├── DeepI2I_BigGAN
    ├── BigGAN.py
    ├── BigGANdeep.py
    ├── README.md
    ├── TFHub
    │   ├── README.md
    │   ├── biggan_v1.py
    │   └── converter.py
    ├── __pycache__
    │   ├── BigGAN.cpython-36.pyc
    │   ├── BigGAN.cpython-37.pyc
    │   ├── animal_hash.cpython-36.pyc
    │   ├── animal_hash.cpython-37.pyc
    │   ├── datasets.cpython-36.pyc
    │   ├── datasets.cpython-37.pyc
    │   ├── inception_utils.cpython-36.pyc
    │   ├── inception_utils.cpython-37.pyc
    │   ├── layers.cpython-36.pyc
    │   ├── layers.cpython-37.pyc
    │   ├── losses.cpython-36.pyc
    │   ├── losses.cpython-37.pyc
    │   ├── train_fns.cpython-36.pyc
    │   ├── train_fns.cpython-37.pyc
    │   ├── utils.cpython-36.pyc
    │   └── utils.cpython-37.pyc
    ├── animal_hash.py
    ├── calculate_inception_moments.py
    ├── class_to_index
    │   ├── DeepI2I_NABirds
    │   │   └── I128_imgs.pickle
    │   ├── DeepI2I_UECFOOD256
    │   │   └── I128_imgs.pickle
    │   └── DeepI2I_animals
    │       └── I128_imgs.pickle
    ├── datasets.py
    ├── inception_tf13.py
    ├── inception_utils.py
    ├── layers.py
    ├── losses.py
    ├── make_hdf5.py
    ├── merge_image.py
    ├── sample.py
    ├── scripts
    │   ├── .DeepI2I.sh.swp
    │   ├── DeepI2I.sh
    │   ├── DeepI2I_test.sh
    │   ├── launch_BigGAN_bs256x8.sh~
    │   └── utils
    │       ├── duplicate.sh
    │       └── prepare_data.sh
    ├── sync_batchnorm
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-36.pyc
    │   │   ├── __init__.cpython-37.pyc
    │   │   ├── batchnorm.cpython-36.pyc
    │   │   ├── batchnorm.cpython-37.pyc
    │   │   ├── comm.cpython-36.pyc
    │   │   ├── comm.cpython-37.pyc
    │   │   ├── replicate.cpython-36.pyc
    │   │   └── replicate.cpython-37.pyc
    │   ├── batchnorm.py
    │   ├── batchnorm_reimpl.py
    │   ├── comm.py
    │   ├── replicate.py
    │   └── unittest.py
    ├── test.py
    ├── train.py
    ├── train_fns.py
    └── utils.py
├── DeepI2I_StyleGAN
    ├── README.md
    ├── __pycache__
    │   ├── dataset.cpython-36.pyc
    │   ├── dataset.cpython-37.pyc
    │   ├── model.cpython-36.pyc
    │   ├── model.cpython-37.pyc
    │   └── train.cpython-36.pyc
    ├── dataset.py
    ├── generate.py
    ├── model.py
    ├── prepare_data.py
    └── train.py
├── README.md
└── figures
    ├── framework.png
    ├── interpolation.png
    └── sample.png

/DeepI2I_BigGAN/README.md:
--------------------------------------------------------------------------------
# Overview
- [Dependencies](#dependencies)
- [Installation](#installation)
- [Instructions](#instructions)
# Dependencies
- Python 3.6+, NumPy, SciPy, NVIDIA GPU
- **PyTorch:** version 1.0 or later
- **Dataset:** [animals](https://github.com/NVlabs/FUNIT), [NABirds](https://dl.allaboutbirds.org/nabirds) and [UECFOOD-256](http://foodcam.mobi/dataset256.html)

# Installation
- Install PyTorch
# Instructions

- Run `git clone https://github.com/yaxingwang/DeepI2I.git` to get `DeepI2I`, then `cd DeepI2I/DeepI2I_BigGAN`.

- Pretrained model: download the pretrained model from [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch). Note: use `G_ema.pth` in place of `G.pth` (i.e., copy `G_ema.pth` over `G.pth`), since we don't use `ema`. Move the pretrained model into `BigGAN_weights/`.

- Preparing data: run `sh scripts/utils/prepare_data.sh` and put the output into `data/your_data/data`. See [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch) for how to generate the data.

I have already created an [example](https://drive.google.com/drive/folders/1Wvmz_SHlJekHjuC4UJCncxdcJsYlwcCb?usp=sharing): download the three items and put them into `data/animals`. I have also uploaded [compressed NABirds and UECFOOD-256](https://drive.google.com/drive/folders/1mftJ5RpTs2zPkf3c19suIGMXkswrgO5f?usp=sharing), which should only be used for this project.

Note: when you process the data, save the label of each category. This makes it easy to generate images for a given label later and compare them with the corresponding ground truth (GT). Here the labels are saved in `class_to_index/*`, which is used at test time.

- Training: `sh scripts/DeepI2I.sh`

The model and generated images are saved in `result/animals`, where four folders are created automatically: `logs`, `samples`, `test` and `weights`.
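A minimal sketch of the `G_ema.pth` swap described above, written in Python. It assumes (hypothetically) that the BigGAN checkpoint files sit directly in `BigGAN_weights/`; adjust the directory if your download unpacks into a subfolder. The helper name `use_ema_generator` is illustrative and is not part of this repository.

```python
import os
import shutil


def use_ema_generator(weights_dir: str) -> str:
    """Copy G_ema.pth over G.pth so that code loading G.pth gets the EMA weights.

    Assumption (hypothetical layout): both files sit directly in weights_dir,
    e.g. BigGAN_weights/G_ema.pth and BigGAN_weights/G.pth.
    """
    src = os.path.join(weights_dir, "G_ema.pth")
    dst = os.path.join(weights_dir, "G.pth")
    if not os.path.isfile(src):
        raise FileNotFoundError(f"EMA generator weights not found: {src}")
    # copyfile overwrites dst if it already exists; G_ema.pth itself is left untouched.
    shutil.copyfile(src, dst)
    return dst
```

Calling `use_ema_generator("BigGAN_weights")` once before training covers the whole step. A plain `cp` on the command line does the same job.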
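One way to save the per-category labels mentioned above is to store a plain `{class_name: index}` dict, built from the sorted class-folder names, and pickle it. Sorting the names is the same convention torchvision's `ImageFolder` uses. This is only a sketch under that assumption. The exact contents of the repository's `class_to_index/*/I128_imgs.pickle` files are not documented here, and the function names below are hypothetical.

```python
import os
import pickle


def build_class_to_index(image_root: str) -> dict:
    """Return {class_folder_name: integer_label} for the subfolders of image_root.

    Folder names are sorted so the labels are deterministic across runs;
    plain files in image_root are ignored.
    """
    classes = sorted(e.name for e in os.scandir(image_root) if e.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def save_class_to_index(mapping: dict, out_path: str) -> None:
    """Pickle the mapping, creating the parent folder if needed."""
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(mapping, f)


def load_class_to_index(path: str) -> dict:
    """Load a mapping previously written by save_class_to_index."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

At test time, looking up a class name in the loaded dict gives the label to condition the generator on. Inverting the dict maps a label back to its ground-truth folder for comparison.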
- Testing: `sh scripts/DeepI2I_test.sh`

Note: if you use a new experiment name (e.g., `--experiment DeepI2I_animalv2` in `scripts/DeepI2I.sh`), use the same name in `scripts/DeepI2I_test.sh`, and rename the folder `class_to_index/DeepI2I_animals` to match (here, `class_to_index/DeepI2I_animalv2`).

Download our [pre-trained model](https://drive.google.com/drive/folders/19pSSiNDmebtm17ymw3tYe5V5G9wI6RHR?usp=sharing) on animals and put it into `result/animals/weights/DeepI2I_animals/0`. Pre-trained models for [birds](https://drive.google.com/drive/folders/1gZpkFzLp9w8X1PsTiqPrPJWll5DgX2XP?usp=sharing) and [foods](https://drive.google.com/drive/folders/1RgdpYmOoWnX0gqQzETgcpZQTPJFAw5Pp?usp=sharing) are also available.

If you use the provided data and code, please cite the following paper:

```
@article{wang2020deepi2i,
  title={DeepI2I: Enabling Deep Hierarchical Image-to-Image Translation by Transferring from GANs},
  author={Wang, Yaxing and Yu, Lu and van de Weijer, Joost},
  journal={arXiv preprint arXiv:2011.05867},
  year={2020}
}
```
--------------------------------------------------------------------------------
/DeepI2I_BigGAN/TFHub/README.md:
--------------------------------------------------------------------------------
# BigGAN-PyTorch TFHub converter
This dir contains scripts for taking the [pre-trained generator weights from TFHub](https://tfhub.dev/s?q=biggan) and porting them to BigGAN-PyTorch.

In addition to the base libraries for BigGAN-PyTorch, to run this code you will need:

- TensorFlow
- TFHub
- parse

Note that this code is currently set up only to run the ported models without truncation; you'll need to accumulate standing stats at each truncation level yourself if you wish to use truncation.
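Porting the weights mostly comes down to reordering tensor axes. In `converter.py`, `load_linear` applies `.permute(1, 0)` to dense weights and `load_snconv` applies `.permute(3, 2, 0, 1)` to conv kernels. TensorFlow stores a dense matrix as `[in, out]` and a conv kernel as `[H, W, in, out]`, while PyTorch's `nn.Linear` and `nn.Conv2d` expect `[out, in]` and `[out, in, H, W]`. Both frameworks compute cross-correlation, so the kernel needs no spatial flip.

The NumPy sketch below uses `transpose` as a stand-in for `torch.Tensor.permute` and checks that both layouts produce the same outputs. The helper names are illustrative and are not part of the converter.

```python
import numpy as np


def tf_conv_to_torch(kernel_hwio: np.ndarray) -> np.ndarray:
    """TF conv kernel [H, W, in, out] -> PyTorch Conv2d weight [out, in, H, W]."""
    assert kernel_hwio.ndim == 4
    return np.ascontiguousarray(kernel_hwio.transpose(3, 2, 0, 1))


def tf_dense_to_torch(w_io: np.ndarray) -> np.ndarray:
    """TF dense matrix [in, out] -> PyTorch Linear weight [out, in]."""
    assert w_io.ndim == 2
    return np.ascontiguousarray(w_io.T)


def check_dense_equivalence(seed: int = 0) -> bool:
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((4, 5))     # batch of 4, 5 input features
    w_tf = rng.standard_normal((5, 3))  # TF layout [in, out]
    w_pt = tf_dense_to_torch(w_tf)      # PyTorch layout [out, in]
    # TF computes x @ w; torch.nn.functional.linear computes x @ W.T
    return bool(np.allclose(x @ w_tf, x @ w_pt.T))


def check_conv_equivalence(seed: int = 0) -> bool:
    rng = np.random.default_rng(seed)
    patch_hwc = rng.standard_normal((3, 3, 2))  # one 3x3 patch, 2 channels (NHWC style)
    k_tf = rng.standard_normal((3, 3, 2, 4))    # TF layout [H, W, in, out]
    w_pt = tf_conv_to_torch(k_tf)               # PyTorch layout [out, in, H, W]
    # Output of the convolution at a single spatial position, in each convention
    out_tf = np.einsum("hwi,hwio->o", patch_hwc, k_tf)
    out_pt = np.einsum("ihw,oihw->o", patch_hwc.transpose(2, 0, 1), w_pt)
    return bool(np.allclose(out_tf, out_pt))
```

Spectral-norm vectors such as `u0` and `u1` carry no spatial axes, which is why the converter only `squeeze`s them instead of permuting.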
11 | 12 | To port the 128x128 model from tfhub, produce a pretrained weights .pth file, and generate samples using all your GPUs, run 13 | 14 | `python converter.py -r 128 --generate_samples --parallel` -------------------------------------------------------------------------------- /DeepI2I_BigGAN/TFHub/biggan_v1.py: -------------------------------------------------------------------------------- 1 | # BigGAN V1: 2 | # This is now deprecated code used for porting the TFHub modules to pytorch, 3 | # included here for reference only. 4 | import numpy as np 5 | import torch 6 | from scipy.stats import truncnorm 7 | from torch import nn 8 | from torch.nn import Parameter 9 | from torch.nn import functional as F 10 | 11 | 12 | def l2normalize(v, eps=1e-4): 13 | return v / (v.norm() + eps) 14 | 15 | 16 | def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None): 17 | state = None if seed is None else np.random.RandomState(seed) 18 | values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state) 19 | return truncation * values 20 | 21 | 22 | def denorm(x): 23 | out = (x + 1) / 2 24 | return out.clamp_(0, 1) 25 | 26 | 27 | class SpectralNorm(nn.Module): 28 | def __init__(self, module, name='weight', power_iterations=1): 29 | super(SpectralNorm, self).__init__() 30 | self.module = module 31 | self.name = name 32 | self.power_iterations = power_iterations 33 | if not self._made_params(): 34 | self._make_params() 35 | 36 | def _update_u_v(self): 37 | u = getattr(self.module, self.name + "_u") 38 | v = getattr(self.module, self.name + "_v") 39 | w = getattr(self.module, self.name + "_bar") 40 | 41 | height = w.data.shape[0] 42 | _w = w.view(height, -1) 43 | for _ in range(self.power_iterations): 44 | v = l2normalize(torch.matmul(_w.t(), u)) 45 | u = l2normalize(torch.matmul(_w, v)) 46 | 47 | sigma = u.dot((_w).mv(v)) 48 | setattr(self.module, self.name, w / sigma.expand_as(w)) 49 | 50 | def _made_params(self): 51 | try: 52 | getattr(self.module, 
self.name + "_u") 53 | getattr(self.module, self.name + "_v") 54 | getattr(self.module, self.name + "_bar") 55 | return True 56 | except AttributeError: 57 | return False 58 | 59 | def _make_params(self): 60 | w = getattr(self.module, self.name) 61 | 62 | height = w.data.shape[0] 63 | width = w.view(height, -1).data.shape[1] 64 | 65 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 66 | v = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 67 | u.data = l2normalize(u.data) 68 | v.data = l2normalize(v.data) 69 | w_bar = Parameter(w.data) 70 | 71 | del self.module._parameters[self.name] 72 | self.module.register_parameter(self.name + "_u", u) 73 | self.module.register_parameter(self.name + "_v", v) 74 | self.module.register_parameter(self.name + "_bar", w_bar) 75 | 76 | def forward(self, *args): 77 | self._update_u_v() 78 | return self.module.forward(*args) 79 | 80 | 81 | class SelfAttention(nn.Module): 82 | """ Self Attention Layer""" 83 | 84 | def __init__(self, in_dim, activation=F.relu): 85 | super().__init__() 86 | self.chanel_in = in_dim 87 | self.activation = activation 88 | 89 | self.theta = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False)) 90 | self.phi = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False)) 91 | self.pool = nn.MaxPool2d(2, 2) 92 | self.g = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1, bias=False)) 93 | self.o_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim // 2, out_channels=in_dim, kernel_size=1, bias=False)) 94 | self.gamma = nn.Parameter(torch.zeros(1)) 95 | 96 | self.softmax = nn.Softmax(dim=-1) 97 | 98 | def forward(self, x): 99 | m_batchsize, C, width, height = x.size() 100 | N = height * width 101 | 102 | theta = self.theta(x) 103 | phi = self.phi(x) 104 | phi = self.pool(phi) 105 | phi = phi.view(m_batchsize, -1, N // 4) 106 | theta = theta.view(m_batchsize, -1, N) 107 | 
theta = theta.permute(0, 2, 1) 108 | attention = self.softmax(torch.bmm(theta, phi)) 109 | g = self.pool(self.g(x)).view(m_batchsize, -1, N // 4) 110 | attn_g = torch.bmm(g, attention.permute(0, 2, 1)).view(m_batchsize, -1, width, height) 111 | out = self.o_conv(attn_g) 112 | return self.gamma * out + x 113 | 114 | 115 | class ConditionalBatchNorm2d(nn.Module): 116 | def __init__(self, num_features, num_classes, eps=1e-4, momentum=0.1): 117 | super().__init__() 118 | self.num_features = num_features 119 | self.bn = nn.BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum) 120 | self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False)) 121 | self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False)) 122 | 123 | def forward(self, x, y): 124 | out = self.bn(x) 125 | gamma = self.gamma_embed(y) + 1 126 | beta = self.beta_embed(y) 127 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 128 | return out 129 | 130 | 131 | class GBlock(nn.Module): 132 | def __init__( 133 | self, 134 | in_channel, 135 | out_channel, 136 | kernel_size=[3, 3], 137 | padding=1, 138 | stride=1, 139 | n_class=None, 140 | bn=True, 141 | activation=F.relu, 142 | upsample=True, 143 | downsample=False, 144 | z_dim=148, 145 | ): 146 | super().__init__() 147 | 148 | self.conv0 = SpectralNorm( 149 | nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=True if bn else True) 150 | ) 151 | self.conv1 = SpectralNorm( 152 | nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding, bias=True if bn else True) 153 | ) 154 | 155 | self.skip_proj = False 156 | if in_channel != out_channel or upsample or downsample: 157 | self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0)) 158 | self.skip_proj = True 159 | 160 | self.upsample = upsample 161 | self.downsample = downsample 162 | self.activation = activation 163 | self.bn = bn 164 | if bn: 165 | self.HyperBN = 
ConditionalBatchNorm2d(in_channel, z_dim) 166 | self.HyperBN_1 = ConditionalBatchNorm2d(out_channel, z_dim) 167 | 168 | def forward(self, input, condition=None): 169 | out = input 170 | 171 | if self.bn: 172 | out = self.HyperBN(out, condition) 173 | out = self.activation(out) 174 | if self.upsample: 175 | out = F.interpolate(out, scale_factor=2) 176 | out = self.conv0(out) 177 | if self.bn: 178 | out = self.HyperBN_1(out, condition) 179 | out = self.activation(out) 180 | out = self.conv1(out) 181 | 182 | if self.downsample: 183 | out = F.avg_pool2d(out, 2) 184 | 185 | if self.skip_proj: 186 | skip = input 187 | if self.upsample: 188 | skip = F.interpolate(skip, scale_factor=2) 189 | skip = self.conv_sc(skip) 190 | if self.downsample: 191 | skip = F.avg_pool2d(skip, 2) 192 | else: 193 | skip = input 194 | return out + skip 195 | 196 | 197 | class Generator128(nn.Module): 198 | def __init__(self, code_dim=120, n_class=1000, chn=96, debug=False): 199 | super().__init__() 200 | 201 | self.linear = nn.Linear(n_class, 128, bias=False) 202 | 203 | if debug: 204 | chn = 8 205 | 206 | self.first_view = 16 * chn 207 | 208 | self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn)) 209 | 210 | z_dim = code_dim + 28 211 | 212 | self.GBlock = nn.ModuleList([ 213 | GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim), 214 | GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), 215 | GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim), 216 | GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim), 217 | GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), 218 | ]) 219 | 220 | self.sa_id = 4 221 | self.num_split = len(self.GBlock) + 1 222 | self.attention = SelfAttention(2 * chn) 223 | self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4) 224 | self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) 225 | 226 | def forward(self, input, class_id): 227 | codes = torch.chunk(input, self.num_split, 1) 228 | class_emb = self.linear(class_id) # 
128 229 | 230 | out = self.G_linear(codes[0]) 231 | out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) 232 | for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): 233 | if i == self.sa_id: 234 | out = self.attention(out) 235 | condition = torch.cat([code, class_emb], 1) 236 | out = GBlock(out, condition) 237 | 238 | out = self.ScaledCrossReplicaBN(out) 239 | out = F.relu(out) 240 | out = self.colorize(out) 241 | return torch.tanh(out) 242 | 243 | 244 | class Generator256(nn.Module): 245 | def __init__(self, code_dim=140, n_class=1000, chn=96, debug=False): 246 | super().__init__() 247 | 248 | self.linear = nn.Linear(n_class, 128, bias=False) 249 | 250 | if debug: 251 | chn = 8 252 | 253 | self.first_view = 16 * chn 254 | 255 | self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn)) 256 | 257 | self.GBlock = nn.ModuleList([ 258 | GBlock(16 * chn, 16 * chn, n_class=n_class), 259 | GBlock(16 * chn, 8 * chn, n_class=n_class), 260 | GBlock(8 * chn, 8 * chn, n_class=n_class), 261 | GBlock(8 * chn, 4 * chn, n_class=n_class), 262 | GBlock(4 * chn, 2 * chn, n_class=n_class), 263 | GBlock(2 * chn, 1 * chn, n_class=n_class), 264 | ]) 265 | 266 | self.sa_id = 5 267 | self.num_split = len(self.GBlock) + 1 268 | self.attention = SelfAttention(2 * chn) 269 | self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4) 270 | self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) 271 | 272 | def forward(self, input, class_id): 273 | codes = torch.chunk(input, self.num_split, 1) 274 | class_emb = self.linear(class_id) # 128 275 | 276 | out = self.G_linear(codes[0]) 277 | out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) 278 | for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): 279 | if i == self.sa_id: 280 | out = self.attention(out) 281 | condition = torch.cat([code, class_emb], 1) 282 | out = GBlock(out, condition) 283 | 284 | out = self.ScaledCrossReplicaBN(out) 285 | out = F.relu(out) 286 | out = 
self.colorize(out) 287 | return torch.tanh(out) 288 | 289 | 290 | class Generator512(nn.Module): 291 | def __init__(self, code_dim=128, n_class=1000, chn=96, debug=False): 292 | super().__init__() 293 | 294 | self.linear = nn.Linear(n_class, 128, bias=False) 295 | 296 | if debug: 297 | chn = 8 298 | 299 | self.first_view = 16 * chn 300 | 301 | self.G_linear = SpectralNorm(nn.Linear(16, 4 * 4 * 16 * chn)) 302 | 303 | z_dim = code_dim + 16 304 | 305 | self.GBlock = nn.ModuleList([ 306 | GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim), 307 | GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), 308 | GBlock(8 * chn, 8 * chn, n_class=n_class, z_dim=z_dim), 309 | GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim), 310 | GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim), 311 | GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), 312 | GBlock(1 * chn, 1 * chn, n_class=n_class, z_dim=z_dim), 313 | ]) 314 | 315 | self.sa_id = 4 316 | self.num_split = len(self.GBlock) + 1 317 | self.attention = SelfAttention(4 * chn) 318 | self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn) 319 | self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1)) 320 | 321 | def forward(self, input, class_id): 322 | codes = torch.chunk(input, self.num_split, 1) 323 | class_emb = self.linear(class_id) # 128 324 | 325 | out = self.G_linear(codes[0]) 326 | out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2) 327 | for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)): 328 | if i == self.sa_id: 329 | out = self.attention(out) 330 | condition = torch.cat([code, class_emb], 1) 331 | out = GBlock(out, condition) 332 | 333 | out = self.ScaledCrossReplicaBN(out) 334 | out = F.relu(out) 335 | out = self.colorize(out) 336 | return torch.tanh(out) 337 | 338 | 339 | class Discriminator(nn.Module): 340 | def __init__(self, n_class=1000, chn=96, debug=False): 341 | super().__init__() 342 | 343 | def conv(in_channel, out_channel, downsample=True): 344 | return 
GBlock(in_channel, out_channel, bn=False, upsample=False, downsample=downsample) 345 | 346 | if debug: 347 | chn = 8 348 | self.debug = debug 349 | 350 | self.pre_conv = nn.Sequential( 351 | SpectralNorm(nn.Conv2d(3, 1 * chn, 3, padding=1)), 352 | nn.ReLU(), 353 | SpectralNorm(nn.Conv2d(1 * chn, 1 * chn, 3, padding=1)), 354 | nn.AvgPool2d(2), 355 | ) 356 | self.pre_skip = SpectralNorm(nn.Conv2d(3, 1 * chn, 1)) 357 | 358 | self.conv = nn.Sequential( 359 | conv(1 * chn, 1 * chn, downsample=True), 360 | conv(1 * chn, 2 * chn, downsample=True), 361 | SelfAttention(2 * chn), 362 | conv(2 * chn, 2 * chn, downsample=True), 363 | conv(2 * chn, 4 * chn, downsample=True), 364 | conv(4 * chn, 8 * chn, downsample=True), 365 | conv(8 * chn, 8 * chn, downsample=True), 366 | conv(8 * chn, 16 * chn, downsample=True), 367 | conv(16 * chn, 16 * chn, downsample=False), 368 | ) 369 | 370 | self.linear = SpectralNorm(nn.Linear(16 * chn, 1)) 371 | 372 | self.embed = nn.Embedding(n_class, 16 * chn) 373 | self.embed.weight.data.uniform_(-0.1, 0.1) 374 | self.embed = SpectralNorm(self.embed) 375 | 376 | def forward(self, input, class_id): 377 | 378 | out = self.pre_conv(input) 379 | out += self.pre_skip(F.avg_pool2d(input, 2)) 380 | out = self.conv(out) 381 | out = F.relu(out) 382 | out = out.view(out.size(0), out.size(1), -1) 383 | out = out.sum(2) 384 | out_linear = self.linear(out).squeeze(1) 385 | embed = self.embed(class_id) 386 | 387 | prod = (out * embed).sum(1) 388 | 389 | return out_linear + prod -------------------------------------------------------------------------------- /DeepI2I_BigGAN/TFHub/converter.py: -------------------------------------------------------------------------------- 1 | """Utilities for converting TFHub BigGAN generator weights to PyTorch. 
2 | 3 | Recommended usage: 4 | 5 | To convert all BigGAN variants and generate test samples, use: 6 | 7 | ```bash 8 | CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples 9 | ``` 10 | 11 | See `parse_args` for additional options. 12 | """ 13 | 14 | import argparse 15 | import os 16 | import sys 17 | 18 | import h5py 19 | import torch 20 | import torch.nn as nn 21 | from torchvision.utils import save_image 22 | import tensorflow as tf 23 | import tensorflow_hub as hub 24 | import parse 25 | 26 | # import reference biggan from this folder 27 | import biggan_v1 as biggan_for_conversion 28 | 29 | # Import model from main folder 30 | sys.path.append('..') 31 | import BigGAN 32 | 33 | 34 | 35 | 36 | DEVICE = 'cuda' 37 | HDF5_TMPL = 'biggan-{}.h5' 38 | PTH_TMPL = 'biggan-{}.pth' 39 | MODULE_PATH_TMPL = 'https://tfhub.dev/deepmind/biggan-{}/2' 40 | Z_DIMS = { 41 | 128: 120, 42 | 256: 140, 43 | 512: 128} 44 | RESOLUTIONS = list(Z_DIMS) 45 | 46 | 47 | def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False): 48 | """Loads TFHub weights and saves them to intermediate HDF5 file. 49 | 50 | Args: 51 | module_path ([Path-like]): Path to TFHub module. 52 | hdf5_path ([Path-like]): Path to output HDF5 file. 53 | 54 | Returns: 55 | [h5py.File]: Loaded hdf5 file containing module weights. 
56 | """ 57 | if os.path.exists(hdf5_path) and (not redownload): 58 | print('Loading BigGAN hdf5 file from:', hdf5_path) 59 | return h5py.File(hdf5_path, 'r') 60 | 61 | print('Loading BigGAN module from:', module_path) 62 | tf.reset_default_graph() 63 | hub.Module(module_path) 64 | print('Loaded BigGAN module from:', module_path) 65 | 66 | initializer = tf.global_variables_initializer() 67 | sess = tf.Session() 68 | sess.run(initializer) 69 | 70 | print('Saving BigGAN weights to :', hdf5_path) 71 | h5f = h5py.File(hdf5_path, 'w') 72 | for var in tf.global_variables(): 73 | val = sess.run(var) 74 | h5f.create_dataset(var.name, data=val) 75 | print(f'Saving {var.name} with shape {val.shape}') 76 | h5f.close() 77 | return h5py.File(hdf5_path, 'r') 78 | 79 | 80 | class TFHub2Pytorch(object): 81 | 82 | TF_ROOT = 'module' 83 | 84 | NUM_GBLOCK = { 85 | 128: 5, 86 | 256: 6, 87 | 512: 7 88 | } 89 | 90 | w = 'w' 91 | b = 'b' 92 | u = 'u0' 93 | v = 'u1' 94 | gamma = 'gamma' 95 | beta = 'beta' 96 | 97 | def __init__(self, state_dict, tf_weights, resolution=256, load_ema=True, verbose=False): 98 | self.state_dict = state_dict 99 | self.tf_weights = tf_weights 100 | self.resolution = resolution 101 | self.verbose = verbose 102 | if load_ema: 103 | for name in ['w', 'b', 'gamma', 'beta']: 104 | setattr(self, name, getattr(self, name) + '/ema_b999900') 105 | 106 | def load(self): 107 | self.load_generator() 108 | return self.state_dict 109 | 110 | def load_generator(self): 111 | GENERATOR_ROOT = os.path.join(self.TF_ROOT, 'Generator') 112 | 113 | for i in range(self.NUM_GBLOCK[self.resolution]): 114 | name_tf = os.path.join(GENERATOR_ROOT, 'GBlock') 115 | name_tf += f'_{i}' if i != 0 else '' 116 | self.load_GBlock(f'GBlock.{i}.', name_tf) 117 | 118 | self.load_attention('attention.', os.path.join(GENERATOR_ROOT, 'attention')) 119 | self.load_linear('linear', os.path.join(self.TF_ROOT, 'linear'), bias=False) 120 | self.load_snlinear('G_linear', os.path.join(GENERATOR_ROOT, 'G_Z', 
'G_linear')) 121 | self.load_colorize('colorize', os.path.join(GENERATOR_ROOT, 'conv_2d')) 122 | self.load_ScaledCrossReplicaBNs('ScaledCrossReplicaBN', 123 | os.path.join(GENERATOR_ROOT, 'ScaledCrossReplicaBN')) 124 | 125 | def load_linear(self, name_pth, name_tf, bias=True): 126 | self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0) 127 | if bias: 128 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b) 129 | 130 | def load_snlinear(self, name_pth, name_tf, bias=True): 131 | self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() 132 | self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() 133 | self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0) 134 | if bias: 135 | self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b) 136 | 137 | def load_colorize(self, name_pth, name_tf): 138 | self.load_snconv(name_pth, name_tf) 139 | 140 | def load_GBlock(self, name_pth, name_tf): 141 | self.load_convs(name_pth, name_tf) 142 | self.load_HyperBNs(name_pth, name_tf) 143 | 144 | def load_convs(self, name_pth, name_tf): 145 | self.load_snconv(name_pth + 'conv0', os.path.join(name_tf, 'conv0')) 146 | self.load_snconv(name_pth + 'conv1', os.path.join(name_tf, 'conv1')) 147 | self.load_snconv(name_pth + 'conv_sc', os.path.join(name_tf, 'conv_sc')) 148 | 149 | def load_snconv(self, name_pth, name_tf, bias=True): 150 | if self.verbose: 151 | print(f'loading: {name_pth} from {name_tf}') 152 | self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() 153 | self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() 154 | self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1) 155 | if bias: 156 | self.state_dict[name_pth + '.module.bias'] = 
self.load_tf_tensor(name_tf, self.b).squeeze() 157 | 158 | def load_conv(self, name_pth, name_tf, bias=True): 159 | 160 | self.state_dict[name_pth + '.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() 161 | self.state_dict[name_pth + '.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() 162 | self.state_dict[name_pth + '.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1) 163 | if bias: 164 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b) 165 | 166 | def load_HyperBNs(self, name_pth, name_tf): 167 | self.load_HyperBN(name_pth + 'HyperBN', os.path.join(name_tf, 'HyperBN')) 168 | self.load_HyperBN(name_pth + 'HyperBN_1', os.path.join(name_tf, 'HyperBN_1')) 169 | 170 | def load_ScaledCrossReplicaBNs(self, name_pth, name_tf): 171 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.beta).squeeze() 172 | self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.gamma).squeeze() 173 | self.state_dict[name_pth + '.running_mean'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_mean') 174 | self.state_dict[name_pth + '.running_var'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_var') 175 | self.state_dict[name_pth + '.num_batches_tracked'] = torch.tensor( 176 | self.tf_weights[os.path.join(name_tf + 'bn', 'accumulation_counter:0')][()], dtype=torch.float32) 177 | 178 | def load_HyperBN(self, name_pth, name_tf): 179 | if self.verbose: 180 | print(f'loading: {name_pth} from {name_tf}') 181 | beta = name_pth + '.beta_embed.module' 182 | gamma = name_pth + '.gamma_embed.module' 183 | self.state_dict[beta + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.u).squeeze() 184 | self.state_dict[gamma + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.u).squeeze() 185 | self.state_dict[beta + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.v).squeeze() 186 | self.state_dict[gamma + '.weight_v'] = 
self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.v).squeeze() 187 | self.state_dict[beta + '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.w).permute(1, 0) 188 | self.state_dict[gamma + 189 | '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.w).permute(1, 0) 190 | 191 | cr_bn_name = name_tf.replace('HyperBN', 'CrossReplicaBN') 192 | self.state_dict[name_pth + '.bn.running_mean'] = self.load_tf_tensor(cr_bn_name, 'accumulated_mean') 193 | self.state_dict[name_pth + '.bn.running_var'] = self.load_tf_tensor(cr_bn_name, 'accumulated_var') 194 | self.state_dict[name_pth + '.bn.num_batches_tracked'] = torch.tensor( 195 | self.tf_weights[os.path.join(cr_bn_name, 'accumulation_counter:0')][()], dtype=torch.float32) 196 | 197 | def load_attention(self, name_pth, name_tf): 198 | 199 | self.load_snconv(name_pth + 'theta', os.path.join(name_tf, 'theta'), bias=False) 200 | self.load_snconv(name_pth + 'phi', os.path.join(name_tf, 'phi'), bias=False) 201 | self.load_snconv(name_pth + 'g', os.path.join(name_tf, 'g'), bias=False) 202 | self.load_snconv(name_pth + 'o_conv', os.path.join(name_tf, 'o_conv'), bias=False) 203 | self.state_dict[name_pth + 'gamma'] = self.load_tf_tensor(name_tf, self.gamma) 204 | 205 | def load_tf_tensor(self, prefix, var, device='0'): 206 | name = os.path.join(prefix, var) + f':{device}' 207 | return torch.from_numpy(self.tf_weights[name][:]) 208 | 209 | # Convert from v1: This function maps 210 | def convert_from_v1(hub_dict, resolution=128): 211 | weightname_dict = {'weight_u': 'u0', 'weight_bar': 'weight', 'bias': 'bias'} 212 | convnum_dict = {'conv0': 'conv1', 'conv1': 'conv2', 'conv_sc': 'conv_sc'} 213 | attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution] 214 | hub2me = {'linear.weight': 'shared.weight', # This is actually the shared weight 215 | # Linear stuff 216 | 'G_linear.module.weight_bar': 'linear.weight', 217 | 'G_linear.module.bias': 'linear.bias', 218 | 
'G_linear.module.weight_u': 'linear.u0', 219 | # output layer stuff 220 | 'ScaledCrossReplicaBN.weight': 'output_layer.0.gain', 221 | 'ScaledCrossReplicaBN.bias': 'output_layer.0.bias', 222 | 'ScaledCrossReplicaBN.running_mean': 'output_layer.0.stored_mean', 223 | 'ScaledCrossReplicaBN.running_var': 'output_layer.0.stored_var', 224 | 'colorize.module.weight_bar': 'output_layer.2.weight', 225 | 'colorize.module.bias': 'output_layer.2.bias', 226 | 'colorize.module.weight_u': 'output_layer.2.u0', 227 | # Attention stuff 228 | 'attention.gamma': 'blocks.%d.1.gamma' % attention_blocknum, 229 | 'attention.theta.module.weight_u': 'blocks.%d.1.theta.u0' % attention_blocknum, 230 | 'attention.theta.module.weight_bar': 'blocks.%d.1.theta.weight' % attention_blocknum, 231 | 'attention.phi.module.weight_u': 'blocks.%d.1.phi.u0' % attention_blocknum, 232 | 'attention.phi.module.weight_bar': 'blocks.%d.1.phi.weight' % attention_blocknum, 233 | 'attention.g.module.weight_u': 'blocks.%d.1.g.u0' % attention_blocknum, 234 | 'attention.g.module.weight_bar': 'blocks.%d.1.g.weight' % attention_blocknum, 235 | 'attention.o_conv.module.weight_u': 'blocks.%d.1.o.u0' % attention_blocknum, 236 | 'attention.o_conv.module.weight_bar':'blocks.%d.1.o.weight' % attention_blocknum, 237 | } 238 | 239 | # Loop over the hub dict and build the hub2me map 240 | for name in hub_dict.keys(): 241 | if 'GBlock' in name: 242 | if 'HyperBN' not in name: # it's a conv 243 | out = parse.parse('GBlock.{:d}.{}.module.{}',name) 244 | blocknum, convnum, weightname = out 245 | if weightname not in weightname_dict: 246 | continue # else hyperBN in 247 | out_name = 'blocks.%d.0.%s.%s' % (blocknum, convnum_dict[convnum], weightname_dict[weightname]) # Increment conv number by 1 248 | else: # hyperbn not conv 249 | BNnum = 2 if 'HyperBN_1' in name else 1 250 | if 'embed' in name: 251 | out = parse.parse('GBlock.{:d}.{}.module.{}',name) 252 | blocknum, gamma_or_beta, weightname = out 253 | if weightname not in 
weightname_dict: # Ignore weight_v 254 | continue 255 | out_name = 'blocks.%d.0.bn%d.%s.%s' % (blocknum, BNnum, 'gain' if 'gamma' in gamma_or_beta else 'bias', weightname_dict[weightname]) 256 | else: 257 | out = parse.parse('GBlock.{:d}.{}.bn.{}',name) 258 | blocknum, dummy, mean_or_var = out 259 | if 'num_batches_tracked' in mean_or_var: 260 | continue 261 | out_name = 'blocks.%d.0.bn%d.%s' % (blocknum, BNnum, 'stored_mean' if 'mean' in mean_or_var else 'stored_var') 262 | hub2me[name] = out_name 263 | 264 | 265 | # Invert the hub2me map 266 | me2hub = {hub2me[item]: item for item in hub2me} 267 | new_dict = {} 268 | dimz_dict = {128: 20, 256: 20, 512:16} 269 | for item in me2hub: 270 | # Swap input dim ordering on batchnorm bois to account for my arbitrary change of ordering when concatenating Ys and Zs 271 | if ('bn' in item and 'weight' in item) and ('gain' in item or 'bias' in item) and ('output_layer' not in item): 272 | new_dict[item] = torch.cat([hub_dict[me2hub[item]][:, -128:], hub_dict[me2hub[item]][:, :dimz_dict[resolution]]], 1) 273 | # Reshape the first linear weight, bias, and u0 274 | elif item == 'linear.weight': 275 | new_dict[item] = hub_dict[me2hub[item]].contiguous().view(4, 4, 96 * 16, -1).permute(2,0,1,3).contiguous().view(-1,dimz_dict[resolution]) 276 | elif item == 'linear.bias': 277 | new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 16).permute(2,0,1).contiguous().view(-1) 278 | elif item == 'linear.u0': 279 | new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 16).permute(2,0,1).contiguous().view(1, -1) 280 | elif me2hub[item] == 'linear.weight': # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER 281 | # Transpose shared weight so that it's an embedding 282 | new_dict[item] = hub_dict[me2hub[item]].t() 283 | elif 'weight_u' in me2hub[item]: # Unsqueeze u0s 284 | new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0) 285 | else: 286 | new_dict[item] = hub_dict[me2hub[item]] 287 | return new_dict 288 | 289 | def 
get_config(resolution): 290 | attn_dict = {128: '64', 256: '128', 512: '64'} 291 | dim_z_dict = {128: 120, 256: 140, 512: 128} 292 | config = {'G_param': 'SN', 'D_param': 'SN', 293 | 'G_ch': 96, 'D_ch': 96, 294 | 'D_wide': True, 'G_shared': True, 295 | 'shared_dim': 128, 'dim_z': dim_z_dict[resolution], 296 | 'hier': True, 'cross_replica': False, 297 | 'mybn': False, 'G_activation': nn.ReLU(inplace=True), 298 | 'G_attn': attn_dict[resolution], 299 | 'norm_style': 'bn', 300 | 'G_init': 'ortho', 'skip_init': True, 'no_optim': True, 301 | 'G_fp16': False, 'G_mixed_precision': False, 302 | 'accumulate_stats': False, 'num_standing_accumulations': 16, 303 | 'G_eval_mode': True, 304 | 'BN_eps': 1e-04, 'SN_eps': 1e-04, 305 | 'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution, 306 | 'n_classes': 1000} 307 | return config 308 | 309 | 310 | def convert_biggan(resolution, weight_dir, redownload=False, no_ema=False, verbose=False): 311 | module_path = MODULE_PATH_TMPL.format(resolution) 312 | hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution)) 313 | pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution)) 314 | 315 | tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload) 316 | G_temp = getattr(biggan_for_conversion, f'Generator{resolution}')() 317 | state_dict_temp = G_temp.state_dict() 318 | 319 | converter = TFHub2Pytorch(state_dict_temp, tf_weights, resolution=resolution, 320 | load_ema=(not no_ema), verbose=verbose) 321 | state_dict_v1 = converter.load() 322 | state_dict = convert_from_v1(state_dict_v1, resolution) 323 | # Get the config, build the model 324 | config = get_config(resolution) 325 | G = BigGAN.Generator(**config) 326 | G.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries 327 | torch.save(state_dict, pth_path) 328 | 329 | # output_location ='pretrained_weights/TFHub-PyTorch-128.pth' 330 | 331 | return G 332 | 333 | 334 | def generate_sample(G, z_dim, batch_size, filename, 
parallel=False): 335 | 336 | G.eval() 337 | G.to(DEVICE) 338 | with torch.no_grad(): 339 | z = torch.randn(batch_size, G.dim_z).to(DEVICE) 340 | y = torch.randint(low=0, high=1000, size=(batch_size,), 341 | device=DEVICE, dtype=torch.int64, requires_grad=False) 342 | if parallel: 343 | images = nn.parallel.data_parallel(G, (z, G.shared(y))) 344 | else: 345 | images = G(z, G.shared(y)) 346 | save_image(images, filename, scale_each=True, normalize=True) 347 | 348 | def parse_args(): 349 | usage = 'Parser for conversion script.' 350 | parser = argparse.ArgumentParser(description=usage) 351 | parser.add_argument( 352 | '--resolution', '-r', type=int, default=None, choices=[128, 256, 512], 353 | help='Resolution of TFHub module to convert. Converts all resolutions if None.') 354 | parser.add_argument( 355 | '--redownload', action='store_true', default=False, 356 | help='Redownload weights and overwrite current hdf5 file, if present.') 357 | parser.add_argument( 358 | '--weights_dir', type=str, default='pretrained_weights') 359 | parser.add_argument( 360 | '--samples_dir', type=str, default='pretrained_samples') 361 | parser.add_argument( 362 | '--no_ema', action='store_true', default=False, 363 | help='Do not load ema weights.') 364 | parser.add_argument( 365 | '--verbose', action='store_true', default=False, 366 | help='Enable additional logging.') 367 | parser.add_argument( 368 | '--generate_samples', action='store_true', default=False, 369 | help='Generate test samples with the pretrained model.') 370 | parser.add_argument( 371 | '--batch_size', type=int, default=64, 372 | help='Batch size used for test samples.') 373 | parser.add_argument( 374 | '--parallel', action='store_true', default=False, 375 | help='Parallelize G?') 376 | args = parser.parse_args() 377 | return args 378 | 379 | 380 | if __name__ == '__main__': 381 | 382 | args = parse_args() 383 | os.makedirs(args.weights_dir, exist_ok=True) 384 | os.makedirs(args.samples_dir, exist_ok=True) 385 | 386 | if
args.resolution is not None: 387 | G = convert_biggan(args.resolution, args.weights_dir, 388 | redownload=args.redownload, 389 | no_ema=args.no_ema, verbose=args.verbose) 390 | if args.generate_samples: 391 | filename = os.path.join(args.samples_dir, f'biggan{args.resolution}_samples.jpg') 392 | print('Generating samples...') 393 | generate_sample(G, Z_DIMS[args.resolution], args.batch_size, filename, args.parallel) 394 | else: 395 | for res in RESOLUTIONS: 396 | G = convert_biggan(res, args.weights_dir, 397 | redownload=args.redownload, 398 | no_ema=args.no_ema, verbose=args.verbose) 399 | if args.generate_samples: 400 | filename = os.path.join(args.samples_dir, f'biggan{res}_samples.jpg') 401 | print('Generating samples...') 402 | generate_sample(G, Z_DIMS[res], args.batch_size, filename, args.parallel) -------------------------------------------------------------------------------- /DeepI2I_BigGAN/calculate_inception_moments.py: -------------------------------------------------------------------------------- 1 | ''' Calculate Inception Moments 2 | This script iterates over the dataset and calculates the moments of the 3 | activations of the Inception net (needed for FID), and also returns 4 | the Inception Score of the training data. 5 | 6 | Note that if you don't shuffle the data, the IS of true data will be under- 7 | estimated as it is label-ordered. By default, the data is not shuffled 8 | so as to reduce non-determinism. ''' 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | import utils 15 | import inception_utils 16 | from tqdm import tqdm, trange 17 | from argparse import ArgumentParser 18 | import pdb 19 | 20 | def prepare_parser(): 21 | usage = 'Calculate and store inception metrics.' 22 | parser = ArgumentParser(description=usage) 23 | parser.add_argument( 24 | '--dataset', type=str, default='I128_hdf5', 25 | help='Which Dataset to train on, out of I128, I256, C10, C100...' 26 | 'Append _hdf5 to use the hdf5 version of the dataset.
(default: %(default)s)') 27 | parser.add_argument( 28 | '--data_root', type=str, default='data', 29 | help='Default location where data is stored (default: %(default)s)') 30 | parser.add_argument( 31 | '--batch_size', type=int, default=64, 32 | help='Default overall batchsize (default: %(default)s)') 33 | parser.add_argument( 34 | '--parallel', action='store_true', default=False, 35 | help='Run the Inception net on multiple GPUs (default: %(default)s)') 36 | parser.add_argument( 37 | '--augment', action='store_true', default=False, 38 | help='Augment with random crops and flips (default: %(default)s)') 39 | parser.add_argument( 40 | '--num_workers', type=int, default=8, 41 | help='Number of dataloader workers (default: %(default)s)') 42 | parser.add_argument( 43 | '--shuffle', action='store_true', default=False, 44 | help='Shuffle the data? (default: %(default)s)') 45 | parser.add_argument( 46 | '--seed', type=int, default=0, 47 | help='Random seed to use.') 48 | return parser 49 | 50 | def run(config): 51 | # Get loader 52 | config['drop_last'] = False 53 | loaders = utils.get_data_loaders(**config) 54 | 55 | # Load inception net 56 | net = inception_utils.load_inception_net(parallel=config['parallel']) 57 | pool, logits, labels = [], [], [] 58 | device = 'cuda' 59 | for i, (x, y) in enumerate(tqdm(loaders[0])): 60 | try: 61 | x = x.to(device) 62 | with torch.no_grad(): 63 | pool_val, logits_val = net(x) 64 | pool += [np.asarray(pool_val.cpu())] 65 | logits += [np.asarray(F.softmax(logits_val, 1).cpu())] 66 | labels += [np.asarray(y.cpu())] 67 | except Exception as e: # retry the same batch once if the first attempt raised 68 | x = x.to(device) 69 | with torch.no_grad(): 70 | pool_val, logits_val = net(x) 71 | pool += [np.asarray(pool_val.cpu())] 72 | logits += [np.asarray(F.softmax(logits_val, 1).cpu())] 73 | labels += [np.asarray(y.cpu())] 74 | 75 | pool, logits, labels = [np.concatenate(item, 0) for item in [pool, logits, labels]] 76 | # Save pool, logits, and labels to disk 77 | print('Saving pool, logits, and
labels to disk...') 78 | np.savez(config['dataset']+'_inception_activations.npz', 79 | **{'pool': pool, 'logits': logits, 'labels': labels}) 80 | # Calculate inception metrics and report them 81 | print('Calculating inception metrics...') 82 | IS_mean, IS_std = inception_utils.calculate_inception_score(logits) 83 | print('Training data from dataset %s has IS of %5.5f +/- %5.5f' % (config['dataset'], IS_mean, IS_std)) 84 | # Prepare mu and sigma, save to disk. Remove "hdf5" by default 85 | # (the FID code also knows to strip "hdf5") 86 | print('Calculating means and covariances...') 87 | mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False) 88 | print('Saving calculated means and covariances to disk...') 89 | np.savez(config['dataset'].strip('_hdf5')+'_inception_moments.npz', **{'mu' : mu, 'sigma' : sigma}) 90 | 91 | def main(): 92 | # parse command line 93 | parser = prepare_parser() 94 | config = vars(parser.parse_args()) 95 | print(config) 96 | run(config) 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/class_to_index/DeepI2I_NABirds/I128_imgs.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/class_to_index/DeepI2I_NABirds/I128_imgs.pickle -------------------------------------------------------------------------------- /DeepI2I_BigGAN/class_to_index/DeepI2I_UECFOOD256/I128_imgs.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/class_to_index/DeepI2I_UECFOOD256/I128_imgs.pickle -------------------------------------------------------------------------------- /DeepI2I_BigGAN/class_to_index/DeepI2I_animals/I128_imgs.pickle:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/class_to_index/DeepI2I_animals/I128_imgs.pickle -------------------------------------------------------------------------------- /DeepI2I_BigGAN/datasets.py: -------------------------------------------------------------------------------- 1 | ''' Datasets 2 | This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets 3 | ''' 4 | import os 5 | import os.path 6 | import sys 7 | from PIL import Image 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | import numpy as np 11 | from tqdm import tqdm, trange 12 | 13 | import torchvision.datasets as dset 14 | import torchvision.transforms as transforms 15 | from torchvision.datasets.utils import download_url, check_integrity 16 | import torch.utils.data as data 17 | from torch.utils.data import DataLoader 18 | 19 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 20 | 21 | 22 | def is_image_file(filename): 23 | """Checks if a file is an image. 
24 | 25 | Args: 26 | filename (string): path to a file 27 | 28 | Returns: 29 | bool: True if the filename ends with a known image extension 30 | """ 31 | filename_lower = filename.lower() 32 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 33 | 34 | 35 | def find_classes(dir): 36 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 37 | classes.sort() 38 | class_to_idx = {classes[i]: i for i in range(len(classes))} 39 | return classes, class_to_idx 40 | 41 | 42 | def make_dataset(dir, class_to_idx): 43 | images = [] 44 | dir = os.path.expanduser(dir) 45 | for target in tqdm(sorted(os.listdir(dir))): 46 | d = os.path.join(dir, target) 47 | if not os.path.isdir(d): 48 | continue 49 | 50 | for root, _, fnames in sorted(os.walk(d)): 51 | for fname in sorted(fnames): 52 | if is_image_file(fname): 53 | path = os.path.join(root, fname) 54 | item = (path, class_to_idx[target]) 55 | images.append(item) 56 | 57 | return images 58 | 59 | 60 | def pil_loader(path): 61 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 62 | with open(path, 'rb') as f: 63 | img = Image.open(f) 64 | return img.convert('RGB') 65 | 66 | 67 | def accimage_loader(path): 68 | import accimage 69 | try: 70 | return accimage.Image(path) 71 | except IOError: 72 | # Potentially a decoding problem, fall back to PIL.Image 73 | return pil_loader(path) 74 | 75 | 76 | def default_loader(path): 77 | from torchvision import get_image_backend 78 | if get_image_backend() == 'accimage': 79 | return accimage_loader(path) 80 | else: 81 | return pil_loader(path) 82 | 83 | 84 | class ImageFolder(data.Dataset): 85 | """A generic data loader where the images are arranged in this way: :: 86 | 87 | root/dogball/xxx.png 88 | root/dogball/xxy.png 89 | root/dogball/xxz.png 90 | 91 | root/cat/123.png 92 | root/cat/nsdf3.png 93 | root/cat/asd932_.png 94 | 95 | Args: 96 | root (string): Root directory path. 
97 | transform (callable, optional): A function/transform that takes in a PIL image 98 | and returns a transformed version, e.g. ``transforms.RandomCrop`` 99 | target_transform (callable, optional): A function/transform that takes in the 100 | target and transforms it. 101 | loader (callable, optional): A function to load an image given its path. 102 | 103 | Attributes: 104 | classes (list): List of the class names. 105 | class_to_idx (dict): Dict with items (class_name, class_index). 106 | imgs (list): List of (image path, class_index) tuples 107 | """ 108 | 109 | def __init__(self, root, transform=None, target_transform=None, 110 | loader=default_loader, load_in_mem=False, 111 | index_filename='imagenet_imgs.npz', **kwargs): 112 | classes, class_to_idx = find_classes(root) 113 | # Load pre-computed image directory walk 114 | if os.path.exists(index_filename): 115 | print('Loading pre-saved Index file %s...' % index_filename) 116 | imgs = np.load(index_filename)['imgs'] 117 | # If first time, walk the folder directory and save the 118 | # results to a pre-computed file. 119 | else: 120 | print('Generating Index file %s...'
% index_filename) 121 | imgs = make_dataset(root, class_to_idx) 122 | np.savez_compressed(index_filename, **{'imgs' : imgs}) 123 | if len(imgs) == 0: 124 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 125 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 126 | 127 | self.root = root 128 | self.imgs = imgs 129 | self.classes = classes 130 | self.class_to_idx = class_to_idx 131 | self.transform = transform 132 | self.target_transform = target_transform 133 | self.loader = loader 134 | self.load_in_mem = load_in_mem 135 | 136 | if self.load_in_mem: 137 | print('Loading all images into memory...') 138 | self.data, self.labels = [], [] 139 | for index in tqdm(range(len(self.imgs))): 140 | path, target = imgs[index][0], imgs[index][1] 141 | self.data.append(self.transform(self.loader(path))) 142 | self.labels.append(target) 143 | 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | 150 | Returns: 151 | tuple: (image, target) where target is class_index of the target class. 
152 | """ 153 | if self.load_in_mem: 154 | img = self.data[index] 155 | target = self.labels[index] 156 | else: 157 | path, target = self.imgs[index] 158 | img = self.loader(str(path)) 159 | if self.transform is not None: 160 | img = self.transform(img) 161 | 162 | if self.target_transform is not None: 163 | target = self.target_transform(target) 164 | 165 | # print(img.size(), target) 166 | return img, int(target) 167 | 168 | def __len__(self): 169 | return len(self.imgs) 170 | 171 | def __repr__(self): 172 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 173 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 174 | fmt_str += ' Root Location: {}\n'.format(self.root) 175 | tmp = ' Transforms (if any): ' 176 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 177 | tmp = ' Target Transforms (if any): ' 178 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 179 | return fmt_str 180 | 181 | 182 | ''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid 183 | having to load individual images all the time. ''' 184 | import h5py as h5 185 | import torch 186 | class ILSVRC_HDF5(data.Dataset): 187 | def __init__(self, root, transform=None, target_transform=None, 188 | load_in_mem=False, train=True,download=False, validate_seed=0, 189 | val_split=0, **kwargs): # last four are dummies 190 | 191 | self.root = root 192 | self.num_imgs = len(h5.File(root, 'r')['labels']) 193 | 194 | # self.transform = transform 195 | self.target_transform = target_transform 196 | 197 | # Set the transform here 198 | self.transform = transform 199 | 200 | # load the entire dataset into memory? 201 | self.load_in_mem = load_in_mem 202 | 203 | # If loading into memory, do so now 204 | if self.load_in_mem: 205 | print('Loading %s into memory...' 
% root) 206 | with h5.File(root,'r') as f: 207 | self.data = f['imgs'][:] 208 | self.labels = f['labels'][:] 209 | 210 | def __getitem__(self, index): 211 | """ 212 | Args: 213 | index (int): Index 214 | 215 | Returns: 216 | tuple: (image, target) where target is class_index of the target class. 217 | """ 218 | # If loaded the entire dataset in RAM, get image from memory 219 | if self.load_in_mem: 220 | img = self.data[index] 221 | target = self.labels[index] 222 | 223 | # Else load it from disk 224 | else: 225 | with h5.File(self.root,'r') as f: 226 | img = f['imgs'][index] 227 | target = f['labels'][index] 228 | 229 | 230 | # if self.transform is not None: 231 | # img = self.transform(img) 232 | # Apply my own transform 233 | img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2 234 | 235 | if self.target_transform is not None: 236 | target = self.target_transform(target) 237 | 238 | return img, int(target) 239 | 240 | def __len__(self): 241 | return self.num_imgs 242 | # return len(self.f['imgs']) 243 | 244 | import pickle 245 | class CIFAR10(dset.CIFAR10): 246 | 247 | def __init__(self, root, train=True, 248 | transform=None, target_transform=None, 249 | download=True, validate_seed=0, 250 | val_split=0, load_in_mem=True, **kwargs): 251 | self.root = os.path.expanduser(root) 252 | self.transform = transform 253 | self.target_transform = target_transform 254 | self.train = train # training set or test set 255 | self.val_split = val_split 256 | 257 | if download: 258 | self.download() 259 | 260 | if not self._check_integrity(): 261 | raise RuntimeError('Dataset not found or corrupted.' 
' You can use download=True to download it') 263 | 264 | # now load the pickled numpy arrays 265 | self.data = [] 266 | self.labels= [] 267 | for fentry in self.train_list: 268 | f = fentry[0] 269 | file = os.path.join(self.root, self.base_folder, f) 270 | fo = open(file, 'rb') 271 | if sys.version_info[0] == 2: 272 | entry = pickle.load(fo) 273 | else: 274 | entry = pickle.load(fo, encoding='latin1') 275 | self.data.append(entry['data']) 276 | if 'labels' in entry: 277 | self.labels += entry['labels'] 278 | else: 279 | self.labels += entry['fine_labels'] 280 | fo.close() 281 | 282 | self.data = np.concatenate(self.data) 283 | # Randomly select indices for validation 284 | if self.val_split > 0: 285 | label_indices = [[] for _ in range(max(self.labels)+1)] 286 | for i,l in enumerate(self.labels): 287 | label_indices[l] += [i] 288 | label_indices = np.asarray(label_indices) 289 | 290 | # randomly grab int(len(self.data) * val_split) // n_classes elements of each class (e.g. 500 for CIFAR-10 with val_split=0.1) 291 | np.random.seed(validate_seed) 292 | self.val_indices = [] 293 | for l_i in label_indices: 294 | self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1) ,replace=False)]) 295 | 296 | if self.train=='validate': 297 | self.data = self.data[self.val_indices] 298 | self.labels = list(np.asarray(self.labels)[self.val_indices]) 299 | 300 | self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32)) 301 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 302 | 303 | elif self.train: 304 | print(np.shape(self.data)) 305 | if self.val_split > 0: 306 | self.data = np.delete(self.data,self.val_indices,axis=0) 307 | self.labels = list(np.delete(np.asarray(self.labels),self.val_indices,axis=0)) 308 | 309 | self.data = self.data.reshape((int(50e3 * (1.-self.val_split)), 3, 32, 32)) 310 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 311 | else: 312 | f = self.test_list[0][0] 313 | file = os.path.join(self.root, self.base_folder, f) 314 | fo =
open(file, 'rb') 315 | if sys.version_info[0] == 2: 316 | entry = pickle.load(fo) 317 | else: 318 | entry = pickle.load(fo, encoding='latin1') 319 | self.data = entry['data'] 320 | if 'labels' in entry: 321 | self.labels = entry['labels'] 322 | else: 323 | self.labels = entry['fine_labels'] 324 | fo.close() 325 | self.data = self.data.reshape((10000, 3, 32, 32)) 326 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 327 | 328 | def __getitem__(self, index): 329 | """ 330 | Args: 331 | index (int): Index 332 | Returns: 333 | tuple: (image, target) where target is index of the target class. 334 | """ 335 | img, target = self.data[index], self.labels[index] 336 | 337 | # doing this so that it is consistent with all other datasets 338 | # to return a PIL Image 339 | img = Image.fromarray(img) 340 | 341 | if self.transform is not None: 342 | img = self.transform(img) 343 | 344 | if self.target_transform is not None: 345 | target = self.target_transform(target) 346 | 347 | return img, target 348 | 349 | def __len__(self): 350 | return len(self.data) 351 | 352 | 353 | class CIFAR100(CIFAR10): 354 | base_folder = 'cifar-100-python' 355 | url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 356 | filename = "cifar-100-python.tar.gz" 357 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 358 | train_list = [ 359 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 360 | ] 361 | 362 | test_list = [ 363 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 364 | ] 365 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/inception_tf13.py: -------------------------------------------------------------------------------- 1 | ''' Tensorflow inception score code 2 | Derived from https://github.com/openai/improved-gan 3 | Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 4 | THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE 5 | 6 | To use this code, run sample.py on your 
model with --sample_npz, and then 7 | pass the experiment name via --experiment_name. 8 | This code also saves pool3 stats to an npz file for FID calculation 9 | ''' 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os.path 15 | import sys 16 | import tarfile 17 | import math 18 | from tqdm import tqdm, trange 19 | from argparse import ArgumentParser 20 | 21 | import numpy as np 22 | from six.moves import urllib 23 | import tensorflow as tf 24 | 25 | MODEL_DIR = '' 26 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 27 | softmax = None 28 | 29 | def prepare_parser(): 30 | usage = 'Parser for TF1.3- Inception Score scripts.' 31 | parser = ArgumentParser(description=usage) 32 | parser.add_argument( 33 | '--experiment_name', type=str, default='', 34 | help="Which experiment's samples.npz file to pull and evaluate") 35 | parser.add_argument( 36 | '--experiment_root', type=str, default='samples', 37 | help='Default location where samples are stored (default: %(default)s)') 38 | parser.add_argument( 39 | '--batch_size', type=int, default=500, 40 | help='Default overall batchsize (default: %(default)s)') 41 | return parser 42 | 43 | 44 | def run(config): 45 | # Inception with TF1.3 or earlier. 46 | # Call this function with a list of images. Each element should be a 47 | # numpy array with values ranging from 0 to 255.
48 | def get_inception_score(images, splits=10): 49 | assert(type(images) == list) 50 | assert(type(images[0]) == np.ndarray) 51 | assert(len(images[0].shape) == 3) 52 | assert(np.max(images[0]) > 10) 53 | assert(np.min(images[0]) >= 0.0) 54 | inps = [] 55 | for img in images: 56 | img = img.astype(np.float32) 57 | inps.append(np.expand_dims(img, 0)) 58 | bs = config['batch_size'] 59 | with tf.Session() as sess: 60 | preds, pools = [], [] 61 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 62 | for i in trange(n_batches): 63 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 64 | inp = np.concatenate(inp, 0) 65 | pred, pool = sess.run([softmax, pool3], {'ExpandDims:0': inp}) 66 | preds.append(pred) 67 | pools.append(pool) 68 | preds = np.concatenate(preds, 0) 69 | scores = [] 70 | for i in range(splits): 71 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 72 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 73 | kl = np.mean(np.sum(kl, 1)) 74 | scores.append(np.exp(kl)) 75 | return np.mean(scores), np.std(scores), np.squeeze(np.concatenate(pools, 0)) 76 | # Init inception 77 | def _init_inception(): 78 | global softmax, pool3 79 | if MODEL_DIR and not os.path.exists(MODEL_DIR): # '' means the current directory; os.makedirs('') would raise 80 | os.makedirs(MODEL_DIR) 81 | filename = DATA_URL.split('/')[-1] 82 | filepath = os.path.join(MODEL_DIR, filename) 83 | if not os.path.exists(filepath): 84 | def _progress(count, block_size, total_size): 85 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 86 | filename, float(count * block_size) / float(total_size) * 100.0)) 87 | sys.stdout.flush() 88 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 89 | print() 90 | statinfo = os.stat(filepath) 91 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 92 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 93 | with tf.gfile.FastGFile(os.path.join( 94 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 95 | graph_def =
tf.GraphDef() 96 | graph_def.ParseFromString(f.read()) 97 | _ = tf.import_graph_def(graph_def, name='') 98 | # Works with an arbitrary minibatch size. 99 | with tf.Session() as sess: 100 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 101 | ops = pool3.graph.get_operations() 102 | for op_idx, op in enumerate(ops): 103 | for o in op.outputs: 104 | shape = o.get_shape() 105 | shape = [s.value for s in shape] 106 | new_shape = [] 107 | for j, s in enumerate(shape): 108 | if s == 1 and j == 0: 109 | new_shape.append(None) 110 | else: 111 | new_shape.append(s) 112 | o._shape = tf.TensorShape(new_shape) 113 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 114 | logits = tf.matmul(tf.squeeze(pool3), w) 115 | softmax = tf.nn.softmax(logits) 116 | 117 | # if softmax is None: # No need to functionalize like this. 118 | _init_inception() 119 | 120 | fname = '%s/%s/samples.npz' % (config['experiment_root'], config['experiment_name']) 121 | print('loading %s ...'%fname) 122 | ims = np.load(fname)['x'] 123 | import time 124 | t0 = time.time() 125 | inc_mean, inc_std, pool_activations = get_inception_score(list(ims.swapaxes(1,2).swapaxes(2,3)), splits=10) 126 | t1 = time.time() 127 | print('Saving pool to numpy file for FID calculations...') 128 | np.savez('%s/%s/TF_pool.npz' % (config['experiment_root'], config['experiment_name']), **{'pool_mean': np.mean(pool_activations,axis=0), 'pool_var': np.cov(pool_activations, rowvar=False)}) 129 | print('Inception took %3f seconds, score of %3f +/- %3f.'%(t1-t0, inc_mean, inc_std)) 130 | def main(): 131 | # parse command line and run 132 | parser = prepare_parser() 133 | config = vars(parser.parse_args()) 134 | print(config) 135 | run(config) 136 | 137 | if __name__ == '__main__': 138 | main() -------------------------------------------------------------------------------- /DeepI2I_BigGAN/inception_utils.py: -------------------------------------------------------------------------------- 1 | ''' Inception 
utilities 2 | This file contains methods for calculating IS and FID, using either 3 | the original numpy code or an accelerated fully-pytorch version that 4 | uses a fast newton-schulz approximation for the matrix sqrt. There are also 5 | methods for acquiring a desired number of samples from the Generator, 6 | and parallelizing the inbuilt PyTorch inception network. 7 | 8 | NOTE that Inception Scores and FIDs calculated using these methods will 9 | *not* be directly comparable to values calculated using the original TF 10 | IS/FID code. You *must* use the TF model if you wish to report and compare 11 | numbers. This code tends to produce IS values that are 5-10% lower than 12 | those obtained through TF. 13 | ''' 14 | import numpy as np 15 | from scipy import linalg # For numpy FID 16 | import time 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.nn import Parameter as P 22 | from torchvision.models.inception import inception_v3 23 | 24 | 25 | # Module that wraps the inception network to enable use with dataparallel and 26 | # returning pool features and logits. 27 | class WrapInception(nn.Module): 28 | def __init__(self, net): 29 | super(WrapInception,self).__init__() 30 | self.net = net 31 | self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1), 32 | requires_grad=False) 33 | self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1), 34 | requires_grad=False) 35 | def forward(self, x): 36 | # Normalize x 37 | x = (x + 1.) 
/ 2.0 38 | x = (x - self.mean) / self.std 39 | # Upsample if necessary 40 | if x.shape[2] != 299 or x.shape[3] != 299: 41 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 42 | # 299 x 299 x 3 43 | x = self.net.Conv2d_1a_3x3(x) 44 | # 149 x 149 x 32 45 | x = self.net.Conv2d_2a_3x3(x) 46 | # 147 x 147 x 32 47 | x = self.net.Conv2d_2b_3x3(x) 48 | # 147 x 147 x 64 49 | x = F.max_pool2d(x, kernel_size=3, stride=2) 50 | # 73 x 73 x 64 51 | x = self.net.Conv2d_3b_1x1(x) 52 | # 73 x 73 x 80 53 | x = self.net.Conv2d_4a_3x3(x) 54 | # 71 x 71 x 192 55 | x = F.max_pool2d(x, kernel_size=3, stride=2) 56 | # 35 x 35 x 192 57 | x = self.net.Mixed_5b(x) 58 | # 35 x 35 x 256 59 | x = self.net.Mixed_5c(x) 60 | # 35 x 35 x 288 61 | x = self.net.Mixed_5d(x) 62 | # 35 x 35 x 288 63 | x = self.net.Mixed_6a(x) 64 | # 17 x 17 x 768 65 | x = self.net.Mixed_6b(x) 66 | # 17 x 17 x 768 67 | x = self.net.Mixed_6c(x) 68 | # 17 x 17 x 768 69 | x = self.net.Mixed_6d(x) 70 | # 17 x 17 x 768 71 | x = self.net.Mixed_6e(x) 72 | # 17 x 17 x 768 73 | # 17 x 17 x 768 74 | x = self.net.Mixed_7a(x) 75 | # 8 x 8 x 1280 76 | x = self.net.Mixed_7b(x) 77 | # 8 x 8 x 2048 78 | x = self.net.Mixed_7c(x) 79 | # 8 x 8 x 2048 80 | pool = torch.mean(x.view(x.size(0), x.size(1), -1), 2) 81 | # 1 x 1 x 2048 82 | logits = self.net.fc(F.dropout(pool, training=False).view(pool.size(0), -1)) 83 | # 1000 (num_classes) 84 | return pool, logits 85 | 86 | 87 | # A pytorch implementation of cov, from Modar M. Alfadly 88 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 89 | def torch_cov(m, rowvar=False): 90 | '''Estimate a covariance matrix given data. 91 | 92 | Covariance indicates the level to which two variables vary together. 93 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 94 | then the covariance matrix element `C_{ij}` is the covariance of 95 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 

  Args:
    m: A 1-D or 2-D tensor containing multiple variables and observations.
      Each row of `m` represents a variable, and each column a single
      observation of all those variables.
    rowvar: If `rowvar` is True, then each row represents a
      variable, with observations in the columns. Otherwise, the
      relationship is transposed: each column represents a variable,
      while the rows contain observations.

  Returns:
    The covariance matrix of the variables.
  '''
  if m.dim() > 2:
    raise ValueError('m has more than 2 dimensions')
  if m.dim() < 2:
    m = m.view(1, -1)
  if not rowvar and m.size(0) != 1:
    m = m.t()
  # m = m.type(torch.double)  # uncomment this line if desired
  fact = 1.0 / (m.size(1) - 1)
  # Out-of-place centering: m may be a transposed view of the caller's
  # tensor, and an in-place `m -= ...` would silently modify it.
  m = m - torch.mean(m, dim=1, keepdim=True)
  mt = m.t()  # if complex: mt = m.t().conj()
  return fact * m.matmul(mt).squeeze()


# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
# https://github.com/msubhransu/matrix-sqrt
def sqrt_newton_schulz(A, numIters, dtype=None):
  with torch.no_grad():
    if dtype is None:
      dtype = A.type()
    batchSize = A.shape[0]
    dim = A.shape[1]
    normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
    Y = A.div(normA.view(batchSize, 1, 1).expand_as(A))
    I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
    Z = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
    for i in range(numIters):
      T = 0.5*(3.0*I - Z.bmm(Y))
      Y = Y.bmm(T)
      Z = T.bmm(Z)
    sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
  return sA


# FID calculator from TTUR--consider replacing this with GPU-accelerated cov
# calculations using torch?
def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  """Numpy implementation of the Frechet Distance.
  Taken from https://github.com/bioinf-jku/TTUR
  The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
  and X_2 ~ N(mu_2, C_2) is
          d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
  Stable version by Dougal J. Sutherland.
  Params:
  -- mu1   : Numpy array holding the sample mean over activations of a layer
             of the inception net (e.g. pool3) for generated samples.
  -- mu2   : The sample mean over activations, precalculated on a
             representative data set.
  -- sigma1: The covariance matrix over activations for generated samples.
  -- sigma2: The covariance matrix over activations, precalculated on a
             representative data set.
  Returns:
  --   : The Frechet Distance.
  """

  mu1 = np.atleast_1d(mu1)
  mu2 = np.atleast_1d(mu2)

  sigma1 = np.atleast_2d(sigma1)
  sigma2 = np.atleast_2d(sigma2)

  assert mu1.shape == mu2.shape, \
    'Training and test mean vectors have different lengths'
  assert sigma1.shape == sigma2.shape, \
    'Training and test covariances have different dimensions'

  diff = mu1 - mu2

  # Product might be almost singular
  covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
  if not np.isfinite(covmean).all():
    msg = ('fid calculation produces singular product; '
           'adding %s to diagonal of cov estimates') % eps
    print(msg)
    offset = np.eye(sigma1.shape[0]) * eps
    covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

  # Numerical error might give slight imaginary component
  if np.iscomplexobj(covmean):
    if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
      m = np.max(np.abs(covmean.imag))
      raise ValueError('Imaginary component {}'.format(m))
    print('FID: discarding a negligible imaginary component of covmean')
    covmean = covmean.real

  tr_covmean = np.trace(covmean)

  out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
  return out


def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  """Pytorch implementation of the Frechet Distance.
  Adapted from https://github.com/bioinf-jku/TTUR; here the matrix square
  root is approximated with Newton-Schulz iterations (sqrt_newton_schulz),
  and eps is currently unused.
  The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
  and X_2 ~ N(mu_2, C_2) is
          d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
  Stable version by Dougal J. Sutherland.
  Params:
  -- mu1   : Torch tensor holding the sample mean over activations of a layer
             of the inception net (e.g. pool3) for generated samples.
  -- mu2   : The sample mean over activations, precalculated on a
             representative data set.
  -- sigma1: The covariance matrix over activations for generated samples.
  -- sigma2: The covariance matrix over activations, precalculated on a
             representative data set.
  Returns:
  --   : The Frechet Distance.
  """

  assert mu1.shape == mu2.shape, \
    'Training and test mean vectors have different lengths'
  assert sigma1.shape == sigma2.shape, \
    'Training and test covariances have different dimensions'

  diff = mu1 - mu2
  # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2
  covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze()
  out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2)
         - 2 * torch.trace(covmean))
  return out


# Calculate Inception Score mean + std given softmax'd logits and number of splits
def calculate_inception_score(pred, num_splits=10):
  scores = []
  for index in range(num_splits):
    pred_chunk = pred[index * (pred.shape[0] // num_splits): (index + 1) * (pred.shape[0] // num_splits), :]
    kl_inception = pred_chunk * (np.log(pred_chunk) - np.log(np.expand_dims(np.mean(pred_chunk, 0), 0)))
    kl_inception = np.mean(np.sum(kl_inception, 1))
    scores.append(np.exp(kl_inception))
  return np.mean(scores), np.std(scores)


# Loop and run the sampler and the net until it accumulates num_inception_images
# activations. Return the pool, the logits, and the labels (if one wants
# Inception Accuracy the labels of the generated class will be needed)
def accumulate_inception_activations(sample, net, num_inception_images=50000):
  pool, logits, labels = [], [], []
  # Keep a running count instead of re-concatenating all logits every loop
  num_accumulated = 0
  while num_accumulated < num_inception_images:
    with torch.no_grad():
      images, labels_val = sample()
      pool_val, logits_val = net(images.float())
      pool += [pool_val]
      logits += [F.softmax(logits_val, 1)]
      labels += [labels_val]
      num_accumulated += logits_val.shape[0]
  return torch.cat(pool, 0), torch.cat(logits, 0), torch.cat(labels, 0)


# Load and wrap the Inception model
def load_inception_net(parallel=False):
  inception_model = inception_v3(pretrained=True, transform_input=False)
  inception_model = WrapInception(inception_model.eval()).cuda()
  if parallel:
    print('Parallelizing Inception module...')
    inception_model = nn.DataParallel(inception_model)
  return inception_model


# This produces a function which takes in a sampler (a callable that returns a
# batch of images and labels) and calls it until it accumulates
# config['num_inception_images'] images.
# The sampler can return samples with a different batch size than used in
# training, using the setting config['inception_batchsize'].
def prepare_inception_metrics(dataset, parallel, no_fid=False):
  # Load metrics; this is intentionally not in a try-except loop so that
  # the script will crash here if it cannot find the Inception moments.
  # By default, remove a trailing "_hdf5" from the dataset name.
  # (dataset.strip('_hdf5') would strip any of those *characters* from both
  # ends of the string, not the suffix.)
  if dataset.endswith('_hdf5'):
    dataset = dataset[:-len('_hdf5')]
  moments = np.load(dataset + '_inception_moments.npz')
  data_mu, data_sigma = moments['mu'], moments['sigma']
  # Load network
  net = load_inception_net(parallel)
  def get_inception_metrics(sample, num_inception_images, num_splits=10,
                            prints=True, use_torch=True):
    if prints:
      print('Gathering activations...')
    pool, logits, labels = accumulate_inception_activations(sample, net, num_inception_images)
    if prints:
      print('Calculating Inception Score...')
    IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits)
    if no_fid:
      FID = 9999.0
    else:
      if prints:
        print('Calculating means and covariances...')
      if use_torch:
        mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False)
      else:
        mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False)
      if prints:
        print('Covariances calculated, getting FID...')
      if use_torch:
        FID = torch_calculate_frechet_distance(mu, sigma, torch.tensor(data_mu).float().cuda(), torch.tensor(data_sigma).float().cuda())
        FID = float(FID.cpu().numpy())
      else:
        # mu and sigma are already NumPy arrays in this branch
        FID = numpy_calculate_frechet_distance(mu, sigma, data_mu, data_sigma)
      del mu, sigma
    # Delete pool, logits, and labels, just in case
    # (mu and sigma only exist when FID is computed, so they are freed above)
    del pool, logits, labels
    return IS_mean, IS_std, FID
  return get_inception_metrics

--------------------------------------------------------------------------------
/DeepI2I_BigGAN/layers.py:
--------------------------------------------------------------------------------
''' Layers
This file contains various layers for the BigGAN models.
'''
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d


# Projection of x onto y
def proj(x, y):
  return torch.mm(y, x.t()) * y / torch.mm(y, y.t())


# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
  for y in ys:
    x = x - proj(x, y)
  return x


# Apply one step of the power method to estimate the top N singular values
# (one per vector in u_); SN.W_ calls this num_itrs times.
def power_iteration(W, u_, update=True, eps=1e-12):
  # Lists holding singular vectors and values
  us, vs, svs = [], [], []
  for i, u in enumerate(u_):
    # Run one step of the power iteration
    with torch.no_grad():
      v = torch.matmul(u, W)
      # Run Gram-Schmidt to subtract components of all other singular vectors
      v = F.normalize(gram_schmidt(v, vs), eps=eps)
      # Add to the list
      vs += [v]
      # Update the other singular vector
      u = torch.matmul(v, W.t())
      # Run Gram-Schmidt to subtract components of all other singular vectors
      u = F.normalize(gram_schmidt(u, us), eps=eps)
      # Add to the list
      us += [u]
      if update:
        u_[i][:] = u
    # Compute this singular value and add it to the list
    svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
    #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
  return svs, us, vs


# Convenience passthrough function
class identity(nn.Module):
  def forward(self, input):
    return input


# Spectral normalization base class
class SN(object):
  def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
    # Number of power iterations per step
    self.num_itrs = num_itrs
    # Number of singular values
    self.num_svs = num_svs
    # Transposed?
    self.transpose = transpose
    # Epsilon value for avoiding divide-by-0
    self.eps = eps
    # Register a singular vector for each sv
    for i in range(self.num_svs):
      self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
      self.register_buffer('sv%d' % i, torch.ones(1))

  # Singular vectors (u side)
  @property
  def u(self):
    return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

  # Singular values;
  # note that these buffers are just for logging and are not used in training.
  @property
  def sv(self):
    return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

  # Compute the spectrally-normalized weight
  def W_(self):
    W_mat = self.weight.view(self.weight.size(0), -1)
    if self.transpose:
      W_mat = W_mat.t()
    # Apply num_itrs power iterations
    for _ in range(self.num_itrs):
      svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
    # Update the svs
    if self.training:
      with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
        for i, sv in enumerate(svs):
          self.sv[i][:] = sv
    return self.weight / svs[0]


# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
  def __init__(self, in_channels, out_channels, kernel_size, stride=1,
               padding=0, dilation=1, groups=1, bias=True,
               num_svs=1, num_itrs=1, eps=1e-12):
    nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                       padding, dilation, groups, bias)
    SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
  def forward(self, x):
    return F.conv2d(x, self.W_(), self.bias, self.stride,
                    self.padding, self.dilation, self.groups)


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
  def __init__(self, in_features, out_features, bias=True,
               num_svs=1, num_itrs=1, eps=1e-12):
    nn.Linear.__init__(self, in_features, out_features, bias)
    SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
  def forward(self, x):
    return F.linear(x, self.W_(), self.bias)


# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience's sake
class SNEmbedding(nn.Embedding, SN):
  def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
               max_norm=None, norm_type=2, scale_grad_by_freq=False,
               sparse=False, _weight=None,
               num_svs=1, num_itrs=1, eps=1e-12):
    nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
                          sparse, _weight)
    SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
  def forward(self, x):
    return F.embedding(x, self.W_())


# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
  def __init__(self, ch, which_conv=SNConv2d, name='attention'):
    super(Attention, self).__init__()
    # Number of input (and output) channels
    self.ch = ch
    self.which_conv = which_conv
    self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
    self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
    self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
    self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
    # Learnable gain parameter
    self.gamma = P(torch.tensor(0.), requires_grad=True)
  def forward(self, x, y=None):
    # Apply convs
    theta = self.theta(x)
    phi = F.max_pool2d(self.phi(x), [2,2])
    g = F.max_pool2d(self.g(x), [2,2])
    # Perform reshapes
    theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
    phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
    g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
    # Matmul and softmax to get attention maps
    beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
    # Attention map times g path
    o = self.o(torch.bmm(g, beta.transpose(1,2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
    return self.gamma * o + x


# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
  # Apply scale and shift--if gain and bias are provided, fuse them here
  # Prepare scale
  scale = torch.rsqrt(var + eps)
  # If a gain is provided, use it
  if gain is not None:
    scale = scale * gain
  # Prepare shift
  shift = mean * scale
  # If bias is provided, use it
  if bias is not None:
    shift = shift - bias
  return x * scale - shift
  #return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
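

# Illustrative sketch (hypothetical helper, not part of the original BigGAN /
# DeepI2I code and not used elsewhere in this file). It numerically checks two
# identities that fused_bn above and manual_bn below rely on, using NumPy so
# it runs without a GPU:
#  1. fused_bn: with scale = gain / sqrt(var + eps) and
#     shift = mean * scale - bias, the fused form x * scale - shift equals
#     the unfused gain * (x - mean) / sqrt(var + eps) + bias.
#  2. manual_bn: the per-channel variance over dims (0, 2, 3) equals the
#     mean of squares minus the squared mean, E[x^2] - E[x]^2.
# The (N, C, H, W) shape and random inputs are assumptions for the demo.
import numpy as np


def _bn_identities_example(seed=0, eps=1e-5):
  rng = np.random.RandomState(seed)
  x = rng.randn(4, 3, 5, 5)
  gain = rng.rand(1, 3, 1, 1) + 0.5
  bias = rng.randn(1, 3, 1, 1)
  # Per-channel statistics, computed the way manual_bn does
  mean = x.mean(axis=(0, 2, 3), keepdims=True)
  var = (x ** 2).mean(axis=(0, 2, 3), keepdims=True) - mean ** 2
  # Identity 2: compare with NumPy's (biased) variance
  var_err = np.max(np.abs(var - x.var(axis=(0, 2, 3), keepdims=True)))
  # Identity 1: fused vs. unfused batchnorm
  scale = gain / np.sqrt(var + eps)
  shift = mean * scale - bias
  fused = x * scale - shift
  unfused = gain * (x - mean) / np.sqrt(var + eps) + bias
  bn_err = np.max(np.abs(fused - unfused))
  # Both errors should be at floating-point rounding level
  return bn_err, var_err

# Usage: bn_err, var_err = _bn_identities_example()  # both close to zero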


# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
  # Cast x to float32 if necessary
  float_x = x.float()
  # Calculate expected value of x (m) and expected value of x**2 (m2)
  # Mean of x
  m = torch.mean(float_x, [0, 2, 3], keepdim=True)
  # Mean of x squared
  m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
  # Calculate variance as mean of squared minus mean squared.
  var = (m2 - m **2)
  # Cast back to float 16 if necessary
  var = var.type(x.type())
  m = m.type(x.type())
  # Return mean and variance for updating stored mean/var if requested
  if return_mean_var:
    return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
  else:
    return fused_bn(x, m, var, gain, bias, eps)


# My batchnorm, supports standing stats
class myBN(nn.Module):
  def __init__(self, num_channels, eps=1e-5, momentum=0.1):
    super(myBN, self).__init__()
    # momentum for updating running stats
    self.momentum = momentum
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Register buffers
    self.register_buffer('stored_mean', torch.zeros(num_channels))
    self.register_buffer('stored_var', torch.ones(num_channels))
    self.register_buffer('accumulation_counter', torch.zeros(1))
    # Accumulate running means and vars
    self.accumulate_standing = False

  # reset standing stats
  def reset_stats(self):
    self.stored_mean[:] = 0
    self.stored_var[:] = 0
    self.accumulation_counter[:] = 0

  def forward(self, x, gain, bias):
    if self.training:
      out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
      # If accumulating standing stats, increment them
      if self.accumulate_standing:
        self.stored_mean[:] = self.stored_mean + mean.data
        self.stored_var[:] = self.stored_var + var.data
        self.accumulation_counter += 1.0
      # If not accumulating standing stats, take running averages
      # (use .data so the stored buffers do not join the autograd graph)
      else:
        self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean.data * self.momentum
        self.stored_var[:] = self.stored_var * (1 - self.momentum) + var.data * self.momentum
      return out
    # If not in training mode, use the stored statistics
    else:
      mean = self.stored_mean.view(1, -1, 1, 1)
      var = self.stored_var.view(1, -1, 1, 1)
      # If using standing stats, divide them by the accumulation counter
      if self.accumulate_standing:
        mean = mean / self.accumulation_counter
        var = var / self.accumulation_counter
      return fused_bn(x, mean, var, gain, bias, self.eps)


# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
  # If number of channels specified in norm_style:
  if 'ch' in norm_style:
    ch = int(norm_style.split('_')[-1])
    groups = max(int(x.shape[1]) // ch, 1)
  # If number of groups specified in norm style
  elif 'grp' in norm_style:
    groups = int(norm_style.split('_')[-1])
  # If neither, default to groups = 16
  else:
    groups = 16
  return F.group_norm(x, groups)


# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
  def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
               cross_replica=False, mybn=False, norm_style='bn',):
    super(ccbn, self).__init__()
    self.output_size, self.input_size = output_size, input_size
    # Prepare gain and bias layers
    self.gain = which_linear(input_size, output_size)
    self.bias = which_linear(input_size, output_size)
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Momentum
    self.momentum = momentum
    # Use cross-replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # Norm style?
    self.norm_style = norm_style

    if self.cross_replica:
      self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
    elif self.mybn:
      self.bn = myBN(output_size, self.eps, self.momentum)
    elif self.norm_style in ['bn', 'in']:
      self.register_buffer('stored_mean', torch.zeros(output_size))
      self.register_buffer('stored_var', torch.ones(output_size))


  def forward(self, x, y):
    # Calculate class-conditional gains and biases
    gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
    bias = self.bias(y).view(y.size(0), -1, 1, 1)
    # If using my batchnorm or cross-replica batchnorm
    if self.mybn or self.cross_replica:
      return self.bn(x, gain=gain, bias=bias)
    # else:
    else:
      if self.norm_style == 'bn':
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, 0.1, self.eps)
      elif self.norm_style == 'in':
        out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                              self.training, 0.1, self.eps)
      elif self.norm_style == 'gn':
        out = groupnorm(x, self.norm_style)
      elif self.norm_style == 'nonorm':
        out = x
      return out * gain + bias
  def extra_repr(self):
    s = 'out: {output_size}, in: {input_size},'
    s += ' cross_replica={cross_replica}'
    return s.format(**self.__dict__)


# Normal, non-class-conditional BN
class bn(nn.Module):
  def __init__(self, output_size, eps=1e-5, momentum=0.1,
               cross_replica=False, mybn=False):
    super(bn, self).__init__()
    self.output_size = output_size
    # Prepare gain and bias layers
    self.gain = P(torch.ones(output_size), requires_grad=True)
    self.bias = P(torch.zeros(output_size), requires_grad=True)
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Momentum
    self.momentum = momentum
    # Use cross-replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn

    if self.cross_replica:
      self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
    elif mybn:
      self.bn = myBN(output_size, self.eps, self.momentum)
    else:
      # Register buffers if neither of the above
      self.register_buffer('stored_mean', torch.zeros(output_size))
      self.register_buffer('stored_var', torch.ones(output_size))

  def forward(self, x, y=None):
    if self.cross_replica or self.mybn:
      gain = self.gain.view(1,-1,1,1)
      bias = self.bias.view(1,-1,1,1)
      return self.bn(x, gain=gain, bias=bias)
    else:
      return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                          self.bias, self.training, self.momentum, self.eps)


# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
  def __init__(self, in_channels, out_channels,
               which_conv=nn.Conv2d, which_bn=bn, activation=None,
               upsample=None):
    super(GBlock, self).__init__()

    self.in_channels, self.out_channels = in_channels, out_channels
    self.which_conv, self.which_bn = which_conv, which_bn
    self.activation = activation
    # upsample layer (or None)
    self.upsample = upsample
    # Conv layers
    self.conv1 = self.which_conv(self.in_channels, self.out_channels)
    self.conv2 = self.which_conv(self.out_channels, self.out_channels)
    self.learnable_sc = in_channels != out_channels or upsample
    if self.learnable_sc:
      self.conv_sc = self.which_conv(in_channels, out_channels,
                                     kernel_size=1, padding=0)
    # Batchnorm layers
    self.bn1 = self.which_bn(in_channels)
    self.bn2 = self.which_bn(out_channels)

  def forward(self, x, y):
    h = self.activation(self.bn1(x, y))
    if self.upsample:
      h = self.upsample(h)
      x = self.upsample(x)
    h = self.conv1(h)
    h = self.activation(self.bn2(h, y))
    h = self.conv2(h)
    if self.learnable_sc:
      x = self.conv_sc(x)
    return h + x

# dogball

# Residual block for the discriminator
class DBlock(nn.Module):
  def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
               preactivation=False, activation=None, downsample=None,):
    super(DBlock, self).__init__()
    self.in_channels, self.out_channels = in_channels, out_channels
    # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
    self.hidden_channels = self.out_channels if wide else self.in_channels
    self.which_conv = which_conv
    self.preactivation = preactivation
    self.activation = activation
    self.downsample = downsample

    # Conv layers
    self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
    self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
    self.learnable_sc = True if (in_channels != out_channels) or downsample else False
    if self.learnable_sc:
      self.conv_sc = self.which_conv(in_channels, out_channels,
                                     kernel_size=1, padding=0)
  def shortcut(self, x):
    if self.preactivation:
      if self.learnable_sc:
        x = self.conv_sc(x)
      if self.downsample:
        x = self.downsample(x)
    else:
      if self.downsample:
        x = self.downsample(x)
      if self.learnable_sc:
        x = self.conv_sc(x)
    return x

  def forward(self, x):
    if self.preactivation:
      # h = self.activation(x) # NOT TODAY SATAN
      # Andy's note: This line *must* be an out-of-place ReLU or it
      #              will negatively affect the shortcut connection.
      h = F.relu(x)
    else:
      h = x
    h = self.conv1(h)
    h = self.conv2(self.activation(h))
    if self.downsample:
      h = self.downsample(h)

    return h + self.shortcut(x)

# dogball

--------------------------------------------------------------------------------
/DeepI2I_BigGAN/losses.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

# DCGAN loss
def loss_dcgan_dis(dis_fake, dis_real):
  L1 = torch.mean(F.softplus(-dis_real))
  L2 = torch.mean(F.softplus(dis_fake))
  return L1, L2


def loss_dcgan_gen(dis_fake, M_regu=None, D_fea_w={4:10, 8:5, 16:2, 32:0.1}):
  # Optional regularizer: M_regu[-1] maps each key (presumably a feature-map
  # resolution, matching the keys of D_fea_w) to a pair of feature tensors;
  # their MSE, weighted by D_fea_w[key], is added to the adversarial loss.
  loss = torch.mean(F.softplus(-dis_fake))
  loss_M = 0.
15 | if M_regu is not None: 16 | for keys in M_regu[-1].keys(): 17 | loss_M += D_fea_w[keys] * torch.mean(F.mse_loss(M_regu[-1][keys][0], M_regu[-1][keys][1])) 18 | loss += loss_M 19 | return loss 20 | 21 | 22 | # Hinge Loss 23 | def loss_hinge_dis(dis_fake, dis_real): 24 | loss_real = torch.mean(F.relu(1. - dis_real)) 25 | loss_fake = torch.mean(F.relu(1. + dis_fake)) 26 | return loss_real, loss_fake 27 | # def loss_hinge_dis(dis_fake, dis_real): # This version returns a single loss 28 | # loss = torch.mean(F.relu(1. - dis_real)) 29 | # loss += torch.mean(F.relu(1. + dis_fake)) 30 | # return loss 31 | 32 | 33 | def loss_hinge_gen(dis_fake, M_regu=None, D_fea_w={4:10, 8:5, 16:2, 32:0.1}): 34 | loss = -torch.mean(dis_fake) 35 | loss_M = 0. 36 | if M_regu is not None: 37 | for keys in M_regu[-1].keys(): 38 | loss_M += D_fea_w[keys] * torch.mean(F.mse_loss(M_regu[-1][keys][0], M_regu[-1][keys][1])) 39 | loss += loss_M 40 | return loss 41 | 42 | # Default to hinge loss 43 | generator_loss = loss_hinge_gen 44 | discriminator_loss = loss_hinge_dis 45 | #generator_loss = loss_dcgan_gen 46 | #discriminator_loss = loss_dcgan_dis 47 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/make_hdf5.py: -------------------------------------------------------------------------------- 1 | """ Convert dataset to HDF5 2 | This script preprocesses a dataset and saves it (images and labels) to 3 | an HDF5 file for improved I/O. 
""" 4 | import os 5 | import sys 6 | from argparse import ArgumentParser 7 | from tqdm import tqdm, trange 8 | import h5py as h5 9 | 10 | import numpy as np 11 | import torch 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | from torchvision.utils import save_image 15 | 16 | from torch.utils.data import DataLoader 17 | 18 | import utils 19 | import pdb 20 | 21 | def prepare_parser(): 22 | usage = 'Parser for ImageNet HDF5 scripts.' 23 | parser = ArgumentParser(description=usage) 24 | parser.add_argument( 25 | '--dataset', type=str, default='I128', 26 | help='Which Dataset to train on, out of I128, I256, C10, C100; ' 27 | 'Append "_hdf5" to use the hdf5 version for ILSVRC (default: %(default)s)') 28 | parser.add_argument( 29 | '--data_root', type=str, default='data', 30 | help='Default location where data is stored (default: %(default)s)') 31 | parser.add_argument( 32 | '--batch_size', type=int, default=256, 33 | help='Default overall batchsize (default: %(default)s)') 34 | parser.add_argument( 35 | '--num_workers', type=int, default=16, 36 | help='Number of dataloader workers (default: %(default)s)') 37 | parser.add_argument( 38 | '--chunk_size', type=int, default=500, 39 | help='Number of images per HDF5 chunk (default: %(default)s)') 40 | parser.add_argument( 41 | '--compression', action='store_true', default=False, 42 | help='Use LZF compression? (default: %(default)s)') 43 | return parser 44 | 45 | 46 | def run(config): 47 | if 'hdf5' in config['dataset']: 48 | raise ValueError('Reading from an HDF5 file which you will probably be ' 49 | 'about to overwrite! 
Override this error only if you know ' 50 | "what you're doing!") 51 | # Get image size 52 | config['image_size'] = utils.imsize_dict[config['dataset']] 53 | 54 | # Update compression entry 55 | config['compression'] = 'lzf' if config['compression'] else None # LZF if --compression was passed, otherwise no compression 56 | 57 | # Get dataset 58 | kwargs = {'num_workers': config['num_workers'], 'pin_memory': False, 'drop_last': False} 59 | train_loader = utils.get_data_loaders(dataset=config['dataset'], 60 | batch_size=config['batch_size'], 61 | shuffle=False, 62 | data_root=config['data_root'], 63 | use_multiepoch_sampler=False, 64 | **kwargs)[0] 65 | 66 | # HDF5 supports chunking and compression. You may want to experiment 67 | # with different chunk sizes to see how it runs on your machines. 68 | # Chunk Size/compression Read speed @ 256x256 Read speed @ 128x128 Filesize @ 128x128 Time to write @128x128 69 | # 1 / None 20/s 70 | # 500 / None ramps up to 77/s 102/s 61GB 23min 71 | # 500 / LZF 8/s 56GB 23min 72 | # 1000 / None 78/s 73 | # 5000 / None 81/s 74 | # auto:(125,1,16,32) / None 11/s 61GB 75 | 76 | print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...'
% (config['dataset'], config['chunk_size'], config['compression'])) 77 | # Loop over train loader 78 | for i,(x,y) in enumerate(tqdm(train_loader)): 79 | # Stick X into the range [0, 255] since it's coming from the train loader 80 | x = (255 * ((x + 1) / 2.0)).byte().numpy() 81 | # Numpyify y 82 | y = y.numpy() 83 | # If we're on the first batch, prepare the hdf5 84 | if i==0: 85 | with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'w') as f: 86 | print('Producing dataset of len %d' % len(train_loader.dataset)) 87 | imgs_dset = f.create_dataset('imgs', x.shape,dtype='uint8', maxshape=(len(train_loader.dataset), 3, config['image_size'], config['image_size']), 88 | chunks=(config['chunk_size'], 3, config['image_size'], config['image_size']), compression=config['compression']) 89 | print('Image chunks chosen as ' + str(imgs_dset.chunks)) 90 | imgs_dset[...] = x 91 | labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(train_loader.dataset),), chunks=(config['chunk_size'],), compression=config['compression']) 92 | print('Label chunks chosen as ' + str(labels_dset.chunks)) 93 | labels_dset[...] 
= y 94 | # Else append to the hdf5 95 | else: 96 | with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'a') as f: 97 | f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0) 98 | f['imgs'][-x.shape[0]:] = x 99 | f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0) 100 | f['labels'][-y.shape[0]:] = y 101 | 102 | 103 | def main(): 104 | # parse command line and run 105 | parser = prepare_parser() 106 | config = vars(parser.parse_args()) 107 | print(config) 108 | run(config) 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/merge_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | import torchvision 5 | 6 | from scipy.misc import imread 7 | import argparse 8 | parser = argparse.ArgumentParser() 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--experiment_name', 12 | type=str, 13 | default='MineGAN_I2I_conditional_add_align_layers') 14 | parser.add_argument('--guid_path', 15 | type=str, 16 | default='/DATA/data/Imagenet/crop_animal_faces_hdf/MineGAN_I2I_conditional_add_align_layers/test') 17 | parser.add_argument('--save_path', 18 | type=str, 19 | default='/DATA/data/Imagenet/crop_animal_faces_hdf/MineGAN_I2I_conditional_add_align_layers/test_merge_image') 20 | opts = parser.parse_args() 21 | 22 | if not os.path.exists(os.path.join(opts.save_path, opts.experiment_name + '_output')): 23 | os.makedirs(os.path.join(opts.save_path, opts.experiment_name + '_output')) 24 | 25 | cate = os.listdir(os.path.join(opts.guid_path, opts.experiment_name)) 26 | img_dirs = os.listdir(os.path.join(opts.guid_path, opts.experiment_name, cate[0])) 27 | 28 | for img_dir in img_dirs: 29 | for img_index, cate_name in enumerate(cate): 30 | img = imread(os.path.join(opts.guid_path, opts.experiment_name, cate_name, img_dir)) 31 | img = 
((torch.from_numpy(img).float() / 255) - 0.5) * 2 32 | img = img.unsqueeze(0); img = img.permute(0, 3, 1, 2) 33 | if img_index==0: 34 | img_input = imread(os.path.join(opts.guid_path, opts.experiment_name + '_input_GT', img_dir)) 35 | img_input = ((torch.from_numpy(img_input).float() / 255) - 0.5) * 2 36 | img_input = img_input.unsqueeze(0); img_input = img_input.permute(0, 3, 1, 2) 37 | fixed_t_x = img_input.detach().clone() 38 | fixed_t_x = torch.cat((img_input, fixed_t_x.detach().clone()), 0) 39 | img_dirs = [img_dir] 40 | else: 41 | fixed_t_x = torch.cat((fixed_t_x, img.detach().clone()), 0) 42 | img_dirs.append(img_dir) 43 | #if not os.path.exists(os.path.join(opts.save_path, img_dir)): 44 | # os.makedirs(os.path.join(opts.save_path, img_dir)) 45 | torchvision.utils.save_image(fixed_t_x, os.path.join(opts.save_path, opts.experiment_name + '_output', img_dir), 46 | nrow=int(fixed_t_x.shape[0] **0.5), normalize=True) 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sample.py: -------------------------------------------------------------------------------- 1 | ''' Sample 2 | This script loads a pretrained net and a weightsfile and sample ''' 3 | import functools 4 | import math 5 | import numpy as np 6 | from tqdm import tqdm, trange 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import init 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from torch.nn import Parameter as P 15 | import torchvision 16 | 17 | # Import my stuff 18 | import inception_utils 19 | import utils 20 | import losses 21 | import pdb 22 | 23 | 24 | 25 | def run(config): 26 | # Prepare state dict, which holds things like epoch # and itr # 27 | state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 28 | 'best_IS': 0, 'best_FID': 999999, 'config': config} 29 | 30 | # Optionally, get the configuration from the state dict. 
This allows for 31 | # recovery of the config provided only a state dict and experiment name, 32 | # and can be convenient for writing less verbose sample shell scripts. 33 | if config['config_from_name']: 34 | utils.load_weights(None, None, state_dict, config['weights_root'], 35 | config['experiment_name'], config['load_weights'], None, 36 | strict=False, load_optim=False) 37 | # Ignore items which we might want to overwrite from the command line 38 | for item in state_dict['config']: 39 | if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']: 40 | config[item] = state_dict['config'][item] 41 | 42 | # update config (see train.py for explanation) 43 | config['resolution'] = utils.imsize_dict[config['dataset']] 44 | config['n_classes'] = utils.nclass_dict[config['dataset']] 45 | config['G_activation'] = utils.activation_dict[config['G_nl']] 46 | config['D_activation'] = utils.activation_dict[config['D_nl']] 47 | config = utils.update_config_roots(config) 48 | config['skip_init'] = True 49 | config['no_optim'] = True 50 | device = 'cuda' 51 | 52 | # Seed RNG 53 | utils.seed_rng(config['seed']) 54 | 55 | # Setup cudnn.benchmark for free speed 56 | torch.backends.cudnn.benchmark = True 57 | 58 | # Import the model--this line allows us to dynamically select different files. 
59 | model = __import__(config['model']) 60 | experiment_name = (config['experiment_name'] if config['experiment_name'] 61 | else utils.name_from_config(config)) 62 | print('Experiment name is %s' % experiment_name) 63 | 64 | G = model.Generator(**config).cuda() 65 | utils.count_parameters(G) 66 | 67 | # Load weights 68 | print('Loading weights...') 69 | # Here is where we deal with the ema--load ema weights or load normal weights 70 | utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, 71 | config['weights_root'], experiment_name, config['load_weights'], 72 | G if config['ema'] and config['use_ema'] else None, 73 | strict=False, load_optim=False) 74 | # Update batch size setting used for G 75 | G_batch_size = max(config['G_batch_size'], config['batch_size']) 76 | z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], 77 | device=device, fp16=config['G_fp16'], 78 | z_var=config['z_var']) 79 | 80 | if config['G_eval_mode']: 81 | print('Putting G in eval mode..') 82 | G.eval() 83 | else: 84 | print('G is in %s mode...' % ('training' if G.training else 'eval')) 85 | 86 | #Sample function 87 | sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) 88 | if config['accumulate_stats']: 89 | print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations']) 90 | utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], 91 | config['num_standing_accumulations']) 92 | 93 | 94 | # Sample a number of images and save them to an NPZ, for use with TF-Inception 95 | if config['sample_npz']: 96 | # Lists to hold images and labels for images 97 | x, y = [], [] 98 | print('Sampling %d images and saving them to npz...' 
% config['sample_num_npz']) 99 | for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))): 100 | with torch.no_grad(): 101 | images, labels = sample() 102 | x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)] 103 | y += [labels.cpu().numpy()] 104 | x = np.concatenate(x, 0)[:config['sample_num_npz']] 105 | y = np.concatenate(y, 0)[:config['sample_num_npz']] 106 | print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) 107 | npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name) 108 | 109 | print('Saving npz to %s...' % npz_filename) 110 | np.savez(npz_filename, **{'x' : x, 'y' : y}) 111 | 112 | # Prepare sample sheets 113 | if config['sample_sheets']: 114 | print('Preparing conditional sample sheets...') 115 | utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], 116 | num_classes=config['n_classes'], 117 | samples_per_class=10, parallel=config['parallel'], 118 | samples_root=config['samples_root'], 119 | experiment_name=experiment_name, 120 | folder_number=config['sample_sheet_folder_num'], 121 | z_=z_,) 122 | # Sample interp sheets 123 | if config['sample_interps']: 124 | print('Preparing interp sheets...') 125 | for fix_z, fix_y in zip([False, False, True], [False, True, False]): 126 | utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8, 127 | num_classes=config['n_classes'], 128 | parallel=config['parallel'], 129 | samples_root=config['samples_root'], 130 | experiment_name=experiment_name, 131 | folder_number=config['sample_sheet_folder_num'], 132 | sheet_number=0, 133 | fix_z=fix_z, fix_y=fix_y, device='cuda') 134 | # Sample random sheet 135 | if config['sample_random']: 136 | print('Preparing random sample sheet...') 137 | images, labels = sample() 138 | torchvision.utils.save_image(images.float(), 139 | '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name), 140 | nrow=int(G_batch_size**0.5), 141 | normalize=True) 142 | 143 | # Get
Inception Score and FID 144 | get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid']) 145 | # Prepare a simple function get metrics that we use for trunc curves 146 | def get_metrics(): 147 | sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) 148 | IS_mean, IS_std, FID = get_inception_metrics(sample, config['num_inception_images'], num_splits=10, prints=False) 149 | # Prepare output string 150 | outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema') 151 | outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else 'training') 152 | outstring += 'with noise variance %3.3f, ' % z_.var 153 | outstring += 'over %d images, ' % config['num_inception_images'] 154 | if config['accumulate_stats'] or not config['G_eval_mode']: 155 | outstring += 'with batch size %d, ' % G_batch_size 156 | if config['accumulate_stats']: 157 | outstring += 'using %d standing stat accumulations, ' % config['num_standing_accumulations'] 158 | outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID) 159 | print(outstring) 160 | if config['sample_inception_metrics']: 161 | print('Calculating Inception metrics...') 162 | get_metrics() 163 | 164 | # Sample truncation curve stuff. This is basically the same as the inception metrics code 165 | if config['sample_trunc_curves']: 166 | start, step, end = [float(item) for item in config['sample_trunc_curves'].split('_')] 167 | print('Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' 
% (start, step, end)) 168 | for var in np.arange(start, end + step, step): 169 | z_.var = var 170 | # Optionally comment this out if you want to run with standing stats 171 | # accumulated at one z variance setting 172 | if config['accumulate_stats']: 173 | utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], 174 | config['num_standing_accumulations']) 175 | get_metrics() 176 | def main(): 177 | # parse command line and run 178 | parser = utils.prepare_parser() 179 | parser = utils.add_sample_parser(parser) 180 | config = vars(parser.parse_args()) 181 | print(config) 182 | run(config) 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/scripts/.DeepI2I.sh.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/scripts/.DeepI2I.sh.swp -------------------------------------------------------------------------------- /DeepI2I_BigGAN/scripts/DeepI2I.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 1 --batch_size 4 --resume --resume_BigGAN \ 4 | --num_G_accumulations 8 --num_D_accumulations 8 \ 5 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 6 | --G_attn 64 --D_attn 64 \ 7 | --G_nl inplace_relu --D_nl inplace_relu \ 8 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 9 | --G_ortho 0.0 \ 10 | --G_shared \ 11 | --G_init ortho --D_init ortho \ 12 | --hier --dim_z 120 --shared_dim 128 \ 13 | --G_eval_mode \ 14 | --N_target_cate 149 \ 15 | --G_ch 96 --D_ch 96 \ 16 | --test_every 2000000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 17 | --use_multiepoch_sampler \ 18 | --base_root result/animals \ 19 | --experiment_name DeepI2I_animals \ 20 | --data_root 
./data/animals 21 | 22 | ##!/bin/bash 23 | #python train.py \ 24 | #--dataset I128_hdf5 --parallel --shuffle --num_workers 1 --batch_size 4 --training_scratch \ 25 | #--num_G_accumulations 8 --num_D_accumulations 8 \ 26 | #--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 27 | #--G_attn 64 --D_attn 64 \ 28 | #--G_nl inplace_relu --D_nl inplace_relu \ 29 | #--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 30 | #--G_ortho 0.0 \ 31 | #--G_shared \ 32 | #--G_init ortho --D_init ortho \ 33 | #--hier --dim_z 120 --shared_dim 128 \ 34 | #--G_eval_mode \ 35 | #--N_target_cate 555 \ 36 | #--G_ch 96 --D_ch 96 \ 37 | #--test_every 2000000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 38 | #--use_multiepoch_sampler \ 39 | #--base_root result/NABirds \ 40 | #--experiment_name DeepI2I_NABirds \ 41 | #--data_root ./data/NABirds 42 | 43 | 44 | 45 | 46 | ##!/bin/bash 47 | #python train.py \ 48 | #--dataset I128_hdf5 --parallel --shuffle --num_workers 1 --batch_size 4 --training_scratch \ 49 | #--num_G_accumulations 8 --num_D_accumulations 8 \ 50 | #--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 51 | #--G_attn 64 --D_attn 64 \ 52 | #--G_nl inplace_relu --D_nl inplace_relu \ 53 | #--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 54 | #--G_ortho 0.0 \ 55 | #--G_shared \ 56 | #--G_init ortho --D_init ortho \ 57 | #--hier --dim_z 120 --shared_dim 128 \ 58 | #--G_eval_mode \ 59 | #--N_target_cate 256 \ 60 | #--G_ch 96 --D_ch 96 \ 61 | #--test_every 2000000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 62 | #--use_multiepoch_sampler \ 63 | #--base_root result/foods \ 64 | #--experiment_name DeepI2I_UECFOOD256 \ 65 | #--data_root ./data/foods 66 | 67 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/scripts/DeepI2I_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for model_ite in $(seq 0 10000 1) 3 | do 
4 | python test.py \ 5 | --dataset I128_hdf5 --parallel --shuffle --num_workers 1 --batch_size 32 --resume \ 6 | --num_G_accumulations 8 --num_D_accumulations 8 \ 7 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 8 | --G_attn 64 --D_attn 64 \ 9 | --G_nl inplace_relu --D_nl inplace_relu \ 10 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 11 | --G_ortho 0.0 \ 12 | --G_shared \ 13 | --G_init ortho --D_init ortho \ 14 | --hier --dim_z 120 --shared_dim 128 \ 15 | --G_eval_mode \ 16 | --model_ite $model_ite \ 17 | --N_target_cate 149 \ 18 | --G_ch 96 --D_ch 96 \ 19 | --test_every 2000000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 2020 \ 20 | --use_multiepoch_sampler \ 21 | --base_root result/animals \ 22 | --experiment_name DeepI2I_animals \ 23 | --data_root ./data/animals 24 | 25 | experiment_name=DeepI2I_animals 26 | python merge_image.py --experiment_name $experiment_name \ 27 | --guid_path result/animals/test/$model_ite \ 28 | --save_path result/animals/test/$model_ite 29 | 30 | done 31 | 32 | 33 | ##!/bin/bash 34 | #for model_ite in $(seq 30000 1000 30001) 35 | #do 36 | # python test.py \ 37 | # --dataset I128_hdf5 --parallel --shuffle --num_workers 1 --batch_size 64 --resume \ 38 | # --num_G_accumulations 8 --num_D_accumulations 8 \ 39 | # --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 40 | # --G_attn 64 --D_attn 64 \ 41 | # --G_nl inplace_relu --D_nl inplace_relu \ 42 | # --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 43 | # --G_ortho 0.0 \ 44 | # --G_shared \ 45 | # --G_init ortho --D_init ortho \ 46 | # --hier --dim_z 120 --shared_dim 128 \ 47 | # --G_eval_mode \ 48 | # --model_ite $model_ite \ 49 | # --N_target_cate 256 \ 50 | # --G_ch 96 --D_ch 96 \ 51 | # --test_every 2000000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 2020 \ 52 | # --use_multiepoch_sampler \ 53 | # --base_root result/foods \ 54 | # --experiment_name DeepI2I_UECFOOD256 \ 55 | # --data_root ./data/foods 56 | 
# 57 | # experiment_name=DeepI2I_UECFOOD256 58 | # python merge_image.py --experiment_name $experiment_name \ 59 | # --guid_path result/foods/test/$model_ite \ 60 | # --save_path result/foods/test/$model_ite 61 | # 62 | #done 63 | 64 | 65 | 66 | 67 | ##!/bin/bash 68 | #for model_ite in $(seq 30000 1000 30001) 69 | #do 70 | # python test.py \ 71 | # --dataset I128_hdf5 --parallel --shuffle --num_workers 1 --batch_size 64 --resume \ 72 | # --num_G_accumulations 8 --num_D_accumulations 8 \ 73 | # --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 74 | # --G_attn 64 --D_attn 64 \ 75 | # --G_nl inplace_relu --D_nl inplace_relu \ 76 | # --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 77 | # --G_ortho 0.0 \ 78 | # --G_shared \ 79 | # --G_init ortho --D_init ortho \ 80 | # --hier --dim_z 120 --shared_dim 128 \ 81 | # --G_eval_mode \ 82 | # --model_ite $model_ite \ 83 | # --N_target_cate 555 \ 84 | # --G_ch 96 --D_ch 96 \ 85 | # --test_every 2000000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 2020 \ 86 | # --use_multiepoch_sampler \ 87 | # --base_root result/NABirds \ 88 | # --experiment_name DeepI2I_NABirds \ 89 | # --data_root ./data/NABirds 90 | # 91 | # experiment_name=DeepI2I_NABirds 92 | # python merge_image.py --experiment_name $experiment_name \ 93 | # --guid_path result/NABirds/test/$model_ite \ 94 | # --save_path result/NABirds/test/$model_ite 95 | # 96 | #done 97 | # 98 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/scripts/launch_BigGAN_bs256x8.sh~: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 4 --batch_size 32 --resume \ 4 | --num_G_accumulations 8 --num_D_accumulations 8 \ 5 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 6 | --G_attn 64 --D_attn 64 \ 7 | --G_nl inplace_relu --D_nl inplace_relu \ 8 | --SN_eps 1e-6 
--BN_eps 1e-5 --adam_eps 1e-6 \ 9 | --G_ortho 0.0 \ 10 | --G_shared \ 11 | --G_init ortho --D_init ortho \ 12 | --hier --dim_z 120 --shared_dim 128 \ 13 | --G_eval_mode \ 14 | --G_ch 96 --D_ch 96 \ 15 | --test_every 2000000 --save_every 100 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 16 | --use_multiepoch_sampler \ 17 | --base_root /data2/users/yaxing/MineGAN_I2I/data/cat2dog \ 18 | --target_domain /data2/users/yaxing/MineGAN_I2I/data/cat2dog/dog \ 19 | --ema --use_ema --ema_start 20000 \ 20 | --experiment_name pretrained_model 21 | ##--E1_fea_w {4:0.01, 8:0.1, 16:1, 32:1} this one should be updated in utils.py 22 | ##--D_fea_w {4:20, 8:10, 16:1, 32:0.5} this one should be updated in utils.py; the weights follow the CRN paper 23 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/scripts/utils/duplicate.sh: -------------------------------------------------------------------------------- 1 | #duplicate.sh 2 | source=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0 3 | target=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0A 4 | logs_root=logs 5 | weights_root=weights 6 | echo "copying ${source} to ${target}" 7 | cp -r ${logs_root}/${source} ${logs_root}/${target} 8 | cp ${logs_root}/${source}_log.jsonl ${logs_root}/${target}_log.jsonl 9 | cp ${weights_root}/${source}_G.pth ${weights_root}/${target}_G.pth 10 | cp ${weights_root}/${source}_G_ema.pth ${weights_root}/${target}_G_ema.pth 11 | cp ${weights_root}/${source}_D.pth ${weights_root}/${target}_D.pth 12 | cp ${weights_root}/${source}_G_optim.pth ${weights_root}/${target}_G_optim.pth 13 | cp ${weights_root}/${source}_D_optim.pth ${weights_root}/${target}_D_optim.pth 14 | cp ${weights_root}/${source}_state_dict.pth ${weights_root}/${target}_state_dict.pth
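The following is a minimal pure-Python sketch of the arithmetic in `losses.py`, included for reference. It covers two pieces:

* the hinge losses;
* the optional feature-alignment term that `loss_hinge_gen` adds when `M_regu` is passed. This term is a mean-squared error for each feature resolution, weighted by `D_fea_w` (default `{4:10, 8:5, 16:2, 32:0.1}`).

The sketch makes some assumptions:

* `M_regu[-1]` maps each resolution key to a pair of feature tensors, as the code's indexing suggests.
* Plain lists of floats stand in for tensors.
* The helper names `hinge_dis`, `hinge_gen`, `mse` and `feat_pairs` are hypothetical and not part of the repository.

```python
from typing import Dict, Optional, Sequence, Tuple

# Per-resolution weights, copied from the D_fea_w default argument in losses.py.
DEFAULT_D_FEA_W = {4: 10, 8: 5, 16: 2, 32: 0.1}


def _mean(xs: Sequence[float]) -> float:
    return sum(xs) / len(xs)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error, matching F.mse_loss with its default 'mean' reduction."""
    assert len(a) == len(b), "feature vectors must have the same length"
    return _mean([(x - y) ** 2 for x, y in zip(a, b)])


def hinge_dis(dis_fake: Sequence[float], dis_real: Sequence[float]) -> Tuple[float, float]:
    """Discriminator hinge loss; returns (loss_real, loss_fake) like loss_hinge_dis."""
    loss_real = _mean([max(0.0, 1.0 - d) for d in dis_real])
    loss_fake = _mean([max(0.0, 1.0 + d) for d in dis_fake])
    return loss_real, loss_fake


def hinge_gen(dis_fake: Sequence[float],
              feat_pairs: Optional[Dict[int, Tuple[Sequence[float], Sequence[float]]]] = None,
              d_fea_w: Optional[Dict[int, float]] = None) -> float:
    """Generator hinge loss plus an optional weighted feature-alignment term.

    feat_pairs (hypothetical) stands in for M_regu[-1]: it maps a feature
    resolution (4, 8, 16, 32) to the two feature vectors being aligned.
    """
    if d_fea_w is None:
        d_fea_w = DEFAULT_D_FEA_W
    loss = -_mean(dis_fake)
    if feat_pairs is not None:
        for res, (a, b) in feat_pairs.items():
            loss += d_fea_w[res] * mse(a, b)
    return loss
```

For example, `hinge_gen([1.0, 3.0], {4: ([1.0, 2.0], [1.0, 0.0]))` gives `-2.0 + 10 * 2.0 = 18.0`. The 4x4 features carry the largest default weight, so a mismatch at coarse resolutions dominates the regularizer.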
-------------------------------------------------------------------------------- /DeepI2I_BigGAN/scripts/utils/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #python make_hdf5.py --dataset I128 --batch_size 256 --data_root /DATA/data/catdog_drit/cat 3 | python make_hdf5.py --dataset I128 --batch_size 256 --data_root /DATA/data/catdog_drit/dog 4 | #python calculate_inception_moments.py --dataset I128_hdf5 --data_root /media/yaxing/Elements/IIAI_raid/Imagenet/single_cate 5 | 6 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/sync_batchnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/sync_batchnorm/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/DeepI2I_BigGAN/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimension""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | # _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size']) 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input, gain=None, bias=None): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | out = F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | if gain is not None: 55 | out = out + gain 56 | if bias is not None: 57 | out = out + bias 58 | return out 59 | 60 | # Resize the input to (B, C, -1). 61 | input_shape = input.size() 62 | # print(input_shape) 63 | input = input.view(input.size(0), input.size(1), -1) 64 | 65 | # Compute the sum and square-sum.
66 | sum_size = input.size(0) * input.size(2) 67 | input_sum = _sum_ft(input) 68 | input_ssum = _sum_ft(input ** 2) 69 | # Reduce-and-broadcast the statistics. 70 | # print('it begins') 71 | if self._parallel_id == 0: 72 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 73 | else: 74 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 75 | # if self._parallel_id == 0: 76 | # # print('here') 77 | # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 78 | # else: 79 | # # print('there') 80 | # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 81 | 82 | # print('how2') 83 | # num = sum_size 84 | # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu()))) 85 | # Fix the graph 86 | # sum = (sum.detach() - input_sum.detach()) + input_sum 87 | # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum 88 | 89 | # mean = sum / num 90 | # var = ssum / num - mean ** 2 91 | # # var = (ssum - mean * sum) / num 92 | # inv_std = torch.rsqrt(var + self.eps) 93 | 94 | # Compute the output. 95 | if gain is not None: 96 | # print('gaining') 97 | # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1) 98 | # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1) 99 | # output = input * scale - shift 100 | output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1) 101 | elif self.affine: 102 | # MJY:: Fuse the multiplication for speed. 103 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 104 | else: 105 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 106 | 107 | # Reshape it. 
108 | return output.view(input_shape) 109 | 110 | def __data_parallel_replicate__(self, ctx, copy_id): 111 | self._is_parallel = True 112 | self._parallel_id = copy_id 113 | 114 | # parallel_id == 0 means the master device. 115 | if self._parallel_id == 0: 116 | ctx.sync_master = self._sync_master 117 | else: 118 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 119 | 120 | def _data_parallel_master(self, intermediates): 121 | """Reduce the sum and square-sum, compute the statistics, and broadcast them.""" 122 | 123 | # Always using the same "device order" makes the ReduceAdd operation faster. 124 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 125 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 126 | 127 | to_reduce = [i[1][:2] for i in intermediates] 128 | to_reduce = [j for i in to_reduce for j in i] # flatten 129 | target_gpus = [i[1].sum.get_device() for i in intermediates] 130 | 131 | sum_size = sum([i[1].sum_size for i in intermediates]) 132 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 133 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 134 | 135 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 136 | # print('a') 137 | # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size) 138 | # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device)) 139 | # print('b') 140 | outputs = [] 141 | for i, rec in enumerate(intermediates): 142 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 143 | # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3]))) 144 | 145 | return outputs 146 | 147 | def _compute_mean_std(self, sum_, ssum, size): 148 | """Compute the mean and standard deviation from the sum and square-sum. This method 149 | also maintains the moving average on the master device.""" 150 | assert size > 1, 'BatchNorm computes the unbiased standard deviation, which requires size > 1.'
151 | mean = sum_ / size 152 | sumvar = ssum - sum_ * mean 153 | unbias_var = sumvar / (size - 1) 154 | bias_var = sumvar / size 155 | 156 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 157 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 158 | return mean, torch.rsqrt(bias_var + self.eps) 159 | # return mean, bias_var.clamp(self.eps) ** -0.5 160 | 161 | 162 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 163 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 164 | mini-batch. 165 | 166 | .. math:: 167 | 168 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 169 | 170 | This module differs from the built-in PyTorch BatchNorm1d in that the mean and 171 | standard deviation are reduced across all devices during training. 172 | 173 | For example, when one uses `nn.DataParallel` to wrap the network during 174 | training, PyTorch's implementation normalizes the tensor on each device using 175 | the statistics of that device only, which accelerates the computation and 176 | is also easy to implement, but the statistics might be inaccurate. 177 | Instead, in this synchronized version, the statistics will be computed 178 | over all training samples distributed on multiple devices. 179 | 180 | Note that, in the single-GPU or CPU-only case, this module behaves exactly the same 181 | as the built-in PyTorch implementation. 182 | 183 | The mean and standard deviation are calculated per-dimension over 184 | the mini-batches and gamma and beta are learnable parameter vectors 185 | of size C (where C is the input size). 186 | 187 | During training, this layer keeps a running estimate of its computed mean 188 | and variance. The running estimates are kept with a default momentum of 0.1. 189 | 190 | During evaluation, this running mean/variance is used for normalization.
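The running estimate described above is an exponential moving average. `_compute_mean_std` updates it as `running = (1 - momentum) * running + momentum * batch_stat`, and uses the unbiased variance for `running_var`. Below is a minimal plain-Python sketch of that update rule. The batch statistics are made up, and the initial values assume PyTorch's default buffers (zeros for the mean, ones for the variance).

```python
def update_running(running, batch_stat, momentum=0.1):
    # Exponential moving average, as in _compute_mean_std:
    # running = (1 - momentum) * running + momentum * batch_stat
    return (1 - momentum) * running + momentum * batch_stat


running_mean, running_var = 0.0, 1.0   # initial buffer values (zeros / ones)
# Hypothetical (mean, unbiased variance) pairs from three training batches.
for batch_mean, batch_unbias_var in [(2.0, 4.0)] * 3:
    running_mean = update_running(running_mean, batch_mean)
    running_var = update_running(running_var, batch_unbias_var)

print(round(running_mean, 3), round(running_var, 3))  # 0.542 1.813
```

With momentum 0.1, the estimate approaches a constant batch statistic geometrically: after three steps the running mean has covered `1 - 0.9**3 = 0.271` of the gap from 0 to 2.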
191 | 192 | Because the BatchNorm is done over the `C` dimension, computing statistics 193 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm. 194 | 195 | Args: 196 | num_features: num_features from an expected input of size 197 | `batch_size x num_features [x width]` 198 | eps: a value added to the denominator for numerical stability. 199 | Default: 1e-5 200 | momentum: the value used for the running_mean and running_var 201 | computation. Default: 0.1 202 | affine: a boolean value that when set to ``True``, gives the layer learnable 203 | affine parameters. Default: ``True`` 204 | 205 | Shape: 206 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 207 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 208 | 209 | Examples: 210 | >>> # With Learnable Parameters 211 | >>> m = SynchronizedBatchNorm1d(100) 212 | >>> # Without Learnable Parameters 213 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 214 | >>> input = torch.randn(20, 100) 215 | >>> output = m(input) 216 | """ 217 | 218 | def _check_input_dim(self, input): 219 | if input.dim() != 2 and input.dim() != 3: 220 | raise ValueError('expected 2D or 3D input (got {}D input)' 221 | .format(input.dim())) 222 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 223 | 224 | 225 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 226 | r"""Applies Synchronized Batch Normalization over a 4d input that is seen as a mini-batch 227 | of 3d inputs. 228 | 229 | .. math:: 230 | 231 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 232 | 233 | This module differs from the built-in PyTorch BatchNorm2d in that the mean and 234 | standard deviation are reduced across all devices during training.
235 | 236 | For example, when one uses `nn.DataParallel` to wrap the network during 237 | training, PyTorch's implementation normalizes the tensor on each device using 238 | the statistics of that device only, which accelerates the computation and 239 | is also easy to implement, but the statistics might be inaccurate. 240 | Instead, in this synchronized version, the statistics will be computed 241 | over all training samples distributed on multiple devices. 242 | 243 | Note that, in the single-GPU or CPU-only case, this module behaves exactly the same 244 | as the built-in PyTorch implementation. 245 | 246 | The mean and standard deviation are calculated per-dimension over 247 | the mini-batches and gamma and beta are learnable parameter vectors 248 | of size C (where C is the input size). 249 | 250 | During training, this layer keeps a running estimate of its computed mean 251 | and variance. The running estimates are kept with a default momentum of 0.1. 252 | 253 | During evaluation, this running mean/variance is used for normalization. 254 | 255 | Because the BatchNorm is done over the `C` dimension, computing statistics 256 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm. 257 | 258 | Args: 259 | num_features: num_features from an expected input of 260 | size batch_size x num_features x height x width 261 | eps: a value added to the denominator for numerical stability. 262 | Default: 1e-5 263 | momentum: the value used for the running_mean and running_var 264 | computation. Default: 0.1 265 | affine: a boolean value that when set to ``True``, gives the layer learnable 266 | affine parameters.
Default: ``True`` 267 | 268 | Shape: 269 | - Input: :math:`(N, C, H, W)` 270 | - Output: :math:`(N, C, H, W)` (same shape as input) 271 | 272 | Examples: 273 | >>> # With Learnable Parameters 274 | >>> m = SynchronizedBatchNorm2d(100) 275 | >>> # Without Learnable Parameters 276 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 277 | >>> input = torch.randn(20, 100, 35, 45) 278 | >>> output = m(input) 279 | """ 280 | 281 | def _check_input_dim(self, input): 282 | if input.dim() != 4: 283 | raise ValueError('expected 4D input (got {}D input)' 284 | .format(input.dim())) 285 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 286 | 287 | 288 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 289 | r"""Applies Synchronized Batch Normalization over a 5d input that is seen as a mini-batch 290 | of 4d inputs. 291 | 292 | .. math:: 293 | 294 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 295 | 296 | This module differs from the built-in PyTorch BatchNorm3d in that the mean and 297 | standard deviation are reduced across all devices during training. 298 | 299 | For example, when one uses `nn.DataParallel` to wrap the network during 300 | training, PyTorch's implementation normalizes the tensor on each device using 301 | the statistics of that device only, which accelerates the computation and 302 | is also easy to implement, but the statistics might be inaccurate. 303 | Instead, in this synchronized version, the statistics will be computed 304 | over all training samples distributed on multiple devices. 305 | 306 | Note that, in the single-GPU or CPU-only case, this module behaves exactly the same 307 | as the built-in PyTorch implementation. 308 | 309 | The mean and standard deviation are calculated per-dimension over 310 | the mini-batches and gamma and beta are learnable parameter vectors 311 | of size C (where C is the input size).
312 | 313 | During training, this layer keeps a running estimate of its computed mean 314 | and variance. The running estimates are kept with a default momentum of 0.1. 315 | 316 | During evaluation, this running mean/variance is used for normalization. 317 | 318 | Because the BatchNorm is done over the `C` dimension, computing statistics 319 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 320 | or Spatio-temporal BatchNorm. 321 | 322 | Args: 323 | num_features: num_features from an expected input of 324 | size batch_size x num_features x depth x height x width 325 | eps: a value added to the denominator for numerical stability. 326 | Default: 1e-5 327 | momentum: the value used for the running_mean and running_var 328 | computation. Default: 0.1 329 | affine: a boolean value that when set to ``True``, gives the layer learnable 330 | affine parameters. Default: ``True`` 331 | 332 | Shape: 333 | - Input: :math:`(N, C, D, H, W)` 334 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 335 | 336 | Examples: 337 | >>> # With Learnable Parameters 338 | >>> m = SynchronizedBatchNorm3d(100) 339 | >>> # Without Learnable Parameters 340 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 341 | >>> input = torch.randn(20, 100, 35, 45, 10) 342 | >>> output = m(input) 343 | """ 344 | 345 | def _check_input_dim(self, input): 346 | if input.dim() != 5: 347 | raise ValueError('expected 5D input (got {}D input)' 348 | .format(input.dim())) 349 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 |
self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | while self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object.
58 | 59 | - During the replication, as DataParallel triggers a callback on each module, all slave devices should 60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages are first collected from each device (including the master device), and then 106 | a callback is invoked to compute the message to be sent back to each device 107 | (including the master device).
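The gather, callback and scatter round described above can be sketched with plain threads and queues. This is a simplified stand-in, not the module's code: each slave has its own reply queue instead of a `FutureResult`, the final acknowledgement round (`run_slave`'s `self.queue.put(True)`) is omitted, and a plain sum stands in for the BatchNorm statistics.

```python
import queue
import threading

# Simplified stand-in for SyncMaster.run_master / SlavePipe.run_slave.
# Assumption: plain queues replace FutureResult, and the final
# acknowledgement round (run_slave's `self.queue.put(True)`) is omitted.
shared = queue.Queue()                           # slaves -> master
replies = {1: queue.Queue(), 2: queue.Queue()}   # master -> each slave


def callback(intermediates):
    # Reduce every message (here: a plain sum) and address the result to each copy.
    total = sum(msg for _, msg in intermediates)
    return [(copy_id, total) for copy_id, _ in intermediates]


def run_slave(identifier, msg, out):
    shared.put((identifier, msg))                 # send the local message
    out[identifier] = replies[identifier].get()   # block until the master replies


def run_master(master_msg, nr_slaves):
    intermediates = [(0, master_msg)]             # the master's own message goes first
    for _ in range(nr_slaves):
        intermediates.append(shared.get())        # gather one message per slave
    results = callback(intermediates)
    assert results[0][0] == 0, "the first result belongs to the master"
    for copy_id, res in results[1:]:
        replies[copy_id].put(res)                 # scatter replies to the slaves
    return results[0][1]


slave_out = {}
threads = [threading.Thread(target=run_slave, args=(i, 10.0 * i, slave_out)) for i in (1, 2)]
for t in threads:
    t.start()
master_result = run_master(1.0, nr_slaves=2)
for t in threads:
    t.join()
print(master_result, sorted(slave_out.items()))  # 31.0 [(1, 31.0), (2, 31.0)]
```

The slaves may arrive in either order. The master's own message is always placed first, which is why the real code can assert that `results[0]` belongs to copy 0.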
108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
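The shared-context mechanism can be illustrated without torch. In this hypothetical sketch, lists of fake layers stand in for `module.modules()`, and `FakeSyncBN` is an invented stand-in for a module that defines `__data_parallel_replicate__`. Copy `i` of the layer at position `j` receives `ctxs[j]`. Because copy 0 is visited first, it can publish an object there that the slave copies later read back, much as `_SynchronizedBatchNorm` does with its `SyncMaster`.

```python
# Hypothetical, torch-free sketch of the shared-context idea behind
# execute_replication_callbacks: copy i of the layer at position j receives ctxs[j].
class CallbackContext(object):
    pass


class FakeSyncBN(object):
    # Stand-in for a module that defines __data_parallel_replicate__.
    def __init__(self):
        self.copy_id = None
        self.master = None

    def __data_parallel_replicate__(self, ctx, copy_id):
        self.copy_id = copy_id
        if copy_id == 0:
            ctx.master = self          # the master copy runs first and publishes itself
        else:
            self.master = ctx.master   # slave copies read it from the shared context


# Three "devices", each holding a replica with two such layers.
replicas = [[FakeSyncBN(), FakeSyncBN()] for _ in range(3)]
ctxs = [CallbackContext() for _ in range(2)]     # one context per layer position
for i, modules in enumerate(replicas):           # copy 0 (the master) is visited first
    for j, m in enumerate(modules):
        m.__data_parallel_replicate__(ctxs[j], i)

print([[m.copy_id for m in mods] for mods in replicas])  # [[0, 0], [1, 1], [2, 2]]
```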
30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback 38 | of any slave copy. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is created by the 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation.
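The monkey-patching pattern used here replaces `replicate` on one instance with a `functools.wraps` wrapper that calls the original method and then runs the callbacks. Below is a torch-free sketch of that pattern. `FakeDataParallel` and `execute_callbacks` are hypothetical stand-ins for `DataParallel` and `execute_replication_callbacks`, not real library objects.

```python
import functools


class FakeDataParallel(object):
    # Hypothetical stand-in for torch.nn.DataParallel; only `replicate` matters here.
    def replicate(self, module, device_ids):
        return ["%s@cuda:%d" % (module, d) for d in device_ids]


calls = []


def execute_callbacks(modules):
    # Stand-in for execute_replication_callbacks: record (copy_id, replica).
    calls.extend(enumerate(modules))


def patch(dp):
    old_replicate = dp.replicate              # bound method of this one instance

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_callbacks(modules)
        return modules

    dp.replicate = new_replicate              # instance attribute shadows the class method


dp = FakeDataParallel()
patch(dp)
print(dp.replicate("bn", [0, 1]))             # ['bn@cuda:0', 'bn@cuda:1']
print(calls)                                  # [(0, 'bn@cuda:0'), (1, 'bn@cuda:1')]
FakeDataParallel().replicate("bn", [0])       # other instances are not patched
```

Only the patched instance is affected. Subclassing, as `DataParallelWithCallback` does, is the alternative when you control how the wrapper is constructed.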
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/test.py: -------------------------------------------------------------------------------- 1 | """ BigGAN: The Authorized Unofficial PyTorch release 2 | Code by A. Brock and A. 
Andonian 3 | This code is an unofficial reimplementation of 4 | "Large-Scale GAN Training for High Fidelity Natural Image Synthesis," 5 | by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096). 6 | 7 | Let's go. 8 | """ 9 | 10 | import os 11 | import functools 12 | import math 13 | import numpy as np 14 | from tqdm import tqdm, trange 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import init 20 | import torch.optim as optim 21 | import torch.nn.functional as F 22 | from torch.nn import Parameter as P 23 | 24 | # Import my stuff 25 | import inception_utils 26 | import utils 27 | import losses 28 | import train_fns 29 | from sync_batchnorm import patch_replication_callback 30 | import pickle 31 | import pdb 32 | 33 | # The main test file. Config is a dictionary specifying the configuration 34 | # of this test run. 35 | def run(config): 36 | 37 | # Update the config dict as necessary 38 | # This is for convenience, to add settings derived from the user-specified 39 | # configuration into the config-dict (e.g. inferring the number of classes 40 | # and size of the images from the dataset, passing in a pytorch object 41 | # for the activation specified as a string) 42 | config['resolution'] = utils.imsize_dict[config['dataset']] 43 | config['n_classes'] = utils.nclass_dict[config['dataset']] 44 | config['G_activation'] = utils.activation_dict[config['G_nl']] 45 | config['D_activation'] = utils.activation_dict[config['D_nl']] 46 | # By default, skip init if resuming training.
47 | if config['resume']: 48 | print('Skipping initialization for training resumption...') 49 | config['skip_init'] = True 50 | config = utils.update_config_roots(config) 51 | device = 'cuda' 52 | 53 | # changing the parameters when training the model from scratch 54 | if config['training_scratch']: 55 | config['E1_fea_w'] = {4:1, 8:1, 16:1, 32:.1} 56 | config['D_fea_w'] = {4:0.1, 8:0.1, 16:0.1, 32:.01} 57 | 58 | # Seed RNG 59 | utils.seed_rng(config['seed']) 60 | 61 | # Prepare root folders if necessary 62 | utils.prepare_root(config) 63 | 64 | # Setup cudnn.benchmark for free speed 65 | torch.backends.cudnn.benchmark = True 66 | 67 | # Import the model--this line allows us to dynamically select different files. 68 | model = __import__(config['model']) 69 | experiment_name = (config['experiment_name'] if config['experiment_name'] 70 | else utils.name_from_config(config)) 71 | print('Experiment name is %s' % experiment_name) 72 | 73 | # Next, build the model 74 | # Minor 75 | G = model.Generator(**config).to(device) 76 | D = model.Discriminator(**config).to(device) 77 | E1 = model.Encoder(**config).to(device) 78 | A1 = model.Alignment(**config).to(device) 79 | 80 | # If using EMA, prepare it 81 | if config['ema']:# here it is True 82 | print('Preparing EMA for G with decay of {}'.format(config['ema_decay'])) 83 | G_ema = model.Generator(**{**config, 'skip_init':True, 84 | 'no_optim': True}).to(device) 85 | ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start']) 86 | else: 87 | G_ema, ema = None, None 88 | 89 | # FP16? 90 | if config['G_fp16']:# here it is False 91 | print('Casting G to float16...') 92 | G = G.half() 93 | if config['ema']: 94 | G_ema = G_ema.half() 95 | if config['D_fp16']:# here it is False 96 | print('Casting D to fp16...') 97 | D = D.half() 98 | # Consider automatically reducing SN_eps?
99 | GD = model.G_D(G, D, E1, A1) 100 | print(G) 101 | print(D) 102 | print(E1) 103 | print(A1) 104 | print('Number of params in G: {} D: {} E1: {} A1: {}'.format( 105 | *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D,E1, A1]])) 106 | # Prepare state dict, which holds things like epoch # and itr # 107 | state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 108 | 'best_IS': 0, 'best_FID': 999999, 'config': config} 109 | 110 | # If loading from a pre-trained model, load weights 111 | if config['resume']: 112 | print('Loading weights...') 113 | utils.load_weights(G, D, E1, A1, state_dict, 114 | config['weights_root'], experiment_name, 115 | config['load_weights'] if config['load_weights'] else None, 116 | G_ema if config['ema'] else None, model_ite=config['model_ite'])# I added load_optim=False 117 | 118 | # If parallel, parallelize the GD module 119 | if config['parallel']: 120 | GD = nn.DataParallel(GD) # BigGAN uses it 121 | #GD = nn.DistributedDataParallel(GD)# Yaxing updated it 122 | 123 | if config['cross_replica']:# here it is False 124 | patch_replication_callback(GD) 125 | 126 | # Prepare loggers for stats; metrics holds test metrics, 127 | # lmetrics holds any desired training metrics. 128 | train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name) 129 | print('Training Metrics will be saved to {}'.format(train_metrics_fname)) 130 | train_log = utils.MyLogger(train_metrics_fname, 131 | reinitialize=(not config['resume']), 132 | logstyle=config['logstyle']) 133 | # Write metadata 134 | utils.write_metadata(config['logs_root'], experiment_name, config, state_dict) 135 | # Prepare data; the Discriminator's batch size is all that needs to be passed 136 | # to the dataloader, as G doesn't require dataloading.
137 | # Note that at every loader iteration we pass in enough data to complete 138 | # a full D iteration (regardless of number of D steps and accumulations) 139 | D_batch_size = (config['batch_size'] * config['num_D_steps'] 140 | * config['num_D_accumulations']) 141 | loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 142 | 'start_itr': state_dict['itr'], 'target_domain':None}) 143 | 144 | loaders_t = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 145 | 'start_itr': state_dict['itr'], 'target_domain':None}) 146 | 147 | if config['model_ite'] > 0: 148 | config['test_root'] = config['test_root'] + '/'+ str(config['model_ite']) 149 | 150 | with open('class_to_index/%s/I128_imgs.pickle'%experiment_name, 'rb') as handle: 151 | class_to_index = pickle.load(handle) 152 | index_to_class={class_to_index[i]:i for i in class_to_index.keys()} 153 | 154 | # Prepare noise and randomly sampled label arrays 155 | # Allow for different batch sizes in G 156 | G_batch_size = max(config['G_batch_size'], config['batch_size']) 157 | z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], 158 | device=device, fp16=config['G_fp16'], N_target_cate = config['N_target_cate']) 159 | # Prepare a fixed z & y to see individual sample evolution throughout training 160 | fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, 161 | config['n_classes'], device=device, 162 | fp16=config['G_fp16']) 163 | fixed_z.sample_() 164 | fixed_y.sample_() 165 | # Loaders are loaded, prepare the training function 166 | if config['which_train_fn'] == 'GAN': # here it is GAN 167 | train = train_fns.GAN_training_function(G, D, E1, A1, GD, z_, y_, 168 | ema, state_dict, config) 169 | # Else, assume debugging and use the dummy train fn 170 | else: 171 | train = train_fns.dummy_training_function() 172 | 173 | # target data 174 | t_data = iter(loaders_t[0])# target domain 175 | fixed_t_x, fixed_t_y = None, None 176 | 177 | # Train for specified number of
epochs, although we mostly track G iterations. 178 | print('Switching G to eval mode...') 179 | E1.eval() 180 | A1.eval() 181 | G.eval() 182 | if config['ema']: 183 | G_ema.eval() 184 | 185 | 186 | for index_epoch_num in range(0, 1): 187 | t_batch = next(t_data)#target domain 188 | t_x, t_y = t_batch 189 | if len(t_x) != (config['num_D_accumulations'] * config['batch_size']): 190 | t_data = iter(loaders_t[0])# a short batch means the loader has been exhausted, so reset the iterator 191 | t_batch = next(t_data)#target domain 192 | t_x, t_y = t_batch 193 | 194 | fixed_t_x, fixed_t_y = t_x[:G_batch_size].detach().clone(), t_y[:G_batch_size].detach().clone() 195 | fixed_t_x, fixed_t_y = fixed_t_x.to(device), fixed_t_y.to(device) 196 | 197 | train_fns.sample_all_cate(G, D, E1, A1, G_ema, z_, y_, fixed_z, fixed_y, 198 | state_dict, config, experiment_name, fixed_t_x, fixed_t_y, None, config['N_target_cate'], index_to_class=index_to_class,index_epoch_num=index_epoch_num) 199 | 200 | 201 | 202 | def main(): 203 | # parse command line and run 204 | parser = utils.prepare_parser() 205 | config = vars(parser.parse_args()) 206 | print(config) 207 | run(config) 208 | 209 | if __name__ == '__main__': 210 | main() 211 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/train.py: -------------------------------------------------------------------------------- 1 | """ BigGAN: The Authorized Unofficial PyTorch release 2 | Code by A. Brock and A. Andonian 3 | This code is an unofficial reimplementation of 4 | "Large-Scale GAN Training for High Fidelity Natural Image Synthesis," 5 | by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096). 6 | 7 | Let's go.
8 | """ 9 | 10 | import os 11 | import functools 12 | import math 13 | import numpy as np 14 | from tqdm import tqdm, trange 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import init 20 | import torch.optim as optim 21 | import torch.nn.functional as F 22 | from torch.nn import Parameter as P 23 | 24 | # Import my stuff 25 | import inception_utils 26 | import utils 27 | import losses 28 | import train_fns 29 | from sync_batchnorm import patch_replication_callback 30 | import pdb 31 | 32 | # The main training file. Config is a dictionary specifying the configuration 33 | # of this training run. 34 | def run(config): 35 | 36 | # Update the config dict as necessary 37 | # This is for convenience, to add settings derived from the user-specified 38 | # configuration into the config-dict (e.g. inferring the number of classes 39 | # and size of the images from the dataset, passing in a pytorch object 40 | # for the activation specified as a string) 41 | config['resolution'] = utils.imsize_dict[config['dataset']] 42 | config['n_classes'] = utils.nclass_dict[config['dataset']] 43 | config['G_activation'] = utils.activation_dict[config['G_nl']] 44 | config['D_activation'] = utils.activation_dict[config['D_nl']] 45 | # By default, skip init if resuming training. 
46 | if config['resume']: 47 | print('Skipping initialization for training resumption...') 48 | config['skip_init'] = True 49 | config = utils.update_config_roots(config) 50 | device = 'cuda' 51 | 52 | # When training from scratch, override the encoder (E1) and discriminator feature weights 53 | if config['training_scratch']: 54 | config['E1_fea_w'] = {4:1, 8:1, 16:1, 32:.1} 55 | config['D_fea_w'] = {4:0.1, 8:0.1, 16:0.1, 32:.01} 56 | 57 | # Seed RNG 58 | utils.seed_rng(config['seed']) 59 | 60 | # Prepare root folders if necessary 61 | utils.prepare_root(config) 62 | 63 | # Setup cudnn.benchmark for free speed 64 | torch.backends.cudnn.benchmark = True 65 | 66 | # Import the model--this line allows us to dynamically select different files. 67 | model = __import__(config['model']) 68 | experiment_name = (config['experiment_name'] if config['experiment_name'] 69 | else utils.name_from_config(config)) 70 | print('Experiment name is %s' % experiment_name) 71 | 72 | # Next, build the model 73 | # Generator, discriminator, encoder (E1) and adaptor (A1) 74 | G = model.Generator(**config).to(device) 75 | D = model.Discriminator(**config).to(device) 76 | E1 = model.Encoder(**config).to(device) 77 | A1 = model.Alignment(**config).to(device) 78 | 79 | # If using EMA, prepare it 80 | if config['ema']: # here it is True 81 | print('Preparing EMA for G with decay of {}'.format(config['ema_decay'])) 82 | G_ema = model.Generator(**{**config, 'skip_init':True, 83 | 'no_optim': True}).to(device) 84 | ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start']) 85 | else: 86 | G_ema, ema = None, None 87 | 88 | # FP16? 89 | if config['G_fp16']: # here it is False 90 | print('Casting G to float16...') 91 | G = G.half() 92 | if config['ema']: 93 | G_ema = G_ema.half() 94 | if config['D_fp16']: # here it is False 95 | print('Casting D to fp16...') 96 | D = D.half() 97 | # Consider automatically reducing SN_eps?
98 | GD = model.G_D(G, D, E1, A1) 99 | print(G) 100 | print(D) 101 | print(E1) 102 | print(A1) 103 | print('Number of params in G: {} D: {} E1: {} A1: {}'.format( 104 | *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D,E1, A1]])) 105 | # Prepare state dict, which holds things like epoch # and itr # 106 | state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 107 | 'best_IS': 0, 'best_FID': 999999, 'config': config} 108 | 109 | # If loading from a pre-trained model, load weights 110 | if config['resume']: 111 | print('Loading weights...') 112 | utils.load_weights(G, D, E1, A1, state_dict, 113 | config['weights_root'], experiment_name, 114 | config['load_weights'] if config['load_weights'] else None, 115 | G_ema if config['ema'] else None, resume_BigGAN=config['resume_BigGAN']) # I added load_optim=False 116 | 117 | # If parallel, parallelize the GD module 118 | if config['parallel']: 119 | GD = nn.DataParallel(GD) # as in BigGAN 120 | #GD = nn.DistributedDataParallel(GD) # Yaxing's alternative 121 | 122 | if config['cross_replica']: # here it is False 123 | patch_replication_callback(GD) 124 | 125 | # Prepare loggers for stats; metrics holds test metrics, 126 | # lmetrics holds any desired training metrics. 127 | train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name) 128 | print('Training Metrics will be saved to {}'.format(train_metrics_fname)) 129 | train_log = utils.MyLogger(train_metrics_fname, 130 | reinitialize=(not config['resume']), 131 | logstyle=config['logstyle']) 132 | # Write metadata 133 | utils.write_metadata(config['logs_root'], experiment_name, config, state_dict) 134 | # Prepare data; the Discriminator's batch size is all that needs to be passed 135 | # to the dataloader, as G doesn't require dataloading.
136 | # Note that at every loader iteration we pass in enough data to complete 137 | # a full D iteration (regardless of number of D steps and accumulations) 138 | D_batch_size = (config['batch_size'] * config['num_D_steps'] 139 | * config['num_D_accumulations']) 140 | loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 141 | 'start_itr': state_dict['itr'], 'target_domain':None}) 142 | 143 | loaders_t = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 144 | 'start_itr': state_dict['itr'], 'target_domain':None}) 145 | 146 | # Prepare noise and randomly sampled label arrays 147 | # Allow for different batch sizes in G 148 | G_batch_size = max(config['G_batch_size'], config['batch_size']) 149 | z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], 150 | device=device, fp16=config['G_fp16'], N_target_cate = config['N_target_cate']) 151 | # Prepare a fixed z & y to see individual sample evolution throughout training 152 | fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, 153 | config['n_classes'], device=device, 154 | fp16=config['G_fp16']) 155 | fixed_z.sample_() 156 | fixed_y.sample_() 157 | # Loaders are loaded, prepare the training function 158 | if config['which_train_fn'] == 'GAN': # here it is GAN 159 | train = train_fns.GAN_training_function(G, D, E1, A1, GD, z_, y_, 160 | ema, state_dict, config) 161 | # Else, assume debugging and use the dummy train fn 162 | else: 163 | train = train_fns.dummy_training_function() 164 | 165 | # target data 166 | t_data = iter(loaders_t[0]) # target domain 167 | fixed_t_x, fixed_t_y = None, None 168 | 169 | print('Beginning training at epoch %d...' % state_dict['epoch']) 170 | # Train for specified number of epochs, although we mostly track G iterations. 171 | for epoch in range(state_dict['epoch'], config['num_epochs']): 172 | #for epoch in range(state_dict['epoch'], 7): 173 | # Which progressbar to use? TQDM or my own?
174 | if config['pbar'] == 'mine': # here it is 'mine' 175 | pbar = utils.progress(loaders[0],displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta') 176 | else: 177 | pbar = tqdm(loaders[0]) 178 | # pdb.set_trace() 179 | for i, (x, y) in enumerate(pbar): 180 | # Increment the iteration counter 181 | state_dict['itr'] += 1 182 | 183 | t_batch = next(t_data) # target domain 184 | t_x, t_y = t_batch 185 | if len(x) != (config['num_D_accumulations'] * config['batch_size']) : 186 | break 187 | if len(t_x) != (config['num_D_accumulations'] * config['batch_size']): 188 | t_data = iter(loaders_t[0]) # a short batch marks the end of a pass, so restart the iterator 189 | t_batch = next(t_data) # target domain 190 | t_x, t_y = t_batch 191 | if fixed_t_x is None: 192 | fixed_t_x, fixed_t_y = t_x[:G_batch_size].detach().clone(), t_y[:G_batch_size].detach().clone() 193 | fixed_x_v2 = x[:G_batch_size].detach().clone() 194 | fixed_t_x, fixed_t_y = fixed_t_x.to(device), fixed_t_y.to(device) 195 | 196 | # Make sure G and D are in training mode, just in case they got set to eval 197 | # For D, which typically doesn't have BN, this shouldn't matter much 198 | 199 | G.train() 200 | D.train() 201 | E1.train() 202 | A1.train() 203 | if config['ema']: 204 | G_ema.train() 205 | if config['D_fp16']: # here it is False 206 | x, y = x.to(device).half(), y.to(device) 207 | t_x, t_y = t_x.to(device).half(), t_y.to(device) 208 | else: 209 | x, y = x.to(device), y.to(device) 210 | t_x, t_y = t_x.to(device), t_y.to(device) 211 | 212 | if state_dict['itr'] > 138500 or config['training_scratch']: # 138000 is the last iteration of the pre-trained BigGAN: stage 1 keeps G frozen and trains the adaptor for the first 500 iterations, stage 2 also updates G 213 | stage = 2 214 | else: 215 | stage = 1 216 | metrics = train(x, y, t_x, t_y, stage=stage, training_scratch=config['training_scratch']) 217 | 218 | train_log.log(itr=int(state_dict['itr']), **metrics) 219 | 220 | # Every sv_log_interval, log singular values 221 | if
(config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])): 222 | train_log.log(itr=int(state_dict['itr']), 223 | **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')}) 224 | 225 | # If using my progbar, print metrics. 226 | if config['pbar'] == 'mine': 227 | print(', '.join(['itr: %d' % state_dict['itr']] 228 | + ['%s : %+4.3f' % (key, metrics[key]) 229 | for key in metrics]), end=' ') 230 | 231 | # Save weights and copies as configured at specified interval 232 | if not (state_dict['itr'] % config['save_every']): 233 | if config['G_eval_mode']: 234 | print('Switching G to eval mode...') 235 | E1.eval() 236 | A1.eval() 237 | G.eval() 238 | if config['ema']: 239 | G_ema.eval() 240 | train_fns.save_and_sample(G, D, E1, A1, G_ema, z_, y_, fixed_z, fixed_y, 241 | state_dict, config, experiment_name, fixed_t_x, fixed_t_y, fixed_x_v2) 242 | 243 | state_dict['epoch'] += 1 244 | 245 | 246 | def main(): 247 | # parse command line and run 248 | parser = utils.prepare_parser() 249 | config = vars(parser.parse_args()) 250 | print(config) 251 | run(config) 252 | 253 | if __name__ == '__main__': 254 | main() 255 | -------------------------------------------------------------------------------- /DeepI2I_BigGAN/train_fns.py: -------------------------------------------------------------------------------- 1 | ''' train_fns.py 2 | Functions for the main loop of training different conditional image models 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | import os 8 | 9 | import utils 10 | import losses 11 | import pdb 12 | 13 | 14 | # Dummy training function for debugging 15 | def dummy_training_function(): 16 | def train(x, y, *args, **kwargs): # also accept the t_x, t_y, stage and training_scratch arguments passed by train.py 17 | return {} 18 | return train 19 | 20 | 21 | def GAN_training_function(G, D, E1, A1, GD, z_, y_, ema, state_dict, config): 22 | 23 | def train(x, y, t_x, t_y, stage=1, training_scratch=False): 24 | G.optim.zero_grad() 25 | D.optim.zero_grad() 26 | E1.optim.zero_grad() 27 | A1.optim.zero_grad()
28 | x = torch.split(x, config['batch_size']) # split into chunks of batch_size (one per D step and accumulation) 29 | y = torch.split(y, config['batch_size']) 30 | t_x = torch.split(t_x, config['batch_size']) 31 | t_y = torch.split(t_y, config['batch_size']) 32 | D_fea_w = config['D_fea_w'] 33 | # names of G parameter groups (not used further in this function) 34 | G_para_update = ['shared', 'linear', 'bn', 'output_layer', 'blocks.3.1.gamma', 'blocks.3.1.theta', 'blocks.3.1.phi', 'blocks.3.1.g', 'blocks.3.1.o'] 35 | counter = 0 36 | 37 | # Optionally toggle D and G's "requires_grad" 38 | if config['toggle_grads']: # here it is True 39 | utils.toggle_grad(D, True) 40 | utils.toggle_grad(G, False) 41 | utils.toggle_grad(E1, False) 42 | utils.toggle_grad(A1, False) 43 | 44 | for step_index in range(config['num_D_steps']): 45 | 46 | 47 | # If accumulating gradients, loop multiple times before an optimizer step 48 | D.optim.zero_grad() 49 | for accumulation_index in range(config['num_D_accumulations']): 50 | z_.sample_() 51 | y_.sample_() 52 | D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']], 53 | x[counter], y[counter], t_x=t_x[counter], t_y=t_y[counter], train_G=False, 54 | split_D=config['split_D']) 55 | 56 | # Compute components of D's loss, average them, and divide by 57 | # the number of gradient accumulations 58 | D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real) 59 | D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations']) 60 | D_loss.backward() 61 | counter += 1 62 | 63 | # Optionally apply ortho reg in D 64 | if config['D_ortho'] > 0.0: # here it is 0.0 65 | # Debug print to indicate we're using ortho reg in D.
66 | print('using modified ortho reg in D') 67 | utils.ortho(D, config['D_ortho']) 68 | 69 | D.optim.step() 70 | 71 | # Optionally toggle "requires_grad" 72 | if config['toggle_grads']: 73 | utils.toggle_grad(D, False) 74 | T_F = True if stage==2 else False 75 | utils.toggle_grad(G, T_F) # G is only trained in stage 2 76 | utils.toggle_grad(A1, True) 77 | if training_scratch: 78 | utils.toggle_grad(E1, True) 79 | else: 80 | utils.toggle_grad(E1, False) 81 | 82 | 83 | 84 | # Zero G's gradients by default before training G, for safety 85 | A1.optim.zero_grad() 86 | if stage==2: 87 | G.optim.zero_grad() 88 | if training_scratch: 89 | E1.optim.zero_grad() 90 | 91 | # If accumulating gradients, loop multiple times 92 | for accumulation_index in range(config['num_G_accumulations']): # here it is 1 93 | z_.sample_() 94 | y_.sample_() 95 | # set gy and dy to 0, since we do not know the label 96 | D_fake, M_regu = GD(z_, y_, t_x=t_x[accumulation_index], t_y=t_y[accumulation_index], train_G=True, split_D=config['split_D'], M_regu=True, train_E1=True, train_A1=True) 97 | 98 | M_E1_loss = losses.generator_loss(D_fake, M_regu, D_fea_w=D_fea_w) / float(config['num_G_accumulations']) 99 | #pdb.set_trace() 100 | M_E1_loss.backward() 101 | 102 | # Optionally apply modified ortho reg in G 103 | if config['G_ortho'] > 0.0: # here it is 0.0 104 | print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G 105 | # Don't ortho reg shared, it makes no sense.
Really we should blacklist any embeddings for this 106 | utils.ortho(G, config['G_ortho'], 107 | blacklist=[param for param in G.shared.parameters()]) 108 | A1.optim.step() 109 | if stage==2: 110 | G.optim.step() 111 | if training_scratch: 112 | E1.optim.step() 113 | 114 | # If we have an ema, update it, regardless of if we test with it or not 115 | if config['ema']: 116 | ema.update(state_dict['itr']) 117 | 118 | #out = {'G_loss': float(G_loss.item()), 119 | out = {'G_loss': float(M_E1_loss.item()), 120 | 'D_loss_real': float(D_loss_real.item()), 121 | 'D_loss_fake': float(D_loss_fake.item())} 122 | # Return G's loss and the components of D's loss. 123 | return out 124 | return train 125 | 126 | ''' This function takes in the model, saves the weights (multiple copies if 127 | requested), and prepares sample sheets: one consisting of samples given 128 | a fixed noise seed (to show how the model evolves throughout training), 129 | a set of full conditional sample sheets, and a set of interp sheets. ''' 130 | def save_and_sample(G, D, E1, A1, G_ema, z_, y_, fixed_z, fixed_y, 131 | state_dict, config, experiment_name, fixed_t_x, fixed_t_y, fixed_x_v2): 132 | utils.save_weights(G, D, E1, A1, state_dict, config['weights_root'], 133 | experiment_name, None, G_ema if config['ema'] else None) 134 | if not (state_dict['itr'] % config['save_every']): 135 | utils.save_weights(G, D, E1, A1, state_dict, config['weights_root'], 136 | experiment_name, None, G_ema if config['ema'] else None, copy=True) 137 | # # Use EMA G for samples or non-EMA? 
138 | which_G = G_ema if config['ema'] and config['use_ema'] else G 139 | 140 | 141 | # Save a random sample sheet with fixed z and y 142 | with torch.no_grad(): 143 | if config['parallel']: 144 | E1_L_feat = nn.parallel.data_parallel(E1, (fixed_t_x, fixed_t_y, True)) 145 | 146 | E1_L_feat = nn.parallel.data_parallel(A1, (E1_L_feat, None, None)) 147 | 148 | fixed_Gz = nn.parallel.data_parallel(which_G, (fixed_z, which_G.shared(fixed_y), E1_L_feat)) 149 | else: 150 | E1_L_feat = E1(fixed_t_x, fixed_t_y, True) 151 | E1_L_feat = A1(E1_L_feat, None, None) 152 | fixed_Gz = which_G(fixed_z, fixed_t_y, E1_L_feat) 153 | if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)): 154 | os.mkdir('%s/%s' % (config['samples_root'], experiment_name)) 155 | image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'], 156 | experiment_name, 157 | state_dict['itr']) 158 | torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename, 159 | nrow=int(fixed_Gz.shape[0] **0.5), normalize=True) 160 | 161 | # source real image 162 | source_real_name = image_filename.split('fixed')[0] + 'soure_real.jpg' 163 | torchvision.utils.save_image(fixed_x_v2.float().cpu(), source_real_name, 164 | nrow=int(fixed_x_v2.shape[0] **0.5), normalize=True) 165 | 166 | # target real image 167 | target_real_name = image_filename.split('fixed')[0] + 'target_real.jpg' 168 | torchvision.utils.save_image(fixed_t_x.float().cpu(), target_real_name, 169 | nrow=int(fixed_t_x.shape[0] **0.5), normalize=True) 170 | 171 | # For now, every time we save, also save sample sheets 172 | 173 | utils.sample_sheet(which_G,E1_L_feat, 174 | classes_per_sheet=E1_L_feat[32].shape[0] if (E1_L_feat[32].shape[0] args.phase * 2: 102 | used_sample = 0 103 | step += 1 104 | 105 | if step > max_step: 106 | step = max_step 107 | final_progress = True 108 | ckpt_step = step + 1 109 | 110 | else: 111 | alpha = 0 112 | ckpt_step = step 113 | 114 | resolution = 4 * 2 ** step 115 | 116 | loader = sample_data( 
117 | dataset, args.batch.get(resolution, args.batch_default), resolution 118 | ) 119 | data_loader = iter(loader) 120 | loader_y = sample_data( 121 | dataset_y, args.batch.get(resolution, args.batch_default), resolution 122 | ) 123 | data_loader_y = iter(loader_y) 124 | 125 | torch.save( 126 | { 127 | 'generator': generator.module.state_dict(), 128 | 'discriminator': discriminator.module.state_dict(), 129 | 'encoder': encoder.module.state_dict(), 130 | 'g_optimizer': g_optimizer.state_dict(), 131 | 'd_optimizer': d_optimizer.state_dict(), 132 | 'e_optimizer': e_optimizer.state_dict(), 133 | 'g_running': g_running.state_dict(), 134 | }, 135 | f'checkpoint/train_step-{ckpt_step}.model', 136 | ) 137 | 138 | adjust_lr(g_optimizer, args.lr.get(resolution, 0.001)) 139 | adjust_lr(d_optimizer, args.lr.get(resolution, 0.001)) 140 | 141 | try: 142 | real_image = next(data_loader) 143 | real_image_y = next(data_loader_y) 144 | if i==0: 145 | fix_real_image=[] 146 | fix_real_image_y=[] 147 | for _ in range(10): # the (10, 5) grid used to visualize the generated images; it corresponds to the (10, 5) grid in the sampling code below 148 | fix_real_image.append(real_image[:value_column].cuda()) 149 | fix_real_image_y.append(real_image_y[:value_column].cuda()) 150 | 151 | real_image = next(data_loader) 152 | real_image_y = next(data_loader_y) 153 | 154 | 155 | except (OSError, StopIteration): 156 | data_loader = iter(loader) 157 | real_image = next(data_loader) 158 | 159 | data_loader_y = iter(loader_y) 160 | real_image_y = next(data_loader_y) 161 | 162 | used_sample += real_image.shape[0] 163 | 164 | b_size = real_image.size(0) 165 | real_image = real_image.cuda() 166 | real_image_y = real_image_y.cuda() 167 | 168 | if args.loss == 'wgan-gp': 169 | real_predict = discriminator(real_image, step=step, alpha=alpha) 170 | real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean() 171 | (-real_predict).backward() 172 | 173 | elif args.loss == 'r1': 174 | real_image.requires_grad = True 175 |
real_scores = discriminator(real_image, step=step, alpha=alpha) 176 | real_predict = F.softplus(-real_scores).mean() 177 | real_predict.backward(retain_graph=True) 178 | 179 | grad_real = grad( 180 | outputs=real_scores.sum(), inputs=real_image, create_graph=True 181 | )[0] 182 | grad_penalty = ( 183 | grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2 184 | ).mean() 185 | grad_penalty = 10 / 2 * grad_penalty 186 | grad_penalty.backward() 187 | if i%10 == 0: 188 | grad_loss_val = grad_penalty.item() 189 | 190 | if args.mixing and random.random() < 0.9: # mixing is True 191 | gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn( 192 | 4, b_size, code_size, device='cuda' 193 | ).chunk(4, 0) 194 | gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)] 195 | gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)] 196 | 197 | else: 198 | gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device='cuda').chunk( 199 | 2, 0 200 | ) 201 | gen_in1 = gen_in1.squeeze(0) 202 | gen_in2 = gen_in2.squeeze(0) 203 | # E1 204 | _, L_feat = encoder(real_image_y, step=step, alpha=alpha, E1_output_feat=True, RESOLUTION=args.RESOLUTION) 205 | # A1 206 | L_feat = align(L_feat) 207 | 208 | fake_image = generator(gen_in1, step=step, alpha=alpha, E1_output_feat=True, L_feat=L_feat, RESOLUTION=args.RESOLUTION, E1_fea_w=args.E1_fea_w) 209 | fake_predict = discriminator(fake_image, step=step, alpha=alpha) 210 | 211 | if args.loss == 'wgan-gp': 212 | fake_predict = fake_predict.mean() 213 | fake_predict.backward() 214 | 215 | eps = torch.rand(b_size, 1, 1, 1).cuda() 216 | x_hat = eps * real_image.data + (1 - eps) * fake_image.data 217 | x_hat.requires_grad = True 218 | hat_predict = discriminator(x_hat, step=step, alpha=alpha) 219 | grad_x_hat = grad( 220 | outputs=hat_predict.sum(), inputs=x_hat, create_graph=True 221 | )[0] 222 | grad_penalty = ( 223 | (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2 224 | ).mean() 225 | grad_penalty = 10 * grad_penalty 226 | 
grad_penalty.backward() 227 | if i%10 == 0: 228 | grad_loss_val = grad_penalty.item() 229 | disc_loss_val = (-real_predict + fake_predict).item() 230 | 231 | elif args.loss == 'r1': 232 | fake_predict = F.softplus(fake_predict).mean() 233 | fake_predict.backward() 234 | if i%10 == 0: 235 | disc_loss_val = (real_predict + fake_predict).item() 236 | 237 | d_optimizer.step() 238 | 239 | if (i + 1) % n_critic == 0: 240 | align.zero_grad() 241 | requires_grad(align, True) 242 | # E1 243 | if i > stage_change: 244 | #encoder.zero_grad() 245 | #requires_grad(encoder, True) 246 | generator.zero_grad() 247 | requires_grad(generator, True) 248 | else: 249 | generator.zero_grad() 250 | requires_grad(generator, False) 251 | 252 | requires_grad(discriminator, False) 253 | 254 | # E1 255 | _, L_feat = encoder(real_image_y, step=step, alpha=alpha, E1_output_feat=True, RESOLUTION=args.RESOLUTION) 256 | # A1 257 | L_feat = align(L_feat) 258 | fake_image = generator(gen_in2, step=step, alpha=alpha, E1_output_feat=True, L_feat=L_feat, RESOLUTION=args.RESOLUTION, E1_fea_w=args.E1_fea_w) 259 | 260 | predict, L_out_feat = discriminator(fake_image, step=step, alpha=alpha, E1_output_feat=True, RESOLUTION=args.RESOLUTION) 261 | _, L_in_feat = discriminator(real_image_y, step=step, alpha=alpha, E1_output_feat=True, RESOLUTION=args.RESOLUTION) 262 | 263 | if args.loss == 'wgan-gp': 264 | loss = -predict.mean() 265 | # reconstruction loss 266 | loss_M = 0. 267 | for keys in args.RESOLUTION: 268 | loss_M += args.D_fea_w[keys] * torch.mean(F.mse_loss(L_in_feat[keys], L_out_feat[keys])) 269 | loss += loss_M 270 | 271 | elif args.loss == 'r1': 272 | loss = F.softplus(-predict).mean() 273 | # reconstruction loss 274 | loss_M = 0. 
275 | for keys in args.RESOLUTION: 276 | loss_M += args.D_fea_w[keys] * torch.mean(F.mse_loss(L_in_feat[keys], L_out_feat[keys])) 277 | loss += loss_M 278 | 279 | if i%10 == 0: 280 | gen_loss_val = loss.item() 281 | 282 | loss.backward() 283 | a_optimizer.step() 284 | requires_grad(align, False) 285 | # E1 286 | if i > stage_change: 287 | #e_optimizer.step() 288 | #requires_grad(encoder, False) 289 | g_optimizer.step() 290 | accumulate(g_running, generator.module, 0) 291 | 292 | requires_grad(generator, False) 293 | requires_grad(encoder, False) 294 | requires_grad(discriminator, True) 295 | 296 | if i % 100 == 0: 297 | images = [] 298 | images_source = [] 299 | images_y = [] 300 | 301 | encoder.eval() 302 | align.eval() 303 | gen_i, gen_j = args.gen_sample.get(resolution, (10, value_column)) 304 | 305 | 306 | with torch.no_grad(): 307 | for i_ in range(gen_i): 308 | 309 | _, L_feat_fixed = encoder(fix_real_image_y[i_], step=step, alpha=alpha, E1_output_feat=True, RESOLUTION=args.RESOLUTION) 310 | L_feat_fixed = align(L_feat_fixed) 311 | images.append( 312 | g_running( 313 | torch.randn(gen_j, code_size).cuda(), step=step, alpha=alpha, E1_output_feat=True, L_feat=L_feat_fixed, RESOLUTION=args.RESOLUTION, E1_fea_w=args.E1_fea_w 314 | ).data.cpu() 315 | ) 316 | 317 | utils.save_image( 318 | torch.cat(images, 0), 319 | f'sample/{str(i + 1).zfill(6)}.png', 320 | nrow=gen_i, 321 | normalize=True, 322 | range=(-1, 1), 323 | ) 324 | if i==0: # source and target real images 325 | utils.save_image( 326 | torch.cat([img.data.cpu() for img in fix_real_image_y], 0), 327 | f'sample/y.png', 328 | nrow=gen_i, 329 | normalize=True, 330 | range=(-1, 1), 331 | ) 332 | utils.save_image( 333 | torch.cat([img.data.cpu() for img in fix_real_image], 0), 334 | f'sample/target.png', 335 | nrow=gen_i, 336 | normalize=True, 337 | range=(-1, 1), 338 | ) 339 | align.train() 340 | if (i + 1) % 5000 == 0: 341 | torch.save( 342 | { 343 | 'generator': generator.module.state_dict(), 344 | 
'discriminator': discriminator.module.state_dict(), 345 | 'encoder': encoder.module.state_dict(), 346 | 'align': align.module.state_dict(), 347 | 'g_optimizer': g_optimizer.state_dict(), 348 | 'd_optimizer': d_optimizer.state_dict(), 349 | 'e_optimizer': e_optimizer.state_dict(), 350 | 'a_optimizer': a_optimizer.state_dict(), 351 | 'g_running': g_running.state_dict(), 352 | }, 353 | f'checkpoint/train_step-{str(i + 1).zfill(6)}.model', 354 | ) 355 | 356 | # torch.save( 357 | # g_running.state_dict(), f'checkpoint/G_running_{str(i + 1).zfill(6)}.model' 358 | # ) 359 | # torch.save( 360 | # generator.state_dict(), f'checkpoint/G_{str(i + 1).zfill(6)}.model' 361 | # ) 362 | # torch.save( 363 | # g_optimizer.state_dict(), f'checkpoint/G_optim_{str(i + 1).zfill(6)}.model' 364 | # ) 365 | 366 | # torch.save( 367 | # discriminator.state_dict(), f'checkpoint/D_{str(i + 1).zfill(6)}.model' 368 | # ) 369 | # torch.save( 370 | # d_optimizer.state_dict(), f'checkpoint/D_optim_{str(i + 1).zfill(6)}.model' 371 | # ) 372 | 373 | # torch.save( 374 | # encoder.state_dict(), f'checkpoint/E_{str(i + 1).zfill(6)}.model' 375 | # ) 376 | # torch.save( 377 | # e_optimizer.state_dict(), f'checkpoint/E_optim_{str(i + 1).zfill(6)}.model' 378 | # ) 379 | 380 | state_msg = ( 381 | f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};' 382 | f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}' 383 | ) 384 | 385 | pbar.set_description(state_msg) 386 | 387 | 388 | if __name__ == '__main__': 389 | code_size = 512 390 | batch_size = 16 391 | n_critic = 1 392 | 393 | parser = argparse.ArgumentParser(description='Progressive Growing of GANs') 394 | 395 | parser.add_argument('path', type=str, help='path of specified dataset') 396 | parser.add_argument('--path_y', default='data/dog', type=str, help='path of specified dataset') 397 | parser.add_argument( 398 | '--phase', 399 | type=int, 400 | default=600_000, 401 | help='number of samples used for each training phase', 402 | ) 403 | 404 |
405 | parser.add_argument('--RESOLUTION', default=[64,32,16,8,4], type=list, help='resolutions of the selected features') 406 | parser.add_argument('--E1_fea_w', default={4:.1, 8:.1, 16:.1, 32:.1, 64:.0}, type=dict, help='weight of each encoder feature resolution used in the generator') 407 | parser.add_argument('--D_fea_w', default={4:.1, 8:.1, 16:.1, 32:.1, 64:0.}, type=dict, help='weight of the per-resolution distance between discriminator features of the real and fake images, used to preserve structure information') 408 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 409 | parser.add_argument('--sched', action='store_true', help='use lr scheduling') 410 | parser.add_argument('--init_size', default=8, type=int, help='initial image size') 411 | parser.add_argument('--max_size', default=1024, type=int, help='max image size') 412 | parser.add_argument( 413 | '--ckpt', default=None, type=str, help='load from previous checkpoints' 414 | ) 415 | parser.add_argument( 416 | '--no_from_rgb_activate', 417 | action='store_true', 418 | help='use activate in from_rgb (original implementation)', 419 | ) 420 | parser.add_argument( 421 | '--mixing', action='store_true', help='use mixing regularization' 422 | ) 423 | parser.add_argument( 424 | '--loss', 425 | type=str, 426 | default='wgan-gp', 427 | choices=['wgan-gp', 'r1'], 428 | help='class of gan loss', 429 | ) 430 | 431 | args = parser.parse_args() 432 | 433 | generator = nn.DataParallel(StyledGenerator(code_size)).cuda() 434 | discriminator = nn.DataParallel( 435 | Discriminator(from_rgb_activate=not args.no_from_rgb_activate) 436 | ).cuda() 437 | # E1 438 | encoder = nn.DataParallel( 439 | Discriminator(from_rgb_activate=not args.no_from_rgb_activate) 440 | ).cuda() 441 | 442 | # A1 443 | align = nn.DataParallel( 444 | Alignment(from_rgb_activate=not args.no_from_rgb_activate) 445 | ).cuda() 446 | 447 | g_running = StyledGenerator(code_size).cuda() 448 | g_running.train(False) 449 | 450 | g_optimizer = optim.Adam( 451 |
generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99) 452 | ) 453 | g_optimizer.add_param_group( 454 | { 455 | 'params': generator.module.style.parameters(), 456 | 'lr': args.lr * 0.01, 457 | 'mult': 0.01, 458 | } 459 | ) 460 | d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99)) 461 | # E1 462 | e_optimizer = optim.Adam(encoder.parameters(), lr=args.lr, betas=(0.0, 0.99)) 463 | # A1 464 | a_optimizer = optim.Adam(align.parameters(), lr=args.lr, betas=(0.0, 0.99)) 465 | 466 | accumulate(g_running, generator.module, 0) # decay 0: copy the generator weights into g_running 467 | 468 | # Load a pretrained checkpoint; the encoder (E1) and its optimizer are initialized from the discriminator 469 | if args.ckpt is not None: 470 | ckpt = torch.load(args.ckpt) 471 | 472 | generator.module.load_state_dict(ckpt['generator']) 473 | discriminator.module.load_state_dict(ckpt['discriminator']) 474 | # E1 475 | encoder.module.load_state_dict(ckpt['discriminator']) 476 | g_running.load_state_dict(ckpt['g_running']) 477 | g_optimizer.load_state_dict(ckpt['g_optimizer']) 478 | d_optimizer.load_state_dict(ckpt['d_optimizer']) 479 | # E1 480 | e_optimizer.load_state_dict(ckpt['d_optimizer']) 481 | # A1 482 | #a_optimizer.load_state_dict(ckpt['d_optimizer']) 483 | 484 | transform = transforms.Compose( 485 | [ 486 | transforms.RandomHorizontalFlip(), 487 | transforms.ToTensor(), 488 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 489 | ] 490 | ) 491 | 492 | dataset = MultiResolutionDataset(args.path, transform) 493 | dataset_y = MultiResolutionDataset(args.path_y, transform) 494 | 495 | if args.sched: 496 | args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} 497 | args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32} 498 | 499 | else: 500 | args.lr = {} 501 | args.batch = {} 502 | 503 | args.gen_sample = {512: (8, 4), 1024: (4, 2)} 504 | 505 | #args.batch_default = 32 506 | args.batch_default = 32#yaxing 507 | 508 | train(args, dataset, generator, discriminator, encoder, align, dataset_y) 509 |
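The generator update in `DeepI2I_StyleGAN/train.py` adds a structure-preserving term, `loss_M`. For every resolution in `args.RESOLUTION`, it takes the MSE between the discriminator features of the encoder input `real_image_y` and those of the generated `fake_image`, and scales it by `args.D_fea_w`. Below is a minimal plain-Python sketch of that weighted sum. It assumes each feature map has been flattened into a list of floats, whereas the repository works on torch tensors. The helper names `mse` and `feature_reconstruction_loss` are hypothetical and do not come from the repository.

```python
# Sketch (plain Python, no torch) of the weighted multi-scale feature
# reconstruction term `loss_M` from the StyleGAN generator step.
# Assumption: feature maps are flattened into lists of floats.
# `mse` and `feature_reconstruction_loss` are hypothetical helper names.


def mse(a, b):
    """Mean squared error, like F.mse_loss(a, b) with reduction='mean'."""
    if len(a) != len(b) or not a:
        raise ValueError("feature maps must be non-empty and the same size")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def feature_reconstruction_loss(feats_in, feats_out, resolutions, weights):
    """Sum over the selected resolutions of weights[r] * MSE(in[r], out[r])."""
    total = 0.0
    for r in resolutions:
        total += weights[r] * mse(feats_in[r], feats_out[r])
    return total


# Defaults copied from the argparse section of the StyleGAN train.py.
RESOLUTION = [64, 32, 16, 8, 4]
D_fea_w = {4: .1, 8: .1, 16: .1, 32: .1, 64: 0.}

if __name__ == "__main__":
    # Toy features: each per-scale MSE is (0**2 + 2**2) / 2 = 2.0.
    feats_real = {r: [1.0, 2.0] for r in RESOLUTION}
    feats_fake = {r: [1.0, 4.0] for r in RESOLUTION}
    # Four scales with weight 0.1 add 0.2 each; scale 64 has weight 0.
    print(feature_reconstruction_loss(feats_real, feats_fake, RESOLUTION, D_fea_w))
```

With the default weights, the 64 scale has weight `0.`, so only the 4 to 32 scales contribute. The same weighting applies to `--E1_fea_w` on the encoder side. Because `--RESOLUTION`, `--E1_fea_w` and `--D_fea_w` are declared with `type=list` and `type=dict`, in practice only their defaults are usable. For example, argparse would turn `--RESOLUTION 64,32` into a list of single characters.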
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepI2I: Enabling Deep Hierarchical Image-to-Image Translation by Transferring from GANs 2 | # Abstract: 3 | Image-to-image translation has recently achieved remarkable results. But despite current success, it suffers from inferior performance when translations between classes require large shape changes. We attribute this to the high-resolution bottlenecks which are used by current state-of-the-art image-to-image methods. Therefore, in this work, we propose a novel deep hierarchical Image-to-Image Translation method, called DeepI2I. We learn a model by leveraging hierarchical features: (a) structural information contained in the shallow layers and (b) semantic information extracted from the deep layers. To enable the training of deep I2I models on small datasets, we propose a novel transfer learning method that transfers knowledge from pre-trained GANs. Specifically, we leverage the discriminator of a pre-trained GAN (i.e. BigGAN or StyleGAN) to initialize both the encoder and the discriminator, and the pre-trained generator to initialize the generator of our model. Applying knowledge transfer leads to an alignment problem between the encoder and generator. We introduce an adaptor network to address this. On many-class image-to-image translation on three datasets (Animal faces, Birds, and Foods) we decrease mFID by at least 35% when compared to the state-of-the-art. Furthermore, we qualitatively and quantitatively demonstrate that transfer learning significantly improves the performance of I2I systems, especially for small datasets. Finally, we are the first to perform I2I translations for domains with over 100 classes. 4 | 5 | 6 | # Framework 7 |
8 | [Figure: figures/framework.png] 9 | 10 | # Result 11 | 12 | [Figure: figures/sample.png] 13 | 14 | # Interpolation 15 | 16 | [Figure: figures/interpolation.png]
17 | 18 | 19 | 20 | 21 | # References 22 | - \[1\] [BigGAN](https://arxiv.org/abs/1809.11096) 23 | - \[2\] [StyleGAN](https://arxiv.org/pdf/1812.04948.pdf) 24 | 25 | # Contact 26 | If you run into any problems with this code, please submit a bug report on the project's GitHub page. For other inquiries, please contact me at yaxing@cvc.uab.es 27 | -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/figures/framework.png -------------------------------------------------------------------------------- /figures/interpolation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/figures/interpolation.png -------------------------------------------------------------------------------- /figures/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaxingwang/DeepI2I/9eb03b749ef016715194c11447b44053bc009b3a/figures/sample.png --------------------------------------------------------------------------------