├── .gitignore
├── LICENSE
├── README.md
├── attn_gan_pytorch
│   ├── ConfigManagement.py
│   ├── CustomLayers.py
│   ├── Losses.py
│   ├── Networks.py
│   ├── Utils.py
│   └── __init__.py
├── literature
│   └── self_attention_gan.pdf
├── samples
│   ├── .gitignore
│   ├── data_processing
│   │   ├── DataLoader.py
│   │   └── __init__.py
│   ├── generate_loss_plots.py
│   └── sample_celeba
│       ├── .gitignore
│       ├── configs
│       │   ├── 1
│       │   │   ├── dis.conf
│       │   │   └── gen.conf
│       │   ├── 2
│       │   │   ├── dis.conf
│       │   │   └── gen.conf
│       │   └── 3
│       │       ├── dis.conf
│       │       └── gen.conf
│       └── train.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # ignore the pycharm setup
107 | .idea/
108 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Animesh Karnewar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # attn_gan_pytorch
2 | Python package for a Self-Attention GAN implemented as an
3 | extension of PyTorch `nn.Module`.
4 | Paper: https://arxiv.org/abs/1805.08318
5 |
6 | Also includes generic layers for image-based attention mechanisms,
7 | including a **`Full-Attention`** layer as proposed in another
8 | project of mine [here](https://github.com/akanimax/fagan).
9 |
10 | ## Installation:
11 | This is a Python package available on
12 | [**pypi.org**](https://pypi.org/project/attn-gan-pytorch/#description),
13 | so installation is fairly straightforward. The package depends on
14 | a suitable GPU build of **`torch`** and **`torchvision`** for your
15 | architecture, so please install PyTorch before installing
16 | this package; follow the instructions at
17 | [pytorch.org](https://pytorch.org/) to install your version of PyTorch.
18 |
19 | Install with the following commands:
20 |
21 | $ workon [your virtual environment]
22 | $ pip install attn-gan-pytorch
23 |
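A minimal training sketch (illustrative: the config paths mirror
`samples/sample_celeba/configs/`, and `data` stands for any PyTorch
`DataLoader` that yields plain image tensors, e.g. one built with the
helpers in `samples/data_processing/DataLoader.py`):

    import torch as th
    from attn_gan_pytorch.ConfigManagement import get_config
    from attn_gan_pytorch.Utils import get_layer
    from attn_gan_pytorch.Networks import Generator, Discriminator, GAN
    from attn_gan_pytorch.Losses import RelativisticAverageHingeGAN

    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    # build the generator and discriminator from the .conf files
    gen_conf = get_config("configs/1/gen.conf")
    dis_conf = get_config("configs/1/dis.conf")
    gen = Generator([get_layer(layer) for layer in gen_conf.architecture],
                    latent_size=128)
    dis = Discriminator([get_layer(layer) for layer in dis_conf.architecture])

    gan = GAN(gen, dis, device=device)

    # data = <DataLoader yielding image tensors>
    gan.train(data,
              gen_optim=th.optim.Adam(gen.parameters(), lr=0.0001, betas=(0, 0.9)),
              dis_optim=th.optim.Adam(dis.parameters(), lr=0.0004, betas=(0, 0.9)),
              loss_fn=RelativisticAverageHingeGAN(device, dis))
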
24 | ## CelebA Samples:
25 | Some CelebA samples generated using this code with the
26 | fagan architecture:
27 |
28 |
29 |
30 |
31 | ### Head over to the [**Fagan project**](https://github.com/akanimax/fagan) repo for more info!
32 | Also, this repo contains the code for using this package
33 | to build the `SAGAN` architecture described in the paper.
34 | Please refer to the `samples/` directory for this.
35 |
36 | ## Thanks
37 | Please feel free to open PRs here if you train on other datasets
38 | using this package. Suggestions / Issues / Contributions are most
39 | welcome.
40 |
41 | Best regards,
42 | @akanimax :)
43 |
--------------------------------------------------------------------------------
/attn_gan_pytorch/ConfigManagement.py:
--------------------------------------------------------------------------------
1 | """ Module for reading and parsing configuration files """
2 |
3 | import yaml
4 |
5 |
6 | def get_config(conf_file):
7 | """
8 | parse and load the provided configuration
9 | :param conf_file: configuration file
10 | :return: conf => parsed configuration
11 | """
12 | from easydict import EasyDict as edict
13 |
14 | with open(conf_file, "r") as file_descriptor:
15 | data = yaml.safe_load(file_descriptor)
16 |
17 | # convert the data into an easyDictionary
18 | return edict(data)
19 |
20 |
21 | def parse2tuple(inp_str):
22 | """
23 | function for parsing a 2 tuple of integers
24 | :param inp_str: string of the form: '(3, 3)'
25 | :return: tuple => parsed tuple
26 | """
27 | inp_str = inp_str[1: -1]  # remove the parentheses
28 | args = inp_str.split(',')
29 | args = tuple(map(int, args))
30 |
31 | return args
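
# Illustrative usage (the file path is hypothetical; see
# samples/sample_celeba/configs/ for real examples):
#   cfg = get_config("configs/1/gen.conf")
#   cfg.architecture[0].name                     # e.g. "conv_transpose"
#   parse2tuple(cfg.architecture[0].channels)    # "(128, 512)" -> (128, 512)
#   parse2tuple("(3, 3)")                        # -> (3, 3)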
--------------------------------------------------------------------------------
/attn_gan_pytorch/CustomLayers.py:
--------------------------------------------------------------------------------
1 | """ Module implements the custom layers """
2 |
3 | import torch as th
4 |
5 |
6 | class SelfAttention(th.nn.Module):
7 | """
8 | Layer implements the self-attention module
9 | which is the main logic behind this architecture.
10 |
11 | args:
12 | channels: number of channels in the image tensor
13 | activation: activation function to be applied (default: None)
14 | squeeze_factor: squeeze factor for query and keys (default: 8)
15 | bias: whether to apply bias or not (default: True)
16 | """
17 |
18 | def __init__(self, channels, activation=None, squeeze_factor=8, bias=True):
19 | """ constructor for the layer """
20 |
21 | from torch.nn import Conv2d, Parameter, Softmax
22 |
23 | # base constructor call
24 | super().__init__()
25 |
26 | # state of the layer
27 | self.activation = activation
28 | self.gamma = Parameter(th.zeros(1))
29 |
30 | # Modules required for computations
31 | self.query_conv = Conv2d( # query convolution
32 | in_channels=channels,
33 | out_channels=channels // squeeze_factor,
34 | kernel_size=(1, 1),
35 | stride=1,
36 | padding=0,
37 | bias=bias
38 | )
39 |
40 | self.key_conv = Conv2d(
41 | in_channels=channels,
42 | out_channels=channels // squeeze_factor,
43 | kernel_size=(1, 1),
44 | stride=1,
45 | padding=0,
46 | bias=bias
47 | )
48 |
49 | self.value_conv = Conv2d(
50 | in_channels=channels,
51 | out_channels=channels,
52 | kernel_size=(1, 1),
53 | stride=1,
54 | padding=0,
55 | bias=bias
56 | )
57 |
58 | # softmax module for applying attention
59 | self.softmax = Softmax(dim=-1)
60 |
61 | def forward(self, x):
62 | """
63 | forward computations of the layer
64 | :param x: input feature maps (B x C x H x W)
65 | :return:
66 | out: self attention value + input feature (B x O x H x W)
67 | attention: attention map (B x (H*W) x H x W)
68 | """
69 |
70 | # extract the shape of the input tensor
71 | m_batchsize, c, height, width = x.size()
72 |
73 | # create the query projection
74 | proj_query = self.query_conv(x).view(
75 | m_batchsize, -1, width * height).permute(0, 2, 1) # B x (N) x C
76 |
77 | # create the key projection
78 | proj_key = self.key_conv(x).view(
79 | m_batchsize, -1, width * height) # B x C x (N)
80 |
81 | # calculate the attention maps
82 | energy = th.bmm(proj_query, proj_key) # energy
83 | attention = self.softmax(energy) # attention (B x (N) x (N))
84 |
85 | # create the value projection
86 | proj_value = self.value_conv(x).view(
87 | m_batchsize, -1, width * height) # B X C X N
88 |
89 | # calculate the output
90 | out = th.bmm(proj_value, attention.permute(0, 2, 1))
91 | out = out.view(m_batchsize, c, height, width)
92 |
93 | attention = attention.view(m_batchsize, -1, height, width)
94 |
95 | if self.activation is not None:
96 | out = self.activation(out)
97 |
98 | out = self.gamma * out + x
99 | return out, attention
100 |
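# Shape sketch (illustrative): for an input x of shape (2, 64, 16, 16),
#   layer = SelfAttention(channels=64)
#   out, attention = layer(x)
# gives out of shape (2, 64, 16, 16) and attention of shape (2, 256, 16, 16),
# i.e. one (H*W)-long attention row per spatial position, reshaped to (B, H*W, H, W).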
101 |
102 | class FullAttention(th.nn.Module):
103 | """
104 | Layer implements my version of the self-attention module
105 | it is mostly same as self attention, but generalizes to
106 | (k x k) convolutions instead of (1 x 1)
107 | args:
108 | in_channels: number of input channels
109 | out_channels: number of output channels
110 | activation: activation function to be applied (default: None)
111 | kernel_size: kernel size for convolution (default: (1 x 1))
112 | transpose_conv: boolean denoting whether to use convolutions or transpose
113 | convolutions
114 | squeeze_factor: squeeze factor for query and keys (default: 8)
115 | stride: stride for the convolutions (default: 1)
116 | padding: padding for the applied convolutions (default: 0)
117 | bias: whether to apply bias or not (default: True)
118 | """
119 |
120 | def __init__(self, in_channels, out_channels,
121 | activation=None, kernel_size=(1, 1), transpose_conv=False,
122 | use_spectral_norm=True, use_batch_norm=True,
123 | squeeze_factor=8, stride=1, padding=0, bias=True):
124 | """ constructor for the layer """
125 |
126 | from torch.nn import Conv2d, Parameter, \
127 | Softmax, ConvTranspose2d, BatchNorm2d
128 |
129 | # base constructor call
130 | super().__init__()
131 |
132 | # state of the layer
133 | self.activation = activation
134 | self.gamma = Parameter(th.zeros(1))
135 |
136 | self.in_channels = in_channels
137 | self.out_channels = out_channels
138 | self.squeezed_channels = in_channels // squeeze_factor
139 | self.use_batch_norm = use_batch_norm
140 |
141 | # Modules required for computations
142 | if transpose_conv:
143 | self.query_conv = ConvTranspose2d( # query convolution
144 | in_channels=in_channels,
145 | out_channels=in_channels // squeeze_factor,
146 | kernel_size=kernel_size,
147 | stride=stride,
148 | padding=padding,
149 | bias=bias
150 | )
151 |
152 | self.key_conv = ConvTranspose2d(
153 | in_channels=in_channels,
154 | out_channels=in_channels // squeeze_factor,
155 | kernel_size=kernel_size,
156 | stride=stride,
157 | padding=padding,
158 | bias=bias
159 | )
160 |
161 | self.value_conv = ConvTranspose2d(
162 | in_channels=in_channels,
163 | out_channels=out_channels,
164 | kernel_size=kernel_size,
165 | stride=stride,
166 | padding=padding,
167 | bias=bias
168 | )
169 |
170 | self.residual_conv = ConvTranspose2d(
171 | in_channels=in_channels,
172 | out_channels=out_channels,
173 | kernel_size=kernel_size,
174 | stride=stride,
175 | padding=padding,
176 | bias=bias
177 | ) if not use_spectral_norm else SpectralNorm(
178 | ConvTranspose2d(
179 | in_channels=in_channels,
180 | out_channels=out_channels,
181 | kernel_size=kernel_size,
182 | stride=stride,
183 | padding=padding,
184 | bias=bias
185 | )
186 | )
187 |
188 | else:
189 | self.query_conv = Conv2d( # query convolution
190 | in_channels=in_channels,
191 | out_channels=in_channels // squeeze_factor,
192 | kernel_size=kernel_size,
193 | stride=stride,
194 | padding=padding,
195 | bias=bias
196 | )
197 |
198 | self.key_conv = Conv2d(
199 | in_channels=in_channels,
200 | out_channels=in_channels // squeeze_factor,
201 | kernel_size=kernel_size,
202 | stride=stride,
203 | padding=padding,
204 | bias=bias
205 | )
206 |
207 | self.value_conv = Conv2d(
208 | in_channels=in_channels,
209 | out_channels=out_channels,
210 | kernel_size=kernel_size,
211 | stride=stride,
212 | padding=padding,
213 | bias=bias
214 | )
215 |
216 | self.residual_conv = Conv2d(
217 | in_channels=in_channels,
218 | out_channels=out_channels,
219 | kernel_size=kernel_size,
220 | stride=stride,
221 | padding=padding,
222 | bias=bias
223 | ) if not use_spectral_norm else SpectralNorm(
224 | Conv2d(
225 | in_channels=in_channels,
226 | out_channels=out_channels,
227 | kernel_size=kernel_size,
228 | stride=stride,
229 | padding=padding,
230 | bias=bias
231 | )
232 | )
233 |
234 | # softmax module for applying attention
235 | self.softmax = Softmax(dim=-1)
236 | self.batch_norm = BatchNorm2d(out_channels)
237 |
238 | def forward(self, x):
239 | """
240 | forward computations of the layer
241 | :param x: input feature maps (B x C x H x W)
242 | :return:
243 | out: self attention value + input feature (B x O x H x W)
244 | attention: attention map (B x N x H x W), where N is the number of attended spatial positions
245 | """
246 |
247 | # extract the batch size of the input tensor
248 | m_batchsize, _, _, _ = x.size()
249 |
250 | # create the query projection
251 | proj_query = self.query_conv(x).view(
252 | m_batchsize, self.squeezed_channels, -1).permute(0, 2, 1) # B x (N) x C
253 |
254 | # create the key projection
255 | proj_key = self.key_conv(x).view(
256 | m_batchsize, self.squeezed_channels, -1) # B x C x (N)
257 |
258 | # calculate the attention maps
259 | energy = th.bmm(proj_query, proj_key) # energy
260 | attention = self.softmax(energy) # attention (B x (N) x (N))
261 |
262 | # create the value projection
263 | proj_value = self.value_conv(x).view(
264 | m_batchsize, self.out_channels, -1) # B X C X N
265 |
266 | # calculate the output
267 | out = th.bmm(proj_value, attention.permute(0, 2, 1))
268 |
269 | # calculate the residual output
270 | res_out = self.residual_conv(x)
271 |
272 | out = out.view(m_batchsize, self.out_channels,
273 | res_out.shape[-2], res_out.shape[-1])
274 |
275 | attention = attention.view(m_batchsize, -1,
276 | res_out.shape[-2], res_out.shape[-1])
277 |
278 | if self.use_batch_norm:
279 | res_out = self.batch_norm(res_out)
280 |
281 | if self.activation is not None:
282 | out = self.activation(out)
283 | res_out = self.activation(res_out)
284 |
285 | # apply the residual connections
286 | out = (self.gamma * out) + ((1 - self.gamma) * res_out)
287 | return out, attention
288 |
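# Shape sketch (illustrative): with a size-preserving (3 x 3) convolution,
#   layer = FullAttention(in_channels=64, out_channels=128,
#                         kernel_size=(3, 3), padding=1)
#   out, attention = layer(th.randn(2, 64, 16, 16))
# gives out of shape (2, 128, 16, 16); gamma (initialised to 0) blends the
# attention branch with the residual convolution branch.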
289 |
290 | class SpectralNorm(th.nn.Module):
291 | """
292 | Wrapper around a Torch module which applies spectral normalization
293 | """
294 |
295 | # TODO complete the documentation for this Layer
296 |
297 | def __init__(self, module, name='weight', power_iterations=1):
298 | super(SpectralNorm, self).__init__()
299 | self.module = module
300 | self.name = name
301 | self.power_iterations = power_iterations
302 | if not self._made_params():
303 | self._make_params()
304 |
305 | @staticmethod
306 | def l2normalize(v, eps=1e-12):
307 | return v / (v.norm() + eps)
308 |
309 | def _update_u_v(self):
310 | u = getattr(self.module, self.name + "_u")
311 | v = getattr(self.module, self.name + "_v")
312 | w = getattr(self.module, self.name + "_bar")
313 |
314 | height = w.data.shape[0]
315 | for _ in range(self.power_iterations):
316 | v.data = self.l2normalize(th.mv(th.t(w.view(height, -1).data), u.data))
317 | u.data = self.l2normalize(th.mv(w.view(height, -1).data, v.data))
318 |
319 | sigma = u.dot(w.view(height, -1).mv(v))
320 | setattr(self.module, self.name, w / sigma.expand_as(w))
321 |
322 | def _made_params(self):
323 | try:
324 | getattr(self.module, self.name + "_u")
325 | getattr(self.module, self.name + "_v")
326 | getattr(self.module, self.name + "_bar")
327 | return True
328 | except AttributeError:
329 | return False
330 |
331 | def _make_params(self):
332 | from torch.nn import Parameter
333 |
334 | w = getattr(self.module, self.name)
335 |
336 | height = w.data.shape[0]
337 | width = w.view(height, -1).data.shape[1]
338 |
339 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
340 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
341 | u.data = self.l2normalize(u.data)
342 | v.data = self.l2normalize(v.data)
343 | w_bar = Parameter(w.data)
344 |
345 | del self.module._parameters[self.name]
346 |
347 | self.module.register_parameter(self.name + "_u", u)
348 | self.module.register_parameter(self.name + "_v", v)
349 | self.module.register_parameter(self.name + "_bar", w_bar)
350 |
351 | def forward(self, *args):
352 | self._update_u_v()
353 | return self.module.forward(*args)
354 |
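# Illustrative usage: wrap any module that owns a `weight` parameter, e.g.
#   snorm_conv = SpectralNorm(th.nn.Conv2d(3, 64, (4, 4), stride=2, padding=1))
# Each forward pass runs one power iteration and rescales the weight by its
# estimated largest singular value before delegating to the wrapped module.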
355 |
356 | class IgnoreAttentionMap(th.nn.Module):
357 | """
358 | A petty module to ignore the attention
359 | map output by the self_attention layer
360 | """
361 |
362 | def __init__(self):
363 | """ has nothing and does nothing apart from super calls """
364 | super().__init__()
365 |
366 | def forward(self, inp):
367 | """
368 | ignores the attention map in the obtained input and returns only the features
369 | :param inp: (features, attention_maps)
370 | :return: output => features
371 | """
372 | return inp[0]
373 |
--------------------------------------------------------------------------------
/attn_gan_pytorch/Losses.py:
--------------------------------------------------------------------------------
1 | """ Module implementing various loss functions """
2 |
3 | import torch as th
4 |
5 |
6 | # TODO (complete) Major rewrite: change the interface to use only predictions
7 | # for real and fake samples.
8 | # Resolution: the interface does not change to predictions-only, because
9 | # losses such as WGAN-GP require the samples themselves to compute the gradient penalty.
10 |
11 | class GANLoss:
12 | """
13 | Base class for all losses.
14 | Note that gen_loss also receives the real samples (some losses need them).
15 | """
16 |
17 | def __init__(self, device, dis):
18 | self.device = device
19 | self.dis = dis
20 |
21 | def dis_loss(self, real_samps, fake_samps):
22 | raise NotImplementedError("dis_loss method has not been implemented")
23 |
24 | def gen_loss(self, real_samps, fake_samps):
25 | raise NotImplementedError("gen_loss method has not been implemented")
26 |
27 | def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
28 | raise NotImplementedError("conditional_dis_loss method has not been implemented")
29 |
30 | def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
31 | raise NotImplementedError("conditional_gen_loss method has not been implemented")
32 |
33 |
34 | class StandardGAN(GANLoss):
35 |
36 | def __init__(self, dev, dis):
37 | from torch.nn import BCELoss
38 |
39 | super().__init__(dev, dis)
40 |
41 | # define the criterion object
42 | self.criterion = BCELoss()
43 |
44 | def dis_loss(self, real_samps, fake_samps):
45 | # calculate the real loss:
46 | real_loss = self.criterion(th.squeeze(self.dis(real_samps)),
47 | th.ones(real_samps.shape[0]).to(self.device))
48 | # calculate the fake loss:
49 | fake_loss = self.criterion(th.squeeze(self.dis(fake_samps)),
50 | th.zeros(fake_samps.shape[0]).to(self.device))
51 |
52 | # return final loss as average of the two:
53 | return (real_loss + fake_loss) / 2
54 |
55 | def gen_loss(self, _, fake_samps):
56 | return self.criterion(th.squeeze(self.dis(fake_samps)),
57 | th.ones(fake_samps.shape[0]).to(self.device))
58 |
59 | def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
60 | # calculate the real loss:
61 | real_loss = self.criterion(th.squeeze(self.dis(real_samps, conditional_vectors)),
62 | th.ones(real_samps.shape[0]).to(self.device))
63 | # calculate the fake loss:
64 | fake_loss = self.criterion(th.squeeze(self.dis(fake_samps, conditional_vectors)),
65 | th.zeros(fake_samps.shape[0]).to(self.device))
66 |
67 | # return final loss as average of the two:
68 | return (real_loss + fake_loss) / 2
69 |
70 | def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
71 | return self.criterion(th.squeeze(self.dis(fake_samps, conditional_vectors)),
72 | th.ones(fake_samps.shape[0]).to(self.device))
73 |
74 |
75 | class LSGAN(GANLoss):
76 |
77 | def __init__(self, device, dis):
78 | super().__init__(device, dis)
79 |
80 | def dis_loss(self, real_samps, fake_samps):
81 | return 0.5 * (((th.mean(self.dis(real_samps)) - 1) ** 2)
82 | + (th.mean(self.dis(fake_samps))) ** 2)
83 |
84 | def gen_loss(self, _, fake_samps):
85 | return 0.5 * ((th.mean(self.dis(fake_samps)) - 1) ** 2)
86 |
87 | def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
88 | return 0.5 * (((th.mean(self.dis(real_samps, conditional_vectors)) - 1) ** 2)
89 | + (th.mean(self.dis(fake_samps, conditional_vectors))) ** 2)
90 |
91 | def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
92 | return 0.5 * ((th.mean(self.dis(fake_samps, conditional_vectors)) - 1) ** 2)
93 |
94 |
95 | class HingeGAN(GANLoss):
96 |
97 | def __init__(self, device, dis):
98 | super().__init__(device, dis)
99 |
100 | def dis_loss(self, real_samps, fake_samps):
101 | return (th.mean(th.nn.ReLU()(1 - self.dis(real_samps))) +
102 | th.mean(th.nn.ReLU()(1 + self.dis(fake_samps))))
103 |
104 | def gen_loss(self, real_samps, fake_samps):
105 | return -th.mean(self.dis(fake_samps))
106 |
107 | def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
108 | return (th.mean(th.nn.ReLU()(1 - self.dis(real_samps, conditional_vectors))) +
109 | th.mean(th.nn.ReLU()(1 + self.dis(fake_samps, conditional_vectors))))
110 |
111 | def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
112 | return -th.mean(self.dis(fake_samps, conditional_vectors))
113 |
114 |
115 | class RelativisticAverageHingeGAN(GANLoss):
116 |
117 | def __init__(self, device, dis):
118 | super().__init__(device, dis)
119 |
120 | def dis_loss(self, real_samps, fake_samps):
121 | # difference between real and fake:
122 | r_f_diff = self.dis(real_samps) - th.mean(self.dis(fake_samps))
123 |
124 | # difference between fake and real samples
125 | f_r_diff = self.dis(fake_samps) - th.mean(self.dis(real_samps))
126 |
127 | # return the loss
128 | return (th.mean(th.nn.ReLU()(1 - r_f_diff))
129 | + th.mean(th.nn.ReLU()(1 + f_r_diff)))
130 |
131 | def gen_loss(self, real_samps, fake_samps):
132 | # difference between real and fake:
133 | r_f_diff = self.dis(real_samps) - th.mean(self.dis(fake_samps))
134 |
135 | # difference between fake and real samples
136 | f_r_diff = self.dis(fake_samps) - th.mean(self.dis(real_samps))
137 |
138 | # return the loss
139 | return (th.mean(th.nn.ReLU()(1 + r_f_diff))
140 | + th.mean(th.nn.ReLU()(1 - f_r_diff)))
141 |
142 | def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
143 | # difference between real and fake:
144 | r_f_diff = self.dis(real_samps, conditional_vectors) \
145 | - th.mean(self.dis(fake_samps, conditional_vectors))
146 |
147 | # difference between fake and real samples
148 | f_r_diff = self.dis(fake_samps, conditional_vectors) \
149 | - th.mean(self.dis(real_samps, conditional_vectors))
150 |
151 | # return the loss
152 | return (th.mean(th.nn.ReLU()(1 - r_f_diff))
153 | + th.mean(th.nn.ReLU()(1 + f_r_diff)))
154 |
155 | def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
156 | # difference between real and fake:
157 | r_f_diff = self.dis(real_samps, conditional_vectors) \
158 | - th.mean(self.dis(fake_samps, conditional_vectors))
159 |
160 | # difference between fake and real samples
161 | f_r_diff = self.dis(fake_samps, conditional_vectors) \
162 | - th.mean(self.dis(real_samps, conditional_vectors))
163 |
164 | # return the loss
165 | return (th.mean(th.nn.ReLU()(1 + r_f_diff))
166 | + th.mean(th.nn.ReLU()(1 - f_r_diff)))
167 |
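# Illustrative usage (assumes an initialised discriminator `dis`):
#   loss_fn = RelativisticAverageHingeGAN(device, dis)
#   d_loss = loss_fn.dis_loss(real_batch, fake_batch)
#   g_loss = loss_fn.gen_loss(real_batch, fake_batch)
# The same loss_fn object is what Networks.GAN.train() expects.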
--------------------------------------------------------------------------------
/attn_gan_pytorch/Networks.py:
--------------------------------------------------------------------------------
1 | """ module implements the networks functionality """
2 |
3 | import torch as th
4 | import numpy as np
5 | import timeit
6 | import datetime
7 | import time
8 | import os
9 |
10 |
11 | class Network(th.nn.Module):
12 | """ General module that creates a Network from the configuration provided
13 | Extends a PyTorch Module
14 | args:
15 | modules: list of PyTorch layers (nn.Modules)
16 | """
17 |
18 | def __init__(self, modules):
19 | """ derived constructor """
20 |
21 | # make a call to Module constructor for allowing
22 | # us to attach required modules
23 | super().__init__()
24 |
25 | self.model = th.nn.Sequential(*modules)
26 |
27 | def forward(self, x):
28 | """
29 | forward computations
30 | :param x: input
31 | :return: y => output features volume
32 | """
33 | return self.model(x)
34 |
35 |
36 | class Generator(Network):
37 | """
38 | Generator is an extension of a Generic Network
39 |
40 | args:
41 | modules: same as for Network
42 | latent_size: latent size of the Generator (GAN)
43 | """
44 |
45 | def __init__(self, modules, latent_size):
46 | super().__init__(modules)
47 |
48 | # attach the latent size for the GAN here
49 | self.latent_size = latent_size
50 |
51 |
52 | class ConditionalGenerator(Generator):
53 | """ Conditional Generator is a special case of a generator
54 | well nothing special more than just the name. Nevertheless,
55 | something does lie in name. (Niki is the name
56 | I can't stop thinking about :blush:)
57 |
58 | args:
59 | modules: same as for Network
60 | latent_size: latent size of the Generator (GAN)
61 | Note that latent_size also includes the size of the
62 | conditional labels
63 | """
64 | pass
65 |
66 |
67 | class Discriminator(Network):
68 | pass
69 |
70 |
71 | class ConditionalDiscriminator(Discriminator):
72 | """
73 | The conditional variant of the Discriminator; the plain Discriminator
74 | sits just above it in the Network class tree.
75 |
76 | args:
77 | modules: Note that this list of modules must not contain the final prediction
78 | layer. This only reduces the spatial dimension to
79 | (reduced_height x reduced_width) specifically.
80 | last_module: th.nn.Module which makes the conditional prediction
81 | (applied after the conditional embedding is concatenated to the reduced volume)
82 | """
83 |
84 | def __init__(self, modules, last_module):
85 | super().__init__(modules)
86 |
87 | # attach the last module separately here:
88 | self.last_module = last_module
89 |
90 | # adding the last projector conv layer which
91 | # concatenates the text embedding prior to prediction
92 | # calculation.
93 |
94 | def forward(self, x, embedding):
95 | """
96 | The forward pass of the Conditional Discriminator.
97 | :param x: input images tensor
98 | :param embedding: conditional vector
99 | :return: predictions => scores for the inputs
100 | """
101 | # obtain the reduced volume:
102 | reduced_volume = super().forward(x)
103 |
104 | # concatenate the embeddings to reduced_volume here:
105 | cat = th.unsqueeze(th.unsqueeze(embedding, -1), -1)
106 | # spatial replication
107 | cat = cat.expand(cat.shape[0], cat.shape[1],
108 | reduced_volume.shape[2], reduced_volume.shape[3])
109 | final_input = th.cat((reduced_volume, cat), dim=1)
110 |
111 | # apply the last module to obtain the predictions:
112 | prediction_scores = self.last_module(final_input)
113 |
114 | # return the prediction scores:
115 | return prediction_scores
116 |
117 |
118 | class GAN:
119 | """
120 | Unconditional GAN
121 |
122 | args:
123 | gen: Generator object
124 | dis: Discriminator object
125 | device: torch.device() for running on GPU or CPU
126 | default = torch.device("cpu")
127 | """
128 |
129 | def __init__(self, gen, dis,
130 | device=th.device("cpu")):
131 | """ constructor for the class """
132 | assert isinstance(gen, Generator), "gen is not an Unconditional Generator"
133 | assert isinstance(dis, Discriminator), "dis is not an Unconditional Discriminator"
134 |
135 | # define the state of the object
136 | self.generator = gen.to(device)
137 | self.discriminator = dis.to(device)
138 | self.device = device
139 |
140 | # by default the generator and discriminator are in eval mode
141 | self.generator.eval()
142 | self.discriminator.eval()
143 |
144 | def generate_samples(self, num_samples):
145 | """
146 | generate samples using this gan
147 | :param num_samples: number of samples to be generated
148 | :return: generated samples tensor: (B x H x W x C)
149 | """
150 | noise = th.randn(num_samples, self.generator.latent_size).to(self.device)
151 | generated_images = self.generator(noise).detach()
152 |
153 | # reshape the generated images
154 | generated_images = generated_images.permute(0, 2, 3, 1)
155 |
156 | return generated_images
157 |
158 | def optimize_discriminator(self, dis_optim, noise, real_batch, loss_fn):
159 | """
160 | performs one step of weight update on discriminator using the batch of data
161 | :param dis_optim: discriminator optimizer
162 | :param noise: input noise of sample generation
163 | :param real_batch: real samples batch
164 | :param loss_fn: loss function to be used (object of GANLoss)
165 | :return: current loss
166 | """
167 |
168 | # generate a batch of samples
169 | fake_samples = self.generator(noise).detach()
170 |
171 | loss = loss_fn.dis_loss(real_batch, fake_samples)
172 |
173 | # optimize discriminator
174 | dis_optim.zero_grad()
175 | loss.backward()
176 | dis_optim.step()
177 |
178 | return loss.item()
179 |
180 | def optimize_generator(self, gen_optim, noise, real_batch, loss_fn):
181 | """
182 | performs one step of weight update on generator using the batch of data
183 | :param gen_optim: generator optimizer
184 | :param noise: input noise of sample generation
185 | :param real_batch: real samples batch
186 | :param loss_fn: loss function to be used (object of GANLoss)
187 | :return: current loss
188 | """
189 |
190 | # generate a batch of samples
191 | fake_samples = self.generator(noise)
192 |
193 | loss = loss_fn.gen_loss(real_batch, fake_samples)
194 |
195 | # optimize generator
196 | gen_optim.zero_grad()
197 | loss.backward()
198 | gen_optim.step()
199 |
200 | return loss.item()
201 |
202 | @staticmethod
203 | def create_grid(samples, img_file):
204 | """
205 | utility function to create a grid of GAN samples
206 | :param samples: generated samples for storing
207 | :param img_file: name of file to write
208 | :return: None (saves a file)
209 | """
210 | from torchvision.utils import save_image
211 | from numpy import sqrt
212 |
213 | samples = th.clamp((samples / 2) + 0.5, min=0, max=1)
214 |
215 | # save the images:
216 | save_image(samples, img_file, nrow=int(sqrt(samples.shape[0])))
217 |
218 | def train(self, data, gen_optim, dis_optim, loss_fn,
219 | start=1, num_epochs=12, feedback_factor=10, checkpoint_factor=1,
220 | data_percentage=100, num_samples=36,
221 | log_dir=None, sample_dir="./samples",
222 | save_dir="./models"):
223 |
224 | # TODO write the documentation for this method
225 |
226 | # turn the generator and discriminator into train mode
227 | self.generator.train()
228 | self.discriminator.train()
229 |
230 | assert isinstance(gen_optim, th.optim.Optimizer), \
231 | "gen_optim is not an Optimizer"
232 | assert isinstance(dis_optim, th.optim.Optimizer), \
233 | "dis_optim is not an Optimizer"
234 |
235 | print("Starting the training process ... ")
236 |
237 | # create fixed_input for debugging
238 | fixed_input = th.randn(num_samples,
239 | self.generator.latent_size, 1, 1).to(self.device)
240 |
241 | # create a global time counter
242 | global_time = time.time()
243 |
244 | for epoch in range(start, num_epochs + 1):
245 | start = timeit.default_timer() # record time at the start of epoch
246 |
247 | print("\nEpoch: %d" % epoch)
248 | total_batches = len(iter(data))
249 |
250 | limit = int((data_percentage / 100) * total_batches)
251 |
252 | for (i, batch) in enumerate(data, 1):
253 |
254 | # extract current batch of data for training
255 | images = batch.to(self.device)
256 |
257 | gan_input = th.randn(images.shape[0],
258 | self.generator.latent_size, 1, 1).to(self.device)
259 |
260 | # optimize the discriminator:
261 | dis_loss = self.optimize_discriminator(dis_optim, gan_input,
262 | images, loss_fn)
263 |
264 | # optimize the generator:
265 | # resample from the latent noise
266 | gan_input = th.randn(images.shape[0],
267 | self.generator.latent_size, 1, 1).to(self.device)
268 | gen_loss = self.optimize_generator(gen_optim, gan_input,
269 | images, loss_fn)
270 |
271 | # provide a loss feedback
272 | if i % int(limit / feedback_factor) == 0 or i == 1:
273 | elapsed = time.time() - global_time
274 | elapsed = str(datetime.timedelta(seconds=elapsed))
275 | print("Elapsed [%s] batch: %d d_loss: %f g_loss: %f"
276 | % (elapsed, i, dis_loss, gen_loss))
277 |
278 | # also write the losses to the log file:
279 | if log_dir is not None:
280 | log_file = os.path.join(log_dir, "loss.log")
281 | os.makedirs(os.path.dirname(log_file), exist_ok=True)
282 | with open(log_file, "a") as log:
283 | log.write(str(dis_loss) + "\t" + str(gen_loss) + "\n")
284 |
285 | # create a grid of samples and save it
286 | os.makedirs(sample_dir, exist_ok=True)
287 | gen_img_file = os.path.join(sample_dir, "gen_" +
288 | str(epoch) + "_" +
289 | str(i) + ".png")
290 | self.create_grid(self.generator(fixed_input).detach(), gen_img_file)
291 |
292 | if i > limit:
293 | break
294 |
295 | # calculate the time required for the epoch
296 | stop = timeit.default_timer()
297 | print("Time taken for epoch: %.3f secs" % (stop - start))
298 |
299 | if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == num_epochs:
300 | os.makedirs(save_dir, exist_ok=True)
301 | gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(epoch) + ".pth")
302 | dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(epoch) + ".pth")
303 |
304 | th.save(self.generator.state_dict(), gen_save_file)
305 | th.save(self.discriminator.state_dict(), dis_save_file)
306 |
307 | print("Training completed ...")
308 |
309 | # return the generator and discriminator back to eval mode
310 | self.generator.eval()
311 | self.discriminator.eval()
312 |
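# Usage sketch (illustrative; the module lists would normally be built from
# .conf files via attn_gan_pytorch.Utils.get_layer):
#   gan = GAN(Generator(gen_modules, latent_size=128),
#             Discriminator(dis_modules),
#             device=th.device("cuda"))
#   gan.train(data_loader,
#             th.optim.Adam(gan.generator.parameters(), lr=0.0001, betas=(0, 0.9)),
#             th.optim.Adam(gan.discriminator.parameters(), lr=0.0004, betas=(0, 0.9)),
#             loss_fn)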
313 |
314 | # TODO (complete): implement conditional GAN variant of this
315 | # (conditional GAN implemented below)
316 |
317 | class ConditionalGAN(GAN):
318 | """
319 | Conditional GAN. Actually modifies the calls
320 | for optimize discriminator, optimize generator and train
321 |
322 | args:
323 | gen: ConditionalGenerator object
324 | dis: ConditionalDiscriminator object
325 | device: torch.device() for running on GPU or CPU
326 | default = torch.device("cpu")
327 | """
328 |
329 | def __init__(self, gen, dis, device=th.device("cpu")):
330 | """ constructor for this derived class """
331 |
332 | # some more specific checks here
333 | assert isinstance(gen, ConditionalGenerator), \
334 | "gen is not a Conditional Generator"
335 | assert isinstance(dis, ConditionalDiscriminator), \
336 | "dis is not a Conditional Discriminator"
337 |
338 | super().__init__(gen, dis, device)
339 |
340 | @staticmethod
341 | def randomize(correct_labels):
342 | """
343 | static helper for mismatching the given labels
344 | :param correct_labels: input correct labels
345 | :return: shuffled labels
346 | (Note, that this behaviour is not
347 | guaranteed to create a mismatch for every sample)
348 | """
349 | return correct_labels[np.random.permutation(correct_labels.shape[0]), :]
350 |
351 | def optimize_discriminator(self, dis_optim, noise, real_batch, loss_fn,
352 | conditional_vectors, matching_aware=False,
353 | randomizer=None):
354 | """
355 | performs one step of weight update on discriminator using the batch of data
356 | :param dis_optim: discriminator optimizer
357 | :param noise: input noise of sample generation
358 | :param real_batch: real samples batch
359 | :param loss_fn: loss function to be used (object of GANLoss)
360 | :param conditional_vectors: for conditional discrimination
361 | :param matching_aware: boolean for whether to use matching aware discriminator
362 | :param randomizer: function object for randomizing the conditional vectors.
363 | i.e. to mismatch conditional vectors
364 | uses the default randomize function here
365 | :return: current loss
366 | """
367 |
368 | # generate a batch of samples
369 | fake_samples = self.generator(noise).detach()
370 |
371 | loss = loss_fn.conditional_dis_loss(real_batch, fake_samples,
372 | conditional_vectors)
373 |
374 | # if matching aware discrimination is to be used:
375 | if matching_aware:
376 | loss += loss_fn.conditional_dis_loss(
377 | real_batch, real_batch,
378 | randomizer(conditional_vectors)
379 | if randomizer is not None
380 | else self.randomize(conditional_vectors)
381 | )
382 | loss = loss / 2
383 |
384 | # optimize discriminator
385 | dis_optim.zero_grad()
386 | loss.backward()
387 | dis_optim.step()
388 |
389 | return loss.item()
390 |
391 | def optimize_generator(self, gen_optim, noise, real_batch, loss_fn,
392 | conditional_vectors):
393 | """
394 | performs one step of weight update on generator using the batch of data
395 | :param gen_optim: generator optimizer
396 | :param noise: input noise of sample generation
397 | :param real_batch: real samples batch
398 | :param loss_fn: loss function to be used (object of GANLoss)
399 | :param conditional_vectors: for conditional discrimination
400 | :return: current loss
401 | """
402 |
403 | # generate a batch of samples
404 | fake_samples = self.generator(noise)
405 |
406 | loss = loss_fn.conditional_gen_loss(real_batch, fake_samples,
407 | conditional_vectors)
408 |
409 | # optimize generator
410 | gen_optim.zero_grad()
411 | # retain graph is true for applying regularization on the
412 | # conditional input
413 | loss.backward(retain_graph=True)
414 | gen_optim.step()
415 |
416 | return loss.item()
417 |
418 | def train(self, data, gen_optim, dis_optim, loss_fn,
419 | start=1, num_epochs=12, feedback_factor=10, checkpoint_factor=1,
420 | data_percentage=100, num_samples=36,
421 | matching_aware=False, mismatcher=None,
422 | log_dir=None, sample_dir="./samples",
423 | save_dir="./models"):
424 |
425 | # TODO write the documentation for this method
426 | # This is the limit of procrastination now :D
427 | # Just note that data here gives image, label (one-hot encoded)
428 | # in every batch
429 |
430 | # turn the generator and discriminator into train mode
431 | self.generator.train()
432 | self.discriminator.train()
433 |
434 | assert isinstance(gen_optim, th.optim.Optimizer), \
435 | "gen_optim is not an Optimizer"
436 | assert isinstance(dis_optim, th.optim.Optimizer), \
437 | "dis_optim is not an Optimizer"
438 |
439 | print("Starting the training process ... ")
440 |
441 | # create fixed_input for debugging
442 | _, debug_labels = next(iter(data))
443 | debug_labels = th.unsqueeze(th.unsqueeze(debug_labels, -1), -1).to(self.device)
444 | fixed_latent_vectors = th.randn(
445 | num_samples,
446 | self.generator.latent_size - debug_labels.shape[1],
447 | 1, 1
448 | ).to(self.device)
449 |
450 | fixed_input = th.cat((fixed_latent_vectors, debug_labels), dim=1)
451 |
452 | # create a global time counter
453 | global_time = time.time()
454 |
455 | for epoch in range(start, num_epochs + 1):
456 | start = timeit.default_timer() # record time at the start of epoch
457 |
458 | print("\nEpoch: %d" % epoch)
459 | total_batches = len(iter(data))
460 |
461 | limit = int((data_percentage / 100) * total_batches)
462 |
463 | for (i, batch) in enumerate(data, 1):
464 |
465 | # extract current batch of data for training
466 | images, labels = batch
467 | images, labels = images.to(self.device), labels.to(self.device)
468 | expanded_labels = th.unsqueeze(th.unsqueeze(labels, -1), -1)
469 |
470 | latent_input = th.randn(
471 | images.shape[0],
472 | self.generator.latent_size - expanded_labels.shape[1],
473 | 1, 1
474 | ).to(self.device)
475 |
476 | gan_input = th.cat((latent_input, expanded_labels), dim=1)
477 |
478 | # optimize the discriminator:
479 | dis_loss = self.optimize_discriminator(dis_optim, gan_input,
480 | images, loss_fn, labels,
481 | matching_aware, mismatcher)
482 |
483 | # optimize the generator:
484 | # resample from the latent noise
485 | latent_input = th.randn(
486 | images.shape[0],
487 | self.generator.latent_size - expanded_labels.shape[1],
488 | 1, 1
489 | ).to(self.device)
490 | gan_input = th.cat((latent_input, expanded_labels), dim=1)
491 | gen_loss = self.optimize_generator(gen_optim, gan_input,
492 | images, loss_fn, labels)
493 |
494 | # provide a loss feedback
495 | if i % int(limit / feedback_factor) == 0 or i == 1:
496 | elapsed = time.time() - global_time
497 | elapsed = str(datetime.timedelta(seconds=elapsed))
498 | print("Elapsed [%s] batch: %d d_loss: %f g_loss: %f"
499 | % (elapsed, i, dis_loss, gen_loss))
500 |
501 | # also write the losses to the log file:
502 | if log_dir is not None:
503 | log_file = os.path.join(log_dir, "loss.log")
504 | os.makedirs(os.path.dirname(log_file), exist_ok=True)
505 | with open(log_file, "a") as log:
506 | log.write(str(dis_loss) + "\t" + str(gen_loss) + "\n")
507 |
508 | # create a grid of samples and save it
509 | os.makedirs(sample_dir, exist_ok=True)
510 | gen_img_file = os.path.join(sample_dir, "gen_" +
511 | str(epoch) + "_" +
512 | str(i) + ".png")
513 | self.create_grid(self.generator(fixed_input).detach(), gen_img_file)
514 |
515 | if i > limit:
516 | break
517 |
518 | # calculate the time required for the epoch
519 | stop = timeit.default_timer()
520 | print("Time taken for epoch: %.3f secs" % (stop - start))
521 |
522 | if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == num_epochs:
523 | os.makedirs(save_dir, exist_ok=True)
524 | gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(epoch) + ".pth")
525 | dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(epoch) + ".pth")
526 |
527 | th.save(self.generator.state_dict(), gen_save_file)
528 | th.save(self.discriminator.state_dict(), dis_save_file)
529 |
530 | print("Training completed ...")
531 |
532 | # return the generator and discriminator back to eval mode
533 | self.generator.eval()
534 | self.discriminator.eval()
535 |
--------------------------------------------------------------------------------
/attn_gan_pytorch/Utils.py:
--------------------------------------------------------------------------------
1 | """ module contains small utils for parsing configurations """
2 |
3 | import torch as th
4 |
5 |
6 | def get_act_fn(fn_name):
7 | """
8 | helper for creating the activation function
9 | :param fn_name: string containing act_fn name
10 | currently supports: [tanh, sigmoid, relu, lrelu]
11 | :return: fn => PyTorch activation function
12 | """
13 | fn_name = fn_name.lower()
14 |
15 | if fn_name == "tanh":
16 | fn = th.nn.Tanh()
17 |
18 | elif fn_name == "sigmoid":
19 | fn = th.nn.Sigmoid()
20 |
21 | elif fn_name == "relu":
22 | fn = th.nn.ReLU()
23 |
24 | elif "lrelu" in fn_name:
25 | negative_slope = float(fn_name.split("(")[-1][:-1])
26 | fn = th.nn.LeakyReLU(negative_slope=negative_slope)
27 |
28 | else:
29 | raise NotImplementedError("requested activation function is not implemented")
30 |
31 | return fn
32 |
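# Illustrative examples:
#   get_act_fn("relu")        -> th.nn.ReLU()
#   get_act_fn("lrelu(0.2)")  -> th.nn.LeakyReLU(negative_slope=0.2)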
33 |
34 | def get_layer(layer):
35 | """
36 | helper for creating a layer from the given configuration entry
37 | :param layer: dict containing info
38 | :return: lay => PyTorch layer
39 | """
40 | from attn_gan_pytorch.CustomLayers import SelfAttention, \
41 | SpectralNorm, IgnoreAttentionMap, FullAttention
42 | from torch.nn import Sequential, Conv2d, Dropout2d, ConvTranspose2d, BatchNorm2d
43 | from attn_gan_pytorch.ConfigManagement import parse2tuple
44 |
45 | # lowercase the name
46 | name = layer.name.lower()
47 |
48 | if name == "conv":
49 | in_channels, out_channels = parse2tuple(layer.channels)
50 | kernel_size = parse2tuple(layer.kernel_dims)
51 | stride = parse2tuple(layer.stride)
52 | padding = parse2tuple(layer.padding)
53 | bias = layer.bias
54 | act_fn = get_act_fn(layer.activation)
55 |
56 | if hasattr(layer, "spectral_norm") and layer.spectral_norm:
57 | if layer.batch_norm:
58 | mod_layer = Sequential(
59 | SpectralNorm(Conv2d(in_channels, out_channels, kernel_size,
60 | stride, padding, bias=bias)),
61 | BatchNorm2d(out_channels),
62 | act_fn
63 | )
64 | else:
65 | mod_layer = Sequential(
66 | SpectralNorm(Conv2d(in_channels, out_channels, kernel_size,
67 | stride, padding, bias=bias)),
68 | act_fn
69 | )
70 | else:
71 | if layer.batch_norm:
72 | mod_layer = Sequential(
73 | Conv2d(in_channels, out_channels, kernel_size,
74 | stride, padding, bias=bias),
75 | BatchNorm2d(out_channels),
76 | act_fn
77 | )
78 | else:
79 | mod_layer = Sequential(
80 | Conv2d(in_channels, out_channels, kernel_size,
81 | stride, padding, bias=bias),
82 | act_fn
83 | )
84 |
85 | elif name == "conv_transpose":
86 | in_channels, out_channels = parse2tuple(layer.channels)
87 | kernel_size = parse2tuple(layer.kernel_dims)
88 | stride = parse2tuple(layer.stride)
89 | padding = parse2tuple(layer.padding)
90 | bias = layer.bias
91 | act_fn = get_act_fn(layer.activation)
92 |
93 | if hasattr(layer, "spectral_norm") and layer.spectral_norm:
94 | if layer.batch_norm:
95 | mod_layer = Sequential(
96 | SpectralNorm(ConvTranspose2d(in_channels, out_channels, kernel_size,
97 | stride, padding, bias=bias)),
98 | BatchNorm2d(out_channels),
99 | act_fn
100 | )
101 | else:
102 | mod_layer = Sequential(
103 | SpectralNorm(ConvTranspose2d(in_channels, out_channels, kernel_size,
104 | stride, padding, bias=bias)),
105 | act_fn
106 | )
107 | else:
108 | if layer.batch_norm:
109 | mod_layer = Sequential(
110 | ConvTranspose2d(in_channels, out_channels, kernel_size,
111 | stride, padding, bias=bias),
112 | BatchNorm2d(out_channels),
113 | act_fn
114 | )
115 | else:
116 | mod_layer = Sequential(
117 | ConvTranspose2d(in_channels, out_channels, kernel_size,
118 | stride, padding, bias=bias),
119 | act_fn
120 | )
121 |
122 | elif name == "dropout":
123 | drop_probability = layer.drop_prob
124 | mod_layer = Dropout2d(p=drop_probability, inplace=False)
125 |
126 | elif name == "batch_norm":
127 | channel_num = layer.num_channels
128 | mod_layer = BatchNorm2d(channel_num)
129 |
130 | elif name == "ignore_attn_maps":
131 | mod_layer = IgnoreAttentionMap()
132 |
133 | elif name == "self_attention":
134 | channels = layer.channels
135 | squeeze_factor = layer.squeeze_factor
136 | bias = layer.bias
137 |
138 | if hasattr(layer, "activation"):
139 | act_fn = get_act_fn(layer.activation)
140 | mod_layer = SelfAttention(channels, act_fn, squeeze_factor, bias)
141 | else:
142 | mod_layer = SelfAttention(channels, None, squeeze_factor, bias)
143 |
144 | elif name == "full_attention":
145 | in_channels, out_channels = parse2tuple(layer.channels)
146 | kernel_size = parse2tuple(layer.kernel_dims)
147 | squeeze_factor = layer.squeeze_factor
148 | stride = parse2tuple(layer.stride)
149 | use_batch_norm = layer.use_batch_norm
150 | use_spectral_norm = layer.use_spectral_norm
151 | padding = parse2tuple(layer.padding)
152 | transpose_conv = layer.transpose_conv
153 | bias = layer.bias
154 |
155 | if hasattr(layer, "activation"):
156 | act_fn = get_act_fn(layer.activation)
157 | mod_layer = FullAttention(in_channels, out_channels, act_fn,
158 | kernel_size, transpose_conv,
159 | use_spectral_norm, use_batch_norm,
160 | squeeze_factor, stride, padding, bias)
161 | else:
162 | mod_layer = FullAttention(in_channels, out_channels, None,
163 | kernel_size, transpose_conv,
164 | use_spectral_norm, use_batch_norm,
165 | squeeze_factor, stride, padding, bias)
166 | else:
167 | raise ValueError("unknown layer type requested")
168 |
169 | return mod_layer
170 |
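# Illustrative example: a "conv" entry parsed from a .conf file, e.g.
#   {"name": "conv", "channels": "(3, 64)", "kernel_dims": "(4, 4)",
#    "stride": "(2, 2)", "padding": "(1, 1)", "bias": True,
#    "batch_norm": False, "spectral_norm": True, "activation": "lrelu(0.1)"}
# (wrapped in an EasyDict by get_config) yields roughly
#   Sequential(SpectralNorm(Conv2d(3, 64, (4, 4), (2, 2), (1, 1))), LeakyReLU(0.1))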
--------------------------------------------------------------------------------
/attn_gan_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | """ package implements Attentional gan as an extension of PyTorch nn.Module """
2 |
3 | # import everything for flat package access also
4 | from attn_gan_pytorch import ConfigManagement
5 | from attn_gan_pytorch import CustomLayers
6 | from attn_gan_pytorch import Losses
7 | from attn_gan_pytorch import Networks
8 | from attn_gan_pytorch import Utils
9 |
--------------------------------------------------------------------------------
/literature/self_attention_gan.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/akanimax/attn_gan_pytorch/2cf3810963eaf00ebc642c9413a0a0ab79a4a7bc/literature/self_attention_gan.pdf
--------------------------------------------------------------------------------
/samples/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore the data folder
2 | data/
--------------------------------------------------------------------------------
/samples/data_processing/DataLoader.py:
--------------------------------------------------------------------------------
1 | """ Module for the data loading pipeline for the model to train """
2 |
3 | import os
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class FlatDirectoryImageDataset(Dataset):
8 | """ pyTorch Dataset wrapper for the generic flat directory images dataset """
9 |
10 | def __setup_files(self):
11 | """
12 | private helper for setting up the files_list
13 | :return: files => list of paths of files
14 | """
15 | file_names = os.listdir(self.data_dir)
16 | files = [] # initialize to empty list
17 |
18 | for file_name in file_names:
19 | possible_file = os.path.join(self.data_dir, file_name)
20 | if os.path.isfile(possible_file):
21 | files.append(possible_file)
22 |
23 | # return the files list
24 | return files
25 |
26 | def __init__(self, data_dir, transform=None):
27 | """
28 | constructor for the class
29 | :param data_dir: path to the directory containing the data
30 | :param transform: transforms to be applied to the images
31 | """
32 | # define the state of the object
33 | self.data_dir = data_dir
34 | self.transform = transform
35 |
36 | # setup the files for reading
37 | self.files = self.__setup_files()
38 |
39 | def __len__(self):
40 | """
41 | compute the length of the dataset
42 | :return: len => length of dataset
43 | """
44 | return len(self.files)
45 |
46 | def __getitem__(self, idx):
47 | """
48 | obtain the image (read and transform)
49 | :param idx: index of the file required
50 | :return: img => image array
51 | """
52 | from PIL import Image
53 |
54 | # read the image:
55 | img = Image.open(self.files[idx])
56 |
57 | # apply the transforms on the image
58 | if self.transform is not None:
59 | img = self.transform(img)
60 |
61 | # return the image:
62 | return img
63 |
64 |
65 | class FoldersDistributedDataset(Dataset):
66 | """ pyTorch Dataset wrapper for the MNIST dataset """
67 |
68 | def __setup_files(self):
69 | """
70 | private helper for setting up the files_list
71 | :return: files => list of paths of files
72 | """
73 |
74 | dir_names = os.listdir(self.data_dir)
75 | files = [] # initialize to empty list
76 |
77 | for dir_name in dir_names:
78 | file_path = os.path.join(self.data_dir, dir_name)
79 | file_names = os.listdir(file_path)
80 | for file_name in file_names:
81 | possible_file = os.path.join(file_path, file_name)
82 | if os.path.isfile(possible_file):
83 | files.append(possible_file)
84 |
85 | # return the files list
86 | return files
87 |
88 | def __init__(self, data_dir, transform=None):
89 | """
90 | constructor for the class
91 | :param data_dir: path to the directory containing the data
92 | :param transform: transforms to be applied to the images
93 | """
94 | # define the state of the object
95 | self.data_dir = data_dir
96 | self.transform = transform
97 |
98 | # setup the files for reading
99 | self.files = self.__setup_files()
100 |
101 | def __len__(self):
102 | """
103 | compute the length of the dataset
104 | :return: len => length of dataset
105 | """
106 | return len(self.files)
107 |
108 | def __getitem__(self, idx):
109 | """
110 | obtain the image (read and transform)
111 | :param idx: index of the file required
112 | :return: img => image array
113 | """
114 | from PIL import Image
115 |
116 | # read the image:
117 | img = Image.open(self.files[idx])
118 |
119 | # apply the transforms on the image
120 | if self.transform is not None:
121 | img = self.transform(img)
122 |
123 | # replicate single-channel (grayscale) images across 3 channels:
124 | img = img.expand(3, -1, -1)
125 |
126 | # return the image:
127 | return img
128 |
129 |
130 | def get_transform(new_size=None):
131 | """
132 | obtain the image transforms required for the input data
133 | :param new_size: size of the resized images
134 | :return: image_transform => transform object from TorchVision
135 | """
136 | from torchvision.transforms import ToTensor, Normalize, Compose, Resize
137 |
138 | if new_size is not None:
139 | image_transform = Compose([
140 | Resize(new_size),
141 | ToTensor(),
142 | Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
143 | ])
144 |
145 | else:
146 | image_transform = Compose([
147 | ToTensor(),
148 | Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
149 | ])
150 | return image_transform
151 |
152 |
153 | def get_data_loader(dataset, batch_size, num_workers):
154 | """
155 | generate the data_loader from the given dataset
156 | :param dataset: the Dataset object to wrap
157 | :param batch_size: batch size of the data
158 | :param num_workers: num of parallel readers
159 | :return: dl => dataloader for the dataset
160 | """
161 | from torch.utils.data import DataLoader
162 |
163 | dl = DataLoader(
164 | dataset,
165 | batch_size=batch_size,
166 | shuffle=True,
167 | num_workers=num_workers
168 | )
169 |
170 | return dl
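
# Typical pipeline (illustrative; the data directory is an assumption):
#   dataset = FlatDirectoryImageDataset("data/celeba",
#                                       transform=get_transform(new_size=(64, 64)))
#   loader = get_data_loader(dataset, batch_size=64, num_workers=3)
# Each batch from `loader` is then a (B x 3 x 64 x 64) tensor normalised to [-1, 1].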
--------------------------------------------------------------------------------
/samples/data_processing/__init__.py:
--------------------------------------------------------------------------------
1 | """ Package for keeping all the data processing utilities """
2 | from data_processing import DataLoader
--------------------------------------------------------------------------------
/samples/generate_loss_plots.py:
--------------------------------------------------------------------------------
1 | """ script for generating the loss plots from the Loss logs """
2 |
3 | import argparse
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def read_loss_log(file_name, delimiter='\t'):
8 | """
9 | read and load the loss values from a loss.log file
10 | :param file_name: path of the loss.log file
11 | :param delimiter: delimiter used to delimit the two columns
12 | :return: loss_val => numpy array [Iterations x 2]
13 | """
14 | from numpy import genfromtxt
15 | losses = genfromtxt(file_name, delimiter=delimiter)
16 | return losses
17 |
18 |
19 | def plot_loss(*loss_vals, plot_name="Loss plot",
20 | fig_size=(17, 7), save_path=None,
21 | legends=("discriminator", "generator")):
22 | """
23 | plot the discriminator loss values and save the plot if required
24 | :param loss_vals: (Variable Arg) numpy array or Sequence like for plotting values
25 | :param plot_name: Name of the plot
26 | :param fig_size: size of the generated figure (column_width, row_width)
27 | :param save_path: path to save the figure
28 | :param legends: list containing labels for loss plots' legends
29 | len(legends) == len(loss_vals)
30 | :return:
31 | """
32 | assert len(loss_vals) == len(legends), "Not enough labels for legends"
33 |
34 | plt.figure(figsize=fig_size).suptitle(plot_name)
35 | plt.grid(True, which="both")
36 | plt.ylabel("loss value")
37 | plt.xlabel("spaced iterations")
38 |
39 | plt.axhline(y=0, color='k')
40 | plt.axvline(x=0, color='k')
41 |
42 | # plot all the provided loss values in a single plot
43 | plts = []
44 | for loss_val in loss_vals:
45 | plts.append(plt.plot(loss_val)[0])
46 |
47 | plt.legend(plts, legends, loc="upper right", fontsize=16)
48 |
49 | if save_path is not None:
50 | plt.savefig(save_path)
51 |
52 |
53 | def parse_arguments():
54 | """
55 | command line arguments parser
56 | :return: args => parsed command line arguments
57 | """
58 | parser = argparse.ArgumentParser()
59 |
60 | parser.add_argument("--loss_file", action="store", type=str, default=None,
61 | help="path to loss log file")
62 |
63 | parser.add_argument("--plot_file", action="store", type=str, default=".",
64 | help="path to the file where plots are to be saved")
65 |
66 | args = parser.parse_args()
67 |
68 | return args
69 |
70 |
71 | def main(args):
72 | """
73 | Main function for the script
74 | :param args: parsed command line arguments
75 | :return: None
76 | """
77 | # Make sure input logs directory is provided
78 | assert args.loss_file is not None, "Loss-Log file not specified"
79 |
80 | # read the loss file
81 | loss_vals = read_loss_log(args.loss_file)
82 |
83 | # plot the loss:
84 | plot_loss(loss_vals[:, 0], loss_vals[:, 1], save_path=args.plot_file)
85 |
86 | print("Loss plots have been successfully generated ...")
87 | print("Please check: ", args.plot_file)
88 |
89 |
90 | if __name__ == '__main__':
91 | main(parse_arguments())
92 |
--------------------------------------------------------------------------------
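For programmatic use, the two functions compose in the same way as main() above; a minimal sketch, assuming the loss log lives under the models/relativistic directory that train.py (further down) writes its logs to:

    from generate_loss_plots import read_loss_log, plot_loss

    # assumed location and file name of the loss log
    losses = read_loss_log("sample_celeba/models/relativistic/loss.log")
    plot_loss(losses[:, 0], losses[:, 1],
              plot_name="SAGAN celeba losses",
              save_path="celeba_losses.png")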
/samples/sample_celeba/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore the generated samples and trained models
2 | models/
3 | samples/
--------------------------------------------------------------------------------
/samples/sample_celeba/configs/1/dis.conf:
--------------------------------------------------------------------------------
1 | # configuration for the discriminator architecture
2 |
3 | architecture:
4 | -
5 | name: "conv"
6 | channels: (3, 64)
7 | kernel_dims: (4, 4)
8 | stride: (2, 2)
9 | padding: (1, 1)
10 | bias: True
11 | batch_norm: False
12 | spectral_norm: True
13 | activation: "lrelu(0.1)"
14 |
15 | -
16 | name: "conv"
17 | channels: (64, 128)
18 | kernel_dims: (4, 4)
19 | stride: (2, 2)
20 | padding: (1, 1)
21 | bias: True
22 | batch_norm: False
23 | spectral_norm: True
24 | activation: "lrelu(0.1)"
25 |
26 | -
27 | name: "conv"
28 | channels: (128, 256)
29 | kernel_dims: (4, 4)
30 | stride: (2, 2)
31 | padding: (1, 1)
32 | bias: True
33 | batch_norm: False
34 | spectral_norm: True
35 | activation: "lrelu(0.1)"
36 |
37 | -
38 | name: "self_attention"
39 | channels: 256
40 | bias: True
41 | squeeze_factor: 8
42 |
43 | -
44 | name: "ignore_attn_maps"
45 |
46 | -
47 | name: "conv"
48 | channels: (256, 512)
49 | kernel_dims: (4, 4)
50 | stride: (2, 2)
51 | padding: (1, 1)
52 | bias: True
53 | batch_norm: False
54 | spectral_norm: True
55 | activation: "lrelu(0.1)"
56 |
57 | -
58 | name: "self_attention"
59 | channels: 512
60 | bias: True
61 | squeeze_factor: 8
62 |
63 | -
64 | name: "ignore_attn_maps"
65 |
66 | -
67 | name: "conv"
68 | channels: (512, 1)
69 | kernel_dims: (4, 4)
70 | stride: (1, 1)
71 | padding: (0, 0)
72 | bias: True
73 | batch_norm: False
74 | spectral_norm: False
75 | activation: "lrelu(1.0)"
--------------------------------------------------------------------------------
/samples/sample_celeba/configs/1/gen.conf:
--------------------------------------------------------------------------------
1 | # configuration for the Generator architecture
2 |
3 | architecture:
4 | -
5 | name: "conv_transpose"
6 | channels: (128, 512)
7 | kernel_dims: (4, 4)
8 | stride: (1, 1)
9 | padding: (0, 0)
10 | bias: True
11 | batch_norm: True
12 | spectral_norm: True
13 | activation: "relu"
14 |
15 | -
16 | name: "conv_transpose"
17 | channels: (512, 256)
18 | kernel_dims: (4, 4)
19 | stride: (2, 2)
20 | padding: (1, 1)
21 | bias: True
22 | batch_norm: True
23 | spectral_norm: True
24 | activation: "relu"
25 |
26 | -
27 | name: "conv_transpose"
28 | channels: (256, 128)
29 | kernel_dims: (4, 4)
30 | stride: (2, 2)
31 | padding: (1, 1)
32 | bias: True
33 | batch_norm: True
34 | spectral_norm: True
35 | activation: "relu"
36 |
37 | -
38 | name: "self_attention"
39 | channels: 128
40 | bias: True
41 | squeeze_factor: 8
42 |
43 | -
44 | name: "ignore_attn_maps"
45 |
46 | -
47 | name: "conv_transpose"
48 | channels: (128, 64)
49 | kernel_dims: (4, 4)
50 | stride: (2, 2)
51 | padding: (1, 1)
52 | bias: True
53 | batch_norm: True
54 | spectral_norm: True
55 | activation: "relu"
56 |
57 | -
58 | name: "self_attention"
59 | channels: 64
60 | bias: True
61 | squeeze_factor: 8
62 |
63 | -
64 | name: "ignore_attn_maps"
65 |
66 | -
67 | name: "conv_transpose"
68 | channels: (64, 3)
69 | kernel_dims: (4, 4)
70 | stride: (2, 2)
71 | padding: (1, 1)
72 | bias: True
73 | batch_norm: False
74 | spectral_norm: False
75 | activation: "tanh"
--------------------------------------------------------------------------------
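These .conf files are not consumed as raw YAML dictionaries by the networks; as train.py (further down) shows, the file is loaded with attn_gan_pytorch.ConfigManagement.get_config and each entry under architecture is mapped to a layer object with attn_gan_pytorch.Utils.get_layer. A minimal sketch, assuming the working directory is samples/sample_celeba as in train.py:

    from attn_gan_pytorch.Utils import get_layer
    from attn_gan_pytorch.ConfigManagement import get_config
    from attn_gan_pytorch.Networks import Generator, Discriminator

    # build the generator from configs/1/gen.conf; 128 is the latent size and matches
    # the (128, 512) input channels of the first conv_transpose layer
    gen_conf = get_config("configs/1/gen.conf")
    generator = Generator(list(map(get_layer, gen_conf.architecture)), 128)

    # build the discriminator from configs/1/dis.conf
    dis_conf = get_config("configs/1/dis.conf")
    discriminator = Discriminator(list(map(get_layer, dis_conf.architecture)))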
/samples/sample_celeba/configs/2/dis.conf:
--------------------------------------------------------------------------------
1 | # configuration for the discriminator architecture
2 |
3 | architecture:
4 | -
5 | name: "conv"
6 | channels: (3, 64)
7 | kernel_dims: (4, 4)
8 | stride: (2, 2)
9 | padding: (1, 1)
10 | bias: True
11 | batch_norm: False
12 | spectral_norm: True
13 | activation: "lrelu(0.1)"
14 |
15 | -
16 | name: "conv"
17 | channels: (64, 128)
18 | kernel_dims: (4, 4)
19 | stride: (2, 2)
20 | padding: (1, 1)
21 | bias: True
22 | batch_norm: False
23 | spectral_norm: True
24 | activation: "lrelu(0.1)"
25 |
26 | -
27 | name: "full_attention"
28 | channels: (128, 256)
29 | kernel_dims: (4, 4)
30 | stride: (2, 2)
31 | padding: (1, 1)
32 | bias: True
33 | use_batch_norm: False
34 | use_spectral_norm: False
35 | squeeze_factor: 8
36 | transpose_conv: False
37 | activation: "lrelu(0.3)"
38 |
39 | -
40 | name: "ignore_attn_maps"
41 |
42 | -
43 | name: "self_attention"
44 | channels: 256
45 | bias: True
46 | squeeze_factor: 8
47 |
48 | -
49 | name: "ignore_attn_maps"
50 |
51 | -
52 | name: "conv"
53 | channels: (256, 512)
54 | kernel_dims: (4, 4)
55 | stride: (2, 2)
56 | padding: (1, 1)
57 | bias: True
58 | batch_norm: False
59 | spectral_norm: True
60 | activation: "lrelu(0.1)"
61 |
62 | -
63 | name: "self_attention"
64 | channels: 512
65 | bias: True
66 | squeeze_factor: 8
67 |
68 | -
69 | name: "ignore_attn_maps"
70 |
71 | -
72 | name: "conv"
73 | channels: (512, 1)
74 | kernel_dims: (4, 4)
75 | stride: (1, 1)
76 | padding: (0, 0)
77 | bias: True
78 | batch_norm: False
79 | spectral_norm: False
80 | activation: "lrelu(1.0)"
--------------------------------------------------------------------------------
/samples/sample_celeba/configs/2/gen.conf:
--------------------------------------------------------------------------------
1 | # configuration for the Generator architecture
2 |
3 | architecture:
4 | -
5 | name: "conv_transpose"
6 | channels: (128, 512)
7 | kernel_dims: (4, 4)
8 | stride: (1, 1)
9 | padding: (0, 0)
10 | bias: True
11 | batch_norm: True
12 | spectral_norm: True
13 | activation: "relu"
14 |
15 | -
16 | name: "full_attention"
17 | channels: (512, 256)
18 | kernel_dims: (4, 4)
19 | stride: (2, 2)
20 | padding: (1, 1)
21 | bias: True
22 | use_spectral_norm: False
23 | use_batch_norm: False
24 | squeeze_factor: 8
25 | transpose_conv: True
26 | activation: "lrelu(0.3)"
27 |
28 | -
29 | name: "ignore_attn_maps"
30 |
31 | -
32 | name: "conv_transpose"
33 | channels: (256, 128)
34 | kernel_dims: (4, 4)
35 | stride: (2, 2)
36 | padding: (1, 1)
37 | bias: True
38 | batch_norm: True
39 | spectral_norm: True
40 | activation: "relu"
41 |
42 | -
43 | name: "self_attention"
44 | channels: 128
45 | bias: True
46 | squeeze_factor: 8
47 |
48 | -
49 | name: "ignore_attn_maps"
50 |
51 | -
52 | name: "conv_transpose"
53 | channels: (128, 64)
54 | kernel_dims: (4, 4)
55 | stride: (2, 2)
56 | padding: (1, 1)
57 | bias: True
58 | batch_norm: True
59 | spectral_norm: True
60 | activation: "relu"
61 |
62 | -
63 | name: "self_attention"
64 | channels: 64
65 | bias: True
66 | squeeze_factor: 8
67 |
68 | -
69 | name: "ignore_attn_maps"
70 |
71 | -
72 | name: "conv_transpose"
73 | channels: (64, 3)
74 | kernel_dims: (4, 4)
75 | stride: (2, 2)
76 | padding: (1, 1)
77 | bias: True
78 | batch_norm: False
79 | spectral_norm: False
80 | activation: "tanh"
--------------------------------------------------------------------------------
/samples/sample_celeba/configs/3/dis.conf:
--------------------------------------------------------------------------------
1 | # configuration for the discriminator architecture
2 |
3 | architecture:
4 | -
5 | name: "conv"
6 | channels: (3, 64)
7 | kernel_dims: (4, 4)
8 | stride: (2, 2)
9 | padding: (1, 1)
10 | bias: True
11 | batch_norm: False
12 | spectral_norm: True
13 | activation: "lrelu(0.1)"
14 |
15 | -
16 | name: "conv"
17 | channels: (64, 128)
18 | kernel_dims: (4, 4)
19 | stride: (2, 2)
20 | padding: (1, 1)
21 | bias: True
22 | batch_norm: False
23 | spectral_norm: True
24 | activation: "lrelu(0.1)"
25 |
26 | -
27 | name: "full_attention"
28 | channels: (128, 256)
29 | kernel_dims: (4, 4)
30 | stride: (2, 2)
31 | padding: (1, 1)
32 | bias: True
33 | use_spectral_norm: True
34 | use_batch_norm: True
35 | squeeze_factor: 8
36 | transpose_conv: False
37 | activation: "lrelu(0.3)"
38 |
39 | -
40 | name: "ignore_attn_maps"
41 |
42 | -
43 | name: "self_attention"
44 | channels: 256
45 | bias: True
46 | squeeze_factor: 8
47 |
48 | -
49 | name: "ignore_attn_maps"
50 |
51 | -
52 | name: "conv"
53 | channels: (256, 512)
54 | kernel_dims: (4, 4)
55 | stride: (2, 2)
56 | padding: (1, 1)
57 | bias: True
58 | batch_norm: False
59 | spectral_norm: True
60 | activation: "lrelu(0.1)"
61 |
62 | -
63 | name: "self_attention"
64 | channels: 512
65 | bias: True
66 | squeeze_factor: 8
67 |
68 | -
69 | name: "ignore_attn_maps"
70 |
71 | -
72 | name: "conv"
73 | channels: (512, 1)
74 | kernel_dims: (4, 4)
75 | stride: (1, 1)
76 | padding: (0, 0)
77 | bias: True
78 | batch_norm: False
79 | spectral_norm: False
80 | activation: "lrelu(1.0)"
--------------------------------------------------------------------------------
/samples/sample_celeba/configs/3/gen.conf:
--------------------------------------------------------------------------------
1 | # configuration for the Generator architecture
2 |
3 | architecture:
4 | -
5 | name: "conv_transpose"
6 | channels: (128, 512)
7 | kernel_dims: (4, 4)
8 | stride: (1, 1)
9 | padding: (0, 0)
10 | bias: True
11 | batch_norm: True
12 | spectral_norm: True
13 | activation: "relu"
14 |
15 | -
16 | name: "full_attention"
17 | channels: (512, 256)
18 | kernel_dims: (4, 4)
19 | stride: (2, 2)
20 | padding: (1, 1)
21 | bias: True
22 | use_spectral_norm: True
23 | use_batch_norm: True
24 | squeeze_factor: 8
25 | transpose_conv: True
26 | activation: "lrelu(0.3)"
27 |
28 | -
29 | name: "ignore_attn_maps"
30 |
31 | -
32 | name: "conv_transpose"
33 | channels: (256, 128)
34 | kernel_dims: (4, 4)
35 | stride: (2, 2)
36 | padding: (1, 1)
37 | bias: True
38 | batch_norm: True
39 | spectral_norm: True
40 | activation: "relu"
41 |
42 | -
43 | name: "self_attention"
44 | channels: 128
45 | bias: True
46 | squeeze_factor: 8
47 |
48 | -
49 | name: "ignore_attn_maps"
50 |
51 | -
52 | name: "conv_transpose"
53 | channels: (128, 64)
54 | kernel_dims: (4, 4)
55 | stride: (2, 2)
56 | padding: (1, 1)
57 | bias: True
58 | batch_norm: True
59 | spectral_norm: True
60 | activation: "relu"
61 |
62 | -
63 | name: "self_attention"
64 | channels: 64
65 | bias: True
66 | squeeze_factor: 8
67 |
68 | -
69 | name: "ignore_attn_maps"
70 |
71 | -
72 | name: "conv_transpose"
73 | channels: (64, 3)
74 | kernel_dims: (4, 4)
75 | stride: (2, 2)
76 | padding: (1, 1)
77 | bias: True
78 | batch_norm: False
79 | spectral_norm: False
80 | activation: "tanh"
--------------------------------------------------------------------------------
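All three generator configurations above produce 64 x 64 outputs, which is why train.py (next) resizes the celeba images to (64, 64): the 1 x 1 latent passes through one 4 x 4 / stride-1 / pad-0 transposed convolution and then four stride-2 upsampling steps (plain conv_transpose layers, or, in configs 2 and 3, the transposed convolution inside the full_attention block, assuming it upsamples with the listed kernel/stride/padding; the self_attention and ignore_attn_maps entries leave the spatial size unchanged). A quick check with the standard transposed-convolution size formula out = (in - 1) * stride - 2 * padding + kernel:

    def tconv_out(size, kernel=4, stride=2, padding=1):
        """spatial output size of a transposed convolution (no output_padding, no dilation)"""
        return (size - 1) * stride - 2 * padding + kernel

    size = tconv_out(1, stride=1, padding=0)  # 1 -> 4
    for _ in range(4):                        # 4 -> 8 -> 16 -> 32 -> 64
        size = tconv_out(size)
    print(size)                               # 64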
/samples/sample_celeba/train.py:
--------------------------------------------------------------------------------
1 | """ script for training a Self Attention GAN on celeba images """
2 |
3 | import torch as th
4 | import argparse
5 |
6 | from torch.backends import cudnn
7 |
8 | # define the device for the training script
9 | device = th.device("cuda" if th.cuda.is_available() else "cpu")
10 |
11 | # enable fast training
12 | cudnn.benchmark = True
13 |
14 |
15 | def parse_arguments():
16 | """
17 | command line arguments parser
18 | :return: args => parsed command line arguments
19 | """
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--generator_config", action="store", type=str,
22 | default="configs/3/gen.conf",
23 |                         help="path to the configuration file for the generator network")
24 |
25 | parser.add_argument("--discriminator_config", action="store", type=str,
26 | default="configs/3/dis.conf",
27 |                         help="path to the configuration file for the discriminator network")
28 |
29 | parser.add_argument("--images_dir", action="store", type=str,
30 | default="../data/celeba",
31 | help="path for the images directory")
32 |
33 | parser.add_argument("--latent_size", action="store", type=int,
34 | default=128,
35 | help="latent size for the generator")
36 |
37 | parser.add_argument("--batch_size", action="store", type=int,
38 | default=64,
39 | help="batch_size for training")
40 |
41 | parser.add_argument("--num_epochs", action="store", type=int,
42 | default=3,
43 | help="number of epochs for training")
44 |
45 | parser.add_argument("--checkpoint_factor", action="store", type=int,
46 | default=1,
47 |                         help="save the model every n epochs")
48 |
49 | parser.add_argument("--g_lr", action="store", type=float,
50 | default=0.0001,
51 | help="learning rate for generator")
52 |
53 | parser.add_argument("--d_lr", action="store", type=float,
54 | default=0.0004,
55 | help="learning rate for discriminator")
56 |
57 | parser.add_argument("--data_percentage", action="store", type=float,
58 | default=100,
59 | help="percentage of data to use")
60 |
61 | parser.add_argument("--num_workers", action="store", type=int,
62 | default=3,
63 | help="number of parallel workers for reading files")
64 |
65 | args = parser.parse_args()
66 |
67 | return args
68 |
69 |
70 | def main(args):
71 | """
72 | Main function for the script
73 | :param args: parsed command line arguments
74 | :return: None
75 | """
76 | from attn_gan_pytorch.Utils import get_layer
77 | from attn_gan_pytorch.ConfigManagement import get_config
78 | from attn_gan_pytorch.Networks import Generator, Discriminator, GAN
79 | from data_processing.DataLoader import FlatDirectoryImageDataset, \
80 | get_transform, get_data_loader
81 | from attn_gan_pytorch.Losses import RelativisticAverageHingeGAN
82 |
83 | # create a data source:
84 | celeba_dataset = FlatDirectoryImageDataset(args.images_dir,
85 | transform=get_transform((64, 64)))
86 | data = get_data_loader(celeba_dataset, args.batch_size, args.num_workers)
87 |
88 | # create generator object:
89 | gen_conf = get_config(args.generator_config)
90 | gen_conf = list(map(get_layer, gen_conf.architecture))
91 | generator = Generator(gen_conf, args.latent_size)
92 |
93 | print("Generator Configuration: ")
94 | print(generator)
95 |
96 | # create discriminator object:
97 | dis_conf = get_config(args.discriminator_config)
98 | dis_conf = list(map(get_layer, dis_conf.architecture))
99 | discriminator = Discriminator(dis_conf)
100 |
101 | print("Discriminator Configuration: ")
102 | print(discriminator)
103 |
104 | # create a gan from these
105 | sagan = GAN(generator, discriminator, device=device)
106 |
107 | # create optimizer for generator:
108 | gen_optim = th.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()),
109 | args.g_lr, [0, 0.9])
110 |     # create optimizer for discriminator:
111 | dis_optim = th.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()),
112 | args.d_lr, [0, 0.9])
113 |
114 | # train the GAN
115 | sagan.train(
116 | data,
117 | gen_optim,
118 | dis_optim,
119 | loss_fn=RelativisticAverageHingeGAN(device, discriminator),
120 | num_epochs=args.num_epochs,
121 | checkpoint_factor=args.checkpoint_factor,
122 | data_percentage=args.data_percentage,
123 | feedback_factor=31,
124 | num_samples=64,
125 | save_dir="models/relativistic/",
126 | sample_dir="samples/4/",
127 | log_dir="models/relativistic"
128 | )
129 |
130 |
131 | if __name__ == '__main__':
132 | # invoke the main function of the script
133 | main(parse_arguments())
134 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='attn_gan_pytorch',
5 | version='0.6',
6 | packages=find_packages(exclude=("samples", "literature")),
7 | url='https://github.com/akanimax/attn_gan_pytorch',
8 | license='MIT',
9 | author='animesh karnewar',
10 | author_email='animeshsk3@gmail.com',
11 |     description='python package for self-attention gan implemented as an extension of ' +
12 | 'PyTorch nn.Module. paper -> https://arxiv.org/abs/1805.08318',
13 | install_requires=['torch', 'torchvision', 'numpy', 'PyYAML']
14 | )
15 |
--------------------------------------------------------------------------------
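With this setup.py in place, the package can be installed from the repository root using the standard setuptools workflow, for example pip install . (or pip install -e . for an editable install while experimenting with the samples); note that the samples and literature directories are deliberately excluded from the distribution via the find_packages(exclude=...) call.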