├── metrics ├── __init__.py └── fid │ ├── __init__.py │ ├── inception.py │ └── fid_score.py ├── videos ├── __init__.py └── test_video.py ├── faceseg ├── __init__.py ├── resnet.py ├── FaceSegmentation.py └── networks_faceseg.py ├── assets ├── kid.png ├── teaser.png ├── ablation.png ├── generator.png ├── user_study.png ├── bg_parsing_boy.png ├── discriminator.png └── face_parsing_boy.png ├── dataset └── YOUR_DATASET_NAME │ ├── testB │ └── 3414.png │ ├── trainB │ └── 0006.png │ ├── testA │ └── female_2321.jpg │ └── trainA │ └── female_222.jpg ├── requirements.txt ├── scripts ├── test.sh └── train.sh ├── LICENSE ├── .gitignore ├── README.md ├── dataset.py ├── utils.py ├── histogram.py ├── main.py ├── networks.py └── UGATIT.py /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .fid import FIDScore 2 | -------------------------------------------------------------------------------- /videos/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Description: 4 | """ -------------------------------------------------------------------------------- /faceseg/__init__.py: -------------------------------------------------------------------------------- 1 | # ref: https://github.com/zllrunning/face-parsing.PyTorch 2 | -------------------------------------------------------------------------------- /assets/kid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/kid.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /assets/ablation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/ablation.png -------------------------------------------------------------------------------- /assets/generator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/generator.png -------------------------------------------------------------------------------- /assets/user_study.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/user_study.png -------------------------------------------------------------------------------- /assets/bg_parsing_boy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/bg_parsing_boy.png -------------------------------------------------------------------------------- /assets/discriminator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/discriminator.png -------------------------------------------------------------------------------- /metrics/fid/__init__.py: -------------------------------------------------------------------------------- 1 | # ref: https://github.com/mseitzer/pytorch-fid 2 | from .fid_score import FIDScore 3 | 
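`metrics/fid` adapts the FID implementation referenced above (mseitzer/pytorch-fid): Inception-v3 features are extracted for the real and the generated image sets, a Gaussian is fitted to each set, and the score is the Fréchet distance between the two Gaussians; a lower score means the two feature distributions are closer. A minimal sketch of that final step is shown below; `frechet_distance` is an illustrative helper, not the `FIDScore` API exposed by this package.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2) fitted to Inception features."""
    diff = mu1 - mu2
    # Matrix square root of the covariance product; retry with a small diagonal offset if it is singular.
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts caused by numerical error
    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))
```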
-------------------------------------------------------------------------------- /assets/face_parsing_boy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/assets/face_parsing_boy.png -------------------------------------------------------------------------------- /dataset/YOUR_DATASET_NAME/testB/3414.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/dataset/YOUR_DATASET_NAME/testB/3414.png -------------------------------------------------------------------------------- /dataset/YOUR_DATASET_NAME/trainB/0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/dataset/YOUR_DATASET_NAME/trainB/0006.png -------------------------------------------------------------------------------- /dataset/YOUR_DATASET_NAME/testA/female_2321.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/dataset/YOUR_DATASET_NAME/testA/female_2321.jpg -------------------------------------------------------------------------------- /dataset/YOUR_DATASET_NAME/trainA/female_222.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zheng-yuwei/enhanced-UGATIT/HEAD/dataset/YOUR_DATASET_NAME/trainA/female_222.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # pip install -r requirements.txt 2 | # python=3.8 3 | torch==1.9.0 4 | torchvision>=0.9.0 5 | tensorboard 6 | opencv-python 7 | scipy 8 | tqdm 9 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 模型测试视频 + attention 模型 + scse 3 | python main_light.py --phase video --generator_model checkpoints/big_normal_100w.pt \ 4 | --use_se --attention_gan 3 --attention_input --device cpu --img_size 384 --light 32 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 郑煜伟 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 原始训练 3 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 4 | 5 | # 直方图匹配 6 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 \ 7 | --match_histograms --match_mode hsl --match_prob 0.5 --match_ratio 1.0 8 | 9 | # 背景不变 10 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 \ 11 | --cam_D_weight -1 --cam_D_attention --seg_fix_weight 100 --seg_D_mask --seg_G_detach 12 | 13 | # se + blur 14 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 --use_se --has_blur 15 | 16 | # attention + 原图 17 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 \ 18 | --attention_gan 3 --attention_input 19 | 20 | # 损失权重调整 21 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 \ 22 | --adv_weight 1.0 --forward_adv_weight 2 --cycle_weight 5 --identity_weight 5 23 | 24 | # cpu调试 25 | python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 --device cpu --num_workers 0 26 | 27 | # 复杂cpu测试:直方图匹配 + 背景不变 + se + blur + attention + 原图 28 | python main.py --dataset YOUR_DATA_SET --result_dir results --img_size 256 --device cpu --num_workers 0 \ 29 | --match_histograms --match_mode hsl --match_prob 0.5 --match_ratio 1.0 \ 30 | --cam_D_weight -1 --cam_D_attention --seg_fix_weight 100 --seg_D_mask --seg_G_detach \ 31 | --use_se --has_blur --attention_gan 3 --attention_input \ 32 | --adv_weight 1.0 --forward_adv_weight 2 --cycle_weight 20 --identity_weight 5 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # my ignore 132 | .idea 133 | .DS_Store 134 | -------------------------------------------------------------------------------- /faceseg/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def 
create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum - 1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | # self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /videos/test_video.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 视频/摄像头测试:人脸关键点检测、矫正、变漫画、贴回原图、保存为视频 4 | """ 5 | import os 6 | from functools import partial 7 | 8 | import cv2 9 | import numpy as np 10 | from PIL import Image 11 | import torch 12 | from torchvision import transforms 13 | 14 | from .face_align.align_utils import detect_face, align_face 15 | 16 | 17 | class VideoTester: 18 | """ GAN测试 """ 19 | 20 | IMAGE_EXT = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 21 | 22 | def __init__(self, args, generator): 23 | self.args = args 24 | self.generator = generator 25 | self.generator.to(args.device) 26 | self.generator.eval() 27 | self.preprocess = transforms.Compose([ 28 | partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB), # 将cv读取的图像转为RGB 29 | Image.fromarray, 30 | transforms.Resize((args.img_size, args.img_size)), 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 33 | partial(torch.unsqueeze, dim=0), 34 | ]) 35 | width = 1280 36 | height = 720 * 2 37 | self.frame_size = (width, height) 38 | self.text = partial(cv2.putText, org=(width-100, 100), fontFace=cv2.FONT_HERSHEY_SIMPLEX, 39 | fontScale=1.2, color=(255, 255, 255), thickness=2) 40 | self.fourcc = cv2.VideoWriter_fourcc(*'XVID') 41 | self.fps = 30 42 | 43 | def generate(self, img): 44 | """ 脸部(0, 255)BGR原图生成(0, 255)BGR动漫图 """ 45 | img = 
self.preprocess(img) 46 | img = img.to(self.args.device) 47 | with torch.no_grad(): 48 | img, _, _, _ = self.generator(img) 49 | img = (img.cpu()[0].numpy().transpose(1, 2, 0) + 1) * 255 / 2 50 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 51 | return img 52 | 53 | def image2image(self, image): 54 | """ 整张图像到图像的转换 """ 55 | # 脸部检测及矫正 56 | facepoints = detect_face(image, detect_scale=0.25, flag_384=self.args.img_size>256) 57 | if facepoints is None: 58 | return image 59 | align_result = align_face(image, facepoints, flag_384=self.args.img_size>256) 60 | if align_result is None: 61 | return image 62 | img_align, face2mean_matrix, mean2face_matrix, _ = align_result 63 | # 生成 64 | img_gen = self.generate(img_align) 65 | # 变换为原来位置 66 | img_gen = np.clip(img_gen, 0, 255).astype(np.uint8) 67 | output = np.ones([self.args.img_size, self.args.img_size, 4], dtype=np.uint8) * 255 68 | output[:, :, :3] = img_gen 69 | output = cv2.warpAffine(output, mean2face_matrix, 70 | (image.shape[1], image.shape[0]), 71 | flags=cv2.INTER_LINEAR, 72 | borderMode=cv2.BORDER_CONSTANT, borderValue=0) 73 | alpha_img_gen = (output[:, :, 3:4] / 255.0) 74 | image_compose = image.astype(np.float32) 75 | image_compose[:, :, 0:3] = (image_compose[:, :, 0:3] * (1.0 - alpha_img_gen) + output[:, :, 0:3]) 76 | image_compose = np.clip(image_compose, 0, 255).astype(np.uint8) 77 | return image_compose 78 | 79 | def record(self, cap: cv2.VideoCapture, writer: cv2.VideoWriter, 80 | frame_num: int): 81 | """ 从cap视频源读取图像,经过图像转换后,写入到writer中 """ 82 | ret, frame = cap.read() 83 | if frame is None: 84 | return -1 85 | frame = frame[:, ::-1, :] # 左右flip 86 | org_frame = frame.copy() 87 | new_frame = self.image2image(frame) 88 | new_frame = np.concatenate([org_frame, new_frame], axis=0) 89 | self.text(new_frame, str(frame_num)) 90 | writer.write(new_frame) 91 | frame_num += 1 92 | return frame_num 93 | 94 | def video(self, path): 95 | """ 视频测试 """ 96 | video_name = os.path.basename(self.args.generator_model) 97 | video_name = os.path.splitext(os.path.basename(path))[0] + '_' + video_name 98 | new_record = cv2.VideoWriter(f'{os.path.splitext(video_name)[0]}_cartoon.avi', 99 | self.fourcc, self.fps, self.frame_size) 100 | cap = cv2.VideoCapture(path) # 打开视频 101 | frame_num = 0 102 | while cap.isOpened(): 103 | frame_num = self.record(cap, new_record, frame_num) 104 | if frame_num < 0: 105 | break 106 | cap.release(), new_record.release() 107 | 108 | def camera(self): 109 | """ 摄像头测试 """ 110 | video_name = os.path.basename(self.args.generator_model) 111 | cap = cv2.VideoCapture(0) # 打开摄像头 112 | ret = True 113 | while ret: 114 | ret, frame = cap.read() 115 | frame = self.image2image(frame) 116 | cv2.imshow(video_name, frame) 117 | cv2.waitKey(1) 118 | if 0xFF == ord('q'): 119 | break 120 | cap.release() 121 | cv2.destroyAllWindows() 122 | 123 | def image_dir(self, img_dir): 124 | """ 图像文件夹测试 """ 125 | img_paths = [f for f in os.listdir(img_dir) if os.path.splitext(f)[-1].lower() in self.IMAGE_EXT] 126 | save_dir = os.path.join(img_dir, '..', 'gan_result') 127 | os.makedirs(save_dir, exist_ok=False) 128 | for img_path in img_paths: 129 | image = cv2.imread(os.path.join(img_dir, img_path)) 130 | image_gan = self.image2image(image) 131 | cv2.imwrite(os.path.join(save_dir, img_path), image_gan) 132 | -------------------------------------------------------------------------------- /faceseg/FaceSegmentation.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | import os 3 | import 
functools 4 | 5 | import cv2 6 | import torch 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | 11 | from faceseg.networks_faceseg import BiSeNet 12 | 13 | 14 | class FaceSegmentation: 15 | part_map = {0: 'background', 1: 'skin', 2: 'l_brow', 3: 'r_brow', 4: 'l_eye', 5: 'r_eye', 6: 'eye_g', 7: 'l_ear', 16 | 8: 'r_ear', 9: 'ear_r', 10: 'nose', 11: 'mouth', 12: 'u_lip', 13: 'l_lip', 14: 'neck', 15: 'neck_l', 17 | 16: 'cloth', 17: 'hair', 18: 'hat'} 18 | # 不同部分的颜色 19 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 0, 85], [255, 0, 170], 20 | [0, 255, 0], [85, 255, 0], [170, 255, 0], [0, 255, 85], [0, 255, 170], 21 | [0, 0, 255], [85, 0, 255], [170, 0, 255], [0, 85, 255], [0, 170, 255], 22 | [255, 255, 0], [255, 255, 85], [255, 255, 170], [255, 0, 255], [255, 85, 255], 23 | [255, 170, 255], [0, 255, 255], [85, 255, 255], [170, 255, 255]] 24 | 25 | def __init__(self, device): 26 | self.device = device 27 | self.n_classes = len(FaceSegmentation.part_map.keys()) 28 | # 膨胀、腐蚀、闭运算、开运算 29 | self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) 30 | # 模型加载 31 | net = BiSeNet(n_classes=self.n_classes).to(self.device) 32 | net.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), '79999_iter.pth'), 33 | map_location=torch.device('cpu'))) 34 | net.eval() 35 | self.net = net 36 | self.mean = torch.as_tensor((0.485, 0.456, 0.406), dtype=torch.float32, device=self.device).view(-1, 1, 1) 37 | self.std = torch.as_tensor((0.229, 0.224, 0.225), dtype=torch.float32, device=self.device).view(-1, 1, 1) 38 | # 预处理 39 | self.preprocessor = transforms.Compose([ 40 | lambda x: x * 0.5 + 0.5, 41 | functools.partial(F.interpolate, size=(512, 512), mode='bilinear', align_corners=True), 42 | lambda x: x.sub_(self.mean).div_(self.std), 43 | ]) 44 | 45 | def face_segmentation(self, input_x): 46 | """ 人脸分割 47 | :param input_x: NCHW 标准化的tensor, (image / 255 - 0.5) * 2 48 | :return mask N1HW tensor,每一个像素点位置取值 0~18 int数,表示属于哪一类 49 | """ 50 | with torch.no_grad(): 51 | img = self.preprocessor(input_x) 52 | img = img.to(self.device) 53 | out = self.net(img)[0] 54 | out = F.interpolate(out, input_x.shape[2:], mode='bicubic', align_corners=True) 55 | mask = out.detach().softmax(axis=1) 56 | return mask 57 | 58 | def gen_mask(self, mask_tensor, is_soft_edge=True, 59 | normal_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17), dilate_parts=(), erode_parts=()): 60 | """ 根据指定的parts生成mask numpy数组 61 | :param mask_tensor: face parsing 分割出来的不同类别的mask 62 | :param is_soft_edge: mask是否有软边界 63 | :param normal_parts: 不做操作的 前景类别列表 64 | :param dilate_parts: 需要做膨胀操作的 前景类别列表 65 | :param erode_parts: 需要做腐蚀操作的 前景类别列表 66 | :return mask: 前景mask 67 | """ 68 | mask_tensor = mask_tensor.cpu().numpy() 69 | N, C, H, W = mask_tensor.shape 70 | normal_mask = np.zeros((N, 1, H, W), dtype=np.float32) 71 | dilate_mask = np.zeros((N, 1, H, W), dtype=np.float32) 72 | erode_mask = np.zeros((N, 1, H, W), dtype=np.float32) 73 | for i in normal_parts: 74 | normal_mask += mask_tensor[:, i, :, :] 75 | for i in dilate_parts: 76 | dilate_mask += mask_tensor[:, i, :, :] 77 | for i in erode_parts: 78 | erode_mask += mask_tensor[:, i, :, :] 79 | # 闭运算 + 膨胀 80 | for i in range(len(dilate_mask)): 81 | dilate_mask[i, 0] = cv2.morphologyEx(dilate_mask[i, 0], cv2.MORPH_CLOSE, self.kernel) 82 | dilate_mask[i, 0] = cv2.dilate(dilate_mask[i, 0], self.kernel, iterations=1) 83 | # 闭运算 + 腐蚀 84 | for i in range(len(erode_mask)): 85 | erode_mask[i, 0] = cv2.morphologyEx(erode_mask[i, 0], 
cv2.MORPH_CLOSE, self.kernel) 86 | erode_mask[i, 0] = cv2.erode(erode_mask[i, 0], self.kernel, iterations=1) 87 | mask = np.maximum(normal_mask, dilate_mask) # 三个区域的交集 88 | mask = np.maximum(mask, erode_mask) 89 | if is_soft_edge: 90 | mask = ((0.7 > mask) & (mask > 0.3)) * mask + (mask > 0.7) # 概率很高/低的区域更加hard,留下过渡区域 91 | else: 92 | mask = (mask >= 0.5).astype(np.float32) 93 | mask = torch.from_numpy(mask).to(self.device) 94 | return mask 95 | 96 | def vis(self, image, mask, is_show=False): 97 | """ 可视化人脸分割结果,如果mask包含不同类别,会将不同类别用不同的颜色mask可视化,如果只有一种类别,会将前景用蓝色mask可视化 98 | :param image: 待可视化图像 99 | :param mask: 待可视化图像分割后的mask 100 | :param is_show:是否用cv2.imshow可视化 101 | :return 可视化的图像 102 | """ 103 | mask = mask.cpu().numpy() 104 | vis_im = (image.numpy().transpose(1, 2, 0) * 127.5 + 127.5).astype(np.uint8) 105 | vis_mask_color = np.zeros((mask.shape[0], mask.shape[1], 3)) + 255 106 | 107 | if mask.dtype == np.int64: 108 | for pi in range(1, self.n_classes): 109 | index = np.where(mask == pi) 110 | vis_mask_color[index[0], index[1], :] = self.part_colors[pi] 111 | else: 112 | mask = np.repeat(np.expand_dims(mask, axis=-1), repeats=3, axis=-1) 113 | vis_mask_color = np.array([[self.part_colors[1]]], dtype=np.float32) * mask + (1 - mask) * vis_mask_color 114 | 115 | vis_mask_color = vis_mask_color.astype(np.uint8) 116 | vis_im_hm = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.5, vis_mask_color, 0.5, 0) 117 | if is_show: 118 | cv2.imshow('seg', np.concatenate([vis_im_hm, cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR)], axis=1)) 119 | cv2.waitKey(0) 120 | return np.concatenate([vis_im_hm, cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR)], axis=1) 121 | 122 | 123 | if __name__ == '__main__': 124 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 125 | img_dir = '../dataset/cartoon/testA' 126 | face_seg = FaceSegmentation('cpu') 127 | train_transform = transforms.Compose([ 128 | cv2.imread, 129 | functools.partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB), 130 | transforms.ToTensor(), 131 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 132 | functools.partial(torch.unsqueeze, dim=0) 133 | ]) 134 | for f_name in sorted(os.listdir(img_dir)): 135 | test_img = train_transform(os.path.join(img_dir, f_name)) 136 | test_mask = face_seg.face_segmentation(test_img) 137 | # 注释这三句,使用下面一句,可以看到所有类别 138 | test_mask0 = face_seg.gen_mask(test_mask, normal_parts=(), dilate_parts=(), 139 | erode_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17)) 140 | test_mask1 = face_seg.gen_mask(test_mask, normal_parts=(1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17), 141 | dilate_parts=(), erode_parts=(6, )) 142 | test_mask = test_mask0 * test_mask1 143 | # test_mask = test_mask.argmax(1, keepdims=True) 144 | vis_img = face_seg.vis(test_img[0], test_mask[0, 0], is_show=True) 145 | cv2.imwrite(f_name, vis_img) 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # enhanced-U-GAT-IT 2 | 3 | 知乎上关于论文的解读blog: 4 | - [论文阅读 | 图像转换(五) AttentionGAN](https://zhuanlan.zhihu.com/p/168382844) 5 | - [论文阅读 | 图像转换(六) U-GAT-IT](https://zhuanlan.zhihu.com/p/270958248) 6 | 7 | 专栏内及其他专栏有更多内容,欢迎进行技术讨论~ 8 | 9 | ## Usage 10 | 11 | ```shell 12 | CUDA_VISIBLE_DEVICES=2 python main.py --dataset YOUR_DATASET_NAME --result_dir YOUR_RESULT_DIR \ 13 | # 适当减小light可降低显存的使用 14 | --light 32 \ 15 | --img_size 384 --aug_prob 0.2 --device cuda:0 --num_worker 1 --print_freq 1000 --calc_fid_freq 10000 --save_freq 100000 \ 16 | # 
适当提高forward_adv_weight系数可适当增强 A2B的风格化程度 17 | --adv_weight 1.0 --forward_adv_weight 1 --cycle_weight 10 --identity_weight 10 \ 18 | # ema大幅提高模型稳定性和fid 19 | --ema_start 0.5 --ema_beta 0.9999 \ 20 | # 固定背景参数 21 | # 去除判别器中的CAM和logit loss 22 | --cam_D_weight -1 --cam_D_attention \ 23 | # 使用faceseg分割出背景做L1损失 24 | --seg_fix_weight 100 \ 25 | # 背景的判别器损失丢弃mask区域的损失 26 | --seg_D_mask \ 27 | # 背景detach 28 | --seg_G_detach \ 29 | # D判别器可以设置较为local 30 | --n_global_dis 7 --n_local_dis 5 31 | --has_blur --use_se 32 | # 如果attention_input想配合背景固定一起用的话,建议加大对抗loss的权重 adv_weight -> 10 甚至更大 33 | --attention_gan 2 --attention_input 34 | ``` 35 | 36 | 本地调试的话:`--device cuda:0` -> `--device cpu`, 37 | 38 | ## 基于官方U-GAT-IT[1]做改进: 39 | 40 | - 代码结构优化: 41 | - 将`UGATIT.py`中的`train`函数的部分代码模块化为函数(`get_batch, forward, backward_G, backward_D`); 42 | - 增加训练时损失函数的打印,tensorboard记录; 43 | - 优化了官方代码中一次多余的前向推理,约节省20%的训练时间。 44 | 45 | - 功能增强(可选是否开启): 46 | - `--ema_start`: 开始做模型ema的迭代次数的比例数,也就是`--iteration 100 --ema_start 0.7`表示 `100 * 0.7 = 70` 个迭代后开始做模型ema; 47 | - `--ema_beta`: 模型ema的滑动平均系数; 48 | - `--calc_fid_freq`: 计算fid score的频率,应该设置为`--print_freq`的整数倍; 49 | - `--fid_batch`: 每次计算fid score时,推理时使用的`batch size`; 50 | - `--adv_weight, forward_adv_weight, backward_adv_weight`: 将对抗损失项权重分成3个(原始的对抗损失项权重,A2B对抗损失项的权重,B2A的对抗损失项权重); 51 | - `--aug_prob`: 数据增强`RandomResizedCrop`的概率; 52 | - `--match_histograms`:是否将两个域进行直方图匹配,使两个域的分布一致;[2] 53 | - `--match_mode`:直方图匹配时的匹配模式,(hsv, hsl, rgb)中的一个; 54 | - `--match_prob`:将域B的图像往域A进行直方图匹配的概率,否则将域A的图像往域B直方图匹配; 55 | - `--match_ratio`:直方图匹配的比例,匹配图 * ratio + 原图 * (1 - ratio); 56 | - `--sn`: 判别器中,卷积操作后是否进行谱归一化,默认使用,tf官方代码有,pt官方代码没有; 57 | - `--has_blur`: 判别网络训练/模型更新时,损失项中增加模糊对抗损失项,增强D对模糊图片的判别;[3] 58 | - `--use_se`: 在生成网络中,是否给每个`ResnetBlock, ResnetAdaILNBlock`增加`scse block`;[4] 59 | - `--attention_gan`: 生成网络中,是否使用`attention mask`机制(在上采样前多一个分支,用于生成mask);[5] 60 | - `--use_deconv`: 生成网络中,上采样是否使用反卷积`nn.ConvTranspose2d`,默认使用插值+卷积方式; 61 | - `--seg_fix_weight`: 人脸分割[6]损失权重(本库代码默认的分割区域为下图的二分类图,借鉴的库可实现多个类别的解析),将原图与生成图在背景区域上做L1损失; 62 | - `--seg_D_mask`: 背景区域的对抗损失置0、背景区域的判别损失置0; 63 | - `--seg_G_detach`: 生成图的背景区域detach掉再放入D网络做对抗; 64 | - `--seg_D_cam_inp_mask`: 训练D网络时,整个背景替换为0作为CAM分支的输入; 65 | - `--seg_D_cam_fea_mask`: 训练D网络时,将CAM分支的feature map的背景区域替换为0; 66 | - `--cam_D_weight -1 --cam_D_attention`: 将CAM的attention和logit loss砍掉。 67 | 68 | ![背景分割](./assets/bg_parsing_boy.png) ![人脸解析](./assets/face_parsing_boy.png) 69 | 70 | - 其他功能: 71 | - `--phase`: 模型的训练/验证/测试/视频/视频文件夹/摄像头/图像文件夹/以对齐的人脸图像文件夹 模式(`'train', 'val', 'test', 'video', 'video_dir', 'camera', 'img_dir', 'generate'`); 72 | 73 | ## 文件目录说明 74 | 75 | - `assets`: `README.md`中的图像; 76 | - `dataset`: 图像数据 77 | - `faceseg`: 人脸分割部分代码与明星权重文件夹; 78 | - `metrics`: FID score 计算脚本; 79 | - `scripts`: 运行命令shell脚本文件夹;训练、测试的运行指令可以参考这里~ 80 | - `videos`: 视频测试脚本; 81 | - `dataset.py`: 数据集; 82 | - `histogram.py`: 直方图匹配; 83 | - `main.py`: 程序主入口,配置信息; 84 | - `networks.py`: 生成器、判别器网络结构 85 | - `UGATIT.py`: 整个训练、测试过程的执行脚本; 86 | - `utils.py`: 工具脚本; 87 | 88 | ## 补充实验说明 89 | 90 | - 部分代码修改:官方其实主库是tensorflow的库,pytorch是根据tensorflow写的,其中有比较多的差异,本库修正了部分差异; 91 | - 损失项权重:在`CycleGAN`实验中发现,适当增加A2B阶段的损失项权重,有利于提升A2B阶段的图像生成质量; 92 | - 反卷积和谱归一化:在`CycleGAN`实验中发现,使用反卷积和去除谱归一化有利于生成网络对原图做更强的变化; 93 | - 模糊对抗对纹理区分度略有提升; 94 | - attention mask fusion:在本库中实验,发现提升不大,不同任务需进一步实验;同时提供了一个把输入图与生成图一起做区域融合的选项; 95 | - scse模块:在本库中实验,发现提升不大,不同任务需进一步实验; 96 | - 在控制人脸除外的背景不变时,使用`faceseg`具有一定效果,可配合背景随机换色一起使用; 97 | - 直方图匹配在两个域比较接近,但是分布存在差异的时候有效,例如 大人 -> 小孩,其他情况可能只对某个通道进行直方图匹配就可,例如在value上使得光照一致; 98 | - 
When the two domains also differ in hue or other color statistics, another option is to convert the images fed to D to grayscale before discrimination, similar to what AnimeGAN does. 99 | 100 | ## Reference 101 | 102 | [1] Official PyTorch implementation of U-GAT-IT and the baseline of this repo: https://github.com/znxlwm/UGATIT-pytorch 103 | [2] Histogram matching adapted from [Bringing-Old-Photos-Back-to-Life](https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life) 104 | [3] CartoonGAN: Generative Adversarial Networks for Photo Cartoonization 105 | [4] Concurrent Spatial and Channel ‘Squeeze & Excitation’ in Fully Convolutional Networks 106 | [5] AttentionGAN: Unpaired Image-to-Image Translation using Attention-Guided Generative Adversarial Networks 107 | [6] Face parsing model: [zllrunning/face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch) 108 | 109 | 110 | ------ 111 | (The original official README follows below.) 112 | 113 | ## U-GAT-IT — Official PyTorch Implementation 114 | ### : Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation 115 | 116 |
117 | ![teaser](./assets/teaser.png) 118 | 
119 | 120 | ### [Paper](https://arxiv.org/abs/1907.10830) | [Official Tensorflow code](https://github.com/taki0112/UGATIT) 121 | The results of the paper came from the **Tensorflow code** 122 | 123 | 124 | > **U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation**
125 | > 126 | > **Abstract** *We propose a novel method for unsupervised image-to-image translation, which incorporates a new attention module and a new learnable normalization function in an end-to-end manner. The attention module guides our model to focus on more important regions distinguishing between source and target domains based on the attention map obtained by the auxiliary classifier. Unlike previous attention-based methods which cannot handle the geometric changes between domains, our model can translate both images requiring holistic changes and images requiring large shape changes. Moreover, our new AdaLIN (Adaptive Layer-Instance Normalization) function helps our attention-guided model to flexibly control the amount of change in shape and texture by learned parameters depending on datasets. Experimental results show the superiority of the proposed method compared to the existing state-of-the-art models with a fixed network architecture and hyper-parameters.* 127 | 128 | ## Usage 129 | ``` 130 | ├── dataset 131 |    └── YOUR_DATASET_NAME 132 |    ├── trainA 133 |           ├── xxx.jpg (name, format doesn't matter) 134 | ├── yyy.png 135 | └── ... 136 |    ├── trainB 137 | ├── zzz.jpg 138 | ├── www.png 139 | └── ... 140 |    ├── testA 141 |    ├── aaa.jpg 142 | ├── bbb.png 143 | └── ... 144 |    └── testB 145 | ├── ccc.jpg 146 | ├── ddd.png 147 | └── ... 148 | ``` 149 | 150 | ### Train 151 | ``` 152 | > python main.py --dataset selfie2anime 153 | ``` 154 | * If the memory of gpu is **not sufficient**, set `--light` to True 155 | 156 | ### Test 157 | ``` 158 | > python main.py --dataset selfie2anime --phase test 159 | ``` 160 | 161 | ## Architecture 162 | ![generator](./assets/generator.png) 163 | 164 | --- 165 | 166 | ![discriminator](./assets/discriminator.png) 167 | 168 | ## Results 169 | ### Ablation study 170 | ![Ablation study](./assets/ablation.png) 171 | 172 | ### User study 173 | ![User study](./assets/user_study.png) 174 | 175 | ### Comparison 176 | ![KID score](./assets/kid.png) 177 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import os.path 4 | from queue import Queue 5 | from threading import Thread 6 | 7 | import cv2 8 | import torch 9 | import torch.utils.data 10 | import numpy as np 11 | 12 | from histogram import match_histograms 13 | 14 | 15 | def get_loader(my_dataset, device, batch_size, num_workers, shuffle): 16 | """ 根据dataset及设置,获取对应的 DataLoader """ 17 | my_loader = torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, num_workers=num_workers, 18 | shuffle=shuffle, pin_memory=True, persistent_workers=(num_workers > 0)) 19 | # if torch.cuda.is_available(): 20 | # my_loader = CudaDataLoader(my_loader, device=device) 21 | return my_loader 22 | 23 | 24 | class MatchHistogramsDataset(torch.utils.data.Dataset): 25 | 26 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 27 | 28 | def __init__(self, root, transform=None, target_transform=None, is_match_histograms=False, match_mode=True, 29 | b2a_prob=0.5, match_ratio=1.0): 30 | """ 获取指定的两个文件夹下,两张图像numpy数组的Dataset """ 31 | assert len(root) == 2, f'root of MatchHistogramsDataset must has two dir!' 
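        # root[0] / root[1] are the domain-A and domain-B image folders. __getitem__ pairs one image
        # from each; __len__ is the larger of the two sizes, so the smaller domain is re-sampled at a
        # random index once its own images run out.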
32 | self.dataset_0 = DatasetFolder(root[0]) 33 | self.dataset_1 = DatasetFolder(root[1]) 34 | 35 | self.transform = transform 36 | self.target_transform = target_transform 37 | self.len_0 = len(self.dataset_0) 38 | self.len_1 = len(self.dataset_1) 39 | self.len = max(self.len_0, self.len_1) 40 | self.is_match_histograms = is_match_histograms 41 | self.match_mode = match_mode 42 | assert self.match_mode in ('hsv', 'hsl', 'rgb'), f'match mode must in {self.match_mode}' 43 | self.b2a_prob = b2a_prob 44 | self.match_ratio = match_ratio 45 | 46 | def __getitem__(self, index): 47 | sample_0 = self.dataset_0[index] if index < self.len_0 else self.dataset_0[np.random.randint(self.len_0)] 48 | sample_1 = self.dataset_1[index] if index < self.len_1 else self.dataset_1[np.random.randint(self.len_1)] 49 | 50 | if self.is_match_histograms: 51 | if self.match_mode == 'hsv': 52 | sample_0 = cv2.cvtColor(sample_0, cv2.COLOR_RGB2HSV_FULL) 53 | sample_1 = cv2.cvtColor(sample_1, cv2.COLOR_RGB2HSV_FULL) 54 | elif self.match_mode == 'hsl': 55 | sample_0 = cv2.cvtColor(sample_0, cv2.COLOR_RGB2HLS_FULL) 56 | sample_1 = cv2.cvtColor(sample_1, cv2.COLOR_RGB2HLS_FULL) 57 | 58 | if np.random.rand() < self.b2a_prob: 59 | sample_1 = match_histograms(sample_1, sample_0, rate=self.match_ratio) 60 | else: 61 | sample_0 = match_histograms(sample_0, sample_1, rate=self.match_ratio) 62 | 63 | if self.match_mode == 'hsv': 64 | sample_0 = cv2.cvtColor(sample_0, cv2.COLOR_HSV2RGB_FULL) 65 | sample_1 = cv2.cvtColor(sample_1, cv2.COLOR_HSV2RGB_FULL) 66 | elif self.match_mode == 'hsl': 67 | sample_0 = cv2.cvtColor(sample_0, cv2.COLOR_HLS2RGB_FULL) 68 | sample_1 = cv2.cvtColor(sample_1, cv2.COLOR_HLS2RGB_FULL) 69 | 70 | if self.transform is not None: 71 | sample_0 = self.transform(sample_0) 72 | sample_1 = self.transform(sample_1) 73 | 74 | return sample_0, sample_1 75 | 76 | def __len__(self): 77 | return self.len 78 | 79 | def __repr__(self): 80 | fmt_str = f'MatchHistogramsDataset for: \n' \ 81 | f'{self.dataset_0.__repr__()} \n ' \ 82 | f'{self.dataset_1.__repr__()}' 83 | return fmt_str 84 | 85 | 86 | class DatasetFolder(torch.utils.data.Dataset): 87 | 88 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 89 | 90 | def __init__(self, root, transform=None): 91 | """ 获取指定文件夹下,单张图像numpy数组的Dataset """ 92 | samples = [] 93 | for sub_root, _, filenames in sorted(os.walk(root)): 94 | for filename in sorted(filenames): 95 | if os.path.splitext(filename)[-1].lower() in self.IMG_EXTENSIONS: 96 | path = os.path.join(sub_root, filename) 97 | samples.append(path) 98 | 99 | if len(samples) == 0: 100 | raise RuntimeError(f"Found 0 files in sub-folders of: {root}\n" 101 | f"Supported extensions are: {','.join(self.IMG_EXTENSIONS)}") 102 | 103 | self.root = root 104 | self.samples = samples 105 | self.transform = transform 106 | 107 | def __getitem__(self, index): 108 | path = self.samples[index] 109 | sample = cv2.imread(path)[..., ::-1] 110 | if self.transform is not None: 111 | sample = self.transform(sample) 112 | return sample 113 | 114 | def __len__(self): 115 | return len(self.samples) 116 | 117 | def __repr__(self): 118 | fmt_str = f'Dataset {self.__class__.__name__}\n'\ 119 | f' Number of data points: {self.__len__()}\n'\ 120 | f' Root Location: {self.root}\n' 121 | tmp = ' Transforms (if any): ' 122 | trans_tmp = self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)) 123 | fmt_str += f'{tmp}{trans_tmp}' 124 | return fmt_str 125 | 126 | 127 | class CudaDataLoader: 128 | """ 异步预先将数据从CPU加载到GPU中 """ 129 
| 130 | def __init__(self, loader, device, queue_size=2): 131 | self.device = device 132 | self.queue_size = queue_size 133 | self.loader = loader 134 | 135 | self.load_stream = torch.cuda.Stream(device=device) 136 | self.queue = Queue(maxsize=self.queue_size) 137 | 138 | self.idx = 0 139 | self.worker = Thread(target=self.load_loop) 140 | self.worker.setDaemon(True) 141 | self.worker.start() 142 | 143 | def load_loop(self): 144 | """ 不断的将cuda数据加载到队列里 """ 145 | # The loop that will load into the queue in the background 146 | torch.cuda.set_device(self.device) 147 | while True: 148 | for i, sample in enumerate(self.loader): 149 | self.queue.put(self.load_instance(sample)) 150 | 151 | def load_instance(self, sample): 152 | """ 将batch数据从CPU加载到GPU中 """ 153 | if torch.is_tensor(sample): 154 | with torch.cuda.stream(self.load_stream): 155 | return sample.to(self.device, non_blocking=True) 156 | elif sample is None or type(sample) in (list, str): 157 | return sample 158 | elif isinstance(sample, dict): 159 | return {k: self.load_instance(v) for k, v in sample.items()} 160 | else: 161 | return [self.load_instance(s) for s in sample] 162 | 163 | def __iter__(self): 164 | self.idx = 0 165 | return self 166 | 167 | def __next__(self): 168 | # 加载线程挂了 169 | if not self.worker.is_alive() and self.queue.empty(): 170 | self.idx = 0 171 | self.queue.join() 172 | self.worker.join() 173 | raise StopIteration 174 | # 一个epoch加载完了 175 | elif self.idx >= len(self.loader): 176 | self.idx = 0 177 | raise StopIteration 178 | # 下一个batch 179 | else: 180 | out = self.queue.get() 181 | self.queue.task_done() 182 | self.idx += 1 183 | return out 184 | 185 | def next(self): 186 | return self.__next__() 187 | 188 | def __len__(self): 189 | return len(self.loader) 190 | 191 | @property 192 | def sampler(self): 193 | return self.loader.sampler 194 | 195 | @property 196 | def dataset(self): 197 | return self.loader.dataset 198 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | import random 5 | from typing import Any, Union 6 | 7 | import cv2 8 | import torch 9 | import numpy as np 10 | from scipy import misc 11 | from tqdm import tqdm 12 | 13 | 14 | def calc_tv_loss(inp, mask=None, eps=1e-8): 15 | """ 提供inp平滑性约束 """ 16 | x_diff = inp[:, :, 1:, 1:] - inp[:, :, :-1, 1:] 17 | y_diff = inp[:, :, 1:, 1:] - inp[:, :, 1:, :-1] 18 | if mask is not None: 19 | x_diff *= mask[:, :, 1:, 1:] 20 | y_diff *= mask[:, :, 1:, 1:] 21 | 22 | # loss = torch.mean(torch.abs(x_diff) + torch.abs(y_diff)) 23 | loss = torch.mean(torch.sqrt(torch.square(inp) + torch.square(inp) + eps)) 24 | return loss 25 | 26 | 27 | def load_test_data(image_path, size=256): 28 | img = misc.imread(image_path, mode='RGB') 29 | img = misc.imresize(img, [size, size]) 30 | img = np.expand_dims(img, axis=0) 31 | img = preprocessing(img) 32 | 33 | return img 34 | 35 | 36 | def preprocessing(x): 37 | x = x / 127.5 - 1 # -1 ~ 1 38 | return x 39 | 40 | 41 | def save_images(images, size, image_path): 42 | return imsave(inverse_transform(images), size, image_path) 43 | 44 | 45 | def inverse_transform(images): 46 | return (images + 1.) 
/ 2 47 | 48 | 49 | def imsave(images, size, path): 50 | return misc.imsave(path, merge(images, size)) 51 | 52 | 53 | def merge(images, size): 54 | h, w = images.shape[1], images.shape[2] 55 | img = np.zeros((h * size[0], w * size[1], 3)) 56 | for idx, image in enumerate(images): 57 | i = idx % size[1] 58 | j = idx // size[1] 59 | img[h * j:h * (j + 1), w * i:w * (i + 1), :] = image 60 | 61 | return img 62 | 63 | 64 | def check_folder(log_dir): 65 | if not os.path.exists(log_dir): 66 | os.makedirs(log_dir) 67 | return log_dir 68 | 69 | 70 | def cam(x, size=256): 71 | x = x - np.min(x) 72 | cam_img = x / np.max(x) 73 | cam_img = np.uint8(255 * cam_img) 74 | cam_img = cv2.resize(cam_img, (size, size)) 75 | cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET) 76 | return cam_img / 255.0 77 | 78 | 79 | def attention_mask(x, size=256): 80 | attention_img = cv2.resize(np.uint8(255 * x), (size, size)) 81 | attention_img = cv2.applyColorMap(attention_img, cv2.COLORMAP_JET) 82 | return attention_img / 255.0 83 | 84 | 85 | def imagenet_norm(x): 86 | mean = [0.485, 0.456, 0.406] 87 | std = [0.299, 0.224, 0.225] 88 | mean = torch.FloatTensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device) 89 | std = torch.FloatTensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device) 90 | return (x - mean) / std 91 | 92 | 93 | def denorm(x): 94 | return x * 0.5 + 0.5 95 | 96 | 97 | def tensor2numpy(x): 98 | return x.detach().cpu().numpy().transpose(1, 2, 0) 99 | 100 | 101 | def RGB2BGR(x): 102 | return cv2.cvtColor(x, cv2.COLOR_RGB2BGR) 103 | 104 | 105 | def setup_seed(seed): 106 | torch.manual_seed(seed) 107 | torch.cuda.manual_seed_all(seed) 108 | np.random.seed(seed) 109 | random.seed(seed) 110 | torch.backends.cudnn.deterministic = True 111 | 112 | 113 | """ 114 | 评估量:记录,打印 115 | """ 116 | 117 | 118 | class AverageMeter: 119 | """ 计算并存储 评估量的均值和当前值 """ 120 | 121 | def __init__(self, name, fmt=':f'): 122 | self.name = name # 评估量名称 123 | self.fmt = fmt # 评估量打印格式 124 | self.val = 0 # 评估量当前值 125 | self.avg = 0 # 评估量均值 126 | self.sum = 0 # 历史评估量的和 127 | self.count = 0 # 历史评估量的数量 128 | 129 | def reset(self): 130 | self.val = 0 131 | self.avg = 0 132 | self.sum = 0 133 | self.count = 0 134 | 135 | def update(self, val, n=1): 136 | self.val = val 137 | self.sum += val * n 138 | self.count += n 139 | self.avg = self.sum / self.count 140 | 141 | def __str__(self): 142 | fmtstr = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})' 143 | return fmtstr.format(**self.__dict__) 144 | 145 | 146 | class ProgressMeter: 147 | """ 评估量的进度条打印 """ 148 | 149 | def __init__(self, num_batches, *meters, prefix=""): 150 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 151 | self.meters = meters 152 | self.prefix = prefix 153 | 154 | def print(self, batch): 155 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 156 | entries += [str(meter) for meter in self.meters] 157 | print('\t'.join(entries)) 158 | 159 | @staticmethod 160 | def _get_batch_fmtstr(num_batches): 161 | num_digits = len(str(num_batches // 1)) 162 | fmt = f'{{:{str(num_digits)}d}}' 163 | return f'[{fmt}/{fmt.format(num_batches)}]' 164 | 165 | 166 | class Logger(object): 167 | """ 168 | Redirect stderr to stdout, optionally print stdout to a file, 169 | and optionally force flushing on both stdout and the file. 
170 | """ 171 | 172 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): 173 | self.file = None 174 | 175 | if file_name is not None: 176 | self.file = open(file_name, file_mode) 177 | 178 | self.should_flush = should_flush 179 | self.stdout = sys.stdout 180 | self.stderr = sys.stderr 181 | 182 | sys.stdout = self 183 | sys.stderr = self 184 | 185 | def __enter__(self) -> "Logger": 186 | return self 187 | 188 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 189 | self.close() 190 | 191 | def write(self, text: Union[str, bytes]) -> None: 192 | """Write text to stdout (and a file) and optionally flush.""" 193 | if isinstance(text, bytes): 194 | text = text.decode() 195 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 196 | return 197 | 198 | if self.file is not None: 199 | self.file.write(text) 200 | 201 | self.stdout.write(text) 202 | 203 | if self.should_flush: 204 | self.flush() 205 | 206 | def flush(self) -> None: 207 | """Flush written text to both stdout and a file, if open.""" 208 | if self.file is not None: 209 | self.file.flush() 210 | 211 | self.stdout.flush() 212 | 213 | def close(self) -> None: 214 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 215 | self.flush() 216 | 217 | # if using multiple loggers, prevent closing in wrong order 218 | if sys.stdout is self: 219 | sys.stdout = self.stdout 220 | if sys.stderr is self: 221 | sys.stderr = self.stderr 222 | 223 | if self.file is not None: 224 | self.file.close() 225 | self.file = None 226 | 227 | 228 | """ 229 | 制作模糊图像 230 | """ 231 | 232 | 233 | def generate_blur_images(root, save): 234 | """ 根据清晰图像制作模糊的图像 235 | :param root: 清晰图像所在的根目录 236 | :param save: 模糊图像存放的根目录 237 | """ 238 | print(f'generating blur images: {root} to {save}...') 239 | file_list = os.listdir(root) 240 | if not os.path.isdir(save): 241 | os.makedirs(save) 242 | kernel_size = 5 243 | kernel = np.ones((kernel_size, kernel_size), np.uint8) 244 | gauss = cv2.getGaussianKernel(kernel_size, 0) 245 | gauss = gauss * gauss.transpose(1, 0) 246 | for f in tqdm(file_list): 247 | try: 248 | rgb_img = cv2.imread(os.path.join(root, f)) 249 | gray_img = cv2.imread(os.path.join(root, f), 0) 250 | pad_img = np.pad(rgb_img, ((2, 2), (2, 2), (0, 0)), mode='reflect') 251 | edges = cv2.Canny(gray_img, 100, 200) 252 | dilation = cv2.dilate(edges, kernel) 253 | 254 | gauss_img = np.copy(rgb_img) 255 | idx = np.where(dilation != 0) 256 | for i in range(np.sum(dilation != 0)): 257 | gauss_img[idx[0][i], idx[1][i], 0] = np.sum( 258 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 0], 259 | gauss)) 260 | gauss_img[idx[0][i], idx[1][i], 1] = np.sum( 261 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 1], 262 | gauss)) 263 | gauss_img[idx[0][i], idx[1][i], 2] = np.sum( 264 | np.multiply(pad_img[idx[0][i]:idx[0][i] + kernel_size, idx[1][i]:idx[1][i] + kernel_size, 2], 265 | gauss)) 266 | 267 | cv2.imwrite(os.path.join(save, f), gauss_img) 268 | except Exception as e: 269 | print(f'{f} failed!\n{e}') 270 | 271 | print(f'finish: blur images over! 
') 272 | -------------------------------------------------------------------------------- /histogram.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 直方图匹配,将一张图片的直方图匹配到目标图上,使两张图的视觉感觉接近 4 | ref https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/blob/95ba2834a358fa243665c86407b220e4e78854fe/Face_Detection/align_warp_back_multiple_dlib.py 5 | """ 6 | import cv2 7 | import numpy as np 8 | 9 | 10 | def match_histograms(src_image, ref_image, rate=1.0, image_type='HWC'): 11 | """ 12 | This method matches the source image histogram to the 13 | reference signal 14 | :param image src_image: The original source image 15 | :param image ref_image: The reference image 16 | :param rate: histograms shift ratio 17 | :param image_type: HWC or CHW 18 | :return: image_after_matching 19 | :rtype: image (array) 20 | """ 21 | # Split the images into the different color channels 22 | # b means blue, g means green and r means red 23 | if image_type == 'HWC': 24 | src_b, src_g, src_r = cv2.split(src_image) 25 | ref_b, ref_g, ref_r = cv2.split(ref_image) 26 | elif image_type == 'CHW': 27 | src_b, src_g, src_r = src_image[0], src_image[1], src_image[2] 28 | ref_b, ref_g, ref_r = ref_image[0], ref_image[1], ref_image[2] 29 | else: 30 | raise ValueError(f'image_type only HWC or CHW, no: {image_type}') 31 | 32 | # Compute the b, g, and r histograms separately 33 | # The flatten() Numpy method returns a copy of the array c 34 | # collapsed into one dimension. 35 | src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256]) 36 | src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256]) 37 | src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256]) 38 | ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256]) 39 | ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256]) 40 | ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256]) 41 | 42 | # Compute the normalized cdf for the source and reference image 43 | src_cdf_blue = calculate_cdf(src_hist_blue) 44 | src_cdf_green = calculate_cdf(src_hist_green) 45 | src_cdf_red = calculate_cdf(src_hist_red) 46 | ref_cdf_blue = calculate_cdf(ref_hist_blue) 47 | ref_cdf_green = calculate_cdf(ref_hist_green) 48 | ref_cdf_red = calculate_cdf(ref_hist_red) 49 | 50 | if rate < 1.0: 51 | ref_cdf_blue = src_cdf_blue * (1.0 - rate) + ref_cdf_blue * rate 52 | ref_cdf_green = src_cdf_green * (1.0 - rate) + ref_cdf_green * rate 53 | ref_cdf_red = src_cdf_red * (1.0 - rate) + ref_cdf_red * rate 54 | 55 | # Make a separate lookup table for each color 56 | blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue) 57 | green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green) 58 | red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red) 59 | 60 | # Use the lookup function to transform the colors of the original 61 | # source image 62 | blue_after_transform = cv2.LUT(src_b, blue_lookup_table) 63 | green_after_transform = cv2.LUT(src_g, green_lookup_table) 64 | red_after_transform = cv2.LUT(src_r, red_lookup_table) 65 | 66 | # Put the image back together 67 | if image_type == 'HWC': 68 | image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform]) 69 | elif image_type == 'CHW': 70 | image_after_matching = np.array([blue_after_transform, green_after_transform, red_after_transform]) 71 | else: 72 | raise ValueError(f'image_type only HWC or CHW, no: {image_type}') 73 | 74 | image_after_matching = 
cv2.convertScaleAbs(image_after_matching) 75 | 76 | return image_after_matching 77 | 78 | 79 | def calculate_cdf(histogram: np.ndarray) -> np.ndarray: 80 | """ 81 | This method calculates the cumulative distribution function 82 | :param array histogram: The values of the histogram 83 | :return: normalized_cdf: The normalized cumulative distribution function 84 | :rtype: array 85 | """ 86 | # Get the cumulative sum of the elements 87 | cdf = histogram.cumsum() 88 | 89 | # Normalize the cdf 90 | normalized_cdf = cdf / float(cdf.max()) 91 | 92 | return normalized_cdf 93 | 94 | 95 | def calculate_lookup(src_cdf: np.ndarray, ref_cdf: np.ndarray) -> np.ndarray: 96 | """ 97 | This method creates the lookup table 98 | :param array src_cdf: The cdf for the source image 99 | :param array ref_cdf: The cdf for the reference image 100 | :return: lookup_table: The lookup table 101 | :rtype: array 102 | """ 103 | lookup_table = np.zeros(256) 104 | lookup_val = 0 105 | for src_pixel_val in range(len(src_cdf)): 106 | for ref_pixel_val in range(len(ref_cdf)): 107 | if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]: 108 | lookup_val = ref_pixel_val 109 | break 110 | lookup_table[src_pixel_val] = lookup_val 111 | return lookup_table 112 | 113 | 114 | if __name__ == '__main__': 115 | import numpy as np 116 | 117 | kid_src = cv2.imread('dataset/kid_src.png') 118 | man_src = cv2.imread('dataset/man_src.png') 119 | 120 | # 均衡 1 121 | kid_match_bgr_10 = match_histograms(kid_src, man_src, rate=1) 122 | kid_match_hsv_10 = match_histograms(cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL), 123 | cv2.cvtColor(man_src, cv2.COLOR_BGR2HSV_FULL), rate=1) 124 | kid_match_hls_10 = match_histograms(cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL), 125 | cv2.cvtColor(man_src, cv2.COLOR_BGR2HLS_FULL), rate=1) 126 | # 均衡0.5 127 | kid_match_bgr_05 = match_histograms(kid_src, man_src, rate=0.5) 128 | kid_match_hsv_05 = match_histograms(cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL), 129 | cv2.cvtColor(man_src, cv2.COLOR_BGR2HSV_FULL), rate=0.5) 130 | kid_match_hls_05 = match_histograms(cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL), 131 | cv2.cvtColor(man_src, cv2.COLOR_BGR2HLS_FULL), rate=0.5) 132 | 133 | result = np.concatenate([np.concatenate([kid_src, man_src, man_src], axis=1), 134 | np.concatenate([kid_match_bgr_10, cv2.cvtColor(kid_match_hsv_10, cv2.COLOR_HSV2BGR_FULL), 135 | cv2.cvtColor(kid_match_hls_10, cv2.COLOR_HLS2BGR_FULL)], 136 | axis=1), 137 | np.concatenate([kid_match_bgr_05, cv2.cvtColor(kid_match_hsv_05, cv2.COLOR_HSV2BGR_FULL), 138 | cv2.cvtColor(kid_match_hls_05, cv2.COLOR_HLS2BGR_FULL)], 139 | axis=1)], 140 | axis=0) 141 | cv2.imwrite('all.png', result) 142 | cv2.imshow('match_hsv', result) 143 | cv2.waitKey(0) 144 | 145 | # 只均衡hsv某一通道 146 | kid_hsv_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL) 147 | kid_hsv_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL) 148 | kid_hsv_src_10[..., 0] = kid_match_hsv_10[..., 0] 149 | kid_hsv_src_05[..., 0] = kid_match_hsv_05[..., 0] 150 | kid_hls_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL) 151 | kid_hls_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL) 152 | kid_hls_src_10[..., 0] = kid_match_hls_10[..., 0] 153 | kid_hls_src_05[..., 0] = kid_match_hls_05[..., 0] 154 | kid_match_h = np.concatenate([cv2.cvtColor(kid_hsv_src_10, cv2.COLOR_HSV2BGR_FULL), 155 | cv2.cvtColor(kid_hsv_src_05, cv2.COLOR_HSV2BGR_FULL), 156 | cv2.cvtColor(kid_hls_src_10, cv2.COLOR_HLS2BGR_FULL), 157 | cv2.cvtColor(kid_hls_src_05, cv2.COLOR_HLS2BGR_FULL)], axis=1) 158 | 159 | 
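    # Same single-channel swap as above, now for the saturation channel (index 1 in HSV, index 2 in HLS).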
kid_hsv_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL) 160 | kid_hsv_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL) 161 | kid_hsv_src_10[..., 1] = kid_match_hsv_10[..., 1] 162 | kid_hsv_src_05[..., 1] = kid_match_hsv_05[..., 1] 163 | kid_hls_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL) 164 | kid_hls_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL) 165 | kid_hls_src_10[..., 2] = kid_match_hls_10[..., 2] 166 | kid_hls_src_05[..., 2] = kid_match_hls_05[..., 2] 167 | kid_match_s = np.concatenate([cv2.cvtColor(kid_hsv_src_10, cv2.COLOR_HSV2BGR_FULL), 168 | cv2.cvtColor(kid_hsv_src_05, cv2.COLOR_HSV2BGR_FULL), 169 | cv2.cvtColor(kid_hls_src_10, cv2.COLOR_HLS2BGR_FULL), 170 | cv2.cvtColor(kid_hls_src_05, cv2.COLOR_HLS2BGR_FULL)], axis=1) 171 | 172 | kid_hsv_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL) 173 | kid_hsv_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HSV_FULL) 174 | kid_hsv_src_10[..., 2] = kid_match_hsv_10[..., 2] 175 | kid_hsv_src_05[..., 2] = kid_match_hsv_05[..., 2] 176 | kid_hls_src_10 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL) 177 | kid_hls_src_05 = cv2.cvtColor(kid_src, cv2.COLOR_BGR2HLS_FULL) 178 | kid_hls_src_10[..., 1] = kid_match_hls_10[..., 1] 179 | kid_hls_src_05[..., 1] = kid_match_hls_05[..., 1] 180 | kid_match_v = np.concatenate([cv2.cvtColor(kid_hsv_src_10, cv2.COLOR_HSV2BGR_FULL), 181 | cv2.cvtColor(kid_hsv_src_05, cv2.COLOR_HSV2BGR_FULL), 182 | cv2.cvtColor(kid_hls_src_10, cv2.COLOR_HLS2BGR_FULL), 183 | cv2.cvtColor(kid_hls_src_05, cv2.COLOR_HLS2BGR_FULL)], axis=1) 184 | 185 | result = np.concatenate([np.concatenate([kid_src, man_src, kid_src, man_src], axis=1), 186 | kid_match_h, kid_match_s, kid_match_v], axis=0) 187 | 188 | cv2.imwrite('one.png', result) 189 | cv2.imshow('match_hsv', result) 190 | cv2.waitKey(0) 191 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | CUDA_VISIBLE_DEVICES=0 python main.py \ 4 | --dataset YOUR_DATA_SET --result_dir results --img_size 384 --device cuda:0 --num_workers 1 \ 5 | # G网络计算AdaLIN的大小 6 | --light -1 \ 7 | # 指标打印 8 | --print_freq 1000 --calc_fid_freq 10000 --save_freq 100000 \ 9 | # 模型ema 10 | --ema_start 0.5 --ema_beta 0.9999 \ 11 | # 分割损失 12 | --seg_fix_weight 50 --seg_fix_glass_mouth --seg_D_mask --seg_G_detach --seg_D_cam_fea_mask --seg_D_cam_inp_mask \ 13 | --cam_D_weight -1 \ 14 | # attention gan 15 | --attention_gan 2 --attention_input \ 16 | # 不同损失项权重 17 | --adv_weight 1.0 --forward_adv_weight 1 --cycle_weight 10 --identity_weight 10 \ 18 | # 直方图匹配 19 | --match_histograms --match_mode hsl --match_prob 0.5 --match_ratio 1.0 \ 20 | # 模糊、se 21 | --has_blur --use_se 22 | 23 | e.g. 
24 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset transfer/yuwei/styleGAN/dy_cartoon/dy_cartoon.tar \ 25 | --light 32 --result_dir transfer/yuwei/styleGAN/oilpaint_result/pai2/dy_only \ 26 | --img_size 384 --device cuda:0 --num_workers 4 27 | """ 28 | import os 29 | import time 30 | import datetime 31 | import argparse 32 | 33 | import torch 34 | import torch.backends.cudnn 35 | 36 | import utils 37 | from UGATIT import UGATIT 38 | 39 | VIDEO_EXT = ('.ts', '.mp4') 40 | 41 | 42 | def parse_args(): 43 | """parsing and configuration""" 44 | desc = "Pytorch implementation of U-GAT-IT" 45 | parser = argparse.ArgumentParser(description=desc) 46 | parser.add_argument('--phase', type=str, default='train', 47 | choices=['train', 'val', 'test', 'video', 'video_dir', 'camera', 'img_dir', 'generate'], 48 | help='训练/验证/测试/视频/视频文件夹/摄像头/图像文件夹/以对齐的人脸图像文件夹 模式') 49 | parser.add_argument('--light', type=int, default=-1, 50 | help='[U-GAT-IT full version / U-GAT-IT light version],求gamma和beta的MLP的输入尺寸') 51 | parser.add_argument('--dataset', type=str, default='YOUR_DATASET_NAME', help='dataset_name') 52 | 53 | parser.add_argument('--ema_start', type=float, default=0.5, help='start ema after ratio of --iteration') 54 | parser.add_argument('--ema_beta', type=float, default=0.9999, help='ema gamma for genA2B/B2A, 0.9999^10000=0.37') 55 | parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations') 56 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size') 57 | parser.add_argument('--num_workers', type=int, default=1, help='每个 DataLoader 的进程数') 58 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image print freq') 59 | parser.add_argument('--calc_fid_freq', type=int, default=10000, help='The number of fid print freq') 60 | parser.add_argument('--fid_batch', type=int, default=50, help='计算fid score时的batch size') 61 | parser.add_argument('--save_freq', type=int, default=100000, help='The number of model save freq') 62 | parser.add_argument('--no_decay_flag', action='store_false', help='在中间iteration时,使用学习率下降策略,默认使用') 63 | 64 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate') 65 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='The weight decay') 66 | parser.add_argument('--adv_weight', type=float, default=1, help='Weight for GAN,建议值:0.8') 67 | parser.add_argument('--forward_adv_weight', type=float, default=1, help='前向对抗损失的权重,建议值:2') 68 | parser.add_argument('--backward_adv_weight', type=float, default=1, help='后向对抗损失的权重,建议值:1') 69 | parser.add_argument('--cycle_weight', type=float, default=10, help='Weight for Cycle,建议值:3') 70 | parser.add_argument('--identity_weight', type=float, default=10, help='Weight for Identity,建议值:1.5') 71 | parser.add_argument('--cam_weight', type=float, default=1000, help='Weight for CAM,建议值:1000') 72 | 73 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer') 74 | parser.add_argument('--n_res', type=int, default=4, help='The number of resblock') 75 | parser.add_argument('--n_global_dis', type=int, default=7, help='The number of global discriminator layer') 76 | parser.add_argument('--n_local_dis', type=int, default=5, help='The number of local discriminator layer') 77 | parser.add_argument('--img_size', type=int, default=384, help='The size of image') 78 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel') 79 | parser.add_argument('--result_dir', 
type=str, default='project/results', help='Directory name to save the results') 80 | parser.add_argument('--device', type=str, default='cuda:0', choices=['cpu', 'cuda:0'], help='Set gpu mode; [cpu, cuda:0]') # noqa, E501 81 | parser.add_argument('--resume', action='store_true', help='是否继续最后的一次训练') 82 | 83 | # 增强U-GAT-IT选项 84 | parser.add_argument('--aug_prob', type=float, default=0.2, help='对数据应用 resize & crop 数据增强的概率,建议值,<=0.2') 85 | parser.add_argument('--sn', action='store_false', help='默认D网络使用sn,建议使用,tf版本使用了') 86 | parser.add_argument('--has_blur', action='store_true', help='默认不使用模糊数据增强D网络,建议使用') 87 | parser.add_argument('--tv_loss', action='store_true', help='是否对生成图像使用TVLoss,默认不适用') 88 | parser.add_argument('--tv_weight', type=float, default=1.0, help='Weight for TVLoss,建议值:1.0') 89 | parser.add_argument('--use_se', action='store_true', help='resblock是否使用se-block,可以使用') 90 | parser.add_argument('--attention_gan', type=int, default=0, help='attention_gan,可以尝试') 91 | parser.add_argument('--attention_input', action='store_true', help='attention_gan时,是否把输入加入做attention,可以尝试') 92 | parser.add_argument('--cam_D_weight', type=float, default=1, help='判别器的CAM分类损失项权重,建议值:1') 93 | parser.add_argument('--cam_D_attention', action='store_false', help='是否使用判别器的CAM注意力机制,默认使用') 94 | # 直方图匹配 95 | parser.add_argument('--match_histograms', action='store_true', help='默认不使用直方图匹配,两个域真实域存在部分分布差异可尝试') 96 | parser.add_argument('--match_mode', type=str, default='hsl', help='默认直方图匹配使用hsl') 97 | parser.add_argument('--match_prob', type=float, default=0.5, help='从 B->A 进行直方图匹配的概率,否则 A->B 进行直方图匹配') 98 | parser.add_argument('--match_ratio', type=float, default=1.0, help='直方图匹配的比例') 99 | # 固定背景选项 100 | parser.add_argument('--hard_seg_edge', action='store_true', help='分割边界是否为硬边界,默认为软边界') 101 | parser.add_argument('--seg_fix_weight', type=float, default=-1, help='对生成图像的分割区域与原图做L1损失项的权重,建议值:50') 102 | parser.add_argument('--seg_fix_glass_mouth', action='store_true', help='分割是否固定眼镜边框和嘴巴内部(作为背景),默认不固定') 103 | parser.add_argument('--seg_D_mask', action='store_true', help='只计算分割mask区域的判别损失,默认都计算') 104 | parser.add_argument('--seg_G_detach', action='store_true', help='对生成图像的非分割mask区域做detach,默认不detach') 105 | parser.add_argument('--seg_D_cam_fea_mask', action='store_true', help='将判别器cam的feature map做mask替换,默认不替换') 106 | parser.add_argument('--seg_D_cam_inp_mask', action='store_true', help='将输入给判别器cam的图像做mask替换,默认不替换') 107 | 108 | # 测试 109 | parser.add_argument('--generator_model', type=str, default='', help='测试的A2B生成器路径') 110 | parser.add_argument('--video_path', type=str, default='', help='测试的A2B生成器路径所用的视频路径') 111 | parser.add_argument('--img_dir', type=str, default='', help='测试的A2B生成器路径所用的图像文件夹路径') 112 | 113 | return check_args(parser.parse_args()) 114 | 115 | 116 | def check_args(args): 117 | """checking arguments""" 118 | utils.check_folder(args.result_dir) 119 | utils.Logger(file_name=os.path.join(args.result_dir, f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.log"), 120 | file_mode='a', should_flush=True) 121 | if args.cam_D_weight <= 0: 122 | args.cam_D_attention = False 123 | print(f'can not use D cam attention while D cam weight = {args.cam_D_weight} <= 0') 124 | 125 | if args.phase in ('video', 'video_dir', 'camera', 'img_dir', 'generate'): 126 | return args 127 | 128 | # --result_dir 129 | utils.check_folder(os.path.join(args.result_dir, args.dataset, 'model')) 130 | utils.check_folder(os.path.join(args.result_dir, args.dataset, 'img')) 131 | utils.check_folder(os.path.join(args.result_dir, args.dataset, 
'test')) 132 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 133 | return args 134 | 135 | 136 | def main(): 137 | # parse arguments 138 | args = parse_args() 139 | if args is None: 140 | exit() 141 | if not torch.cuda.is_available(): 142 | args.device = 'cpu' 143 | args.device = torch.device(args.device) 144 | print(args) 145 | 146 | # 视频/摄像头/图像文件夹 测试 147 | if args.phase in ('video', 'video_dir', 'camera', 'img_dir', 'generate'): 148 | if args.generator_model == '': 149 | raise ValueError('No define A2B G model path!') 150 | from videos import test_video 151 | from networks import ResnetGenerator 152 | torch.set_flush_denormal(True) 153 | # 定义及加载生成模型 154 | generator = ResnetGenerator(input_nc=3, output_nc=3, ngf=args.ch, n_blocks=args.n_res, 155 | img_size=args.img_size, args=args) 156 | params = torch.load(args.generator_model, map_location=torch.device("cpu")) 157 | generator.load_state_dict(params['genA2B_ema']) 158 | # 模型测试 159 | tester = test_video.VideoTester(args, generator) 160 | if args.phase in ('video', 'video_dir'): 161 | assert args.video_path and os.path.exists(args.video_path), f'video path ({args.video_path}) error!' 162 | if args.phase == 'video_dir': 163 | video_paths = [os.path.join(args.video_path, video_name) for video_name in os.listdir(args.video_path) 164 | if video_name.endswith(VIDEO_EXT)] 165 | else: 166 | video_paths = [args.video_path] 167 | for video_path in video_paths: 168 | print(f'generating video: {video_path} ...') 169 | tester.video(video_path) 170 | elif args.phase == 'camera': 171 | tester.camera() 172 | elif args.phase == 'img_dir': 173 | assert args.img_dir and os.path.exists(args.img_dir), f'image directory ({args.img_dir}) error!' 174 | tester.image_dir(args.img_dir) 175 | elif args.phase == 'generate': 176 | assert args.img_dir and os.path.exists(args.img_dir), f'image directory ({args.img_dir}) error!' 
177 | tester.generate_images(args.img_dir) 178 | else: 179 | raise Exception(f'unknown phase: {args.phase}') 180 | return 181 | 182 | if torch.backends.cudnn.enabled: 183 | torch.backends.cudnn.benchmark = True 184 | utils.setup_seed(0) 185 | 186 | # open session 187 | gan = UGATIT(args) 188 | 189 | # build graph 190 | gan.build_model() 191 | 192 | if args.phase == 'train': 193 | gan.train() 194 | print(" [*] Training finished!") 195 | args.phase = 'test' 196 | 197 | if args.phase == 'test': 198 | torch.set_flush_denormal(True) 199 | gan.test() 200 | print(" [*] Test finished!") 201 | 202 | return 203 | 204 | 205 | if __name__ == '__main__': 206 | main() 207 | -------------------------------------------------------------------------------- /faceseg/networks_faceseg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from faceseg.resnet import Resnet18 10 | 11 | 12 | # from modules.bn import InPlaceABNSync as BatchNorm2d 13 | 14 | 15 | class ConvBNReLU(nn.Module): 16 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 17 | super(ConvBNReLU, self).__init__() 18 | self.conv = nn.Conv2d(in_chan, 19 | out_chan, 20 | kernel_size=ks, 21 | stride=stride, 22 | padding=padding, 23 | bias=False) 24 | self.bn = nn.BatchNorm2d(out_chan) 25 | self.init_weight() 26 | 27 | def forward(self, x): 28 | x = self.conv(x) 29 | x = F.relu(self.bn(x)) 30 | return x 31 | 32 | def init_weight(self): 33 | for ly in self.children(): 34 | if isinstance(ly, nn.Conv2d): 35 | nn.init.kaiming_normal_(ly.weight, a=1) 36 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 37 | 38 | 39 | class BiSeNetOutput(nn.Module): 40 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 41 | super(BiSeNetOutput, self).__init__() 42 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 43 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 44 | self.init_weight() 45 | 46 | def forward(self, x): 47 | x = self.conv(x) 48 | x = self.conv_out(x) 49 | return x 50 | 51 | def init_weight(self): 52 | for ly in self.children(): 53 | if isinstance(ly, nn.Conv2d): 54 | nn.init.kaiming_normal_(ly.weight, a=1) 55 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 56 | 57 | def get_params(self): 58 | wd_params, nowd_params = [], [] 59 | for name, module in self.named_modules(): 60 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 61 | wd_params.append(module.weight) 62 | if not module.bias is None: 63 | nowd_params.append(module.bias) 64 | elif isinstance(module, nn.BatchNorm2d): 65 | nowd_params += list(module.parameters()) 66 | return wd_params, nowd_params 67 | 68 | 69 | class AttentionRefinementModule(nn.Module): 70 | def __init__(self, in_chan, out_chan, *args, **kwargs): 71 | super(AttentionRefinementModule, self).__init__() 72 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 73 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) 74 | self.bn_atten = nn.BatchNorm2d(out_chan) 75 | self.sigmoid_atten = nn.Sigmoid() 76 | self.init_weight() 77 | 78 | def forward(self, x): 79 | feat = self.conv(x) 80 | atten = F.avg_pool2d(feat, feat.size()[2:]) 81 | atten = self.conv_atten(atten) 82 | atten = self.bn_atten(atten) 83 | atten = self.sigmoid_atten(atten) 84 | out = torch.mul(feat, atten) 85 | return out 86 | 87 | 
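# Editor's note (added comment): the gating above is an SE-style channel
# attention -- global average pooling followed by a 1x1 conv, BN and sigmoid,
# used to reweight the conv features channel-wise. The init_weight below is the
# same Kaiming-normal Conv2d initialisation (a=1, zero bias) repeated by every
# module defined in this file.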
def init_weight(self): 88 | for ly in self.children(): 89 | if isinstance(ly, nn.Conv2d): 90 | nn.init.kaiming_normal_(ly.weight, a=1) 91 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 92 | 93 | 94 | class ContextPath(nn.Module): 95 | def __init__(self, *args, **kwargs): 96 | super(ContextPath, self).__init__() 97 | self.resnet = Resnet18() 98 | self.arm16 = AttentionRefinementModule(256, 128) 99 | self.arm32 = AttentionRefinementModule(512, 128) 100 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 101 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 102 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 103 | 104 | self.init_weight() 105 | 106 | def forward(self, x): 107 | H0, W0 = x.size()[2:] 108 | feat8, feat16, feat32 = self.resnet(x) 109 | H8, W8 = feat8.size()[2:] 110 | H16, W16 = feat16.size()[2:] 111 | H32, W32 = feat32.size()[2:] 112 | 113 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 114 | avg = self.conv_avg(avg) 115 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 116 | 117 | feat32_arm = self.arm32(feat32) 118 | feat32_sum = feat32_arm + avg_up 119 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 120 | feat32_up = self.conv_head32(feat32_up) 121 | 122 | feat16_arm = self.arm16(feat16) 123 | feat16_sum = feat16_arm + feat32_up 124 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 125 | feat16_up = self.conv_head16(feat16_up) 126 | 127 | return feat8, feat16_up, feat32_up # x8, x8, x16 128 | 129 | def init_weight(self): 130 | for ly in self.children(): 131 | if isinstance(ly, nn.Conv2d): 132 | nn.init.kaiming_normal_(ly.weight, a=1) 133 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 134 | 135 | def get_params(self): 136 | wd_params, nowd_params = [], [] 137 | for name, module in self.named_modules(): 138 | if isinstance(module, (nn.Linear, nn.Conv2d)): 139 | wd_params.append(module.weight) 140 | if not module.bias is None: 141 | nowd_params.append(module.bias) 142 | elif isinstance(module, nn.BatchNorm2d): 143 | nowd_params += list(module.parameters()) 144 | return wd_params, nowd_params 145 | 146 | 147 | ### This is not used, since I replace this with the resnet feature with the same size 148 | class SpatialPath(nn.Module): 149 | def __init__(self, *args, **kwargs): 150 | super(SpatialPath, self).__init__() 151 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 152 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 153 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 154 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 155 | self.init_weight() 156 | 157 | def forward(self, x): 158 | feat = self.conv1(x) 159 | feat = self.conv2(feat) 160 | feat = self.conv3(feat) 161 | feat = self.conv_out(feat) 162 | return feat 163 | 164 | def init_weight(self): 165 | for ly in self.children(): 166 | if isinstance(ly, nn.Conv2d): 167 | nn.init.kaiming_normal_(ly.weight, a=1) 168 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 169 | 170 | def get_params(self): 171 | wd_params, nowd_params = [], [] 172 | for name, module in self.named_modules(): 173 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 174 | wd_params.append(module.weight) 175 | if not module.bias is None: 176 | nowd_params.append(module.bias) 177 | elif isinstance(module, nn.BatchNorm2d): 178 | nowd_params += list(module.parameters()) 179 | return wd_params, nowd_params 180 | 181 | 182 | class FeatureFusionModule(nn.Module): 183 | def 
__init__(self, in_chan, out_chan, *args, **kwargs): 184 | super(FeatureFusionModule, self).__init__() 185 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 186 | self.conv1 = nn.Conv2d(out_chan, 187 | out_chan // 4, 188 | kernel_size=1, 189 | stride=1, 190 | padding=0, 191 | bias=False) 192 | self.conv2 = nn.Conv2d(out_chan // 4, 193 | out_chan, 194 | kernel_size=1, 195 | stride=1, 196 | padding=0, 197 | bias=False) 198 | self.relu = nn.ReLU(inplace=True) 199 | self.sigmoid = nn.Sigmoid() 200 | self.init_weight() 201 | 202 | def forward(self, fsp, fcp): 203 | fcat = torch.cat([fsp, fcp], dim=1) 204 | feat = self.convblk(fcat) 205 | atten = F.avg_pool2d(feat, feat.size()[2:]) 206 | atten = self.conv1(atten) 207 | atten = self.relu(atten) 208 | atten = self.conv2(atten) 209 | atten = self.sigmoid(atten) 210 | feat_atten = torch.mul(feat, atten) 211 | feat_out = feat_atten + feat 212 | return feat_out 213 | 214 | def init_weight(self): 215 | for ly in self.children(): 216 | if isinstance(ly, nn.Conv2d): 217 | nn.init.kaiming_normal_(ly.weight, a=1) 218 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 219 | 220 | def get_params(self): 221 | wd_params, nowd_params = [], [] 222 | for name, module in self.named_modules(): 223 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 224 | wd_params.append(module.weight) 225 | if not module.bias is None: 226 | nowd_params.append(module.bias) 227 | elif isinstance(module, nn.BatchNorm2d): 228 | nowd_params += list(module.parameters()) 229 | return wd_params, nowd_params 230 | 231 | 232 | class BiSeNet(nn.Module): 233 | def __init__(self, n_classes, *args, **kwargs): 234 | super(BiSeNet, self).__init__() 235 | self.cp = ContextPath() 236 | ## here self.sp is deleted 237 | self.ffm = FeatureFusionModule(256, 256) 238 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 239 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 240 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 241 | self.init_weight() 242 | 243 | def forward(self, x): 244 | H, W = x.size()[2:] 245 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 246 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 247 | feat_fuse = self.ffm(feat_sp, feat_cp8) 248 | 249 | feat_out = self.conv_out(feat_fuse) 250 | feat_out16 = self.conv_out16(feat_cp8) 251 | feat_out32 = self.conv_out32(feat_cp16) 252 | 253 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 254 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 255 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 256 | return feat_out, feat_out16, feat_out32 257 | 258 | def init_weight(self): 259 | for ly in self.children(): 260 | if isinstance(ly, nn.Conv2d): 261 | nn.init.kaiming_normal_(ly.weight, a=1) 262 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 263 | 264 | def get_params(self): 265 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 266 | for name, child in self.named_children(): 267 | child_wd_params, child_nowd_params = child.get_params() 268 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 269 | lr_mul_wd_params += child_wd_params 270 | lr_mul_nowd_params += child_nowd_params 271 | else: 272 | wd_params += child_wd_params 273 | nowd_params += child_nowd_params 274 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 275 | 276 | 277 | if __name__ == 
"__main__": 278 | net = BiSeNet(19) 279 | net.cuda() 280 | net.eval() 281 | in_ten = torch.randn(16, 3, 640, 480).cuda() 282 | out, out16, out32 = net(in_ten) 283 | print(out.shape) 284 | 285 | net.get_params() 286 | -------------------------------------------------------------------------------- /metrics/fid/inception.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | 7 | try: 8 | from torchvision.models.utils import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 12 | # Inception weights ported to Pytorch from 13 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 14 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 15 | FID_WEIGHTS_LOCAL = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pt_inception-2015-12-05-6726825d.pth') 16 | 17 | 18 | class InceptionV3(nn.Module): 19 | """Pretrained InceptionV3 network returning feature maps""" 20 | 21 | # Index of default block of inception to return, 22 | # corresponds to output of final average pooling 23 | DEFAULT_BLOCK_INDEX = 3 24 | 25 | # Maps feature dimensionality to their output blocks indices 26 | BLOCK_INDEX_BY_DIM = { 27 | 64: 0, # First max pooling features 28 | 192: 1, # Second max pooling featurs 29 | 768: 2, # Pre-aux classifier features 30 | 2048: 3 # Final average pooling features 31 | } 32 | 33 | def __init__(self, 34 | output_blocks=(DEFAULT_BLOCK_INDEX,), 35 | resize_input=True, 36 | normalize_input=True, 37 | requires_grad=False, 38 | use_fid_inception=True): 39 | """Build pretrained InceptionV3 40 | 41 | Parameters 42 | ---------- 43 | output_blocks : list of int 44 | Indices of blocks to return features of. Possible values are: 45 | - 0: corresponds to output of first max pooling 46 | - 1: corresponds to output of second max pooling 47 | - 2: corresponds to output which is fed to aux classifier 48 | - 3: corresponds to output of final average pooling 49 | resize_input : bool 50 | If true, bilinearly resizes input to width and height 299 before 51 | feeding input to model. As the network without fully connected 52 | layers is fully convolutional, it should be able to handle inputs 53 | of arbitrary size, so resizing might not be strictly needed 54 | normalize_input : bool 55 | If true, scales the input from range (0, 1) to the range the 56 | pretrained Inception network expects, namely (-1, 1) 57 | requires_grad : bool 58 | If true, parameters of the model require gradients. Possibly useful 59 | for finetuning the network 60 | use_fid_inception : bool 61 | If true, uses the pretrained Inception model used in Tensorflow's 62 | FID implementation. If false, uses the pretrained Inception model 63 | available in torchvision. The FID Inception model has different 64 | weights and a slightly different structure from torchvision's 65 | Inception model. If you want to compute FID scores, you are 66 | strongly advised to set this parameter to true to get comparable 67 | results. 
68 | """ 69 | super(InceptionV3, self).__init__() 70 | 71 | self.resize_input = resize_input 72 | self.normalize_input = normalize_input 73 | self.output_blocks = sorted(output_blocks) 74 | self.last_needed_block = max(output_blocks) 75 | 76 | assert self.last_needed_block <= 3, \ 77 | 'Last possible output block index is 3' 78 | 79 | self.blocks = nn.ModuleList() 80 | 81 | if use_fid_inception: 82 | inception = fid_inception_v3() 83 | else: 84 | inception = _inception_v3(pretrained=True) 85 | 86 | # Block 0: input to maxpool1 87 | block0 = [ 88 | inception.Conv2d_1a_3x3, 89 | inception.Conv2d_2a_3x3, 90 | inception.Conv2d_2b_3x3, 91 | nn.MaxPool2d(kernel_size=3, stride=2) 92 | ] 93 | self.blocks.append(nn.Sequential(*block0)) 94 | 95 | # Block 1: maxpool1 to maxpool2 96 | if self.last_needed_block >= 1: 97 | block1 = [ 98 | inception.Conv2d_3b_1x1, 99 | inception.Conv2d_4a_3x3, 100 | nn.MaxPool2d(kernel_size=3, stride=2) 101 | ] 102 | self.blocks.append(nn.Sequential(*block1)) 103 | 104 | # Block 2: maxpool2 to aux classifier 105 | if self.last_needed_block >= 2: 106 | block2 = [ 107 | inception.Mixed_5b, 108 | inception.Mixed_5c, 109 | inception.Mixed_5d, 110 | inception.Mixed_6a, 111 | inception.Mixed_6b, 112 | inception.Mixed_6c, 113 | inception.Mixed_6d, 114 | inception.Mixed_6e, 115 | ] 116 | self.blocks.append(nn.Sequential(*block2)) 117 | 118 | # Block 3: aux classifier to final avgpool 119 | if self.last_needed_block >= 3: 120 | block3 = [ 121 | inception.Mixed_7a, 122 | inception.Mixed_7b, 123 | inception.Mixed_7c, 124 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 125 | ] 126 | self.blocks.append(nn.Sequential(*block3)) 127 | 128 | for param in self.parameters(): 129 | param.requires_grad = requires_grad 130 | 131 | def forward(self, inp): 132 | """Get Inception feature maps 133 | 134 | Parameters 135 | ---------- 136 | inp : torch.autograd.Variable 137 | Input tensor of shape Bx3xHxW. Values are expected to be in 138 | range (0, 1) 139 | 140 | Returns 141 | ------- 142 | List of torch.autograd.Variable, corresponding to the selected output 143 | block, sorted ascending by index 144 | """ 145 | outp = [] 146 | x = inp 147 | 148 | if self.resize_input: 149 | x = F.interpolate(x, 150 | size=(299, 299), 151 | mode='bilinear', 152 | align_corners=False) 153 | 154 | if self.normalize_input: 155 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 156 | 157 | for idx, block in enumerate(self.blocks): 158 | x = block(x) 159 | if idx in self.output_blocks: 160 | outp.append(x) 161 | 162 | if idx == self.last_needed_block: 163 | break 164 | 165 | return outp 166 | 167 | 168 | def _inception_v3(*args, **kwargs): 169 | """Wraps `torchvision.models.inception_v3` 170 | 171 | Skips default weight inititialization if supported by torchvision version. 172 | See https://github.com/mseitzer/pytorch-fid/issues/28. 173 | """ 174 | try: 175 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 176 | except ValueError: 177 | # Just a caution against weird version strings 178 | version = (0,) 179 | 180 | if version >= (0, 6): 181 | kwargs['init_weights'] = False 182 | 183 | return torchvision.models.inception_v3(*args, **kwargs) 184 | 185 | 186 | def fid_inception_v3(): 187 | """Build pretrained Inception model for FID computation 188 | 189 | The Inception model for FID computation uses a different set of weights 190 | and has a slightly different structure than torchvision's Inception. 
191 | 192 | This method first constructs torchvision's Inception and then patches the 193 | necessary parts that are different in the FID Inception model. 194 | """ 195 | inception = _inception_v3(num_classes=1008, 196 | aux_logits=False, 197 | pretrained=False) 198 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 199 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 200 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 201 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 202 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 203 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 204 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 205 | inception.Mixed_7b = FIDInceptionE_1(1280) 206 | inception.Mixed_7c = FIDInceptionE_2(2048) 207 | 208 | try: 209 | state_dict = torch.load(FID_WEIGHTS_LOCAL) 210 | except: 211 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 212 | inception.load_state_dict(state_dict) 213 | return inception 214 | 215 | 216 | class FIDInceptionA(torchvision.models.inception.InceptionA): 217 | """InceptionA block patched for FID computation""" 218 | def __init__(self, in_channels, pool_features): 219 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 220 | 221 | def forward(self, x): 222 | branch1x1 = self.branch1x1(x) 223 | 224 | branch5x5 = self.branch5x5_1(x) 225 | branch5x5 = self.branch5x5_2(branch5x5) 226 | 227 | branch3x3dbl = self.branch3x3dbl_1(x) 228 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 229 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 230 | 231 | # Patch: Tensorflow's average pool does not use the padded zero's in 232 | # its average calculation 233 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 234 | count_include_pad=False) 235 | branch_pool = self.branch_pool(branch_pool) 236 | 237 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 238 | return torch.cat(outputs, 1) 239 | 240 | 241 | class FIDInceptionC(torchvision.models.inception.InceptionC): 242 | """InceptionC block patched for FID computation""" 243 | def __init__(self, in_channels, channels_7x7): 244 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 245 | 246 | def forward(self, x): 247 | branch1x1 = self.branch1x1(x) 248 | 249 | branch7x7 = self.branch7x7_1(x) 250 | branch7x7 = self.branch7x7_2(branch7x7) 251 | branch7x7 = self.branch7x7_3(branch7x7) 252 | 253 | branch7x7dbl = self.branch7x7dbl_1(x) 254 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 255 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 256 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 257 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 258 | 259 | # Patch: Tensorflow's average pool does not use the padded zero's in 260 | # its average calculation 261 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 262 | count_include_pad=False) 263 | branch_pool = self.branch_pool(branch_pool) 264 | 265 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 266 | return torch.cat(outputs, 1) 267 | 268 | 269 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 270 | """First InceptionE block patched for FID computation""" 271 | def __init__(self, in_channels): 272 | super(FIDInceptionE_1, self).__init__(in_channels) 273 | 274 | def forward(self, x): 275 | branch1x1 = self.branch1x1(x) 276 | 277 | branch3x3 = self.branch3x3_1(x) 278 | branch3x3 = [ 279 | self.branch3x3_2a(branch3x3), 280 | self.branch3x3_2b(branch3x3), 281 | ] 282 | 
branch3x3 = torch.cat(branch3x3, 1) 283 | 284 | branch3x3dbl = self.branch3x3dbl_1(x) 285 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 286 | branch3x3dbl = [ 287 | self.branch3x3dbl_3a(branch3x3dbl), 288 | self.branch3x3dbl_3b(branch3x3dbl), 289 | ] 290 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 291 | 292 | # Patch: Tensorflow's average pool does not use the padded zero's in 293 | # its average calculation 294 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 295 | count_include_pad=False) 296 | branch_pool = self.branch_pool(branch_pool) 297 | 298 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 299 | return torch.cat(outputs, 1) 300 | 301 | 302 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 303 | """Second InceptionE block patched for FID computation""" 304 | def __init__(self, in_channels): 305 | super(FIDInceptionE_2, self).__init__(in_channels) 306 | 307 | def forward(self, x): 308 | branch1x1 = self.branch1x1(x) 309 | 310 | branch3x3 = self.branch3x3_1(x) 311 | branch3x3 = [ 312 | self.branch3x3_2a(branch3x3), 313 | self.branch3x3_2b(branch3x3), 314 | ] 315 | branch3x3 = torch.cat(branch3x3, 1) 316 | 317 | branch3x3dbl = self.branch3x3dbl_1(x) 318 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 319 | branch3x3dbl = [ 320 | self.branch3x3dbl_3a(branch3x3dbl), 321 | self.branch3x3dbl_3b(branch3x3dbl), 322 | ] 323 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 324 | 325 | # Patch: The FID Inception model uses max pooling instead of average 326 | # pooling. This is likely an error in this specific Inception 327 | # implementation, as other Inception models use average pooling here 328 | # (which matches the description in the paper). 329 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 330 | branch_pool = self.branch_pool(branch_pool) 331 | 332 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 333 | return torch.cat(outputs, 1) 334 | -------------------------------------------------------------------------------- /metrics/fid/fid_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 2 | 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 6 | 7 | When run as a stand-alone program, it compares the distribution of 8 | images that are stored as PNG/JPEG at a specified location with a 9 | distribution given by summary statistics (in pickle format). 10 | 11 | The FID is calculated by assuming that X_1 and X_2 are the activations of 12 | the pool_3 layer of the inception net for generated samples and real world 13 | samples respectively. 14 | 15 | See --help to see further details. 16 | 17 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 18 | of Tensorflow 19 | 20 | Copyright 2018 Institute of Bioinformatics, JKU Linz 21 | 22 | Licensed under the Apache License, Version 2.0 (the "License"); 23 | you may not use this file except in compliance with the License. 24 | You may obtain a copy of the License at 25 | 26 | http://www.apache.org/licenses/LICENSE-2.0 27 | 28 | Unless required by applicable law or agreed to in writing, software 29 | distributed under the License is distributed on an "AS IS" BASIS, 30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
31 | See the License for the specific language governing permissions and 32 | limitations under the License. 33 | """ 34 | import os 35 | import pathlib 36 | 37 | import numpy as np 38 | import torch 39 | import torchvision.transforms as TF 40 | from PIL import Image 41 | from scipy import linalg 42 | from torch.nn.functional import adaptive_avg_pool2d 43 | from tqdm import tqdm 44 | 45 | from metrics.fid.inception import InceptionV3 46 | 47 | # parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 48 | # parser.add_argument('--batch-size', type=int, default=50, 49 | # help='Batch size to use') 50 | # parser.add_argument('--num-workers', type=int, default=8, 51 | # help='Number of processes to use for data loading') 52 | # parser.add_argument('--device', type=str, default=None, 53 | # help='Device to use. Like cuda, cuda:0 or cpu') 54 | # parser.add_argument('--dims', type=int, default=2048, 55 | # choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 56 | # help=('Dimensionality of Inception features to use. ' 57 | # 'By default, uses pool3 features')) 58 | # parser.add_argument('path', type=str, nargs=2, 59 | # help=('Paths to the generated images or ' 60 | # 'to .npz statistic files')) 61 | 62 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 63 | 'tif', 'tiff', 'webp'} 64 | 65 | 66 | class ImagePathDataset(torch.utils.data.Dataset): 67 | def __init__(self, files, transforms=None): 68 | self.files = files 69 | self.transforms = transforms 70 | 71 | def __len__(self): 72 | return len(self.files) 73 | 74 | def __getitem__(self, i): 75 | path = self.files[i] 76 | img = Image.open(path).convert('RGB') 77 | if self.transforms is not None: 78 | img = self.transforms(img) 79 | return img 80 | 81 | 82 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8): 83 | """Calculates the activations of the pool_3 layer for all images. 84 | 85 | Params: 86 | -- files : List of image files paths 87 | -- model : Instance of inception model 88 | -- batch_size : Batch size of images for the model to process at once. 89 | Make sure that the number of samples is a multiple of 90 | the batch size, otherwise some samples are ignored. This 91 | behavior is retained to match the original FID score 92 | implementation. 93 | -- dims : Dimensionality of features returned by Inception 94 | -- device : Device to run calculations 95 | -- num_workers : Number of parallel dataloader workers 96 | 97 | Returns: 98 | -- A numpy array of dimension (num images, dims) that contains the 99 | activations of the given tensor when feeding inception with the 100 | query tensor. 101 | """ 102 | model.eval() 103 | 104 | if batch_size > len(files): 105 | print(('Warning: batch size is bigger than the data size. ' 106 | 'Setting batch size to data size')) 107 | batch_size = len(files) 108 | 109 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 110 | dataloader = torch.utils.data.DataLoader(dataset, 111 | batch_size=batch_size, 112 | shuffle=False, 113 | drop_last=False, 114 | num_workers=num_workers) 115 | 116 | pred_arr = np.empty((len(files), dims)) 117 | 118 | start_idx = 0 119 | 120 | for batch in tqdm(dataloader): 121 | batch = batch.to(device) 122 | 123 | with torch.no_grad(): 124 | pred = model(batch)[0] 125 | 126 | # If model output is not scalar, apply global spatial average pooling. 127 | # This happens if you choose a dimensionality not equal 2048. 
128 | if pred.size(2) != 1 or pred.size(3) != 1: 129 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 130 | 131 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 132 | 133 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 134 | 135 | start_idx = start_idx + pred.shape[0] 136 | 137 | return pred_arr 138 | 139 | 140 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 141 | """Numpy implementation of the Frechet Distance. 142 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 143 | and X_2 ~ N(mu_2, C_2) is 144 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 145 | 146 | Stable version by Dougal J. Sutherland. 147 | 148 | Params: 149 | -- mu1 : Numpy array containing the activations of a layer of the 150 | inception net (like returned by the function 'get_predictions') 151 | for generated samples. 152 | -- mu2 : The sample mean over activations, precalculated on an 153 | representative data set. 154 | -- sigma1: The covariance matrix over activations for generated samples. 155 | -- sigma2: The covariance matrix over activations, precalculated on an 156 | representative data set. 157 | 158 | Returns: 159 | -- : The Frechet Distance. 160 | """ 161 | 162 | mu1 = np.atleast_1d(mu1) 163 | mu2 = np.atleast_1d(mu2) 164 | 165 | sigma1 = np.atleast_2d(sigma1) 166 | sigma2 = np.atleast_2d(sigma2) 167 | 168 | assert mu1.shape == mu2.shape, \ 169 | 'Training and test mean vectors have different lengths' 170 | assert sigma1.shape == sigma2.shape, \ 171 | 'Training and test covariances have different dimensions' 172 | 173 | diff = mu1 - mu2 174 | 175 | # Product might be almost singular 176 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 177 | if not np.isfinite(covmean).all(): 178 | msg = ('fid calculation produces singular product; ' 179 | 'adding %s to diagonal of cov estimates') % eps 180 | print(msg) 181 | offset = np.eye(sigma1.shape[0]) * eps 182 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 183 | 184 | # Numerical error might give slight imaginary component 185 | if np.iscomplexobj(covmean): 186 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 187 | m = np.max(np.abs(covmean.imag)) 188 | raise ValueError('Imaginary component {}'.format(m)) 189 | covmean = covmean.real 190 | 191 | tr_covmean = np.trace(covmean) 192 | 193 | return (diff.dot(diff) + np.trace(sigma1) 194 | + np.trace(sigma2) - 2 * tr_covmean) 195 | 196 | 197 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 198 | device='cpu', num_workers=8): 199 | """Calculation of the statistics used by the FID. 200 | Params: 201 | -- files : List of image files paths 202 | -- model : Instance of inception model 203 | -- batch_size : The images numpy array is split into batches with 204 | batch size batch_size. A reasonable batch size 205 | depends on the hardware. 206 | -- dims : Dimensionality of features returned by Inception 207 | -- device : Device to run calculations 208 | -- num_workers : Number of parallel dataloader workers 209 | 210 | Returns: 211 | -- mu : The mean over samples of the activations of the pool_3 layer of 212 | the inception model. 213 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 214 | the inception model. 
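Example (editorial addition; a minimal end-to-end sketch -- `files_real` and
`files_fake` are assumed lists of image paths):
    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]])
    m1, s1 = calculate_activation_statistics(files_real, model, device='cpu')
    m2, s2 = calculate_activation_statistics(files_fake, model, device='cpu')
    fid = calculate_frechet_distance(m1, s1, m2, s2)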
215 | """ 216 | act = get_activations(files, model, batch_size, dims, device, num_workers) 217 | mu = np.mean(act, axis=0) 218 | sigma = np.cov(act, rowvar=False) 219 | return mu, sigma 220 | 221 | 222 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8): 223 | if path.endswith('.npz'): 224 | with np.load(path) as f: 225 | m, s = f['mu'][:], f['sigma'][:] 226 | else: 227 | path = pathlib.Path(path) 228 | files = sorted([file for ext in IMAGE_EXTENSIONS 229 | for file in path.glob('*.{}'.format(ext))]) 230 | m, s = calculate_activation_statistics(files, model, batch_size, 231 | dims, device, num_workers) 232 | 233 | return m, s 234 | 235 | 236 | class FIDScore: 237 | 238 | def __init__(self, device=None, dims=2048, batch_size=50, num_workers=4): 239 | self.device = device 240 | self.dims = dims 241 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 242 | self.inception_model = InceptionV3([block_idx]).to(device).eval() 243 | self.batch_size = batch_size 244 | self.num_workers = num_workers 245 | 246 | def calc_mean_std(self, p): 247 | """Calculates the mean and covariance matrix of path""" 248 | if not os.path.exists(p): 249 | raise RuntimeError(f'Invalid path: {p}') 250 | mean, std = compute_statistics_of_path(p, self.inception_model, self.batch_size, 251 | self.dims, self.device, self.num_workers) 252 | return mean, std 253 | 254 | def calc_mean_std_with_gen(self, gen_model, data_loader): 255 | self.inception_model.eval() 256 | 257 | pred_arr = np.empty((len(data_loader.dataset), self.dims)) 258 | start_idx = 0 259 | for batch in tqdm(data_loader): 260 | batch = batch.to(self.device, non_blocking=True) 261 | 262 | with torch.no_grad(): 263 | pred = self.inception_model(gen_model(batch))[0] 264 | 265 | # If model output is not scalar, apply global spatial average pooling. 266 | # This happens if you choose a dimensionality not equal 2048. 
267 | if pred.size(2) != 1 or pred.size(3) != 1: 268 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 269 | 270 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 271 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 272 | start_idx = start_idx + pred.shape[0] 273 | 274 | mu = np.mean(pred_arr, axis=0) 275 | sigma = np.cov(pred_arr, rowvar=False) 276 | return mu, sigma 277 | 278 | @staticmethod 279 | def calc_fid(mean_std1, mean_std2): 280 | """Calculates the FID of two paths""" 281 | m1, s1 = mean_std1 282 | m2, s2 = mean_std2 283 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 284 | return fid_value 285 | 286 | 287 | # def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=8): 288 | # """Calculates the FID of two paths""" 289 | # for p in paths: 290 | # if not os.path.exists(p): 291 | # raise RuntimeError('Invalid path: %s' % p) 292 | # 293 | # block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 294 | # 295 | # model = InceptionV3([block_idx]).to(device) 296 | # 297 | # m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 298 | # dims, device, num_workers) 299 | # m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 300 | # dims, device, num_workers) 301 | # fid_value = calculate_frechet_distance(m1, s1, m2, s2) 302 | # 303 | # return fid_value 304 | # 305 | # 306 | # def main(): 307 | # args = parser.parse_args() 308 | # 309 | # if args.device is None: 310 | # device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 311 | # else: 312 | # device = torch.device(args.device) 313 | # 314 | # fid_value = calculate_fid_given_paths(args.path, 315 | # args.batch_size, 316 | # device, 317 | # args.dims, 318 | # args.num_workers) 319 | # print('FID: ', fid_value) 320 | # 321 | # 322 | if __name__ == '__main__': 323 | file_dir_1 = 'dataset/selfie2anime/testB' 324 | file_dir_2 = 'dataset/selfie2anime/testB' 325 | 326 | fid_score = FIDScore('cpu', batch_size=2, num_workers=1) 327 | mean_std_A_1 = fid_score.calc_mean_std(file_dir_1) 328 | mean_std_A_2 = fid_score.calc_mean_std(file_dir_2) 329 | score = fid_score.calc_fid(mean_std_A_1, mean_std_A_2) 330 | print(f'fid score between {file_dir_1} and {file_dir_2}: {score}') 331 | 332 | mean_std_A_1 = fid_score.calc_mean_std(file_dir_1) 333 | mean_std_A_2 = fid_score.calc_mean_std(file_dir_2) 334 | score = fid_score.calc_fid(mean_std_A_1, mean_std_A_2) 335 | print(f'fid score between {file_dir_1} and {file_dir_2}: {score}') 336 | 337 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional 5 | from torch.nn.parameter import Parameter 6 | 7 | 8 | class ResnetGenerator(nn.Module): 9 | 10 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, img_size=256, args=None): 11 | assert (n_blocks >= 0) 12 | super(ResnetGenerator, self).__init__() 13 | self.input_nc = input_nc 14 | self.output_nc = output_nc 15 | self.ngf = ngf 16 | self.n_blocks = n_blocks 17 | self.img_size = img_size 18 | self.args = args 19 | 20 | self.light = args.light 21 | self.attention_gan = args.attention_gan 22 | self.attention_input = args.attention_input 23 | self.use_se = args.use_se 24 | 25 | # 下采样模块:特征抽取、下采样、Bottleneck(resnet-block)特征编码 26 | DownBlock = [nn.ReflectionPad2d(3), 27 | nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=False), 28 | nn.InstanceNorm2d(ngf, 
affine=True), 29 | nn.ReLU(True)] 30 | # 下采样 31 | n_downsampling = 2 32 | for i in range(n_downsampling): 33 | mult = 2 ** i 34 | DownBlock += [nn.ReflectionPad2d(1), 35 | nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0, bias=False), 36 | nn.InstanceNorm2d(ngf * mult * 2, affine=True), 37 | nn.ReLU(True)] 38 | 39 | # Down-Sampling Bottleneck 40 | mult = 2 ** n_downsampling 41 | for i in range(n_blocks): 42 | DownBlock += [ResnetBlock(ngf * mult, use_bias=False, use_se=self.use_se)] 43 | 44 | # Class Activation Map 45 | self.gap_fc = nn.Linear(ngf * mult, 1, bias=True) 46 | # self.gmp_fc = self.gap_fc # tf版本 47 | self.gmp_fc = nn.Linear(ngf * mult, 1, bias=True) # pytorch版本 48 | self.conv1x1 = nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=1, stride=1, bias=True) 49 | self.relu = nn.ReLU(True) 50 | 51 | # 生成 gamma 和 beta,小模型和大模型 52 | if self.light > 0: 53 | FC = [nn.Linear(self.light * self.light * ngf * mult, ngf * mult, bias=True)] 54 | else: 55 | FC = [nn.Linear(img_size // mult * img_size // mult * ngf * mult, ngf * mult, bias=True)] 56 | FC += [nn.ReLU(True), 57 | nn.Linear(ngf * mult, ngf * mult, bias=True), 58 | nn.ReLU(True)] 59 | self.gamma = nn.Linear(ngf * mult, ngf * mult, bias=True) 60 | self.beta = nn.Linear(ngf * mult, ngf * mult, bias=True) 61 | 62 | # Up-Sampling Bottleneck 63 | for i in range(n_blocks): 64 | setattr(self, 'UpBlock1_' + str(i + 1), ResnetAdaLINBlock(ngf * mult, use_bias=True, use_se=self.use_se)) 65 | # Up-Sampling 66 | UpBlock2 = [] 67 | for i in range(n_downsampling): 68 | mult = 2 ** (n_downsampling - i) 69 | UpBlock2 += [nn.Upsample(scale_factor=2, mode='nearest'), 70 | nn.ReflectionPad2d(1), 71 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0, bias=True), 72 | LIN(int(ngf * mult / 2)), 73 | nn.ReLU(True) 74 | ] 75 | 76 | if self.attention_gan > 0: 77 | UpBlock_attention = [] 78 | mult = 2 ** n_downsampling 79 | for i in range(n_blocks): 80 | UpBlock_attention += [ResnetBlock(ngf * mult, use_bias=False, use_se=self.use_se)] 81 | for i in range(n_downsampling): 82 | mult = 2 ** (n_downsampling - i) 83 | UpBlock_attention += [nn.Upsample(scale_factor=2, mode='nearest'), 84 | nn.ReflectionPad2d(1), 85 | nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, 86 | stride=1, padding=0, bias=True), 87 | LIN(int(ngf * mult / 2)), 88 | nn.ReLU(True)] 89 | UpBlock_attention += [nn.Conv2d(ngf, self.attention_gan, kernel_size=1, stride=1, padding=0, bias=True), 90 | nn.Softmax(dim=1)] 91 | self.UpBlock_attention = nn.Sequential(*UpBlock_attention) 92 | if self.attention_input: 93 | output_nc *= (self.attention_gan - 1) 94 | else: 95 | output_nc *= self.attention_gan 96 | 97 | UpBlock2 += [nn.ReflectionPad2d(3), 98 | nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0, bias=True), 99 | nn.Tanh()] 100 | 101 | self.DownBlock = nn.Sequential(*DownBlock) 102 | self.FC = nn.Sequential(*FC) 103 | self.UpBlock2 = nn.Sequential(*UpBlock2) 104 | 105 | def forward(self, input_x): 106 | attention = None 107 | x = self.DownBlock(input_x) 108 | # cam作为attention加权x 109 | gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) 110 | gap_logit = self.gap_fc(gap.view(x.shape[0], -1)) 111 | gap_weight = sum(list(self.gap_fc.parameters())) 112 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 113 | 114 | gmp = torch.nn.functional.adaptive_max_pool2d(x, 1) 115 | gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1)) 116 | gmp_weight = sum(list(self.gmp_fc.parameters())) 117 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 118 
| 119 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 120 | x = torch.cat([gap, gmp], 1) 121 | x = self.relu(self.conv1x1(x)) 122 | 123 | heatmap = torch.sum(x, dim=1, keepdim=True) 124 | # 生成上采样模块的adaILN的gamma和beta 125 | if self.light > 0: 126 | x_ = torch.nn.functional.adaptive_avg_pool2d(x, self.light) 127 | x_ = self.FC(x_.view(x_.shape[0], -1)) 128 | else: 129 | x_ = self.FC(x.view(x.shape[0], -1)) 130 | gamma, beta = self.gamma(x_), self.beta(x_) 131 | 132 | new_x = x 133 | for i in range(self.n_blocks): 134 | new_x = getattr(self, 'UpBlock1_' + str(i + 1))(new_x, gamma, beta) 135 | 136 | out = self.UpBlock2(new_x) 137 | 138 | if self.attention_gan > 0: 139 | attention = self.UpBlock_attention(x) 140 | batch_size, attention_ch, height, width = attention.shape 141 | if self.attention_input: 142 | out = torch.cat([input_x, out], dim=1) 143 | out = out.view(batch_size, 3, attention_ch, height, width) 144 | out = out * attention.view(batch_size, 1, attention_ch, height, width) 145 | out = out.sum(dim=2) 146 | 147 | return out, cam_logit, heatmap, attention 148 | 149 | 150 | class ChannelSELayer(nn.Module): 151 | def __init__(self, in_size, reduction=4, min_hidden_channel=8): 152 | super(ChannelSELayer, self).__init__() 153 | 154 | hidden_channel = max(in_size // reduction, min_hidden_channel) 155 | 156 | self.se = nn.Sequential( 157 | nn.AdaptiveAvgPool2d(1), 158 | nn.Conv2d(in_size, hidden_channel, kernel_size=1, stride=1, bias=True), 159 | nn.ReLU(inplace=True), 160 | nn.Conv2d(hidden_channel, in_size, kernel_size=1, stride=1, bias=True), 161 | nn.Sigmoid() 162 | ) 163 | 164 | def forward(self, x): 165 | return self.se(x) * x 166 | 167 | 168 | class SpatialSELayer(nn.Module): 169 | """ 170 | Re-implementation of SE block -- squeezing spatially and exciting channel-wise described in: 171 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018* 172 | """ 173 | 174 | def __init__(self, num_channels): 175 | """ 176 | :param num_channels: No of input channels 177 | """ 178 | super(SpatialSELayer, self).__init__() 179 | self.conv = nn.Conv2d(num_channels, 1, 1) 180 | self.sigmoid = nn.Sigmoid() 181 | 182 | def forward(self, input_tensor): 183 | """ 184 | :param input_tensor: X, shape = (batch_size, num_channels, H, W) 185 | :return: output_tensor 186 | """ 187 | out = self.conv(input_tensor) 188 | squeeze_tensor = self.sigmoid(out) 189 | return squeeze_tensor * input_tensor 190 | 191 | 192 | class ChannelSpatialSELayer(nn.Module): 193 | """ 194 | Re-implementation of concurrent spatial and channel squeeze & excitation: 195 | *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, 196 | MICCAI 2018, arXiv:1803.02579* 197 | """ 198 | 199 | def __init__(self, num_channels, reduction_ratio=4): 200 | """ 201 | :param num_channels: No of input channels 202 | :param reduction_ratio: By how much should the num_channels should be reduced 203 | """ 204 | super(ChannelSpatialSELayer, self).__init__() 205 | self.cSE = ChannelSELayer(num_channels, reduction_ratio) 206 | self.sSE = SpatialSELayer(num_channels) 207 | 208 | def forward(self, input_tensor): 209 | """ 210 | :param input_tensor: X, shape = (batch_size, num_channels, H, W) 211 | :return: output_tensor 212 | """ 213 | attention = self.cSE(input_tensor) + self.sSE(input_tensor) 214 | return attention 215 | 216 | 217 | class ResnetBlock(nn.Module): 218 | def __init__(self, dim, use_bias, use_se=False): 219 | super(ResnetBlock, self).__init__() 220 
| conv_block = [] 221 | conv_block += [nn.ReflectionPad2d(1), 222 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias), 223 | nn.InstanceNorm2d(dim, affine=True), 224 | nn.ReLU(True)] 225 | 226 | conv_block += [nn.ReflectionPad2d(1), 227 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias), 228 | nn.InstanceNorm2d(dim, affine=True)] 229 | if use_se: 230 | conv_block += [ChannelSpatialSELayer(dim)] 231 | self.conv_block = nn.Sequential(*conv_block) 232 | 233 | def forward(self, x): 234 | out = x + self.conv_block(x) 235 | return out 236 | 237 | 238 | class ResnetAdaLINBlock(nn.Module): 239 | def __init__(self, dim, use_bias, use_se=False): 240 | super(ResnetAdaLINBlock, self).__init__() 241 | self.pad1 = nn.ReflectionPad2d(1) 242 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias) 243 | self.norm1 = AdaLIN(dim) 244 | self.relu1 = nn.ReLU(True) 245 | 246 | self.pad2 = nn.ReflectionPad2d(1) 247 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias) 248 | self.norm2 = AdaLIN(dim) 249 | self.use_se = use_se 250 | if use_se: 251 | self.se = ChannelSpatialSELayer(dim) 252 | 253 | def forward(self, x, gamma, beta): 254 | out = self.pad1(x) 255 | out = self.conv1(out) 256 | out = self.norm1(out, gamma, beta) 257 | out = self.relu1(out) 258 | out = self.pad2(out) 259 | out = self.conv2(out) 260 | out = self.norm2(out, gamma, beta) 261 | if self.use_se: 262 | out = self.se(out) 263 | return out + x 264 | 265 | 266 | class AdaLIN(nn.Module): 267 | def __init__(self, num_features, eps=1e-5): 268 | super(AdaLIN, self).__init__() 269 | self.eps = eps 270 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 271 | self.rho.data.fill_(0.9) 272 | 273 | def forward(self, in_x, gamma, beta): 274 | in_mean, in_var = torch.mean(in_x, dim=[2, 3], keepdim=True), torch.var(in_x, dim=[2, 3], keepdim=True) 275 | out_in = (in_x - in_mean) / torch.sqrt(in_var + self.eps) 276 | ln_mean, ln_var = torch.mean(in_x, dim=[1, 2, 3], keepdim=True), torch.var(in_x, dim=[1, 2, 3], keepdim=True) 277 | out_ln = (in_x - ln_mean) / torch.sqrt(ln_var + self.eps) 278 | out = self.rho.expand(in_x.shape[0], -1, -1, -1) * out_in + ( 279 | 1 - self.rho.expand(in_x.shape[0], -1, -1, -1)) * out_ln 280 | out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3) 281 | 282 | return out 283 | 284 | 285 | class LIN(nn.Module): 286 | def __init__(self, num_features, eps=1e-5): 287 | super(LIN, self).__init__() 288 | self.eps = eps 289 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 290 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1)) 291 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1)) 292 | self.rho.data.fill_(0.0) 293 | self.gamma.data.fill_(1.0) 294 | self.beta.data.fill_(0.0) 295 | 296 | def forward(self, in_x): 297 | in_mean, in_var = torch.mean(in_x, dim=[2, 3], keepdim=True), torch.var(in_x, dim=[2, 3], keepdim=True) 298 | out_in = (in_x - in_mean) / torch.sqrt(in_var + self.eps) 299 | ln_mean, ln_var = torch.mean(in_x, dim=[1, 2, 3], keepdim=True), torch.var(in_x, dim=[1, 2, 3], keepdim=True) 300 | out_ln = (in_x - ln_mean) / torch.sqrt(ln_var + self.eps) 301 | out = self.rho.expand(in_x.shape[0], -1, -1, -1) * out_in + ( 302 | 1 - self.rho.expand(in_x.shape[0], -1, -1, -1)) * out_ln 303 | out = out * self.gamma.expand(in_x.shape[0], -1, -1, -1) + self.beta.expand(in_x.shape[0], -1, -1, -1) 304 | 305 | return out 306 | 307 | 308 | class Discriminator(nn.Module): 309 | def 
__init__(self, input_nc, ndf=64, n_layers=5, with_sn=True, use_cam_attention=True): 310 | super(Discriminator, self).__init__() 311 | self.use_cam_attention = use_cam_attention 312 | conv1 = nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True) 313 | model = [nn.ReflectionPad2d(1), 314 | nn.utils.spectral_norm(conv1) if with_sn else conv1, 315 | nn.LeakyReLU(0.2, True)] 316 | 317 | for i in range(1, n_layers - 2): 318 | mult = 2 ** (i - 1) 319 | conv = nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True) 320 | model += [nn.ReflectionPad2d(1), 321 | nn.utils.spectral_norm(conv) if with_sn else conv, 322 | nn.LeakyReLU(0.2, True)] 323 | if n_layers < 5: 324 | mult = 2 ** (n_layers - 2 - 1) 325 | for i in range(0, 5 - n_layers): 326 | conv = nn.Conv2d(ndf * mult, ndf * mult, kernel_size=3, stride=1, padding=0, bias=True) 327 | model += [nn.ReflectionPad2d(1), 328 | nn.utils.spectral_norm(conv) if with_sn else conv, 329 | nn.LeakyReLU(0.2, True)] 330 | 331 | mult = 2 ** (n_layers - 2 - 1) 332 | conv2 = nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True) 333 | model += [nn.ReflectionPad2d(1), 334 | nn.utils.spectral_norm(conv2) if with_sn else conv2, 335 | nn.LeakyReLU(0.2, True)] 336 | 337 | # Class Activation Map 338 | mult = 2 ** (n_layers - 2) 339 | linear_gmp = nn.Linear(ndf * mult, 1, bias=True) 340 | linear_gap = nn.Linear(ndf * mult, 1, bias=True) 341 | self.gmp_fc = nn.utils.spectral_norm(linear_gmp) if with_sn else linear_gmp 342 | self.gap_fc = nn.utils.spectral_norm(linear_gap) if with_sn else linear_gap 343 | self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True) 344 | self.leaky_relu = nn.LeakyReLU(0.2, True) 345 | 346 | self.pad = nn.ReflectionPad2d(1) 347 | conv3 = nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=True) 348 | self.conv = nn.utils.spectral_norm(conv3) if with_sn else conv3 349 | 350 | self.model = nn.Sequential(*model) 351 | 352 | def forward(self, in_x, cam_input=None, mask=None): 353 | x = self.model(in_x) 354 | 355 | # 如果专门给CAM用的输入(例如将背景抹除的输入,这样可以使背景)不是None,则用这个输入计算logit和对应CAM 356 | cam_x = self.model(cam_input) if cam_input is not None else x 357 | 358 | if mask is not None: 359 | cam_x = torch.nn.functional.interpolate(mask, cam_x.shape[2:], mode='area') * cam_x 360 | 361 | gap = torch.nn.functional.adaptive_avg_pool2d(cam_x, 1) 362 | gap_logit = self.gap_fc(gap.view(cam_x.shape[0], -1)) 363 | if self.use_cam_attention: 364 | gap_weight = sum(list(self.gap_fc.parameters())) 365 | gap = x * gap_weight.unsqueeze(2).unsqueeze(3) 366 | else: 367 | gap = x 368 | 369 | gmp = torch.nn.functional.adaptive_max_pool2d(cam_x, 1) 370 | gmp_logit = self.gmp_fc(gmp.view(cam_x.shape[0], -1)) 371 | if self.use_cam_attention: 372 | gmp_weight = sum(list(self.gmp_fc.parameters())) 373 | gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3) 374 | else: 375 | gmp = x 376 | 377 | cam_logit = torch.cat([gap_logit, gmp_logit], 1) 378 | 379 | x = torch.cat([gap, gmp], 1) 380 | x = self.leaky_relu(self.conv1x1(x)) 381 | 382 | heatmap = torch.sum(x, dim=1, keepdim=True) 383 | 384 | x = self.pad(x) 385 | out = self.conv(x) 386 | 387 | return out, cam_logit, heatmap 388 | 389 | 390 | class RhoClipper(object): 391 | 392 | def __init__(self, min_num, max_num, module_type=None): 393 | self.clip_min = min_num 394 | self.clip_max = max_num 395 | self.module_type = module_type 396 | assert min_num < max_num 397 | 398 | def __call__(self, module): 399 | if 
(self.module_type is None and hasattr(module, 'rho')) or (type(module) == self.module_type): 400 | w = module.rho.data 401 | w = w.clamp(self.clip_min, self.clip_max) 402 | module.rho.data = w 403 | -------------------------------------------------------------------------------- /UGATIT.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | import copy 5 | import itertools 6 | from glob import glob 7 | from typing import Union 8 | 9 | import cv2 10 | import PIL 11 | import numpy as np 12 | from tqdm import tqdm 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F # noqa 16 | from torchvision import transforms 17 | from torch.utils.tensorboard import SummaryWriter 18 | 19 | from faceseg.FaceSegmentation import FaceSegmentation 20 | from utils import (calc_tv_loss, AverageMeter, ProgressMeter, generate_blur_images, 21 | RGB2BGR, tensor2numpy, attention_mask, cam, denorm) 22 | from dataset import MatchHistogramsDataset, DatasetFolder, get_loader 23 | from metrics import FIDScore 24 | 25 | 26 | class UGATIT(object): 27 | def __init__(self, args): 28 | self.args = args 29 | 30 | if self.args.light > 0: 31 | self.model_name = 'UGATIT_light' + str(self.args.light) 32 | else: 33 | self.model_name = 'UGATIT' 34 | 35 | print(f'\n##### Information #####\n' 36 | f'# light : {self.args.light}\n' 37 | f'# dataset : {self.args.dataset}\n' 38 | f'# batch_size : {self.args.batch_size}\n' 39 | f'# num_workers : {self.args.num_workers}\n' 40 | f'# ema_start : {self.args.ema_start}\n' 41 | f'# ema_beta : {self.args.ema_beta}\n' 42 | f'# iteration : {self.args.iteration}\n' 43 | f'# is decay : {self.args.no_decay_flag}\n' 44 | f'##### Data #####\n' 45 | f'# img_size : {self.args.img_size}\n' 46 | f'# aug_prob : {self.args.aug_prob}\n' 47 | f'# match_histograms : {self.args.match_histograms}\n' 48 | f'# match_mode : {self.args.match_mode}\n' 49 | f'# match_prob : {self.args.match_prob}\n' 50 | f'# match_ratio : {self.args.match_ratio}\n' 51 | f'##### Generator #####\n' 52 | f'# residual blocks : {self.args.n_res}\n' 53 | f'# use se or not : {self.args.use_se}\n' 54 | f'# use blur or not : {self.args.has_blur}\n' 55 | f'# tv_loss : {self.args.tv_loss}\n' 56 | f'# tv_weight : {self.args.tv_weight}\n' 57 | f'# use attention gan : {self.args.attention_gan}\n' 58 | f'# use attention input : {self.args.attention_input}\n' 59 | f'##### Discriminator #####\n' 60 | f'# global discriminator layer : {self.args.n_global_dis}\n' 61 | f'# local discriminator layer : {self.args.n_local_dis}\n' 62 | f'##### Weight #####\n' 63 | f'# adv_weight : {self.args.adv_weight}\n' 64 | f'# forward_adv_weight : {self.args.forward_adv_weight}\n' 65 | f'# backward_adv_weight : {self.args.backward_adv_weight}\n' 66 | f'# cycle_weight : {self.args.cycle_weight}\n' 67 | f'# identity_weight : {self.args.identity_weight}\n' 68 | f'# cam_weight : {self.args.cam_weight}\n' 69 | f'##### Enhanced #####\n' 70 | f'# cam_D_weight : {self.args.cam_D_weight}\n' 71 | f'# cam_D_attention : {self.args.cam_D_attention}\n' 72 | f'##### Segment #####\n' 73 | f'# hard_seg_edge : {self.args.hard_seg_edge}\n' 74 | f'# seg_fix_weight : {self.args.seg_fix_weight}\n' 75 | f'# seg_fix_glass_mouth : {self.args.seg_fix_glass_mouth}\n' 76 | f'# seg_D_mask : {self.args.seg_D_mask}\n' 77 | f'# seg_G_detach : {self.args.seg_G_detach}\n' 78 | f'# seg_D_cam_fea_mask : {self.args.seg_D_cam_fea_mask}\n' 79 | f'# seg_D_cam_inp_mask : 
{self.args.seg_D_cam_inp_mask}\n' 80 | f'# resume : {self.args.resume}\n\n' 81 | ) 82 | 83 | self.use_seg = ((self.args.seg_fix_weight > 0) or self.args.seg_D_mask or self.args.seg_G_detach or 84 | self.args.seg_D_cam_fea_mask or self.args.seg_D_cam_inp_mask) 85 | 86 | self.genA2B, self.genB2A = None, None 87 | self.genA2B_ema, self.genB2A_ema = None, None 88 | self.disGA, self.disGB, self.disLA, self.disLB = None, None, None, None 89 | self.FaceSeg = None 90 | self.maskA, self.maskA_erode, self.maskB, self.maskB_erode = None, None, None, None 91 | self.trainA_data_root = os.path.join('dataset', self.args.dataset, 'trainA') 92 | self.trainB_data_root = os.path.join('dataset', self.args.dataset, 'trainB') 93 | self.testA_data_root = os.path.join('dataset', self.args.dataset, 'testA') 94 | self.testB_data_root = os.path.join('dataset', self.args.dataset, 'testB') 95 | self.blurA_data_root = os.path.join('dataset', self.args.dataset, 'blurA') 96 | self.blurB_data_root = os.path.join('dataset', self.args.dataset, 'blurB') 97 | self.train_transform, self.test_transform = None, None 98 | self.trainAB_loader, self.blurAB_loader = None, None 99 | self.trainAB_iter, self.blurAB_iter = None, None 100 | self.testA_loader, self.testB_loader = None, None 101 | self.testA_iter, self.testB_iter = None, None 102 | self.L1_loss, self.MSE_loss, self.BCE_loss = None, None, None 103 | self.G_optim, self.D_optim = None, None 104 | self.Rho_LIN_clipper, self.Rho_AdaLIN_clipper = None, None 105 | self.G_adv_loss, self.G_cyc_loss, self.G_idt_loss, self.G_cam_loss = None, None, None, None 106 | self.Generator_loss, self.G_seg_loss, self.tv_loss = None, None, None 107 | self.discriminator_loss = None 108 | self.fid_score, self.mean_std_A, self.mean_std_B = None, None, None 109 | self.fid_loaderA, self.fid_loaderB = None, None 110 | 111 | ################################################################################## 112 | # Model 113 | ################################################################################## 114 | 115 | def build_data_loader(self): 116 | """ 构造data loader """ 117 | self.train_transform = transforms.Compose([ 118 | PIL.Image.fromarray, 119 | transforms.RandomHorizontalFlip(), 120 | transforms.RandomApply( 121 | [transforms.RandomResizedCrop(size=self.args.img_size, scale=(0.748, 1.0), ratio=(1.0, 1.0), 122 | interpolation=transforms.InterpolationMode.BICUBIC)], 123 | p=self.args.aug_prob), 124 | transforms.Resize(size=self.args.img_size, interpolation=transforms.InterpolationMode.BICUBIC), 125 | transforms.ToTensor(), 126 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 127 | ]) 128 | self.test_transform = transforms.Compose([ 129 | PIL.Image.fromarray, 130 | transforms.Resize((self.args.img_size, self.args.img_size), 131 | interpolation=transforms.InterpolationMode.BICUBIC), 132 | transforms.ToTensor(), 133 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 134 | ]) 135 | 136 | trainAB = MatchHistogramsDataset((self.trainA_data_root, self.trainB_data_root), 137 | self.train_transform, is_match_histograms=self.args.match_histograms, 138 | match_mode=self.args.match_mode, b2a_prob=self.args.match_prob, 139 | match_ratio=self.args.match_ratio) 140 | self.trainAB_loader = get_loader(trainAB, self.args.device, batch_size=self.args.batch_size, 141 | shuffle=True, num_workers=self.args.num_workers) 142 | testA = DatasetFolder(self.testA_data_root, self.test_transform) 143 | testB = DatasetFolder(self.testB_data_root, self.test_transform) 144 | self.testA_loader 
= get_loader(testA, self.args.device, batch_size=1, shuffle=False, 145 | num_workers=self.args.num_workers) 146 | self.testB_loader = get_loader(testB, self.args.device, batch_size=1, shuffle=False, 147 | num_workers=self.args.num_workers) 148 | 149 | # 使用模糊图像增强判别器D对模糊的判别,从而增强生成器G生成清晰图像 150 | if self.args.has_blur: 151 | if not os.path.exists(self.blurA_data_root): 152 | generate_blur_images(self.trainA_data_root, self.blurA_data_root) 153 | if not os.path.exists(self.blurB_data_root): 154 | generate_blur_images(self.trainB_data_root, self.blurB_data_root) 155 | 156 | blurAB = MatchHistogramsDataset((self.blurA_data_root, self.blurB_data_root), self.train_transform, 157 | is_match_histograms=self.args.match_histograms, 158 | match_mode=self.args.match_mode, b2a_prob=self.args.match_prob, 159 | match_ratio=self.args.match_ratio) 160 | self.blurAB_loader = get_loader(blurAB, self.args.device, batch_size=self.args.batch_size, 161 | shuffle=True, num_workers=self.args.num_workers) 162 | 163 | def build_model(self): 164 | """ 构造data loader,Generator,Discriminator 模型,损失,优化器 """ 165 | from networks import ResnetGenerator, Discriminator, RhoClipper, LIN, AdaLIN 166 | self.build_data_loader() 167 | # Define Generator, Discriminator 168 | self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.args.ch, n_blocks=self.args.n_res, 169 | img_size=self.args.img_size, args=self.args).to(self.args.device) 170 | self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.args.ch, n_blocks=self.args.n_res, 171 | img_size=self.args.img_size, args=self.args).to(self.args.device) 172 | self.genA2B_ema = copy.deepcopy(self.genA2B).eval().requires_grad_(False) 173 | self.genB2A_ema = copy.deepcopy(self.genB2A).eval().requires_grad_(False) 174 | self.disGA = Discriminator(input_nc=3, ndf=self.args.ch, n_layers=self.args.n_global_dis, with_sn=self.args.sn, 175 | use_cam_attention=self.args.cam_D_attention).to(self.args.device) 176 | self.disGB = Discriminator(input_nc=3, ndf=self.args.ch, n_layers=self.args.n_global_dis, with_sn=self.args.sn, 177 | use_cam_attention=self.args.cam_D_attention).to(self.args.device) 178 | self.disLA = Discriminator(input_nc=3, ndf=self.args.ch, n_layers=self.args.n_local_dis, with_sn=self.args.sn, 179 | use_cam_attention=self.args.cam_D_attention).to(self.args.device) 180 | self.disLB = Discriminator(input_nc=3, ndf=self.args.ch, n_layers=self.args.n_local_dis, with_sn=self.args.sn, 181 | use_cam_attention=self.args.cam_D_attention).to(self.args.device) 182 | 183 | # 使用分割区域做L2监督损失,或,分割出来的区域随机填充颜色的填充概率值 184 | if self.use_seg: 185 | self.FaceSeg = FaceSegmentation(self.args.device) 186 | 187 | # Define Loss 188 | self.L1_loss = nn.L1Loss().to(self.args.device) 189 | self.MSE_loss = nn.MSELoss().to(self.args.device) 190 | self.BCE_loss = nn.BCEWithLogitsLoss().to(self.args.device) 191 | 192 | # 优化器 193 | gen_params = itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()) 194 | self.G_optim = torch.optim.Adam(gen_params, lr=self.args.lr, betas=(0.5, 0.999), 195 | weight_decay=self.args.weight_decay) 196 | disc_params = itertools.chain(self.disGA.parameters(), self.disGB.parameters(), 197 | self.disLA.parameters(), self.disLB.parameters()) 198 | self.D_optim = torch.optim.Adam(disc_params, lr=self.args.lr, betas=(0.5, 0.999), 199 | weight_decay=self.args.weight_decay) 200 | 201 | # Define Rho clipper to constraint the value of rho in AdaLIN and LIN 202 | # self.Rho_clipper = RhoClipper(0, 1) 203 | self.Rho_LIN_clipper = RhoClipper(0, 1, LIN) 204 | 
self.Rho_AdaLIN_clipper = RhoClipper(0.0, 0.9, AdaLIN) 205 | 206 | ################################################################################## 207 | # Utility functions 208 | ################################################################################## 209 | 210 | def gen_train(self, on=True): 211 | """ Switch the generator networks between train and eval mode """ 212 | if on: 213 | self.genA2B.train(), self.genB2A.train() 214 | else: 215 | self.genA2B.eval(), self.genB2A.eval() 216 | 217 | def dis_train(self, on=True): 218 | """ Switch the discriminator networks between train and eval mode """ 219 | if on: 220 | self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train() 221 | else: 222 | self.disGA.eval(), self.disGB.eval(), self.disLA.eval(), self.disLB.eval() 223 | 224 | def get_batch(self, mode='train'): 225 | """ Fetch the next batch of training or test data """ 226 | if mode == 'train': 227 | try: 228 | real_A, real_B = next(self.trainAB_iter) 229 | except (StopIteration, TypeError): 230 | self.trainAB_iter = iter(self.trainAB_loader) 231 | real_A, real_B = next(self.trainAB_iter) 232 | else: 233 | try: 234 | real_A = next(self.testA_iter) 235 | except (StopIteration, TypeError): 236 | self.testA_iter = iter(self.testA_loader) 237 | real_A = next(self.testA_iter) 238 | 239 | try: 240 | real_B = next(self.testB_iter) 241 | except (StopIteration, TypeError): 242 | self.testB_iter = iter(self.testB_loader) 243 | real_B = next(self.testB_iter) 244 | 245 | real_A, real_B = real_A.to(self.args.device, non_blocking=True), real_B.to(self.args.device, non_blocking=True) 246 | 247 | blur = None 248 | if self.args.has_blur and mode == 'train': 249 | try: 250 | blur_A, blur_B = next(self.blurAB_iter) 251 | except (StopIteration, TypeError): 252 | self.blurAB_iter = iter(self.blurAB_loader) 253 | blur_A, blur_B = next(self.blurAB_iter) 254 | 255 | blur = (blur_A.to(self.args.device, non_blocking=True), blur_B.to(self.args.device, non_blocking=True)) 256 | 257 | return real_A, real_B, blur 258 | 259 | ################################################################################## 260 | # Training 261 | ################################################################################## 262 | 263 | def forward(self, real_A, real_B): 264 | """ Forward pass: A->B->A, B->A->B, A->A, B->B """ 265 | # cycle 266 | fake_A2B, fake_A2B_cam_logit, fake_A2B_heatmap, fake_A2B_attention = self.genA2B(real_A) 267 | fake_A2B2A, _, fake_A2B2A_heatmap, fake_A2B2A_attention = self.genB2A(fake_A2B) 268 | fake_B2A, fake_B2A_cam_logit, fake_B2A_heatmap, fake_B2A_attention = self.genB2A(real_B) 269 | fake_B2A2B, _, fake_B2A2B_heatmap, fake_B2A2B_attention = self.genA2B(fake_B2A) 270 | # identity mapping 271 | fake_A2A, fake_A2A_cam_logit, fake_A2A_heatmap, fake_A2A_attention = self.genB2A(real_A) 272 | fake_B2B, fake_B2B_cam_logit, fake_B2B_heatmap, fake_B2B_attention = self.genA2B(real_B) 273 | 274 | # use face parsing to obtain the segmented face region self.maskA (==1) 275 | if self.use_seg: 276 | maskA = self.FaceSeg.face_segmentation(real_A) 277 | self.maskA = self.FaceSeg.gen_mask(maskA, is_soft_edge=not self.args.hard_seg_edge, 278 | normal_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17), 279 | dilate_parts=(), erode_parts=()) 280 | self.maskA_erode = self.FaceSeg.gen_mask(maskA, normal_parts=(), dilate_parts=(), 281 | is_soft_edge=not self.args.hard_seg_edge, 282 | erode_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17)) 283 | if self.args.seg_fix_glass_mouth: 284 | maskA_erode = self.FaceSeg.gen_mask(maskA, normal_parts=(1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17), 285 | dilate_parts=(), is_soft_edge=not self.args.hard_seg_edge, 286 | erode_parts=(6, )) 287 | self.maskA_erode *= 
maskA_erode 288 | maskB = self.FaceSeg.face_segmentation(real_B) 289 | self.maskB = self.FaceSeg.gen_mask(maskB, is_soft_edge=not self.args.hard_seg_edge, 290 | normal_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17), 291 | dilate_parts=(), erode_parts=()) 292 | self.maskB_erode = self.FaceSeg.gen_mask(maskB, normal_parts=(), dilate_parts=(), 293 | is_soft_edge=not self.args.hard_seg_edge, 294 | erode_parts=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17)) 295 | if self.args.seg_fix_glass_mouth: 296 | maskB_erode = self.FaceSeg.gen_mask(maskB, normal_parts=(1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 17), 297 | dilate_parts=(), is_soft_edge=not self.args.hard_seg_edge, 298 | erode_parts=(6, )) 299 | self.maskB_erode *= maskB_erode 300 | 301 | return (fake_A2B, fake_A2B_cam_logit, fake_A2B_heatmap, fake_A2B_attention, 302 | fake_A2B2A, fake_A2B2A_heatmap, fake_A2B2A_attention, 303 | fake_B2A, fake_B2A_cam_logit, fake_B2A_heatmap, fake_B2A_attention, 304 | fake_B2A2B, fake_B2A2B_heatmap, fake_B2A2B_attention, 305 | fake_A2A, fake_A2A_cam_logit, fake_A2A_heatmap, fake_A2A_attention, 306 | fake_B2B, fake_B2B_cam_logit, fake_B2B_heatmap, fake_B2B_attention) 307 | 308 | def backward_D(self, real_A, real_B, fake_A2B, fake_B2A, blur=None): # noqa 309 | """ D网络前向+反向计算 """ 310 | fake_A2B, fake_B2A = fake_A2B.detach(), fake_B2A.detach() 311 | # 将生成图像的分割区域(膨胀)随机填充颜色,用于后续D训练的cam 312 | cam_real_A, cam_fake_A2B, cam_real_B, cam_fake_B2A = None, None, None, None 313 | if self.args.seg_D_cam_inp_mask: 314 | cam_real_A = self.maskA * real_A 315 | cam_fake_A2B = self.maskA * fake_A2B 316 | cam_real_B = self.maskB * real_B 317 | cam_fake_B2A = self.maskB * fake_B2A 318 | 319 | maskA, maskB = None, None 320 | if self.args.seg_D_cam_fea_mask: 321 | maskA, maskB = self.maskA, self.maskB 322 | 323 | real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A, cam_real_A, maskA) 324 | real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A, cam_real_A, maskA) 325 | fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A, cam_fake_B2A, maskB) 326 | fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A, cam_fake_B2A, maskB) 327 | 328 | real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B, cam_real_B, maskB) 329 | real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B, cam_real_B, maskB) 330 | fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B, cam_fake_A2B, maskA) 331 | fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B, cam_fake_A2B, maskA) 332 | 333 | # 设置目标常量,GA和LA的D网络输出shape不一致,但cam的分类输出是一致的 334 | flag_GA_1 = torch.ones_like(real_GA_logit, requires_grad=False).to(self.args.device) 335 | flag_GA_0 = torch.zeros_like(fake_GA_logit, requires_grad=False).to(self.args.device) 336 | flag_LA_1 = torch.ones_like(real_LA_logit, requires_grad=False).to(self.args.device) 337 | flag_LA_0 = torch.zeros_like(fake_LA_logit, requires_grad=False).to(self.args.device) 338 | flag_cam_1 = torch.ones_like(real_GA_cam_logit, requires_grad=False).to(self.args.device) 339 | flag_cam_0 = torch.zeros_like(fake_GA_cam_logit, requires_grad=False).to(self.args.device) 340 | 341 | # D网络损失函数:cam和D网络损失 342 | D_loss_GA, D_cam_loss_GA = 0, 0 343 | D_loss_LA, D_cam_loss_LA = 0, 0 344 | D_loss_GB, D_cam_loss_GB = 0, 0 345 | D_loss_LB, D_cam_loss_LB = 0, 0 346 | if blur is not None: 347 | blur_A, blur_B = blur 348 | cam_blur_A, cam_blur_B = None, None 349 | if self.args.seg_D_cam_inp_mask: 350 | cam_blur_A = self.maskA * blur_A 351 | cam_blur_B = self.maskB * blur_B 352 | blur_GA_logit, blur_GA_cam_logit, _ = self.disGA(blur_A, 
cam_blur_A, maskA) 353 | blur_LA_logit, blur_LA_cam_logit, _ = self.disLA(blur_A, cam_blur_A, maskA) 354 | blur_GB_logit, blur_GB_cam_logit, _ = self.disGB(blur_B, cam_blur_B, maskB) 355 | blur_LB_logit, blur_LB_cam_logit, _ = self.disLB(blur_B, cam_blur_B, maskB) 356 | # 无论使用什么策略,模糊图像都是false 357 | D_loss_GA = self.MSE_loss(blur_GA_logit, flag_GA_0) 358 | D_loss_LA = self.MSE_loss(blur_LA_logit, flag_LA_0) 359 | D_loss_GB = self.MSE_loss(blur_GB_logit, flag_GA_0) 360 | D_loss_LB = self.MSE_loss(blur_LB_logit, flag_LA_0) 361 | if self.args.cam_D_weight > 0: 362 | D_cam_loss_GA = self.MSE_loss(blur_GA_cam_logit, flag_cam_0) 363 | D_cam_loss_LA = self.MSE_loss(blur_LA_cam_logit, flag_cam_0) 364 | D_cam_loss_GB = self.MSE_loss(blur_GB_cam_logit, flag_cam_0) 365 | D_cam_loss_LB = self.MSE_loss(blur_LB_cam_logit, flag_cam_0) 366 | 367 | # 只计算分割区域的损失 368 | if self.args.seg_D_mask: 369 | maskA_G = F.interpolate(self.maskA, fake_GA_logit.shape[2:], mode='area') 370 | maskA_L = F.interpolate(self.maskA, fake_LA_logit.shape[2:], mode='area') 371 | maskB_G = F.interpolate(self.maskB, fake_GB_logit.shape[2:], mode='area') 372 | maskB_L = F.interpolate(self.maskB, fake_LB_logit.shape[2:], mode='area') 373 | 374 | real_GA_logit = real_GA_logit * maskA_G + flag_GA_1 * (1 - maskA_G) 375 | fake_GA_logit = fake_GA_logit * maskB_G 376 | real_LA_logit = real_LA_logit * maskA_L + flag_LA_1 * (1 - maskA_L) 377 | fake_LA_logit = fake_LA_logit * maskB_L 378 | real_GB_logit = real_GB_logit * maskB_G + flag_GA_1 * (1 - maskB_G) 379 | fake_GB_logit = fake_GB_logit * maskA_G 380 | real_LB_logit = real_LB_logit * maskB_L + flag_LA_1 * (1 - maskB_L) 381 | fake_LB_logit = fake_LB_logit * maskA_L 382 | 383 | if self.args.cam_D_weight > 0: 384 | D_cam_loss_GA += (self.MSE_loss(real_GA_cam_logit, flag_cam_1) + self.MSE_loss(fake_GA_cam_logit, flag_cam_0)) # noqa, E501 385 | D_cam_loss_LA += (self.MSE_loss(real_LA_cam_logit, flag_cam_1) + self.MSE_loss(fake_LA_cam_logit, flag_cam_0)) # noqa, E501 386 | D_cam_loss_GB += (self.MSE_loss(real_GB_cam_logit, flag_cam_1) + self.MSE_loss(fake_GB_cam_logit, flag_cam_0)) # noqa, E501 387 | D_cam_loss_LB += (self.MSE_loss(real_LB_cam_logit, flag_cam_1) + self.MSE_loss(fake_LB_cam_logit, flag_cam_0)) # noqa, E501 388 | D_loss_GA += (self.MSE_loss(real_GA_logit, flag_GA_1) + self.MSE_loss(fake_GA_logit, flag_GA_0)) 389 | D_loss_LA += (self.MSE_loss(real_LA_logit, flag_LA_1) + self.MSE_loss(fake_LA_logit, flag_LA_0)) 390 | D_loss_A = D_loss_GA + D_loss_LA + (D_cam_loss_GA + D_cam_loss_LA) * self.args.cam_D_weight 391 | D_loss_A = self.args.forward_adv_weight * self.args.adv_weight * D_loss_A 392 | 393 | D_loss_GB += (self.MSE_loss(real_GB_logit, flag_GA_1) + self.MSE_loss(fake_GB_logit, flag_GA_0)) 394 | D_loss_LB += (self.MSE_loss(real_LB_logit, flag_LA_1) + self.MSE_loss(fake_LB_logit, flag_LA_0)) 395 | D_loss_B = D_loss_GB + D_loss_LB + (D_cam_loss_GB + D_cam_loss_LB) * self.args.cam_D_weight 396 | D_loss_B = self.args.backward_adv_weight * self.args.adv_weight * D_loss_B 397 | 398 | self.discriminator_loss = D_loss_A + D_loss_B 399 | # backward 400 | self.discriminator_loss.backward() 401 | return D_loss_A, D_loss_B 402 | 403 | def backward_G(self, real_A, real_B, fake_A2B, fake_B2A, fake_A2B2A, fake_B2A2B, fake_A2A, fake_B2B, # noqa 404 | fake_A2B_cam_logit, fake_B2A_cam_logit, fake_A2A_cam_logit, fake_B2B_cam_logit): 405 | self.Generator_loss: Union[int, torch.tensor] = 0 406 | # 根据人脸分割,获取背景不变性损失项 407 | if self.args.seg_fix_weight > 0: 408 | G_seg_loss_B = self.L1_loss(fake_A2B * (1 
- self.maskA_erode), real_A * (1 - self.maskA_erode)) 409 | G_seg_loss_A = self.L1_loss(fake_B2A * (1 - self.maskB_erode), real_B * (1 - self.maskB_erode)) 410 | self.G_seg_loss = self.args.seg_fix_weight * (G_seg_loss_A + G_seg_loss_B) 411 | self.Generator_loss += self.G_seg_loss 412 | 413 | if self.args.tv_loss: 414 | self.tv_loss = calc_tv_loss(fake_A2B, mask=self.maskA) + calc_tv_loss(fake_B2A, mask=self.maskB) 415 | self.tv_loss *= self.args.tv_weight 416 | self.Generator_loss += self.tv_loss 417 | 418 | # 将生成图像的背景detach掉,使背景上的对抗损失梯度不影响G 419 | if self.args.seg_G_detach: 420 | fake_A2B = fake_A2B * self.maskA + fake_A2B.detach() * (1.0 - self.maskA) 421 | fake_B2A = fake_B2A * self.maskB + fake_B2A.detach() * (1.0 - self.maskB) 422 | 423 | cam_fake_A2B, cam_fake_B2A = None, None 424 | if self.args.seg_D_cam_inp_mask: 425 | cam_fake_A2B = self.maskA * fake_A2B 426 | cam_fake_B2A = self.maskB * fake_B2A 427 | 428 | maskA, maskB = None, None 429 | if self.args.seg_D_cam_fea_mask: 430 | maskA, maskB = self.maskA, self.maskB 431 | 432 | # 判别器输出 433 | fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A, cam_fake_B2A, maskB) 434 | fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A, cam_fake_B2A, maskB) 435 | fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B, cam_fake_A2B, maskA) 436 | fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B, cam_fake_A2B, maskA) 437 | # 设置目标常量,GA和LA的D网络输出shape不一致,但cam的分类输出是一致的 438 | flag_GA_1 = torch.ones_like(fake_GA_logit, requires_grad=False).to(self.args.device) 439 | flag_LA_1 = torch.ones_like(fake_LA_logit, requires_grad=False).to(self.args.device) 440 | flag_GA_cam_1 = torch.ones_like(fake_GA_cam_logit, requires_grad=False).to(self.args.device) 441 | 442 | # 对抗损失 443 | if self.args.seg_D_mask: 444 | # 背景区域不需要对抗损失,置为目标值,等价于把背景区域的对抗损失置0 445 | maskB_G = F.interpolate(self.maskB, fake_GA_logit.shape[2:], mode='area') 446 | maskB_L = F.interpolate(self.maskB, fake_LA_logit.shape[2:], mode='area') 447 | maskA_G = F.interpolate(self.maskA, fake_GB_logit.shape[2:], mode='area') 448 | maskA_L = F.interpolate(self.maskA, fake_LB_logit.shape[2:], mode='area') 449 | fake_GA_logit = fake_GA_logit * maskB_G + flag_GA_1 * (1 - maskB_G) 450 | fake_LA_logit = fake_LA_logit * maskB_L + flag_LA_1 * (1 - maskB_L) 451 | fake_GB_logit = fake_GB_logit * maskA_G + flag_GA_1 * (1 - maskA_G) 452 | fake_LB_logit = fake_LB_logit * maskA_L + flag_LA_1 * (1 - maskA_L) 453 | 454 | if self.args.cam_D_weight > 0: 455 | G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, flag_GA_cam_1) 456 | G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, flag_GA_cam_1) 457 | G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, flag_GA_cam_1) 458 | G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, flag_GA_cam_1) 459 | else: 460 | G_ad_cam_loss_GA, G_ad_cam_loss_LA, G_ad_cam_loss_GB, G_ad_cam_loss_LB = 0, 0, 0, 0 461 | G_ad_loss_GA = self.MSE_loss(fake_GA_logit, flag_GA_1) 462 | G_ad_loss_LA = self.MSE_loss(fake_LA_logit, flag_LA_1) 463 | G_ad_loss_A = (G_ad_loss_GA + G_ad_loss_LA) + (G_ad_cam_loss_GA + G_ad_cam_loss_LA) * self.args.cam_D_weight 464 | G_ad_loss_A = self.args.adv_weight * self.args.forward_adv_weight * G_ad_loss_A 465 | 466 | G_ad_loss_GB = self.MSE_loss(fake_GB_logit, flag_GA_1) 467 | G_ad_loss_LB = self.MSE_loss(fake_LB_logit, flag_LA_1) 468 | G_ad_loss_B = (G_ad_loss_GB + G_ad_loss_LB) + (G_ad_cam_loss_GB + G_ad_cam_loss_LB) * self.args.cam_D_weight 469 | G_ad_loss_B = self.args.adv_weight * self.args.backward_adv_weight * G_ad_loss_B 470 | # 循环一致性损失 471 | 
G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A) * self.args.cycle_weight 472 | G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B) * self.args.cycle_weight 473 | # 单位映射损失 474 | G_identity_loss_A = self.L1_loss(fake_A2A, real_A) * self.args.identity_weight 475 | G_identity_loss_B = self.L1_loss(fake_B2B, real_B) * self.args.identity_weight 476 | # G的cam损失 477 | flag_cam_1 = torch.ones_like(fake_B2A_cam_logit, requires_grad=False).to(self.args.device) 478 | flag_cam_0 = torch.zeros_like(fake_A2A_cam_logit, requires_grad=False).to(self.args.device) 479 | G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, flag_cam_1) + self.BCE_loss(fake_A2A_cam_logit, flag_cam_0) 480 | G_cam_loss_A *= self.args.cam_weight 481 | flag_cam_1 = torch.ones_like(fake_A2B_cam_logit, requires_grad=False).to(self.args.device) 482 | flag_cam_0 = torch.zeros_like(fake_B2B_cam_logit, requires_grad=False).to(self.args.device) 483 | G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, flag_cam_1) + self.BCE_loss(fake_B2B_cam_logit, flag_cam_0) 484 | G_cam_loss_B *= self.args.cam_weight 485 | 486 | G_loss_A = G_ad_loss_A + G_recon_loss_A + G_identity_loss_A + G_cam_loss_A 487 | G_loss_B = G_ad_loss_B + G_recon_loss_B + G_identity_loss_B + G_cam_loss_B 488 | 489 | self.G_adv_loss = G_ad_loss_A + G_ad_loss_B 490 | self.G_cyc_loss = G_recon_loss_A + G_recon_loss_B 491 | self.G_idt_loss = G_identity_loss_A + G_identity_loss_B 492 | self.G_cam_loss = G_cam_loss_A + G_cam_loss_B 493 | 494 | self.Generator_loss += (G_loss_A + G_loss_B) 495 | 496 | # backward 497 | self.Generator_loss.backward() 498 | return (G_ad_loss_A, G_recon_loss_A, G_identity_loss_A, G_cam_loss_A, 499 | G_ad_loss_B, G_recon_loss_B, G_identity_loss_B, G_cam_loss_B) 500 | 501 | def train(self): 502 | train_writer = SummaryWriter(os.path.join(self.args.result_dir, 'logs')) 503 | D_losses_A = AverageMeter('D_losses_A', ':.4e') 504 | D_losses_B = AverageMeter('D_losses_B', ':.4e') 505 | Discriminator_losses = AverageMeter('Discriminator_losses', ':.4e') 506 | G_ad_losses_A = AverageMeter('G_ad_losses_A', ':.4e') 507 | G_recon_losses_A = AverageMeter('G_recon_losses_A', ':.4e') 508 | G_identity_losses_A = AverageMeter('G_identity_losses_A', ':.4e') 509 | G_cam_losses_A = AverageMeter('G_cam_losses_A', ':.4e') 510 | G_ad_losses_B = AverageMeter('G_ad_losses_B', ':.4e') 511 | G_recon_losses_B = AverageMeter('G_recon_losses_B', ':.4e') 512 | G_identity_losses_B = AverageMeter('G_identity_losses_B', ':.4e') 513 | G_cam_losses_B = AverageMeter('G_cam_losses_B', ':.4e') 514 | Generator_losses = AverageMeter('Generator_losses', ':.4e') 515 | train_progress = ProgressMeter(self.args.iteration, D_losses_A, D_losses_B, Discriminator_losses, 516 | G_ad_losses_A, G_recon_losses_A, G_identity_losses_A, G_cam_losses_A, 517 | G_ad_losses_B, G_recon_losses_B, G_identity_losses_B, G_cam_losses_B, 518 | Generator_losses, prefix=f"Iteration: ") 519 | 520 | # 用于学习率 decay策略 521 | start_iter = 1 522 | mid_iter = self.args.iteration // 2 523 | lr_rate = self.args.lr / mid_iter 524 | if self.args.resume: 525 | model_list = glob(os.path.join(self.args.result_dir, self.args.dataset, 'model', '*.pt')) 526 | if not len(model_list) == 0: 527 | model_list.sort() 528 | start_iter = int(model_list[-1].split('_')[-1].split('.')[0]) 529 | self.load(os.path.join(self.args.result_dir, self.args.dataset, 'model'), start_iter) 530 | print(" [*] Load SUCCESS") 531 | if not self.args.no_decay_flag and start_iter > mid_iter: 532 | self.G_optim.param_groups[0]['lr'] -= lr_rate * (start_iter - mid_iter) 533 | 
self.D_optim.param_groups[0]['lr'] -= lr_rate * (start_iter - mid_iter) 534 | 535 | # training loop 536 | print('training start!') 537 | start_time = time.time() 538 | for step in range(start_iter, self.args.iteration + 1): 539 | if not self.args.no_decay_flag and step > mid_iter: 540 | self.G_optim.param_groups[0]['lr'] -= lr_rate 541 | self.D_optim.param_groups[0]['lr'] -= lr_rate 542 | 543 | real_A, real_B, blur = self.get_batch(mode='train') 544 | 545 | self.gen_train(True) 546 | 547 | (fake_A2B, fake_A2B_cam_logit, _, _, 548 | fake_A2B2A, _, _, 549 | fake_B2A, fake_B2A_cam_logit, _, _, 550 | fake_B2A2B, _, _, 551 | fake_A2A, fake_A2A_cam_logit, _, _, 552 | fake_B2B, fake_B2B_cam_logit, _, _) = self.forward(real_A, real_B) 553 | 554 | # Update D 555 | self.dis_train(True) 556 | self.D_optim.zero_grad() 557 | D_loss_A, D_loss_B = self.backward_D(real_A, real_B, fake_A2B, fake_B2A, blur) 558 | self.D_optim.step() 559 | # update running loss statistics 560 | D_losses_A.update(D_loss_A.detach().cpu().item()) 561 | D_losses_B.update(D_loss_B.detach().cpu().item()) 562 | Discriminator_losses.update(self.discriminator_loss.detach().cpu().item()) 563 | 564 | # Update G 565 | self.dis_train(False) 566 | self.G_optim.zero_grad() 567 | (G_ad_loss_A, G_recon_loss_A, G_identity_loss_A, G_cam_loss_A, 568 | G_ad_loss_B, G_recon_loss_B, G_identity_loss_B, G_cam_loss_B) = \ 569 | self.backward_G(real_A, real_B, fake_A2B, fake_B2A, fake_A2B2A, fake_B2A2B, fake_A2A, fake_B2B, 570 | fake_A2B_cam_logit, fake_B2A_cam_logit, fake_A2A_cam_logit, fake_B2B_cam_logit) 571 | self.G_optim.step() 572 | self.gen_train(False) 573 | 574 | # clip the rho parameters of AdaLIN and LIN, applied after the optimizer step 575 | self.genA2B.apply(self.Rho_LIN_clipper) 576 | self.genB2A.apply(self.Rho_LIN_clipper) 577 | self.genA2B.apply(self.Rho_AdaLIN_clipper) 578 | self.genB2A.apply(self.Rho_AdaLIN_clipper) 579 | self.model_ema(step, self.genA2B_ema, self.genA2B) 580 | self.model_ema(step, self.genB2A_ema, self.genB2A) 581 | 582 | # print the losses of each step 583 | info = f'[{step:5d}/{self.args.iteration:5d}] time: {(time.time() - start_time):4.4f} ' \ 584 | f'd_loss: {self.discriminator_loss:.8f}, g_loss: {self.Generator_loss:.8f}, ' \ 585 | f'g_adv: {self.G_adv_loss:.8f}, g_cyc: {self.G_cyc_loss:.8f}, ' \ 586 | f'g_idt: {self.G_idt_loss:.8f}, g_cam: {self.G_cam_loss:.8f}' 587 | if self.args.seg_fix_weight > 0: 588 | info += f', g_seg: {self.G_seg_loss:.8f}' 589 | if self.args.tv_loss: 590 | info += f', g_tv: {self.tv_loss:.8f}' 591 | print(info) 592 | 593 | # update running loss statistics 594 | G_ad_losses_A.update(G_ad_loss_A.detach().cpu().item(), real_A.size(0)) 595 | G_recon_losses_A.update(G_recon_loss_A.detach().cpu().item(), real_A.size(0)) 596 | G_identity_losses_A.update(G_identity_loss_A.detach().cpu().item(), real_A.size(0)) 597 | G_cam_losses_A.update(G_cam_loss_A.detach().cpu().item(), real_A.size(0)) 598 | G_ad_losses_B.update(G_ad_loss_B.detach().cpu().item(), real_B.size(0)) 599 | G_recon_losses_B.update(G_recon_loss_B.detach().cpu().item(), real_B.size(0)) 600 | G_identity_losses_B.update(G_identity_loss_B.detach().cpu().item(), real_B.size(0)) 601 | G_cam_losses_B.update(G_cam_loss_B.detach().cpu().item(), real_B.size(0)) 602 | Generator_losses.update(self.Generator_loss.detach().cpu().item(), real_A.size(0)) 603 | 604 | # visualize intermediate results, compute FID, and log to tensorboard 605 | if step % self.args.print_freq == 0: 606 | # visualize intermediate results 607 | self.vis_inference_result(step, train_sample_num=5, test_sample_num=5) 608 | if step > self.args.ema_start * self.args.iteration: 609 | temp = self.genA2B, 
self.genB2A 610 | self.genA2B, self.genB2A = self.genA2B_ema, self.genB2A_ema 611 | self.vis_inference_result(step, train_sample_num=5, test_sample_num=5, name='_ema') 612 | self.genA2B, self.genB2A = temp 613 | # 计算fid 614 | if step % self.args.calc_fid_freq == 0: 615 | temp_prefix = train_progress.prefix 616 | fid_score_A2B, fid_score_B2A = self.calc_fid_score() 617 | train_writer.add_scalar('13_fid_score_A2B', fid_score_A2B, step) 618 | train_writer.add_scalar('13_fid_score_B2A', fid_score_B2A, step) 619 | train_progress.prefix = f"Iteration: fid: A2B {fid_score_A2B:.4e}, B2A {fid_score_B2A:.4e}" 620 | if step > self.args.ema_start * self.args.iteration: 621 | temp = self.genA2B, self.genB2A 622 | self.genA2B, self.genB2A = self.genA2B_ema, self.genB2A_ema 623 | fid_score_A2B, fid_score_B2A = self.calc_fid_score() 624 | self.genA2B, self.genB2A = temp 625 | train_writer.add_scalar('14_fid_score_A2B_ema', fid_score_A2B, step) 626 | train_writer.add_scalar('14_fid_score_B2A_ema', fid_score_B2A, step) 627 | train_progress.prefix += f" A2B_ema {fid_score_A2B:.4e}, B2A_ema {fid_score_B2A:.4e}" 628 | train_progress.print(step) 629 | train_progress.prefix = temp_prefix 630 | else: 631 | train_progress.print(step) 632 | 633 | # 打印统计量 634 | train_writer.add_scalar('01_D_losses_A', D_losses_A.avg, step) 635 | train_writer.add_scalar('02_D_losses_B', D_losses_B.avg, step) 636 | train_writer.add_scalar('03_Discriminator_losses', Discriminator_losses.avg, step) 637 | train_writer.add_scalar('04_G_ad_losses_A', G_ad_losses_A.avg, step) 638 | train_writer.add_scalar('05_G_recon_losses_A', G_recon_losses_A.avg, step) 639 | train_writer.add_scalar('06_G_identity_losses_A', G_identity_losses_A.avg, step) 640 | train_writer.add_scalar('07_G_cam_losses_A', G_cam_losses_A.avg, step) 641 | train_writer.add_scalar('08_G_ad_losses_B', G_ad_losses_B.avg, step) 642 | train_writer.add_scalar('09_G_recon_losses_B', G_recon_losses_B.avg, step) 643 | train_writer.add_scalar('10_G_identity_losses_B', G_identity_losses_B.avg, step) 644 | train_writer.add_scalar('11_G_cam_losses_B', G_cam_losses_B.avg, step) 645 | train_writer.add_scalar('12_Generator_losses', Generator_losses.avg, step) 646 | train_writer.add_scalar('Learning rate', self.G_optim.param_groups[0]['lr'], step) 647 | train_writer.flush() 648 | D_losses_A.reset(), D_losses_B.reset(), Discriminator_losses.reset() 649 | G_ad_losses_A.reset(), G_recon_losses_A.reset(), G_identity_losses_A.reset() 650 | G_cam_losses_A.reset(), G_ad_losses_B.reset(), G_recon_losses_B.reset() 651 | G_identity_losses_B.reset(), G_cam_losses_B.reset(), Generator_losses.reset() 652 | 653 | if step % self.args.save_freq == 0 or step == self.args.iteration: 654 | self.save(os.path.join(self.args.result_dir, self.args.dataset, 'model'), step) 655 | 656 | # if step % 1000 == 0: 657 | # self.save(self.args.result_dir, step=None, name='_params_latest.pt') 658 | train_writer.close() 659 | 660 | def calc_fid_score(self): 661 | self.gen_train(False) 662 | if self.fid_score is None: 663 | self.fid_score = FIDScore(self.args.device, batch_size=self.args.fid_batch, num_workers=1) 664 | self.mean_std_A = self.fid_score.calc_mean_std(self.trainA_data_root) 665 | self.mean_std_B = self.fid_score.calc_mean_std(self.trainB_data_root) 666 | self.fid_loaderA = get_loader(DatasetFolder(self.trainA_data_root, self.test_transform), self.args.device, 667 | batch_size=self.args.fid_batch, shuffle=False, 668 | num_workers=self.args.num_workers) 669 | self.fid_loaderB = 
get_loader(DatasetFolder(self.trainB_data_root, self.test_transform), self.args.device, 670 | batch_size=self.args.fid_batch, shuffle=False, 671 | num_workers=self.args.num_workers) 672 | self.fid_score.inception_model.normalize_input = False 673 | mean_std_A2B = self.fid_score.calc_mean_std_with_gen(lambda batch: self.genA2B(batch)[0].detach(), 674 | self.fid_loaderA) 675 | fid_score_A2B = self.fid_score.calc_fid(self.mean_std_B, mean_std_A2B) 676 | mean_std_B2A = self.fid_score.calc_mean_std_with_gen(lambda batch: self.genB2A(batch)[0].detach(), 677 | self.fid_loaderB) 678 | fid_score_B2A = self.fid_score.calc_fid(self.mean_std_A, mean_std_B2A) 679 | return fid_score_A2B, fid_score_B2A 680 | 681 | def vis_inference_result(self, step, train_sample_num=5, test_sample_num=5, name=''): 682 | A2B = np.zeros((self.args.img_size * (9 + self.args.attention_gan), 0, 3)) 683 | B2A = np.zeros((self.args.img_size * (9 + self.args.attention_gan), 0, 3)) 684 | self.gen_train(False), self.dis_train(False) 685 | for _ in range(train_sample_num): 686 | real_A, real_B, _ = self.get_batch(mode='train') 687 | with torch.no_grad(): 688 | (fake_A2B, _, fake_A2B_heatmap, fake_A2B_attention, 689 | fake_A2B2A, fake_A2B2A_heatmap, fake_A2B2A_attention, 690 | fake_B2A, _, fake_B2A_heatmap, fake_B2A_attention, 691 | fake_B2A2B, fake_B2A2B_heatmap, fake_B2A2B_attention, 692 | fake_A2A, _, fake_A2A_heatmap, fake_A2A_attention, 693 | fake_B2B, _, fake_B2B_heatmap, fake_B2B_attention) = \ 694 | self.forward(real_A, real_B) 695 | 696 | cam_fake_A2B, cam_fake_B2A = None, None 697 | if self.args.seg_D_cam_inp_mask: 698 | cam_fake_A2B = self.maskA * fake_A2B 699 | cam_fake_B2A = self.maskB * fake_B2A 700 | 701 | maskA, maskB = None, None 702 | if self.args.seg_D_cam_fea_mask: 703 | maskA = self.maskA 704 | maskB = self.maskB 705 | 706 | _, _, fake_disGB_cam_hm = self.disGB(fake_A2B, cam_fake_A2B, maskA) 707 | _, _, fake_disLB_cam_hm = self.disLB(fake_A2B, cam_fake_A2B, maskA) 708 | _, _, fake_disGA_cam_hm = self.disGA(fake_B2A, cam_fake_B2A, maskB) 709 | _, _, fake_disLA_cam_hm = self.disLA(fake_B2A, cam_fake_B2A, maskB) 710 | A2B_list = [RGB2BGR(tensor2numpy(denorm(real_A[0]))), 711 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.args.img_size), 712 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), 713 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.args.img_size), 714 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), 715 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.args.img_size), 716 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0]))), 717 | cam(tensor2numpy(fake_disGB_cam_hm[0]), self.args.img_size), 718 | cam(tensor2numpy(fake_disLB_cam_hm[0]), self.args.img_size), 719 | ] 720 | if self.args.attention_gan > 0: 721 | for i in range(self.args.attention_gan): 722 | A2B_list.append(attention_mask(tensor2numpy(fake_A2B_attention[0][i:(i + 1)]), 723 | self.args.img_size)) 724 | A2B = np.concatenate((A2B, np.concatenate(A2B_list, 0)), 1) 725 | 726 | B2A_list = [RGB2BGR(tensor2numpy(denorm(real_B[0]))), 727 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.args.img_size), 728 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), 729 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.args.img_size), 730 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), 731 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.args.img_size), 732 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0]))), 733 | cam(tensor2numpy(fake_disGA_cam_hm[0]), self.args.img_size), 734 | cam(tensor2numpy(fake_disLA_cam_hm[0]), self.args.img_size), 735 | ] 736 | if self.args.attention_gan > 0: 737 | for i in 
range(self.args.attention_gan): 738 | B2A_list.append(attention_mask(tensor2numpy(fake_B2A_attention[0][i:(i + 1)]), 739 | self.args.img_size)) 740 | B2A = np.concatenate((B2A, np.concatenate(B2A_list, 0)), 1) 741 | 742 | for _ in range(test_sample_num): 743 | real_A, real_B, _ = self.get_batch(mode='test') 744 | with torch.no_grad(): 745 | (fake_A2B, _, fake_A2B_heatmap, fake_A2B_attention, 746 | fake_A2B2A, fake_A2B2A_heatmap, fake_A2B2A_attention, 747 | fake_B2A, _, fake_B2A_heatmap, fake_B2A_attention, 748 | fake_B2A2B, fake_B2A2B_heatmap, fake_B2A2B_attention, 749 | fake_A2A, _, fake_A2A_heatmap, fake_A2A_attention, 750 | fake_B2B, _, fake_B2B_heatmap, fake_B2B_attention) = \ 751 | self.forward(real_A, real_B) 752 | 753 | cam_fake_A2B, cam_fake_B2A = None, None 754 | if self.args.seg_D_cam_inp_mask: 755 | cam_fake_A2B = self.maskA * fake_A2B 756 | cam_fake_B2A = self.maskB * fake_B2A 757 | 758 | maskA, maskB = None, None 759 | if self.args.seg_D_cam_fea_mask: 760 | maskA = self.maskA 761 | maskB = self.maskB 762 | 763 | _, _, fake_disGB_cam_hm = self.disGB(fake_A2B, cam_fake_A2B, maskA) 764 | _, _, fake_disLB_cam_hm = self.disLB(fake_A2B, cam_fake_A2B, maskA) 765 | _, _, fake_disGA_cam_hm = self.disGA(fake_B2A, cam_fake_B2A, maskB) 766 | _, _, fake_disLA_cam_hm = self.disLA(fake_B2A, cam_fake_B2A, maskB) 767 | A2B_list = [RGB2BGR(tensor2numpy(denorm(real_A[0]))), 768 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.args.img_size), 769 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), 770 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.args.img_size), 771 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), 772 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.args.img_size), 773 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0]))), 774 | cam(tensor2numpy(fake_disGB_cam_hm[0]), self.args.img_size), 775 | cam(tensor2numpy(fake_disLB_cam_hm[0]), self.args.img_size), 776 | ] 777 | if self.args.attention_gan > 0: 778 | for i in range(self.args.attention_gan): 779 | A2B_list.append(attention_mask(tensor2numpy(fake_A2B_attention[0][i:(i + 1)]), 780 | self.args.img_size)) 781 | A2B = np.concatenate((A2B, np.concatenate(A2B_list, 0)), 1) 782 | 783 | B2A_list = [RGB2BGR(tensor2numpy(denorm(real_B[0]))), 784 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.args.img_size), 785 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), 786 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.args.img_size), 787 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), 788 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.args.img_size), 789 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0]))), 790 | cam(tensor2numpy(fake_disGA_cam_hm[0]), self.args.img_size), 791 | cam(tensor2numpy(fake_disLA_cam_hm[0]), self.args.img_size), 792 | ] 793 | if self.args.attention_gan > 0: 794 | for i in range(self.args.attention_gan): 795 | B2A_list.append(attention_mask(tensor2numpy(fake_B2A_attention[0][i:(i + 1)]), 796 | self.args.img_size)) 797 | B2A = np.concatenate((B2A, np.concatenate(B2A_list, 0)), 1) 798 | 799 | cv2.imwrite(os.path.join(self.args.result_dir, self.args.dataset, 'img', f'A2B{name}_{step:07d}.png'), 800 | A2B * 255.0) 801 | cv2.imwrite(os.path.join(self.args.result_dir, self.args.dataset, 'img', f'B2A{name}_{step:07d}.png'), 802 | B2A * 255.0) 803 | return 804 | 805 | def model_ema(self, step, G_ema, G): 806 | if step > self.args.ema_start * self.args.iteration: 807 | for p_ema, p in zip(G_ema.parameters(), G.parameters()): 808 | p_ema.copy_(p.lerp(p_ema, self.args.ema_beta)) 809 | else: 810 | for p_ema, p in zip(G_ema.parameters(), G.parameters()): 811 | 
p_ema.copy_(p) 812 | for b_ema, b in zip(G_ema.buffers(), G.buffers()): 813 | b_ema.copy_(b) 814 | return 815 | 816 | def save(self, root, step, name=None): 817 | if name is None: 818 | name = '_params_%07d.pt' % step 819 | params = {'genA2B': self.genA2B.state_dict(), 'genB2A': self.genB2A.state_dict(), 820 | 'genA2B_ema': self.genA2B_ema.state_dict(), 'genB2A_ema': self.genB2A_ema.state_dict(), 821 | 'disGA': self.disGA.state_dict(), 'disGB': self.disGB.state_dict(), 'disLA': self.disLA.state_dict(), 822 | 'disLB': self.disLB.state_dict()} 823 | torch.save(params, os.path.join(root, self.args.dataset + name)) 824 | g_params = {'genA2B': self.genA2B.state_dict(), 'genA2B_ema': self.genA2B_ema.state_dict()} 825 | torch.save(g_params, os.path.join(root, self.args.dataset + f'_g{name}')) 826 | 827 | def load(self, root, step): 828 | params = torch.load(os.path.join(root, self.args.dataset + '_params_%07d.pt' % step), 829 | map_location=torch.device("cpu")) 830 | self.genA2B.load_state_dict(params['genA2B']) 831 | self.genB2A.load_state_dict(params['genB2A']) 832 | self.genA2B_ema.load_state_dict(params['genA2B_ema']) 833 | self.genB2A_ema.load_state_dict(params['genB2A_ema']) 834 | self.disGA.load_state_dict(params['disGA']) 835 | self.disGB.load_state_dict(params['disGB']) 836 | self.disLA.load_state_dict(params['disLA']) 837 | self.disLB.load_state_dict(params['disLB']) 838 | 839 | def test(self): 840 | model_list = glob(os.path.join(self.args.result_dir, '*_params_latest.pt')) 841 | if len(model_list) == 0: 842 | model_list = glob(os.path.join(self.args.result_dir, self.args.dataset, 'model', '*.pt')) 843 | if len(model_list) != 0: 844 | model_list.sort() 845 | if not (self.args.generator_model and os.path.isfile(self.args.generator_model)): 846 | self.args.generator_model = model_list[-1] 847 | 848 | if self.args.generator_model and os.path.isfile(self.args.generator_model): 849 | params = torch.load(self.args.generator_model, map_location=torch.device("cpu")) 850 | self.genA2B.load_state_dict(params['genA2B_ema']) 851 | self.genB2A.load_state_dict(params['genB2A_ema']) 852 | print(" [*] Load SUCCESS") 853 | else: 854 | print(" [*] Load FAILURE") 855 | return 856 | 857 | self.genA2B.eval(), self.genB2A.eval() 858 | for n, real_A in tqdm(enumerate(self.testA_loader)): 859 | real_A = real_A.to(self.args.device) 860 | with torch.no_grad(): 861 | fake_A2B, _, fake_A2B_heatmap, fake_A2B_attention = self.genA2B(real_A) 862 | fake_A2B2A, _, fake_A2B2A_heatmap, fake_A2B2A_attention = self.genB2A(fake_A2B) 863 | fake_A2A, _, fake_A2A_heatmap, fake_A2A_attention = self.genB2A(real_A) 864 | 865 | A2B_list = [RGB2BGR(tensor2numpy(denorm(real_A[0]))), 866 | cam(tensor2numpy(fake_A2A_heatmap[0]), self.args.img_size), 867 | RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))), 868 | cam(tensor2numpy(fake_A2B_heatmap[0]), self.args.img_size), 869 | RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))), 870 | cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.args.img_size), 871 | RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0]))) 872 | ] 873 | if self.args.attention_gan > 0: 874 | for i in range(self.args.attention_gan): 875 | A2B_list.append(attention_mask(tensor2numpy(fake_A2B_attention[0][i:(i + 1)]), self.args.img_size)) 876 | A2B = np.concatenate(A2B_list, 0) 877 | cv2.imwrite(os.path.join(self.args.result_dir, self.args.dataset, 'test', 'A2B_%d.png' % (n + 1)), 878 | A2B * 255.0) 879 | 880 | for n, real_B in tqdm(enumerate(self.testB_loader)): 881 | real_B = real_B.to(self.args.device) 882 | with torch.no_grad(): 883 | fake_B2A, 
_, fake_B2A_heatmap, fake_B2A_attention = self.genB2A(real_B) 884 | fake_B2A2B, _, fake_B2A2B_heatmap, fake_B2A2B_attention = self.genA2B(fake_B2A) 885 | fake_B2B, _, fake_B2B_heatmap, fake_B2B_attention = self.genA2B(real_B) 886 | 887 | B2A_list = [RGB2BGR(tensor2numpy(denorm(real_B[0]))), 888 | cam(tensor2numpy(fake_B2B_heatmap[0]), self.args.img_size), 889 | RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))), 890 | cam(tensor2numpy(fake_B2A_heatmap[0]), self.args.img_size), 891 | RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))), 892 | cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.args.img_size), 893 | RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))] 894 | if self.args.attention_gan > 0: 895 | for i in range(self.args.attention_gan): 896 | B2A_list.append(attention_mask(tensor2numpy(fake_B2A_attention[0][i:(i + 1)]), self.args.img_size)) 897 | B2A = np.concatenate(B2A_list, 0) 898 | cv2.imwrite(os.path.join(self.args.result_dir, self.args.dataset, 'test', 'B2A_%d.png' % (n + 1)), 899 | B2A * 255.0) 900 | --------------------------------------------------------------------------------
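A note on the EMA update in UGATIT.model_ema above: torch.Tensor.lerp(end, weight) computes input + weight * (end - input), so p.lerp(p_ema, ema_beta) equals ema_beta * p_ema + (1 - ema_beta) * p, i.e. a standard exponential moving average of the generator weights. The minimal standalone sketch below is illustrative only; it is not part of the repository, and the ema_beta value is just a stand-in for args.ema_beta.

import torch

ema_beta = 0.999          # illustrative value; the project reads this from args.ema_beta
p = torch.randn(4)        # a current generator parameter
p_ema = torch.randn(4)    # its EMA shadow copy

updated = p.lerp(p_ema, ema_beta)                 # update used in model_ema
expected = ema_beta * p_ema + (1 - ema_beta) * p  # closed-form exponential moving average
assert torch.allclose(updated, expected)

Before step exceeds ema_start * iteration, model_ema simply copies the raw generator weights into the EMA copies, so the averaging only kicks in once training has passed that point.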