├── .gitignore ├── LICENSE ├── ReadMe.md ├── assets ├── ldh.png ├── net.png └── xcaq.gif ├── data ├── CCNLoader.py ├── DataLoader.py └── TTNLoader.py ├── inference.py ├── model ├── Pix2PixModule │ ├── config.py │ ├── loss.py │ ├── model.py │ └── module.py └── styleganModule │ ├── arcface.py │ ├── config.py │ ├── loss.py │ ├── model.py │ ├── op │ ├── __init__.py │ ├── conv2d_gradfix.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu │ └── utils.py ├── run.sh ├── train.py ├── trainer ├── CCNTrainer.py ├── ModelTrainer.py └── TTNTrainer.py └── utils ├── download_weight.sh ├── get_face_expression.py ├── get_tcc_input.py ├── utils.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | pretrain_models 3 | *checkpoint* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 LeslieZhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ReadMe.md: -------------------------------------------------------------------------------- 1 | # DCT-NET.Pytorch 2 | unofficial implementation of DCT-Net: Domain-Calibrated Translation for Portrait Stylization.
3 | you can find the official version [here](https://github.com/menyifang/DCT-Net)
4 | ![](assets/net.png)
5 | 
6 | ## show
7 | ![img](assets/ldh.png)
8 | ![video](assets/xcaq.gif)
9 | 
10 | ## environment
11 | you can build your environment by following [this](https://github.com/rosinality/stylegan2-pytorch)<br>
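a minimal sketch of one environment that has worked for stylegan2-pytorch-based code (the exact versions are assumptions, pick the torch build matching your CUDA; `ninja` is typically needed to JIT-compile the fused CUDA ops, and the bundled `conv2d_gradfix` only takes its fast path on PyTorch 1.7/1.8):
```shell
conda create -n dctnet python=3.8 -y && conda activate dctnet
pip install torch==1.8.1 torchvision==0.9.1
pip install ninja opencv-python pillow numpy
```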
12 | ```pip install tensorboardX ``` for training visualization
13 | 
14 | ## how to run
15 | ### train
16 | download the pretrained weights<br>
17 | 
18 | ```shell
19 | cd utils
20 | bash download_weight.sh
21 | ```
22 | follow [rosinality/stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch) to get 550000.pt and put it in pretrain_models
23 | #### CCN
24 | 1. prepare the style pictures and align them<br>
25 | the image directory should look like this<br>
26 | style-photos/
27 | |-- 000000.png
28 | |-- 000006.png
29 | |-- 000010.png
30 | |-- 000011.png
31 | |-- 000015.png
32 | |-- 000028.png
33 | |-- 000039.png
34 | 2. set your own style-image path in [ccn_config](./model/styleganModule/config.py#L7)
35 | 3. train ccn<br>
36 | 
37 | ```shell
38 | # single gpu
39 | python train.py \
40 | --model ccn \
41 | --batch_size 16 \
42 | --checkpoint_path checkpoint \
43 | --lr 0.002 \
44 | --print_interval 100 \
45 | --save_interval 100 --dist
46 | ```
47 | 
48 | ```shell
49 | # multi gpu
50 | python -m torch.distributed.launch train.py \
51 | --model ccn \
52 | --batch_size 16 \
53 | --checkpoint_path checkpoint \
54 | --lr 0.002 \
55 | --print_interval 100 \
56 | --save_interval 100
57 | ```
58 | after about 1000 steps, you can stop training
59 | #### TTN
60 | 1. prepare expression information<br>
61 | you can follow [LVT](https://github.com/LeslieZhoa/LVT) to estimate facial landmarks<br>
62 | ```shell
63 | cd utils
64 | python get_face_expression.py \
65 | --img_base '' \
66 | --pool_num 2 \
67 | --LVT '' \
68 | --train  # img_base: your real image root, like ffhq; pool_num: multiprocess number; LVT: the path where you put LVT; pass --train for train data, omit it for val data
69 | ```
70 | 2. prepare your generated images<br>
71 | ```shell
72 | cd utils
73 | python get_tcc_input.py \
74 | --model_path '' \
75 | --output_path ''  # model_path: the trained ccn checkpoint; output_path: where to save the generated images
76 | ```
77 | __manually select roughly 5k~10k good images__
78 | 3. set your own paths in [ttn_config](./model/Pix2PixModule/config.py#L21)
79 | ```python
80 | # for example
81 | self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img'
82 | self.train_tgt_root = '/StyleTransform/DATA/select-style-gan'
83 | self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img'
84 | self.val_tgt_root = '/StyleTransform/DATA/select-style-gan'
85 | ```
86 | 4. train ttn
87 | ```shell
88 | # single and multi gpu usage is the same as for ccn
89 | python train.py \
90 | --model ttn \
91 | --batch_size 64 \
92 | --checkpoint_path checkpoint \
93 | --lr 2e-4 \
94 | --print_interval 100 \
95 | --save_interval 100 \
96 | --dist
97 | ```
98 | ## inference
99 | edit inference.py to set your own ttn model path and image path<br>
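a minimal sketch of calling the `Infer` class from inference.py directly (the checkpoint and image paths below are placeholders):
```python
import cv2
from inference import Infer

model = Infer('pretrain_models/final.pth')       # your trained ttn generator checkpoint

img = cv2.imread('test.png')                     # any face image
h, w, _ = img.shape
img = cv2.resize(img, (w // 8 * 8, h // 8 * 8))  # the generator expects sizes divisible by 8

out = model.run(img)                             # stylized BGR image, ready for cv2.imwrite
cv2.imwrite('output.png', out)
```
or edit the paths at the bottom of inference.py and run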
100 | ```python inference.py``` 101 | 102 | ## Credits 103 | SEAN model and implementation:
104 | https://github.com/ZPdesu/SEAN Copyright © 2020, ZPdesu.
105 | License https://github.com/ZPdesu/SEAN/blob/master/LICENSE.md 106 | 107 | stylegan2-pytorch model and implementation:
108 | https://github.com/rosinality/stylegan2-pytorch Copyright © 2019, rosinality.
109 | License https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE 110 | 111 | White-box-Cartoonization model and implementation:
112 | https://github.com/SystemErrorWang/White-box-Cartoonization Copyright © 2020, SystemErrorWang.
113 | 
114 | White-box-Cartoonization PyTorch model and implementation:<br>
115 | https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch Copyright © 2022, vinesmsuic.
116 | License https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch/blob/main/LICENSE
117 | 
118 | arcface PyTorch model and implementation:<br>
119 | https://github.com/ronghuaiyang/arcface-pytorch Copyright © 2018, ronghuaiyang.
120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /assets/ldh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeslieZhoa/DCT-NET.Pytorch/bedafe430de2c92d21cea0d587a78d6cb06292e7/assets/ldh.png -------------------------------------------------------------------------------- /assets/net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeslieZhoa/DCT-NET.Pytorch/bedafe430de2c92d21cea0d587a78d6cb06292e7/assets/net.png -------------------------------------------------------------------------------- /assets/xcaq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeslieZhoa/DCT-NET.Pytorch/bedafe430de2c92d21cea0d587a78d6cb06292e7/assets/xcaq.gif -------------------------------------------------------------------------------- /data/CCNLoader.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author LeslieZhao 5 | @date 20220721 6 | ''' 7 | 8 | import os 9 | 10 | from torchvision import transforms 11 | import PIL.Image as Image 12 | from data.DataLoader import DatasetBase 13 | import random 14 | 15 | 16 | class CCNData(DatasetBase): 17 | def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs): 18 | super().__init__(slice_id, slice_count,dist, **kwargs) 19 | 20 | 21 | self.transform = transforms.Compose([ 22 | transforms.RandomHorizontalFlip(), 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 25 | ]) 26 | 27 | root = kwargs['root'] 28 | self.paths = [os.path.join(root,f) for f in os.listdir(root)] 29 | self.length = len(self.paths) 30 | random.shuffle(self.paths) 31 | 32 | def __getitem__(self,i): 33 | idx = i % self.length 34 | img_path = self.paths[idx] 35 | 36 | with Image.open(img_path) as img: 37 | Img = self.transform(img) 38 | 39 | return Img 40 | 41 | 42 | def __len__(self): 43 | return max(100000,self.length) 44 | # return 4 45 | 46 | -------------------------------------------------------------------------------- /data/DataLoader.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author LeslieZhao 5 | @date 20220721 6 | ''' 7 | 8 | 9 | from torch.utils.data import Dataset 10 | import torch.distributed as dist 11 | 12 | 13 | class DatasetBase(Dataset): 14 | def __init__(self,slice_id=0,slice_count=1,use_dist=False,**kwargs): 15 | 16 | if use_dist: 17 | slice_id = dist.get_rank() 18 | slice_count = dist.get_world_size() 19 | self.id = slice_id 20 | self.count = slice_count 21 | 22 | 23 | def __getitem__(self,i): 24 | pass 25 | 26 | 27 | 28 | 29 | def __len__(self): 30 | return 1000 31 | 32 | -------------------------------------------------------------------------------- /data/TTNLoader.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author LeslieZhao 5 | @date 20220721 6 | ''' 7 | import os 8 | from torchvision import transforms 9 | import PIL.Image as Image 10 | from data.DataLoader import DatasetBase 11 | import random 12 | import numpy as np 13 | import torch 14 | 15 | 16 | class TTNData(DatasetBase): 17 | def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs): 18 | super().__init__(slice_id, slice_count,dist, **kwargs) 19 | 20 | 21 | self.transform = transforms.Compose([ 22 | transforms.Resize([256,256]), 23 | transforms.RandomResizedCrop(256,scale=(0.8,1.2)), 24 | transforms.RandomRotation(degrees=(-90,90)), 25 | transforms.RandomHorizontalFlip(), 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 28 | ]) 29 | 30 | if kwargs['eval']: 31 | self.transform = transforms.Compose([ 32 | transforms.Resize([256,256]), 33 | transforms.ToTensor(), 34 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 35 | self.length = 100 36 | 37 | src_root = kwargs['src_root'] 38 | tgt_root = kwargs['tgt_root'] 39 | 40 | self.src_paths = [os.path.join(src_root,f) for f in os.listdir(src_root) if f.endswith('.png')] 41 | self.tgt_paths = [os.path.join(tgt_root,f) for f in os.listdir(tgt_root) if f.endswith('.png')] 42 | self.src_length = len(self.src_paths) 43 | self.tgt_length = len(self.tgt_paths) 44 | random.shuffle(self.src_paths) 45 | random.shuffle(self.tgt_paths) 46 | 47 | self.mx_left_eye_all,\ 48 | self.mn_left_eye_all,\ 49 | self.mx_right_eye_all,\ 50 | self.mn_right_eye_all,\ 51 | self.mx_lip_all,\ 52 | self.mn_lip_all = \ 53 | np.load(kwargs['score_info']) 54 | 55 | def __getitem__(self,i): 56 | src_idx = i % self.src_length 57 | tgt_idx = i % self.tgt_length 58 | 59 | src_path = self.src_paths[src_idx] 60 | tgt_path = self.tgt_paths[tgt_idx] 61 | exp_path = src_path.replace('img','express')[:-3] + 'npy' 62 | 63 | with Image.open(src_path) as img: 64 | srcImg = self.transform(img) 65 | 66 | with Image.open(tgt_path) as img: 67 | tgtImg = self.transform(img) 68 | 69 | score = np.load(exp_path) 70 | score[0] = (score[0] - self.mn_left_eye_all) / (self.mx_left_eye_all - self.mn_left_eye_all) 71 | score[1] = (score[1] - self.mn_right_eye_all) / (self.mx_right_eye_all - self.mn_right_eye_all) 72 | score[2] = (score[2] - self.mn_lip_all) / (self.mx_lip_all - self.mn_lip_all) 73 | score = torch.from_numpy(score.astype(np.float32)) 74 | 75 | return srcImg,tgtImg,score 76 | 77 | 78 | def __len__(self): 79 | # return max(self.src_length,self.tgt_length) 80 | if hasattr(self,'length'): 81 | return self.length 82 | else: 83 | return 10000 84 | 85 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import torch 5 | from model.Pix2PixModule.model import Generator 6 | from utils.utils import convert_img 7 | 8 | class Infer: 9 | def __init__(self,model_path): 10 | self.net = Generator(img_channels=3) 11 | self.load_checkpoint(model_path) 12 | 13 | 14 | def run(self,img): 15 | if isinstance(img,str): 16 | img = cv2.imread(img) 17 | inp = self.preprocess(img) 18 | with torch.no_grad(): 19 | xg = self.net(inp) 20 | oup = self.postprocess(xg[0]) 21 | return oup 22 | 23 | def load_checkpoint(self,path): 24 | ckpt = torch.load(path, map_location=lambda storage, loc: storage) 25 | self.net.load_state_dict(ckpt['netG'],strict=False) 26 | if 
torch.cuda.is_available(): 27 | self.net.cuda() 28 | self.net.eval() 29 | 30 | def preprocess(self,img): 31 | 32 | img = (img[...,::-1] / 255.0 - 0.5) * 2 33 | img = img.transpose(2,0,1)[np.newaxis,:].astype(np.float32) 34 | img = torch.from_numpy(img) 35 | if torch.cuda.is_available(): 36 | img = img.cuda() 37 | return img 38 | def postprocess(self,img): 39 | img = convert_img(img,unit=True) 40 | return img.permute(1,2,0).cpu().numpy()[...,::-1] 41 | 42 | 43 | 44 | if __name__ == "__main__": 45 | 46 | path = 'pretrain_models/final.pth' 47 | model = Infer(path) 48 | 49 | img = cv2.imread('') 50 | 51 | img_h,img_w,_ = img.shape 52 | n_h,n_w = img_h // 8 * 8,img_w // 8 * 8 53 | img = cv2.resize(img,(n_w,n_h)) 54 | 55 | oup = model.run(img) 56 | cv2.imwrite('output.png',oup) 57 | 58 | 59 | -------------------------------------------------------------------------------- /model/Pix2PixModule/config.py: -------------------------------------------------------------------------------- 1 | class Params: 2 | def __init__(self): 3 | 4 | self.name = 'Pix2Pix' 5 | 6 | self.pretrain_path = None 7 | self.vgg_model = 'pretrain_models/vgg19-dcbb9e9d.pth' 8 | self.lr = 2e-4 9 | self.beta1 = 0.5 10 | self.beta2 = 0.99 11 | 12 | self.use_exp = True 13 | self.lambda_surface = 2.0 14 | self.lambda_texture = 2.0 15 | self.lambda_content = 200 16 | self.lambda_tv = 1e4 17 | 18 | self.lambda_exp = 1.0 19 | 20 | 21 | self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img' 22 | self.train_tgt_root = '/StyleTransform/DATA/select-style-gan' 23 | self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img' 24 | self.val_tgt_root = '/StyleTransform/DATA/select-style-gan' 25 | self.score_info = 'pretrain_models/all_express_mean.npy' 26 | 27 | self.infer_batch_size = 2 28 | 29 | -------------------------------------------------------------------------------- /model/Pix2PixModule/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .module import * 5 | 6 | from torch.autograd import Variable 7 | # Perceptual loss that uses a pretrained VGG network 8 | class VGGLoss(nn.Module): 9 | def __init__(self,model_path): 10 | super(VGGLoss, self).__init__() 11 | self.vgg = VGG19(in_channels=3, 12 | VGGtype='VGG19', 13 | init_weights=model_path, 14 | batch_norm=False, feature_mode=True) 15 | self.vgg.eval() 16 | self.criterion = nn.L1Loss() 17 | 18 | 19 | def forward(self, x, y): 20 | 21 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 22 | _, c, h, w = x_vgg.shape 23 | loss = self.criterion(x_vgg,y_vgg) 24 | return loss *255 / (c*h*w) 25 | 26 | 27 | 28 | 29 | class TVLoss(nn.Module): 30 | def __init__(self, k_size): 31 | super().__init__() 32 | self.k_size = k_size 33 | 34 | def forward(self, image): 35 | b, c, h, w = image.shape 36 | tv_h = torch.mean((image[:, :, self.k_size:, :] - image[:, :, : -self.k_size, :])**2) 37 | tv_w = torch.mean((image[:, :, :, self.k_size:] - image[:, :, :, : -self.k_size])**2) 38 | tv_loss = (tv_h + tv_w) / (3 * h * w) 39 | return tv_loss.mean() 40 | 41 | 42 | -------------------------------------------------------------------------------- /model/Pix2PixModule/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class ResidualBlock(nn.Module): 6 | def __init__(self, channels, kernel_size, stride, padding, padding_mode): 7 | super().__init__() 8 | self.block = 
nn.Sequential( 9 | nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode), 10 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 11 | nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode), 12 | ) 13 | 14 | def forward(self, x): 15 | #Elementwise Sum (ES) 16 | return x + self.block(x) 17 | 18 | class Generator(nn.Module): 19 | def __init__(self, img_channels=3, num_features=32, num_residuals=4, padding_mode="zeros"): 20 | super().__init__() 21 | self.padding_mode = padding_mode 22 | 23 | self.initial_down = nn.Sequential( 24 | #k7n32s1 25 | nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode=self.padding_mode), 26 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 27 | ) 28 | 29 | #Down-convolution 30 | self.down1 = nn.Sequential( 31 | #k3n32s2 32 | nn.Conv2d(num_features, num_features, kernel_size=3, stride=2, padding=1, padding_mode=self.padding_mode), 33 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 34 | 35 | #k3n64s1 36 | nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode), 37 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 38 | ) 39 | 40 | self.down2 = nn.Sequential( 41 | #k3n64s2 42 | nn.Conv2d(num_features*2, num_features*2, kernel_size=3, stride=2, padding=1, padding_mode=self.padding_mode), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | 45 | #k3n128s1 46 | nn.Conv2d(num_features*2, num_features*4, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode), 47 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 48 | ) 49 | 50 | #Bottleneck: 4 residual blocks => 4 times [K3n128s1] 51 | self.res_blocks = nn.Sequential( 52 | *[ResidualBlock(num_features*4, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode) for _ in range(num_residuals)] 53 | ) 54 | 55 | #Up-convolution 56 | self.up1 = nn.Sequential( 57 | #k3n128s1 (should be k3n64s1?) 58 | nn.Conv2d(num_features*4, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode), 59 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 60 | ) 61 | 62 | self.up2 = nn.Sequential( 63 | #k3n64s1 64 | nn.Conv2d(num_features*2, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode), 65 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 66 | #k3n64s1 (should be k3n32s1?) 
67 | nn.Conv2d(num_features*2, num_features, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode), 68 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 69 | ) 70 | 71 | self.last = nn.Sequential( 72 | #k3n32s1 73 | nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode), 74 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 75 | #k7n3s1 76 | nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode=self.padding_mode) 77 | ) 78 | 79 | def forward(self, x): 80 | x1 = self.initial_down(x) 81 | x2 = self.down1(x1) 82 | x = self.down2(x2) 83 | x = self.res_blocks(x) 84 | x = self.up1(x) 85 | #Resize Bilinear 86 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners = False) 87 | x = self.up2(x + x2) 88 | #Resize Bilinear 89 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners = False) 90 | x = self.last(x + x1) 91 | #TanH 92 | return torch.tanh(x) 93 | 94 | 95 | 96 | from torch.nn.utils import spectral_norm 97 | 98 | # PyTorch implementation by vinesmsuic 99 | # Referenced from official tensorflow implementation: https://github.com/SystemErrorWang/White-box-Cartoonization/blob/master/train_code/network.py 100 | # slim.convolution2d uses constant padding (zeros). 101 | # Paper used spectral_norm 102 | 103 | class Block(nn.Module): 104 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding,activate=True): 105 | super().__init__() 106 | self.sn_conv = spectral_norm(nn.Conv2d( 107 | in_channels, 108 | out_channels, 109 | kernel_size, 110 | stride, 111 | padding, 112 | padding_mode="zeros" # Author's code used slim.convolution2d, which is using SAME padding (zero padding in pytorch) 113 | )) 114 | self.activate = activate 115 | if self.activate: 116 | self.LReLU = nn.LeakyReLU(negative_slope=0.2, inplace=True) 117 | 118 | def forward(self, x): 119 | x = self.sn_conv(x) 120 | if self.activate: 121 | x = self.LReLU(x) 122 | 123 | return x 124 | 125 | 126 | class Discriminator(nn.Module): 127 | def __init__(self, in_channels=3, out_channels=1, features=[32, 64, 128]): 128 | super().__init__() 129 | 130 | self.model = nn.Sequential( 131 | #k3n32s2 132 | Block(in_channels, features[0], kernel_size=3, stride=2, padding=1), 133 | #k3n32s1 134 | Block(features[0], features[0], kernel_size=3, stride=1, padding=1), 135 | 136 | #k3n64s2 137 | Block(features[0], features[1], kernel_size=3, stride=2, padding=1), 138 | #k3n64s1 139 | Block(features[1], features[1], kernel_size=3, stride=1, padding=1), 140 | 141 | #k3n128s2 142 | Block(features[1], features[2], kernel_size=3, stride=2, padding=1), 143 | #k3n128s1 144 | Block(features[2], features[2], kernel_size=3, stride=1, padding=1), 145 | 146 | #k1n1s1 147 | Block(features[2], out_channels, kernel_size=1, stride=1, padding=0) 148 | ) 149 | 150 | def forward(self, x): 151 | x = self.model(x) 152 | 153 | return x 154 | 155 | 156 | class ExpressDetector(nn.Module): 157 | def __init__(self, in_channels=3, out_channels=3, features=[32, 64, 128,512]): 158 | super().__init__() 159 | 160 | self.model = nn.Sequential( 161 | #k3n32s2 162 | Block(in_channels, features[0], kernel_size=3, stride=2, padding=1), 163 | #k3n32s1 164 | Block(features[0], features[0], kernel_size=3, stride=1, padding=1), 165 | 166 | #k3n64s2 167 | Block(features[0], features[1], kernel_size=3, stride=2, padding=1), 168 | #k3n64s1 169 | Block(features[1], features[1], kernel_size=3, stride=1, padding=1), 170 | 171 | #k3n128s2 172 | Block(features[1], 
features[2], kernel_size=3, stride=2, padding=1), 173 | #k3n128s1 174 | Block(features[2], features[2], kernel_size=3, stride=1, padding=1), 175 | nn.AdaptiveAvgPool2d(1), 176 | nn.Flatten(), 177 | nn.Linear(features[2],features[3]), 178 | nn.Linear(features[3],out_channels) 179 | 180 | ) 181 | 182 | def forward(self, x): 183 | x = self.model(x) 184 | 185 | return x 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /model/Pix2PixModule/module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | import torch 4 | 5 | # refer to https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch 6 | # VGG architecter, used for the perceptual loss using a pretrained VGG network 7 | VGG_types = { 8 | "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 9 | "VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 10 | "VGG16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], 11 | "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], 12 | } 13 | class VGG19(torch.nn.Module): 14 | def __init__(self, in_channels=3, 15 | VGGtype="VGG19", 16 | init_weights=None, 17 | batch_norm=False, 18 | num_classes=1000, 19 | feature_mode=False, 20 | requires_grad=False): 21 | super(VGG19, self).__init__() 22 | self.in_channels = in_channels 23 | self.feature_mode = feature_mode 24 | self.batch_norm = batch_norm 25 | 26 | self.features = self.create_conv_layers(VGG_types[VGGtype]) 27 | 28 | self.classifier = nn.Sequential( 29 | nn.Linear(512 * 7 * 7, 4096), 30 | nn.ReLU(), 31 | nn.Dropout(p=0.5), 32 | nn.Linear(4096, 4096), 33 | nn.ReLU(), 34 | nn.Dropout(p=0.5), 35 | nn.Linear(4096, num_classes), 36 | ) 37 | 38 | if init_weights is not None: 39 | self.load_state_dict(torch.load(init_weights)) 40 | 41 | if not requires_grad: 42 | for param in self.parameters(): 43 | param.requires_grad = False 44 | 45 | def forward(self, x): 46 | if not self.feature_mode: 47 | x = self.features(x) 48 | x = x.view(x.size(0), -1) 49 | x = self.classifier(x) 50 | 51 | elif self.feature_mode == True and self.batch_norm == False: 52 | module_list = list(self.features.modules()) 53 | #print(module_list[1:27]) 54 | for layer in module_list[1:27]: # conv4_4 Feature maps 55 | x = layer(x) 56 | else: 57 | raise ValueError('Feature mode does not work with batch norm enabled. 
Set batch_norm=False and try again.') 58 | 59 | return x 60 | 61 | def create_conv_layers(self, architecture): 62 | layers = [] 63 | in_channels = self.in_channels 64 | batch_norm = self.batch_norm 65 | 66 | for x in architecture: 67 | if type(x) == int: # Number of features 68 | out_channels = x 69 | 70 | layers += [ 71 | nn.Conv2d( 72 | in_channels=in_channels, 73 | out_channels=out_channels, 74 | kernel_size=3, 75 | stride=1, 76 | padding=1, 77 | ), 78 | ] 79 | 80 | if batch_norm == True: 81 | # Back at that time Batch Norm was not invented 82 | layers += [nn.BatchNorm2d(x),nn.ReLU(),] 83 | else: 84 | layers += [nn.ReLU()] 85 | 86 | in_channels = x #update in_channel 87 | 88 | elif x == "M": # Maxpooling 89 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 90 | 91 | return nn.Sequential(*layers) 92 | 93 | 94 | 95 | # refer to https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch 96 | def box_filter( x, r): 97 | channel = x.shape[1] # Batch, Channel, H, W 98 | kernel_size = (2*r+1) 99 | weight = 1.0/(kernel_size**2) 100 | box_kernel = weight*torch.ones((channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device) 101 | output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel) #tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME') 102 | 103 | return output 104 | 105 | 106 | def guided_filter(x, y, r, eps=1e-2): 107 | # Batch, Channel, H, W 108 | _, _, H, W = x.shape 109 | 110 | N = box_filter(torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), r) 111 | 112 | mean_x = box_filter(x, r) / N 113 | mean_y = box_filter(y, r) / N 114 | cov_xy = box_filter(x * y, r) / N - mean_x * mean_y 115 | var_x = box_filter(x * x, r) / N - mean_x * mean_x 116 | 117 | A = cov_xy / (var_x + eps) 118 | b = mean_y - A * mean_x 119 | 120 | mean_A = box_filter(A, r) / N 121 | mean_b = box_filter(b, r) / N 122 | 123 | output = mean_A * x + mean_b 124 | return output 125 | 126 | # refer to https://github.com/SystemErrorWang/White-box-Cartoonization 127 | def color_shift(image, mode='uniform'): 128 | device = image.device 129 | r1, b1, g1 = torch.split(image, 1, dim=1) 130 | 131 | if mode == 'normal': 132 | b_weight = torch.normal(mean=0.114, std=0.1,size=[1]).to(device) 133 | g_weight = torch.normal(mean=0.587, std=0.1,size=[1]).to(device) 134 | r_weight = torch.normal(mean=0.299, std=0.1,size=[1]).to(device) 135 | elif mode == 'uniform': 136 | 137 | b_weight = torch.FloatTensor(1).uniform_(0.014,0.214).to(device) 138 | g_weight = torch.FloatTensor(1).uniform_(0.487, 0.687).to(device) 139 | r_weight = torch.FloatTensor(1).uniform_(0.199, 0.399).to(device) 140 | output1 = (b_weight*b1+g_weight*g1+r_weight*r1)/(b_weight+g_weight+r_weight) 141 | 142 | return output1 143 | 144 | 145 | -------------------------------------------------------------------------------- /model/styleganModule/arcface.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 2 | import torch 3 | from collections import namedtuple 4 | 5 | 6 | # ------------------------arcface--------------------------------------- 7 | class Backbone(Module): 8 | def __init__(self, num_layers, drop_ratio, mode='ir'): 9 | super(Backbone, self).__init__() 10 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 11 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 12 | blocks = get_blocks(num_layers) 13 
| if mode == 'ir': 14 | unit_module = bottleneck_IR 15 | elif mode == 'ir_se': 16 | unit_module = bottleneck_IR_SE 17 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1 ,bias=False), 18 | BatchNorm2d(64), 19 | PReLU(64)) 20 | self.output_layer = Sequential(BatchNorm2d(512), 21 | Dropout(drop_ratio), 22 | Flatten(), 23 | Linear(512 * 7 * 7, 512), 24 | BatchNorm1d(512)) 25 | # ) 26 | modules = [] 27 | for block in blocks: 28 | for bottleneck in block: 29 | modules.append( 30 | unit_module(bottleneck.in_channel, 31 | bottleneck.depth, 32 | bottleneck.stride)) 33 | self.body = Sequential(*modules) 34 | 35 | def forward(self,x): 36 | 37 | feats = [] 38 | x = self.input_layer(x) 39 | for m in self.body.children(): 40 | x = m(x) 41 | feats.append(x) 42 | # x = self.body(x) 43 | x = self.output_layer(x) 44 | return l2_norm(x), feats 45 | 46 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 47 | '''A named tuple describing a ResNet block.''' 48 | 49 | class bottleneck_IR(Module): 50 | def __init__(self, in_channel, depth, stride): 51 | super(bottleneck_IR, self).__init__() 52 | if in_channel == depth: 53 | self.shortcut_layer = MaxPool2d(1, stride) 54 | else: 55 | self.shortcut_layer = Sequential( 56 | Conv2d(in_channel, depth, (1, 1), stride ,bias=False), BatchNorm2d(depth)) 57 | self.res_layer = Sequential( 58 | BatchNorm2d(in_channel), 59 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1 ,bias=False), PReLU(depth), 60 | Conv2d(depth, depth, (3, 3), stride, 1 ,bias=False), BatchNorm2d(depth)) 61 | 62 | def forward(self, x): 63 | shortcut = self.shortcut_layer(x) 64 | res = self.res_layer(x) 65 | return res + shortcut 66 | 67 | class bottleneck_IR_SE(Module): 68 | def __init__(self, in_channel, depth, stride): 69 | super(bottleneck_IR_SE, self).__init__() 70 | if in_channel == depth: 71 | self.shortcut_layer = MaxPool2d(1, stride) 72 | else: 73 | self.shortcut_layer = Sequential( 74 | Conv2d(in_channel, depth, (1, 1), stride ,bias=False), 75 | BatchNorm2d(depth)) 76 | self.res_layer = Sequential( 77 | BatchNorm2d(in_channel), 78 | Conv2d(in_channel, depth, (3,3), (1,1),1 ,bias=False), 79 | PReLU(depth), 80 | Conv2d(depth, depth, (3,3), stride, 1 ,bias=False), 81 | BatchNorm2d(depth), 82 | SEModule(depth,16) 83 | ) 84 | def forward(self,x): 85 | shortcut = self.shortcut_layer(x) 86 | res = self.res_layer(x) 87 | return res + shortcut 88 | 89 | def l2_norm(input,axis=1): 90 | norm = torch.norm(input,2,axis,True) 91 | output = torch.div(input, norm) 92 | return output 93 | 94 | class SEModule(Module): 95 | def __init__(self, channels, reduction): 96 | super(SEModule, self).__init__() 97 | self.avg_pool = AdaptiveAvgPool2d(1) 98 | self.fc1 = Conv2d( 99 | channels, channels // reduction, kernel_size=1, padding=0 ,bias=False) 100 | self.relu = ReLU(inplace=True) 101 | self.fc2 = Conv2d( 102 | channels // reduction, channels, kernel_size=1, padding=0 ,bias=False) 103 | self.sigmoid = Sigmoid() 104 | 105 | def forward(self, x): 106 | module_input = x 107 | x = self.avg_pool(x) 108 | x = self.fc1(x) 109 | x = self.relu(x) 110 | x = self.fc2(x) 111 | x = self.sigmoid(x) 112 | return module_input * x 113 | 114 | class Flatten(Module): 115 | def forward(self, input): 116 | return input.view(input.size(0), -1) 117 | 118 | def get_block(in_channel, depth, num_units, stride = 2): 119 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units-1)] 120 | 121 | def get_blocks(num_layers): 122 | if num_layers == 50: 123 | blocks = [ 124 | 
get_block(in_channel=64, depth=64, num_units = 3), 125 | get_block(in_channel=64, depth=128, num_units=4), 126 | get_block(in_channel=128, depth=256, num_units=14), 127 | get_block(in_channel=256, depth=512, num_units=3) 128 | ] 129 | elif num_layers == 100: 130 | blocks = [ 131 | get_block(in_channel=64, depth=64, num_units=3), 132 | get_block(in_channel=64, depth=128, num_units=13), 133 | get_block(in_channel=128, depth=256, num_units=30), 134 | get_block(in_channel=256, depth=512, num_units=3) 135 | ] 136 | elif num_layers == 152: 137 | blocks = [ 138 | get_block(in_channel=64, depth=64, num_units=3), 139 | get_block(in_channel=64, depth=128, num_units=8), 140 | get_block(in_channel=128, depth=256, num_units=36), 141 | get_block(in_channel=256, depth=512, num_units=3) 142 | ] 143 | return blocks 144 | 145 | -------------------------------------------------------------------------------- /model/styleganModule/config.py: -------------------------------------------------------------------------------- 1 | class Params: 2 | def __init__(self): 3 | 4 | self.name = 'StyleGAN' 5 | self.size = 256 6 | self.stylegan_path = 'pretrain_models/550000.pt' 7 | self.root = '' 8 | self.id_model = 'pretrain_models/model_ir_se50.pth' 9 | self.g_reg_every = 4 10 | self.d_reg_every = 16 11 | self.D_steps_pre_G = 1 12 | self.latent = 512 13 | self.n_mlp =8 14 | self.channel_multiplier =2 15 | self.size =256 16 | self.mixing =0.9 17 | self.inject_index =4 18 | self.n_sample =8 19 | 20 | self.lambda_gan =1.0 21 | self.lambda_id =1.0 22 | self.path_regularize =2.0 23 | self.path_batch_shrink = 2 24 | self.r1 =10.0 25 | 26 | self.interval_steps = 100 27 | self.interval_train = False 28 | 29 | self.infer_batch_size = 1 30 | self.mx_gen_iters = 20000 31 | -------------------------------------------------------------------------------- /model/styleganModule/loss.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from torch import autograd 3 | from torch import nn 4 | import torch 5 | from .arcface import Backbone 6 | import math 7 | from .op import conv2d_gradfix 8 | 9 | class IDLoss(nn.Module): 10 | def __init__(self,pretrain_model, requires_grad=False): 11 | super(IDLoss, self).__init__() 12 | self.idModel = Backbone(50,0.6,'ir_se') 13 | self.idModel.load_state_dict(torch.load(pretrain_model),strict=False) 14 | self.idModel.eval() 15 | self.criterion = nn.CosineSimilarity(dim=1,eps=1e-6) 16 | self.id_size = 112 17 | if not requires_grad: 18 | for param in self.parameters(): 19 | param.requires_grad = False 20 | 21 | 22 | 23 | def forward(self, x, y): 24 | x_id, _ = self.idModel(F.interpolate(x[:,:,28:228,28:228],[self.id_size, self.id_size], mode='bilinear')) 25 | y_id,_ = self.idModel(F.interpolate(y[:,:,28:228,28:228], 26 | [self.id_size, self.id_size], mode='bilinear')) 27 | loss = 1 - self.criterion(x_id,y_id) 28 | return loss.mean() 29 | 30 | 31 | def g_nonsaturating_loss(fake_pred): 32 | loss = F.softplus(-fake_pred).mean() 33 | 34 | return loss 35 | 36 | 37 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 38 | noise = torch.randn_like(fake_img) / math.sqrt( 39 | fake_img.shape[2] * fake_img.shape[3] 40 | ) 41 | grad, = autograd.grad( 42 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 43 | ) 44 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 45 | 46 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 47 | 48 | path_penalty = (path_lengths - 
path_mean).pow(2).mean() 49 | 50 | return path_penalty, path_mean.detach(), path_lengths 51 | 52 | def d_r1_loss(real_pred, real_img): 53 | with conv2d_gradfix.no_weight_gradients(): 54 | grad_real, = autograd.grad( 55 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 56 | ) 57 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 58 | 59 | return grad_penalty 60 | 61 | def d_logistic_loss(real_pred, fake_pred): 62 | real_loss = F.softplus(-real_pred) 63 | fake_loss = F.softplus(fake_pred) 64 | 65 | return real_loss.mean() + fake_loss.mean() -------------------------------------------------------------------------------- /model/styleganModule/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 10 | 11 | 12 | class PixelNorm(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def forward(self, input): 17 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 18 | 19 | 20 | def make_kernel(k): 21 | k = torch.tensor(k, dtype=torch.float32) 22 | 23 | if k.ndim == 1: 24 | k = k[None, :] * k[:, None] 25 | 26 | k /= k.sum() 27 | 28 | return k 29 | 30 | 31 | class Upsample(nn.Module): 32 | def __init__(self, kernel, factor=2): 33 | super().__init__() 34 | 35 | self.factor = factor 36 | kernel = make_kernel(kernel) * (factor ** 2) 37 | self.register_buffer("kernel", kernel) 38 | 39 | p = kernel.shape[0] - factor 40 | 41 | pad0 = (p + 1) // 2 + factor - 1 42 | pad1 = p // 2 43 | 44 | self.pad = (pad0, pad1) 45 | 46 | def forward(self, input): 47 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 48 | 49 | return out 50 | 51 | 52 | class Downsample(nn.Module): 53 | def __init__(self, kernel, factor=2): 54 | super().__init__() 55 | 56 | self.factor = factor 57 | kernel = make_kernel(kernel) 58 | self.register_buffer("kernel", kernel) 59 | 60 | p = kernel.shape[0] - factor 61 | 62 | pad0 = (p + 1) // 2 63 | pad1 = p // 2 64 | 65 | self.pad = (pad0, pad1) 66 | 67 | def forward(self, input): 68 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 69 | 70 | return out 71 | 72 | 73 | class Blur(nn.Module): 74 | def __init__(self, kernel, pad, upsample_factor=1): 75 | super().__init__() 76 | 77 | kernel = make_kernel(kernel) 78 | 79 | if upsample_factor > 1: 80 | kernel = kernel * (upsample_factor ** 2) 81 | 82 | self.register_buffer("kernel", kernel) 83 | 84 | self.pad = pad 85 | 86 | def forward(self, input): 87 | out = upfirdn2d(input, self.kernel, pad=self.pad) 88 | 89 | return out 90 | 91 | 92 | class EqualConv2d(nn.Module): 93 | def __init__( 94 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 95 | ): 96 | super().__init__() 97 | 98 | self.weight = nn.Parameter( 99 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 100 | ) 101 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 102 | 103 | self.stride = stride 104 | self.padding = padding 105 | 106 | if bias: 107 | self.bias = nn.Parameter(torch.zeros(out_channel)) 108 | 109 | else: 110 | self.bias = None 111 | 112 | def forward(self, input): 113 | out = conv2d_gradfix.conv2d( 114 | input, 115 | self.weight * self.scale, 116 | bias=self.bias, 117 | stride=self.stride, 118 | padding=self.padding, 119 | ) 120 | 121 | return out 122 
| 123 | def __repr__(self): 124 | return ( 125 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 126 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 127 | ) 128 | 129 | 130 | class EqualLinear(nn.Module): 131 | def __init__( 132 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 133 | ): 134 | super().__init__() 135 | 136 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 137 | 138 | if bias: 139 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 140 | 141 | else: 142 | self.bias = None 143 | 144 | self.activation = activation 145 | 146 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 147 | self.lr_mul = lr_mul 148 | 149 | def forward(self, input): 150 | if self.activation: 151 | out = F.linear(input, self.weight * self.scale) 152 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 153 | 154 | else: 155 | out = F.linear( 156 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 157 | ) 158 | 159 | return out 160 | 161 | def __repr__(self): 162 | return ( 163 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 164 | ) 165 | 166 | 167 | class ModulatedConv2d(nn.Module): 168 | def __init__( 169 | self, 170 | in_channel, 171 | out_channel, 172 | kernel_size, 173 | style_dim, 174 | demodulate=True, 175 | upsample=False, 176 | downsample=False, 177 | blur_kernel=[1, 3, 3, 1], 178 | fused=True, 179 | ): 180 | super().__init__() 181 | 182 | self.eps = 1e-8 183 | self.kernel_size = kernel_size 184 | self.in_channel = in_channel 185 | self.out_channel = out_channel 186 | self.upsample = upsample 187 | self.downsample = downsample 188 | 189 | if upsample: 190 | factor = 2 191 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 192 | pad0 = (p + 1) // 2 + factor - 1 193 | pad1 = p // 2 + 1 194 | 195 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 196 | 197 | if downsample: 198 | factor = 2 199 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 200 | pad0 = (p + 1) // 2 201 | pad1 = p // 2 202 | 203 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 204 | 205 | fan_in = in_channel * kernel_size ** 2 206 | self.scale = 1 / math.sqrt(fan_in) 207 | self.padding = kernel_size // 2 208 | 209 | self.weight = nn.Parameter( 210 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 211 | ) 212 | 213 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 214 | 215 | self.demodulate = demodulate 216 | self.fused = fused 217 | 218 | def __repr__(self): 219 | return ( 220 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 221 | f"upsample={self.upsample}, downsample={self.downsample})" 222 | ) 223 | 224 | def forward(self, input, style): 225 | batch, in_channel, height, width = input.shape 226 | 227 | if not self.fused: 228 | weight = self.scale * self.weight.squeeze(0) 229 | style = self.modulation(style) 230 | 231 | if self.demodulate: 232 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1) 233 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() 234 | 235 | input = input * style.reshape(batch, in_channel, 1, 1) 236 | 237 | if self.upsample: 238 | weight = weight.transpose(0, 1) 239 | out = conv2d_gradfix.conv_transpose2d( 240 | input, weight, padding=0, stride=2 241 | ) 242 | out = self.blur(out) 243 | 244 | elif self.downsample: 245 | input = self.blur(input) 246 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) 247 | 248 | else: 
249 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding) 250 | 251 | if self.demodulate: 252 | out = out * dcoefs.view(batch, -1, 1, 1) 253 | 254 | return out 255 | 256 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 257 | weight = self.scale * self.weight * style 258 | 259 | if self.demodulate: 260 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 261 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 262 | 263 | weight = weight.view( 264 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 265 | ) 266 | 267 | if self.upsample: 268 | input = input.view(1, batch * in_channel, height, width) 269 | weight = weight.view( 270 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 271 | ) 272 | weight = weight.transpose(1, 2).reshape( 273 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 274 | ) 275 | out = conv2d_gradfix.conv_transpose2d( 276 | input, weight, padding=0, stride=2, groups=batch 277 | ) 278 | _, _, height, width = out.shape 279 | out = out.view(batch, self.out_channel, height, width) 280 | out = self.blur(out) 281 | 282 | elif self.downsample: 283 | input = self.blur(input) 284 | _, _, height, width = input.shape 285 | input = input.view(1, batch * in_channel, height, width) 286 | out = conv2d_gradfix.conv2d( 287 | input, weight, padding=0, stride=2, groups=batch 288 | ) 289 | _, _, height, width = out.shape 290 | out = out.view(batch, self.out_channel, height, width) 291 | 292 | else: 293 | input = input.view(1, batch * in_channel, height, width) 294 | out = conv2d_gradfix.conv2d( 295 | input, weight, padding=self.padding, groups=batch 296 | ) 297 | _, _, height, width = out.shape 298 | out = out.view(batch, self.out_channel, height, width) 299 | 300 | return out 301 | 302 | 303 | class NoiseInjection(nn.Module): 304 | def __init__(self): 305 | super().__init__() 306 | 307 | self.weight = nn.Parameter(torch.zeros(1)) 308 | 309 | def forward(self, image, noise=None): 310 | if noise is None: 311 | batch, _, height, width = image.shape 312 | noise = image.new_empty(batch, 1, height, width).normal_() 313 | 314 | return image + self.weight * noise 315 | 316 | 317 | class ConstantInput(nn.Module): 318 | def __init__(self, channel, size=4): 319 | super().__init__() 320 | 321 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 322 | 323 | def forward(self, input): 324 | batch = input.shape[0] 325 | out = self.input.repeat(batch, 1, 1, 1) 326 | 327 | return out 328 | 329 | 330 | class StyledConv(nn.Module): 331 | def __init__( 332 | self, 333 | in_channel, 334 | out_channel, 335 | kernel_size, 336 | style_dim, 337 | upsample=False, 338 | blur_kernel=[1, 3, 3, 1], 339 | demodulate=True, 340 | ): 341 | super().__init__() 342 | 343 | self.conv = ModulatedConv2d( 344 | in_channel, 345 | out_channel, 346 | kernel_size, 347 | style_dim, 348 | upsample=upsample, 349 | blur_kernel=blur_kernel, 350 | demodulate=demodulate, 351 | ) 352 | 353 | self.noise = NoiseInjection() 354 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 355 | # self.activate = ScaledLeakyReLU(0.2) 356 | self.activate = FusedLeakyReLU(out_channel) 357 | 358 | def forward(self, input, style, noise=None): 359 | out = self.conv(input, style) 360 | out = self.noise(out, noise=noise) 361 | # out = out + self.bias 362 | out = self.activate(out) 363 | 364 | return out 365 | 366 | 367 | class ToRGB(nn.Module): 368 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 
3, 3, 1]): 369 | super().__init__() 370 | 371 | if upsample: 372 | self.upsample = Upsample(blur_kernel) 373 | 374 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 375 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 376 | 377 | def forward(self, input, style, skip=None): 378 | out = self.conv(input, style) 379 | out = out + self.bias 380 | 381 | if skip is not None: 382 | skip = self.upsample(skip) 383 | 384 | out = out + skip 385 | 386 | return out 387 | 388 | 389 | class Generator(nn.Module): 390 | def __init__( 391 | self, 392 | size, 393 | style_dim, 394 | n_mlp, 395 | channel_multiplier=2, 396 | blur_kernel=[1, 3, 3, 1], 397 | lr_mlp=0.01, 398 | ): 399 | super().__init__() 400 | 401 | self.size = size 402 | 403 | self.style_dim = style_dim 404 | 405 | layers = [PixelNorm()] 406 | 407 | for i in range(n_mlp): 408 | layers.append( 409 | EqualLinear( 410 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 411 | ) 412 | ) 413 | 414 | self.style = nn.Sequential(*layers) 415 | 416 | self.channels = { 417 | 4: 512, 418 | 8: 512, 419 | 16: 512, 420 | 32: 512, 421 | 64: 256 * channel_multiplier, 422 | 128: 128 * channel_multiplier, 423 | 256: 64 * channel_multiplier, 424 | 512: 32 * channel_multiplier, 425 | 1024: 16 * channel_multiplier, 426 | } 427 | 428 | self.input = ConstantInput(self.channels[4]) 429 | self.conv1 = StyledConv( 430 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 431 | ) 432 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 433 | 434 | self.log_size = int(math.log(size, 2)) 435 | self.num_layers = (self.log_size - 2) * 2 + 1 436 | 437 | self.convs = nn.ModuleList() 438 | self.upsamples = nn.ModuleList() 439 | self.to_rgbs = nn.ModuleList() 440 | self.noises = nn.Module() 441 | 442 | in_channel = self.channels[4] 443 | 444 | for layer_idx in range(self.num_layers): 445 | res = (layer_idx + 5) // 2 446 | shape = [1, 1, 2 ** res, 2 ** res] 447 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 448 | 449 | for i in range(3, self.log_size + 1): 450 | out_channel = self.channels[2 ** i] 451 | 452 | self.convs.append( 453 | StyledConv( 454 | in_channel, 455 | out_channel, 456 | 3, 457 | style_dim, 458 | upsample=True, 459 | blur_kernel=blur_kernel, 460 | ) 461 | ) 462 | 463 | self.convs.append( 464 | StyledConv( 465 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 466 | ) 467 | ) 468 | 469 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 470 | 471 | in_channel = out_channel 472 | 473 | self.n_latent = self.log_size * 2 - 2 474 | 475 | def make_noise(self): 476 | device = self.input.input.device 477 | 478 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 479 | 480 | for i in range(3, self.log_size + 1): 481 | for _ in range(2): 482 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 483 | 484 | return noises 485 | 486 | def mean_latent(self, n_latent): 487 | latent_in = torch.randn( 488 | n_latent, self.style_dim, device=self.input.input.device 489 | ) 490 | latent = self.style(latent_in).mean(0, keepdim=True) 491 | 492 | return latent 493 | 494 | def get_latent(self, input): 495 | return self.style(input) 496 | 497 | def forward( 498 | self, 499 | styles, 500 | return_latents=False, 501 | inject_index=None, 502 | truncation=1, 503 | truncation_latent=None, 504 | input_is_latent=False, 505 | noise=None, 506 | randomize_noise=True, 507 | only_latent=False 508 | ): 509 | if not input_is_latent: 510 | styles = [self.style(s) 
for s in styles] 511 | 512 | if noise is None: 513 | if randomize_noise: 514 | noise = [None] * self.num_layers 515 | else: 516 | noise = [ 517 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 518 | ] 519 | 520 | if truncation < 1: 521 | style_t = [] 522 | 523 | for style in styles: 524 | style_t.append( 525 | truncation_latent + truncation * (style - truncation_latent) 526 | ) 527 | 528 | styles = style_t 529 | 530 | if len(styles) < 2: 531 | inject_index = self.n_latent 532 | 533 | if styles[0].ndim < 3: 534 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 535 | 536 | else: 537 | latent = styles[0] 538 | 539 | else: 540 | if inject_index is None: 541 | inject_index = random.randint(1, self.n_latent - 1) 542 | 543 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 544 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 545 | 546 | latent = torch.cat([latent, latent2], 1) 547 | if only_latent: 548 | return latent 549 | 550 | out = self.input(latent) 551 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 552 | 553 | skip = self.to_rgb1(out, latent[:, 1]) 554 | 555 | i = 1 556 | for conv1, conv2, noise1, noise2, to_rgb in zip( 557 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 558 | ): 559 | out = conv1(out, latent[:, i], noise=noise1) 560 | out = conv2(out, latent[:, i + 1], noise=noise2) 561 | skip = to_rgb(out, latent[:, i + 2], skip) 562 | 563 | i += 2 564 | 565 | image = skip 566 | 567 | if return_latents: 568 | return image, latent 569 | 570 | else: 571 | return image, None 572 | 573 | 574 | class ConvLayer(nn.Sequential): 575 | def __init__( 576 | self, 577 | in_channel, 578 | out_channel, 579 | kernel_size, 580 | downsample=False, 581 | blur_kernel=[1, 3, 3, 1], 582 | bias=True, 583 | activate=True, 584 | ): 585 | layers = [] 586 | 587 | if downsample: 588 | factor = 2 589 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 590 | pad0 = (p + 1) // 2 591 | pad1 = p // 2 592 | 593 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 594 | 595 | stride = 2 596 | self.padding = 0 597 | 598 | else: 599 | stride = 1 600 | self.padding = kernel_size // 2 601 | 602 | layers.append( 603 | EqualConv2d( 604 | in_channel, 605 | out_channel, 606 | kernel_size, 607 | padding=self.padding, 608 | stride=stride, 609 | bias=bias and not activate, 610 | ) 611 | ) 612 | 613 | if activate: 614 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 615 | 616 | super().__init__(*layers) 617 | 618 | 619 | class ResBlock(nn.Module): 620 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 621 | super().__init__() 622 | 623 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 624 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 625 | 626 | self.skip = ConvLayer( 627 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 628 | ) 629 | 630 | def forward(self, input): 631 | out = self.conv1(input) 632 | out = self.conv2(out) 633 | 634 | skip = self.skip(input) 635 | out = (out + skip) / math.sqrt(2) 636 | 637 | return out 638 | 639 | 640 | class Discriminator(nn.Module): 641 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 642 | super().__init__() 643 | 644 | channels = { 645 | 4: 512, 646 | 8: 512, 647 | 16: 512, 648 | 32: 512, 649 | 64: 256 * channel_multiplier, 650 | 128: 128 * channel_multiplier, 651 | 256: 64 * channel_multiplier, 652 | 512: 32 * channel_multiplier, 653 | 1024: 16 * channel_multiplier, 654 | } 655 | 656 | 
convs = [ConvLayer(3, channels[size], 1)] 657 | 658 | log_size = int(math.log(size, 2)) 659 | 660 | in_channel = channels[size] 661 | 662 | for i in range(log_size, 2, -1): 663 | out_channel = channels[2 ** (i - 1)] 664 | 665 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 666 | 667 | in_channel = out_channel 668 | 669 | self.convs = nn.Sequential(*convs) 670 | 671 | self.stddev_group = 4 672 | self.stddev_feat = 1 673 | 674 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 675 | self.final_linear = nn.Sequential( 676 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 677 | EqualLinear(channels[4], 1), 678 | ) 679 | 680 | def forward(self, input): 681 | 682 | out = self.convs(input) 683 | 684 | batch, channel, height, width = out.shape 685 | group = min(batch, self.stddev_group) 686 | stddev = out.view( 687 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 688 | ) 689 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 690 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 691 | stddev = stddev.repeat(group, 1, height, width) 692 | out = torch.cat([out, stddev], 1) 693 | 694 | out = self.final_conv(out) 695 | 696 | out = out.view(batch, -1) 697 | out = self.final_linear(out) 698 | 699 | return out 700 | 701 | -------------------------------------------------------------------------------- /model/styleganModule/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /model/styleganModule/op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | 
return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, 
None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /model/styleganModule/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input.contiguous(), 51 | gradgrad_bias, 52 | out, 53 | 3, 54 | 1, 55 | ctx.negative_slope, 56 | ctx.scale, 57 | ) 58 | 59 | return gradgrad_out, None, None, None, None 60 | 61 | 62 | class FusedLeakyReLUFunction(Function): 63 | @staticmethod 64 | def forward(ctx, input, bias, negative_slope, scale): 65 | empty = input.new_empty(0) 66 | 67 | ctx.bias = bias is not None 68 | 69 | if bias is None: 70 | bias = empty 71 | 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 85 | ) 86 | 87 | if not ctx.bias: 88 | grad_bias = None 89 | 90 | return grad_input, grad_bias, None, None 91 | 92 | 93 | class FusedLeakyReLU(nn.Module): 94 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 95 | super().__init__() 96 | 97 | if bias: 98 | self.bias = nn.Parameter(torch.zeros(channel)) 99 | 100 | else: 101 | self.bias = None 102 | 103 | self.negative_slope = negative_slope 104 | self.scale = scale 105 | 106 | def forward(self, input): 107 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 108 | 109 | 110 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 111 | if 
input.device.type == "cpu": 112 | if bias is not None: 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return F.leaky_relu(input, negative_slope=0.2) * scale 123 | 124 | else: 125 | return FusedLeakyReLUFunction.apply( 126 | input.contiguous(), bias, negative_slope, scale 127 | ) 128 | -------------------------------------------------------------------------------- /model/styleganModule/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /model/styleganModule/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? 
x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 1 : 0; 80 | int use_ref = ref.numel() ? 1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<<>>( 99 | y.data_ptr(), x.data_ptr(), 100 | b.data_ptr(), ref.data_ptr(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /model/styleganModule/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /model/styleganModule/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | 
grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | 
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /model/styleganModule/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x 
< p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = 
out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = 
dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /model/styleganModule/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | def make_noise(batch, latent_dim, n_noise, device): 5 | if n_noise == 1: 6 | return torch.randn(batch, latent_dim, device=device) 7 | 8 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 9 | 10 | return noises 11 | 12 | 13 | def mixing_noise(batch, latent_dim, prob, device): 14 | if prob > 0 and random.random() < prob: 15 | return make_noise(batch, latent_dim, 2, device) 16 | 17 | else: 18 | return [make_noise(batch, latent_dim, 1, device)] -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | pip install tensorboardX -i https://pypi.tuna.tsinghua.edu.cn/simple 2 | export NCCL_SOCKET_IFNAME=eth0 3 | export NCCL_DEBUG=INFO 4 | # python train.py --model ccn --batch_size 16 --checkpoint_path checkpoint-ccn 5 | # python -m torch.distributed.launch train.py --model ccn --batch_size 16 --checkpoint_path checkpoint-ccn 6 | # python train.py --model ttn --batch_size 16 --checkpoint_path checkpoint-ttn-noxgf --lr 2e-4 --dist --print_interval 100 --save_interval 100 7 | # python train.py --model exp --batch_size 2 --checkpoint_path checkpoint-exp --lr 2e-4 --dist --print_interval 1 --save_interval 1 --early_stop --test_interval 1 --stop_interval 1 8 | # python train.py --model ccn --batch_size 16 --checkpoint_path checkpoint --lr 0.002 --print_interval 100 --save_interval 100 --dist 9 | python train.py --model ttn --batch_size 64 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100 --dist -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 
@author LeslieZhao 3 | @date 20220721 4 | ''' 5 | import os 6 | 7 | import argparse 8 | from trainer.CCNTrainer import CCNTrainer 9 | 10 | from trainer.TTNTrainer import TTNTrainer 11 | import torch.distributed as dist 12 | from utils.utils import setup_seed,get_data_loader,merge_args 13 | from model.styleganModule.config import Params as CCNParams 14 | from model.Pix2PixModule.config import Params as TTNParams 15 | 16 | # torch.multiprocessing.set_start_method('spawn') 17 | 18 | parser = argparse.ArgumentParser(description="StyleGAN") 19 | #---------train set------------------------------------- 20 | parser.add_argument('--model',default="ccn",help='') 21 | parser.add_argument('--isTrain',action="store_false",help='') 22 | parser.add_argument('--dist',action="store_false",help='') 23 | parser.add_argument('--batch_size',default=16,type=int) 24 | parser.add_argument('--seed',default=10,type=int) 25 | parser.add_argument('--eval',default=1,type=int,help='whether use eval') 26 | parser.add_argument('--nDataLoaderThread',default=5,type=int,help='Num of loader threads') 27 | parser.add_argument('--print_interval',default=100,type=int) 28 | parser.add_argument('--test_interval',default=100,type=int,help='Test and save every [test_intervaal] epochs') 29 | parser.add_argument('--save_interval',default=100,type=int,help='save model interval') 30 | parser.add_argument('--stop_interval',default=20,type=int) 31 | parser.add_argument('--begin_it',default=0,type=int,help='begin epoch') 32 | parser.add_argument('--mx_data_length',default=100,type=int,help='max data length') 33 | parser.add_argument('--max_epoch',default=10000,type=int) 34 | parser.add_argument('--early_stop',action="store_true",help='') 35 | parser.add_argument('--scratch',action="store_true",help='') 36 | #---------path set-------------------------------------- 37 | parser.add_argument('--checkpoint_path',default='checkpoint-onlybaby',type=str) 38 | parser.add_argument('--pretrain_path',default=None,type=str) 39 | 40 | # ------optimizer set-------------------------------------- 41 | parser.add_argument('--lr',default=0.002,type=float,help="Learning rate") 42 | 43 | parser.add_argument( 44 | '--local_rank', 45 | type=int, 46 | default=0, 47 | help='Local rank passed from distributed launcher' 48 | ) 49 | 50 | args = parser.parse_args() 51 | 52 | def train_net(args): 53 | train_loader,test_loader,mx_length = get_data_loader(args) 54 | 55 | args.mx_data_length = mx_length 56 | if args.model == 'ccn': 57 | trainer = CCNTrainer(args) 58 | if args.model == 'ttn': 59 | trainer = TTNTrainer(args) 60 | 61 | trainer.train_network(train_loader,test_loader) 62 | 63 | if __name__ == "__main__": 64 | 65 | args = parser.parse_args() 66 | 67 | if args.model == 'ccn': 68 | params = CCNParams() 69 | if args.model == 'ttn': 70 | params = TTNParams() 71 | 72 | args = merge_args(args,params) 73 | if args.dist: 74 | dist.init_process_group(backend="nccl") # backbend='nccl' 75 | dist.barrier() # 用于同步训练 76 | args.world_size = dist.get_world_size() # 一共有几个节点 77 | args.rank = dist.get_rank() # 当前节点编号 78 | 79 | else: 80 | args.world_size = 1 81 | args.rank = 0 82 | 83 | setup_seed(args.seed+args.rank) 84 | print(args) 85 | 86 | args.checkpoint_path = os.path.join(args.checkpoint_path,args.name) 87 | 88 | print("local_rank %d | rank %d | world_size: %d"%(int(os.environ.get('LOCAL_RANK','0')),args.rank,args.world_size)) 89 | if args.rank == 0 : 90 | if not os.path.exists(args.checkpoint_path): 91 | os.makedirs(args.checkpoint_path) 92 | print("make dir: 
",args.checkpoint_path) 93 | train_net(args) 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /trainer/CCNTrainer.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author LeslieZhao 5 | @date 20220721 6 | ''' 7 | import torch 8 | 9 | from trainer.ModelTrainer import ModelTrainer 10 | from model.styleganModule.model import Generator,Discriminator 11 | from model.styleganModule.utils import * 12 | from model.styleganModule.loss import * 13 | from utils.utils import * 14 | 15 | 16 | class CCNTrainer(ModelTrainer): 17 | 18 | def __init__(self, args): 19 | super().__init__(args) 20 | self.device = 'cpu' 21 | if torch.cuda.is_available(): 22 | self.device = 'cuda' 23 | 24 | self.netGs = Generator( 25 | args.size,args.latent, 26 | args.n_mlp, 27 | channel_multiplier=args.channel_multiplier).to(self.device) 28 | 29 | self.netGt = Generator(args.size,args.latent, 30 | args.n_mlp, 31 | channel_multiplier=args.channel_multiplier).to(self.device) 32 | 33 | 34 | self.netD = Discriminator( 35 | args.size, channel_multiplier=args.channel_multiplier).to(self.device) 36 | self.gt_ema = Generator( 37 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(self.device) 38 | self.gt_ema.eval() 39 | accumulate(self.gt_ema, self.netGt, 0) 40 | self.g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) 41 | self.d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 42 | 43 | self.ckpt = None 44 | # self.init_weights(init_type='kaiming') 45 | if not args.scratch: 46 | ckpt = torch.load(args.stylegan_path, map_location=lambda storage, loc: storage) 47 | self.netGs.load_state_dict(ckpt['g_ema'],strict=False) 48 | self.netGt.load_state_dict(ckpt['g'],strict=False) 49 | self.gt_ema.load_state_dict(ckpt['g_ema'],strict=False) 50 | self.netD.load_state_dict(ckpt["d"]) 51 | self.ckpt = ckpt 52 | self.optimG,self.optimD = self.create_optimizer() 53 | 54 | self.sample_z = torch.randn(args.n_sample, args.latent).to(self.device) 55 | 56 | 57 | if args.pretrain_path is not None: 58 | self.loadParameters(args.pretrain_path) 59 | 60 | if args.dist: 61 | self.netGt,self.netGt_module = self.use_ddp(self.netGt) 62 | self.netD,self.netD_module = self.use_ddp(self.netD) 63 | else: 64 | self.netGt_module = self.netGt 65 | self.netD_module = self.netD 66 | self.netGs.eval() 67 | 68 | self.accum = 0.5 ** (32 / (10 * 1000)) 69 | self.criterionID = IDLoss(args.id_model).to(self.device) 70 | 71 | self.mean_path_length = 0 72 | 73 | 74 | def create_optimizer(self): 75 | g_optim = torch.optim.Adam( 76 | self.netGt.parameters(), 77 | lr=self.args.lr * self.g_reg_ratio, 78 | betas=(0 ** self.g_reg_ratio, 0.99 ** self.g_reg_ratio), 79 | ) 80 | d_optim = torch.optim.Adam( 81 | self.netD.parameters(), 82 | lr=self.args.lr * self.d_reg_ratio, 83 | betas=(0 ** self.d_reg_ratio, 0.99 ** self.d_reg_ratio), 84 | ) 85 | if self.ckpt is not None: 86 | g_optim.load_state_dict(self.ckpt["g_optim"]) 87 | d_optim.load_state_dict(self.ckpt["d_optim"]) 88 | return g_optim,d_optim 89 | 90 | 91 | def run_single_step(self, data, steps): 92 | self.netGt.train() 93 | if self.args.interval_train: 94 | data = self.process_input(data) 95 | if not hasattr(self,'interval_flag'): 96 | self.interval_flag = True 97 | if self.interval_flag: 98 | self.run_generator_one_step(data,steps) 99 | self.d_losses = {} 100 | else: 101 | self.run_discriminator_one_step(data,steps) 102 | self.g_losses 
= {} 103 | if steps % self.args.interval_steps == 0: 104 | self.interval_flag = not self.interval_flag 105 | 106 | 107 | else: 108 | super().run_single_step(data, steps) 109 | 110 | 111 | def run_discriminator_one_step(self, data,step): 112 | 113 | D_losses = {} 114 | requires_grad(self.netGt, False) 115 | requires_grad(self.netD, True) 116 | noise = mixing_noise(self.args.batch_size, 117 | self.args.latent, 118 | self.args.mixing,self.device) 119 | 120 | fake_img, _ = self.netGt(noise) 121 | fake_pred = self.netD(fake_img) 122 | real_pred = self.netD(data) 123 | d_loss = d_logistic_loss(real_pred, fake_pred) 124 | D_losses['d'] = d_loss 125 | 126 | self.netD.zero_grad() 127 | d_loss.backward() 128 | self.optimD.step() 129 | 130 | if step % self.args.d_reg_every == 0: 131 | data.requires_grad = True 132 | 133 | real_pred = self.netD(data) 134 | r1_loss = d_r1_loss(real_pred,data) 135 | self.netD.zero_grad() 136 | r1_loss = self.args.r1 / 2 * \ 137 | r1_loss * self.args.d_reg_every + \ 138 | 0 * real_pred[0] 139 | 140 | r1_loss.mean().backward() 141 | 142 | self.optimD.step() 143 | D_losses['r1'] = r1_loss 144 | 145 | self.d_losses = D_losses 146 | 147 | 148 | def run_generator_one_step(self, data,step): 149 | 150 | G_losses = {} 151 | requires_grad(self.netGt, True) 152 | requires_grad(self.netD, False) 153 | requires_grad(self.netGs, False) 154 | 155 | noise = mixing_noise(self.args.batch_size, 156 | self.args.latent, 157 | self.args.mixing,self.device) 158 | 159 | fake_s,_ = self.netGs(noise) 160 | fake_t,_ = self.netGt(noise) 161 | fake_pred = self.netD(fake_t) 162 | gan_loss = g_nonsaturating_loss(fake_pred) * self.args.lambda_gan 163 | id_loss = self.criterionID(fake_s,fake_t) * self.args.lambda_id 164 | G_losses['g'] = gan_loss 165 | G_losses['id'] = id_loss 166 | losses = gan_loss + id_loss 167 | G_losses['g_losses'] = losses 168 | self.netGt.zero_grad() 169 | losses.mean().backward() 170 | self.optimG.step() 171 | 172 | if step % self.args.g_reg_every == 0: 173 | path_batch_size = max(1, self.args.batch_size // self.args.path_batch_shrink) 174 | noise = mixing_noise(path_batch_size, self.args.latent, self.args.mixing,self.device) 175 | fake_img, latents = self.netGt(noise, return_latents=True) 176 | 177 | path_loss, self.mean_path_length, path_lengths = g_path_regularize( 178 | fake_img, latents, self.mean_path_length 179 | ) 180 | 181 | weighted_path_loss = self.args.path_regularize * self.args.g_reg_every * path_loss 182 | if self.args.path_batch_shrink: 183 | weighted_path_loss += 0 * fake_img[0, 0, 0, 0] 184 | 185 | self.netGt.zero_grad() 186 | weighted_path_loss.mean().backward() 187 | self.optimG.step() 188 | 189 | G_losses['path'] = weighted_path_loss 190 | 191 | accumulate(self.gt_ema,self.netGt_module,self.accum) 192 | self.g_losses = G_losses 193 | self.generator = [data.detach(),fake_s.detach(),fake_t.detach()] 194 | 195 | 196 | def evalution(self,test_loader,steps,epoch): 197 | 198 | loss_dict = {} 199 | with torch.no_grad(): 200 | fake_s,_ = self.netGs([self.sample_z]) 201 | fake_t,_ = self.gt_ema([self.sample_z]) 202 | if self.args.rank == 0 : 203 | self.val_vis.display_current_results(self.select_img([fake_s,fake_t]),steps) 204 | # self.val_vis.display_current_results(self.select_img([fake_t]),steps) 205 | 206 | return loss_dict 207 | 208 | def get_latest_losses(self): 209 | return {**self.g_losses,**self.d_losses} 210 | 211 | def get_latest_generated(self): 212 | return self.generator 213 | 214 | def loadParameters(self,path): 215 | ckpt = torch.load(path, 
map_location=lambda storage, loc: storage) 216 | self.netGs.load_state_dict(ckpt['Gs'],strict=False) 217 | self.netGt.load_state_dict(ckpt['Gt'],strict=False) 218 | self.gt_ema.load_state_dict(ckpt['gt_ema'],strict=False) 219 | self.optimG.load_state_dict(ckpt['g_optim']) 220 | self.optimD.load_state_dict(ckpt['d_optim']) 221 | 222 | def saveParameters(self,path): 223 | torch.save( 224 | { 225 | "Gs": self.netGs.state_dict(), 226 | "Gt": self.netGt_module.state_dict(), 227 | "gt_ema": self.gt_ema.state_dict(), 228 | "g_optim": self.optimG.state_dict(), 229 | "d_optim": self.optimD.state_dict(), 230 | "args": self.args, 231 | }, 232 | path 233 | ) 234 | 235 | def get_lr(self): 236 | return self.optimG.state_dict()['param_groups'][0]['lr'] 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /trainer/ModelTrainer.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author LeslieZhao 5 | @date 20220721 6 | ''' 7 | import torch 8 | import math 9 | import time,os 10 | 11 | from utils.visualizer import Visualizer 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | import torch.distributed as dist 14 | import subprocess 15 | from utils.utils import convert_img 16 | class ModelTrainer: 17 | 18 | def __init__(self,args): 19 | 20 | self.args = args 21 | self.batch_size = args.batch_size 22 | self.old_lr = args.lr 23 | if args.rank == 0 : 24 | self.vis = Visualizer(args) 25 | 26 | if args.eval: 27 | self.val_vis = Visualizer(args,"val") 28 | 29 | # ## ===== ===== ===== ===== ===== 30 | # ## Train network 31 | # ## ===== ===== ===== ===== ===== 32 | 33 | def train_network(self,train_loader,test_loader): 34 | 35 | counter = 0 36 | loss_dict = {} 37 | acc_num = 0 38 | mn_loss = float('inf') 39 | 40 | steps = 0 41 | begin_it = 0 42 | if self.args.pretrain_path: 43 | begin_it = int(self.args.pretrain_path.split('/')[-1].split('-')[0]) 44 | steps = (begin_it+1) * math.ceil(self.args.mx_data_length/self.args.batch_size) 45 | 46 | print("current steps: %d | one epoch steps: %d "%(steps,self.args.mx_data_length)) 47 | 48 | for epoch in range(begin_it+1,self.args.max_epoch): 49 | 50 | for ii,(data) in enumerate(train_loader): 51 | 52 | tstart = time.time() 53 | 54 | self.run_single_step(data,steps) 55 | losses = self.get_latest_losses() 56 | 57 | for key,val in losses.items(): 58 | loss_dict[key] = loss_dict.get(key,0) + val.mean().item() 59 | 60 | counter += 1 61 | steps += 1 62 | 63 | telapsed = time.time() - tstart 64 | 65 | 66 | if ii % self.args.print_interval == 0 and self.args.rank == 0: 67 | for key,val in loss_dict.items(): 68 | loss_dict[key] /= counter 69 | 70 | lr_rate = self.get_lr() 71 | print_dict = {**{"time":telapsed,"lr":lr_rate}, 72 | **loss_dict} 73 | self.vis.print_current_errors(epoch,ii,print_dict,telapsed) 74 | 75 | self.vis.plot_current_errors(print_dict,steps) 76 | 77 | loss_dict = {} 78 | counter = 0 79 | 80 | # torch.cuda.empty_cache() 81 | if self.args.save_interval != 0 and ii % self.args.save_interval == 0 and \ 82 | self.args.rank == 0: 83 | self.saveParameters(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,ii))) 84 | 85 | display_data = self.select_img(self.get_latest_generated()) 86 | 87 | self.vis.display_current_results(display_data,steps) 88 | 89 | 90 | 91 | if self.args.eval and self.args.test_interval > 0 and steps % 
self.args.test_interval == 0: 92 | val_loss = self.evalution(test_loader,steps,epoch) 93 | 94 | if self.args.early_stop: 95 | 96 | acc_num,mn_loss,stop_flag = self.early_stop_wait(self.get_loss_from_val(val_loss),acc_num,mn_loss,epoch) 97 | if stop_flag: 98 | return 99 | 100 | # print('******************memory:',psutil.virtual_memory()[3]) 101 | 102 | if self.args.rank == 0 : 103 | self.saveParameters(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,0))) 104 | 105 | # 验证,保存最优模型 106 | if test_loader or self.args.eval: 107 | val_loss = self.evalution(test_loader,steps,epoch) 108 | 109 | if self.args.early_stop: 110 | 111 | acc_num,mn_loss,stop_flag = self.early_stop_wait(self.get_loss_from_val(val_loss),acc_num,mn_loss,epoch) 112 | if stop_flag: 113 | return 114 | 115 | 116 | if self.args.rank == 0 : 117 | self.vis.close() 118 | 119 | 120 | 121 | def early_stop_wait(self,loss,acc_num,mn_loss,epoch): 122 | 123 | if self.args.rank == 0: 124 | if loss < mn_loss: 125 | mn_loss = loss 126 | cmd_one = 'cp -r %s %s'%(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,0)), 127 | os.path.join(self.args.checkpoint_path,'final.pth')) 128 | done_one = subprocess.Popen(cmd_one,stdout=subprocess.PIPE,shell=True) 129 | done_one.wait() 130 | acc_num = 0 131 | else: 132 | acc_num += 1 133 | # 多机多卡,某一张卡退出则终止程序,使用all_reduce 134 | if self.args.dist: 135 | 136 | if acc_num > self.args.stop_interval: 137 | signal = torch.tensor([0]).cuda() 138 | else: 139 | signal = torch.tensor([1]).cuda() 140 | else: 141 | if self.args.dist: 142 | signal = torch.tensor([1]).cuda() 143 | 144 | if self.args.dist: 145 | dist.all_reduce(signal) 146 | value = signal.item() 147 | if value >= int(os.environ.get("WORLD_SIZE","1")): 148 | dist.all_reduce(torch.tensor([0]).cuda()) 149 | return acc_num,mn_loss,False 150 | else: 151 | return acc_num,mn_loss,True 152 | 153 | else: 154 | if acc_num > self.args.stop_interval: 155 | return acc_num,mn_loss,True 156 | else: 157 | return acc_num,mn_loss,False 158 | 159 | def run_single_step(self,data,steps): 160 | data = self.process_input(data) 161 | self.run_discriminator_one_step(data,steps) 162 | self.run_generator_one_step(data,steps) 163 | 164 | def select_img(self,data,name='fake'): 165 | if data is None: 166 | return None 167 | cat_img = [] 168 | for v in data: 169 | cat_img.append(v.detach().cpu()) 170 | 171 | cat_img = torch.cat(cat_img,-1) 172 | cat_img = torch.cat(torch.split(cat_img,1,dim=0),2)[0] 173 | 174 | return {name:convert_img(cat_img)} 175 | 176 | 177 | 178 | ################################################################## 179 | # Helper functions 180 | ################################################################## 181 | 182 | def get_loss_from_val(self,loss): 183 | return loss 184 | 185 | def get_show_inp(self,data): 186 | if not isinstance(data,list): 187 | return [data] 188 | return data 189 | 190 | def use_ddp(self,model): 191 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) #用于将BN转换成ddp模式/ 192 | # model = DDP(model,broadcast_buffers=False,find_unused_parameters=True) # find_unused_parameters->训练gan会有判别器或生成器参数不参与训练,需使用该参数 193 | model = DDP(model, 194 | broadcast_buffers=False, 195 | ) 196 | model_on_one_gpu = model.module #若需要调用self.model的函数,在ddp模式要调用self._model_on_one_gpu 197 | return model,model_on_one_gpu 198 | def process_input(self,data): 199 | 200 | if torch.cuda.is_available(): 201 | if isinstance(data,list): 202 | data = [x.cuda() for x in data] 203 | else: 204 | data = data.cuda() 205 | return data 
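ModelTrainer above is the shared training skeleton: a concrete trainer such as CCNTrainer or TTNTrainer only supplies the per-step generator/discriminator updates, checkpoint load/save and loss accessors, while train_network handles logging, periodic saving, evaluation and early stopping. A minimal driver sketch is shown below; it simply mirrors what train_net() in train.py already does (the single-process rank/world_size values are assumptions for a non-distributed run, not new repository API).

```python
# Minimal driver sketch (illustrative only; mirrors train_net() in train.py).
from trainer.TTNTrainer import TTNTrainer
from utils.utils import get_data_loader

def run(args):                                    # `args` = the argparse Namespace built in train.py
    args.world_size, args.rank = 1, 0             # assumed single-process run, no torch.distributed setup
    train_loader, test_loader, mx_length = get_data_loader(args)
    args.mx_data_length = mx_length
    trainer = TTNTrainer(args)                    # or CCNTrainer(args) for the CCN stage
    trainer.train_network(train_loader, test_loader)
```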
-------------------------------------------------------------------------------- /trainer/TTNTrainer.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author LeslieZhao 5 | @date 20220721 6 | ''' 7 | import torch 8 | 9 | from trainer.ModelTrainer import ModelTrainer 10 | from model.Pix2PixModule.model import Generator,Discriminator,ExpressDetector 11 | from utils.utils import * 12 | from model.Pix2PixModule.module import * 13 | from model.Pix2PixModule.loss import * 14 | import torch.distributed as dist 15 | import random 16 | import itertools 17 | 18 | class TTNTrainer(ModelTrainer): 19 | 20 | def __init__(self, args): 21 | super().__init__(args) 22 | self.device = 'cpu' 23 | if torch.cuda.is_available(): 24 | self.device = 'cuda' 25 | self.netG = Generator(img_channels=3).to(self.device) 26 | 27 | self.netTxD = Discriminator(in_channels=1).to(self.device) 28 | 29 | self.netSfD = Discriminator(in_channels=3).to(self.device) 30 | 31 | self.ExpG = None 32 | if args.use_exp: 33 | self.ExpG = ExpressDetector().to(self.device) 34 | self.ExpG.apply(init_weights) 35 | 36 | self.netG.apply(init_weights) 37 | 38 | self.netTxD.apply(init_weights) 39 | self.netSfD.apply(init_weights) 40 | 41 | self.optimG,self.optimD,self.optimExp = self.create_optimizer() 42 | 43 | if args.pretrain_path is not None: 44 | self.loadParameters(args.pretrain_path) 45 | 46 | if args.dist: 47 | self.netG,self.netG_module = self.use_ddp(self.netG) 48 | 49 | self.netTxD,self.netTxD_module = self.use_ddp(self.netTxD) 50 | self.netSfD,self.netSfD_module = self.use_ddp(self.netSfD) 51 | 52 | if args.use_exp: 53 | self.ExpG,self.ExpG_module = self.use_ddp(self.ExpG) 54 | else: 55 | self.netG_module = self.netG 56 | self.netTxD_module = self.netTxD 57 | self.netSfD_module = self.netSfD 58 | if args.use_exp: 59 | self.ExpG_module = self.ExpG 60 | 61 | 62 | self.VggLoss = VGGLoss(args.vgg_model).to(self.device).eval() 63 | self.TVLoss = TVLoss(1).to(self.device).eval() 64 | self.L1_Loss = nn.L1Loss() 65 | self.MSE_Loss = nn.MSELoss() 66 | 67 | 68 | def create_optimizer(self): 69 | g_optim = torch.optim.Adam( 70 | self.netG.parameters(), 71 | lr=self.args.lr, 72 | betas=(self.args.beta1,self.args.beta2), 73 | ) 74 | 75 | d_optim = torch.optim.Adam( 76 | 77 | itertools.chain(self.netTxD.parameters(),self.netSfD.parameters()), 78 | lr=self.args.lr, 79 | betas=(self.args.beta1,self.args.beta2), 80 | ) 81 | exp_optim = None 82 | if self.args.use_exp: 83 | exp_optim = torch.optim.Adam( 84 | self.ExpG.parameters(), 85 | lr=self.args.lr, 86 | betas=(self.args.beta1,self.args.beta2), 87 | ) 88 | 89 | return g_optim,d_optim,exp_optim 90 | 91 | 92 | def run_single_step(self, data, steps): 93 | self.netG.train() 94 | super().run_single_step(data, steps) 95 | 96 | 97 | def run_discriminator_one_step(self, data,step): 98 | 99 | D_losses = {} 100 | requires_grad(self.netTxD, True) 101 | requires_grad(self.netSfD, True) 102 | xs,xt,_ = data 103 | with torch.no_grad(): 104 | xg = self.netG(xs) 105 | 106 | # surface 107 | 108 | blur_fake = guided_filter(xg,xg,r=5,eps=2e-1) 109 | blur_style = guided_filter(xt,xt,r=5,eps=2e-1) 110 | 111 | D_blur_real = self.netSfD(blur_style) 112 | D_blur_fake = self.netSfD(blur_fake) 113 | d_loss_surface_real = self.MSE_Loss(D_blur_real,torch.ones_like(D_blur_real)) 114 | d_loss_surface_fake = self.MSE_Loss(D_blur_fake, torch.zeros_like(D_blur_fake)) 115 | d_loss_surface = (d_loss_surface_real + 
d_loss_surface_fake)/2.0 116 | 117 | D_losses['d_surface_real'] = d_loss_surface_real 118 | D_losses['d_surface_fake'] = d_loss_surface_fake 119 | 120 | # texture 121 | gray_fake = color_shift(xg) 122 | gray_style = color_shift(xt) 123 | 124 | D_gray_real = self.netTxD(gray_style) 125 | D_gray_fake = self.netTxD(gray_fake.detach()) 126 | d_loss_texture_real = self.MSE_Loss(D_gray_real, torch.ones_like(D_gray_real)) 127 | d_loss_texture_fake = self.MSE_Loss(D_gray_fake, torch.zeros_like(D_gray_fake)) 128 | d_loss_texture = (d_loss_texture_real + d_loss_texture_fake)/2.0 129 | 130 | D_losses['d_texture_real'] = d_loss_texture_real 131 | D_losses['d_texture_fake'] = d_loss_texture_fake 132 | 133 | d_loss_total = d_loss_surface + d_loss_texture 134 | 135 | 136 | self.optimD.zero_grad() 137 | d_loss_total.backward() 138 | 139 | self.optimD.step() 140 | self.d_losses = D_losses 141 | 142 | def run_generator_one_step(self, data,step): 143 | 144 | G_losses = {} 145 | requires_grad(self.netG, True) 146 | requires_grad(self.netTxD, False) 147 | requires_grad(self.netSfD, False) 148 | requires_grad(self.ExpG, False) 149 | xs,xt,exp_gt = data 150 | 151 | G_losses,losses,xg = self.compute_g_loss(xs,exp_gt,step) 152 | 153 | self.netG.zero_grad() 154 | losses.backward() 155 | self.optimG.step() 156 | 157 | if self.args.use_exp: 158 | requires_grad(self.ExpG, True) 159 | pred_exp = self.ExpG(xg.detach()) 160 | exp_loss = self.MSE_Loss(pred_exp,exp_gt) 161 | self.optimExp.zero_grad() 162 | exp_loss.backward() 163 | self.optimExp.step() 164 | G_losses['raw_exp_loss'] = exp_loss 165 | 166 | 167 | self.g_losses = G_losses 168 | self.generator = [xs.detach(),xg.detach(),xt.detach()] 169 | 170 | 171 | def evalution(self,test_loader,steps,epoch): 172 | 173 | loss_dict = {} 174 | counter = 0 175 | index = random.randint(0,len(test_loader)-1) 176 | self.netG.eval() 177 | 178 | with torch.no_grad(): 179 | for i,data in enumerate(test_loader): 180 | 181 | data = self.process_input(data) 182 | xs,xt,exp_gt = data 183 | G_losses,losses,xg = self.compute_g_loss(xs,exp_gt,steps) 184 | for k,v in G_losses.items(): 185 | loss_dict[k] = loss_dict.get(k,0) + v.detach() 186 | if i == index and self.args.rank == 0 : 187 | 188 | self.val_vis.display_current_results(self.select_img([xs,xg,xt]),steps) 189 | counter += 1 190 | 191 | 192 | for key,val in loss_dict.items(): 193 | loss_dict[key] /= counter 194 | 195 | if self.args.dist: 196 | # if self.args.rank == 0 : 197 | dist_losses = loss_dict.copy() 198 | for key,val in loss_dict.items(): 199 | 200 | dist.reduce(dist_losses[key],0) 201 | value = dist_losses[key].item() 202 | loss_dict[key] = value / self.args.world_size 203 | 204 | if self.args.rank == 0 : 205 | self.val_vis.plot_current_errors(loss_dict,steps) 206 | self.val_vis.print_current_errors(epoch+1,0,loss_dict,0) 207 | 208 | return loss_dict 209 | 210 | 211 | def compute_g_loss(self,xs,exp_gt,step): 212 | G_losses = {} 213 | 214 | xg = self.netG(xs) 215 | 216 | # warp_up 217 | if step < 100: 218 | lambda_surface = 0 219 | lambda_texture = 0 220 | lambda_exp = 0 221 | 222 | elif step < 500: 223 | lambda_surface = self.args.lambda_surface * 0.01 224 | lambda_texture = self.args.lambda_texture * 0.01 225 | lambda_exp = self.args.lambda_exp * 0.01 226 | 227 | elif step < 1000: 228 | lambda_surface = self.args.lambda_surface * 0.1 229 | lambda_texture = self.args.lambda_texture * 0.1 230 | lambda_exp = self.args.lambda_exp * 0.1 231 | 232 | else: 233 | lambda_surface = self.args.lambda_surface 234 | lambda_texture = 
self.args.lambda_texture 235 | lambda_exp = self.args.lambda_exp 236 | 237 | 238 | # surface 239 | blur_fake = guided_filter(xg,xg,r=5,eps=2e-1) 240 | D_blur_fake = self.netSfD(blur_fake) 241 | g_loss_surface = lambda_surface * self.MSE_Loss(D_blur_fake, torch.ones_like(D_blur_fake)) 242 | G_losses['g_loss_surface'] = g_loss_surface 243 | 244 | # texture 245 | gray_fake = color_shift(xg) 246 | D_gray_fake = self.netTxD(gray_fake) 247 | g_loss_texture = lambda_texture * self.MSE_Loss(D_gray_fake, torch.ones_like(D_gray_fake)) 248 | G_losses['g_loss_texture'] = g_loss_texture 249 | 250 | # content 251 | content_loss = self.VggLoss(xs,xg) * self.args.lambda_content 252 | G_losses['content_loss'] = content_loss 253 | 254 | # tv loss 255 | tv_loss = self.TVLoss(xg) * self.args.lambda_tv 256 | G_losses['tvloss'] = tv_loss 257 | 258 | # exp loss 259 | exp_loss = 0 260 | if self.args.use_exp: 261 | exp_pred = self.ExpG(xg) 262 | exp_loss = self.MSE_Loss(exp_pred,exp_gt) * lambda_exp 263 | G_losses['exploss'] = exp_loss 264 | 265 | losses = g_loss_surface + g_loss_texture + content_loss + tv_loss + exp_loss * lambda_exp 266 | G_losses['total_loss'] = losses 267 | return G_losses,losses,xg 268 | 269 | def get_latest_losses(self): 270 | return {**self.g_losses,**self.d_losses} 271 | 272 | def get_latest_generated(self): 273 | return self.generator 274 | 275 | def get_loss_from_val(self,loss): 276 | return loss['total_loss'] 277 | 278 | def loadParameters(self,path): 279 | ckpt = torch.load(path, map_location=lambda storage, loc: storage) 280 | self.netG.load_state_dict(ckpt['netG'],strict=False) 281 | self.netTxD.load_state_dict(ckpt['netTxD'],strict=False) 282 | self.netSfD.load_state_dict(ckpt['netSfD'],strict=False) 283 | self.optimG.load_state_dict(ckpt['g_optim']) 284 | self.optimD.load_state_dict(ckpt['d_optim']) 285 | if self.args.use_exp: 286 | self.ExpG.load_state_dict(ckpt['ExpG'],strict=False) 287 | self.optimExp.load_state_dict(ckpt['exp_optim']) 288 | def saveParameters(self,path): 289 | save_dict = { 290 | "netG": self.netG_module.state_dict(), 291 | "netTxD": self.netTxD_module.state_dict(), 292 | "netSfD": self.netSfD_module.state_dict(), 293 | "g_optim": self.optimG.state_dict(), 294 | "d_optim": self.optimD.state_dict(), 295 | "args": self.args, 296 | } 297 | if self.args.use_exp: 298 | save_dict['ExpG'] = self.ExpG_module.state_dict() 299 | save_dict['exp_optim'] = self.optimExp.state_dict() 300 | torch.save( 301 | save_dict, 302 | path 303 | ) 304 | 305 | def get_lr(self): 306 | return self.optimG.state_dict()['param_groups'][0]['lr'] 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | -------------------------------------------------------------------------------- /utils/download_weight.sh: -------------------------------------------------------------------------------- 1 | wget https://github.com/LeslieZhoa/DCT-NET.Pytorch/releases/download/v0.0/final.pth -P ../pretrain_models 2 | wget https://github.com/LeslieZhoa/DCT-NET.Pytorch/releases/download/v0.0/model_ir_se50.pth -P ../pretrain_models 3 | wget https://github.com/LeslieZhoa/DCT-NET.Pytorch/releases/download/v0.0/vgg19-dcbb9e9d.pth -P ../pretrain_models -------------------------------------------------------------------------------- /utils/get_face_expression.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pdb 4 | import cv2 5 | import os 6 | from multiprocessing import Pool 7 | import time 8 | import math 9 | import 
multiprocessing as mp 10 | import numpy as np 11 | 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser(description="Process") 15 | parser.add_argument('--img_base',default="",type=str,help='') 16 | parser.add_argument('--pool_num',default=2,type=int,help='') 17 | parser.add_argument('--LVT',default='',type=str,help='') 18 | parser.add_argument('--train',action="store_true",help='') 19 | args = parser.parse_args() 20 | 21 | class Process: 22 | def __init__(self): 23 | self.engine = Engine( 24 | face_lmk_path='') 25 | 26 | def run(self,img_paths): 27 | mx_left_eye = -1 28 | mn_left_eye = 100 29 | 30 | mx_right_eye = -1 31 | mn_right_eye = 100 32 | 33 | mx_lip = -1 34 | mn_lip = 100 35 | for i,img_path in enumerate(img_paths): 36 | img = cv2.imread(img_path) 37 | left_eye_score,right_eye_score,lip_score = \ 38 | self.run_single(img) 39 | 40 | base,img_name = os.path.split(img_path) 41 | score_base = base.replace('img','express') 42 | os.makedirs(score_base,exist_ok=True) 43 | np.save(os.path.join(score_base,img_name.split('.')[0]+'.npy'),[left_eye_score,right_eye_score,lip_score]) 44 | 45 | mx_left_eye = max(left_eye_score,mx_left_eye) 46 | mn_left_eye = min(left_eye_score,mn_left_eye) 47 | 48 | mx_right_eye = max(right_eye_score,mx_right_eye) 49 | mn_right_eye = min(right_eye_score,mn_right_eye) 50 | 51 | mx_lip = max(lip_score,mx_lip) 52 | mn_lip = min(lip_score,mn_lip) 53 | print('\rhave done %04d'%i,end='',flush=True) 54 | print() 55 | return mx_left_eye,mn_left_eye,mx_right_eye,mn_right_eye,mx_lip,mn_lip 56 | 57 | 58 | def run_single(self,img): 59 | inp = self.engine.preprocess_lmk(img) 60 | lmk = self.engine.get_lmk(inp) 61 | lmk = self.engine.postprocess_lmk(lmk,256,[0,0]) 62 | scores = self.get_expression(lmk[0]) 63 | return scores 64 | def get_expression(self,lmk): 65 | left_eye_h = abs(lmk[66,1]-lmk[62,1]) 66 | left_eye_w = abs(lmk[60,0]-lmk[64,0]) 67 | left_eye_score = left_eye_h / max(left_eye_w,1e-5) 68 | 69 | right_eye_h = abs(lmk[70,1]-lmk[74,1]) 70 | right_eye_w = abs(lmk[68,0]-lmk[72,0]) 71 | right_eye_score = right_eye_h / max(right_eye_w,1e-5) 72 | 73 | lip_h = abs(lmk[90,1]-lmk[94,1]) 74 | lip_w = abs(lmk[88,0]-lmk[82,0]) 75 | lip_score = lip_h / max(lip_w,1e-5) 76 | 77 | return left_eye_score,right_eye_score,lip_score 78 | 79 | def work(queue,img_paths): 80 | model = Process() 81 | mx_left_eye,mn_left_eye,mx_right_eye,mn_right_eye,mx_lip,mn_lip = \ 82 | model.run(img_paths) 83 | queue.put([mx_left_eye,mn_left_eye,mx_right_eye,mn_right_eye,mx_lip,mn_lip]) 84 | 85 | def print_error(value): 86 | print("error: ", value) 87 | if __name__ == "__main__": 88 | 89 | args = parser.parse_args() 90 | # import LVT base path 91 | sys.path.insert(0,args.LVT) 92 | from LVT import Engine 93 | from utils import utils 94 | 95 | mp.set_start_method('spawn') 96 | m = mp.Manager() 97 | queue = m.Queue() 98 | model = Process() 99 | base = args.img_base 100 | img_paths = [os.path.join(base,f) for f in os.listdir(base)] 101 | pool_num = args.pool_num 102 | length = len(img_paths) 103 | 104 | dis = math.ceil(length/float(pool_num)) 105 | 106 | 107 | t1 = time.time() 108 | print('***************all length: %d ******************'%length) 109 | p = Pool(pool_num) 110 | for i in range(pool_num): 111 | p.apply_async(work, args = (queue,img_paths[i*dis:(i+1)*dis],),error_callback=print_error) 112 | 113 | p.close() 114 | p.join() 115 | print("all the time: %s"%(time.time()-t1)) 116 | 117 | if args.train: 118 | mx_left_eye_all = -1 119 | mn_left_eye_all = 100 120 | 121 | mx_right_eye_all = -1 
122 | mn_right_eye_all = 100 123 | 124 | mx_lip_all = -1 125 | mn_lip_all = 100 126 | 127 | while not queue.empty(): 128 | mx_left_eye,mn_left_eye,mx_right_eye,mn_right_eye,mx_lip,mn_lip = \ 129 | queue.get() 130 | 131 | mx_left_eye_all = max(mx_left_eye_all,mx_left_eye) 132 | mn_left_eye_all = min(mn_left_eye_all,mn_left_eye) 133 | 134 | mx_right_eye_all = max(mx_right_eye_all,mx_right_eye) 135 | mn_right_eye_all = min(mn_right_eye_all,mn_right_eye) 136 | 137 | mx_lip_all = max(mx_lip_all,mx_lip) 138 | mn_lip_all = min(mn_lip_all,mn_lip) 139 | os.makedirs('../pretrain_models',exist_ok=True) 140 | np.save('../pretrain_models/all_express_mean.npy',[mx_left_eye_all,mn_left_eye_all, 141 | mx_right_eye_all,mn_right_eye_all, 142 | mx_lip_all,mn_lip_all]) 143 | 144 | -------------------------------------------------------------------------------- /utils/get_tcc_input.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0,'..') 3 | import pdb 4 | from model.styleganModule.model import Generator 5 | import torch 6 | from model.styleganModule.utils import * 7 | from utils import convert_img 8 | from model.styleganModule.config import Params as CCNParams 9 | import cv2 10 | import os 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser(description="Process") 14 | parser.add_argument('--model_path',default="",type=str,help='') 15 | parser.add_argument('--output_path',default="",type=str,help='') 16 | args = parser.parse_args() 17 | 18 | class Process: 19 | def __init__(self,model_path): 20 | self.args = CCNParams() 21 | self.device = 'cpu' 22 | if torch.cuda.is_available(): 23 | self.device = 'cuda' 24 | 25 | self.netGt = Generator( 26 | self.args.size,self.args.latent, 27 | self.args.n_mlp, 28 | channel_multiplier=self.args.channel_multiplier).to(self.device) 29 | self.netGs = Generator( 30 | self.args.size,self.args.latent, 31 | self.args.n_mlp, 32 | channel_multiplier=self.args.channel_multiplier).to(self.device) 33 | 34 | self.loadparams(model_path) 35 | 36 | self.netGs.eval() 37 | self.netGt.eval() 38 | 39 | def __call__(self,save_base): 40 | os.makedirs(save_base,exist_ok=True) 41 | steps = 0 42 | for i in range(self.args.mx_gen_iters): 43 | fakes = self.run_sigle() 44 | for f in fakes: 45 | cv2.imwrite(os.path.join(save_base,'%06d.png'%steps),f) 46 | steps += 1 47 | print('\r have done %06d'%steps,end='',flush=True) 48 | 49 | def run_sigle(self): 50 | noise = mixing_noise(self.args.infer_batch_size, 51 | self.args.latent, 52 | self.args.mixing,self.device) 53 | with torch.no_grad(): 54 | latent_s = self.netGs(noise,only_latent=True) 55 | latent_t = self.netGt(noise,only_latent=True) 56 | # fake_s,latent_s = self.netGs(noise,return_latents=True) 57 | # fake_t,latent_t = self.netGt(noise,return_latents=True) 58 | # mix 59 | latent_mix = torch.cat([latent_s[:,:self.args.inject_index],latent_t[:,self.args.inject_index:]],1) 60 | fake,_ = self.netGt([latent_mix],input_is_latent=True) 61 | # fake = torch.cat([fake_s,fake_t,fake],-1) 62 | fake = convert_img(fake,unit=True).permute(0,2,3,1) 63 | return fake.cpu().numpy()[...,::-1] 64 | 65 | 66 | def loadparams(self,path): 67 | ckpt = torch.load(path, map_location=lambda storage, loc: storage) 68 | self.netGs.load_state_dict(ckpt['Gs'],strict=False) 69 | self.netGt.load_state_dict(ckpt['gt_ema'],strict=False) 70 | 71 | if __name__ == "__main__": 72 | args = parser.parse_args() 73 | 74 | 75 | model = Process(args.model_path) 76 | 77 | model(args.output_path) 
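The two preprocessing scripts above are meant to be run from the `utils/` directory (they resolve the repo root and `../pretrain_models` via relative paths). Below is a minimal invocation sketch; every path is a placeholder, `--LVT` is assumed to point at a local checkout of the landmark toolkit that provides the `Engine` class imported by `get_face_expression.py`, and the CCN checkpoint name is illustrative only. Note that `get_face_expression.py` derives its output directory by replacing `img` with `express` in the image folder path, so the photo folder name should contain `img`.

```shell
# score eye/lip openness for every aligned photo; --train additionally writes
# ../pretrain_models/all_express_mean.npy with the min/max statistics
python get_face_expression.py \
    --img_base /path/to/photos/img \
    --LVT /path/to/LVT \
    --pool_num 2 \
    --train

# sample style-domain images from the latent-mixed CCN generators for TTN training
python get_tcc_input.py \
    --model_path ../checkpoint/ccn/<your-ccn-checkpoint>.pth \
    --output_path ../style-data
```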
-------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | @author LeslieZhao 4 | @date 20220721 5 | ''' 6 | import torch 7 | from data.CCNLoader import CCNData 8 | from data.TTNLoader import TTNData 9 | 10 | import os 11 | import torch.distributed as dist 12 | 13 | def reduce_loss_dict(loss_dict): 14 | world_size = int(os.environ.get("WORLD_SIZE","1")) 15 | 16 | if world_size < 2: 17 | return loss_dict 18 | 19 | with torch.no_grad(): 20 | keys = [] 21 | losses = [] 22 | 23 | for k in sorted(loss_dict.keys()): 24 | keys.append(k) 25 | losses.append(loss_dict[k]) 26 | 27 | losses = torch.stack(losses, 0) 28 | dist.reduce(losses, dst=0) 29 | 30 | if dist.get_rank() == 0: 31 | losses /= world_size 32 | 33 | reduced_losses = {k: v for k, v in zip(keys, losses)} 34 | 35 | return reduced_losses 36 | 37 | def requires_grad(model, flag=True): 38 | if model is None: 39 | return 40 | for p in model.parameters(): 41 | p.requires_grad = flag 42 | def need_grad(x): 43 | x = x.detach() 44 | x.requires_grad_() 45 | return x 46 | 47 | def init_weights(model,init_type='normal', gain=0.02): 48 | def init_func(m): 49 | classname = m.__class__.__name__ 50 | if classname.find('BatchNorm2d') != -1: 51 | if hasattr(m, 'weight') and m.weight is not None: 52 | torch.nn.init.normal_(m.weight.data, 1.0, gain) 53 | if hasattr(m, 'bias') and m.bias is not None: 54 | torch.nn.init.constant_(m.bias.data, 0.0) 55 | 56 | elif hasattr(m, 'bias') and m.bias is not None: 57 | torch.nn.init.constant_(m.bias.data, 0.0) 58 | 59 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 60 | if init_type == 'normal': 61 | torch.nn.init.normal_(m.weight.data, 0.0, gain) 62 | elif init_type == 'xavier': 63 | torch.nn.init.xavier_normal_(m.weight.data, gain=gain) 64 | elif init_type == 'xavier_uniform': 65 | torch.nn.init.xavier_uniform_(m.weight.data, gain=1.0) 66 | elif init_type == 'kaiming': 67 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 68 | elif init_type == 'orthogonal': 69 | torch.nn.init.orthogonal_(m.weight.data, gain=gain) 70 | elif init_type == 'none': # uses pytorch's default init method 71 | m.reset_parameters() 72 | else: 73 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 74 | 75 | model.apply(init_func) 76 | 77 | def accumulate(model1, model2, decay=0.999,use_buffer=False): 78 | par1 = dict(model1.named_parameters()) 79 | par2 = dict(model2.named_parameters()) 80 | 81 | for k in par1.keys(): 82 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 83 | 84 | if use_buffer: 85 | for p1,p2 in zip(model1.buffers(),model2.buffers()): 86 | p1.detach().copy_(decay*p1.detach()+(1-decay)*p2.detach()) 87 | 88 | def setup_seed(seed): 89 | torch.manual_seed(seed) 90 | if torch.cuda.is_available(): 91 | torch.cuda.manual_seed_all(seed) 92 | torch.backends.cudnn.deterministic = True 93 | 94 | def get_data_loader(args): 95 | if args.model == 'ccn': 96 | train_data = CCNData(root=args.root,dist=args.dist) 97 | test_data = None 98 | if args.model == 'ttn': 99 | train_data = TTNData(dist=args.dist,eval=False, 100 | src_root=args.train_src_root, 101 | tgt_root=args.train_tgt_root, 102 | score_info=args.score_info) 103 | test_data = TTNData(dist=args.dist,eval=True, 104 | src_root=args.val_src_root, 105 | tgt_root=args.val_tgt_root, 106 | score_info=args.score_info) 107 | 108 | train_loader = 
torch.utils.data.DataLoader( 109 | train_data, 110 | batch_size=args.batch_size, 111 | num_workers=args.nDataLoaderThread, 112 | pin_memory=False, 113 | drop_last=True 114 | ) 115 | test_loader = None if test_data is None else \ 116 | torch.utils.data.DataLoader( 117 | test_data, 118 | batch_size=args.batch_size, 119 | num_workers=args.nDataLoaderThread, 120 | pin_memory=False, 121 | drop_last=True 122 | ) 123 | return train_loader,test_loader,len(train_data) 124 | 125 | 126 | 127 | def merge_args(args,params): 128 | for k,v in vars(params).items(): 129 | setattr(args,k,v) 130 | return args 131 | 132 | def convert_img(img,unit=False): 133 | 134 | img = (img + 1) * 0.5 135 | if unit: 136 | return torch.clamp(img*255+0.5,0,255) 137 | 138 | return torch.clamp(img,0,1) -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. ALL rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import time 8 | import subprocess 9 | 10 | from tensorboardX import SummaryWriter 11 | 12 | 13 | 14 | class Visualizer: 15 | def __init__(self,opt,mode='train'): 16 | self.opt = opt 17 | self.name = opt.name 18 | self.mode = mode 19 | self.train_log_dir = os.path.join(opt.checkpoint_path,"logs/%s"%mode) 20 | self.log_name = os.path.join(opt.checkpoint_path,'loss_log_%s.txt'%mode) 21 | if opt.local_rank == 0: 22 | if not os.path.exists(self.train_log_dir): 23 | os.makedirs(self.train_log_dir) 24 | 25 | self.train_writer = SummaryWriter(self.train_log_dir) 26 | 27 | self.log_file = open(self.log_name,"a") 28 | now = time.strftime("%c") 29 | self.log_file.write('================ Training Loss (%s) =================\n'%now) 30 | self.log_file.flush() 31 | 32 | 33 | # errors:dictionary of error labels and values 34 | def plot_current_errors(self,errors,step): 35 | 36 | for tag,value in errors.items(): 37 | 38 | self.train_writer.add_scalar("%s/"%self.name+tag,value,step) 39 | self.train_writer.flush() 40 | 41 | 42 | # errors: same format as |errors| of CurrentErrors 43 | def print_current_errors(self,epoch,i,errors,t): 44 | message = '(epoch: %d\t iters: %d\t time: %.5f)\t'%(epoch,i,t) 45 | for k,v in errors.items(): 46 | 47 | message += '%s: %.5f\t' %(k,v) 48 | 49 | print(message) 50 | 51 | self.log_file.write('%s\n' % message) 52 | self.log_file.flush() 53 | 54 | def display_current_results(self, visuals, step): 55 | if visuals is None: 56 | return 57 | for label, image in visuals.items(): 58 | # Write the image to a string 59 | 60 | self.train_writer.add_image("%s/"%self.name+label,image,global_step=step) 61 | 62 | def close(self): 63 | 64 | self.train_writer.close() 65 | self.log_file.close() --------------------------------------------------------------------------------
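As a quick reference, here is a minimal standalone sketch of driving the `Visualizer` above with dummy values (during training it is exercised by the trainers, as in the `evalution` method earlier). The `Namespace` fields mirror the attributes read in `__init__`; the loss values, the image tensor, and the assumption that the repository root is on `PYTHONPATH` are for illustration only.

```python
# Hypothetical standalone usage of utils/visualizer.py with dummy data.
import torch
from argparse import Namespace
from utils.visualizer import Visualizer

opt = Namespace(name='ttn', checkpoint_path='checkpoint', local_rank=0)  # fields read by __init__
vis = Visualizer(opt, mode='train')

step, epoch = 100, 1
errors = {'g_loss_surface': 0.51, 'g_loss_texture': 0.32}           # dummy scalar losses
vis.plot_current_errors(errors, step)                               # -> checkpoint/logs/train
vis.print_current_errors(epoch, step, errors, 0.25)                 # console + loss_log_train.txt
vis.display_current_results({'xg': torch.rand(3, 256, 256)}, step)  # CHW image in [0,1]
vis.close()
```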