├── E_align_cropping_s1.py
├── E_align_s2.py
├── E_mis_align_cropping_s1.py
├── ablation_utils
│   ├── 1.E_align_z.py
│   ├── 2.E_align_w_2.py
│   ├── 3.E_align_w.py
│   ├── 4.E_align_w_zn.py
│   ├── 5.E_align_w_zn_zc.py
│   ├── 6.E_align_x.py
│   ├── 7.E_align_x_AT1.py
│   ├── 8.E_align_x_AT1_AT2.py
│   └── Cat256
│       ├── E_align_case_1.py
│       ├── E_align_case_2.py
│       ├── E_mis_align_case_1.py
│       └── E_mis_align_case_2.py
├── baseline_utils
│   ├── image2stylegan_w2z_opW.py
│   ├── test-baseline-IndomainG.py
│   ├── test_baseline_alae.py
│   └── test_baseline_psp.py
├── comparing-baseline.py
├── embedding_img.py
├── embedding_v2_BigGAN.py
├── embedding_v2_styleGAN1.py
├── embedding_v2_styleGAN2.py
├── embeded_img_edit.py
├── image_results
│   ├── baseline
│   │   ├── ALAE-rec.jpg
│   │   ├── Real.jpg
│   │   ├── indomain_images_inv.jpg
│   │   ├── indomain_images_rec.jpg
│   │   ├── mtv-tsa-(1500E).jpg
│   │   └── psp-rec.jpg
│   ├── cxx1.gif
│   ├── cxx2.gif
│   ├── dy.gif
│   ├── msk.gif
│   └── zy.gif
├── inferE.py
├── latent_code
│   ├── directions
│   │   ├── stylegan_ffhq_age_w_boundary.npy
│   │   ├── stylegan_ffhq_eyeglasses_w_boundary.npy
│   │   ├── stylegan_ffhq_gender_w_boundary.npy
│   │   ├── stylegan_ffhq_pose_w_boundary.npy
│   │   └── stylegan_ffhq_smile_w_boundary.npy
│   └── real_face_code
│       ├── i0_cxx1.pt
│       ├── i1_dy.pt
│       ├── i2_zy.pt
│       ├── i3_cxx2.pt
│       ├── i4_msk.pt
│       └── i5_ty.pt
├── metric
│   ├── grad_cam.py
│   └── pytorch_ssim.py
├── model
│   ├── E
│   │   ├── Ablation_Study
│   │   │   ├── E_Blur_W.py
│   │   │   ├── E_Blur_W_2.py
│   │   │   ├── E_Blur_Z.py
│   │   │   ├── E_v1.py
│   │   │   └── E_v2_std.py
│   │   ├── E.py
│   │   ├── E_BIG.py
│   │   ├── E_Blur.py
│   │   └── E_PG.py
│   ├── biggan_generator.py
│   ├── pggan
│   │   ├── pggan_d2e.py
│   │   ├── pggan_discriminator.py
│   │   ├── pggan_generator.py
│   │   └── utils
│   │       ├── CustomLayers.py
│   │       ├── Encoder.py
│   │       └── Networks.py
│   ├── stylegan1
│   │   ├── alae.py
│   │   ├── custom_adam.py
│   │   ├── lod_driver.py
│   │   ├── losses.py
│   │   ├── lreq.py
│   │   ├── model.py
│   │   ├── net.py
│   │   └── text_alae.py
│   ├── stylegan2_generator.py
│   └── utils
│       ├── CustomLayers.py
│       ├── biggan_config.py
│       ├── biggan_file_utils.py
│       ├── custom_adam.py
│       ├── lreq.py
│       └── net.py
├── readme.md
├── readme_cn.md
├── rec_real_img.py
├── requirements.txt
├── synthesized_IMG.py
├── synthesized_textBigGAN.py
└── training_utils.py
/ablation_utils/1.E_align_z.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | import os
4 | import math
5 | import torch
6 | import torchvision
7 | import model.E.Ablation_Study.E_Blur_Z as BE
8 | from model.utils.custom_adam import LREQAdam
9 | import metric.pytorch_ssim as pytorch_ssim
10 | import lpips
11 | import numpy as np
12 | import tensorboardX
13 | import argparse
14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
15 | from training_utils import *
16 |
17 | def train(tensor_writer = None, args = None):
18 | type = args.mtype
19 |
20 | model_path = args.checkpoint_dir_GAN
21 | if type == 1: # StyleGAN1
22 | #model_path = './checkpoint/stylegan_v1/ffhq1024/'
23 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
24 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth'))
25 |
26 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
27 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth'))
28 |
29 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt')
30 | const_ = Gs.const
31 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda()
32 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024
33 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape: [1,18,1], layer_idx = [0,1,2,...,17]
34 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape: [1,18,1], ones = [1,1,...,1]
35 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # truncation_psi: the first half of the 18 per-layer coefficients are scaled to 0.7 -> [0.7,0.7,...,1,1,1]
36 | coefs = coefs.cuda()
37 |
38 | Gs.cuda()
39 | Gm.cuda()
40 |
41 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
42 |
43 | else:
44 | print('error')
45 | return
46 |
47 | if args.checkpoint_dir_E != None:
48 | E.load_state_dict(torch.load(args.checkpoint_dir_E))
49 | E.cuda()
50 | writer = tensor_writer
51 |
52 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0)
53 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
54 |
55 | batch_size = args.batch_size
56 | it_d = 0
57 | for iteration in range(0,args.iterations):
58 | set_seed(iteration%30000)
59 | z_c1 = torch.randn(batch_size, args.z_dim).cuda() #[n, 512]
60 |
61 | if type == 1:
62 | w1 = Gm(z_c1,coefs_m=coefs) #[batch_size,18,512]
63 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256
64 | z_c2, _ = E(imgs1)
65 | z_c2 = z_c2.squeeze(-1).squeeze(-1)
66 | w2 = Gm(z_c2,coefs_m=coefs)
67 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2))
68 | else:
69 | print('model type error')
70 | return
71 |
72 | E_optimizer.zero_grad()
73 |
74 | #loss Images
75 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips)
76 |
77 | loss_msiv = loss_imgs
78 | E_optimizer.zero_grad()
79 | loss_msiv.backward(retain_graph=True)
80 | E_optimizer.step()
81 |
82 | #Latent-Vectors
83 |
84 | ## w
85 | #loss_w, loss_w_info = space_loss(w1,w2,image_space = False)
86 | ## c
87 | loss_c, loss_c_info = space_loss(z_c1,z_c2,image_space = False)
88 |
89 | loss_mslv = loss_c*0.01
90 | E_optimizer.zero_grad()
91 | loss_mslv.backward()
92 | E_optimizer.step()
93 |
94 |
95 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000))
96 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]')
97 | print('---------ImageSpace--------')
98 | print('loss_imgs_info: %s'%loss_imgs_info)
99 | print('---------LatentSpace--------')
100 | print('loss_c_info: %s'%loss_c_info)
101 |
102 | it_d += 1
103 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d)
104 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d)
105 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d)
106 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
107 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d)
108 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d)
109 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d)
110 |
111 | writer.add_scalar('loss_c_mse', loss_c_info[0][0], global_step=it_d)
112 | writer.add_scalar('loss_c_mse_mean', loss_c_info[0][1], global_step=it_d)
113 | writer.add_scalar('loss_c_mse_std', loss_c_info[0][2], global_step=it_d)
114 | writer.add_scalar('loss_c_kl', loss_c_info[1], global_step=it_d)
115 | writer.add_scalar('loss_c_cosine', loss_c_info[2], global_step=it_d)
116 | writer.add_scalar('loss_c_ssim', loss_c_info[3], global_step=it_d)
117 | writer.add_scalar('loss_c_lpips', loss_c_info[4], global_step=it_d)
118 |
119 | writer.add_scalars('Latent Space C', {'loss_c_mse':loss_c_info[0][0],'loss_c_mse_mean':loss_c_info[0][1],'loss_c_mse_std':loss_c_info[0][2],'loss_c_kl':loss_c_info[1],'loss_c_cosine':loss_c_info[2]}, global_step=it_d)
120 |
121 |
122 | if iteration % 100 == 0:
123 | n_row = batch_size
124 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5
125 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3
126 | with open(resultPath+'/Loss.txt', 'a+') as f:
127 | print('i_'+str(iteration),file=f)
128 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f)
129 | print('---------ImageSpace--------',file=f)
130 | print('loss_imgs_info: %s'%loss_imgs_info,file=f)
131 | print('---------LatentSpace--------',file=f)
132 | print('loss_c_info: %s'%loss_c_info,file=f)
133 | if iteration % 5000 == 0:
134 | torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000))
135 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration)
136 |
137 | if __name__ == "__main__":
138 |
139 | parser = argparse.ArgumentParser(description='the training args')
140 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000
141 | parser.add_argument('--lr', type=float, default=0.0015)
142 | parser.add_argument('--beta_1', type=float, default=0.0)
143 | parser.add_argument('--batch_size', type=int, default=2)
144 | parser.add_argument('--experiment_dir', default=None) #None
145 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt
146 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it
147 | parser.add_argument('--checkpoint_dir_E', default=None)
148 | parser.add_argument('--img_size',type=int, default=1024)
149 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
150 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128
151 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
152 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256
153 | args = parser.parse_args()
154 |
155 | if not os.path.exists('./result'): os.mkdir('./result')
156 | resultPath = args.experiment_dir
157 | if resultPath == None:
158 | resultPath = "./result/StyleGANv1-AlationStudy-Z"
159 | if not os.path.exists(resultPath): os.mkdir(resultPath)
160 |
161 | resultPath1_1 = resultPath+"/imgs"
162 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1)
163 |
164 | resultPath1_2 = resultPath+"/models"
165 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2)
166 |
167 | writer_path = os.path.join(resultPath, './summaries')
168 | if not os.path.exists(writer_path): os.mkdir(writer_path)
169 | writer = tensorboardX.SummaryWriter(writer_path)
170 |
171 | use_gpu = True
172 | device = torch.device("cuda" if use_gpu else "cpu")
173 |
174 | train(tensor_writer=writer, args = args)
175 |
--------------------------------------------------------------------------------
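
A note on the truncation setup in 1.E_align_z.py (and repeated in the other ablation scripts): for a 1024x1024 StyleGANv1 generator there are 18 style layers, and the coefficients built at lines 32-35 scale the first half of them by truncation_psi = 0.7 while leaving the rest at 1.0. A standalone sketch of that construction (illustration only, not part of the repository; layer_num is hard-coded to 18 here):

    import numpy as np
    import torch

    layer_num = 18  # int(math.log(1024, 2) - 1) * 2 for a 1024px generator
    layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis]     # shape [1, 18, 1]
    ones = torch.ones(layer_idx.shape, dtype=torch.float32)
    coefs = torch.where(layer_idx < layer_num // 2, 0.7 * ones, ones)  # nine 0.7s followed by nine 1.0s
    print(coefs.flatten().tolist())
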
/ablation_utils/2.E_align_w_2.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | import os
4 | import math
5 | import torch
6 | import torchvision
7 | import model.E.Ablation_Study.E_Blur_W_2 as BE
8 | from model.utils.custom_adam import LREQAdam
9 | import metric.pytorch_ssim as pytorch_ssim
10 | import lpips
11 | import numpy as np
12 | import tensorboardX
13 | import argparse
14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
15 | from training_utils import *
16 |
17 |
18 | def train(tensor_writer = None, args = None):
19 | type = args.mtype
20 |
21 | model_path = args.checkpoint_dir_GAN
22 | if type == 1: # StyleGAN1
23 | #model_path = './checkpoint/stylegan_v1/ffhq1024/'
24 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
25 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth'))
26 |
27 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
28 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth'))
29 |
30 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt')
31 | const_ = Gs.const
32 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda()
33 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024
34 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape: [1,18,1], layer_idx = [0,1,2,...,17]
35 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape: [1,18,1], ones = [1,1,...,1]
36 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # truncation_psi: the first half of the 18 per-layer coefficients are scaled to 0.7 -> [0.7,0.7,...,1,1,1]
37 |
38 | Gs.cuda()
39 | Gm.eval()
40 |
41 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
42 | else:
43 | print('error')
44 | return
45 |
46 | if args.checkpoint_dir_E != None:
47 | E.load_state_dict(torch.load(args.checkpoint_dir_E))
48 | E.cuda()
49 | writer = tensor_writer
50 |
51 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0)
52 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
53 |
54 | batch_size = args.batch_size
55 | it_d = 0
56 | for iteration in range(0,args.iterations):
57 | set_seed(iteration%30000)
58 | z = torch.randn(batch_size, args.z_dim) #[32, 512]
59 |
60 | if type == 1:
61 | with torch.no_grad(): #the images and latent variables are generated here
62 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512]
63 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256
64 | const2,w2 = E(imgs1)
65 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2))
66 | else:
67 | print('model type error')
68 | return
69 |
70 | E_optimizer.zero_grad()
71 |
72 | #loss Images
73 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips)
74 |
75 | loss_msiv = loss_imgs
76 | E_optimizer.zero_grad()
77 | loss_msiv.backward(retain_graph=True)
78 | E_optimizer.step()
79 |
80 | #Latent-Vectors
81 |
82 | ## w
83 | loss_w, loss_w_info = space_loss(w1,w2,image_space = False)
84 |
85 | ## c
86 | #loss_c, loss_c_info = space_loss(const1,const2,image_space = False)
87 |
88 | loss_mslv = loss_w * 0.01
89 | E_optimizer.zero_grad()
90 | loss_mslv.backward()
91 | E_optimizer.step()
92 |
93 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000))
94 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]')
95 | print('---------ImageSpace--------')
96 | print('loss_imgs_info: %s'%loss_imgs_info)
97 | print('---------LatentSpace--------')
98 | print('loss_w_info: %s'%loss_w_info)
99 |
100 | it_d += 1
101 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d)
102 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d)
103 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d)
104 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
105 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d)
106 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d)
107 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d)
108 |
109 | writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d)
110 | writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d)
111 | writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d)
112 | writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d)
113 | writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d)
114 | writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d)
115 | writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d)
116 |
117 | writer.add_scalars('Latent Space W', {'loss_w_mse':loss_w_info[0][0],'loss_w_mse_mean':loss_w_info[0][1],'loss_w_mse_std':loss_w_info[0][2],'loss_w_kl':loss_w_info[1],'loss_w_cosine':loss_w_info[2]}, global_step=it_d)
118 |
119 | if iteration % 100 == 0:
120 | n_row = batch_size
121 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5
122 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3
123 | with open(resultPath+'/Loss.txt', 'a+') as f:
124 | print('i_'+str(iteration),file=f)
125 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f)
126 | print('---------ImageSpace--------',file=f)
127 | print('loss_imgs_info: %s'%loss_imgs_info,file=f)
128 | print('---------LatentSpace--------',file=f)
129 | print('loss_w_info: %s'%loss_w_info,file=f)
130 | if iteration % 5000 == 0:
131 | torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000))
132 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration)
133 |
134 | if __name__ == "__main__":
135 |
136 | parser = argparse.ArgumentParser(description='the training args')
137 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000
138 | parser.add_argument('--lr', type=float, default=0.0015)
139 | parser.add_argument('--beta_1', type=float, default=0.0)
140 | parser.add_argument('--batch_size', type=int, default=2)
141 | parser.add_argument('--experiment_dir', default=None) #None
142 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt
143 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it
144 | parser.add_argument('--checkpoint_dir_E', default=None)
145 | parser.add_argument('--img_size',type=int, default=1024)
146 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
147 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128
148 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
149 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256
150 | args = parser.parse_args()
151 |
152 | if not os.path.exists('./result'): os.mkdir('./result')
153 | resultPath = args.experiment_dir
154 | if resultPath == None:
155 | resultPath = "./result/StyleGANv1-AlationStudy-w_2"
156 | if not os.path.exists(resultPath): os.mkdir(resultPath)
157 |
158 | resultPath1_1 = resultPath+"/imgs"
159 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1)
160 |
161 | resultPath1_2 = resultPath+"/models"
162 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2)
163 |
164 | writer_path = os.path.join(resultPath, './summaries')
165 | if not os.path.exists(writer_path): os.mkdir(writer_path)
166 | writer = tensorboardX.SummaryWriter(writer_path)
167 |
168 | use_gpu = True
169 | device = torch.device("cuda" if use_gpu else "cpu")
170 |
171 | train(tensor_writer=writer, args = args)
172 |
--------------------------------------------------------------------------------
/ablation_utils/3.E_align_w.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | import os
4 | import math
5 | import torch
6 | import torchvision
7 | import model.E.Ablation_Study.E_Blur_W as BE
8 | from model.utils.custom_adam import LREQAdam
9 | import metric.pytorch_ssim as pytorch_ssim
10 | import lpips
11 | import numpy as np
12 | import tensorboardX
13 | import argparse
14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
15 | from training_utils import *
16 |
17 | def train(tensor_writer = None, args = None):
18 | type = args.mtype
19 |
20 | model_path = args.checkpoint_dir_GAN
21 | if type == 1: # StyleGAN1
22 | #model_path = './checkpoint/stylegan_v1/ffhq1024/'
23 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
24 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth'))
25 |
26 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
27 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth'))
28 |
29 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt')
30 | const_ = Gs.const
31 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda()
32 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024
33 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape: [1,18,1], layer_idx = [0,1,2,...,17]
34 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape: [1,18,1], ones = [1,1,...,1]
35 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # truncation_psi: the first half of the 18 per-layer coefficients are scaled to 0.7 -> [0.7,0.7,...,1,1,1]
36 |
37 | Gs.cuda()
38 | Gm.eval()
39 |
40 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
41 | else:
42 | print('error')
43 | return
44 |
45 | if args.checkpoint_dir_E != None:
46 | E.load_state_dict(torch.load(args.checkpoint_dir_E))
47 | E.cuda()
48 | writer = tensor_writer
49 |
50 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0)
51 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
52 |
53 | batch_size = args.batch_size
54 | it_d = 0
55 | for iteration in range(0,args.iterations):
56 | set_seed(iteration%30000)
57 | z = torch.randn(batch_size, args.z_dim) #[32, 512]
58 |
59 | if type == 1:
60 | with torch.no_grad(): #the images and latent variables are generated here
61 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512]
62 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256
63 | const2,w2 = E(imgs1)
64 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2))
65 | else:
66 | print('model type error')
67 | return
68 |
69 | E_optimizer.zero_grad()
70 |
71 | #loss Images
72 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips)
73 |
74 | loss_msiv = loss_imgs
75 | E_optimizer.zero_grad()
76 | loss_msiv.backward(retain_graph=True)
77 | E_optimizer.step()
78 |
79 | #Latent-Vectors
80 |
81 | ## w
82 | loss_w, loss_w_info = space_loss(w1,w2,image_space = False)
83 |
84 | ## c
85 | #loss_c, loss_c_info = space_loss(const1,const2,image_space = False)
86 |
87 | loss_mslv = loss_w * 0.01
88 | E_optimizer.zero_grad()
89 | loss_mslv.backward()
90 | E_optimizer.step()
91 |
92 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000))
93 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]')
94 | print('---------ImageSpace--------')
95 | print('loss_imgs_info: %s'%loss_imgs_info)
96 | print('---------LatentSpace--------')
97 | print('loss_w_info: %s'%loss_w_info)
98 |
99 | it_d += 1
100 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d)
101 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d)
102 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d)
103 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
104 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d)
105 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d)
106 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d)
107 |
108 | writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d)
109 | writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d)
110 | writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d)
111 | writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d)
112 | writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d)
113 | writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d)
114 | writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d)
115 |
116 | writer.add_scalars('Latent Space W', {'loss_w_mse':loss_w_info[0][0],'loss_w_mse_mean':loss_w_info[0][1],'loss_w_mse_std':loss_w_info[0][2],'loss_w_kl':loss_w_info[1],'loss_w_cosine':loss_w_info[2]}, global_step=it_d)
117 |
118 | if iteration % 100 == 0:
119 | n_row = batch_size
120 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5
121 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3
122 | with open(resultPath+'/Loss.txt', 'a+') as f:
123 | print('i_'+str(iteration),file=f)
124 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f)
125 | print('---------ImageSpace--------',file=f)
126 | print('loss_imgs_info: %s'%loss_imgs_info,file=f)
127 | print('---------LatentSpace--------',file=f)
128 | print('loss_w_info: %s'%loss_w_info,file=f)
129 | if iteration % 5000 == 0:
130 | torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000))
131 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration)
132 |
133 | if __name__ == "__main__":
134 |
135 | parser = argparse.ArgumentParser(description='the training args')
136 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000
137 | parser.add_argument('--lr', type=float, default=0.0015)
138 | parser.add_argument('--beta_1', type=float, default=0.0)
139 | parser.add_argument('--batch_size', type=int, default=2)
140 | parser.add_argument('--experiment_dir', default=None) #None
141 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt
142 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it
143 | parser.add_argument('--checkpoint_dir_E', default=None)
144 | parser.add_argument('--img_size',type=int, default=1024)
145 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
146 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128
147 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
148 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256
149 | args = parser.parse_args()
150 |
151 | if not os.path.exists('./result'): os.mkdir('./result')
152 | resultPath = args.experiment_dir
153 | if resultPath == None:
154 | resultPath = "./result/StyleGANv1-AlationStudy-w"
155 | if not os.path.exists(resultPath): os.mkdir(resultPath)
156 |
157 | resultPath1_1 = resultPath+"/imgs"
158 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1)
159 |
160 | resultPath1_2 = resultPath+"/models"
161 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2)
162 |
163 | writer_path = os.path.join(resultPath, './summaries')
164 | if not os.path.exists(writer_path): os.mkdir(writer_path)
165 | writer = tensorboardX.SummaryWriter(writer_path)
166 |
167 | use_gpu = True
168 | device = torch.device("cuda" if use_gpu else "cpu")
169 |
170 | train(tensor_writer=writer, args = args)
171 |
--------------------------------------------------------------------------------
/ablation_utils/4.E_align_w_zn.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | import os
4 | import math
5 | import torch
6 | import torchvision
7 | import model.E.E_Blur as BE
8 | from model.utils.custom_adam import LREQAdam
9 | import metric.pytorch_ssim as pytorch_ssim
10 | import lpips
11 | import numpy as np
12 | import tensorboardX
13 | import argparse
14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
15 | from training_utils import *
16 |
17 | def train(tensor_writer = None, args = None):
18 | type = args.mtype
19 |
20 | model_path = args.checkpoint_dir_GAN
21 | if type == 1: # StyleGAN1
22 | #model_path = './checkpoint/stylegan_v1/ffhq1024/'
23 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
24 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth'))
25 |
26 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
27 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth'))
28 |
29 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt')
30 | const_ = Gs.const
31 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda()
32 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024
33 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape: [1,18,1], layer_idx = [0,1,2,...,17]
34 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape: [1,18,1], ones = [1,1,...,1]
35 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # truncation_psi: the first half of the 18 per-layer coefficients are scaled to 0.7 -> [0.7,0.7,...,1,1,1]
36 |
37 | Gs.cuda()
38 | Gm.eval()
39 |
40 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
41 | else:
42 | print('error')
43 | return
44 |
45 | if args.checkpoint_dir_E != None:
46 | E.load_state_dict(torch.load(args.checkpoint_dir_E))
47 | E.cuda()
48 | writer = tensor_writer
49 |
50 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0)
51 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
52 |
53 | batch_size = args.batch_size
54 | it_d = 0
55 | for iteration in range(0,args.iterations):
56 | set_seed(iteration%30000)
57 | z = torch.randn(batch_size, args.z_dim) #[32, 512]
58 |
59 | if type == 1:
60 | with torch.no_grad(): #the images and latent variables are generated here
61 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512]
62 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256
63 | const2,w2 = E(imgs1)
64 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2))
65 | else:
66 | print('model type error')
67 | return
68 |
69 | E_optimizer.zero_grad()
70 |
71 | #loss Images
72 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips)
73 |
74 | loss_msiv = loss_imgs
75 | E_optimizer.zero_grad()
76 | loss_msiv.backward(retain_graph=True)
77 | E_optimizer.step()
78 |
79 | #Latent-Vectors
80 |
81 | ## w
82 | loss_w, loss_w_info = space_loss(w1,w2,image_space = False)
83 |
84 | ## c
85 | #loss_c, loss_c_info = space_loss(const1,const2,image_space = False)
86 |
87 | loss_mslv = loss_w * 0.01
88 | E_optimizer.zero_grad()
89 | loss_mslv.backward()
90 | E_optimizer.step()
91 |
92 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000))
93 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]')
94 | print('---------ImageSpace--------')
95 | print('loss_imgs_info: %s'%loss_imgs_info)
96 | print('---------LatentSpace--------')
97 | print('loss_w_info: %s'%loss_w_info)
98 |
99 | it_d += 1
100 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d)
101 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d)
102 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d)
103 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
104 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d)
105 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d)
106 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d)
107 |
108 | writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d)
109 | writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d)
110 | writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d)
111 | writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d)
112 | writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d)
113 | writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d)
114 | writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d)
115 |
116 | writer.add_scalars('Latent Space W', {'loss_w_mse':loss_w_info[0][0],'loss_w_mse_mean':loss_w_info[0][1],'loss_w_mse_std':loss_w_info[0][2],'loss_w_kl':loss_w_info[1],'loss_w_cosine':loss_w_info[2]}, global_step=it_d)
117 |
118 | if iteration % 100 == 0:
119 | n_row = batch_size
120 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5
121 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3
122 | with open(resultPath+'/Loss.txt', 'a+') as f:
123 | print('i_'+str(iteration),file=f)
124 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f)
125 | print('---------ImageSpace--------',file=f)
126 | print('loss_imgs_info: %s'%loss_imgs_info,file=f)
127 | print('---------LatentSpace--------',file=f)
128 | print('loss_w_info: %s'%loss_w_info,file=f)
129 | if iteration % 5000 == 0:
130 | torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000))
131 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration)
132 |
133 | if __name__ == "__main__":
134 |
135 | parser = argparse.ArgumentParser(description='the training args')
136 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000
137 | parser.add_argument('--lr', type=float, default=0.0015)
138 | parser.add_argument('--beta_1', type=float, default=0.0)
139 | parser.add_argument('--batch_size', type=int, default=2)
140 | parser.add_argument('--experiment_dir', default=None) #None
141 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt
142 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it
143 | parser.add_argument('--checkpoint_dir_E', default=None)
144 | parser.add_argument('--img_size',type=int, default=1024)
145 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
146 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128
147 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
148 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256
149 | args = parser.parse_args()
150 |
151 | if not os.path.exists('./result'): os.mkdir('./result')
152 | resultPath = args.experiment_dir
153 | if resultPath == None:
154 | resultPath = "./result/StyleGANv1-AlationStudy-w_zn"
155 | if not os.path.exists(resultPath): os.mkdir(resultPath)
156 |
157 | resultPath1_1 = resultPath+"/imgs"
158 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1)
159 |
160 | resultPath1_2 = resultPath+"/models"
161 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2)
162 |
163 | writer_path = os.path.join(resultPath, './summaries')
164 | if not os.path.exists(writer_path): os.mkdir(writer_path)
165 | writer = tensorboardX.SummaryWriter(writer_path)
166 |
167 | use_gpu = True
168 | device = torch.device("cuda" if use_gpu else "cpu")
169 |
170 | train(tensor_writer=writer, args = args)
171 |
--------------------------------------------------------------------------------
/ablation_utils/5.E_align_w_zn_zc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 | import os
4 | import math
5 | import torch
6 | import torchvision
7 | import model.E.E_Blur as BE
8 | from model.utils.custom_adam import LREQAdam
9 | import metric.pytorch_ssim as pytorch_ssim
10 | import lpips
11 | import numpy as np
12 | import tensorboardX
13 | import argparse
14 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
15 | from training_utils import *
16 |
17 | def train(tensor_writer = None, args = None):
18 | type = args.mtype
19 |
20 | model_path = args.checkpoint_dir_GAN
21 | if type == 1: # StyleGAN1
22 | #model_path = './checkpoint/stylegan_v1/ffhq1024/'
23 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
24 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth'))
25 |
26 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
27 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth'))
28 |
29 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt')
30 | const_ = Gs.const
31 | const1 = const_.repeat(args.batch_size,1,1,1).detach().clone().cuda()
32 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024
33 | layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape: [1,18,1], layer_idx = [0,1,2,...,17]
34 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape: [1,18,1], ones = [1,1,...,1]
35 | coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # truncation_psi: the first half of the 18 per-layer coefficients are scaled to 0.7 -> [0.7,0.7,...,1,1,1]
36 |
37 | Gs.cuda()
38 | Gm.eval()
39 |
40 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
41 |
42 | else:
43 | print('error')
44 | return
45 |
46 | if args.checkpoint_dir_E != None:
47 | E.load_state_dict(torch.load(args.checkpoint_dir_E))
48 | E.cuda()
49 | writer = tensor_writer
50 |
51 | E_optimizer = LREQAdam([{'params': E.parameters()},], lr=args.lr, betas=(args.beta_1, 0.99), weight_decay=0)
52 | loss_lpips = lpips.LPIPS(net='vgg').to('cuda')
53 |
54 | batch_size = args.batch_size
55 | it_d = 0
56 | for iteration in range(0,args.iterations):
57 | set_seed(iteration%30000)
58 | z = torch.randn(batch_size, args.z_dim) #[32, 512]
59 |
60 | if type == 1:
61 | with torch.no_grad(): #the images and latent variables are generated here
62 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512]
63 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256
64 | const2,w2 = E(imgs1)
65 | imgs2 = Gs.forward(w2,int(math.log(args.img_size,2)-2))
66 | else:
67 | print('model type error')
68 | return
69 |
70 | E_optimizer.zero_grad()
71 |
72 | #loss Images
73 | loss_imgs, loss_imgs_info = space_loss(imgs1,imgs2,lpips_model=loss_lpips)
74 |
75 | loss_msiv = loss_imgs
76 | E_optimizer.zero_grad()
77 | loss_msiv.backward(retain_graph=True)
78 | E_optimizer.step()
79 |
80 | #Latent-Vectors
81 |
82 | ## w
83 | loss_w, loss_w_info = space_loss(w1,w2,image_space = False)
84 |
85 | ## c
86 | loss_c, loss_c_info = space_loss(const1,const2,image_space = False)
87 |
88 | loss_mslv = (loss_w + loss_c)*0.01
89 | E_optimizer.zero_grad()
90 | loss_mslv.backward()
91 | E_optimizer.step()
92 |
93 | print('ep_%d_iter_%d'%(iteration//30000,iteration%30000))
94 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]')
95 | print('---------ImageSpace--------')
96 | print('loss_imgs_info: %s'%loss_imgs_info)
97 | print('---------LatentSpace--------')
98 | print('loss_w_info: %s'%loss_w_info)
99 | print('loss_c_info: %s'%loss_c_info)
100 |
101 |
102 | it_d += 1
103 |
104 | writer.add_scalar('loss_imgs_mse', loss_imgs_info[0][0], global_step=it_d)
105 | writer.add_scalar('loss_imgs_mse_mean', loss_imgs_info[0][1], global_step=it_d)
106 | writer.add_scalar('loss_imgs_mse_std', loss_imgs_info[0][2], global_step=it_d)
107 | writer.add_scalar('loss_imgs_kl', loss_imgs_info[1], global_step=it_d)
108 | writer.add_scalar('loss_imgs_cosine', loss_imgs_info[2], global_step=it_d)
109 | writer.add_scalar('loss_imgs_ssim', loss_imgs_info[3], global_step=it_d)
110 | writer.add_scalar('loss_imgs_lpips', loss_imgs_info[4], global_step=it_d)
111 |
112 | writer.add_scalar('loss_w_mse', loss_w_info[0][0], global_step=it_d)
113 | writer.add_scalar('loss_w_mse_mean', loss_w_info[0][1], global_step=it_d)
114 | writer.add_scalar('loss_w_mse_std', loss_w_info[0][2], global_step=it_d)
115 | writer.add_scalar('loss_w_kl', loss_w_info[1], global_step=it_d)
116 | writer.add_scalar('loss_w_cosine', loss_w_info[2], global_step=it_d)
117 | writer.add_scalar('loss_w_ssim', loss_w_info[3], global_step=it_d)
118 | writer.add_scalar('loss_w_lpips', loss_w_info[4], global_step=it_d)
119 |
120 | writer.add_scalar('loss_c_mse', loss_c_info[0][0], global_step=it_d)
121 | writer.add_scalar('loss_c_mse_mean', loss_c_info[0][1], global_step=it_d)
122 | writer.add_scalar('loss_c_mse_std', loss_c_info[0][2], global_step=it_d)
123 | writer.add_scalar('loss_c_kl', loss_c_info[1], global_step=it_d)
124 | writer.add_scalar('loss_c_cosine', loss_c_info[2], global_step=it_d)
125 | writer.add_scalar('loss_c_ssim', loss_c_info[3], global_step=it_d)
126 | writer.add_scalar('loss_c_lpips', loss_c_info[4], global_step=it_d)
127 |
128 | writer.add_scalars('Latent Space W', {'loss_w_mse':loss_w_info[0][0],'loss_w_mse_mean':loss_w_info[0][1],'loss_w_mse_std':loss_w_info[0][2],'loss_w_kl':loss_w_info[1],'loss_w_cosine':loss_w_info[2]}, global_step=it_d)
129 | writer.add_scalars('Latent Space C', {'loss_c_mse':loss_c_info[0][0],'loss_c_mse_mean':loss_c_info[0][1],'loss_c_mse_std':loss_c_info[0][2],'loss_c_kl':loss_c_info[1],'loss_c_cosine':loss_c_info[2]}, global_step=it_d)
130 |
131 |
132 | if iteration % 100 == 0:
133 | n_row = batch_size
134 | test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5
135 | torchvision.utils.save_image(test_img, resultPath1_1+'/ep%d_iter%d.jpg'%(iteration//30000,iteration%30000),nrow=n_row) # nrow=3
136 | with open(resultPath+'/Loss.txt', 'a+') as f:
137 | print('i_'+str(iteration),file=f)
138 | print('[loss_imgs_mse[img,img_mean,img_std], loss_imgs_kl, loss_imgs_cosine, loss_imgs_ssim, loss_imgs_lpips]',file=f)
139 | print('---------ImageSpace--------',file=f)
140 | print('loss_imgs_info: %s'%loss_imgs_info,file=f)
141 | print('---------LatentSpace--------',file=f)
142 | print('loss_w_info: %s'%loss_w_info,file=f)
143 | print('loss_c_info: %s'%loss_c_info,file=f)
144 | if iteration % 5000 == 0:
145 | torch.save(E.state_dict(), resultPath1_2+'/E_model_ep%d_iter%d.pth'%(iteration//30000,iteration%30000))
146 | #torch.save(Gm.buffer1,resultPath1_2+'/center_tensor_iter%d.pt'%iteration)
147 |
148 | if __name__ == "__main__":
149 |
150 | parser = argparse.ArgumentParser(description='the training args')
151 | parser.add_argument('--iterations', type=int, default=60001) # epoch = iterations//30000
152 | parser.add_argument('--lr', type=float, default=0.0015)
153 | parser.add_argument('--beta_1', type=float, default=0.0)
154 | parser.add_argument('--batch_size', type=int, default=2)
155 | parser.add_argument('--experiment_dir', default=None) #None
156 | parser.add_argument('--checkpoint_dir_GAN', default='../checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt
157 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it
158 | parser.add_argument('--checkpoint_dir_E', default=None)
159 | parser.add_argument('--img_size',type=int, default=1024)
160 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
161 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128
162 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
163 | parser.add_argument('--start_features', type=int, default=16) # 16->1024 32->512 64->256
164 | args = parser.parse_args()
165 |
166 | if not os.path.exists('./result'): os.mkdir('./result')
167 | resultPath = args.experiment_dir
168 | if resultPath == None:
169 | resultPath = "./result/StyleGANv1-AlationStudy-w_zn_zc"
170 | if not os.path.exists(resultPath): os.mkdir(resultPath)
171 |
172 | resultPath1_1 = resultPath+"/imgs"
173 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1)
174 |
175 | resultPath1_2 = resultPath+"/models"
176 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2)
177 |
178 | writer_path = os.path.join(resultPath, './summaries')
179 | if not os.path.exists(writer_path): os.mkdir(writer_path)
180 | writer = tensorboardX.SummaryWriter(writer_path)
181 |
182 | use_gpu = True
183 | device = torch.device("cuda" if use_gpu else "cpu")
184 |
185 | train(tensor_writer=writer, args = args)
186 |
--------------------------------------------------------------------------------
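
All five ablation scripts above expose the same argparse interface and default to the StyleGANv1 FFHQ 1024x1024 checkpoints. Judging from sys.path.append("..") and the ../checkpoint/stylegan_v1/ffhq1024/ default, they are presumably launched from inside ablation_utils/, for example:

    cd ablation_utils
    python 1.E_align_z.py --batch_size 2 --img_size 1024

Outputs then land under ./result/StyleGANv1-AlationStudy-*/, with sample grids in imgs/, encoder checkpoints in models/, and TensorBoard summaries in summaries/.
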
/baseline_utils/test-baseline-IndomainG.py:
--------------------------------------------------------------------------------
1 | # python 3.6
2 | """Inverts given images to latent codes with In-Domain GAN Inversion.
3 |
4 | Basically, for a particular image (real or synthesized), this script first
5 | employs the domain-guided encoder to produce an initial point in the latent
6 | space and then performs domain-regularized optimization to refine the latent
7 | code.
8 | """
9 |
10 | # invert.py
11 | # python test-baseline-IndomainG.py 'styleganinv_ffhq256' './styleganv1-generations-512/'
12 | import os
13 | import argparse
14 | from tqdm import tqdm
15 | import numpy as np
16 |
17 | from utils.inverter import StyleGANInverter
18 | from utils.logger import setup_logger
19 | from utils.visualizer import HtmlPageVisualizer
20 | from utils.visualizer import save_image, load_image, resize_image
21 |
22 | import torch
23 | import torchvision
24 |
25 | def parse_args():
26 | """Parses arguments."""
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('model_name', type=str, help='Name of the GAN model.')
29 | parser.add_argument('image_list', type=str,
30 | help='Directory containing the images to invert.')
31 | parser.add_argument('-o', '--output_dir', type=str, default='',
32 | help='Directory to save the results. If not specified, '
33 | '`./results/inversion/${IMAGE_LIST}` '
34 | 'will be used by default.')
35 | parser.add_argument('--learning_rate', type=float, default=0.01,
36 | help='Learning rate for optimization. (default: 0.01)')
37 | parser.add_argument('--num_iterations', type=int, default=100,
38 | help='Number of optimization iterations. (default: 100)')
39 | parser.add_argument('--num_results', type=int, default=5,
40 | help='Number of intermediate optimization results to '
41 | 'save for each sample. (default: 5)')
42 | parser.add_argument('--loss_weight_feat', type=float, default=5e-5,
43 | help='The perceptual loss scale for optimization. '
44 | '(default: 5e-5)')
45 | parser.add_argument('--loss_weight_enc', type=float, default=2.0,
46 | help='The encoder loss scale for optimization.'
47 | '(default: 2.0)')
48 | parser.add_argument('--viz_size', type=int, default=256,
49 | help='Image size for visualization. (default: 256)')
50 | parser.add_argument('--gpu_id', type=str, default='0',
51 | help='Which GPU(s) to use. (default: `0`)')
52 | return parser.parse_args()
53 |
54 |
55 | def main():
56 | """Main function."""
57 | args = parse_args()
58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
59 | assert os.path.exists(args.image_list)
60 | image_list_name = os.path.splitext(os.path.basename(args.image_list))[0]
61 | output_dir = args.output_dir or f'results/inversion/{image_list_name}'
62 | logger = setup_logger(output_dir, 'inversion.log', 'inversion_logger')
63 |
64 | logger.info(f'Loading model.')
65 | inverter = StyleGANInverter(
66 | args.model_name,
67 | learning_rate=args.learning_rate,
68 | iteration=args.num_iterations,
69 | reconstruction_loss_weight=1.0,
70 | perceptual_loss_weight=args.loss_weight_feat,
71 | regularization_loss_weight=args.loss_weight_enc,
72 | logger=logger)
73 | image_size = inverter.G.resolution
74 |
75 | # Load image list.
76 | # logger.info(f'Loading image list.')
77 | # image_list = []
78 | # with open(args.image_list, 'r') as f:
79 | # for line in f:
80 | # image_list.append(line.strip())
81 |
82 | path = args.image_list # directory containing the images
83 | paths = list(os.listdir(path))
84 | image_list = []
85 | for i in paths:
86 | image_list.append(path + '/' + i)
87 |
88 |
89 | # Initialize visualizer.
90 | save_interval = args.num_iterations // args.num_results
91 | headers = ['Name', 'Original Image', 'Encoder Output']
92 | for step in range(1, args.num_iterations + 1):
93 | if step == args.num_iterations or step % save_interval == 0:
94 | headers.append(f'Step {step:06d}')
95 | viz_size = None if args.viz_size == 0 else args.viz_size
96 | visualizer = HtmlPageVisualizer(
97 | num_rows=len(image_list), num_cols=len(headers), viz_size=viz_size)
98 | visualizer.set_headers(headers)
99 |
100 | # Invert images.
101 | logger.info(f'Start inversion.')
102 | latent_codes = []
103 | codes_rec = []
104 | codes_inv = []
105 | for img_idx in tqdm(range(len(image_list)), leave=False):
106 | image_path = image_list[img_idx]
107 | image_name = os.path.splitext(os.path.basename(image_path))[0]
108 | image = resize_image(load_image(image_path), (image_size, image_size))
109 | code, viz_results = inverter.easy_invert(image, num_viz=args.num_results)
110 | latent_codes.append(code)
111 | #save_image(f'{output_dir}/{image_name}_ori.png', image)
112 | save_image(f'{output_dir}/r/{image_name}_enc.png', viz_results[1])
113 | save_image(f'{output_dir}/i/{image_name}_inv.png', viz_results[-1])
114 | visualizer.set_cell(img_idx, 0, text=image_name)
115 | visualizer.set_cell(img_idx, 1, image=image)
116 | codes_rec.append(torch.tensor(viz_results[1],dtype=float)/255.0) # cv2->pytorch
117 | codes_inv.append(torch.tensor(viz_results[-1],dtype=float)/255.0)
118 | for viz_idx, viz_img in enumerate(viz_results[1:]):
119 | visualizer.set_cell(img_idx, viz_idx + 2, image=viz_img)
120 |
121 | # Save results.
122 | #os.system(f'cp {args.image_list} {output_dir}/image_list.txt')
123 | np.save(f'{output_dir}/inverted_codes.npy', np.concatenate(latent_codes, axis=0))
124 | visualizer.save(f'{output_dir}/inversion.html')
125 |
126 | # codes_rec_tensor_ = torch.stack(codes_rec, dim=0)
127 | # codes_rec_tensor = codes_rec_tensor_.permute(0,3,1,2) # cv2->pytorch
128 | # codes_inv_tensor_ = torch.stack(codes_inv, dim=0)
129 | # codes_inv_tensor = codes_inv_tensor_.permute(0,3,1,2)
130 | # torch.save(codes_rec_tensor,'./indomain-rec32.pt')
131 | # torch.save(codes_inv_tensor,'./indomain-inv32.pt')
132 | # torchvision.utils.save_image(codes_rec_tensor,'./indomain_images_rec.png',nrow=5)
133 | # torchvision.utils.save_image(codes_inv_tensor,'./indomain_images_inv.png',nrow=5)
134 |
135 | if __name__ == '__main__':
136 | main()
137 |
--------------------------------------------------------------------------------
/baseline_utils/test_baseline_alae.py:
--------------------------------------------------------------------------------
1 | #test_baseline_alae.py
2 | import os
3 | import numpy as np
4 | import sys
5 | sys.path.append(".")
6 | sys.path.append("..")
7 |
8 | import torch.utils.data
9 | import torchvision
10 | import random
11 | from net import *
12 | from model import Model
13 | from launcher import run
14 | from checkpointer import Checkpointer
15 | from dlutils.pytorch import count_parameters
16 | from defaults import get_cfg_defaults
17 | import lreq
18 | from skimage.transform import resize
19 | from PIL import Image
20 | import logging
21 |
22 | config_file='./configs/ffhq.yaml'
23 | cfg=get_cfg_defaults()
24 | cfg.merge_from_file(config_file)
25 | cfg.freeze()
26 |
27 | torch.cuda.set_device(0)
28 | model = Model(
29 | startf=cfg.MODEL.START_CHANNEL_COUNT,
30 | layer_count=cfg.MODEL.LAYER_COUNT,
31 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
32 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
33 | truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
34 | truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
35 | mapping_layers=cfg.MODEL.MAPPING_LAYERS,
36 | channels=cfg.MODEL.CHANNELS,
37 | generator=cfg.MODEL.GENERATOR,
38 | encoder=cfg.MODEL.ENCODER)
39 | model.cuda(0)
40 | model.eval()
41 | model.requires_grad_(False)
42 |
43 | decoder = model.decoder
44 | encoder = model.encoder
45 | mapping_tl = model.mapping_d
46 | mapping_fl = model.mapping_f
47 | dlatent_avg = model.dlatent_avg
48 |
49 | logger = logging.getLogger("logger")
50 | logger.setLevel(logging.DEBUG)
51 |
52 | logger.info("Trainable parameters generator:")
53 | count_parameters(decoder)
54 |
55 | logger.info("Trainable parameters discriminator:")
56 | count_parameters(encoder)
57 |
58 | arguments = dict()
59 | arguments["iteration"] = 0
60 |
61 | model_dict = {
62 | 'discriminator_s': encoder,
63 | 'generator_s': decoder,
64 | 'mapping_tl_s': mapping_tl,
65 | 'mapping_fl_s': mapping_fl,
66 | 'dlatent_avg': dlatent_avg
67 | }
68 |
69 | checkpointer = Checkpointer(cfg,
70 | model_dict,
71 | {},
72 | logger=logger,
73 | save=False)
74 |
75 | extra_checkpoint_data = checkpointer.load()
76 |
77 | model.eval()
78 |
79 | layer_count = cfg.MODEL.LAYER_COUNT
80 | im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)
81 | print(im_size)
82 |
83 | def encode(x):
84 | Z, _ = model.encode(x, layer_count - 1, 1)
85 | Z = Z.repeat(1, model.mapping_f.num_layers, 1)
86 | return Z
87 |
88 | def decode(x):
89 | layer_idx = torch.arange(2 * cfg.MODEL.LAYER_COUNT)[np.newaxis, :, np.newaxis]
90 | ones = torch.ones(layer_idx.shape, dtype=torch.float32)
91 | coefs = torch.where(layer_idx < model.truncation_cutoff, 1.0 * ones, ones)
92 | # x = torch.lerp(model.dlatent_avg.buff.data, x, coefs)
93 | return model.decoder(x, layer_count - 1, 1, noise=True)
94 |
95 | #given a list of file names, returns the source images and their reconstructions
96 | def make(paths):
97 | src = []
98 | for i,filename in enumerate(paths):
99 | img = Image.open(path + '/' + filename)
100 | img = img.resize((im_size,im_size))
101 | img = np.asarray(img)
102 | print(i,img.shape)
103 | if img.shape[2] == 4:
104 | img = img[:, :, :3]
105 | im = img.transpose((2, 0, 1))
106 | x = torch.tensor(np.asarray(im, dtype=np.float32), requires_grad=True).cuda() / 127.5 - 1.
107 | if x.shape[0] == 4:
108 | x = x[:3]
109 | factor = x.shape[2] // im_size
110 | if factor != 1:
111 | x = torch.nn.functional.avg_pool2d(x[None, ...], factor, factor)[0]
112 | assert x.shape[2] == im_size
113 | src.append(x)
114 |
115 | with torch.no_grad():
116 | reconstructions = []
117 | for s in src:
118 | latents = encode(s[None, ...])
119 | reconstructions.append(decode(latents).cpu().detach().numpy())
120 | # print(len(src))
121 | # print(src[0].shape)
122 | # print(reconstructions[0].shape)
123 | return src, reconstructions
124 |
125 |
126 | path = './styleganv1-generations'
127 | paths = list(os.listdir(path))
128 |
129 | # paths = paths[:256]
130 | # chuncker_id = 0 # 0, 1, 2, 3
131 |
132 | paths = paths[256:]
133 | chuncker_id = 1 # 0, 1, 2, 3
134 |
135 | src0, rec0 = make(paths)
136 | src0 = [torch.tensor(array_np) for array_np in src0]
137 | rec0 = [torch.tensor(array_np[0]) for array_np in rec0]
138 | batched_images1 = torch.stack(src0, dim=0)
139 | batched_images2 = torch.stack(rec0, dim=0)
140 | print(batched_images1.shape)
141 | print(batched_images2.shape)
142 |
143 | for i,j in enumerate(batched_images2):
144 | j = j.unsqueeze(0)
145 | torchvision.utils.save_image(j*0.5+0.5,'./%s_rec.png'%str(i+256*chuncker_id).rjust(3,'0'))
146 |
147 | #torch.save(batched_images1*0.5+0.5,'./ALAE-real-30.pt')
148 | #torch.save(batched_images2*0.5+0.5,'./ALAE-rec-30.pt')
--------------------------------------------------------------------------------
/baseline_utils/test_baseline_psp.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | import time
3 | import sys
4 | import os
5 | import pprint
6 | import numpy as np
7 | from PIL import Image
8 | import torch
9 | import torchvision
10 | import torchvision.transforms as transforms
11 |
12 | sys.path.append(".")
13 | sys.path.append("..")
14 |
15 |
16 | from datasets import augmentations
17 | from utils.common import tensor2im, log_input_image
18 | from models.psp import pSp
19 |
20 |
21 | ## Directory of images to encode
22 | save_path = "./real-128imgs"
23 |
24 | ## Configuration
25 | experiment_type = 'ffhq_encode'
26 | #@param ['ffhq_encode', 'ffhq_frontalize', 'celebs_sketch_to_face', 'celebs_seg_to_face', 'celebs_super_resolution', 'toonify']
27 |
28 | EXPERIMENT_DATA_ARGS = {
29 | "ffhq_encode": {
30 | "model_path": "pretrained_models/psp_ffhq_encode.pt",
31 | "image_path": "notebooks/images/input_img.jpg",
32 | "transform": transforms.Compose([
33 | transforms.Resize((256, 256)),
34 | transforms.ToTensor(),
35 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
36 | },
37 | "ffhq_frontalize": {
38 | "model_path": "pretrained_models/psp_ffhq_frontalization.pt",
39 | "image_path": "notebooks/images/input_img.jpg",
40 | "transform": transforms.Compose([
41 | transforms.Resize((256, 256)),
42 | transforms.ToTensor(),
43 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
44 | },
45 | "celebs_sketch_to_face": {
46 | "model_path": "pretrained_models/psp_celebs_sketch_to_face.pt",
47 | "image_path": "notebooks/images/input_sketch.jpg",
48 | "transform": transforms.Compose([
49 | transforms.Resize((256, 256)),
50 | transforms.ToTensor()])
51 | },
52 | "celebs_seg_to_face": {
53 | "model_path": "pretrained_models/psp_celebs_seg_to_face.pt",
54 | "image_path": "notebooks/images/input_mask.png",
55 | "transform": transforms.Compose([
56 | transforms.Resize((256, 256)),
57 | augmentations.ToOneHot(n_classes=19),
58 | transforms.ToTensor()])
59 | },
60 | "celebs_super_resolution": {
61 | "model_path": "pretrained_models/psp_celebs_super_resolution.pt",
62 | "image_path": "notebooks/images/input_img.jpg",
63 | "transform": transforms.Compose([
64 | transforms.Resize((256, 256)),
65 | augmentations.BilinearResize(factors=[16]),
66 | transforms.Resize((256, 256)),
67 | transforms.ToTensor(),
68 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
69 | },
70 | "toonify": {
71 | "model_path": "pretrained_models/psp_ffhq_toonify.pt",
72 | "image_path": "notebooks/images/input_img.jpg",
73 | "transform": transforms.Compose([
74 | transforms.Resize((256, 256)),
75 | transforms.ToTensor(),
76 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
77 | },
78 | }
79 |
80 | EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
81 |
82 | model_path = EXPERIMENT_ARGS['model_path']
83 | ckpt = torch.load(model_path, map_location='cpu')
84 |
85 | opts = ckpt['opts']
86 | # pprint.pprint(opts)
87 |
88 | # update the training options
89 | opts['checkpoint_path'] = model_path
90 | if 'learn_in_w' not in opts:
91 | opts['learn_in_w'] = False
92 | if 'output_size' not in opts:
93 | opts['output_size'] = 1024
94 |
95 |
96 | opts = Namespace(**opts)
97 | net = pSp(opts)
98 | net.eval()
99 | device = 'cuda' # 'cuda'
100 | net.cuda()
101 | print('Model successfully loaded!')
102 |
103 | import matplotlib; matplotlib.use('TkAgg')
104 | import matplotlib.pyplot as plt
105 |
106 | def run_on_batch(inputs, net, latent_mask=None):
107 | if latent_mask is None:
108 | result_batch = net(inputs.to(device).float(), randomize_noise=False)
109 | else:
110 | result_batch = []
111 | for image_idx, input_image in enumerate(inputs):
112 | # get latent vector to inject into our input image
113 | vec_to_inject = np.random.randn(1, 512).astype('float32')
114 | _, latent_to_inject = net(torch.from_numpy(vec_to_inject).to("cuda"),
115 | input_code=True,
116 | return_latents=True)
117 | # get output image with injected style vector
118 | res = net(input_image.unsqueeze(0).to(device).float(),
119 | latent_mask=latent_mask,
120 | inject_latent=latent_to_inject)
121 | result_batch.append(res)
122 | result_batch = torch.cat(result_batch, dim=0)
123 | return result_batch
124 |
125 | image_paths = [os.path.join(save_path, f) for f in os.listdir(save_path) if f.endswith(".png") or f.endswith(".jpg")]
126 | n_images = len(image_paths)
127 |
128 | images = []
129 | n_cols = np.ceil(n_images / 2)
130 | #fig = plt.figure(figsize=(20, 4))
131 | for idx, image_path in enumerate(image_paths):
132 | #ax = fig.add_subplot(2, n_cols, idx + 1)
133 | img = Image.open(image_path).convert("RGB")
134 | images.append(img)
135 | # ax.grid(False)
136 | # ax.set_xticks([])
137 | # ax.set_yticks([])
138 | # ax.imshow(img)
139 | #plt.show()
140 |
141 | img_transforms = EXPERIMENT_ARGS['transform']
142 | transformed_images = [img_transforms(image) for image in images]
143 |
144 | batched_images = torch.stack(transformed_images, dim=0)
145 |
146 | #batched_images = torch.load('real-30.pt')
147 |
148 | with torch.no_grad():
149 | tic = time.time()
150 | for i,j in enumerate(batched_images):
151 | j = j.unsqueeze(0)
152 | result_images = run_on_batch(j, net, latent_mask=None)
153 | toc = time.time()
154 | print('Inference took {:.4f} seconds.'.format(toc - tic))
155 | torchvision.utils.save_image(result_images*0.5+0.5,'./r/%s_pSp_Rec.jpg'%str(i).rjust(5,'0'))
156 |
157 | # from IPython.display import display
158 |
159 | # couple_results = []
160 | # for original_image, result_image in zip(images, result_images):
161 | # result_image = tensor2im(result_image)
162 | # res = np.concatenate([np.array(original_image.resize((256, 256))),
163 | # np.array(result_image.resize((256, 256)))], axis=1)
164 | # res_im = Image.fromarray(res)
165 | # couple_results.append(res_im)
166 | # display(res_im)
167 | # #import matplotlib.pyplot as plt
168 | #     #img = plt.imread('1.jpg')  # read an image
169 | #     plt.imshow(res_im)  # display the image
170 | # plt.show()
171 | # result_images = result_images*0.5+0.5
172 | # torch.save(result_images,'./psp_w30.pt')
173 | # torchvision.utils.save_image(result_images,'./batched_images2.png',nrow=5)
174 |
175 |
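176 | # Usage sketch (illustrative, not executed above): run_on_batch also supports style mixing by
177 | # injecting a random latent into a subset of layers via latent_mask (a list of layer indices).
178 | # The indices below are assumed example values, not tuned settings.
179 | # mixed = run_on_batch(batched_images[:1], net, latent_mask=[8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
180 | # torchvision.utils.save_image(mixed*0.5+0.5, './r/style_mix.jpg')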
--------------------------------------------------------------------------------
/comparing-baseline.py:
--------------------------------------------------------------------------------
1 | # testing
2 |
3 | import os
4 | import skimage
5 | import lpips
6 | import torch
7 | import torchvision
8 | from PIL import Image
9 |
10 |
11 | save_path1 = './styleganv1-generations/'
12 | save_path2 = './MTV-rec/'
13 |
14 | loss_mse = torch.nn.MSELoss()
15 | loss_lpips = lpips.LPIPS(net='vgg')
16 |
17 | def cosineSimilarty(imgs1_cos,imgs2_cos):
18 |     values = imgs1_cos.dot(imgs2_cos)/(torch.sqrt(imgs1_cos.dot(imgs1_cos))*torch.sqrt(imgs2_cos.dot(imgs2_cos))) # cosine similarity, range [-1,1]
19 | return values
20 |
21 | def metrics(img_tensor1,img_tensor2):
22 |
23 | psnr = skimage.measure.compare_psnr(img_tensor1.float().numpy().transpose(1,2,0), img_tensor2.float().numpy().transpose(1,2,0), 255) #range:[0,255]
24 |
25 | ssim = skimage.measure.compare_ssim(img_tensor1.float().numpy().transpose(1,2,0), img_tensor2.float().numpy().transpose(1,2,0), data_range=255, multichannel=True)#[h,w,c] and range:[0,255]
26 |
27 | mse_value = loss_mse(img_tensor1,img_tensor2).numpy() #range:[0,255]
28 |
29 | lpips_value = loss_lpips(img_tensor1.unsqueeze(0)/255.0*2-1,img_tensor2.unsqueeze(0)/255.0*2-1).mean().detach().numpy() #range:[-1,1]
30 |
31 | cosine_value = cosineSimilarty(img_tensor1.view(-1)/255.0*2-1,img_tensor2.view(-1)/255.0*2-1).numpy() #range:[-1,1]
32 |
33 | print('-------------')
34 | print('psnr:',psnr)
35 | print('-------------')
36 | print('ssim:',ssim)
37 | print('-------------')
38 | print('lpips:',lpips_value)
39 | print('-------------')
40 | print('mse:',mse_value)
41 | print('-------------')
42 | print('cosine:',cosine_value)
43 |
44 | return psnr, ssim, lpips_value, mse_value, cosine_value
45 |
46 | #-------- Convert the images in each folder to tensors: [n,c,h,w] ------------------
47 | img_size = 512
48 | #PIL 2 Tensor
49 | transform = torchvision.transforms.Compose([
50 | #torchvision.transforms.CenterCrop(160),
51 | torchvision.transforms.Resize((img_size,img_size)),
52 | torchvision.transforms.ToTensor(),
53 | #torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
54 | ])
55 |
56 | #-----------------------------------------metric imgs_tensor-------------------------------
57 | n = 0
58 | psnr_values = 0
59 | ssim_values = 0
60 | lpips_values = 0
61 | mse_values = 0
62 | cosine_values = 0
63 |
64 | imgs_path1 = [os.path.join(save_path1, f) for f in os.listdir(save_path1) if f.endswith(".png") or f.endswith(".jpg")]
65 | imgs_path2 = [os.path.join(save_path2, f) for f in os.listdir(save_path2) if f.endswith(".png") or f.endswith(".jpg")]
66 |
67 | for i,j in zip(imgs_path1,imgs_path2):
68 | print(i, 'vs.', j)
69 | img1 = Image.open(i).convert("RGB")
70 | img1 = transform(img1)
71 | img2 = Image.open(j).convert("RGB")
72 | img2 = transform(img2)
73 |
74 | img1 = img1*255.0
75 | img2 = img2*255.0
76 |
77 | n = n + 1
78 |
79 | psnr, ssim, lpips_value, mse_value, cosine_value = metrics(img1,img2)
80 | psnr_values +=psnr
81 | ssim_values +=ssim
82 | lpips_values +=lpips_value
83 | mse_values +=mse_value
84 | cosine_values +=cosine_value
85 |
86 | print('img_num:%d--psnr:%f--ssim:%f--mse_value:%f--lpips_value:%f--cosine_value:%f'\
87 | %(n,psnr_values/n, ssim_values/n, mse_values/n, lpips_values/n, cosine_values/n))
88 | # if imgs_tensor1 = imgs_tensor2: -psnr: inf or 88.132626(with 1e-3) --ssim:1.000000--lpips_value:0.000000--mse_value:0.000000--cosine_value:1.000001
89 |
90 | # imgs_path1 = [os.path.join(save_path1, f) for f in os.listdir(save_path1) if f.endswith(".png")]
91 | # images1 = []
92 | # for idx, image_path in enumerate(imgs_path1):
93 | # print(image_path)
94 | # img = Image.open(image_path).convert("RGB")
95 | # img = transform(img)
96 | # images1.append(img)
97 | # imgs_tensor1 = torch.stack(images1, dim=0)
98 |
99 |
100 | # imgs_path2 = [os.path.join(save_path2, f) for f in os.listdir(save_path2) if f.endswith(".png")]
101 | # images2 = []
102 | # for idx, image_path in enumerate(imgs_path2):
103 | # print(image_path)
104 | # img = Image.open(image_path).convert("RGB")
105 | # img = transform(img)
106 | # images2.append(img)
107 | # imgs_tensor2 = torch.stack(images2, dim=0)
108 |
109 | # if len(imgs_tensor1) != len(imgs_tensor2):
110 | # print('error: 2 comparing numbers are not equal!')
111 |
112 |
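113 | # Compatibility sketch: skimage.measure.compare_psnr / compare_ssim were removed in scikit-image >= 0.18.
114 | # On newer releases the equivalent calls (assuming the same [0,255] HWC numpy inputs as above) would be:
115 | # from skimage.metrics import peak_signal_noise_ratio, structural_similarity
116 | # psnr = peak_signal_noise_ratio(img1_np, img2_np, data_range=255)
117 | # ssim = structural_similarity(img1_np, img2_np, data_range=255, multichannel=True)  # channel_axis=-1 on the newest versions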
--------------------------------------------------------------------------------
/embeded_img_edit.py:
--------------------------------------------------------------------------------
1 | #Embedded_ImageProcessing, Just for StyleGAN_v1 FFHQ
2 |
3 | import numpy as np
4 | import math
5 | import torch
6 | import torchvision
7 | import model.E.E_Blur as BE
8 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
9 |
10 | #Params
11 | use_gpu = False
12 | device = torch.device("cuda" if use_gpu else "cpu")
13 | img_size = 1024
14 | GAN_path = './checkpoint/stylegan_v1/ffhq1024/'
15 | direction = 'eyeglasses' #smile, eyeglasses, pose, age, gender
16 | direction_path = './latentvectors/directions/stylegan_ffhq_%s_w_boundary.npy'%direction
17 | w_path = './latentvectors/faces/i3_cxx2.pt'
18 |
19 | #Loading Pre-trained Model, Directions
20 | Gs = Generator(startf=16, maxf=512, layer_count=int(math.log(img_size,2)-1), latent_size=512, channels=3)
21 | Gs.load_state_dict(torch.load(GAN_path+'Gs_dict.pth', map_location=device))
22 |
23 | # E = BE.BE()
24 | # E.load_state_dict(torch.load('./checkpoint/E/styleganv1.pth',map_location=torch.device('cpu')))
25 |
26 | direction = np.load(direction_path) # [1, 512] InterFaceGAN boundary
27 | direction = torch.tensor(direction).float()
28 | direction = direction.expand(18,512)
29 | print(direction.shape)
30 |
31 | w = torch.load(w_path, map_location=device).clone().squeeze(0)
32 | print(w.shape)
33 |
34 | # editing along facial semantic attribute directions
35 | bonus= 70 # editing strength, e.g. -10, -5, 0, 5, 10
36 | start= 0 # default 0; other values give poor results
37 | end= 3 # default 3 or 4: 3 preserves identity features (good for eyeglasses), 4 preserves the direction's features (good for smile)
38 | w[start:start+end] = (w+bonus*direction)[start:start+end]
39 | #w = w + bonus*direction
40 | w = w.reshape(1,18,512)
41 | with torch.no_grad():
42 | img = Gs.forward(w,8) # 8->1024
43 | torchvision.utils.save_image(img*0.5+0.5, './img_bonus%d_start%d_end%d.png'%(bonus,start,end))
44 |
45 | ## end=3 keeps the person's identity dominant; end=4 makes the direction's features dominant; end>4 causes severe latent-space entanglement
46 | #smile: bonus*100, start=0, end=4 (end below 4 has little effect; larger end or bonus makes the edit increasingly aggressive)
47 | #glasses: bonus*200, start=0, end=4 (the image starts to break down once end exceeds 6; bonus should not be too large either)
48 | #pose: bonus*5-10, start=0, end=3
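49 |
50 | # Sweep sketch (illustrative; values and output paths are assumed): render a sequence of edits along
51 | # the chosen direction, e.g. to assemble frames like the GIFs under image_results/.
52 | # w0 = torch.load(w_path, map_location=device).clone().squeeze(0)
53 | # for k, b in enumerate(range(0, 71, 10)):
54 | #     w_k = w0.clone()
55 | #     w_k[start:start+end] = (w0 + b*direction)[start:start+end]
56 | #     with torch.no_grad():
57 | #         img_k = Gs.forward(w_k.reshape(1,18,512), 8)  # 8 -> 1024
58 | #     torchvision.utils.save_image(img_k*0.5+0.5, './edit_frame_%02d.png'%k)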
--------------------------------------------------------------------------------
/image_results/baseline/ALAE-rec.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/ALAE-rec.jpg
--------------------------------------------------------------------------------
/image_results/baseline/Real.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/Real.jpg
--------------------------------------------------------------------------------
/image_results/baseline/indomain_images_inv.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/indomain_images_inv.jpg
--------------------------------------------------------------------------------
/image_results/baseline/indomain_images_rec.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/indomain_images_rec.jpg
--------------------------------------------------------------------------------
/image_results/baseline/mtv-tsa-(1500E).jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/mtv-tsa-(1500E).jpg
--------------------------------------------------------------------------------
/image_results/baseline/psp-rec.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/baseline/psp-rec.jpg
--------------------------------------------------------------------------------
/image_results/cxx1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/cxx1.gif
--------------------------------------------------------------------------------
/image_results/cxx2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/cxx2.gif
--------------------------------------------------------------------------------
/image_results/dy.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/dy.gif
--------------------------------------------------------------------------------
/image_results/msk.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/msk.gif
--------------------------------------------------------------------------------
/image_results/zy.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/image_results/zy.gif
--------------------------------------------------------------------------------
/latent_code/directions/stylegan_ffhq_age_w_boundary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_age_w_boundary.npy
--------------------------------------------------------------------------------
/latent_code/directions/stylegan_ffhq_eyeglasses_w_boundary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_eyeglasses_w_boundary.npy
--------------------------------------------------------------------------------
/latent_code/directions/stylegan_ffhq_gender_w_boundary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_gender_w_boundary.npy
--------------------------------------------------------------------------------
/latent_code/directions/stylegan_ffhq_pose_w_boundary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_pose_w_boundary.npy
--------------------------------------------------------------------------------
/latent_code/directions/stylegan_ffhq_smile_w_boundary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/directions/stylegan_ffhq_smile_w_boundary.npy
--------------------------------------------------------------------------------
/latent_code/real_face_code/i0_cxx1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i0_cxx1.pt
--------------------------------------------------------------------------------
/latent_code/real_face_code/i1_dy.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i1_dy.pt
--------------------------------------------------------------------------------
/latent_code/real_face_code/i2_zy.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i2_zy.pt
--------------------------------------------------------------------------------
/latent_code/real_face_code/i3_cxx2.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i3_cxx2.pt
--------------------------------------------------------------------------------
/latent_code/real_face_code/i4_msk.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i4_msk.pt
--------------------------------------------------------------------------------
/latent_code/real_face_code/i5_ty.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/disanda/Deep-GAN-Encoders/0c2655f3fa6e1a2d57fac7568f7bbd5c1bc4dbbb/latent_code/real_face_code/i5_ty.pt
--------------------------------------------------------------------------------
/metric/pytorch_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10 | return gauss/gauss.sum()
11 |
12 | def create_window(window_size, channel):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
16 | return window
17 |
18 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
21 |
22 | mu1_sq = mu1.pow(2)
23 | mu2_sq = mu2.pow(2)
24 | mu1_mu2 = mu1*mu2
25 |
26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
29 |
30 | C1 = 0.01**2
31 | C2 = 0.03**2
32 |
33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
34 |
35 | if size_average:
36 | return ssim_map.mean()
37 | else:
38 | return ssim_map.mean(1).mean(1).mean(1)
39 |
40 | class SSIM(torch.nn.Module):
41 | def __init__(self, window_size = 11, size_average = True):
42 | super(SSIM, self).__init__()
43 | self.window_size = window_size
44 | self.size_average = size_average
45 | self.channel = 1
46 | self.window = create_window(window_size, self.channel)
47 |
48 | def forward(self, img1, img2):
49 | (_, channel, _, _) = img1.size()
50 |
51 | if channel == self.channel and self.window.data.type() == img1.data.type():
52 | window = self.window
53 | else:
54 | window = create_window(self.window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | self.window = window
61 | self.channel = channel
62 |
63 |
64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def ssim(img1, img2, window_size = 11, size_average = True):
67 | (_, channel, _, _) = img1.size()
68 | window = create_window(window_size, channel)
69 |
70 | if img1.is_cuda:
71 | window = window.cuda(img1.get_device())
72 | window = window.type_as(img1)
73 |
74 | return _ssim(img1, img2, window, window_size, channel, size_average)
75 |
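76 | # Usage sketch (shapes assumed): both entry points expect [n, c, h, w] tensors on the same device.
77 | # x = torch.rand(1, 3, 256, 256)
78 | # y = torch.rand(1, 3, 256, 256)
79 | # print(ssim(x, y))                   # functional form, builds a fresh window on each call
80 | # ssim_module = SSIM(window_size=11)  # module form, caches the window between calls
81 | # print(1 - ssim_module(x, y))        # a common way to turn SSIM into a training loss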
--------------------------------------------------------------------------------
/model/E/Ablation_Study/E_Blur_W.py:
--------------------------------------------------------------------------------
1 | # For ablation study: w
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | from torch.nn.parameter import Parameter
7 | import sys
8 | #sys.path.append('../')
9 | import model.utils.lreq as ln
10 | from model.utils.net import Blur,FromRGB,downscale2d
11 | from torch.nn import functional as F
12 |
13 | # Turn G into E: E blocks are built by reversing G blocks, so the style latent codes are recovered at the mirrored positions
14 | # Compared with v0: adds residual connections; the two per-layer outputs (conv / linear) w1 and w2 are merged into a single w
15 | # Compared with v1: adds the learnable bias_1 and bias_2; the layer ordering differs slightly from v1 (more symmetric)
16 |
17 | class BEBlock(nn.Module):
18 |     def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): # use fused_scale for resolutions >= 128, i.e. the strided conv performs the downscaling
19 | super().__init__()
20 | self.has_last_conv = has_last_conv
21 | # self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
22 | # self.noise_weight_1.data.zero_()
23 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
24 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
25 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512]
26 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
27 |
28 | # self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
29 | # self.noise_weight_2.data.zero_()
30 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
31 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
32 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1)
33 | self.blur = Blur(inputs)
34 | if has_last_conv:
35 | if fused_scale:
36 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True)
37 | else:
38 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
39 | self.fused_scale = fused_scale
40 |
41 | self.inputs = inputs
42 | self.outputs = outputs
43 |
44 | if self.inputs != self.outputs:
45 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0)
46 |
47 | with torch.no_grad():
48 | self.bias_1.zero_()
49 | self.bias_2.zero_()
50 |
51 | def forward(self, x):
52 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
53 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
54 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1]
55 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512]
56 |
57 | residual = x
58 |
59 | x = self.instance_norm_1(x)
60 | x = self.conv_1(x)
61 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
62 | x = x + self.bias_1
63 | x = F.leaky_relu(x, 0.2)
64 |
65 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
66 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
67 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1]
68 |         w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512]; note: this used to be mistakenly written as style1.view instead of style2.view
69 |
70 | x = self.instance_norm_2(x)
71 | if self.has_last_conv:
72 | x = self.blur(x)
73 | x = self.conv_2(x)
74 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
75 | x = x + self.bias_2
76 | x = F.leaky_relu(x, 0.2)
77 |             if not self.fused_scale: # when fused_scale is False, downscaling is done here by downscale2d instead of inside the conv
78 | x = downscale2d(x)
79 | residual = downscale2d(residual)
80 |
81 |
82 | if self.inputs != self.outputs:
83 | residual = self.conv_3(residual)
84 |
85 |         x = 0.111*x+0.889*residual # lowering the weight on x shrinks the const loss!! with 0.7*residual: 10-11 >> 7, while the c_s loss grows to about 3 and the w jitter starts earlier; better results overall
86 | return x, w1, w2
87 |
88 |
89 | class BE(nn.Module):
90 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3):
91 | super().__init__()
92 | self.maxf = maxf
93 | self.startf = startf
94 | self.latent_size = latent_size
95 | self.layer_to_resolution = [0 for _ in range(layer_count)]
96 | self.decode_block = nn.ModuleList()
97 | self.layer_count = layer_count
98 | inputs = startf # 16
99 | outputs = startf*2
100 | resolution = 1024
101 | self.FromRGB = FromRGB(channels, inputs)
102 | #from_RGB = nn.ModuleList()
103 | for i in range(layer_count):
104 |
105 | has_last_conv = i+1 != layer_count
106 |             fused_scale = resolution >= 128 # below 128, fused_scale is False and downscaling is done by downscale2d instead of the strided conv
107 |
108 | #from_RGB.append(FromRGB(channels, inputs))
109 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale)
110 |
111 | inputs = inputs*2
112 | outputs = outputs*2
113 | inputs = min(maxf, inputs)
114 | outputs = min(maxf, outputs)
115 | self.layer_to_resolution[i] = resolution
116 | resolution /=2
117 | self.decode_block.append(block)
118 |
119 | #self.FromRGB = from_RGB
120 |
121 |     # Reverse the w order so it matches G's w ordering; block_num controls the progressive depth
122 | def forward(self, x, block_num=9):
123 |         #x = self.FromRGB[9-block_num](x) # not progressive, so the redundant FromRGB layers were removed
124 | x = self.FromRGB(x)
125 | #print(x.shape)
126 | w = torch.tensor(0)
127 | for i in range(9-block_num,self.layer_count):
128 | x,w1,w2 = self.decode_block[i](x)
129 | #print(x.shape)
130 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512]
131 | if i == (9-block_num):
132 | w = w_ # [b,n,512]
133 | else:
134 | w = torch.cat((w_,w),dim=1)
135 | #print(w.shape)
136 | return x, w
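137 |
138 | # Shape-check sketch (assumed 1024x1024 input matching layer_count=9 above):
139 | # E = BE(startf=16, maxf=512, layer_count=9, latent_size=512, channels=3)
140 | # const_, w = E(torch.randn(2, 3, 1024, 1024))
141 | # print(const_.shape)  # expected [2, 512, 4, 4]
142 | # print(w.shape)       # expected [2, 18, 512], two w codes per block in G's ordering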
--------------------------------------------------------------------------------
/model/E/Ablation_Study/E_Blur_W_2.py:
--------------------------------------------------------------------------------
1 | # For ablation study: w/2
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | from torch.nn.parameter import Parameter
7 | import sys
8 | #sys.path.append('../')
9 | import model.utils.lreq as ln
10 | from model.utils.net import Blur,FromRGB,downscale2d
11 | from torch.nn import functional as F
12 |
13 | # Turn G into E: E blocks are built by reversing G blocks, so the style latent codes are recovered at the mirrored positions
14 | # Compared with v0: adds residual connections; the two per-layer outputs (conv / linear) w1 and w2 are merged into a single w
15 | # Compared with v1: adds the learnable bias_1 and bias_2; the layer ordering differs slightly from v1 (more symmetric)
16 |
17 | class BEBlock(nn.Module):
18 |     def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): # use fused_scale for resolutions >= 128, i.e. the strided conv performs the downscaling
19 | super().__init__()
20 | self.has_last_conv = has_last_conv
21 | # self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
22 | # self.noise_weight_1.data.zero_()
23 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
24 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
25 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512]
26 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
27 |
28 | # self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
29 | # self.noise_weight_2.data.zero_()
30 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
31 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
32 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1)
33 | self.blur = Blur(inputs)
34 | if has_last_conv:
35 | if fused_scale:
36 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True)
37 | else:
38 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
39 | self.fused_scale = fused_scale
40 |
41 | self.inputs = inputs
42 | self.outputs = outputs
43 |
44 | if self.inputs != self.outputs:
45 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0)
46 |
47 | with torch.no_grad():
48 | self.bias_1.zero_()
49 | self.bias_2.zero_()
50 |
51 | def forward(self, x):
52 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
53 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
54 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1]
55 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512]
56 |
57 | residual = x
58 |
59 | x = self.instance_norm_1(x)
60 | x = self.conv_1(x)
61 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
62 | x = x + self.bias_1
63 | x = F.leaky_relu(x, 0.2)
64 |
65 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
66 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
67 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1]
68 |         w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512]; note: this used to be mistakenly written as style1.view instead of style2.view
69 |
70 | x = self.instance_norm_2(x)
71 | if self.has_last_conv:
72 | x = self.blur(x)
73 | x = self.conv_2(x)
74 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
75 | x = x + self.bias_2
76 | x = F.leaky_relu(x, 0.2)
77 |             if not self.fused_scale: # when fused_scale is False, downscaling is done here by downscale2d instead of inside the conv
78 | x = downscale2d(x)
79 | residual = downscale2d(residual)
80 |
81 |
82 | if self.inputs != self.outputs:
83 | residual = self.conv_3(residual)
84 |
85 |         x = 0.111*x+0.889*residual # lowering the weight on x shrinks the const loss!! with 0.7*residual: 10-11 >> 7, while the c_s loss grows to about 3 and the w jitter starts earlier; better results overall
86 | return x, w1, w2
87 |
88 |
89 | class BE(nn.Module):
90 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3):
91 | super().__init__()
92 | self.maxf = maxf
93 | self.startf = startf
94 | self.latent_size = latent_size
95 | self.layer_to_resolution = [0 for _ in range(layer_count)]
96 | self.decode_block = nn.ModuleList()
97 | self.layer_count = layer_count
98 | inputs = startf # 16
99 | outputs = startf*2
100 | resolution = 1024
101 | self.FromRGB = FromRGB(channels, inputs)
102 | #from_RGB = nn.ModuleList()
103 | for i in range(layer_count):
104 |
105 | has_last_conv = i+1 != layer_count
106 |             fused_scale = resolution >= 128 # below 128, fused_scale is False and downscaling is done by downscale2d instead of the strided conv
107 |
108 | #from_RGB.append(FromRGB(channels, inputs))
109 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale)
110 |
111 | inputs = inputs*2
112 | outputs = outputs*2
113 | inputs = min(maxf, inputs)
114 | outputs = min(maxf, outputs)
115 | self.layer_to_resolution[i] = resolution
116 | resolution /=2
117 | self.decode_block.append(block)
118 |
119 | #self.FromRGB = from_RGB
120 |
121 |     # Reverse the w order so it matches G's w ordering; block_num controls the progressive depth
122 | def forward(self, x, block_num=9):
123 |         #x = self.FromRGB[9-block_num](x) # not progressive, so the redundant FromRGB layers were removed
124 | x = self.FromRGB(x)
125 | #print(x.shape)
126 | w = torch.tensor(0)
127 | for i in range(9-block_num,self.layer_count):
128 | x, _ , w2 = self.decode_block[i](x)
129 | #print(x.shape)
130 | w_ = torch.cat((w2.view(x.shape[0],1,512),w2.view(x.shape[0],1,512)),dim=1) # [b,2,512]
131 | if i == (9-block_num):
132 | w = w_ # [b,n,512]
133 | else:
134 | w = torch.cat((w_,w),dim=1)
135 | #print(w.shape)
136 | return x, w
--------------------------------------------------------------------------------
/model/E/Ablation_Study/E_Blur_Z.py:
--------------------------------------------------------------------------------
1 | # For ablation study: z
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | from torch.nn.parameter import Parameter
7 | import sys
8 | #sys.path.append('../')
9 | import model.utils.lreq as ln
10 | from model.utils.net import Blur,FromRGB,downscale2d
11 | from torch.nn import functional as F
12 |
13 | # Turn G into E: E blocks are built by reversing G blocks, so the style latent codes are recovered at the mirrored positions
14 | # Compared with v0: adds residual connections; the two per-layer outputs (conv / linear) w1 and w2 are merged into a single w
15 | # Compared with v1: adds the learnable bias_1 and bias_2; the layer ordering differs slightly from v1 (more symmetric)
16 |
17 | class BEBlock(nn.Module):
18 |     def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): # use fused_scale for resolutions >= 128, i.e. the strided conv performs the downscaling
19 | super().__init__()
20 | self.has_last_conv = has_last_conv
21 | # self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
22 | # self.noise_weight_1.data.zero_()
23 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
24 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
25 | #self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512]
26 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
27 |
28 | # self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
29 | # self.noise_weight_2.data.zero_()
30 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
31 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
32 | #self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1)
33 | self.blur = Blur(inputs)
34 | if has_last_conv:
35 | if fused_scale:
36 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True)
37 | else:
38 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
39 | self.fused_scale = fused_scale
40 |
41 | self.inputs = inputs
42 | self.outputs = outputs
43 |
44 | if self.inputs != self.outputs:
45 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0)
46 |
47 | with torch.no_grad():
48 | self.bias_1.zero_()
49 | self.bias_2.zero_()
50 |
51 | def forward(self, x):
52 | # mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
53 | # std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
54 | # style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1]
55 | # w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512]
56 |
57 | residual = x
58 |
59 | x = self.instance_norm_1(x)
60 | x = self.conv_1(x)
61 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
62 | x = x + self.bias_1
63 | x = F.leaky_relu(x, 0.2)
64 |
65 | # mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
66 | # std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
67 | # style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1]
68 | #         w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512]; note: this used to be mistakenly written as style1.view instead of style2.view
69 |
70 | x = self.instance_norm_2(x)
71 | if self.has_last_conv:
72 | x = self.blur(x)
73 | x = self.conv_2(x)
74 | #x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
75 | x = x + self.bias_2
76 | x = F.leaky_relu(x, 0.2)
77 |             if not self.fused_scale: # when fused_scale is False, downscaling is done here by downscale2d instead of inside the conv
78 | x = downscale2d(x)
79 | residual = downscale2d(residual)
80 |
81 |
82 | if self.inputs != self.outputs:
83 | residual = self.conv_3(residual)
84 |
85 |         x = 0.111*x+0.889*residual # lowering the weight on x shrinks the const loss!! with 0.7*residual: 10-11 >> 7, while the c_s loss grows to about 3 and the w jitter starts earlier; better results overall
86 | return x
87 |
88 |
89 | class BE(nn.Module):
90 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3):
91 | super().__init__()
92 | self.maxf = maxf
93 | self.startf = startf
94 | self.latent_size = latent_size
95 | self.layer_to_resolution = [0 for _ in range(layer_count)]
96 | self.decode_block = nn.ModuleList()
97 | self.layer_count = layer_count
98 | inputs = startf # 16
99 | outputs = startf*2
100 | resolution = 1024
101 | self.FromRGB = FromRGB(channels, inputs)
102 | self.out_z = ln.Conv2d(512,512,3,2)
103 | #from_RGB = nn.ModuleList()
104 | for i in range(layer_count):
105 |
106 | has_last_conv = i+1 != layer_count
107 |             fused_scale = resolution >= 128 # below 128, fused_scale is False and downscaling is done by downscale2d instead of the strided conv
108 |
109 | #from_RGB.append(FromRGB(channels, inputs))
110 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale)
111 |
112 | inputs = inputs*2
113 | outputs = outputs*2
114 | inputs = min(maxf, inputs)
115 | outputs = min(maxf, outputs)
116 | self.layer_to_resolution[i] = resolution
117 | resolution /=2
118 | self.decode_block.append(block)
119 |
120 | #self.FromRGB = from_RGB
121 |
122 |     # Reverse the w order so it matches G's w ordering; block_num controls the progressive depth
123 | def forward(self, x, block_num=9):
124 |         #x = self.FromRGB[9-block_num](x) # not progressive, so the redundant FromRGB layers were removed
125 | x = self.FromRGB(x)
126 | #print(x.shape)
127 | w = torch.tensor(0)
128 | for i in range(9-block_num,self.layer_count):
129 | x = self.decode_block[i](x)
130 | #print(x.shape)
131 | # w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512]
132 | # if i == (9-block_num):
133 | # w = w_ # [b,n,512]
134 | # else:
135 | # w = torch.cat((w_,w),dim=1)
136 | # #print(w.shape)
137 | z = self.out_z(x)
138 | return z, w
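139 |
140 | # Shape-check sketch (assumed 1024x1024 input): this ablation returns a single z-style code rather than per-layer w codes.
141 | # E = BE(startf=16, maxf=512, layer_count=9, latent_size=512, channels=3)
142 | # z, _ = E(torch.randn(2, 3, 1024, 1024))  # the second output is only a placeholder here
143 | # print(z.shape)  # expected [2, 512, 1, 1]: the final 3x3 stride-2 conv collapses the 4x4 feature map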
--------------------------------------------------------------------------------
/model/E/Ablation_Study/E_v1.py:
--------------------------------------------------------------------------------
1 | # Turn G into E: E blocks are built by reversing G blocks, so the style latent codes are recovered at the mirrored positions
2 | # Compared with v0: adds residual connections; the two per-layer outputs (conv / linear) w1 and w2 are merged into a single w
3 | # Compared with v1: adds the learnable bias_1 and bias_2; the layer ordering differs slightly from v1 (more symmetric)
4 | # Compared with v2: works with both StyleGANv1 and StyleGANv2; no longer uses Conv with equalized learning rate (this point has since been dropped) nor the blur op in the second layer of each block
5 | # Changed the rescaling: it is no longer done inside the conv
6 | # Changed the InstanceNorm to use learnable (affine) parameters
7 | # Changed the residual path to match standard ResNets; the skip branch adds a 1x1 conv to match channels and an InstanceNorm with learnable parameters
8 | # Tested: parameter layers without Eq (equalized learning rate) do not learn well
9 |
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.nn import init
14 | from torch.nn.parameter import Parameter
15 | import sys
16 | sys.path.append('../')
17 | from torch.nn import functional as F
18 | import model.utils.lreq as ln
19 |
20 | # Turn G into E: E blocks are built by reversing G blocks, so the style latent codes are recovered at the mirrored positions
21 | # Compared with v0: adds residual connections; the two per-layer outputs (conv / linear) w1 and w2 are merged into a single w
22 | # Compared with v1: adds the learnable bias_1 and bias_2; the layer ordering differs slightly from v1 (more symmetric)
23 |
24 | class FromRGB(nn.Module):
25 | def __init__(self, channels, outputs):
26 | super(FromRGB, self).__init__()
27 | self.from_rgb = ln.Conv2d(channels, outputs, 1, 1, 0)
28 | def forward(self, x):
29 | x = self.from_rgb(x)
30 | x = F.leaky_relu(x, 0.2)
31 | return x
32 |
33 | class BEBlock(nn.Module):
34 |     def __init__(self, inputs, outputs, latent_size, has_second_conv=True, fused_scale=True): # use fused_scale for resolutions >= 128, i.e. the strided conv performs the downscaling
35 | super().__init__()
36 | self.has_second_conv = has_second_conv
37 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
38 | self.noise_weight_1.data.zero_()
39 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
40 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
41 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size) # [n, 2c] -> [n,512]
42 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
43 |
44 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
45 | self.noise_weight_2.data.zero_()
46 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
47 | self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8)
48 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size)
49 | if has_second_conv:
50 | if fused_scale:
51 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False)
52 | else:
53 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
54 | self.fused_scale = fused_scale
55 |
56 | self.inputs = inputs
57 | self.outputs = outputs
58 |
59 | if self.inputs != self.outputs:
60 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0)
61 | self.instance_norm_3 = nn.InstanceNorm2d(outputs, affine=True, eps=1e-8)
62 |
63 | with torch.no_grad():
64 | self.bias_1.zero_()
65 | self.bias_2.zero_()
66 |
67 | def forward(self, x):
68 | residual = x
69 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
70 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
71 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1]
72 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512]
73 |
74 | x = self.conv_1(x)
75 | x = self.instance_norm_1(x)
76 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
77 | x = x + self.bias_1
78 | x = F.leaky_relu(x, 0.2)
79 |
80 |
81 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
82 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
83 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1]
84 |         w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512]; note: this used to be mistakenly written as style1.view instead of style2.view
85 |
86 | if self.has_second_conv:
87 | x = self.conv_2(x)
88 | x = self.instance_norm_2(x)
89 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
90 | x = x + self.bias_2
91 | if self.inputs != self.outputs:
92 | residual = self.conv_3(residual)
93 | residual = self.instance_norm_3(residual)
94 | x = x + residual
95 | x = F.leaky_relu(x, 0.2)
96 |             if not self.fused_scale: # downscale via average pooling
97 | x = F.avg_pool2d(x, 2, 2)
98 |
99 |         #x = 0.111*x+0.889*residual # lowering the weight on x shrinks the const loss!! with 0.7*residual: 10-11 >> 7, while the c_s loss grows to about 3 and the w jitter starts earlier; better results overall
100 | return x, w1, w2
101 |
102 |
103 | class BE(nn.Module):
104 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3, pggan=False):
105 | super().__init__()
106 | self.maxf = maxf
107 | self.startf = startf
108 | self.latent_size = latent_size
109 | #self.layer_to_resolution = [0 for _ in range(layer_count)]
110 | self.decode_block = nn.ModuleList()
111 | self.layer_count = layer_count
112 | inputs = startf # 16
113 | outputs = startf*2
114 | #resolution = 1024
115 | # from_RGB = nn.ModuleList()
116 | fused_scale = False
117 | self.FromRGB = FromRGB(channels, inputs)
118 |
119 | for i in range(layer_count):
120 |
121 |             has_second_conv = i+1 != layer_count # in an ordinary D, the second layer of the last block is minibatch_std
122 |             #fused_scale = resolution >= 128 # below 128, fused_scale is False and downscaling is done outside the conv
123 | #from_RGB.append(FromRGB(channels, inputs))
124 |
125 | block = BEBlock(inputs, outputs, latent_size, has_second_conv, fused_scale=fused_scale)
126 |
127 | inputs = inputs*2
128 | outputs = outputs*2
129 | inputs = min(maxf, inputs)
130 | outputs = min(maxf, outputs)
131 | #self.layer_to_resolution[i] = resolution
132 | #resolution /=2
133 | self.decode_block.append(block)
134 | #self.FromRGB = from_RGB
135 |
136 | self.pggan = pggan
137 | if pggan:
138 | self.new_final = nn.Conv2d(512, 512, 4, 1, 0, bias=True)
139 |
140 |     # Reverse the w order so it matches G's w ordering; block_num controls the progressive depth
141 | def forward(self, x, block_num=9):
142 |         #x = self.FromRGB[9-block_num](x) # one FromRGB per block
143 | x = self.FromRGB(x)
144 | w = torch.tensor(0)
145 | for i in range(9-block_num,self.layer_count):
146 | x,w1,w2 = self.decode_block[i](x)
147 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512]
148 | if i == (9-block_num):
149 | w = w_ # [b,n,512]
150 | else:
151 | w = torch.cat((w_,w),dim=1)
152 | if self.pggan:
153 | x = self.new_final(x)
154 | return x, w
155 |
156 | #test
157 | # E = BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
158 | # imgs1 = torch.randn(2,3,256,256)
159 | # const2,w2 = E(imgs1)
160 | # print(const2.shape)
161 | # print(w2.shape)
162 | # print(E)
163 |
--------------------------------------------------------------------------------
/model/E/Ablation_Study/E_v2_std.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torch.nn.parameter import Parameter
5 | import sys
6 | #sys.path.append('../')
7 | import model.utils.lreq as ln
8 | from model.utils.net import Blur,FromRGB,downscale2d
9 | from torch.nn import functional as F
10 |
11 |
12 | # Turn G into E: E blocks are built by reversing G blocks, so the style latent codes are recovered at the mirrored positions
13 | # Compared with v0: adds residual connections; the two per-layer outputs (conv / linear) w1 and w2 are merged into a single w
14 | # Compared with v1: adds the learnable bias_1 and bias_2; the layer ordering differs slightly from v1 (more symmetric)
15 |
16 | # In this version w uses only the std, not the mean; slightly worse than the horse model, cf. the StyleGANv2 horse training run on the 3090
17 |
18 | class BEBlock(nn.Module):
19 |     def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): # use fused_scale for resolutions >= 128, i.e. the strided conv performs the downscaling
20 | super().__init__()
21 | self.has_last_conv = has_last_conv
22 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
23 | self.noise_weight_1.data.zero_()
24 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
25 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
26 | self.inver_mod1 = ln.Linear(inputs, latent_size, gain=1) # [n, 2c] -> [n,512]
27 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
28 |
29 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
30 | self.noise_weight_2.data.zero_()
31 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
32 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
33 | self.inver_mod2 = ln.Linear(inputs, latent_size, gain=1)
34 | self.blur = Blur(inputs)
35 | if has_last_conv:
36 | if fused_scale:
37 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True)
38 | else:
39 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
40 | self.fused_scale = fused_scale
41 |
42 | self.inputs = inputs
43 | self.outputs = outputs
44 |
45 | if self.inputs != self.outputs:
46 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0)
47 |
48 | with torch.no_grad():
49 | self.bias_1.zero_()
50 | self.bias_2.zero_()
51 |
52 | def forward(self, x):
53 | #mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
54 | #std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
55 | #style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1]
56 | std1 = x.std((2,3))
57 | w1 = self.inver_mod1(std1.squeeze()) # [b,2c]->[b,512]
58 |
59 | residual = x
60 |
61 | x = self.instance_norm_1(x)
62 | x = self.conv_1(x)
63 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
64 | x = x + self.bias_1
65 | x = F.leaky_relu(x, 0.2)
66 |
67 | #mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
68 | #std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
69 | #style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1]
70 | std2 = x.std((2,3))
71 |         w2 = self.inver_mod2(std2.squeeze()) # [b,512]; note: style2 used to be mistakenly written as style1 here
72 |
73 | x = self.instance_norm_2(x)
74 | if self.has_last_conv:
75 | x = self.blur(x)
76 | x = self.conv_2(x)
77 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
78 | x = x + self.bias_2
79 | x = F.leaky_relu(x, 0.2)
80 |             if not self.fused_scale: # when fused_scale is False, downscaling is done here by downscale2d instead of inside the conv
81 | x = downscale2d(x)
82 | residual = downscale2d(residual)
83 |
84 |
85 | if self.inputs != self.outputs:
86 | residual = self.conv_3(residual)
87 |
88 |         x = 0.111*x+0.889*residual # lowering the weight on x shrinks the const loss!! with 0.7*residual: 10-11 >> 7, while the c_s loss grows to about 3 and the w jitter starts earlier; better results overall
89 | return x, w1, w2
90 |
91 |
92 | class BE(nn.Module):
93 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3, pggan=False):
94 | super().__init__()
95 | self.maxf = maxf
96 | self.startf = startf
97 | self.latent_size = latent_size
98 | self.layer_to_resolution = [0 for _ in range(layer_count)]
99 | self.decode_block = nn.ModuleList()
100 | self.layer_count = layer_count
101 | inputs = startf # 16
102 | outputs = startf*2
103 | resolution = 1024
104 | self.FromRGB = FromRGB(channels, inputs)
105 | #from_RGB = nn.ModuleList()
106 | for i in range(layer_count):
107 |
108 | has_last_conv = i+1 != layer_count
109 |             fused_scale = resolution >= 128 # below 128, fused_scale is False and downscaling is done by downscale2d instead of the strided conv
110 |
111 | #from_RGB.append(FromRGB(channels, inputs))
112 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale)
113 |
114 | inputs = inputs*2
115 | outputs = outputs*2
116 | inputs = min(maxf, inputs)
117 | outputs = min(maxf, outputs)
118 | self.layer_to_resolution[i] = resolution
119 | resolution /=2
120 | self.decode_block.append(block)
121 |
122 | self.pggan = pggan
123 | if self.pggan:
124 | #self.new_final = ln.Conv2d(512, 512, 4, 1, 0, bias=True)
125 | self.new_final = ln.Linear(512 * 16, latent_size, gain=1)
126 | #self.FromRGB = from_RGB
127 |
128 |     # Reverse the w order so it matches G's w ordering; block_num controls the progressive depth
129 | def forward(self, x, block_num=9):
130 |         #x = self.FromRGB[9-block_num](x) # not progressive, so the redundant FromRGB layers were removed
131 | x = self.FromRGB(x)
132 | w = torch.tensor(0)
133 | for i in range(9-block_num,self.layer_count):
134 | x,w1,w2 = self.decode_block[i](x)
135 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512]
136 | if i == (9-block_num):
137 | w = w_ # [b,n,512]
138 | else:
139 | w = torch.cat((w_,w),dim=1)
140 | if self.pggan:
141 | #x = self.new_final(x)
142 | x = self.new_final(x.view(x.shape[0],-1))
143 | return x, w
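144 |
145 | # Sketch of the pggan=True path (shapes assumed): the final 4x4x512 feature map is flattened and
146 | # mapped to a single latent vector instead of being returned as a const tensor.
147 | # E = BE(startf=16, maxf=512, layer_count=9, latent_size=512, channels=3, pggan=True)
148 | # z, w = E(torch.randn(1, 3, 1024, 1024))
149 | # print(z.shape)  # expected [1, 512]
150 | # print(w.shape)  # expected [1, 18, 512]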
--------------------------------------------------------------------------------
/model/E/E.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torch.nn.parameter import Parameter
5 | import sys
6 | #sys.path.append('../')
7 | import model.utils.lreq as ln
8 | from model.utils.net import Blur,FromRGB,downscale2d
9 | from torch.nn import functional as F
10 |
11 |
12 | # Turn G into E: E blocks are built by reversing G blocks, so the style latent codes are recovered at the mirrored positions
13 | # Compared with v0: adds residual connections; the two per-layer outputs (conv / linear) w1 and w2 are merged into a single w
14 | # Compared with v1: adds the learnable bias_1 and bias_2; the layer ordering differs slightly from v1 (more symmetric)
15 |
16 | class BEBlock(nn.Module):
17 |     def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): # use fused_scale for resolutions >= 128, i.e. the strided conv performs the downscaling
18 | super().__init__()
19 | self.has_last_conv = has_last_conv
20 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
21 | self.noise_weight_1.data.zero_()
22 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
23 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
24 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512]
25 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
26 |
27 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
28 | self.noise_weight_2.data.zero_()
29 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
30 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
31 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1)
32 | #self.blur = Blur(inputs)
33 | if has_last_conv:
34 | if fused_scale:
35 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True)
36 | else:
37 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
38 | self.fused_scale = fused_scale
39 |
40 | self.inputs = inputs
41 | self.outputs = outputs
42 |
43 | if self.inputs != self.outputs:
44 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0)
45 |
46 | with torch.no_grad():
47 | self.bias_1.zero_()
48 | self.bias_2.zero_()
49 |
50 | def forward(self, x):
51 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
52 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
53 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1]
54 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512]
55 |
56 | residual = x
57 |
58 | x = self.instance_norm_1(x)
59 | x = self.conv_1(x)
60 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
61 | x = x + self.bias_1
62 | x = F.leaky_relu(x, 0.2)
63 |
64 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
65 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
66 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1]
67 |         w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512]; note: this used to be mistakenly written as style1.view instead of style2.view
68 |
69 | x = self.instance_norm_2(x)
70 | if self.has_last_conv:
71 | #x = self.blur(x)
72 | x = self.conv_2(x)
73 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
74 | x = x + self.bias_2
75 | x = F.leaky_relu(x, 0.2)
76 |             if not self.fused_scale: # when fused_scale is False, downscaling is done here by downscale2d instead of inside the conv
77 | x = downscale2d(x)
78 | residual = downscale2d(residual)
79 |
80 |
81 | if self.inputs != self.outputs:
82 | residual = self.conv_3(residual)
83 |
84 |         x = 0.111*x+0.889*residual # lowering the weight on x shrinks the const loss!! with 0.7*residual: 10-11 >> 7, while the c_s loss grows to about 3 and the w jitter starts earlier; better results overall
85 | return x, w1, w2
86 |
87 |
88 | class BE(nn.Module):
89 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3):
90 | super().__init__()
91 | self.maxf = maxf
92 | self.startf = startf
93 | self.latent_size = latent_size
94 | self.layer_to_resolution = [0 for _ in range(layer_count)]
95 | self.decode_block = nn.ModuleList()
96 | self.layer_count = layer_count
97 | inputs = startf # 16
98 | outputs = startf*2
99 | resolution = 1024
100 | self.FromRGB = FromRGB(channels, inputs)
101 | #from_RGB = nn.ModuleList()
102 | for i in range(layer_count):
103 |
104 | has_last_conv = i+1 != layer_count
105 |             #fused_scale = resolution >= 128 # below 128, fused_scale is False and downscaling is done by downscale2d instead of the strided conv
106 | fused_scale = False
107 |
108 | #from_RGB.append(FromRGB(channels, inputs))
109 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale)
110 |
111 | inputs = inputs*2
112 | outputs = outputs*2
113 | inputs = min(maxf, inputs)
114 | outputs = min(maxf, outputs)
115 | self.layer_to_resolution[i] = resolution
116 | resolution /=2
117 | self.decode_block.append(block)
118 |
119 | #self.FromRGB = from_RGB
120 |
121 |     # Reverse the w order so it matches G's w ordering; block_num controls the progressive depth
122 | def forward(self, x, block_num=9):
123 |         #x = self.FromRGB[9-block_num](x) # not progressive, so the redundant FromRGB layers were removed
124 | x = self.FromRGB(x)
125 | #print(x.shape)
126 | w = torch.tensor(0)
127 | for i in range(9-block_num,self.layer_count):
128 | x,w1,w2 = self.decode_block[i](x)
129 | #print(x.shape)
130 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512]
131 | if i == (9-block_num):
132 | w = w_ # [b,n,512]
133 | else:
134 | w = torch.cat((w_,w),dim=1)
135 | #print(w.shape)
136 | return x, w
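137 |
138 | # Shape-check sketch (assumed 256x256 input, as in the Cat256 settings; layer_count=7 keeps the const at 4x4):
139 | # E = BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
140 | # const_, w = E(torch.randn(2, 3, 256, 256))
141 | # print(const_.shape)  # expected [2, 512, 4, 4]
142 | # print(w.shape)       # expected [2, 14, 512]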
--------------------------------------------------------------------------------
/model/E/E_Blur.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torch.nn.parameter import Parameter
5 | import sys
6 | #sys.path.append('../')
7 | import model.utils.lreq as ln
8 | from model.utils.net import Blur,FromRGB,downscale2d
9 | from torch.nn import functional as F
10 |
11 |
12 | # Turn G into E: E blocks are built by reversing G blocks, so the style latent codes are recovered at the mirrored positions
13 | # Compared with v0: adds residual connections; the two per-layer outputs (conv / linear) w1 and w2 are merged into a single w
14 | # Compared with v1: adds the learnable bias_1 and bias_2; the layer ordering differs slightly from v1 (more symmetric)
15 |
16 | class BEBlock(nn.Module):
17 |     def __init__(self, inputs, outputs, latent_size, has_last_conv=True, fused_scale=True): # use fused_scale for resolutions >= 128, i.e. the strided conv performs the downscaling
18 | super().__init__()
19 | self.has_last_conv = has_last_conv
20 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
21 | self.noise_weight_1.data.zero_()
22 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
23 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
24 | self.inver_mod1 = ln.Linear(2 * inputs, latent_size, gain=1) # [n, 2c] -> [n,512]
25 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
26 |
27 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
28 | self.noise_weight_2.data.zero_()
29 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
30 | self.instance_norm_2 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
31 | self.inver_mod2 = ln.Linear(2 * inputs, latent_size, gain=1)
32 | self.blur = Blur(inputs)
33 | if has_last_conv:
34 | if fused_scale:
35 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True)
36 | else:
37 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
38 | self.fused_scale = fused_scale
39 |
40 | self.inputs = inputs
41 | self.outputs = outputs
42 |
43 | if self.inputs != self.outputs:
44 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0)
45 |
46 | with torch.no_grad():
47 | self.bias_1.zero_()
48 | self.bias_2.zero_()
49 |
50 | def forward(self, x):
51 | mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
52 | std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
53 | style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1]
54 | w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512]
55 |
56 | residual = x
57 |
58 | x = self.instance_norm_1(x)
59 | x = self.conv_1(x)
60 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
61 | x = x + self.bias_1
62 | x = F.leaky_relu(x, 0.2)
63 |
64 | mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
65 | std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
66 | style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1]
67 | w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512]; note: style2.view used to be mistyped as style1.view here
68 |
69 | x = self.instance_norm_2(x)
70 | if self.has_last_conv:
71 | x = self.blur(x)
72 | x = self.conv_2(x)
73 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
74 | x = x + self.bias_2
75 | x = F.leaky_relu(x, 0.2)
76 | if not self.fused_scale: # at the start of a new layer fused_scale = False, so the downscaling is done here
77 | x = downscale2d(x)
78 | residual = downscale2d(residual)
79 |
80 |
81 | if self.inputs != self.outputs:
82 | residual = self.conv_3(residual)
83 |
84 | x = 0.111*x+0.889*residual # lowering x's weight shrinks the const loss!! With 0.7*residual: 10-11 >> 7, while the c_s loss grows to about 3 and the fluctuation of ws appears earlier; results are better
85 | return x, w1, w2
86 |
87 |
88 | class BE(nn.Module):
89 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3):
90 | super().__init__()
91 | self.maxf = maxf
92 | self.startf = startf
93 | self.latent_size = latent_size
94 | self.layer_to_resolution = [0 for _ in range(layer_count)]
95 | self.decode_block = nn.ModuleList()
96 | self.layer_count = layer_count
97 | inputs = startf # 16
98 | outputs = startf*2
99 | resolution = 1024
100 | self.FromRGB = FromRGB(channels, inputs)
101 | #from_RGB = nn.ModuleList()
102 | for i in range(layer_count):
103 |
104 | has_last_conv = i+1 != layer_count
105 | fused_scale = resolution >= 128 # at the start of a new layer fused_scale = False, so the downscaling is done separately
106 |
107 | #from_RGB.append(FromRGB(channels, inputs))
108 | block = BEBlock(inputs, outputs, latent_size, has_last_conv, fused_scale=fused_scale)
109 |
110 | inputs = inputs*2
111 | outputs = outputs*2
112 | inputs = min(maxf, inputs)
113 | outputs = min(maxf, outputs)
114 | self.layer_to_resolution[i] = resolution
115 | resolution /=2
116 | self.decode_block.append(block)
117 |
118 | #self.FromRGB = from_RGB
119 |
120 | # reverse the w order so it matches G's w order; block_num controls the progressive depth
121 | def forward(self, x, block_num=9):
122 | #x = self.FromRGB[9-block_num](x) # not progressive, so the extra FromRGB layers were removed
123 | x = self.FromRGB(x)
124 | #print(x.shape)
125 | w = torch.tensor(0)
126 | for i in range(9-block_num,self.layer_count):
127 | x,w1,w2 = self.decode_block[i](x)
128 | #print(x.shape)
129 | w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512]
130 | if i == (9-block_num):
131 | w = w_ # [b,n,512]
132 | else:
133 | w = torch.cat((w_,w),dim=1)
134 | #print(w.shape)
135 | return x, w
--------------------------------------------------------------------------------
/model/E/E_PG.py:
--------------------------------------------------------------------------------
1 | # G turned into E: the E block is built from the G block, mirrored in reverse order so the style latent codes are recovered at the same positions
2 | # Compared with version 0, a residual branch is added, and each layer's two outputs (conv/linear) w1 and w2 are merged into one w
3 | # Compared with version 1, learnable bias_1 and bias_2 are added; the layer order also differs from version 1 (more symmetric)
4 | # Compared with version 2: usable with StyleGANv1 and StyleGANv2; the Conv with equalized learning rate is no longer used (this point was later dropped), nor is the blur in the block's second layer
5 | # The downscaling is changed so it is no longer fused into the conv
6 | # The IN is changed to one with learnable parameters
7 | # The residual path is changed to match a standard residual network; the skip branch also gains a 1x1 conv for the channels and an IN with learnable parameters
8 | # Tested: parameter layers without Eq (equalized learning rate) do not learn well
9 |
10 |
11 | # This version is compatible with PGGAN and BigGAN: the main change is the last layer, where FC layers are added
12 | # PGGAN: add one fc, similar to the original D
13 | # BigGAN: add two fc layers of 128 channels each; one is for the label, completing the 128 -> 1000 mapping
14 | # BigGAN improvement idea 1: replace CBN with IN
15 | # BigGAN improvement idea 2: add w to G, symmetric with E's w
16 |
17 | import torch
18 | import torch.nn as nn
19 | from torch.nn import init
20 | from torch.nn.parameter import Parameter
21 | import sys
22 | sys.path.append('../')
23 | from torch.nn import functional as F
24 | import model.utils.lreq as ln
25 |
26 | # G turned into E: the E block is built from the G block, mirrored in reverse order so the style latent codes are recovered at the same positions
27 | # Compared with version 0, a residual branch is added, and each layer's two outputs (conv/linear) w1 and w2 are merged into one w
28 | # Compared with version 1, learnable bias_1 and bias_2 are added; the layer order also differs from version 1 (more symmetric)
29 |
30 | class FromRGB(nn.Module):
31 | def __init__(self, channels, outputs):
32 | super(FromRGB, self).__init__()
33 | self.from_rgb = ln.Conv2d(channels, outputs, 1, 1, 0)
34 | def forward(self, x):
35 | x = self.from_rgb(x)
36 | x = F.leaky_relu(x, 0.2)
37 | return x
38 |
39 | class BEBlock(nn.Module):
40 | def __init__(self, inputs, outputs, latent_size, has_second_conv=True, fused_scale=True): # fused_scale is used for resolutions >= 128, i.e. the strided conv does the rescaling
41 | super().__init__()
42 | self.has_second_conv = has_second_conv
43 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
44 | self.noise_weight_1.data.zero_()
45 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
46 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False, eps=1e-8)
47 | #self.inver_mod1 = ln.Linear(2 * inputs, latent_size) # [n, 2c] -> [n,512]
48 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
49 |
50 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
51 | self.noise_weight_2.data.zero_()
52 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
53 | self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8)
54 | #self.inver_mod2 = ln.Linear(2 * inputs, latent_size)
55 | if has_second_conv:
56 | if fused_scale:
57 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False)
58 | else:
59 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
60 | self.fused_scale = fused_scale
61 |
62 | self.inputs = inputs
63 | self.outputs = outputs
64 |
65 | if self.inputs != self.outputs:
66 | self.conv_3 = ln.Conv2d(inputs, outputs, 1, 1, 0)
67 | self.instance_norm_3 = nn.InstanceNorm2d(outputs, affine=True, eps=1e-8)
68 |
69 | with torch.no_grad():
70 | self.bias_1.zero_()
71 | self.bias_2.zero_()
72 |
73 | def forward(self, x):
74 | residual = x
75 | # mean1 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
76 | # std1 = torch.sqrt(torch.mean((x - mean1) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
77 | # style1 = torch.cat((mean1, std1), dim=1) # [b,2c,1,1]
78 | # w1 = self.inver_mod1(style1.view(style1.shape[0],style1.shape[1])) # [b,2c]->[b,512]
79 |
80 | x = self.instance_norm_1(x)
81 | x = self.conv_1(x)
82 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_1, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
83 | x = x + self.bias_1
84 | x = F.leaky_relu(x, 0.2)
85 |
86 |
87 | # mean2 = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
88 | # std2 = torch.sqrt(torch.mean((x - mean2) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
89 | # style2 = torch.cat((mean2, std2), dim=1) # [b,2c,1,1]
90 | # w2 = self.inver_mod2(style2.view(style2.shape[0],style2.shape[1])) # [b,512]; note: style2.view used to be mistyped as style1.view here
91 |
92 | if self.has_second_conv:
93 | x = self.instance_norm_2(x)
94 | x = self.conv_2(x)
95 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]).to(x.device))
96 | x = x + self.bias_2
97 | if self.inputs != self.outputs:
98 | residual = self.conv_3(residual)
99 | residual = self.instance_norm_3(residual)
100 | x = x + residual
101 | x = F.leaky_relu(x, 0.2)
102 | if not self.fused_scale: # downscale outside the conv
103 | x = F.avg_pool2d(x, 2, 2)
104 |
105 | #x = 0.111*x+0.889*residual # lowering x's weight shrinks the const loss!! With 0.7*residual: 10-11 >> 7, while the c_s loss grows to about 3 and the fluctuation of ws appears earlier; results are better
106 | w1 = 0
107 | w2 = 0
108 | return x , w1, w2
109 |
110 |
111 | class BE(nn.Module):
112 | def __init__(self, startf=16, maxf=512, layer_count=9, latent_size=512, channels=3, pggan=False):
113 | super().__init__()
114 | self.maxf = maxf
115 | self.startf = startf
116 | self.latent_size = latent_size
117 | #self.layer_to_resolution = [0 for _ in range(layer_count)]
118 | self.decode_block = nn.ModuleList()
119 | self.layer_count = layer_count
120 | inputs = startf # 16
121 | outputs = startf*2
122 | #resolution = 1024
123 | # from_RGB = nn.ModuleList()
124 | fused_scale = False
125 | self.FromRGB = FromRGB(channels, inputs)
126 |
127 | for i in range(layer_count):
128 |
129 | has_second_conv = i+1 != layer_count # in a standard D, the last block's second layer is mini_batch_std
130 | #fused_scale = resolution >= 128 # at the start of a new layer fused_scale = False, so the downscaling is done separately
131 | #from_RGB.append(FromRGB(channels, inputs))
132 |
133 | block = BEBlock(inputs, outputs, latent_size, has_second_conv, fused_scale=fused_scale)
134 |
135 | inputs = inputs*2
136 | outputs = outputs*2
137 | inputs = min(maxf, inputs)
138 | outputs = min(maxf, outputs)
139 | #self.layer_to_resolution[i] = resolution
140 | #resolution /=2
141 | self.decode_block.append(block)
142 | #self.FromRGB = from_RGB
143 |
144 | self.pggan = pggan
145 | if pggan:
146 | #self.new_final = nn.Conv2d(512, 512, 4, 1, 0, bias=True)
147 | self.new_final = ln.Linear(512 * 16, latent_size, gain=1)
148 |
149 | # reverse the w order so it matches G's w order; block_num controls the progressive depth
150 | def forward(self, x, block_num=9):
151 | #x = self.FromRGB[9-block_num](x) # one per block
152 | x = self.FromRGB(x)
153 | w = torch.tensor(0)
154 | for i in range(9-block_num,self.layer_count):
155 | x,w1,w2 = self.decode_block[i](x)
156 | # w_ = torch.cat((w2.view(x.shape[0],1,512),w1.view(x.shape[0],1,512)),dim=1) # [b,2,512]
157 | # if i == (9-block_num):
158 | # w = w_ # [b,n,512]
159 | # else:
160 | # w = torch.cat((w_,w),dim=1)
161 | if self.pggan:
162 | #x = self.new_final(x)
163 | x = self.new_final(x.view(x.shape[0],-1))
164 | return x, w  # x is the encoded latent ([b,512] after new_final when pggan=True); w is unused in this variant
165 |
166 | #test
167 | # E = BE(startf=64, maxf=512, layer_count=7, latent_size=512, channels=3)
168 | # imgs1 = torch.randn(2,3,256,256)
169 | # const2,w2 = E(imgs1)
170 | # print(const2.shape)
171 | # print(w2.shape)
172 | # print(E)
173 |
--------------------------------------------------------------------------------
/model/pggan/utils/Encoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import ModuleList, AvgPool2d
5 | from networks.PGGAN.CustomLayers import DisGeneralConvBlock, DisFinalBlock, _equalized_conv2d
6 |
7 | #Discriminator to Encoder, Just change the last layer
8 | # Encoder matching the original network's D; G stays fixed during training. v1 only changes the last layer, v2 is a smaller network
9 |
10 | # in: [-1,3,1024,1024] , out: [-1,512]
11 | class encoder(torch.nn.Module):
12 | """ Discriminator of the GAN """
13 | def __init__(self, height=7, feature_size=512, use_eql=True):
14 | """
15 | constructor for the class
16 | :param height: total height of the discriminator (Must be equal to the Generator depth)
17 | :param feature_size: size of the deepest features extracted
18 | (Must be equal to Generator latent_size)
19 | :param use_eql: whether to use equalized learning rate
20 | """
21 | super().__init__()
22 | assert feature_size != 0 and ((feature_size & (feature_size - 1)) == 0), \
23 | "latent size not a power of 2"
24 | if height >= 4:
25 | assert feature_size >= np.power(2, height - 4), "feature size cannot be produced"
26 | # create state of the object
27 | self.use_eql = use_eql
28 | self.height = height
29 | self.feature_size = feature_size
30 | #self.final_block = DisFinalBlock(self.feature_size, use_eql=self.use_eql)
31 | # create a module list of the other required general convolution blocks
32 | self.layers = ModuleList([]) # initialize to empty list
33 | # create the fromRGB layers for various inputs:
34 | if self.use_eql:
35 | self.fromRGB = lambda out_channels: \
36 | _equalized_conv2d(3, out_channels, (1, 1), bias=True)
37 | else:
38 | from torch.nn import Conv2d
39 | self.fromRGB = lambda out_channels: Conv2d(3, out_channels, (1, 1), bias=True)
40 | self.rgb_to_features = ModuleList([self.fromRGB(self.feature_size)])
41 | # create the remaining layers
42 | for i in range(self.height - 1):
43 | if i > 2:
44 | layer = DisGeneralConvBlock(
45 | int(self.feature_size // np.power(2, i - 2)),
46 | int(self.feature_size // np.power(2, i - 3)),
47 | use_eql=self.use_eql
48 | )
49 | rgb = self.fromRGB(int(self.feature_size // np.power(2, i - 2)))
50 | else:
51 | layer = DisGeneralConvBlock(self.feature_size,
52 | self.feature_size, use_eql=self.use_eql)
53 | rgb = self.fromRGB(self.feature_size)
54 | self.layers.append(layer)
55 | self.rgb_to_features.append(rgb)
56 | # register the temporary downSampler
57 | self.temporaryDownsampler = AvgPool2d(2)
58 | #new
59 | self.new_final = nn.Conv2d(512, 512, 4, 1, 0, bias=True)
60 | def forward(self, x, depth, alpha):
61 | """
62 | forward pass of the discriminator
63 | :param x: input to the network
64 | :param depth: current depth of operation (Progressive GAN)
65 | :param alpha: current value of alpha for fade-in
66 | :return: out => raw prediction values (WGAN-GP)
67 | """
68 | assert depth < self.height, "Requested output depth cannot be produced"
69 | if depth > 0:
70 | residual = self.rgb_to_features[depth - 1](self.temporaryDownsampler(x))
71 |
72 | straight = self.layers[depth - 1](
73 | self.rgb_to_features[depth](x)
74 | )
75 |
76 | y = (alpha * straight) + ((1 - alpha) * residual)
77 |
78 | for block in reversed(self.layers[:depth - 1]):
79 | y = block(y)
80 | else:
81 | y = self.rgb_to_features[0](x)
82 |
83 | #out = self.final_block(y)
84 | out = self.new_final(y)
85 | return out
86 |
87 | #in: [-1,3,1024,1024], out: [-1,512]
88 | class encoder_small(torch.nn.Module):
89 | def __init__(self):
90 | super().__init__()
91 | self.main = nn.Sequential(
92 | nn.Conv2d(3,12,4,2,1,bias=False), # 1024->512
93 | nn.LeakyReLU(0.2, inplace=True),
94 | nn.Conv2d(12,12,4,2,1,bias=False),# 512->256
95 | nn.BatchNorm2d(12),
96 | nn.LeakyReLU(0.2, inplace=True),
97 | nn.Conv2d(12,3,4,2,1,bias=False),# 256->128
98 | nn.BatchNorm2d(3),
99 | nn.LeakyReLU(0.2, inplace=True),
100 | nn.Conv2d(3,1,4,2,1,bias=False),# 128->64*64=4096
101 | )
102 | self.fc = nn.Linear(4096,512)
103 | def forward(self, x):
104 | y1 = self.main(x)
105 | y2 = y1.view(-1,4096)
106 | y3 = self.fc(y2)
107 | return y3
--------------------------------------------------------------------------------
/model/stylegan1/alae.py:
--------------------------------------------------------------------------------
1 | #ALAE
2 | class EncodeBlock(nn.Module):
3 | def __init__(self, inputs, outputs, latent_size, last=False, fused_scale=True):
4 | super(EncodeBlock, self).__init__()
5 | self.conv_1 = ln.Conv2d(inputs, inputs, 3, 1, 1, bias=False)
6 | self.bias_1 = nn.Parameter(torch.Tensor(1, inputs, 1, 1))
7 | self.instance_norm_1 = nn.InstanceNorm2d(inputs, affine=False)
8 | self.blur = Blur(inputs)
9 | self.last = last
10 | self.fused_scale = fused_scale
11 | if last:
12 | self.dense = ln.Linear(inputs * 4 * 4, outputs)
13 | else:
14 | if fused_scale:
15 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True)
16 | else:
17 | self.conv_2 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False)
18 |
19 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1))
20 | self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False)
21 | self.style_1 = ln.Linear(2 * inputs, latent_size)
22 | if last:
23 | self.style_2 = ln.Linear(outputs, latent_size)
24 | else:
25 | self.style_2 = ln.Linear(2 * outputs, latent_size)
26 |
27 | with torch.no_grad():
28 | self.bias_1.zero_()
29 | self.bias_2.zero_()
30 |
31 | def forward(self, x):
32 | x = self.conv_1(x) + self.bias_1
33 | x = F.leaky_relu(x, 0.2)
34 |
35 | m = torch.mean(x, dim=[2, 3], keepdim=True) # [b, c, 1, 1]
36 | std = torch.sqrt(torch.mean((x - m) ** 2, dim=[2, 3], keepdim=True)) # [b, c, 1, 1]
37 | style_1 = torch.cat((m, std), dim=1) # [b,2c,1,1]
38 |
39 | x = self.instance_norm_1(x)
40 |
41 | if self.last:
42 | x = self.dense(x.view(x.shape[0], -1))
43 |
44 | x = F.leaky_relu(x, 0.2)
45 | w1 = self.style_1(style_1.view(style_1.shape[0], style_1.shape[1]))
46 | w2 = self.style_2(x.view(x.shape[0], x.shape[1]))
47 | else:
48 | x = self.conv_2(self.blur(x))
49 | if not self.fused_scale:
50 | x = downscale2d(x)
51 | x = x + self.bias_2
52 |
53 | x = F.leaky_relu(x, 0.2)
54 |
55 | m = torch.mean(x, dim=[2, 3], keepdim=True)
56 | std = torch.sqrt(torch.mean((x - m) ** 2, dim=[2, 3], keepdim=True))
57 | style_2 = torch.cat((m, std), dim=1)
58 |
59 | x = self.instance_norm_2(x)
60 |
61 | w1 = self.style_1(style_1.view(style_1.shape[0], style_1.shape[1]))
62 | w2 = self.style_2(style_2.view(style_2.shape[0], style_2.shape[1]))
63 |
64 | return x, w1, w2 # result after one downscaling step, [b,512], [b,512]
65 |
66 | class EncoderDefault(nn.Module):
67 | def __init__(self, startf, maxf, layer_count, latent_size, channels=3):
68 | super(EncoderDefault, self).__init__()
69 | self.maxf = maxf
70 | self.startf = startf
71 | self.layer_count = layer_count
72 | self.from_rgb: nn.ModuleList[FromRGB] = nn.ModuleList()
73 | self.channels = channels
74 | self.latent_size = latent_size
75 |
76 | mul = 2
77 | inputs = startf
78 | self.encode_block: nn.ModuleList[EncodeBlock] = nn.ModuleList()
79 |
80 | resolution = 2 ** (self.layer_count + 1)
81 |
82 | for i in range(self.layer_count):
83 | outputs = min(self.maxf, startf * mul)
84 |
85 | self.from_rgb.append(FromRGB(channels, inputs))
86 |
87 | fused_scale = resolution >= 128
88 |
89 | block = EncodeBlock(inputs, outputs, latent_size, False, fused_scale=fused_scale)
90 |
91 | resolution //= 2
92 |
93 | self.encode_block.append(block)
94 | inputs = outputs
95 | mul *= 2
96 |
97 | def encode(self, x, lod):
98 | styles = torch.zeros(x.shape[0], 1, self.latent_size) #[b, 1 , 512]
99 |
100 | x = self.from_rgb[self.layer_count - lod - 1](x)
101 | x = F.leaky_relu(x, 0.2)
102 |
103 | for i in range(self.layer_count - lod - 1, self.layer_count):
104 | x, s1, s2 = self.encode_block[i](x)
105 | styles[:, 0] += s1 + s2
106 |
107 | return styles
108 |
109 | def encode2(self, x, lod, blend):
110 | x_orig = x
111 | styles = torch.zeros(x.shape[0], 1, self.latent_size)
112 |
113 | x = self.from_rgb[self.layer_count - lod - 1](x)
114 | x = F.leaky_relu(x, 0.2)
115 |
116 | x, s1, s2 = self.encode_block[self.layer_count - lod - 1](x)
117 | styles[:, 0] += s1 * blend + s2 * blend
118 |
119 | x_prev = F.avg_pool2d(x_orig, 2, 2)
120 |
121 | x_prev = self.from_rgb[self.layer_count - (lod - 1) - 1](x_prev)
122 | x_prev = F.leaky_relu(x_prev, 0.2)
123 |
124 | x = torch.lerp(x_prev, x, blend)
125 |
126 | for i in range(self.layer_count - (lod - 1) - 1, self.layer_count):
127 | x, s1, s2 = self.encode_block[i](x)
128 | styles[:, 0] += s1 + s2
129 |
130 | return styles
131 |
132 | def forward(self, x, lod, blend):
133 | if blend == 1:
134 | return self.encode(x, lod)
135 | else:
136 | return self.encode2(x, lod, blend)
137 |
138 | def get_statistics(self, lod):
139 | rgb_std = self.from_rgb[self.layer_count - lod - 1].from_rgb.weight.std().item()
140 | rgb_std_c = self.from_rgb[self.layer_count - lod - 1].from_rgb.std
141 |
142 | layers = []
143 | for i in range(self.layer_count - lod - 1, self.layer_count):
144 | conv_1 = self.encode_block[i].conv_1.weight.std().item()
145 | conv_1_c = self.encode_block[i].conv_1.std
146 | conv_2 = self.encode_block[i].conv_2.weight.std().item()
147 | conv_2_c = self.encode_block[i].conv_2.std
148 | layers.append(((conv_1 / conv_1_c), (conv_2 / conv_2_c)))
149 | return rgb_std / rgb_std_c, layers
150 |
--------------------------------------------------------------------------------
/model/stylegan1/custom_adam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class LREQAdam(Optimizer):
7 | def __init__(self, params, lr=1e-3, betas=(0.0, 0.99), eps=1e-8, weight_decay=0):
8 | beta_2 = betas[1]
9 | if not 0.0 <= lr:
10 | raise ValueError("Invalid learning rate: {}".format(lr))
11 | if not 0.0 <= eps:
12 | raise ValueError("Invalid epsilon value: {}".format(eps))
13 | if not 0.0 == betas[0]:
14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
15 | if not 0.0 <= beta_2 < 1.0:
16 | raise ValueError("Invalid beta parameter at index 1: {}".format(beta_2))
17 | defaults = dict(lr=lr, beta_2=beta_2, eps=eps, weight_decay=weight_decay)
18 | super(LREQAdam, self).__init__(params, defaults)
19 |
20 | def __setstate__(self, state):
21 | super(LREQAdam, self).__setstate__(state)
22 |
23 | def step(self, closure=None):
24 | """Performs a single optimization step.
25 | Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss.
26 | """
27 | loss = None
28 | if closure is not None:
29 | loss = closure()
30 |
31 | for group in self.param_groups:
32 | for p in group['params']:
33 | if p.grad is None:
34 | continue
35 | grad = p.grad.data
36 | if grad.is_sparse:
37 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
38 |
39 | state = self.state[p]
40 |
41 | # State initialization
42 | if len(state) == 0:
43 | state['step'] = 0
44 | # Exponential moving average of gradient values
45 | # state['exp_avg'] = torch.zeros_like(p.data)
46 | # Exponential moving average of squared gradient values
47 | state['exp_avg_sq'] = torch.zeros_like(p.data)
48 |
49 | # exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
50 | exp_avg_sq = state['exp_avg_sq']
51 | beta_2 = group['beta_2']
52 |
53 | state['step'] += 1
54 |
55 | if group['weight_decay'] != 0:
56 | grad.add_(group['weight_decay'], p.data / p.coef)
57 |
58 | # Decay the first and second moment running average coefficient
59 | # exp_avg.mul_(beta1).add_(1 - beta1, grad)
60 | exp_avg_sq.mul_(beta_2).addcmul_(1 - beta_2, grad, grad)
61 | denom = exp_avg_sq.sqrt().add_(group['eps'])
62 |
63 | # bias_correction1 = 1 - beta1 ** state['step'] # 1
64 | bias_correction2 = 1 - beta_2 ** state['step']
65 | # step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
66 | step_size = group['lr'] * math.sqrt(bias_correction2)
67 |
68 | # p.data.addcdiv_(-step_size, exp_avg, denom)
69 | if hasattr(p, 'lr_equalization_coef'):
70 | step_size *= p.lr_equalization_coef
71 |
72 | p.data.addcdiv_(-step_size, grad, denom) # p = p - step_size * grad / denom
73 |
74 | return loss
75 |
--------------------------------------------------------------------------------
/model/stylegan1/lod_driver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import time
4 | from collections import defaultdict
5 |
6 | # An object that provides the various training parameters (params)
7 | class LODDriver:
8 | def __init__(self, cfg, logger, dataset_size):
9 | self.lod_2_batch = cfg.TRAIN.LOD_2_BATCH_1GPU #[128, 128, 128, 64, 32, 16]
10 | self.minibatch_base = 16
11 | self.cfg = cfg
12 | self.dataset_size = dataset_size
13 | self.current_epoch = 0
14 | self.lod = -1
15 | self.in_transition = False
16 | self.logger = logger
17 | self.iteration = 0
18 | self.epoch_end_time = 0
19 | self.epoch_start_time = 0
20 | self.per_epoch_ptime = 0
21 | self.reports = cfg.TRAIN.REPORT_FREQ
22 | self.snapshots = cfg.TRAIN.SNAPSHOT_FREQ
23 | self.tick_start_nimg_report = 0
24 | self.tick_start_nimg_snapshot = 0
25 |
26 | def get_lod_power2(self):
27 | return self.lod + 2
28 |
29 | def get_batch_size(self):
30 | return self.lod_2_batch[min(self.lod, len(self.lod_2_batch) - 1)]
31 |
32 | def get_dataset_size(self):
33 | return self.dataset_size
34 |
35 | def get_blend_factor(self):
36 | blend_factor = float((self.current_epoch % self.cfg.TRAIN.EPOCHS_PER_LOD) * self.dataset_size + self.iteration)
37 | blend_factor /= float(self.cfg.TRAIN.EPOCHS_PER_LOD // 2 * self.dataset_size)
38 | blend_factor = math.sin(blend_factor * math.pi - 0.5 * math.pi) * 0.5 + 0.5
39 |
40 | if not self.in_transition:
41 | blend_factor = 1
42 |
43 | return blend_factor
44 |
45 | def is_time_to_report(self):
46 | if self.iteration >= self.tick_start_nimg_report + self.reports[min(self.lod, len(self.reports) - 1)] * 1000:
47 | self.tick_start_nimg_report = self.iteration
48 | return True
49 | return False
50 |
51 | def is_time_to_save(self):
52 | if self.iteration >= self.tick_start_nimg_snapshot + self.snapshots[min(self.lod, len(self.snapshots) - 1)] * 1000:
53 | self.tick_start_nimg_snapshot = self.iteration
54 | return True
55 | return False
56 |
57 | def step(self):
58 | self.iteration += self.get_batch_size()
59 | self.epoch_end_time = time.time()
60 | self.per_epoch_ptime = self.epoch_end_time - self.epoch_start_time
61 |
62 | def set_epoch(self, epoch, optimizers):
63 | self.current_epoch = epoch
64 | self.iteration = 0
65 | self.tick_start_nimg_report = 0
66 | self.tick_start_nimg_snapshot = 0
67 | self.epoch_start_time = time.time()
68 |
69 | new_lod = min(self.cfg.MODEL.LAYER_COUNT - 1, epoch // self.cfg.TRAIN.EPOCHS_PER_LOD) # the first term is fixed; the second starts at 0; raising the lod raises the resolution
70 | if new_lod != self.lod:
71 | self.lod = new_lod
72 | self.logger.info("#" * 80)
73 | self.logger.info("# Switching LOD to %d" % self.lod)
74 | self.logger.info("#" * 80)
75 | self.logger.info("Start transition")
76 | self.in_transition = True
77 | for opt in optimizers:
78 | opt.state = defaultdict(dict)
79 |
80 | is_in_first_half_of_cycle = (epoch % self.cfg.TRAIN.EPOCHS_PER_LOD) < (self.cfg.TRAIN.EPOCHS_PER_LOD // 2)
81 | is_growing = epoch // self.cfg.TRAIN.EPOCHS_PER_LOD == self.lod > 0
82 | new_in_transition = is_in_first_half_of_cycle and is_growing
83 |
84 | if new_in_transition != self.in_transition:
85 | self.in_transition = new_in_transition
86 | self.logger.info("#" * 80)
87 | self.logger.info("# Transition ended")
88 | self.logger.info("#" * 80)
89 |
--------------------------------------------------------------------------------
/model/stylegan1/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | __all__ = ['kl', 'reconstruction', 'discriminator_logistic_simple_gp',
6 | 'discriminator_gradient_penalty', 'generator_logistic_non_saturating']
7 |
8 |
9 | def kl(mu, log_var):
10 | return -0.5 * torch.mean(torch.mean(1 + log_var - mu.pow(2) - log_var.exp(), 1))
11 |
12 |
13 | def reconstruction(recon_x, x, lod=None):
14 | return torch.mean((recon_x - x)**2)
15 |
16 |
17 | def discriminator_logistic_simple_gp(d_result_fake, d_result_real, reals, r1_gamma=10.0):
18 | loss = (F.softplus(d_result_fake) + F.softplus(-d_result_real))
19 |
20 | if r1_gamma != 0.0:
21 | real_loss = d_result_real.sum()
22 | real_grads = torch.autograd.grad(real_loss, reals, create_graph=True, retain_graph=True)[0]
23 | r1_penalty = torch.sum(real_grads.pow(2.0), dim=[1, 2, 3])
24 | loss = loss + r1_penalty * (r1_gamma * 0.5)
25 | return loss.mean()
26 |
27 |
28 | def discriminator_gradient_penalty(d_result_real, reals, r1_gamma=10.0):
29 | real_loss = d_result_real.sum()
30 | real_grads = torch.autograd.grad(real_loss, reals, create_graph=True, retain_graph=True)[0]
31 | r1_penalty = torch.sum(real_grads.pow(2.0), dim=[1, 2, 3])
32 | loss = r1_penalty * (r1_gamma * 0.5)
33 | return loss.mean()
34 |
35 |
36 | def generator_logistic_non_saturating(d_result_fake):
37 | return F.softplus(-d_result_fake).mean()
38 |
--------------------------------------------------------------------------------
/model/stylegan1/lreq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.nn import init
5 | from torch.nn.parameter import Parameter
6 | import numpy as np
7 |
8 |
9 | class Bool:
10 | def __init__(self):
11 | self.value = False
12 |
13 | def __bool__(self):
14 | return self.value
15 | __nonzero__ = __bool__
16 |
17 | def set(self, value):
18 | self.value = value
19 |
20 |
21 | use_implicit_lreq = Bool()
22 | use_implicit_lreq.set(True)
23 |
24 |
25 | def is_sequence(arg):
26 | return (not hasattr(arg, "strip") and
27 | hasattr(arg, "__getitem__") or
28 | hasattr(arg, "__iter__"))
29 |
30 |
31 | def make_tuple(x, n):
32 | if is_sequence(x):
33 | return x
34 | return tuple([x for _ in range(n)])
35 |
36 |
37 | class Linear(nn.Module):
38 | def __init__(self, in_features, out_features, bias=True, gain=np.sqrt(2.0), lrmul=1.0, implicit_lreq=use_implicit_lreq):
39 | super(Linear, self).__init__()
40 | self.in_features = in_features
41 | self.weight = Parameter(torch.Tensor(out_features, in_features))
42 | if bias:
43 | self.bias = Parameter(torch.Tensor(out_features))
44 | else:
45 | self.register_parameter('bias', None)
46 | self.std = 0
47 | self.gain = gain
48 | self.lrmul = lrmul
49 | self.implicit_lreq = implicit_lreq
50 | self.reset_parameters()
51 |
52 | def reset_parameters(self):
53 | self.std = self.gain / np.sqrt(self.in_features) * self.lrmul
54 | if not self.implicit_lreq:
55 | init.normal_(self.weight, mean=0, std=1.0 / self.lrmul)
56 | else:
57 | init.normal_(self.weight, mean=0, std=self.std / self.lrmul)
58 | setattr(self.weight, 'lr_equalization_coef', self.std)
59 | if self.bias is not None:
60 | setattr(self.bias, 'lr_equalization_coef', self.lrmul)
61 |
62 | if self.bias is not None:
63 | with torch.no_grad():
64 | self.bias.zero_()
65 |
66 | def forward(self, input):
67 | if not self.implicit_lreq:
68 | bias = self.bias
69 | if bias is not None:
70 | bias = bias * self.lrmul
71 | return F.linear(input, self.weight * self.std, bias)
72 | else:
73 | return F.linear(input, self.weight, self.bias)
74 |
75 |
76 | class Conv2d(nn.Module):
77 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1,
78 | groups=1, bias=True, gain=np.sqrt(2.0), transpose=False, transform_kernel=False, lrmul=1.0,
79 | implicit_lreq=use_implicit_lreq):
80 | super(Conv2d, self).__init__()
81 | if in_channels % groups != 0:
82 | raise ValueError('in_channels must be divisible by groups')
83 | if out_channels % groups != 0:
84 | raise ValueError('out_channels must be divisible by groups')
85 | self.in_channels = in_channels
86 | self.out_channels = out_channels
87 | self.kernel_size = make_tuple(kernel_size, 2)
88 | self.stride = make_tuple(stride, 2)
89 | self.padding = make_tuple(padding, 2)
90 | self.output_padding = make_tuple(output_padding, 2)
91 | self.dilation = make_tuple(dilation, 2)
92 | self.groups = groups
93 | self.gain = gain
94 | self.lrmul = lrmul
95 | self.transpose = transpose
96 | self.fan_in = np.prod(self.kernel_size) * in_channels // groups
97 | self.transform_kernel = transform_kernel
98 | if transpose:
99 | self.weight = Parameter(torch.Tensor(in_channels, out_channels // groups, *self.kernel_size))
100 | else:
101 | self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
102 | if bias:
103 | self.bias = Parameter(torch.Tensor(out_channels))
104 | else:
105 | self.register_parameter('bias', None)
106 | self.std = 0
107 | self.implicit_lreq = implicit_lreq
108 | self.reset_parameters()
109 |
110 | def reset_parameters(self):
111 | self.std = self.gain / np.sqrt(self.fan_in)
112 | if not self.implicit_lreq:
113 | init.normal_(self.weight, mean=0, std=1.0 / self.lrmul)
114 | else:
115 | init.normal_(self.weight, mean=0, std=self.std / self.lrmul)
116 | setattr(self.weight, 'lr_equalization_coef', self.std)
117 | if self.bias is not None:
118 | setattr(self.bias, 'lr_equalization_coef', self.lrmul)
119 |
120 | if self.bias is not None:
121 | with torch.no_grad():
122 | self.bias.zero_()
123 |
124 | def forward(self, x):
125 | if self.transpose:
126 | w = self.weight
127 | if self.transform_kernel:
128 | w = F.pad(w, (1, 1, 1, 1), mode='constant')
129 | w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
130 | if not self.implicit_lreq:
131 | bias = self.bias
132 | if bias is not None:
133 | bias = bias * self.lrmul
134 | return F.conv_transpose2d(x, w * self.std, bias, stride=self.stride,
135 | padding=self.padding, output_padding=self.output_padding,
136 | dilation=self.dilation, groups=self.groups)
137 | else:
138 | return F.conv_transpose2d(x, w, self.bias, stride=self.stride, padding=self.padding,
139 | output_padding=self.output_padding, dilation=self.dilation,
140 | groups=self.groups)
141 | else:
142 | w = self.weight
143 | if self.transform_kernel:
144 | w = F.pad(w, (1, 1, 1, 1), mode='constant')
145 | w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25
146 | if not self.implicit_lreq:
147 | bias = self.bias
148 | if bias is not None:
149 | bias = bias * self.lrmul
150 | return F.conv2d(x, w * self.std, bias, stride=self.stride, padding=self.padding,
151 | dilation=self.dilation, groups=self.groups)
152 | else:
153 | return F.conv2d(x, w, self.bias, stride=self.stride, padding=self.padding,
154 | dilation=self.dilation, groups=self.groups)
155 |
156 |
157 | class ConvTranspose2d(Conv2d):
158 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1,
159 | groups=1, bias=True, gain=np.sqrt(2.0), transform_kernel=False, lrmul=1.0, implicit_lreq=use_implicit_lreq):
160 | super(ConvTranspose2d, self).__init__(in_channels=in_channels,
161 | out_channels=out_channels,
162 | kernel_size=kernel_size,
163 | stride=stride,
164 | padding=padding,
165 | output_padding=output_padding,
166 | dilation=dilation,
167 | groups=groups,
168 | bias=bias,
169 | gain=gain,
170 | transpose=True,
171 | transform_kernel=transform_kernel,
172 | lrmul=lrmul,
173 | implicit_lreq=implicit_lreq)
174 |
175 |
176 | class SeparableConv2d(nn.Module):
177 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1,
178 | bias=True, gain=np.sqrt(2.0), transpose=False):
179 | super(SeparableConv2d, self).__init__()
180 | self.spatial_conv = Conv2d(in_channels, in_channels, kernel_size, stride, padding, output_padding, dilation,
181 | in_channels, False, 1, transpose)
182 | self.channel_conv = Conv2d(in_channels, out_channels, 1, bias, 1, gain=gain)
183 |
184 | def forward(self, x):
185 | return self.channel_conv(self.spatial_conv(x))
186 |
187 |
188 | class SeparableConvTranspose2d(Conv2d):
189 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1,
190 | bias=True, gain=np.sqrt(2.0)):
191 | super(SeparableConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
192 | output_padding, dilation, bias, gain, True)
193 |
194 |
--------------------------------------------------------------------------------
/model/stylegan1/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import random
4 | import model.stylegan1.losses as losses
5 | from model.stylegan1.net import Generator, Mapping, Discriminator
6 | import numpy as np
7 |
8 |
9 | class DLatent(nn.Module):
10 | def __init__(self, dlatent_size, layer_count):
11 | super(DLatent, self).__init__()
12 | buffer = torch.zeros(layer_count, dlatent_size, dtype=torch.float32)
13 | self.register_buffer('buff', buffer)
14 |
15 | class Model(nn.Module): # wrapper around the three networks Gp, Gs and D
16 | def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_layers=5, dlatent_avg_beta=None,
17 | truncation_psi=None, truncation_cutoff=None, style_mixing_prob=None, channels=3):
18 | super(Model, self).__init__()
19 |
20 | self.mapping = Mapping(
21 | num_layers=2 * layer_count, # 2*9 = 18
22 | latent_size=latent_size,
23 | dlatent_size=latent_size,
24 | mapping_fmaps=latent_size,
25 | mapping_layers=mapping_layers)
26 |
27 | self.generator = Generator(
28 | startf=startf,
29 | layer_count=layer_count,
30 | maxf=maxf,
31 | latent_size=latent_size,
32 | channels=channels)
33 |
34 | self.discriminator = Discriminator(
35 | startf=startf,
36 | layer_count=layer_count,
37 | maxf=maxf,
38 | channels=channels)
39 |
40 | self.dlatent_avg = DLatent(latent_size, self.mapping.num_layers)
41 | self.latent_size = latent_size
42 | self.dlatent_avg_beta = dlatent_avg_beta
43 | self.truncation_psi = truncation_psi # 0.7
44 | self.style_mixing_prob = style_mixing_prob
45 | self.truncation_cutoff = truncation_cutoff # first 8 layers
46 |
47 | def generate(self, lod, blend_factor, z=None, count=32, remove_blob=False):
48 | if z is None:
49 | z = torch.randn(count, self.latent_size)
50 | styles = self.mapping(z)
51 | if self.dlatent_avg_beta is not None: # pull the style vectors toward the center vector dlatent_avg.buff.data, with ratio self.dlatent_avg_beta=0.995
52 | with torch.no_grad():
53 | batch_avg = styles.mean(dim=0)
54 | self.dlatent_avg.buff.data.lerp_(batch_avg.data, 1.0 - self.dlatent_avg_beta) # y.lerp(x,a) = y - (y-x)*a
55 |
56 | if self.style_mixing_prob is not None:
57 | if random.random() < self.style_mixing_prob:
58 | z2 = torch.randn(count, self.latent_size) # z2 : [32, 512]
59 | styles2 = self.mapping(z2)
60 |
61 | layer_idx = torch.arange(self.mapping.num_layers)[np.newaxis, :, np.newaxis]
62 | cur_layers = (lod + 1) * 2
63 | mixing_cutoff = random.randint(1, cur_layers)
64 | styles = torch.where(layer_idx < mixing_cutoff, styles, styles2)
65 |
66 | if self.truncation_psi is not None: # pull the style vectors toward the center vector dlatent_avg.buff.data, with ratio truncation_psi
67 | layer_idx = torch.arange(self.mapping.num_layers)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,...,17]
68 | ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], all ones
69 | coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones) # of the 18 layers, the first truncation_cutoff (8) are scaled by truncation_psi
70 | styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs) # avg + (styles-avg) * 0.7
71 |
72 | rec = self.generator.forward(styles, lod, blend_factor, remove_blob) # styles:[-1 , 18, 512]
73 | return rec
74 |
75 | def forward(self, x, lod, blend_factor, d_train):
76 | if d_train:
77 | with torch.no_grad():
78 | rec = self.generate(lod, blend_factor, count=x.shape[0])
79 | self.discriminator.requires_grad_(True)
80 | d_result_real = self.discriminator(x, lod, blend_factor).squeeze()
81 | d_result_fake = self.discriminator(rec.detach(), lod, blend_factor).squeeze()
82 |
83 | loss_d = losses.discriminator_logistic_simple_gp(d_result_fake, d_result_real, x)
84 | return loss_d
85 | else:
86 | rec = self.generate(lod, blend_factor, count=x.shape[0])
87 | self.discriminator.requires_grad_(False)
88 | d_result_fake = self.discriminator(rec, lod, blend_factor).squeeze()
89 | loss_g = losses.generator_logistic_non_saturating(d_result_fake)
90 | return loss_g
91 |
92 | def lerp(self, other, betta):
93 | if hasattr(other, 'module'):
94 | other = other.module
95 | with torch.no_grad():
96 | params = list(self.mapping.parameters()) + list(self.generator.parameters()) + list(self.dlatent_avg.parameters())
97 | other_param = list(other.mapping.parameters()) + list(other.generator.parameters()) + list(other.dlatent_avg.parameters())
98 | for p, p_other in zip(params, other_param):
99 | p.data.lerp_(p_other.data, 1.0 - betta)
100 |
--------------------------------------------------------------------------------
/model/stylegan1/text_alae.py:
--------------------------------------------------------------------------------
1 | from net import *
2 | from model import Model
3 | from launcher import run
4 | from dataloader import *
5 | from checkpointer import Checkpointer
6 | from dlutils.pytorch import count_parameters
7 | from defaults import get_cfg_defaults
8 | from PIL import Image
9 | import PIL
10 | from torchvision.utils import save_image
11 |
12 | def draw_uncurated_result_figure(cfg, png, model, cx, cy, cw, ch, rows, lods, seed):
13 | print(png)
14 | N = sum(rows * 2**lod for lod in lods)
15 | images = []
16 |
17 | rnd = np.random.RandomState(5)
18 | for i in range(N):
19 | latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
20 | samplez = torch.tensor(latents).float().cuda()
21 | image = model.generate(cfg.DATASET.MAX_RESOLUTION_LEVEL-2, 1, samplez, 1, mixing=True)
22 | images.append(image[0])
23 |
24 | canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white')
25 | image_iter = iter(list(images))
26 | for col, lod in enumerate(lods):
27 | for row in range(rows * 2**lod):
28 | im = next(image_iter).cpu().numpy()
29 | im = im.transpose(1, 2, 0)
30 | im = im * 0.5 + 0.5
31 | image = PIL.Image.fromarray(np.clip(im * 255, 0, 255).astype(np.uint8), 'RGB')
32 | image = image.crop((cx, cy, cx + cw, cy + ch))
33 | image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS)
34 | canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod))
35 | canvas.save(png)
36 |
37 |
38 | def sample(cfg, logger):
39 | torch.cuda.set_device(0)
40 | model = Model(
41 | startf=cfg.MODEL.START_CHANNEL_COUNT,
42 | layer_count=cfg.MODEL.LAYER_COUNT,
43 | maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
44 | latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
45 | truncation_psi=cfg.MODEL.TRUNCATIOM_PSI,
46 | truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
47 | style_mixing_prob=cfg.MODEL.STYLE_MIXING_PROB,
48 | mapping_layers=cfg.MODEL.MAPPING_LAYERS,
49 | channels=cfg.MODEL.CHANNELS,
50 | generator=cfg.MODEL.GENERATOR,
51 | encoder=cfg.MODEL.ENCODER)
52 |
53 | model.cuda(0)
54 | model.eval()
55 | model.requires_grad_(False)
56 |
57 | decoder = model.decoder
58 | encoder = model.encoder
59 | mapping_tl = model.mapping_d
60 | mapping_fl = model.mapping_f
61 |
62 | dlatent_avg = model.dlatent_avg
63 |
64 | logger.info("Trainable parameters generator:")
65 | count_parameters(decoder)
66 |
67 | logger.info("Trainable parameters discriminator:")
68 | count_parameters(encoder)
69 |
70 | arguments = dict()
71 | arguments["iteration"] = 0
72 |
73 | model_dict = {
74 | 'discriminator_s': encoder,
75 | 'generator_s': decoder,
76 | 'mapping_tl_s': mapping_tl,
77 | 'mapping_fl_s': mapping_fl,
78 | 'dlatent_avg': dlatent_avg
79 | }
80 |
81 | checkpointer = Checkpointer(cfg,
82 | model_dict,
83 | {},
84 | logger=logger,
85 | save=False)
86 |
87 | checkpointer.load()
88 |
89 | model.eval()
90 |
91 | layer_count = cfg.MODEL.LAYER_COUNT
92 |
93 | decoder = nn.DataParallel(decoder)
94 |
95 | im_size = 2 ** (cfg.MODEL.LAYER_COUNT + 1)
96 | with torch.no_grad():
97 | rnd = np.random.RandomState(0)
98 | latents = rnd.randn(1, 512)
99 | samplez = torch.tensor(latents).float().cuda()
100 | image = model.generate(8, 1, samplez, 1, mixing=True)
101 | ls = model.encode(image,8,1)
102 | x = model.decoder(ls, 8, 1, noise=True)  # decode the encoded styles (ls), not the previously undefined x
103 | save_image(image,'1.png')                # save the generated image rather than the latent z
104 | save_image(x,'2.png')
105 |
106 | if __name__ == "__main__":
107 | gpu_count = 1
108 | run(sample, get_cfg_defaults(), description='ALAE-generations', default_config='configs/ffhq.yaml',
109 | world_size=gpu_count, write_log=False)
110 |
--------------------------------------------------------------------------------
/model/utils/biggan_config.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | """
3 | BigGAN config.
4 | """
5 | from __future__ import (absolute_import, division, print_function, unicode_literals)
6 |
7 | import copy
8 | import json
9 | import os
10 |
11 | class BigGANConfig(object):
12 | """ Configuration class to store the configuration of a `BigGAN`.
13 | Defaults are for the 128x128 model.
14 | layers tuple are (up-sample in the layer ?, input channels, output channels)
15 | """
16 | def __init__(self,
17 | output_dim=128,
18 | z_dim=128,
19 | class_embed_dim=128,
20 | channel_width=128,
21 | num_classes=1000,
22 | layers=[(False, 16, 16),
23 | (True, 16, 16),
24 | (False, 16, 16),
25 | (True, 16, 8),
26 | (False, 8, 8),
27 | (True, 8, 4),
28 | (False, 4, 4),
29 | (True, 4, 2),
30 | (False, 2, 2),
31 | (True, 2, 1)],
32 | attention_layer_position=8,
33 | eps=1e-4,
34 | n_stats=51):
35 | """Constructs BigGANConfig. """
36 | self.output_dim = output_dim
37 | self.z_dim = z_dim
38 | self.class_embed_dim = class_embed_dim
39 | self.channel_width = channel_width
40 | self.num_classes = num_classes
41 | self.layers = layers
42 | self.attention_layer_position = attention_layer_position
43 | self.eps = eps
44 | self.n_stats = n_stats
45 |
46 | @classmethod
47 | def from_dict(cls, json_object):
48 | """Constructs a `BigGANConfig` from a Python dictionary of parameters."""
49 | config = BigGANConfig()
50 | for key, value in json_object.items():
51 | config.__dict__[key] = value
52 | return config
53 |
54 | @classmethod
55 | def from_json_file(cls, json_file):
56 | """Constructs a `BigGANConfig` from a json file of parameters."""
57 | with open(json_file, "r", encoding='utf-8') as reader:
58 | text = reader.read()
59 | return cls.from_dict(json.loads(text))
60 |
61 | def __repr__(self):
62 | return str(self.to_json_string())
63 |
64 | def to_dict(self):
65 | """Serializes this instance to a Python dictionary."""
66 | output = copy.deepcopy(self.__dict__)
67 | return output
68 |
69 | def to_json_string(self):
70 | """Serializes this instance to a JSON string."""
71 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
72 |
--------------------------------------------------------------------------------
/model/utils/biggan_file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 |
8 | import json
9 | import logging
10 | import os
11 | import shutil
12 | import tempfile
13 | from functools import wraps
14 | from hashlib import sha256
15 | import sys
16 | from io import open
17 |
18 | import boto3
19 | import requests
20 | from botocore.exceptions import ClientError
21 | from tqdm import tqdm
22 |
23 | try:
24 | from urllib.parse import urlparse
25 | except ImportError:
26 | from urlparse import urlparse
27 |
28 | try:
29 | from pathlib import Path
30 | PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
31 | Path.home() / '.pytorch_pretrained_biggan'))
32 | except (AttributeError, ImportError):
33 | PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan'))
35 |
36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
37 |
38 |
39 | def url_to_filename(url, etag=None):
40 | """
41 | Convert `url` into a hashed filename in a repeatable way.
42 | If `etag` is specified, append its hash to the url's, delimited
43 | by a period.
44 | """
45 | url_bytes = url.encode('utf-8')
46 | url_hash = sha256(url_bytes)
47 | filename = url_hash.hexdigest()
48 |
49 | if etag:
50 | etag_bytes = etag.encode('utf-8')
51 | etag_hash = sha256(etag_bytes)
52 | filename += '.' + etag_hash.hexdigest()
53 |
54 | return filename
55 |
56 |
57 | def filename_to_url(filename, cache_dir=None):
58 | """
59 | Return the url and etag (which may be ``None``) stored for `filename`.
60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
61 | """
62 | if cache_dir is None:
63 | cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
65 | cache_dir = str(cache_dir)
66 |
67 | cache_path = os.path.join(cache_dir, filename)
68 | if not os.path.exists(cache_path):
69 | raise EnvironmentError("file {} not found".format(cache_path))
70 |
71 | meta_path = cache_path + '.json'
72 | if not os.path.exists(meta_path):
73 | raise EnvironmentError("file {} not found".format(meta_path))
74 |
75 | with open(meta_path, encoding="utf-8") as meta_file:
76 | metadata = json.load(meta_file)
77 | url = metadata['url']
78 | etag = metadata['etag']
79 |
80 | return url, etag
81 |
82 |
83 | def cached_path(url_or_filename, cache_dir=None):
84 | """
85 | Given something that might be a URL (or might be a local path),
86 | determine which. If it's a URL, download the file and cache it, and
87 | return the path to the cached file. If it's already a local path,
88 | make sure the file exists and then return the path.
89 | """
90 | if cache_dir is None:
91 | cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
93 | url_or_filename = str(url_or_filename)
94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
95 | cache_dir = str(cache_dir)
96 |
97 | parsed = urlparse(url_or_filename)
98 |
99 | if parsed.scheme in ('http', 'https', 's3'):
100 | # URL, so get it from the cache (downloading if necessary)
101 | return get_from_cache(url_or_filename, cache_dir)
102 | elif os.path.exists(url_or_filename):
103 | # File, and it exists.
104 | return url_or_filename
105 | elif parsed.scheme == '':
106 | # File, but it doesn't exist.
107 | raise EnvironmentError("file {} not found".format(url_or_filename))
108 | else:
109 | # Something unknown
110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
111 |
112 |
113 | def split_s3_path(url):
114 | """Split a full s3 path into the bucket name and path."""
115 | parsed = urlparse(url)
116 | if not parsed.netloc or not parsed.path:
117 | raise ValueError("bad s3 path {}".format(url))
118 | bucket_name = parsed.netloc
119 | s3_path = parsed.path
120 | # Remove '/' at beginning of path.
121 | if s3_path.startswith("/"):
122 | s3_path = s3_path[1:]
123 | return bucket_name, s3_path
124 |
125 |
126 | def s3_request(func):
127 | """
128 | Wrapper function for s3 requests in order to create more helpful error
129 | messages.
130 | """
131 |
132 | @wraps(func)
133 | def wrapper(url, *args, **kwargs):
134 | try:
135 | return func(url, *args, **kwargs)
136 | except ClientError as exc:
137 | if int(exc.response["Error"]["Code"]) == 404:
138 | raise EnvironmentError("file {} not found".format(url))
139 | else:
140 | raise
141 |
142 | return wrapper
143 |
144 |
145 | @s3_request
146 | def s3_etag(url):
147 | """Check ETag on S3 object."""
148 | s3_resource = boto3.resource("s3")
149 | bucket_name, s3_path = split_s3_path(url)
150 | s3_object = s3_resource.Object(bucket_name, s3_path)
151 | return s3_object.e_tag
152 |
153 |
154 | @s3_request
155 | def s3_get(url, temp_file):
156 | """Pull a file directly from S3."""
157 | s3_resource = boto3.resource("s3")
158 | bucket_name, s3_path = split_s3_path(url)
159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
160 |
161 |
162 | def http_get(url, temp_file):
163 | req = requests.get(url, stream=True)
164 | content_length = req.headers.get('Content-Length')
165 | total = int(content_length) if content_length is not None else None
166 | progress = tqdm(unit="B", total=total)
167 | for chunk in req.iter_content(chunk_size=1024):
168 | if chunk: # filter out keep-alive new chunks
169 | progress.update(len(chunk))
170 | temp_file.write(chunk)
171 | progress.close()
172 |
173 |
174 | def get_from_cache(url, cache_dir=None):
175 | """
176 | Given a URL, look for the corresponding dataset in the local cache.
177 | If it's not there, download it. Then return the path to the cached file.
178 | """
179 | if cache_dir is None:
180 | cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
182 | cache_dir = str(cache_dir)
183 |
184 | if not os.path.exists(cache_dir):
185 | os.makedirs(cache_dir)
186 |
187 | # Get eTag to add to filename, if it exists.
188 | if url.startswith("s3://"):
189 | etag = s3_etag(url)
190 | else:
191 | response = requests.head(url, allow_redirects=True)
192 | if response.status_code != 200:
193 | raise IOError("HEAD request failed for url {} with status code {}"
194 | .format(url, response.status_code))
195 | etag = response.headers.get("ETag")
196 |
197 | filename = url_to_filename(url, etag)
198 |
199 | # get cache path to put the file
200 | cache_path = os.path.join(cache_dir, filename)
201 |
202 | if not os.path.exists(cache_path):
203 | # Download to temporary file, then copy to cache dir once finished.
204 | # Otherwise you get corrupt cache entries if the download gets interrupted.
205 | with tempfile.NamedTemporaryFile() as temp_file:
206 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
207 |
208 | # GET file object
209 | if url.startswith("s3://"):
210 | s3_get(url, temp_file)
211 | else:
212 | http_get(url, temp_file)
213 |
214 | # we are copying the file before closing it, so flush to avoid truncation
215 | temp_file.flush()
216 | # shutil.copyfileobj() starts at the current position, so go to the start
217 | temp_file.seek(0)
218 |
219 | logger.info("copying %s to cache at %s", temp_file.name, cache_path)
220 | with open(cache_path, 'wb') as cache_file:
221 | shutil.copyfileobj(temp_file, cache_file)
222 |
223 | logger.info("creating metadata file for %s", cache_path)
224 | meta = {'url': url, 'etag': etag}
225 | meta_path = cache_path + '.json'
226 | with open(meta_path, 'w', encoding="utf-8") as meta_file:
227 | json.dump(meta, meta_file)
228 |
229 | logger.info("removing temp file %s", temp_file.name)
230 |
231 | return cache_path
232 |
233 |
234 | def read_set_from_file(filename):
235 | '''
236 | Extract a de-duped collection (set) of text from a file.
237 | Expected file format is one item per line.
238 | '''
239 | collection = set()
240 | with open(filename, 'r', encoding='utf-8') as file_:
241 | for line in file_:
242 | collection.add(line.rstrip())
243 | return collection
244 |
245 |
246 | def get_file_extension(path, dot=True, lower=True):
247 | ext = os.path.splitext(path)[1]
248 | ext = ext if dot else ext[1:]
249 | return ext.lower() if lower else ext
250 |
--------------------------------------------------------------------------------
/model/utils/custom_adam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class LREQAdam(Optimizer):
7 | def __init__(self, params, lr=1e-3, betas=(0.0, 0.99), eps=1e-8,
8 | weight_decay=0):
9 | beta_2 = betas[1]
10 | if not 0.0 <= lr:
11 | raise ValueError("Invalid learning rate: {}".format(lr))
12 | if not 0.0 <= eps:
13 | raise ValueError("Invalid epsilon value: {}".format(eps))
14 | if not 0.0 == betas[0]:
15 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
16 | if not 0.0 <= beta_2 < 1.0:
17 | raise ValueError("Invalid beta parameter at index 1: {}".format(beta_2))
18 | defaults = dict(lr=lr, beta_2=beta_2, eps=eps, weight_decay=weight_decay)
19 | super(LREQAdam, self).__init__(params, defaults)
20 |
21 | def __setstate__(self, state):
22 | super(LREQAdam, self).__setstate__(state)
23 |
24 | def step(self, closure=None):
25 | """Performs a single optimization step.
26 | Arguments:
27 | closure (callable, optional): A closure that reevaluates the model and returns the loss.
28 | """
29 | loss = None
30 | if closure is not None:
31 | loss = closure()
32 |
33 | for group in self.param_groups:
34 | for p in group['params']:
35 | if p.grad is None:
36 | continue
37 | grad = p.grad.data
38 | if grad.is_sparse:
39 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
40 |
41 | state = self.state[p]
42 |
43 | # State initialization
44 | if len(state) == 0:
45 | state['step'] = 0
46 | # Exponential moving average of gradient values
47 | # state['exp_avg'] = torch.zeros_like(p.data)
48 | # Exponential moving average of squared gradient values
49 | state['exp_avg_sq'] = torch.zeros_like(p.data)
50 |
51 | # exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
52 | exp_avg_sq = state['exp_avg_sq']
53 | beta_2 = group['beta_2']
54 |
55 | state['step'] += 1
56 |
57 | if group['weight_decay'] != 0:
58 | grad.add_(group['weight_decay'], p.data / p.coef)
59 |
60 | # Decay the first and second moment running average coefficient
61 | # exp_avg.mul_(beta1).add_(1 - beta1, grad)
62 | exp_avg_sq.mul_(beta_2).addcmul_(1 - beta_2, grad, grad)
63 | denom = exp_avg_sq.sqrt().add_(group['eps'])
64 |
65 | # bias_correction1 = 1 - beta1 ** state['step'] # 1
66 | bias_correction2 = 1 - beta_2 ** state['step']
67 | # step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
68 | step_size = group['lr'] * math.sqrt(bias_correction2)
69 |
70 | # p.data.addcdiv_(-step_size, exp_avg, denom)
71 | if hasattr(p, 'lr_equalization_coef'):
72 | step_size *= p.lr_equalization_coef
73 |
74 |                 p.data.addcdiv_(-step_size, grad, denom)  # p = p - step_size * grad / denom
75 |
76 | return loss
77 |
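78 | 
79 | if __name__ == "__main__":
80 |     # Minimal usage sketch: the toy linear layer below is only an assumption for
81 |     # illustration; in this repository LREQAdam is used to optimize the encoder E
82 |     # (see the training scripts), typically with lr=0.0015 and betas=(0.0, 0.99).
83 |     net = torch.nn.Linear(512, 512)
84 |     opt = LREQAdam(net.parameters(), lr=0.0015, betas=(0.0, 0.99))
85 |     loss = net(torch.randn(4, 512)).pow(2).mean()
86 |     opt.zero_grad()
87 |     loss.backward()
88 |     opt.step()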
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Improving generative adversarial network inversion via fine-tuning GAN encoders
2 |
3 | 
4 | 
5 | 
6 |
7 |
8 |
9 |
10 | >This is the official code for "Improving generative adversarial network inversion via
11 | fine-tuning GAN encoders".
12 |
13 | >The code provides a set of encoders matched to pre-trained GANs (PGGAN, StyleGANv1, StyleGANv2, BigGAN). The project can be extended to other GANs in the same way.
14 |
15 |
16 | ## Usage
17 |
18 |
19 | - training an encoder with center attention (suited to aligned images)
20 |
21 | > python E_align.py
22 |
23 | - training an encoder with Grad-CAM-based attention (suited to misaligned images)
24 |
25 | > python E_mis_align.py
26 |
27 |
28 | - embedding real images into the latent space (using StyleGANv1 and w).
29 |
30 | ```
31 | a. Put real images in './checkpoint/realimg_file/' (the default directory for args.img_dir).
32 | 
33 | b. Load the pre-trained encoder from './checkpoint/E/E_blur(case2)_styleganv1_FFHQ_state_dict.pth'.
34 | 
35 | c. Then run:
36 | ```
37 |
38 | > python embedding_img.py
39 |
40 | - discovering attribute directions in the latent space: embedded_img_processing.py
41 |
42 | Note: pre-trained models should be downloaded first; by default they are saved to './checkpoint/'.
43 |
44 | ## Metric
45 |
46 | - validating performance (pre-trained GANs and baselines)
47 |
48 | 1. Use generations.py to generate reconstructed images (and GAN-generated images if needed).
49 | 2. The files in the directory "./baseline/" help you quickly format images and latent vectors (w).
50 | 3. Put the images to be compared into separate directories, then run comparing-baseline.py.
51 |
52 |
53 | - ablation study: see './ablations-study/'
54 |
55 |
56 | ## Setup
57 |
58 | ### Encoders
59 |
60 | - Case 1: training encoders for most pre-trained GANs,
61 | at './model/E/E.py' (converges quickly when reconstructing GAN-generated images)
62 | - Case 2: training the StyleGANv1 FFHQ encoder for the ablation study and real face image processing,
63 | at './model/E/E_Blur.py' (margin blur; needs more GPU memory for pixel gradients)
64 |
65 | ### Pre-Trained GANs
66 | > note: put the pre-trained GAN weight files in the './checkpoint/' directory
67 | - StyleGAN_V1 (should contain 3 files: Gm, Gs, center-tensor):
68 | - Cat 256:
69 | - ./checkpoint/stylegan_V1/cat/cat256_Gs_dict.pth
70 | - ./checkpoint/stylegan_V1/cat/cat256_Gm_dict.pth
71 | - ./checkpoint/stylegan_V1/cat/cat256_tensor.pt
72 | - Car 256: same above
73 | - Bedroom 256:
74 | - StyleGAN_V2 (only one file: .pth):
75 | - FFHQ 1024:
76 | - ./checkpoint/stylegan_V2/stylegan2_ffhq1024.pth
77 | - PGGAN (only one file: .pth):
78 | - Horse 256:
79 | - ./checkpoint/PGGAN/
80 | - BigGAN (two files: the model as .pt and the config as .json):
81 | - Image-Net 256:
82 | - ./checkpoint/biggan/256/G-256.pt
83 | - ./checkpoint/biggan/256/biggan-deep-256-config.json
84 |
85 | ### Options and Setting
86 | > note: different GANs require different parameter settings; set them carefully.
87 |
88 | - choose --mtype: StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
89 | - choose the encoder's --start_features carefully; the mapping is 16->1024x1024, 32->512x512, 64->256x256
90 | - to resume training, set --checkpoint_dir_E to the path of a pre-trained encoder model
91 | - --checkpoint_dir_GAN is required; for StyleGANv1 it is a directory (containing 3 files: Gm, Gs, center-tensor), for the others it is a file path (.pth or .pt)
92 | ```python
93 | parser = argparse.ArgumentParser(description='the training args')
94 | parser.add_argument('--iterations', type=int, default=210000) # epoch = iterations//30000
95 | parser.add_argument('--lr', type=float, default=0.0015)
96 | parser.add_argument('--beta_1', type=float, default=0.0)
97 | parser.add_argument('--batch_size', type=int, default=2)
98 | parser.add_argument('--experiment_dir', default=None) #None
99 | parser.add_argument('--checkpoint_dir_GAN', default='./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth or ./checkpoint/biggan/256/G-256.pt
100 | parser.add_argument('--config_dir', default='./checkpoint/biggan/256/biggan-deep-256-config.json') # BigGAN needs it
101 | parser.add_argument('--checkpoint_dir_E', default=None)
102 | parser.add_argument('--img_size',type=int, default=1024)
103 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
104 | parser.add_argument('--z_dim', type=int, default=512) # PGGAN , StyleGANs are 512. BIGGAN is 128
105 | parser.add_argument('--mtype', type=int, default=2) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
106 | parser.add_argument('--start_features', type=int, default=16) # 16->1024x1024, 32->512x512, 64->256x256, 128->128x128
107 | ```
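For example, a StyleGANv1 FFHQ run could look like the following (E_align.py is the training script from the Usage section; the flag values follow the comments above, and the checkpoint path assumes the default layout):

> python E_align.py --mtype 1 --checkpoint_dir_GAN ./checkpoint/stylegan_v1/ffhq1024/ --img_size 1024 --z_dim 512 --start_features 16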
108 |
109 | ## Pre-trained Model
110 |
111 | We offer pre-trained GANs and their corresponding encoders here: [models](https://drive.google.com/drive/folders/1vqx5Sol04MAbeNLk9h0ouo8MiR3rJI4f?usp=sharing) (the default setting is case 1).
112 |
113 | GANs:
114 |
115 | - StyleGANv1 (FFHQ1024, Car512, Cat256): each model consists of 3 files, Gm, Gs and center-tensor.
116 | - PGGAN and StyleGANv2: a single .pth file contains Gm, Gs and center-tensor together.
117 | - BigGAN 128x128, 256x256, and 512x512: each variant contains a config file and a model (.pt).
118 |
119 | Encoders:
120 |
121 | - StyleGANv1 FFHQ (case 2), for real-image embedding and processing.
122 | - StyleGANv2 LSUN Cat 256: one model from case 1 (Grad-CAM-based attention) and both models from case 2 (Grad-CAM-based and center-aligned attention, for the ablation study).
123 | - StyleGANv2 FFHQ (case 1)
124 | - BigGAN-256 (case 1)
125 |
126 | If you want to try more GANs, see the pre-trained GANs cited below:
127 |
128 |
129 | ## Acknowledgements
130 |
131 | Pre-trained GANs:
132 |
133 | > StyleGANv1: https://github.com/podgorskiy/StyleGan.git,
134 | > ( Converting code for official pre-trained model is here: https://github.com/podgorskiy/StyleGAN_Blobless.git)
135 | > StyleGANv2 and PGGAN: https://github.com/genforce/genforce.git
136 | > BigGAN: https://github.com/huggingface/pytorch-pretrained-BigGAN
137 |
138 | Comparing Works:
139 |
140 | > E2Style: https://github.com/wty-ustc/e2style
141 | > In-Domain GAN: https://github.com/genforce/idinvert_pytorch
142 | > pSp: https://github.com/eladrich/pixel2style2pixel
143 | > ALAE: https://github.com/podgorskiy/ALAE.git
144 |
145 |
146 | Related Works:
147 |
148 | > Grad-CAM & Grad-CAM++: https://github.com/yizt/Grad-CAM.pytorch
149 | > SSIM Index: https://github.com/Po-Hsun-Su/pytorch-ssim
150 |
151 | Our implementation partly borrows from the above works (ALAE and the related works). We would like to thank their authors.
152 |
153 |
154 | ## License
155 |
156 | The code of this repository is released under the [Apache 2.0](LICENSE) license. The directories `models/biggan` and `models/stylegan2` are provided under the MIT license.
157 |
158 | ## Simplified Chinese
159 | 
160 | How to apply this to [face editing](./readme_cn.md)
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/readme_cn.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Building encoders for mainstream GANs and editing high-resolution face images
4 |
5 |
6 |
7 |
8 | The face animations shown here are each generated from a single image. This post introduces a concise method: using StyleGAN (v1) to edit an image and change its attributes (e.g., make the face smile, add glasses, change the pose, etc.).
9 |
10 | ## 1. Loading StyleGAN (v1)
11 |
12 | First, download the pre-trained model and make sure StyleGAN can generate 1024x1024 face images (i.e., reproduce StyleGAN).
13 |
14 | Note that the pre-trained model is stored as three parts:
15 |
16 | a. Gm maps z [n, 512] to w [n, 18, 512], where n is the batch_size.
17 |
18 | b. Gs takes w and outputs the corresponding image; note that w feeds one [n, 512] vector into each of the 18 layers, i.e., its shape is [n, 18, 512].
19 |
20 | c. avg_tensor is a trained constant tensor of shape [n, 512, 4, 4], used as the first input of the model (whereas the w above is fed into each layer separately).
21 |
22 | A StyleGANv1 pre-trained model is provided here: [Model](https://pan.baidu.com/s/1_JewahCd_UK5wIMCQzFAPA
23 | ), extraction code: kwsk
24 |
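A minimal loading sketch, following the loading code in rec_real_img.py from this repository (the path below is the default checkpoint location; adjust it to wherever you saved the three files):

```python
import math
import torch
from model.stylegan1.net import Generator, Mapping  # StyleGANv1

model_path = './checkpoint/stylegan_v1/ffhq1024/'
img_size = 1024

# Gs: the synthesis network, renders an image from w [n, 18, 512]
Gs = Generator(startf=16, maxf=512, layer_count=int(math.log(img_size, 2) - 1), latent_size=512, channels=3)
Gs.load_state_dict(torch.load(model_path + 'Gs_dict.pth'))

# Gm: the mapping network, maps z [n, 512] to w [n, 18, 512]
Gm = Mapping(num_layers=int(math.log(img_size, 2) - 1) * 2, mapping_layers=8,
             latent_size=512, dlatent_size=512, mapping_fmaps=512)
Gm.load_state_dict(torch.load(model_path + 'Gm_dict.pth'))

# avg_tensor (center tensor): the trained constant first input
Gm.buffer1 = torch.load(model_path + 'center_tensor.pt')
Gs.cuda(); Gm.eval()
```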
25 | ## 2. Encoding real images
26 |
27 | Use the provided encoder to encode a 1024x1024 face image into the latent variable w (1, 18, 512). You can also process several faces at once (1->n), depending on your GPU memory; it is best to use faces with aligned facial features. The encoded w is saved to a folder (default: ./latentvectors/faces/), which already contains several faces that can be used to test the StyleGAN above.
28 |
29 | Run the following script:
30 |
31 | > python embedding_img.py
32 |
33 | a. This script loads the pre-trained encoder
34 | > default path: ./checkpoint/E/E_blur(case2)_styleganv1_FFHQ_state_dict.pth
35 | 
36 | b. and encodes the real images
37 | > default input path: ./checkpoint/realimg_file/; the encoded w is saved to ./result/models/ by default
38 | 
39 | c. A pre-trained encoder is provided here: [Model](https://pan.baidu.com/s/1F9Tv5ph9Rejp5JTQK2HSYQ
40 | ), extraction code: swtl
41 |
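A minimal encoding sketch based on rec_real_img.py (it assumes the encoder checkpoint above has been downloaded; 'face.png' and the output path are placeholders):

```python
import math
import torch
import torchvision
from PIL import Image
import model.E.E_Blur as BE  # the case-2 encoder used for real faces

img_size = 1024
E = BE.BE(startf=16, maxf=512, layer_count=int(math.log(img_size, 2) - 1), latent_size=512, channels=3).cuda()
E.load_state_dict(torch.load('./checkpoint/E/E_blur(case2)_styleganv1_FFHQ_state_dict.pth'))

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((img_size, img_size)),
    torchvision.transforms.ToTensor(),
])
img = transform(Image.open('face.png').convert('RGB')).unsqueeze(0).cuda() * 2 - 1  # [0,1] -> [-1,1]
with torch.no_grad():
    const2, w2 = E(img)   # w2: [1, 18, 512]
torch.save(w2, './latentvectors/faces/face_w.pt')
```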
42 | ## 3. Editing expressions
43 |
44 | Multiply an attribute vector from './latentvectors/directions/' by a coefficient (a multiple of 5 or 10) and add it to the face's latent vector to explore and save the resulting face changes (a short sketch follows); see:
45 |
46 | > embedded_img_processing.py
47 |
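A minimal editing sketch (it assumes Gs from section 1 is already loaded on the GPU and that the .pt file stores the embedded w tensor; the direction and face files below are the ones shipped under ./latent_code/):

```python
import numpy as np
import torch
import torchvision

w = torch.load('./latent_code/real_face_code/i1_dy.pt').cuda()   # an embedded face, [1, 18, 512]
direction = np.load('./latent_code/directions/stylegan_ffhq_smile_w_boundary.npy')
direction = torch.tensor(direction, dtype=torch.float).cuda()    # broadcast across the 18 layers

coeff = 10.0                                   # try multiples of 5 or 10, positive or negative
w_edit = w + coeff * direction                 # move the code along the attribute direction
with torch.no_grad():
    img = Gs.forward(w_edit, 8)                # 8 = log2(1024) - 2 for 1024x1024 output
torchvision.utils.save_image(img * 0.5 + 0.5, 'smile.png')
```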
48 | If you want to explore other face attributes, there are many methods; the simplest supervised approach is described at: https://github.com/Puzer/stylegan-encoder (which also includes face alignment).
49 |
50 |
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/rec_real_img.py:
--------------------------------------------------------------------------------
1 | # reconstructing real images directly with the MTV encoder E via StyleGANv1 (no per-image optimization of E)
2 | # set args.realimg_dir to the directory of real images to embed into latent vectors
3 |
4 | import os
5 | import math
6 | import torch
7 | import argparse
8 | import numpy as np
9 | import torchvision
10 | import model.E.E_Blur as BE
11 | from collections import OrderedDict
12 | from training_utils import *
13 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
14 | import model.stylegan2_generator as model_v2 #StyleGANv2
15 | import model.pggan.pggan_generator as model_pggan #PGGAN
16 | from model.biggan_generator import BigGAN #BigGAN
17 |
18 | if __name__ == "__main__":
19 |
20 | if not os.path.exists('./images'): os.mkdir('./images')
21 |
22 | parser = argparse.ArgumentParser(description='the training args')
23 | parser.add_argument('--batch_size', type=int, default=2)
24 | parser.add_argument('--experiment_dir', default=None) #None
25 | parser.add_argument('--checkpoint_dir_gan', default='./checkpoint/stylegan_v1/ffhq1024/') # stylegan_v2/stylegan2_ffhq1024.pth
26 | parser.add_argument('--checkpoint_dir_e', default='./checkpoint/E/E_styleganv1_state_dict.pth') #None or E_ffhq_styleganv2_modelv2_ep110000.pth E_ffhq_styleganv2_modelv1_ep85000.pth
27 | parser.add_argument('--config_dir', default=None)
28 | parser.add_argument('--realimg_dir', default='./images/real_images128/')
29 | parser.add_argument('--img_size',type=int, default=1024)
30 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
31 | parser.add_argument('--z_dim', type=int, default=512)
32 | parser.add_argument('--mtype', type=int, default=1) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
33 | parser.add_argument('--start_features', type=int, default=16)
34 | args = parser.parse_args()
35 |
36 | use_gpu = True
37 | device = torch.device("cuda" if use_gpu else "cpu")
38 |
39 | resultPath1_1 = "./images/imgs"
40 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1)
41 |
42 | resultPath1_2 = "./images/rec"
43 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2)
44 |
45 | #Load GANs
46 | type = args.mtype
47 | model_path = args.checkpoint_dir_gan
48 | config_path = args.config_dir
49 |
50 |     if type == 1: # StyleGAN1: one directory containing 3 files (Gm, Gs, center-tensor)
51 | #model_path = './checkpoint/stylegan_v1/ffhq1024/'
52 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
53 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth'))
54 |
55 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
56 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth'))
57 |
58 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt')
59 | const_ = Gs.const
60 | const1 = const_.repeat(args.batch_size,1,1,1).cuda()
61 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024
62 |         layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,...,17]
63 |         ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,...,1]
64 |         coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # truncation_psi for the first half of the 18 layers: [0.7,0.7,...,1,1,1]
65 |
66 | Gs.cuda()
67 | Gm.eval()
68 |
69 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
70 |
71 | else:
72 | print('error')
73 |
74 | #Load E
75 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) .to(device)
76 | if args.checkpoint_dir_e is not None:
77 | E.load_state_dict(torch.load(args.checkpoint_dir_e, map_location=torch.device(device)))
78 |
79 | type = args.mtype
80 | save_path = args.realimg_dir
81 | imgs_path = [os.path.join(save_path, f) for f in os.listdir(save_path) if f.endswith(".png") or f.endswith(".jpg")]
82 | img_size = args.img_size
83 |
84 | #PIL 2 Tensor
85 | transform = torchvision.transforms.Compose([
86 | #torchvision.transforms.CenterCrop(160),
87 | torchvision.transforms.Resize((img_size,img_size)),
88 | torchvision.transforms.ToTensor(),
89 | #torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
90 | ])
91 |
92 | images = []
93 | for idx, image_path in enumerate(imgs_path):
94 | img = Image.open(image_path).convert("RGB")
95 | img = transform(img)
96 | images.append(img)
97 |
98 | imgs_tensor = torch.stack(images, dim=0)
99 |
100 | for i, j in enumerate(imgs_tensor):
101 | imgs1 = j.unsqueeze(0).cuda() * 2 -1
102 | if type != 4:
103 | const2,w2 = E(imgs1)
104 | else:
105 | const2,w2 = E(imgs1, cond_vector)
106 |
107 | if type == 1:
108 | imgs2=Gs.forward(w2,int(math.log(args.img_size,2)-2))
109 | elif type == 2 or type == 3:
110 | imgs2=generator.synthesis(w2)['image']
111 | elif type == 4:
112 | imgs2, _=G(w2, conditions, truncation)
113 | else:
114 | print('model type error')
115 |
116 | # n_row = args.batch_size
117 | # test_img = torch.cat((imgs1[:n_row],imgs2[:n_row]))*0.5+0.5
118 | # torchvision.utils.save_image(test_img, './v2ep%d.jpg'%(args.seed),nrow=n_row) # nrow=3
119 | torchvision.utils.save_image(imgs1*0.5+0.5, resultPath1_1+'/%s_realimg.png'%str(i).rjust(5,'0'))
120 | torchvision.utils.save_image(imgs2*0.5+0.5, resultPath1_2+'/%s_mtv_rec.png'%str(i).rjust(5,'0'))
121 | print('doing:'+str(i).rjust(5,'0'))
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # pip install the packages below:
2 | 
3 | 
4 | # python 3.7.3 (interpreter version; install it separately, not via pip)
5 |
6 | torch>=1.8
7 |
8 | torchvision
9 |
10 | tensorboardx
11 |
12 | lpips
13 |
14 | boto3
15 |
16 | scikit-image
17 |
18 | matplotlib
19 |
20 | scipy
21 |
22 | tqdm
23 |
24 | requests
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/synthesized_IMG.py:
--------------------------------------------------------------------------------
1 | # Generate StyleGANv1 and BigGAN synthesized images for validation. Validation seeds must differ from the training seeds (0-30000).
2 |
3 | import os
4 | import math
5 | import torch
6 | import torchvision
7 | import model.E.E as BE # default StyleGANv1 encoder; to reconstruct other GANs, import the corresponding E here.
8 | import lpips
9 | import numpy as np
10 | import argparse
11 | import tensorboardX
12 | from collections import OrderedDict
13 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
14 | import model.stylegan2_generator as model_v2 #StyleGANv2
15 | import model.pggan.pggan_generator as model_pggan #PGGAN
16 | from model.biggan_generator import BigGAN #BigGAN
17 | from training_utils import *
18 | from model.utils.biggan_config import BigGANConfig
19 | import model.E.E_BIG as BE_BIG
20 | import model.E.E_PG as BE_PG # needed for the PGGAN branch (mtype == 3)
21 | def train(tensor_writer = None, args = None):
22 | type = args.mtype
23 |
24 | model_path = args.checkpoint_dir_GAN
25 | config_path = args.config_dir
26 | if type == 1: # StyleGAN1
27 | #model_path = './checkpoint/stylegan_v1/ffhq1024/'
28 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
29 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth'))
30 |
31 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
32 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth'))
33 |
34 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt')
35 | const_ = Gs.const
36 | const1 = const_.repeat(args.batch_size,1,1,1).cuda()
37 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024
38 |         layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,...,17]
39 |         ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,...,1]
40 |         coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # truncation_psi for the first half of the 18 layers: [0.7,0.7,...,1,1,1]
41 |
42 | Gs.cuda()
43 | Gm.eval()
44 |
45 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
46 |
47 | elif type == 2: # StyleGAN2
48 | #model_path = './checkpoint/stylegan_v2/stylegan2_ffhq1024.pth'
49 | generator = model_v2.StyleGAN2Generator(resolution=args.img_size).to(device)
50 | checkpoint = torch.load(model_path) #map_location='cpu'
51 | if 'generator_smooth' in checkpoint: #default
52 | generator.load_state_dict(checkpoint['generator_smooth'])
53 | else:
54 | generator.load_state_dict(checkpoint['generator'])
55 | synthesis_kwargs = dict(trunc_psi=0.7,trunc_layers=8,randomize_noise=False)
56 | #Gs = generator.synthesis
57 | #Gs.cuda()
58 | #Gm = generator.mapping
59 | #truncation = generator.truncation
60 | const_r = torch.randn(args.batch_size)
61 | const1 = generator.synthesis.early_layer(const_r) #[n,512,4,4]
62 |
63 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) # layer_count: 7->256 8->512 9->1024
64 |
65 | elif type == 3: # PGGAN
66 | #model_path = './checkpoint/PGGAN/pggan_horse256.pth'
67 | generator = model_pggan.PGGANGenerator(resolution=args.img_size).to(device)
68 | checkpoint = torch.load(model_path) #map_location='cpu'
69 |         if 'generator_smooth' in checkpoint: # default branch
70 | generator.load_state_dict(checkpoint['generator_smooth'])
71 | else:
72 | generator.load_state_dict(checkpoint['generator'])
73 | const1 = torch.tensor(0)
74 |
75 | E = BE_PG.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3, pggan=True)
76 |
77 | elif type == 4:
78 | model_path = './checkpoint/biggan/256/G-256.pt'
79 | config_file = './checkpoint/biggan/256/biggan-deep-256-config.json'
80 | config = BigGANConfig.from_json_file(config_file)
81 | generator = BigGAN(config).to(device)
82 | generator.load_state_dict(torch.load(model_path))
83 |
84 | E = BE_BIG.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3, biggan=True).to(device)
85 |
86 | else:
87 | print('error')
88 | return
89 |
90 | #Load E
91 | # E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) .to(device)
92 | if args.checkpoint_dir_E is not None:
93 | E.load_state_dict(torch.load(args.checkpoint_dir_E, map_location=torch.device(device)))
94 |
95 | batch_size = args.batch_size
96 | it_d = 0
97 | for iteration in range(30000,30000+args.iterations):
98 | set_seed(iteration)
99 | z = torch.randn(batch_size, args.z_dim) #[32, 512]
100 | if type == 1:
101 |             with torch.no_grad(): # generate the images and latent variables here
102 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512]
103 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256
104 | elif type == 2:
105 | with torch.no_grad():
106 | #use generator
107 | result_all = generator(z.cuda(), **synthesis_kwargs)
108 | imgs1 = result_all['image']
109 | w1 = result_all['wp']
110 |
111 | elif type == 3:
112 |             with torch.no_grad(): # generate the images and latent variables here
113 | w1 = z.cuda()
114 | result_all = generator(w1)
115 | imgs1 = result_all['image']
116 | elif type == 4:
117 | z = truncated_noise_sample(truncation=0.4, batch_size=batch_size, seed=iteration%30000)
118 |             #label = np.random.randint(1000,size=batch_size) # random class labels
119 | flag = np.array(30)
120 | print(flag)
121 | label = np.ones(batch_size)
122 | label = flag * label
123 | label = one_hot(label)
124 | w1 = torch.tensor(z, dtype=torch.float).cuda()
125 | conditions = torch.tensor(label, dtype=torch.float).cuda() # as label
126 | truncation = torch.tensor(0.4, dtype=torch.float).cuda()
127 |             with torch.no_grad(): # generate the images and latent variables here
128 | imgs1, const1 = generator(w1, conditions, truncation) # const1 are conditional vectors in BigGAN
129 |
130 | if type != 4:
131 | const2,w2 = E(imgs1)
132 | else:
133 | const2,w2 = E(imgs1, const1)
134 |
135 | if type == 1:
136 | imgs2=Gs.forward(w2,int(math.log(args.img_size,2)-2))
137 | elif type == 2 or type == 3:
138 | imgs2=generator.synthesis(w2)['image']
139 | elif type == 4:
140 | imgs2, _=generator(w2, conditions, truncation)
141 | else:
142 | print('model type error')
143 | return
144 |
145 | imgs = torch.cat((imgs1,imgs2))
146 | torchvision.utils.save_image(imgs*0.5+0.5, resultPath1_1+'/id%s_%s.png'%(flag, str(iteration-30000).rjust(5,'0')), nrow=10) # nrow=3
147 | #torchvision.utils.save_image(imgs2*0.5+0.5, resultPath1_2+'/%s_styleganv1_rec.png'%str(iteration-30000).rjust(5,'0')) # nrow=3
148 |
149 | if __name__ == "__main__":
150 | parser = argparse.ArgumentParser(description='the training args')
151 | parser.add_argument('--iterations', type=int, default=5)
152 | parser.add_argument('--seed', type=int, default=30001) # training seeds: 0-30000; validated seeds > 30000
153 | parser.add_argument('--lr', type=float, default=0.0015)
154 | parser.add_argument('--beta_1', type=float, default=0.0)
155 | parser.add_argument('--batch_size', type=int, default=10)
156 | parser.add_argument('--experiment_dir', default=None) #None
157 | parser.add_argument('--checkpoint_dir_GAN', default='./checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth
158 | parser.add_argument('--config_dir', default=None) # BigGAN needs it
159 | parser.add_argument('--checkpoint_dir_E', default='./result/BigGAN-256/models/E_model_ep30000.pth')
160 | parser.add_argument('--img_size',type=int, default=256)
161 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
162 | parser.add_argument('--z_dim', type=int, default=128)
163 | parser.add_argument('--mtype', type=int, default=4) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
164 | parser.add_argument('--start_features', type=int, default=64) # 16->1024 32->512 64->256
165 | args = parser.parse_args()
166 |
167 | if not os.path.exists('./result'): os.mkdir('./result')
168 | resultPath = args.experiment_dir
169 | if resultPath == None:
170 | resultPath = "./result/BigGAN256_inversion_v4"
171 | if not os.path.exists(resultPath): os.mkdir(resultPath)
172 |
173 | resultPath1_1 = resultPath+"/generations"
174 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1)
175 |
176 | resultPath1_2 = resultPath+"/reconstructions"
177 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2)
178 |
179 | writer_path = os.path.join(resultPath, './summaries')
180 | if not os.path.exists(writer_path): os.mkdir(writer_path)
181 | writer = tensorboardX.SummaryWriter(writer_path)
182 |
183 | use_gpu = True
184 | device = torch.device("cuda" if use_gpu else "cpu")
185 |
186 | train(tensor_writer=writer, args = args)
187 |
--------------------------------------------------------------------------------
/synthesized_textBigGAN.py:
--------------------------------------------------------------------------------
1 | # Generate StyleGANv1 and BigGAN synthesized images for validation. Validation seeds must differ from the training seeds (0-30000).
2 |
3 | import os
4 | import math
5 | import torch
6 | import torchvision
7 | import model.E.E as BE # default StyleGANv1 encoder; to reconstruct other GANs, import the corresponding E here.
8 | import lpips
9 | import numpy as np
10 | import argparse
11 | import tensorboardX
12 | from collections import OrderedDict
13 | from model.stylegan1.net import Generator, Mapping #StyleGANv1
14 | import model.stylegan2_generator as model_v2 #StyleGANv2
15 | import model.pggan.pggan_generator as model_pggan #PGGAN
16 | from model.biggan_generator import BigGAN #BigGAN
17 | from training_utils import *
18 | from model.utils.biggan_config import BigGANConfig
19 | import model.E.E_BIG as BE_BIG
20 | import model.E.E_PG as BE_PG # needed for the PGGAN branch (mtype == 3)
21 | def train(tensor_writer = None, args = None):
22 | type = args.mtype
23 |
24 | model_path = args.checkpoint_dir_GAN
25 | config_path = args.config_dir
26 | if type == 1: # StyleGAN1
27 | #model_path = './checkpoint/stylegan_v1/ffhq1024/'
28 | Gs = Generator(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
29 | Gs.load_state_dict(torch.load(model_path+'Gs_dict.pth'))
30 |
31 | Gm = Mapping(num_layers=int(math.log(args.img_size,2)-1)*2, mapping_layers=8, latent_size=512, dlatent_size=512, mapping_fmaps=512) #num_layers: 14->256 / 16->512 / 18->1024
32 | Gm.load_state_dict(torch.load(model_path+'Gm_dict.pth'))
33 |
34 | Gm.buffer1 = torch.load(model_path+'./center_tensor.pt')
35 | const_ = Gs.const
36 | const1 = const_.repeat(args.batch_size,1,1,1).cuda()
37 | layer_num = int(math.log(args.img_size,2)-1)*2 # 14->256 / 16 -> 512 / 18->1024
38 |         layer_idx = torch.arange(layer_num)[np.newaxis, :, np.newaxis] # shape:[1,18,1], layer_idx = [0,1,2,...,17]
39 |         ones = torch.ones(layer_idx.shape, dtype=torch.float32) # shape:[1,18,1], ones = [1,1,...,1]
40 |         coefs = torch.where(layer_idx < layer_num//2, 0.7 * ones, ones) # truncation_psi for the first half of the 18 layers: [0.7,0.7,...,1,1,1]
41 |
42 | Gs.cuda()
43 | Gm.eval()
44 |
45 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3)
46 |
47 | elif type == 2: # StyleGAN2
48 | #model_path = './checkpoint/stylegan_v2/stylegan2_ffhq1024.pth'
49 | generator = model_v2.StyleGAN2Generator(resolution=args.img_size).to(device)
50 | checkpoint = torch.load(model_path) #map_location='cpu'
51 | if 'generator_smooth' in checkpoint: #default
52 | generator.load_state_dict(checkpoint['generator_smooth'])
53 | else:
54 | generator.load_state_dict(checkpoint['generator'])
55 | synthesis_kwargs = dict(trunc_psi=0.7,trunc_layers=8,randomize_noise=False)
56 | #Gs = generator.synthesis
57 | #Gs.cuda()
58 | #Gm = generator.mapping
59 | #truncation = generator.truncation
60 | const_r = torch.randn(args.batch_size)
61 | const1 = generator.synthesis.early_layer(const_r) #[n,512,4,4]
62 |
63 | E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) # layer_count: 7->256 8->512 9->1024
64 |
65 | elif type == 3: # PGGAN
66 | #model_path = './checkpoint/PGGAN/pggan_horse256.pth'
67 | generator = model_pggan.PGGANGenerator(resolution=args.img_size).to(device)
68 | checkpoint = torch.load(model_path) #map_location='cpu'
69 |         if 'generator_smooth' in checkpoint: # default branch
70 | generator.load_state_dict(checkpoint['generator_smooth'])
71 | else:
72 | generator.load_state_dict(checkpoint['generator'])
73 | const1 = torch.tensor(0)
74 |
75 | E = BE_PG.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3, pggan=True)
76 |
77 | elif type == 4:
78 | model_path = './checkpoint/biggan/256/G-256.pt'
79 | config_file = './checkpoint/biggan/256/biggan-deep-256-config.json'
80 | config = BigGANConfig.from_json_file(config_file)
81 | generator = BigGAN(config).to(device)
82 | generator.load_state_dict(torch.load(model_path))
83 |
84 | E = BE_BIG.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3, biggan=True).to(device)
85 |
86 | else:
87 | print('error')
88 | return
89 |
90 | #Load E
91 | # E = BE.BE(startf=args.start_features, maxf=512, layer_count=int(math.log(args.img_size,2)-1), latent_size=512, channels=3) .to(device)
92 | if args.checkpoint_dir_E is not None:
93 | E.load_state_dict(torch.load(args.checkpoint_dir_E, map_location=torch.device(device)))
94 |
95 | batch_size = args.batch_size
96 | it_d = 0
97 | for iteration in range(30000,30000+args.iterations):
98 | set_seed(iteration)
99 | z = torch.randn(batch_size, args.z_dim) #[32, 512]
100 | if type == 1:
101 |             with torch.no_grad(): # generate the images and latent variables here
102 | w1 = Gm(z,coefs_m=coefs).cuda() #[batch_size,18,512]
103 | imgs1 = Gs.forward(w1,int(math.log(args.img_size,2)-2)) # 7->512 / 6->256
104 | elif type == 2:
105 | with torch.no_grad():
106 | #use generator
107 | result_all = generator(z.cuda(), **synthesis_kwargs)
108 | imgs1 = result_all['image']
109 | w1 = result_all['wp']
110 |
111 | elif type == 3:
112 |             with torch.no_grad(): # generate the images and latent variables here
113 | w1 = z.cuda()
114 | result_all = generator(w1)
115 | imgs1 = result_all['image']
116 | elif type == 4:
117 | z = truncated_noise_sample(truncation=0.4, batch_size=batch_size, seed=iteration%30000)
118 |             #label = np.random.randint(1000,size=batch_size) # random class labels
119 | flag = np.array(726)
120 | print(flag)
121 | label = np.ones(batch_size)
122 | label = flag * label
123 | label = one_hot(label)
124 | w1 = torch.tensor(z, dtype=torch.float).cuda()
125 | conditions = torch.tensor(label, dtype=torch.float).cuda() # as label
126 | truncation = torch.tensor(0.4, dtype=torch.float).cuda()
127 |             with torch.no_grad(): # generate the images and latent variables here
128 | imgs1, const1 = generator(w1, conditions, truncation) # const1 are conditional vectors in BigGAN
129 |
130 | if type != 4:
131 | const2,w2 = E(imgs1)
132 | else:
133 | const2,w2 = E(imgs1, const1)
134 |
135 | if type == 1:
136 | imgs2=Gs.forward(w2,int(math.log(args.img_size,2)-2))
137 | elif type == 2 or type == 3:
138 | imgs2=generator.synthesis(w2)['image']
139 | elif type == 4:
140 | imgs2, _=generator(w2, conditions, truncation)
141 | else:
142 | print('model type error')
143 | return
144 |
145 | imgs = torch.cat((imgs1,imgs2))
146 | torchvision.utils.save_image(imgs*0.5+0.5, resultPath1_1+'/id%s_%s.png'%(flag, str(iteration-30000).rjust(5,'0')), nrow=10) # nrow=3
147 | #torchvision.utils.save_image(imgs2*0.5+0.5, resultPath1_2+'/%s_styleganv1_rec.png'%str(iteration-30000).rjust(5,'0')) # nrow=3
148 |
149 | if __name__ == "__main__":
150 | parser = argparse.ArgumentParser(description='the training args')
151 | parser.add_argument('--iterations', type=int, default=3)
152 | parser.add_argument('--seed', type=int, default=30001) # training seeds: 0-30000; validated seeds > 30000
153 | parser.add_argument('--lr', type=float, default=0.0015)
154 | parser.add_argument('--beta_1', type=float, default=0.0)
155 | parser.add_argument('--batch_size', type=int, default=10)
156 | parser.add_argument('--experiment_dir', default=None) #None
157 | parser.add_argument('--checkpoint_dir_GAN', default='./checkpoint/stylegan_v1/ffhq1024/') #None ./checkpoint/stylegan_v1/ffhq1024/ or ./checkpoint/stylegan_v2/stylegan2_ffhq1024.pth
158 | parser.add_argument('--config_dir', default=None) # BigGAN needs it
159 | parser.add_argument('--checkpoint_dir_E', default='./result/BigGAN-256/models/E_model_ep30000.pth')
160 | parser.add_argument('--img_size',type=int, default=256)
161 | parser.add_argument('--img_channels', type=int, default=3)# RGB:3 ,L:1
162 | parser.add_argument('--z_dim', type=int, default=128)
163 | parser.add_argument('--mtype', type=int, default=4) # StyleGANv1=1, StyleGANv2=2, PGGAN=3, BigGAN=4
164 | parser.add_argument('--start_features', type=int, default=64) # 16->1024 32->512 64->256
165 | args = parser.parse_args()
166 |
167 | if not os.path.exists('./result'): os.mkdir('./result')
168 | resultPath = args.experiment_dir
169 | if resultPath == None:
170 | resultPath = "./result/BigGAN256_inversion_v4"
171 | if not os.path.exists(resultPath): os.mkdir(resultPath)
172 |
173 | resultPath1_1 = resultPath+"/generations"
174 | if not os.path.exists(resultPath1_1): os.mkdir(resultPath1_1)
175 |
176 | resultPath1_2 = resultPath+"/reconstructions"
177 | if not os.path.exists(resultPath1_2): os.mkdir(resultPath1_2)
178 |
179 | writer_path = os.path.join(resultPath, './summaries')
180 | if not os.path.exists(writer_path): os.mkdir(writer_path)
181 | writer = tensorboardX.SummaryWriter(writer_path)
182 |
183 | use_gpu = True
184 | device = torch.device("cuda" if use_gpu else "cpu")
185 |
186 | train(tensor_writer=writer, args = args)
187 |
--------------------------------------------------------------------------------
/training_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.stats import truncnorm
4 | import metric.pytorch_ssim as pytorch_ssim
5 | from torch.nn import functional as F
6 | from PIL import Image
7 | import torchvision
8 |
9 | #img_path2tensor
10 | loader = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
11 | def imgPath2loader(image_name,size):
12 | image = Image.open(image_name).convert('RGB')
13 | image = image.resize((size,size))
14 | image = loader(image)#.unsqueeze(0)
15 | return image.to(torch.float)
16 |
17 | def get_parameter_number(net):
18 | total_num = sum(p.numel() for p in net.parameters())
19 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
20 | return {'Total': total_num, 'Trainable': trainable_num}
21 |
22 | def get_para_GByte(parameter_number):
23 |     x=parameter_number['Total']*8/1024/1024/1024
24 |     y=parameter_number['Trainable']*8/1024/1024/1024
25 |     return {'Total_GB': x, 'Trainable_GB': y}
26 |
27 | def one_hot(x, class_count=1000):
28 |     # build a [class_count, class_count] identity matrix,
29 |     # then keep the rows indexed by the labels (cast to long so float/numpy labels also work)
30 |     return torch.eye(class_count)[torch.as_tensor(x, dtype=torch.long), :]
31 |
32 | def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None):
33 | """ Create a truncated noise vector.
34 | Params:
35 | batch_size: batch size.
36 | dim_z: dimension of z
37 | truncation: truncation value to use
38 | seed: seed for the random generator
39 | Output:
40 | array of shape (batch_size, dim_z)
41 | """
42 | state = None if seed is None else np.random.RandomState(seed)
43 | values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
44 | return truncation * values
45 |
46 | def set_seed(seed): # set the random seeds
47 | np.random.seed(seed)
48 | #random.seed(seed)
49 | torch.manual_seed(seed) # cpu
50 | torch.cuda.manual_seed_all(seed) # gpu
51 | torch.backends.cudnn.deterministic = True
52 | #torch.backends.cudnn.benchmark = False
53 |
54 | def space_loss(imgs1,imgs2,image_space=True,lpips_model=None):
55 | loss_mse = torch.nn.MSELoss()
56 | loss_kl = torch.nn.KLDivLoss()
57 | ssim_loss = pytorch_ssim.SSIM()
58 | loss_lpips = lpips_model
59 |
60 | imgs1 = imgs1.contiguous()
61 | imgs2 = imgs2.contiguous()
62 |
63 | loss_imgs_mse_1 = loss_mse(imgs1,imgs2)
64 | loss_imgs_mse_2 = loss_mse(imgs1.mean(),imgs2.mean())
65 | loss_imgs_mse_3 = loss_mse(imgs1.std(),imgs2.std())
66 | loss_imgs_mse = loss_imgs_mse_1 #+ loss_imgs_mse_2 + loss_imgs_mse_3
67 |
68 | imgs1_kl, imgs2_kl = torch.nn.functional.softmax(imgs1),torch.nn.functional.softmax(imgs2)
69 | loss_imgs_kl = loss_kl(torch.log(imgs2_kl),imgs1_kl) #D_kl(True=y1_imgs||Fake=y2_imgs)
70 | loss_imgs_kl = torch.where(torch.isnan(loss_imgs_kl),torch.full_like(loss_imgs_kl,0), loss_imgs_kl)
71 | loss_imgs_kl = torch.where(torch.isinf(loss_imgs_kl),torch.full_like(loss_imgs_kl,1), loss_imgs_kl)
72 |
73 | imgs1_cos = imgs1.view(-1)
74 | imgs2_cos = imgs2.view(-1)
75 |     loss_imgs_cosine = 1 - imgs1_cos.dot(imgs2_cos)/(torch.sqrt(imgs1_cos.dot(imgs1_cos))*torch.sqrt(imgs2_cos.dot(imgs2_cos))) # in [-1,1]; -1: opposite direction, 1: same direction
76 |
77 | if imgs1.view(-1).shape[0] != imgs2.view(-1).shape[0]:
78 |         print('error: vector1 dimensions are not equal to vector2 dimensions')
79 | return
80 |
81 | if image_space:
82 | while imgs1.shape[2] > 256:
83 | imgs1 = F.avg_pool2d(imgs1,2,2)
84 | imgs2 = F.avg_pool2d(imgs2,2,2)
85 |
86 | if image_space:
87 | ssim_value = pytorch_ssim.ssim(imgs1, imgs2) # while ssim_value<0.999:
88 | loss_imgs_ssim = 1-ssim_loss(imgs1, imgs2)
89 | else:
90 | loss_imgs_ssim = torch.tensor(0)
91 |
92 | if image_space:
93 | loss_imgs_lpips = loss_lpips(imgs1,imgs2).mean()
94 | else:
95 | loss_imgs_lpips = torch.tensor(0)
96 |
97 | loss_imgs = 5*loss_imgs_mse + 3*loss_imgs_cosine + loss_imgs_ssim + 2*loss_imgs_lpips # loss_imgs_kl
98 | loss_info = [[loss_imgs_mse_1.item(),loss_imgs_mse_2.item(),loss_imgs_mse_3.item()], loss_imgs_kl.item(), loss_imgs_cosine.item(), loss_imgs_ssim.item(), loss_imgs_lpips.item()]
99 | return loss_imgs, loss_info
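100 | 
101 | if __name__ == "__main__":
102 |     # Minimal self-test sketch (illustration only, not part of the training pipeline):
103 |     # compare two random image batches in [-1, 1] with the combined loss.
104 |     # A perceptual model such as lpips.LPIPS(net='alex') can be passed as lpips_model.
105 |     import lpips
106 |     lpips_model = lpips.LPIPS(net='alex')
107 |     a = torch.rand(2, 3, 256, 256) * 2 - 1
108 |     b = torch.rand(2, 3, 256, 256) * 2 - 1
109 |     loss_imgs, loss_info = space_loss(a, b, image_space=True, lpips_model=lpips_model)
110 |     print(loss_imgs.item(), loss_info)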
--------------------------------------------------------------------------------