├── .gitignore
├── LICENSE
├── ReadMe.md
├── assets
│   ├── ldh.png
│   ├── net.png
│   └── xcaq.gif
├── data
│   ├── CCNLoader.py
│   ├── DataLoader.py
│   └── TTNLoader.py
├── inference.py
├── model
│   ├── Pix2PixModule
│   │   ├── config.py
│   │   ├── loss.py
│   │   ├── model.py
│   │   └── module.py
│   └── styleganModule
│       ├── arcface.py
│       ├── config.py
│       ├── loss.py
│       ├── model.py
│       ├── op
│       │   ├── __init__.py
│       │   ├── conv2d_gradfix.py
│       │   ├── fused_act.py
│       │   ├── fused_bias_act.cpp
│       │   ├── fused_bias_act_kernel.cu
│       │   ├── upfirdn2d.cpp
│       │   ├── upfirdn2d.py
│       │   └── upfirdn2d_kernel.cu
│       └── utils.py
├── run.sh
├── train.py
├── trainer
│   ├── CCNTrainer.py
│   ├── ModelTrainer.py
│   └── TTNTrainer.py
└── utils
    ├── download_weight.sh
    ├── get_face_expression.py
    ├── get_tcc_input.py
    ├── utils.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
2 | pretrain_models
3 | *checkpoint*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 LeslieZhao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ReadMe.md:
--------------------------------------------------------------------------------
1 | # DCT-NET.Pytorch
2 | Unofficial PyTorch implementation of DCT-Net: Domain-Calibrated Translation for Portrait Stylization.
3 | You can find the official version [here](https://github.com/menyifang/DCT-Net).
4 | ![net](./assets/net.png)
5 |
6 | ## show
7 | ![ldh](./assets/ldh.png)
8 | ![xcaq](./assets/xcaq.gif)
9 |
10 | ## environment
11 | You can build your environment by following [this](https://github.com/rosinality/stylegan2-pytorch).
12 | Run ```pip install tensorboardX``` for training visualization.
13 |
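A minimal environment sketch (assumptions: a CUDA-enabled PyTorch build as required by stylegan2-pytorch; this repo does not pin package versions):
```shell
# sketch only -- follow rosinality/stylegan2-pytorch for the exact PyTorch/CUDA combination
conda create -n dctnet python=3.8 -y
conda activate dctnet
pip install torch torchvision                        # pick the wheel matching your CUDA version
pip install tensorboardX opencv-python pillow ninja  # ninja is needed to build the fused CUDA ops
```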
14 | ## how to run
15 | ### train
16 | Download the pretrained weights:
17 |
18 | ```shell
19 | cd utils
20 | bash download_weight.sh
21 | ```
22 | Follow [rosinality/stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch) to obtain the pretrained checkpoint and put `550000.pt` in `pretrain_models`.
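The checkpoint is expected at the path set in [ccn_config](./model/styleganModule/config.py) (`stylegan_path = 'pretrain_models/550000.pt'`):
```shell
mkdir -p pretrain_models
mv /path/to/550000.pt pretrain_models/550000.pt   # source path is a placeholder
```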
23 | #### CCN
24 | 1. Prepare the style pictures and align them.
25 | The image directory looks like this:
26 | style-photos/
27 | |-- 000000.png
28 | |-- 000006.png
29 | |-- 000010.png
30 | |-- 000011.png
31 | |-- 000015.png
32 | |-- 000028.png
33 | |-- 000039.png
34 | 2. Change the data path to your own in [ccn_config](./model/styleganModule/config.py#L7)
35 | 3. Train the CCN:
36 |
37 | ```shell
38 | # single gpu
39 | python train.py \
40 | --model ccn \
41 | --batch_size 16 \
42 | --checkpoint_path checkpoint \
43 | --lr 0.002 \
44 | --print_interval 100 \
45 | --save_interval 100 --dist
46 | ```
47 |
48 | ```shell
49 | # multi gpu
50 | python -m torch.distributed.launch train.py \
51 | --model ccn \
52 | --batch_size 16 \
53 | --checkpoint_path checkpoint \
54 | --lr 0.002 \
55 | --print_interval 100 \
56 | --save_interval 100
57 | ```
58 | After roughly 1000 steps you can stop training.
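To watch the losses while training, you can point TensorBoard at the run directory (a guess based on the tensorboardX dependency; the exact log location depends on utils/visualizer.py and your `--checkpoint_path`):
```shell
tensorboard --logdir checkpoint
```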
59 | #### TTN
60 | 1. Prepare the expression information.
61 | You can follow [LVT](https://github.com/LeslieZhoa/LVT) to estimate the facial landmarks (see the layout sketch after this command):
62 | ```shell
63 | cd utils
64 | # --img_base: your real image base path, like ffhq; --pool_num: number of worker processes
65 | # --LVT: the path where you put LVT; pass --train for train data, omit it for val data
66 | python get_face_expression.py \
67 |     --img_base '' --pool_num 2 \
68 |     --LVT '' --train
69 | ```
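The layout the TTN loader ends up reading (inferred from [TTNLoader](./data/TTNLoader.py) and [ttn_config](./model/Pix2PixModule/config.py); directory names are illustrative):
```
ffhq-2w/
|-- img/          # real photos, *.png
|-- express/      # one .npy per photo: [left_eye, right_eye, lip] scores
pretrain_models/all_express_mean.npy   # global min/max used to normalize the scores
```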
70 | 2. Prepare the images generated by your CCN:
71 | ```shell
72 | cd utils
73 | # --model_path: the trained CCN checkpoint; --output_path: where to save the generated images
74 | python get_tcc_input.py \
75 |     --model_path '' --output_path ''
76 | ```
77 | __Manually select roughly 5k~10k good images__
78 | 3. Change the data paths to your own in [ttn_config](./model/Pix2PixModule/config.py#L21):
79 | ```python
80 | # for example
81 | self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img'
82 | self.train_tgt_root = '/StyleTransform/DATA/select-style-gan'
83 | self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img'
84 | self.val_tgt_root = '/StyleTransform/DATA/select-style-gan'
85 | ```
86 | 4. Train the TTN:
87 | ```shell
88 | # single/multi GPU works the same way as for CCN
89 | python train.py \
90 | --model ttn \
91 | --batch_size 64 \
92 | --checkpoint_path checkpoint \
93 | --lr 2e-4 \
94 | --print_interval 100 \
95 | --save_interval 100 \
96 | --dist
97 | ```
98 | ## inference
99 | Edit [inference.py](./inference.py) to set your own TTN model path and input image path, then run
100 | ```python inference.py```
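If you would rather call it from your own code, a minimal sketch using the `Infer` class from [inference.py](./inference.py) (the input image path is a placeholder):
```python
import cv2
from inference import Infer

model = Infer('pretrain_models/final.pth')       # your trained TTN checkpoint
img = cv2.imread('your_face.png')                # placeholder input path
h, w = img.shape[:2]
img = cv2.resize(img, (w // 8 * 8, h // 8 * 8))  # inference.py resizes sides to multiples of 8
out = model.run(img)                             # stylized image, BGR, ready for cv2.imwrite
cv2.imwrite('output.png', out)
```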
101 |
102 | ## Credits
103 | SEAN model and implementation:
104 | https://github.com/ZPdesu/SEAN Copyright © 2020, ZPdesu.
105 | License https://github.com/ZPdesu/SEAN/blob/master/LICENSE.md
106 |
107 | stylegan2-pytorch model and implementation:
108 | https://github.com/rosinality/stylegan2-pytorch Copyright © 2019, rosinality.
109 | License https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE
110 |
111 | White-box-Cartoonization model and implementation:
112 | https://github.com/SystemErrorWang/White-box-Cartoonization Copyright © 2020, SystemErrorWang.
113 |
114 | White-box-Cartoonization model pytorch model and implementation:
115 | https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch Copyright © 2022, vinesmsuic.
116 | License https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch/blob/main/LICENSE
117 |
118 | arcface pytorch model and implementation:
119 | https://github.com/ronghuaiyang/arcface-pytorch Copyright © 2018, ronghuaiyang.
120 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
/assets/ldh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeslieZhoa/DCT-NET.Pytorch/bedafe430de2c92d21cea0d587a78d6cb06292e7/assets/ldh.png
--------------------------------------------------------------------------------
/assets/net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeslieZhoa/DCT-NET.Pytorch/bedafe430de2c92d21cea0d587a78d6cb06292e7/assets/net.png
--------------------------------------------------------------------------------
/assets/xcaq.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeslieZhoa/DCT-NET.Pytorch/bedafe430de2c92d21cea0d587a78d6cb06292e7/assets/xcaq.gif
--------------------------------------------------------------------------------
/data/CCNLoader.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author LeslieZhao
5 | @date 20220721
6 | '''
7 |
8 | import os
9 |
10 | from torchvision import transforms
11 | import PIL.Image as Image
12 | from data.DataLoader import DatasetBase
13 | import random
14 |
15 |
16 | class CCNData(DatasetBase):
17 | def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs):
18 | super().__init__(slice_id, slice_count,dist, **kwargs)
19 |
20 |
21 | self.transform = transforms.Compose([
22 | transforms.RandomHorizontalFlip(),
23 | transforms.ToTensor(),
24 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
25 | ])
26 |
27 | root = kwargs['root']
28 | self.paths = [os.path.join(root,f) for f in os.listdir(root)]
29 | self.length = len(self.paths)
30 | random.shuffle(self.paths)
31 |
32 | def __getitem__(self,i):
33 | idx = i % self.length
34 | img_path = self.paths[idx]
35 |
36 | with Image.open(img_path) as img:
37 | Img = self.transform(img)
38 |
39 | return Img
40 |
41 |
42 | def __len__(self):
43 | return max(100000,self.length)
44 | # return 4
45 |
46 |
--------------------------------------------------------------------------------
/data/DataLoader.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author LeslieZhao
5 | @date 20220721
6 | '''
7 |
8 |
9 | from torch.utils.data import Dataset
10 | import torch.distributed as dist
11 |
12 |
13 | class DatasetBase(Dataset):
14 | def __init__(self,slice_id=0,slice_count=1,use_dist=False,**kwargs):
15 |
16 | if use_dist:
17 | slice_id = dist.get_rank()
18 | slice_count = dist.get_world_size()
19 | self.id = slice_id
20 | self.count = slice_count
21 |
22 |
23 | def __getitem__(self,i):
24 | pass
25 |
26 |
27 |
28 |
29 | def __len__(self):
30 | return 1000
31 |
32 |
--------------------------------------------------------------------------------
/data/TTNLoader.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author LeslieZhao
5 | @date 20220721
6 | '''
7 | import os
8 | from torchvision import transforms
9 | import PIL.Image as Image
10 | from data.DataLoader import DatasetBase
11 | import random
12 | import numpy as np
13 | import torch
14 |
15 |
16 | class TTNData(DatasetBase):
17 | def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs):
18 | super().__init__(slice_id, slice_count,dist, **kwargs)
19 |
20 |
21 | self.transform = transforms.Compose([
22 | transforms.Resize([256,256]),
23 | transforms.RandomResizedCrop(256,scale=(0.8,1.2)),
24 | transforms.RandomRotation(degrees=(-90,90)),
25 | transforms.RandomHorizontalFlip(),
26 | transforms.ToTensor(),
27 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
28 | ])
29 |
30 | if kwargs['eval']:
31 | self.transform = transforms.Compose([
32 | transforms.Resize([256,256]),
33 | transforms.ToTensor(),
34 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
35 | self.length = 100
36 |
37 | src_root = kwargs['src_root']
38 | tgt_root = kwargs['tgt_root']
39 |
40 | self.src_paths = [os.path.join(src_root,f) for f in os.listdir(src_root) if f.endswith('.png')]
41 | self.tgt_paths = [os.path.join(tgt_root,f) for f in os.listdir(tgt_root) if f.endswith('.png')]
42 | self.src_length = len(self.src_paths)
43 | self.tgt_length = len(self.tgt_paths)
44 | random.shuffle(self.src_paths)
45 | random.shuffle(self.tgt_paths)
46 |
47 | self.mx_left_eye_all,\
48 | self.mn_left_eye_all,\
49 | self.mx_right_eye_all,\
50 | self.mn_right_eye_all,\
51 | self.mx_lip_all,\
52 | self.mn_lip_all = \
53 | np.load(kwargs['score_info'])
54 |
55 | def __getitem__(self,i):
56 | src_idx = i % self.src_length
57 | tgt_idx = i % self.tgt_length
58 |
59 | src_path = self.src_paths[src_idx]
60 | tgt_path = self.tgt_paths[tgt_idx]
61 | exp_path = src_path.replace('img','express')[:-3] + 'npy'
62 |
63 | with Image.open(src_path) as img:
64 | srcImg = self.transform(img)
65 |
66 | with Image.open(tgt_path) as img:
67 | tgtImg = self.transform(img)
68 |
69 | score = np.load(exp_path)
70 | score[0] = (score[0] - self.mn_left_eye_all) / (self.mx_left_eye_all - self.mn_left_eye_all)
71 | score[1] = (score[1] - self.mn_right_eye_all) / (self.mx_right_eye_all - self.mn_right_eye_all)
72 | score[2] = (score[2] - self.mn_lip_all) / (self.mx_lip_all - self.mn_lip_all)
73 | score = torch.from_numpy(score.astype(np.float32))
74 |
75 | return srcImg,tgtImg,score
76 |
77 |
78 | def __len__(self):
79 | # return max(self.src_length,self.tgt_length)
80 | if hasattr(self,'length'):
81 | return self.length
82 | else:
83 | return 10000
84 |
85 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import torch
5 | from model.Pix2PixModule.model import Generator
6 | from utils.utils import convert_img
7 |
8 | class Infer:
9 | def __init__(self,model_path):
10 | self.net = Generator(img_channels=3)
11 | self.load_checkpoint(model_path)
12 |
13 |
14 | def run(self,img):
15 | if isinstance(img,str):
16 | img = cv2.imread(img)
17 | inp = self.preprocess(img)
18 | with torch.no_grad():
19 | xg = self.net(inp)
20 | oup = self.postprocess(xg[0])
21 | return oup
22 |
23 | def load_checkpoint(self,path):
24 | ckpt = torch.load(path, map_location=lambda storage, loc: storage)
25 | self.net.load_state_dict(ckpt['netG'],strict=False)
26 | if torch.cuda.is_available():
27 | self.net.cuda()
28 | self.net.eval()
29 |
30 | def preprocess(self,img):
31 |
32 | img = (img[...,::-1] / 255.0 - 0.5) * 2
33 | img = img.transpose(2,0,1)[np.newaxis,:].astype(np.float32)
34 | img = torch.from_numpy(img)
35 | if torch.cuda.is_available():
36 | img = img.cuda()
37 | return img
38 | def postprocess(self,img):
39 | img = convert_img(img,unit=True)
40 | return img.permute(1,2,0).cpu().numpy()[...,::-1]
41 |
42 |
43 |
44 | if __name__ == "__main__":
45 |
46 | path = 'pretrain_models/final.pth'
47 | model = Infer(path)
48 |
49 | img = cv2.imread('')
50 |
51 | img_h,img_w,_ = img.shape
52 | n_h,n_w = img_h // 8 * 8,img_w // 8 * 8
53 | img = cv2.resize(img,(n_w,n_h))
54 |
55 | oup = model.run(img)
56 | cv2.imwrite('output.png',oup)
57 |
58 |
59 |
--------------------------------------------------------------------------------
/model/Pix2PixModule/config.py:
--------------------------------------------------------------------------------
1 | class Params:
2 | def __init__(self):
3 |
4 | self.name = 'Pix2Pix'
5 |
6 | self.pretrain_path = None
7 | self.vgg_model = 'pretrain_models/vgg19-dcbb9e9d.pth'
8 | self.lr = 2e-4
9 | self.beta1 = 0.5
10 | self.beta2 = 0.99
11 |
12 | self.use_exp = True
13 | self.lambda_surface = 2.0
14 | self.lambda_texture = 2.0
15 | self.lambda_content = 200
16 | self.lambda_tv = 1e4
17 |
18 | self.lambda_exp = 1.0
19 |
20 |
21 | self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img'
22 | self.train_tgt_root = '/StyleTransform/DATA/select-style-gan'
23 | self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img'
24 | self.val_tgt_root = '/StyleTransform/DATA/select-style-gan'
25 | self.score_info = 'pretrain_models/all_express_mean.npy'
26 |
27 | self.infer_batch_size = 2
28 |
29 |
--------------------------------------------------------------------------------
/model/Pix2PixModule/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .module import *
5 |
6 | from torch.autograd import Variable
7 | # Perceptual loss that uses a pretrained VGG network
8 | class VGGLoss(nn.Module):
9 | def __init__(self,model_path):
10 | super(VGGLoss, self).__init__()
11 | self.vgg = VGG19(in_channels=3,
12 | VGGtype='VGG19',
13 | init_weights=model_path,
14 | batch_norm=False, feature_mode=True)
15 | self.vgg.eval()
16 | self.criterion = nn.L1Loss()
17 |
18 |
19 | def forward(self, x, y):
20 |
21 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
22 | _, c, h, w = x_vgg.shape
23 | loss = self.criterion(x_vgg,y_vgg)
24 | return loss *255 / (c*h*w)
25 |
26 |
27 |
28 |
29 | class TVLoss(nn.Module):
30 | def __init__(self, k_size):
31 | super().__init__()
32 | self.k_size = k_size
33 |
34 | def forward(self, image):
35 | b, c, h, w = image.shape
36 | tv_h = torch.mean((image[:, :, self.k_size:, :] - image[:, :, : -self.k_size, :])**2)
37 | tv_w = torch.mean((image[:, :, :, self.k_size:] - image[:, :, :, : -self.k_size])**2)
38 | tv_loss = (tv_h + tv_w) / (3 * h * w)
39 | return tv_loss.mean()
40 |
41 |
42 |
--------------------------------------------------------------------------------
/model/Pix2PixModule/model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | class ResidualBlock(nn.Module):
6 | def __init__(self, channels, kernel_size, stride, padding, padding_mode):
7 | super().__init__()
8 | self.block = nn.Sequential(
9 | nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode),
10 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
11 | nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode),
12 | )
13 |
14 | def forward(self, x):
15 | #Elementwise Sum (ES)
16 | return x + self.block(x)
17 |
18 | class Generator(nn.Module):
19 | def __init__(self, img_channels=3, num_features=32, num_residuals=4, padding_mode="zeros"):
20 | super().__init__()
21 | self.padding_mode = padding_mode
22 |
23 | self.initial_down = nn.Sequential(
24 | #k7n32s1
25 | nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode=self.padding_mode),
26 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
27 | )
28 |
29 | #Down-convolution
30 | self.down1 = nn.Sequential(
31 | #k3n32s2
32 | nn.Conv2d(num_features, num_features, kernel_size=3, stride=2, padding=1, padding_mode=self.padding_mode),
33 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
34 |
35 | #k3n64s1
36 | nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode),
37 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
38 | )
39 |
40 | self.down2 = nn.Sequential(
41 | #k3n64s2
42 | nn.Conv2d(num_features*2, num_features*2, kernel_size=3, stride=2, padding=1, padding_mode=self.padding_mode),
43 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
44 |
45 | #k3n128s1
46 | nn.Conv2d(num_features*2, num_features*4, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode),
47 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
48 | )
49 |
50 | #Bottleneck: 4 residual blocks => 4 times [K3n128s1]
51 | self.res_blocks = nn.Sequential(
52 | *[ResidualBlock(num_features*4, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode) for _ in range(num_residuals)]
53 | )
54 |
55 | #Up-convolution
56 | self.up1 = nn.Sequential(
57 | #k3n128s1 (should be k3n64s1?)
58 | nn.Conv2d(num_features*4, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode),
59 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
60 | )
61 |
62 | self.up2 = nn.Sequential(
63 | #k3n64s1
64 | nn.Conv2d(num_features*2, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode),
65 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
66 | #k3n64s1 (should be k3n32s1?)
67 | nn.Conv2d(num_features*2, num_features, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode),
68 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
69 | )
70 |
71 | self.last = nn.Sequential(
72 | #k3n32s1
73 | nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1, padding_mode=self.padding_mode),
74 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
75 | #k7n3s1
76 | nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode=self.padding_mode)
77 | )
78 |
79 | def forward(self, x):
80 | x1 = self.initial_down(x)
81 | x2 = self.down1(x1)
82 | x = self.down2(x2)
83 | x = self.res_blocks(x)
84 | x = self.up1(x)
85 | #Resize Bilinear
86 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners = False)
87 | x = self.up2(x + x2)
88 | #Resize Bilinear
89 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners = False)
90 | x = self.last(x + x1)
91 | #TanH
92 | return torch.tanh(x)
93 |
94 |
95 |
96 | from torch.nn.utils import spectral_norm
97 |
98 | # PyTorch implementation by vinesmsuic
99 | # Referenced from official tensorflow implementation: https://github.com/SystemErrorWang/White-box-Cartoonization/blob/master/train_code/network.py
100 | # slim.convolution2d uses constant padding (zeros).
101 | # Paper used spectral_norm
102 |
103 | class Block(nn.Module):
104 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding,activate=True):
105 | super().__init__()
106 | self.sn_conv = spectral_norm(nn.Conv2d(
107 | in_channels,
108 | out_channels,
109 | kernel_size,
110 | stride,
111 | padding,
112 | padding_mode="zeros" # Author's code used slim.convolution2d, which is using SAME padding (zero padding in pytorch)
113 | ))
114 | self.activate = activate
115 | if self.activate:
116 | self.LReLU = nn.LeakyReLU(negative_slope=0.2, inplace=True)
117 |
118 | def forward(self, x):
119 | x = self.sn_conv(x)
120 | if self.activate:
121 | x = self.LReLU(x)
122 |
123 | return x
124 |
125 |
126 | class Discriminator(nn.Module):
127 | def __init__(self, in_channels=3, out_channels=1, features=[32, 64, 128]):
128 | super().__init__()
129 |
130 | self.model = nn.Sequential(
131 | #k3n32s2
132 | Block(in_channels, features[0], kernel_size=3, stride=2, padding=1),
133 | #k3n32s1
134 | Block(features[0], features[0], kernel_size=3, stride=1, padding=1),
135 |
136 | #k3n64s2
137 | Block(features[0], features[1], kernel_size=3, stride=2, padding=1),
138 | #k3n64s1
139 | Block(features[1], features[1], kernel_size=3, stride=1, padding=1),
140 |
141 | #k3n128s2
142 | Block(features[1], features[2], kernel_size=3, stride=2, padding=1),
143 | #k3n128s1
144 | Block(features[2], features[2], kernel_size=3, stride=1, padding=1),
145 |
146 | #k1n1s1
147 | Block(features[2], out_channels, kernel_size=1, stride=1, padding=0)
148 | )
149 |
150 | def forward(self, x):
151 | x = self.model(x)
152 |
153 | return x
154 |
155 |
156 | class ExpressDetector(nn.Module):
157 | def __init__(self, in_channels=3, out_channels=3, features=[32, 64, 128,512]):
158 | super().__init__()
159 |
160 | self.model = nn.Sequential(
161 | #k3n32s2
162 | Block(in_channels, features[0], kernel_size=3, stride=2, padding=1),
163 | #k3n32s1
164 | Block(features[0], features[0], kernel_size=3, stride=1, padding=1),
165 |
166 | #k3n64s2
167 | Block(features[0], features[1], kernel_size=3, stride=2, padding=1),
168 | #k3n64s1
169 | Block(features[1], features[1], kernel_size=3, stride=1, padding=1),
170 |
171 | #k3n128s2
172 | Block(features[1], features[2], kernel_size=3, stride=2, padding=1),
173 | #k3n128s1
174 | Block(features[2], features[2], kernel_size=3, stride=1, padding=1),
175 | nn.AdaptiveAvgPool2d(1),
176 | nn.Flatten(),
177 | nn.Linear(features[2],features[3]),
178 | nn.Linear(features[3],out_channels)
179 |
180 | )
181 |
182 | def forward(self, x):
183 | x = self.model(x)
184 |
185 | return x
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
--------------------------------------------------------------------------------
/model/Pix2PixModule/module.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn import functional as F
3 | import torch
4 |
5 | # refer to https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch
6 | # VGG architecter, used for the perceptual loss using a pretrained VGG network
7 | VGG_types = {
8 | "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
9 | "VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
10 | "VGG16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
11 | "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
12 | }
13 | class VGG19(torch.nn.Module):
14 | def __init__(self, in_channels=3,
15 | VGGtype="VGG19",
16 | init_weights=None,
17 | batch_norm=False,
18 | num_classes=1000,
19 | feature_mode=False,
20 | requires_grad=False):
21 | super(VGG19, self).__init__()
22 | self.in_channels = in_channels
23 | self.feature_mode = feature_mode
24 | self.batch_norm = batch_norm
25 |
26 | self.features = self.create_conv_layers(VGG_types[VGGtype])
27 |
28 | self.classifier = nn.Sequential(
29 | nn.Linear(512 * 7 * 7, 4096),
30 | nn.ReLU(),
31 | nn.Dropout(p=0.5),
32 | nn.Linear(4096, 4096),
33 | nn.ReLU(),
34 | nn.Dropout(p=0.5),
35 | nn.Linear(4096, num_classes),
36 | )
37 |
38 | if init_weights is not None:
39 | self.load_state_dict(torch.load(init_weights))
40 |
41 | if not requires_grad:
42 | for param in self.parameters():
43 | param.requires_grad = False
44 |
45 | def forward(self, x):
46 | if not self.feature_mode:
47 | x = self.features(x)
48 | x = x.view(x.size(0), -1)
49 | x = self.classifier(x)
50 |
51 | elif self.feature_mode == True and self.batch_norm == False:
52 | module_list = list(self.features.modules())
53 | #print(module_list[1:27])
54 | for layer in module_list[1:27]: # conv4_4 Feature maps
55 | x = layer(x)
56 | else:
57 | raise ValueError('Feature mode does not work with batch norm enabled. Set batch_norm=False and try again.')
58 |
59 | return x
60 |
61 | def create_conv_layers(self, architecture):
62 | layers = []
63 | in_channels = self.in_channels
64 | batch_norm = self.batch_norm
65 |
66 | for x in architecture:
67 | if type(x) == int: # Number of features
68 | out_channels = x
69 |
70 | layers += [
71 | nn.Conv2d(
72 | in_channels=in_channels,
73 | out_channels=out_channels,
74 | kernel_size=3,
75 | stride=1,
76 | padding=1,
77 | ),
78 | ]
79 |
80 | if batch_norm == True:
81 | # Back at that time Batch Norm was not invented
82 | layers += [nn.BatchNorm2d(x),nn.ReLU(),]
83 | else:
84 | layers += [nn.ReLU()]
85 |
86 | in_channels = x #update in_channel
87 |
88 | elif x == "M": # Maxpooling
89 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
90 |
91 | return nn.Sequential(*layers)
92 |
93 |
94 |
95 | # refer to https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch
96 | def box_filter(x, r):  # mean filter over a (2r+1)x(2r+1) window, implemented as a depthwise conv
97 | channel = x.shape[1] # Batch, Channel, H, W
98 | kernel_size = (2*r+1)
99 | weight = 1.0/(kernel_size**2)
100 | box_kernel = weight*torch.ones((channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device)
101 | output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel) #tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
102 |
103 | return output
104 |
105 |
106 | def guided_filter(x, y, r, eps=1e-2):  # edge-preserving smoothing of y guided by x; White-box-Cartoonization uses it to extract the smooth "surface" representation
107 | # Batch, Channel, H, W
108 | _, _, H, W = x.shape
109 |
110 | N = box_filter(torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), r)
111 |
112 | mean_x = box_filter(x, r) / N
113 | mean_y = box_filter(y, r) / N
114 | cov_xy = box_filter(x * y, r) / N - mean_x * mean_y
115 | var_x = box_filter(x * x, r) / N - mean_x * mean_x
116 |
117 | A = cov_xy / (var_x + eps)
118 | b = mean_y - A * mean_x
119 |
120 | mean_A = box_filter(A, r) / N
121 | mean_b = box_filter(b, r) / N
122 |
123 | output = mean_A * x + mean_b
124 | return output
125 |
126 | # refer to https://github.com/SystemErrorWang/White-box-Cartoonization
127 | def color_shift(image, mode='uniform'):  # randomly-weighted grayscale projection; the "texture" representation in White-box-Cartoonization
128 | device = image.device
129 | r1, b1, g1 = torch.split(image, 1, dim=1)
130 |
131 | if mode == 'normal':
132 | b_weight = torch.normal(mean=0.114, std=0.1,size=[1]).to(device)
133 | g_weight = torch.normal(mean=0.587, std=0.1,size=[1]).to(device)
134 | r_weight = torch.normal(mean=0.299, std=0.1,size=[1]).to(device)
135 | elif mode == 'uniform':
136 |
137 | b_weight = torch.FloatTensor(1).uniform_(0.014,0.214).to(device)
138 | g_weight = torch.FloatTensor(1).uniform_(0.487, 0.687).to(device)
139 | r_weight = torch.FloatTensor(1).uniform_(0.199, 0.399).to(device)
140 | output1 = (b_weight*b1+g_weight*g1+r_weight*r1)/(b_weight+g_weight+r_weight)
141 |
142 | return output1
143 |
144 |
145 |
--------------------------------------------------------------------------------
/model/styleganModule/arcface.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
2 | import torch
3 | from collections import namedtuple
4 |
5 |
6 | # ------------------------arcface---------------------------------------
7 | class Backbone(Module):
8 | def __init__(self, num_layers, drop_ratio, mode='ir'):
9 | super(Backbone, self).__init__()
10 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
11 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
12 | blocks = get_blocks(num_layers)
13 | if mode == 'ir':
14 | unit_module = bottleneck_IR
15 | elif mode == 'ir_se':
16 | unit_module = bottleneck_IR_SE
17 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1 ,bias=False),
18 | BatchNorm2d(64),
19 | PReLU(64))
20 | self.output_layer = Sequential(BatchNorm2d(512),
21 | Dropout(drop_ratio),
22 | Flatten(),
23 | Linear(512 * 7 * 7, 512),
24 | BatchNorm1d(512))
25 | # )
26 | modules = []
27 | for block in blocks:
28 | for bottleneck in block:
29 | modules.append(
30 | unit_module(bottleneck.in_channel,
31 | bottleneck.depth,
32 | bottleneck.stride))
33 | self.body = Sequential(*modules)
34 |
35 | def forward(self,x):
36 |
37 | feats = []
38 | x = self.input_layer(x)
39 | for m in self.body.children():
40 | x = m(x)
41 | feats.append(x)
42 | # x = self.body(x)
43 | x = self.output_layer(x)
44 | return l2_norm(x), feats
45 |
46 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
47 | '''A named tuple describing a ResNet block.'''
48 |
49 | class bottleneck_IR(Module):
50 | def __init__(self, in_channel, depth, stride):
51 | super(bottleneck_IR, self).__init__()
52 | if in_channel == depth:
53 | self.shortcut_layer = MaxPool2d(1, stride)
54 | else:
55 | self.shortcut_layer = Sequential(
56 | Conv2d(in_channel, depth, (1, 1), stride ,bias=False), BatchNorm2d(depth))
57 | self.res_layer = Sequential(
58 | BatchNorm2d(in_channel),
59 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1 ,bias=False), PReLU(depth),
60 | Conv2d(depth, depth, (3, 3), stride, 1 ,bias=False), BatchNorm2d(depth))
61 |
62 | def forward(self, x):
63 | shortcut = self.shortcut_layer(x)
64 | res = self.res_layer(x)
65 | return res + shortcut
66 |
67 | class bottleneck_IR_SE(Module):
68 | def __init__(self, in_channel, depth, stride):
69 | super(bottleneck_IR_SE, self).__init__()
70 | if in_channel == depth:
71 | self.shortcut_layer = MaxPool2d(1, stride)
72 | else:
73 | self.shortcut_layer = Sequential(
74 | Conv2d(in_channel, depth, (1, 1), stride ,bias=False),
75 | BatchNorm2d(depth))
76 | self.res_layer = Sequential(
77 | BatchNorm2d(in_channel),
78 | Conv2d(in_channel, depth, (3,3), (1,1),1 ,bias=False),
79 | PReLU(depth),
80 | Conv2d(depth, depth, (3,3), stride, 1 ,bias=False),
81 | BatchNorm2d(depth),
82 | SEModule(depth,16)
83 | )
84 | def forward(self,x):
85 | shortcut = self.shortcut_layer(x)
86 | res = self.res_layer(x)
87 | return res + shortcut
88 |
89 | def l2_norm(input,axis=1):
90 | norm = torch.norm(input,2,axis,True)
91 | output = torch.div(input, norm)
92 | return output
93 |
94 | class SEModule(Module):
95 | def __init__(self, channels, reduction):
96 | super(SEModule, self).__init__()
97 | self.avg_pool = AdaptiveAvgPool2d(1)
98 | self.fc1 = Conv2d(
99 | channels, channels // reduction, kernel_size=1, padding=0 ,bias=False)
100 | self.relu = ReLU(inplace=True)
101 | self.fc2 = Conv2d(
102 | channels // reduction, channels, kernel_size=1, padding=0 ,bias=False)
103 | self.sigmoid = Sigmoid()
104 |
105 | def forward(self, x):
106 | module_input = x
107 | x = self.avg_pool(x)
108 | x = self.fc1(x)
109 | x = self.relu(x)
110 | x = self.fc2(x)
111 | x = self.sigmoid(x)
112 | return module_input * x
113 |
114 | class Flatten(Module):
115 | def forward(self, input):
116 | return input.view(input.size(0), -1)
117 |
118 | def get_block(in_channel, depth, num_units, stride = 2):
119 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units-1)]
120 |
121 | def get_blocks(num_layers):
122 | if num_layers == 50:
123 | blocks = [
124 | get_block(in_channel=64, depth=64, num_units = 3),
125 | get_block(in_channel=64, depth=128, num_units=4),
126 | get_block(in_channel=128, depth=256, num_units=14),
127 | get_block(in_channel=256, depth=512, num_units=3)
128 | ]
129 | elif num_layers == 100:
130 | blocks = [
131 | get_block(in_channel=64, depth=64, num_units=3),
132 | get_block(in_channel=64, depth=128, num_units=13),
133 | get_block(in_channel=128, depth=256, num_units=30),
134 | get_block(in_channel=256, depth=512, num_units=3)
135 | ]
136 | elif num_layers == 152:
137 | blocks = [
138 | get_block(in_channel=64, depth=64, num_units=3),
139 | get_block(in_channel=64, depth=128, num_units=8),
140 | get_block(in_channel=128, depth=256, num_units=36),
141 | get_block(in_channel=256, depth=512, num_units=3)
142 | ]
143 | return blocks
144 |
145 |
--------------------------------------------------------------------------------
/model/styleganModule/config.py:
--------------------------------------------------------------------------------
1 | class Params:
2 | def __init__(self):
3 |
4 | self.name = 'StyleGAN'
5 | self.size = 256
6 | self.stylegan_path = 'pretrain_models/550000.pt'
7 |         self.root = ''  # path to your aligned style images (CCN step 1 in the ReadMe)
8 | self.id_model = 'pretrain_models/model_ir_se50.pth'
9 | self.g_reg_every = 4
10 | self.d_reg_every = 16
11 | self.D_steps_pre_G = 1
12 | self.latent = 512
13 | self.n_mlp =8
14 | self.channel_multiplier =2
15 | self.size =256
16 | self.mixing =0.9
17 | self.inject_index =4
18 | self.n_sample =8
19 |
20 | self.lambda_gan =1.0
21 | self.lambda_id =1.0
22 | self.path_regularize =2.0
23 | self.path_batch_shrink = 2
24 | self.r1 =10.0
25 |
26 | self.interval_steps = 100
27 | self.interval_train = False
28 |
29 | self.infer_batch_size = 1
30 | self.mx_gen_iters = 20000
31 |
--------------------------------------------------------------------------------
/model/styleganModule/loss.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from torch import autograd
3 | from torch import nn
4 | import torch
5 | from .arcface import Backbone
6 | import math
7 | from .op import conv2d_gradfix
8 |
9 | class IDLoss(nn.Module):
10 | def __init__(self,pretrain_model, requires_grad=False):
11 | super(IDLoss, self).__init__()
12 | self.idModel = Backbone(50,0.6,'ir_se')
13 | self.idModel.load_state_dict(torch.load(pretrain_model),strict=False)
14 | self.idModel.eval()
15 | self.criterion = nn.CosineSimilarity(dim=1,eps=1e-6)
16 | self.id_size = 112
17 | if not requires_grad:
18 | for param in self.parameters():
19 | param.requires_grad = False
20 |
21 |
22 |
23 | def forward(self, x, y):
24 | x_id, _ = self.idModel(F.interpolate(x[:,:,28:228,28:228],[self.id_size, self.id_size], mode='bilinear'))
25 | y_id,_ = self.idModel(F.interpolate(y[:,:,28:228,28:228],
26 | [self.id_size, self.id_size], mode='bilinear'))
27 | loss = 1 - self.criterion(x_id,y_id)
28 | return loss.mean()
29 |
30 |
31 | def g_nonsaturating_loss(fake_pred):
32 | loss = F.softplus(-fake_pred).mean()
33 |
34 | return loss
35 |
36 |
37 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
38 | noise = torch.randn_like(fake_img) / math.sqrt(
39 | fake_img.shape[2] * fake_img.shape[3]
40 | )
41 | grad, = autograd.grad(
42 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
43 | )
44 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
45 |
46 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
47 |
48 | path_penalty = (path_lengths - path_mean).pow(2).mean()
49 |
50 | return path_penalty, path_mean.detach(), path_lengths
51 |
52 | def d_r1_loss(real_pred, real_img):
53 | with conv2d_gradfix.no_weight_gradients():
54 | grad_real, = autograd.grad(
55 | outputs=real_pred.sum(), inputs=real_img, create_graph=True
56 | )
57 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
58 |
59 | return grad_penalty
60 |
61 | def d_logistic_loss(real_pred, fake_pred):
62 | real_loss = F.softplus(-real_pred)
63 | fake_loss = F.softplus(fake_pred)
64 |
65 | return real_loss.mean() + fake_loss.mean()
--------------------------------------------------------------------------------
/model/styleganModule/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
10 |
11 |
12 | class PixelNorm(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def forward(self, input):
17 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
18 |
19 |
20 | def make_kernel(k):
21 | k = torch.tensor(k, dtype=torch.float32)
22 |
23 | if k.ndim == 1:
24 | k = k[None, :] * k[:, None]
25 |
26 | k /= k.sum()
27 |
28 | return k
29 |
30 |
31 | class Upsample(nn.Module):
32 | def __init__(self, kernel, factor=2):
33 | super().__init__()
34 |
35 | self.factor = factor
36 | kernel = make_kernel(kernel) * (factor ** 2)
37 | self.register_buffer("kernel", kernel)
38 |
39 | p = kernel.shape[0] - factor
40 |
41 | pad0 = (p + 1) // 2 + factor - 1
42 | pad1 = p // 2
43 |
44 | self.pad = (pad0, pad1)
45 |
46 | def forward(self, input):
47 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
48 |
49 | return out
50 |
51 |
52 | class Downsample(nn.Module):
53 | def __init__(self, kernel, factor=2):
54 | super().__init__()
55 |
56 | self.factor = factor
57 | kernel = make_kernel(kernel)
58 | self.register_buffer("kernel", kernel)
59 |
60 | p = kernel.shape[0] - factor
61 |
62 | pad0 = (p + 1) // 2
63 | pad1 = p // 2
64 |
65 | self.pad = (pad0, pad1)
66 |
67 | def forward(self, input):
68 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
69 |
70 | return out
71 |
72 |
73 | class Blur(nn.Module):
74 | def __init__(self, kernel, pad, upsample_factor=1):
75 | super().__init__()
76 |
77 | kernel = make_kernel(kernel)
78 |
79 | if upsample_factor > 1:
80 | kernel = kernel * (upsample_factor ** 2)
81 |
82 | self.register_buffer("kernel", kernel)
83 |
84 | self.pad = pad
85 |
86 | def forward(self, input):
87 | out = upfirdn2d(input, self.kernel, pad=self.pad)
88 |
89 | return out
90 |
91 |
92 | class EqualConv2d(nn.Module):
93 | def __init__(
94 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
95 | ):
96 | super().__init__()
97 |
98 | self.weight = nn.Parameter(
99 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
100 | )
101 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
102 |
103 | self.stride = stride
104 | self.padding = padding
105 |
106 | if bias:
107 | self.bias = nn.Parameter(torch.zeros(out_channel))
108 |
109 | else:
110 | self.bias = None
111 |
112 | def forward(self, input):
113 | out = conv2d_gradfix.conv2d(
114 | input,
115 | self.weight * self.scale,
116 | bias=self.bias,
117 | stride=self.stride,
118 | padding=self.padding,
119 | )
120 |
121 | return out
122 |
123 | def __repr__(self):
124 | return (
125 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
126 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
127 | )
128 |
129 |
130 | class EqualLinear(nn.Module):
131 | def __init__(
132 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
133 | ):
134 | super().__init__()
135 |
136 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
137 |
138 | if bias:
139 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
140 |
141 | else:
142 | self.bias = None
143 |
144 | self.activation = activation
145 |
146 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
147 | self.lr_mul = lr_mul
148 |
149 | def forward(self, input):
150 | if self.activation:
151 | out = F.linear(input, self.weight * self.scale)
152 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
153 |
154 | else:
155 | out = F.linear(
156 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
157 | )
158 |
159 | return out
160 |
161 | def __repr__(self):
162 | return (
163 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
164 | )
165 |
166 |
167 | class ModulatedConv2d(nn.Module):
168 | def __init__(
169 | self,
170 | in_channel,
171 | out_channel,
172 | kernel_size,
173 | style_dim,
174 | demodulate=True,
175 | upsample=False,
176 | downsample=False,
177 | blur_kernel=[1, 3, 3, 1],
178 | fused=True,
179 | ):
180 | super().__init__()
181 |
182 | self.eps = 1e-8
183 | self.kernel_size = kernel_size
184 | self.in_channel = in_channel
185 | self.out_channel = out_channel
186 | self.upsample = upsample
187 | self.downsample = downsample
188 |
189 | if upsample:
190 | factor = 2
191 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
192 | pad0 = (p + 1) // 2 + factor - 1
193 | pad1 = p // 2 + 1
194 |
195 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
196 |
197 | if downsample:
198 | factor = 2
199 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
200 | pad0 = (p + 1) // 2
201 | pad1 = p // 2
202 |
203 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
204 |
205 | fan_in = in_channel * kernel_size ** 2
206 | self.scale = 1 / math.sqrt(fan_in)
207 | self.padding = kernel_size // 2
208 |
209 | self.weight = nn.Parameter(
210 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
211 | )
212 |
213 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
214 |
215 | self.demodulate = demodulate
216 | self.fused = fused
217 |
218 | def __repr__(self):
219 | return (
220 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
221 | f"upsample={self.upsample}, downsample={self.downsample})"
222 | )
223 |
224 | def forward(self, input, style):
225 | batch, in_channel, height, width = input.shape
226 |
227 | if not self.fused:
228 | weight = self.scale * self.weight.squeeze(0)
229 | style = self.modulation(style)
230 |
231 | if self.demodulate:
232 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
233 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
234 |
235 | input = input * style.reshape(batch, in_channel, 1, 1)
236 |
237 | if self.upsample:
238 | weight = weight.transpose(0, 1)
239 | out = conv2d_gradfix.conv_transpose2d(
240 | input, weight, padding=0, stride=2
241 | )
242 | out = self.blur(out)
243 |
244 | elif self.downsample:
245 | input = self.blur(input)
246 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
247 |
248 | else:
249 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
250 |
251 | if self.demodulate:
252 | out = out * dcoefs.view(batch, -1, 1, 1)
253 |
254 | return out
255 |
256 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
257 | weight = self.scale * self.weight * style
258 |
259 | if self.demodulate:
260 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
261 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
262 |
263 | weight = weight.view(
264 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
265 | )
266 |
267 | if self.upsample:
268 | input = input.view(1, batch * in_channel, height, width)
269 | weight = weight.view(
270 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
271 | )
272 | weight = weight.transpose(1, 2).reshape(
273 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
274 | )
275 | out = conv2d_gradfix.conv_transpose2d(
276 | input, weight, padding=0, stride=2, groups=batch
277 | )
278 | _, _, height, width = out.shape
279 | out = out.view(batch, self.out_channel, height, width)
280 | out = self.blur(out)
281 |
282 | elif self.downsample:
283 | input = self.blur(input)
284 | _, _, height, width = input.shape
285 | input = input.view(1, batch * in_channel, height, width)
286 | out = conv2d_gradfix.conv2d(
287 | input, weight, padding=0, stride=2, groups=batch
288 | )
289 | _, _, height, width = out.shape
290 | out = out.view(batch, self.out_channel, height, width)
291 |
292 | else:
293 | input = input.view(1, batch * in_channel, height, width)
294 | out = conv2d_gradfix.conv2d(
295 | input, weight, padding=self.padding, groups=batch
296 | )
297 | _, _, height, width = out.shape
298 | out = out.view(batch, self.out_channel, height, width)
299 |
300 | return out
301 |
302 |
303 | class NoiseInjection(nn.Module):
304 | def __init__(self):
305 | super().__init__()
306 |
307 | self.weight = nn.Parameter(torch.zeros(1))
308 |
309 | def forward(self, image, noise=None):
310 | if noise is None:
311 | batch, _, height, width = image.shape
312 | noise = image.new_empty(batch, 1, height, width).normal_()
313 |
314 | return image + self.weight * noise
315 |
316 |
317 | class ConstantInput(nn.Module):
318 | def __init__(self, channel, size=4):
319 | super().__init__()
320 |
321 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
322 |
323 | def forward(self, input):
324 | batch = input.shape[0]
325 | out = self.input.repeat(batch, 1, 1, 1)
326 |
327 | return out
328 |
329 |
330 | class StyledConv(nn.Module):
331 | def __init__(
332 | self,
333 | in_channel,
334 | out_channel,
335 | kernel_size,
336 | style_dim,
337 | upsample=False,
338 | blur_kernel=[1, 3, 3, 1],
339 | demodulate=True,
340 | ):
341 | super().__init__()
342 |
343 | self.conv = ModulatedConv2d(
344 | in_channel,
345 | out_channel,
346 | kernel_size,
347 | style_dim,
348 | upsample=upsample,
349 | blur_kernel=blur_kernel,
350 | demodulate=demodulate,
351 | )
352 |
353 | self.noise = NoiseInjection()
354 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
355 | # self.activate = ScaledLeakyReLU(0.2)
356 | self.activate = FusedLeakyReLU(out_channel)
357 |
358 | def forward(self, input, style, noise=None):
359 | out = self.conv(input, style)
360 | out = self.noise(out, noise=noise)
361 | # out = out + self.bias
362 | out = self.activate(out)
363 |
364 | return out
365 |
366 |
367 | class ToRGB(nn.Module):
368 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
369 | super().__init__()
370 |
371 | if upsample:
372 | self.upsample = Upsample(blur_kernel)
373 |
374 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
375 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
376 |
377 | def forward(self, input, style, skip=None):
378 | out = self.conv(input, style)
379 | out = out + self.bias
380 |
381 | if skip is not None:
382 | skip = self.upsample(skip)
383 |
384 | out = out + skip
385 |
386 | return out
387 |
388 |
389 | class Generator(nn.Module):
390 | def __init__(
391 | self,
392 | size,
393 | style_dim,
394 | n_mlp,
395 | channel_multiplier=2,
396 | blur_kernel=[1, 3, 3, 1],
397 | lr_mlp=0.01,
398 | ):
399 | super().__init__()
400 |
401 | self.size = size
402 |
403 | self.style_dim = style_dim
404 |
405 | layers = [PixelNorm()]
406 |
407 | for i in range(n_mlp):
408 | layers.append(
409 | EqualLinear(
410 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
411 | )
412 | )
413 |
414 | self.style = nn.Sequential(*layers)
415 |
416 | self.channels = {
417 | 4: 512,
418 | 8: 512,
419 | 16: 512,
420 | 32: 512,
421 | 64: 256 * channel_multiplier,
422 | 128: 128 * channel_multiplier,
423 | 256: 64 * channel_multiplier,
424 | 512: 32 * channel_multiplier,
425 | 1024: 16 * channel_multiplier,
426 | }
427 |
428 | self.input = ConstantInput(self.channels[4])
429 | self.conv1 = StyledConv(
430 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
431 | )
432 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
433 |
434 | self.log_size = int(math.log(size, 2))
435 | self.num_layers = (self.log_size - 2) * 2 + 1
436 |
437 | self.convs = nn.ModuleList()
438 | self.upsamples = nn.ModuleList()
439 | self.to_rgbs = nn.ModuleList()
440 | self.noises = nn.Module()
441 |
442 | in_channel = self.channels[4]
443 |
444 | for layer_idx in range(self.num_layers):
445 | res = (layer_idx + 5) // 2
446 | shape = [1, 1, 2 ** res, 2 ** res]
447 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
448 |
449 | for i in range(3, self.log_size + 1):
450 | out_channel = self.channels[2 ** i]
451 |
452 | self.convs.append(
453 | StyledConv(
454 | in_channel,
455 | out_channel,
456 | 3,
457 | style_dim,
458 | upsample=True,
459 | blur_kernel=blur_kernel,
460 | )
461 | )
462 |
463 | self.convs.append(
464 | StyledConv(
465 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
466 | )
467 | )
468 |
469 | self.to_rgbs.append(ToRGB(out_channel, style_dim))
470 |
471 | in_channel = out_channel
472 |
473 | self.n_latent = self.log_size * 2 - 2
474 |
475 | def make_noise(self):
476 | device = self.input.input.device
477 |
478 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
479 |
480 | for i in range(3, self.log_size + 1):
481 | for _ in range(2):
482 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
483 |
484 | return noises
485 |
486 | def mean_latent(self, n_latent):
487 | latent_in = torch.randn(
488 | n_latent, self.style_dim, device=self.input.input.device
489 | )
490 | latent = self.style(latent_in).mean(0, keepdim=True)
491 |
492 | return latent
493 |
494 | def get_latent(self, input):
495 | return self.style(input)
496 |
497 | def forward(
498 | self,
499 | styles,
500 | return_latents=False,
501 | inject_index=None,
502 | truncation=1,
503 | truncation_latent=None,
504 | input_is_latent=False,
505 | noise=None,
506 | randomize_noise=True,
507 | only_latent=False
508 | ):
509 | if not input_is_latent:
510 | styles = [self.style(s) for s in styles]
511 |
512 | if noise is None:
513 | if randomize_noise:
514 | noise = [None] * self.num_layers
515 | else:
516 | noise = [
517 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
518 | ]
519 |
520 | if truncation < 1:
521 | style_t = []
522 |
523 | for style in styles:
524 | style_t.append(
525 | truncation_latent + truncation * (style - truncation_latent)
526 | )
527 |
528 | styles = style_t
529 |
530 | if len(styles) < 2:
531 | inject_index = self.n_latent
532 |
533 | if styles[0].ndim < 3:
534 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
535 |
536 | else:
537 | latent = styles[0]
538 |
539 | else:
540 | if inject_index is None:
541 | inject_index = random.randint(1, self.n_latent - 1)
542 |
543 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
544 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
545 |
546 | latent = torch.cat([latent, latent2], 1)
547 | if only_latent:
548 | return latent
549 |
550 | out = self.input(latent)
551 | out = self.conv1(out, latent[:, 0], noise=noise[0])
552 |
553 | skip = self.to_rgb1(out, latent[:, 1])
554 |
555 | i = 1
556 | for conv1, conv2, noise1, noise2, to_rgb in zip(
557 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
558 | ):
559 | out = conv1(out, latent[:, i], noise=noise1)
560 | out = conv2(out, latent[:, i + 1], noise=noise2)
561 | skip = to_rgb(out, latent[:, i + 2], skip)
562 |
563 | i += 2
564 |
565 | image = skip
566 |
567 | if return_latents:
568 | return image, latent
569 |
570 | else:
571 | return image, None
572 |
573 |
574 | class ConvLayer(nn.Sequential):
575 | def __init__(
576 | self,
577 | in_channel,
578 | out_channel,
579 | kernel_size,
580 | downsample=False,
581 | blur_kernel=[1, 3, 3, 1],
582 | bias=True,
583 | activate=True,
584 | ):
585 | layers = []
586 |
587 | if downsample:
588 | factor = 2
589 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
590 | pad0 = (p + 1) // 2
591 | pad1 = p // 2
592 |
593 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
594 |
595 | stride = 2
596 | self.padding = 0
597 |
598 | else:
599 | stride = 1
600 | self.padding = kernel_size // 2
601 |
602 | layers.append(
603 | EqualConv2d(
604 | in_channel,
605 | out_channel,
606 | kernel_size,
607 | padding=self.padding,
608 | stride=stride,
609 | bias=bias and not activate,
610 | )
611 | )
612 |
613 | if activate:
614 | layers.append(FusedLeakyReLU(out_channel, bias=bias))
615 |
616 | super().__init__(*layers)
617 |
618 |
619 | class ResBlock(nn.Module):
620 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
621 | super().__init__()
622 |
623 | self.conv1 = ConvLayer(in_channel, in_channel, 3)
624 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
625 |
626 | self.skip = ConvLayer(
627 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False
628 | )
629 |
630 | def forward(self, input):
631 | out = self.conv1(input)
632 | out = self.conv2(out)
633 |
634 | skip = self.skip(input)
635 | out = (out + skip) / math.sqrt(2)
636 |
637 | return out
638 |
639 |
640 | class Discriminator(nn.Module):
641 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
642 | super().__init__()
643 |
644 | channels = {
645 | 4: 512,
646 | 8: 512,
647 | 16: 512,
648 | 32: 512,
649 | 64: 256 * channel_multiplier,
650 | 128: 128 * channel_multiplier,
651 | 256: 64 * channel_multiplier,
652 | 512: 32 * channel_multiplier,
653 | 1024: 16 * channel_multiplier,
654 | }
655 |
656 | convs = [ConvLayer(3, channels[size], 1)]
657 |
658 | log_size = int(math.log(size, 2))
659 |
660 | in_channel = channels[size]
661 |
662 | for i in range(log_size, 2, -1):
663 | out_channel = channels[2 ** (i - 1)]
664 |
665 | convs.append(ResBlock(in_channel, out_channel, blur_kernel))
666 |
667 | in_channel = out_channel
668 |
669 | self.convs = nn.Sequential(*convs)
670 |
671 | self.stddev_group = 4
672 | self.stddev_feat = 1
673 |
674 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
675 | self.final_linear = nn.Sequential(
676 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
677 | EqualLinear(channels[4], 1),
678 | )
679 |
680 | def forward(self, input):
681 |
682 | out = self.convs(input)
683 |
684 | batch, channel, height, width = out.shape
685 | group = min(batch, self.stddev_group)
686 | stddev = out.view(
687 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
688 | )
689 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
690 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
691 | stddev = stddev.repeat(group, 1, height, width)
692 | out = torch.cat([out, stddev], 1)
693 |
694 | out = self.final_conv(out)
695 |
696 | out = out.view(batch, -1)
697 | out = self.final_linear(out)
698 |
699 | return out
700 |
701 |
--------------------------------------------------------------------------------
/model/styleganModule/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/model/styleganModule/op/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import warnings
3 |
4 | import torch
5 | from torch import autograd
6 | from torch.nn import functional as F
7 |
8 | enabled = True
9 | weight_gradients_disabled = False
10 |
11 |
12 | @contextlib.contextmanager
13 | def no_weight_gradients():
14 | global weight_gradients_disabled
15 |
16 | old = weight_gradients_disabled
17 | weight_gradients_disabled = True
18 | yield
19 | weight_gradients_disabled = old
20 |
21 |
22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23 | if could_use_op(input):
24 | return conv2d_gradfix(
25 | transpose=False,
26 | weight_shape=weight.shape,
27 | stride=stride,
28 | padding=padding,
29 | output_padding=0,
30 | dilation=dilation,
31 | groups=groups,
32 | ).apply(input, weight, bias)
33 |
34 | return F.conv2d(
35 | input=input,
36 | weight=weight,
37 | bias=bias,
38 | stride=stride,
39 | padding=padding,
40 | dilation=dilation,
41 | groups=groups,
42 | )
43 |
44 |
45 | def conv_transpose2d(
46 | input,
47 | weight,
48 | bias=None,
49 | stride=1,
50 | padding=0,
51 | output_padding=0,
52 | groups=1,
53 | dilation=1,
54 | ):
55 | if could_use_op(input):
56 | return conv2d_gradfix(
57 | transpose=True,
58 | weight_shape=weight.shape,
59 | stride=stride,
60 | padding=padding,
61 | output_padding=output_padding,
62 | groups=groups,
63 | dilation=dilation,
64 | ).apply(input, weight, bias)
65 |
66 | return F.conv_transpose2d(
67 | input=input,
68 | weight=weight,
69 | bias=bias,
70 | stride=stride,
71 | padding=padding,
72 | output_padding=output_padding,
73 | dilation=dilation,
74 | groups=groups,
75 | )
76 |
77 |
78 | def could_use_op(input):
79 | if (not enabled) or (not torch.backends.cudnn.enabled):
80 | return False
81 |
82 | if input.device.type != "cuda":
83 | return False
84 |
85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86 | return True
87 |
88 | warnings.warn(
89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90 | )
91 |
92 | return False
93 |
94 |
95 | def ensure_tuple(xs, ndim):
96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97 |
98 | return xs
99 |
100 |
101 | conv2d_gradfix_cache = dict()
102 |
103 |
104 | def conv2d_gradfix(
105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups
106 | ):
107 | ndim = 2
108 | weight_shape = tuple(weight_shape)
109 | stride = ensure_tuple(stride, ndim)
110 | padding = ensure_tuple(padding, ndim)
111 | output_padding = ensure_tuple(output_padding, ndim)
112 | dilation = ensure_tuple(dilation, ndim)
113 |
114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115 | if key in conv2d_gradfix_cache:
116 | return conv2d_gradfix_cache[key]
117 |
118 | common_kwargs = dict(
119 | stride=stride, padding=padding, dilation=dilation, groups=groups
120 | )
121 |
122 | def calc_output_padding(input_shape, output_shape):
123 | if transpose:
124 | return [0, 0]
125 |
126 | return [
127 | input_shape[i + 2]
128 | - (output_shape[i + 2] - 1) * stride[i]
129 | - (1 - 2 * padding[i])
130 | - dilation[i] * (weight_shape[i + 2] - 1)
131 | for i in range(ndim)
132 | ]
133 |
134 | class Conv2d(autograd.Function):
135 | @staticmethod
136 | def forward(ctx, input, weight, bias):
137 | if not transpose:
138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139 |
140 | else:
141 | out = F.conv_transpose2d(
142 | input=input,
143 | weight=weight,
144 | bias=bias,
145 | output_padding=output_padding,
146 | **common_kwargs,
147 | )
148 |
149 | ctx.save_for_backward(input, weight)
150 |
151 | return out
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | input, weight = ctx.saved_tensors
156 | grad_input, grad_weight, grad_bias = None, None, None
157 |
158 | if ctx.needs_input_grad[0]:
159 | p = calc_output_padding(
160 | input_shape=input.shape, output_shape=grad_output.shape
161 | )
162 | grad_input = conv2d_gradfix(
163 | transpose=(not transpose),
164 | weight_shape=weight_shape,
165 | output_padding=p,
166 | **common_kwargs,
167 | ).apply(grad_output, weight, None)
168 |
169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
171 |
172 | if ctx.needs_input_grad[2]:
173 | grad_bias = grad_output.sum((0, 2, 3))
174 |
175 | return grad_input, grad_weight, grad_bias
176 |
177 | class Conv2dGradWeight(autograd.Function):
178 | @staticmethod
179 | def forward(ctx, grad_output, input):
180 | op = torch._C._jit_get_operation(
181 | "aten::cudnn_convolution_backward_weight"
182 | if not transpose
183 | else "aten::cudnn_convolution_transpose_backward_weight"
184 | )
185 | flags = [
186 | torch.backends.cudnn.benchmark,
187 | torch.backends.cudnn.deterministic,
188 | torch.backends.cudnn.allow_tf32,
189 | ]
190 | grad_weight = op(
191 | weight_shape,
192 | grad_output,
193 | input,
194 | padding,
195 | stride,
196 | dilation,
197 | groups,
198 | *flags,
199 | )
200 | ctx.save_for_backward(grad_output, input)
201 |
202 | return grad_weight
203 |
204 | @staticmethod
205 | def backward(ctx, grad_grad_weight):
206 | grad_output, input = ctx.saved_tensors
207 | grad_grad_output, grad_grad_input = None, None
208 |
209 | if ctx.needs_input_grad[0]:
210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211 |
212 | if ctx.needs_input_grad[1]:
213 | p = calc_output_padding(
214 | input_shape=input.shape, output_shape=grad_output.shape
215 | )
216 | grad_grad_input = conv2d_gradfix(
217 | transpose=(not transpose),
218 | weight_shape=weight_shape,
219 | output_padding=p,
220 | **common_kwargs,
221 | ).apply(grad_output, grad_grad_weight, None)
222 |
223 | return grad_grad_output, grad_grad_input
224 |
225 | conv2d_gradfix_cache[key] = Conv2d
226 |
227 | return Conv2d
228 |
--------------------------------------------------------------------------------
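
A minimal usage sketch (not part of the repo; it assumes the `model.styleganModule.op` package imports resolve and the CUDA extensions in `op/` can be built): `conv2d_gradfix.conv2d` is a drop-in replacement for `F.conv2d`, and `no_weight_gradients()` is wrapped around extra backward passes (such as path-length regularization) so they do not accumulate weight gradients.

```python
# Illustrative only; on CPU or unsupported PyTorch versions conv2d() simply
# falls back to F.conv2d, as implemented in the file above.
import torch
from model.styleganModule.op import conv2d_gradfix

x = torch.randn(2, 3, 64, 64, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

y = conv2d_gradfix.conv2d(x, w, padding=1)   # same signature as F.conv2d

with conv2d_gradfix.no_weight_gradients():
    # input gradients still flow; weight gradients are skipped in the custom-op path
    (grad_x,) = torch.autograd.grad(y.sum(), x)
print(grad_x.shape)                          # torch.Size([2, 3, 64, 64])
```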
/model/styleganModule/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | if bias:
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | else:
42 | grad_bias = empty
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | out, = ctx.saved_tensors
49 | gradgrad_out = fused.fused_bias_act(
50 | gradgrad_input.contiguous(),
51 | gradgrad_bias,
52 | out,
53 | 3,
54 | 1,
55 | ctx.negative_slope,
56 | ctx.scale,
57 | )
58 |
59 | return gradgrad_out, None, None, None, None
60 |
61 |
62 | class FusedLeakyReLUFunction(Function):
63 | @staticmethod
64 | def forward(ctx, input, bias, negative_slope, scale):
65 | empty = input.new_empty(0)
66 |
67 | ctx.bias = bias is not None
68 |
69 | if bias is None:
70 | bias = empty
71 |
72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
73 | ctx.save_for_backward(out)
74 | ctx.negative_slope = negative_slope
75 | ctx.scale = scale
76 |
77 | return out
78 |
79 | @staticmethod
80 | def backward(ctx, grad_output):
81 | out, = ctx.saved_tensors
82 |
83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85 | )
86 |
87 | if not ctx.bias:
88 | grad_bias = None
89 |
90 | return grad_input, grad_bias, None, None
91 |
92 |
93 | class FusedLeakyReLU(nn.Module):
94 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95 | super().__init__()
96 |
97 | if bias:
98 | self.bias = nn.Parameter(torch.zeros(channel))
99 |
100 | else:
101 | self.bias = None
102 |
103 | self.negative_slope = negative_slope
104 | self.scale = scale
105 |
106 | def forward(self, input):
107 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108 |
109 |
110 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111 | if input.device.type == "cpu":
112 | if bias is not None:
113 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
114 | return (
115 | F.leaky_relu(
116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117 | )
118 | * scale
119 | )
120 |
121 | else:
122 | return F.leaky_relu(input, negative_slope=0.2) * scale
123 |
124 | else:
125 | return FusedLeakyReLUFunction.apply(
126 | input.contiguous(), bias, negative_slope, scale
127 | )
128 |
--------------------------------------------------------------------------------
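
A rough usage sketch (an illustration, not repo code; importing the module builds the CUDA extension, and on CPU the pure-PyTorch branch above is used instead): the op fuses bias addition, leaky ReLU and the sqrt(2) gain into a single kernel.

```python
import torch
from model.styleganModule.op.fused_act import FusedLeakyReLU, fused_leaky_relu

x = torch.randn(4, 64, 32, 32)
act = FusedLeakyReLU(channel=64)        # learnable per-channel bias, slope 0.2, gain sqrt(2)
y_module = act(x)                       # module form
y_func = fused_leaky_relu(x, act.bias)  # functional form, same computation
```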
/model/styleganModule/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <torch/extension.h>
3 | #include <ATen/ATen.h>
4 |
5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6 | const torch::Tensor &bias,
7 | const torch::Tensor &refer, int act, int grad,
8 | float alpha, float scale);
9 |
10 | #define CHECK_CUDA(x) \
11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CONTIGUOUS(x) \
13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14 | #define CHECK_INPUT(x) \
15 | CHECK_CUDA(x); \
16 | CHECK_CONTIGUOUS(x)
17 |
18 | torch::Tensor fused_bias_act(const torch::Tensor &input,
19 | const torch::Tensor &bias,
20 | const torch::Tensor &refer, int act, int grad,
21 | float alpha, float scale) {
22 | CHECK_INPUT(input);
23 | CHECK_INPUT(bias);
24 |
25 | at::DeviceGuard guard(input.device());
26 |
27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28 | }
29 |
30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32 | }
--------------------------------------------------------------------------------
/model/styleganModule/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 |
15 | #include <cuda.h>
16 | #include <cuda_runtime.h>
17 |
18 | template <typename scalar_t>
19 | static __global__ void
20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22 | scalar_t scale, int loop_x, int size_x, int step_b,
23 | int size_b, int use_bias, int use_ref) {
24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25 |
26 | scalar_t zero = 0.0;
27 |
28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29 | loop_idx++, xi += blockDim.x) {
30 | scalar_t x = p_x[xi];
31 |
32 | if (use_bias) {
33 | x += p_b[(xi / step_b) % size_b];
34 | }
35 |
36 | scalar_t ref = use_ref ? p_ref[xi] : zero;
37 |
38 | scalar_t y;
39 |
40 | switch (act * 10 + grad) {
41 | default:
42 | case 10:
43 | y = x;
44 | break;
45 | case 11:
46 | y = x;
47 | break;
48 | case 12:
49 | y = 0.0;
50 | break;
51 |
52 | case 30:
53 | y = (x > 0.0) ? x : x * alpha;
54 | break;
55 | case 31:
56 | y = (ref > 0.0) ? x : x * alpha;
57 | break;
58 | case 32:
59 | y = 0.0;
60 | break;
61 | }
62 |
63 | out[xi] = y * scale;
64 | }
65 | }
66 |
67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68 | const torch::Tensor &bias,
69 | const torch::Tensor &refer, int act, int grad,
70 | float alpha, float scale) {
71 | int curDevice = -1;
72 | cudaGetDevice(&curDevice);
73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74 |
75 | auto x = input.contiguous();
76 | auto b = bias.contiguous();
77 | auto ref = refer.contiguous();
78 |
79 | int use_bias = b.numel() ? 1 : 0;
80 | int use_ref = ref.numel() ? 1 : 0;
81 |
82 | int size_x = x.numel();
83 | int size_b = b.numel();
84 | int step_b = 1;
85 |
86 | for (int i = 1 + 1; i < x.dim(); i++) {
87 | step_b *= x.size(i);
88 | }
89 |
90 | int loop_x = 4;
91 | int block_size = 4 * 32;
92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93 |
94 | auto y = torch::empty_like(x);
95 |
96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97 | x.scalar_type(), "fused_bias_act_kernel", [&] {
98 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99 |             y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100 |             b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102 | });
103 |
104 | return y;
105 | }
--------------------------------------------------------------------------------
/model/styleganModule/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/ATen.h>
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5 | const torch::Tensor &kernel, int up_x, int up_y,
6 | int down_x, int down_y, int pad_x0, int pad_x1,
7 | int pad_y0, int pad_y1);
8 |
9 | #define CHECK_CUDA(x) \
10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) \
12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 | #define CHECK_INPUT(x) \
14 | CHECK_CUDA(x); \
15 | CHECK_CONTIGUOUS(x)
16 |
17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18 | int up_x, int up_y, int down_x, int down_y, int pad_x0,
19 | int pad_x1, int pad_y0, int pad_y1) {
20 | CHECK_INPUT(input);
21 | CHECK_INPUT(kernel);
22 |
23 | at::DeviceGuard guard(input.device());
24 |
25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26 | pad_y0, pad_y1);
27 | }
28 |
29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31 | }
--------------------------------------------------------------------------------
/model/styleganModule/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | from collections import abc
2 | import os
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | upfirdn2d_op = load(
12 | "upfirdn2d",
13 | sources=[
14 | os.path.join(module_path, "upfirdn2d.cpp"),
15 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class UpFirDn2dBackward(Function):
21 | @staticmethod
22 | def forward(
23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24 | ):
25 |
26 | up_x, up_y = up
27 | down_x, down_y = down
28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29 |
30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31 |
32 | grad_input = upfirdn2d_op.upfirdn2d(
33 | grad_output,
34 | grad_kernel,
35 | down_x,
36 | down_y,
37 | up_x,
38 | up_y,
39 | g_pad_x0,
40 | g_pad_x1,
41 | g_pad_y0,
42 | g_pad_y1,
43 | )
44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45 |
46 | ctx.save_for_backward(kernel)
47 |
48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
49 |
50 | ctx.up_x = up_x
51 | ctx.up_y = up_y
52 | ctx.down_x = down_x
53 | ctx.down_y = down_y
54 | ctx.pad_x0 = pad_x0
55 | ctx.pad_x1 = pad_x1
56 | ctx.pad_y0 = pad_y0
57 | ctx.pad_y1 = pad_y1
58 | ctx.in_size = in_size
59 | ctx.out_size = out_size
60 |
61 | return grad_input
62 |
63 | @staticmethod
64 | def backward(ctx, gradgrad_input):
65 | kernel, = ctx.saved_tensors
66 |
67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68 |
69 | gradgrad_out = upfirdn2d_op.upfirdn2d(
70 | gradgrad_input,
71 | kernel,
72 | ctx.up_x,
73 | ctx.up_y,
74 | ctx.down_x,
75 | ctx.down_y,
76 | ctx.pad_x0,
77 | ctx.pad_x1,
78 | ctx.pad_y0,
79 | ctx.pad_y1,
80 | )
81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82 | gradgrad_out = gradgrad_out.view(
83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84 | )
85 |
86 | return gradgrad_out, None, None, None, None, None, None, None, None
87 |
88 |
89 | class UpFirDn2d(Function):
90 | @staticmethod
91 | def forward(ctx, input, kernel, up, down, pad):
92 | up_x, up_y = up
93 | down_x, down_y = down
94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
95 |
96 | kernel_h, kernel_w = kernel.shape
97 | batch, channel, in_h, in_w = input.shape
98 | ctx.in_size = input.shape
99 |
100 | input = input.reshape(-1, in_h, in_w, 1)
101 |
102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103 |
104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106 | ctx.out_size = (out_h, out_w)
107 |
108 | ctx.up = (up_x, up_y)
109 | ctx.down = (down_x, down_y)
110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111 |
112 | g_pad_x0 = kernel_w - pad_x0 - 1
113 | g_pad_y0 = kernel_h - pad_y0 - 1
114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116 |
117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118 |
119 | out = upfirdn2d_op.upfirdn2d(
120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121 | )
122 | # out = out.view(major, out_h, out_w, minor)
123 | out = out.view(-1, channel, out_h, out_w)
124 |
125 | return out
126 |
127 | @staticmethod
128 | def backward(ctx, grad_output):
129 | kernel, grad_kernel = ctx.saved_tensors
130 |
131 | grad_input = None
132 |
133 | if ctx.needs_input_grad[0]:
134 | grad_input = UpFirDn2dBackward.apply(
135 | grad_output,
136 | kernel,
137 | grad_kernel,
138 | ctx.up,
139 | ctx.down,
140 | ctx.pad,
141 | ctx.g_pad,
142 | ctx.in_size,
143 | ctx.out_size,
144 | )
145 |
146 | return grad_input, None, None, None, None
147 |
148 |
149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150 | if not isinstance(up, abc.Iterable):
151 | up = (up, up)
152 |
153 | if not isinstance(down, abc.Iterable):
154 | down = (down, down)
155 |
156 | if len(pad) == 2:
157 | pad = (pad[0], pad[1], pad[0], pad[1])
158 |
159 | if input.device.type == "cpu":
160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161 |
162 | else:
163 | out = UpFirDn2d.apply(input, kernel, up, down, pad)
164 |
165 | return out
166 |
167 |
168 | def upfirdn2d_native(
169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170 | ):
171 | _, channel, in_h, in_w = input.shape
172 | input = input.reshape(-1, in_h, in_w, 1)
173 |
174 | _, in_h, in_w, minor = input.shape
175 | kernel_h, kernel_w = kernel.shape
176 |
177 | out = input.view(-1, in_h, 1, in_w, 1, minor)
178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180 |
181 | out = F.pad(
182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183 | )
184 | out = out[
185 | :,
186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188 | :,
189 | ]
190 |
191 | out = out.permute(0, 3, 1, 2)
192 | out = out.reshape(
193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194 | )
195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196 | out = F.conv2d(out, w)
197 | out = out.reshape(
198 | -1,
199 | minor,
200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202 | )
203 | out = out.permute(0, 2, 3, 1)
204 | out = out[:, ::down_y, ::down_x, :]
205 |
206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208 |
209 | return out.view(-1, channel, out_h, out_w)
210 |
--------------------------------------------------------------------------------
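
A hypothetical example (not from the repo) of calling `upfirdn2d` for a 2x upsample with a 4-tap binomial filter; the `pad` and gain values follow the usual StyleGAN2 convention and are assumptions here.

```python
import torch
from model.styleganModule.op.upfirdn2d import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 16, 16)
out = upfirdn2d(x, kernel * 4, up=2, down=1, pad=(2, 1))  # gain of up**2 when upsampling
print(out.shape)  # torch.Size([1, 3, 32, 32])
```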
/model/styleganModule/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 |           v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 |           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 |                                                  x.data_ptr<scalar_t>(),
317 |                                                  k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 |                                                  x.data_ptr<scalar_t>(),
325 |                                                  k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 |                                                  x.data_ptr<scalar_t>(),
333 |                                                  k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 |                                                  x.data_ptr<scalar_t>(),
341 |                                                  k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 |       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 |                                                  x.data_ptr<scalar_t>(),
349 |                                                  k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 |       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 |                                                  x.data_ptr<scalar_t>(),
357 |                                                  k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 |       upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 |           out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 |           k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
--------------------------------------------------------------------------------
/model/styleganModule/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 | def make_noise(batch, latent_dim, n_noise, device):
5 | if n_noise == 1:
6 | return torch.randn(batch, latent_dim, device=device)
7 |
8 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
9 |
10 | return noises
11 |
12 |
13 | def mixing_noise(batch, latent_dim, prob, device):
14 | if prob > 0 and random.random() < prob:
15 | return make_noise(batch, latent_dim, 2, device)
16 |
17 | else:
18 | return [make_noise(batch, latent_dim, 1, device)]
--------------------------------------------------------------------------------
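
For context, a small sketch (illustrative only) of how these helpers are consumed in the trainers: `mixing_noise` returns a single latent most of the time and two latents (for style mixing) with probability `prob`, and the generator accepts the whole list.

```python
import torch
from model.styleganModule.utils import mixing_noise

device = 'cuda' if torch.cuda.is_available() else 'cpu'
noise = mixing_noise(batch=8, latent_dim=512, prob=0.9, device=device)
print(len(noise), noise[0].shape)   # 1 or 2 latents of shape (8, 512)
# fake, _ = generator(noise)        # how CCNTrainer feeds the list to the StyleGAN generator
```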
/run.sh:
--------------------------------------------------------------------------------
1 | pip install tensorboardX -i https://pypi.tuna.tsinghua.edu.cn/simple
2 | export NCCL_SOCKET_IFNAME=eth0
3 | export NCCL_DEBUG=INFO
4 | # python train.py --model ccn --batch_size 16 --checkpoint_path checkpoint-ccn
5 | # python -m torch.distributed.launch train.py --model ccn --batch_size 16 --checkpoint_path checkpoint-ccn
6 | # python train.py --model ttn --batch_size 16 --checkpoint_path checkpoint-ttn-noxgf --lr 2e-4 --dist --print_interval 100 --save_interval 100
7 | # python train.py --model exp --batch_size 2 --checkpoint_path checkpoint-exp --lr 2e-4 --dist --print_interval 1 --save_interval 1 --early_stop --test_interval 1 --stop_interval 1
8 | # python train.py --model ccn --batch_size 16 --checkpoint_path checkpoint --lr 0.002 --print_interval 100 --save_interval 100 --dist
9 | python train.py --model ttn --batch_size 64 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100 --dist
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | @author LeslieZhao
3 | @date 20220721
4 | '''
5 | import os
6 |
7 | import argparse
8 | from trainer.CCNTrainer import CCNTrainer
9 |
10 | from trainer.TTNTrainer import TTNTrainer
11 | import torch.distributed as dist
12 | from utils.utils import setup_seed,get_data_loader,merge_args
13 | from model.styleganModule.config import Params as CCNParams
14 | from model.Pix2PixModule.config import Params as TTNParams
15 |
16 | # torch.multiprocessing.set_start_method('spawn')
17 |
18 | parser = argparse.ArgumentParser(description="StyleGAN")
19 | #---------train set-------------------------------------
20 | parser.add_argument('--model',default="ccn",help='')
21 | parser.add_argument('--isTrain',action="store_false",help='')
22 | parser.add_argument('--dist',action="store_false",help='')
23 | parser.add_argument('--batch_size',default=16,type=int)
24 | parser.add_argument('--seed',default=10,type=int)
25 | parser.add_argument('--eval',default=1,type=int,help='whether to run evaluation')
26 | parser.add_argument('--nDataLoaderThread',default=5,type=int,help='Num of loader threads')
27 | parser.add_argument('--print_interval',default=100,type=int)
28 | parser.add_argument('--test_interval',default=100,type=int,help='run evaluation every [test_interval] steps')
29 | parser.add_argument('--save_interval',default=100,type=int,help='save model interval')
30 | parser.add_argument('--stop_interval',default=20,type=int)
31 | parser.add_argument('--begin_it',default=0,type=int,help='begin epoch')
32 | parser.add_argument('--mx_data_length',default=100,type=int,help='max data length')
33 | parser.add_argument('--max_epoch',default=10000,type=int)
34 | parser.add_argument('--early_stop',action="store_true",help='')
35 | parser.add_argument('--scratch',action="store_true",help='')
36 | #---------path set--------------------------------------
37 | parser.add_argument('--checkpoint_path',default='checkpoint-onlybaby',type=str)
38 | parser.add_argument('--pretrain_path',default=None,type=str)
39 |
40 | # ------optimizer set--------------------------------------
41 | parser.add_argument('--lr',default=0.002,type=float,help="Learning rate")
42 |
43 | parser.add_argument(
44 | '--local_rank',
45 | type=int,
46 | default=0,
47 | help='Local rank passed from distributed launcher'
48 | )
49 |
50 | args = parser.parse_args()
51 |
52 | def train_net(args):
53 | train_loader,test_loader,mx_length = get_data_loader(args)
54 |
55 | args.mx_data_length = mx_length
56 | if args.model == 'ccn':
57 | trainer = CCNTrainer(args)
58 | if args.model == 'ttn':
59 | trainer = TTNTrainer(args)
60 |
61 | trainer.train_network(train_loader,test_loader)
62 |
63 | if __name__ == "__main__":
64 |
65 | args = parser.parse_args()
66 |
67 | if args.model == 'ccn':
68 | params = CCNParams()
69 | if args.model == 'ttn':
70 | params = TTNParams()
71 |
72 | args = merge_args(args,params)
73 | if args.dist:
74 |         dist.init_process_group(backend="nccl")
75 |         dist.barrier() # synchronize all processes before training starts
76 |         args.world_size = dist.get_world_size() # total number of processes
77 |         args.rank = dist.get_rank() # rank of the current process
78 |
79 | else:
80 | args.world_size = 1
81 | args.rank = 0
82 |
83 | setup_seed(args.seed+args.rank)
84 | print(args)
85 |
86 | args.checkpoint_path = os.path.join(args.checkpoint_path,args.name)
87 |
88 | print("local_rank %d | rank %d | world_size: %d"%(int(os.environ.get('LOCAL_RANK','0')),args.rank,args.world_size))
89 | if args.rank == 0 :
90 | if not os.path.exists(args.checkpoint_path):
91 | os.makedirs(args.checkpoint_path)
92 | print("make dir: ",args.checkpoint_path)
93 | train_net(args)
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/trainer/CCNTrainer.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author LeslieZhao
5 | @date 20220721
6 | '''
7 | import torch
8 |
9 | from trainer.ModelTrainer import ModelTrainer
10 | from model.styleganModule.model import Generator,Discriminator
11 | from model.styleganModule.utils import *
12 | from model.styleganModule.loss import *
13 | from utils.utils import *
14 |
15 |
16 | class CCNTrainer(ModelTrainer):
17 |
18 | def __init__(self, args):
19 | super().__init__(args)
20 | self.device = 'cpu'
21 | if torch.cuda.is_available():
22 | self.device = 'cuda'
23 |
24 | self.netGs = Generator(
25 | args.size,args.latent,
26 | args.n_mlp,
27 | channel_multiplier=args.channel_multiplier).to(self.device)
28 |
29 | self.netGt = Generator(args.size,args.latent,
30 | args.n_mlp,
31 | channel_multiplier=args.channel_multiplier).to(self.device)
32 |
33 |
34 | self.netD = Discriminator(
35 | args.size, channel_multiplier=args.channel_multiplier).to(self.device)
36 | self.gt_ema = Generator(
37 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(self.device)
38 | self.gt_ema.eval()
39 | accumulate(self.gt_ema, self.netGt, 0)
40 | self.g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
41 | self.d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)
42 |
43 | self.ckpt = None
44 | # self.init_weights(init_type='kaiming')
45 | if not args.scratch:
46 | ckpt = torch.load(args.stylegan_path, map_location=lambda storage, loc: storage)
47 | self.netGs.load_state_dict(ckpt['g_ema'],strict=False)
48 | self.netGt.load_state_dict(ckpt['g'],strict=False)
49 | self.gt_ema.load_state_dict(ckpt['g_ema'],strict=False)
50 | self.netD.load_state_dict(ckpt["d"])
51 | self.ckpt = ckpt
52 | self.optimG,self.optimD = self.create_optimizer()
53 |
54 | self.sample_z = torch.randn(args.n_sample, args.latent).to(self.device)
55 |
56 |
57 | if args.pretrain_path is not None:
58 | self.loadParameters(args.pretrain_path)
59 |
60 | if args.dist:
61 | self.netGt,self.netGt_module = self.use_ddp(self.netGt)
62 | self.netD,self.netD_module = self.use_ddp(self.netD)
63 | else:
64 | self.netGt_module = self.netGt
65 | self.netD_module = self.netD
66 | self.netGs.eval()
67 |
68 | self.accum = 0.5 ** (32 / (10 * 1000))
69 | self.criterionID = IDLoss(args.id_model).to(self.device)
70 |
71 | self.mean_path_length = 0
72 |
73 |
74 | def create_optimizer(self):
75 | g_optim = torch.optim.Adam(
76 | self.netGt.parameters(),
77 | lr=self.args.lr * self.g_reg_ratio,
78 | betas=(0 ** self.g_reg_ratio, 0.99 ** self.g_reg_ratio),
79 | )
80 | d_optim = torch.optim.Adam(
81 | self.netD.parameters(),
82 | lr=self.args.lr * self.d_reg_ratio,
83 | betas=(0 ** self.d_reg_ratio, 0.99 ** self.d_reg_ratio),
84 | )
85 | if self.ckpt is not None:
86 | g_optim.load_state_dict(self.ckpt["g_optim"])
87 | d_optim.load_state_dict(self.ckpt["d_optim"])
88 | return g_optim,d_optim
89 |
90 |
91 | def run_single_step(self, data, steps):
92 | self.netGt.train()
93 | if self.args.interval_train:
94 | data = self.process_input(data)
95 | if not hasattr(self,'interval_flag'):
96 | self.interval_flag = True
97 | if self.interval_flag:
98 | self.run_generator_one_step(data,steps)
99 | self.d_losses = {}
100 | else:
101 | self.run_discriminator_one_step(data,steps)
102 | self.g_losses = {}
103 | if steps % self.args.interval_steps == 0:
104 | self.interval_flag = not self.interval_flag
105 |
106 |
107 | else:
108 | super().run_single_step(data, steps)
109 |
110 |
111 | def run_discriminator_one_step(self, data,step):
112 |
113 | D_losses = {}
114 | requires_grad(self.netGt, False)
115 | requires_grad(self.netD, True)
116 | noise = mixing_noise(self.args.batch_size,
117 | self.args.latent,
118 | self.args.mixing,self.device)
119 |
120 | fake_img, _ = self.netGt(noise)
121 | fake_pred = self.netD(fake_img)
122 | real_pred = self.netD(data)
123 | d_loss = d_logistic_loss(real_pred, fake_pred)
124 | D_losses['d'] = d_loss
125 |
126 | self.netD.zero_grad()
127 | d_loss.backward()
128 | self.optimD.step()
129 |
130 | if step % self.args.d_reg_every == 0:
131 | data.requires_grad = True
132 |
133 | real_pred = self.netD(data)
134 | r1_loss = d_r1_loss(real_pred,data)
135 | self.netD.zero_grad()
136 | r1_loss = self.args.r1 / 2 * \
137 | r1_loss * self.args.d_reg_every + \
138 | 0 * real_pred[0]
139 |
140 | r1_loss.mean().backward()
141 |
142 | self.optimD.step()
143 | D_losses['r1'] = r1_loss
144 |
145 | self.d_losses = D_losses
146 |
147 |
148 | def run_generator_one_step(self, data,step):
149 |
150 | G_losses = {}
151 | requires_grad(self.netGt, True)
152 | requires_grad(self.netD, False)
153 | requires_grad(self.netGs, False)
154 |
155 | noise = mixing_noise(self.args.batch_size,
156 | self.args.latent,
157 | self.args.mixing,self.device)
158 |
159 | fake_s,_ = self.netGs(noise)
160 | fake_t,_ = self.netGt(noise)
161 | fake_pred = self.netD(fake_t)
162 | gan_loss = g_nonsaturating_loss(fake_pred) * self.args.lambda_gan
163 | id_loss = self.criterionID(fake_s,fake_t) * self.args.lambda_id
164 | G_losses['g'] = gan_loss
165 | G_losses['id'] = id_loss
166 | losses = gan_loss + id_loss
167 | G_losses['g_losses'] = losses
168 | self.netGt.zero_grad()
169 | losses.mean().backward()
170 | self.optimG.step()
171 |
172 | if step % self.args.g_reg_every == 0:
173 | path_batch_size = max(1, self.args.batch_size // self.args.path_batch_shrink)
174 | noise = mixing_noise(path_batch_size, self.args.latent, self.args.mixing,self.device)
175 | fake_img, latents = self.netGt(noise, return_latents=True)
176 |
177 | path_loss, self.mean_path_length, path_lengths = g_path_regularize(
178 | fake_img, latents, self.mean_path_length
179 | )
180 |
181 | weighted_path_loss = self.args.path_regularize * self.args.g_reg_every * path_loss
182 | if self.args.path_batch_shrink:
183 | weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
184 |
185 | self.netGt.zero_grad()
186 | weighted_path_loss.mean().backward()
187 | self.optimG.step()
188 |
189 | G_losses['path'] = weighted_path_loss
190 |
191 | accumulate(self.gt_ema,self.netGt_module,self.accum)
192 | self.g_losses = G_losses
193 | self.generator = [data.detach(),fake_s.detach(),fake_t.detach()]
194 |
195 |
196 | def evalution(self,test_loader,steps,epoch):
197 |
198 | loss_dict = {}
199 | with torch.no_grad():
200 | fake_s,_ = self.netGs([self.sample_z])
201 | fake_t,_ = self.gt_ema([self.sample_z])
202 | if self.args.rank == 0 :
203 | self.val_vis.display_current_results(self.select_img([fake_s,fake_t]),steps)
204 | # self.val_vis.display_current_results(self.select_img([fake_t]),steps)
205 |
206 | return loss_dict
207 |
208 | def get_latest_losses(self):
209 | return {**self.g_losses,**self.d_losses}
210 |
211 | def get_latest_generated(self):
212 | return self.generator
213 |
214 | def loadParameters(self,path):
215 | ckpt = torch.load(path, map_location=lambda storage, loc: storage)
216 | self.netGs.load_state_dict(ckpt['Gs'],strict=False)
217 | self.netGt.load_state_dict(ckpt['Gt'],strict=False)
218 | self.gt_ema.load_state_dict(ckpt['gt_ema'],strict=False)
219 | self.optimG.load_state_dict(ckpt['g_optim'])
220 | self.optimD.load_state_dict(ckpt['d_optim'])
221 |
222 | def saveParameters(self,path):
223 | torch.save(
224 | {
225 | "Gs": self.netGs.state_dict(),
226 | "Gt": self.netGt_module.state_dict(),
227 | "gt_ema": self.gt_ema.state_dict(),
228 | "g_optim": self.optimG.state_dict(),
229 | "d_optim": self.optimD.state_dict(),
230 | "args": self.args,
231 | },
232 | path
233 | )
234 |
235 | def get_lr(self):
236 | return self.optimG.state_dict()['param_groups'][0]['lr']
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
--------------------------------------------------------------------------------
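
The EMA update above relies on `accumulate` from `utils/utils.py`, which is not shown in this section; the sketch below states the assumed behaviour (an exponential moving average of the target generator's weights) together with the decay used here, `0.5 ** (32 / (10 * 1000)) ≈ 0.9978` per step.

```python
# Assumed implementation, for illustration only (the real helper lives in utils/utils.py).
def accumulate(model_ema, model, decay=0.5 ** (32 / (10 * 1000))):
    ema_params = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    for name, param in params.items():
        # decay=0 (as used right after construction) simply copies the weights
        ema_params[name].data.mul_(decay).add_(param.data, alpha=1 - decay)
```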
/trainer/ModelTrainer.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author LeslieZhao
5 | @date 20220721
6 | '''
7 | import torch
8 | import math
9 | import time,os
10 |
11 | from utils.visualizer import Visualizer
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | import torch.distributed as dist
14 | import subprocess
15 | from utils.utils import convert_img
16 | class ModelTrainer:
17 |
18 | def __init__(self,args):
19 |
20 | self.args = args
21 | self.batch_size = args.batch_size
22 | self.old_lr = args.lr
23 | if args.rank == 0 :
24 | self.vis = Visualizer(args)
25 |
26 | if args.eval:
27 | self.val_vis = Visualizer(args,"val")
28 |
29 | # ## ===== ===== ===== ===== =====
30 | # ## Train network
31 | # ## ===== ===== ===== ===== =====
32 |
33 | def train_network(self,train_loader,test_loader):
34 |
35 | counter = 0
36 | loss_dict = {}
37 | acc_num = 0
38 | mn_loss = float('inf')
39 |
40 | steps = 0
41 | begin_it = 0
42 | if self.args.pretrain_path:
43 | begin_it = int(self.args.pretrain_path.split('/')[-1].split('-')[0])
44 | steps = (begin_it+1) * math.ceil(self.args.mx_data_length/self.args.batch_size)
45 |
46 | print("current steps: %d | one epoch steps: %d "%(steps,self.args.mx_data_length))
47 |
48 | for epoch in range(begin_it+1,self.args.max_epoch):
49 |
50 | for ii,(data) in enumerate(train_loader):
51 |
52 | tstart = time.time()
53 |
54 | self.run_single_step(data,steps)
55 | losses = self.get_latest_losses()
56 |
57 | for key,val in losses.items():
58 | loss_dict[key] = loss_dict.get(key,0) + val.mean().item()
59 |
60 | counter += 1
61 | steps += 1
62 |
63 | telapsed = time.time() - tstart
64 |
65 |
66 | if ii % self.args.print_interval == 0 and self.args.rank == 0:
67 | for key,val in loss_dict.items():
68 | loss_dict[key] /= counter
69 |
70 | lr_rate = self.get_lr()
71 | print_dict = {**{"time":telapsed,"lr":lr_rate},
72 | **loss_dict}
73 | self.vis.print_current_errors(epoch,ii,print_dict,telapsed)
74 |
75 | self.vis.plot_current_errors(print_dict,steps)
76 |
77 | loss_dict = {}
78 | counter = 0
79 |
80 | # torch.cuda.empty_cache()
81 | if self.args.save_interval != 0 and ii % self.args.save_interval == 0 and \
82 | self.args.rank == 0:
83 | self.saveParameters(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,ii)))
84 |
85 | display_data = self.select_img(self.get_latest_generated())
86 |
87 | self.vis.display_current_results(display_data,steps)
88 |
89 |
90 |
91 | if self.args.eval and self.args.test_interval > 0 and steps % self.args.test_interval == 0:
92 | val_loss = self.evalution(test_loader,steps,epoch)
93 |
94 | if self.args.early_stop:
95 |
96 | acc_num,mn_loss,stop_flag = self.early_stop_wait(self.get_loss_from_val(val_loss),acc_num,mn_loss,epoch)
97 | if stop_flag:
98 | return
99 |
100 | # print('******************memory:',psutil.virtual_memory()[3])
101 |
102 | if self.args.rank == 0 :
103 | self.saveParameters(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,0)))
104 |
105 |         # validate and keep track of the best model
106 | if test_loader or self.args.eval:
107 | val_loss = self.evalution(test_loader,steps,epoch)
108 |
109 | if self.args.early_stop:
110 |
111 | acc_num,mn_loss,stop_flag = self.early_stop_wait(self.get_loss_from_val(val_loss),acc_num,mn_loss,epoch)
112 | if stop_flag:
113 | return
114 |
115 |
116 | if self.args.rank == 0 :
117 | self.vis.close()
118 |
119 |
120 |
121 | def early_stop_wait(self,loss,acc_num,mn_loss,epoch):
122 |
123 | if self.args.rank == 0:
124 | if loss < mn_loss:
125 | mn_loss = loss
126 | cmd_one = 'cp -r %s %s'%(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,0)),
127 | os.path.join(self.args.checkpoint_path,'final.pth'))
128 | done_one = subprocess.Popen(cmd_one,stdout=subprocess.PIPE,shell=True)
129 | done_one.wait()
130 | acc_num = 0
131 | else:
132 | acc_num += 1
133 |             # multi-node / multi-GPU: if any rank decides to stop, all ranks must stop together (synchronized with all_reduce)
134 | if self.args.dist:
135 |
136 | if acc_num > self.args.stop_interval:
137 | signal = torch.tensor([0]).cuda()
138 | else:
139 | signal = torch.tensor([1]).cuda()
140 | else:
141 | if self.args.dist:
142 | signal = torch.tensor([1]).cuda()
143 |
144 | if self.args.dist:
145 | dist.all_reduce(signal)
146 | value = signal.item()
147 | if value >= int(os.environ.get("WORLD_SIZE","1")):
148 | dist.all_reduce(torch.tensor([0]).cuda())
149 | return acc_num,mn_loss,False
150 | else:
151 | return acc_num,mn_loss,True
152 |
153 | else:
154 | if acc_num > self.args.stop_interval:
155 | return acc_num,mn_loss,True
156 | else:
157 | return acc_num,mn_loss,False
158 |
159 | def run_single_step(self,data,steps):
160 | data = self.process_input(data)
161 | self.run_discriminator_one_step(data,steps)
162 | self.run_generator_one_step(data,steps)
163 |
164 | def select_img(self,data,name='fake'):
165 | if data is None:
166 | return None
167 | cat_img = []
168 | for v in data:
169 | cat_img.append(v.detach().cpu())
170 |
171 | cat_img = torch.cat(cat_img,-1)
172 | cat_img = torch.cat(torch.split(cat_img,1,dim=0),2)[0]
173 |
174 | return {name:convert_img(cat_img)}
175 |
176 |
177 |
178 | ##################################################################
179 | # Helper functions
180 | ##################################################################
181 |
182 | def get_loss_from_val(self,loss):
183 | return loss
184 |
185 | def get_show_inp(self,data):
186 | if not isinstance(data,list):
187 | return [data]
188 | return data
189 |
190 | def use_ddp(self,model):
191 |         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # convert BatchNorm layers to SyncBatchNorm for DDP
192 |         # model = DDP(model,broadcast_buffers=False,find_unused_parameters=True) # find_unused_parameters -> needed when some generator/discriminator parameters receive no gradient in a GAN step
193 |         model = DDP(model,
194 |                     broadcast_buffers=False,
195 |                     )
196 |         model_on_one_gpu = model.module # keep the unwrapped module so its own methods can still be called under DDP
197 | return model,model_on_one_gpu
198 | def process_input(self,data):
199 |
200 | if torch.cuda.is_available():
201 | if isinstance(data,list):
202 | data = [x.cuda() for x in data]
203 | else:
204 | data = data.cuda()
205 | return data
--------------------------------------------------------------------------------
/trainer/TTNTrainer.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author LeslieZhao
5 | @date 20220721
6 | '''
7 | import torch
8 |
9 | from trainer.ModelTrainer import ModelTrainer
10 | from model.Pix2PixModule.model import Generator,Discriminator,ExpressDetector
11 | from utils.utils import *
12 | from model.Pix2PixModule.module import *
13 | from model.Pix2PixModule.loss import *
14 | import torch.distributed as dist
15 | import random
16 | import itertools
17 |
18 | class TTNTrainer(ModelTrainer):
19 |
20 | def __init__(self, args):
21 | super().__init__(args)
22 | self.device = 'cpu'
23 | if torch.cuda.is_available():
24 | self.device = 'cuda'
25 | self.netG = Generator(img_channels=3).to(self.device)
26 |
27 | self.netTxD = Discriminator(in_channels=1).to(self.device)
28 |
29 | self.netSfD = Discriminator(in_channels=3).to(self.device)
30 |
31 | self.ExpG = None
32 | if args.use_exp:
33 | self.ExpG = ExpressDetector().to(self.device)
34 | self.ExpG.apply(init_weights)
35 |
36 | self.netG.apply(init_weights)
37 |
38 | self.netTxD.apply(init_weights)
39 | self.netSfD.apply(init_weights)
40 |
41 | self.optimG,self.optimD,self.optimExp = self.create_optimizer()
42 |
43 | if args.pretrain_path is not None:
44 | self.loadParameters(args.pretrain_path)
45 |
46 | if args.dist:
47 | self.netG,self.netG_module = self.use_ddp(self.netG)
48 |
49 | self.netTxD,self.netTxD_module = self.use_ddp(self.netTxD)
50 | self.netSfD,self.netSfD_module = self.use_ddp(self.netSfD)
51 |
52 | if args.use_exp:
53 | self.ExpG,self.ExpG_module = self.use_ddp(self.ExpG)
54 | else:
55 | self.netG_module = self.netG
56 | self.netTxD_module = self.netTxD
57 | self.netSfD_module = self.netSfD
58 | if args.use_exp:
59 | self.ExpG_module = self.ExpG
60 |
61 |
62 | self.VggLoss = VGGLoss(args.vgg_model).to(self.device).eval()
63 | self.TVLoss = TVLoss(1).to(self.device).eval()
64 | self.L1_Loss = nn.L1Loss()
65 | self.MSE_Loss = nn.MSELoss()
66 |
67 |
68 | def create_optimizer(self):
69 | g_optim = torch.optim.Adam(
70 | self.netG.parameters(),
71 | lr=self.args.lr,
72 | betas=(self.args.beta1,self.args.beta2),
73 | )
74 |
75 | d_optim = torch.optim.Adam(
76 |
77 | itertools.chain(self.netTxD.parameters(),self.netSfD.parameters()),
78 | lr=self.args.lr,
79 | betas=(self.args.beta1,self.args.beta2),
80 | )
81 | exp_optim = None
82 | if self.args.use_exp:
83 | exp_optim = torch.optim.Adam(
84 | self.ExpG.parameters(),
85 | lr=self.args.lr,
86 | betas=(self.args.beta1,self.args.beta2),
87 | )
88 |
89 | return g_optim,d_optim,exp_optim
90 |
91 |
92 | def run_single_step(self, data, steps):
93 | self.netG.train()
94 | super().run_single_step(data, steps)
95 |
96 |
97 | def run_discriminator_one_step(self, data,step):
98 |
99 | D_losses = {}
100 | requires_grad(self.netTxD, True)
101 | requires_grad(self.netSfD, True)
102 | xs,xt,_ = data
103 | with torch.no_grad():
104 | xg = self.netG(xs)
105 |
106 | # surface
107 |
108 | blur_fake = guided_filter(xg,xg,r=5,eps=2e-1)
109 | blur_style = guided_filter(xt,xt,r=5,eps=2e-1)
110 |
111 | D_blur_real = self.netSfD(blur_style)
112 | D_blur_fake = self.netSfD(blur_fake)
113 | d_loss_surface_real = self.MSE_Loss(D_blur_real,torch.ones_like(D_blur_real))
114 | d_loss_surface_fake = self.MSE_Loss(D_blur_fake, torch.zeros_like(D_blur_fake))
115 | d_loss_surface = (d_loss_surface_real + d_loss_surface_fake)/2.0
116 |
117 | D_losses['d_surface_real'] = d_loss_surface_real
118 | D_losses['d_surface_fake'] = d_loss_surface_fake
119 |
120 | # texture
121 | gray_fake = color_shift(xg)
122 | gray_style = color_shift(xt)
123 |
124 | D_gray_real = self.netTxD(gray_style)
125 | D_gray_fake = self.netTxD(gray_fake.detach())
126 | d_loss_texture_real = self.MSE_Loss(D_gray_real, torch.ones_like(D_gray_real))
127 | d_loss_texture_fake = self.MSE_Loss(D_gray_fake, torch.zeros_like(D_gray_fake))
128 | d_loss_texture = (d_loss_texture_real + d_loss_texture_fake)/2.0
129 |
130 | D_losses['d_texture_real'] = d_loss_texture_real
131 | D_losses['d_texture_fake'] = d_loss_texture_fake
132 |
133 | d_loss_total = d_loss_surface + d_loss_texture
134 |
135 |
136 | self.optimD.zero_grad()
137 | d_loss_total.backward()
138 |
139 | self.optimD.step()
140 | self.d_losses = D_losses
141 |
142 | def run_generator_one_step(self, data,step):
143 |
144 | G_losses = {}
145 | requires_grad(self.netG, True)
146 | requires_grad(self.netTxD, False)
147 | requires_grad(self.netSfD, False)
148 |         if self.ExpG is not None: requires_grad(self.ExpG, False)
149 | xs,xt,exp_gt = data
150 |
151 | G_losses,losses,xg = self.compute_g_loss(xs,exp_gt,step)
152 |
153 | self.netG.zero_grad()
154 | losses.backward()
155 | self.optimG.step()
156 |
157 | if self.args.use_exp:
158 | requires_grad(self.ExpG, True)
159 | pred_exp = self.ExpG(xg.detach())
160 | exp_loss = self.MSE_Loss(pred_exp,exp_gt)
161 | self.optimExp.zero_grad()
162 | exp_loss.backward()
163 | self.optimExp.step()
164 | G_losses['raw_exp_loss'] = exp_loss
165 |
166 |
167 | self.g_losses = G_losses
168 | self.generator = [xs.detach(),xg.detach(),xt.detach()]
169 |
170 |
171 | def evalution(self,test_loader,steps,epoch):
172 |
173 | loss_dict = {}
174 | counter = 0
175 | index = random.randint(0,len(test_loader)-1)
176 | self.netG.eval()
177 |
178 | with torch.no_grad():
179 | for i,data in enumerate(test_loader):
180 |
181 | data = self.process_input(data)
182 | xs,xt,exp_gt = data
183 | G_losses,losses,xg = self.compute_g_loss(xs,exp_gt,steps)
184 | for k,v in G_losses.items():
185 | loss_dict[k] = loss_dict.get(k,0) + v.detach()
186 | if i == index and self.args.rank == 0 :
187 |
188 | self.val_vis.display_current_results(self.select_img([xs,xg,xt]),steps)
189 | counter += 1
190 |
191 |
192 | for key,val in loss_dict.items():
193 | loss_dict[key] /= counter
194 |
195 | if self.args.dist:
196 | # if self.args.rank == 0 :
197 | dist_losses = loss_dict.copy()
198 | for key,val in loss_dict.items():
199 |
200 | dist.reduce(dist_losses[key],0)
201 | value = dist_losses[key].item()
202 | loss_dict[key] = value / self.args.world_size
203 |
204 | if self.args.rank == 0 :
205 | self.val_vis.plot_current_errors(loss_dict,steps)
206 | self.val_vis.print_current_errors(epoch+1,0,loss_dict,0)
207 |
208 | return loss_dict
209 |
210 |
211 | def compute_g_loss(self,xs,exp_gt,step):
212 | G_losses = {}
213 |
214 | xg = self.netG(xs)
215 |
216 |         # warm-up: ramp the stylization losses in gradually
217 | if step < 100:
218 | lambda_surface = 0
219 | lambda_texture = 0
220 | lambda_exp = 0
221 |
222 | elif step < 500:
223 | lambda_surface = self.args.lambda_surface * 0.01
224 | lambda_texture = self.args.lambda_texture * 0.01
225 | lambda_exp = self.args.lambda_exp * 0.01
226 |
227 | elif step < 1000:
228 | lambda_surface = self.args.lambda_surface * 0.1
229 | lambda_texture = self.args.lambda_texture * 0.1
230 | lambda_exp = self.args.lambda_exp * 0.1
231 |
232 | else:
233 | lambda_surface = self.args.lambda_surface
234 | lambda_texture = self.args.lambda_texture
235 | lambda_exp = self.args.lambda_exp
236 |
237 |
238 | # surface
239 | blur_fake = guided_filter(xg,xg,r=5,eps=2e-1)
240 | D_blur_fake = self.netSfD(blur_fake)
241 | g_loss_surface = lambda_surface * self.MSE_Loss(D_blur_fake, torch.ones_like(D_blur_fake))
242 | G_losses['g_loss_surface'] = g_loss_surface
243 |
244 | # texture
245 | gray_fake = color_shift(xg)
246 | D_gray_fake = self.netTxD(gray_fake)
247 | g_loss_texture = lambda_texture * self.MSE_Loss(D_gray_fake, torch.ones_like(D_gray_fake))
248 | G_losses['g_loss_texture'] = g_loss_texture
249 |
250 | # content
251 | content_loss = self.VggLoss(xs,xg) * self.args.lambda_content
252 | G_losses['content_loss'] = content_loss
253 |
254 | # tv loss
255 | tv_loss = self.TVLoss(xg) * self.args.lambda_tv
256 | G_losses['tvloss'] = tv_loss
257 |
258 | # exp loss
259 | exp_loss = 0
260 | if self.args.use_exp:
261 | exp_pred = self.ExpG(xg)
262 | exp_loss = self.MSE_Loss(exp_pred,exp_gt) * lambda_exp
263 | G_losses['exploss'] = exp_loss
264 |
265 |         losses = g_loss_surface + g_loss_texture + content_loss + tv_loss + exp_loss  # exp_loss already carries lambda_exp
266 | G_losses['total_loss'] = losses
267 | return G_losses,losses,xg
268 |
269 | def get_latest_losses(self):
270 | return {**self.g_losses,**self.d_losses}
271 |
272 | def get_latest_generated(self):
273 | return self.generator
274 |
275 | def get_loss_from_val(self,loss):
276 | return loss['total_loss']
277 |
278 | def loadParameters(self,path):
279 | ckpt = torch.load(path, map_location=lambda storage, loc: storage)
280 | self.netG.load_state_dict(ckpt['netG'],strict=False)
281 | self.netTxD.load_state_dict(ckpt['netTxD'],strict=False)
282 | self.netSfD.load_state_dict(ckpt['netSfD'],strict=False)
283 | self.optimG.load_state_dict(ckpt['g_optim'])
284 | self.optimD.load_state_dict(ckpt['d_optim'])
285 | if self.args.use_exp:
286 | self.ExpG.load_state_dict(ckpt['ExpG'],strict=False)
287 | self.optimExp.load_state_dict(ckpt['exp_optim'])
288 | def saveParameters(self,path):
289 | save_dict = {
290 | "netG": self.netG_module.state_dict(),
291 | "netTxD": self.netTxD_module.state_dict(),
292 | "netSfD": self.netSfD_module.state_dict(),
293 | "g_optim": self.optimG.state_dict(),
294 | "d_optim": self.optimD.state_dict(),
295 | "args": self.args,
296 | }
297 | if self.args.use_exp:
298 | save_dict['ExpG'] = self.ExpG_module.state_dict()
299 | save_dict['exp_optim'] = self.optimExp.state_dict()
300 | torch.save(
301 | save_dict,
302 | path
303 | )
304 |
305 | def get_lr(self):
306 | return self.optimG.state_dict()['param_groups'][0]['lr']
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
--------------------------------------------------------------------------------
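A note on the staged warm-up in `compute_g_loss` above: the surface, texture and expression weights start at zero and are scaled up at steps 100, 500 and 1000, while the content and TV losses are active from the start. A minimal sketch of that schedule as a standalone helper (the thresholds and factors come from the trainer code; the helper name is ours):

```python
def warmup_scale(step: int) -> float:
    """Multiplier applied to lambda_surface / lambda_texture / lambda_exp."""
    if step < 100:
        return 0.0    # only the content and TV losses drive the generator
    if step < 500:
        return 0.01
    if step < 1000:
        return 0.1
    return 1.0

# e.g. at step 750 the surface loss is weighted by args.lambda_surface * 0.1
assert warmup_scale(750) == 0.1
```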
/utils/download_weight.sh:
--------------------------------------------------------------------------------
1 | wget https://github.com/LeslieZhoa/DCT-NET.Pytorch/releases/download/v0.0/final.pth -P ../pretrain_models
2 | wget https://github.com/LeslieZhoa/DCT-NET.Pytorch/releases/download/v0.0/model_ir_se50.pth -P ../pretrain_models
3 | wget https://github.com/LeslieZhoa/DCT-NET.Pytorch/releases/download/v0.0/vgg19-dcbb9e9d.pth -P ../pretrain_models
--------------------------------------------------------------------------------
/utils/get_face_expression.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pdb
4 | import cv2
5 | import os
6 | from multiprocessing import Pool
7 | import time
8 | import math
9 | import multiprocessing as mp
10 | import numpy as np
11 |
12 | import argparse
13 |
14 | parser = argparse.ArgumentParser(description="Process")
15 | parser.add_argument('--img_base',default="",type=str,help='')
16 | parser.add_argument('--pool_num',default=2,type=int,help='')
17 | parser.add_argument('--LVT',default='',type=str,help='')
18 | parser.add_argument('--train',action="store_true",help='')
19 | args = parser.parse_args()
20 |
21 | class Process:
22 | def __init__(self):
23 | self.engine = Engine(
24 | face_lmk_path='')
25 |
26 | def run(self,img_paths):
27 | mx_left_eye = -1
28 | mn_left_eye = 100
29 |
30 | mx_right_eye = -1
31 | mn_right_eye = 100
32 |
33 | mx_lip = -1
34 | mn_lip = 100
35 | for i,img_path in enumerate(img_paths):
36 | img = cv2.imread(img_path)
37 | left_eye_score,right_eye_score,lip_score = \
38 | self.run_single(img)
39 |
40 | base,img_name = os.path.split(img_path)
41 | score_base = base.replace('img','express')
42 | os.makedirs(score_base,exist_ok=True)
43 | np.save(os.path.join(score_base,img_name.split('.')[0]+'.npy'),[left_eye_score,right_eye_score,lip_score])
44 |
45 | mx_left_eye = max(left_eye_score,mx_left_eye)
46 | mn_left_eye = min(left_eye_score,mn_left_eye)
47 |
48 | mx_right_eye = max(right_eye_score,mx_right_eye)
49 | mn_right_eye = min(right_eye_score,mn_right_eye)
50 |
51 | mx_lip = max(lip_score,mx_lip)
52 | mn_lip = min(lip_score,mn_lip)
53 | print('\rhave done %04d'%i,end='',flush=True)
54 | print()
55 | return mx_left_eye,mn_left_eye,mx_right_eye,mn_right_eye,mx_lip,mn_lip
56 |
57 |
58 | def run_single(self,img):
59 | inp = self.engine.preprocess_lmk(img)
60 | lmk = self.engine.get_lmk(inp)
61 | lmk = self.engine.postprocess_lmk(lmk,256,[0,0])
62 | scores = self.get_expression(lmk[0])
63 | return scores
64 | def get_expression(self,lmk):
65 | left_eye_h = abs(lmk[66,1]-lmk[62,1])
66 | left_eye_w = abs(lmk[60,0]-lmk[64,0])
67 | left_eye_score = left_eye_h / max(left_eye_w,1e-5)
68 |
69 | right_eye_h = abs(lmk[70,1]-lmk[74,1])
70 | right_eye_w = abs(lmk[68,0]-lmk[72,0])
71 | right_eye_score = right_eye_h / max(right_eye_w,1e-5)
72 |
73 | lip_h = abs(lmk[90,1]-lmk[94,1])
74 | lip_w = abs(lmk[88,0]-lmk[82,0])
75 | lip_score = lip_h / max(lip_w,1e-5)
76 |
77 | return left_eye_score,right_eye_score,lip_score
78 |
79 | def work(queue,img_paths):
80 | model = Process()
81 | mx_left_eye,mn_left_eye,mx_right_eye,mn_right_eye,mx_lip,mn_lip = \
82 | model.run(img_paths)
83 | queue.put([mx_left_eye,mn_left_eye,mx_right_eye,mn_right_eye,mx_lip,mn_lip])
84 |
85 | def print_error(value):
86 | print("error: ", value)
87 | if __name__ == "__main__":
88 |
89 | args = parser.parse_args()
90 | # add the LVT base path to sys.path so Engine and its utils can be imported
91 | sys.path.insert(0,args.LVT)
92 | from LVT import Engine
93 | from utils import utils
94 |
95 | mp.set_start_method('spawn')
96 | m = mp.Manager()
97 | queue = m.Queue()
98 | model = Process()
99 | base = args.img_base
100 | img_paths = [os.path.join(base,f) for f in os.listdir(base)]
101 | pool_num = args.pool_num
102 | length = len(img_paths)
103 |
104 | dis = math.ceil(length/float(pool_num))
105 |
106 |
107 | t1 = time.time()
108 | print('***************all length: %d ******************'%length)
109 | p = Pool(pool_num)
110 | for i in range(pool_num):
111 | p.apply_async(work, args = (queue,img_paths[i*dis:(i+1)*dis],),error_callback=print_error)
112 |
113 | p.close()
114 | p.join()
115 | print("all the time: %s"%(time.time()-t1))
116 |
117 | if args.train:
118 | mx_left_eye_all = -1
119 | mn_left_eye_all = 100
120 |
121 | mx_right_eye_all = -1
122 | mn_right_eye_all = 100
123 |
124 | mx_lip_all = -1
125 | mn_lip_all = 100
126 |
127 | while not queue.empty():
128 | mx_left_eye,mn_left_eye,mx_right_eye,mn_right_eye,mx_lip,mn_lip = \
129 | queue.get()
130 |
131 | mx_left_eye_all = max(mx_left_eye_all,mx_left_eye)
132 | mn_left_eye_all = min(mn_left_eye_all,mn_left_eye)
133 |
134 | mx_right_eye_all = max(mx_right_eye_all,mx_right_eye)
135 | mn_right_eye_all = min(mn_right_eye_all,mn_right_eye)
136 |
137 | mx_lip_all = max(mx_lip_all,mx_lip)
138 | mn_lip_all = min(mn_lip_all,mn_lip)
139 | os.makedirs('../pretrain_models',exist_ok=True)
140 | np.save('../pretrain_models/all_express_mean.npy',[mx_left_eye_all,mn_left_eye_all,
141 | mx_right_eye_all,mn_right_eye_all,
142 | mx_lip_all,mn_lip_all])
143 |
144 |
--------------------------------------------------------------------------------
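`get_face_expression.py` writes, for every aligned face, a triplet of openness ratios (left eye, right eye, lips) as an `.npy` file in a parallel `express` directory (the image folder path with `img` replaced by `express`), and, when run with `--train`, the global min/max of each ratio in `../pretrain_models/all_express_mean.npy`. A hedged sketch of how a single score file could be min-max normalized with those statistics; the normalization itself is our assumption about how the expression scores are consumed downstream, not repo code:

```python
import numpy as np

def normalize_expression(score_path,
                         stats_path='../pretrain_models/all_express_mean.npy'):
    # per-image [left_eye, right_eye, lip] ratios written by Process.run
    left, right, lip = np.load(score_path)
    # global [mx_l, mn_l, mx_r, mn_r, mx_lip, mn_lip] written with --train
    mx_l, mn_l, mx_r, mn_r, mx_lip, mn_lip = np.load(stats_path)
    norm = lambda v, mn, mx: (v - mn) / max(mx - mn, 1e-5)
    return np.array([norm(left,  mn_l,   mx_l),
                     norm(right, mn_r,   mx_r),
                     norm(lip,   mn_lip, mx_lip)], dtype=np.float32)
```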
/utils/get_tcc_input.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0,'..')
3 | import pdb
4 | from model.styleganModule.model import Generator
5 | import torch
6 | from model.styleganModule.utils import *
7 | from utils import convert_img
8 | from model.styleganModule.config import Params as CCNParams
9 | import cv2
10 | import os
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser(description="Process")
14 | parser.add_argument('--model_path',default="",type=str,help='')
15 | parser.add_argument('--output_path',default="",type=str,help='')
16 | args = parser.parse_args()
17 |
18 | class Process:
19 | def __init__(self,model_path):
20 | self.args = CCNParams()
21 | self.device = 'cpu'
22 | if torch.cuda.is_available():
23 | self.device = 'cuda'
24 |
25 | self.netGt = Generator(
26 | self.args.size,self.args.latent,
27 | self.args.n_mlp,
28 | channel_multiplier=self.args.channel_multiplier).to(self.device)
29 | self.netGs = Generator(
30 | self.args.size,self.args.latent,
31 | self.args.n_mlp,
32 | channel_multiplier=self.args.channel_multiplier).to(self.device)
33 |
34 | self.loadparams(model_path)
35 |
36 | self.netGs.eval()
37 | self.netGt.eval()
38 |
39 | def __call__(self,save_base):
40 | os.makedirs(save_base,exist_ok=True)
41 | steps = 0
42 | for i in range(self.args.mx_gen_iters):
43 |             fakes = self.run_single()
44 | for f in fakes:
45 | cv2.imwrite(os.path.join(save_base,'%06d.png'%steps),f)
46 | steps += 1
47 | print('\r have done %06d'%steps,end='',flush=True)
48 |
49 |     def run_single(self):
50 | noise = mixing_noise(self.args.infer_batch_size,
51 | self.args.latent,
52 | self.args.mixing,self.device)
53 | with torch.no_grad():
54 | latent_s = self.netGs(noise,only_latent=True)
55 | latent_t = self.netGt(noise,only_latent=True)
56 | # fake_s,latent_s = self.netGs(noise,return_latents=True)
57 | # fake_t,latent_t = self.netGt(noise,return_latents=True)
58 | # mix
59 | latent_mix = torch.cat([latent_s[:,:self.args.inject_index],latent_t[:,self.args.inject_index:]],1)
60 | fake,_ = self.netGt([latent_mix],input_is_latent=True)
61 | # fake = torch.cat([fake_s,fake_t,fake],-1)
62 | fake = convert_img(fake,unit=True).permute(0,2,3,1)
63 | return fake.cpu().numpy()[...,::-1]
64 |
65 |
66 | def loadparams(self,path):
67 | ckpt = torch.load(path, map_location=lambda storage, loc: storage)
68 | self.netGs.load_state_dict(ckpt['Gs'],strict=False)
69 | self.netGt.load_state_dict(ckpt['gt_ema'],strict=False)
70 |
71 | if __name__ == "__main__":
72 | args = parser.parse_args()
73 |
74 |
75 | model = Process(args.model_path)
76 |
77 | model(args.output_path)
--------------------------------------------------------------------------------
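`get_tcc_input.py` (invoked as `python get_tcc_input.py --model_path <CCN checkpoint> --output_path <dir>`) generates stylized images, presumably the target-domain training data for the TTN stage, by style-mixing two generators: W+ latents from the generator loaded as `Gs` for the first `inject_index` layers and from the one loaded as `gt_ema` for the rest, with the mixed code decoded by `netGt`. A toy illustration of that layer swap with dummy tensors (the sizes below are placeholders, not the repo's config values):

```python
import torch

n_latent, latent_dim, inject_index = 14, 512, 4   # placeholder sizes
latent_s = torch.randn(2, n_latent, latent_dim)   # stands in for netGs(noise, only_latent=True)
latent_t = torch.randn(2, n_latent, latent_dim)   # stands in for netGt(noise, only_latent=True)

# early layers from the source latent, remaining layers from the target latent
latent_mix = torch.cat([latent_s[:, :inject_index],
                        latent_t[:, inject_index:]], dim=1)
assert latent_mix.shape == (2, n_latent, latent_dim)
```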
/utils/utils.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | @author LeslieZhao
4 | @date 20220721
5 | '''
6 | import torch
7 | from data.CCNLoader import CCNData
8 | from data.TTNLoader import TTNData
9 |
10 | import os
11 | import torch.distributed as dist
12 |
13 | def reduce_loss_dict(loss_dict):
14 | world_size = int(os.environ.get("WORLD_SIZE","1"))
15 |
16 | if world_size < 2:
17 | return loss_dict
18 |
19 | with torch.no_grad():
20 | keys = []
21 | losses = []
22 |
23 | for k in sorted(loss_dict.keys()):
24 | keys.append(k)
25 | losses.append(loss_dict[k])
26 |
27 | losses = torch.stack(losses, 0)
28 | dist.reduce(losses, dst=0)
29 |
30 | if dist.get_rank() == 0:
31 | losses /= world_size
32 |
33 | reduced_losses = {k: v for k, v in zip(keys, losses)}
34 |
35 | return reduced_losses
36 |
37 | def requires_grad(model, flag=True):
38 | if model is None:
39 | return
40 | for p in model.parameters():
41 | p.requires_grad = flag
42 | def need_grad(x):
43 | x = x.detach()
44 | x.requires_grad_()
45 | return x
46 |
47 | def init_weights(model,init_type='normal', gain=0.02):
48 | def init_func(m):
49 | classname = m.__class__.__name__
50 | if classname.find('BatchNorm2d') != -1:
51 | if hasattr(m, 'weight') and m.weight is not None:
52 | torch.nn.init.normal_(m.weight.data, 1.0, gain)
53 | if hasattr(m, 'bias') and m.bias is not None:
54 | torch.nn.init.constant_(m.bias.data, 0.0)
55 |
56 |         # Conv/Linear layers: initialize the weight with the chosen scheme, then zero any bias
57 |         elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
58 |             if init_type == 'normal':
59 |                 torch.nn.init.normal_(m.weight.data, 0.0, gain)
60 |             elif init_type == 'xavier':
61 |                 torch.nn.init.xavier_normal_(m.weight.data, gain=gain)
62 |             elif init_type == 'xavier_uniform':
63 |                 torch.nn.init.xavier_uniform_(m.weight.data, gain=1.0)
64 |             elif init_type == 'kaiming':
65 |                 torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
66 |             elif init_type == 'orthogonal':
67 |                 torch.nn.init.orthogonal_(m.weight.data, gain=gain)
68 |             elif init_type == 'none': # uses pytorch's default init method
69 |                 m.reset_parameters()
70 |             else:
71 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
72 |             if hasattr(m, 'bias') and m.bias is not None:
73 |                 torch.nn.init.constant_(m.bias.data, 0.0)
74 | 
75 | model.apply(init_func)
76 |
77 | def accumulate(model1, model2, decay=0.999,use_buffer=False):
78 | par1 = dict(model1.named_parameters())
79 | par2 = dict(model2.named_parameters())
80 |
81 | for k in par1.keys():
82 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
83 |
84 | if use_buffer:
85 | for p1,p2 in zip(model1.buffers(),model2.buffers()):
86 | p1.detach().copy_(decay*p1.detach()+(1-decay)*p2.detach())
87 |
88 | def setup_seed(seed):
89 | torch.manual_seed(seed)
90 | if torch.cuda.is_available():
91 | torch.cuda.manual_seed_all(seed)
92 | torch.backends.cudnn.deterministic = True
93 |
94 | def get_data_loader(args):
95 | if args.model == 'ccn':
96 | train_data = CCNData(root=args.root,dist=args.dist)
97 | test_data = None
98 | if args.model == 'ttn':
99 | train_data = TTNData(dist=args.dist,eval=False,
100 | src_root=args.train_src_root,
101 | tgt_root=args.train_tgt_root,
102 | score_info=args.score_info)
103 | test_data = TTNData(dist=args.dist,eval=True,
104 | src_root=args.val_src_root,
105 | tgt_root=args.val_tgt_root,
106 | score_info=args.score_info)
107 |
108 | train_loader = torch.utils.data.DataLoader(
109 | train_data,
110 | batch_size=args.batch_size,
111 | num_workers=args.nDataLoaderThread,
112 | pin_memory=False,
113 | drop_last=True
114 | )
115 | test_loader = None if test_data is None else \
116 | torch.utils.data.DataLoader(
117 | test_data,
118 | batch_size=args.batch_size,
119 | num_workers=args.nDataLoaderThread,
120 | pin_memory=False,
121 | drop_last=True
122 | )
123 | return train_loader,test_loader,len(train_data)
124 |
125 |
126 |
127 | def merge_args(args,params):
128 | for k,v in vars(params).items():
129 | setattr(args,k,v)
130 | return args
131 |
132 | def convert_img(img,unit=False):
133 |
134 | img = (img + 1) * 0.5
135 | if unit:
136 | return torch.clamp(img*255+0.5,0,255)
137 |
138 | return torch.clamp(img,0,1)
--------------------------------------------------------------------------------
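`accumulate` in utils.py is the usual exponential-moving-average update for model weights (default `decay=0.999`); the `gt_ema` checkpoint key loaded in get_tcc_input.py is presumably maintained with it. A hedged usage sketch, assuming it is run from inside `utils/` so that `from utils import accumulate` resolves to the file above, the same trick get_tcc_input.py uses:

```python
import torch
from utils import accumulate   # utils/utils.py above

g     = torch.nn.Linear(4, 4)  # stand-in for the training generator
g_ema = torch.nn.Linear(4, 4)  # stand-in for its EMA copy

accumulate(g_ema, g, decay=0)      # decay=0 copies g into g_ema once at start-up
# ... then after every optimizer step:
accumulate(g_ema, g, decay=0.999)  # g_ema <- 0.999 * g_ema + 0.001 * g
```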
/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import os
7 | import time
8 | import subprocess
9 |
10 | from tensorboardX import SummaryWriter
11 |
12 |
13 |
14 | class Visualizer:
15 | def __init__(self,opt,mode='train'):
16 | self.opt = opt
17 | self.name = opt.name
18 | self.mode = mode
19 | self.train_log_dir = os.path.join(opt.checkpoint_path,"logs/%s"%mode)
20 | self.log_name = os.path.join(opt.checkpoint_path,'loss_log_%s.txt'%mode)
21 | if opt.local_rank == 0:
22 | if not os.path.exists(self.train_log_dir):
23 | os.makedirs(self.train_log_dir)
24 |
25 | self.train_writer = SummaryWriter(self.train_log_dir)
26 |
27 | self.log_file = open(self.log_name,"a")
28 | now = time.strftime("%c")
29 | self.log_file.write('================ Training Loss (%s) =================\n'%now)
30 | self.log_file.flush()
31 |
32 |
33 | # errors:dictionary of error labels and values
34 | def plot_current_errors(self,errors,step):
35 |
36 | for tag,value in errors.items():
37 |
38 | self.train_writer.add_scalar("%s/"%self.name+tag,value,step)
39 | self.train_writer.flush()
40 |
41 |
42 |     # errors: same format as |errors| of plot_current_errors
43 | def print_current_errors(self,epoch,i,errors,t):
44 | message = '(epoch: %d\t iters: %d\t time: %.5f)\t'%(epoch,i,t)
45 | for k,v in errors.items():
46 |
47 | message += '%s: %.5f\t' %(k,v)
48 |
49 | print(message)
50 |
51 | self.log_file.write('%s\n' % message)
52 | self.log_file.flush()
53 |
54 | def display_current_results(self, visuals, step):
55 | if visuals is None:
56 | return
57 | for label, image in visuals.items():
58 |             # write each labeled image tensor to the tensorboard log
59 |
60 | self.train_writer.add_image("%s/"%self.name+label,image,global_step=step)
61 |
62 | def close(self):
63 |
64 | self.train_writer.close()
65 | self.log_file.close()
--------------------------------------------------------------------------------
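For reference, a small smoke test of the Visualizer above, assuming it is run from inside `utils/`; `opt` only needs `name`, `checkpoint_path` and `local_rank`, and the values below are placeholders of ours:

```python
from argparse import Namespace
from visualizer import Visualizer   # this file

opt = Namespace(name='ttn', checkpoint_path='checkpoint', local_rank=0)
vis = Visualizer(opt, mode='train')              # creates checkpoint/logs/train and the loss log file
vis.plot_current_errors({'total_loss': 0.5}, 1)  # scalar to tensorboard
vis.print_current_errors(1, 0, {'total_loss': 0.5}, 0.0)
vis.close()
```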