├── E_align_cropping_s1.py
├── E_align_s2.py
├── E_mis_align_cropping_s1.py
├── ablation_utils
│   ├── 1.E_align_z.py
│   ├── 2.E_align_w_2.py
│   ├── 3.E_align_w.py
│   ├── 4.E_align_w_zn.py
│   ├── 5.E_align_w_zn_zc.py
│   ├── 6.E_align_x.py
│   ├── 7.E_align_x_AT1.py
│   ├── 8.E_align_x_AT1_AT2.py
│   └── Cat256
│       ├── E_align_case_1.py
│       ├── E_align_case_2.py
│       ├── E_mis_align_case_1.py
│       └── E_mis_align_case_2.py
├── baseline_utils
│   ├── image2stylegan_w2z_opW.py
│   ├── test-baseline-IndomainG.py
│   ├── test_baseline_alae.py
│   └── test_baseline_psp.py
├── comparing-baseline.py
├── embedding_img.py
├── embedding_v2_BigGAN.py
├── embedding_v2_styleGAN1.py
├── embedding_v2_styleGAN2.py
├── embeded_img_edit.py
├── image_results
│   ├── baseline
│   │   ├── ALAE-rec.jpg
│   │   ├── Real.jpg
│   │   ├── indomain_images_inv.jpg
│   │   ├── indomain_images_rec.jpg
│   │   ├── mtv-tsa-(1500E).jpg
│   │   └── psp-rec.jpg
│   ├── cxx1.gif
│   ├── cxx2.gif
│   ├── dy.gif
│   ├── msk.gif
│   └── zy.gif
├── inferE.py
├── latent_code
│   ├── directions
│   │   ├── stylegan_ffhq_age_w_boundary.npy
│   │   ├── stylegan_ffhq_eyeglasses_w_boundary.npy
│   │   ├── stylegan_ffhq_gender_w_boundary.npy
│   │   ├── stylegan_ffhq_pose_w_boundary.npy
│   │   └── stylegan_ffhq_smile_w_boundary.npy
│   └── real_face_code
│       ├── i0_cxx1.pt
│       ├── i1_dy.pt
│       ├── i2_zy.pt
│       ├── i3_cxx2.pt
│       ├── i4_msk.pt
│       └── i5_ty.pt
├── metric
│   ├── grad_cam.py
│   └── pytorch_ssim.py
├── model
│   ├── E
│   │   ├── Ablation_Study
│   │   │   ├── E_Blur_W.py
│   │   │   ├── E_Blur_W_2.py
│   │   │   ├── E_Blur_Z.py
│   │   │   ├── E_v1.py
│   │   │   └── E_v2_std.py
│   │   ├── E.py
│   │   ├── E_BIG.py
│   │   ├── E_Blur.py
│   │   └── E_PG.py
│   ├── biggan_generator.py
│   ├── pggan
│   │   ├── pggan_d2e.py
│   │   ├── pggan_discriminator.py
│   │   ├── pggan_generator.py
│   │   └── utils
│   │       ├── CustomLayers.py
│   │       ├── Encoder.py
│   │       └── Networks.py
│   ├── stylegan1
│   │   ├── alae.py
│   │   ├── custom_adam.py
│   │   ├── lod_driver.py
│   │   ├── losses.py
│   │   ├── lreq.py
│   │   ├── model.py
│   │   ├── net.py
│   │   └── text_alae.py
│   ├── stylegan2_generator.py
│   └── utils
│       ├── CustomLayers.py
│       ├── biggan_config.py
│       ├── biggan_file_utils.py
│       ├── custom_adam.py
│       ├── lreq.py
│       └── net.py
├── readme.md
├── readme_cn.md
├── rec_real_img.py
├── requirements.txt
├── synthesized_IMG.py
├── synthesized_textBigGAN.py
└── training_utils.py
/ablation_utils/1.E_align_z.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import os 4 | import math 5 | import torch 6 | import torchvision 7 | import model.E.Ablation_Study.E_Blur_Z as BE 8 | from model.utils.custom_adam import LREQAdam 9 | import metric.pytorch_ssim as pytorch_ssim 10 | import lpips 11 | import numpy as np 12 | import tensorboardX 13 | import argparse 14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 15 | from training_utils import * 16 | 17 | def train(tensor_writer = None, args = None): 18 | type = args.mtype 19 | 20 | model_path = args.checkpoint_dir_GAN 21 | if type == 1: # StyleGAN1 22 | #model_path = './checkpoint/stylegan_v1/ffhq1024/' 23 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 24 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth')) 25 | 26 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024 27 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth')) 28 | 29 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt') 30 | const_ = Gs.const 31 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda() 32 | layer_num = 
int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024 33 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 34 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 35 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # 18个变量前8个裁剪比例truncation_psi [0.7,0.7,...,1,1,1] 36 | coefs = coefs.cuda() 37 | 38 | Gs.cuda() 39 | Gm.cuda() 40 | 41 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 42 | 43 | else: 44 | print('error') 45 | return 46 | 47 | if args.checkpoint_dir_E != None: 48 | E.load_state_dict(torch.load(args.checkpoint_dir_E)) 49 | E.cuda() 50 | writer = tensor_writer 51 | 52 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0) 53 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda') 54 | 55 | batch_size = args.batch_size 56 | it_d = 0 57 | for iteration in range(0,args.iterations): 58 | set_seed(iteration%30000) 59 | z_c1 = torch.randn(batch_size, args.z_dim).cuda() #[n, 512] 60 | 61 | if type == 1: 62 | w1 = Gm(z_c1,coefs_m=coefs) #[batch_size,18,512] 63 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256 64 | z_c2, _ = E(imgs1) 65 | z_c2 = z_c2.squeeze(-1).squeeze(-1) 66 | w2 = Gm(z_c2,coefs_m=coefs) 67 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2)) 68 | else: 69 | print('model type error') 70 | return 71 | 72 | E_optimizer.zero_grad() 73 | 74 | #loss Images 75 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips) 76 | 77 | loss_msiv = loss_imgs 78 | E_optimizer.zero_grad() 79 | loss_msiv.backward(retain_graph=True) 80 | E_optimizer.step() 81 | 82 | #Latent-Vectors 83 | 84 | ## w 85 | #loss_w, loss_w_info = space_loss(w1,w2,image_space = False) 86 | ## c 87 | loss_c, loss_c_info = space_loss(z_c1,z_c2,image_space = False) 88 | 89 | loss_mslv = loss_c*0.01 90 | E_optimizer.zero_grad() 91 | loss_mslv.backward() 92 | E_optimizer.step() 93 | 94 | 95 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000)) 96 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]') 97 | print('---------ImageSpace--------') 98 | print('loss_imgs_info: %s'%loss_imgs_info) 99 | print('---------LatentSpace--------') 100 | print('loss_c_info: %s'%loss_c_info) 101 | 102 | it_d += 1 103 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d) 104 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d) 105 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d) 106 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d) 107 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d) 108 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d) 109 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d) 110 | 111 | writer.add_scalar('loss_c_mse', loss_c_info[0][0], global_step=it_d) 112 | writer.add_scalar('loss_c_mse_mean', loss_c_info[0][1], global_step=it_d) 113 | writer.add_scalar('loss_c_mse_std', loss_c_info[0][2], global_step=it_d) 114 | writer.add_scalar('loss_c_kl', loss_c_info[1], global_step=it_d) 115 | writer.add_scalar('loss_c_cosine', loss_c_info[2], global_step=it_d) 116 | writer.add_scalar('loss_c_ssim', loss_c_info[3], global_step=it_d) 117 | writer.add_scalar('loss_c_lpips', 
loss_c_info[4], global_step=it_d) 118 | 119 | writer.add_scalars('Latent Space C', {'loss_c_mse':loss_c_info[0][0],'loss_c_mse_mean':loss_c_info[0][1],'loss_c_mse_std':loss_c_info[0][2],'loss_c_kl':loss_c_info[1],'loss_c_cosine':loss_c_info[2]}, global_step=it_d) 120 | 121 | 122 | if iteration % 100 == 0: 123 | n_row = batch_size 124 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5 125 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3 126 | with open(resultPath+'/Loss.txt', 'a+') as f: 127 | print('i_'+str(iteration),file=f) 128 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f) 129 | print('---------ImageSpace--------',file=f) 130 | print('loss_imgs_info: %s'%loss_imgs_info,file=f) 131 | print('---------LatentSpace--------',file=f) 132 | print('loss_c_info: %s'%loss_c_info,file=f) 133 | if iteration % 5000 == 0: 134 | torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000)) 135 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration) 136 | 137 | if __name__ == "__main__": 138 | 139 | parser = argparse.ArgumentParser(description='the training args') 140 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000 141 | parser.add_argument('--lr', type=float, default=0.0015) 142 | parser.add_argument('--beta_1', type=float, default=0.0) 143 | parser.add_argument('--batch_size', type=int, default=2) 144 | parser.add_argument('--experiment_dir', default=None) #None 145 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt 146 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it 147 | parser.add_argument('--checkpoint_dir_E', default=None) 148 | parser.add_argument('--img_size',type=int, default=1024) 149 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 150 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. 
BIGGAN is 128 151 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4 152 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256 153 | args = parser.parse_args() 154 | 155 | if not os.path.exists('./result'): os.mkdir('./result') 156 | resultPath = args.experiment_dir 157 | if resultPath == None: 158 | resultPath = "./result/StyleGANv1-AlationStudy-Z" 159 | if not os.path.exists(resultPath): os.mkdir(resultPath) 160 | 161 | resultPath1_1 = resultPath+"/imgs" 162 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1) 163 | 164 | resultPath1_2 = resultPath+"/models" 165 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2) 166 | 167 | writer_path = os.path.join(resultPath, './summaries') 168 | if not os.path.exists(writer_path): os.mkdir(writer_path) 169 | writer = tensorboardX.SummaryWriter(writer_path) 170 | 171 | use_gpu = True 172 | device = torch.device("cuda" if use_gpu else "cpu") 173 | 174 | train(tensor_writer=writer, args = args) 175 | -------------------------------------------------------------------------------- /ablation_utils/2.E_align_w_2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import os 4 | import math 5 | import torch 6 | import torchvision 7 | import model.E.Ablation_Study.E_Blur_W_2 as BE 8 | from model.utils.custom_adam import LREQAdam 9 | import metric.pytorch_ssim as pytorch_ssim 10 | import lpips 11 | import numpy as np 12 | import tensorboardX 13 | import argparse 14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 15 | from training_utils import * 16 | 17 | 18 | def train(tensor_writer = None, args = None): 19 | type = args.mtype 20 | 21 | model_path = args.checkpoint_dir_GAN 22 | if type == 1: # StyleGAN1 23 | #model_path = './checkpoint/stylegan_v1/ffhq1024/' 24 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 25 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth')) 26 | 27 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024 28 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth')) 29 | 30 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt') 31 | const_ = Gs.const 32 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda() 33 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024 34 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 35 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 36 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # 18个变量前8个裁剪比例truncation_psi [0.7,0.7,...,1,1,1] 37 | 38 | Gs.cuda() 39 | Gm.eval() 40 | 41 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 42 | else: 43 | print('error') 44 | return 45 | 46 | if args.checkpoint_dir_E != None: 47 | E.load_state_dict(torch.load(args.checkpoint_dir_E)) 48 | E.cuda() 49 | writer = tensor_writer 50 | 51 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0) 52 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda') 53 | 54 | batch_size = args.batch_size 55 | it_d = 0 56 | for 
iteration in range(0,args.iterations): 57 | set_seed(iteration%30000) 58 | z = torch.randn(batch_size, args.z_dim) #[32, 512] 59 | 60 | if type == 1: 61 | with torch.no_grad(): #这里需要生成图片和变量 62 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512] 63 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256 64 | const2,w2 = E(imgs1) 65 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2)) 66 | else: 67 | print('model type error') 68 | return 69 | 70 | E_optimizer.zero_grad() 71 | 72 | #loss Images 73 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips) 74 | 75 | loss_msiv = loss_imgs 76 | E_optimizer.zero_grad() 77 | loss_msiv.backward(retain_graph=True) 78 | E_optimizer.step() 79 | 80 | #Latent-Vectors 81 | 82 | ## w 83 | loss_w, loss_w_info = space_loss(w1,w2,image_space = False) 84 | 85 | ## c 86 | #loss_c, loss_c_info = space_loss(const1,const2,image_space = False) 87 | 88 | loss_mslv = loss_w * 0.01 89 | E_optimizer.zero_grad() 90 | loss_mslv.backward() 91 | E_optimizer.step() 92 | 93 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000)) 94 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]') 95 | print('---------ImageSpace--------') 96 | print('loss_imgs_info: %s'%loss_imgs_info) 97 | print('---------LatentSpace--------') 98 | print('loss_w_info: %s'%loss_w_info) 99 | 100 | it_d += 1 101 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d) 102 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d) 103 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d) 104 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d) 105 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d) 106 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d) 107 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d) 108 | 109 | writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d) 110 | writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d) 111 | writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d) 112 | writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d) 113 | writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d) 114 | writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d) 115 | writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d) 116 | 117 | writer.add_scalars('Latent Space W', {'loss_w_mse':loss_w_info[0][0],'loss_w_mse_mean':loss_w_info[0][1],'loss_w_mse_std':loss_w_info[0][2],'loss_w_kl':loss_w_info[1],'loss_w_cosine':loss_w_info[2]}, global_step=it_d) 118 | 119 | if iteration % 100 == 0: 120 | n_row = batch_size 121 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5 122 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3 123 | with open(resultPath+'/Loss.txt', 'a+') as f: 124 | print('i_'+str(iteration),file=f) 125 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f) 126 | print('---------ImageSpace--------',file=f) 127 | print('loss_imgs_info: %s'%loss_imgs_info,file=f) 128 | print('---------LatentSpace--------',file=f) 129 | print('loss_w_info: %s'%loss_w_info,file=f) 130 | if iteration % 5000 == 0: 131 | torch.save(E.state_dict(), 
resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000)) 132 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration) 133 | 134 | if __name__ == "__main__": 135 | 136 | parser = argparse.ArgumentParser(description='the training args') 137 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000 138 | parser.add_argument('--lr', type=float, default=0.0015) 139 | parser.add_argument('--beta_1', type=float, default=0.0) 140 | parser.add_argument('--batch_size', type=int, default=2) 141 | parser.add_argument('--experiment_dir', default=None) #None 142 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt 143 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it 144 | parser.add_argument('--checkpoint_dir_E', default=None) 145 | parser.add_argument('--img_size',type=int, default=1024) 146 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 147 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128 148 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4 149 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256 150 | args = parser.parse_args() 151 | 152 | if not os.path.exists('./result'): os.mkdir('./result') 153 | resultPath = args.experiment_dir 154 | if resultPath == None: 155 | resultPath = "./result/StyleGANv1-AlationStudy-w_2" 156 | if not os.path.exists(resultPath): os.mkdir(resultPath) 157 | 158 | resultPath1_1 = resultPath+"/imgs" 159 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1) 160 | 161 | resultPath1_2 = resultPath+"/models" 162 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2) 163 | 164 | writer_path = os.path.join(resultPath, './summaries') 165 | if not os.path.exists(writer_path): os.mkdir(writer_path) 166 | writer = tensorboardX.SummaryWriter(writer_path) 167 | 168 | use_gpu = True 169 | device = torch.device("cuda" if use_gpu else "cpu") 170 | 171 | train(tensor_writer=writer, args = args) 172 | -------------------------------------------------------------------------------- /ablation_utils/3.E_align_w.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import os 4 | import math 5 | import torch 6 | import torchvision 7 | import model.E.Ablation_Study.E_Blur_W as BE 8 | from model.utils.custom_adam import LREQAdam 9 | import metric.pytorch_ssim as pytorch_ssim 10 | import lpips 11 | import numpy as np 12 | import tensorboardX 13 | import argparse 14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 15 | from training_utils import * 16 | 17 | def train(tensor_writer = None, args = None): 18 | type = args.mtype 19 | 20 | model_path = args.checkpoint_dir_GAN 21 | if type == 1: # StyleGAN1 22 | #model_path = './checkpoint/stylegan_v1/ffhq1024/' 23 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 24 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth')) 25 | 26 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 
18->1024 27 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth')) 28 | 29 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt') 30 | const_ = Gs.const 31 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda() 32 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024 33 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 34 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 35 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # 18个变量前8个裁剪比例truncation_psi [0.7,0.7,...,1,1,1] 36 | 37 | Gs.cuda() 38 | Gm.eval() 39 | 40 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 41 | else: 42 | print('error') 43 | return 44 | 45 | if args.checkpoint_dir_E != None: 46 | E.load_state_dict(torch.load(args.checkpoint_dir_E)) 47 | E.cuda() 48 | writer = tensor_writer 49 | 50 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0) 51 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda') 52 | 53 | batch_size = args.batch_size 54 | it_d = 0 55 | for iteration in range(0,args.iterations): 56 | set_seed(iteration%30000) 57 | z = torch.randn(batch_size, args.z_dim) #[32, 512] 58 | 59 | if type == 1: 60 | with torch.no_grad(): #这里需要生成图片和变量 61 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512] 62 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256 63 | const2,w2 = E(imgs1) 64 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2)) 65 | else: 66 | print('model type error') 67 | return 68 | 69 | E_optimizer.zero_grad() 70 | 71 | #loss Images 72 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips) 73 | 74 | loss_msiv = loss_imgs 75 | E_optimizer.zero_grad() 76 | loss_msiv.backward(retain_graph=True) 77 | E_optimizer.step() 78 | 79 | #Latent-Vectors 80 | 81 | ## w 82 | loss_w, loss_w_info = space_loss(w1,w2,image_space = False) 83 | 84 | ## c 85 | #loss_c, loss_c_info = space_loss(const1,const2,image_space = False) 86 | 87 | loss_mslv = loss_w * 0.01 88 | E_optimizer.zero_grad() 89 | loss_mslv.backward() 90 | E_optimizer.step() 91 | 92 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000)) 93 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]') 94 | print('---------ImageSpace--------') 95 | print('loss_imgs_info: %s'%loss_imgs_info) 96 | print('---------LatentSpace--------') 97 | print('loss_w_info: %s'%loss_w_info) 98 | 99 | it_d += 1 100 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d) 101 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d) 102 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d) 103 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d) 104 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d) 105 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d) 106 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d) 107 | 108 | writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d) 109 | writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d) 110 | writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d) 111 | writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d) 
112 | writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d) 113 | writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d) 114 | writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d) 115 | 116 | writer.add_scalars('Latent Space W', {'loss_w_mse':loss_w_info[0][0],'loss_w_mse_mean':loss_w_info[0][1],'loss_w_mse_std':loss_w_info[0][2],'loss_w_kl':loss_w_info[1],'loss_w_cosine':loss_w_info[2]}, global_step=it_d) 117 | 118 | if iteration % 100 == 0: 119 | n_row = batch_size 120 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5 121 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3 122 | with open(resultPath+'/Loss.txt', 'a+') as f: 123 | print('i_'+str(iteration),file=f) 124 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f) 125 | print('---------ImageSpace--------',file=f) 126 | print('loss_imgs_info: %s'%loss_imgs_info,file=f) 127 | print('---------LatentSpace--------',file=f) 128 | print('loss_w_info: %s'%loss_w_info,file=f) 129 | if iteration % 5000 == 0: 130 | torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000)) 131 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration) 132 | 133 | if __name__ == "__main__": 134 | 135 | parser = argparse.ArgumentParser(description='the training args') 136 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000 137 | parser.add_argument('--lr', type=float, default=0.0015) 138 | parser.add_argument('--beta_1', type=float, default=0.0) 139 | parser.add_argument('--batch_size', type=int, default=2) 140 | parser.add_argument('--experiment_dir', default=None) #None 141 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt 142 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it 143 | parser.add_argument('--checkpoint_dir_E', default=None) 144 | parser.add_argument('--img_size',type=int, default=1024) 145 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 146 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. 
BIGGAN is 128 147 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4 148 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256 149 | args = parser.parse_args() 150 | 151 | if not os.path.exists('./result'): os.mkdir('./result') 152 | resultPath = args.experiment_dir 153 | if resultPath == None: 154 | resultPath = "./result/StyleGANv1-AlationStudy-w" 155 | if not os.path.exists(resultPath): os.mkdir(resultPath) 156 | 157 | resultPath1_1 = resultPath+"/imgs" 158 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1) 159 | 160 | resultPath1_2 = resultPath+"/models" 161 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2) 162 | 163 | writer_path = os.path.join(resultPath, './summaries') 164 | if not os.path.exists(writer_path): os.mkdir(writer_path) 165 | writer = tensorboardX.SummaryWriter(writer_path) 166 | 167 | use_gpu = True 168 | device = torch.device("cuda" if use_gpu else "cpu") 169 | 170 | train(tensor_writer=writer, args = args) 171 | -------------------------------------------------------------------------------- /ablation_utils/4.E_align_w_zn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import os 4 | import math 5 | import torch 6 | import torchvision 7 | import model.E.E_Blur as BE 8 | from model.utils.custom_adam import LREQAdam 9 | import metric.pytorch_ssim as pytorch_ssim 10 | import lpips 11 | import numpy as np 12 | import tensorboardX 13 | import argparse 14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 15 | from training_utils import * 16 | 17 | def train(tensor_writer = None, args = None): 18 | type = args.mtype 19 | 20 | model_path = args.checkpoint_dir_GAN 21 | if type == 1: # StyleGAN1 22 | #model_path = './checkpoint/stylegan_v1/ffhq1024/' 23 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 24 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth')) 25 | 26 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024 27 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth')) 28 | 29 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt') 30 | const_ = Gs.const 31 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda() 32 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024 33 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 34 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 35 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # 18个变量前8个裁剪比例truncation_psi [0.7,0.7,...,1,1,1] 36 | 37 | Gs.cuda() 38 | Gm.eval() 39 | 40 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 41 | else: 42 | print('error') 43 | return 44 | 45 | if args.checkpoint_dir_E != None: 46 | E.load_state_dict(torch.load(args.checkpoint_dir_E)) 47 | E.cuda() 48 | writer = tensor_writer 49 | 50 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0) 51 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda') 52 | 53 | batch_size = args.batch_size 54 | it_d = 0 55 | for iteration in 
range(0,args.iterations): 56 | set_seed(iteration%30000) 57 | z = torch.randn(batch_size, args.z_dim) #[32, 512] 58 | 59 | if type == 1: 60 | with torch.no_grad(): #这里需要生成图片和变量 61 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512] 62 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256 63 | const2,w2 = E(imgs1) 64 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2)) 65 | else: 66 | print('model type error') 67 | return 68 | 69 | E_optimizer.zero_grad() 70 | 71 | #loss Images 72 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips) 73 | 74 | loss_msiv = loss_imgs 75 | E_optimizer.zero_grad() 76 | loss_msiv.backward(retain_graph=True) 77 | E_optimizer.step() 78 | 79 | #Latent-Vectors 80 | 81 | ## w 82 | loss_w, loss_w_info = space_loss(w1,w2,image_space = False) 83 | 84 | ## c 85 | #loss_c, loss_c_info = space_loss(const1,const2,image_space = False) 86 | 87 | loss_mslv = loss_w * 0.01 88 | E_optimizer.zero_grad() 89 | loss_mslv.backward() 90 | E_optimizer.step() 91 | 92 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000)) 93 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]') 94 | print('---------ImageSpace--------') 95 | print('loss_imgs_info: %s'%loss_imgs_info) 96 | print('---------LatentSpace--------') 97 | print('loss_w_info: %s'%loss_w_info) 98 | 99 | it_d += 1 100 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d) 101 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d) 102 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d) 103 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d) 104 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d) 105 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d) 106 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d) 107 | 108 | writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d) 109 | writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d) 110 | writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d) 111 | writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d) 112 | writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d) 113 | writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d) 114 | writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d) 115 | 116 | writer.add_scalars('Latent Space W', {'loss_w_mse':loss_w_info[0][0],'loss_w_mse_mean':loss_w_info[0][1],'loss_w_mse_std':loss_w_info[0][2],'loss_w_kl':loss_w_info[1],'loss_w_cosine':loss_w_info[2]}, global_step=it_d) 117 | 118 | if iteration % 100 == 0: 119 | n_row = batch_size 120 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5 121 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3 122 | with open(resultPath+'/Loss.txt', 'a+') as f: 123 | print('i_'+str(iteration),file=f) 124 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f) 125 | print('---------ImageSpace--------',file=f) 126 | print('loss_imgs_info: %s'%loss_imgs_info,file=f) 127 | print('---------LatentSpace--------',file=f) 128 | print('loss_w_info: %s'%loss_w_info,file=f) 129 | if iteration % 5000 == 0: 130 | torch.save(E.state_dict(), 
resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000)) 131 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration) 132 | 133 | if __name__ == "__main__": 134 | 135 | parser = argparse.ArgumentParser(description='the training args') 136 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000 137 | parser.add_argument('--lr', type=float, default=0.0015) 138 | parser.add_argument('--beta_1', type=float, default=0.0) 139 | parser.add_argument('--batch_size', type=int, default=2) 140 | parser.add_argument('--experiment_dir', default=None) #None 141 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt 142 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it 143 | parser.add_argument('--checkpoint_dir_E', default=None) 144 | parser.add_argument('--img_size',type=int, default=1024) 145 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 146 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128 147 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4 148 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256 149 | args = parser.parse_args() 150 | 151 | if not os.path.exists('./result'): os.mkdir('./result') 152 | resultPath = args.experiment_dir 153 | if resultPath == None: 154 | resultPath = "./result/StyleGANv1-AlationStudy-w_zn" 155 | if not os.path.exists(resultPath): os.mkdir(resultPath) 156 | 157 | resultPath1_1 = resultPath+"/imgs" 158 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1) 159 | 160 | resultPath1_2 = resultPath+"/models" 161 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2) 162 | 163 | writer_path = os.path.join(resultPath, './summaries') 164 | if not os.path.exists(writer_path): os.mkdir(writer_path) 165 | writer = tensorboardX.SummaryWriter(writer_path) 166 | 167 | use_gpu = True 168 | device = torch.device("cuda" if use_gpu else "cpu") 169 | 170 | train(tensor_writer=writer, args = args) 171 | -------------------------------------------------------------------------------- /ablation_utils/5.E_align_w_zn_zc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import os 4 | import math 5 | import torch 6 | import torchvision 7 | import model.E.E_Blur as BE 8 | from model.utils.custom_adam import LREQAdam 9 | import metric.pytorch_ssim as pytorch_ssim 10 | import lpips 11 | import numpy as np 12 | import tensorboardX 13 | import argparse 14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 15 | from training_utils import * 16 | 17 | def train(tensor_writer = None, args = None): 18 | type = args.mtype 19 | 20 | model_path = args.checkpoint_dir_GAN 21 | if type == 1: # StyleGAN1 22 | #model_path = './checkpoint/stylegan_v1/ffhq1024/' 23 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 24 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth')) 25 | 26 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024 27 | 
Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth')) 28 | 29 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt') 30 | const_ = Gs.const 31 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda() 32 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024 33 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 34 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 35 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # 18个变量前8个裁剪比例truncation_psi [0.7,0.7,...,1,1,1] 36 | 37 | Gs.cuda() 38 | Gm.eval() 39 | 40 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 41 | 42 | else: 43 | print('error') 44 | return 45 | 46 | if args.checkpoint_dir_E != None: 47 | E.load_state_dict(torch.load(args.checkpoint_dir_E)) 48 | E.cuda() 49 | writer = tensor_writer 50 | 51 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0) 52 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda') 53 | 54 | batch_size = args.batch_size 55 | it_d = 0 56 | for iteration in range(0,args.iterations): 57 | set_seed(iteration%30000) 58 | z = torch.randn(batch_size, args.z_dim) #[32, 512] 59 | 60 | if type == 1: 61 | with torch.no_grad(): #这里需要生成图片和变量 62 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512] 63 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256 64 | const2,w2 = E(imgs1) 65 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2)) 66 | else: 67 | print('model type error') 68 | return 69 | 70 | E_optimizer.zero_grad() 71 | 72 | #loss Images 73 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips) 74 | 75 | loss_msiv = loss_imgs 76 | E_optimizer.zero_grad() 77 | loss_msiv.backward(retain_graph=True) 78 | E_optimizer.step() 79 | 80 | #Latent-Vectors 81 | 82 | ## w 83 | loss_w, loss_w_info = space_loss(w1,w2,image_space = False) 84 | 85 | ## c 86 | loss_c, loss_c_info = space_loss(const1,const2,image_space = False) 87 | 88 | loss_mslv = (loss_w + loss_c)*0.01 89 | E_optimizer.zero_grad() 90 | loss_mslv.backward() 91 | E_optimizer.step() 92 | 93 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000)) 94 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]') 95 | print('---------ImageSpace--------') 96 | print('loss_imgs_info: %s'%loss_imgs_info) 97 | print('---------LatentSpace--------') 98 | print('loss_w_info: %s'%loss_w_info) 99 | print('loss_c_info: %s'%loss_c_info) 100 | 101 | 102 | it_d += 1 103 | 104 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d) 105 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d) 106 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d) 107 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d) 108 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d) 109 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d) 110 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d) 111 | 112 | writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d) 113 | writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d) 114 | writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d) 115 | 
writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d) 116 | writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d) 117 | writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d) 118 | writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d) 119 | 120 | writer.add_scalar('loss_c_mse', loss_c_info[0][0], global_step=it_d) 121 | writer.add_scalar('loss_c_mse_mean', loss_c_info[0][1], global_step=it_d) 122 | writer.add_scalar('loss_c_mse_std', loss_c_info[0][2], global_step=it_d) 123 | writer.add_scalar('loss_c_kl', loss_c_info[1], global_step=it_d) 124 | writer.add_scalar('loss_c_cosine', loss_c_info[2], global_step=it_d) 125 | writer.add_scalar('loss_c_ssim', loss_c_info[3], global_step=it_d) 126 | writer.add_scalar('loss_c_lpips', loss_c_info[4], global_step=it_d) 127 | 128 | writer.add_scalars('Latent Space W', {'loss_w_mse':loss_w_info[0][0],'loss_w_mse_mean':loss_w_info[0][1],'loss_w_mse_std':loss_w_info[0][2],'loss_w_kl':loss_w_info[1],'loss_w_cosine':loss_w_info[2]}, global_step=it_d) 129 | writer.add_scalars('Latent Space C', {'loss_c_mse':loss_c_info[0][0],'loss_c_mse_mean':loss_c_info[0][1],'loss_c_mse_std':loss_c_info[0][2],'loss_c_kl':loss_c_info[1],'loss_c_cosine':loss_c_info[2]}, global_step=it_d) 130 | 131 | 132 | if iteration % 100 == 0: 133 | n_row = batch_size 134 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5 135 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3 136 | with open(resultPath+'/Loss.txt', 'a+') as f: 137 | print('i_'+str(iteration),file=f) 138 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f) 139 | print('---------ImageSpace--------',file=f) 140 | print('loss_imgs_info: %s'%loss_imgs_info,file=f) 141 | print('---------LatentSpace--------',file=f) 142 | print('loss_w_info: %s'%loss_w_info,file=f) 143 | print('loss_c_info: %s'%loss_c_info,file=f) 144 | if iteration % 5000 == 0: 145 | torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000)) 146 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration) 147 | 148 | if __name__ == "__main__": 149 | 150 | parser = argparse.ArgumentParser(description='the training args') 151 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000 152 | parser.add_argument('--lr', type=float, default=0.0015) 153 | parser.add_argument('--beta_1', type=float, default=0.0) 154 | parser.add_argument('--batch_size', type=int, default=2) 155 | parser.add_argument('--experiment_dir', default=None) #None 156 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt 157 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it 158 | parser.add_argument('--checkpoint_dir_E', default=None) 159 | parser.add_argument('--img_size',type=int, default=1024) 160 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 161 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. 
BIGGAN is 128 162 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4 163 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256 164 | args = parser.parse_args() 165 | 166 | if not os.path.exists('./result'): os.mkdir('./result') 167 | resultPath = args.experiment_dir 168 | if resultPath == None: 169 | resultPath = "./result/StyleGANv1-AlationStudy-w_zn_zc" 170 | if not os.path.exists(resultPath): os.mkdir(resultPath) 171 | 172 | resultPath1_1 = resultPath+"/imgs" 173 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1) 174 | 175 | resultPath1_2 = resultPath+"/models" 176 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2) 177 | 178 | writer_path = os.path.join(resultPath, './summaries') 179 | if not os.path.exists(writer_path): os.mkdir(writer_path) 180 | writer = tensorboardX.SummaryWriter(writer_path) 181 | 182 | use_gpu = True 183 | device = torch.device("cuda" if use_gpu else "cpu") 184 | 185 | train(tensor_writer=writer, args = args) 186 | -------------------------------------------------------------------------------- /baseline_utils/test-baseline-IndomainG.py: -------------------------------------------------------------------------------- 1 | # python 3.6 2 | """Inverts given images to latent codes with In-Domain GAN Inversion. 3 | 4 | Basically, for a particular image (real or synthesized), this script first 5 | employs the domain-guided encoder to produce a initial point in the latent 6 | space and then performs domain-regularized optimization to refine the latent 7 | code. 8 | """ 9 | 10 | # invert.py 11 | # python test-baseline-indomainG.py 'styleganinv_ffhq256' './styleganv1-generations-512/' 12 | import os 13 | import argparse 14 | from tqdm import tqdm 15 | import numpy as np 16 | 17 | from utils.inverter import StyleGANInverter 18 | from utils.logger import setup_logger 19 | from utils.visualizer import HtmlPageVisualizer 20 | from utils.visualizer import save_image, load_image, resize_image 21 | 22 | import torch 23 | import torchvision 24 | 25 | def parse_args(): 26 | """Parses arguments.""" 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('model_name', type=str, help='Name of the GAN model.') 29 | parser.add_argument('image_list', type=str, 30 | help='List of images to invert.') 31 | parser.add_argument('-o', '--output_dir', type=str, default='', 32 | help='Directory to save the results. If not specified, ' 33 | '`./results/inversion/${IMAGE_LIST}` ' 34 | 'will be used by default.') 35 | parser.add_argument('--learning_rate', type=float, default=0.01, 36 | help='Learning rate for optimization. (default: 0.01)') 37 | parser.add_argument('--num_iterations', type=int, default=100, 38 | help='Number of optimization iterations. (default: 100)') 39 | parser.add_argument('--num_results', type=int, default=5, 40 | help='Number of intermediate optimization results to ' 41 | 'save for each sample. (default: 5)') 42 | parser.add_argument('--loss_weight_feat', type=float, default=5e-5, 43 | help='The perceptual loss scale for optimization. ' 44 | '(default: 5e-5)') 45 | parser.add_argument('--loss_weight_enc', type=float, default=2.0, 46 | help='The encoder loss scale for optimization.' 47 | '(default: 2.0)') 48 | parser.add_argument('--viz_size', type=int, default=256, 49 | help='Image size for visualization. (default: 256)') 50 | parser.add_argument('--gpu_id', type=str, default='0', 51 | help='Which GPU(s) to use. 
(default: `0`)') 52 | return parser.parse_args() 53 | 54 | 55 | def main(): 56 | """Main function.""" 57 | args = parse_args() 58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 59 | assert os.path.exists(args.image_list) 60 | image_list_name = os.path.splitext(os.path.basename(args.image_list))[0] 61 | output_dir = args.output_dir or f'results/inversion/{image_list_name}' 62 | logger = setup_logger(output_dir, 'inversion.log', 'inversion_logger') 63 | 64 | logger.info(f'Loading model.') 65 | inverter = StyleGANInverter( 66 | args.model_name, 67 | learning_rate=args.learning_rate, 68 | iteration=args.num_iterations, 69 | reconstruction_loss_weight=1.0, 70 | perceptual_loss_weight=args.loss_weight_feat, 71 | regularization_loss_weight=args.loss_weight_enc, 72 | logger=logger) 73 | image_size = inverter.G.resolution 74 | 75 | # Load image list. 76 | # logger.info(f'Loading image list.') 77 | # image_list = [] 78 | # with open(args.image_list, 'r') as f: 79 | # for line in f: 80 | # image_list.append(line.strip()) 81 | 82 | path = args.image_list # 文件路径 83 | paths = list(os.listdir(path)) 84 | image_list = [] 85 | for i in paths: 86 | image_list.append(path + '/' + i) 87 | 88 | 89 | # Initialize visualizer. 90 | save_interval = args.num_iterations // args.num_results 91 | headers = ['Name', 'Original Image', 'Encoder Output'] 92 | for step in range(1, args.num_iterations + 1): 93 | if step == args.num_iterations or step % save_interval == 0: 94 | headers.append(f'Step {step:06d}') 95 | viz_size = None if args.viz_size == 0 else args.viz_size 96 | visualizer = HtmlPageVisualizer( 97 | num_rows=len(image_list), num_cols=len(headers), viz_size=viz_size) 98 | visualizer.set_headers(headers) 99 | 100 | # Invert images. 101 | logger.info(f'Start inversion.') 102 | latent_codes = [] 103 | codes_rec = [] 104 | codes_inv = [] 105 | for img_idx in tqdm(range(len(image_list)), leave=False): 106 | image_path = image_list[img_idx] 107 | image_name = os.path.splitext(os.path.basename(image_path))[0] 108 | image = resize_image(load_image(image_path), (image_size, image_size)) 109 | code, viz_results = inverter.easy_invert(image, num_viz=args.num_results) 110 | latent_codes.append(code) 111 | #save_image(f'{output_dir}/{image_name}_ori.png', image) 112 | save_image(f'{output_dir}/r/{image_name}_enc.png', viz_results[1]) 113 | save_image(f'{output_dir}/i/{image_name}_inv.png', viz_results[-1]) 114 | visualizer.set_cell(img_idx, 0, text=image_name) 115 | visualizer.set_cell(img_idx, 1, image=image) 116 | codes_rec.append(torch.tensor(viz_results[1],dtype=float)/255.0) # cv2->pytorch 117 | codes_inv.append(torch.tensor(viz_results[-1],dtype=float)/255.0) 118 | for viz_idx, viz_img in enumerate(viz_results[1:]): 119 | visualizer.set_cell(img_idx, viz_idx + 2, image=viz_img) 120 | 121 | # Save results. 
122 | #os.system(f'cp {args.image_list} {output_dir}/image_list.txt') 123 | np.save(f'{output_dir}/inverted_codes.npy', np.concatenate(latent_codes, axis=0)) 124 | visualizer.save(f'{output_dir}/inversion.html') 125 | 126 | # codes_rec_tensor_ = torch.stack(codes_rec, dim=0) 127 | # codes_rec_tensor = codes_rec_tensor_.permute(0,3,1,2) # cv2->pytorch 128 | # codes_inv_tensor_ = torch.stack(codes_inv, dim=0) 129 | # codes_inv_tensor = codes_inv_tensor_.permute(0,3,1,2) 130 | # torch.save(codes_rec_tensor,'./indomain-rec32.pt') 131 | # torch.save(codes_inv_tensor,'./indomain-inv32.pt') 132 | # torchvision.utils.save_image(codes_rec_tensor,'./indomain_images_rec.png',nrow=5) 133 | # torchvision.utils.save_image(codes_inv_tensor,'./indomain_images_inv.png',nrow=5) 134 | 135 | if __name__ == '__main__': 136 | main() 137 | -------------------------------------------------------------------------------- /baseline_utils/test_baseline_alae.py: -------------------------------------------------------------------------------- 1 | #test_baseline_alae.py 2 | 3 | 4 | import sys 5 | sys.path.append(".") 6 | sys.path.append("..") 7 | 8 | import torch.utils.data 9 | import torchvision 10 | import random 11 | from net import * 12 | from model import Model 13 | from launcher import run 14 | from checkpointer import Checkpointer 15 | from dlutils.pytorch import count_parameters 16 | from defaults import get_cfg_defaults 17 | import lreq 18 | from skimage.transform import resize 19 | from PIL import Image 20 | import logging 21 | 22 | config_file='./configs/ffhq.yaml' 23 | cfg=get_cfg_defaults() 24 | cfg.merge_from_file(config_file) 25 | cfg.freeze() 26 | 27 | torch.cuda.set_device(0) 28 | model = Model( 29 | startf=cfg.MODEL.START_CHANNEL_COUNT, 30 | layer_count=cfg.MODEL.LAYER_COUNT, 31 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT, 32 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE, 33 | truncation_psi=cfg.MODEL.TRUNCATIOM_PSI, 34 | truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF, 35 | mapping_layers=cfg.MODEL.MAPPING_LAYERS, 36 | channels=cfg.MODEL.CHANNELS, 37 | generator=cfg.MODEL.GENERATOR, 38 | encoder=cfg.MODEL.ENCODER) 39 | model.cuda(0) 40 | model.eval() 41 | model.requires_grad_(False) 42 | 43 | decoder = model.decoder 44 | encoder = model.encoder 45 | mapping_tl = model.mapping_d 46 | mapping_fl = model.mapping_f 47 | dlatent_avg = model.dlatent_avg 48 | 49 | logger = logging.getLogger("logger") 50 | logger.setLevel(logging.DEBUG) 51 | 52 | logger.info("Trainable parameters generator:") 53 | count_parameters(decoder) 54 | 55 | logger.info("Trainable parameters discriminator:") 56 | count_parameters(encoder) 57 | 58 | arguments = dict() 59 | arguments["iteration"] = 0 60 | 61 | model_dict = { 62 | 'discriminator_s': encoder, 63 | 'generator_s': decoder, 64 | 'mapping_tl_s': mapping_tl, 65 | 'mapping_fl_s': mapping_fl, 66 | 'dlatent_avg': dlatent_avg 67 | } 68 | 69 | checkpointer = Checkpointer(cfg, 70 | model_dict, 71 | {}, 72 | logger=logger, 73 | save=False) 74 | 75 | extra_checkpoint_data = checkpointer.load() 76 | 77 | model.eval() 78 | 79 | layer_count = cfg.MODEL.LAYER_COUNT 80 | im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1) 81 | print(im_size) 82 | 83 | def encode(x): 84 | Z, _ = model.encode(x, layer_count - 1, 1) 85 | Z = Z.repeat(1, model.mapping_f.num_layers, 1) 86 | return Z 87 | 88 | def decode(x): 89 | layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis] 90 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) 91 | coefs = torch.where(layer_idx < model.truncation_cutoff, 1.0 * 
ones, ones) 92 | # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs) 93 | return model.decoder(x, layer_count - 1, 1, noise=True) 94 | 95 | #传入路径,输出图片及其重构 96 | def make(paths): 97 | src = [] 98 | for i,filename in enumerate(paths): 99 | img = Image.open(path + '/' + filename) 100 | img = img.resize((im_size,im_size)) 101 | img = np.asarray(img) 102 | print(i,img.shape) 103 | if img.shape[2] == 4: 104 | img = img[:, :, :3] 105 | im = img.transpose((2, 0, 1)) 106 | x = torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1. 107 | if x.shape[0] == 4: 108 | x = x[:3] 109 | factor = x.shape[2] // im_size 110 | if factor != 1: 111 | x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0] 112 | assert x.shape[2] == im_size 113 | src.append(x) 114 | 115 | with torch.no_grad(): 116 | reconstructions = [] 117 | for s in src: 118 | latents = encode(s[None, ...]) 119 | reconstructions.append(decode(latents).cpu().detach().numpy()) 120 | # print(len(src)) 121 | # print(src[0].shape) 122 | # print(reconstructions[0].shape) 123 | return src, reconstructions 124 | 125 | 126 | path = './styleganv1-generations' 127 | paths = list(os.listdir(path)) 128 | 129 | # paths = paths[:256] 130 | # chuncker_id = 0 # 0, 1, 2, 3 131 | 132 | paths = paths[256:] 133 | chuncker_id = 1 # 0, 1, 2, 3 134 | 135 | src0, rec0 = make(paths) 136 | src0 = [torch.tensor(array_np) for array_np in src0] 137 | rec0 = [torch.tensor(array_np[0]) for array_np in rec0] 138 | batched_images1 = torch.stack(src0, dim=0) 139 | batched_images2 = torch.stack(rec0, dim=0) 140 | print(batched_images1.shape) 141 | print(batched_images2.shape) 142 | 143 | for i,j in enumerate(batched_images2): 144 | j = j.unsqueeze(0) 145 | torchvision.utils.save_image(j*0.5+0.5,'./%s_rec.png'%str(i+256*chuncker_id).rjust(3,'0')) 146 | 147 | #torch.save(batched_images1*0.5+0.5,'./ALAE-real-30.pt') 148 | #torch.save(batched_images2*0.5+0.5,'./ALAE-rec-30.pt') -------------------------------------------------------------------------------- /baseline_utils/test_baseline_psp.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import time 3 | import sys 4 | import os 5 | import pprint 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | 15 | 16 | from datasets import augmentations 17 | from utils.common import tensor2im, log_input_image 18 | from models.psp import pSp 19 | 20 | 21 | ## 处理编码图像 22 | save_path = "./real-128imgs" 23 | 24 | ## 配置 25 | experiment_type = 'ffhq_encode' 26 | #@param ['ffhq_encode', 'ffhq_frontalize', 'celebs_sketch_to_face', 'celebs_seg_to_face', 'celebs_super_resolution', 'toonify'] 27 | 28 | EXPERIMENT_DATA_ARGS = { 29 | "ffhq_encode": { 30 | "model_path": "pretrained_models/psp_ffhq_encode.pt", 31 | "image_path": "notebooks/images/input_img.jpg", 32 | "transform": transforms.Compose([ 33 | transforms.Resize((256, 256)), 34 | transforms.ToTensor(), 35 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 36 | }, 37 | "ffhq_frontalize": { 38 | "model_path": "pretrained_models/psp_ffhq_frontalization.pt", 39 | "image_path": "notebooks/images/input_img.jpg", 40 | "transform": transforms.Compose([ 41 | transforms.Resize((256, 256)), 42 | transforms.ToTensor(), 43 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 44 | }, 45 | "celebs_sketch_to_face": { 46 | "model_path": 
"pretrained_models/psp_celebs_sketch_to_face.pt", 47 | "image_path": "notebooks/images/input_sketch.jpg", 48 | "transform": transforms.Compose([ 49 | transforms.Resize((256, 256)), 50 | transforms.ToTensor()]) 51 | }, 52 | "celebs_seg_to_face": { 53 | "model_path": "pretrained_models/psp_celebs_seg_to_face.pt", 54 | "image_path": "notebooks/images/input_mask.png", 55 | "transform": transforms.Compose([ 56 | transforms.Resize((256, 256)), 57 | augmentations.ToOneHot(n_classes=19), 58 | transforms.ToTensor()]) 59 | }, 60 | "celebs_super_resolution": { 61 | "model_path": "pretrained_models/psp_celebs_super_resolution.pt", 62 | "image_path": "notebooks/images/input_img.jpg", 63 | "transform": transforms.Compose([ 64 | transforms.Resize((256, 256)), 65 | augmentations.BilinearResize(factors=[16]), 66 | transforms.Resize((256, 256)), 67 | transforms.ToTensor(), 68 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 69 | }, 70 | "toonify": { 71 | "model_path": "pretrained_models/psp_ffhq_toonify.pt", 72 | "image_path": "notebooks/images/input_img.jpg", 73 | "transform": transforms.Compose([ 74 | transforms.Resize((256, 256)), 75 | transforms.ToTensor(), 76 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 77 | }, 78 | } 79 | 80 | EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type] 81 | 82 | model_path = EXPERIMENT_ARGS['model_path'] 83 | ckpt = torch.load(model_path, map_location='cpu') 84 | 85 | opts = ckpt['opts'] 86 | # pprint.pprint(opts) 87 | 88 | # update the training options 89 | opts['checkpoint_path'] = model_path 90 | if 'learn_in_w' not in opts: 91 | opts['learn_in_w'] = False 92 | if 'output_size' not in opts: 93 | opts['output_size'] = 1024 94 | 95 | 96 | opts = Namespace(**opts) 97 | net = pSp(opts) 98 | net.eval() 99 | device = 'cuda' # 'cuda' 100 | net.cuda() 101 | print('Model successfully loaded!') 102 | 103 | import matplotlib; matplotlib.use('TkAgg') 104 | import matplotlib.pyplot as plt 105 | 106 | def run_on_batch(inputs, net, latent_mask=None): 107 | if latent_mask is None: 108 | result_batch = net(inputs.to(device).float(), randomize_noise=False) 109 | else: 110 | result_batch = [] 111 | for image_idx, input_image in enumerate(inputs): 112 | # get latent vector to inject into our input image 113 | vec_to_inject = np.random.randn(1, 512).astype('float32') 114 | _, latent_to_inject = net(torch.from_numpy(vec_to_inject).to("cuda"), 115 | input_code=True, 116 | return_latents=True) 117 | # get output image with injected style vector 118 | res = net(input_image.unsqueeze(0).to(device).float(), 119 | latent_mask=latent_mask, 120 | inject_latent=latent_to_inject) 121 | result_batch.append(res) 122 | result_batch = torch.cat(result_batch, dim=0) 123 | return result_batch 124 | 125 | image_paths = [os.path.join(save_path, f) for f in os.listdir(save_path) if f.endswith(".png") or f.endswith(".jpg")] 126 | n_images = len(image_paths) 127 | 128 | images = [] 129 | n_cols = np.ceil(n_images / 2) 130 | #fig = plt.figure(figsize=(20, 4)) 131 | for idx, image_path in enumerate(image_paths): 132 | #ax = fig.add_subplot(2, n_cols, idx + 1) 133 | img = Image.open(image_path).convert("RGB") 134 | images.append(img) 135 | # ax.grid(False) 136 | # ax.set_xticks([]) 137 | # ax.set_yticks([]) 138 | # ax.imshow(img) 139 | #plt.show() 140 | 141 | img_transforms = EXPERIMENT_ARGS['transform'] 142 | transformed_images = [img_transforms(image) for image in images] 143 | 144 | batched_images = torch.stack(transformed_images, dim=0) 145 | 146 | #batched_images = 
torch.load('real-30.pt') 147 | 148 | with torch.no_grad(): 149 | tic = time.time() 150 | for i,j in enumerate(batched_images): 151 | j = j.unsqueeze(0) 152 | result_images = run_on_batch(j, net, latent_mask=None) 153 | toc = time.time() 154 | print('Inference took {:.4f} seconds.'.format(toc - tic)) 155 | torchvision.utils.save_image(result_images*0.5+0.5,'./r/%s_pSp_Rec.jpg'%str(i).rjust(5,'0')) 156 | 157 | # from IPython.display import display 158 | 159 | # couple_results = [] 160 | # for original_image, result_image in zip(images, result_images): 161 | # result_image = tensor2im(result_image) 162 | # res = np.concatenate([np.array(original_image.resize((256, 256))), 163 | # np.array(result_image.resize((256, 256)))], axis=1) 164 | # res_im = Image.fromarray(res) 165 | # couple_results.append(res_im) 166 | # display(res_im) 167 | # #import matplotlib.pyplot as plt 168 | # #img = plt.imread('1.jpg')#读取图片 169 | # plt.imshow(res_im)#展示图片 170 | # plt.show() 171 | # result_images = result_images*0.5+0.5 172 | # torch.save(result_images,'./psp_w30.pt') 173 | # torchvision.utils.save_image(result_images,'./batched_images2.png',nrow=5) 174 | 175 | -------------------------------------------------------------------------------- /comparing-baseline.py: -------------------------------------------------------------------------------- 1 | # testing 2 | 3 | import os 4 | import skimage 5 | import lpips 6 | import torch 7 | import torchvision 8 | from PIL import Image 9 | 10 | 11 | save_path1 = './styleganv1-generations/' 12 | save_path2 = './MTV-rec/' 13 | 14 | loss_mse = torch.nn.MSELoss() 15 | loss_lpips = lpips.LPIPS(net='vgg') 16 | 17 | def cosineSimilarty(imgs1_cos,imgs2_cos): 18 | values = imgs1_cos.dot(imgs2_cos)/(torch.sqrt(imgs1_cos.dot(imgs1_cos))*torch.sqrt(imgs2_cos.dot(imgs2_cos))) # [0,1] 19 | return values 20 | 21 | def metrics(img_tensor1,img_tensor2): 22 | 23 | psnr = skimage.measure.compare_psnr(img_tensor1.float().numpy().transpose(1,2,0), img_tensor2.float().numpy().transpose(1,2,0), 255) #range:[0,255] 24 | 25 | ssim = skimage.measure.compare_ssim(img_tensor1.float().numpy().transpose(1,2,0), img_tensor2.float().numpy().transpose(1,2,0), data_range=255, multichannel=True)#[h,w,c] and range:[0,255] 26 | 27 | mse_value = loss_mse(img_tensor1,img_tensor2).numpy() #range:[0,255] 28 | 29 | lpips_value = loss_lpips(img_tensor1.unsqueeze(0)/255.0*2-1,img_tensor2.unsqueeze(0)/255.0*2-1).mean().detach().numpy() #range:[-1,1] 30 | 31 | cosine_value = cosineSimilarty(img_tensor1.view(-1)/255.0*2-1,img_tensor2.view(-1)/255.0*2-1).numpy() #range:[-1,1] 32 | 33 | print('-------------') 34 | print('psnr:',psnr) 35 | print('-------------') 36 | print('ssim:',ssim) 37 | print('-------------') 38 | print('lpips:',lpips_value) 39 | print('-------------') 40 | print('mse:',mse_value) 41 | print('-------------') 42 | print('cosine:',cosine_value) 43 | 44 | return psnr, ssim, lpips_value, mse_value, cosine_value 45 | 46 | #--------文件夹内的图片转换为tensor:[n,c,h,w]------------------ 47 | img_size = 512 48 | #PIL 2 Tensor 49 | transform = torchvision.transforms.Compose([ 50 | #torchvision.transforms.CenterCrop(160), 51 | torchvision.transforms.Resize((img_size,img_size)), 52 | torchvision.transforms.ToTensor(), 53 | #torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 54 | ]) 55 | 56 | #-----------------------------------------metric imgs_tensor------------------------------- 57 | n = 0 58 | psnr_values = 0 59 | ssim_values = 0 60 | lpips_values = 0 61 | mse_values = 0 62 | cosine_values 
= 0 63 | 64 | imgs_path1 = [os.path.join(save_path1, f) for f in os.listdir(save_path1) if f.endswith(".png") or f.endswith(".jpg")] 65 | imgs_path2 = [os.path.join(save_path2, f) for f in os.listdir(save_path2) if f.endswith(".png") or f.endswith(".jpg")] 66 | 67 | for i,j in zip(imgs_path1,imgs_path2): 68 | print(i, 'vs.', j) 69 | img1 = Image.open(i).convert("RGB") 70 | img1 = transform(img1) 71 | img2 = Image.open(j).convert("RGB") 72 | img2 = transform(img2) 73 | 74 | img1 = img1*255.0 75 | img2 = img2*255.0 76 | 77 | n = n + 1 78 | 79 | psnr, ssim, lpips_value, mse_value, cosine_value = metrics(img1,img2) 80 | psnr_values +=psnr 81 | ssim_values +=ssim 82 | lpips_values +=lpips_value 83 | mse_values +=mse_value 84 | cosine_values +=cosine_value 85 | 86 | print('img_num:%d--psnr:%f--ssim:%f--mse_value:%f--lpips_value:%f--cosine_value:%f'\ 87 | %(n,psnr_values/n, ssim_values/n, mse_values/n, lpips_values/n, cosine_values/n)) 88 | # if imgs_tensor1 = imgs_tensor2: -psnr: inf or 88.132626(with 1e-3) --ssim:1.000000--lpips_value:0.000000--mse_value:0.000000--cosine_value:1.000001 89 | 90 | # imgs_path1 = [os.path.join(save_path1, f) for f in os.listdir(save_path1) if f.endswith(".png")] 91 | # images1 = [] 92 | # for idx, image_path in enumerate(imgs_path1): 93 | # print(image_path) 94 | # img = Image.open(image_path).convert("RGB") 95 | # img = transform(img) 96 | # images1.append(img) 97 | # imgs_tensor1 = torch.stack(images1, dim=0) 98 | 99 | 100 | # imgs_path2 = [os.path.join(save_path2, f) for f in os.listdir(save_path2) if f.endswith(".png")] 101 | # images2 = [] 102 | # for idx, image_path in enumerate(imgs_path2): 103 | # print(image_path) 104 | # img = Image.open(image_path).convert("RGB") 105 | # img = transform(img) 106 | # images2.append(img) 107 | # imgs_tensor2 = torch.stack(images2, dim=0) 108 | 109 | # if len(imgs_tensor1) != len(imgs_tensor2): 110 | # print('error: 2 comparing numbers are not equal!') 111 | 112 | -------------------------------------------------------------------------------- /embeded_img_edit.py: -------------------------------------------------------------------------------- 1 | #Embedded_ImageProcessing, Just for StyleGAN_v1 FFHQ 2 | 3 | import numpy as np 4 | import math 5 | import torch 6 | import torchvision 7 | import model.E.E_Blur as BE 8 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 9 | 10 | #Params 11 | use_gpu = False 12 | device = torch.device("cuda" if use_gpu else "cpu") 13 | img_size = 1024 14 | GAN_path = './checkpoint/stylegan_v1/ffhq1024/' 15 | direction = 'eyeglasses' #smile, eyeglasses, pose, age, gender 16 | direction_path = './latentvectors/directions/stylegan_ffhq_%s_w_boundary.npy'%direction 17 | w_path = './latentvectors/faces/i3_cxx2.pt' 18 | 19 | #Loading Pre-trained Model, Directions 20 | Gs = Generator(startf=16, maxf=512, layer_count=int(math.log(img_size,2)-1), latent_size=512, channels=3) 21 | Gs.load_state_dict(torch.load(GAN_path+'Gs_dict.pth', map_location=device)) 22 | 23 | # E = BE.BE() 24 | # E.load_state_dict(torch.load('./checkpoint/E/styleganv1.pth',map_location=torch.device('cpu'))) 25 | 26 | direction = np.load(direction_path) #[[1, 512] interfaceGAN 27 | direction = torch.tensor(direction).float() 28 | direction = direction.expand(18,512) 29 | print(direction.shape) 30 | 31 | w = torch.load(w_path, map_location=device).clone().squeeze(0) 32 | print(w.shape) 33 | 34 | # discovering face semantic attribute dirrections 35 | bonus= 70 #bonus (-10) <- (-5) <- 0 ->5 ->10 36 | start= 0 # default 0, 
if not 0, it will give bad performance 37 | end= 3 # default 3 or 4. if 3, it will keep face features (glasses). if 4, it will keep direction features (Smile). 38 | w[start:start+end] = (w+bonus*direction)[start:start+end] 39 | #w = w + bonus*direction 40 | w = w.reshape(1,18,512) 41 | with torch.no_grad(): 42 | img = Gs.forward(w,8) # 8->1024 43 | torchvision.utils.save_image(img*0.5+0.5, './img_bonus%d_start%d_end%d.png'%(bonus,start,end)) 44 | 45 | ## end=3 keeps the person's identity features dominant, end=4 makes the direction's features dominant, end>4 causes severe spatial entanglement 46 | #smile: bonus*100, start=0, end=4 (end below 4 has little effect; the larger end or bonus is, the more aggressive the edit) 47 | #glass: bonus*200, start=0, end=4 (end above 6 starts to break down; bonus should not be too large either) 48 | #pose: bonus*5-10, start=0, end=3 -------------------------------------------------------------------------------- /image_results/baseline/ALAE-rec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/ALAE-rec.jpg -------------------------------------------------------------------------------- /image_results/baseline/Real.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/Real.jpg -------------------------------------------------------------------------------- /image_results/baseline/indomain_images_inv.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/indomain_images_inv.jpg -------------------------------------------------------------------------------- /image_results/baseline/indomain_images_rec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/indomain_images_rec.jpg -------------------------------------------------------------------------------- /image_results/baseline/mtv-tsa-(1500E).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/mtv-tsa-(1500E).jpg -------------------------------------------------------------------------------- /image_results/baseline/psp-rec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/psp-rec.jpg -------------------------------------------------------------------------------- /image_results/cxx1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/cxx1.gif -------------------------------------------------------------------------------- /image_results/cxx2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/cxx2.gif -------------------------------------------------------------------------------- /image_results/dy.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/dy.gif -------------------------------------------------------------------------------- /image_results/msk.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/msk.gif -------------------------------------------------------------------------------- /image_results/zy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/zy.gif -------------------------------------------------------------------------------- /latent_code/directions/stylegan_ffhq_age_w_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_age_w_boundary.npy -------------------------------------------------------------------------------- /latent_code/directions/stylegan_ffhq_eyeglasses_w_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_eyeglasses_w_boundary.npy -------------------------------------------------------------------------------- /latent_code/directions/stylegan_ffhq_gender_w_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_gender_w_boundary.npy -------------------------------------------------------------------------------- /latent_code/directions/stylegan_ffhq_pose_w_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_pose_w_boundary.npy -------------------------------------------------------------------------------- /latent_code/directions/stylegan_ffhq_smile_w_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_smile_w_boundary.npy -------------------------------------------------------------------------------- /latent_code/real_face_code/i0_cxx1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i0_cxx1.pt -------------------------------------------------------------------------------- /latent_code/real_face_code/i1_dy.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i1_dy.pt -------------------------------------------------------------------------------- /latent_code/real_face_code/i2_zy.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i2_zy.pt -------------------------------------------------------------------------------- /latent_code/real_face_code/i3_cxx2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i3_cxx2.pt -------------------------------------------------------------------------------- /latent_code/real_face_code/i4_msk.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i4_msk.pt -------------------------------------------------------------------------------- /latent_code/real_face_code/i5_ty.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i5_ty.pt -------------------------------------------------------------------------------- /metric/pytorch_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 21 | 22 | mu1_sq = mu1.pow(2) 23 | mu2_sq = mu2.pow(2) 24 | mu1_mu2 = mu1*mu2 25 | 26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 29 | 30 | C1 = 0.01**2 31 | C2 = 0.03**2 32 | 33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | return ssim_map.mean(1).mean(1).mean(1) 39 | 40 | class SSIM(torch.nn.Module): 41 | def __init__(self, window_size = 11, size_average = True): 42 | super(SSIM, self).__init__() 43 | self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = 
channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, size_average) 75 | -------------------------------------------------------------------------------- /model/E/Ablation_Study/E_Blur_W.py: -------------------------------------------------------------------------------- 1 | # For ablation study: w 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.nn.parameter import Parameter 7 | import sys 8 | #sys.path.append('../') 9 | import model.utils.lreq as ln 10 | from model.utils.net import Blur,FromRGB,downscale2d 11 | from torch.nn import functional as F 12 | 13 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 14 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 15 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 16 | 17 | class BEBlock(nn.Module): 18 | def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): #分辨率大于128用fused_scale,即conv完成上采样 19 | super().__init__() 20 | self.has_last_conv = has_last_conv 21 | # self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 22 | # self.noise_weight_1.data.zero_() 23 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 24 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 25 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512] 26 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 27 | 28 | # self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 29 | # self.noise_weight_2.data.zero_() 30 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 31 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 32 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1) 33 | self.blur = Blur(inputs) 34 | if has_last_conv: 35 | if fused_scale: 36 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) 37 | else: 38 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 39 | self.fused_scale = fused_scale 40 | 41 | self.inputs = inputs 42 | self.outputs = outputs 43 | 44 | if self.inputs != self.outputs: 45 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0) 46 | 47 | with torch.no_grad(): 48 | self.bias_1.zero_() 49 | self.bias_2.zero_() 50 | 51 | def forward(self, x): 52 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 53 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 54 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1] 55 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512] 56 | 57 | residual = x 58 | 59 | x = self.instance_norm_1(x) 60 | x = self.conv_1(x) 61 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 62 | x = x + self.bias_1 63 | x = F.leaky_relu(x, 0.2) 64 | 65 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 66 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 67 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1] 68 | w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512] , 
这里style2.view一直写错成style1.view 69 | 70 | x = self.instance_norm_2(x) 71 | if self.has_last_conv: 72 | x = self.blur(x) 73 | x = self.conv_2(x) 74 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 75 | x = x + self.bias_2 76 | x = F.leaky_relu(x, 0.2) 77 | if not self.fused_scale: #在新的一层起初 fused_scale = flase, 完成上采样 78 | x = downscale2d(x) 79 | residual = downscale2d(residual) 80 | 81 | 82 | if self.inputs != self.outputs: 83 | residual = self.conv_3(residual) 84 | 85 | x = 0.111*x+0.889*residual #降低x的比例,可以将const的loss缩小!!0.7*residual: 10-11 >> 7 同时 c_s的loss扩大至3, ws的抖动提前, 效果更好 86 | return x, w1, w2 87 | 88 | 89 | class BE(nn.Module): 90 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3): 91 | super().__init__() 92 | self.maxf = maxf 93 | self.startf = startf 94 | self.latent_size = latent_size 95 | self.layer_to_resolution = [0 for _ in range(layer_count)] 96 | self.decode_block = nn.ModuleList() 97 | self.layer_count = layer_count 98 | inputs = startf # 16 99 | outputs = startf*2 100 | resolution = 1024 101 | self.FromRGB = FromRGB(channels, inputs) 102 | #from_RGB = nn.ModuleList() 103 | for i in range(layer_count): 104 | 105 | has_last_conv = i+1 != layer_count 106 | fused_scale = resolution >= 128 # 在新的一层起初 fused_scale = flase, 完成上采样 107 | 108 | #from_RGB.append(FromRGB(channels, inputs)) 109 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale) 110 | 111 | inputs = inputs*2 112 | outputs = outputs*2 113 | inputs = min(maxf, inputs) 114 | outputs = min(maxf, outputs) 115 | self.layer_to_resolution[i] = resolution 116 | resolution /=2 117 | self.decode_block.append(block) 118 | 119 | #self.FromRGB = from_RGB 120 | 121 | #将w逆序,以保证和G的w顺序, block_num控制progressive 122 | def forward(self, x, block_num=9): 123 | #x = self.FromRGB[9-block_num](x) #不是progressive,去除多余的FromRGB 124 | x = self.FromRGB(x) 125 | #print(x.shape) 126 | w = torch.tensor(0) 127 | for i in range(9-block_num,self.layer_count): 128 | x,w1,w2 = self.decode_block[i](x) 129 | #print(x.shape) 130 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512] 131 | if i == (9-block_num): 132 | w = w_ # [b,n,512] 133 | else: 134 | w = torch.cat((w_,w),dim=1) 135 | #print(w.shape) 136 | return x, w -------------------------------------------------------------------------------- /model/E/Ablation_Study/E_Blur_W_2.py: -------------------------------------------------------------------------------- 1 | # For ablation study: w/2 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.nn.parameter import Parameter 7 | import sys 8 | #sys.path.append('../') 9 | import model.utils.lreq as ln 10 | from model.utils.net import Blur,FromRGB,downscale2d 11 | from torch.nn import functional as F 12 | 13 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 14 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 15 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 16 | 17 | class BEBlock(nn.Module): 18 | def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): #分辨率大于128用fused_scale,即conv完成上采样 19 | super().__init__() 20 | self.has_last_conv = has_last_conv 21 | # self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 22 | # self.noise_weight_1.data.zero_() 23 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 24 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 25 | 
self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512] 26 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 27 | 28 | # self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 29 | # self.noise_weight_2.data.zero_() 30 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 31 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 32 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1) 33 | self.blur = Blur(inputs) 34 | if has_last_conv: 35 | if fused_scale: 36 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) 37 | else: 38 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 39 | self.fused_scale = fused_scale 40 | 41 | self.inputs = inputs 42 | self.outputs = outputs 43 | 44 | if self.inputs != self.outputs: 45 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0) 46 | 47 | with torch.no_grad(): 48 | self.bias_1.zero_() 49 | self.bias_2.zero_() 50 | 51 | def forward(self, x): 52 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 53 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 54 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1] 55 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512] 56 | 57 | residual = x 58 | 59 | x = self.instance_norm_1(x) 60 | x = self.conv_1(x) 61 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 62 | x = x + self.bias_1 63 | x = F.leaky_relu(x, 0.2) 64 | 65 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 66 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 67 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1] 68 | w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512] , 这里style2.view一直写错成style1.view 69 | 70 | x = self.instance_norm_2(x) 71 | if self.has_last_conv: 72 | x = self.blur(x) 73 | x = self.conv_2(x) 74 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 75 | x = x + self.bias_2 76 | x = F.leaky_relu(x, 0.2) 77 | if not self.fused_scale: #在新的一层起初 fused_scale = flase, 完成上采样 78 | x = downscale2d(x) 79 | residual = downscale2d(residual) 80 | 81 | 82 | if self.inputs != self.outputs: 83 | residual = self.conv_3(residual) 84 | 85 | x = 0.111*x+0.889*residual #降低x的比例,可以将const的loss缩小!!0.7*residual: 10-11 >> 7 同时 c_s的loss扩大至3, ws的抖动提前, 效果更好 86 | return x, w1, w2 87 | 88 | 89 | class BE(nn.Module): 90 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3): 91 | super().__init__() 92 | self.maxf = maxf 93 | self.startf = startf 94 | self.latent_size = latent_size 95 | self.layer_to_resolution = [0 for _ in range(layer_count)] 96 | self.decode_block = nn.ModuleList() 97 | self.layer_count = layer_count 98 | inputs = startf # 16 99 | outputs = startf*2 100 | resolution = 1024 101 | self.FromRGB = FromRGB(channels, inputs) 102 | #from_RGB = nn.ModuleList() 103 | for i in range(layer_count): 104 | 105 | has_last_conv = i+1 != layer_count 106 | fused_scale = resolution >= 128 # 在新的一层起初 fused_scale = flase, 完成上采样 107 | 108 | #from_RGB.append(FromRGB(channels, inputs)) 109 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale) 110 | 111 | inputs = inputs*2 112 | outputs = outputs*2 113 | inputs = min(maxf, inputs) 114 | 
outputs = min(maxf, outputs) 115 | self.layer_to_resolution[i] = resolution 116 | resolution /=2 117 | self.decode_block.append(block) 118 | 119 | #self.FromRGB = from_RGB 120 | 121 | #将w逆序,以保证和G的w顺序, block_num控制progressive 122 | def forward(self, x, block_num=9): 123 | #x = self.FromRGB[9-block_num](x) #不是progressive,去除多余的FromRGB 124 | x = self.FromRGB(x) 125 | #print(x.shape) 126 | w = torch.tensor(0) 127 | for i in range(9-block_num,self.layer_count): 128 | x, _ , w2 = self.decode_block[i](x) 129 | #print(x.shape) 130 | w_ = torch.cat((w2.view(x.shape[0],1,512),w2.view(x.shape[0],1,512)),dim=1) # [b,2,512] 131 | if i == (9-block_num): 132 | w = w_ # [b,n,512] 133 | else: 134 | w = torch.cat((w_,w),dim=1) 135 | #print(w.shape) 136 | return x, w -------------------------------------------------------------------------------- /model/E/Ablation_Study/E_Blur_Z.py: -------------------------------------------------------------------------------- 1 | # For ablation study: w 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.nn.parameter import Parameter 7 | import sys 8 | #sys.path.append('../') 9 | import model.utils.lreq as ln 10 | from model.utils.net import Blur,FromRGB,downscale2d 11 | from torch.nn import functional as F 12 | 13 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 14 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 15 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 16 | 17 | class BEBlock(nn.Module): 18 | def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): #分辨率大于128用fused_scale,即conv完成上采样 19 | super().__init__() 20 | self.has_last_conv = has_last_conv 21 | # self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 22 | # self.noise_weight_1.data.zero_() 23 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 24 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 25 | #self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512] 26 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 27 | 28 | # self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 29 | # self.noise_weight_2.data.zero_() 30 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 31 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 32 | #self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1) 33 | self.blur = Blur(inputs) 34 | if has_last_conv: 35 | if fused_scale: 36 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) 37 | else: 38 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 39 | self.fused_scale = fused_scale 40 | 41 | self.inputs = inputs 42 | self.outputs = outputs 43 | 44 | if self.inputs != self.outputs: 45 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0) 46 | 47 | with torch.no_grad(): 48 | self.bias_1.zero_() 49 | self.bias_2.zero_() 50 | 51 | def forward(self, x): 52 | # mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 53 | # std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 54 | # style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1] 55 | # w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512] 56 | 57 | residual = x 58 | 59 | x = self.instance_norm_1(x) 60 | x = self.conv_1(x) 61 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 62 | x = x + self.bias_1 63 | x = F.leaky_relu(x, 
0.2) 64 | 65 | # mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 66 | # std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 67 | # style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1] 68 | # w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512] , 这里style2.view一直写错成style1.view 69 | 70 | x = self.instance_norm_2(x) 71 | if self.has_last_conv: 72 | x = self.blur(x) 73 | x = self.conv_2(x) 74 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 75 | x = x + self.bias_2 76 | x = F.leaky_relu(x, 0.2) 77 | if not self.fused_scale: #在新的一层起初 fused_scale = flase, 完成上采样 78 | x = downscale2d(x) 79 | residual = downscale2d(residual) 80 | 81 | 82 | if self.inputs != self.outputs: 83 | residual = self.conv_3(residual) 84 | 85 | x = 0.111*x+0.889*residual #降低x的比例,可以将const的loss缩小!!0.7*residual: 10-11 >> 7 同时 c_s的loss扩大至3, ws的抖动提前, 效果更好 86 | return x 87 | 88 | 89 | class BE(nn.Module): 90 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3): 91 | super().__init__() 92 | self.maxf = maxf 93 | self.startf = startf 94 | self.latent_size = latent_size 95 | self.layer_to_resolution = [0 for _ in range(layer_count)] 96 | self.decode_block = nn.ModuleList() 97 | self.layer_count = layer_count 98 | inputs = startf # 16 99 | outputs = startf*2 100 | resolution = 1024 101 | self.FromRGB = FromRGB(channels, inputs) 102 | self.out_z = ln.Conv2d(512,512,3,2) 103 | #from_RGB = nn.ModuleList() 104 | for i in range(layer_count): 105 | 106 | has_last_conv = i+1 != layer_count 107 | fused_scale = resolution >= 128 # 在新的一层起初 fused_scale = flase, 完成上采样 108 | 109 | #from_RGB.append(FromRGB(channels, inputs)) 110 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale) 111 | 112 | inputs = inputs*2 113 | outputs = outputs*2 114 | inputs = min(maxf, inputs) 115 | outputs = min(maxf, outputs) 116 | self.layer_to_resolution[i] = resolution 117 | resolution /=2 118 | self.decode_block.append(block) 119 | 120 | #self.FromRGB = from_RGB 121 | 122 | #将w逆序,以保证和G的w顺序, block_num控制progressive 123 | def forward(self, x, block_num=9): 124 | #x = self.FromRGB[9-block_num](x) #不是progressive,去除多余的FromRGB 125 | x = self.FromRGB(x) 126 | #print(x.shape) 127 | w = torch.tensor(0) 128 | for i in range(9-block_num,self.layer_count): 129 | x = self.decode_block[i](x) 130 | #print(x.shape) 131 | # w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512] 132 | # if i == (9-block_num): 133 | # w = w_ # [b,n,512] 134 | # else: 135 | # w = torch.cat((w_,w),dim=1) 136 | # #print(w.shape) 137 | z = self.out_z(x) 138 | return z, w -------------------------------------------------------------------------------- /model/E/Ablation_Study/E_v1.py: -------------------------------------------------------------------------------- 1 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 2 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 3 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 4 | # 比第2版,即可以使用到styleganv1,styleganv2, 不再使用带Equalize learning rate的Conv (这条已经废除). 
以及Block第二层的blur操作 5 | # 改变了上采样,不在conv中完成 6 | # 改变了In,带参数的学习 7 | # 改变了了residual,和残差网络一致,另外旁路多了conv1处理通道和In学习参数 8 | # 经测试,不带Eq(Equalize Learning Rate)的参数层学习效果不好 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import init 14 | from torch.nn.parameter import Parameter 15 | import sys 16 | sys.path.append('../') 17 | from torch.nn import functional as F 18 | import model.utils.lreq as ln 19 | 20 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 21 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 22 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 23 | 24 | class FromRGB(nn.Module): 25 | def __init__(self, channels, outputs): 26 | super(FromRGB, self).__init__() 27 | self.from_rgb = ln.Conv2d(channels, outputs, 1, 1, 0) 28 | def forward(self, x): 29 | x = self.from_rgb(x) 30 | x = F.leaky_relu(x, 0.2) 31 | return x 32 | 33 | class BEBlock(nn.Module): 34 | def __init__(self, inputs, outputs, latent_size, has_second_conv=True, fused_scale=True): #分辨率大于128用fused_scale,即conv完成上采样 35 | super().__init__() 36 | self.has_second_conv = has_second_conv 37 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 38 | self.noise_weight_1.data.zero_() 39 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 40 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 41 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size) # [n, 2c] -> [n,512] 42 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 43 | 44 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 45 | self.noise_weight_2.data.zero_() 46 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 47 | self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8) 48 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size) 49 | if has_second_conv: 50 | if fused_scale: 51 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False) 52 | else: 53 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 54 | self.fused_scale = fused_scale 55 | 56 | self.inputs = inputs 57 | self.outputs = outputs 58 | 59 | if self.inputs != self.outputs: 60 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0) 61 | self.instance_norm_3 = nn.InstanceNorm2d(outputs, affine=True, eps=1e-8) 62 | 63 | with torch.no_grad(): 64 | self.bias_1.zero_() 65 | self.bias_2.zero_() 66 | 67 | def forward(self, x): 68 | residual = x 69 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 70 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 71 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1] 72 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512] 73 | 74 | x = self.conv_1(x) 75 | x = self.instance_norm_1(x) 76 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 77 | x = x + self.bias_1 78 | x = F.leaky_relu(x, 0.2) 79 | 80 | 81 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 82 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 83 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1] 84 | w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512] , 这里style2.view一直写错成style1.view 85 | 86 | if self.has_second_conv: 87 | x = self.conv_2(x) 88 | x = self.instance_norm_2(x) 89 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 90 | x = x + self.bias_2 
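# if this block changes the channel count, the residual is projected by the 1x1 conv_3 (plus InstanceNorm) below so it can be added back to x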
91 | if self.inputs != self.outputs: 92 | residual = self.conv_3(residual) 93 | residual = self.instance_norm_3(residual) 94 | x = x + residual 95 | x = F.leaky_relu(x, 0.2) 96 | if not self.fused_scale: #上采样 97 | x = F.avg_pool2d(x, 2, 2) 98 | 99 | #x = 0.111*x+0.889*residual #降低x的比例,可以将const的loss缩小!!0.7*residual: 10-11 >> 7 同时 c_s的loss扩大至3, ws的抖动提前, 效果更好 100 | return x, w1, w2 101 | 102 | 103 | class BE(nn.Module): 104 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3, pggan=False): 105 | super().__init__() 106 | self.maxf = maxf 107 | self.startf = startf 108 | self.latent_size = latent_size 109 | #self.layer_to_resolution = [0 for _ in range(layer_count)] 110 | self.decode_block = nn.ModuleList() 111 | self.layer_count = layer_count 112 | inputs = startf # 16 113 | outputs = startf*2 114 | #resolution = 1024 115 | # from_RGB = nn.ModuleList() 116 | fused_scale = False 117 | self.FromRGB = FromRGB(channels, inputs) 118 | 119 | for i in range(layer_count): 120 | 121 | has_second_conv = i+1 != layer_count #普通的D最后一个块的第二层是 mini_batch_std 122 | #fused_scale = resolution >= 128 # 在新的一层起初 fused_scale = flase, 完成上采样 123 | #from_RGB.append(FromRGB(channels, inputs)) 124 | 125 | block = BEBlock(inputs, outputs, latent_size, has_second_conv, fused_scale=fused_scale) 126 | 127 | inputs = inputs*2 128 | outputs = outputs*2 129 | inputs = min(maxf, inputs) 130 | outputs = min(maxf, outputs) 131 | #self.layer_to_resolution[i] = resolution 132 | #resolution /=2 133 | self.decode_block.append(block) 134 | #self.FromRGB = from_RGB 135 | 136 | self.pggan = pggan 137 | if pggan: 138 | self.new_final = nn.Conv2d(512, 512, 4, 1, 0, bias=True) 139 | 140 | #将w逆序,以保证和G的w顺序, block_num控制progressive 141 | def forward(self, x, block_num=9): 142 | #x = self.FromRGB[9-block_num](x) #每个block一个 143 | x = self.FromRGB(x) 144 | w = torch.tensor(0) 145 | for i in range(9-block_num,self.layer_count): 146 | x,w1,w2 = self.decode_block[i](x) 147 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512] 148 | if i == (9-block_num): 149 | w = w_ # [b,n,512] 150 | else: 151 | w = torch.cat((w_,w),dim=1) 152 | if self.pggan: 153 | x = self.new_final(x) 154 | return x, w 155 | 156 | #test 157 | # E = BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3) 158 | # imgs1 = torch.randn(2,3,256,256) 159 | # const2,w2 = E(imgs1) 160 | # print(const2.shape) 161 | # print(w2.shape) 162 | # print(E) 163 | -------------------------------------------------------------------------------- /model/E/Ablation_Study/E_v2_std.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torch.nn.parameter import Parameter 5 | import sys 6 | #sys.path.append('../') 7 | import model.utils.lreq as ln 8 | from model.utils.net import Blur,FromRGB,downscale2d 9 | from torch.nn import functional as F 10 | 11 | 12 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 13 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 14 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 15 | 16 | # 这一版的w只有std,没有mean, 比马要差一些,参照3090-styleganv2马的训练 17 | 18 | class BEBlock(nn.Module): 19 | def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): #分辨率大于128用fused_scale,即conv完成上采样 20 | super().__init__() 21 | self.has_last_conv = has_last_conv 22 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 23 | self.noise_weight_1.data.zero_() 24 | self.bias_1 = 
nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 25 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 26 | self.inver_mod1 = ln.Linear(inputs, latent_size, gain=1) # [n, 2c] -> [n,512] 27 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 28 | 29 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 30 | self.noise_weight_2.data.zero_() 31 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 32 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 33 | self.inver_mod2 = ln.Linear(inputs, latent_size, gain=1) 34 | self.blur = Blur(inputs) 35 | if has_last_conv: 36 | if fused_scale: 37 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) 38 | else: 39 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 40 | self.fused_scale = fused_scale 41 | 42 | self.inputs = inputs 43 | self.outputs = outputs 44 | 45 | if self.inputs != self.outputs: 46 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0) 47 | 48 | with torch.no_grad(): 49 | self.bias_1.zero_() 50 | self.bias_2.zero_() 51 | 52 | def forward(self, x): 53 | #mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 54 | #std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 55 | #style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1] 56 | std1 = x.std((2,3)) 57 | w1 = self.inver_mod1(std1.squeeze()) # [b,2c]->[b,512] 58 | 59 | residual = x 60 | 61 | x = self.instance_norm_1(x) 62 | x = self.conv_1(x) 63 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 64 | x = x + self.bias_1 65 | x = F.leaky_relu(x, 0.2) 66 | 67 | #mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 68 | #std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 69 | #style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1] 70 | std2 = x.std((2,3)) 71 | w2 = self.inver_mod2(std2.squeeze()) # [b,512] , 这里style2.view一直写错成style1.view 72 | 73 | x = self.instance_norm_2(x) 74 | if self.has_last_conv: 75 | x = self.blur(x) 76 | x = self.conv_2(x) 77 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 78 | x = x + self.bias_2 79 | x = F.leaky_relu(x, 0.2) 80 | if not self.fused_scale: #在新的一层起初 fused_scale = flase, 完成上采样 81 | x = downscale2d(x) 82 | residual = downscale2d(residual) 83 | 84 | 85 | if self.inputs != self.outputs: 86 | residual = self.conv_3(residual) 87 | 88 | x = 0.111*x+0.889*residual #降低x的比例,可以将const的loss缩小!!0.7*residual: 10-11 >> 7 同时 c_s的loss扩大至3, ws的抖动提前, 效果更好 89 | return x, w1, w2 90 | 91 | 92 | class BE(nn.Module): 93 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3, pggan=False): 94 | super().__init__() 95 | self.maxf = maxf 96 | self.startf = startf 97 | self.latent_size = latent_size 98 | self.layer_to_resolution = [0 for _ in range(layer_count)] 99 | self.decode_block = nn.ModuleList() 100 | self.layer_count = layer_count 101 | inputs = startf # 16 102 | outputs = startf*2 103 | resolution = 1024 104 | self.FromRGB = FromRGB(channels, inputs) 105 | #from_RGB = nn.ModuleList() 106 | for i in range(layer_count): 107 | 108 | has_last_conv = i+1 != layer_count 109 | fused_scale = resolution >= 128 # 在新的一层起初 fused_scale = flase, 完成上采样 110 | 111 | #from_RGB.append(FromRGB(channels, inputs)) 112 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, 
fused_scale=fused_scale) 113 | 114 | inputs = inputs*2 115 | outputs = outputs*2 116 | inputs = min(maxf, inputs) 117 | outputs = min(maxf, outputs) 118 | self.layer_to_resolution[i] = resolution 119 | resolution /=2 120 | self.decode_block.append(block) 121 | 122 | self.pggan = pggan 123 | if self.pggan: 124 | #self.new_final = ln.Conv2d(512, 512, 4, 1, 0, bias=True) 125 | self.new_final = ln.Linear(512 * 16, latent_size, gain=1) 126 | #self.FromRGB = from_RGB 127 | 128 | #将w逆序,以保证和G的w顺序, block_num控制progressive 129 | def forward(self, x, block_num=9): 130 | #x = self.FromRGB[9-block_num](x) #不是progressive,去除多余的FromRGB 131 | x = self.FromRGB(x) 132 | w = torch.tensor(0) 133 | for i in range(9-block_num,self.layer_count): 134 | x,w1,w2 = self.decode_block[i](x) 135 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512] 136 | if i == (9-block_num): 137 | w = w_ # [b,n,512] 138 | else: 139 | w = torch.cat((w_,w),dim=1) 140 | if self.pggan: 141 | #x = self.new_final(x) 142 | x = self.new_final(x.view(x.shape[0],-1)) 143 | return x, w -------------------------------------------------------------------------------- /model/E/E.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torch.nn.parameter import Parameter 5 | import sys 6 | #sys.path.append('../') 7 | import model.utils.lreq as ln 8 | from model.utils.net import Blur,FromRGB,downscale2d 9 | from torch.nn import functional as F 10 | 11 | 12 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 13 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 14 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 15 | 16 | class BEBlock(nn.Module): 17 | def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): #分辨率大于128用fused_scale,即conv完成上采样 18 | super().__init__() 19 | self.has_last_conv = has_last_conv 20 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 21 | self.noise_weight_1.data.zero_() 22 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 23 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 24 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512] 25 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 26 | 27 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 28 | self.noise_weight_2.data.zero_() 29 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 30 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 31 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1) 32 | #self.blur = Blur(inputs) 33 | if has_last_conv: 34 | if fused_scale: 35 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) 36 | else: 37 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 38 | self.fused_scale = fused_scale 39 | 40 | self.inputs = inputs 41 | self.outputs = outputs 42 | 43 | if self.inputs != self.outputs: 44 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0) 45 | 46 | with torch.no_grad(): 47 | self.bias_1.zero_() 48 | self.bias_2.zero_() 49 | 50 | def forward(self, x): 51 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 52 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 53 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1] 54 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512] 55 | 56 | residual = x 
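# keep the pre-normalization features as a residual branch; it is blended back in as 0.111*x + 0.889*residual at the end of this block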
57 | 58 | x = self.instance_norm_1(x) 59 | x = self.conv_1(x) 60 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 61 | x = x + self.bias_1 62 | x = F.leaky_relu(x, 0.2) 63 | 64 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 65 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 66 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1] 67 | w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512] , 这里style2.view一直写错成style1.view 68 | 69 | x = self.instance_norm_2(x) 70 | if self.has_last_conv: 71 | #x = self.blur(x) 72 | x = self.conv_2(x) 73 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 74 | x = x + self.bias_2 75 | x = F.leaky_relu(x, 0.2) 76 | if not self.fused_scale: #在新的一层起初 fused_scale = flase, 完成上采样 77 | x = downscale2d(x) 78 | residual = downscale2d(residual) 79 | 80 | 81 | if self.inputs != self.outputs: 82 | residual = self.conv_3(residual) 83 | 84 | x = 0.111*x+0.889*residual #降低x的比例,可以将const的loss缩小!!0.7*residual: 10-11 >> 7 同时 c_s的loss扩大至3, ws的抖动提前, 效果更好 85 | return x, w1, w2 86 | 87 | 88 | class BE(nn.Module): 89 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3): 90 | super().__init__() 91 | self.maxf = maxf 92 | self.startf = startf 93 | self.latent_size = latent_size 94 | self.layer_to_resolution = [0 for _ in range(layer_count)] 95 | self.decode_block = nn.ModuleList() 96 | self.layer_count = layer_count 97 | inputs = startf # 16 98 | outputs = startf*2 99 | resolution = 1024 100 | self.FromRGB = FromRGB(channels, inputs) 101 | #from_RGB = nn.ModuleList() 102 | for i in range(layer_count): 103 | 104 | has_last_conv = i+1 != layer_count 105 | #fused_scale = resolution >= 128 # 在新的一层起初 fused_scale = flase, 完成上采样 106 | fused_scale = False 107 | 108 | #from_RGB.append(FromRGB(channels, inputs)) 109 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale) 110 | 111 | inputs = inputs*2 112 | outputs = outputs*2 113 | inputs = min(maxf, inputs) 114 | outputs = min(maxf, outputs) 115 | self.layer_to_resolution[i] = resolution 116 | resolution /=2 117 | self.decode_block.append(block) 118 | 119 | #self.FromRGB = from_RGB 120 | 121 | #将w逆序,以保证和G的w顺序, block_num控制progressive 122 | def forward(self, x, block_num=9): 123 | #x = self.FromRGB[9-block_num](x) #不是progressive,去除多余的FromRGB 124 | x = self.FromRGB(x) 125 | #print(x.shape) 126 | w = torch.tensor(0) 127 | for i in range(9-block_num,self.layer_count): 128 | x,w1,w2 = self.decode_block[i](x) 129 | #print(x.shape) 130 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512] 131 | if i == (9-block_num): 132 | w = w_ # [b,n,512] 133 | else: 134 | w = torch.cat((w_,w),dim=1) 135 | #print(w.shape) 136 | return x, w -------------------------------------------------------------------------------- /model/E/E_Blur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torch.nn.parameter import Parameter 5 | import sys 6 | #sys.path.append('../') 7 | import model.utils.lreq as ln 8 | from model.utils.net import Blur,FromRGB,downscale2d 9 | from torch.nn import functional as F 10 | 11 | 12 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 13 | # 比第0版多了残差, 
每一层的两个(conv/line)输出的w1和w2合并为1个w 14 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 15 | 16 | class BEBlock(nn.Module): 17 | def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): #分辨率大于128用fused_scale,即conv完成上采样 18 | super().__init__() 19 | self.has_last_conv = has_last_conv 20 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 21 | self.noise_weight_1.data.zero_() 22 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 23 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 24 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512] 25 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 26 | 27 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 28 | self.noise_weight_2.data.zero_() 29 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 30 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 31 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1) 32 | self.blur = Blur(inputs) 33 | if has_last_conv: 34 | if fused_scale: 35 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) 36 | else: 37 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 38 | self.fused_scale = fused_scale 39 | 40 | self.inputs = inputs 41 | self.outputs = outputs 42 | 43 | if self.inputs != self.outputs: 44 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0) 45 | 46 | with torch.no_grad(): 47 | self.bias_1.zero_() 48 | self.bias_2.zero_() 49 | 50 | def forward(self, x): 51 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 52 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 53 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1] 54 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512] 55 | 56 | residual = x 57 | 58 | x = self.instance_norm_1(x) 59 | x = self.conv_1(x) 60 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 61 | x = x + self.bias_1 62 | x = F.leaky_relu(x, 0.2) 63 | 64 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 65 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 66 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1] 67 | w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512] , 这里style2.view一直写错成style1.view 68 | 69 | x = self.instance_norm_2(x) 70 | if self.has_last_conv: 71 | x = self.blur(x) 72 | x = self.conv_2(x) 73 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 74 | x = x + self.bias_2 75 | x = F.leaky_relu(x, 0.2) 76 | if not self.fused_scale: #在新的一层起初 fused_scale = flase, 完成上采样 77 | x = downscale2d(x) 78 | residual = downscale2d(residual) 79 | 80 | 81 | if self.inputs != self.outputs: 82 | residual = self.conv_3(residual) 83 | 84 | x = 0.111*x+0.889*residual #降低x的比例,可以将const的loss缩小!!0.7*residual: 10-11 >> 7 同时 c_s的loss扩大至3, ws的抖动提前, 效果更好 85 | return x, w1, w2 86 | 87 | 88 | class BE(nn.Module): 89 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3): 90 | super().__init__() 91 | self.maxf = maxf 92 | self.startf = startf 93 | self.latent_size = latent_size 94 | self.layer_to_resolution = [0 for _ in range(layer_count)] 95 | self.decode_block = nn.ModuleList() 96 | self.layer_count = layer_count 97 | 
inputs = startf # 16 98 | outputs = startf*2 99 | resolution = 1024 100 | self.FromRGB = FromRGB(channels, inputs) 101 | #from_RGB = nn.ModuleList() 102 | for i in range(layer_count): 103 | 104 | has_last_conv = i+1 != layer_count 105 | fused_scale = resolution >= 128 # 在新的一层起初 fused_scale = flase, 完成上采样 106 | 107 | #from_RGB.append(FromRGB(channels, inputs)) 108 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale) 109 | 110 | inputs = inputs*2 111 | outputs = outputs*2 112 | inputs = min(maxf, inputs) 113 | outputs = min(maxf, outputs) 114 | self.layer_to_resolution[i] = resolution 115 | resolution /=2 116 | self.decode_block.append(block) 117 | 118 | #self.FromRGB = from_RGB 119 | 120 | #将w逆序,以保证和G的w顺序, block_num控制progressive 121 | def forward(self, x, block_num=9): 122 | #x = self.FromRGB[9-block_num](x) #不是progressive,去除多余的FromRGB 123 | x = self.FromRGB(x) 124 | #print(x.shape) 125 | w = torch.tensor(0) 126 | for i in range(9-block_num,self.layer_count): 127 | x,w1,w2 = self.decode_block[i](x) 128 | #print(x.shape) 129 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512] 130 | if i == (9-block_num): 131 | w = w_ # [b,n,512] 132 | else: 133 | w = torch.cat((w_,w),dim=1) 134 | #print(w.shape) 135 | return x, w -------------------------------------------------------------------------------- /model/E/E_PG.py: -------------------------------------------------------------------------------- 1 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 2 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 3 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 4 | # 比第2版,即可以使用到styleganv1,styleganv2, 不再使用带Equalize learning rate的Conv (这条已经废除). 以及Block第二层的blur操作 5 | # 改变了上采样,不在conv中完成 6 | # 改变了In,带参数的学习 7 | # 改变了了residual,和残差网络一致,另外旁路多了conv1处理通道和In学习参数 8 | # 经测试,不带Eq(Equalize Learning Rate)的参数层学习效果不好 9 | 10 | 11 | #这一版兼容PGGAN和BIGGAN: 主要改变最后一层,增加FC 12 | #PGGAN: 加一个fc, 和原D类似 13 | #BIGGAN,加两个fc,各128channel,其中一个是标签,完成128->1000的映射 14 | #BIGGAN 改进思路1: IN替换CBN 15 | #BIGGAN 改进思路2: G加w,和E的w对称 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import init 20 | from torch.nn.parameter import Parameter 21 | import sys 22 | sys.path.append('../') 23 | from torch.nn import functional as F 24 | import model.utils.lreq as ln 25 | 26 | # G 改 E, 实际上需要用G Block改出E block, 完成逆序对称,在同样位置还原style潜码 27 | # 比第0版多了残差, 每一层的两个(conv/line)输出的w1和w2合并为1个w 28 | # 比第1版加了要学习的bias_1和bias_2,网络顺序和第1版有所不同(更对称) 29 | 30 | class FromRGB(nn.Module): 31 | def __init__(self, channels, outputs): 32 | super(FromRGB, self).__init__() 33 | self.from_rgb = ln.Conv2d(channels, outputs, 1, 1, 0) 34 | def forward(self, x): 35 | x = self.from_rgb(x) 36 | x = F.leaky_relu(x, 0.2) 37 | return x 38 | 39 | class BEBlock(nn.Module): 40 | def __init__(self, inputs, outputs, latent_size, has_second_conv=True, fused_scale=True): #分辨率大于128用fused_scale,即conv完成上采样 41 | super().__init__() 42 | self.has_second_conv = has_second_conv 43 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 44 | self.noise_weight_1.data.zero_() 45 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 46 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8) 47 | #self.inver_mod1 = ln.Linear(2 * inputs, latent_size) # [n, 2c] -> [n,512] 48 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 49 | 50 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 51 | self.noise_weight_2.data.zero_() 52 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 
1)) 53 | self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8) 54 | #self.inver_mod2 = ln.Linear(2 * inputs, latent_size) 55 | if has_second_conv: 56 | if fused_scale: 57 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False) 58 | else: 59 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 60 | self.fused_scale = fused_scale 61 | 62 | self.inputs = inputs 63 | self.outputs = outputs 64 | 65 | if self.inputs != self.outputs: 66 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0) 67 | self.instance_norm_3 = nn.InstanceNorm2d(outputs, affine=True, eps=1e-8) 68 | 69 | with torch.no_grad(): 70 | self.bias_1.zero_() 71 | self.bias_2.zero_() 72 | 73 | def forward(self, x): 74 | residual = x 75 | # mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 76 | # std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 77 | # style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1] 78 | # w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512] 79 | 80 | x = self.instance_norm_1(x) 81 | x = self.conv_1(x) 82 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 83 | x = x + self.bias_1 84 | x = F.leaky_relu(x, 0.2) 85 | 86 | 87 | # mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 88 | # std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 89 | # style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1] 90 | # w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512] , 这里style2.view一直写错成style1.view 91 | 92 | if self.has_second_conv: 93 | x = self.instance_norm_2(x) 94 | x = self.conv_2(x) 95 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device)) 96 | x = x + self.bias_2 97 | if self.inputs != self.outputs: 98 | residual = self.conv_3(residual) 99 | residual = self.instance_norm_3(residual) 100 | x = x + residual 101 | x = F.leaky_relu(x, 0.2) 102 | if not self.fused_scale: #上采样 103 | x = F.avg_pool2d(x, 2, 2) 104 | 105 | #x = 0.111*x+0.889*residual #降低x的比例,可以将const的loss缩小!!0.7*residual: 10-11 >> 7 同时 c_s的loss扩大至3, ws的抖动提前, 效果更好 106 | w1 = 0 107 | w2 = 0 108 | return x , w1, w2 109 | 110 | 111 | class BE(nn.Module): 112 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3, pggan=False): 113 | super().__init__() 114 | self.maxf = maxf 115 | self.startf = startf 116 | self.latent_size = latent_size 117 | #self.layer_to_resolution = [0 for _ in range(layer_count)] 118 | self.decode_block = nn.ModuleList() 119 | self.layer_count = layer_count 120 | inputs = startf # 16 121 | outputs = startf*2 122 | #resolution = 1024 123 | # from_RGB = nn.ModuleList() 124 | fused_scale = False 125 | self.FromRGB = FromRGB(channels, inputs) 126 | 127 | for i in range(layer_count): 128 | 129 | has_second_conv = i+1 != layer_count #普通的D最后一个块的第二层是 mini_batch_std 130 | #fused_scale = resolution >= 128 # 在新的一层起初 fused_scale = flase, 完成上采样 131 | #from_RGB.append(FromRGB(channels, inputs)) 132 | 133 | block = BEBlock(inputs, outputs, latent_size, has_second_conv, fused_scale=fused_scale) 134 | 135 | inputs = inputs*2 136 | outputs = outputs*2 137 | inputs = min(maxf, inputs) 138 | outputs = min(maxf, outputs) 139 | #self.layer_to_resolution[i] = resolution 140 | #resolution /=2 141 | self.decode_block.append(block) 142 | #self.FromRGB = from_RGB 143 | 144 | 
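# For PGGAN (and, in the same spirit, BigGAN) the encoder ends with an extra
# fully-connected head that maps the flattened 4x4 feature map (512*16 values)
# to the latent vector; the StyleGAN encoders skip this head.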
self.pggan = pggan 145 | if pggan: 146 | #self.new_final = nn.Conv2d(512, 512, 4, 1, 0, bias=True) 147 | self.new_final = ln.Linear(512 * 16, latent_size, gain=1) 148 | 149 | #将w逆序,以保证和G的w顺序, block_num控制progressive 150 | def forward(self, x, block_num=9): 151 | #x = self.FromRGB[9-block_num](x) #每个block一个 152 | x = self.FromRGB(x) 153 | w = torch.tensor(0) 154 | for i in range(9-block_num,self.layer_count): 155 | x,w1,w2 = self.decode_block[i](x) 156 | # w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512] 157 | # if i == (9-block_num): 158 | # w = w_ # [b,n,512] 159 | # else: 160 | # w = torch.cat((w_,w),dim=1) 161 | if self.pggan: 162 | #x = self.new_final(x) 163 | x = self.new_final(x.view(x.shape[0],-1)) 164 | return torch.tensor(0), w 165 | 166 | #test 167 | # E = BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3) 168 | # imgs1 = torch.randn(2,3,256,256) 169 | # const2,w2 = E(imgs1) 170 | # print(const2.shape) 171 | # print(w2.shape) 172 | # print(E) 173 | -------------------------------------------------------------------------------- /model/pggan/utils/Encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import ModuleList, AvgPool2d 5 | from networks.PGGAN.CustomLayers import DisGeneralConvBlock, DisFinalBlock, _equalized_conv2d 6 | 7 | #Discriminator to Encoder, Just change the last layer 8 | #和原网络中D对应的Encoder, 训练时G不变, v1只改了最后一层,v2是一个规模较小的网络 9 | 10 | # in: [-1,512] , out: [-1,3,1024,1024] 11 | class encoder(torch.nn.Module): 12 | """ Discriminator of the GAN """ 13 | def __init__(self, height=7, feature_size=512, use_eql=True): 14 | """ 15 | constructor for the class 16 | :param height: total height of the discriminator (Must be equal to the Generator depth) 17 | :param feature_size: size of the deepest features extracted 18 | (Must be equal to Generator latent_size) 19 | :param use_eql: whether to use equalized learning rate 20 | """ 21 | super().__init__() 22 | assert feature_size != 0 and ((feature_size & (feature_size - 1)) == 0), \ 23 | "latent size not a power of 2" 24 | if height >= 4: 25 | assert feature_size >= np.power(2, height - 4), "feature size cannot be produced" 26 | # create state of the object 27 | self.use_eql = use_eql 28 | self.height = height 29 | self.feature_size = feature_size 30 | #self.final_block = DisFinalBlock(self.feature_size, use_eql=self.use_eql) 31 | # create a module list of the other required general convolution blocks 32 | self.layers = ModuleList([]) # initialize to empty list 33 | # create the fromRGB layers for various inputs: 34 | if self.use_eql: 35 | self.fromRGB = lambda out_channels: \ 36 | _equalized_conv2d(3, out_channels, (1, 1), bias=True) 37 | else: 38 | from torch.nn import Conv2d 39 | self.fromRGB = lambda out_channels: Conv2d(3, out_channels, (1, 1), bias=True) 40 | self.rgb_to_features = ModuleList([self.fromRGB(self.feature_size)]) 41 | # create the remaining layers 42 | for i in range(self.height - 1): 43 | if i > 2: 44 | layer = DisGeneralConvBlock( 45 | int(self.feature_size // np.power(2, i - 2)), 46 | int(self.feature_size // np.power(2, i - 3)), 47 | use_eql=self.use_eql 48 | ) 49 | rgb = self.fromRGB(int(self.feature_size // np.power(2, i - 2))) 50 | else: 51 | layer = DisGeneralConvBlock(self.feature_size, 52 | self.feature_size, use_eql=self.use_eql) 53 | rgb = self.fromRGB(self.feature_size) 54 | self.layers.append(layer) 55 | 
self.rgb_to_features.append(rgb) 56 | # register the temporary downSampler 57 | self.temporaryDownsampler = AvgPool2d(2) 58 | #new 59 | self.new_final = nn.Conv2d(512, 512, 4, 1, 0, bias=True) 60 | def forward(self, x, depth, alpha): 61 | """ 62 | forward pass of the discriminator 63 | :param x: input to the network 64 | :param depth: current depth of operation (Progressive GAN) 65 | :param alpha: current value of alpha for fade-in 66 | :return: out => raw prediction values (WGAN-GP) 67 | """ 68 | assert depth < self.height, "Requested output depth cannot be produced" 69 | if depth > 0: 70 | residual = self.rgb_to_features[depth - 1](self.temporaryDownsampler(x)) 71 | 72 | straight = self.layers[depth - 1]( 73 | self.rgb_to_features[depth](x) 74 | ) 75 | 76 | y = (alpha * straight) + ((1 - alpha) * residual) 77 | 78 | for block in reversed(self.layers[:depth - 1]): 79 | y = block(y) 80 | else: 81 | y = self.rgb_to_features[0](x) 82 | 83 | #out = self.final_block(y) 84 | out = self.new_final(y) 85 | return out 86 | 87 | #in: [-1,3,1024,1024], out: [-1,512] 88 | class encoder_small(torch.nn.Module): 89 | def __init__(self): 90 | super().__init__() 91 | self.main = nn.Sequential( 92 | nn.Conv2d(3,12,4,2,1,bias=False), # 1024->512 93 | nn.LeakyReLU(0.2, inplace=True), 94 | nn.Conv2d(12,12,4,2,1,bias=False),# 512->256 95 | nn.BatchNorm2d(12), 96 | nn.LeakyReLU(0.2, inplace=True), 97 | nn.Conv2d(12,3,4,2,1,bias=False),# 256->128 98 | nn.BatchNorm2d(3), 99 | nn.LeakyReLU(0.2, inplace=True), 100 | nn.Conv2d(3,1,4,2,1,bias=False),# 128->64*64=4096 101 | ) 102 | self.fc = nn.Linear(4096,512) 103 | def forward(self, x): 104 | y1 = self.main(x) 105 | y2 = y1.view(-1,4096) 106 | y3 = self.fc(y2) 107 | return y3 -------------------------------------------------------------------------------- /model/stylegan1/alae.py: -------------------------------------------------------------------------------- 1 | #ALAE 2 | class EncodeBlock(nn.Module): 3 | def __init__(self, inputs, outputs, latent_size, last=False, fused_scale=True): 4 | super(EncodeBlock, self).__init__() 5 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False) 6 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1)) 7 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False) 8 | self.blur = Blur(inputs) 9 | self.last = last 10 | self.fused_scale = fused_scale 11 | if last: 12 | self.dense = ln.Linear(inputs * 4 * 4, outputs) 13 | else: 14 | if fused_scale: 15 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) 16 | else: 17 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 18 | 19 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 20 | self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False) 21 | self.style_1 = ln.Linear(2 * inputs, latent_size) 22 | if last: 23 | self.style_2 = ln.Linear(outputs, latent_size) 24 | else: 25 | self.style_2 = ln.Linear(2 * outputs, latent_size) 26 | 27 | with torch.no_grad(): 28 | self.bias_1.zero_() 29 | self.bias_2.zero_() 30 | 31 | def forward(self, x): 32 | x = self.conv_1(x) + self.bias_1 33 | x = F.leaky_relu(x, 0.2) 34 | 35 | m = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1] 36 | std = torch.sqrt(torch.mean((x - m) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1] 37 | style_1 = torch.cat((m, std), dim=1) # [b,2c,1,1] 38 | 39 | x = self.instance_norm_1(x) 40 | 41 | if self.last: 42 | x = self.dense(x.view(x.shape[0], -1)) 43 | 44 | x = F.leaky_relu(x, 0.2) 45 | w1 = self.style_1(style_1.view(style_1.shape[0], 
style_1.shape[1])) 46 | w2 = self.style_2(x.view(x.shape[0], x.shape[1])) 47 | else: 48 | x = self.conv_2(self.blur(x)) 49 | if not self.fused_scale: 50 | x = downscale2d(x) 51 | x = x + self.bias_2 52 | 53 | x = F.leaky_relu(x, 0.2) 54 | 55 | m = torch.mean(x, dim=[2, 3], keepdim=True) 56 | std = torch.sqrt(torch.mean((x - m) ** 2, dim=[2, 3], keepdim=True)) 57 | style_2 = torch.cat((m, std), dim=1) 58 | 59 | x = self.instance_norm_2(x) 60 | 61 | w1 = self.style_1(style_1.view(style_1.shape[0], style_1.shape[1])) 62 | w2 = self.style_2(style_2.view(style_2.shape[0], style_2.shape[1])) 63 | 64 | return x, w1, w2 #降采样一次的结果 , [b,512] , [b,512] 65 | 66 | class EncoderDefault(nn.Module): 67 | def __init__(self, startf, maxf, layer_count, latent_size, channels=3): 68 | super(EncoderDefault, self).__init__() 69 | self.maxf = maxf 70 | self.startf = startf 71 | self.layer_count = layer_count 72 | self.from_rgb: nn.ModuleList[FromRGB] = nn.ModuleList() 73 | self.channels = channels 74 | self.latent_size = latent_size 75 | 76 | mul = 2 77 | inputs = startf 78 | self.encode_block: nn.ModuleList[EncodeBlock] = nn.ModuleList() 79 | 80 | resolution = 2 ** (self.layer_count + 1) 81 | 82 | for i in range(self.layer_count): 83 | outputs = min(self.maxf, startf * mul) 84 | 85 | self.from_rgb.append(FromRGB(channels, inputs)) 86 | 87 | fused_scale = resolution >= 128 88 | 89 | block = EncodeBlock(inputs, outputs, latent_size, False, fused_scale=fused_scale) 90 | 91 | resolution //= 2 92 | 93 | self.encode_block.append(block) 94 | inputs = outputs 95 | mul *= 2 96 | 97 | def encode(self, x, lod): 98 | styles = torch.zeros(x.shape[0], 1, self.latent_size) #[b, 1 , 512] 99 | 100 | x = self.from_rgb[self.layer_count - lod - 1](x) 101 | x = F.leaky_relu(x, 0.2) 102 | 103 | for i in range(self.layer_count - lod - 1, self.layer_count): 104 | x, s1, s2 = self.encode_block[i](x) 105 | styles[:, 0] += s1 + s2 106 | 107 | return styles 108 | 109 | def encode2(self, x, lod, blend): 110 | x_orig = x 111 | styles = torch.zeros(x.shape[0], 1, self.latent_size) 112 | 113 | x = self.from_rgb[self.layer_count - lod - 1](x) 114 | x = F.leaky_relu(x, 0.2) 115 | 116 | x, s1, s2 = self.encode_block[self.layer_count - lod - 1](x) 117 | styles[:, 0] += s1 * blend + s2 * blend 118 | 119 | x_prev = F.avg_pool2d(x_orig, 2, 2) 120 | 121 | x_prev = self.from_rgb[self.layer_count - (lod - 1) - 1](x_prev) 122 | x_prev = F.leaky_relu(x_prev, 0.2) 123 | 124 | x = torch.lerp(x_prev, x, blend) 125 | 126 | for i in range(self.layer_count - (lod - 1) - 1, self.layer_count): 127 | x, s1, s2 = self.encode_block[i](x) 128 | styles[:, 0] += s1 + s2 129 | 130 | return styles 131 | 132 | def forward(self, x, lod, blend): 133 | if blend == 1: 134 | return self.encode(x, lod) 135 | else: 136 | return self.encode2(x, lod, blend) 137 | 138 | def get_statistics(self, lod): 139 | rgb_std = self.from_rgb[self.layer_count - lod - 1].from_rgb.weight.std().item() 140 | rgb_std_c = self.from_rgb[self.layer_count - lod - 1].from_rgb.std 141 | 142 | layers = [] 143 | for i in range(self.layer_count - lod - 1, self.layer_count): 144 | conv_1 = self.encode_block[i].conv_1.weight.std().item() 145 | conv_1_c = self.encode_block[i].conv_1.std 146 | conv_2 = self.encode_block[i].conv_2.weight.std().item() 147 | conv_2_c = self.encode_block[i].conv_2.std 148 | layers.append(((conv_1 / conv_1_c), (conv_2 / conv_2_c))) 149 | return rgb_std / rgb_std_c, layers 150 | -------------------------------------------------------------------------------- 
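The encoder blocks above (ALAE's `EncodeBlock` and the repository's `BEBlock` variants) recover latent estimates in the same way: the channel-wise mean and standard deviation of a feature map are concatenated and projected through a linear layer back to a 512-dimensional style vector. A minimal standalone sketch of that inverse-AdaIN mapping, using a plain `nn.Linear` in place of the equalized-learning-rate `ln.Linear` and made-up sizes:

```python
import torch
import torch.nn as nn

class StyleStats(nn.Module):
    """Toy inverse-AdaIN head: per-channel mean/std -> one style vector (sketch only)."""
    def __init__(self, channels, latent_size=512):
        super().__init__()
        # plain Linear stands in for the repo's equalized-learning-rate ln.Linear
        self.proj = nn.Linear(2 * channels, latent_size)

    def forward(self, x):
        mean = x.mean(dim=[2, 3], keepdim=True)                        # [b, c, 1, 1]
        std = ((x - mean) ** 2).mean(dim=[2, 3], keepdim=True).sqrt()  # [b, c, 1, 1]
        style = torch.cat((mean, std), dim=1).flatten(1)               # [b, 2c]
        return self.proj(style)                                        # [b, latent_size]

feat = torch.randn(2, 64, 32, 32)     # a fake feature map
print(StyleStats(64)(feat).shape)     # torch.Size([2, 512])
```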
/model/stylegan1/custom_adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class LREQAdam(Optimizer): 7 | def __init__(self, params, lr=1e-3, betas=(0.0, 0.99), eps=1e-8, weight_decay=0): 8 | beta_2 = betas[1] 9 | if not 0.0 <= lr: 10 | raise ValueError("Invalid learning rate: {}".format(lr)) 11 | if not 0.0 <= eps: 12 | raise ValueError("Invalid epsilon value: {}".format(eps)) 13 | if not 0.0 == betas[0]: 14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 15 | if not 0.0 <= beta_2 < 1.0: 16 | raise ValueError("Invalid beta parameter at index 1: {}".format(beta_2)) 17 | defaults = dict(lr=lr, beta_2=beta_2, eps=eps, weight_decay=weight_decay) 18 | super(LREQAdam, self).__init__(params, defaults) 19 | 20 | def __setstate__(self, state): 21 | super(LREQAdam, self).__setstate__(state) 22 | 23 | def step(self, closure=None): 24 | """Performs a single optimization step. 25 | Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. 26 | """ 27 | loss = None 28 | if closure is not None: 29 | loss = closure() 30 | 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | grad = p.grad.data 36 | if grad.is_sparse: 37 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 38 | 39 | state = self.state[p] 40 | 41 | # State initialization 42 | if len(state) == 0: 43 | state['step'] = 0 44 | # Exponential moving average of gradient values 45 | # state['exp_avg'] = torch.zeros_like(p.data) 46 | # Exponential moving average of squared gradient values 47 | state['exp_avg_sq'] = torch.zeros_like(p.data) 48 | 49 | # exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 50 | exp_avg_sq = state['exp_avg_sq'] 51 | beta_2 = group['beta_2'] 52 | 53 | state['step'] += 1 54 | 55 | if group['weight_decay'] != 0: 56 | grad.add_(group['weight_decay'], p.data / p.coef) 57 | 58 | # Decay the first and second moment running average coefficient 59 | # exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | exp_avg_sq.mul_(beta_2).addcmul_(1 - beta_2, grad, grad) 61 | denom = exp_avg_sq.sqrt().add_(group['eps']) 62 | 63 | # bias_correction1 = 1 - beta1 ** state['step'] # 1 64 | bias_correction2 = 1 - beta_2 ** state['step'] 65 | # step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 66 | step_size = group['lr'] * math.sqrt(bias_correction2) 67 | 68 | # p.data.addcdiv_(-step_size, exp_avg, denom) 69 | if hasattr(p, 'lr_equalization_coef'): 70 | step_size *= p.lr_equalization_coef 71 | 72 | p.data.addcdiv_(-step_size, grad, denom) # p = p + (-steo_size)/grad * demon 73 | 74 | return loss 75 | -------------------------------------------------------------------------------- /model/stylegan1/lod_driver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import time 4 | from collections import defaultdict 5 | 6 | #一个可以返回各类训练参数(param)的对象 7 | class LODDriver: 8 | def __init__(self, cfg, logger, dataset_size): 9 | self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_1GPU #[128, 128, 128, 64, 32, 16] 10 | self.minibatch_base = 16 11 | self.cfg = cfg 12 | self.dataset_size = dataset_size 13 | self.current_epoch = 0 14 | self.lod = -1 15 | self.in_transition = False 16 | self.logger = logger 17 | self.iteration = 0 18 | self.epoch_end_time = 0 19 | 
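# Bookkeeping for per-epoch timing and for the report/snapshot schedule used by
# the progressive-growing (LOD) training loop.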
self.epoch_start_time = 0 20 | self.per_epoch_ptime = 0 21 | self.reports = cfg.TRAIN.REPORT_FREQ 22 | self.snapshots = cfg.TRAIN.SNAPSHOT_FREQ 23 | self.tick_start_nimg_report = 0 24 | self.tick_start_nimg_snapshot = 0 25 | 26 | def get_lod_power2(self): 27 | return self.lod + 2 28 | 29 | def get_batch_size(self): 30 | return self.lod_2_batch[min(self.lod, len(self.lod_2_batch) - 1)] 31 | 32 | def get_dataset_size(self): 33 | return self.dataset_size 34 | 35 | def get_blend_factor(self): 36 | blend_factor = float((self.current_epoch % self.cfg.TRAIN.EPOCHS_PER_LOD) * self.dataset_size + self.iteration) 37 | blend_factor /= float(self.cfg.TRAIN.EPOCHS_PER_LOD // 2 * self.dataset_size) 38 | blend_factor = math.sin(blend_factor * math.pi - 0.5 * math.pi) * 0.5 + 0.5 39 | 40 | if not self.in_transition: 41 | blend_factor = 1 42 | 43 | return blend_factor 44 | 45 | def is_time_to_report(self): 46 | if self.iteration >= self.tick_start_nimg_report + self.reports[min(self.lod, len(self.reports) - 1)] * 1000: 47 | self.tick_start_nimg_report = self.iteration 48 | return True 49 | return False 50 | 51 | def is_time_to_save(self): 52 | if self.iteration >= self.tick_start_nimg_snapshot + self.snapshots[min(self.lod, len(self.snapshots) - 1)] * 1000: 53 | self.tick_start_nimg_snapshot = self.iteration 54 | return True 55 | return False 56 | 57 | def step(self): 58 | self.iteration += self.get_batch_size() 59 | self.epoch_end_time = time.time() 60 | self.per_epoch_ptime = self.epoch_end_time - self.epoch_start_time 61 | 62 | def set_epoch(self, epoch, optimizers): 63 | self.current_epoch = epoch 64 | self.iteration = 0 65 | self.tick_start_nimg_report = 0 66 | self.tick_start_nimg_snapshot = 0 67 | self.epoch_start_time = time.time() 68 | 69 | new_lod = min(self.cfg.MODEL.LAYER_COUNT - 1, epoch // self.cfg.TRAIN.EPOCHS_PER_LOD) #第一个值是固定的。第二个值最小是0,提升lod即分辨率 70 | if new_lod != self.lod: 71 | self.lod = new_lod 72 | self.logger.info("#" * 80) 73 | self.logger.info("# Switching LOD to %d" % self.lod) 74 | self.logger.info("#" * 80) 75 | self.logger.info("Start transition") 76 | self.in_transition = True 77 | for opt in optimizers: 78 | opt.state = defaultdict(dict) 79 | 80 | is_in_first_half_of_cycle = (epoch % self.cfg.TRAIN.EPOCHS_PER_LOD) < (self.cfg.TRAIN.EPOCHS_PER_LOD // 2) 81 | is_growing = epoch // self.cfg.TRAIN.EPOCHS_PER_LOD == self.lod > 0 82 | new_in_transition = is_in_first_half_of_cycle and is_growing 83 | 84 | if new_in_transition != self.in_transition: 85 | self.in_transition = new_in_transition 86 | self.logger.info("#" * 80) 87 | self.logger.info("# Transition ended") 88 | self.logger.info("#" * 80) 89 | -------------------------------------------------------------------------------- /model/stylegan1/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | __all__ = ['kl', 'reconstruction', 'discriminator_logistic_simple_gp', 6 | 'discriminator_gradient_penalty', 'generator_logistic_non_saturating'] 7 | 8 | 9 | def kl(mu, log_var): 10 | return -0.5 * torch.mean(torch.mean(1 + log_var - mu.pow(2) - log_var.exp(), 1)) 11 | 12 | 13 | def reconstruction(recon_x, x, lod=None): 14 | return torch.mean((recon_x - x)**2) 15 | 16 | 17 | def discriminator_logistic_simple_gp(d_result_fake, d_result_real, reals, r1_gamma=10.0): 18 | loss = (F.softplus(d_result_fake) + F.softplus(-d_result_real)) 19 | 20 | if r1_gamma != 0.0: 21 | real_loss = d_result_real.sum() 22 | real_grads = 
torch.autograd.grad(real_loss, reals, create_graph=True, retain_graph=True)[0] 23 | r1_penalty = torch.sum(real_grads.pow(2.0), dim=[1, 2, 3]) 24 | loss = loss + r1_penalty * (r1_gamma * 0.5) 25 | return loss.mean() 26 | 27 | 28 | def discriminator_gradient_penalty(d_result_real, reals, r1_gamma=10.0): 29 | real_loss = d_result_real.sum() 30 | real_grads = torch.autograd.grad(real_loss, reals, create_graph=True, retain_graph=True)[0] 31 | r1_penalty = torch.sum(real_grads.pow(2.0), dim=[1, 2, 3]) 32 | loss = r1_penalty * (r1_gamma * 0.5) 33 | return loss.mean() 34 | 35 | 36 | def generator_logistic_non_saturating(d_result_fake): 37 | return F.softplus(-d_result_fake).mean() 38 | -------------------------------------------------------------------------------- /model/stylegan1/lreq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn import init 5 | from torch.nn.parameter import Parameter 6 | import numpy as np 7 | 8 | 9 | class Bool: 10 | def __init__(self): 11 | self.value = False 12 | 13 | def __bool__(self): 14 | return self.value 15 | __nonzero__ = __bool__ 16 | 17 | def set(self, value): 18 | self.value = value 19 | 20 | 21 | use_implicit_lreq = Bool() 22 | use_implicit_lreq.set(True) 23 | 24 | 25 | def is_sequence(arg): 26 | return (not hasattr(arg, "strip") and 27 | hasattr(arg, "__getitem__") or 28 | hasattr(arg, "__iter__")) 29 | 30 | 31 | def make_tuple(x, n): 32 | if is_sequence(x): 33 | return x 34 | return tuple([x for _ in range(n)]) 35 | 36 | 37 | class Linear(nn.Module): 38 | def __init__(self, in_features, out_features, bias=True, gain=np.sqrt(2.0), lrmul=1.0, implicit_lreq=use_implicit_lreq): 39 | super(Linear, self).__init__() 40 | self.in_features = in_features 41 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 42 | if bias: 43 | self.bias = Parameter(torch.Tensor(out_features)) 44 | else: 45 | self.register_parameter('bias', None) 46 | self.std = 0 47 | self.gain = gain 48 | self.lrmul = lrmul 49 | self.implicit_lreq = implicit_lreq 50 | self.reset_parameters() 51 | 52 | def reset_parameters(self): 53 | self.std = self.gain / np.sqrt(self.in_features) * self.lrmul 54 | if not self.implicit_lreq: 55 | init.normal_(self.weight, mean=0, std=1.0 / self.lrmul) 56 | else: 57 | init.normal_(self.weight, mean=0, std=self.std / self.lrmul) 58 | setattr(self.weight, 'lr_equalization_coef', self.std) 59 | if self.bias is not None: 60 | setattr(self.bias, 'lr_equalization_coef', self.lrmul) 61 | 62 | if self.bias is not None: 63 | with torch.no_grad(): 64 | self.bias.zero_() 65 | 66 | def forward(self, input): 67 | if not self.implicit_lreq: 68 | bias = self.bias 69 | if bias is not None: 70 | bias = bias * self.lrmul 71 | return F.linear(input, self.weight * self.std, bias) 72 | else: 73 | return F.linear(input, self.weight, self.bias) 74 | 75 | 76 | class Conv2d(nn.Module): 77 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, 78 | groups=1, bias=True, gain=np.sqrt(2.0), transpose=False, transform_kernel=False, lrmul=1.0, 79 | implicit_lreq=use_implicit_lreq): 80 | super(Conv2d, self).__init__() 81 | if in_channels % groups != 0: 82 | raise ValueError('in_channels must be divisible by groups') 83 | if out_channels % groups != 0: 84 | raise ValueError('out_channels must be divisible by groups') 85 | self.in_channels = in_channels 86 | self.out_channels = out_channels 87 
| self.kernel_size = make_tuple(kernel_size, 2) 88 | self.stride = make_tuple(stride, 2) 89 | self.padding = make_tuple(padding, 2) 90 | self.output_padding = make_tuple(output_padding, 2) 91 | self.dilation = make_tuple(dilation, 2) 92 | self.groups = groups 93 | self.gain = gain 94 | self.lrmul = lrmul 95 | self.transpose = transpose 96 | self.fan_in = np.prod(self.kernel_size) * in_channels // groups 97 | self.transform_kernel = transform_kernel 98 | if transpose: 99 | self.weight = Parameter(torch.Tensor(in_channels, out_channels // groups, *self.kernel_size)) 100 | else: 101 | self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 102 | if bias: 103 | self.bias = Parameter(torch.Tensor(out_channels)) 104 | else: 105 | self.register_parameter('bias', None) 106 | self.std = 0 107 | self.implicit_lreq = implicit_lreq 108 | self.reset_parameters() 109 | 110 | def reset_parameters(self): 111 | self.std = self.gain / np.sqrt(self.fan_in) 112 | if not self.implicit_lreq: 113 | init.normal_(self.weight, mean=0, std=1.0 / self.lrmul) 114 | else: 115 | init.normal_(self.weight, mean=0, std=self.std / self.lrmul) 116 | setattr(self.weight, 'lr_equalization_coef', self.std) 117 | if self.bias is not None: 118 | setattr(self.bias, 'lr_equalization_coef', self.lrmul) 119 | 120 | if self.bias is not None: 121 | with torch.no_grad(): 122 | self.bias.zero_() 123 | 124 | def forward(self, x): 125 | if self.transpose: 126 | w = self.weight 127 | if self.transform_kernel: 128 | w = F.pad(w, (1, 1, 1, 1), mode='constant') 129 | w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1] 130 | if not self.implicit_lreq: 131 | bias = self.bias 132 | if bias is not None: 133 | bias = bias * self.lrmul 134 | return F.conv_transpose2d(x, w * self.std, bias, stride=self.stride, 135 | padding=self.padding, output_padding=self.output_padding, 136 | dilation=self.dilation, groups=self.groups) 137 | else: 138 | return F.conv_transpose2d(x, w, self.bias, stride=self.stride, padding=self.padding, 139 | output_padding=self.output_padding, dilation=self.dilation, 140 | groups=self.groups) 141 | else: 142 | w = self.weight 143 | if self.transform_kernel: 144 | w = F.pad(w, (1, 1, 1, 1), mode='constant') 145 | w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25 146 | if not self.implicit_lreq: 147 | bias = self.bias 148 | if bias is not None: 149 | bias = bias * self.lrmul 150 | return F.conv2d(x, w * self.std, bias, stride=self.stride, padding=self.padding, 151 | dilation=self.dilation, groups=self.groups) 152 | else: 153 | return F.conv2d(x, w, self.bias, stride=self.stride, padding=self.padding, 154 | dilation=self.dilation, groups=self.groups) 155 | 156 | 157 | class ConvTranspose2d(Conv2d): 158 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, 159 | groups=1, bias=True, gain=np.sqrt(2.0), transform_kernel=False, lrmul=1.0, implicit_lreq=use_implicit_lreq): 160 | super(ConvTranspose2d, self).__init__(in_channels=in_channels, 161 | out_channels=out_channels, 162 | kernel_size=kernel_size, 163 | stride=stride, 164 | padding=padding, 165 | output_padding=output_padding, 166 | dilation=dilation, 167 | groups=groups, 168 | bias=bias, 169 | gain=gain, 170 | transpose=True, 171 | transform_kernel=transform_kernel, 172 | lrmul=lrmul, 173 | implicit_lreq=implicit_lreq) 174 | 175 | 176 | class SeparableConv2d(nn.Module): 177 | def __init__(self, in_channels, out_channels, 
kernel_size, stride=1, padding=0, output_padding=0, dilation=1, 178 | bias=True, gain=np.sqrt(2.0), transpose=False): 179 | super(SeparableConv2d, self).__init__() 180 | self.spatial_conv = Conv2d(in_channels, in_channels, kernel_size, stride, padding, output_padding, dilation, 181 | in_channels, False, 1, transpose) 182 | self.channel_conv = Conv2d(in_channels, out_channels, 1, bias, 1, gain=gain) 183 | 184 | def forward(self, x): 185 | return self.channel_conv(self.spatial_conv(x)) 186 | 187 | 188 | class SeparableConvTranspose2d(Conv2d): 189 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, 190 | bias=True, gain=np.sqrt(2.0)): 191 | super(SeparableConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, 192 | output_padding, dilation, bias, gain, True) 193 | 194 | -------------------------------------------------------------------------------- /model/stylegan1/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import random 4 | import module.losses as losses 5 | from module.net import Generator, Mapping, Discriminator 6 | import numpy as np 7 | 8 | 9 | class DLatent(nn.Module): 10 | def __init__(self, dlatent_size, layer_count): 11 | super(DLatent, self).__init__() 12 | buffer = torch.zeros(layer_count, dlatent_size, dtype=torch.float32) 13 | self.register_buffer('buff', buffer) 14 | 15 | class Model(nn.Module): #三个网络 Gp Gs D 的封装 16 | def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_layers=5, dlatent_avg_beta=None, 17 | truncation_psi=None, truncation_cutoff=None, style_mixing_prob=None, channels=3): 18 | super(Model, self).__init__() 19 | 20 | self.mapping = Mapping( 21 | num_layers=2 * layer_count, # 2*9 = 18 22 | latent_size=latent_size, 23 | dlatent_size=latent_size, 24 | mapping_fmaps=latent_size, 25 | mapping_layers=mapping_layers) 26 | 27 | self.generator = Generator( 28 | startf=startf, 29 | layer_count=layer_count, 30 | maxf=maxf, 31 | latent_size=latent_size, 32 | channels=channels) 33 | 34 | self.discriminator = Discriminator( 35 | startf=startf, 36 | layer_count=layer_count, 37 | maxf=maxf, 38 | channels=channels) 39 | 40 | self.dlatent_avg = DLatent(latent_size, self.mapping.num_layers) 41 | self.latent_size = latent_size 42 | self.dlatent_avg_beta = dlatent_avg_beta 43 | self.truncation_psi = truncation_psi # 0.7 44 | self.style_mixing_prob = style_mixing_prob 45 | self.truncation_cutoff = truncation_cutoff # 前8层 46 | 47 | def generate(self, lod, blend_factor, z=None, count=32, remove_blob=False): 48 | if z is None: 49 | z = torch.randn(count, self.latent_size) 50 | styles = self.mapping(z) 51 | if self.dlatent_avg_beta is not None: #让原向量以中心向量 dlatent_avg.buff.data 为中心,按比例self.dlatent_avg_beta=0.995围绕中心向量拉近, 52 | with torch.no_grad(): 53 | batch_avg = styles.mean(dim=0) 54 | self.dlatent_avg.buff.data.lerp_(batch_avg.data, 1.0 - self.dlatent_avg_beta) # y.lerp(x,a) = y - (y-x)*a 55 | 56 | if self.style_mixing_prob is not None: 57 | if random.random() < self.style_mixing_prob: 58 | z2 = torch.randn(count, self.latent_size) # z2 : [32, 512] 59 | styles2 = self.mapping(z2) 60 | 61 | layer_idx = torch.arange(self.mapping.num_layers)[np.newaxis, :, np.newaxis] 62 | cur_layers = (lod + 1) * 2 63 | mixing_cutoff = random.randint(1, cur_layers) 64 | styles = torch.where(layer_idx < mixing_cutoff, styles, styles2) 65 | 66 | if self.truncation_psi is not None: #让原向量以中心向量 
dlatent_avg.buff.data 为中心,按比例truncation_psi围绕中心向量拉近, 67 | layer_idx = torch.arange(self.mapping.num_layers)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 68 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 69 | coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) # 18个变量前8个裁剪比例truncation_psi 70 | styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs) # avg + (styles-avg) * 0.7 71 | 72 | rec = self.generator.forward(styles, lod, blend_factor, remove_blob) # styles:[-1 , 18, 512] 73 | return rec 74 | 75 | def forward(self, x, lod, blend_factor, d_train): 76 | if d_train: 77 | with torch.no_grad(): 78 | rec = self.generate(lod, blend_factor, count=x.shape[0]) 79 | self.discriminator.requires_grad_(True) 80 | d_result_real = self.discriminator(x, lod, blend_factor).squeeze() 81 | d_result_fake = self.discriminator(rec.detach(), lod, blend_factor).squeeze() 82 | 83 | loss_d = losses.discriminator_logistic_simple_gp(d_result_fake, d_result_real, x) 84 | return loss_d 85 | else: 86 | rec = self.generate(lod, blend_factor, count=x.shape[0]) 87 | self.discriminator.requires_grad_(False) 88 | d_result_fake = self.discriminator(rec, lod, blend_factor).squeeze() 89 | loss_g = losses.generator_logistic_non_saturating(d_result_fake) 90 | return loss_g 91 | 92 | def lerp(self, other, betta): 93 | if hasattr(other, 'module'): 94 | other = other.module 95 | with torch.no_grad(): 96 | params = list(self.mapping.parameters()) + list(self.generator.parameters()) + list(self.dlatent_avg.parameters()) 97 | other_param = list(other.mapping.parameters()) + list(other.generator.parameters()) + list(other.dlatent_avg.parameters()) 98 | for p, p_other in zip(params, other_param): 99 | p.data.lerp_(p_other.data, 1.0 - betta) 100 | -------------------------------------------------------------------------------- /model/stylegan1/text_alae.py: -------------------------------------------------------------------------------- 1 | from net import * 2 | from model import Model 3 | from launcher import run 4 | from dataloader import * 5 | from checkpointer import Checkpointer 6 | from dlutils.pytorch import count_parameters 7 | from defaults import get_cfg_defaults 8 | from PIL import Image 9 | import PIL 10 | from torchvision.utils import save_image 11 | 12 | def draw_uncurated_result_figure(cfg, png, model, cx, cy, cw, ch, rows, lods, seed): 13 | print(png) 14 | N = sum(rows * 2**lod for lod in lods) 15 | images = [] 16 | 17 | rnd = np.random.RandomState(5) 18 | for i in range(N): 19 | latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE) 20 | samplez = torch.tensor(latents).float().cuda() 21 | image = model.generate(cfg.DATASET.MAX_RESOLUTION_LEVEL-2, 1, samplez, 1, mixing=True) 22 | images.append(image[0]) 23 | 24 | canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white') 25 | image_iter = iter(list(images)) 26 | for col, lod in enumerate(lods): 27 | for row in range(rows * 2**lod): 28 | im = next(image_iter).cpu().numpy() 29 | im = im.transpose(1, 2, 0) 30 | im = im * 0.5 + 0.5 31 | image = PIL.Image.fromarray(np.clip(im * 255, 0, 255).astype(np.uint8), 'RGB') 32 | image = image.crop((cx, cy, cx + cw, cy + ch)) 33 | image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS) 34 | canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod)) 35 | canvas.save(png) 36 | 37 | 38 | def sample(cfg, logger): 39 | torch.cuda.set_device(0) 40 
| model = Model( 41 | startf=cfg.MODEL.START_CHANNEL_COUNT, 42 | layer_count=cfg.MODEL.LAYER_COUNT, 43 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT, 44 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE, 45 | truncation_psi=cfg.MODEL.TRUNCATIOM_PSI, 46 | truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF, 47 | style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB, 48 | mapping_layers=cfg.MODEL.MAPPING_LAYERS, 49 | channels=cfg.MODEL.CHANNELS, 50 | generator=cfg.MODEL.GENERATOR, 51 | encoder=cfg.MODEL.ENCODER) 52 | 53 | model.cuda(0) 54 | model.eval() 55 | model.requires_grad_(False) 56 | 57 | decoder = model.decoder 58 | encoder = model.encoder 59 | mapping_tl = model.mapping_d 60 | mapping_fl = model.mapping_f 61 | 62 | dlatent_avg = model.dlatent_avg 63 | 64 | logger.info("Trainable parameters generator:") 65 | count_parameters(decoder) 66 | 67 | logger.info("Trainable parameters discriminator:") 68 | count_parameters(encoder) 69 | 70 | arguments = dict() 71 | arguments["iteration"] = 0 72 | 73 | model_dict = { 74 | 'discriminator_s': encoder, 75 | 'generator_s': decoder, 76 | 'mapping_tl_s': mapping_tl, 77 | 'mapping_fl_s': mapping_fl, 78 | 'dlatent_avg': dlatent_avg 79 | } 80 | 81 | checkpointer = Checkpointer(cfg, 82 | model_dict, 83 | {}, 84 | logger=logger, 85 | save=False) 86 | 87 | checkpointer.load() 88 | 89 | model.eval() 90 | 91 | layer_count = cfg.MODEL.LAYER_COUNT 92 | 93 | decoder = nn.DataParallel(decoder) 94 | 95 | im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1) 96 | with torch.no_grad(): 97 | rnd = np.random.RandomState(0) 98 | latents = rnd.randn(1, 512) 99 | samplez = torch.tensor(latents).float().cuda() 100 | image = model.generate(8, 1, samplez, 1, mixing=True) 101 | ls = model.encode(image,8,1) 102 | x = model.decoder(x, 8, 1, noise=True) 103 | save_image(samplez,'1.png') 104 | save_image(x,'2.png') 105 | 106 | if __name__ == "__main__": 107 | gpu_count = 1 108 | run(sample, get_cfg_defaults(), description='ALAE-generations', default_config='configs/ffhq.yaml', 109 | world_size=gpu_count, write_log=False) 110 | -------------------------------------------------------------------------------- /model/utils/biggan_config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | BigGAN config. 4 | """ 5 | from __future__ import (absolute_import, division, print_function, unicode_literals) 6 | 7 | import copy 8 | import json 9 | import os 10 | 11 | class BigGANConfig(object): 12 | """ Configuration class to store the configuration of a `BigGAN`. 13 | Defaults are for the 128x128 model. 14 | layers tuple are (up-sample in the layer ?, input channels, output channels) 15 | """ 16 | def __init__(self, 17 | output_dim=128, 18 | z_dim=128, 19 | class_embed_dim=128, 20 | channel_width=128, 21 | num_classes=1000, 22 | layers=[(False, 16, 16), 23 | (True, 16, 16), 24 | (False, 16, 16), 25 | (True, 16, 8), 26 | (False, 8, 8), 27 | (True, 8, 4), 28 | (False, 4, 4), 29 | (True, 4, 2), 30 | (False, 2, 2), 31 | (True, 2, 1)], 32 | attention_layer_position=8, 33 | eps=1e-4, 34 | n_stats=51): 35 | """Constructs BigGANConfig. 
""" 36 | self.output_dim = output_dim 37 | self.z_dim = z_dim 38 | self.class_embed_dim = class_embed_dim 39 | self.channel_width = channel_width 40 | self.num_classes = num_classes 41 | self.layers = layers 42 | self.attention_layer_position = attention_layer_position 43 | self.eps = eps 44 | self.n_stats = n_stats 45 | 46 | @classmethod 47 | def from_dict(cls, json_object): 48 | """Constructs a `BigGANConfig` from a Python dictionary of parameters.""" 49 | config = BigGANConfig() 50 | for key, value in json_object.items(): 51 | config.__dict__[key] = value 52 | return config 53 | 54 | @classmethod 55 | def from_json_file(cls, json_file): 56 | """Constructs a `BigGANConfig` from a json file of parameters.""" 57 | with open(json_file, "r", encoding='utf-8') as reader: 58 | text = reader.read() 59 | return cls.from_dict(json.loads(text)) 60 | 61 | def __repr__(self): 62 | return str(self.to_json_string()) 63 | 64 | def to_dict(self): 65 | """Serializes this instance to a Python dictionary.""" 66 | output = copy.deepcopy(self.__dict__) 67 | return output 68 | 69 | def to_json_string(self): 70 | """Serializes this instance to a JSON string.""" 71 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 72 | -------------------------------------------------------------------------------- /model/utils/biggan_file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE', 31 | Path.home() / '.pytorch_pretrained_biggan')) 32 | except (AttributeError, ImportError): 33 | PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan')) 35 | 36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 37 | 38 | 39 | def url_to_filename(url, etag=None): 40 | """ 41 | Convert `url` into a hashed filename in a repeatable way. 42 | If `etag` is specified, append its hash to the url's, delimited 43 | by a period. 44 | """ 45 | url_bytes = url.encode('utf-8') 46 | url_hash = sha256(url_bytes) 47 | filename = url_hash.hexdigest() 48 | 49 | if etag: 50 | etag_bytes = etag.encode('utf-8') 51 | etag_hash = sha256(etag_bytes) 52 | filename += '.' + etag_hash.hexdigest() 53 | 54 | return filename 55 | 56 | 57 | def filename_to_url(filename, cache_dir=None): 58 | """ 59 | Return the url and etag (which may be ``None``) stored for `filename`. 60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | parsed = urlparse(url_or_filename) 98 | 99 | if parsed.scheme in ('http', 'https', 's3'): 100 | # URL, so get it from the cache (downloading if necessary) 101 | return get_from_cache(url_or_filename, cache_dir) 102 | elif os.path.exists(url_or_filename): 103 | # File, and it exists. 104 | return url_or_filename 105 | elif parsed.scheme == '': 106 | # File, but it doesn't exist. 107 | raise EnvironmentError("file {} not found".format(url_or_filename)) 108 | else: 109 | # Something unknown 110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 111 | 112 | 113 | def split_s3_path(url): 114 | """Split a full s3 path into the bucket name and path.""" 115 | parsed = urlparse(url) 116 | if not parsed.netloc or not parsed.path: 117 | raise ValueError("bad s3 path {}".format(url)) 118 | bucket_name = parsed.netloc 119 | s3_path = parsed.path 120 | # Remove '/' at beginning of path. 121 | if s3_path.startswith("/"): 122 | s3_path = s3_path[1:] 123 | return bucket_name, s3_path 124 | 125 | 126 | def s3_request(func): 127 | """ 128 | Wrapper function for s3 requests in order to create more helpful error 129 | messages. 
130 | """ 131 | 132 | @wraps(func) 133 | def wrapper(url, *args, **kwargs): 134 | try: 135 | return func(url, *args, **kwargs) 136 | except ClientError as exc: 137 | if int(exc.response["Error"]["Code"]) == 404: 138 | raise EnvironmentError("file {} not found".format(url)) 139 | else: 140 | raise 141 | 142 | return wrapper 143 | 144 | 145 | @s3_request 146 | def s3_etag(url): 147 | """Check ETag on S3 object.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_object = s3_resource.Object(bucket_name, s3_path) 151 | return s3_object.e_tag 152 | 153 | 154 | @s3_request 155 | def s3_get(url, temp_file): 156 | """Pull a file directly from S3.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 160 | 161 | 162 | def http_get(url, temp_file): 163 | req = requests.get(url, stream=True) 164 | content_length = req.headers.get('Content-Length') 165 | total = int(content_length) if content_length is not None else None 166 | progress = tqdm(unit="B", total=total) 167 | for chunk in req.iter_content(chunk_size=1024): 168 | if chunk: # filter out keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | response = requests.head(url, allow_redirects=True) 192 | if response.status_code != 200: 193 | raise IOError("HEAD request failed for url {} with status code {}" 194 | .format(url, response.status_code)) 195 | etag = response.headers.get("ETag") 196 | 197 | filename = url_to_filename(url, etag) 198 | 199 | # get cache path to put the file 200 | cache_path = os.path.join(cache_dir, filename) 201 | 202 | if not os.path.exists(cache_path): 203 | # Download to temporary file, then copy to cache dir once finished. 204 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
205 | with tempfile.NamedTemporaryFile() as temp_file: 206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 207 | 208 | # GET file object 209 | if url.startswith("s3://"): 210 | s3_get(url, temp_file) 211 | else: 212 | http_get(url, temp_file) 213 | 214 | # we are copying the file before closing it, so flush to avoid truncation 215 | temp_file.flush() 216 | # shutil.copyfileobj() starts at the current position, so go to the start 217 | temp_file.seek(0) 218 | 219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 220 | with open(cache_path, 'wb') as cache_file: 221 | shutil.copyfileobj(temp_file, cache_file) 222 | 223 | logger.info("creating metadata file for %s", cache_path) 224 | meta = {'url': url, 'etag': etag} 225 | meta_path = cache_path + '.json' 226 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 227 | json.dump(meta, meta_file) 228 | 229 | logger.info("removing temp file %s", temp_file.name) 230 | 231 | return cache_path 232 | 233 | 234 | def read_set_from_file(filename): 235 | ''' 236 | Extract a de-duped collection (set) of text from a file. 237 | Expected file format is one item per line. 238 | ''' 239 | collection = set() 240 | with open(filename, 'r', encoding='utf-8') as file_: 241 | for line in file_: 242 | collection.add(line.rstrip()) 243 | return collection 244 | 245 | 246 | def get_file_extension(path, dot=True, lower=True): 247 | ext = os.path.splitext(path)[1] 248 | ext = ext if dot else ext[1:] 249 | return ext.lower() if lower else ext 250 | -------------------------------------------------------------------------------- /model/utils/custom_adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class LREQAdam(Optimizer): 7 | def __init__(self, params, lr=1e-3, betas=(0.0, 0.99), eps=1e-8, 8 | weight_decay=0): 9 | beta_2 = betas[1] 10 | if not 0.0 <= lr: 11 | raise ValueError("Invalid learning rate: {}".format(lr)) 12 | if not 0.0 <= eps: 13 | raise ValueError("Invalid epsilon value: {}".format(eps)) 14 | if not 0.0 == betas[0]: 15 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 16 | if not 0.0 <= beta_2 < 1.0: 17 | raise ValueError("Invalid beta parameter at index 1: {}".format(beta_2)) 18 | defaults = dict(lr=lr, beta_2=beta_2, eps=eps, weight_decay=weight_decay) 19 | super(LREQAdam, self).__init__(params, defaults) 20 | 21 | def __setstate__(self, state): 22 | super(LREQAdam, self).__setstate__(state) 23 | 24 | def step(self, closure=None): 25 | """Performs a single optimization step. 26 | Arguments: 27 | closure (callable, optional): A closure that reevaluates the model and returns the loss. 
28 | """ 29 | loss = None 30 | if closure is not None: 31 | loss = closure() 32 | 33 | for group in self.param_groups: 34 | for p in group['params']: 35 | if p.grad is None: 36 | continue 37 | grad = p.grad.data 38 | if grad.is_sparse: 39 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 40 | 41 | state = self.state[p] 42 | 43 | # State initialization 44 | if len(state) == 0: 45 | state['step'] = 0 46 | # Exponential moving average of gradient values 47 | # state['exp_avg'] = torch.zeros_like(p.data) 48 | # Exponential moving average of squared gradient values 49 | state['exp_avg_sq'] = torch.zeros_like(p.data) 50 | 51 | # exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 52 | exp_avg_sq = state['exp_avg_sq'] 53 | beta_2 = group['beta_2'] 54 | 55 | state['step'] += 1 56 | 57 | if group['weight_decay'] != 0: 58 | grad.add_(group['weight_decay'], p.data / p.coef) 59 | 60 | # Decay the first and second moment running average coefficient 61 | # exp_avg.mul_(beta1).add_(1 - beta1, grad) 62 | exp_avg_sq.mul_(beta_2).addcmul_(1 - beta_2, grad, grad) 63 | denom = exp_avg_sq.sqrt().add_(group['eps']) 64 | 65 | # bias_correction1 = 1 - beta1 ** state['step'] # 1 66 | bias_correction2 = 1 - beta_2 ** state['step'] 67 | # step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 68 | step_size = group['lr'] * math.sqrt(bias_correction2) 69 | 70 | # p.data.addcdiv_(-step_size, exp_avg, denom) 71 | if hasattr(p, 'lr_equalization_coef'): 72 | step_size *= p.lr_equalization_coef 73 | 74 | p.data.addcdiv_(-step_size, grad, denom) # p = p + (-steo_size)/grad * demon 75 | 76 | return loss 77 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Improving generative adversarial network inversion via fine-tuning GAN encoders 2 | 3 | ![Python 3.7.3](https://img.shields.io/badge/python-3.7.3-blue.svg?style=plastic) 4 | ![PyTorch 1.8.1](https://img.shields.io/badge/pytorch-1.8.1-blue.svg?style=plastic) 5 | ![Apache-2.0](https://img.shields.io/badge/License-Apache%202.0-green.svg?style=plastic) 6 | 7 | cxx1 cxx2 msk dy zy 8 | 9 | 10 | >This is the official code for "Improving generative adversarial network inversion via 11 | fine-tuning GAN encoders". 12 | 13 | >The code contains a set of encoders that match pre-trained GANs (PGGAN, StyleGAN1, StyleGAN2, BigGAN). BTW, The project can match other GANs in the same way. 14 | 15 | 16 | ## Usage 17 | 18 | 19 | - training encoder with center attentions (scale for align images) 20 | 21 | > python E_align.py 22 | 23 | - training encoder with Grad-CAM-based attentions (scale for misalign images) 24 | 25 | > python E_mis_align.py 26 | 27 | 28 | - embedding real images to latent space (using StyleGANv1 and w). 29 | 30 | ``` 31 | a. You can put real images at './checkpoint/realimg_file/' (default file as args.img_dir) 32 | 33 | b. You should load pre-trained Encoder at './checkpoint/E/E_blur(case2)_styleganv1_FFHQ_state_dict.pth' 34 | 35 | c. Then run: 36 | ``` 37 | 38 | > python embedding_img.py 39 | 40 | - discovering attribute directions with latent space : embedded_img_processing.py 41 | 42 | Note: Pre-trained Model should be download first , and default save to './chechpoint/' 43 | 44 | ## Metric 45 | 46 | - validate performance (Pre-trained GANs and baseline) 47 | 48 | 1. using generations.py to generate reconstructed images (generate GANs images if needed) 49 | 2. 
Files in the directory "./baseline/" can help you quickly format images and latent vectors (w). 50 | 3. Put the images to be compared into separate directories, then run comparing-baseline.py 51 | 52 | 53 | - ablation study: see './ablations-study/' 54 | 55 | 56 | ## Setup 57 | 58 | ### Encoders 59 | 60 | - Case 1: Training encoders for most pre-trained GANs, 61 | at './model/E/E.py' (converges quickly for reconstructing GAN-generated images) 62 | - Case 2: Training a StyleGANv1 encoder on FFHQ for the ablation study and real face image processing, 63 | at './model/E/E_Blur.py' (margin blur; needs more GPU memory for pixel gradients) 64 | 65 | ### Pre-Trained GANs 66 | > note: put the pre-trained GAN weight files in the './checkpoint/' directory 67 | - StyleGAN_V1 (should contain 3 files: Gm, Gs, center-tensor): 68 | - Cat 256: 69 | - ./checkpoint/stylegan_V1/cat/cat256_Gs_dict.pth 70 | - ./checkpoint/stylegan_V1/cat/cat256_Gm_dict.pth 71 | - ./checkpoint/stylegan_V1/cat/cat256_tensor.pt 72 | - Car 256: same as above 73 | - Bedroom 256: 74 | - StyleGAN_V2 (only one file: .pth): 75 | - FFHQ 1024: 76 | - ./checkpoint/stylegan_V2/stylegan2_ffhq1024.pth 77 | - PGGAN (only one file: .pth): 78 | - Horse 256: 79 | - ./checkpoint/PGGAN/ 80 | - BigGAN (two files: model as .pt and config as .json): 81 | - Image-Net 256: 82 | - ./checkpoint/biggan/256/G-256.pt 83 | - ./checkpoint/biggan/256/biggan-deep-256-config.json 84 | 85 | ### Options and Setting 86 | > note: different GANs need different parameters, so set them carefully. 87 | 88 | - choose --mtype: StyleGANv1=1, StyleGANv2=2, PGGAN=3, BIGGAN=4 89 | - choose the encoder's --start_features carefully; the values are: 16->1024x1024, 32->512x512, 64->256x256 90 | - to continue training from a saved encoder, set --checkpoint_dir_E to the path of that pre-trained encoder model 91 | - --checkpoint_dir_GAN is required; for StyleGANv1 it is a directory (containing 3 files: Gm, Gs, center-tensor), for the others it is a file path (.pth or .pt) 92 | ```python 93 | parser = argparse.ArgumentParser(description='the training args') 94 | parser.add_argument('--iterations', type=int, default=210000) # epoch = iterations//30000 95 | parser.add_argument('--lr', type=float, default=0.0015) 96 | parser.add_argument('--beta_1', type=float, default=0.0) 97 | parser.add_argument('--batch_size', type=int, default=2) 98 | parser.add_argument('--experiment_dir', default=None) #None 99 | parser.add_argument('--checkpoint_dir_GAN', default='./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt 100 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it 101 | parser.add_argument('--checkpoint_dir_E', default=None) 102 | parser.add_argument('--img_size',type=int, default=1024) 103 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 104 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128 105 | parser.add_argument('--mtype', type=int, default=2) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4 106 | parser.add_argument('--start_features', type=int, default=16) # 16->1024x1024, 32->512x512, 64->256x256, 128->128x128 107 | ``` 108 | 109 | ## Pre-trained Model 110 | 111 | We offer pre-trained GANs and their corresponding encoders here: [models](https://drive.google.com/drive/folders/1vqx5Sol04MAbeNLk9h0ouo8MiR3rJI4f?usp=sharing) (the default setting is case 1).
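As a quick check after downloading, the sketch below shows one plausible way to wire the StyleGANv1 FFHQ checkpoint to its case-2 encoder for reconstruction. The checkpoint paths and constructor arguments follow the Setup and Usage sections; the input filename, the [-1, 1] preprocessing and the `(const, w)` return order of `E_Blur` are assumptions, so treat this as a sketch and prefer the provided `embedding_img.py` / `rec_real_img.py` scripts. The downloadable GANs and encoders are listed right after this snippet.

```python
import math
import torch
import torchvision
import model.E.E_Blur as BE
from model.stylegan1.net import Generator

img_size = 1024
layer_count = int(math.log(img_size, 2) - 1)   # 9 blocks -> 1024x1024

# generator weights, laid out as in "Pre-Trained GANs" above
Gs = Generator(startf=16, maxf=512, layer_count=layer_count, latent_size=512, channels=3)
Gs.load_state_dict(torch.load('./checkpoint/stylegan_v1/ffhq1024/Gs_dict.pth'))
Gs.cuda().eval()

# case-2 encoder weights, as described in "Usage" above
E = BE.BE(startf=16, maxf=512, layer_count=layer_count, latent_size=512, channels=3)
E.load_state_dict(torch.load('./checkpoint/E/E_blur(case2)_styleganv1_FFHQ_state_dict.pth'))
E.cuda().eval()

# 'face.png' is a placeholder; an aligned 1024x1024 face scaled to [-1, 1] is assumed
img = torchvision.io.read_image('./checkpoint/realimg_file/face.png').float()
img = (img / 127.5 - 1.0).unsqueeze(0).cuda()

with torch.no_grad():
    _, w = E(img)                                         # assumed (const, w); w: [1, 18, 512]
    rec = Gs.forward(w, int(math.log(img_size, 2) - 2))   # 8 -> 1024x1024, as in the training scripts
torchvision.utils.save_image(rec * 0.5 + 0.5, 'rec.png')
```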
112 | 113 | GANs: 114 | 115 | - StyleGANv1 (FFHQ1024, Car512, Cat256): each model contains 3 files: Gm, Gs and center-tensor. 116 | - PGGAN and StyleGANv2: a single .pth file packs Gm, Gs and center-tensor together. 117 | - BigGAN 128x128, 256x256, and 512x512: each type contains a config file and a model (.pt) 118 | 119 | Encoders: 120 | 121 | - StyleGANv1 FFHQ (case 2) for real-image embedding and processing. 122 | - StyleGANv2 LSUN Cat 256: one model from case 1 (Grad-CAM-based attentions) and both models from case 2 (Grad-CAM-based and center-aligned attentions, for the ablation study) 123 | - StyleGANv2 FFHQ (case 1) 124 | - BigGAN-256 (case 1) 125 | 126 | If you want to try more GANs, see the pre-trained GANs cited below: 127 | 128 | 129 | ## Acknowledgements 130 | 131 | Pre-trained GANs: 132 | 133 | > StyleGANv1: https://github.com/podgorskiy/StyleGan.git, 134 | > (converting code for the official pre-trained model is here: https://github.com/podgorskiy/StyleGAN_Blobless.git) 135 | > StyleGANv2 and PGGAN: https://github.com/genforce/genforce.git 136 | > BigGAN: https://github.com/huggingface/pytorch-pretrained-BigGAN 137 | 138 | Comparing Works: 139 | 140 | > E2Style: https://github.com/wty-ustc/e2style 141 | > In-Domain GAN: https://github.com/genforce/idinvert_pytorch 142 | > pSp: https://github.com/eladrich/pixel2style2pixel 143 | > ALAE: https://github.com/podgorskiy/ALAE.git 144 | 145 | 146 | Related Works: 147 | 148 | > Grad-CAM & Grad-CAM++: https://github.com/yizt/Grad-CAM.pytorch 149 | > SSIM Index: https://github.com/Po-Hsun-Su/pytorch-ssim 150 | 151 | Our implementation partly borrows from the above works (ALAE and the Related Works). We would like to thank those authors. 152 | 153 | 154 | ## License 155 | 156 | The code of this repository is released under the [Apache 2.0](LICENSE) license.
The directories `models/biggan` and `models/stylegan2` are provided under the MIT license.
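For reference, the face-editing step described in the [editing faces](./readme_cn.md) guide boils down to adding a scaled attribute direction to an embedded w code and decoding it again. A minimal sketch is below; the file names, the direction shape and the coefficient 10 are placeholders, and the API calls follow the StyleGANv1 loading and forward code used elsewhere in this repository.

```python
import math
import numpy as np
import torch
import torchvision
from model.stylegan1.net import Generator  # StyleGANv1

img_size = 1024
Gs = Generator(startf=16, maxf=512, layer_count=int(math.log(img_size, 2) - 1),
               latent_size=512, channels=3)
Gs.load_state_dict(torch.load('./checkpoint/stylegan_v1/ffhq1024/Gs_dict.pth'))
Gs.cuda()

w = torch.load('./latentvectors/faces/face01.pt').cuda()       # embedded code, assumed [1, 18, 512] (placeholder file name)
d = np.load('./latentvectors/directions/smile.npy')            # attribute direction (placeholder file name)
d = torch.tensor(d, dtype=torch.float).cuda().view(1, 1, 512)  # assumed to broadcast over the 18 layers

with torch.no_grad():
    w_edit = w + 10 * d                                        # coefficient: multiples of 5 or 10
    img = Gs.forward(w_edit, int(math.log(img_size, 2) - 2))   # 8 -> 1024x1024
torchvision.utils.save_image(img * 0.5 + 0.5, './edited_face.png')
```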
157 | 158 | ## 简体中文 (Simplified Chinese): 159 | 160 | How to apply this to [editing faces](./readme_cn.md) 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /readme_cn.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Building encoders with mainstream GANs and editing high-resolution face images 4 | 5 | 6 | cxx1 cxx2 msk dy zy 7 | 8 | The face animations above are all generated from a single image. This post introduces a concise method: using StyleGAN (v1) to edit an image and change its attributes (e.g., make the face smile, add glasses, turn the head, etc.). 9 | 10 | ## 1. Loading StyleGAN (v1) 11 | 12 | First, download the pre-trained model and make sure that 1024x1024 StyleGAN face images can be generated (i.e., reproduce StyleGAN). 13 | 14 | Note that the pre-trained model is stored as three parts: 15 | 16 | a. Gm, which maps z [n, 512] to w [n, 18, 512], where n is the batch_size 17 | 18 | b. Gs, which takes w as input and outputs the corresponding images; note that w is fed to each layer as a [n,512] vector; there are 18 layers, i.e. [n,18,512] 19 | 20 | c. avg_tensor, a trained constant tensor used as the first input of the model, [n, 512, 4, 4] (while the w above is fed into the individual layers separately) 21 | 22 | The StyleGANv1 pre-trained model is provided here: [Model](https://pan.baidu.com/s/1_JewahCd_UK5wIMCQzFAPA 23 | ), extraction code: kwsk 24 | 25 | ## 2. Encoding real images 26 | 27 | With the encoder we provide, a 1024x1024 face image is encoded into the latent variable W (1,18,512); multiple faces can also be processed at once (1->n), depending on your GPU memory. Note that the faces should preferably have aligned facial features. The encoded W is saved to a folder (default: ./latentvectors/faces/), which already contains several faces that can be used to test the StyleGAN above. 28 | 29 | Run the following file: 30 | 31 | > python embedding_img.py 32 | 33 | a. This file loads the pre-trained encoder 34 | > the default path is ./checkpoint/E/E_blur(case2)_styleganv1_FFHQ_state_dict.pth 35 | 36 | b. and encodes the real images 37 | > (default path ./checkpoint/realimg_file/) into w (default path ./result/models/) 38 | 39 | c. The pre-trained encoder is provided here: [Model](https://pan.baidu.com/s/1F9Tv5ph9Rejp5JTQK2HSYQ 40 | ), extraction code: swtl 41 | 42 | ## 3. Editing expressions 43 | 44 | Multiply an attribute vector from './latentvectors/directions/' by a coefficient (a multiple of 5 or 10), add it to the face vector, then explore and save the resulting face changes; see: 45 | 46 | > embedded_img_processing.py 47 | 48 | If you want to explore other face attributes, there are many methods; the simplest supervised approach can be found at: https://github.com/Puzer/stylegan-encoder (it also includes face alignment) 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /rec_real_img.py: -------------------------------------------------------------------------------- 1 | #reconstructing real images directly with the MTV encoder E (without optimizing E via StyleGANv1) 2 | #set args.realimg_dir for embedding real images to latent vectors 3 | 4 | import os 5 | import math 6 | import torch 7 | import argparse 8 | import numpy as np 9 | import torchvision 10 | import model.E.E_Blur as BE 11 | from collections import OrderedDict 12 | from training_utils import * 13 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 14 | import model.stylegan2_generator as model_v2 #StyleGANv2 15 | import model.pggan.pggan_generator as model_pggan #PGGAN 16 | from model.biggan_generator import BigGAN #BigGAN 17 | 18 | if __name__ == "__main__": 19 | 20 | if not os.path.exists('./images'): os.mkdir('./images') 21 | 22 | parser = argparse.ArgumentParser(description='the training args') 23 | parser.add_argument('--batch_size', type=int, default=2) 24 | parser.add_argument('--experiment_dir', default=None) #None 25 | parser.add_argument('--checkpoint_dir_gan', default='./checkpoint/stylegan_v1/ffhq1024/') # stylegan_v2/stylegan2_ffhq1024.pth 26 | parser.add_argument('--checkpoint_dir_e', default='./checkpoint/E/E_styleganv1_state_dict.pth') #None or E_ffhq_styleganv2_modelv2_ep110000.pth E_ffhq_styleganv2_modelv1_ep85000.pth 27 | parser.add_argument('--config_dir', default=None) 28 | parser.add_argument('--realimg_dir', default='./images/real_images128/') 29 | parser.add_argument('--img_size',type=int, default=1024) 30 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 31 | parser.add_argument('--z_dim', type=int, default=512) 32 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, 
BigGAN=4 33 | parser.add_argument('--start_features', type=int, default=16) 34 | args = parser.parse_args() 35 | 36 | use_gpu = True 37 | device = torch.device("cuda" if use_gpu else "cpu") 38 | 39 | resultPath1_1 = "./images/imgs" 40 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1) 41 | 42 | resultPath1_2 = "./images/rec" 43 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2) 44 | 45 | #Load GANs 46 | type = args.mtype 47 | model_path = args.checkpoint_dir_gan 48 | config_path = args.config_dir 49 | 50 | if type == 1: # StyleGAN1, 1 diretory contains 3files(Gm, Gs, center-tensor) 51 | #model_path = './checkpoint/stylegan_v1/ffhq1024/' 52 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 53 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth')) 54 | 55 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024 56 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth')) 57 | 58 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt') 59 | const_ = Gs.const 60 | const1 = const_.repeat(args.batch_size,1,1,1).cuda() 61 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024 62 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 63 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 64 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # 18个变量前8个裁剪比例truncation_psi [0.7,0.7,...,1,1,1] 65 | 66 | Gs.cuda() 67 | Gm.eval() 68 | 69 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 70 | 71 | else: 72 | print('error') 73 | 74 | #Load E 75 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) .to(device) 76 | if args.checkpoint_dir_e is not None: 77 | E.load_state_dict(torch.load(args.checkpoint_dir_e, map_location=torch.device(device))) 78 | 79 | type = args.mtype 80 | save_path = args.realimg_dir 81 | imgs_path = [os.path.join(save_path, f) for f in os.listdir(save_path) if f.endswith(".png") or f.endswith(".jpg")] 82 | img_size = args.img_size 83 | 84 | #PIL 2 Tensor 85 | transform = torchvision.transforms.Compose([ 86 | #torchvision.transforms.CenterCrop(160), 87 | torchvision.transforms.Resize((img_size,img_size)), 88 | torchvision.transforms.ToTensor(), 89 | #torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 90 | ]) 91 | 92 | images = [] 93 | for idx, image_path in enumerate(imgs_path): 94 | img = Image.open(image_path).convert("RGB") 95 | img = transform(img) 96 | images.append(img) 97 | 98 | imgs_tensor = torch.stack(images, dim=0) 99 | 100 | for i, j in enumerate(imgs_tensor): 101 | imgs1 = j.unsqueeze(0).cuda() * 2 -1 102 | if type != 4: 103 | const2,w2 = E(imgs1) 104 | else: 105 | const2,w2 = E(imgs1, cond_vector) 106 | 107 | if type == 1: 108 | imgs2=Gs.forward(w2,int(math.log(args.img_size,2)-2)) 109 | elif type == 2 or type == 3: 110 | imgs2=generator.synthesis(w2)['image'] 111 | elif type == 4: 112 | imgs2, _=G(w2, conditions, truncation) 113 | else: 114 | print('model type error') 115 | 116 | # n_row = args.batch_size 117 | # test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5 118 | # torchvision.utils.save_image(test_img, 
'./v2ep%d.jpg'%(args.seed),nrow=n_row) # nrow=3 119 | torchvision.utils.save_image(imgs1*0.5+0.5, resultPath1_1+'/%s_realimg.png'%str(i).rjust(5,'0')) 120 | torchvision.utils.save_image(imgs2*0.5+0.5, resultPath1_2+'/%s_mtv_rec.png'%str(i).rjust(5,'0')) 121 | print('doing:'+str(i).rjust(5,'0')) 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | #pip install belows: 2 | 3 | 4 | python==3.7.3 5 | 6 | torch>=1.8 7 | 8 | torchvision 9 | 10 | tensorboardx 11 | 12 | lpips 13 | 14 | boto3 15 | 16 | skimage 17 | 18 | matplotlib 19 | 20 | scipy 21 | 22 | tqdm 23 | 24 | requests 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /synthesized_IMG.py: -------------------------------------------------------------------------------- 1 | # make styleganv1 and BigGAN synthesized images for validation. seeds should be different(0,30000) 2 | 3 | import os 4 | import math 5 | import torch 6 | import torchvision 7 | import model.E as BE #default styleganv1, if you will reconstruct other GANs, change corresponding E in here. 8 | import lpips 9 | import numpy as np 10 | import argparse 11 | import tensorboardX 12 | from collections import OrderedDict 13 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 14 | import model.stylegan2_generator as model_v2 #StyleGANv2 15 | import model.pggan.pggan_generator as model_pggan #PGGAN 16 | from model.biggan_generator import BigGAN #BigGAN 17 | from training_utils import * 18 | from model.utils.biggan_config import BigGANConfig 19 | import model.E.E_BIG as BE_BIG 20 | 21 | def train(tensor_writer = None, args = None): 22 | type = args.mtype 23 | 24 | model_path = args.checkpoint_dir_GAN 25 | config_path = args.config_dir 26 | if type == 1: # StyleGAN1 27 | #model_path = './checkpoint/stylegan_v1/ffhq1024/' 28 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 29 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth')) 30 | 31 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024 32 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth')) 33 | 34 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt') 35 | const_ = Gs.const 36 | const1 = const_.repeat(args.batch_size,1,1,1).cuda() 37 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024 38 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 39 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 40 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # 18个变量前8个裁剪比例truncation_psi [0.7,0.7,...,1,1,1] 41 | 42 | Gs.cuda() 43 | Gm.eval() 44 | 45 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 46 | 47 | elif type == 2: # StyleGAN2 48 | #model_path = './checkpoint/stylegan_v2/stylegan2_ffhq1024.pth' 49 | generator = model_v2.StyleGAN2Generator(resolution=args.img_size).to(device) 50 | checkpoint = torch.load(model_path) #map_location='cpu' 51 | if 'generator_smooth' in checkpoint: #default 52 | generator.load_state_dict(checkpoint['generator_smooth']) 
53 | else: 54 | generator.load_state_dict(checkpoint['generator']) 55 | synthesis_kwargs = dict(trunc_psi=0.7,trunc_layers=8,randomize_noise=False) 56 | #Gs = generator.synthesis 57 | #Gs.cuda() 58 | #Gm = generator.mapping 59 | #truncation = generator.truncation 60 | const_r = torch.randn(args.batch_size) 61 | const1 = generator.synthesis.early_layer(const_r) #[n,512,4,4] 62 | 63 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) # layer_count: 7->256 8->512 9->1024 64 | 65 | elif type == 3: # PGGAN 66 | #model_path = './checkpoint/PGGAN/pggan_horse256.pth' 67 | generator = model_pggan.PGGANGenerator(resolution=args.img_size).to(device) 68 | checkpoint = torch.load(model_path) #map_location='cpu' 69 | if 'generator_smooth' in checkpoint: #默认是这个 70 | generator.load_state_dict(checkpoint['generator_smooth']) 71 | else: 72 | generator.load_state_dict(checkpoint['generator']) 73 | const1 = torch.tensor(0) 74 | 75 | E = BE_PG.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3, pggan=True) 76 | 77 | elif type == 4: 78 | model_path = './checkpoint/biggan/256/G-256.pt' 79 | config_file = './checkpoint/biggan/256/biggan-deep-256-config.json' 80 | config = BigGANConfig.from_json_file(config_file) 81 | generator = BigGAN(config).to(device) 82 | generator.load_state_dict(torch.load(model_path)) 83 | 84 | E = BE_BIG.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3, biggan=True).to(device) 85 | 86 | else: 87 | print('error') 88 | return 89 | 90 | #Load E 91 | # E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) .to(device) 92 | if args.checkpoint_dir_E is not None: 93 | E.load_state_dict(torch.load(args.checkpoint_dir_E, map_location=torch.device(device))) 94 | 95 | batch_size = args.batch_size 96 | it_d = 0 97 | for iteration in range(30000,30000+args.iterations): 98 | set_seed(iteration) 99 | z = torch.randn(batch_size, args.z_dim) #[32, 512] 100 | if type == 1: 101 | with torch.no_grad(): #这里需要生成图片和变量 102 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512] 103 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256 104 | elif type == 2: 105 | with torch.no_grad(): 106 | #use generator 107 | result_all = generator(z.cuda(), **synthesis_kwargs) 108 | imgs1 = result_all['image'] 109 | w1 = result_all['wp'] 110 | 111 | elif type == 3: 112 | with torch.no_grad(): #这里需要生成图片和变量 113 | w1 = z.cuda() 114 | result_all = generator(w1) 115 | imgs1 = result_all['image'] 116 | elif type == 4: 117 | z = truncated_noise_sample(truncation=0.4, batch_size=batch_size, seed=iteration%30000) 118 | #label = np.random.randint(1000,size=batch_size) # 生成标签 119 | flag = np.array(30) 120 | print(flag) 121 | label = np.ones(batch_size) 122 | label = flag * label 123 | label = one_hot(label) 124 | w1 = torch.tensor(z, dtype=torch.float).cuda() 125 | conditions = torch.tensor(label, dtype=torch.float).cuda() # as label 126 | truncation = torch.tensor(0.4, dtype=torch.float).cuda() 127 | with torch.no_grad(): #这里需要生成图片和变量 128 | imgs1, const1 = generator(w1, conditions, truncation) # const1 are conditional vectors in BigGAN 129 | 130 | if type != 4: 131 | const2,w2 = E(imgs1) 132 | else: 133 | const2,w2 = E(imgs1, const1) 134 | 135 | if type == 1: 136 | imgs2=Gs.forward(w2,int(math.log(args.img_size,2)-2)) 137 | elif type == 2 or type == 
3: 138 | imgs2=generator.synthesis(w2)['image'] 139 | elif type == 4: 140 | imgs2, _=generator(w2, conditions, truncation) 141 | else: 142 | print('model type error') 143 | return 144 | 145 | imgs = torch.cat((imgs1,imgs2)) 146 | torchvision.utils.save_image(imgs*0.5+0.5, resultPath1_1+'/id%s_%s.png'%(flag, str(iteration-30000).rjust(5,'0')), nrow=10) # nrow=3 147 | #torchvision.utils.save_image(imgs2*0.5+0.5, resultPath1_2+'/%s_styleganv1_rec.png'%str(iteration-30000).rjust(5,'0')) # nrow=3 148 | 149 | if __name__ == "__main__": 150 | parser = argparse.ArgumentParser(description='the training args') 151 | parser.add_argument('--iterations', type=int, default=5) 152 | parser.add_argument('--seed', type=int, default=30001) # training seeds: 0-30000; validated seeds > 30000 153 | parser.add_argument('--lr', type=float, default=0.0015) 154 | parser.add_argument('--beta_1', type=float, default=0.0) 155 | parser.add_argument('--batch_size', type=int, default=10) 156 | parser.add_argument('--experiment_dir', default=None) #None 157 | parser.add_argument('--checkpoint_dir_GAN', default='./checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth 158 | parser.add_argument('--config_dir', default=None) # BigGAN needs it 159 | parser.add_argument('--checkpoint_dir_E', default='./result/BigGAN-256/models/E_model_ep30000.pth') 160 | parser.add_argument('--img_size',type=int, default=256) 161 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 162 | parser.add_argument('--z_dim', type=int, default=128) 163 | parser.add_argument('--mtype', type=int, default=4) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4 164 | parser.add_argument('--start_features', type=int, default=64) # 16->1024 32->512 64->256 165 | args = parser.parse_args() 166 | 167 | if not os.path.exists('./result'): os.mkdir('./result') 168 | resultPath = args.experiment_dir 169 | if resultPath == None: 170 | resultPath = "./result/BigGAN256_inversion_v4" 171 | if not os.path.exists(resultPath): os.mkdir(resultPath) 172 | 173 | resultPath1_1 = resultPath+"/generations" 174 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1) 175 | 176 | resultPath1_2 = resultPath+"/reconstructions" 177 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2) 178 | 179 | writer_path = os.path.join(resultPath, './summaries') 180 | if not os.path.exists(writer_path): os.mkdir(writer_path) 181 | writer = tensorboardX.SummaryWriter(writer_path) 182 | 183 | use_gpu = True 184 | device = torch.device("cuda" if use_gpu else "cpu") 185 | 186 | train(tensor_writer=writer, args = args) 187 | -------------------------------------------------------------------------------- /synthesized_textBigGAN.py: -------------------------------------------------------------------------------- 1 | # make styleganv1 and BigGAN synthesized images for validation. seeds should be different seeds (0,30000) 2 | 3 | import os 4 | import math 5 | import torch 6 | import torchvision 7 | import model.E as BE #default styleganv1, if you will reconstruct other GANs, change corresponding E in here. 
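# Note: the import above targets the StyleGANv1 encoder; the BigGAN branch below uses
# model.E.E_BIG (imported further down as BE_BIG), while the PGGAN branch references
# BE_PG, which is not imported in this file -- an import such as
# `import model.E.E_PG as BE_PG` (assumed module path) would be needed before running
# with --mtype 3.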
8 | import lpips 9 | import numpy as np 10 | import argparse 11 | import tensorboardX 12 | from collections import OrderedDict 13 | from model.stylegan1.net import Generator, Mapping #StyleGANv1 14 | import model.stylegan2_generator as model_v2 #StyleGANv2 15 | import model.pggan.pggan_generator as model_pggan #PGGAN 16 | from model.biggan_generator import BigGAN #BigGAN 17 | from training_utils import * 18 | from model.utils.biggan_config import BigGANConfig 19 | import model.E.E_BIG as BE_BIG 20 | 21 | def train(tensor_writer = None, args = None): 22 | type = args.mtype 23 | 24 | model_path = args.checkpoint_dir_GAN 25 | config_path = args.config_dir 26 | if type == 1: # StyleGAN1 27 | #model_path = './checkpoint/stylegan_v1/ffhq1024/' 28 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 29 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth')) 30 | 31 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024 32 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth')) 33 | 34 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt') 35 | const_ = Gs.const 36 | const1 = const_.repeat(args.batch_size,1,1,1).cuda() 37 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024 38 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,3,4,5,6。。。,17] 39 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,1,1,1,1,1,1] 40 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # 18个变量前8个裁剪比例truncation_psi [0.7,0.7,...,1,1,1] 41 | 42 | Gs.cuda() 43 | Gm.eval() 44 | 45 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) 46 | 47 | elif type == 2: # StyleGAN2 48 | #model_path = './checkpoint/stylegan_v2/stylegan2_ffhq1024.pth' 49 | generator = model_v2.StyleGAN2Generator(resolution=args.img_size).to(device) 50 | checkpoint = torch.load(model_path) #map_location='cpu' 51 | if 'generator_smooth' in checkpoint: #default 52 | generator.load_state_dict(checkpoint['generator_smooth']) 53 | else: 54 | generator.load_state_dict(checkpoint['generator']) 55 | synthesis_kwargs = dict(trunc_psi=0.7,trunc_layers=8,randomize_noise=False) 56 | #Gs = generator.synthesis 57 | #Gs.cuda() 58 | #Gm = generator.mapping 59 | #truncation = generator.truncation 60 | const_r = torch.randn(args.batch_size) 61 | const1 = generator.synthesis.early_layer(const_r) #[n,512,4,4] 62 | 63 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) # layer_count: 7->256 8->512 9->1024 64 | 65 | elif type == 3: # PGGAN 66 | #model_path = './checkpoint/PGGAN/pggan_horse256.pth' 67 | generator = model_pggan.PGGANGenerator(resolution=args.img_size).to(device) 68 | checkpoint = torch.load(model_path) #map_location='cpu' 69 | if 'generator_smooth' in checkpoint: #默认是这个 70 | generator.load_state_dict(checkpoint['generator_smooth']) 71 | else: 72 | generator.load_state_dict(checkpoint['generator']) 73 | const1 = torch.tensor(0) 74 | 75 | E = BE_PG.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3, pggan=True) 76 | 77 | elif type == 4: 78 | model_path = './checkpoint/biggan/256/G-256.pt' 79 | config_file = 
'./checkpoint/biggan/256/biggan-deep-256-config.json' 80 | config = BigGANConfig.from_json_file(config_file) 81 | generator = BigGAN(config).to(device) 82 | generator.load_state_dict(torch.load(model_path)) 83 | 84 | E = BE_BIG.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3, biggan=True).to(device) 85 | 86 | else: 87 | print('error') 88 | return 89 | 90 | #Load E 91 | # E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) .to(device) 92 | if args.checkpoint_dir_E is not None: 93 | E.load_state_dict(torch.load(args.checkpoint_dir_E, map_location=torch.device(device))) 94 | 95 | batch_size = args.batch_size 96 | it_d = 0 97 | for iteration in range(30000,30000+args.iterations): 98 | set_seed(iteration) 99 | z = torch.randn(batch_size, args.z_dim) #[32, 512] 100 | if type == 1: 101 | with torch.no_grad(): #这里需要生成图片和变量 102 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512] 103 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256 104 | elif type == 2: 105 | with torch.no_grad(): 106 | #use generator 107 | result_all = generator(z.cuda(), **synthesis_kwargs) 108 | imgs1 = result_all['image'] 109 | w1 = result_all['wp'] 110 | 111 | elif type == 3: 112 | with torch.no_grad(): #这里需要生成图片和变量 113 | w1 = z.cuda() 114 | result_all = generator(w1) 115 | imgs1 = result_all['image'] 116 | elif type == 4: 117 | z = truncated_noise_sample(truncation=0.4, batch_size=batch_size, seed=iteration%30000) 118 | #label = np.random.randint(1000,size=batch_size) # 生成标签 119 | flag = np.array(726) 120 | print(flag) 121 | label = np.ones(batch_size) 122 | label = flag * label 123 | label = one_hot(label) 124 | w1 = torch.tensor(z, dtype=torch.float).cuda() 125 | conditions = torch.tensor(label, dtype=torch.float).cuda() # as label 126 | truncation = torch.tensor(0.4, dtype=torch.float).cuda() 127 | with torch.no_grad(): #这里需要生成图片和变量 128 | imgs1, const1 = generator(w1, conditions, truncation) # const1 are conditional vectors in BigGAN 129 | 130 | if type != 4: 131 | const2,w2 = E(imgs1) 132 | else: 133 | const2,w2 = E(imgs1, const1) 134 | 135 | if type == 1: 136 | imgs2=Gs.forward(w2,int(math.log(args.img_size,2)-2)) 137 | elif type == 2 or type == 3: 138 | imgs2=generator.synthesis(w2)['image'] 139 | elif type == 4: 140 | imgs2, _=generator(w2, conditions, truncation) 141 | else: 142 | print('model type error') 143 | return 144 | 145 | imgs = torch.cat((imgs1,imgs2)) 146 | torchvision.utils.save_image(imgs*0.5+0.5, resultPath1_1+'/id%s_%s.png'%(flag, str(iteration-30000).rjust(5,'0')), nrow=10) # nrow=3 147 | #torchvision.utils.save_image(imgs2*0.5+0.5, resultPath1_2+'/%s_styleganv1_rec.png'%str(iteration-30000).rjust(5,'0')) # nrow=3 148 | 149 | if __name__ == "__main__": 150 | parser = argparse.ArgumentParser(description='the training args') 151 | parser.add_argument('--iterations', type=int, default=3) 152 | parser.add_argument('--seed', type=int, default=30001) # training seeds: 0-30000; validated seeds > 30000 153 | parser.add_argument('--lr', type=float, default=0.0015) 154 | parser.add_argument('--beta_1', type=float, default=0.0) 155 | parser.add_argument('--batch_size', type=int, default=10) 156 | parser.add_argument('--experiment_dir', default=None) #None 157 | parser.add_argument('--checkpoint_dir_GAN', default='./checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth 158 | 
parser.add_argument('--config_dir', default=None) # BigGAN needs it 159 | parser.add_argument('--checkpoint_dir_E', default='./result/BigGAN-256/models/E_model_ep30000.pth') 160 | parser.add_argument('--img_size',type=int, default=256) 161 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1 162 | parser.add_argument('--z_dim', type=int, default=128) 163 | parser.add_argument('--mtype', type=int, default=4) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4 164 | parser.add_argument('--start_features', type=int, default=64) # 16->1024 32->512 64->256 165 | args = parser.parse_args() 166 | 167 | if not os.path.exists('./result'): os.mkdir('./result') 168 | resultPath = args.experiment_dir 169 | if resultPath == None: 170 | resultPath = "./result/BigGAN256_inversion_v4" 171 | if not os.path.exists(resultPath): os.mkdir(resultPath) 172 | 173 | resultPath1_1 = resultPath+"/generations" 174 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1) 175 | 176 | resultPath1_2 = resultPath+"/reconstructions" 177 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2) 178 | 179 | writer_path = os.path.join(resultPath, './summaries') 180 | if not os.path.exists(writer_path): os.mkdir(writer_path) 181 | writer = tensorboardX.SummaryWriter(writer_path) 182 | 183 | use_gpu = True 184 | device = torch.device("cuda" if use_gpu else "cpu") 185 | 186 | train(tensor_writer=writer, args = args) 187 | -------------------------------------------------------------------------------- /training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.stats import truncnorm 4 | import metric.pytorch_ssim as pytorch_ssim 5 | from torch.nn import functional as F 6 | from PIL import Image 7 | import torchvision 8 | 9 | #img_path2tensor 10 | loader = torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) 11 | def imgPath2loader(image_name,size): 12 | image = Image.open(image_name).convert('RGB') 13 | image = image.resize((size,size)) 14 | image = loader(image)#.unsqueeze(0) 15 | return image.to(torch.float) 16 | 17 | def get_parameter_number(net): 18 | total_num = sum(p.numel() for p in net.parameters()) 19 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 20 | return {'Total': total_num, 'Trainable': trainable_num} 21 | 22 | def get_para_GByte(parameter_number): 23 | x=parameter_number['Total']*8/1024/1024/1024 24 | y=parameter_number['Total']*8/1024/1024/1024 25 | return {'Total_GB': x, 'Trainable_BG': y} 26 | 27 | def one_hot(x, class_count=1000): 28 | # 第一构造一个[class_count, class_count]的对角线为1的向量 29 | # 第二保留label对应的行并返回 30 | return torch.eye(class_count)[x,:] 31 | 32 | def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None): 33 | """ Create a truncated noise vector. 34 | Params: 35 | batch_size: batch size. 
36 | dim_z: dimension of z 37 | truncation: truncation value to use 38 | seed: seed for the random generator 39 | Output: 40 | array of shape (batch_size, dim_z) 41 | """ 42 | state = None if seed is None else np.random.RandomState(seed) 43 | values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32) 44 | return truncation * values 45 | 46 | def set_seed(seed): # set the random seeds 47 | np.random.seed(seed) 48 | #random.seed(seed) 49 | torch.manual_seed(seed) # cpu 50 | torch.cuda.manual_seed_all(seed) # gpu 51 | torch.backends.cudnn.deterministic = True 52 | #torch.backends.cudnn.benchmark = False 53 | 54 | def space_loss(imgs1,imgs2,image_space=True,lpips_model=None): 55 | loss_mse = torch.nn.MSELoss() 56 | loss_kl = torch.nn.KLDivLoss() 57 | ssim_loss = pytorch_ssim.SSIM() 58 | loss_lpips = lpips_model 59 | 60 | imgs1 = imgs1.contiguous() 61 | imgs2 = imgs2.contiguous() 62 | 63 | loss_imgs_mse_1 = loss_mse(imgs1,imgs2) 64 | loss_imgs_mse_2 = loss_mse(imgs1.mean(),imgs2.mean()) 65 | loss_imgs_mse_3 = loss_mse(imgs1.std(),imgs2.std()) 66 | loss_imgs_mse = loss_imgs_mse_1 #+ loss_imgs_mse_2 + loss_imgs_mse_3 67 | 68 | imgs1_kl, imgs2_kl = torch.nn.functional.softmax(imgs1),torch.nn.functional.softmax(imgs2) 69 | loss_imgs_kl = loss_kl(torch.log(imgs2_kl),imgs1_kl) #D_kl(True=y1_imgs||Fake=y2_imgs) 70 | loss_imgs_kl = torch.where(torch.isnan(loss_imgs_kl),torch.full_like(loss_imgs_kl,0), loss_imgs_kl) 71 | loss_imgs_kl = torch.where(torch.isinf(loss_imgs_kl),torch.full_like(loss_imgs_kl,1), loss_imgs_kl) 72 | 73 | imgs1_cos = imgs1.view(-1) 74 | imgs2_cos = imgs2.view(-1) 75 | loss_imgs_cosine = 1 - imgs1_cos.dot(imgs2_cos)/(torch.sqrt(imgs1_cos.dot(imgs1_cos))*torch.sqrt(imgs2_cos.dot(imgs2_cos))) # in [-1,1]; -1: opposite direction, 1: same direction 76 | 77 | if imgs1.view(-1).shape[0] != imgs2.view(-1).shape[0]: 78 | print('error: vector1 dimensions are not equal to vector2 dimensions') 79 | return 80 | 81 | if image_space: 82 | while imgs1.shape[2] > 256: 83 | imgs1 = F.avg_pool2d(imgs1,2,2) 84 | imgs2 = F.avg_pool2d(imgs2,2,2) 85 | 86 | if image_space: 87 | ssim_value = pytorch_ssim.ssim(imgs1, imgs2) # while ssim_value<0.999: 88 | loss_imgs_ssim = 1-ssim_loss(imgs1, imgs2) 89 | else: 90 | loss_imgs_ssim = torch.tensor(0) 91 | 92 | if image_space: 93 | loss_imgs_lpips = loss_lpips(imgs1,imgs2).mean() 94 | else: 95 | loss_imgs_lpips = torch.tensor(0) 96 | 97 | loss_imgs = 5*loss_imgs_mse + 3*loss_imgs_cosine + loss_imgs_ssim + 2*loss_imgs_lpips # loss_imgs_kl 98 | loss_info = [[loss_imgs_mse_1.item(),loss_imgs_mse_2.item(),loss_imgs_mse_3.item()], loss_imgs_kl.item(), loss_imgs_cosine.item(), loss_imgs_ssim.item(), loss_imgs_lpips.item()] 99 | return loss_imgs, loss_info --------------------------------------------------------------------------------
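For reference, a minimal usage sketch of `space_loss` as defined above; the tensors, the LPIPS model and the 0.01 weighting are illustrative placeholders rather than values prescribed by the repository.

```python
import lpips
import torch
from training_utils import space_loss, set_seed

set_seed(0)
loss_lpips = lpips.LPIPS(net='vgg').cuda()

# toy stand-ins for generated and reconstructed images in [-1, 1]
imgs1 = torch.rand(2, 3, 256, 256).cuda() * 2 - 1
imgs2 = torch.rand(2, 3, 256, 256).cuda() * 2 - 1
# toy stand-ins for latent codes (14 layers at 256x256)
w1 = torch.randn(2, 14, 512).cuda()
w2 = torch.randn(2, 14, 512).cuda()

# image space: weighted MSE + cosine + SSIM + LPIPS, as combined inside space_loss
loss_imgs, loss_imgs_info = space_loss(imgs1, imgs2, lpips_model=loss_lpips)

# latent space: the SSIM and LPIPS terms are skipped when image_space=False
loss_w, loss_w_info = space_loss(w1, w2, image_space=False)

total = loss_imgs + 0.01 * loss_w  # illustrative weighting of the two terms
```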