├── README.md ├── datasets.py ├── DCDGANc.py └── models.py /README.md: -------------------------------------------------------------------------------- 1 | # DCDGANc 2 | Diversified and Multi-Class Controllable Industrial Defect Synthesis for Data Augmentation and Transfer 3 | ![image](https://github.com/cg-light/DCDGANc/assets/63499394/7930b883-f30a-454d-97ae-737f53547868) 4 | 5 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @ Project Name: styleVAEGAN 4 | @ Author: Jing 5 | @ TIME: 13:17/15/11/2021 6 | """ 7 | import glob 8 | import random 9 | import os 10 | 11 | import cv2 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import Dataset 15 | from PIL import Image 16 | import torchvision.transforms as transforms 17 | import torchvision.transforms.functional as ttf 18 | 19 | 20 | def to_rgb(image): 21 | rgb_image = Image.new("RGB", image.size) 22 | rgb_image.paste(image) 23 | return rgb_image 24 | 25 | 26 | class ImageDataset(Dataset): 27 | def __init__(self, root, c_dim, transforms_=None, unaligned=False, mode="train"): 28 | self.transform = transforms.Compose(transforms_) 29 | self.unaligned = unaligned 30 | 31 | self.files = {} 32 | path = os.path.join(root, "%s" % mode) 33 | self.class_count = 0 34 | for parent, dirnames, _ in os.walk(path): 35 | for dirname in dirnames: 36 | # print(dirname) 37 | files = sorted(glob.glob(os.path.join(parent, dirname) + "/*.*")) 38 | self.files[self.class_count] = files 39 | self.class_count += 1 40 | if self.class_count == c_dim: 41 | break 42 | break 43 | print(self.class_count) 44 | 45 | def __getitem__(self, index): 46 | y = random.randint(0, (self.class_count - 1)) 47 | if self.unaligned: 48 | image = Image.open(self.files[y][random.randint(0, (len(self.files[y]) - 1))]) 49 | else: 50 | image = Image.open(self.files[y][index % len(self.files[y])]) 51 | img = self.transform(image) 52 | return img, y 53 | 54 | def __len__(self): 55 | l = 0 56 | for key in self.files.keys(): 57 | l += len(self.files[key]) # if l < len(self.files[i]) else l 58 | return l 59 | -------------------------------------------------------------------------------- /DCDGANc.py: -------------------------------------------------------------------------------- 1 | # -*- ecoding: utf-8 -*- 2 | # @ModuleName: BicycleGAN 3 | # @ProjectName: BicycleGAN 4 | # @Author: Kilo 5 | # @Time: 2021/8/20-21:30 6 | import argparse 7 | import datetime 8 | import time 9 | import sys 10 | 11 | from torch import autograd 12 | from torchvision.utils import save_image 13 | 14 | from torch.utils.data import DataLoader 15 | 16 | from models import * 17 | from datasets import * 18 | 19 | import torch 20 | 21 | import os 22 | 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--filename", type=str, default="../../../paras/DCDGANc", help="storage path") 26 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from") 27 | parser.add_argument("--n_class", type=int, default=5, help="number of classes") 28 | parser.add_argument("--dataset_name", type=str, default="carpet", help="name of the dataset") 29 | parser.add_argument("--noise", type=str, default=False, help="insert noise or not") 30 | parser.add_argument("--sample_interval", type=int, default=250, help="interval between saving generator samples") 31 | parser.add_argument("--latent_dim", type=int, default=8, help="number of latent codes") 32 | parser.add_argument("--res_num", type=int, default=4, help="number of classes") 33 | parser.add_argument("--checkpoint_interval", type=int, default=2, help="interval between model checkpoints") 34 | parser.add_argument("--norm_type", type=str, default="spade", help="name of the dataset") 35 | parser.add_argument("--n_epochs", type=int, default=400, help="number of epochs of training") 36 | parser.add_argument("--batch_size", type=int, default=20, help="size of the batches") 37 | parser.add_argument("--lr", type=float, default=0.0005, help="adam: learning rate") 38 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") 39 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 40 | parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") 41 | parser.add_argument("--img_height", type=int, default=256, help="size of image height") 42 | parser.add_argument("--img_width", type=int, default=256, help="size of image width") 43 | parser.add_argument("--channels", type=int, default=3, help="number of image channels") 44 | parser.add_argument("--lambda_pixel", type=float, default=10, help="pixelwise loss weight") 45 | parser.add_argument("--lambda_latent", type=float, default=1, help="latent loss weight") 46 | parser.add_argument("--lambda_gp", type=float, default=2.5, help="latent loss weight") 47 | parser.add_argument("--lambda_kl", type=float, default=0.01, help="kullback-leibler loss weight") 48 | parser.add_argument("--lambda_cls", type=float, default=5, help="kullback-leibler loss weight") 49 | opt = parser.parse_args() 50 | print(opt) 51 | os.makedirs("%s/images/%s" % (opt.filename, opt.dataset_name), exist_ok=True) 52 | os.makedirs("%s/saved_models/%s" % (opt.filename, opt.dataset_name), exist_ok=True) 53 | 54 | cuda = True if torch.cuda.is_available() else False 55 | 56 | input_shape = (opt.channels, opt.img_height, opt.img_width) 57 | 58 | # Loss functions 59 | mae_loss = torch.nn.L1Loss() 60 | 61 | # Initialize generator, encoder and discriminators 62 | generator = Res_Generator(opt.latent_dim + opt.n_class, input_shape, opt.norm_type, opt.res_num, noise=opt.noise) 63 | encoder = Encoder(opt.latent_dim, input_shape) 64 | D_VAE = Multiclass(input_shape, opt.n_class) 65 | 66 | if cuda: 67 | generator = generator.cuda() 68 | encoder.cuda() 69 | D_VAE = D_VAE.cuda() 70 | mae_loss.cuda() 71 | 72 | if opt.epoch != 0: 73 | # Load pretrained models 74 | generator.load_state_dict( 75 | torch.load("%s/saved_models/%s/generator_%d.pth" % (opt.filename, opt.dataset_name, opt.epoch))) 76 | encoder.load_state_dict( 77 | torch.load("%s/saved_models/%s/encoder_%d.pth" % (opt.filename, opt.dataset_name, opt.epoch))) 78 | D_VAE.load_state_dict(torch.load("%s/saved_models/%s/D_VAE_%d.pth" % (opt.filename, opt.dataset_name, opt.epoch))) 79 | else: 80 | # Initialize weights 81 | generator.apply(weights_init_normal) 82 | D_VAE.apply(weights_init_normal) 83 | 84 | # Optimizers 85 | optimizer_E = torch.optim.Adam(encoder.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 86 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 87 | optimizer_D_VAE = torch.optim.Adam(D_VAE.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 88 | 89 | Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor 90 | 91 | # Image transformations 92 | transforms_ = [ 93 | transforms.Resize((int(1.12 * opt.img_height), int(1.12 * opt.img_width)), Image.BICUBIC), 94 | transforms.RandomCrop((opt.img_height, opt.img_width)), 95 | transforms.RandomHorizontalFlip(), 96 | transforms.ToTensor(), 97 | transforms.Normalize((0.5,), (0.5,)) 98 | ] 99 | 100 | dataloader = DataLoader( 101 | ImageDataset("../../../datasets/%s" % opt.dataset_name, opt.n_class, transforms_=transforms_, mode='train', 102 | unaligned=True), 103 | batch_size=opt.batch_size, 104 | shuffle=True, 105 | num_workers=opt.n_cpu, 106 | ) 107 | print(len(dataloader)) 108 | 109 | 110 | def sample_images(batches_done): 111 | """Saves a generated sample from the validation set""" 112 | generator.eval() 113 | img_samples = None 114 | for label in range(opt.n_class): 115 | real_label = torch.zeros((1, opt.n_class)) 116 | real_label[0, label] = 1 117 | real_label = real_label.repeat(8, 1) 118 | real_label = real_label.type(Tensor) 119 | clabels = 2 * label / (opt.n_class - 1) - 1 120 | clabels = np.array(clabels) 121 | clabels = Tensor(clabels).repeat(8, 1) # -1, 1 122 | # Sample latent representations 123 | sampled_z = Tensor(np.random.normal(0, 1, (8, opt.latent_dim))) # * torch.exp(logvar) + mu) 124 | # Generate samples 125 | fake_B = generator(clabels, real_label, sampled_z) 126 | # Concatenate samples horisontally 127 | fake_B = torch.cat([x for x in fake_B.data.cpu()], 2) 128 | img_sample = fake_B 129 | img_sample = img_sample.view(1, *img_sample.shape) 130 | # Concatenate with previous samples vertically 131 | img_samples = img_sample if img_samples is None else torch.cat((img_samples, img_sample), -2) 132 | save_image(img_samples, "%s/images/%s/%s.png" % (opt.filename, opt.dataset_name, batches_done), nrow=8, 133 | normalize=True) 134 | generator.train() 135 | 136 | 137 | def compute_gradient_penalty(D, real_samples, fake_samples, l): 138 | """Calculates the gradient penalty loss for WGAN GP""" 139 | # Random weight term for interpolation between real and fake samples 140 | alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1))) 141 | # Get random interpolation between real and fake samples 142 | interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) 143 | d_interpolates = D.compute_loss(interpolates, l, valid) 144 | # Get gradient w.r.t. interpolates 145 | gradients = autograd.grad( 146 | outputs=d_interpolates, 147 | inputs=interpolates, 148 | # grad_outputs=fake, 149 | create_graph=True, 150 | retain_graph=True, 151 | only_inputs=True, 152 | )[0] 153 | gradients = gradients.view(gradients.size(0), -1) 154 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 155 | return gradient_penalty 156 | 157 | 158 | # 重参数 159 | def reparameterization(mu, lovar): 160 | std = torch.exp(lovar / 2) 161 | sample_z = Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim))) 162 | z = sample_z * std + mu 163 | return z 164 | 165 | 166 | def embedding(condition, n=opt.n_class): 167 | label = torch.zeros(n) 168 | label[condition] = 1 169 | return label 170 | 171 | 172 | # ---------- 173 | # Training 174 | # ---------- 175 | # Adversarial loss 176 | valid = 1 177 | fake = 0 178 | 179 | prev_time = time.time() 180 | print("Training...\n") 181 | if __name__ == "__main__": 182 | for epoch in range(opt.epoch, opt.n_epochs): 183 | for i, (imgs, ys) in enumerate(dataloader): 184 | 185 | # set model input 186 | real_B = imgs.type(Tensor) 187 | y = ys 188 | labels = torch.zeros((real_B.size(0), opt.n_class)) 189 | j = 0 190 | for yc in ys: 191 | labels[j, yc] = 1 192 | j += 1 193 | labels = labels.type(Tensor) 194 | # glabels = glabels.type(Tensor) 195 | clabels = ((2 * ys / (opt.n_class - 1) - 1)).type(Tensor) # -1, 1 196 | # ys = ys.type(torch.LongTensor).cuda() 197 | 198 | # ------------------------------- 199 | # Train Generator and Encoder 200 | # ------------------------------- 201 | optimizer_G.zero_grad() 202 | optimizer_E.zero_grad() 203 | 204 | # ---------- 205 | # cVAE-GAN 编码生成的图像 206 | # ---------- 207 | 208 | # Produce output using encoding of B (cVAE-GAN) 209 | mu, logvar = encoder(real_B) 210 | # Kullback-Leibler divergence of encoded B 211 | loss_kl = 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - logvar - 1) 212 | encoder_z = reparameterization(mu, logvar) 213 | fake_B = generator(clabels, labels, encoder_z) 214 | # Pixelwise loss of translated image by VAE 215 | loss_pixel = mae_loss(fake_B, real_B) 216 | # Adversarial loss 217 | loss_VAE_GAN = D_VAE.compute_loss(fake_B, labels, valid) 218 | 219 | # --------- 220 | # cLR-GAN 采样z生成的图像骗过鉴别器 221 | # --------- 222 | # Produce output using sampled z (cLR-GAN) 223 | sampled_z = Tensor(np.random.normal(0, 1, (real_B.size(0), opt.latent_dim))) 224 | _fake_B = generator(clabels, labels, sampled_z) 225 | loss_LR_GAN = D_VAE.compute_loss(_fake_B, labels, valid) 226 | 227 | # ---------------------------------- 228 | # Total Loss (Generator + Encoder) 229 | # ---------------------------------- 230 | loss_GAN = loss_VAE_GAN + loss_LR_GAN 231 | loss_GE = loss_GAN + opt.lambda_pixel * loss_pixel + opt.lambda_kl * loss_kl 232 | 233 | loss_GE.backward(retain_graph=True) 234 | optimizer_E.step() 235 | 236 | # --------------------- 237 | # Generator Only Loss 生成图像和潜在空间采样的向量距离最小 238 | # --------------------- 239 | 240 | # Latent L1 loss 241 | _mu, _ = encoder(_fake_B) 242 | loss_latent = opt.lambda_latent * mae_loss(_mu, sampled_z) 243 | loss_latent.backward() 244 | optimizer_G.step() 245 | 246 | # ---------------------------------- 247 | # Train Discriminator (cVAE-GAN)真实图像和编码向量生成的图像损失 248 | # ---------------------------------- 249 | 250 | optimizer_D_VAE.zero_grad() 251 | 252 | loss_adv_r = D_VAE.compute_loss(real_B, labels, valid) 253 | loss_adv_f = D_VAE.compute_loss(fake_B.detach(), labels, fake) 254 | loss_adv_f_ = D_VAE.compute_loss(_fake_B.detach(), labels, fake) 255 | loss_VAE_gp = opt.lambda_gp * (compute_gradient_penalty(D_VAE, real_B.detach(), _fake_B.detach(), labels) + 256 | compute_gradient_penalty(D_VAE, real_B.detach(), fake_B.detach(), labels)) 257 | loss_D_VAE = 2 * loss_adv_r + loss_adv_f + loss_adv_f_ + loss_VAE_gp 258 | loss_D_VAE.backward() 259 | optimizer_D_VAE.step() 260 | 261 | # -------------- 262 | # Log Progress 263 | # -------------- 264 | # Determine approximate time left 265 | batches_done = epoch * len(dataloader) + i 266 | batches_left = opt.n_epochs * len(dataloader) - batches_done 267 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) 268 | prev_time = time.time() 269 | 270 | if batches_done % opt.sample_interval == 0: 271 | sample_images(batches_done) 272 | 273 | if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0: # and epoch > 100: 274 | # Save model checkpoints 275 | torch.save(generator.state_dict(), 276 | "%s/saved_models/%s/generator_%d.pth" % (opt.filename, opt.dataset_name, epoch)) 277 | torch.save(encoder.state_dict(), 278 | "%s/saved_models/%s/encoder_%d.pth" % (opt.filename, opt.dataset_name, epoch)) 279 | torch.save(D_VAE.state_dict(), "%s/saved_models/%s/D_VAE_%d.pth" % (opt.filename, opt.dataset_name, epoch)) 280 | 281 | # Print log 282 | sys.stdout.write( 283 | "\r[Epoch %d/%d] [Batch %d/%d] [D adv_loss: %f] [G loss: %f, pixel: %f, kl: %f, latent: %f] ETA: %s\n" 284 | % ( 285 | epoch, 286 | opt.n_epochs, 287 | i, 288 | len(dataloader), 289 | loss_D_VAE.item(), 290 | (loss_VAE_GAN + loss_LR_GAN).item(), 291 | (opt.lambda_pixel * loss_pixel).item(), 292 | (opt.lambda_kl * loss_kl).item(), 293 | (opt.lambda_latent * loss_latent).item(), 294 | time_left, 295 | ) 296 | ) 297 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @ Project Name: styleVAEGAN 4 | @ Author: Jing 5 | @ TIME: 13:17/15/11/2021 6 | """ 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | from torchvision.models import resnet18 11 | import torch.nn.functional as F 12 | 13 | 14 | def weights_init_normal(m): 15 | classname = m.__class__.__name__ 16 | if classname.find("Conv") != -1: 17 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 18 | elif classname.find("BatchNorm2d") != -1: 19 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 20 | torch.nn.init.constant_(m.bias.data, 0.0) 21 | 22 | 23 | class AdaIN(nn.Module): 24 | def __init__(self, weight_dim, channels): 25 | super(AdaIN, self).__init__() 26 | # print(channels) 27 | self.weight = nn.Linear(weight_dim, 2 * channels) 28 | 29 | def forward(self, x, w): 30 | eps = 0.0001 31 | mu = torch.mean(torch.mean(x, dim=2, keepdim=True), dim=3, keepdim=True) 32 | mu = mu.repeat(1, 1, x.shape[2], x.shape[3]) 33 | seta = torch.mean(torch.mean((x - mu) ** 2, dim=2, keepdim=True), dim=3, keepdim=True).sqrt() + eps 34 | # print(x.shape) 35 | y = self.weight(w) 36 | # print(y.shape) 37 | y = y.view(2, x.shape[0], x.shape[1], 1, 1) 38 | out = y[0] * (x - mu) / seta + y[1] 39 | # print(out.shape) 40 | return out 41 | 42 | 43 | def norm(x, gama, beta): 44 | mu = torch.mean(x, dim=1, keepdim=True) 45 | seta = torch.mean((x ** 2 - mu ** 2), dim=1, keepdim=True).sqrt() 46 | out = gama * (x - mu) / seta + beta 47 | return out 48 | 49 | 50 | class SPADE(nn.Module): 51 | def __init__(self, in_size, out_channels): 52 | super(SPADE, self).__init__() 53 | 54 | self.fc = nn.Linear(in_size[0], in_size[2] * in_size[3]) 55 | self.in_size = in_size 56 | self.spade1 = nn.Sequential( 57 | nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, bias=False), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | ) 60 | self.spade_gama = nn.Conv2d(3, out_channels, 3, 1, 1, bias=False) 61 | self.spade_beta = nn.Conv2d(3, out_channels, 3, 1, 1, bias=False) 62 | 63 | def forward(self, x, z): 64 | z = self.fc(z).view(z.size(0), 1, self.in_size[2], self.in_size[3]) 65 | z = self.spade1(z) 66 | gama = self.spade_gama(z) 67 | beta = self.spade_beta(z) 68 | out = norm(x, gama, beta) 69 | return out 70 | 71 | 72 | class ResidualBlock(nn.Module): 73 | def __init__(self, in_size, out_c, norm_type, noise=True): 74 | super(ResidualBlock, self).__init__() 75 | 76 | in_features = in_size[1] 77 | if in_features == out_c: 78 | self.use_1x1 = False 79 | else: 80 | self.use_1x1 = True 81 | self.norm = norm_type 82 | self.noise = noise 83 | 84 | self.conv1 = nn.Conv2d(in_features, out_c, 3, stride=1, padding=1, groups=1, bias=False) 85 | if norm_type == 'spade': 86 | norm_layer = SPADE(in_size, out_c) 87 | elif norm_type == 'adain': 88 | norm_layer = AdaIN(in_size[0], out_c) 89 | else: 90 | norm_layer = nn.InstanceNorm2d(out_c, 0.8, track_running_stats=True) 91 | self.act1 = nn.LeakyReLU(0.2, inplace=True) 92 | self.act2 = nn.LeakyReLU(0.2, inplace=True) 93 | self.conv2 = nn.Conv2d(out_c, out_c, 3, stride=1, padding=1, groups=1, bias=False) 94 | if self.use_1x1: 95 | self.conv3 = nn.Conv2d(in_features, out_c, 1) 96 | self.norm_layer1 = norm_layer 97 | self.norm_layer2 = norm_layer 98 | if noise: 99 | self.insert_noise = nn.Linear(1, out_c) 100 | 101 | def forward(self, x, z): 102 | if self.noise: 103 | noise = torch.rand(x.shape[0], x.shape[2], x.shape[3], 1).cuda() 104 | noise = self.insert_noise(noise).permute((0, 3, 2, 1)) 105 | x = x + noise 106 | # print(noise.shape, x.shape) 107 | x_out = self.conv1(x) 108 | # print(x_out.shape, noise.shape) 109 | if self.norm == 'spade' or self.norm == 'adain': 110 | x_out = self.norm_layer1(x_out, z) 111 | else: 112 | x_out = self.norm_layer1(x_out) 113 | x_out = self.act1(x_out) 114 | x_out = self.conv2(x_out) 115 | if self.norm == 'spade' or self.norm == 'adain': 116 | x_out = self.norm_layer2(x_out, z) 117 | else: 118 | x_out = self.norm_layer2(x) 119 | if self.use_1x1: 120 | return self.conv3(x) + x_out 121 | else: 122 | return x + x_out 123 | 124 | 125 | class down_block(nn.Module): 126 | def __init__(self, in_channels, out_channels, weight_dim, h, w, norm_type, noise=True): 127 | super(down_block, self).__init__() 128 | self.norm_type = norm_type 129 | self.noise = noise 130 | group_num = 1 131 | 132 | model = nn.ModuleList([nn.Conv2d(in_channels, out_channels, 4, 2, 1, groups=group_num)]) 133 | if norm_type == 'adain': 134 | model.extend([AdaIN(weight_dim, out_channels), nn.LeakyReLU(0.2, inplace=True)]) 135 | elif norm_type == 'spade': 136 | model.extend([SPADE([weight_dim, out_channels, h, w], out_channels), nn.LeakyReLU(0.2, inplace=True)]) 137 | elif norm_type == 'instance': 138 | model.extend([nn.InstanceNorm2d(out_channels, 0.8), nn.LeakyReLU(0.2, inplace=True)]) 139 | elif not norm_type: 140 | model.append(nn.LeakyReLU(0.2, inplace=True)) 141 | else: 142 | print('choose appropriate normalization') 143 | 144 | self.models = model 145 | if noise: 146 | self.insert_noise = nn.Linear(1, in_channels) 147 | 148 | def forward(self, x, z): 149 | if self.noise: 150 | noise = torch.rand(x.shape[0], x.shape[2], x.shape[3], 1).cuda() 151 | noise = self.insert_noise(noise).permute((0, 3, 2, 1)) 152 | x = x + noise 153 | i = 0 154 | for model in self.models: 155 | if self.norm_type == 'spade' or self.norm_type == 'adain': 156 | if i == 1: 157 | x = model(x, z) 158 | else: 159 | x = model(x) 160 | else: 161 | x = model(x) 162 | i += 1 163 | return x 164 | 165 | 166 | class up_block(nn.Module): 167 | def __init__(self, in_channels, out_channels, weight_dim, h, w, norm_type, noise=True): 168 | super(up_block, self).__init__() 169 | self.norm_type = norm_type 170 | self.noise = noise 171 | group_num = 1 172 | 173 | model = nn.ModuleList([nn.Upsample(scale_factor=2), 174 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group_num)]) 175 | if norm_type == 'adain': 176 | model.extend([AdaIN(weight_dim, out_channels), nn.LeakyReLU(0.2, inplace=True)]) 177 | elif norm_type == 'spade': 178 | model.extend([SPADE([weight_dim, out_channels, h, w], out_channels), nn.LeakyReLU(0.2, inplace=True)]) 179 | elif norm_type == 'instance': 180 | model.extend([nn.InstanceNorm2d(out_channels, 0.8), nn.LeakyReLU(0.2, inplace=True)]) 181 | elif not norm_type: 182 | model.append(nn.LeakyReLU(0.2, inplace=True)) 183 | else: 184 | print('choose appropriate normalization') 185 | 186 | self.models = model 187 | if noise: 188 | self.insert_noise = nn.Linear(1, in_channels) 189 | 190 | def forward(self, x, z): 191 | if self.noise: 192 | noise = torch.rand(x.shape[0], x.shape[2], x.shape[3], 1).cuda() 193 | noise = self.insert_noise(noise).permute((0, 3, 2, 1)) 194 | x = x + noise 195 | i = 0 196 | for model in self.models: 197 | if self.norm_type == 'spade' or self.norm_type == 'adain': 198 | if i == 2: 199 | x = model(x, z) 200 | else: 201 | 202 | x = model(x) 203 | else: 204 | x = model(x) 205 | i += 1 206 | return x 207 | 208 | 209 | class Res_Generator(nn.Module): 210 | def __init__(self, weight_dim, in_size, norm_type, res_num=6, noise=False): 211 | super(Res_Generator, self).__init__() 212 | channel, h, w = in_size 213 | self.c, self.h, self.w = channel, h, w 214 | self.norm = norm_type 215 | 216 | # down sample 217 | h = h // 2 218 | w = w // 2 219 | curr_channel = 32 220 | if norm_type == 'instance': 221 | self.in_layer = nn.Sequential(nn.Linear(weight_dim, h * w * 4, bias=True)) 222 | model = nn.ModuleList([down_block(channel + 1, curr_channel, weight_dim, h, w, False, noise)]) 223 | else: 224 | model = nn.ModuleList([down_block(channel, curr_channel, weight_dim, h, w, False, noise)]) 225 | for _ in range(3): 226 | h = h // 2 227 | w = w // 2 228 | model.append(down_block(curr_channel, curr_channel * 2, weight_dim, h, w, norm_type, noise)) 229 | curr_channel *= 2 230 | 231 | h = h // 2 232 | w = w // 2 233 | model.extend([down_block(curr_channel, curr_channel, weight_dim, h, w, norm_type, noise)]) # , 234 | 235 | for i in range(res_num): 236 | model.append(ResidualBlock([weight_dim, curr_channel, h, w], curr_channel, norm_type, noise)) # 256 237 | 238 | # # up sample and output 239 | for _ in range(4): 240 | h *= 2 241 | w *= 2 242 | model.append(up_block(curr_channel, curr_channel // 2, weight_dim, h, w, norm_type, noise)) 243 | curr_channel = curr_channel // 2 244 | 245 | self.models = model 246 | self.out_layer = nn.Sequential( 247 | nn.Upsample(scale_factor=2), 248 | nn.Conv2d(curr_channel, channel, 3, 1, 1), # 16 249 | nn.Tanh()) 250 | 251 | def forward(self, clabel, label, z): 252 | x = clabel.repeat(1, self.c * self.h * self.w).view(z.size(0), self.c, self.h, self.w) 253 | z = torch.cat((z, label), 1) 254 | if self.norm == 'instance': 255 | z = self.in_layer(z).view(x.size(0), 1, self.h, self.w) 256 | x = torch.cat((x, z), 1) 257 | for model in self.models: 258 | x = model(x, z) 259 | return self.out_layer(x) 260 | 261 | 262 | class Encoder(nn.Module): 263 | def __init__(self, latent_dim, input_shape): 264 | super(Encoder, self).__init__() 265 | self.feature_extractor = nn.Sequential( 266 | nn.Conv2d(input_shape[0], 64, 7, 2, 3, bias=False), 267 | nn.InstanceNorm2d(64, 0.8), 268 | nn.ReLU(inplace=True)) 269 | self.pooling = nn.AvgPool2d(kernel_size=8, stride=8, padding=0) 270 | # Output is mu and log(var) for reparameterization trick used in VAEs 271 | self.fc_mu = nn.Linear(input_shape[1] * input_shape[2] // 4, latent_dim) 272 | self.fc_logvar = nn.Linear(input_shape[1] * input_shape[2] // 4, latent_dim) 273 | 274 | def forward(self, img): 275 | out = self.feature_extractor(img) # torch.Size([2, 64, 128, 128]) 276 | out = self.pooling(out) # torch.Size([2, 64, 16, 16]) 277 | out = out.view(out.size(0), -1) 278 | mu = self.fc_mu(out) # torch.Size([2, 6]) 279 | logvar = self.fc_logvar(out) 280 | return mu, logvar 281 | 282 | 283 | class MulticlassDis(nn.Module): 284 | def __init__(self, img_shape, c_dim): 285 | super(MulticlassDis, self).__init__() 286 | channels, h, w = img_shape 287 | self.h = h 288 | self.w = w 289 | 290 | def discriminator_block(in_filters, out_filters, norm_type=True): 291 | """Returns downsampling layers of each discriminator block""" 292 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 293 | if norm_type: 294 | layers.append(nn.InstanceNorm2d(out_filters)) 295 | # layers.append(nn.BatchNorm2d(out_filters, 0.8)) 296 | layers.append(nn.LeakyReLU(0.2)) 297 | return layers 298 | 299 | self.in_layers = nn.ModuleList([nn.Linear(c_dim, h * w), 300 | nn.Linear(c_dim, h * w // 4), 301 | nn.Linear(c_dim, h * w // 16)]) 302 | self.models = nn.ModuleList() 303 | self.out_layers = nn.ModuleList() 304 | for i in range(3): 305 | self.models.add_module( 306 | 'class %s' % i, 307 | nn.Sequential( 308 | *discriminator_block(channels, 64, False), 309 | *discriminator_block(64, 128), 310 | *discriminator_block(128, 256), 311 | *discriminator_block(256, 256) 312 | ) 313 | ) 314 | self.out_layers.append(nn.Conv2d(256, 1, 3, padding=1)) 315 | self.down_sample = nn.AvgPool2d(channels + 1, stride=2, padding=[1, 1], count_include_pad=False) 316 | 317 | def compute_loss(self, img, label, gt): 318 | loss_c = sum([torch.mean(out - gt) ** 2 for out in self.forward(img, label)]) 319 | return loss_c 320 | 321 | def forward(self, img, label): 322 | # input = label.repeat(1, self.img_size * self.img_size).view(img.size(0), 1, self.img_size, self.img_size) 323 | outputs = [] 324 | i = 0 325 | for model in self.models: 326 | h, w = self.h // 2 ** (i + 4), self.w // 2 ** (i + 4) 327 | label_in = self.in_layers[i](label).view(img.size(0), -1, h, w) 328 | feature = model(img) 329 | outputs.append(self.out_layers[i](feature * label_in)) 330 | img = self.down_sample(img) 331 | i += 1 332 | return outputs 333 | 334 | 335 | class Multiclass(nn.Module): 336 | def __init__(self, img_shape, c_dim): 337 | super(Multiclass, self).__init__() 338 | channels, h, w = img_shape 339 | self.h = h 340 | self.w = w 341 | 342 | def discriminator_block(in_filters, out_filters, norm_type=True): 343 | """Returns downsampling layers of each discriminator block""" 344 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 345 | if norm_type: 346 | layers.append(nn.InstanceNorm2d(out_filters)) 347 | # layers.append(nn.BatchNorm2d(out_filters, 0.8)) 348 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 349 | return layers 350 | 351 | self.in_layer = nn.Linear(c_dim, h * w) 352 | self.models = nn.Sequential( 353 | *discriminator_block(channels + 1, 64), 354 | *discriminator_block(64, 128), 355 | *discriminator_block(128, 256), 356 | *discriminator_block(256, 512), 357 | nn.Conv2d(512, 1, 3, padding=1) 358 | ) 359 | 360 | def compute_loss(self, img, label, gt): 361 | loss_c = sum([torch.mean(out - gt) ** 2 for out in self.forward(img, label)]) 362 | return loss_c 363 | 364 | def forward(self, img, label): 365 | in_c = self.in_layer(label).reshape(img.size(0), 1, img.size(2), img.size(3)) 366 | img = torch.cat((img, in_c), dim=1) 367 | return self.models(img) 368 | --------------------------------------------------------------------------------