├── .gitignore ├── LICENSE ├── README.md ├── cog.yaml ├── configs ├── __init__.py ├── data_configs.py ├── paths_config.py └── transforms_config.py ├── criteria ├── __init__.py ├── arcface │ ├── __init__.py │ └── iresnet.py ├── contrastive_id_loss.py ├── contrastive_loss.py ├── feature_matching_loss.py ├── id_loss.py ├── lpips │ ├── __init__.py │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── moco_loss.py └── w_norm.py ├── datasets ├── __init__.py ├── augmentations.py ├── contrastive_dataset.py ├── gt_res_dataset.py ├── image_folder.py ├── images_dataset.py ├── inference_dataset.py ├── inference_dataset_me.py ├── inversion_dataset.py └── test_ddp_sample.py ├── doc ├── contrastive.png ├── pipeline.png └── teaser.png ├── docs ├── encoding_inputs.jpg ├── encoding_outputs.jpg ├── frontalization_inputs.jpg ├── frontalization_outputs.jpg ├── seg2image.png ├── sketch2image.png ├── super_res_32.jpg ├── super_res_style_mixing.jpg ├── teaser.png ├── toonify_input.jpg └── toonify_output.jpg ├── editings ├── ganspace.py ├── ganspace_pca │ └── ffhq_pca.pt ├── interfacegan_directions │ ├── age.pt │ ├── pose.pt │ └── smile.pt ├── latent_editor.py └── sefa.py ├── environment └── clcae_env.yaml ├── mertric ├── __init__.py ├── fid.py └── measure.py ├── models ├── __init__.py ├── attention_feature_psp.py ├── base_network.py ├── encoders │ ├── __init__.py │ ├── fapsp_encoder.py │ ├── helpers.py │ ├── model_irse.py │ ├── projection_head.py │ ├── psp_encoders.py │ └── transformer.py ├── image_encoder.py ├── latent_encoder.py ├── mtcnn │ ├── __init__.py │ ├── mtcnn.py │ └── mtcnn_pytorch │ │ ├── __init__.py │ │ └── src │ │ ├── __init__.py │ │ ├── align_trans.py │ │ ├── box_utils.py │ │ ├── detector.py │ │ ├── first_stage.py │ │ ├── get_nets.py │ │ ├── matlab_cp2tform.py │ │ ├── visualization_utils.py │ │ └── weights │ │ ├── onet.npy │ │ ├── pnet.npy │ │ └── rnet.npy ├── psp.py └── stylegan2 │ ├── __init__.py │ ├── model.py │ └── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── options ├── __init__.py ├── test_options.py └── train_options.py ├── scripts ├── align_all_parallel.py ├── calc_id_loss_parallel.py ├── calc_losses_on_images.py ├── inference_car.py ├── inference_edit.py ├── inference_edit_not_interface.py ├── inference_inversion.py └── train.py ├── training ├── __init__.py ├── contrastive_coach.py ├── distributed.py ├── inversion_coach.py └── ranger.py └── utils ├── __init__.py ├── alignment.py ├── common.py ├── data_utils.py ├── dataset_txt_generate.py ├── train_utils.py └── wandb_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Generation results 2 | results/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | .idea 134 | checkpoints 135 | jizhi_task_scripts -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Elad Richardson, Yuval Alaluf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | python_version: "3.8" 4 | system_packages: 5 | - "libgl1-mesa-glx" 6 | - "libglib2.0-0" 7 | - "ninja-build" 8 | python_packages: 9 | - "cmake==3.21.2" 10 | - "torch==1.8.0" 11 | - "torchvision==0.9.0" 12 | - "numpy==1.21.1" 13 | - "ipython==7.21.0" 14 | - "tensorboard==2.6.0" 15 | - "tqdm==4.43.0" 16 | - "torch-optimizer==0.1.0" 17 | - "opencv-python==4.5.3.56" 18 | - "Pillow==8.3.2" 19 | - "matplotlib==3.2.1" 20 | - "scipy==1.7.1" 21 | run: 22 | - pip install dlib 23 | 24 | predict: "predict.py:Predictor" 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/configs/__init__.py -------------------------------------------------------------------------------- /configs/data_configs.py: -------------------------------------------------------------------------------- 1 | from configs import transforms_config 2 | from configs.paths_config import dataset_paths 3 | import os 4 | 5 | DATASETS = { 6 | 'ffhq_encode_contrastive': { 7 | 'transforms': transforms_config.ContrastiveTransforms, 8 | 'train_image_root': os.path.join(dataset_paths['ffhq_generate_train'], 'image'), 9 | 'train_latent_root': os.path.join(dataset_paths['ffhq_generate_train'], 'latent'), 10 | 'test_image_root': os.path.join(dataset_paths['ffhq_generate_test'], 'image'), 11 | 'test_latent_root': os.path.join(dataset_paths['ffhq_generate_test'], 'latent'), 12 | 'avg_latent_root': dataset_paths['avg_latent_root'], 13 | 'avg_image_root': dataset_paths['avg_image_root'], 14 | }, 15 | 16 | 'ffhq_encode_inversion': { 17 | 'transforms': transforms_config.EncodeTransforms, 18 | 'train_image_root': os.path.join(dataset_paths['ffhq_inversion'], 'train_img.txt'), 19 | 'test_image_root': os.path.join(dataset_paths['ffhq_inversion'], 'val_img.txt'), 20 | 'avg_image_root': dataset_paths['avg_image_root'] 21 | }, 22 | 23 | 24 | 'car_encode_contrastive': { 25 | 'transforms': transforms_config.ContrastiveTransforms, 26 | 'train_image_root': os.path.join(dataset_paths['car_generate_train'], 'image'), 27 | 'train_latent_root': os.path.join(dataset_paths['car_generate_train'], 'latent'), 28 | 'test_image_root': os.path.join(dataset_paths['car_generate_test'], 'image'), 29 | 'test_latent_root': os.path.join(dataset_paths['car_generate_test'], 'latent'), 30 | 'avg_latent_root': dataset_paths['car_avg_latent_root'], 31 | 'avg_image_root': dataset_paths['car_avg_image_root'], 32 | }, 33 | 34 | 'car_encode_inversion': { 35 | 'transforms': transforms_config.CarsEncodeTransforms, 36 | 'train_image_root': os.path.join(dataset_paths['car_inversion'], 'train_img.txt'), 37 | 'test_image_root': os.path.join(dataset_paths['car_inversion'], 'val_img.txt'), 38 | 'avg_image_root': dataset_paths['car_avg_image_root'] 39 | }, 40 | } 41 | -------------------------------------------------------------------------------- /configs/paths_config.py: -------------------------------------------------------------------------------- 1 | dataset_paths = { 2 | 'ffhq_generate_train': './data/gan_inversion/train', 3 | 'ffhq_generate_test': './data/data/gan_inversion/test', 4 | 'avg_latent_root': 
'/apdcephfs/share_1290939/kumamzqliu/data/gan_inversion/avg/latent/latent_avg.npy', 5 | 'avg_image_root': '/apdcephfs/share_1290939/kumamzqliu/data/gan_inversion/avg/image/image_avg.png', 6 | 'ffhq_inversion': '/apdcephfs/share_1290939/kumamzqliu/data/face_inversion', 7 | 8 | 9 | 'car_generate_train': '/apdcephfs/share_1290939/kumamzqliu/data/car_inversion/train', 10 | 'car_generate_test': '/apdcephfs/share_1290939/kumamzqliu/data/car_inversion/test', 11 | 'car_avg_latent_root': '/apdcephfs/share_1290939/kumamzqliu/data/car_inversion/avg/latent/000000.npy', 12 | 'car_avg_image_root': '/apdcephfs/share_1290939/kumamzqliu/data/car_inversion/avg/image/000000.png', 13 | 'car_inversion': '/apdcephfs/share_1290939/kumamzqliu/data/car_inversion', 14 | } 15 | 16 | model_paths = { 17 | 'stylegan_ffhq': 'pretrained/stylegan2-ffhq-config-f.pt', 18 | 'stylegan_church': 'pretrained/stylegan2-church-config-f.pt', 19 | 'stylegan_horse': 'pretrained/stylegan2-horse-config-f.pt', 20 | 'ir_se50': 'pretrained/model_ir_se50.pth', 21 | 'circular_face': 'pretrained/CurricularFace_Backbone.pth', 22 | 'mtcnn_pnet': 'pretrained/pnet.npy', 23 | 'mtcnn_rnet': 'pretrained/rnet.npy', 24 | 'mtcnn_onet': 'pretrained/onet.npy', 25 | 'shape_predictor': 'shape_predictor_68_face_landmarks.dat', 26 | 'moco': 'pretrained/moco_v2_800ep_pretrain.pt', 27 | 'contrastive_ffhq_image': 'pretrained/ffhq_cont/best_model_image.pt', 28 | 'contrastive_ffhq_latent': 'pretrained/ffhq_cont/best_model_latent.pt', 29 | 'contrastive_car_image': 'pretrained/car_cont/best_model_image.pt', 30 | 'contrastive_car_latent': 'pretrained/car_cont/best_model_latent.pt', 31 | 32 | } 33 | 34 | edit_paths = { 35 | 'age': 'pretrained/age.pt', 36 | 'pose': '/apdcephfs/share_1290939/kumamzqliu/code/pixel2style2pixel/editings/interfacegan_directions/pose.pt', 37 | 'smile': '/apdcephfs/share_1290939/kumamzqliu/code/pixel2style2pixel/editings/interfacegan_directions/smile.pt', 38 | 'ffhq_pca': '/apdcephfs/share_1290939/kumamzqliu/code/pixel2style2pixel/editings/ganspace_pca/ffhq_pca.pt', 39 | 'car_pca': '/apdcephfs/share_1290939/kumamzqliu/code/pixel2style2pixel/editings/ganspace_pca/cars_pca.pt' 40 | } -------------------------------------------------------------------------------- /configs/transforms_config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torchvision.transforms as transforms 3 | from datasets import augmentations 4 | 5 | 6 | class TransformsConfig(object): 7 | 8 | def __init__(self, opts): 9 | self.opts = opts 10 | 11 | @abstractmethod 12 | def get_transforms(self): 13 | pass 14 | 15 | 16 | class ContrastiveTransforms(TransformsConfig): 17 | 18 | def __init__(self, opts): 19 | super(ContrastiveTransforms, self).__init__(opts) 20 | 21 | def get_transforms(self): 22 | transforms_dict = { 23 | 'transform_gt_train': transforms.Compose([ 24 | transforms.Resize((256, 256)), 25 | transforms.ToTensor(), 26 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 27 | 'transform_source': None, 28 | 'transform_test': transforms.Compose([ 29 | transforms.Resize((256, 256)), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 32 | 'transform_inference': transforms.Compose([ 33 | transforms.Resize((256, 256)), 34 | transforms.ToTensor(), 35 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 36 | } 37 | return transforms_dict 38 | 39 | 40 | class EncodeTransforms(TransformsConfig): 41 | 42 | def __init__(self, opts): 43 | 
super(EncodeTransforms, self).__init__(opts) 44 | 45 | def get_transforms(self): 46 | transforms_dict = { 47 | 'transform_gt_train': transforms.Compose([ 48 | transforms.Resize((256, 256)), 49 | transforms.RandomHorizontalFlip(0.5), 50 | transforms.ToTensor(), 51 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 52 | 'transform_source': None, 53 | 'transform_test': transforms.Compose([ 54 | transforms.Resize((256, 256)), 55 | transforms.ToTensor(), 56 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 57 | 'transform_inference': transforms.Compose([ 58 | transforms.Resize((256, 256)), 59 | transforms.ToTensor(), 60 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 61 | } 62 | return transforms_dict 63 | 64 | 65 | class CarsEncodeTransforms(TransformsConfig): 66 | 67 | def __init__(self, opts): 68 | super(CarsEncodeTransforms, self).__init__(opts) 69 | 70 | def get_transforms(self): 71 | transforms_dict = { 72 | 'transform_gt_train': transforms.Compose([ 73 | transforms.Resize((192, 256)), 74 | transforms.RandomHorizontalFlip(0.5), 75 | transforms.ToTensor(), 76 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 77 | 'transform_source': None, 78 | 'transform_test': transforms.Compose([ 79 | transforms.Resize((192, 256)), 80 | transforms.ToTensor(), 81 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 82 | 'transform_inference': transforms.Compose([ 83 | transforms.Resize((192, 256)), 84 | transforms.ToTensor(), 85 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 86 | } 87 | return transforms_dict 88 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/criteria/__init__.py -------------------------------------------------------------------------------- /criteria/arcface/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | -------------------------------------------------------------------------------- /criteria/arcface/iresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, 10 | out_planes, 11 | kernel_size=3, 12 | stride=stride, 13 | padding=dilation, 14 | groups=groups, 15 | bias=False, 16 | dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, 22 | out_planes, 23 | kernel_size=1, 24 | stride=stride, 25 | bias=False) 26 | 27 | 28 | class IBasicBlock(nn.Module): 29 | expansion = 1 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, 31 | groups=1, base_width=64, dilation=1): 32 | super(IBasicBlock, self).__init__() 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 38 | self.conv1 = conv3x3(inplanes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 40 | 
self.prelu = nn.PReLU(planes) 41 | self.conv2 = conv3x3(planes, planes, stride) 42 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | identity = x 48 | out = self.bn1(x) 49 | out = self.conv1(out) 50 | out = self.bn2(out) 51 | out = self.prelu(out) 52 | out = self.conv2(out) 53 | out = self.bn3(out) 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | out += identity 57 | return out 58 | 59 | 60 | class IResNet(nn.Module): 61 | fc_scale = 7 * 7 62 | def __init__(self, 63 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 64 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 65 | super(IResNet, self).__init__() 66 | self.fp16 = fp16 67 | self.inplanes = 64 68 | self.dilation = 1 69 | if replace_stride_with_dilation is None: 70 | replace_stride_with_dilation = [False, False, False] 71 | if len(replace_stride_with_dilation) != 3: 72 | raise ValueError("replace_stride_with_dilation should be None " 73 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 74 | self.groups = groups 75 | self.base_width = width_per_group 76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 78 | self.prelu = nn.PReLU(self.inplanes) 79 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 80 | self.layer2 = self._make_layer(block, 81 | 128, 82 | layers[1], 83 | stride=2, 84 | dilate=replace_stride_with_dilation[0]) 85 | self.layer3 = self._make_layer(block, 86 | 256, 87 | layers[2], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[1]) 90 | self.layer4 = self._make_layer(block, 91 | 512, 92 | layers[3], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[2]) 95 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 96 | self.dropout = nn.Dropout(p=dropout, inplace=True) 97 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 98 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 99 | nn.init.constant_(self.features.weight, 1.0) 100 | self.features.weight.requires_grad = False 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | nn.init.normal_(m.weight, 0, 0.1) 105 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 106 | nn.init.constant_(m.weight, 1) 107 | nn.init.constant_(m.bias, 0) 108 | 109 | if zero_init_residual: 110 | for m in self.modules(): 111 | if isinstance(m, IBasicBlock): 112 | nn.init.constant_(m.bn2.weight, 0) 113 | 114 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 115 | downsample = None 116 | previous_dilation = self.dilation 117 | if dilate: 118 | self.dilation *= stride 119 | stride = 1 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | conv1x1(self.inplanes, planes * block.expansion, stride), 123 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 124 | ) 125 | layers = [] 126 | layers.append( 127 | block(self.inplanes, planes, stride, downsample, self.groups, 128 | self.base_width, previous_dilation)) 129 | self.inplanes = planes * block.expansion 130 | for _ in range(1, blocks): 131 | layers.append( 132 | block(self.inplanes, 133 | planes, 134 | groups=self.groups, 135 | base_width=self.base_width, 136 | dilation=self.dilation)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x, return_features=False): 141 | out = [] 142 
| with torch.cuda.amp.autocast(self.fp16): 143 | x = self.conv1(x) 144 | x = self.bn1(x) 145 | x = self.prelu(x) 146 | x = self.layer1(x) 147 | out.append(x) 148 | x = self.layer2(x) 149 | out.append(x) 150 | x = self.layer3(x) 151 | out.append(x) 152 | x = self.layer4(x) 153 | out.append(x) 154 | x = self.bn2(x) 155 | x = torch.flatten(x, 1) 156 | x = self.dropout(x) 157 | x = self.fc(x.float() if self.fp16 else x) 158 | x = self.features(x) 159 | 160 | if return_features: 161 | out.append(x) 162 | return out 163 | return x 164 | 165 | 166 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 167 | model = IResNet(block, layers, **kwargs) 168 | if pretrained: 169 | raise ValueError() 170 | return model 171 | 172 | 173 | def iresnet18(pretrained=False, progress=True, **kwargs): 174 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 175 | progress, **kwargs) 176 | 177 | 178 | def iresnet34(pretrained=False, progress=True, **kwargs): 179 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 180 | progress, **kwargs) 181 | 182 | 183 | def iresnet50(pretrained=False, progress=True, **kwargs): 184 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 185 | progress, **kwargs) 186 | 187 | 188 | def iresnet100(pretrained=False, progress=True, **kwargs): 189 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 190 | progress, **kwargs) 191 | 192 | 193 | def iresnet200(pretrained=False, progress=True, **kwargs): 194 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, 195 | progress, **kwargs) 196 | -------------------------------------------------------------------------------- /criteria/contrastive_id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from configs.paths_config import model_paths 5 | from models.image_encoder import ImageEncoder 6 | from models.latent_encoder import LatentEncoder 7 | from .contrastive_loss import ClipLoss 8 | 9 | 10 | class ContrastiveID(nn.Module): 11 | def __init__(self, opts): 12 | super(ContrastiveID, self).__init__() 13 | self.net_image = ImageEncoder(opts) 14 | self.net_latent = LatentEncoder(opts) 15 | self.net_image.load_weights(model_paths[opts.contrastive_model_image]) 16 | 17 | print('Loading decoder weights from pretrained contrastive_image!') 18 | self.net_latent.load_weights(model_paths[opts.contrastive_model_latent]) 19 | print('Loading decoder weights from pretrained contrastive_latent!') 20 | self.net_image.eval() 21 | self.net_latent.eval() 22 | for param in self.net_image.parameters(): 23 | param.requires_grad = False 24 | for latent_encoder in self.net_latent.parameters(): 25 | latent_encoder.requires_grad = False 26 | self.contrastive_loss = ClipLoss() 27 | 28 | def forward(self, image, image_avg, latent, latent_avg): 29 | ''' 30 | 31 | Args: 32 | image: True 33 | latent: Fake 34 | 35 | Returns: 36 | 37 | ''' 38 | B = image.shape[0] 39 | latent_avg = latent_avg.repeat(B, 1) 40 | latent_avg = latent_avg.unsqueeze(1) 41 | image_embedding, t = self.net_image.forward(image, image_avg) 42 | latent_embedding = self.net_latent.forward(latent, latent_avg) 43 | loss = self.contrastive_loss(image_embedding, latent_embedding, t) 44 | return loss 45 | -------------------------------------------------------------------------------- /criteria/feature_matching_loss.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FeatureMatchingLoss(nn.Module): 5 | def __init__(self): 6 | super(FeatureMatchingLoss, self).__init__() 7 | self.l1 = nn.L1Loss() 8 | 9 | def forward(self, enc_feat, dec_feat, layer_idx=None): 10 | loss = [] 11 | if layer_idx is None: 12 | layer_idx = [i for i in range(len(enc_feat))] 13 | for i in layer_idx: 14 | loss.append(self.l1(enc_feat[i], dec_feat[i])) 15 | return loss 16 | -------------------------------------------------------------------------------- /criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from configs.paths_config import model_paths 4 | from models.encoders.model_irse import Backbone 5 | 6 | 7 | class IDLoss(nn.Module): 8 | def __init__(self): 9 | super(IDLoss, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | for module in [self.facenet, self.face_pool]: 16 | for param in module.parameters(): 17 | param.requires_grad = False 18 | 19 | def extract_feats(self, x): 20 | x = x[:, :, 35:223, 32:220] # Crop interesting region 21 | x = self.face_pool(x) 22 | x_feats = self.facenet(x) 23 | return x_feats 24 | 25 | def forward(self, y_hat, y, x): 26 | n_samples = x.shape[0] 27 | x_feats = self.extract_feats(x) 28 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 29 | y_hat_feats = self.extract_feats(y_hat) 30 | y_feats = y_feats.detach() 31 | loss = 0 32 | sim_improvement = 0 33 | id_logs = [] 34 | count = 0 35 | for i in range(n_samples): 36 | diff_target = y_hat_feats[i].dot(y_feats[i]) 37 | diff_input = y_hat_feats[i].dot(x_feats[i]) 38 | diff_views = y_feats[i].dot(x_feats[i]) 39 | id_logs.append({'diff_target': float(diff_target), 40 | 'diff_input': float(diff_input), 41 | 'diff_views': float(diff_views)}) 42 | loss += 1 - diff_target 43 | id_diff = float(diff_target) - float(diff_views) 44 | sim_improvement += id_diff 45 | count += 1 46 | 47 | return loss / count, sim_improvement / count, id_logs 48 | -------------------------------------------------------------------------------- /criteria/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/criteria/lpips/__init__.py -------------------------------------------------------------------------------- /criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1.
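Example (assumed usage): inputs are float image tensors normalized to [-1, 1], placed on CUDA since the backbone is moved to 'cuda' below, e.g. criterion = LPIPS(net_type='alex'); loss = criterion(y_hat, y).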
15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | from configs.paths_config import model_paths 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | model = models.alexnet(True) 81 | # model.load_state_dict(torch.load(model_paths['alex'])) 82 | self.layers = model.features 83 | self.target_layers = [2, 5, 8, 10, 12] 84 | self.n_channels_list = [64, 192, 384, 256, 256] 85 | 86 | self.set_requires_grad(False) 87 | 88 | 89 | class VGG16(BaseNet): 90 | def __init__(self): 91 | super(VGG16, self).__init__() 92 | 93 | self.layers = models.vgg16(True).features 94 | 
self.target_layers = [4, 9, 16, 23, 30] 95 | self.n_channels_list = [64, 128, 256, 512, 512] 96 | 97 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from configs.paths_config import model_paths 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | # old_state_dict = torch.load(model_paths['alex0.1'], map_location=None if torch.cuda.is_available() else torch.device('cpu')) 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from configs.paths_config import model_paths 5 | 6 | 7 | class MocoLoss(nn.Module): 8 | 9 | def __init__(self): 10 | super(MocoLoss, self).__init__() 11 | print("Loading MOCO model from path: {}".format(model_paths["moco"])) 12 | self.model = self.__load_model() 13 | self.model.cuda() 14 | self.model.eval() 15 | 16 | @staticmethod 17 | def __load_model(): 18 | import torchvision.models as models 19 | model = models.__dict__["resnet50"]() 20 | # freeze all layers but the last fc 21 | for name, param in model.named_parameters(): 22 | if name not in ['fc.weight', 'fc.bias']: 23 | param.requires_grad = False 24 | checkpoint = torch.load(model_paths['moco'], map_location="cpu") 25 | state_dict = checkpoint['state_dict'] 26 | # rename moco pre-trained keys 27 | for k in list(state_dict.keys()): 28 | # retain only encoder_q up to before the embedding layer 29 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 30 | # remove prefix 31 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 32 | # delete renamed or unused k 33 | del state_dict[k] 34 | msg = model.load_state_dict(state_dict, strict=False) 35 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 36 | # remove output layer 37 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 38 | return model 39 | 40 | def extract_feats(self, x): 41 | x = F.interpolate(x, size=224) 42 | x_feats = self.model(x) 43 | x_feats = nn.functional.normalize(x_feats, dim=1) 44 | x_feats = x_feats.squeeze() 45 | return x_feats 46 | 47 | def forward(self, y_hat, y, x): 48 | n_samples = x.shape[0] 49 | x_feats = self.extract_feats(x) 50 | y_feats = self.extract_feats(y) 51 | y_hat_feats = self.extract_feats(y_hat) 52 | y_feats = y_feats.detach() 53 | loss = 0 54 | sim_improvement = 0 55 | sim_logs = [] 56 | count = 0 57 | for i in range(n_samples): 58 | 
diff_target = y_hat_feats[i].dot(y_feats[i]) 59 | diff_input = y_hat_feats[i].dot(x_feats[i]) 60 | diff_views = y_feats[i].dot(x_feats[i]) 61 | sim_logs.append({'diff_target': float(diff_target), 62 | 'diff_input': float(diff_input), 63 | 'diff_views': float(diff_views)}) 64 | loss += 1 - diff_target 65 | sim_diff = float(diff_target) - float(diff_views) 66 | sim_improvement += sim_diff 67 | count += 1 68 | 69 | return loss / count, sim_improvement / count, sim_logs 70 | -------------------------------------------------------------------------------- /criteria/w_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class WNormLoss(nn.Module): 6 | 7 | def __init__(self, start_from_latent_avg=True): 8 | super(WNormLoss, self).__init__() 9 | self.start_from_latent_avg = start_from_latent_avg 10 | 11 | def forward(self, latent, latent_avg=None): 12 | if self.start_from_latent_avg: 13 | latent = latent - latent_avg 14 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] 15 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torchvision import transforms 6 | 7 | 8 | class ToOneHot(object): 9 | """ Convert the input PIL image to a one-hot torch tensor """ 10 | def __init__(self, n_classes=None): 11 | self.n_classes = n_classes 12 | 13 | def onehot_initialization(self, a): 14 | if self.n_classes is None: 15 | self.n_classes = len(np.unique(a)) 16 | out = np.zeros(a.shape + (self.n_classes, ), dtype=int) 17 | out[self.__all_idx(a, axis=2)] = 1 18 | return out 19 | 20 | def __all_idx(self, idx, axis): 21 | grid = np.ogrid[tuple(map(slice, idx.shape))] 22 | grid.insert(axis, idx) 23 | return tuple(grid) 24 | 25 | def __call__(self, img): 26 | img = np.array(img) 27 | one_hot = self.onehot_initialization(img) 28 | return one_hot 29 | 30 | 31 | class BilinearResize(object): 32 | def __init__(self, factors=[1, 2, 4, 8, 16, 32]): 33 | self.factors = factors 34 | 35 | def __call__(self, image): 36 | factor = np.random.choice(self.factors, size=1)[0] 37 | D = BicubicDownSample(factor=factor, cuda=False) 38 | img_tensor = transforms.ToTensor()(image).unsqueeze(0) 39 | img_tensor_lr = D(img_tensor)[0].clamp(0, 1) 40 | img_low_res = transforms.ToPILImage()(img_tensor_lr) 41 | return img_low_res 42 | 43 | 44 | class BicubicDownSample(nn.Module): 45 | def bicubic_kernel(self, x, a=-0.50): 46 | """ 47 | This equation is exactly copied from the website below: 48 | https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic 49 | """ 50 | abs_x = torch.abs(x) 51 | if abs_x <= 1.: 52 | return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1 53 | elif 1. < abs_x < 2.: 54 | return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. 
* a 55 | else: 56 | return 0.0 57 | 58 | def __init__(self, factor=4, cuda=True, padding='reflect'): 59 | super().__init__() 60 | self.factor = factor 61 | size = factor * 4 62 | k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) 63 | for i in range(size)], dtype=torch.float32) 64 | k = k / torch.sum(k) 65 | k1 = torch.reshape(k, shape=(1, 1, size, 1)) 66 | self.k1 = torch.cat([k1, k1, k1], dim=0) 67 | k2 = torch.reshape(k, shape=(1, 1, 1, size)) 68 | self.k2 = torch.cat([k2, k2, k2], dim=0) 69 | self.cuda = '.cuda' if cuda else '' 70 | self.padding = padding 71 | for param in self.parameters(): 72 | param.requires_grad = False 73 | 74 | def forward(self, x, nhwc=False, clip_round=False, byte_output=False): 75 | filter_height = self.factor * 4 76 | filter_width = self.factor * 4 77 | stride = self.factor 78 | 79 | pad_along_height = max(filter_height - stride, 0) 80 | pad_along_width = max(filter_width - stride, 0) 81 | filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda)) 82 | filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda)) 83 | 84 | # compute actual padding values for each side 85 | pad_top = pad_along_height // 2 86 | pad_bottom = pad_along_height - pad_top 87 | pad_left = pad_along_width // 2 88 | pad_right = pad_along_width - pad_left 89 | 90 | # apply mirror padding 91 | if nhwc: 92 | x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW 93 | 94 | # downscaling performed by 1-d convolution 95 | x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) 96 | x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) 97 | if clip_round: 98 | x = torch.clamp(torch.round(x), 0.0, 255.) 99 | 100 | x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) 101 | x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) 102 | if clip_round: 103 | x = torch.clamp(torch.round(x), 0.0, 255.) 
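# The remaining branches below transpose the result back to NHWC when `nhwc=True` and, if `byte_output=True`, cast it to a byte tensor before returning.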
104 | 105 | if nhwc: 106 | x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) 107 | if byte_output: 108 | return x.type('torch{}.ByteTensor'.format(self.cuda)) 109 | else: 110 | return x 111 | -------------------------------------------------------------------------------- /datasets/contrastive_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | from PIL import ImageFile 6 | from utils import data_utils 7 | import os 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | class ContrastiveDataset(Dataset): 10 | 11 | def __init__(self, image_root, latent_root, avg_latent_root, avg_image_root, opts, image_transform=None): 12 | self.image_paths = sorted(data_utils.make_dataset(image_root)) 13 | self.target_paths = latent_root 14 | self.image_transform = image_transform 15 | self.opts = opts 16 | self.avg_latent_root = avg_latent_root 17 | self.avg_image_root = avg_image_root 18 | 19 | def __len__(self): 20 | return len(self.image_paths) 21 | 22 | def __getitem__(self, index): 23 | image_path = self.image_paths[index] 24 | image = Image.open(image_path).convert('RGB') 25 | name = os.path.split(image_path)[1].split('.')[0] 26 | latent_path = os.path.join(self.target_paths, name + '.npy') 27 | latent = torch.from_numpy(np.load(latent_path)) 28 | image = self.image_transform(image) 29 | 30 | image_avg = Image.open(self.avg_image_root).convert('RGB') 31 | image_avg = self.image_transform(image_avg) 32 | latent_avg = torch.from_numpy(np.load(self.avg_latent_root)) 33 | return image, latent.unsqueeze(0), image_avg, latent_avg.unsqueeze(0) 34 | -------------------------------------------------------------------------------- /datasets/gt_res_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | import os 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | 7 | 8 | class GTResDataset(Dataset): 9 | 10 | def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): 11 | self.pairs = [] 12 | for f in os.listdir(root_path): 13 | image_path = os.path.join(root_path, f) 14 | gt_path = os.path.join(gt_dir, f) 15 | if f.endswith(".jpg") or f.endswith(".png"): 16 | self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None]) 17 | self.transform = transform 18 | self.transform_train = transform_train 19 | 20 | def __len__(self): 21 | return len(self.pairs) 22 | 23 | def __getitem__(self, index): 24 | from_path, to_path, _ = self.pairs[index] 25 | from_im = Image.open(from_path).convert('RGB') 26 | to_im = Image.open(to_path).convert('RGB') 27 | 28 | if self.transform: 29 | to_im = self.transform(to_im) 30 | from_im = self.transform(from_im) 31 | 32 | return from_im, to_im 33 | -------------------------------------------------------------------------------- /datasets/image_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """ 5 | 6 | ############################################################################### 7 | # Code from 8 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 9 | # Modified the original code so that it also loads images from the current 10 | # directory as well as the subdirectories 11 | ############################################################################### 12 | import torch.utils.data as data 13 | from PIL import Image 14 | import os 15 | 16 | IMG_EXTENSIONS = [ 17 | '.jpg', '.JPG', '.jpeg', '.JPEG', 18 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp' 19 | ] 20 | 21 | 22 | def is_image_file(filename): 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | 25 | 26 | def make_dataset_rec(dir, images): 27 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 28 | 29 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): 30 | for fname in fnames: 31 | if is_image_file(fname): 32 | path = os.path.join(root, fname) 33 | images.append(path) 34 | 35 | 36 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False): 37 | images = [] 38 | if read_cache: 39 | possible_filelist = dir 40 | if os.path.isfile(possible_filelist): 41 | with open(possible_filelist, 'r') as f: 42 | images = f.read().splitlines() 43 | return images 44 | 45 | if recursive: 46 | make_dataset_rec(dir, images) 47 | else: 48 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 49 | 50 | for root, dnames, fnames in sorted(os.walk(dir)): 51 | for fname in fnames: 52 | if is_image_file(fname): 53 | path = os.path.join(root, fname) 54 | images.append(path) 55 | 56 | if write_cache: 57 | filelist_cache = os.path.join(dir, 'files.list') 58 | with open(filelist_cache, 'w') as f: 59 | for path in images: 60 | f.write("%s\n" % path) 61 | print('wrote filelist cache at %s' % filelist_cache) 62 | 63 | return images 64 | 65 | 66 | def default_loader(path): 67 | return Image.open(path).convert('RGB') 68 | 69 | 70 | 71 | if __name__ == '__main__': 72 | test_path = '/apdcephfs/share_1290939/kumamzqliu/data/face_inversion/val_img.txt' 73 | test_file = make_dataset(dir=test_path, recursive=False, read_cache=True) 74 | print(len(test_file)) 75 | -------------------------------------------------------------------------------- /datasets/images_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class ImagesDataset(Dataset): 7 | 8 | def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): 9 | self.source_paths = sorted(data_utils.make_dataset(source_root)) 10 | self.target_paths = sorted(data_utils.make_dataset(target_root)) 11 | self.source_transform = source_transform 12 | self.target_transform = target_transform 13 | self.opts = opts 14 | 15 | def __len__(self): 16 | return len(self.source_paths) 17 | 18 | def __getitem__(self, index): 19 | from_path = self.source_paths[index] 20 | from_im = Image.open(from_path) 21 | from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L') 22 | 23 | to_path = self.target_paths[index] 24 | to_im = Image.open(to_path).convert('RGB') 25 | if self.target_transform: 26 | to_im = self.target_transform(to_im) 27 | 28 | if self.source_transform: 29 | from_im = self.source_transform(from_im) 30 | else: 31 | from_im = to_im 32 | 33 | 
return from_im, to_im 34 | -------------------------------------------------------------------------------- /datasets/inference_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | 5 | 6 | class InferenceDataset(Dataset): 7 | 8 | def __init__(self, root, opts, transform=None): 9 | self.paths = sorted(data_utils.make_dataset(root)) 10 | self.transform = transform 11 | self.opts = opts 12 | 13 | def __len__(self): 14 | return len(self.paths) 15 | 16 | def __getitem__(self, index): 17 | from_path = self.paths[index] 18 | from_im = Image.open(from_path) 19 | from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L') 20 | if self.transform: 21 | from_im = self.transform(from_im) 22 | return from_im 23 | -------------------------------------------------------------------------------- /datasets/inference_dataset_me.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | from .image_folder import make_dataset 5 | 6 | class InferenceDataset(Dataset): 7 | 8 | def __init__(self, root, opts, image_avg_root, transform=None): 9 | self.paths = sorted(data_utils.make_dataset(root)) 10 | self.transform = transform 11 | self.opts = opts 12 | self.image_avg_root = image_avg_root 13 | 14 | def __len__(self): 15 | return len(self.paths) 16 | 17 | def __getitem__(self, index): 18 | from_path = self.paths[index] 19 | from_im = Image.open(from_path) 20 | from_im = from_im.convert('RGB') 21 | 22 | image_avg = Image.open(self.image_avg_root).convert('RGB') 23 | if self.transform: 24 | from_im = self.transform(from_im) 25 | image_avg = self.transform(image_avg) 26 | return from_im, image_avg 27 | -------------------------------------------------------------------------------- /datasets/inversion_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | from utils import data_utils 4 | import os 5 | from .image_folder import make_dataset 6 | from PIL import ImageFile 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | class ImagesDataset(Dataset): 10 | 11 | def __init__(self, image_root, image_avg_root, opts, image_transform=None): 12 | self.image_paths = sorted(make_dataset(dir=image_root, recursive=False, read_cache=True)) 13 | self.image_transform = image_transform 14 | self.image_avg_root = image_avg_root 15 | self.opts = opts 16 | 17 | def __len__(self): 18 | return len(self.image_paths) 19 | 20 | def __getitem__(self, index): 21 | image_path = self.image_paths[index] 22 | image = Image.open(image_path).convert('RGB') 23 | if image_path.find('car_inversion') != -1: 24 | image = image.crop((0, 64, 512, 448)) 25 | image = self.image_transform(image) 26 | 27 | image_avg = Image.open(self.image_avg_root).convert('RGB') 28 | image_avg = self.image_transform(image_avg) 29 | return image, image_avg, image 30 | -------------------------------------------------------------------------------- /datasets/test_ddp_sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.distributed 4 | 5 | class SequentialDistributedSampler(torch.utils.data.sampler.Sampler): 6 | """ 7 | Distributed Sampler that subsamples indices sequentially, 8 |
making it easier to collate all results at the end. 9 | Even though we only use this sampler for eval and predict (no training), 10 | which means that the model params won't have to be synced (i.e. will not hang 11 | for synchronization even if varied number of forward passes), we still add extra 12 | samples to the sampler to make it evenly divisible (like in `DistributedSampler`) 13 | to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. 14 | """ 15 | 16 | def __init__(self, dataset, batch_size, rank=None, num_replicas=None): 17 | if num_replicas is None: 18 | if not torch.distributed.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = torch.distributed.get_world_size() 21 | if rank is None: 22 | if not torch.distributed.is_available(): 23 | raise RuntimeError("Requires distrib" 24 | "uted package to be available") 25 | rank = torch.distributed.get_rank() 26 | self.dataset = dataset 27 | self.num_replicas = num_replicas 28 | self.rank = rank 29 | self.batch_size = batch_size 30 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size 31 | self.total_size = self.num_samples * self.num_replicas 32 | 33 | def __iter__(self): 34 | indices = list(range(len(self.dataset))) 35 | # add extra samples to make it evenly divisible 36 | indices += [indices[-1]] * (self.total_size - len(indices)) 37 | # subsample 38 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] 39 | return iter(indices) 40 | 41 | def __len__(self): 42 | return self.num_samples 43 | -------------------------------------------------------------------------------- /doc/contrastive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/doc/contrastive.png -------------------------------------------------------------------------------- /doc/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/doc/pipeline.png -------------------------------------------------------------------------------- /doc/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/doc/teaser.png -------------------------------------------------------------------------------- /docs/encoding_inputs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/encoding_inputs.jpg -------------------------------------------------------------------------------- /docs/encoding_outputs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/encoding_outputs.jpg -------------------------------------------------------------------------------- /docs/frontalization_inputs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/frontalization_inputs.jpg -------------------------------------------------------------------------------- 
/docs/frontalization_outputs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/frontalization_outputs.jpg -------------------------------------------------------------------------------- /docs/seg2image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/seg2image.png -------------------------------------------------------------------------------- /docs/sketch2image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/sketch2image.png -------------------------------------------------------------------------------- /docs/super_res_32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/super_res_32.jpg -------------------------------------------------------------------------------- /docs/super_res_style_mixing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/super_res_style_mixing.jpg -------------------------------------------------------------------------------- /docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/teaser.png -------------------------------------------------------------------------------- /docs/toonify_input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/toonify_input.jpg -------------------------------------------------------------------------------- /docs/toonify_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/docs/toonify_output.jpg -------------------------------------------------------------------------------- /editings/ganspace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def edit(latents, pca, edit_directions): 5 | edit_latents = [] 6 | for latent in latents: 7 | for pca_idx, start, end, strength in edit_directions: 8 | delta = get_delta(pca, latent, pca_idx, strength) 9 | delta_padded = torch.zeros(latent.shape).to('cuda') 10 | delta_padded[start:end] += delta.repeat(end - start, 1) 11 | edit_latents.append(latent + delta_padded) 12 | return torch.stack(edit_latents) 13 | 14 | 15 | def get_delta(pca, latent, idx, strength): 16 | # pca: ganspace checkpoint. 
latent: (16, 512) w+ 17 | w_centered = latent - pca['mean'].to('cuda') 18 | lat_comp = pca['comp'].to('cuda') 19 | lat_std = pca['std'].to('cuda') 20 | w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx] 21 | delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx] 22 | return delta 23 | -------------------------------------------------------------------------------- /editings/ganspace_pca/ffhq_pca.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/editings/ganspace_pca/ffhq_pca.pt -------------------------------------------------------------------------------- /editings/interfacegan_directions/age.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/editings/interfacegan_directions/age.pt -------------------------------------------------------------------------------- /editings/interfacegan_directions/pose.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/editings/interfacegan_directions/pose.pt -------------------------------------------------------------------------------- /editings/interfacegan_directions/smile.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/editings/interfacegan_directions/smile.pt -------------------------------------------------------------------------------- /editings/latent_editor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.append(".") 4 | sys.path.append("..") 5 | from editings import ganspace, sefa 6 | from utils.common import tensor2im 7 | 8 | 9 | class LatentEditor(object): 10 | def __init__(self, is_cars=False): 11 | self.generator = None 12 | self.is_cars = is_cars # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output. 13 | 14 | def apply_ganspace(self, latent, ganspace_pca, edit_directions): 15 | edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions) 16 | return edit_latents 17 | 18 | def apply_interfacegan(self, latent, direction, factor=1, factor_range=None): 19 | edit_latents = [] 20 | if factor_range is not None: # Apply a range of editing factors. for example, (-5, 5) 21 | for f in range(*factor_range): 22 | edit_latent = latent + f * direction 23 | edit_latents.append(edit_latent) 24 | edit_latents = torch.cat(edit_latents) 25 | else: 26 | edit_latents = latent + factor * direction 27 | return self._latents_to_image(edit_latents) 28 | 29 | def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs): 30 | edit_latents = sefa.edit(self.generator, latent, indices, **kwargs) 31 | return self._latents_to_image(edit_latents) 32 | 33 | # Currently, in order to apply StyleFlow editings, one should run inference, 34 | # save the latent codes and load them form the official StyleFlow repository. 
35 | # def apply_styleflow(self): 36 | # pass 37 | 38 | def _latents_to_image(self, latents): 39 | with torch.no_grad(): 40 | images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True) 41 | if self.is_cars: 42 | images = images[:, :, 64:448, :] # 512x512 -> 384x512 43 | horizontal_concat_image = torch.cat(list(images), 2) 44 | final_image = tensor2im(horizontal_concat_image) 45 | return final_image 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /editings/sefa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | 6 | def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11): 7 | 8 | layers, boundaries, values = factorize_weight(generator, indices) 9 | codes = latents.detach().cpu().numpy() # (1,18,512) 10 | 11 | # Generate visualization pages. 12 | distances = np.linspace(start_distance, end_distance, step) 13 | num_sam = num_samples 14 | num_sem = semantics 15 | 16 | edited_latents = [] 17 | for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False): 18 | boundary = boundaries[sem_id:sem_id + 1] 19 | for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False): 20 | code = codes[sam_id:sam_id + 1] 21 | for col_id, d in enumerate(distances, start=1): 22 | temp_code = code.copy() 23 | temp_code[:, layers, :] += boundary * d 24 | edited_latents.append(torch.from_numpy(temp_code).float().cuda()) 25 | return torch.cat(edited_latents) 26 | 27 | 28 | def factorize_weight(g_ema, layers='all'): 29 | 30 | weights = [] 31 | if layers == 'all' or 0 in layers: 32 | weight = g_ema.conv1.conv.modulation.weight.T 33 | weights.append(weight.cpu().detach().numpy()) 34 | 35 | if layers == 'all': 36 | layers = list(range(g_ema.num_layers - 1)) 37 | else: 38 | layers = [l - 1 for l in layers if l != 0] 39 | 40 | for idx in layers: 41 | weight = g_ema.convs[idx].conv.modulation.weight.T 42 | weights.append(weight.cpu().detach().numpy()) 43 | weight = np.concatenate(weights, axis=1).astype(np.float32) 44 | weight = weight / np.linalg.norm(weight, axis=0, keepdims=True) 45 | eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T)) 46 | return layers, eigen_vectors.T, eigen_values 47 | -------------------------------------------------------------------------------- /environment/clcae_env.yaml: -------------------------------------------------------------------------------- 1 | name: psp_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - ca-certificates=2020.4.5.1=hecc5488_0 8 | - certifi=2020.4.5.1=py36h9f0ad1d_0 9 | - libedit=3.1.20181209=hc058e9b_0 10 | - libffi=3.2.1=hd88cf55_4 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - ncurses=6.2=he6710b0_1 14 | - ninja=1.10.0=hc9558a2_0 15 | - openssl=1.1.1g=h516909a_0 16 | - pip=20.0.2=py36_3 17 | - python=3.6.7=h0371630_0 18 | - python_abi=3.6=1_cp36m 19 | - readline=7.0=h7b6447c_5 20 | - setuptools=46.4.0=py36_0 21 | - sqlite=3.31.1=h62c20be_1 22 | - tk=8.6.8=hbc83047_0 23 | - wheel=0.34.2=py36_0 24 | - xz=5.2.5=h7b6447c_0 25 | - zlib=1.2.11=h7b6447c_3 26 | - pip: 27 | - scipy==1.4.1 28 | - matplotlib==3.2.1 29 | - tqdm==4.46.0 30 | - numpy==1.18.4 31 | - opencv-python==4.2.0.34 32 | - pillow==7.1.2 33 | - tensorboard==2.2.1 34 | - torch==1.6.0 35 | - torchvision==0.4.2 36 | prefix: ~/anaconda3/envs/psp_env 37 | 38 | 
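The pinned environment above can be reproduced with the usual conda workflow, e.g. "conda env create -f environment/clcae_env.yaml" followed by "conda activate psp_env" (the environment name comes from the name: field at the top of the file).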
-------------------------------------------------------------------------------- /mertric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/mertric/__init__.py -------------------------------------------------------------------------------- /mertric/measure.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from fid import FID 5 | from PIL import Image 6 | from natsort import natsorted 7 | from skimage.metrics import structural_similarity 8 | from skimage.metrics import peak_signal_noise_ratio 9 | import torch 10 | # import lpips 11 | import shutil 12 | 13 | 14 | def img_to_tensor(img): 15 | image_size = img.width 16 | image = (np.asarray(img) / 255.0).reshape(image_size * image_size, 3).transpose().reshape(3, image_size, image_size) 17 | torch_image = torch.from_numpy(image).float() 18 | torch_image = torch_image * 2.0 - 1.0 19 | torch_image = torch_image.unsqueeze(0) 20 | return torch_image 21 | 22 | 23 | class Reconstruction_Metrics: 24 | def __init__(self, metric_list=['ssim', 'psnr', 'fid'], data_range=1, win_size=21, multichannel=True): 25 | self.data_range = data_range 26 | self.win_size = win_size 27 | self.multichannel = multichannel 28 | self.fid_calculate = FID() 29 | # self.loss_fn_vgg = lpips.LPIPS(net='alex') 30 | for metric in metric_list: 31 | setattr(self, metric, True) 32 | 33 | def calculate_metric(self, real_image_path, fake_image_path): 34 | """ 35 | inputs: .txt files, floders, image files (string), image files (list) 36 | gts: .txt files, floders, image files (string), image files (list) 37 | """ 38 | # with torch.no_grad(): 39 | # fid_value = self.fid_calculate.calculate_from_disk(fake_image_path, real_image_path) 40 | psnr = [] 41 | ssim = [] 42 | # lipis = [] 43 | image_name_list = [name for name in os.listdir(real_image_path) if 44 | name.endswith((('.png', '.jpg', '.jpeg', '.JPG', '.bmp')))] 45 | # fake_image_name_list = os.listdir(fake_image_path) 46 | for i, image_name in enumerate(image_name_list): 47 | image_fake_name = image_name.split('.')[0] + '.jpg' 48 | path_real = os.path.join(real_image_path, image_name) 49 | path_fake = os.path.join(fake_image_path, image_fake_name) 50 | PIL_real = Image.open(path_real).convert('RGB') 51 | PIL_fake = Image.open(path_fake).convert('RGB') 52 | # PIL_real = PIL_real.resize((256, 192)) 53 | # PIL_fake = PIL_fake.resize((256, 256)) 54 | # fake_torch_image = img_to_tensor(PIL_fake) 55 | # real_torch_image = img_to_tensor(PIL_real) 56 | img_content_real = np.array(PIL_real).astype(np.float32) / 255.0 57 | img_content_fake = np.array(PIL_fake).astype(np.float32) / 255.0 58 | # img_content_fake = img_content_fake[32:224, :, :] 59 | # print(img_content_fake.shape) 60 | psnr_each_img = peak_signal_noise_ratio(img_content_real, img_content_fake) 61 | ssim_each_image = structural_similarity(img_content_real, img_content_fake, data_range=self.data_range, 62 | win_size=self.win_size, multichannel=self.multichannel) 63 | # lipis_each_image = self.loss_fn_vgg(fake_torch_image, real_torch_image) 64 | # lipis_each_image = lipis_each_image.detach().numpy() 65 | psnr.append(psnr_each_img) 66 | ssim.append(ssim_each_image) 67 | # lipis.append(lipis_each_image) 68 | print( 69 | "PSNR: %.4f" % np.round(np.mean(psnr), 4), 70 | "PSNR Variance: %.4f" % np.round(np.var(psnr), 4)) 71 | print( 72 | "SSIM: %.4f" % 
np.round(np.mean(ssim), 4), 73 | "SSIM Variance: %.4f" % np.round(np.var(ssim), 4)) 74 | # print( 75 | # "LPIPS: %.4f" % np.round(np.mean(lipis), 4), 76 | # "LPIPS Variance: %.4f" % np.round(np.var(lipis), 4)) 77 | return np.round(np.mean(psnr), 4), np.round(np.mean(ssim), 4) 78 | 79 | 80 | 81 | 82 | def select_fake_from_fewshot(video_name, start, nums): 83 | base_path = r'D:\My Documents\Desktop\transformer_dance\results\LWG\result\reconstruct' 84 | video_path = os.path.join(base_path, video_name, 'imitators') 85 | file_list = natsorted(os.listdir(video_path)) 86 | select_file_list = file_list[start:start + nums] 87 | save_base = r'D:\My Documents\Desktop\transformer_dance\results\real' 88 | save_file = os.path.join(save_base, video_name, 'fake') 89 | if os.path.exists(save_file) is False: 90 | os.makedirs(save_file) 91 | for i, file_name in enumerate(select_file_list): 92 | original_path = os.path.join(video_path, file_name) 93 | save_path = os.path.join(save_file, f'{str(i + 1).zfill(6)}.png') 94 | shutil.copyfile(original_path, save_path) 95 | 96 | 97 | 98 | def get_metric(fake_dir, real_dir): 99 | Get_metric = Reconstruction_Metrics() 100 | psnr_out, ssim_out = Get_metric.calculate_metric(fake_dir, real_dir) 101 | save_txt = os.path.join(fake_dir, 'metric.txt') 102 | with open(save_txt, 'a') as txt2: 103 | txt2.write("psnr:") 104 | txt2.write(str(psnr_out)) 105 | txt2.write(" ") 106 | txt2.write("ssim:") 107 | txt2.write(str(ssim_out)) 108 | txt2.write('\n') 109 | 110 | 111 | # def get_metric(fake_dir, real_dir): 112 | # Get_metric = Reconstruction_Metrics() 113 | # psnr_out, ssim_out, lipis_out, fid_value_out = Get_metric.calculate_metric(fake_dir, real_dir) 114 | # save_txt = os.path.join(fake_dir, 'metric.txt') 115 | # with open(save_txt, 'a') as txt2: 116 | # txt2.write("psnr:") 117 | # txt2.write(str(psnr_out)) 118 | # txt2.write(" ") 119 | # txt2.write("ssim:") 120 | # txt2.write(str(ssim_out)) 121 | # txt2.write(" ") 122 | # txt2.write("lipis:") 123 | # txt2.write(str(lipis_out)) 124 | # txt2.write(" ") 125 | # txt2.write("fid:") 126 | # txt2.write(str(fid_value_out)) 127 | # txt2.write('\n') 128 | 129 | 130 | if __name__ == "__main__": 131 | real_path = r'/apdcephfs/share_1290939/kumamzqliu/data/face_inversion/test' 132 | fake_path = r'/apdcephfs/share_1290939/kumamzqliu/resutls/ours/inference_w' 133 | get_metric(real_path, fake_path) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/models/__init__.py -------------------------------------------------------------------------------- /models/base_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import torch.nn as nn 7 | from torch.nn import init 8 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 9 | 10 | class BaseNetwork(nn.Module): 11 | def __init__(self): 12 | super(BaseNetwork, self).__init__() 13 | 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | return parser 17 | 18 | def print_network(self): 19 | if isinstance(self, list): 20 | self = self[0] 21 | num_params = 0 22 | for param in self.parameters(): 23 | num_params += param.numel() 24 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 25 | 'To see the architecture, do print(network).' 26 | % (type(self).__name__, num_params / 1000000)) 27 | 28 | def init_weights(self, init_type='normal', gain=0.02): 29 | def init_func(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('BatchNorm2d') != -1 or classname.find('LayerNorm') != -1: 32 | if hasattr(m, 'weight') and m.weight is not None: 33 | init.normal_(m.weight.data, 1.0, gain) 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif hasattr(m, 'weight') and (classname.find('Linear') != -1): 37 | trunc_normal_(m.weight.data, std=.02) 38 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 39 | if init_type == 'normal': 40 | init.normal_(m.weight.data, 0.0, gain) 41 | elif init_type == 'xavier': 42 | init.xavier_normal_(m.weight.data, gain=gain) 43 | elif init_type == 'xavier_uniform': 44 | init.xavier_uniform_(m.weight.data, gain=1.0) 45 | elif init_type == 'kaiming': 46 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 47 | elif init_type == 'orthogonal': 48 | init.orthogonal_(m.weight.data, gain=gain) 49 | elif init_type == 'none': # uses pytorch's default init method 50 | m.reset_parameters() 51 | else: 52 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 53 | if hasattr(m, 'bias') and m.bias is not None: 54 | init.constant_(m.bias.data, 0.0) 55 | 56 | self.apply(init_func) 57 | 58 | # propagate to children 59 | for m in self.children(): 60 | if hasattr(m, 'init_weights'): 61 | m.init_weights(init_type, gain) 62 | -------------------------------------------------------------------------------- /models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/models/encoders/__init__.py -------------------------------------------------------------------------------- /models/encoders/fapsp_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module 6 | 7 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE 8 | from models.stylegan2.model import EqualLinear 9 | 10 | 11 | class GradualStyleBlock(Module): 12 | def __init__(self, in_c, out_c, spatial): 13 | super(GradualStyleBlock, self).__init__() 14 | self.out_c = out_c 15 | self.spatial = spatial 16 | num_pools = int(np.log2(spatial)) 17 | modules = [] 18 | modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), 19 | nn.LeakyReLU()] 20 | for i in range(num_pools - 1): 21 | modules += [ 22 | Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), 23 | nn.LeakyReLU() 24 | ] 25 | self.convs = 
nn.Sequential(*modules) 26 | self.linear = EqualLinear(out_c, out_c, lr_mul=1) 27 | 28 | def forward(self, x): 29 | x = self.convs(x) 30 | x = x.view(-1, self.out_c) 31 | x = self.linear(x) 32 | return x 33 | 34 | 35 | class GradualStyleEncoder(Module): 36 | def __init__(self, num_layers, mode='ir', opts=None): 37 | super(GradualStyleEncoder, self).__init__() 38 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 39 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 40 | blocks = get_blocks(num_layers) 41 | if mode == 'ir': 42 | unit_module = bottleneck_IR 43 | elif mode == 'ir_se': 44 | unit_module = bottleneck_IR_SE 45 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 46 | BatchNorm2d(64), 47 | PReLU(64)) 48 | modules = [] 49 | for block in blocks: 50 | for bottleneck in block: 51 | modules.append(unit_module(bottleneck.in_channel, 52 | bottleneck.depth, 53 | bottleneck.stride)) 54 | self.body = Sequential(*modules) 55 | 56 | self.styles = nn.ModuleList() 57 | self.style_count = opts.n_styles 58 | self.coarse_ind = 3 59 | self.middle_ind = 7 60 | for i in range(self.style_count): 61 | if i < self.coarse_ind: 62 | style = GradualStyleBlock(512, 512, 16) 63 | elif i < self.middle_ind: 64 | style = GradualStyleBlock(512, 512, 32) 65 | else: 66 | style = GradualStyleBlock(512, 512, 64) 67 | self.styles.append(style) 68 | self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) 69 | self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) 70 | 71 | def _upsample_add(self, x, y): 72 | '''Upsample and add two feature maps. 73 | Args: 74 | x: (Variable) top feature map to be upsampled. 75 | y: (Variable) lateral feature map. 76 | Returns: 77 | (Variable) added feature map. 78 | Note in PyTorch, when input size is odd, the upsampled feature map 79 | with `F.upsample(..., scale_factor=2, mode='nearest')` 80 | maybe not equal to the lateral feature map size. 81 | e.g. 82 | original input size: [N,_,15,15] -> 83 | conv2d feature map size: [N,_,8,8] -> 84 | upsampled feature map size: [N,_,16,16] 85 | So we choose bilinear upsample which supports arbitrary output sizes. 
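Concretely, the code below calls F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) with (H, W) taken from y, so x is resized to exactly y's spatial resolution before the element-wise addition.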
86 | ''' 87 | _, _, H, W = y.size() 88 | return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y 89 | 90 | def forward(self, x): 91 | x = self.input_layer(x) 92 | 93 | latents = [] 94 | modulelist = list(self.body._modules.values()) 95 | for i, l in enumerate(modulelist): 96 | x = l(x) 97 | if i == 6: 98 | c1 = x 99 | elif i == 20: 100 | c2 = x 101 | elif i == 23: 102 | c3 = x 103 | 104 | for j in range(self.coarse_ind): 105 | latents.append(self.styles[j](c3)) 106 | 107 | p2 = self._upsample_add(c3, self.latlayer1(c2)) 108 | for j in range(self.coarse_ind, self.middle_ind): 109 | latents.append(self.styles[j](p2)) 110 | 111 | p1 = self._upsample_add(p2, self.latlayer2(c1)) 112 | for j in range(self.middle_ind, self.style_count): 113 | latents.append(self.styles[j](p1)) 114 | 115 | out = torch.stack(latents, dim=1) 116 | return out 117 | 118 | 119 | class BackboneEncoderUsingLastLayerIntoW(Module): 120 | def __init__(self, num_layers, mode='ir', opts=None): 121 | super(BackboneEncoderUsingLastLayerIntoW, self).__init__() 122 | print('Using BackboneEncoderUsingLastLayerIntoW') 123 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 124 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 125 | blocks = get_blocks(num_layers) 126 | if mode == 'ir': 127 | unit_module = bottleneck_IR 128 | elif mode == 'ir_se': 129 | unit_module = bottleneck_IR_SE 130 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 131 | BatchNorm2d(64), 132 | PReLU(64)) 133 | self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) 134 | self.linear = EqualLinear(512, 512, lr_mul=1) 135 | modules = [] 136 | for block in blocks: 137 | for bottleneck in block: 138 | modules.append(unit_module(bottleneck.in_channel, 139 | bottleneck.depth, 140 | bottleneck.stride)) 141 | self.body = Sequential(*modules) 142 | 143 | def forward(self, x): 144 | x = self.input_layer(x) 145 | x = self.body(x) 146 | x = self.output_pool(x) 147 | x = x.view(-1, 512) 148 | x = self.linear(x) 149 | return x 150 | 151 | 152 | class BackboneEncoderUsingLastLayerIntoWPlus(Module): 153 | def __init__(self, num_layers, mode='ir', opts=None): 154 | super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__() 155 | print('Using BackboneEncoderUsingLastLayerIntoWPlus') 156 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 157 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 158 | blocks = get_blocks(num_layers) 159 | if mode == 'ir': 160 | unit_module = bottleneck_IR 161 | elif mode == 'ir_se': 162 | unit_module = bottleneck_IR_SE 163 | self.n_styles = opts.n_styles 164 | self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False), 165 | BatchNorm2d(64), 166 | PReLU(64)) 167 | self.output_layer_2 = Sequential(BatchNorm2d(512), 168 | torch.nn.AdaptiveAvgPool2d((7, 7)), 169 | Flatten(), 170 | Linear(512 * 7 * 7, 512)) 171 | self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1) 172 | modules = [] 173 | for block in blocks: 174 | for bottleneck in block: 175 | modules.append(unit_module(bottleneck.in_channel, 176 | bottleneck.depth, 177 | bottleneck.stride)) 178 | self.body = Sequential(*modules) 179 | 180 | def forward(self, x): 181 | x = self.input_layer(x) 182 | x = self.body(x) 183 | x = self.output_layer_2(x) 184 | x = self.linear(x) 185 | x = x.view(-1, self.n_styles, 512) 186 | return -------------------------------------------------------------------------------- /models/encoders/helpers.py: 
-------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. """ 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | 121 | -------------------------------------------------------------------------------- /models/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | 
self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /models/encoders/projection_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ProjectionHead(nn.Module): 4 | def __init__( 5 | self, 6 | opts, 7 | embedding_dim, 8 | 9 | ): 10 | super().__init__() 11 | projection_dim = opts.projection_dim if opts.projection_dim is not None else 512 12 | dropout = opts.projection_dropout if opts.projection_dropout is not None else 0.1 13 | self.projection = nn.Linear(embedding_dim, projection_dim) 14 | self.gelu = nn.GELU() 15 | self.fc = nn.Linear(projection_dim, projection_dim) 16 | self.dropout = nn.Dropout(dropout) 17 | self.layer_norm = nn.LayerNorm(projection_dim) 18 | 19 | def forward(self, x): 20 | projected = self.projection(x) 21 | x = self.gelu(projected) 22 | x = self.fc(x) 23 | x = self.dropout(x) 24 | x = x + projected 25 | x = self.layer_norm(x) 26 | return x 27 | 28 | class Projection(nn.Module): 29 | """ 30 | Creates projection head 31 | Args: 32 | n_in (int): Number of input features 33 | n_hidden (int): Number of hidden features 34 | n_out (int): Number of output features 35 | use_bn (bool): Whether to use batch norm 36 | """ 37 | 38 | def __init__(self, n_in: int, n_hidden: int, n_out: int, 39 | use_bn: bool = True): 40 | super().__init__() 41 | 42 | # No point in using bias if we've batch norm 43 | self.lin1 = nn.Linear(n_in, n_hidden, bias=not use_bn) 44 | self.bn = nn.BatchNorm1d(n_hidden) if use_bn else nn.Identity() 45 | self.relu = nn.ReLU() 46 | # No bias for the final linear layer 47 | self.lin2 = nn.Linear(n_hidden, n_out, bias=False) 48 | 49 | def forward(self, x): 50 | x = self.lin1(x) 51 | x = self.bn(x) 52 | x = self.relu(x) 53 | x = self.lin2(x) 54 | return x -------------------------------------------------------------------------------- 
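A minimal usage sketch for the Projection head defined above. The 512-dimensional sizes are illustrative (the encoders in this repo pass their embedding dimension in via opts), and eval() is used only so BatchNorm1d relies on its running statistics for this toy forward pass.

import torch
from models.encoders.projection_head import Projection

proj = Projection(n_in=512, n_hidden=512, n_out=512)   # illustrative sizes
proj.eval()                                            # BatchNorm1d uses running stats here
with torch.no_grad():
    out = proj(torch.randn(4, 512))
print(out.shape)                                       # torch.Size([4, 512])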
/models/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | DETR Transformer class. 4 | 5 | Copy-paste from torch.nn.Transformer with modifications: 6 | * positional encodings are passed in MHattention 7 | * extra LN at the end of encoder is removed 8 | * decoder returns a stack of activations from all decoding layers 9 | """ 10 | import copy 11 | from typing import Optional, List 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | 18 | class TransformerEncoderLayer(nn.Module): 19 | 20 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 21 | activation="relu", normalize_before=False): 22 | super().__init__() 23 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 24 | # Implementation of Feedforward model 25 | self.linear1 = nn.Linear(d_model, dim_feedforward) 26 | self.dropout = nn.Dropout(dropout) 27 | self.linear2 = nn.Linear(dim_feedforward, d_model) 28 | 29 | self.norm1 = nn.LayerNorm(d_model) 30 | self.norm2 = nn.LayerNorm(d_model) 31 | self.dropout1 = nn.Dropout(dropout) 32 | self.dropout2 = nn.Dropout(dropout) 33 | 34 | self.activation = _get_activation_fn(activation) 35 | self.normalize_before = normalize_before 36 | 37 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 38 | return tensor if pos is None else tensor + pos 39 | 40 | def forward_post(self, 41 | src, 42 | src_mask: Optional[Tensor] = None, 43 | src_key_padding_mask: Optional[Tensor] = None, 44 | pos: Optional[Tensor] = None): 45 | q = k = self.with_pos_embed(src, pos) 46 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 47 | key_padding_mask=src_key_padding_mask)[0] 48 | src = src + self.dropout1(src2) 49 | src = self.norm1(src) 50 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 51 | src = src + self.dropout2(src2) 52 | src = self.norm2(src) 53 | return src 54 | 55 | def forward_pre(self, src, 56 | src_mask: Optional[Tensor] = None, 57 | src_key_padding_mask: Optional[Tensor] = None, 58 | pos: Optional[Tensor] = None): 59 | src2 = self.norm1(src) 60 | q = k = self.with_pos_embed(src2, pos) 61 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 62 | key_padding_mask=src_key_padding_mask)[0] 63 | src = src + self.dropout1(src2) 64 | src2 = self.norm2(src) 65 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 66 | src = src + self.dropout2(src2) 67 | return src 68 | 69 | def forward(self, src, 70 | src_mask: Optional[Tensor] = None, 71 | src_key_padding_mask: Optional[Tensor] = None, 72 | pos: Optional[Tensor] = None): 73 | if self.normalize_before: 74 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 75 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 76 | 77 | 78 | class CrossAttention(nn.Module): 79 | 80 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 81 | activation="relu", normalize_before=False): 82 | super().__init__() 83 | # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 84 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 85 | # Implementation of Feedforward model 86 | self.linear1 = nn.Linear(d_model, dim_feedforward) 87 | self.dropout = nn.Dropout(dropout) 88 | self.linear2 = nn.Linear(dim_feedforward, d_model) 89 | 90 | self.norm1 = nn.LayerNorm(d_model) 91 | self.norm2 = 
nn.LayerNorm(d_model) 92 | self.norm3 = nn.LayerNorm(d_model) 93 | self.dropout1 = nn.Dropout(dropout) 94 | self.dropout2 = nn.Dropout(dropout) 95 | self.dropout3 = nn.Dropout(dropout) 96 | 97 | self.activation = _get_activation_fn(activation) 98 | self.normalize_before = normalize_before 99 | 100 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 101 | return tensor if pos is None else tensor + pos 102 | 103 | def forward(self, tgt, memory, 104 | tgt_mask: Optional[Tensor] = None, 105 | memory_mask: Optional[Tensor] = None, 106 | tgt_key_padding_mask: Optional[Tensor] = None, 107 | memory_key_padding_mask: Optional[Tensor] = None, 108 | pos: Optional[Tensor] = None, 109 | query_pos: Optional[Tensor] = None, 110 | no_res=False): 111 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 112 | key=self.with_pos_embed(memory, pos), 113 | value=memory, attn_mask=memory_mask, 114 | key_padding_mask=memory_key_padding_mask)[0] 115 | # if no_res: 116 | # tgt = tgt2 117 | # else: 118 | tgt = tgt + self.dropout2(tgt2) 119 | tgt = self.norm2(tgt) 120 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 121 | tgt = tgt + self.dropout3(tgt2) 122 | tgt = self.norm3(tgt) 123 | return tgt 124 | 125 | 126 | def _get_activation_fn(activation): 127 | """Return an activation function given a string""" 128 | if activation == "relu": 129 | return F.relu 130 | if activation == "gelu": 131 | return F.gelu 132 | if activation == "glu": 133 | return F.glu 134 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 135 | -------------------------------------------------------------------------------- /models/image_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import matplotlib 5 | 6 | matplotlib.use('Agg') 7 | import math 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | from .base_network import BaseNetwork 13 | from models.encoders import psp_encoders 14 | from configs.paths_config import model_paths 15 | from .encoders.projection_head import ProjectionHead, Projection 16 | 17 | 18 | def get_keys(d, name): 19 | if 'state_dict' in d: 20 | d = d['state_dict'] 21 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 22 | return d_filt 23 | 24 | 25 | class ImageEncoder(BaseNetwork): 26 | 27 | def __init__(self, opts): 28 | super(ImageEncoder, self).__init__() 29 | self.set_opts(opts) 30 | # Define architecture 31 | self.encoder = self.set_encoder() 32 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 33 | # Load weights if needed n_in: int, n_hidden: int, n_out: int 34 | self.image_projection = Projection(n_in=opts.image_embedding_dim, n_hidden=opts.image_embedding_dim, 35 | n_out=opts.image_embedding_dim) 36 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 37 | 38 | def set_encoder(self): 39 | if self.opts.encoder_type == 'GradualStyleEncoder': 40 | encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) 41 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW': 42 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts) 43 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus': 44 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts) 45 | else: 46 | raise Exception('{} is not a valid 
encoders'.format(self.opts.encoder_type)) 47 | return encoder 48 | 49 | def load_weights(self, model_path=None): 50 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) 51 | if self.opts.checkpoint_path_image is not None or model_path is not None: 52 | if model_path is not None: 53 | pretrained_model = model_path 54 | else: 55 | pretrained_model = self.opts.checkpoint_path_image 56 | print('Loading latentencoder from checkpoint: {}'.format(pretrained_model)) 57 | ckpt = torch.load(pretrained_model, map_location='cpu') 58 | self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) 59 | self.image_projection.load_state_dict(get_keys(ckpt, 'image_projection'), strict=True) 60 | self.logit_scale = nn.Parameter(get_keys(ckpt, 'logit_scale')['']) 61 | elif self.opts.load_pretrain_image_encoder: 62 | print('Loading encoders weights from irse50!') 63 | encoder_ckpt = torch.load(model_paths['ir_se50']) 64 | # if input to encoder is not an RGB image, do not load the input layer weights 65 | if self.opts.label_nc != 0: 66 | encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k} 67 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 68 | print('Loading W_avg from pretrained!') 69 | ckpt = torch.load(self.opts.stylegan_weights) 70 | if self.opts.learn_in_w: 71 | self.__load_latent_avg(ckpt, repeat=1) 72 | else: 73 | self.__load_latent_avg(ckpt, repeat=self.opts.n_styles) 74 | else: 75 | print('No former trained model') 76 | 77 | def forward(self, x, x_avg): 78 | out = self.encoder(x - x_avg) 79 | image_feature = self.image_projection(out) 80 | if self.opts.use_norm: 81 | image_feature = F.normalize(image_feature, dim=-1) 82 | return image_feature, self.logit_scale.exp() 83 | 84 | def set_opts(self, opts): 85 | self.opts = opts 86 | 87 | def __load_latent_avg(self, ckpt, repeat=None): 88 | if 'latent_avg' in ckpt: 89 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 90 | if repeat is not None: 91 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 92 | else: 93 | self.latent_avg = None 94 | -------------------------------------------------------------------------------- /models/latent_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | from models.stylegan2.op import fused_leaky_relu 5 | from models.encoders.transformer import TransformerEncoderLayer 6 | import torch.nn as nn 7 | from .encoders.projection_head import Projection 8 | from .base_network import BaseNetwork 9 | 10 | 11 | def get_keys(d, name): 12 | if 'state_dict' in d: 13 | d = d['state_dict'] 14 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 15 | return d_filt 16 | 17 | 18 | class EqualLinear(nn.Module): 19 | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): 20 | super().__init__() 21 | 22 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 23 | 24 | if bias: 25 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 26 | 27 | else: 28 | self.bias = None 29 | 30 | self.activation = activation 31 | 32 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 33 | self.lr_mul = lr_mul 34 | 35 | def forward(self, input): 36 | if self.activation: 37 | out = F.linear(input, self.weight * self.scale) 38 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 39 | 40 | else: 41 | out = F.linear( 42 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 43 | ) 44 | 45 | return 
out 46 | 47 | def __repr__(self): 48 | return ( 49 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 50 | ) 51 | 52 | 53 | class EncoderMlp(nn.Module): 54 | def __init__(self, dim, n_mlp, lr_mlp=0.01): 55 | super().__init__() 56 | layers = [] 57 | 58 | for i in range(n_mlp): 59 | layers.append( 60 | EqualLinear( 61 | dim, dim, lr_mul=lr_mlp, activation='fused_lrelu' 62 | ) 63 | ) 64 | 65 | self.latent_encoder = nn.Sequential(*layers) 66 | 67 | def forward(self, x, x_avg): 68 | out = self.latent_encoder(x - x_avg) 69 | return out 70 | 71 | 72 | class LatentEncoder(BaseNetwork): 73 | def __init__(self, opts): 74 | super(LatentEncoder, self).__init__() 75 | self.set_opts(opts) 76 | num_latent = self.opts.num_latent 77 | dim = self.opts.latent_embedding_dim 78 | self.pos_embedding = nn.Parameter(torch.randn(1, num_latent, dim)) 79 | self.coarse = TransformerEncoderLayer(d_model=512, nhead=4, dim_feedforward=1024) 80 | self.medium = TransformerEncoderLayer(d_model=512, nhead=4, dim_feedforward=1024) 81 | self.fine = TransformerEncoderLayer(d_model=512, nhead=4, dim_feedforward=1024) 82 | self.latent_projection = Projection(n_in=opts.latent_embedding_dim, n_hidden=opts.latent_embedding_dim, 83 | n_out=opts.latent_embedding_dim) 84 | 85 | def load_weights(self, model_path=None): 86 | if self.opts.checkpoint_path_latent is not None or model_path is not None: 87 | if model_path is not None: 88 | pretrained_model = model_path 89 | else: 90 | pretrained_model = self.opts.checkpoint_path_latent 91 | print('Loading latentencoder from checkpoint: {}'.format(pretrained_model)) 92 | ckpt = torch.load(pretrained_model, map_location='cpu') 93 | self.coarse.load_state_dict(get_keys(ckpt, 'coarse'), strict=True) 94 | self.medium.load_state_dict(get_keys(ckpt, 'medium'), strict=True) 95 | self.fine.load_state_dict(get_keys(ckpt, 'fine'), strict=True) 96 | self.latent_projection.load_state_dict(get_keys(ckpt, 'latent_projection'), strict=True) 97 | 98 | self.pos_embedding = nn.Parameter(get_keys(ckpt, 'pos_embedding')['']) 99 | else: 100 | print('No former trained model') 101 | pass 102 | 103 | def forward(self, x, x_avg, return_all=False): 104 | latent_input = x - x_avg 105 | latent_input = latent_input.permute(1, 0, 2) # N B C 106 | latent_coarse = self.coarse(latent_input, pos=self.pos_embedding) 107 | latent_medium = self.medium(latent_coarse, pos=self.pos_embedding) 108 | latent_fine = self.fine(latent_medium, pos=self.pos_embedding) 109 | latent_out = latent_fine.permute(1, 0, 2) 110 | latent_out = latent_out.squeeze(1) 111 | latent_feature = self.latent_projection(latent_out) 112 | if self.opts.use_norm: 113 | latent_feature = F.normalize(latent_feature, dim=-1) 114 | if return_all: 115 | return latent_out, latent_coarse.permute(1, 0, 2), latent_medium.permute(1, 0, 2) 116 | else: 117 | return latent_feature 118 | 119 | def set_opts(self, opts): 120 | self.opts = opts 121 | 122 | 123 | # unit test 124 | if __name__ == '__main__': 125 | model = LatentEncoder().cuda() 126 | input_test = torch.rand(1, 1, 512).cuda() 127 | latent_avg = torch.rand(1, 1, 512).cuda() 128 | out = model(input_test, latent_avg) 129 | print(out.shape) 130 | -------------------------------------------------------------------------------- /models/mtcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/models/mtcnn/__init__.py 
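Referring back to models/latent_encoder.py above: its __main__ unit test calls LatentEncoder() without the required opts argument, so it fails as written. A hedged working sketch is shown below; the option values are illustrative, and running it assumes a CUDA machine where the fused ops imported at the top of that file (models/stylegan2/op) can be built.

from types import SimpleNamespace
import torch
from models.latent_encoder import LatentEncoder

opts = SimpleNamespace(num_latent=1, latent_embedding_dim=512,   # illustrative values
                       use_norm=True, checkpoint_path_latent=None)
model = LatentEncoder(opts).cuda().eval()
w = torch.rand(2, 1, 512).cuda()        # a batch of W latents
w_avg = torch.rand(1, 1, 512).cuda()    # average latent used for re-centering
with torch.no_grad():
    feature = model(w, w_avg)           # L2-normalized latent feature
print(feature.shape)                    # torch.Size([2, 512])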
-------------------------------------------------------------------------------- /models/mtcnn/mtcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet 5 | from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage 7 | from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face 8 | 9 | device = 'cuda:0' 10 | 11 | 12 | class MTCNN(): 13 | def __init__(self): 14 | print(device) 15 | self.pnet = PNet().to(device) 16 | self.rnet = RNet().to(device) 17 | self.onet = ONet().to(device) 18 | self.pnet.eval() 19 | self.rnet.eval() 20 | self.onet.eval() 21 | self.refrence = get_reference_facial_points(default_square=True) 22 | 23 | def align(self, img): 24 | _, landmarks = self.detect_faces(img) 25 | if len(landmarks) == 0: 26 | return None, None 27 | facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] 28 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 29 | return Image.fromarray(warped_face), tfm 30 | 31 | def align_multi(self, img, limit=None, min_face_size=30.0): 32 | boxes, landmarks = self.detect_faces(img, min_face_size) 33 | if limit: 34 | boxes = boxes[:limit] 35 | landmarks = landmarks[:limit] 36 | faces = [] 37 | tfms = [] 38 | for landmark in landmarks: 39 | facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)] 40 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 41 | faces.append(Image.fromarray(warped_face)) 42 | tfms.append(tfm) 43 | return boxes, faces, tfms 44 | 45 | def detect_faces(self, image, min_face_size=20.0, 46 | thresholds=[0.15, 0.25, 0.35], 47 | nms_thresholds=[0.7, 0.7, 0.7]): 48 | """ 49 | Arguments: 50 | image: an instance of PIL.Image. 51 | min_face_size: a float number. 52 | thresholds: a list of length 3. 53 | nms_thresholds: a list of length 3. 54 | 55 | Returns: 56 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 57 | bounding boxes and facial landmarks. 
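Note: both return values are empty lists when no candidate box survives the O-Net stage, and the default thresholds here (0.15 / 0.25 / 0.35) are looser than the 0.6 / 0.7 / 0.8 defaults of the standalone detect_faces in mtcnn_pytorch/src/detector.py.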
58 | """ 59 | 60 | # BUILD AN IMAGE PYRAMID 61 | width, height = image.size 62 | min_length = min(height, width) 63 | 64 | min_detection_size = 12 65 | factor = 0.707 # sqrt(0.5) 66 | 67 | # scales for scaling the image 68 | scales = [] 69 | 70 | # scales the image so that 71 | # minimum size that we can detect equals to 72 | # minimum face size that we want to detect 73 | m = min_detection_size / min_face_size 74 | min_length *= m 75 | 76 | factor_count = 0 77 | while min_length > min_detection_size: 78 | scales.append(m * factor ** factor_count) 79 | min_length *= factor 80 | factor_count += 1 81 | 82 | # STAGE 1 83 | 84 | # it will be returned 85 | bounding_boxes = [] 86 | 87 | with torch.no_grad(): 88 | # run P-Net on different scales 89 | for s in scales: 90 | boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0]) 91 | bounding_boxes.append(boxes) 92 | 93 | # collect boxes (and offsets, and scores) from different scales 94 | bounding_boxes = [i for i in bounding_boxes if i is not None] 95 | bounding_boxes = np.vstack(bounding_boxes) 96 | 97 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 98 | bounding_boxes = bounding_boxes[keep] 99 | 100 | # use offsets predicted by pnet to transform bounding boxes 101 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 102 | # shape [n_boxes, 5] 103 | 104 | bounding_boxes = convert_to_square(bounding_boxes) 105 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 106 | 107 | # STAGE 2 108 | 109 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 110 | img_boxes = torch.FloatTensor(img_boxes).to(device) 111 | 112 | output = self.rnet(img_boxes) 113 | offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] 114 | probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] 115 | 116 | keep = np.where(probs[:, 1] > thresholds[1])[0] 117 | bounding_boxes = bounding_boxes[keep] 118 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 119 | offsets = offsets[keep] 120 | 121 | keep = nms(bounding_boxes, nms_thresholds[1]) 122 | bounding_boxes = bounding_boxes[keep] 123 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 124 | bounding_boxes = convert_to_square(bounding_boxes) 125 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 126 | 127 | # STAGE 3 128 | 129 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 130 | if len(img_boxes) == 0: 131 | return [], [] 132 | img_boxes = torch.FloatTensor(img_boxes).to(device) 133 | output = self.onet(img_boxes) 134 | landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] 135 | offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] 136 | probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] 137 | 138 | keep = np.where(probs[:, 1] > thresholds[2])[0] 139 | bounding_boxes = bounding_boxes[keep] 140 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 141 | offsets = offsets[keep] 142 | landmarks = landmarks[keep] 143 | 144 | # compute landmark points 145 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 146 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 147 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 148 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 149 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 150 | 151 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 152 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 153 | bounding_boxes = 
bounding_boxes[keep] 154 | landmarks = landmarks[keep] 155 | 156 | return bounding_boxes, landmarks 157 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/models/mtcnn/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization_utils import show_bboxes 2 | from .detector import detect_faces 3 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/box_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def nms(boxes, overlap_threshold=0.5, mode='union'): 6 | """Non-maximum suppression. 7 | 8 | Arguments: 9 | boxes: a float numpy array of shape [n, 5], 10 | where each row is (xmin, ymin, xmax, ymax, score). 11 | overlap_threshold: a float number. 12 | mode: 'union' or 'min'. 13 | 14 | Returns: 15 | list with indices of the selected boxes 16 | """ 17 | 18 | # if there are no boxes, return the empty list 19 | if len(boxes) == 0: 20 | return [] 21 | 22 | # list of picked indices 23 | pick = [] 24 | 25 | # grab the coordinates of the bounding boxes 26 | x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] 27 | 28 | area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0) 29 | ids = np.argsort(score) # in increasing order 30 | 31 | while len(ids) > 0: 32 | 33 | # grab index of the largest value 34 | last = len(ids) - 1 35 | i = ids[last] 36 | pick.append(i) 37 | 38 | # compute intersections 39 | # of the box with the largest score 40 | # with the rest of boxes 41 | 42 | # left top corner of intersection boxes 43 | ix1 = np.maximum(x1[i], x1[ids[:last]]) 44 | iy1 = np.maximum(y1[i], y1[ids[:last]]) 45 | 46 | # right bottom corner of intersection boxes 47 | ix2 = np.minimum(x2[i], x2[ids[:last]]) 48 | iy2 = np.minimum(y2[i], y2[ids[:last]]) 49 | 50 | # width and height of intersection boxes 51 | w = np.maximum(0.0, ix2 - ix1 + 1.0) 52 | h = np.maximum(0.0, iy2 - iy1 + 1.0) 53 | 54 | # intersections' areas 55 | inter = w * h 56 | if mode == 'min': 57 | overlap = inter / np.minimum(area[i], area[ids[:last]]) 58 | elif mode == 'union': 59 | # intersection over union (IoU) 60 | overlap = inter / (area[i] + area[ids[:last]] - inter) 61 | 62 | # delete all boxes where overlap is too big 63 | ids = np.delete( 64 | ids, 65 | np.concatenate([[last], np.where(overlap > overlap_threshold)[0]]) 66 | ) 67 | 68 | return pick 69 | 70 | 71 | def convert_to_square(bboxes): 72 | """Convert bounding boxes to a square form. 73 | 74 | Arguments: 75 | bboxes: a float numpy array of shape [n, 5]. 76 | 77 | Returns: 78 | a float numpy array of shape [n, 5], 79 | squared bounding boxes. 
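Each output box keeps the center of its input box and has both sides expanded to max(height, width), so the crops taken downstream are square.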
80 | """ 81 | 82 | square_bboxes = np.zeros_like(bboxes) 83 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 84 | h = y2 - y1 + 1.0 85 | w = x2 - x1 + 1.0 86 | max_side = np.maximum(h, w) 87 | square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5 88 | square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5 89 | square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 90 | square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 91 | return square_bboxes 92 | 93 | 94 | def calibrate_box(bboxes, offsets): 95 | """Transform bounding boxes to be more like true bounding boxes. 96 | 'offsets' is one of the outputs of the nets. 97 | 98 | Arguments: 99 | bboxes: a float numpy array of shape [n, 5]. 100 | offsets: a float numpy array of shape [n, 4]. 101 | 102 | Returns: 103 | a float numpy array of shape [n, 5]. 104 | """ 105 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 106 | w = x2 - x1 + 1.0 107 | h = y2 - y1 + 1.0 108 | w = np.expand_dims(w, 1) 109 | h = np.expand_dims(h, 1) 110 | 111 | # this is what happening here: 112 | # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] 113 | # x1_true = x1 + tx1*w 114 | # y1_true = y1 + ty1*h 115 | # x2_true = x2 + tx2*w 116 | # y2_true = y2 + ty2*h 117 | # below is just more compact form of this 118 | 119 | # are offsets always such that 120 | # x1 < x2 and y1 < y2 ? 121 | 122 | translation = np.hstack([w, h, w, h]) * offsets 123 | bboxes[:, 0:4] = bboxes[:, 0:4] + translation 124 | return bboxes 125 | 126 | 127 | def get_image_boxes(bounding_boxes, img, size=24): 128 | """Cut out boxes from the image. 129 | 130 | Arguments: 131 | bounding_boxes: a float numpy array of shape [n, 5]. 132 | img: an instance of PIL.Image. 133 | size: an integer, size of cutouts. 134 | 135 | Returns: 136 | a float numpy array of shape [n, 3, size, size]. 137 | """ 138 | 139 | num_boxes = len(bounding_boxes) 140 | width, height = img.size 141 | 142 | [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height) 143 | img_boxes = np.zeros((num_boxes, 3, size, size), 'float32') 144 | 145 | for i in range(num_boxes): 146 | img_box = np.zeros((h[i], w[i], 3), 'uint8') 147 | 148 | img_array = np.asarray(img, 'uint8') 149 | img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \ 150 | img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :] 151 | 152 | # resize 153 | img_box = Image.fromarray(img_box) 154 | img_box = img_box.resize((size, size), Image.BILINEAR) 155 | img_box = np.asarray(img_box, 'float32') 156 | 157 | img_boxes[i, :, :, :] = _preprocess(img_box) 158 | 159 | return img_boxes 160 | 161 | 162 | def correct_bboxes(bboxes, width, height): 163 | """Crop boxes that are too big and get coordinates 164 | with respect to cutouts. 165 | 166 | Arguments: 167 | bboxes: a float numpy array of shape [n, 5], 168 | where each row is (xmin, ymin, xmax, ymax, score). 169 | width: a float number. 170 | height: a float number. 171 | 172 | Returns: 173 | dy, dx, edy, edx: a int numpy arrays of shape [n], 174 | coordinates of the boxes with respect to the cutouts. 175 | y, x, ey, ex: a int numpy arrays of shape [n], 176 | corrected ymin, xmin, ymax, xmax. 177 | h, w: a int numpy arrays of shape [n], 178 | just heights and widths of boxes. 179 | 180 | in the following order: 181 | [dy, edy, dx, edx, y, ey, x, ex, w, h]. 
182 | """ 183 | 184 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 185 | w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 186 | num_boxes = bboxes.shape[0] 187 | 188 | # 'e' stands for end 189 | # (x, y) -> (ex, ey) 190 | x, y, ex, ey = x1, y1, x2, y2 191 | 192 | # we need to cut out a box from the image. 193 | # (x, y, ex, ey) are corrected coordinates of the box 194 | # in the image. 195 | # (dx, dy, edx, edy) are coordinates of the box in the cutout 196 | # from the image. 197 | dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,)) 198 | edx, edy = w.copy() - 1.0, h.copy() - 1.0 199 | 200 | # if box's bottom right corner is too far right 201 | ind = np.where(ex > width - 1.0)[0] 202 | edx[ind] = w[ind] + width - 2.0 - ex[ind] 203 | ex[ind] = width - 1.0 204 | 205 | # if box's bottom right corner is too low 206 | ind = np.where(ey > height - 1.0)[0] 207 | edy[ind] = h[ind] + height - 2.0 - ey[ind] 208 | ey[ind] = height - 1.0 209 | 210 | # if box's top left corner is too far left 211 | ind = np.where(x < 0.0)[0] 212 | dx[ind] = 0.0 - x[ind] 213 | x[ind] = 0.0 214 | 215 | # if box's top left corner is too high 216 | ind = np.where(y < 0.0)[0] 217 | dy[ind] = 0.0 - y[ind] 218 | y[ind] = 0.0 219 | 220 | return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] 221 | return_list = [i.astype('int32') for i in return_list] 222 | 223 | return return_list 224 | 225 | 226 | def _preprocess(img): 227 | """Preprocessing step before feeding the network. 228 | 229 | Arguments: 230 | img: a float numpy array of shape [h, w, c]. 231 | 232 | Returns: 233 | a float numpy array of shape [1, c, h, w]. 234 | """ 235 | img = img.transpose((2, 0, 1)) 236 | img = np.expand_dims(img, 0) 237 | img = (img - 127.5) * 0.0078125 238 | return img 239 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from .get_nets import PNet, RNet, ONet 5 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from .first_stage import run_first_stage 7 | 8 | 9 | def detect_faces(image, min_face_size=20.0, 10 | thresholds=[0.6, 0.7, 0.8], 11 | nms_thresholds=[0.7, 0.7, 0.7]): 12 | """ 13 | Arguments: 14 | image: an instance of PIL.Image. 15 | min_face_size: a float number. 16 | thresholds: a list of length 3. 17 | nms_thresholds: a list of length 3. 18 | 19 | Returns: 20 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 21 | bounding boxes and facial landmarks. 
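    Example (sketch; the image path is hypothetical and the pretrained
    P-Net/R-Net/O-Net weights referenced in get_nets.py must be available):
        >>> from PIL import Image
        >>> img = Image.open('some_face_photo.jpg').convert('RGB')
        >>> boxes, landmarks = detect_faces(img)
        >>> # each row of `boxes` is (xmin, ymin, xmax, ymax, score);
        >>> # each row of `landmarks` is 5 x-coordinates followed by 5 y-coordinates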
22 | """ 23 | 24 | # LOAD MODELS 25 | pnet = PNet() 26 | rnet = RNet() 27 | onet = ONet() 28 | onet.eval() 29 | 30 | # BUILD AN IMAGE PYRAMID 31 | width, height = image.size 32 | min_length = min(height, width) 33 | 34 | min_detection_size = 12 35 | factor = 0.707 # sqrt(0.5) 36 | 37 | # scales for scaling the image 38 | scales = [] 39 | 40 | # scales the image so that 41 | # minimum size that we can detect equals to 42 | # minimum face size that we want to detect 43 | m = min_detection_size / min_face_size 44 | min_length *= m 45 | 46 | factor_count = 0 47 | while min_length > min_detection_size: 48 | scales.append(m * factor ** factor_count) 49 | min_length *= factor 50 | factor_count += 1 51 | 52 | # STAGE 1 53 | 54 | # it will be returned 55 | bounding_boxes = [] 56 | 57 | with torch.no_grad(): 58 | # run P-Net on different scales 59 | for s in scales: 60 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 61 | bounding_boxes.append(boxes) 62 | 63 | # collect boxes (and offsets, and scores) from different scales 64 | bounding_boxes = [i for i in bounding_boxes if i is not None] 65 | bounding_boxes = np.vstack(bounding_boxes) 66 | 67 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 68 | bounding_boxes = bounding_boxes[keep] 69 | 70 | # use offsets predicted by pnet to transform bounding boxes 71 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 72 | # shape [n_boxes, 5] 73 | 74 | bounding_boxes = convert_to_square(bounding_boxes) 75 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 76 | 77 | # STAGE 2 78 | 79 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 80 | img_boxes = torch.FloatTensor(img_boxes) 81 | 82 | output = rnet(img_boxes) 83 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 84 | probs = output[1].data.numpy() # shape [n_boxes, 2] 85 | 86 | keep = np.where(probs[:, 1] > thresholds[1])[0] 87 | bounding_boxes = bounding_boxes[keep] 88 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 89 | offsets = offsets[keep] 90 | 91 | keep = nms(bounding_boxes, nms_thresholds[1]) 92 | bounding_boxes = bounding_boxes[keep] 93 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 94 | bounding_boxes = convert_to_square(bounding_boxes) 95 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 96 | 97 | # STAGE 3 98 | 99 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 100 | if len(img_boxes) == 0: 101 | return [], [] 102 | img_boxes = torch.FloatTensor(img_boxes) 103 | output = onet(img_boxes) 104 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 105 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 106 | probs = output[2].data.numpy() # shape [n_boxes, 2] 107 | 108 | keep = np.where(probs[:, 1] > thresholds[2])[0] 109 | bounding_boxes = bounding_boxes[keep] 110 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 111 | offsets = offsets[keep] 112 | landmarks = landmarks[keep] 113 | 114 | # compute landmark points 115 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 116 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 117 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 118 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 119 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 120 | 121 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 122 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 123 | bounding_boxes = 
bounding_boxes[keep] 124 | landmarks = landmarks[keep] 125 | 126 | return bounding_boxes, landmarks 127 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import math 4 | from PIL import Image 5 | import numpy as np 6 | from .box_utils import nms, _preprocess 7 | 8 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | device = 'cuda:0' 10 | 11 | 12 | def run_first_stage(image, net, scale, threshold): 13 | """Run P-Net, generate bounding boxes, and do NMS. 14 | 15 | Arguments: 16 | image: an instance of PIL.Image. 17 | net: an instance of pytorch's nn.Module, P-Net. 18 | scale: a float number, 19 | scale width and height of the image by this number. 20 | threshold: a float number, 21 | threshold on the probability of a face when generating 22 | bounding boxes from predictions of the net. 23 | 24 | Returns: 25 | a float numpy array of shape [n_boxes, 9], 26 | bounding boxes with scores and offsets (4 + 1 + 4). 27 | """ 28 | 29 | # scale the image and convert it to a float array 30 | width, height = image.size 31 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 32 | img = image.resize((sw, sh), Image.BILINEAR) 33 | img = np.asarray(img, 'float32') 34 | 35 | img = torch.FloatTensor(_preprocess(img)).to(device) 36 | with torch.no_grad(): 37 | output = net(img) 38 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 39 | offsets = output[0].cpu().data.numpy() 40 | # probs: probability of a face at each sliding window 41 | # offsets: transformations to true bounding boxes 42 | 43 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 44 | if len(boxes) == 0: 45 | return None 46 | 47 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 48 | return boxes[keep] 49 | 50 | 51 | def _generate_bboxes(probs, offsets, scale, threshold): 52 | """Generate bounding boxes at places 53 | where there is probably a face. 54 | 55 | Arguments: 56 | probs: a float numpy array of shape [n, m]. 57 | offsets: a float numpy array of shape [1, 4, n, m]. 58 | scale: a float number, 59 | width and height of the image were scaled by this number. 60 | threshold: a float number. 61 | 62 | Returns: 63 | a float numpy array of shape [n_boxes, 9] 64 | """ 65 | 66 | # applying P-Net is equivalent, in some sense, to 67 | # moving 12x12 window with stride 2 68 | stride = 2 69 | cell_size = 12 70 | 71 | # indices of boxes where there is probably a face 72 | inds = np.where(probs > threshold) 73 | 74 | if inds[0].size == 0: 75 | return np.array([]) 76 | 77 | # transformations of bounding boxes 78 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 79 | # they are defined as: 80 | # w = x2 - x1 + 1 81 | # h = y2 - y1 + 1 82 | # x1_true = x1 + tx1*w 83 | # x2_true = x2 + tx2*w 84 | # y1_true = y1 + ty1*h 85 | # y2_true = y2 + ty2*h 86 | 87 | offsets = np.array([tx1, ty1, tx2, ty2]) 88 | score = probs[inds[0], inds[1]] 89 | 90 | # P-Net is applied to scaled images 91 | # so we need to rescale bounding boxes back 92 | bounding_boxes = np.vstack([ 93 | np.round((stride * inds[1] + 1.0) / scale), 94 | np.round((stride * inds[0] + 1.0) / scale), 95 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 96 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 97 | score, offsets 98 | ]) 99 | # why one is added? 
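    # (the extra 1.0 is usually attributed to the original Caffe/MATLAB MTCNN
    #  implementation, which used 1-based pixel coordinates; with cell_size=12
    #  and stride=2, detection k at scale s maps back to a roughly 12/s x 12/s
    #  window in the full-resolution image)
    # each column stacked above describes one box: 4 corner coordinates, the
    # face score and the 4 raw offsets, hence the transpose to [n_boxes, 9]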
100 | 101 | return bounding_boxes.T 102 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | from configs.paths_config import model_paths 8 | PNET_PATH = model_paths["mtcnn_pnet"] 9 | ONET_PATH = model_paths["mtcnn_onet"] 10 | RNET_PATH = model_paths["mtcnn_rnet"] 11 | 12 | 13 | class Flatten(nn.Module): 14 | 15 | def __init__(self): 16 | super(Flatten, self).__init__() 17 | 18 | def forward(self, x): 19 | """ 20 | Arguments: 21 | x: a float tensor with shape [batch_size, c, h, w]. 22 | Returns: 23 | a float tensor with shape [batch_size, c*h*w]. 24 | """ 25 | 26 | # without this pretrained model isn't working 27 | x = x.transpose(3, 2).contiguous() 28 | 29 | return x.view(x.size(0), -1) 30 | 31 | 32 | class PNet(nn.Module): 33 | 34 | def __init__(self): 35 | super().__init__() 36 | 37 | # suppose we have input with size HxW, then 38 | # after first layer: H - 2, 39 | # after pool: ceil((H - 2)/2), 40 | # after second conv: ceil((H - 2)/2) - 2, 41 | # after last conv: ceil((H - 2)/2) - 4, 42 | # and the same for W 43 | 44 | self.features = nn.Sequential(OrderedDict([ 45 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 46 | ('prelu1', nn.PReLU(10)), 47 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), 48 | 49 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 50 | ('prelu2', nn.PReLU(16)), 51 | 52 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 53 | ('prelu3', nn.PReLU(32)) 54 | ])) 55 | 56 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 57 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 58 | 59 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 60 | for n, p in self.named_parameters(): 61 | p.data = torch.FloatTensor(weights[n]) 62 | 63 | def forward(self, x): 64 | """ 65 | Arguments: 66 | x: a float tensor with shape [batch_size, 3, h, w]. 67 | Returns: 68 | b: a float tensor with shape [batch_size, 4, h', w']. 69 | a: a float tensor with shape [batch_size, 2, h', w']. 70 | """ 71 | x = self.features(x) 72 | a = self.conv4_1(x) 73 | b = self.conv4_2(x) 74 | a = F.softmax(a, dim=-1) 75 | return b, a 76 | 77 | 78 | class RNet(nn.Module): 79 | 80 | def __init__(self): 81 | super().__init__() 82 | 83 | self.features = nn.Sequential(OrderedDict([ 84 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 85 | ('prelu1', nn.PReLU(28)), 86 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 87 | 88 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 89 | ('prelu2', nn.PReLU(48)), 90 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 91 | 92 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 93 | ('prelu3', nn.PReLU(64)), 94 | 95 | ('flatten', Flatten()), 96 | ('conv4', nn.Linear(576, 128)), 97 | ('prelu4', nn.PReLU(128)) 98 | ])) 99 | 100 | self.conv5_1 = nn.Linear(128, 2) 101 | self.conv5_2 = nn.Linear(128, 4) 102 | 103 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 104 | for n, p in self.named_parameters(): 105 | p.data = torch.FloatTensor(weights[n]) 106 | 107 | def forward(self, x): 108 | """ 109 | Arguments: 110 | x: a float tensor with shape [batch_size, 3, h, w]. 111 | Returns: 112 | b: a float tensor with shape [batch_size, 4]. 113 | a: a float tensor with shape [batch_size, 2]. 
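        Example (illustrative; requires the pretrained weights at RNET_PATH):
            >>> import torch
            >>> rnet = RNet()
            >>> crops = torch.randn(8, 3, 24, 24)  # stage-2 crops are 24 x 24
            >>> offsets, probs = rnet(crops)
            >>> offsets.shape, probs.shape
            (torch.Size([8, 4]), torch.Size([8, 2]))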
114 | """ 115 | x = self.features(x) 116 | a = self.conv5_1(x) 117 | b = self.conv5_2(x) 118 | a = F.softmax(a, dim=-1) 119 | return b, a 120 | 121 | 122 | class ONet(nn.Module): 123 | 124 | def __init__(self): 125 | super().__init__() 126 | 127 | self.features = nn.Sequential(OrderedDict([ 128 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 129 | ('prelu1', nn.PReLU(32)), 130 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 131 | 132 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 133 | ('prelu2', nn.PReLU(64)), 134 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 135 | 136 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 137 | ('prelu3', nn.PReLU(64)), 138 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), 139 | 140 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 141 | ('prelu4', nn.PReLU(128)), 142 | 143 | ('flatten', Flatten()), 144 | ('conv5', nn.Linear(1152, 256)), 145 | ('drop5', nn.Dropout(0.25)), 146 | ('prelu5', nn.PReLU(256)), 147 | ])) 148 | 149 | self.conv6_1 = nn.Linear(256, 2) 150 | self.conv6_2 = nn.Linear(256, 4) 151 | self.conv6_3 = nn.Linear(256, 10) 152 | 153 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 154 | for n, p in self.named_parameters(): 155 | p.data = torch.FloatTensor(weights[n]) 156 | 157 | def forward(self, x): 158 | """ 159 | Arguments: 160 | x: a float tensor with shape [batch_size, 3, h, w]. 161 | Returns: 162 | c: a float tensor with shape [batch_size, 10]. 163 | b: a float tensor with shape [batch_size, 4]. 164 | a: a float tensor with shape [batch_size, 2]. 165 | """ 166 | x = self.features(x) 167 | a = self.conv6_1(x) 168 | b = self.conv6_2(x) 169 | c = self.conv6_3(x) 170 | a = F.softmax(a, dim=-1) 171 | return c, b, a 172 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
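    Example (sketch; the image path is hypothetical):
        >>> from PIL import Image
        >>> from models.mtcnn.mtcnn_pytorch.src import detect_faces, show_bboxes
        >>> img = Image.open('some_face_photo.jpg').convert('RGB')
        >>> boxes, landmarks = detect_faces(img)
        >>> show_bboxes(img, boxes, landmarks).save('annotated.jpg')
    The input image itself is left untouched; drawing happens on a copy.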
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /models/psp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the core research contribution 3 | """ 4 | import matplotlib 5 | 6 | matplotlib.use('Agg') 7 | import math 8 | 9 | import torch 10 | from torch import nn 11 | from models.encoders import psp_encoders 12 | from models.stylegan2.model import Generator 13 | from configs.paths_config import model_paths 14 | 15 | 16 | def get_keys(d, name): 17 | if 'state_dict' in d: 18 | d = d['state_dict'] 19 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 20 | return d_filt 21 | 22 | 23 | class pSp(nn.Module): 24 | 25 | def __init__(self, opts): 26 | super(pSp, self).__init__() 27 | self.set_opts(opts) 28 | # compute number of style inputs based on the output resolution 29 | self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2 30 | # Define architecture 31 | self.encoder = self.set_encoder() 32 | self.decoder = Generator(self.opts.output_size, 512, 8) 33 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 34 | # Load weights if needed 35 | self.load_weights() 36 | 37 | def set_encoder(self): 38 | if self.opts.encoder_type == 'GradualStyleEncoder': 39 | encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) 40 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW': 41 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts) 42 | elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus': 43 | encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts) 44 | else: 45 | raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) 46 | return encoder 47 | 48 | def load_weights(self): 49 | if self.opts.checkpoint_path is not None: 50 | print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path)) 51 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 52 | 
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) 53 | self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) 54 | self.__load_latent_avg(ckpt) 55 | else: 56 | print('Loading encoders weights from irse50!') 57 | encoder_ckpt = torch.load(model_paths['ir_se50']) 58 | # if input to encoder is not an RGB image, do not load the input layer weights 59 | if self.opts.label_nc != 0: 60 | encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k} 61 | self.encoder.load_state_dict(encoder_ckpt, strict=False) 62 | print('Loading decoder weights from pretrained!') 63 | ckpt = torch.load(self.opts.stylegan_weights) 64 | self.decoder.load_state_dict(ckpt['g_ema'], strict=False) 65 | if self.opts.learn_in_w: 66 | self.__load_latent_avg(ckpt, repeat=1) 67 | else: 68 | self.__load_latent_avg(ckpt, repeat=self.opts.n_styles) 69 | 70 | def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, 71 | inject_latent=None, return_latents=False, alpha=None): 72 | if input_code: 73 | codes = x 74 | else: 75 | codes = self.encoder(x) 76 | # normalize with respect to the center of an average face 77 | if self.opts.start_from_latent_avg: 78 | if self.opts.learn_in_w: 79 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1) 80 | else: 81 | codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) 82 | 83 | if latent_mask is not None: 84 | for i in latent_mask: 85 | if inject_latent is not None: 86 | if alpha is not None: 87 | codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] 88 | else: 89 | codes[:, i] = inject_latent[:, i] 90 | else: 91 | codes[:, i] = 0 92 | 93 | input_is_latent = not input_code 94 | images, result_latent = self.decoder([codes], 95 | input_is_latent=input_is_latent, 96 | randomize_noise=randomize_noise, 97 | return_latents=return_latents) 98 | 99 | if resize: 100 | images = self.face_pool(images) 101 | 102 | if return_latents: 103 | return images, result_latent 104 | else: 105 | return images 106 | 107 | def set_opts(self, opts): 108 | self.opts = opts 109 | 110 | def __load_latent_avg(self, ckpt, repeat=None): 111 | if 'latent_avg' in ckpt: 112 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 113 | if repeat is not None: 114 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 115 | else: 116 | self.latent_avg = None 117 | -------------------------------------------------------------------------------- /models/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/models/stylegan2/__init__.py -------------------------------------------------------------------------------- /models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | 
], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } 
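The binding above only exposes the raw CUDA kernel; in practice it is driven through the autograd wrapper defined in upfirdn2d.py below ("upfirdn2d" = upsample, apply an FIR filter, downsample -- the resampling primitive behind StyleGAN2's blur/upsample/downsample layers). A minimal usage sketch, assuming a working CUDA toolchain so the extension can be JIT-compiled; the shapes and the 4-tap blur kernel are illustrative:

    import torch
    from models.stylegan2.op import upfirdn2d

    x = torch.randn(1, 3, 64, 64, device='cuda')      # NCHW feature map
    k = torch.tensor([1., 3., 3., 1.], device='cuda')
    kernel = k[None, :] * k[:, None]                   # separable 4x4 blur kernel
    kernel = kernel / kernel.sum()

    # upsample by 2, filter, no downsampling; pad=(2, 1) is what StyleGAN2's
    # Upsample module uses for a 4-tap kernel and factor 2
    out = upfirdn2d(x, kernel, up=2, down=1, pad=(2, 1))
    print(out.shape)                                   # torch.Size([1, 3, 128, 128])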
-------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | up_x, up_y = up 23 | down_x, down_y = down 24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 25 | 26 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 27 | 28 | grad_input = upfirdn2d_op.upfirdn2d( 29 | grad_output, 30 | grad_kernel, 31 | down_x, 32 | down_y, 33 | up_x, 34 | up_y, 35 | g_pad_x0, 36 | g_pad_x1, 37 | g_pad_y0, 38 | g_pad_y1, 39 | ) 40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 41 | 42 | ctx.save_for_backward(kernel) 43 | 44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 45 | 46 | ctx.up_x = up_x 47 | ctx.up_y = up_y 48 | ctx.down_x = down_x 49 | ctx.down_y = down_y 50 | ctx.pad_x0 = pad_x0 51 | ctx.pad_x1 = pad_x1 52 | ctx.pad_y0 = pad_y0 53 | ctx.pad_y1 = pad_y1 54 | ctx.in_size = in_size 55 | ctx.out_size = out_size 56 | 57 | return grad_input 58 | 59 | @staticmethod 60 | def backward(ctx, gradgrad_input): 61 | kernel, = ctx.saved_tensors 62 | 63 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 64 | 65 | gradgrad_out = upfirdn2d_op.upfirdn2d( 66 | gradgrad_input, 67 | kernel, 68 | ctx.up_x, 69 | ctx.up_y, 70 | ctx.down_x, 71 | ctx.down_y, 72 | ctx.pad_x0, 73 | ctx.pad_x1, 74 | ctx.pad_y0, 75 | ctx.pad_y1, 76 | ) 77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 78 | gradgrad_out = gradgrad_out.view( 79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 80 | ) 81 | 82 | return gradgrad_out, None, None, None, None, None, None, None, None 83 | 84 | 85 | class UpFirDn2d(Function): 86 | @staticmethod 87 | def forward(ctx, input, kernel, up, down, pad): 88 | up_x, up_y = up 89 | down_x, down_y = down 90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 91 | 92 | kernel_h, kernel_w = kernel.shape 93 | batch, channel, in_h, in_w = input.shape 94 | ctx.in_size = input.shape 95 | 96 | input = input.reshape(-1, in_h, in_w, 1) 97 | 98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 99 | 100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 102 | ctx.out_size = (out_h, out_w) 103 | 104 | ctx.up = (up_x, up_y) 105 | ctx.down = (down_x, down_y) 106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 107 | 108 | g_pad_x0 = kernel_w - pad_x0 - 1 109 | g_pad_y0 = kernel_h - pad_y0 - 1 110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 112 | 113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 114 | 115 | out = upfirdn2d_op.upfirdn2d( 116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 117 | ) 118 | # out = out.view(major, out_h, out_w, minor) 119 | out = out.view(-1, channel, out_h, out_w) 120 | 121 | return out 122 | 123 | 
@staticmethod 124 | def backward(ctx, grad_output): 125 | kernel, grad_kernel = ctx.saved_tensors 126 | 127 | grad_input = UpFirDn2dBackward.apply( 128 | grad_output, 129 | kernel, 130 | grad_kernel, 131 | ctx.up, 132 | ctx.down, 133 | ctx.pad, 134 | ctx.g_pad, 135 | ctx.in_size, 136 | ctx.out_size, 137 | ) 138 | 139 | return grad_input, None, None, None, None 140 | 141 | 142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 143 | out = UpFirDn2d.apply( 144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 145 | ) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native( 151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 152 | ): 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad( 161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 162 | ) 163 | out = out[ 164 | :, 165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), 166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), 167 | :, 168 | ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape( 172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 173 | ) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | 184 | return out[:, ::down_y, ::down_x, :] 185 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/options/__init__.py -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class TestOptions: 5 | 6 | def __init__(self): 7 | self.parser = ArgumentParser() 8 | self.initialize() 9 | 10 | def initialize(self): 11 | # arguments for inference script 12 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 13 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') 14 | self.parser.add_argument('--checkpoint_path_af', default=None, type=str, help='Path to pSp model checkpoint') 15 | self.parser.add_argument('--data_path', type=str, default='gt_images', help='Path to directory of images to evaluate') 16 | self.parser.add_argument('--couple_outputs', action='store_true', help='Whether to also save inputs + outputs side-by-side') 17 | self.parser.add_argument('--resize_outputs', action='store_true', help='Whether to resize outputs to 256x256 or keep at 1024x1024') 18 | 19 | self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference') 20 | self.parser.add_argument('--test_workers', default=1, type=int, help='Number of test/inference dataloader workers') 21 | # self.parser.add_argument('--edit_type', edit_paths) 22 | 23 | # arguments for style-mixing script 24 | self.parser.add_argument('--n_images', 
type=int, default=None, help='Number of images to output. If None, run on all data') 25 | self.parser.add_argument('--n_outputs_to_generate', type=int, default=5, help='Number of outputs to generate per input image.') 26 | self.parser.add_argument('--mix_alpha', type=float, default=None, help='Alpha value for style-mixing') 27 | self.parser.add_argument('--latent_mask', type=str, default=None, help='Comma-separated list of latents to perform style-mixing with') 28 | # self.parser.add_argument('--no_feature_attention', action='store_true') 29 | self.parser.add_argument('--edit_attribute', default=None, type=str, help='smile, age, eyes, bread, lip, pose') 30 | self.parser.add_argument('--edit_degree', default=3.0, type=float) 31 | self.parser.add_argument('--no_w_attention', action='store_true') 32 | self.parser.add_argument('--idx_k', type=int, default=None) 33 | self.parser.add_argument('--no_res', action='store_true') 34 | self.parser.add_argument('--is_car', action='store_true') 35 | # arguments for super-resolution 36 | self.parser.add_argument('--resize_factors', type=str, default=None, 37 | help='Downsampling factor for super-res (should be a single value for inference).') 38 | 39 | def parse(self): 40 | opts = self.parser.parse_args() 41 | return opts -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from configs.paths_config import model_paths 3 | 4 | 5 | class TrainOptions: 6 | 7 | def __init__(self): 8 | self.parser = ArgumentParser() 9 | self.initialize() 10 | 11 | def initialize(self): 12 | # common option 13 | self.parser.add_argument('--train_inversion', action="store_true", help='Train your inversion model') 14 | self.parser.add_argument('--train_contrastive', action="store_true", help='Train your contrastive model') 15 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 16 | self.parser.add_argument('--dataset_type', default='ffhq_encode_inversion', type=str, 17 | help='Type of dataset/experiment to run') 18 | self.parser.add_argument('--encoder_type', default='GradualStyleEncoder', type=str, help='Which encoder to use') 19 | self.parser.add_argument('--input_nc', default=3, type=int, 20 | help='Number of input image channels to the psp encoder') 21 | self.parser.add_argument('--label_nc', default=0, type=int, 22 | help='Number of input label channels to the psp encoder') 23 | self.parser.add_argument('--output_size', default=1024, type=int, help='Output size of generator') 24 | 25 | self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training') 26 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') 27 | self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') 28 | self.parser.add_argument('--test_workers', default=2, type=int, 29 | help='Number of test/inference dataloader workers') 30 | 31 | self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate') 32 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') 33 | self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') 34 | self.parser.add_argument('--start_from_latent_avg', action='store_true', 35 | help='Whether to add 
average latent vector to generate codes from encoder.') 36 | self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space instead of w+') 37 | self.parser.add_argument('--n_styles', default=18) 38 | self.parser.add_argument('--num_latent', default=1) 39 | # for contrastive learning 40 | self.parser.add_argument('--contrastive_lambda', default=0.1, type=float, help='Contrastive loss factor') 41 | self.parser.add_argument('--use_norm', action='store_true', help='Use norm before calculate contrastive loss') 42 | self.parser.add_argument('--latent_embedding_dim', default=512, type=int, help='the dim of latent embedding') 43 | self.parser.add_argument('--image_embedding_dim', default=512, type=int, help='the dim of image_embedding') 44 | self.parser.add_argument('--projection_dim', default=512, type=int, help='projection dim of projection head') 45 | self.parser.add_argument('--load_pretrain_image_encoder', default=False) 46 | self.parser.add_argument('--checkpoint_path_image', default=None, type=str, help='Path to image model checkpoint') 47 | self.parser.add_argument('--checkpoint_path_latent', default=None, type=str, help='Path to latent model checkpoint') 48 | self.parser.add_argument('--checkpoint_path_af', default=None, type=str, 49 | help='Path to latset model checkpoint') 50 | 51 | self.parser.add_argument('--landmark_lambda', default=0.8, type=float, help='contrastive id loss multiplier factor' ) 52 | self.parser.add_argument('--feature_matching_lambda', default=0.01, type=float, 53 | help='feature matching loss multiplier factor') 54 | self.parser.add_argument('--multi_scale_lpips', default=True) 55 | self.parser.add_argument('--not_for_feature', action='store_true') 56 | self.parser.add_argument('--not_for_ww', action='store_true') 57 | self.parser.add_argument('--global_step', default=None, type=int) 58 | self.parser.add_argument('--no_feature_attention', action='store_true') 59 | self.parser.add_argument('--contrastive_model_image', type=str, default='contrastive_ffhq_image') 60 | self.parser.add_argument('--contrastive_model_latent', type=str, default='contrastive_ffhq_latent') 61 | self.parser.add_argument('--no_w_attention', action='store_true') 62 | self.parser.add_argument('--idx_k', type=int, default=None) 63 | self.parser.add_argument('--no_res', action='store_true') 64 | #################################### 65 | self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor') 66 | self.parser.add_argument('--id_lambda', default=0, type=float, help='ID loss multiplier factor') 67 | self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor') 68 | self.parser.add_argument('--w_norm_lambda', default=0, type=float, help='W-norm loss multiplier factor') 69 | self.parser.add_argument('--lpips_lambda_crop', default=0, type=float, 70 | help='LPIPS loss multiplier factor for inner image region') 71 | self.parser.add_argument('--l2_lambda_crop', default=0, type=float, 72 | help='L2 loss multiplier factor for inner image region') 73 | self.parser.add_argument('--moco_lambda', default=0, type=float, 74 | help='Moco-based feature similarity loss multiplier factor') 75 | 76 | self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str, 77 | help='Path to StyleGAN model weights') 78 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') 79 | 80 | self.parser.add_argument('--max_epoch', 
default=80, type=int, help='Maximum number of training steps') 81 | self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps') 82 | self.parser.add_argument('--image_interval', default=100, type=int, 83 | help='Interval for logging train images during training') 84 | self.parser.add_argument('--board_interval', default=50, type=int, 85 | help='Interval for logging metrics to tensorboard') 86 | self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval') 87 | self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval') 88 | 89 | # ddp 90 | self.parser.add_argument("--local_rank", default=0, type=int) 91 | self.parser.add_argument("--use_ddp", action='store_true', help = 'Whether to use the ddp of pytorch') 92 | 93 | # arguments for weights & biases support 94 | self.parser.add_argument('--use_wandb', action="store_true", 95 | help='Whether to use Weights & Biases to track experiment.') 96 | 97 | 98 | def parse(self): 99 | opts = self.parser.parse_args() 100 | return opts 101 | -------------------------------------------------------------------------------- /scripts/align_all_parallel.py: -------------------------------------------------------------------------------- 1 | """ 2 | brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) 3 | author: lzhbrian (https://lzhbrian.me) 4 | date: 2020.1.5 5 | note: code is heavily borrowed from 6 | https://github.com/NVlabs/ffhq-dataset 7 | http://dlib.net/face_landmark_detection.py.html 8 | 9 | requirements: 10 | apt install cmake 11 | conda install Pillow numpy scipy 12 | pip install dlib 13 | # download face landmark model from: 14 | # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 15 | """ 16 | from argparse import ArgumentParser 17 | import time 18 | import numpy as np 19 | import PIL 20 | import PIL.Image 21 | import os 22 | import scipy 23 | import scipy.ndimage 24 | import dlib 25 | import multiprocessing as mp 26 | import math 27 | 28 | from configs.paths_config import model_paths 29 | SHAPE_PREDICTOR_PATH = model_paths["shape_predictor"] 30 | 31 | 32 | def get_landmark(filepath, predictor): 33 | """get landmark with dlib 34 | :return: np.array shape=(68, 2) 35 | """ 36 | detector = dlib.get_frontal_face_detector() 37 | 38 | img = dlib.load_rgb_image(filepath) 39 | dets = detector(img, 1) 40 | 41 | for k, d in enumerate(dets): 42 | shape = predictor(img, d) 43 | 44 | t = list(shape.parts()) 45 | a = [] 46 | for tt in t: 47 | a.append([tt.x, tt.y]) 48 | lm = np.array(a) 49 | return lm 50 | 51 | 52 | def align_face(filepath, predictor): 53 | """ 54 | :param filepath: str 55 | :return: PIL Image 56 | """ 57 | 58 | lm = get_landmark(filepath, predictor) 59 | 60 | lm_chin = lm[0: 17] # left-right 61 | lm_eyebrow_left = lm[17: 22] # left-right 62 | lm_eyebrow_right = lm[22: 27] # left-right 63 | lm_nose = lm[27: 31] # top-down 64 | lm_nostrils = lm[31: 36] # top-down 65 | lm_eye_left = lm[36: 42] # left-clockwise 66 | lm_eye_right = lm[42: 48] # left-clockwise 67 | lm_mouth_outer = lm[48: 60] # left-clockwise 68 | lm_mouth_inner = lm[60: 68] # left-clockwise 69 | 70 | # Calculate auxiliary vectors. 
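    # (FFHQ-style alignment: the eye centres and mouth corners computed below
    #  define an oriented crop rectangle `quad`; its x side spans roughly twice
    #  the inter-ocular distance (or 1.8x the eye-to-mouth distance, whichever
    #  is larger), y is the 90-degree rotation of x, and the crop centre sits
    #  slightly below the eye line, shifted 10% of the way towards the mouth)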
71 | eye_left = np.mean(lm_eye_left, axis=0) 72 | eye_right = np.mean(lm_eye_right, axis=0) 73 | eye_avg = (eye_left + eye_right) * 0.5 74 | eye_to_eye = eye_right - eye_left 75 | mouth_left = lm_mouth_outer[0] 76 | mouth_right = lm_mouth_outer[6] 77 | mouth_avg = (mouth_left + mouth_right) * 0.5 78 | eye_to_mouth = mouth_avg - eye_avg 79 | 80 | # Choose oriented crop rectangle. 81 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 82 | x /= np.hypot(*x) 83 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 84 | y = np.flipud(x) * [-1, 1] 85 | c = eye_avg + eye_to_mouth * 0.1 86 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 87 | qsize = np.hypot(*x) * 2 88 | 89 | # read image 90 | img = PIL.Image.open(filepath) 91 | 92 | output_size = 256 93 | transform_size = 256 94 | enable_padding = True 95 | 96 | # Shrink. 97 | shrink = int(np.floor(qsize / output_size * 0.5)) 98 | if shrink > 1: 99 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 100 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 101 | quad /= shrink 102 | qsize /= shrink 103 | 104 | # Crop. 105 | border = max(int(np.rint(qsize * 0.1)), 3) 106 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 107 | int(np.ceil(max(quad[:, 1])))) 108 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), 109 | min(crop[3] + border, img.size[1])) 110 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 111 | img = img.crop(crop) 112 | quad -= crop[0:2] 113 | 114 | # Pad. 115 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 116 | int(np.ceil(max(quad[:, 1])))) 117 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), 118 | max(pad[3] - img.size[1] + border, 0)) 119 | if enable_padding and max(pad) > border - 4: 120 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 121 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 122 | h, w, _ = img.shape 123 | y, x, _ = np.ogrid[:h, :w, :1] 124 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 125 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) 126 | blur = qsize * 0.02 127 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 128 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 129 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 130 | quad += pad[:2] 131 | 132 | # Transform. 133 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 134 | if output_size < transform_size: 135 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 136 | 137 | # Save aligned image. 
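    # (nothing is written to disk here; the caller -- extract_on_paths below --
    #  converts the returned PIL image to RGB and saves it)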
138 | return img 139 | 140 | 141 | def chunks(lst, n): 142 | """Yield successive n-sized chunks from lst.""" 143 | for i in range(0, len(lst), n): 144 | yield lst[i:i + n] 145 | 146 | 147 | def extract_on_paths(file_paths): 148 | predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH) 149 | pid = mp.current_process().name 150 | print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths))) 151 | tot_count = len(file_paths) 152 | count = 0 153 | for file_path, res_path in file_paths: 154 | count += 1 155 | if count % 100 == 0: 156 | print('{} done with {}/{}'.format(pid, count, tot_count)) 157 | try: 158 | res = align_face(file_path, predictor) 159 | res = res.convert('RGB') 160 | os.makedirs(os.path.dirname(res_path), exist_ok=True) 161 | res.save(res_path) 162 | except Exception: 163 | continue 164 | print('\tDone!') 165 | 166 | 167 | def parse_args(): 168 | parser = ArgumentParser(add_help=False) 169 | parser.add_argument('--num_threads', type=int, default=1) 170 | parser.add_argument('--root_path', type=str, default='') 171 | args = parser.parse_args() 172 | return args 173 | 174 | 175 | def run(args): 176 | root_path = args.root_path 177 | out_crops_path = root_path + '_crops' 178 | if not os.path.exists(out_crops_path): 179 | os.makedirs(out_crops_path, exist_ok=True) 180 | 181 | file_paths = [] 182 | for root, dirs, files in os.walk(root_path): 183 | for file in files: 184 | file_path = os.path.join(root, file) 185 | fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path)) 186 | res_path = '{}.jpg'.format(os.path.splitext(fname)[0]) 187 | if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path): 188 | continue 189 | file_paths.append((file_path, res_path)) 190 | 191 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 192 | print(len(file_chunks)) 193 | pool = mp.Pool(args.num_threads) 194 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 195 | tic = time.time() 196 | pool.map(extract_on_paths, file_chunks) 197 | toc = time.time() 198 | print('Mischief managed in {}s'.format(toc - tic)) 199 | 200 | 201 | if __name__ == '__main__': 202 | args = parse_args() 203 | run(args) 204 | -------------------------------------------------------------------------------- /scripts/calc_id_loss_parallel.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import time 3 | import numpy as np 4 | import os 5 | import json 6 | import sys 7 | from PIL import Image 8 | import multiprocessing as mp 9 | import math 10 | import torch 11 | import torchvision.transforms as trans 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from models.mtcnn.mtcnn import MTCNN 17 | from models.encoders.model_irse import IR_101 18 | from configs.paths_config import model_paths 19 | CIRCULAR_FACE_PATH = model_paths['circular_face'] 20 | 21 | 22 | def chunks(lst, n): 23 | """Yield successive n-sized chunks from lst.""" 24 | for i in range(0, len(lst), n): 25 | yield lst[i:i + n] 26 | 27 | 28 | def extract_on_paths(file_paths): 29 | facenet = IR_101(input_size=112) 30 | facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH)) 31 | facenet.cuda() 32 | facenet.eval() 33 | mtcnn = MTCNN() 34 | id_transform = trans.Compose([ 35 | trans.ToTensor(), 36 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 37 | ]) 38 | 39 | pid = mp.current_process().name 40 | print('\t{} is starting to extract on {} images'.format(pid, 
len(file_paths))) 41 | tot_count = len(file_paths) 42 | count = 0 43 | 44 | scores_dict = {} 45 | for res_path, gt_path in file_paths: 46 | count += 1 47 | if count % 100 == 0: 48 | print('{} done with {}/{}'.format(pid, count, tot_count)) 49 | if True: 50 | input_im = Image.open(res_path) 51 | input_im, _ = mtcnn.align(input_im) 52 | if input_im is None: 53 | print('{} skipping {}'.format(pid, res_path)) 54 | continue 55 | 56 | input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0] 57 | 58 | result_im = Image.open(gt_path) 59 | result_im, _ = mtcnn.align(result_im) 60 | if result_im is None: 61 | print('{} skipping {}'.format(pid, gt_path)) 62 | continue 63 | 64 | result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0] 65 | score = float(input_id.dot(result_id)) 66 | scores_dict[os.path.basename(gt_path)] = score 67 | 68 | return scores_dict 69 | 70 | 71 | def parse_args(): 72 | parser = ArgumentParser(add_help=False) 73 | parser.add_argument('--num_threads', type=int, default=4) 74 | parser.add_argument('--data_path', type=str, default='results') 75 | parser.add_argument('--gt_path', type=str, default='gt_images') 76 | args = parser.parse_args() 77 | return args 78 | 79 | 80 | def run(args): 81 | file_paths = [] 82 | for f in os.listdir(args.data_path): 83 | image_path = os.path.join(args.data_path, f) 84 | gt_path = os.path.join(args.gt_path, f) 85 | if f.endswith(".jpg") or f.endswith('.png'): 86 | file_paths.append([image_path, gt_path.replace('.png','.jpg')]) 87 | 88 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 89 | pool = mp.Pool(args.num_threads) 90 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 91 | 92 | tic = time.time() 93 | results = pool.map(extract_on_paths, file_chunks) 94 | scores_dict = {} 95 | for d in results: 96 | scores_dict.update(d) 97 | 98 | all_scores = list(scores_dict.values()) 99 | mean = np.mean(all_scores) 100 | std = np.std(all_scores) 101 | result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std) 102 | print(result_str) 103 | 104 | out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 105 | if not os.path.exists(out_path): 106 | os.makedirs(out_path) 107 | 108 | with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f: 109 | f.write(result_str) 110 | with open(os.path.join(out_path, 'scores_id.json'), 'w') as f: 111 | json.dump(scores_dict, f) 112 | 113 | toc = time.time() 114 | print('Mischief managed in {}s'.format(toc - tic)) 115 | 116 | 117 | if __name__ == '__main__': 118 | args = parse_args() 119 | run(args) 120 | -------------------------------------------------------------------------------- /scripts/calc_losses_on_images.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import json 4 | import sys 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torchvision.transforms as transforms 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from criteria.lpips.lpips import LPIPS 15 | from datasets.gt_res_dataset import GTResDataset 16 | 17 | 18 | def parse_args(): 19 | parser = ArgumentParser(add_help=False) 20 | parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2']) 21 | parser.add_argument('--data_path', type=str, default='results') 22 | parser.add_argument('--gt_path', type=str, default='gt_images') 23 | 
parser.add_argument('--workers', type=int, default=4) 24 | parser.add_argument('--batch_size', type=int, default=4) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def run(args): 30 | transform = transforms.Compose([transforms.Resize((256, 256)), 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 33 | 34 | print('Loading dataset') 35 | dataset = GTResDataset(root_path=args.data_path, 36 | gt_dir=args.gt_path, 37 | transform=transform) 38 | 39 | dataloader = DataLoader(dataset, 40 | batch_size=args.batch_size, 41 | shuffle=False, 42 | num_workers=int(args.workers), 43 | drop_last=True) 44 | 45 | if args.mode == 'lpips': 46 | loss_func = LPIPS(net_type='alex') 47 | elif args.mode == 'l2': 48 | loss_func = torch.nn.MSELoss() 49 | else: 50 | raise Exception('Not a valid mode!') 51 | loss_func.cuda() 52 | 53 | global_i = 0 54 | scores_dict = {} 55 | all_scores = [] 56 | for result_batch, gt_batch in tqdm(dataloader): 57 | for i in range(args.batch_size): 58 | loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda())) 59 | all_scores.append(loss) 60 | im_path = dataset.pairs[global_i][0] 61 | scores_dict[os.path.basename(im_path)] = loss 62 | global_i += 1 63 | 64 | all_scores = list(scores_dict.values()) 65 | mean = np.mean(all_scores) 66 | std = np.std(all_scores) 67 | result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std) 68 | print('Finished with ', args.data_path) 69 | print(result_str) 70 | 71 | out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 72 | if not os.path.exists(out_path): 73 | os.makedirs(out_path) 74 | 75 | with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f: 76 | f.write(result_str) 77 | with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f: 78 | json.dump(scores_dict, f) 79 | 80 | 81 | if __name__ == '__main__': 82 | args = parse_args() 83 | run(args) 84 | -------------------------------------------------------------------------------- /scripts/inference_edit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | 4 | from tqdm import tqdm 5 | import time 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | import sys 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | from editings import latent_editor 15 | from configs import data_configs 16 | from datasets.inference_dataset_me import InferenceDataset 17 | from utils.common import tensor2im, log_input_image 18 | from options.test_options import TestOptions 19 | from models.attention_feature_psp import AFPSP 20 | from configs.paths_config import edit_paths 21 | 22 | 23 | # from editings 24 | 25 | def run(): 26 | test_opts = TestOptions().parse() 27 | edit_directory_path = os.path.join(test_opts.exp_dir, test_opts.edit_attribute) 28 | os.makedirs(edit_directory_path, exist_ok=True) 29 | 30 | # update test options with options used during training 31 | ckpt = torch.load(test_opts.checkpoint_path_af, map_location='cpu') 32 | opts = ckpt['opts'] 33 | print(f"iter:{ckpt['iter']}") 34 | iter = 1.0 35 | opts.update(vars(test_opts)) 36 | if 'learn_in_w' not in opts: 37 | opts['learn_in_w'] = False 38 | if 'output_size' not in opts: 39 | opts['output_size'] = 1024 40 | opts = Namespace(**opts) 41 | print(opts) 42 | print('#################### network init #####################') 43 | net = AFPSP(opts) 44 | 
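# AFPSP (from models/attention_feature_psp.py) is the attention-feature pSp inversion network;
# load_weights() below presumably restores the encoder and StyleGAN2 generator from the
# checkpoint supplied via --checkpoint_path_af (an assumption -- the model code is not shown here).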
net.load_weights() 45 | net.eval() 46 | net.cuda() 47 | 48 | print('Loading dataset for {}'.format(opts.dataset_type)) 49 | dataset_args = data_configs.DATASETS[opts.dataset_type] 50 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 51 | dataset = InferenceDataset(root=opts.data_path, 52 | image_avg_root=dataset_args['avg_image_root'], 53 | transform=transforms_dict['transform_inference'], 54 | opts=opts) 55 | dataloader = DataLoader(dataset, 56 | batch_size=opts.test_batch_size, 57 | shuffle=False, 58 | num_workers=int(opts.test_workers), 59 | drop_last=True) 60 | 61 | if opts.n_images is None: 62 | opts.n_images = len(dataset) 63 | edit_direction = None 64 | edit_degree = None 65 | if opts.edit_attribute is not None: 66 | print(f'######edit {opts.edit_attribute} ##############') 67 | edit_degree = opts.edit_degree 68 | edit_direction, ganspace_pca = edit(opts) 69 | 70 | 71 | for factor in range(int(-edit_degree), int(edit_degree)): 72 | global_i = 0 73 | global_time = [] 74 | out_path_base = os.path.join(edit_directory_path, f'{str(factor)}') 75 | os.makedirs(edit_directory_path, exist_ok=True) 76 | out_path_results = os.path.join(out_path_base, 'inference_results') 77 | out_path_coupled = os.path.join(out_path_base, 'inference_coupled') 78 | out_path_w = os.path.join(out_path_base, 'inference_w') 79 | out_path_ww = os.path.join(out_path_base, 'inference_ww') 80 | os.makedirs(out_path_results, exist_ok=True) 81 | os.makedirs(out_path_coupled, exist_ok=True) 82 | os.makedirs(out_path_w, exist_ok=True) 83 | os.makedirs(out_path_ww, exist_ok=True) 84 | for input_batch in tqdm(dataloader): 85 | if global_i >= opts.n_images: 86 | break 87 | with torch.no_grad(): 88 | input_cuda, image_avg = input_batch 89 | input_cuda = input_cuda.cuda().float() 90 | image_avg = image_avg.cuda().float() 91 | tic = time.time() 92 | result_batch, results_ww_batch, results_w_batch, latent_base = run_on_batch(input_cuda, image_avg, net, 93 | opts, iter, edit_direction, 94 | factor) 95 | toc = time.time() 96 | global_time.append(toc - tic) 97 | for i in range(opts.test_batch_size): 98 | result = tensor2im(result_batch[i]) 99 | result_ww = tensor2im(results_ww_batch[i]) 100 | result_w = tensor2im(results_w_batch[i]) 101 | 102 | im_path = dataset.paths[global_i] 103 | 104 | if opts.couple_outputs or global_i % 100 == 0: 105 | input_im = log_input_image(input_cuda[i], opts) 106 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 107 | # otherwise, save the original and output 108 | res = np.concatenate([np.array(input_im.resize(resize_amount)), 109 | np.array(result.resize(resize_amount))], axis=1) 110 | Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path))) 111 | 112 | im_save_path = os.path.join(out_path_results, os.path.basename(im_path)) 113 | im_save_path_ww = os.path.join(out_path_ww, os.path.basename(im_path)) 114 | im_save_path_w = os.path.join(out_path_w, os.path.basename(im_path)) 115 | Image.fromarray(np.array(result)).save(im_save_path) 116 | Image.fromarray(np.array(result_ww)).save(im_save_path_ww) 117 | 118 | Image.fromarray(np.array(result_w)).save(im_save_path_w) 119 | 120 | global_i += 1 121 | 122 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 123 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 124 | print(result_str) 125 | 126 | with open(stats_path, 'w') as f: 127 | f.write(result_str) 128 | 129 | 130 | def edit(opts): 131 | ganspace_pca = None 132 | if 
opts.edit_attribute == 'age' or opts.edit_attribute == 'smile' or opts.edit_attribute == 'pose': 133 | edit_direction = torch.load(edit_paths[opts.edit_attribute]).cuda() 134 | 135 | elif opts.edit_attribute == 'eyes' or opts.edit_attribute == 'beard' or opts.edit_attribute == 'lip': 136 | ganspace_pca = torch.load(edit_paths[opts.edit_attribute]) 137 | ganspace_directions = { 138 | 'eyes': (54, 7, 8, 20), 139 | 'beard': (58, 7, 9, -20), 140 | 'lip': (34, 10, 11, 20), 141 | 'white_hair': (57, 7, 10, -24)} 142 | edit_direction = ganspace_directions[opts.edit_attribute] 143 | else: 144 | ganspace_pca = torch.load(edit_paths[opts.edit_attribute]) 145 | ganspace_directions = { 146 | 'Viewpoint1': (0, 0, 5, 2), 147 | 'Viewpoint2': (0, 0, 5, -2), 148 | 'Cube': (16, 3, 6, 25), 149 | 'Color': (22, 9, 11, -8), 150 | 'Grass': (41, 9, 11, -18)} 151 | edit_direction = ganspace_directions[opts.edit_attribute] 152 | # For a single edit: 153 | return edit_direction, ganspace_pca 154 | 155 | 156 | def run_on_batch(inputs, image_avg, net, opts, iter, latent_offset=None, factor=None): 157 | result_batch, latent_refine, latent_base, feature_offset, feature_refine, results_ww, results_w = net(inputs, 158 | image_avg, 159 | randomize_noise=False, 160 | resize=opts.resize_outputs, 161 | interation=iter, 162 | return_features=True, 163 | edit_offset=latent_offset * factor) 164 | 165 | return result_batch, results_ww, results_w, latent_base 166 | 167 | 168 | if __name__ == '__main__': 169 | run() 170 | -------------------------------------------------------------------------------- /scripts/inference_inversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | 4 | from tqdm import tqdm 5 | import time 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | import sys 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | from editings import latent_editor 15 | from configs import data_configs 16 | from datasets.inference_dataset_me import InferenceDataset 17 | from utils.common import tensor2im, log_input_image 18 | from options.test_options import TestOptions 19 | from models.attention_feature_psp import AFPSP 20 | from configs.paths_config import edit_paths 21 | 22 | 23 | # from editings 24 | 25 | def run(): 26 | test_opts = TestOptions().parse() 27 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') 28 | out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled') 29 | out_path_input = os.path.join(test_opts.exp_dir, 'inference_input') 30 | out_path_w = os.path.join(test_opts.exp_dir, 'inference_w') 31 | out_path_ww = os.path.join(test_opts.exp_dir, 'inference_ww') 32 | out_path_ww_embedding = os.path.join(test_opts.exp_dir, 'inference_ww_embedding') 33 | out_path_w_embedding = os.path.join(test_opts.exp_dir, 'inference_w_embedding') 34 | os.makedirs(out_path_input, exist_ok=True) 35 | os.makedirs(out_path_ww_embedding, exist_ok=True) 36 | os.makedirs(out_path_w_embedding, exist_ok=True) 37 | os.makedirs(out_path_results, exist_ok=True) 38 | os.makedirs(out_path_coupled, exist_ok=True) 39 | os.makedirs(out_path_w, exist_ok=True) 40 | os.makedirs(out_path_ww, exist_ok=True) 41 | # update test options with options used during training 42 | ckpt = torch.load(test_opts.checkpoint_path_af, map_location='cpu') 43 | opts = ckpt['opts'] 44 | iter = 1.0 45 | opts.update(vars(test_opts)) 46 | if 'learn_in_w' not in opts: 47 | 
opts['learn_in_w'] = False 48 | if 'output_size' not in opts: 49 | opts['output_size'] = 1024 50 | opts = Namespace(**opts) 51 | print(opts) 52 | print('#################### network init #####################') 53 | net = AFPSP(opts) 54 | net.load_weights() 55 | net.eval() 56 | net.cuda() 57 | 58 | print('Loading dataset for {}'.format(opts.dataset_type)) 59 | dataset_args = data_configs.DATASETS[opts.dataset_type] 60 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 61 | dataset = InferenceDataset(root=opts.data_path, 62 | image_avg_root=dataset_args['avg_image_root'], 63 | transform=transforms_dict['transform_inference'], 64 | opts=opts) 65 | dataloader = DataLoader(dataset, 66 | batch_size=opts.test_batch_size, 67 | shuffle=False, 68 | num_workers=int(opts.test_workers), 69 | drop_last=True) 70 | 71 | if opts.n_images is None: 72 | opts.n_images = len(dataset) 73 | 74 | 75 | global_i = 0 76 | global_time = [] 77 | 78 | for input_batch in tqdm(dataloader): 79 | if global_i >= opts.n_images: 80 | break 81 | with torch.no_grad(): 82 | input_cuda, image_avg = input_batch 83 | input_cuda = input_cuda.cuda().float() 84 | image_avg = image_avg.cuda().float() 85 | tic = time.time() 86 | result_batch, results_ww_batch, results_w_batch, latent_base, latent_refine = run_on_batch(input_cuda, image_avg, net, 87 | opts, iter ) 88 | toc = time.time() 89 | global_time.append(toc - tic) 90 | for i in range(opts.test_batch_size): 91 | result = tensor2im(result_batch[i]) 92 | result_ww = tensor2im(results_ww_batch[i]) 93 | result_w = tensor2im(results_w_batch[i]) 94 | 95 | im_path = dataset.paths[global_i] 96 | 97 | if opts.couple_outputs or global_i % 100 == 0: 98 | input_im = log_input_image(input_cuda[i], opts) 99 | resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size) 100 | # otherwise, save the original and output 101 | res = np.concatenate([np.array(input_im.resize(resize_amount)), 102 | np.array(result.resize(resize_amount))], axis=1) 103 | Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path))) 104 | 105 | im_input_save_path = os.path.join(out_path_input, os.path.basename(im_path)) 106 | Image.fromarray(np.array(input_im.resize(resize_amount))).save(im_input_save_path) 107 | 108 | 109 | im_save_path = os.path.join(out_path_results, os.path.basename(im_path)) 110 | 111 | im_save_path_ww = os.path.join(out_path_ww, os.path.basename(im_path)) 112 | im_save_path_w = os.path.join(out_path_w, os.path.basename(im_path)) 113 | 114 | Image.fromarray(np.array(result)).save(im_save_path) 115 | Image.fromarray(np.array(result_ww)).save(im_save_path_ww) 116 | 117 | Image.fromarray(np.array(result_w)).save(im_save_path_w) 118 | global_i += 1 119 | 120 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 121 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 122 | print(result_str) 123 | 124 | with open(stats_path, 'w') as f: 125 | f.write(result_str) 126 | 127 | 128 | def run_on_batch(inputs, image_avg, net, opts, iter, latent_offset=None, factor=None): 129 | result_batch, latent_refine, latent_base, feature_offset, feature_refine, results_ww, results_w = net(inputs, 130 | image_avg, 131 | randomize_noise=False, 132 | resize=opts.resize_outputs, 133 | interation=iter, 134 | return_features=True) 135 | 136 | if latent_offset is not None: 137 | result_batch, latent_refine, latent_base, feature_offset, feature_refine, results_ww, results_w = net(inputs, 138 | image_avg, 139 | 
randomize_noise=False, 140 | resize=opts.resize_outputs, 141 | interation=iter, 142 | return_features=True, 143 | edit_offset=latent_offset * factor) 144 | 145 | 146 | 147 | return result_batch, results_ww, results_w, latent_base, latent_refine 148 | 149 | 150 | if __name__ == '__main__': 151 | run() 152 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import json 5 | import sys 6 | import pprint 7 | import random 8 | import numpy as np 9 | import torch 10 | import os 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | from options.train_options import TrainOptions 14 | from training.contrastive_coach import ContrastiveCoach 15 | from training.inversion_coach import InversionCoach 16 | 17 | def init_seeds(seed=0, cuda_deterministic=True): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | 22 | 23 | def main(): 24 | opts = TrainOptions().parse() 25 | if opts.use_ddp: 26 | rank = int(os.environ["LOCAL_RANK"]) 27 | 28 | init_seeds(seed=1 + rank) 29 | else: 30 | init_seeds(seed=0) 31 | if os.path.exists(opts.exp_dir): 32 | print('Oops... {} already exists'.format(opts.exp_dir)) 33 | else: 34 | os.makedirs(opts.exp_dir, exist_ok=True) 35 | 36 | opts_dict = vars(opts) 37 | pprint.pprint(opts_dict) 38 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 39 | json.dump(opts_dict, f, indent=4, sort_keys=True) 40 | if opts.train_inversion: 41 | coach = InversionCoach(opts) 42 | elif opts.train_contrastive: 43 | coach = ContrastiveCoach(opts) 44 | else: 45 | raise ValueError ('Please select the correct model type') 46 | coach.train() 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/training/__init__.py -------------------------------------------------------------------------------- /training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | try: 6 | import horovod.torch as hvd 7 | except ImportError: 8 | hvd = None 9 | 10 | 11 | def is_global_master(args): 12 | return args.rank == 0 13 | 14 | 15 | def is_local_master(args): 16 | return args.local_rank == 0 17 | 18 | 19 | def is_master(args, local=False): 20 | return is_local_master(args) if local else is_global_master(args) 21 | 22 | 23 | def is_using_horovod(): 24 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 25 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
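    # For reference, an OpenMPI launch exports OMPI_COMM_WORLD_RANK / OMPI_COMM_WORLD_SIZE,
    # while a SLURM PMI launch exports PMI_RANK / PMI_SIZE; if either complete pair is
    # present in the environment, a horovod-style launcher is assumed below.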
26 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 27 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 28 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 29 | return True 30 | else: 31 | return False 32 | 33 | 34 | def is_using_distributed(): 35 | if 'WORLD_SIZE' in os.environ: 36 | return int(os.environ['WORLD_SIZE']) > 1 37 | if 'SLURM_NTASKS' in os.environ: 38 | return int(os.environ['SLURM_NTASKS']) > 1 39 | return False 40 | 41 | 42 | def world_info_from_env(): 43 | local_rank = 0 44 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 45 | if v in os.environ: 46 | local_rank = int(os.environ[v]) 47 | break 48 | global_rank = 0 49 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 50 | if v in os.environ: 51 | global_rank = int(os.environ[v]) 52 | break 53 | world_size = 1 54 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 55 | if v in os.environ: 56 | world_size = int(os.environ[v]) 57 | break 58 | 59 | return local_rank, global_rank, world_size 60 | 61 | 62 | def init_distributed_device(args): 63 | # Distributed training = training on more than one GPU. 64 | # Works in both single and multi-node scenarios. 65 | args.distributed = False 66 | args.world_size = 1 67 | args.rank = 0 # global rank 68 | args.local_rank = 0 69 | if args.horovod: 70 | assert hvd is not None, "Horovod is not installed" 71 | hvd.init() 72 | args.local_rank = int(hvd.local_rank()) 73 | args.rank = hvd.rank() 74 | args.world_size = hvd.size() 75 | args.distributed = True 76 | os.environ['LOCAL_RANK'] = str(args.local_rank) 77 | os.environ['RANK'] = str(args.rank) 78 | os.environ['WORLD_SIZE'] = str(args.world_size) 79 | elif is_using_distributed(): 80 | if 'SLURM_PROCID' in os.environ: 81 | # DDP via SLURM 82 | args.local_rank, args.rank, args.world_size = world_info_from_env() 83 | # SLURM var -> torch.distributed vars in case needed 84 | os.environ['LOCAL_RANK'] = str(args.local_rank) 85 | os.environ['RANK'] = str(args.rank) 86 | os.environ['WORLD_SIZE'] = str(args.world_size) 87 | torch.distributed.init_process_group( 88 | backend=args.dist_backend, 89 | init_method=args.dist_url, 90 | world_size=args.world_size, 91 | rank=args.rank, 92 | ) 93 | else: 94 | # DDP via torchrun, torch.distributed.launch 95 | args.local_rank, _, _ = world_info_from_env() 96 | torch.distributed.init_process_group( 97 | backend=args.dist_backend, 98 | init_method=args.dist_url) 99 | args.world_size = torch.distributed.get_world_size() 100 | args.rank = torch.distributed.get_rank() 101 | args.distributed = True 102 | 103 | if torch.cuda.is_available(): 104 | if args.distributed and not args.no_set_device_rank: 105 | device = 'cuda:%d' % args.local_rank 106 | else: 107 | device = 'cuda:0' 108 | torch.cuda.set_device(device) 109 | else: 110 | device = 'cpu' 111 | args.device = device 112 | device = torch.device(device) 113 | return device 114 | -------------------------------------------------------------------------------- /training/ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 
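# Minimal usage sketch (illustrative only -- `model` and `loss` are assumed placeholders,
# and the hyper-parameters shown are simply this class's own defaults):
#
#   from training.ranger import Ranger
#   optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6,
#                      N_sma_threshhold=5, betas=(.95, 0.999), use_gc=True)
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()   # RAdam update with gradient centralization; every k steps the
#                      # slow (lookahead) weights are folded back into the parameters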
8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | use_gc=True, gc_conv_only=False 35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
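        # use_gc toggles gradient centralization; with gc_conv_only=True only gradients
        # with more than 3 dims (i.e. conv kernels) are centralized, otherwise anything
        # with more than 1 dim (conv + fc weights) is -- see gc_gradient_threshold below.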
52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, 55 | eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | def __setstate__(self, state): 76 | super(Ranger, self).__setstate__(state) 77 | 78 | def step(self, closure=None): 79 | loss = None 80 | 81 | # Evaluate averages and grad, update param tensors 82 | for group in self.param_groups: 83 | 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | grad = p.grad.data.float() 88 | 89 | if grad.is_sparse: 90 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 91 | 92 | p_data_fp32 = p.data.float() 93 | 94 | state = self.state[p] # get state dict for this param 95 | 96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 97 | # if self.first_run_check==0: 98 | # self.first_run_check=1 99 | # print("Initializing slow buffer...should not see this at load from saved model!") 100 | state['step'] = 0 101 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 103 | 104 | # look ahead weight storage now in state dict 105 | state['slow_buffer'] = torch.empty_like(p.data) 106 | state['slow_buffer'].copy_(p.data) 107 | 108 | else: 109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 111 | 112 | # begin computations 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | # GC operation for Conv layers and FC layers 117 | if grad.dim() > self.gc_gradient_threshold: 118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 119 | 120 | state['step'] += 1 121 | 122 | # compute variance mov avg 123 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 124 | # compute mean moving avg 125 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 126 | 127 | buffered = self.radam_buffer[int(state['step'] % 10)] 128 | 129 | if state['step'] == buffered[0]: 130 | N_sma, step_size = buffered[1], buffered[2] 131 | else: 132 | buffered[0] = state['step'] 133 | beta2_t = beta2 ** state['step'] 134 | N_sma_max = 2 / (1 - beta2) - 1 135 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 136 | buffered[1] = N_sma 137 | if N_sma > self.N_sma_threshhold: 138 | step_size = math.sqrt( 139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 140 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 141 | else: 142 | step_size = 1.0 / (1 - beta1 ** state['step']) 143 | buffered[2] = step_size 144 | 145 | if group['weight_decay'] != 0: 146 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 147 | 148 | # apply lr 149 | if N_sma > self.N_sma_threshhold: 150 | denom = exp_avg_sq.sqrt().add_(group['eps']) 151 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 152 | else: 153 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 154 | 155 | p.data.copy_(p_data_fp32) 156 | 157 | # integrated look 
ahead... 158 | # we do it at the param level instead of group level 159 | if state['step'] % group['k'] == 0: 160 | slow_p = state['slow_buffer'] # get access to slow param tensor 161 | slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha 162 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor 163 | 164 | return loss -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumapowerLIU/CLCAE/55986ea75577a48a50e5f57302009dafbc0d8919/utils/__init__.py -------------------------------------------------------------------------------- /utils/alignment.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import PIL 5 | import PIL.Image 6 | import scipy 7 | import scipy.ndimage 8 | import dlib 9 | 10 | 11 | def get_landmark(filepath, predictor): 12 | """get landmark with dlib 13 | :return: np.array shape=(68, 2) 14 | """ 15 | detector = dlib.get_frontal_face_detector() 16 | 17 | img = dlib.load_rgb_image(filepath) 18 | dets = detector(img, 1) 19 | 20 | for k, d in enumerate(dets): 21 | shape = predictor(img, d) 22 | 23 | t = list(shape.parts()) 24 | a = [] 25 | for tt in t: 26 | a.append([tt.x, tt.y]) 27 | lm = np.array(a) 28 | return lm 29 | 30 | 31 | def align_face(filepath, predictor): 32 | """ 33 | :param filepath: str 34 | :return: PIL Image 35 | """ 36 | 37 | lm = get_landmark(filepath, predictor) 38 | 39 | lm_chin = lm[0: 17] # left-right 40 | lm_eyebrow_left = lm[17: 22] # left-right 41 | lm_eyebrow_right = lm[22: 27] # left-right 42 | lm_nose = lm[27: 31] # top-down 43 | lm_nostrils = lm[31: 36] # top-down 44 | lm_eye_left = lm[36: 42] # left-clockwise 45 | lm_eye_right = lm[42: 48] # left-clockwise 46 | lm_mouth_outer = lm[48: 60] # left-clockwise 47 | lm_mouth_inner = lm[60: 68] # left-clockwise 48 | 49 | # Calculate auxiliary vectors. 50 | eye_left = np.mean(lm_eye_left, axis=0) 51 | eye_right = np.mean(lm_eye_right, axis=0) 52 | eye_avg = (eye_left + eye_right) * 0.5 53 | eye_to_eye = eye_right - eye_left 54 | mouth_left = lm_mouth_outer[0] 55 | mouth_right = lm_mouth_outer[6] 56 | mouth_avg = (mouth_left + mouth_right) * 0.5 57 | eye_to_mouth = mouth_avg - eye_avg 58 | 59 | # Choose oriented crop rectangle. 60 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 61 | x /= np.hypot(*x) 62 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 63 | y = np.flipud(x) * [-1, 1] 64 | c = eye_avg + eye_to_mouth * 0.1 65 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 66 | qsize = np.hypot(*x) * 2 67 | 68 | # read image 69 | img = PIL.Image.open(filepath) 70 | 71 | output_size = 256 72 | transform_size = 256 73 | enable_padding = True 74 | 75 | # Shrink. 76 | shrink = int(np.floor(qsize / output_size * 0.5)) 77 | if shrink > 1: 78 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 79 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 80 | quad /= shrink 81 | qsize /= shrink 82 | 83 | # Crop. 
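    # The crop box is the axis-aligned bounding box of the oriented quad, expanded by a
    # small border (at least 3 px, roughly 10% of the quad size) and clamped to the image
    # bounds; the crop is applied only when that box is smaller than the full image.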
84 | border = max(int(np.rint(qsize * 0.1)), 3) 85 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 86 | int(np.ceil(max(quad[:, 1])))) 87 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), 88 | min(crop[3] + border, img.size[1])) 89 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 90 | img = img.crop(crop) 91 | quad -= crop[0:2] 92 | 93 | # Pad. 94 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 95 | int(np.ceil(max(quad[:, 1])))) 96 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), 97 | max(pad[3] - img.size[1] + border, 0)) 98 | if enable_padding and max(pad) > border - 4: 99 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 100 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 101 | h, w, _ = img.shape 102 | y, x, _ = np.ogrid[:h, :w, :1] 103 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 104 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) 105 | blur = qsize * 0.02 106 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 107 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 108 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 109 | quad += pad[:2] 110 | 111 | # Transform. 112 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 113 | if output_size < transform_size: 114 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 115 | 116 | # Return aligned image. 
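    # At this point img is a 256x256 PIL image aligned with the FFHQ-style recipe
    # (crop/pad/quad-transform driven by the eye and mouth landmarks); callers such as
    # scripts/align_all_parallel.py convert the result to RGB before saving.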
117 | return img 118 | 119 | 120 | if __name__ == '__main__': 121 | import dlib 122 | 123 | predictor = dlib.shape_predictor( 124 | '/apdcephfs/share_1290939/kumamzqliu/checkpoint/train_pretrain_model/shape_predictor_68_face_landmarks.dat') 125 | 126 | 127 | def run_alignment(image_path): 128 | aligned_image = align_face(filepath=image_path, predictor=predictor) 129 | print("Aligned image has shape: {}".format(aligned_image.size)) 130 | return aligned_image 131 | 132 | 133 | base_path = '/apdcephfs/share_1290939/kumamzqliu/data/face_inversion/test' 134 | save_path = '/apdcephfs/share_1290939/kumamzqliu/data/face_inversion/test_align' 135 | path = [os.path.join(base_path, k) for k in os.listdir(base_path)] 136 | for i, path_each in enumerate(path): 137 | out = run_alignment(path_each ) 138 | out.save(os.path.join(save_path, f'{str(i).zfill(5)}.png')) 139 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from PIL import Image 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | # Log images 8 | def log_input_image(x, opts): 9 | if opts.label_nc == 0: 10 | return tensor2im(x) 11 | else: 12 | return tensor2map(x) 13 | 14 | 15 | def tensor2im(var): 16 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 17 | var = ((var + 1) / 2) 18 | var[var < 0] = 0 19 | var[var > 1] = 1 20 | var = var * 255 21 | return Image.fromarray(var.astype('uint8')) 22 | 23 | 24 | def tensor2map(var): 25 | mask = np.argmax(var.data.cpu().numpy(), axis=0) 26 | colors = get_colors() 27 | mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3)) 28 | for class_idx in np.unique(mask): 29 | mask_image[mask == class_idx] = colors[class_idx] 30 | mask_image = mask_image.astype('uint8') 31 | return Image.fromarray(mask_image) 32 | 33 | 34 | 35 | 36 | 37 | # Visualization utils 38 | def get_colors(): 39 | # currently support up to 19 classes (for the celebs-hq-mask dataset) 40 | colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], 41 | [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], 42 | [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] 43 | return colors 44 | 45 | 46 | def vis_faces(log_hooks): 47 | display_count = len(log_hooks) 48 | fig = plt.figure(figsize=(16, 4 * display_count)) 49 | gs = fig.add_gridspec(display_count, 5) 50 | for i in range(display_count): 51 | hooks_dict = log_hooks[i] 52 | fig.add_subplot(gs[i, 0]) 53 | if 'diff_input' in hooks_dict: 54 | vis_faces_with_id(hooks_dict, fig, gs, i) 55 | else: 56 | vis_faces_no_id(hooks_dict, fig, gs, i) 57 | plt.tight_layout() 58 | return fig 59 | 60 | 61 | def vis_faces_with_id(hooks_dict, fig, gs, i): 62 | plt.imshow(hooks_dict['input_face']) 63 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) 64 | fig.add_subplot(gs[i, 1]) 65 | plt.imshow(hooks_dict['target_face']) 66 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), 67 | float(hooks_dict['diff_target']))) 68 | fig.add_subplot(gs[i, 2]) 69 | plt.imshow(hooks_dict['output_face']) 70 | plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) 71 | 72 | fig.add_subplot(gs[i, 3]) 73 | plt.imshow(hooks_dict['output_face_w']) 74 | plt.title('Out_w') 75 | 76 | fig.add_subplot(gs[i, 4]) 77 | plt.imshow(hooks_dict['output_face_ww']) 78 | 
plt.title('Out_ww') 79 | 80 | def vis_faces_no_id(hooks_dict, fig, gs, i): 81 | plt.imshow(hooks_dict['input_face'], cmap="gray") 82 | plt.title('Input') 83 | fig.add_subplot(gs[i, 1]) 84 | plt.imshow(hooks_dict['target_face']) 85 | plt.title('Target') 86 | fig.add_subplot(gs[i, 2]) 87 | plt.imshow(hooks_dict['output_face']) 88 | plt.title('Output') 89 | fig.add_subplot(gs[i, 3]) 90 | plt.imshow(hooks_dict['output_face_w']) 91 | plt.title('Out_w') 92 | 93 | fig.add_subplot(gs[i, 4]) 94 | plt.imshow(hooks_dict['output_face_ww']) 95 | plt.title('Out_ww') -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD: 3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 4 | """ 5 | import os 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 20 | for root, _, fnames in sorted(os.walk(dir)): 21 | for fname in fnames: 22 | if is_image_file(fname): 23 | path = os.path.join(root, fname) 24 | images.append(path) 25 | return images 26 | -------------------------------------------------------------------------------- /utils/dataset_txt_generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | 4 | 5 | def sort(img_list): 6 | img_list.sort() 7 | img_list.sort(key=lambda x: int(x[:-4])) 8 | return img_list 9 | 10 | 11 | def data_generate(ffhq_train_path, sys_train_path, test_path, data_txt_save_path): 12 | train_txt = os.path.join(data_txt_save_path, "train_img.txt") 13 | val_txt = os.path.join(data_txt_save_path, "val_img.txt") 14 | pbar_ffhq = tqdm(total=len(os.listdir(ffhq_train_path))) 15 | pbar_sys = tqdm(total=len(os.listdir(sys_train_path))) 16 | pbar_test = tqdm(total=len(os.listdir(test_path))) 17 | ffhq_id_list = os.listdir(ffhq_train_path) 18 | ffhq_id_list.sort() 19 | # ffhq_id_list.remove('check_for_dataset.py') 20 | # ffhq_id_list.remove('LICENSE.txt') 21 | with open(train_txt, 'a') as train_txt_img: 22 | for i, image_name in enumerate(sort(os.listdir(sys_train_path))): 23 | image_path = os.path.join(sys_train_path, image_name) 24 | 25 | train_txt_img.write(image_path) 26 | train_txt_img.write('\n') 27 | pbar_sys.update() 28 | # for i, image_id in enumerate(ffhq_id_list): 29 | # id_file = os.path.join(ffhq_train_path, image_id) 30 | # for k, image_name in enumerate(sort(os.listdir(id_file))): 31 | # image_path = os.path.join(id_file, image_name) 32 | # train_txt_img.write(image_path) 33 | # train_txt_img.write('\n') 34 | # pbar_ffhq.update() 35 | 36 | for i, image_name in enumerate(ffhq_id_list): 37 | image_path = os.path.join(ffhq_train_path, image_name) 38 | train_txt_img.write(image_path) 39 | train_txt_img.write('\n') 40 | pbar_ffhq.update() 41 | train_txt_img.close() 42 | with open(val_txt, 'a') as val_txt_img: 43 | for i, image_name in enumerate(sort(os.listdir(test_path))): 44 | image_path = os.path.join(test_path, image_name) 45 | val_txt_img.write(image_path) 46 | val_txt_img.write('\n') 47 | pbar_test.update() 48 | val_txt_img.close() 49 | 50 | 51 | if __name__ == '__main__': 52 | ffhq_path = 
'/apdcephfs_cq2/share_1290939/liuhongyu/data/cars_train' 53 | sys_path = '/apdcephfs/share_1290939/kumamzqliu/data/car_inversion/train/image' 54 | val_path = '/apdcephfs/share_1290939/kumamzqliu/data/car_inversion/test/image' 55 | data_save = '/apdcephfs/share_1290939/kumamzqliu/data/car_inversion' 56 | data_generate(ffhq_path, sys_path, val_path, data_save) 57 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def aggregate_loss_dict(agg_loss_dict): 3 | mean_vals = {} 4 | for output in agg_loss_dict: 5 | for key in output: 6 | mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] 7 | for key in mean_vals: 8 | if len(mean_vals[key]) > 0: 9 | mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) 10 | else: 11 | print('{} has no value'.format(key)) 12 | mean_vals[key] = 0 13 | return mean_vals 14 | 15 | 16 | class AvgMeter: 17 | def __init__(self, name="Metric"): 18 | self.name = name 19 | self.reset() 20 | 21 | def reset(self): 22 | self.avg, self.sum, self.count = [0] * 3 23 | 24 | def update(self, val, count=1): 25 | self.count += count 26 | self.sum += val * count 27 | self.avg = self.sum / self.count 28 | 29 | def get(self): 30 | return self.avg 31 | def __repr__(self): 32 | text = f"{self.name}: {self.avg:.4f}" 33 | return text 34 | 35 | 36 | def get_lr(optimizer): 37 | for param_group in optimizer.param_groups: 38 | return param_group["lr"] 39 | 40 | def distributed_concat(tensor, num_total_examples): 41 | output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] 42 | torch.distributed.all_gather(output_tensors, tensor) 43 | concat = torch.cat(output_tensors, dim=0) 44 | # truncate the dummy elements added by SequentialDistributedSampler 45 | return concat[:num_total_examples] 46 | 47 | 48 | -------------------------------------------------------------------------------- /utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import numpy as np 4 | import wandb 5 | 6 | from utils import common 7 | 8 | 9 | class WBLogger: 10 | 11 | def __init__(self, opts): 12 | wandb_run_name = os.path.basename(opts.exp_dir) 13 | wandb.init(project="pixel2style2pixel", config=vars(opts), name=wandb_run_name) 14 | 15 | @staticmethod 16 | def log_best_model(): 17 | wandb.run.summary["best-model-save-time"] = datetime.datetime.now() 18 | 19 | @staticmethod 20 | def log(prefix, metrics_dict, global_step): 21 | log_dict = {f'{prefix}_{key}': value for key, value in metrics_dict.items()} 22 | log_dict["global_step"] = global_step 23 | wandb.log(log_dict) 24 | 25 | @staticmethod 26 | def log_dataset_wandb(dataset, dataset_name, n_images=16): 27 | idxs = np.random.choice(a=range(len(dataset)), size=n_images, replace=False) 28 | data = [wandb.Image(dataset.source_paths[idx]) for idx in idxs] 29 | wandb.log({f"{dataset_name} Data Samples": data}) 30 | 31 | @staticmethod 32 | def log_images_to_wandb(x, y, y_hat, id_logs, prefix, step, opts): 33 | im_data = [] 34 | column_names = ["Source", "Target", "Output"] 35 | if id_logs is not None: 36 | column_names.append("ID Diff Output to Target") 37 | for i in range(len(x)): 38 | cur_im_data = [ 39 | wandb.Image(common.log_input_image(x[i], opts)), 40 | wandb.Image(common.tensor2im(y[i])), 41 | wandb.Image(common.tensor2im(y_hat[i])), 42 | ] 43 | if id_logs is not None: 44 | 
cur_im_data.append(id_logs[i]["diff_target"]) 45 | im_data.append(cur_im_data) 46 | outputs_table = wandb.Table(data=im_data, columns=column_names) 47 | wandb.log({f"{prefix.title()} Step {step} Output Samples": outputs_table}) 48 | --------------------------------------------------------------------------------