├── metrics
│   ├── __init__.py
│   └── fid
│       ├── __init__.py
│       ├── inception.py
│       └── fid_score.py
├── videos
│   ├── __init__.py
│   └── test_video.py
├── faceseg
│   ├── __init__.py
│   ├── resnet.py
│   ├── FaceSegmentation.py
│   └── networks_faceseg.py
├── assets
│   ├── kid.png
│   ├── teaser.png
│   ├── ablation.png
│   ├── generator.png
│   ├── user_study.png
│   ├── bg_parsing_boy.png
│   ├── discriminator.png
│   └── face_parsing_boy.png
├── dataset
│   └── YOUR_DATASET_NAME
│       ├── testB
│       │   └── 3414.png
│       ├── trainB
│       │   └── 0006.png
│       ├── testA
│       │   └── female_2321.jpg
│       └── trainA
│           └── female_222.jpg
├── requirements.txt
├── scripts
│   ├── test.sh
│   └── train.sh
├── LICENSE
├── .gitignore
├── README.md
├── dataset.py
├── utils.py
├── histogram.py
├── main.py
├── networks.py
└── UGATIT.py
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .fid import FIDScore
2 |
--------------------------------------------------------------------------------
/videos/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Description:
4 | """
--------------------------------------------------------------------------------
/faceseg/__init__.py:
--------------------------------------------------------------------------------
1 | # ref: https://github.com/zllrunning/face-parsing.PyTorch
2 |
--------------------------------------------------------------------------------
/assets/kid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/kid.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/assets/ablation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/ablation.png
--------------------------------------------------------------------------------
/assets/generator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/generator.png
--------------------------------------------------------------------------------
/assets/user_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/user_study.png
--------------------------------------------------------------------------------
/assets/bg_parsing_boy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/bg_parsing_boy.png
--------------------------------------------------------------------------------
/assets/discriminator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/discriminator.png
--------------------------------------------------------------------------------
/metrics/fid/__init__.py:
--------------------------------------------------------------------------------
1 | # ref: https://github.com/mseitzer/pytorch-fid
2 | from .fid_score import FIDScore
3 |
--------------------------------------------------------------------------------
/assets/face_parsing_boy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/face_parsing_boy.png
--------------------------------------------------------------------------------
/dataset/YOUR_DATASET_NAME/testB/3414.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/dataset/YOUR_DATASET_NAME/testB/3414.png
--------------------------------------------------------------------------------
/dataset/YOUR_DATASET_NAME/trainB/0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/dataset/YOUR_DATASET_NAME/trainB/0006.png
--------------------------------------------------------------------------------
/dataset/YOUR_DATASET_NAME/testA/female_2321.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/dataset/YOUR_DATASET_NAME/testA/female_2321.jpg
--------------------------------------------------------------------------------
/dataset/YOUR_DATASET_NAME/trainA/female_222.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/dataset/YOUR_DATASET_NAME/trainA/female_222.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # pip install -r requirements.txt
2 | # python=3.8
3 | torch==1.9.0
4 | torchvision>=0.9.0
5 | tensorboard
6 | opencv-python
7 | scipy
8 | tqdm
9 |
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Test the model on a video + attention model + scse
3 | python main_light.py --phase video --generator_model checkpoints/big_normal_100w.pt \
4 | --use_se --attention_gan 3 --attention_input --device cpu --img_size 384 --light 32
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 郑煜伟
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Baseline training
3 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256
4 |
5 | # Histogram matching
6 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 \
7 | --match_histograms --match_mode hsl --match_prob 0.5 --match_ratio 1.0
8 |
9 | # Keep the background unchanged
10 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 \
11 | --cam_D_weight -1 --cam_D_attention --seg_fix_weight 100 --seg_D_mask --seg_G_detach
12 |
13 | # se + blur
14 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 --use_se --has_blur
15 |
16 | # attention + original image
17 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 \
18 | --attention_gan 3 --attention_input
19 |
20 | # Adjust the loss weights
21 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 \
22 | --adv_weight 1.0 --forward_adv_weight 2 --cycle_weight 5 --identity_weight 5
23 |
24 | # CPU debugging
25 | python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 --device cpu --num_workers 0
26 |
27 | # Full CPU test: histogram matching + fixed background + se + blur + attention + original image
28 | python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 --device cpu --num_workers 0 \
29 | --match_histograms --match_mode hsl --match_prob 0.5 --match_ratio 1.0 \
30 | --cam_D_weight -1 --cam_D_attention --seg_fix_weight 100 --seg_D_mask --seg_G_detach \
31 | --use_se --has_blur --attention_gan 3 --attention_input \
32 | --adv_weight 1.0 --forward_adv_weight 2 --cycle_weight 20 --identity_weight 5
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # my ignore
132 | .idea
133 | .DS_Store
134 |
--------------------------------------------------------------------------------
/faceseg/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum - 1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | # self.init_weight()
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self):
83 | state_dict = modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 |                 if module.bias is not None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 | net = Resnet18()
104 | x = torch.randn(16, 3, 224, 224)
105 | out = net(x)
106 | print(out[0].size())
107 | print(out[1].size())
108 | print(out[2].size())
109 | net.get_params()
110 |
--------------------------------------------------------------------------------
/videos/test_video.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Video/camera testing: face landmark detection, alignment, cartoonization, pasting back onto the original frame, and saving as a video
4 | """
5 | import os
6 | from functools import partial
7 |
8 | import cv2
9 | import numpy as np
10 | from PIL import Image
11 | import torch
12 | from torchvision import transforms
13 |
14 | from .face_align.align_utils import detect_face, align_face
15 |
16 |
17 | class VideoTester:
18 | """ GAN测试 """
19 |
20 | IMAGE_EXT = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
21 |
22 | def __init__(self, args, generator):
23 | self.args = args
24 | self.generator = generator
25 | self.generator.to(args.device)
26 | self.generator.eval()
27 | self.preprocess = transforms.Compose([
28 |             partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB),  # convert the BGR image read by cv2 to RGB
29 | Image.fromarray,
30 | transforms.Resize((args.img_size, args.img_size)),
31 | transforms.ToTensor(),
32 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
33 | partial(torch.unsqueeze, dim=0),
34 | ])
35 | width = 1280
36 | height = 720 * 2
37 | self.frame_size = (width, height)
38 | self.text = partial(cv2.putText, org=(width-100, 100), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
39 | fontScale=1.2, color=(255, 255, 255), thickness=2)
40 | self.fourcc = cv2.VideoWriter_fourcc(*'XVID')
41 | self.fps = 30
42 |
43 | def generate(self, img):
44 | """ 脸部(0, 255)BGR原图生成(0, 255)BGR动漫图 """
45 | img = self.preprocess(img)
46 | img = img.to(self.args.device)
47 | with torch.no_grad():
48 | img, _, _, _ = self.generator(img)
49 | img = (img.cpu()[0].numpy().transpose(1, 2, 0) + 1) * 255 / 2
50 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
51 | return img
52 |
53 | def image2image(self, image):
54 | """ 整张图像到图像的转换 """
55 | # 脸部检测及矫正
56 | facepoints = detect_face(image, detect_scale=0.25, flag_384=self.args.img_size>256)
57 | if facepoints is None:
58 | return image
59 | align_result = align_face(image, facepoints, flag_384=self.args.img_size>256)
60 | if align_result is None:
61 | return image
62 | img_align, face2mean_matrix, mean2face_matrix, _ = align_result
63 |         # generation
64 | img_gen = self.generate(img_align)
65 |         # warp back to the original position
66 | img_gen = np.clip(img_gen, 0, 255).astype(np.uint8)
67 | output = np.ones([self.args.img_size, self.args.img_size, 4], dtype=np.uint8) * 255
68 | output[:, :, :3] = img_gen
69 | output = cv2.warpAffine(output, mean2face_matrix,
70 | (image.shape[1], image.shape[0]),
71 | flags=cv2.INTER_LINEAR,
72 | borderMode=cv2.BORDER_CONSTANT, borderValue=0)
73 | alpha_img_gen = (output[:, :, 3:4] / 255.0)
74 | image_compose = image.astype(np.float32)
75 | image_compose[:, :, 0:3] = (image_compose[:, :, 0:3] * (1.0 - alpha_img_gen) + output[:, :, 0:3])
76 | image_compose = np.clip(image_compose, 0, 255).astype(np.uint8)
77 | return image_compose
78 |
79 | def record(self, cap: cv2.VideoCapture, writer: cv2.VideoWriter,
80 | frame_num: int):
81 | """ 从cap视频源读取图像,经过图像转换后,写入到writer中 """
82 | ret, frame = cap.read()
83 | if frame is None:
84 | return -1
85 |         frame = frame[:, ::-1, :]  # horizontal flip
86 | org_frame = frame.copy()
87 | new_frame = self.image2image(frame)
88 | new_frame = np.concatenate([org_frame, new_frame], axis=0)
89 | self.text(new_frame, str(frame_num))
90 | writer.write(new_frame)
91 | frame_num += 1
92 | return frame_num
93 |
94 | def video(self, path):
95 | """ 视频测试 """
96 | video_name = os.path.basename(self.args.generator_model)
97 | video_name = os.path.splitext(os.path.basename(path))[0] + '_' + video_name
98 | new_record = cv2.VideoWriter(f'{os.path.splitext(video_name)[0]}_cartoon.avi',
99 | self.fourcc, self.fps, self.frame_size)
100 |         cap = cv2.VideoCapture(path)  # open the video
101 | frame_num = 0
102 | while cap.isOpened():
103 | frame_num = self.record(cap, new_record, frame_num)
104 | if frame_num < 0:
105 | break
106 | cap.release(), new_record.release()
107 |
108 | def camera(self):
109 | """ 摄像头测试 """
110 | video_name = os.path.basename(self.args.generator_model)
111 |         cap = cv2.VideoCapture(0)  # open the camera
112 | ret = True
113 | while ret:
114 | ret, frame = cap.read()
115 | frame = self.image2image(frame)
116 | cv2.imshow(video_name, frame)
117 |             key = cv2.waitKey(1)
118 |             if key & 0xFF == ord('q'):
119 | break
120 | cap.release()
121 | cv2.destroyAllWindows()
122 |
123 | def image_dir(self, img_dir):
124 | """ 图像文件夹测试 """
125 | img_paths = [f for f in os.listdir(img_dir) if os.path.splitext(f)[-1].lower() in self.IMAGE_EXT]
126 | save_dir = os.path.join(img_dir, '..', 'gan_result')
127 | os.makedirs(save_dir, exist_ok=False)
128 | for img_path in img_paths:
129 | image = cv2.imread(os.path.join(img_dir, img_path))
130 | image_gan = self.image2image(image)
131 | cv2.imwrite(os.path.join(save_dir, img_path), image_gan)
132 |
--------------------------------------------------------------------------------
/faceseg/FaceSegmentation.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | import os
3 | import functools
4 |
5 | import cv2
6 | import torch
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torchvision import transforms
10 |
11 | from faceseg.networks_faceseg import BiSeNet
12 |
13 |
14 | class FaceSegmentation:
15 | part_map = {0: 'background', 1: 'skin', 2: 'l_brow', 3: 'r_brow', 4: 'l_eye', 5: 'r_eye', 6: 'eye_g', 7: 'l_ear',
16 | 8: 'r_ear', 9: 'ear_r', 10: 'nose', 11: 'mouth', 12: 'u_lip', 13: 'l_lip', 14: 'neck', 15: 'neck_l',
17 | 16: 'cloth', 17: 'hair', 18: 'hat'}
18 |     # colors for the different parts
19 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 0, 85], [255, 0, 170],
20 | [0, 255, 0], [85, 255, 0], [170, 255, 0], [0, 255, 85], [0, 255, 170],
21 | [0, 0, 255], [85, 0, 255], [170, 0, 255], [0, 85, 255], [0, 170, 255],
22 | [255, 255, 0], [255, 255, 85], [255, 255, 170], [255, 0, 255], [255, 85, 255],
23 | [255, 170, 255], [0, 255, 255], [85, 255, 255], [170, 255, 255]]
24 |
25 | def __init__(self, device):
26 | self.device = device
27 | self.n_classes = len(FaceSegmentation.part_map.keys())
28 |         # kernel for dilation / erosion / closing / opening
29 | self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
30 |         # load the model
31 | net = BiSeNet(n_classes=self.n_classes).to(self.device)
32 | net.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), '79999_iter.pth'),
33 | map_location=torch.device('cpu')))
34 | net.eval()
35 | self.net = net
36 | self.mean = torch.as_tensor((0.485, 0.456, 0.406), dtype=torch.float32, device=self.device).view(-1, 1, 1)
37 | self.std = torch.as_tensor((0.229, 0.224, 0.225), dtype=torch.float32, device=self.device).view(-1, 1, 1)
38 |         # preprocessing
39 | self.preprocessor = transforms.Compose([
40 | lambda x: x * 0.5 + 0.5,
41 | functools.partial(F.interpolate, size=(512, 512), mode='bilinear', align_corners=True),
42 | lambda x: x.sub_(self.mean).div_(self.std),
43 | ])
44 |
45 | def face_segmentation(self, input_x):
46 | """ 人脸分割
47 | :param input_x: NCHW 标准化的tensor, (image / 255 - 0.5) * 2
48 | :return mask N1HW tensor,每一个像素点位置取值 0~18 int数,表示属于哪一类
49 | """
50 | with torch.no_grad():
51 | img = self.preprocessor(input_x)
52 | img = img.to(self.device)
53 | out = self.net(img)[0]
54 | out = F.interpolate(out, input_x.shape[2:], mode='bicubic', align_corners=True)
55 | mask = out.detach().softmax(axis=1)
56 | return mask
57 |
58 | def gen_mask(self, mask_tensor, is_soft_edge=True,
59 | normal_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17), dilate_parts=(), erode_parts=()):
60 | """ 根据指定的parts生成mask numpy数组
61 | :param mask_tensor: face parsing 分割出来的不同类别的mask
62 | :param is_soft_edge: mask是否有软边界
63 | :param normal_parts: 不做操作的 前景类别列表
64 | :param dilate_parts: 需要做膨胀操作的 前景类别列表
65 | :param erode_parts: 需要做腐蚀操作的 前景类别列表
66 | :return mask: 前景mask
67 | """
68 | mask_tensor = mask_tensor.cpu().numpy()
69 | N, C, H, W = mask_tensor.shape
70 | normal_mask = np.zeros((N, 1, H, W), dtype=np.float32)
71 | dilate_mask = np.zeros((N, 1, H, W), dtype=np.float32)
72 | erode_mask = np.zeros((N, 1, H, W), dtype=np.float32)
73 | for i in normal_parts:
74 | normal_mask += mask_tensor[:, i, :, :]
75 | for i in dilate_parts:
76 | dilate_mask += mask_tensor[:, i, :, :]
77 | for i in erode_parts:
78 | erode_mask += mask_tensor[:, i, :, :]
79 |         # closing + dilation
80 | for i in range(len(dilate_mask)):
81 | dilate_mask[i, 0] = cv2.morphologyEx(dilate_mask[i, 0], cv2.MORPH_CLOSE, self.kernel)
82 | dilate_mask[i, 0] = cv2.dilate(dilate_mask[i, 0], self.kernel, iterations=1)
83 |         # closing + erosion
84 | for i in range(len(erode_mask)):
85 | erode_mask[i, 0] = cv2.morphologyEx(erode_mask[i, 0], cv2.MORPH_CLOSE, self.kernel)
86 | erode_mask[i, 0] = cv2.erode(erode_mask[i, 0], self.kernel, iterations=1)
87 |         mask = np.maximum(normal_mask, dilate_mask)  # union of the three regions
88 | mask = np.maximum(mask, erode_mask)
89 | if is_soft_edge:
90 |             mask = ((0.7 > mask) & (mask > 0.3)) * mask + (mask > 0.7)  # harden high/low-probability regions, keep the transition band soft
91 | else:
92 | mask = (mask >= 0.5).astype(np.float32)
93 | mask = torch.from_numpy(mask).to(self.device)
94 | return mask
95 |
96 | def vis(self, image, mask, is_show=False):
97 | """ 可视化人脸分割结果,如果mask包含不同类别,会将不同类别用不同的颜色mask可视化,如果只有一种类别,会将前景用蓝色mask可视化
98 | :param image: 待可视化图像
99 | :param mask: 待可视化图像分割后的mask
100 | :param is_show:是否用cv2.imshow可视化
101 | :return 可视化的图像
102 | """
103 | mask = mask.cpu().numpy()
104 | vis_im = (image.numpy().transpose(1, 2, 0) * 127.5 + 127.5).astype(np.uint8)
105 | vis_mask_color = np.zeros((mask.shape[0], mask.shape[1], 3)) + 255
106 |
107 | if mask.dtype == np.int64:
108 | for pi in range(1, self.n_classes):
109 | index = np.where(mask == pi)
110 | vis_mask_color[index[0], index[1], :] = self.part_colors[pi]
111 | else:
112 | mask = np.repeat(np.expand_dims(mask, axis=-1), repeats=3, axis=-1)
113 | vis_mask_color = np.array([[self.part_colors[1]]], dtype=np.float32) * mask + (1 - mask) * vis_mask_color
114 |
115 | vis_mask_color = vis_mask_color.astype(np.uint8)
116 | vis_im_hm = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.5, vis_mask_color, 0.5, 0)
117 | if is_show:
118 | cv2.imshow('seg', np.concatenate([vis_im_hm, cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR)], axis=1))
119 | cv2.waitKey(0)
120 | return np.concatenate([vis_im_hm, cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR)], axis=1)
121 |
122 |
123 | if __name__ == '__main__':
124 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
125 | img_dir = '../dataset/cartoon/testA'
126 | face_seg = FaceSegmentation('cpu')
127 | train_transform = transforms.Compose([
128 | cv2.imread,
129 | functools.partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB),
130 | transforms.ToTensor(),
131 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
132 | functools.partial(torch.unsqueeze, dim=0)
133 | ])
134 | for f_name in sorted(os.listdir(img_dir)):
135 | test_img = train_transform(os.path.join(img_dir, f_name))
136 | test_mask = face_seg.face_segmentation(test_img)
137 |         # comment out these three statements and use the line below instead to see all classes
138 | test_mask0 = face_seg.gen_mask(test_mask, normal_parts=(), dilate_parts=(),
139 | erode_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17))
140 | test_mask1 = face_seg.gen_mask(test_mask, normal_parts=(1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17),
141 | dilate_parts=(), erode_parts=(6, ))
142 | test_mask = test_mask0 * test_mask1
143 | # test_mask = test_mask.argmax(1, keepdims=True)
144 | vis_img = face_seg.vis(test_img[0], test_mask[0, 0], is_show=True)
145 | cv2.imwrite(f_name, vis_img)
146 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # enhanced-U-GAT-IT
2 |
3 | Zhihu blog posts explaining the papers:
4 | - [Paper notes | Image-to-image translation (5): AttentionGAN](https://zhuanlan.zhihu.com/p/168382844)
5 | - [Paper notes | Image-to-image translation (6): U-GAT-IT](https://zhuanlan.zhihu.com/p/270958248)
6 |
7 | There is more content in that column and other columns; technical discussion is welcome.
8 |
9 | ## Usage
10 |
11 | ```shell
12 | CUDA_VISIBLE_DEVICES=2 python main.py --dataset YOUR_DATASET_NAME --result_dir YOUR_RESULT_DIR \
13 | # reducing light appropriately lowers GPU memory usage
14 | --light 32 \
15 | --img_size 384 --aug_prob 0.2 --device cuda:0 --num_workers 1 --print_freq 1000 --calc_fid_freq 10000 --save_freq 100000 \
16 | # increasing forward_adv_weight strengthens the A2B stylization
17 | --adv_weight 1.0 --forward_adv_weight 1 --cycle_weight 10 --identity_weight 10 \
18 | # EMA greatly improves model stability and FID
19 | --ema_start 0.5 --ema_beta 0.9999 \
20 | # options for keeping the background fixed:
21 | # drop the CAM and logit loss in the discriminator
22 | --cam_D_weight -1 --cam_D_attention \
23 | # L1 loss on the background region segmented by faceseg
24 | --seg_fix_weight 100 \
25 | # the discriminator loss ignores the background (mask) region
26 | --seg_D_mask \
27 | # detach the background of the generated image
28 | --seg_G_detach \
29 | # the discriminator can be configured to be more local
30 | --n_global_dis 7 --n_local_dis 5 \
31 | --has_blur --use_se \
32 | # if attention_input is combined with the fixed background, raise the adversarial loss weight: adv_weight -> 10 or even larger
33 | --attention_gan 2 --attention_input
34 | ```
35 |
36 | For local debugging: `--device cuda:0` -> `--device cpu`.
37 |
38 | ## Improvements over the official U-GAT-IT [1]
39 |
40 | - Code-structure improvements:
41 |     - Modularized parts of the `train` function in `UGATIT.py` into separate functions (`get_batch, forward, backward_G, backward_D`);
42 |     - Added printing of the training losses and TensorBoard logging;
43 |     - Removed a redundant forward pass from the official code, saving roughly 20% of training time.
44 |
45 | - Optional feature enhancements:
46 |     - `--ema_start`: fraction of `--iteration` after which model EMA starts, i.e. `--iteration 100 --ema_start 0.7` starts the EMA after `100 * 0.7 = 70` iterations (see the EMA sketch after this list);
47 |     - `--ema_beta`: exponential-moving-average coefficient of the model weights;
48 |     - `--calc_fid_freq`: how often the FID score is computed; should be an integer multiple of `--print_freq`;
49 |     - `--fid_batch`: inference `batch size` used when computing the FID score;
50 |     - `--adv_weight, forward_adv_weight, backward_adv_weight`: splits the adversarial-loss weight into three (the original adversarial weight, the A2B adversarial weight, and the B2A adversarial weight);
51 |     - `--aug_prob`: probability of the `RandomResizedCrop` augmentation;
52 |     - `--match_histograms`: whether to histogram-match the two domains so that their distributions agree; [2]
53 |     - `--match_mode`: color space used for histogram matching, one of (hsv, hsl, rgb);
54 |     - `--match_prob`: probability of histogram-matching a domain-B image to domain A; otherwise the domain-A image is matched to domain B;
55 |     - `--match_ratio`: blending ratio of the histogram match, matched image * ratio + original image * (1 - ratio);
56 |     - `--sn`: whether to apply spectral normalization after the convolutions in the discriminator; enabled by default (present in the official TF code, absent in the official PyTorch code);
57 |     - `--has_blur`: adds a blur-adversarial loss term when training/updating the discriminator, strengthening D against blurry images; [3]
58 |     - `--use_se`: whether to add an `scse block` to every `ResnetBlock, ResnetAdaILNBlock` in the generator; [4]
59 |     - `--attention_gan`: whether to use the `attention mask` mechanism in the generator (an extra branch before upsampling that produces the mask); [5]
60 |     - `--use_deconv`: whether the generator upsamples with transposed convolution `nn.ConvTranspose2d`; interpolation + convolution is used by default;
61 |     - `--seg_fix_weight`: weight of the face-segmentation [6] loss (this repo defaults to the binary segmentation shown below; the referenced library can parse multiple classes), an L1 loss between the original and generated images over the background region;
62 |     - `--seg_D_mask`: zero out the adversarial loss and the discriminator loss in the background region;
63 |     - `--seg_G_detach`: detach the background region of the generated image before feeding it to D for the adversarial loss;
64 |     - `--seg_D_cam_inp_mask`: when training D, replace the whole background with 0 in the input to the CAM branch;
65 |     - `--seg_D_cam_fea_mask`: when training D, replace the background region of the CAM branch's feature map with 0;
66 |     - `--cam_D_weight -1 --cam_D_attention`: remove the CAM attention and logit loss.
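
A minimal sketch of the EMA update that `--ema_start` / `--ema_beta` control (the helper and variable names below are illustrative assumptions, not the repo's actual implementation):

```python
import copy

import torch


def update_ema(model: torch.nn.Module, ema_model: torch.nn.Module, beta: float = 0.9999):
    """EMA of the weights: ema = beta * ema + (1 - beta) * current."""
    with torch.no_grad():
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            p_ema.mul_(beta).add_(p, alpha=1.0 - beta)
        for b_ema, b in zip(ema_model.buffers(), model.buffers()):
            b_ema.copy_(b)  # copy buffers (e.g. normalization statistics) as-is


# usage sketch: keep a frozen copy and start averaging after iteration * ema_start steps
# ema_genA2B = copy.deepcopy(genA2B)
# if step >= int(args.iteration * args.ema_start):
#     update_ema(genA2B, ema_genA2B, beta=args.ema_beta)
```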
67 |
68 |  
69 |
70 | - Other features:
71 |     - `--phase`: run mode: training / validation / testing / video / video folder / camera / image folder / pre-aligned face-image folder (`'train', 'val', 'test', 'video', 'video_dir', 'camera', 'img_dir', 'generate'`);
72 |
73 | ## Repository layout
74 |
75 | - `assets`: images used in `README.md`;
76 | - `dataset`: image data;
77 | - `faceseg`: face-segmentation code and the weights folder;
78 | - `metrics`: FID score computation;
79 | - `scripts`: shell scripts with run commands; see these for reference training and testing commands;
80 | - `videos`: video testing scripts;
81 | - `dataset.py`: dataset classes;
82 | - `histogram.py`: histogram matching;
83 | - `main.py`: program entry point and configuration;
84 | - `networks.py`: generator and discriminator architectures;
85 | - `UGATIT.py`: the full training and testing pipeline;
86 | - `utils.py`: utility functions;
87 |
88 | ## Additional experimental notes
89 |
90 | - Code fixes: the official main repo is actually the TensorFlow one, and the PyTorch version was written from it with quite a few discrepancies; this repo fixes some of them;
91 | - Loss weights: `CycleGAN` experiments show that moderately increasing the A2B loss weights improves the quality of A2B generation;
92 | - Transposed convolution and spectral normalization: `CycleGAN` experiments show that using transposed convolutions and removing spectral normalization let the generator change the input image more strongly;
93 | - The blur-adversarial loss slightly improves texture discrimination;
94 | - attention mask fusion: experiments in this repo show little improvement; other tasks need further experiments. An option is also provided to fuse regions of the input image with the generated image;
95 | - scse block: experiments in this repo show little improvement; other tasks need further experiments;
96 | - When keeping the non-face background unchanged, `faceseg` helps to some extent and can be combined with randomly recoloring the background;
97 | - Histogram matching helps when the two domains are close but their distributions differ, e.g. adult -> child; in other cases matching a single channel may be enough, e.g. matching the value channel to make the lighting consistent;
98 | - When the two domains differ in hue or similar color statistics, also consider converting the images fed to D to grayscale before discrimination, similar to AnimeGAN (a sketch is shown below).
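
A minimal sketch of that grayscale trick, assuming the usual (-1, 1)-normalized NCHW RGB tensors; the helper name is an illustrative assumption, not part of this repo:

```python
import torch


def to_gray_3ch(x: torch.Tensor) -> torch.Tensor:
    """Convert a (-1, 1)-normalized NCHW RGB batch to 3-channel grayscale before it is fed to D."""
    weights = torch.tensor([0.299, 0.587, 0.114], device=x.device).view(1, 3, 1, 1)  # BT.601 luma
    gray = (x * weights).sum(dim=1, keepdim=True)
    return gray.repeat(1, 3, 1, 1)  # keep 3 channels so the discriminator architecture is unchanged
```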
99 |
100 | ## reference
101 |
102 | [1] Official PyTorch implementation of U-GAT-IT, and the baseline of this repo: https://github.com/znxlwm/UGATIT-pytorch;
103 | [2] Histogram matching, based on [Bringing-Old-Photos-Back-to-Life](https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life) ;
104 | [3] CartoonGAN: Generative Adversarial Networks for Photo Cartoonization;
105 | [4] Concurrent Spatial and Channel ‘Squeeze & Excitation’ in Fully Convolutional Networks;
106 | [5] AttentionGAN: Unpaired Image-to-Image Translation using Attention-Guided Generative Adversarial Networks;
107 | [6] Face-parsing model repo: [zllrunning/face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch) ;
108 |
109 |
110 | ------
111 | (The original official README follows below.)
112 |
113 | ## U-GAT-IT — Official PyTorch Implementation
114 | ### : Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation
115 |
116 |
117 | 
118 |
119 |
120 | ### [Paper](https://arxiv.org/abs/1907.10830) | [Official Tensorflow code](https://github.com/taki0112/UGATIT)
121 | The results of the paper came from the **Tensorflow code**
122 |
123 |
124 | > **U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation**
125 | >
126 | > **Abstract** *We propose a novel method for unsupervised image-to-image translation, which incorporates a new attention module and a new learnable normalization function in an end-to-end manner. The attention module guides our model to focus on more important regions distinguishing between source and target domains based on the attention map obtained by the auxiliary classifier. Unlike previous attention-based methods which cannot handle the geometric changes between domains, our model can translate both images requiring holistic changes and images requiring large shape changes. Moreover, our new AdaLIN (Adaptive Layer-Instance Normalization) function helps our attention-guided model to flexibly control the amount of change in shape and texture by learned parameters depending on datasets. Experimental results show the superiority of the proposed method compared to the existing state-of-the-art models with a fixed network architecture and hyper-parameters.*
127 |
128 | ## Usage
129 | ```
130 | ├── dataset
131 | └── YOUR_DATASET_NAME
132 | ├── trainA
133 | ├── xxx.jpg (name, format doesn't matter)
134 | ├── yyy.png
135 | └── ...
136 | ├── trainB
137 | ├── zzz.jpg
138 | ├── www.png
139 | └── ...
140 | ├── testA
141 | ├── aaa.jpg
142 | ├── bbb.png
143 | └── ...
144 | └── testB
145 | ├── ccc.jpg
146 | ├── ddd.png
147 | └── ...
148 | ```
149 |
150 | ### Train
151 | ```
152 | > python main.py --dataset selfie2anime
153 | ```
154 | * If the memory of gpu is **not sufficient**, set `--light` to True
155 |
156 | ### Test
157 | ```
158 | > python main.py --dataset selfie2anime --phase test
159 | ```
160 |
161 | ## Architecture
162 | 
163 |
164 | ---
165 |
166 | 
167 |
168 | ## Results
169 | ### Ablation study
170 | 
171 |
172 | ### User study
173 | 
174 |
175 | ### Comparison
176 | 
177 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import os.path
4 | from queue import Queue
5 | from threading import Thread
6 |
7 | import cv2
8 | import torch
9 | import torch.utils.data
10 | import numpy as np
11 |
12 | from histogram import match_histograms
13 |
14 |
15 | def get_loader(my_dataset, device, batch_size, num_workers, shuffle):
16 | """ 根据dataset及设置,获取对应的 DataLoader """
17 | my_loader = torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, num_workers=num_workers,
18 | shuffle=shuffle, pin_memory=True, persistent_workers=(num_workers > 0))
19 | # if torch.cuda.is_available():
20 | # my_loader = CudaDataLoader(my_loader, device=device)
21 | return my_loader
22 |
23 |
24 | class MatchHistogramsDataset(torch.utils.data.Dataset):
25 |
26 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
27 |
28 |     def __init__(self, root, transform=None, target_transform=None, is_match_histograms=False, match_mode='hsv',
29 |                  b2a_prob=0.5, match_ratio=1.0):
30 |         """ Dataset that returns a pair of image numpy arrays, one from each of the two given folders """
31 |         assert len(root) == 2, 'root of MatchHistogramsDataset must contain two dirs!'
32 | self.dataset_0 = DatasetFolder(root[0])
33 | self.dataset_1 = DatasetFolder(root[1])
34 |
35 | self.transform = transform
36 | self.target_transform = target_transform
37 | self.len_0 = len(self.dataset_0)
38 | self.len_1 = len(self.dataset_1)
39 | self.len = max(self.len_0, self.len_1)
40 | self.is_match_histograms = is_match_histograms
41 | self.match_mode = match_mode
42 |         assert self.match_mode in ('hsv', 'hsl', 'rgb'), f"match_mode must be one of ('hsv', 'hsl', 'rgb'), got {self.match_mode}"
43 | self.b2a_prob = b2a_prob
44 | self.match_ratio = match_ratio
45 |
46 | def __getitem__(self, index):
47 | sample_0 = self.dataset_0[index] if index < self.len_0 else self.dataset_0[np.random.randint(self.len_0)]
48 | sample_1 = self.dataset_1[index] if index < self.len_1 else self.dataset_1[np.random.randint(self.len_1)]
49 |
50 | if self.is_match_histograms:
51 | if self.match_mode == 'hsv':
52 | sample_0 = cv2.cvtColor(sample_0, cv2.COLOR_RGB2HSV_FULL)
53 | sample_1 = cv2.cvtColor(sample_1, cv2.COLOR_RGB2HSV_FULL)
54 | elif self.match_mode == 'hsl':
55 | sample_0 = cv2.cvtColor(sample_0, cv2.COLOR_RGB2HLS_FULL)
56 | sample_1 = cv2.cvtColor(sample_1, cv2.COLOR_RGB2HLS_FULL)
57 |
58 | if np.random.rand() < self.b2a_prob:
59 | sample_1 = match_histograms(sample_1, sample_0, rate=self.match_ratio)
60 | else:
61 | sample_0 = match_histograms(sample_0, sample_1, rate=self.match_ratio)
62 |
63 | if self.match_mode == 'hsv':
64 | sample_0 = cv2.cvtColor(sample_0, cv2.COLOR_HSV2RGB_FULL)
65 | sample_1 = cv2.cvtColor(sample_1, cv2.COLOR_HSV2RGB_FULL)
66 | elif self.match_mode == 'hsl':
67 | sample_0 = cv2.cvtColor(sample_0, cv2.COLOR_HLS2RGB_FULL)
68 | sample_1 = cv2.cvtColor(sample_1, cv2.COLOR_HLS2RGB_FULL)
69 |
70 | if self.transform is not None:
71 | sample_0 = self.transform(sample_0)
72 | sample_1 = self.transform(sample_1)
73 |
74 | return sample_0, sample_1
75 |
76 | def __len__(self):
77 | return self.len
78 |
79 | def __repr__(self):
80 | fmt_str = f'MatchHistogramsDataset for: \n' \
81 | f'{self.dataset_0.__repr__()} \n ' \
82 | f'{self.dataset_1.__repr__()}'
83 | return fmt_str
84 |
85 |
86 | class DatasetFolder(torch.utils.data.Dataset):
87 |
88 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
89 |
90 | def __init__(self, root, transform=None):
91 | """ 获取指定文件夹下,单张图像numpy数组的Dataset """
92 | samples = []
93 | for sub_root, _, filenames in sorted(os.walk(root)):
94 | for filename in sorted(filenames):
95 | if os.path.splitext(filename)[-1].lower() in self.IMG_EXTENSIONS:
96 | path = os.path.join(sub_root, filename)
97 | samples.append(path)
98 |
99 | if len(samples) == 0:
100 | raise RuntimeError(f"Found 0 files in sub-folders of: {root}\n"
101 | f"Supported extensions are: {','.join(self.IMG_EXTENSIONS)}")
102 |
103 | self.root = root
104 | self.samples = samples
105 | self.transform = transform
106 |
107 | def __getitem__(self, index):
108 | path = self.samples[index]
109 | sample = cv2.imread(path)[..., ::-1]
110 | if self.transform is not None:
111 | sample = self.transform(sample)
112 | return sample
113 |
114 | def __len__(self):
115 | return len(self.samples)
116 |
117 | def __repr__(self):
118 | fmt_str = f'Dataset {self.__class__.__name__}\n'\
119 | f' Number of data points: {self.__len__()}\n'\
120 | f' Root Location: {self.root}\n'
121 | tmp = ' Transforms (if any): '
122 | trans_tmp = self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))
123 | fmt_str += f'{tmp}{trans_tmp}'
124 | return fmt_str
125 |
126 |
127 | class CudaDataLoader:
128 | """ 异步预先将数据从CPU加载到GPU中 """
129 |
130 | def __init__(self, loader, device, queue_size=2):
131 | self.device = device
132 | self.queue_size = queue_size
133 | self.loader = loader
134 |
135 | self.load_stream = torch.cuda.Stream(device=device)
136 | self.queue = Queue(maxsize=self.queue_size)
137 |
138 | self.idx = 0
139 | self.worker = Thread(target=self.load_loop)
140 |         self.worker.daemon = True
141 | self.worker.start()
142 |
143 | def load_loop(self):
144 | """ 不断的将cuda数据加载到队列里 """
145 | # The loop that will load into the queue in the background
146 | torch.cuda.set_device(self.device)
147 | while True:
148 | for i, sample in enumerate(self.loader):
149 | self.queue.put(self.load_instance(sample))
150 |
151 | def load_instance(self, sample):
152 | """ 将batch数据从CPU加载到GPU中 """
153 | if torch.is_tensor(sample):
154 | with torch.cuda.stream(self.load_stream):
155 | return sample.to(self.device, non_blocking=True)
156 | elif sample is None or type(sample) in (list, str):
157 | return sample
158 | elif isinstance(sample, dict):
159 | return {k: self.load_instance(v) for k, v in sample.items()}
160 | else:
161 | return [self.load_instance(s) for s in sample]
162 |
163 | def __iter__(self):
164 | self.idx = 0
165 | return self
166 |
167 | def __next__(self):
168 |         # the loader thread has died
169 | if not self.worker.is_alive() and self.queue.empty():
170 | self.idx = 0
171 | self.queue.join()
172 | self.worker.join()
173 | raise StopIteration
174 |         # one epoch has been fully loaded
175 | elif self.idx >= len(self.loader):
176 | self.idx = 0
177 | raise StopIteration
178 |         # next batch
179 | else:
180 | out = self.queue.get()
181 | self.queue.task_done()
182 | self.idx += 1
183 | return out
184 |
185 | def next(self):
186 | return self.__next__()
187 |
188 | def __len__(self):
189 | return len(self.loader)
190 |
191 | @property
192 | def sampler(self):
193 | return self.loader.sampler
194 |
195 | @property
196 | def dataset(self):
197 | return self.loader.dataset
198 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 | import random
5 | from typing import Any, Union
6 |
7 | import cv2
8 | import torch
9 | import numpy as np
10 | from scipy import misc
11 | from tqdm import tqdm
12 |
13 |
14 | def calc_tv_loss(inp, mask=None, eps=1e-8):
15 | """ 提供inp平滑性约束 """
16 | x_diff = inp[:, :, 1:, 1:] - inp[:, :, :-1, 1:]
17 | y_diff = inp[:, :, 1:, 1:] - inp[:, :, 1:, :-1]
18 | if mask is not None:
19 | x_diff *= mask[:, :, 1:, 1:]
20 | y_diff *= mask[:, :, 1:, 1:]
21 |
22 | # loss = torch.mean(torch.abs(x_diff) + torch.abs(y_diff))
23 |     loss = torch.mean(torch.sqrt(torch.square(x_diff) + torch.square(y_diff) + eps))
24 | return loss
25 |
26 |
27 | def load_test_data(image_path, size=256):
28 | img = misc.imread(image_path, mode='RGB')
29 | img = misc.imresize(img, [size, size])
30 | img = np.expand_dims(img, axis=0)
31 | img = preprocessing(img)
32 |
33 | return img
34 |
35 |
36 | def preprocessing(x):
37 | x = x / 127.5 - 1 # -1 ~ 1
38 | return x
39 |
40 |
41 | def save_images(images, size, image_path):
42 | return imsave(inverse_transform(images), size, image_path)
43 |
44 |
45 | def inverse_transform(images):
46 | return (images + 1.) / 2
47 |
48 |
49 | def imsave(images, size, path):
50 | return misc.imsave(path, merge(images, size))
51 |
52 |
53 | def merge(images, size):
54 | h, w = images.shape[1], images.shape[2]
55 | img = np.zeros((h * size[0], w * size[1], 3))
56 | for idx, image in enumerate(images):
57 | i = idx % size[1]
58 | j = idx // size[1]
59 | img[h * j:h * (j + 1), w * i:w * (i + 1), :] = image
60 |
61 | return img
62 |
63 |
64 | def check_folder(log_dir):
65 | if not os.path.exists(log_dir):
66 | os.makedirs(log_dir)
67 | return log_dir
68 |
69 |
70 | def cam(x, size=256):
71 | x = x - np.min(x)
72 | cam_img = x / np.max(x)
73 | cam_img = np.uint8(255 * cam_img)
74 | cam_img = cv2.resize(cam_img, (size, size))
75 | cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET)
76 | return cam_img / 255.0
77 |
78 |
79 | def attention_mask(x, size=256):
80 | attention_img = cv2.resize(np.uint8(255 * x), (size, size))
81 | attention_img = cv2.applyColorMap(attention_img, cv2.COLORMAP_JET)
82 | return attention_img / 255.0
83 |
84 |
85 | def imagenet_norm(x):
86 | mean = [0.485, 0.456, 0.406]
87 |     std = [0.229, 0.224, 0.225]
88 | mean = torch.FloatTensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)
89 | std = torch.FloatTensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)
90 | return (x - mean) / std
91 |
92 |
93 | def denorm(x):
94 | return x * 0.5 + 0.5
95 |
96 |
97 | def tensor2numpy(x):
98 | return x.detach().cpu().numpy().transpose(1, 2, 0)
99 |
100 |
101 | def RGB2BGR(x):
102 | return cv2.cvtColor(x, cv2.COLOR_RGB2BGR)
103 |
104 |
105 | def setup_seed(seed):
106 | torch.manual_seed(seed)
107 | torch.cuda.manual_seed_all(seed)
108 | np.random.seed(seed)
109 | random.seed(seed)
110 | torch.backends.cudnn.deterministic = True
111 |
112 |
113 | """
114 | Metrics: tracking and printing
115 | """
116 |
117 |
118 | class AverageMeter:
119 | """ 计算并存储 评估量的均值和当前值 """
120 |
121 | def __init__(self, name, fmt=':f'):
122 |         self.name = name  # metric name
123 |         self.fmt = fmt  # print format
124 |         self.val = 0  # current value
125 |         self.avg = 0  # running average
126 |         self.sum = 0  # sum of all recorded values
127 |         self.count = 0  # number of recorded values
128 |
129 | def reset(self):
130 | self.val = 0
131 | self.avg = 0
132 | self.sum = 0
133 | self.count = 0
134 |
135 | def update(self, val, n=1):
136 | self.val = val
137 | self.sum += val * n
138 | self.count += n
139 | self.avg = self.sum / self.count
140 |
141 | def __str__(self):
142 | fmtstr = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
143 | return fmtstr.format(**self.__dict__)
144 |
145 |
146 | class ProgressMeter:
147 | """ 评估量的进度条打印 """
148 |
149 | def __init__(self, num_batches, *meters, prefix=""):
150 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
151 | self.meters = meters
152 | self.prefix = prefix
153 |
154 | def print(self, batch):
155 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
156 | entries += [str(meter) for meter in self.meters]
157 | print('\t'.join(entries))
158 |
159 | @staticmethod
160 | def _get_batch_fmtstr(num_batches):
161 | num_digits = len(str(num_batches // 1))
162 | fmt = f'{{:{str(num_digits)}d}}'
163 | return f'[{fmt}/{fmt.format(num_batches)}]'
164 |
165 |
166 | class Logger(object):
167 | """
168 | Redirect stderr to stdout, optionally print stdout to a file,
169 | and optionally force flushing on both stdout and the file.
170 | """
171 |
172 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
173 | self.file = None
174 |
175 | if file_name is not None:
176 | self.file = open(file_name, file_mode)
177 |
178 | self.should_flush = should_flush
179 | self.stdout = sys.stdout
180 | self.stderr = sys.stderr
181 |
182 | sys.stdout = self
183 | sys.stderr = self
184 |
185 | def __enter__(self) -> "Logger":
186 | return self
187 |
188 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
189 | self.close()
190 |
191 | def write(self, text: Union[str, bytes]) -> None:
192 | """Write text to stdout (and a file) and optionally flush."""
193 | if isinstance(text, bytes):
194 | text = text.decode()
195 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
196 | return
197 |
198 | if self.file is not None:
199 | self.file.write(text)
200 |
201 | self.stdout.write(text)
202 |
203 | if self.should_flush:
204 | self.flush()
205 |
206 | def flush(self) -> None:
207 | """Flush written text to both stdout and a file, if open."""
208 | if self.file is not None:
209 | self.file.flush()
210 |
211 | self.stdout.flush()
212 |
213 | def close(self) -> None:
214 | """Flush, close possible files, and remove stdout/stderr mirroring."""
215 | self.flush()
216 |
217 | # if using multiple loggers, prevent closing in wrong order
218 | if sys.stdout is self:
219 | sys.stdout = self.stdout
220 | if sys.stderr is self:
221 | sys.stderr = self.stderr
222 |
223 | if self.file is not None:
224 | self.file.close()
225 | self.file = None
226 |
227 |
228 | """
229 | Blurred-image generation
230 | """
231 |
232 |
233 | def generate_blur_images(root, save):
234 | """ 根据清晰图像制作模糊的图像
235 | :param root: 清晰图像所在的根目录
236 | :param save: 模糊图像存放的根目录
237 | """
238 | print(f'generating blur images: {root} to {save}...')
239 | file_list = os.listdir(root)
240 | if not os.path.isdir(save):
241 | os.makedirs(save)
242 | kernel_size = 5
243 | kernel = np.ones((kernel_size, kernel_size), np.uint8)
244 | gauss = cv2.getGaussianKernel(kernel_size, 0)
245 | gauss = gauss * gauss.transpose(1, 0)
246 | for f in tqdm(file_list):
247 | try:
248 | rgb_img = cv2.imread(os.path.join(root, f))
249 | gray_img = cv2.imread(os.path.join(root, f), 0)
250 | pad_img = np.pad(rgb_img, ((2, 2), (2, 2), (0, 0)), mode='reflect')
251 | edges = cv2.Canny(gray_img, 100, 200)
252 | dilation = cv2.dilate(edges, kernel)
253 |
254 | gauss_img = np.copy(rgb_img)
255 | idx = np.where(dilation != 0)
256 | for i in range(np.sum(dilation != 0)):
257 | gauss_img[idx[0][i], idx[1][i], 0] = np.sum(
258 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 0],
259 | gauss))
260 | gauss_img[idx[0][i], idx[1][i], 1] = np.sum(
261 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 1],
262 | gauss))
263 | gauss_img[idx[0][i], idx[1][i], 2] = np.sum(
264 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 2],
265 | gauss))
266 |
267 | cv2.imwrite(os.path.join(save, f), gauss_img)
268 | except Exception as e:
269 | print(f'{f} failed!\n{e}')
270 |
271 | print(f'finish: blur images over! ')
272 |
--------------------------------------------------------------------------------
/histogram.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Histogram matching: match one image's histogram to a target image so that the two images look visually similar.
4 | ref https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/blob/95ba2834a358fa243665c86407b220e4e78854fe/Face_Detection/align_warp_back_multiple_dlib.py
5 | """
6 | import cv2
7 | import numpy as np
8 |
9 |
10 | def match_histograms(src_image, ref_image, rate=1.0, image_type='HWC'):
11 | """
12 | This method matches the source image histogram to the
13 | reference signal
14 | :param image src_image: The original source image
15 | :param image ref_image: The reference image
16 | :param rate: histograms shift ratio
17 | :param image_type: HWC or CHW
18 | :return: image_after_matching
19 | :rtype: image (array)
20 | """
21 | # Split the images into the different color channels
22 | # b means blue, g means green and r means red
23 | if image_type == 'HWC':
24 | src_b, src_g, src_r = cv2.split(src_image)
25 | ref_b, ref_g, ref_r = cv2.split(ref_image)
26 | elif image_type == 'CHW':
27 | src_b, src_g, src_r = src_image[0], src_image[1], src_image[2]
28 | ref_b, ref_g, ref_r = ref_image[0], ref_image[1], ref_image[2]
29 | else:
30 | raise ValueError(f'image_type only HWC or CHW, no: {image_type}')
31 |
32 | # Compute the b, g, and r histograms separately
33 | # The flatten() Numpy method returns a copy of the array c
34 | # collapsed into one dimension.
35 | src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256])
36 | src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256])
37 | src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256])
38 | ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256])
39 | ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256])
40 | ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256])
41 |
42 | # Compute the normalized cdf for the source and reference image
43 | src_cdf_blue = calculate_cdf(src_hist_blue)
44 | src_cdf_green = calculate_cdf(src_hist_green)
45 | src_cdf_red = calculate_cdf(src_hist_red)
46 | ref_cdf_blue = calculate_cdf(ref_hist_blue)
47 | ref_cdf_green = calculate_cdf(ref_hist_green)
48 | ref_cdf_red = calculate_cdf(ref_hist_red)
49 |
50 | if rate < 1.0:
51 | ref_cdf_blue = src_cdf_blue * (1.0 - rate) + ref_cdf_blue * rate
52 | ref_cdf_green = src_cdf_green * (1.0 - rate) + ref_cdf_green * rate
53 | ref_cdf_red = src_cdf_red * (1.0 - rate) + ref_cdf_red * rate
54 |
55 | # Make a separate lookup table for each color
56 | blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
57 | green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
58 | red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
59 |
60 | # Use the lookup function to transform the colors of the original
61 | # source image
62 | blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
63 | green_after_transform = cv2.LUT(src_g, green_lookup_table)
64 | red_after_transform = cv2.LUT(src_r, red_lookup_table)
65 |
66 | # Put the image back together
67 | if image_type == 'HWC':
68 | image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
69 | elif image_type == 'CHW':
70 | image_after_matching = np.array([blue_after_transform, green_after_transform, red_after_transform])
71 | else:
72 | raise ValueError(f'image_type only HWC or CHW, no: {image_type}')
73 |
74 | image_after_matching = cv2.convertScaleAbs(image_after_matching)
75 |
76 | return image_after_matching
77 |
78 |
79 | def calculate_cdf(histogram: np.ndarray) -> np.ndarray:
80 | """
81 | This method calculates the cumulative distribution function
82 | :param array histogram: The values of the histogram
83 | :return: normalized_cdf: The normalized cumulative distribution function
84 | :rtype: array
85 | """
86 | # Get the cumulative sum of the elements
87 | cdf = histogram.cumsum()
88 |
89 | # Normalize the cdf
90 | normalized_cdf = cdf / float(cdf.max())
91 |
92 | return normalized_cdf
93 |
94 |
95 | def calculate_lookup(src_cdf: np.ndarray, ref_cdf: np.ndarray) -> np.ndarray:
96 | """
97 | This method creates the lookup table
98 | :param array src_cdf: The cdf for the source image
99 | :param array ref_cdf: The cdf for the reference image
100 | :return: lookup_table: The lookup table
101 | :rtype: array
102 | """
103 | lookup_table = np.zeros(256)
104 | lookup_val = 0
105 | for src_pixel_val in range(len(src_cdf)):
106 | for ref_pixel_val in range(len(ref_cdf)):
107 | if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
108 | lookup_val = ref_pixel_val
109 | break
110 | lookup_table[src_pixel_val] = lookup_val
111 | return lookup_table
112 |
113 |
114 | if __name__ == '__main__':
115 | import numpy as np
116 |
117 | kid_src = cv2.imread('dataset/kid_src.png')
118 | man_src = cv2.imread('dataset/man_src.png')
119 |
120 |     # full matching, rate = 1
121 | kid_match_bgr_10 = match_histograms(kid_src, man_src, rate=1)
122 | kid_match_hsv_10 = match_histograms(cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL),
123 | cv2.cvtColor(man_src, cv2.COLOR_BGR2HSV_FULL), rate=1)
124 | kid_match_hls_10 = match_histograms(cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL),
125 | cv2.cvtColor(man_src, cv2.COLOR_BGR2HLS_FULL), rate=1)
126 |     # partial matching, rate = 0.5
127 | kid_match_bgr_05 = match_histograms(kid_src, man_src, rate=0.5)
128 | kid_match_hsv_05 = match_histograms(cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL),
129 | cv2.cvtColor(man_src, cv2.COLOR_BGR2HSV_FULL), rate=0.5)
130 | kid_match_hls_05 = match_histograms(cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL),
131 | cv2.cvtColor(man_src, cv2.COLOR_BGR2HLS_FULL), rate=0.5)
132 |
133 | result = np.concatenate([np.concatenate([kid_src, man_src, man_src], axis=1),
134 | np.concatenate([kid_match_bgr_10, cv2.cvtColor(kid_match_hsv_10, cv2.COLOR_HSV2BGR_FULL),
135 | cv2.cvtColor(kid_match_hls_10, cv2.COLOR_HLS2BGR_FULL)],
136 | axis=1),
137 | np.concatenate([kid_match_bgr_05, cv2.cvtColor(kid_match_hsv_05, cv2.COLOR_HSV2BGR_FULL),
138 | cv2.cvtColor(kid_match_hls_05, cv2.COLOR_HLS2BGR_FULL)],
139 | axis=1)],
140 | axis=0)
141 | cv2.imwrite('all.png', result)
142 | cv2.imshow('match_hsv', result)
143 | cv2.waitKey(0)
144 |
145 |     # match only a single HSV/HLS channel
146 | kid_hsv_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL)
147 | kid_hsv_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL)
148 | kid_hsv_src_10[..., 0] = kid_match_hsv_10[..., 0]
149 | kid_hsv_src_05[..., 0] = kid_match_hsv_05[..., 0]
150 | kid_hls_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL)
151 | kid_hls_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL)
152 | kid_hls_src_10[..., 0] = kid_match_hls_10[..., 0]
153 | kid_hls_src_05[..., 0] = kid_match_hls_05[..., 0]
154 | kid_match_h = np.concatenate([cv2.cvtColor(kid_hsv_src_10, cv2.COLOR_HSV2BGR_FULL),
155 | cv2.cvtColor(kid_hsv_src_05, cv2.COLOR_HSV2BGR_FULL),
156 | cv2.cvtColor(kid_hls_src_10, cv2.COLOR_HLS2BGR_FULL),
157 | cv2.cvtColor(kid_hls_src_05, cv2.COLOR_HLS2BGR_FULL)], axis=1)
158 |
159 | kid_hsv_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL)
160 | kid_hsv_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL)
161 | kid_hsv_src_10[..., 1] = kid_match_hsv_10[..., 1]
162 | kid_hsv_src_05[..., 1] = kid_match_hsv_05[..., 1]
163 | kid_hls_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL)
164 | kid_hls_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL)
165 | kid_hls_src_10[..., 2] = kid_match_hls_10[..., 2]
166 | kid_hls_src_05[..., 2] = kid_match_hls_05[..., 2]
167 | kid_match_s = np.concatenate([cv2.cvtColor(kid_hsv_src_10, cv2.COLOR_HSV2BGR_FULL),
168 | cv2.cvtColor(kid_hsv_src_05, cv2.COLOR_HSV2BGR_FULL),
169 | cv2.cvtColor(kid_hls_src_10, cv2.COLOR_HLS2BGR_FULL),
170 | cv2.cvtColor(kid_hls_src_05, cv2.COLOR_HLS2BGR_FULL)], axis=1)
171 |
172 | kid_hsv_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL)
173 | kid_hsv_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL)
174 | kid_hsv_src_10[..., 2] = kid_match_hsv_10[..., 2]
175 | kid_hsv_src_05[..., 2] = kid_match_hsv_05[..., 2]
176 | kid_hls_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL)
177 | kid_hls_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL)
178 | kid_hls_src_10[..., 1] = kid_match_hls_10[..., 1]
179 | kid_hls_src_05[..., 1] = kid_match_hls_05[..., 1]
180 | kid_match_v = np.concatenate([cv2.cvtColor(kid_hsv_src_10, cv2.COLOR_HSV2BGR_FULL),
181 | cv2.cvtColor(kid_hsv_src_05, cv2.COLOR_HSV2BGR_FULL),
182 | cv2.cvtColor(kid_hls_src_10, cv2.COLOR_HLS2BGR_FULL),
183 | cv2.cvtColor(kid_hls_src_05, cv2.COLOR_HLS2BGR_FULL)], axis=1)
184 |
185 | result = np.concatenate([np.concatenate([kid_src, man_src, kid_src, man_src], axis=1),
186 | kid_match_h, kid_match_s, kid_match_v], axis=0)
187 |
188 | cv2.imwrite('one.png', result)
189 | cv2.imshow('match_hsv', result)
190 | cv2.waitKey(0)
191 |
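Side note: for normalized CDFs, the nested-loop lookup-table construction near the top of this file can equivalently be written as a vectorized CDF inversion. A minimal sketch (not part of histogram.py), assuming src_cdf and ref_cdf are the same length-256 normalized cumulative histograms used above:

import numpy as np

def lookup_from_cdfs(src_cdf, ref_cdf):
    # for every source value s, take the first index r with ref_cdf[r] >= src_cdf[s];
    # the clip only matters for degenerate (non-normalized) inputs
    return np.clip(np.searchsorted(ref_cdf, src_cdf, side='left'), 0, 255)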
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | CUDA_VISIBLE_DEVICES=0 python main.py \
4 | --dataset YOUR_DATA_SET --result_dir results --img_size 384 --device cuda:0 --num_workers 1 \
5 | # input size of the MLP that predicts AdaLIN gamma/beta in G
6 | --light -1 \
7 | # logging / metric frequencies
8 | --print_freq 1000 --calc_fid_freq 10000 --save_freq 100000 \
9 | # model EMA
10 | --ema_start 0.5 --ema_beta 0.9999 \
11 | # segmentation losses
12 | --seg_fix_weight 50 --seg_fix_glass_mouth --seg_D_mask --seg_G_detach --seg_D_cam_fea_mask --seg_D_cam_inp_mask \
13 | --cam_D_weight -1 \
14 | # attention gan
15 | --attention_gan 2 --attention_input \
16 | # weights of the individual loss terms
17 | --adv_weight 1.0 --forward_adv_weight 1 --cycle_weight 10 --identity_weight 10 \
18 | # histogram matching
19 | --match_histograms --match_mode hsl --match_prob 0.5 --match_ratio 1.0 \
20 | # blur augmentation, SE blocks
21 | --has_blur --use_se
22 |
23 | e.g.
24 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset transfer/yuwei/styleGAN/dy_cartoon/dy_cartoon.tar \
25 | --light 32 --result_dir transfer/yuwei/styleGAN/oilpaint_result/pai2/dy_only \
26 | --img_size 384 --device cuda:0 --num_workers 4
27 | """
28 | import os
29 | import time
30 | import datetime
31 | import argparse
32 |
33 | import torch
34 | import torch.backends.cudnn
35 |
36 | import utils
37 | from UGATIT import UGATIT
38 |
39 | VIDEO_EXT = ('.ts', '.mp4')
40 |
41 |
42 | def parse_args():
43 | """parsing and configuration"""
44 | desc = "Pytorch implementation of U-GAT-IT"
45 | parser = argparse.ArgumentParser(description=desc)
46 | parser.add_argument('--phase', type=str, default='train',
47 | choices=['train', 'val', 'test', 'video', 'video_dir', 'camera', 'img_dir', 'generate'],
48 |                         help='train / val / test / video / video directory / camera / image directory / aligned-face image directory mode')
49 | parser.add_argument('--light', type=int, default=-1,
50 |                         help='[U-GAT-IT full version / U-GAT-IT light version]; input spatial size of the MLP that predicts gamma and beta (<= 0 uses the full version)')
51 | parser.add_argument('--dataset', type=str, default='YOUR_DATASET_NAME', help='dataset_name')
52 |
53 |     parser.add_argument('--ema_start', type=float, default=0.5, help='start EMA after this fraction of --iteration')
54 |     parser.add_argument('--ema_beta', type=float, default=0.9999, help='EMA decay for genA2B/genB2A, 0.9999^10000=0.37')
55 | parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations')
56 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
57 |     parser.add_argument('--num_workers', type=int, default=1, help='number of worker processes per DataLoader')
58 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image print freq')
59 | parser.add_argument('--calc_fid_freq', type=int, default=10000, help='The number of fid print freq')
60 |     parser.add_argument('--fid_batch', type=int, default=50, help='batch size used when computing the FID score')
61 | parser.add_argument('--save_freq', type=int, default=100000, help='The number of model save freq')
62 |     parser.add_argument('--no_decay_flag', action='store_false', help='decay the learning rate after the middle of training; enabled by default')
63 |
64 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
65 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='The weight decay')
66 |     parser.add_argument('--adv_weight', type=float, default=1, help='Weight for the GAN loss; suggested: 0.8')
67 |     parser.add_argument('--forward_adv_weight', type=float, default=1, help='weight of the forward adversarial loss; suggested: 2')
68 |     parser.add_argument('--backward_adv_weight', type=float, default=1, help='weight of the backward adversarial loss; suggested: 1')
69 |     parser.add_argument('--cycle_weight', type=float, default=10, help='Weight for the cycle loss; suggested: 3')
70 |     parser.add_argument('--identity_weight', type=float, default=10, help='Weight for the identity loss; suggested: 1.5')
71 |     parser.add_argument('--cam_weight', type=float, default=1000, help='Weight for the CAM loss; suggested: 1000')
72 |
73 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
74 | parser.add_argument('--n_res', type=int, default=4, help='The number of resblock')
75 | parser.add_argument('--n_global_dis', type=int, default=7, help='The number of global discriminator layer')
76 | parser.add_argument('--n_local_dis', type=int, default=5, help='The number of local discriminator layer')
77 | parser.add_argument('--img_size', type=int, default=384, help='The size of image')
78 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
79 | parser.add_argument('--result_dir', type=str, default='project/results', help='Directory name to save the results')
80 | parser.add_argument('--device', type=str, default='cuda:0', choices=['cpu', 'cuda:0'], help='Set gpu mode; [cpu, cuda:0]') # noqa, E501
81 |     parser.add_argument('--resume', action='store_true', help='resume from the most recent training run')
82 |
83 |     # enhanced U-GAT-IT options
84 |     parser.add_argument('--aug_prob', type=float, default=0.2, help='probability of applying resize & crop augmentation; suggested: <= 0.2')
85 |     parser.add_argument('--sn', action='store_false', help='spectral norm in D is enabled by default (recommended; the TF version uses it)')
86 |     parser.add_argument('--has_blur', action='store_true', help='blur augmentation for D inputs, disabled by default; recommended')
87 |     parser.add_argument('--tv_loss', action='store_true', help='apply a TV loss to generated images; disabled by default')
88 |     parser.add_argument('--tv_weight', type=float, default=1.0, help='Weight for the TV loss; suggested: 1.0')
89 |     parser.add_argument('--use_se', action='store_true', help='use SE blocks inside resblocks; worth enabling')
90 |     parser.add_argument('--attention_gan', type=int, default=0, help='number of attention maps for attention-GAN (0 disables it); worth trying')
91 |     parser.add_argument('--attention_input', action='store_true', help='when attention_gan is enabled, also include the input image in the attention mix; worth trying')
92 |     parser.add_argument('--cam_D_weight', type=float, default=1, help='weight of the discriminator CAM classification loss; suggested: 1')
93 |     parser.add_argument('--cam_D_attention', action='store_false', help='use the discriminator CAM attention; enabled by default')
94 |     # histogram matching
95 |     parser.add_argument('--match_histograms', action='store_true', help='disabled by default; worth trying when the two real domains differ partly in color distribution')
96 |     parser.add_argument('--match_mode', type=str, default='hsl', help='color space used for histogram matching; hsl by default')
97 |     parser.add_argument('--match_prob', type=float, default=0.5, help='probability of matching B to A; otherwise A is matched to B')
98 |     parser.add_argument('--match_ratio', type=float, default=1.0, help='blending ratio of the histogram matching')
99 |     # background-fixing options
100 |     parser.add_argument('--hard_seg_edge', action='store_true', help='use a hard segmentation boundary; soft by default')
101 |     parser.add_argument('--seg_fix_weight', type=float, default=-1, help='weight of the L1 loss between the segmented region of the generated image and the source image; suggested: 50')
102 |     parser.add_argument('--seg_fix_glass_mouth', action='store_true', help='also keep glasses frames and the inner mouth fixed as background; off by default')
103 |     parser.add_argument('--seg_D_mask', action='store_true', help='compute the discriminator loss only inside the segmentation mask; by default the whole image is used')
104 |     parser.add_argument('--seg_G_detach', action='store_true', help='detach the non-mask region of the generated image; off by default')
105 |     parser.add_argument('--seg_D_cam_fea_mask', action='store_true', help='mask the feature maps fed to the discriminator CAM; off by default')
106 |     parser.add_argument('--seg_D_cam_inp_mask', action='store_true', help='mask the image fed to the discriminator CAM; off by default')
107 |
108 |     # testing
109 |     parser.add_argument('--generator_model', type=str, default='', help='path to the A2B generator checkpoint used for testing')
110 |     parser.add_argument('--video_path', type=str, default='', help='path to the video (or video directory) processed with the A2B generator')
111 |     parser.add_argument('--img_dir', type=str, default='', help='path to the image directory processed with the A2B generator')
112 |
113 | return check_args(parser.parse_args())
114 |
115 |
116 | def check_args(args):
117 | """checking arguments"""
118 | utils.check_folder(args.result_dir)
119 | utils.Logger(file_name=os.path.join(args.result_dir, f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.log"),
120 | file_mode='a', should_flush=True)
121 | if args.cam_D_weight <= 0:
122 | args.cam_D_attention = False
123 | print(f'can not use D cam attention while D cam weight = {args.cam_D_weight} <= 0')
124 |
125 | if args.phase in ('video', 'video_dir', 'camera', 'img_dir', 'generate'):
126 | return args
127 |
128 | # --result_dir
129 | utils.check_folder(os.path.join(args.result_dir, args.dataset, 'model'))
130 | utils.check_folder(os.path.join(args.result_dir, args.dataset, 'img'))
131 | utils.check_folder(os.path.join(args.result_dir, args.dataset, 'test'))
132 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one'
133 | return args
134 |
135 |
136 | def main():
137 | # parse arguments
138 | args = parse_args()
139 | if args is None:
140 | exit()
141 | if not torch.cuda.is_available():
142 | args.device = 'cpu'
143 | args.device = torch.device(args.device)
144 | print(args)
145 |
146 |     # video / camera / image-directory testing
147 | if args.phase in ('video', 'video_dir', 'camera', 'img_dir', 'generate'):
148 | if args.generator_model == '':
149 |             raise ValueError('A2B generator model path (--generator_model) is not specified!')
150 | from videos import test_video
151 | from networks import ResnetGenerator
152 | torch.set_flush_denormal(True)
153 |         # build the generator and load its weights
154 | generator = ResnetGenerator(input_nc=3, output_nc=3, ngf=args.ch, n_blocks=args.n_res,
155 | img_size=args.img_size, args=args)
156 | params = torch.load(args.generator_model, map_location=torch.device("cpu"))
157 | generator.load_state_dict(params['genA2B_ema'])
158 |         # run the tester
159 | tester = test_video.VideoTester(args, generator)
160 | if args.phase in ('video', 'video_dir'):
161 | assert args.video_path and os.path.exists(args.video_path), f'video path ({args.video_path}) error!'
162 | if args.phase == 'video_dir':
163 | video_paths = [os.path.join(args.video_path, video_name) for video_name in os.listdir(args.video_path)
164 | if video_name.endswith(VIDEO_EXT)]
165 | else:
166 | video_paths = [args.video_path]
167 | for video_path in video_paths:
168 | print(f'generating video: {video_path} ...')
169 | tester.video(video_path)
170 | elif args.phase == 'camera':
171 | tester.camera()
172 | elif args.phase == 'img_dir':
173 | assert args.img_dir and os.path.exists(args.img_dir), f'image directory ({args.img_dir}) error!'
174 | tester.image_dir(args.img_dir)
175 | elif args.phase == 'generate':
176 | assert args.img_dir and os.path.exists(args.img_dir), f'image directory ({args.img_dir}) error!'
177 | tester.generate_images(args.img_dir)
178 | else:
179 | raise Exception(f'unknown phase: {args.phase}')
180 | return
181 |
182 | if torch.backends.cudnn.enabled:
183 | torch.backends.cudnn.benchmark = True
184 | utils.setup_seed(0)
185 |
186 | # open session
187 | gan = UGATIT(args)
188 |
189 | # build graph
190 | gan.build_model()
191 |
192 | if args.phase == 'train':
193 | gan.train()
194 | print(" [*] Training finished!")
195 | args.phase = 'test'
196 |
197 | if args.phase == 'test':
198 | torch.set_flush_denormal(True)
199 | gan.test()
200 | print(" [*] Test finished!")
201 |
202 | return
203 |
204 |
205 | if __name__ == '__main__':
206 | main()
207 |
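For standalone A2B inference outside of main.py, the loading logic used by the video/image phases above can be reproduced as follows. A minimal sketch (not part of this file): the checkpoint path is a placeholder and the Namespace simply mirrors the relevant command-line options.

import argparse
import torch
from networks import ResnetGenerator

args = argparse.Namespace(light=32, attention_gan=2, attention_input=True, use_se=True,
                          ch=64, n_res=4, img_size=384)
gen = ResnetGenerator(input_nc=3, output_nc=3, ngf=args.ch, n_blocks=args.n_res,
                      img_size=args.img_size, args=args)
params = torch.load('path/to/checkpoint.pt', map_location='cpu')  # placeholder path
gen.load_state_dict(params['genA2B_ema'])                         # EMA A2B generator weights
gen.eval()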
--------------------------------------------------------------------------------
/faceseg/networks_faceseg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from faceseg.resnet import Resnet18
10 |
11 |
12 | # from modules.bn import InPlaceABNSync as BatchNorm2d
13 |
14 |
15 | class ConvBNReLU(nn.Module):
16 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
17 | super(ConvBNReLU, self).__init__()
18 | self.conv = nn.Conv2d(in_chan,
19 | out_chan,
20 | kernel_size=ks,
21 | stride=stride,
22 | padding=padding,
23 | bias=False)
24 | self.bn = nn.BatchNorm2d(out_chan)
25 | self.init_weight()
26 |
27 | def forward(self, x):
28 | x = self.conv(x)
29 | x = F.relu(self.bn(x))
30 | return x
31 |
32 | def init_weight(self):
33 | for ly in self.children():
34 | if isinstance(ly, nn.Conv2d):
35 | nn.init.kaiming_normal_(ly.weight, a=1)
36 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
37 |
38 |
39 | class BiSeNetOutput(nn.Module):
40 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
41 | super(BiSeNetOutput, self).__init__()
42 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
43 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
44 | self.init_weight()
45 |
46 | def forward(self, x):
47 | x = self.conv(x)
48 | x = self.conv_out(x)
49 | return x
50 |
51 | def init_weight(self):
52 | for ly in self.children():
53 | if isinstance(ly, nn.Conv2d):
54 | nn.init.kaiming_normal_(ly.weight, a=1)
55 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
56 |
57 | def get_params(self):
58 | wd_params, nowd_params = [], []
59 | for name, module in self.named_modules():
60 |             if isinstance(module, (nn.Linear, nn.Conv2d)):
61 |                 wd_params.append(module.weight)
62 |                 if module.bias is not None:
63 | nowd_params.append(module.bias)
64 | elif isinstance(module, nn.BatchNorm2d):
65 | nowd_params += list(module.parameters())
66 | return wd_params, nowd_params
67 |
68 |
69 | class AttentionRefinementModule(nn.Module):
70 | def __init__(self, in_chan, out_chan, *args, **kwargs):
71 | super(AttentionRefinementModule, self).__init__()
72 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
73 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
74 | self.bn_atten = nn.BatchNorm2d(out_chan)
75 | self.sigmoid_atten = nn.Sigmoid()
76 | self.init_weight()
77 |
78 | def forward(self, x):
79 | feat = self.conv(x)
80 | atten = F.avg_pool2d(feat, feat.size()[2:])
81 | atten = self.conv_atten(atten)
82 | atten = self.bn_atten(atten)
83 | atten = self.sigmoid_atten(atten)
84 | out = torch.mul(feat, atten)
85 | return out
86 |
87 | def init_weight(self):
88 | for ly in self.children():
89 | if isinstance(ly, nn.Conv2d):
90 | nn.init.kaiming_normal_(ly.weight, a=1)
91 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
92 |
93 |
94 | class ContextPath(nn.Module):
95 | def __init__(self, *args, **kwargs):
96 | super(ContextPath, self).__init__()
97 | self.resnet = Resnet18()
98 | self.arm16 = AttentionRefinementModule(256, 128)
99 | self.arm32 = AttentionRefinementModule(512, 128)
100 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
101 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
102 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
103 |
104 | self.init_weight()
105 |
106 | def forward(self, x):
107 | H0, W0 = x.size()[2:]
108 | feat8, feat16, feat32 = self.resnet(x)
109 | H8, W8 = feat8.size()[2:]
110 | H16, W16 = feat16.size()[2:]
111 | H32, W32 = feat32.size()[2:]
112 |
113 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
114 | avg = self.conv_avg(avg)
115 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
116 |
117 | feat32_arm = self.arm32(feat32)
118 | feat32_sum = feat32_arm + avg_up
119 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
120 | feat32_up = self.conv_head32(feat32_up)
121 |
122 | feat16_arm = self.arm16(feat16)
123 | feat16_sum = feat16_arm + feat32_up
124 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
125 | feat16_up = self.conv_head16(feat16_up)
126 |
127 | return feat8, feat16_up, feat32_up # x8, x8, x16
128 |
129 | def init_weight(self):
130 | for ly in self.children():
131 | if isinstance(ly, nn.Conv2d):
132 | nn.init.kaiming_normal_(ly.weight, a=1)
133 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
134 |
135 | def get_params(self):
136 | wd_params, nowd_params = [], []
137 | for name, module in self.named_modules():
138 | if isinstance(module, (nn.Linear, nn.Conv2d)):
139 | wd_params.append(module.weight)
140 |                 if module.bias is not None:
141 | nowd_params.append(module.bias)
142 | elif isinstance(module, nn.BatchNorm2d):
143 | nowd_params += list(module.parameters())
144 | return wd_params, nowd_params
145 |
146 |
147 | ### Not used: it is replaced by the ResNet feature map of the same spatial size
148 | class SpatialPath(nn.Module):
149 | def __init__(self, *args, **kwargs):
150 | super(SpatialPath, self).__init__()
151 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
152 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
153 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
154 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
155 | self.init_weight()
156 |
157 | def forward(self, x):
158 | feat = self.conv1(x)
159 | feat = self.conv2(feat)
160 | feat = self.conv3(feat)
161 | feat = self.conv_out(feat)
162 | return feat
163 |
164 | def init_weight(self):
165 | for ly in self.children():
166 | if isinstance(ly, nn.Conv2d):
167 | nn.init.kaiming_normal_(ly.weight, a=1)
168 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
169 |
170 | def get_params(self):
171 | wd_params, nowd_params = [], []
172 | for name, module in self.named_modules():
173 |             if isinstance(module, (nn.Linear, nn.Conv2d)):
174 |                 wd_params.append(module.weight)
175 |                 if module.bias is not None:
176 | nowd_params.append(module.bias)
177 | elif isinstance(module, nn.BatchNorm2d):
178 | nowd_params += list(module.parameters())
179 | return wd_params, nowd_params
180 |
181 |
182 | class FeatureFusionModule(nn.Module):
183 | def __init__(self, in_chan, out_chan, *args, **kwargs):
184 | super(FeatureFusionModule, self).__init__()
185 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
186 | self.conv1 = nn.Conv2d(out_chan,
187 | out_chan // 4,
188 | kernel_size=1,
189 | stride=1,
190 | padding=0,
191 | bias=False)
192 | self.conv2 = nn.Conv2d(out_chan // 4,
193 | out_chan,
194 | kernel_size=1,
195 | stride=1,
196 | padding=0,
197 | bias=False)
198 | self.relu = nn.ReLU(inplace=True)
199 | self.sigmoid = nn.Sigmoid()
200 | self.init_weight()
201 |
202 | def forward(self, fsp, fcp):
203 | fcat = torch.cat([fsp, fcp], dim=1)
204 | feat = self.convblk(fcat)
205 | atten = F.avg_pool2d(feat, feat.size()[2:])
206 | atten = self.conv1(atten)
207 | atten = self.relu(atten)
208 | atten = self.conv2(atten)
209 | atten = self.sigmoid(atten)
210 | feat_atten = torch.mul(feat, atten)
211 | feat_out = feat_atten + feat
212 | return feat_out
213 |
214 | def init_weight(self):
215 | for ly in self.children():
216 | if isinstance(ly, nn.Conv2d):
217 | nn.init.kaiming_normal_(ly.weight, a=1)
218 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
219 |
220 | def get_params(self):
221 | wd_params, nowd_params = [], []
222 | for name, module in self.named_modules():
223 |             if isinstance(module, (nn.Linear, nn.Conv2d)):
224 |                 wd_params.append(module.weight)
225 |                 if module.bias is not None:
226 | nowd_params.append(module.bias)
227 | elif isinstance(module, nn.BatchNorm2d):
228 | nowd_params += list(module.parameters())
229 | return wd_params, nowd_params
230 |
231 |
232 | class BiSeNet(nn.Module):
233 | def __init__(self, n_classes, *args, **kwargs):
234 | super(BiSeNet, self).__init__()
235 | self.cp = ContextPath()
236 |         ## the spatial path (self.sp) is removed; a ResNet feature of the same size is used instead
237 | self.ffm = FeatureFusionModule(256, 256)
238 | self.conv_out = BiSeNetOutput(256, 256, n_classes)
239 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
240 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
241 | self.init_weight()
242 |
243 | def forward(self, x):
244 | H, W = x.size()[2:]
245 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
246 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
247 | feat_fuse = self.ffm(feat_sp, feat_cp8)
248 |
249 | feat_out = self.conv_out(feat_fuse)
250 | feat_out16 = self.conv_out16(feat_cp8)
251 | feat_out32 = self.conv_out32(feat_cp16)
252 |
253 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
254 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
255 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
256 | return feat_out, feat_out16, feat_out32
257 |
258 | def init_weight(self):
259 | for ly in self.children():
260 | if isinstance(ly, nn.Conv2d):
261 | nn.init.kaiming_normal_(ly.weight, a=1)
262 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
263 |
264 | def get_params(self):
265 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
266 | for name, child in self.named_children():
267 | child_wd_params, child_nowd_params = child.get_params()
268 |             if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
269 | lr_mul_wd_params += child_wd_params
270 | lr_mul_nowd_params += child_nowd_params
271 | else:
272 | wd_params += child_wd_params
273 | nowd_params += child_nowd_params
274 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
275 |
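A minimal sketch (not part of this file; the learning rate, weight decay, and 10x multiplier are placeholder hyper-parameters) of how the four parameter groups returned by get_params are typically wired into an optimizer: conv/linear weights get weight decay, biases and BatchNorm parameters do not, and the fusion/output heads get a larger learning rate.

import torch

net = BiSeNet(n_classes=19)
wd, nowd, lr_mul_wd, lr_mul_nowd = net.get_params()
base_lr = 1e-2
optimizer = torch.optim.SGD([
    {'params': wd,          'weight_decay': 5e-4},
    {'params': nowd,        'weight_decay': 0.0},
    {'params': lr_mul_wd,   'weight_decay': 5e-4, 'lr': base_lr * 10},
    {'params': lr_mul_nowd, 'weight_decay': 0.0,  'lr': base_lr * 10},
], lr=base_lr, momentum=0.9)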
276 |
277 | if __name__ == "__main__":
278 | net = BiSeNet(19)
279 | net.cuda()
280 | net.eval()
281 | in_ten = torch.randn(16, 3, 640, 480).cuda()
282 | out, out16, out32 = net(in_ten)
283 | print(out.shape)
284 |
285 | net.get_params()
286 |
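A minimal face-parsing inference sketch (not part of this file; the checkpoint path is a placeholder, and taking the argmax over the 19 classes follows the face-parsing.PyTorch reference this code is adapted from):

import torch

net = BiSeNet(n_classes=19).eval()
# net.load_state_dict(torch.load('path/to/faceseg_weights.pth', map_location='cpu'))  # placeholder checkpoint
face = torch.rand(1, 3, 512, 512)          # a normalized, aligned face crop
with torch.no_grad():
    out, _, _ = net(face)                  # (1, 19, 512, 512) full-resolution logits
parsing = out.argmax(dim=1)                # per-pixel class ids (0 = background)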
--------------------------------------------------------------------------------
/metrics/fid/inception.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 |
7 | try:
8 | from torchvision.models.utils import load_state_dict_from_url
9 | except ImportError:
10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
11 |
12 | # Inception weights ported to Pytorch from
13 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
14 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
15 | FID_WEIGHTS_LOCAL = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pt_inception-2015-12-05-6726825d.pth')
16 |
17 |
18 | class InceptionV3(nn.Module):
19 | """Pretrained InceptionV3 network returning feature maps"""
20 |
21 | # Index of default block of inception to return,
22 | # corresponds to output of final average pooling
23 | DEFAULT_BLOCK_INDEX = 3
24 |
25 | # Maps feature dimensionality to their output blocks indices
26 | BLOCK_INDEX_BY_DIM = {
27 | 64: 0, # First max pooling features
28 |         192: 1,   # Second max pooling features
29 | 768: 2, # Pre-aux classifier features
30 | 2048: 3 # Final average pooling features
31 | }
32 |
33 | def __init__(self,
34 | output_blocks=(DEFAULT_BLOCK_INDEX,),
35 | resize_input=True,
36 | normalize_input=True,
37 | requires_grad=False,
38 | use_fid_inception=True):
39 | """Build pretrained InceptionV3
40 |
41 | Parameters
42 | ----------
43 | output_blocks : list of int
44 | Indices of blocks to return features of. Possible values are:
45 | - 0: corresponds to output of first max pooling
46 | - 1: corresponds to output of second max pooling
47 | - 2: corresponds to output which is fed to aux classifier
48 | - 3: corresponds to output of final average pooling
49 | resize_input : bool
50 | If true, bilinearly resizes input to width and height 299 before
51 | feeding input to model. As the network without fully connected
52 | layers is fully convolutional, it should be able to handle inputs
53 | of arbitrary size, so resizing might not be strictly needed
54 | normalize_input : bool
55 | If true, scales the input from range (0, 1) to the range the
56 | pretrained Inception network expects, namely (-1, 1)
57 | requires_grad : bool
58 | If true, parameters of the model require gradients. Possibly useful
59 | for finetuning the network
60 | use_fid_inception : bool
61 | If true, uses the pretrained Inception model used in Tensorflow's
62 | FID implementation. If false, uses the pretrained Inception model
63 | available in torchvision. The FID Inception model has different
64 | weights and a slightly different structure from torchvision's
65 | Inception model. If you want to compute FID scores, you are
66 | strongly advised to set this parameter to true to get comparable
67 | results.
68 | """
69 | super(InceptionV3, self).__init__()
70 |
71 | self.resize_input = resize_input
72 | self.normalize_input = normalize_input
73 | self.output_blocks = sorted(output_blocks)
74 | self.last_needed_block = max(output_blocks)
75 |
76 | assert self.last_needed_block <= 3, \
77 | 'Last possible output block index is 3'
78 |
79 | self.blocks = nn.ModuleList()
80 |
81 | if use_fid_inception:
82 | inception = fid_inception_v3()
83 | else:
84 | inception = _inception_v3(pretrained=True)
85 |
86 | # Block 0: input to maxpool1
87 | block0 = [
88 | inception.Conv2d_1a_3x3,
89 | inception.Conv2d_2a_3x3,
90 | inception.Conv2d_2b_3x3,
91 | nn.MaxPool2d(kernel_size=3, stride=2)
92 | ]
93 | self.blocks.append(nn.Sequential(*block0))
94 |
95 | # Block 1: maxpool1 to maxpool2
96 | if self.last_needed_block >= 1:
97 | block1 = [
98 | inception.Conv2d_3b_1x1,
99 | inception.Conv2d_4a_3x3,
100 | nn.MaxPool2d(kernel_size=3, stride=2)
101 | ]
102 | self.blocks.append(nn.Sequential(*block1))
103 |
104 | # Block 2: maxpool2 to aux classifier
105 | if self.last_needed_block >= 2:
106 | block2 = [
107 | inception.Mixed_5b,
108 | inception.Mixed_5c,
109 | inception.Mixed_5d,
110 | inception.Mixed_6a,
111 | inception.Mixed_6b,
112 | inception.Mixed_6c,
113 | inception.Mixed_6d,
114 | inception.Mixed_6e,
115 | ]
116 | self.blocks.append(nn.Sequential(*block2))
117 |
118 | # Block 3: aux classifier to final avgpool
119 | if self.last_needed_block >= 3:
120 | block3 = [
121 | inception.Mixed_7a,
122 | inception.Mixed_7b,
123 | inception.Mixed_7c,
124 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
125 | ]
126 | self.blocks.append(nn.Sequential(*block3))
127 |
128 | for param in self.parameters():
129 | param.requires_grad = requires_grad
130 |
131 | def forward(self, inp):
132 | """Get Inception feature maps
133 |
134 | Parameters
135 | ----------
136 | inp : torch.autograd.Variable
137 | Input tensor of shape Bx3xHxW. Values are expected to be in
138 | range (0, 1)
139 |
140 | Returns
141 | -------
142 | List of torch.autograd.Variable, corresponding to the selected output
143 | block, sorted ascending by index
144 | """
145 | outp = []
146 | x = inp
147 |
148 | if self.resize_input:
149 | x = F.interpolate(x,
150 | size=(299, 299),
151 | mode='bilinear',
152 | align_corners=False)
153 |
154 | if self.normalize_input:
155 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
156 |
157 | for idx, block in enumerate(self.blocks):
158 | x = block(x)
159 | if idx in self.output_blocks:
160 | outp.append(x)
161 |
162 | if idx == self.last_needed_block:
163 | break
164 |
165 | return outp
166 |
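A minimal sketch (not part of this file) of extracting the 2048-d pool3 features consumed by the FID code in metrics/fid/fid_score.py; as documented in forward(), inputs are expected in the [0, 1] range:

import torch

model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).eval()
images = torch.rand(4, 3, 299, 299)            # values in [0, 1]
with torch.no_grad():
    feats = model(images)[0]                   # (4, 2048, 1, 1)
feats = feats.squeeze(3).squeeze(2)            # (4, 2048) feature matrix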
167 |
168 | def _inception_v3(*args, **kwargs):
169 | """Wraps `torchvision.models.inception_v3`
170 |
171 |     Skips default weight initialization if supported by torchvision version.
172 | See https://github.com/mseitzer/pytorch-fid/issues/28.
173 | """
174 | try:
175 | version = tuple(map(int, torchvision.__version__.split('.')[:2]))
176 | except ValueError:
177 | # Just a caution against weird version strings
178 | version = (0,)
179 |
180 | if version >= (0, 6):
181 | kwargs['init_weights'] = False
182 |
183 | return torchvision.models.inception_v3(*args, **kwargs)
184 |
185 |
186 | def fid_inception_v3():
187 | """Build pretrained Inception model for FID computation
188 |
189 | The Inception model for FID computation uses a different set of weights
190 | and has a slightly different structure than torchvision's Inception.
191 |
192 | This method first constructs torchvision's Inception and then patches the
193 | necessary parts that are different in the FID Inception model.
194 | """
195 | inception = _inception_v3(num_classes=1008,
196 | aux_logits=False,
197 | pretrained=False)
198 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
199 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
200 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
201 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
202 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
203 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
204 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
205 | inception.Mixed_7b = FIDInceptionE_1(1280)
206 | inception.Mixed_7c = FIDInceptionE_2(2048)
207 |
208 | try:
209 | state_dict = torch.load(FID_WEIGHTS_LOCAL)
210 |     except Exception:  # fall back to downloading if the local weights file is unavailable
211 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
212 | inception.load_state_dict(state_dict)
213 | return inception
214 |
215 |
216 | class FIDInceptionA(torchvision.models.inception.InceptionA):
217 | """InceptionA block patched for FID computation"""
218 | def __init__(self, in_channels, pool_features):
219 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
220 |
221 | def forward(self, x):
222 | branch1x1 = self.branch1x1(x)
223 |
224 | branch5x5 = self.branch5x5_1(x)
225 | branch5x5 = self.branch5x5_2(branch5x5)
226 |
227 | branch3x3dbl = self.branch3x3dbl_1(x)
228 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
229 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
230 |
231 |         # Patch: Tensorflow's average pool does not use the padded zeros in
232 | # its average calculation
233 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
234 | count_include_pad=False)
235 | branch_pool = self.branch_pool(branch_pool)
236 |
237 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
238 | return torch.cat(outputs, 1)
239 |
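A tiny check (not part of this file) of what the count_include_pad patch changes: on a constant input, the default average pool divides border windows by the full kernel area (padded zeros included), while the patched call divides only by the number of valid pixels.

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 4, 4)
default = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
patched = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
print(default[0, 0, 0, 0].item(), patched[0, 0, 0, 0].item())   # 0.444... vs 1.0 at the corner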
240 |
241 | class FIDInceptionC(torchvision.models.inception.InceptionC):
242 | """InceptionC block patched for FID computation"""
243 | def __init__(self, in_channels, channels_7x7):
244 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
245 |
246 | def forward(self, x):
247 | branch1x1 = self.branch1x1(x)
248 |
249 | branch7x7 = self.branch7x7_1(x)
250 | branch7x7 = self.branch7x7_2(branch7x7)
251 | branch7x7 = self.branch7x7_3(branch7x7)
252 |
253 | branch7x7dbl = self.branch7x7dbl_1(x)
254 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
255 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
256 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
257 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
258 |
259 |         # Patch: Tensorflow's average pool does not use the padded zeros in
260 | # its average calculation
261 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
262 | count_include_pad=False)
263 | branch_pool = self.branch_pool(branch_pool)
264 |
265 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
266 | return torch.cat(outputs, 1)
267 |
268 |
269 | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
270 | """First InceptionE block patched for FID computation"""
271 | def __init__(self, in_channels):
272 | super(FIDInceptionE_1, self).__init__(in_channels)
273 |
274 | def forward(self, x):
275 | branch1x1 = self.branch1x1(x)
276 |
277 | branch3x3 = self.branch3x3_1(x)
278 | branch3x3 = [
279 | self.branch3x3_2a(branch3x3),
280 | self.branch3x3_2b(branch3x3),
281 | ]
282 | branch3x3 = torch.cat(branch3x3, 1)
283 |
284 | branch3x3dbl = self.branch3x3dbl_1(x)
285 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
286 | branch3x3dbl = [
287 | self.branch3x3dbl_3a(branch3x3dbl),
288 | self.branch3x3dbl_3b(branch3x3dbl),
289 | ]
290 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
291 |
292 |         # Patch: Tensorflow's average pool does not use the padded zeros in
293 | # its average calculation
294 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
295 | count_include_pad=False)
296 | branch_pool = self.branch_pool(branch_pool)
297 |
298 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
299 | return torch.cat(outputs, 1)
300 |
301 |
302 | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
303 | """Second InceptionE block patched for FID computation"""
304 | def __init__(self, in_channels):
305 | super(FIDInceptionE_2, self).__init__(in_channels)
306 |
307 | def forward(self, x):
308 | branch1x1 = self.branch1x1(x)
309 |
310 | branch3x3 = self.branch3x3_1(x)
311 | branch3x3 = [
312 | self.branch3x3_2a(branch3x3),
313 | self.branch3x3_2b(branch3x3),
314 | ]
315 | branch3x3 = torch.cat(branch3x3, 1)
316 |
317 | branch3x3dbl = self.branch3x3dbl_1(x)
318 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
319 | branch3x3dbl = [
320 | self.branch3x3dbl_3a(branch3x3dbl),
321 | self.branch3x3dbl_3b(branch3x3dbl),
322 | ]
323 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
324 |
325 | # Patch: The FID Inception model uses max pooling instead of average
326 | # pooling. This is likely an error in this specific Inception
327 | # implementation, as other Inception models use average pooling here
328 | # (which matches the description in the paper).
329 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
330 | branch_pool = self.branch_pool(branch_pool)
331 |
332 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
333 | return torch.cat(outputs, 1)
334 |
--------------------------------------------------------------------------------
/metrics/fid/fid_score.py:
--------------------------------------------------------------------------------
1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2 |
3 | The FID metric calculates the distance between two distributions of images.
4 | Typically, we have summary statistics (mean & covariance matrix) of one
5 | of these distributions, while the 2nd distribution is given by a GAN.
6 |
7 | When run as a stand-alone program, it compares the distribution of
8 | images that are stored as PNG/JPEG at a specified location with a
9 | distribution given by summary statistics (in pickle format).
10 |
11 | The FID is calculated by assuming that X_1 and X_2 are the activations of
12 | the pool_3 layer of the inception net for generated samples and real world
13 | samples respectively.
14 |
15 | See --help to see further details.
16 |
17 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18 | of Tensorflow
19 |
20 | Copyright 2018 Institute of Bioinformatics, JKU Linz
21 |
22 | Licensed under the Apache License, Version 2.0 (the "License");
23 | you may not use this file except in compliance with the License.
24 | You may obtain a copy of the License at
25 |
26 | http://www.apache.org/licenses/LICENSE-2.0
27 |
28 | Unless required by applicable law or agreed to in writing, software
29 | distributed under the License is distributed on an "AS IS" BASIS,
30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 | See the License for the specific language governing permissions and
32 | limitations under the License.
33 | """
34 | import os
35 | import pathlib
36 |
37 | import numpy as np
38 | import torch
39 | import torchvision.transforms as TF
40 | from PIL import Image
41 | from scipy import linalg
42 | from torch.nn.functional import adaptive_avg_pool2d
43 | from tqdm import tqdm
44 |
45 | from metrics.fid.inception import InceptionV3
46 |
47 | # parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
48 | # parser.add_argument('--batch-size', type=int, default=50,
49 | # help='Batch size to use')
50 | # parser.add_argument('--num-workers', type=int, default=8,
51 | # help='Number of processes to use for data loading')
52 | # parser.add_argument('--device', type=str, default=None,
53 | # help='Device to use. Like cuda, cuda:0 or cpu')
54 | # parser.add_argument('--dims', type=int, default=2048,
55 | # choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
56 | # help=('Dimensionality of Inception features to use. '
57 | # 'By default, uses pool3 features'))
58 | # parser.add_argument('path', type=str, nargs=2,
59 | # help=('Paths to the generated images or '
60 | # 'to .npz statistic files'))
61 |
62 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
63 | 'tif', 'tiff', 'webp'}
64 |
65 |
66 | class ImagePathDataset(torch.utils.data.Dataset):
67 | def __init__(self, files, transforms=None):
68 | self.files = files
69 | self.transforms = transforms
70 |
71 | def __len__(self):
72 | return len(self.files)
73 |
74 | def __getitem__(self, i):
75 | path = self.files[i]
76 | img = Image.open(path).convert('RGB')
77 | if self.transforms is not None:
78 | img = self.transforms(img)
79 | return img
80 |
81 |
82 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8):
83 | """Calculates the activations of the pool_3 layer for all images.
84 |
85 | Params:
86 | -- files : List of image files paths
87 | -- model : Instance of inception model
88 | -- batch_size : Batch size of images for the model to process at once.
89 | Make sure that the number of samples is a multiple of
90 | the batch size, otherwise some samples are ignored. This
91 | behavior is retained to match the original FID score
92 | implementation.
93 | -- dims : Dimensionality of features returned by Inception
94 | -- device : Device to run calculations
95 | -- num_workers : Number of parallel dataloader workers
96 |
97 | Returns:
98 | -- A numpy array of dimension (num images, dims) that contains the
99 | activations of the given tensor when feeding inception with the
100 | query tensor.
101 | """
102 | model.eval()
103 |
104 | if batch_size > len(files):
105 | print(('Warning: batch size is bigger than the data size. '
106 | 'Setting batch size to data size'))
107 | batch_size = len(files)
108 |
109 | dataset = ImagePathDataset(files, transforms=TF.ToTensor())
110 | dataloader = torch.utils.data.DataLoader(dataset,
111 | batch_size=batch_size,
112 | shuffle=False,
113 | drop_last=False,
114 | num_workers=num_workers)
115 |
116 | pred_arr = np.empty((len(files), dims))
117 |
118 | start_idx = 0
119 |
120 | for batch in tqdm(dataloader):
121 | batch = batch.to(device)
122 |
123 | with torch.no_grad():
124 | pred = model(batch)[0]
125 |
126 | # If model output is not scalar, apply global spatial average pooling.
127 | # This happens if you choose a dimensionality not equal 2048.
128 | if pred.size(2) != 1 or pred.size(3) != 1:
129 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
130 |
131 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
132 |
133 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred
134 |
135 | start_idx = start_idx + pred.shape[0]
136 |
137 | return pred_arr
138 |
139 |
140 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
141 | """Numpy implementation of the Frechet Distance.
142 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
143 | and X_2 ~ N(mu_2, C_2) is
144 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
145 |
146 | Stable version by Dougal J. Sutherland.
147 |
148 | Params:
149 | -- mu1 : Numpy array containing the activations of a layer of the
150 | inception net (like returned by the function 'get_predictions')
151 | for generated samples.
152 |     -- mu2   : The sample mean over activations, precalculated on a
153 | representative data set.
154 | -- sigma1: The covariance matrix over activations for generated samples.
155 |     -- sigma2: The covariance matrix over activations, precalculated on a
156 | representative data set.
157 |
158 | Returns:
159 | -- : The Frechet Distance.
160 | """
161 |
162 | mu1 = np.atleast_1d(mu1)
163 | mu2 = np.atleast_1d(mu2)
164 |
165 | sigma1 = np.atleast_2d(sigma1)
166 | sigma2 = np.atleast_2d(sigma2)
167 |
168 | assert mu1.shape == mu2.shape, \
169 | 'Training and test mean vectors have different lengths'
170 | assert sigma1.shape == sigma2.shape, \
171 | 'Training and test covariances have different dimensions'
172 |
173 | diff = mu1 - mu2
174 |
175 | # Product might be almost singular
176 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
177 | if not np.isfinite(covmean).all():
178 | msg = ('fid calculation produces singular product; '
179 | 'adding %s to diagonal of cov estimates') % eps
180 | print(msg)
181 | offset = np.eye(sigma1.shape[0]) * eps
182 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
183 |
184 | # Numerical error might give slight imaginary component
185 | if np.iscomplexobj(covmean):
186 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
187 | m = np.max(np.abs(covmean.imag))
188 | raise ValueError('Imaginary component {}'.format(m))
189 | covmean = covmean.real
190 |
191 | tr_covmean = np.trace(covmean)
192 |
193 | return (diff.dot(diff) + np.trace(sigma1)
194 | + np.trace(sigma2) - 2 * tr_covmean)
195 |
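A quick sanity check (not part of this file): passing identical statistics for both distributions yields a distance of roughly zero.

import numpy as np

rng = np.random.default_rng(0)
acts = rng.normal(size=(100, 8))                           # stand-in activations
mu, sigma = acts.mean(axis=0), np.cov(acts, rowvar=False)
print(calculate_frechet_distance(mu, sigma, mu, sigma))    # ~0 up to numerical error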
196 |
197 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
198 | device='cpu', num_workers=8):
199 | """Calculation of the statistics used by the FID.
200 | Params:
201 | -- files : List of image files paths
202 | -- model : Instance of inception model
203 | -- batch_size : The images numpy array is split into batches with
204 | batch size batch_size. A reasonable batch size
205 | depends on the hardware.
206 | -- dims : Dimensionality of features returned by Inception
207 | -- device : Device to run calculations
208 | -- num_workers : Number of parallel dataloader workers
209 |
210 | Returns:
211 | -- mu : The mean over samples of the activations of the pool_3 layer of
212 | the inception model.
213 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
214 | the inception model.
215 | """
216 | act = get_activations(files, model, batch_size, dims, device, num_workers)
217 | mu = np.mean(act, axis=0)
218 | sigma = np.cov(act, rowvar=False)
219 | return mu, sigma
220 |
221 |
222 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8):
223 | if path.endswith('.npz'):
224 | with np.load(path) as f:
225 | m, s = f['mu'][:], f['sigma'][:]
226 | else:
227 | path = pathlib.Path(path)
228 | files = sorted([file for ext in IMAGE_EXTENSIONS
229 | for file in path.glob('*.{}'.format(ext))])
230 | m, s = calculate_activation_statistics(files, model, batch_size,
231 | dims, device, num_workers)
232 |
233 | return m, s
234 |
235 |
236 | class FIDScore:
237 |
238 | def __init__(self, device=None, dims=2048, batch_size=50, num_workers=4):
239 | self.device = device
240 | self.dims = dims
241 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
242 | self.inception_model = InceptionV3([block_idx]).to(device).eval()
243 | self.batch_size = batch_size
244 | self.num_workers = num_workers
245 |
246 | def calc_mean_std(self, p):
247 | """Calculates the mean and covariance matrix of path"""
248 | if not os.path.exists(p):
249 | raise RuntimeError(f'Invalid path: {p}')
250 | mean, std = compute_statistics_of_path(p, self.inception_model, self.batch_size,
251 | self.dims, self.device, self.num_workers)
252 | return mean, std
253 |
254 | def calc_mean_std_with_gen(self, gen_model, data_loader):
255 | self.inception_model.eval()
256 |
257 | pred_arr = np.empty((len(data_loader.dataset), self.dims))
258 | start_idx = 0
259 | for batch in tqdm(data_loader):
260 | batch = batch.to(self.device, non_blocking=True)
261 |
262 | with torch.no_grad():
263 | pred = self.inception_model(gen_model(batch))[0]
264 |
265 | # If model output is not scalar, apply global spatial average pooling.
266 | # This happens if you choose a dimensionality not equal 2048.
267 | if pred.size(2) != 1 or pred.size(3) != 1:
268 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
269 |
270 | pred = pred.squeeze(3).squeeze(2).cpu().numpy()
271 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred
272 | start_idx = start_idx + pred.shape[0]
273 |
274 | mu = np.mean(pred_arr, axis=0)
275 | sigma = np.cov(pred_arr, rowvar=False)
276 | return mu, sigma
277 |
278 | @staticmethod
279 | def calc_fid(mean_std1, mean_std2):
280 | """Calculates the FID of two paths"""
281 | m1, s1 = mean_std1
282 | m2, s2 = mean_std2
283 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
284 | return fid_value
285 |
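A minimal sketch (not part of this file) of tracking FID during training with this class. genA2B and testA_loader are hypothetical stand-ins for the GAN generator and a test-set DataLoader that yields image batches, and rescaling the tanh output to [0, 1] is an assumption about how generated images should be fed to InceptionV3.

import torch

def genA2B_for_fid(batch):
    fake, _, _, _ = genA2B(batch)   # UGATIT generators return (image, cam_logit, heatmap, attention)
    return fake * 0.5 + 0.5         # tanh range [-1, 1] -> [0, 1]

fid = FIDScore(device=torch.device('cuda:0'), batch_size=50, num_workers=4)
real_B_stats = fid.calc_mean_std('dataset/YOUR_DATASET_NAME/testB')        # placeholder path
fake_B_stats = fid.calc_mean_std_with_gen(genA2B_for_fid, testA_loader)
print('FID(A2B):', FIDScore.calc_fid(fake_B_stats, real_B_stats))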
286 |
287 | # def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=8):
288 | # """Calculates the FID of two paths"""
289 | # for p in paths:
290 | # if not os.path.exists(p):
291 | # raise RuntimeError('Invalid path: %s' % p)
292 | #
293 | # block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
294 | #
295 | # model = InceptionV3([block_idx]).to(device)
296 | #
297 | # m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
298 | # dims, device, num_workers)
299 | # m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
300 | # dims, device, num_workers)
301 | # fid_value = calculate_frechet_distance(m1, s1, m2, s2)
302 | #
303 | # return fid_value
304 | #
305 | #
306 | # def main():
307 | # args = parser.parse_args()
308 | #
309 | # if args.device is None:
310 | # device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
311 | # else:
312 | # device = torch.device(args.device)
313 | #
314 | # fid_value = calculate_fid_given_paths(args.path,
315 | # args.batch_size,
316 | # device,
317 | # args.dims,
318 | # args.num_workers)
319 | # print('FID: ', fid_value)
320 | #
321 | #
322 | if __name__ == '__main__':
323 | file_dir_1 = 'dataset/selfie2anime/testB'
324 | file_dir_2 = 'dataset/selfie2anime/testB'
325 |
326 | fid_score = FIDScore('cpu', batch_size=2, num_workers=1)
327 | mean_std_A_1 = fid_score.calc_mean_std(file_dir_1)
328 | mean_std_A_2 = fid_score.calc_mean_std(file_dir_2)
329 | score = fid_score.calc_fid(mean_std_A_1, mean_std_A_2)
330 | print(f'fid score between {file_dir_1} and {file_dir_2}: {score}')
331 |
332 | mean_std_A_1 = fid_score.calc_mean_std(file_dir_1)
333 | mean_std_A_2 = fid_score.calc_mean_std(file_dir_2)
334 | score = fid_score.calc_fid(mean_std_A_1, mean_std_A_2)
335 | print(f'fid score between {file_dir_1} and {file_dir_2}: {score}')
336 |
337 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional
5 | from torch.nn.parameter import Parameter
6 |
7 |
8 | class ResnetGenerator(nn.Module):
9 |
10 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, args=None):
11 | assert (n_blocks >= 0)
12 | super(ResnetGenerator, self).__init__()
13 | self.input_nc = input_nc
14 | self.output_nc = output_nc
15 | self.ngf = ngf
16 | self.n_blocks = n_blocks
17 | self.img_size = img_size
18 | self.args = args
19 |
20 | self.light = args.light
21 | self.attention_gan = args.attention_gan
22 | self.attention_input = args.attention_input
23 | self.use_se = args.use_se
24 |
25 |         # down-sampling path: feature extraction, down-sampling, and bottleneck (resnet-block) encoding
26 | DownBlock = [nn.ReflectionPad2d(3),
27 | nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False),
28 | nn.InstanceNorm2d(ngf, affine=True),
29 | nn.ReLU(True)]
30 |         # down-sampling
31 | n_downsampling = 2
32 | for i in range(n_downsampling):
33 | mult = 2 ** i
34 | DownBlock += [nn.ReflectionPad2d(1),
35 | nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False),
36 | nn.InstanceNorm2d(ngf * mult * 2, affine=True),
37 | nn.ReLU(True)]
38 |
39 | # Down-Sampling Bottleneck
40 | mult = 2 ** n_downsampling
41 | for i in range(n_blocks):
42 | DownBlock += [ResnetBlock(ngf * mult, use_bias=False, use_se=self.use_se)]
43 |
44 | # Class Activation Map
45 | self.gap_fc = nn.Linear(ngf * mult, 1, bias=True)
46 |         # self.gmp_fc = self.gap_fc  # TF version shares one FC layer for GAP and GMP
47 |         self.gmp_fc = nn.Linear(ngf * mult, 1, bias=True)  # PyTorch version uses a separate FC layer
48 | self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True)
49 | self.relu = nn.ReLU(True)
50 |
51 |         # MLP that produces gamma and beta (light vs. full model)
52 | if self.light > 0:
53 | FC = [nn.Linear(self.light * self.light * ngf * mult, ngf * mult, bias=True)]
54 | else:
55 | FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=True)]
56 | FC += [nn.ReLU(True),
57 | nn.Linear(ngf * mult, ngf * mult, bias=True),
58 | nn.ReLU(True)]
59 | self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=True)
60 | self.beta = nn.Linear(ngf * mult, ngf * mult, bias=True)
61 |
62 | # Up-Sampling Bottleneck
63 | for i in range(n_blocks):
64 | setattr(self, 'UpBlock1_' + str(i + 1), ResnetAdaLINBlock(ngf * mult, use_bias=True, use_se=self.use_se))
65 | # Up-Sampling
66 | UpBlock2 = []
67 | for i in range(n_downsampling):
68 | mult = 2 ** (n_downsampling - i)
69 | UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'),
70 | nn.ReflectionPad2d(1),
71 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=True),
72 | LIN(int(ngf * mult / 2)),
73 | nn.ReLU(True)
74 | ]
75 |
76 | if self.attention_gan > 0:
77 | UpBlock_attention = []
78 | mult = 2 ** n_downsampling
79 | for i in range(n_blocks):
80 | UpBlock_attention += [ResnetBlock(ngf * mult, use_bias=False, use_se=self.use_se)]
81 | for i in range(n_downsampling):
82 | mult = 2 ** (n_downsampling - i)
83 | UpBlock_attention += [nn.Upsample(scale_factor=2, mode='nearest'),
84 | nn.ReflectionPad2d(1),
85 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3,
86 | stride=1, padding=0, bias=True),
87 | LIN(int(ngf * mult / 2)),
88 | nn.ReLU(True)]
89 | UpBlock_attention += [nn.Conv2d(ngf, self.attention_gan, kernel_size=1, stride=1, padding=0, bias=True),
90 | nn.Softmax(dim=1)]
91 | self.UpBlock_attention = nn.Sequential(*UpBlock_attention)
92 | if self.attention_input:
93 | output_nc *= (self.attention_gan - 1)
94 | else:
95 | output_nc *= self.attention_gan
96 |
97 | UpBlock2 += [nn.ReflectionPad2d(3),
98 | nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=True),
99 | nn.Tanh()]
100 |
101 | self.DownBlock = nn.Sequential(*DownBlock)
102 | self.FC = nn.Sequential(*FC)
103 | self.UpBlock2 = nn.Sequential(*UpBlock2)
104 |
105 | def forward(self, input_x):
106 | attention = None
107 | x = self.DownBlock(input_x)
108 |         # CAM logits act as attention weights on x
109 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
110 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
111 | gap_weight = sum(list(self.gap_fc.parameters()))
112 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
113 |
114 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
115 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
116 | gmp_weight = sum(list(self.gmp_fc.parameters()))
117 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
118 |
119 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
120 | x = torch.cat([gap, gmp], 1)
121 | x = self.relu(self.conv1x1(x))
122 |
123 | heatmap = torch.sum(x, dim=1, keepdim=True)
124 |         # produce gamma and beta for the AdaLIN layers of the up-sampling blocks
125 | if self.light > 0:
126 | x_ = torch.nn.functional.adaptive_avg_pool2d(x, self.light)
127 | x_ = self.FC(x_.view(x_.shape[0], -1))
128 | else:
129 | x_ = self.FC(x.view(x.shape[0], -1))
130 | gamma, beta = self.gamma(x_), self.beta(x_)
131 |
132 | new_x = x
133 | for i in range(self.n_blocks):
134 | new_x = getattr(self, 'UpBlock1_' + str(i + 1))(new_x, gamma, beta)
135 |
136 | out = self.UpBlock2(new_x)
137 |
138 | if self.attention_gan > 0:
139 | attention = self.UpBlock_attention(x)
140 | batch_size, attention_ch, height, width = attention.shape
141 | if self.attention_input:
142 | out = torch.cat([input_x, out], dim=1)
143 | out = out.view(batch_size, 3, attention_ch, height, width)
144 | out = out * attention.view(batch_size, 1, attention_ch, height, width)
145 | out = out.sum(dim=2)
146 |
147 | return out, cam_logit, heatmap, attention
148 |
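A minimal construction sketch (not part of this file; the Namespace values are placeholders) showing the options ResnetGenerator reads from args and the four outputs of forward():

import argparse
import torch

args = argparse.Namespace(light=8, attention_gan=0, attention_input=False, use_se=False)
gen = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=4, img_size=128, args=args)
fake, cam_logit, heatmap, attention = gen(torch.randn(1, 3, 128, 128))
print(fake.shape, cam_logit.shape, heatmap.shape, attention)   # (1,3,128,128) (1,2) (1,1,32,32) None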
149 |
150 | class ChannelSELayer(nn.Module):
151 | def __init__(self, in_size, reduction=4, min_hidden_channel=8):
152 | super(ChannelSELayer, self).__init__()
153 |
154 | hidden_channel = max(in_size // reduction, min_hidden_channel)
155 |
156 | self.se = nn.Sequential(
157 | nn.AdaptiveAvgPool2d(1),
158 | nn.Conv2d(in_size, hidden_channel, kernel_size=1, stride=1, bias=True),
159 | nn.ReLU(inplace=True),
160 | nn.Conv2d(hidden_channel, in_size, kernel_size=1, stride=1, bias=True),
161 | nn.Sigmoid()
162 | )
163 |
164 | def forward(self, x):
165 | return self.se(x) * x
166 |
167 |
168 | class SpatialSELayer(nn.Module):
169 | """
170 | Re-implementation of SE block -- squeezing spatially and exciting channel-wise described in:
171 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
172 | """
173 |
174 | def __init__(self, num_channels):
175 | """
176 | :param num_channels: No of input channels
177 | """
178 | super(SpatialSELayer, self).__init__()
179 | self.conv = nn.Conv2d(num_channels, 1, 1)
180 | self.sigmoid = nn.Sigmoid()
181 |
182 | def forward(self, input_tensor):
183 | """
184 | :param input_tensor: X, shape = (batch_size, num_channels, H, W)
185 | :return: output_tensor
186 | """
187 | out = self.conv(input_tensor)
188 | squeeze_tensor = self.sigmoid(out)
189 | return squeeze_tensor * input_tensor
190 |
191 |
192 | class ChannelSpatialSELayer(nn.Module):
193 | """
194 | Re-implementation of concurrent spatial and channel squeeze & excitation:
195 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks,
196 | MICCAI 2018, arXiv:1803.02579*
197 | """
198 |
199 | def __init__(self, num_channels, reduction_ratio=4):
200 | """
201 | :param num_channels: No of input channels
202 | :param reduction_ratio: By how much should the num_channels should be reduced
203 | """
204 | super(ChannelSpatialSELayer, self).__init__()
205 | self.cSE = ChannelSELayer(num_channels, reduction_ratio)
206 | self.sSE = SpatialSELayer(num_channels)
207 |
208 | def forward(self, input_tensor):
209 | """
210 | :param input_tensor: X, shape = (batch_size, num_channels, H, W)
211 | :return: output_tensor
212 | """
213 | attention = self.cSE(input_tensor) + self.sSE(input_tensor)
214 | return attention
215 |
216 |
217 | class ResnetBlock(nn.Module):
218 | def __init__(self, dim, use_bias, use_se=False):
219 | super(ResnetBlock, self).__init__()
220 | conv_block = []
221 | conv_block += [nn.ReflectionPad2d(1),
222 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
223 | nn.InstanceNorm2d(dim, affine=True),
224 | nn.ReLU(True)]
225 |
226 | conv_block += [nn.ReflectionPad2d(1),
227 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
228 | nn.InstanceNorm2d(dim, affine=True)]
229 | if use_se:
230 | conv_block += [ChannelSpatialSELayer(dim)]
231 | self.conv_block = nn.Sequential(*conv_block)
232 |
233 | def forward(self, x):
234 | out = x + self.conv_block(x)
235 | return out
236 |
237 |
238 | class ResnetAdaLINBlock(nn.Module):
239 | def __init__(self, dim, use_bias, use_se=False):
240 | super(ResnetAdaLINBlock, self).__init__()
241 | self.pad1 = nn.ReflectionPad2d(1)
242 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
243 | self.norm1 = AdaLIN(dim)
244 | self.relu1 = nn.ReLU(True)
245 |
246 | self.pad2 = nn.ReflectionPad2d(1)
247 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
248 | self.norm2 = AdaLIN(dim)
249 | self.use_se = use_se
250 | if use_se:
251 | self.se = ChannelSpatialSELayer(dim)
252 |
253 | def forward(self, x, gamma, beta):
254 | out = self.pad1(x)
255 | out = self.conv1(out)
256 | out = self.norm1(out, gamma, beta)
257 | out = self.relu1(out)
258 | out = self.pad2(out)
259 | out = self.conv2(out)
260 | out = self.norm2(out, gamma, beta)
261 | if self.use_se:
262 | out = self.se(out)
263 | return out + x
264 |
265 |
266 | class AdaLIN(nn.Module):
267 | def __init__(self, num_features, eps=1e-5):
268 | super(AdaLIN, self).__init__()
269 | self.eps = eps
270 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
271 | self.rho.data.fill_(0.9)
272 |
273 | def forward(self, in_x, gamma, beta):
274 | in_mean, in_var = torch.mean(in_x, dim=[2, 3], keepdim=True), torch.var(in_x, dim=[2, 3], keepdim=True)
275 | out_in = (in_x - in_mean) / torch.sqrt(in_var + self.eps)
276 | ln_mean, ln_var = torch.mean(in_x, dim=[1, 2, 3], keepdim=True), torch.var(in_x, dim=[1, 2, 3], keepdim=True)
277 | out_ln = (in_x - ln_mean) / torch.sqrt(ln_var + self.eps)
278 | out = self.rho.expand(in_x.shape[0], -1, -1, -1) * out_in + (
279 | 1 - self.rho.expand(in_x.shape[0], -1, -1, -1)) * out_ln
280 | out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
281 |
282 | return out
283 |
284 |
285 | class LIN(nn.Module):
286 | def __init__(self, num_features, eps=1e-5):
287 | super(LIN, self).__init__()
288 | self.eps = eps
289 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
290 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
291 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
292 | self.rho.data.fill_(0.0)
293 | self.gamma.data.fill_(1.0)
294 | self.beta.data.fill_(0.0)
295 |
296 | def forward(self, in_x):
297 | in_mean, in_var = torch.mean(in_x, dim=[2, 3], keepdim=True), torch.var(in_x, dim=[2, 3], keepdim=True)
298 | out_in = (in_x - in_mean) / torch.sqrt(in_var + self.eps)
299 | ln_mean, ln_var = torch.mean(in_x, dim=[1, 2, 3], keepdim=True), torch.var(in_x, dim=[1, 2, 3], keepdim=True)
300 | out_ln = (in_x - ln_mean) / torch.sqrt(ln_var + self.eps)
301 | out = self.rho.expand(in_x.shape[0], -1, -1, -1) * out_in + (
302 | 1 - self.rho.expand(in_x.shape[0], -1, -1, -1)) * out_ln
303 | out = out * self.gamma.expand(in_x.shape[0], -1, -1, -1) + self.beta.expand(in_x.shape[0], -1, -1, -1)
304 |
305 | return out
306 |
307 |
308 | class Discriminator(nn.Module):
309 | def __init__(self, input_nc, ndf=64, n_layers=5, with_sn=True, use_cam_attention=True):
310 | super(Discriminator, self).__init__()
311 | self.use_cam_attention = use_cam_attention
312 | conv1 = nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)
313 | model = [nn.ReflectionPad2d(1),
314 | nn.utils.spectral_norm(conv1) if with_sn else conv1,
315 | nn.LeakyReLU(0.2, True)]
316 |
317 | for i in range(1, n_layers - 2):
318 | mult = 2 ** (i - 1)
319 | conv = nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)
320 | model += [nn.ReflectionPad2d(1),
321 | nn.utils.spectral_norm(conv) if with_sn else conv,
322 | nn.LeakyReLU(0.2, True)]
323 | if n_layers < 5:
324 | mult = 2 ** (n_layers - 2 - 1)
325 | for i in range(0, 5 - n_layers):
326 | conv = nn.Conv2d(ndf * mult, ndf * mult, kernel_size=3, stride=1, padding=0, bias=True)
327 | model += [nn.ReflectionPad2d(1),
328 | nn.utils.spectral_norm(conv) if with_sn else conv,
329 | nn.LeakyReLU(0.2, True)]
330 |
331 | mult = 2 ** (n_layers - 2 - 1)
332 | conv2 = nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)
333 | model += [nn.ReflectionPad2d(1),
334 | nn.utils.spectral_norm(conv2) if with_sn else conv2,
335 | nn.LeakyReLU(0.2, True)]
336 |
337 | # Class Activation Map
338 | mult = 2 ** (n_layers - 2)
339 | linear_gmp = nn.Linear(ndf * mult, 1, bias=True)
340 | linear_gap = nn.Linear(ndf * mult, 1, bias=True)
341 | self.gmp_fc = nn.utils.spectral_norm(linear_gmp) if with_sn else linear_gmp
342 | self.gap_fc = nn.utils.spectral_norm(linear_gap) if with_sn else linear_gap
343 | self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
344 | self.leaky_relu = nn.LeakyReLU(0.2, True)
345 |
346 | self.pad = nn.ReflectionPad2d(1)
347 | conv3 = nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=True)
348 | self.conv = nn.utils.spectral_norm(conv3) if with_sn else conv3
349 |
350 | self.model = nn.Sequential(*model)
351 |
352 | def forward(self, in_x, cam_input=None, mask=None):
353 | x = self.model(in_x)
354 |
355 |         # If a dedicated CAM input is provided (e.g. one with the background erased, so the background does not drive the CAM), use it to compute the CAM logits and attention
356 | cam_x = self.model(cam_input) if cam_input is not None else x
357 |
358 | if mask is not None:
359 | cam_x = torch.nn.functional.interpolate(mask, cam_x.shape[2:], mode='area') * cam_x
360 |
361 | gap = torch.nn.functional.adaptive_avg_pool2d(cam_x, 1)
362 | gap_logit = self.gap_fc(gap.view(cam_x.shape[0], -1))
363 | if self.use_cam_attention:
364 | gap_weight = sum(list(self.gap_fc.parameters()))
365 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
366 | else:
367 | gap = x
368 |
369 | gmp = torch.nn.functional.adaptive_max_pool2d(cam_x, 1)
370 | gmp_logit = self.gmp_fc(gmp.view(cam_x.shape[0], -1))
371 | if self.use_cam_attention:
372 | gmp_weight = sum(list(self.gmp_fc.parameters()))
373 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
374 | else:
375 | gmp = x
376 |
377 | cam_logit = torch.cat([gap_logit, gmp_logit], 1)
378 |
379 | x = torch.cat([gap, gmp], 1)
380 | x = self.leaky_relu(self.conv1x1(x))
381 |
382 | heatmap = torch.sum(x, dim=1, keepdim=True)
383 |
384 | x = self.pad(x)
385 | out = self.conv(x)
386 |
387 | return out, cam_logit, heatmap
388 |
389 |
390 | class RhoClipper(object):
391 |
392 | def __init__(self, min_num, max_num, module_type=None):
393 | self.clip_min = min_num
394 | self.clip_max = max_num
395 | self.module_type = module_type
396 | assert min_num < max_num
397 |
398 | def __call__(self, module):
399 | if (self.module_type is None and hasattr(module, 'rho')) or (type(module) == self.module_type):
400 | w = module.rho.data
401 | w = w.clamp(self.clip_min, self.clip_max)
402 | module.rho.data = w
403 |
--------------------------------------------------------------------------------
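Editor's note: the sketch below is not part of the repository; it is a minimal, hypothetical usage example for the AdaLIN and RhoClipper classes defined in networks.py above, using toy shapes and assuming the repo root is on PYTHONPATH so that `networks` is importable. AdaLIN blends Instance-Norm and Layer-Norm statistics with a learnable per-channel rho and then applies the externally supplied gamma/beta; RhoClipper is applied through `Module.apply` after each optimizer step to keep rho inside its clipping range (mirroring the clippers built in UGATIT.py below).

import torch
from networks import AdaLIN, RhoClipper  # assumption: repo root on PYTHONPATH

x = torch.randn(2, 8, 16, 16)    # (batch, channels, H, W)
gamma = torch.ones(2, 8)         # per-sample scale (toy values; the generator computes these)
beta = torch.zeros(2, 8)         # per-sample shift (toy values; the generator computes these)

adalin = AdaLIN(num_features=8)
y = adalin(x, gamma, beta)       # rho * IN(x) + (1 - rho) * LN(x), then the affine transform
print(y.shape)                   # torch.Size([2, 8, 16, 16])

# After an optimizer step, rho is clamped back into [0, 0.9] for AdaLIN modules.
adalin.apply(RhoClipper(0.0, 0.9, AdaLIN))
print(float(adalin.rho.min()), float(adalin.rho.max()))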
/UGATIT.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 | import copy
5 | import itertools
6 | from glob import glob
7 | from typing import Union
8 |
9 | import cv2
10 | import PIL
11 | import numpy as np
12 | from tqdm import tqdm
13 | import torch
14 | from torch import nn
15 | import torch.nn.functional as F # noqa
16 | from torchvision import transforms
17 | from torch.utils.tensorboard import SummaryWriter
18 |
19 | from faceseg.FaceSegmentation import FaceSegmentation
20 | from utils import (calc_tv_loss, AverageMeter, ProgressMeter, generate_blur_images,
21 | RGB2BGR, tensor2numpy, attention_mask, cam, denorm)
22 | from dataset import MatchHistogramsDataset, DatasetFolder, get_loader
23 | from metrics import FIDScore
24 |
25 |
26 | class UGATIT(object):
27 | def __init__(self, args):
28 | self.args = args
29 |
30 | if self.args.light > 0:
31 | self.model_name = 'UGATIT_light' + str(self.args.light)
32 | else:
33 | self.model_name = 'UGATIT'
34 |
35 | print(f'\n##### Information #####\n'
36 | f'# light : {self.args.light}\n'
37 | f'# dataset : {self.args.dataset}\n'
38 | f'# batch_size : {self.args.batch_size}\n'
39 | f'# num_workers : {self.args.num_workers}\n'
40 | f'# ema_start : {self.args.ema_start}\n'
41 | f'# ema_beta : {self.args.ema_beta}\n'
42 | f'# iteration : {self.args.iteration}\n'
43 |               f'# no_decay_flag : {self.args.no_decay_flag}\n'
44 | f'##### Data #####\n'
45 | f'# img_size : {self.args.img_size}\n'
46 | f'# aug_prob : {self.args.aug_prob}\n'
47 | f'# match_histograms : {self.args.match_histograms}\n'
48 | f'# match_mode : {self.args.match_mode}\n'
49 | f'# match_prob : {self.args.match_prob}\n'
50 | f'# match_ratio : {self.args.match_ratio}\n'
51 | f'##### Generator #####\n'
52 | f'# residual blocks : {self.args.n_res}\n'
53 | f'# use se or not : {self.args.use_se}\n'
54 | f'# use blur or not : {self.args.has_blur}\n'
55 | f'# tv_loss : {self.args.tv_loss}\n'
56 | f'# tv_weight : {self.args.tv_weight}\n'
57 | f'# use attention gan : {self.args.attention_gan}\n'
58 | f'# use attention input : {self.args.attention_input}\n'
59 | f'##### Discriminator #####\n'
60 | f'# global discriminator layer : {self.args.n_global_dis}\n'
61 | f'# local discriminator layer : {self.args.n_local_dis}\n'
62 | f'##### Weight #####\n'
63 | f'# adv_weight : {self.args.adv_weight}\n'
64 | f'# forward_adv_weight : {self.args.forward_adv_weight}\n'
65 | f'# backward_adv_weight : {self.args.backward_adv_weight}\n'
66 | f'# cycle_weight : {self.args.cycle_weight}\n'
67 | f'# identity_weight : {self.args.identity_weight}\n'
68 | f'# cam_weight : {self.args.cam_weight}\n'
69 | f'##### Enhanced #####\n'
70 | f'# cam_D_weight : {self.args.cam_D_weight}\n'
71 | f'# cam_D_attention : {self.args.cam_D_attention}\n'
72 | f'##### Segment #####\n'
73 | f'# hard_seg_edge : {self.args.hard_seg_edge}\n'
74 | f'# seg_fix_weight : {self.args.seg_fix_weight}\n'
75 | f'# seg_fix_glass_mouth : {self.args.seg_fix_glass_mouth}\n'
76 | f'# seg_D_mask : {self.args.seg_D_mask}\n'
77 | f'# seg_G_detach : {self.args.seg_G_detach}\n'
78 | f'# seg_D_cam_fea_mask : {self.args.seg_D_cam_fea_mask}\n'
79 | f'# seg_D_cam_inp_mask : {self.args.seg_D_cam_inp_mask}\n'
80 | f'# resume : {self.args.resume}\n\n'
81 | )
82 |
83 | self.use_seg = ((self.args.seg_fix_weight > 0) or self.args.seg_D_mask or self.args.seg_G_detach or
84 | self.args.seg_D_cam_fea_mask or self.args.seg_D_cam_inp_mask)
85 |
86 | self.genA2B, self.genB2A = None, None
87 | self.genA2B_ema, self.genB2A_ema = None, None
88 | self.disGA, self.disGB, self.disLA, self.disLB = None, None, None, None
89 | self.FaceSeg = None
90 | self.maskA, self.maskA_erode, self.maskB, self.maskB_erode = None, None, None, None
91 | self.trainA_data_root = os.path.join('dataset', self.args.dataset, 'trainA')
92 | self.trainB_data_root = os.path.join('dataset', self.args.dataset, 'trainB')
93 | self.testA_data_root = os.path.join('dataset', self.args.dataset, 'testA')
94 | self.testB_data_root = os.path.join('dataset', self.args.dataset, 'testB')
95 | self.blurA_data_root = os.path.join('dataset', self.args.dataset, 'blurA')
96 | self.blurB_data_root = os.path.join('dataset', self.args.dataset, 'blurB')
97 | self.train_transform, self.test_transform = None, None
98 | self.trainAB_loader, self.blurAB_loader = None, None
99 | self.trainAB_iter, self.blurAB_iter = None, None
100 | self.testA_loader, self.testB_loader = None, None
101 | self.testA_iter, self.testB_iter = None, None
102 | self.L1_loss, self.MSE_loss, self.BCE_loss = None, None, None
103 | self.G_optim, self.D_optim = None, None
104 | self.Rho_LIN_clipper, self.Rho_AdaLIN_clipper = None, None
105 | self.G_adv_loss, self.G_cyc_loss, self.G_idt_loss, self.G_cam_loss = None, None, None, None
106 | self.Generator_loss, self.G_seg_loss, self.tv_loss = None, None, None
107 | self.discriminator_loss = None
108 | self.fid_score, self.mean_std_A, self.mean_std_B = None, None, None
109 | self.fid_loaderA, self.fid_loaderB = None, None
110 |
111 | ##################################################################################
112 | # Model
113 | ##################################################################################
114 |
115 | def build_data_loader(self):
116 |         """ Build the train/test data loaders """
117 | self.train_transform = transforms.Compose([
118 | PIL.Image.fromarray,
119 | transforms.RandomHorizontalFlip(),
120 | transforms.RandomApply(
121 | [transforms.RandomResizedCrop(size=self.args.img_size, scale=(0.748, 1.0), ratio=(1.0, 1.0),
122 | interpolation=transforms.InterpolationMode.BICUBIC)],
123 | p=self.args.aug_prob),
124 | transforms.Resize(size=self.args.img_size, interpolation=transforms.InterpolationMode.BICUBIC),
125 | transforms.ToTensor(),
126 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
127 | ])
128 | self.test_transform = transforms.Compose([
129 | PIL.Image.fromarray,
130 | transforms.Resize((self.args.img_size, self.args.img_size),
131 | interpolation=transforms.InterpolationMode.BICUBIC),
132 | transforms.ToTensor(),
133 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
134 | ])
135 |
136 | trainAB = MatchHistogramsDataset((self.trainA_data_root, self.trainB_data_root),
137 | self.train_transform, is_match_histograms=self.args.match_histograms,
138 | match_mode=self.args.match_mode, b2a_prob=self.args.match_prob,
139 | match_ratio=self.args.match_ratio)
140 | self.trainAB_loader = get_loader(trainAB, self.args.device, batch_size=self.args.batch_size,
141 | shuffle=True, num_workers=self.args.num_workers)
142 | testA = DatasetFolder(self.testA_data_root, self.test_transform)
143 | testB = DatasetFolder(self.testB_data_root, self.test_transform)
144 | self.testA_loader = get_loader(testA, self.args.device, batch_size=1, shuffle=False,
145 | num_workers=self.args.num_workers)
146 | self.testB_loader = get_loader(testB, self.args.device, batch_size=1, shuffle=False,
147 | num_workers=self.args.num_workers)
148 |
149 |         # Train the discriminators on blurred images (always labeled fake) so that the generators are pushed to produce sharp images
150 | if self.args.has_blur:
151 | if not os.path.exists(self.blurA_data_root):
152 | generate_blur_images(self.trainA_data_root, self.blurA_data_root)
153 | if not os.path.exists(self.blurB_data_root):
154 | generate_blur_images(self.trainB_data_root, self.blurB_data_root)
155 |
156 | blurAB = MatchHistogramsDataset((self.blurA_data_root, self.blurB_data_root), self.train_transform,
157 | is_match_histograms=self.args.match_histograms,
158 | match_mode=self.args.match_mode, b2a_prob=self.args.match_prob,
159 | match_ratio=self.args.match_ratio)
160 | self.blurAB_loader = get_loader(blurAB, self.args.device, batch_size=self.args.batch_size,
161 | shuffle=True, num_workers=self.args.num_workers)
162 |
163 | def build_model(self):
164 |         """ Build the data loaders, the Generator/Discriminator models, the losses and the optimizers """
165 | from networks import ResnetGenerator, Discriminator, RhoClipper, LIN, AdaLIN
166 | self.build_data_loader()
167 | # Define Generator, Discriminator
168 | self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.args.ch, n_blocks=self.args.n_res,
169 | img_size=self.args.img_size, args=self.args).to(self.args.device)
170 | self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.args.ch, n_blocks=self.args.n_res,
171 | img_size=self.args.img_size, args=self.args).to(self.args.device)
172 | self.genA2B_ema = copy.deepcopy(self.genA2B).eval().requires_grad_(False)
173 | self.genB2A_ema = copy.deepcopy(self.genB2A).eval().requires_grad_(False)
174 | self.disGA = Discriminator(input_nc=3, ndf=self.args.ch, n_layers=self.args.n_global_dis, with_sn=self.args.sn,
175 | use_cam_attention=self.args.cam_D_attention).to(self.args.device)
176 | self.disGB = Discriminator(input_nc=3, ndf=self.args.ch, n_layers=self.args.n_global_dis, with_sn=self.args.sn,
177 | use_cam_attention=self.args.cam_D_attention).to(self.args.device)
178 | self.disLA = Discriminator(input_nc=3, ndf=self.args.ch, n_layers=self.args.n_local_dis, with_sn=self.args.sn,
179 | use_cam_attention=self.args.cam_D_attention).to(self.args.device)
180 | self.disLB = Discriminator(input_nc=3, ndf=self.args.ch, n_layers=self.args.n_local_dis, with_sn=self.args.sn,
181 | use_cam_attention=self.args.cam_D_attention).to(self.args.device)
182 |
183 |         # The face-segmentation model is needed whenever a segmentation-based loss or masking option is enabled
184 | if self.use_seg:
185 | self.FaceSeg = FaceSegmentation(self.args.device)
186 |
187 | # Define Loss
188 | self.L1_loss = nn.L1Loss().to(self.args.device)
189 | self.MSE_loss = nn.MSELoss().to(self.args.device)
190 | self.BCE_loss = nn.BCEWithLogitsLoss().to(self.args.device)
191 |
192 |         # Optimizers
193 | gen_params = itertools.chain(self.genA2B.parameters(), self.genB2A.parameters())
194 | self.G_optim = torch.optim.Adam(gen_params, lr=self.args.lr, betas=(0.5, 0.999),
195 | weight_decay=self.args.weight_decay)
196 | disc_params = itertools.chain(self.disGA.parameters(), self.disGB.parameters(),
197 | self.disLA.parameters(), self.disLB.parameters())
198 | self.D_optim = torch.optim.Adam(disc_params, lr=self.args.lr, betas=(0.5, 0.999),
199 | weight_decay=self.args.weight_decay)
200 |
201 | # Define Rho clipper to constraint the value of rho in AdaLIN and LIN
202 | # self.Rho_clipper = RhoClipper(0, 1)
203 | self.Rho_LIN_clipper = RhoClipper(0, 1, LIN)
204 | self.Rho_AdaLIN_clipper = RhoClipper(0.0, 0.9, AdaLIN)
205 |
206 | ##################################################################################
207 |     # Utility functions
208 | ##################################################################################
209 |
210 | def gen_train(self, on=True):
211 |         """ Toggle training mode for the generator networks """
212 | if on:
213 | self.genA2B.train(), self.genB2A.train()
214 | else:
215 | self.genA2B.eval(), self.genB2A.eval()
216 |
217 | def dis_train(self, on=True):
218 |         """ Toggle training mode for the discriminator networks """
219 | if on:
220 | self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()
221 | else:
222 | self.disGA.eval(), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
223 |
224 | def get_batch(self, mode='train'):
225 |         """ Fetch a batch of training or test data """
226 | if mode == 'train':
227 | try:
228 | real_A, real_B = next(self.trainAB_iter)
229 | except (StopIteration, TypeError):
230 | self.trainAB_iter = iter(self.trainAB_loader)
231 | real_A, real_B = next(self.trainAB_iter)
232 | else:
233 | try:
234 | real_A = next(self.testA_iter)
235 | except (StopIteration, TypeError):
236 | self.testA_iter = iter(self.testA_loader)
237 | real_A = next(self.testA_iter)
238 |
239 | try:
240 | real_B = next(self.testB_iter)
241 | except (StopIteration, TypeError):
242 | self.testB_iter = iter(self.testB_loader)
243 | real_B = next(self.testB_iter)
244 |
245 | real_A, real_B = real_A.to(self.args.device, non_blocking=True), real_B.to(self.args.device, non_blocking=True)
246 |
247 | blur = None
248 | if self.args.has_blur and mode == 'train':
249 | try:
250 | blur_A, blur_B = next(self.blurAB_iter)
251 | except (StopIteration, TypeError):
252 | self.blurAB_iter = iter(self.blurAB_loader)
253 | blur_A, blur_B = next(self.blurAB_iter)
254 |
255 | blur = (blur_A.to(self.args.device, non_blocking=True), blur_B.to(self.args.device, non_blocking=True))
256 |
257 | return real_A, real_B, blur
258 |
259 | ##################################################################################
260 |     # Training
261 | ##################################################################################
262 |
263 | def forward(self, real_A, real_B):
264 |         """ Forward inference: A->B->A, B->A->B, A->A, B->B """
265 | # cycle
266 | fake_A2B, fake_A2B_cam_logit, fake_A2B_heatmap, fake_A2B_attention = self.genA2B(real_A)
267 | fake_A2B2A, _, fake_A2B2A_heatmap, fake_A2B2A_attention = self.genB2A(fake_A2B)
268 | fake_B2A, fake_B2A_cam_logit, fake_B2A_heatmap, fake_B2A_attention = self.genB2A(real_B)
269 | fake_B2A2B, _, fake_B2A2B_heatmap, fake_B2A2B_attention = self.genA2B(fake_B2A)
270 |         # Identity mapping
271 | fake_A2A, fake_A2A_cam_logit, fake_A2A_heatmap, fake_A2A_attention = self.genB2A(real_A)
272 | fake_B2B, fake_B2B_cam_logit, fake_B2B_heatmap, fake_B2B_attention = self.genA2B(real_B)
273 |
274 |         # Run face segmentation to obtain the face-region masks (self.maskA == 1 inside the face region)
275 | if self.use_seg:
276 | maskA = self.FaceSeg.face_segmentation(real_A)
277 | self.maskA = self.FaceSeg.gen_mask(maskA, is_soft_edge=not self.args.hard_seg_edge,
278 | normal_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17),
279 | dilate_parts=(), erode_parts=())
280 | self.maskA_erode = self.FaceSeg.gen_mask(maskA, normal_parts=(), dilate_parts=(),
281 | is_soft_edge=not self.args.hard_seg_edge,
282 | erode_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17))
283 | if self.args.seg_fix_glass_mouth:
284 | maskA_erode = self.FaceSeg.gen_mask(maskA, normal_parts=(1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17),
285 | dilate_parts=(), is_soft_edge=not self.args.hard_seg_edge,
286 | erode_parts=(6, ))
287 | self.maskA_erode *= maskA_erode
288 | maskB = self.FaceSeg.face_segmentation(real_B)
289 | self.maskB = self.FaceSeg.gen_mask(maskB, is_soft_edge=not self.args.hard_seg_edge,
290 | normal_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17),
291 | dilate_parts=(), erode_parts=())
292 | self.maskB_erode = self.FaceSeg.gen_mask(maskB, normal_parts=(), dilate_parts=(),
293 | is_soft_edge=not self.args.hard_seg_edge,
294 | erode_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17))
295 | if self.args.seg_fix_glass_mouth:
296 | maskB_erode = self.FaceSeg.gen_mask(maskB, normal_parts=(1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17),
297 | dilate_parts=(), is_soft_edge=not self.args.hard_seg_edge,
298 | erode_parts=(6, ))
299 | self.maskB_erode *= maskB_erode
300 |
301 | return (fake_A2B, fake_A2B_cam_logit, fake_A2B_heatmap, fake_A2B_attention,
302 | fake_A2B2A, fake_A2B2A_heatmap, fake_A2B2A_attention,
303 | fake_B2A, fake_B2A_cam_logit, fake_B2A_heatmap, fake_B2A_attention,
304 | fake_B2A2B, fake_B2A2B_heatmap, fake_B2A2B_attention,
305 | fake_A2A, fake_A2A_cam_logit, fake_A2A_heatmap, fake_A2A_attention,
306 | fake_B2B, fake_B2B_cam_logit, fake_B2B_heatmap, fake_B2B_attention)
307 |
308 | def backward_D(self, real_A, real_B, fake_A2B, fake_B2A, blur=None): # noqa
309 |         """ Forward and backward pass for the discriminator networks """
310 | fake_A2B, fake_B2A = fake_A2B.detach(), fake_B2A.detach()
311 |         # Keep only the segmented (face) region of the real/generated images as the CAM inputs for the following D training
312 | cam_real_A, cam_fake_A2B, cam_real_B, cam_fake_B2A = None, None, None, None
313 | if self.args.seg_D_cam_inp_mask:
314 | cam_real_A = self.maskA * real_A
315 | cam_fake_A2B = self.maskA * fake_A2B
316 | cam_real_B = self.maskB * real_B
317 | cam_fake_B2A = self.maskB * fake_B2A
318 |
319 | maskA, maskB = None, None
320 | if self.args.seg_D_cam_fea_mask:
321 | maskA, maskB = self.maskA, self.maskB
322 |
323 | real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A, cam_real_A, maskA)
324 | real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A, cam_real_A, maskA)
325 | fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A, cam_fake_B2A, maskB)
326 | fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A, cam_fake_B2A, maskB)
327 |
328 | real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B, cam_real_B, maskB)
329 | real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B, cam_real_B, maskB)
330 | fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B, cam_fake_A2B, maskA)
331 | fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B, cam_fake_A2B, maskA)
332 |
333 |         # Target constants: the global (GA) and local (LA) discriminators have different output shapes, but their CAM classification outputs share one shape
334 | flag_GA_1 = torch.ones_like(real_GA_logit, requires_grad=False).to(self.args.device)
335 | flag_GA_0 = torch.zeros_like(fake_GA_logit, requires_grad=False).to(self.args.device)
336 | flag_LA_1 = torch.ones_like(real_LA_logit, requires_grad=False).to(self.args.device)
337 | flag_LA_0 = torch.zeros_like(fake_LA_logit, requires_grad=False).to(self.args.device)
338 | flag_cam_1 = torch.ones_like(real_GA_cam_logit, requires_grad=False).to(self.args.device)
339 | flag_cam_0 = torch.zeros_like(fake_GA_cam_logit, requires_grad=False).to(self.args.device)
340 |
341 |         # Discriminator losses: CAM losses and adversarial losses
342 | D_loss_GA, D_cam_loss_GA = 0, 0
343 | D_loss_LA, D_cam_loss_LA = 0, 0
344 | D_loss_GB, D_cam_loss_GB = 0, 0
345 | D_loss_LB, D_cam_loss_LB = 0, 0
346 | if blur is not None:
347 | blur_A, blur_B = blur
348 | cam_blur_A, cam_blur_B = None, None
349 | if self.args.seg_D_cam_inp_mask:
350 | cam_blur_A = self.maskA * blur_A
351 | cam_blur_B = self.maskB * blur_B
352 | blur_GA_logit, blur_GA_cam_logit, _ = self.disGA(blur_A, cam_blur_A, maskA)
353 | blur_LA_logit, blur_LA_cam_logit, _ = self.disLA(blur_A, cam_blur_A, maskA)
354 | blur_GB_logit, blur_GB_cam_logit, _ = self.disGB(blur_B, cam_blur_B, maskB)
355 | blur_LB_logit, blur_LB_cam_logit, _ = self.disLB(blur_B, cam_blur_B, maskB)
356 |             # Blurred images are always labeled fake, whatever other options are enabled
357 | D_loss_GA = self.MSE_loss(blur_GA_logit, flag_GA_0)
358 | D_loss_LA = self.MSE_loss(blur_LA_logit, flag_LA_0)
359 | D_loss_GB = self.MSE_loss(blur_GB_logit, flag_GA_0)
360 | D_loss_LB = self.MSE_loss(blur_LB_logit, flag_LA_0)
361 | if self.args.cam_D_weight > 0:
362 | D_cam_loss_GA = self.MSE_loss(blur_GA_cam_logit, flag_cam_0)
363 | D_cam_loss_LA = self.MSE_loss(blur_LA_cam_logit, flag_cam_0)
364 | D_cam_loss_GB = self.MSE_loss(blur_GB_cam_logit, flag_cam_0)
365 | D_cam_loss_LB = self.MSE_loss(blur_LB_cam_logit, flag_cam_0)
366 |
367 |         # Only compute the adversarial loss inside the segmented region
368 | if self.args.seg_D_mask:
369 | maskA_G = F.interpolate(self.maskA, fake_GA_logit.shape[2:], mode='area')
370 | maskA_L = F.interpolate(self.maskA, fake_LA_logit.shape[2:], mode='area')
371 | maskB_G = F.interpolate(self.maskB, fake_GB_logit.shape[2:], mode='area')
372 | maskB_L = F.interpolate(self.maskB, fake_LB_logit.shape[2:], mode='area')
373 |
374 | real_GA_logit = real_GA_logit * maskA_G + flag_GA_1 * (1 - maskA_G)
375 | fake_GA_logit = fake_GA_logit * maskB_G
376 | real_LA_logit = real_LA_logit * maskA_L + flag_LA_1 * (1 - maskA_L)
377 | fake_LA_logit = fake_LA_logit * maskB_L
378 | real_GB_logit = real_GB_logit * maskB_G + flag_GA_1 * (1 - maskB_G)
379 | fake_GB_logit = fake_GB_logit * maskA_G
380 | real_LB_logit = real_LB_logit * maskB_L + flag_LA_1 * (1 - maskB_L)
381 | fake_LB_logit = fake_LB_logit * maskA_L
382 |
383 | if self.args.cam_D_weight > 0:
384 | D_cam_loss_GA += (self.MSE_loss(real_GA_cam_logit, flag_cam_1) + self.MSE_loss(fake_GA_cam_logit, flag_cam_0)) # noqa, E501
385 | D_cam_loss_LA += (self.MSE_loss(real_LA_cam_logit, flag_cam_1) + self.MSE_loss(fake_LA_cam_logit, flag_cam_0)) # noqa, E501
386 | D_cam_loss_GB += (self.MSE_loss(real_GB_cam_logit, flag_cam_1) + self.MSE_loss(fake_GB_cam_logit, flag_cam_0)) # noqa, E501
387 | D_cam_loss_LB += (self.MSE_loss(real_LB_cam_logit, flag_cam_1) + self.MSE_loss(fake_LB_cam_logit, flag_cam_0)) # noqa, E501
388 | D_loss_GA += (self.MSE_loss(real_GA_logit, flag_GA_1) + self.MSE_loss(fake_GA_logit, flag_GA_0))
389 | D_loss_LA += (self.MSE_loss(real_LA_logit, flag_LA_1) + self.MSE_loss(fake_LA_logit, flag_LA_0))
390 | D_loss_A = D_loss_GA + D_loss_LA + (D_cam_loss_GA + D_cam_loss_LA) * self.args.cam_D_weight
391 | D_loss_A = self.args.forward_adv_weight * self.args.adv_weight * D_loss_A
392 |
393 | D_loss_GB += (self.MSE_loss(real_GB_logit, flag_GA_1) + self.MSE_loss(fake_GB_logit, flag_GA_0))
394 | D_loss_LB += (self.MSE_loss(real_LB_logit, flag_LA_1) + self.MSE_loss(fake_LB_logit, flag_LA_0))
395 | D_loss_B = D_loss_GB + D_loss_LB + (D_cam_loss_GB + D_cam_loss_LB) * self.args.cam_D_weight
396 | D_loss_B = self.args.backward_adv_weight * self.args.adv_weight * D_loss_B
397 |
398 | self.discriminator_loss = D_loss_A + D_loss_B
399 | # backward
400 | self.discriminator_loss.backward()
401 | return D_loss_A, D_loss_B
402 |
403 | def backward_G(self, real_A, real_B, fake_A2B, fake_B2A, fake_A2B2A, fake_B2A2B, fake_A2A, fake_B2B, # noqa
404 | fake_A2B_cam_logit, fake_B2A_cam_logit, fake_A2A_cam_logit, fake_B2B_cam_logit):
405 |         self.Generator_loss: Union[int, torch.Tensor] = 0
406 |         # Background-preservation loss based on the face segmentation
407 | if self.args.seg_fix_weight > 0:
408 | G_seg_loss_B = self.L1_loss(fake_A2B * (1 - self.maskA_erode), real_A * (1 - self.maskA_erode))
409 | G_seg_loss_A = self.L1_loss(fake_B2A * (1 - self.maskB_erode), real_B * (1 - self.maskB_erode))
410 | self.G_seg_loss = self.args.seg_fix_weight * (G_seg_loss_A + G_seg_loss_B)
411 | self.Generator_loss += self.G_seg_loss
412 |
413 | if self.args.tv_loss:
414 | self.tv_loss = calc_tv_loss(fake_A2B, mask=self.maskA) + calc_tv_loss(fake_B2A, mask=self.maskB)
415 | self.tv_loss *= self.args.tv_weight
416 | self.Generator_loss += self.tv_loss
417 |
418 |         # Detach the background of the generated images so that adversarial gradients on the background do not update G
419 | if self.args.seg_G_detach:
420 | fake_A2B = fake_A2B * self.maskA + fake_A2B.detach() * (1.0 - self.maskA)
421 | fake_B2A = fake_B2A * self.maskB + fake_B2A.detach() * (1.0 - self.maskB)
422 |
423 | cam_fake_A2B, cam_fake_B2A = None, None
424 | if self.args.seg_D_cam_inp_mask:
425 | cam_fake_A2B = self.maskA * fake_A2B
426 | cam_fake_B2A = self.maskB * fake_B2A
427 |
428 | maskA, maskB = None, None
429 | if self.args.seg_D_cam_fea_mask:
430 | maskA, maskB = self.maskA, self.maskB
431 |
432 |         # Discriminator outputs
433 | fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A, cam_fake_B2A, maskB)
434 | fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A, cam_fake_B2A, maskB)
435 | fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B, cam_fake_A2B, maskA)
436 | fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B, cam_fake_A2B, maskA)
437 |         # Target constants: the global (GA) and local (LA) discriminators have different output shapes, but their CAM classification outputs share one shape
438 | flag_GA_1 = torch.ones_like(fake_GA_logit, requires_grad=False).to(self.args.device)
439 | flag_LA_1 = torch.ones_like(fake_LA_logit, requires_grad=False).to(self.args.device)
440 | flag_GA_cam_1 = torch.ones_like(fake_GA_cam_logit, requires_grad=False).to(self.args.device)
441 |
442 |         # Adversarial losses
443 | if self.args.seg_D_mask:
444 |             # The background needs no adversarial loss: set it to the target value, which is equivalent to zeroing the loss there
445 | maskB_G = F.interpolate(self.maskB, fake_GA_logit.shape[2:], mode='area')
446 | maskB_L = F.interpolate(self.maskB, fake_LA_logit.shape[2:], mode='area')
447 | maskA_G = F.interpolate(self.maskA, fake_GB_logit.shape[2:], mode='area')
448 | maskA_L = F.interpolate(self.maskA, fake_LB_logit.shape[2:], mode='area')
449 | fake_GA_logit = fake_GA_logit * maskB_G + flag_GA_1 * (1 - maskB_G)
450 | fake_LA_logit = fake_LA_logit * maskB_L + flag_LA_1 * (1 - maskB_L)
451 | fake_GB_logit = fake_GB_logit * maskA_G + flag_GA_1 * (1 - maskA_G)
452 | fake_LB_logit = fake_LB_logit * maskA_L + flag_LA_1 * (1 - maskA_L)
453 |
454 | if self.args.cam_D_weight > 0:
455 | G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, flag_GA_cam_1)
456 | G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, flag_GA_cam_1)
457 | G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, flag_GA_cam_1)
458 | G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, flag_GA_cam_1)
459 | else:
460 | G_ad_cam_loss_GA, G_ad_cam_loss_LA, G_ad_cam_loss_GB, G_ad_cam_loss_LB = 0, 0, 0, 0
461 | G_ad_loss_GA = self.MSE_loss(fake_GA_logit, flag_GA_1)
462 | G_ad_loss_LA = self.MSE_loss(fake_LA_logit, flag_LA_1)
463 | G_ad_loss_A = (G_ad_loss_GA + G_ad_loss_LA) + (G_ad_cam_loss_GA + G_ad_cam_loss_LA) * self.args.cam_D_weight
464 | G_ad_loss_A = self.args.adv_weight * self.args.forward_adv_weight * G_ad_loss_A
465 |
466 | G_ad_loss_GB = self.MSE_loss(fake_GB_logit, flag_GA_1)
467 | G_ad_loss_LB = self.MSE_loss(fake_LB_logit, flag_LA_1)
468 | G_ad_loss_B = (G_ad_loss_GB + G_ad_loss_LB) + (G_ad_cam_loss_GB + G_ad_cam_loss_LB) * self.args.cam_D_weight
469 | G_ad_loss_B = self.args.adv_weight * self.args.backward_adv_weight * G_ad_loss_B
470 |         # Cycle-consistency losses
471 | G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A) * self.args.cycle_weight
472 | G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B) * self.args.cycle_weight
473 |         # Identity-mapping losses
474 | G_identity_loss_A = self.L1_loss(fake_A2A, real_A) * self.args.identity_weight
475 | G_identity_loss_B = self.L1_loss(fake_B2B, real_B) * self.args.identity_weight
476 |         # CAM losses of the generators
477 | flag_cam_1 = torch.ones_like(fake_B2A_cam_logit, requires_grad=False).to(self.args.device)
478 | flag_cam_0 = torch.zeros_like(fake_A2A_cam_logit, requires_grad=False).to(self.args.device)
479 | G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, flag_cam_1) + self.BCE_loss(fake_A2A_cam_logit, flag_cam_0)
480 | G_cam_loss_A *= self.args.cam_weight
481 | flag_cam_1 = torch.ones_like(fake_A2B_cam_logit, requires_grad=False).to(self.args.device)
482 | flag_cam_0 = torch.zeros_like(fake_B2B_cam_logit, requires_grad=False).to(self.args.device)
483 | G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, flag_cam_1) + self.BCE_loss(fake_B2B_cam_logit, flag_cam_0)
484 | G_cam_loss_B *= self.args.cam_weight
485 |
486 | G_loss_A = G_ad_loss_A + G_recon_loss_A + G_identity_loss_A + G_cam_loss_A
487 | G_loss_B = G_ad_loss_B + G_recon_loss_B + G_identity_loss_B + G_cam_loss_B
488 |
489 | self.G_adv_loss = G_ad_loss_A + G_ad_loss_B
490 | self.G_cyc_loss = G_recon_loss_A + G_recon_loss_B
491 | self.G_idt_loss = G_identity_loss_A + G_identity_loss_B
492 | self.G_cam_loss = G_cam_loss_A + G_cam_loss_B
493 |
494 | self.Generator_loss += (G_loss_A + G_loss_B)
495 |
496 | # backward
497 | self.Generator_loss.backward()
498 | return (G_ad_loss_A, G_recon_loss_A, G_identity_loss_A, G_cam_loss_A,
499 | G_ad_loss_B, G_recon_loss_B, G_identity_loss_B, G_cam_loss_B)
500 |
501 | def train(self):
502 | train_writer = SummaryWriter(os.path.join(self.args.result_dir, 'logs'))
503 | D_losses_A = AverageMeter('D_losses_A', ':.4e')
504 | D_losses_B = AverageMeter('D_losses_B', ':.4e')
505 | Discriminator_losses = AverageMeter('Discriminator_losses', ':.4e')
506 | G_ad_losses_A = AverageMeter('G_ad_losses_A', ':.4e')
507 | G_recon_losses_A = AverageMeter('G_recon_losses_A', ':.4e')
508 | G_identity_losses_A = AverageMeter('G_identity_losses_A', ':.4e')
509 | G_cam_losses_A = AverageMeter('G_cam_losses_A', ':.4e')
510 | G_ad_losses_B = AverageMeter('G_ad_losses_B', ':.4e')
511 | G_recon_losses_B = AverageMeter('G_recon_losses_B', ':.4e')
512 | G_identity_losses_B = AverageMeter('G_identity_losses_B', ':.4e')
513 | G_cam_losses_B = AverageMeter('G_cam_losses_B', ':.4e')
514 | Generator_losses = AverageMeter('Generator_losses', ':.4e')
515 | train_progress = ProgressMeter(self.args.iteration, D_losses_A, D_losses_B, Discriminator_losses,
516 | G_ad_losses_A, G_recon_losses_A, G_identity_losses_A, G_cam_losses_A,
517 | G_ad_losses_B, G_recon_losses_B, G_identity_losses_B, G_cam_losses_B,
518 | Generator_losses, prefix=f"Iteration: ")
519 |
520 |         # Linear learning-rate decay schedule
521 | start_iter = 1
522 | mid_iter = self.args.iteration // 2
523 | lr_rate = self.args.lr / mid_iter
524 | if self.args.resume:
525 | model_list = glob(os.path.join(self.args.result_dir, self.args.dataset, 'model', '*.pt'))
526 | if not len(model_list) == 0:
527 | model_list.sort()
528 | start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
529 | self.load(os.path.join(self.args.result_dir, self.args.dataset, 'model'), start_iter)
530 | print(" [*] Load SUCCESS")
531 | if not self.args.no_decay_flag and start_iter > mid_iter:
532 | self.G_optim.param_groups[0]['lr'] -= lr_rate * (start_iter - mid_iter)
533 |                 self.D_optim.param_groups[0]['lr'] -= lr_rate * (start_iter - mid_iter)
534 |
535 | # training loop
536 | print('training start !')
537 | start_time = time.time()
538 | for step in range(start_iter, self.args.iteration + 1):
539 | if not self.args.no_decay_flag and step > mid_iter:
540 | self.G_optim.param_groups[0]['lr'] -= lr_rate
541 | self.D_optim.param_groups[0]['lr'] -= lr_rate
542 |
543 | real_A, real_B, blur = self.get_batch(mode='train')
544 |
545 | self.gen_train(True)
546 |
547 | (fake_A2B, fake_A2B_cam_logit, _, _,
548 | fake_A2B2A, _, _,
549 | fake_B2A, fake_B2A_cam_logit, _, _,
550 | fake_B2A2B, _, _,
551 | fake_A2A, fake_A2A_cam_logit, _, _,
552 | fake_B2B, fake_B2B_cam_logit, _, _) = self.forward(real_A, real_B)
553 |
554 | # Update D
555 | self.dis_train(True)
556 | self.D_optim.zero_grad()
557 | D_loss_A, D_loss_B = self.backward_D(real_A, real_B, fake_A2B, fake_B2A, blur)
558 | self.D_optim.step()
559 |             # Update the running loss statistics
560 | D_losses_A.update(D_loss_A.detach().cpu().item())
561 | D_losses_B.update(D_loss_B.detach().cpu().item())
562 | Discriminator_losses.update(self.discriminator_loss.detach().cpu().item())
563 |
564 | # Update G
565 | self.dis_train(False)
566 | self.G_optim.zero_grad()
567 | (G_ad_loss_A, G_recon_loss_A, G_identity_loss_A, G_cam_loss_A,
568 | G_ad_loss_B, G_recon_loss_B, G_identity_loss_B, G_cam_loss_B) = \
569 | self.backward_G(real_A, real_B, fake_A2B, fake_B2A, fake_A2B2A, fake_B2A2B, fake_A2A, fake_B2B,
570 | fake_A2B_cam_logit, fake_B2A_cam_logit, fake_A2A_cam_logit, fake_B2B_cam_logit)
571 | self.G_optim.step()
572 | self.gen_train(False)
573 |
574 | # clip parameter of AdaLIN and LIN, applied after optimizer step
575 | self.genA2B.apply(self.Rho_LIN_clipper)
576 | self.genB2A.apply(self.Rho_LIN_clipper)
577 | self.genA2B.apply(self.Rho_AdaLIN_clipper)
578 | self.genB2A.apply(self.Rho_AdaLIN_clipper)
579 | self.model_ema(step, self.genA2B_ema, self.genA2B)
580 | self.model_ema(step, self.genB2A_ema, self.genB2A)
581 |
582 |             # Print the losses for every step
583 | info = f'[{step:5d}/{self.args.iteration:5d}] time: {(time.time() - start_time):4.4f} ' \
584 | f'd_loss: {self.discriminator_loss:.8f}, g_loss: {self.Generator_loss:.8f}, ' \
585 | f'g_adv: {self.G_adv_loss:.8f}, g_cyc: {self.G_cyc_loss:.8f}, ' \
586 | f'g_idt: {self.G_idt_loss:.8f}, g_cam: {self.G_cam_loss:.8f}'
587 | if self.args.seg_fix_weight > 0:
588 | info += f', g_seg: {self.G_seg_loss:.8f}'
589 | if self.args.tv_loss:
590 | info += f', g_tv: {self.tv_loss:.8f}'
591 | print(info)
592 |
593 |             # Update the running loss statistics
594 | G_ad_losses_A.update(G_ad_loss_A.detach().cpu().item(), real_A.size(0))
595 | G_recon_losses_A.update(G_recon_loss_A.detach().cpu().item(), real_A.size(0))
596 | G_identity_losses_A.update(G_identity_loss_A.detach().cpu().item(), real_A.size(0))
597 | G_cam_losses_A.update(G_cam_loss_A.detach().cpu().item(), real_A.size(0))
598 | G_ad_losses_B.update(G_ad_loss_B.detach().cpu().item(), real_B.size(0))
599 | G_recon_losses_B.update(G_recon_loss_B.detach().cpu().item(), real_B.size(0))
600 | G_identity_losses_B.update(G_identity_loss_B.detach().cpu().item(), real_B.size(0))
601 | G_cam_losses_B.update(G_cam_loss_B.detach().cpu().item(), real_B.size(0))
602 | Generator_losses.update(self.Generator_loss.detach().cpu().item(), real_A.size(0))
603 |
604 |             # Visualize intermediate results, compute FID, and log statistics to TensorBoard
605 | if step % self.args.print_freq == 0:
606 |                 # Visualize intermediate results
607 | self.vis_inference_result(step, train_sample_num=5, test_sample_num=5)
608 | if step > self.args.ema_start * self.args.iteration:
609 | temp = self.genA2B, self.genB2A
610 | self.genA2B, self.genB2A = self.genA2B_ema, self.genB2A_ema
611 | self.vis_inference_result(step, train_sample_num=5, test_sample_num=5, name='_ema')
612 | self.genA2B, self.genB2A = temp
613 |                 # Compute FID
614 | if step % self.args.calc_fid_freq == 0:
615 | temp_prefix = train_progress.prefix
616 | fid_score_A2B, fid_score_B2A = self.calc_fid_score()
617 | train_writer.add_scalar('13_fid_score_A2B', fid_score_A2B, step)
618 | train_writer.add_scalar('13_fid_score_B2A', fid_score_B2A, step)
619 | train_progress.prefix = f"Iteration: fid: A2B {fid_score_A2B:.4e}, B2A {fid_score_B2A:.4e}"
620 | if step > self.args.ema_start * self.args.iteration:
621 | temp = self.genA2B, self.genB2A
622 | self.genA2B, self.genB2A = self.genA2B_ema, self.genB2A_ema
623 | fid_score_A2B, fid_score_B2A = self.calc_fid_score()
624 | self.genA2B, self.genB2A = temp
625 | train_writer.add_scalar('14_fid_score_A2B_ema', fid_score_A2B, step)
626 | train_writer.add_scalar('14_fid_score_B2A_ema', fid_score_B2A, step)
627 | train_progress.prefix += f" A2B_ema {fid_score_A2B:.4e}, B2A_ema {fid_score_B2A:.4e}"
628 | train_progress.print(step)
629 | train_progress.prefix = temp_prefix
630 | else:
631 | train_progress.print(step)
632 |
633 |                 # Log the accumulated statistics to TensorBoard, then reset the meters
634 | train_writer.add_scalar('01_D_losses_A', D_losses_A.avg, step)
635 | train_writer.add_scalar('02_D_losses_B', D_losses_B.avg, step)
636 | train_writer.add_scalar('03_Discriminator_losses', Discriminator_losses.avg, step)
637 | train_writer.add_scalar('04_G_ad_losses_A', G_ad_losses_A.avg, step)
638 | train_writer.add_scalar('05_G_recon_losses_A', G_recon_losses_A.avg, step)
639 | train_writer.add_scalar('06_G_identity_losses_A', G_identity_losses_A.avg, step)
640 | train_writer.add_scalar('07_G_cam_losses_A', G_cam_losses_A.avg, step)
641 | train_writer.add_scalar('08_G_ad_losses_B', G_ad_losses_B.avg, step)
642 | train_writer.add_scalar('09_G_recon_losses_B', G_recon_losses_B.avg, step)
643 | train_writer.add_scalar('10_G_identity_losses_B', G_identity_losses_B.avg, step)
644 | train_writer.add_scalar('11_G_cam_losses_B', G_cam_losses_B.avg, step)
645 | train_writer.add_scalar('12_Generator_losses', Generator_losses.avg, step)
646 | train_writer.add_scalar('Learning rate', self.G_optim.param_groups[0]['lr'], step)
647 | train_writer.flush()
648 | D_losses_A.reset(), D_losses_B.reset(), Discriminator_losses.reset()
649 | G_ad_losses_A.reset(), G_recon_losses_A.reset(), G_identity_losses_A.reset()
650 | G_cam_losses_A.reset(), G_ad_losses_B.reset(), G_recon_losses_B.reset()
651 | G_identity_losses_B.reset(), G_cam_losses_B.reset(), Generator_losses.reset()
652 |
653 | if step % self.args.save_freq == 0 or step == self.args.iteration:
654 | self.save(os.path.join(self.args.result_dir, self.args.dataset, 'model'), step)
655 |
656 | # if step % 1000 == 0:
657 | # self.save(self.args.result_dir, step=None, name='_params_latest.pt')
658 | train_writer.close()
659 |
660 | def calc_fid_score(self):
661 | self.gen_train(False)
662 | if self.fid_score is None:
663 | self.fid_score = FIDScore(self.args.device, batch_size=self.args.fid_batch, num_workers=1)
664 | self.mean_std_A = self.fid_score.calc_mean_std(self.trainA_data_root)
665 | self.mean_std_B = self.fid_score.calc_mean_std(self.trainB_data_root)
666 | self.fid_loaderA = get_loader(DatasetFolder(self.trainA_data_root, self.test_transform), self.args.device,
667 | batch_size=self.args.fid_batch, shuffle=False,
668 | num_workers=self.args.num_workers)
669 | self.fid_loaderB = get_loader(DatasetFolder(self.trainB_data_root, self.test_transform), self.args.device,
670 | batch_size=self.args.fid_batch, shuffle=False,
671 | num_workers=self.args.num_workers)
672 | self.fid_score.inception_model.normalize_input = False
673 | mean_std_A2B = self.fid_score.calc_mean_std_with_gen(lambda batch: self.genA2B(batch)[0].detach(),
674 | self.fid_loaderA)
675 | fid_score_A2B = self.fid_score.calc_fid(self.mean_std_B, mean_std_A2B)
676 | mean_std_B2A = self.fid_score.calc_mean_std_with_gen(lambda batch: self.genB2A(batch)[0].detach(),
677 | self.fid_loaderB)
678 | fid_score_B2A = self.fid_score.calc_fid(self.mean_std_A, mean_std_B2A)
679 | return fid_score_A2B, fid_score_B2A
680 |
681 | def vis_inference_result(self, step, train_sample_num=5, test_sample_num=5, name=''):
682 | A2B = np.zeros((self.args.img_size * (9 + self.args.attention_gan), 0, 3))
683 | B2A = np.zeros((self.args.img_size * (9 + self.args.attention_gan), 0, 3))
684 | self.gen_train(False), self.dis_train(False)
685 | for _ in range(train_sample_num):
686 | real_A, real_B, _ = self.get_batch(mode='train')
687 | with torch.no_grad():
688 | (fake_A2B, _, fake_A2B_heatmap, fake_A2B_attention,
689 | fake_A2B2A, fake_A2B2A_heatmap, fake_A2B2A_attention,
690 | fake_B2A, _, fake_B2A_heatmap, fake_B2A_attention,
691 | fake_B2A2B, fake_B2A2B_heatmap, fake_B2A2B_attention,
692 | fake_A2A, _, fake_A2A_heatmap, fake_A2A_attention,
693 | fake_B2B, _, fake_B2B_heatmap, fake_B2B_attention) = \
694 | self.forward(real_A, real_B)
695 |
696 | cam_fake_A2B, cam_fake_B2A = None, None
697 | if self.args.seg_D_cam_inp_mask:
698 | cam_fake_A2B = self.maskA * fake_A2B
699 | cam_fake_B2A = self.maskB * fake_B2A
700 |
701 | maskA, maskB = None, None
702 | if self.args.seg_D_cam_fea_mask:
703 | maskA = self.maskA
704 | maskB = self.maskB
705 |
706 | _, _, fake_disGB_cam_hm = self.disGB(fake_A2B, cam_fake_A2B, maskA)
707 | _, _, fake_disLB_cam_hm = self.disLB(fake_A2B, cam_fake_A2B, maskA)
708 | _, _, fake_disGA_cam_hm = self.disGA(fake_B2A, cam_fake_B2A, maskB)
709 | _, _, fake_disLA_cam_hm = self.disLA(fake_B2A, cam_fake_B2A, maskB)
710 | A2B_list = [RGB2BGR(tensor2numpy(denorm(real_A[0]))),
711 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.args.img_size),
712 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
713 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.args.img_size),
714 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
715 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.args.img_size),
716 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0]))),
717 | cam(tensor2numpy(fake_disGB_cam_hm[0]), self.args.img_size),
718 | cam(tensor2numpy(fake_disLB_cam_hm[0]), self.args.img_size),
719 | ]
720 | if self.args.attention_gan > 0:
721 | for i in range(self.args.attention_gan):
722 | A2B_list.append(attention_mask(tensor2numpy(fake_A2B_attention[0][i:(i + 1)]),
723 | self.args.img_size))
724 | A2B = np.concatenate((A2B, np.concatenate(A2B_list, 0)), 1)
725 |
726 | B2A_list = [RGB2BGR(tensor2numpy(denorm(real_B[0]))),
727 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.args.img_size),
728 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
729 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.args.img_size),
730 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
731 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.args.img_size),
732 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0]))),
733 | cam(tensor2numpy(fake_disGA_cam_hm[0]), self.args.img_size),
734 | cam(tensor2numpy(fake_disLA_cam_hm[0]), self.args.img_size),
735 | ]
736 | if self.args.attention_gan > 0:
737 | for i in range(self.args.attention_gan):
738 | B2A_list.append(attention_mask(tensor2numpy(fake_B2A_attention[0][i:(i + 1)]),
739 | self.args.img_size))
740 | B2A = np.concatenate((B2A, np.concatenate(B2A_list, 0)), 1)
741 |
742 | for _ in range(test_sample_num):
743 | real_A, real_B, _ = self.get_batch(mode='test')
744 | with torch.no_grad():
745 | (fake_A2B, _, fake_A2B_heatmap, fake_A2B_attention,
746 | fake_A2B2A, fake_A2B2A_heatmap, fake_A2B2A_attention,
747 | fake_B2A, _, fake_B2A_heatmap, fake_B2A_attention,
748 | fake_B2A2B, fake_B2A2B_heatmap, fake_B2A2B_attention,
749 | fake_A2A, _, fake_A2A_heatmap, fake_A2A_attention,
750 | fake_B2B, _, fake_B2B_heatmap, fake_B2B_attention) = \
751 | self.forward(real_A, real_B)
752 |
753 | cam_fake_A2B, cam_fake_B2A = None, None
754 | if self.args.seg_D_cam_inp_mask:
755 | cam_fake_A2B = self.maskA * fake_A2B
756 | cam_fake_B2A = self.maskB * fake_B2A
757 |
758 | maskA, maskB = None, None
759 | if self.args.seg_D_cam_fea_mask:
760 | maskA = self.maskA
761 | maskB = self.maskB
762 |
763 | _, _, fake_disGB_cam_hm = self.disGB(fake_A2B, cam_fake_A2B, maskA)
764 | _, _, fake_disLB_cam_hm = self.disLB(fake_A2B, cam_fake_A2B, maskA)
765 | _, _, fake_disGA_cam_hm = self.disGA(fake_B2A, cam_fake_B2A, maskB)
766 | _, _, fake_disLA_cam_hm = self.disLA(fake_B2A, cam_fake_B2A, maskB)
767 | A2B_list = [RGB2BGR(tensor2numpy(denorm(real_A[0]))),
768 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.args.img_size),
769 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
770 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.args.img_size),
771 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
772 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.args.img_size),
773 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0]))),
774 | cam(tensor2numpy(fake_disGB_cam_hm[0]), self.args.img_size),
775 | cam(tensor2numpy(fake_disLB_cam_hm[0]), self.args.img_size),
776 | ]
777 | if self.args.attention_gan > 0:
778 | for i in range(self.args.attention_gan):
779 | A2B_list.append(attention_mask(tensor2numpy(fake_A2B_attention[0][i:(i + 1)]),
780 | self.args.img_size))
781 | A2B = np.concatenate((A2B, np.concatenate(A2B_list, 0)), 1)
782 |
783 | B2A_list = [RGB2BGR(tensor2numpy(denorm(real_B[0]))),
784 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.args.img_size),
785 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
786 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.args.img_size),
787 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
788 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.args.img_size),
789 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0]))),
790 | cam(tensor2numpy(fake_disGA_cam_hm[0]), self.args.img_size),
791 | cam(tensor2numpy(fake_disLA_cam_hm[0]), self.args.img_size),
792 | ]
793 | if self.args.attention_gan > 0:
794 | for i in range(self.args.attention_gan):
795 | B2A_list.append(attention_mask(tensor2numpy(fake_B2A_attention[0][i:(i + 1)]),
796 | self.args.img_size))
797 | B2A = np.concatenate((B2A, np.concatenate(B2A_list, 0)), 1)
798 |
799 | cv2.imwrite(os.path.join(self.args.result_dir, self.args.dataset, 'img', f'A2B{name}_{step:07d}.png'),
800 | A2B * 255.0)
801 | cv2.imwrite(os.path.join(self.args.result_dir, self.args.dataset, 'img', f'B2A{name}_{step:07d}.png'),
802 | B2A * 255.0)
803 | return
804 |
805 | def model_ema(self, step, G_ema, G):
806 | if step > self.args.ema_start * self.args.iteration:
807 | for p_ema, p in zip(G_ema.parameters(), G.parameters()):
808 | p_ema.copy_(p.lerp(p_ema, self.args.ema_beta))
809 | else:
810 | for p_ema, p in zip(G_ema.parameters(), G.parameters()):
811 | p_ema.copy_(p)
812 | for b_ema, b in zip(G_ema.buffers(), G.buffers()):
813 | b_ema.copy_(b)
814 | return
815 |
816 | def save(self, root, step, name=None):
817 | if name is None:
818 | name = '_params_%07d.pt' % step
819 | params = {'genA2B': self.genA2B.state_dict(), 'genB2A': self.genB2A.state_dict(),
820 | 'genA2B_ema': self.genA2B_ema.state_dict(), 'genB2A_ema': self.genB2A_ema.state_dict(),
821 | 'disGA': self.disGA.state_dict(), 'disGB': self.disGB.state_dict(), 'disLA': self.disLA.state_dict(),
822 | 'disLB': self.disLB.state_dict()}
823 | torch.save(params, os.path.join(root, self.args.dataset + name))
824 | g_params = {'genA2B': self.genA2B.state_dict(), 'genA2B_ema': self.genA2B_ema.state_dict()}
825 | torch.save(g_params, os.path.join(root, self.args.dataset + f'_g{name}'))
826 |
827 | def load(self, root, step):
828 | params = torch.load(os.path.join(root, self.args.dataset + '_params_%07d.pt' % step),
829 | map_location=torch.device("cpu"))
830 | self.genA2B.load_state_dict(params['genA2B'])
831 | self.genB2A.load_state_dict(params['genB2A'])
832 | self.genA2B_ema.load_state_dict(params['genA2B_ema'])
833 | self.genB2A_ema.load_state_dict(params['genB2A_ema'])
834 | self.disGA.load_state_dict(params['disGA'])
835 | self.disGB.load_state_dict(params['disGB'])
836 | self.disLA.load_state_dict(params['disLA'])
837 | self.disLB.load_state_dict(params['disLB'])
838 |
839 | def test(self):
840 | model_list = glob(os.path.join(self.args.result_dir, '*_params_latest.pt'))
841 | if len(model_list) == 0:
842 | model_list = glob(os.path.join(self.args.result_dir, self.args.dataset, 'model', '*.pt'))
843 | if len(model_list) != 0:
844 | model_list.sort()
845 | if not (self.args.generator_model and os.path.isfile(self.args.generator_model)):
846 | self.args.generator_model = model_list[-1]
847 |
848 | if self.args.generator_model and os.path.isfile(self.args.generator_model):
849 | params = torch.load(self.args.generator_model, map_location=torch.device("cpu"))
850 | self.genA2B.load_state_dict(params['genA2B_ema'])
851 | self.genB2A.load_state_dict(params['genB2A_ema'])
852 | print(" [*] Load SUCCESS")
853 | else:
854 | print(" [*] Load FAILURE")
855 | return
856 |
857 | self.genA2B.eval(), self.genB2A.eval()
858 | for n, real_A in tqdm(enumerate(self.testA_loader)):
859 | real_A = real_A.to(self.args.device)
860 | with torch.no_grad():
861 | fake_A2B, _, fake_A2B_heatmap, fake_A2B_attention = self.genA2B(real_A)
862 | fake_A2B2A, _, fake_A2B2A_heatmap, fake_A2B2A_attention = self.genB2A(fake_A2B)
863 | fake_A2A, _, fake_A2A_heatmap, fake_A2A_attention = self.genB2A(real_A)
864 |
865 | A2B_list = [RGB2BGR(tensor2numpy(denorm(real_A[0]))),
866 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.args.img_size),
867 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
868 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.args.img_size),
869 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
870 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.args.img_size),
871 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))
872 | ]
873 | if self.args.attention_gan > 0:
874 | for i in range(self.args.attention_gan):
875 | A2B_list.append(attention_mask(tensor2numpy(fake_A2B_attention[0][i:(i + 1)]), self.args.img_size))
876 | A2B = np.concatenate(A2B_list, 0)
877 | cv2.imwrite(os.path.join(self.args.result_dir, self.args.dataset, 'test', 'A2B_%d.png' % (n + 1)),
878 | A2B * 255.0)
879 |
880 | for n, real_B in tqdm(enumerate(self.testB_loader)):
881 | real_B = real_B.to(self.args.device)
882 | with torch.no_grad():
883 | fake_B2A, _, fake_B2A_heatmap, fake_B2A_attention = self.genB2A(real_B)
884 | fake_B2A2B, _, fake_B2A2B_heatmap, fake_B2A2B_attention = self.genA2B(fake_B2A)
885 | fake_B2B, _, fake_B2B_heatmap, fake_B2B_attention = self.genA2B(real_B)
886 |
887 | B2A_list = [RGB2BGR(tensor2numpy(denorm(real_B[0]))),
888 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.args.img_size),
889 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
890 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.args.img_size),
891 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
892 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.args.img_size),
893 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))]
894 | if self.args.attention_gan > 0:
895 | for i in range(self.args.attention_gan):
896 | B2A_list.append(attention_mask(tensor2numpy(fake_B2A_attention[0][i:(i + 1)]), self.args.img_size))
897 | B2A = np.concatenate(B2A_list, 0)
898 | cv2.imwrite(os.path.join(self.args.result_dir, self.args.dataset, 'test', 'B2A_%d.png' % (n + 1)),
899 | B2A * 255.0)
900 |
--------------------------------------------------------------------------------
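Editor's note: a small, self-contained sketch (assumptions and toy values, not repository code) illustrating two details of the training loop in UGATIT.py above: (1) the generator EMA update in model_ema, where p.lerp(p_ema, ema_beta) equals ema_beta * p_ema + (1 - ema_beta) * p, and (2) the linear learning-rate decay that starts at the halfway point of training.

import torch

# (1) EMA update on a single tensor
ema_beta = 0.999
p = torch.tensor([1.0, 2.0])          # current generator parameter
p_ema = torch.tensor([0.0, 0.0])      # its exponential moving average
p_ema.copy_(p.lerp(p_ema, ema_beta))  # ema_beta * p_ema + (1 - ema_beta) * p
print(p_ema)                          # tensor([0.0010, 0.0020])

# (2) Linear decay: the lr is constant for the first half of training,
# then decreases by lr / (iteration // 2) per step and reaches ~0 at the end.
lr, iteration = 1e-4, 100             # toy values; the real ones come from args
mid_iter = iteration // 2
lr_rate = lr / mid_iter
cur_lr = lr
for step in range(1, iteration + 1):
    if step > mid_iter:
        cur_lr -= lr_rate
print(f'final lr: {cur_lr:.2e}')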