├── .gitignore
├── LICENSE
├── README.md
├── dual_space_encoder.py
├── dual_space_encoder_test.py
├── environment.yaml
├── metrics
├── calc_inception.py
├── calc_prdc.py
├── celeba_hq_stats_256_29000.pkl
├── evaluate_query.py
├── fid_query.py
├── inception.py
├── inception_ffhq.pkl
├── lpips.py
├── lpips_weights.ckpt
└── prdc.py
├── model_spatial_query.py
├── our_interfaceGAN
├── calculate_score.py
├── calculate_score_id.py
├── celebahq_utils
│ └── dex
│ │ ├── __init__.py
│ │ ├── api.py
│ │ └── networks
│ │ ├── __init__.py
│ │ ├── classifiers
│ │ ├── attribute_classifier.py
│ │ ├── attribute_utils.py
│ │ ├── cifar10_resnet.py
│ │ └── cifar10_utils.py
│ │ ├── domain_classifier.py
│ │ ├── domain_generator.py
│ │ ├── perturb_settings.py
│ │ └── stylegan2
│ │ └── stylegan2_networks.py
├── config_inversion
│ ├── 0.json
│ ├── 1.json
│ ├── 12.json
│ ├── 13.json
│ ├── 3.json
│ ├── 4.json
│ ├── 8.json
│ ├── 9.json
│ ├── age.json
│ ├── gender.json
│ ├── pose.json
│ └── seed.json
├── config_noinversion
│ ├── 0.json
│ ├── 1.json
│ ├── 12.json
│ ├── 13.json
│ ├── 3.json
│ ├── 8.json
│ ├── 9.json
│ ├── age.json
│ ├── gender.json
│ ├── pose.json
│ ├── seed.json
│ └── utils
│ │ ├── edit_utils.py
│ │ └── sample.py
├── edit_all_inversion_celebahq.py
├── edit_all_inversion_ffhq.py
├── edit_all_noinversion_celebahq.py
├── edit_all_noinversion_ffhq.py
├── editing_evaluate.py
├── editing_evaluate_id.py
├── ffhq_utils
│ └── dex
│ │ ├── __init__.py
│ │ ├── api.py
│ │ └── models.py
├── linear_change.py
├── linear_interpolation.py
└── train_boundary.py
├── pSp
├── LICENSE
├── configs
│ ├── data_configs.py
│ ├── paths_config.py
│ └── transforms_config.py
├── criteria
│ ├── id_loss.py
│ ├── lpips
│ │ ├── lpips.py
│ │ ├── networks.py
│ │ └── utils.py
│ └── w_norm.py
├── datasets
│ ├── augmentations.py
│ ├── gt_res_dataset.py
│ ├── images_dataset.py
│ └── inference_dataset.py
├── licenses
│ ├── LICENSE_HuangYG123
│ ├── LICENSE_S-aiueo32
│ ├── LICENSE_TreB1eN
│ ├── LICENSE_lessw2020
│ └── LICENSE_rosinality
├── models
│ ├── encoders
│ │ ├── helpers.py
│ │ ├── model_irse.py
│ │ ├── psp_encoders.py
│ │ └── psp_encoders_new.py
│ ├── mtcnn
│ │ ├── mtcnn.py
│ │ └── mtcnn_pytorch
│ │ │ └── src
│ │ │ ├── __init__.py
│ │ │ ├── align_trans.py
│ │ │ ├── box_utils.py
│ │ │ ├── detector.py
│ │ │ ├── first_stage.py
│ │ │ ├── get_nets.py
│ │ │ ├── matlab_cp2tform.py
│ │ │ ├── visualization_utils.py
│ │ │ └── weights
│ │ │ ├── onet.npy
│ │ │ ├── pnet.npy
│ │ │ └── rnet.npy
│ └── psp_new.py
├── options
│ ├── test_options.py
│ └── train_options.py
├── scripts
│ ├── align_all_parallel.py
│ ├── calc_id_loss_parallel.py
│ ├── calc_losses_on_images.py
│ ├── generate_sketch_data.py
│ ├── inference.py
│ ├── style_mixing.py
│ └── train.py
├── training
│ ├── coach_new.py
│ └── ranger.py
└── utils
│ ├── common.py
│ ├── data_utils.py
│ └── train_utils.py
├── projection
└── encoder_inversion
│ ├── celebahq_encode
│ ├── encoded_p.npy
│ └── encoded_z.npy
│ └── ffhq_encode
│ ├── encoded_p.npy
│ └── encoded_z.npy
├── projector_optimization.py
├── psp_spatial_train.py
├── psp_testing_options.py
├── psp_training_options.py
├── resources
├── Teaser_v2.1-1.png
├── edit_blackhair_celeba.png
├── edit_ffhq_pose.png
├── edit_gender_ffhq.png
├── edit_pose_ffhq.png
├── edit_smile_celeba.png
├── interp_content_celeba.png
├── interp_style_celeba.png
├── teaser.png
└── teaser_change_order.png
├── test_spatial_query.py
├── train_spatial_query.py
└── utils
├── dataset.py
├── dataset_projector.py
├── distributed.py
├── editing_utils.py
├── lpips
├── __init__.py
├── base_model.py
├── dist_model.py
├── networks_basic.py
├── pretrained_networks.py
└── weights
│ ├── v0.0
│ ├── alex.pth
│ ├── squeeze.pth
│ └── vgg.pth
│ └── v0.1
│ ├── alex.pth
│ ├── squeeze.pth
│ └── vgg.pth
├── op
├── __init__.py
├── fused_act.py
├── fused_bias_act.cpp
├── fused_bias_act_kernel.cu
├── upfirdn2d.cpp
├── upfirdn2d.py
└── upfirdn2d_kernel.cu
└── sample.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yanbo Xu, Yueqin Yin, Liming Jiang, Qianyi Wu, Chengyao Zheng, Chen Change Loy, Bo Dai, Wayne Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /dual_space_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | from pSp.models.psp_new import pSp 3 | 4 | 5 | def set_test_options(opts): 6 | opts.start_from_latent_avg = True 7 | 8 | 9 | return opts 10 | 11 | 12 | class DualSpaceEncoder(): 13 | def __init__(self, opts): 14 | self.device = 'cuda' 15 | self.opts = opts 16 | self.opts.device = self.device 17 | self.net = pSp(self.opts).to(self.device) 18 | self.net.eval() 19 | 20 | def encode(self, real_img): 21 | z_code, p_code = self.net(real_img, only_encode=True) 22 | return z_code, p_code 23 | 24 | def decode(self,z_code,p_code, plus_sapce=True): 25 | if plus_sapce: 26 | images, _, _ = self.net.decoder(z_code, p_code, 27 | use_spatial_mapping=False,use_style_mapping=False, 28 | return_latents=False) 29 | else: 30 | images = self.net.decoder(z_code, p_code, return_latents=False) 31 | 32 | return images 33 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: transeditor 2 | channels: 3 | - anaconda 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_gnu 10 | - _pytorch_select=0.2=gpu_0 11 | - absl-py=0.13.0=pyhd8ed1ab_0 12 | - aiohttp=3.7.4.post0=py38h497a2fe_0 13 | - async-timeout=3.0.1=py_1000 14 | - attrs=21.2.0=pyhd8ed1ab_0 15 | - backports=1.0=py_2 16 | - backports.functools_lru_cache=1.6.1=py_0 17 | - blas=1.0=mkl 18 | - blinker=1.4=py_1 19 | - brotlipy=0.7.0=py38h497a2fe_1001 20 | - bzip2=1.0.8=h7f98852_4 21 | - c-ares=1.17.1=h36c2ea0_0 22 | - ca-certificates=2021.5.30=ha878542_0 23 | - cachetools=4.2.2=pyhd8ed1ab_0 24 | - cairo=1.16.0=h7979940_1007 25 | - certifi=2021.5.30=py38h578d9bd_0 26 | - cffi=1.14.5=py38h261ae71_0 27 | - chardet=4.0.0=py38h578d9bd_1 28 | - charset-normalizer=2.0.0=pyhd8ed1ab_0 29 | - click=8.0.1=py38h578d9bd_0 30 | - cryptography=3.4.7=py38ha5dfef3_0 31 | - cudatoolkit=10.1.243=h036e899_8 32 | - cudnn=7.6.5.32=hc0a50b0_1 33 | - dbus=1.13.6=hfdff14a_1 34 | - easydict=1.9=py_0 35 | - expat=2.2.10=h9c3ff4c_0 36 | - ffmpeg=4.3.1=hca11adc_2 37 | - fontconfig=2.13.1=hba837de_1004 38 | - freetype=2.10.4=h0708190_1 39 | - gettext=0.19.8.1=h0b5b191_1005 40 | - glib=2.66.7=h9c3ff4c_1 41 | - glib-tools=2.66.7=h9c3ff4c_1 42 | - gmp=6.2.1=h58526e2_0 43 | - gnutls=3.6.13=h85f3911_1 44 | - google-auth=1.33.1=pyh6c4a22f_0 45 | - google-auth-oauthlib=0.4.1=py_2 46 | - graphite2=1.3.13=h58526e2_1001 47 | - grpcio=1.38.1=py38hdd6454d_0 48 | - gst-plugins-base=1.18.3=h04508c2_0 49 | - gstreamer=1.18.3=h3560a44_0 50 | - harfbuzz=2.7.4=h5cf4720_0 51 | - hdf5=1.10.6=nompi_h3c11f04_101 52 | - icu=68.1=h58526e2_0 53 | - idna=3.1=pyhd3deb0d_0 54 | - importlib-metadata=2.0.0=py_1 55 | - intel-openmp=2019.4=243 56 | - jasper=1.900.1=h07fcdf6_1006 57 | - jpeg=9d=h36c2ea0_0 58 | - krb5=1.17.2=h926e7f8_0 59 | - lame=3.100=h7f98852_1001 60 | - lcms2=2.11=h396b838_0 61 | - ld_impl_linux-64=2.33.1=h53a641e_7 62 | - libblas=3.8.0=21_mkl 63 | - libcblas=3.8.0=21_mkl 64 | - libclang=11.1.0=default_ha53f305_0 65 | - libcurl=7.71.1=hcdd3856_8 66 | - libedit=3.1.20191231=h14c3975_1 67 | - libev=4.33=h516909a_1 68 | - libevent=2.1.10=hcdb4288_3 69 | - libffi=3.3=he6710b0_2 70 | - libgcc-ng=9.3.0=h2828fa1_18 71 | - libgfortran-ng=7.3.0=hdf63c60_0 72 | - libglib=2.66.7=h3e27bee_1 73 | - 
libgomp=9.3.0=h2828fa1_18 74 | - libiconv=1.16=h516909a_0 75 | - liblapack=3.8.0=21_mkl 76 | - liblapacke=3.8.0=21_mkl 77 | - libllvm11=11.1.0=hf817b99_0 78 | - libmklml=2019.0.5=0 79 | - libnghttp2=1.43.0=h812cca2_0 80 | - libopencv=4.5.1=py38h703c3c0_0 81 | - libpng=1.6.37=h21135ba_2 82 | - libpq=13.1=hfd2b0eb_2 83 | - libprotobuf=3.14.0=h8c45485_0 84 | - libssh2=1.9.0=hab1572f_5 85 | - libstdcxx-ng=9.3.0=h6de172a_18 86 | - libtiff=4.2.0=hdc55705_0 87 | - libuuid=2.32.1=h7f98852_1000 88 | - libuv=1.40.0=h7b6447c_0 89 | - libwebp-base=1.2.0=h7f98852_0 90 | - libxcb=1.13=h7f98852_1003 91 | - libxkbcommon=1.0.3=he3ba5ed_0 92 | - libxml2=2.9.10=h72842e0_3 93 | - lz4-c=1.9.3=h9c3ff4c_0 94 | - markdown=3.3.4=pyhd8ed1ab_0 95 | - mkl=2020.2=256 96 | - mkl-service=2.3.0=py38he904b0f_0 97 | - mkl_fft=1.3.0=py38h54f3939_0 98 | - mkl_random=1.1.1=py38h0573a6f_0 99 | - multidict=5.1.0=py38h497a2fe_1 100 | - mysql-common=8.0.23=ha770c72_1 101 | - mysql-libs=8.0.23=h935591d_1 102 | - ncurses=6.2=he6710b0_1 103 | - nettle=3.6=he412f7d_0 104 | - ninja=1.10.2=py38hff7bd54_0 105 | - nspr=4.29=h9c3ff4c_1 106 | - nss=3.62=hb5efdd6_0 107 | - numpy=1.19.2=py38h54aff64_0 108 | - numpy-base=1.19.2=py38hfa32c7d_0 109 | - oauthlib=3.1.1=pyhd8ed1ab_0 110 | - olefile=0.46=py_0 111 | - opencv=4.5.1=py38h578d9bd_0 112 | - openh264=2.1.1=h780b84a_0 113 | - openssl=1.1.1k=h7f98852_0 114 | - pcre=8.44=he1b5a44_0 115 | - pillow=8.1.1=py38he98fc37_0 116 | - pip=21.0.1=py38h06a4308_0 117 | - pixman=0.40.0=h36c2ea0_0 118 | - prettytable=2.1.0=pyhd8ed1ab_0 119 | - protobuf=3.14.0=py38h2531618_1 120 | - pthread-stubs=0.4=h36c2ea0_1001 121 | - py-opencv=4.5.1=py38h81c977d_0 122 | - pyasn1=0.4.8=py_0 123 | - pyasn1-modules=0.2.7=py_0 124 | - pycparser=2.20=py_2 125 | - pyjwt=2.1.0=pyhd8ed1ab_0 126 | - pyopenssl=20.0.1=pyhd8ed1ab_0 127 | - pysocks=1.7.1=py38h578d9bd_3 128 | - python=3.8.8=hdb3f193_4 129 | - python_abi=3.8=1_cp38 130 | - pytorch=1.7.0=py3.8_cuda10.1.243_cudnn7.6.3_0 131 | - pyu2f=0.1.5=pyhd8ed1ab_0 132 | - qt=5.12.9=hda022c4_4 133 | - readline=8.1=h27cfd23_0 134 | - requests=2.26.0=pyhd8ed1ab_0 135 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 136 | - rsa=4.7.2=pyh44b312d_0 137 | - scipy=1.6.1=py38h91f5cce_0 138 | - setuptools=52.0.0=py38h06a4308_0 139 | - six=1.15.0=pyhd3eb1b0_0 140 | - sqlite=3.34.0=h74cdb3f_0 141 | - tensorboard=2.5.0=pyhd8ed1ab_0 142 | - tensorboard-data-server=0.6.0=py38h2b97feb_0 143 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 144 | - tensorboardx=2.1=py_0 145 | - tk=8.6.10=hbc83047_0 146 | - torchvision=0.8.1=py38_cu101 147 | - tqdm=4.56.0=pyhd3eb1b0_0 148 | - typing-extensions=3.7.4.3=hd3eb1b0_0 149 | - typing_extensions=3.7.4.3=pyh06a4308_0 150 | - urllib3=1.26.6=pyhd8ed1ab_0 151 | - wcwidth=0.2.5=pyh9f0ad1d_2 152 | - werkzeug=2.0.1=pyhd8ed1ab_0 153 | - wheel=0.36.2=pyhd3eb1b0_0 154 | - x264=1!161.3030=h7f98852_0 155 | - xorg-kbproto=1.0.7=h7f98852_1002 156 | - xorg-libice=1.0.10=h516909a_0 157 | - xorg-libsm=1.2.3=h84519dc_1000 158 | - xorg-libx11=1.6.12=h516909a_0 159 | - xorg-libxau=1.0.9=h7f98852_0 160 | - xorg-libxdmcp=1.1.3=h7f98852_0 161 | - xorg-libxext=1.3.4=h516909a_0 162 | - xorg-libxrender=0.9.10=h516909a_1002 163 | - xorg-renderproto=0.11.1=h14c3975_1002 164 | - xorg-xextproto=7.3.0=h7f98852_1002 165 | - xorg-xproto=7.0.31=h7f98852_1007 166 | - xz=5.2.5=h7b6447c_0 167 | - yaml=0.2.5=h7b6447c_0 168 | - yarl=1.6.3=py38h497a2fe_2 169 | - zipp=3.4.1=pyhd8ed1ab_0 170 | - zlib=1.2.11=h7b6447c_3 171 | - zstd=1.4.8=ha95c52a_1 172 | - pip: 173 | - backcall==0.2.0 174 | - cycler==0.10.0 
175 | - dataclasses==0.6 176 | - decorator==4.4.2 177 | - future==0.18.2 178 | - imageio==2.9.0 179 | - ipdb==0.13.9 180 | - ipython==7.26.0 181 | - ipython-genutils==0.2.0 182 | - jedi==0.18.0 183 | - joblib==1.0.1 184 | - kiwisolver==1.3.1 185 | - lmdb==1.2.1 186 | - matplotlib==3.4.2 187 | - matplotlib-inline==0.1.2 188 | - munch==2.5.0 189 | - networkx==2.5.1 190 | - parso==0.8.2 191 | - pexpect==4.8.0 192 | - pickleshare==0.7.5 193 | - prompt-toolkit==3.0.19 194 | - ptyprocess==0.7.0 195 | - pygments==2.9.0 196 | - pyparsing==2.4.7 197 | - python-dateutil==2.8.1 198 | - pywavelets==1.1.1 199 | - pyyaml==5.4.1 200 | - scikit-image==0.18.1 201 | - scikit-learn==1.0 202 | - threadpoolctl==2.2.0 203 | - tifffile==2021.6.14 204 | - toml==0.10.2 205 | - traitlets==5.0.5 206 | -------------------------------------------------------------------------------- /metrics/calc_inception.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | from torchvision.models import Inception3 12 | from tqdm import tqdm 13 | 14 | from metrics.inception import InceptionV3 15 | from utils.dataset import MultiResolutionDataset 16 | 17 | 18 | class Inception3Feature(Inception3): 19 | def forward(self, x): 20 | if x.shape[2] != 299 or x.shape[3] != 299: 21 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 22 | 23 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 24 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 25 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 26 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 27 | 28 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 29 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 30 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 31 | 32 | x = self.Mixed_5b(x) # 35 x 35 x 192 33 | x = self.Mixed_5c(x) # 35 x 35 x 256 34 | x = self.Mixed_5d(x) # 35 x 35 x 288 35 | 36 | x = self.Mixed_6a(x) # 35 x 35 x 288 37 | x = self.Mixed_6b(x) # 17 x 17 x 768 38 | x = self.Mixed_6c(x) # 17 x 17 x 768 39 | x = self.Mixed_6d(x) # 17 x 17 x 768 40 | x = self.Mixed_6e(x) # 17 x 17 x 768 41 | 42 | x = self.Mixed_7a(x) # 17 x 17 x 768 43 | x = self.Mixed_7b(x) # 8 x 8 x 1280 44 | x = self.Mixed_7c(x) # 8 x 8 x 2048 45 | 46 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 47 | 48 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 49 | 50 | 51 | def load_patched_inception_v3(): 52 | # inception = inception_v3(pretrained=True) 53 | # inception_feat = Inception3Feature() 54 | # inception_feat.load_state_dict(inception.state_dict()) 55 | inception_feat = InceptionV3([3], normalize_input=False) 56 | 57 | return inception_feat 58 | 59 | 60 | @torch.no_grad() 61 | def extract_features(loader, inception, device): 62 | pbar = tqdm(loader) 63 | 64 | feature_list = [] 65 | 66 | for img in pbar: 67 | img = img.to(device) 68 | feature = inception(img)[0].view(img.shape[0], -1) 69 | feature_list.append(feature.to('cpu')) 70 | 71 | features = torch.cat(feature_list, 0) 72 | 73 | return features 74 | 75 | 76 | if __name__ == '__main__': 77 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 78 | 79 | parser = argparse.ArgumentParser( 80 | description='Calculate Inception v3 features for datasets' 81 | ) 82 | parser.add_argument('--size', type=int, default=256) 83 | 
parser.add_argument('--batch', default=64, type=int, help='batch size') 84 | parser.add_argument('--n_sample', type=int, default=50000) 85 | parser.add_argument('--flip', action='store_true') 86 | parser.add_argument('path', metavar='PATH', help='path to datset lmdb file') 87 | 88 | args = parser.parse_args() 89 | 90 | inception = load_patched_inception_v3() 91 | inception = nn.DataParallel(inception).eval().to(device) 92 | 93 | transform = transforms.Compose( 94 | [ 95 | transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0), 96 | transforms.ToTensor(), 97 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 98 | ] 99 | ) 100 | 101 | dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size) 102 | loader = DataLoader(dset, batch_size=args.batch, num_workers=4) 103 | 104 | features = extract_features(loader, inception, device).numpy() 105 | 106 | features = features[: args.n_sample] 107 | 108 | print(f'extracted {features.shape[0]} features') 109 | 110 | mean = np.mean(features, 0) 111 | cov = np.cov(features, rowvar=False) 112 | 113 | name = os.path.splitext(os.path.basename(args.path))[0] 114 | 115 | with open(f'inception_{name}.pkl', 'wb') as f: 116 | pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f) 117 | -------------------------------------------------------------------------------- /metrics/calc_prdc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from torch import nn 6 | from torch.utils import data 7 | from torchvision import transforms, models 8 | from tqdm import tqdm 9 | 10 | from metrics.prdc import compute_prdc 11 | from model import Generator 12 | from train import data_sampler, sample_data 13 | from utils.dataset import MultiResolutionDataset 14 | 15 | 16 | @torch.no_grad() 17 | def extract_feature_from_samples(generator, vgg, batch_size, n_sample, device): 18 | n_batch = n_sample // batch_size 19 | resid = n_sample - (n_batch * batch_size) 20 | if resid == 0: 21 | batch_sizes = [batch_size] * n_batch 22 | else: 23 | batch_sizes = [batch_size] * n_batch + [resid] 24 | features = [] 25 | 26 | for batch in tqdm(batch_sizes): 27 | latent = torch.randn(batch, 14, 512, device=device) 28 | img, _ = generator(latent) 29 | feat = vgg(img) 30 | features.append(feat.to('cpu')) 31 | 32 | features = torch.cat(features, 0) 33 | 34 | return features 35 | 36 | 37 | @torch.no_grad() 38 | def extract_feature_from_data(dataset, vgg, batch_size, n_sample, device): 39 | n_batch = n_sample // batch_size 40 | resid = n_sample - (n_batch * batch_size) 41 | if resid == 0: 42 | batch_sizes = [batch_size] * n_batch 43 | else: 44 | batch_sizes = [batch_size] * n_batch + [resid] 45 | features = [] 46 | 47 | for batch in tqdm(batch_sizes): 48 | loader = data.DataLoader( 49 | dataset, 50 | batch_size=batch, 51 | sampler=data_sampler(dataset, shuffle=True, distributed=False), 52 | drop_last=True, 53 | ) 54 | loader = sample_data(loader) 55 | img = next(loader).to(device) 56 | feat = vgg(img) 57 | features.append(feat.to('cpu')) 58 | 59 | features = torch.cat(features, 0) 60 | 61 | return features 62 | 63 | 64 | if __name__ == '__main__': 65 | device = 'cuda' 66 | 67 | parser = argparse.ArgumentParser() 68 | 69 | parser.add_argument('--n_sample', type=int, default=50000) 70 | parser.add_argument('--start_num', type=int, default=0) 71 | parser.add_argument('--size', type=int, default=256) 72 | parser.add_argument('--batch', type=int, default=64) 73 
| parser.add_argument('--ckpt', default='./checkpoint') 74 | parser.add_argument('--dataset', type=str, required=True) 75 | 76 | args = parser.parse_args() 77 | 78 | nearest_k = 3 79 | resize = min(args.size, 256) 80 | 81 | if os.path.isdir(args.ckpt): 82 | files = os.listdir(args.ckpt) 83 | ckpt = sorted([os.path.join(args.ckpt, x) for x in files]) 84 | ckpt = list(filter(lambda x: int(x.split('/')[-1].split('.')[0]) >= args.start_num, ckpt)) 85 | print(args.ckpt) 86 | else: 87 | ckpt = [args.ckpt] 88 | 89 | print(ckpt) 90 | 91 | transform = transforms.Compose( 92 | [ 93 | transforms.Resize(resize), 94 | transforms.CenterCrop(resize), 95 | transforms.ToTensor(), 96 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 97 | ] 98 | ) 99 | dataset = MultiResolutionDataset(args.dataset, transform, args.size) 100 | 101 | model_vgg16 = models.vgg16(pretrained=True) 102 | model_vgg16.classifier = model_vgg16.classifier[:-1] 103 | model_vgg16 = nn.DataParallel(model_vgg16).to(device) 104 | model_vgg16.eval() 105 | 106 | for model_path in ckpt: 107 | iteration = int(os.path.splitext(os.path.basename(model_path))[0]) 108 | print(f'Iteration = {iteration}') 109 | 110 | g = Generator(args.size, 512, 8).to(device) 111 | model = torch.load(model_path, map_location='cpu') 112 | g.load_state_dict(model['g_ema']) 113 | g = nn.DataParallel(g) 114 | g.eval() 115 | 116 | fake_features = extract_feature_from_samples(g, model_vgg16, args.batch, args.n_sample, device).numpy() 117 | print(f'extracted {fake_features.shape[0]} fake features') 118 | real_features = extract_feature_from_data(dataset, model_vgg16, args.batch, args.n_sample, device).numpy() 119 | print(f'extracted {real_features.shape[0]} real features') 120 | 121 | metrics = compute_prdc(real_features=real_features, fake_features=fake_features, nearest_k=nearest_k) 122 | print(metrics) 123 | -------------------------------------------------------------------------------- /metrics/celeba_hq_stats_256_29000.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/metrics/celeba_hq_stats_256_29000.pkl -------------------------------------------------------------------------------- /metrics/inception_ffhq.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/metrics/inception_ffhq.pkl -------------------------------------------------------------------------------- /metrics/lpips.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torchvision import models 14 | 15 | 16 | def normalize(x, eps=1e-10): 17 | return x * torch.rsqrt(torch.sum(x**2, dim=1, keepdim=True) + eps) 18 | 19 | 20 | class AlexNet(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | self.layers = models.alexnet(pretrained=True).features 24 | self.channels = [] 25 | for layer in self.layers: 26 | if isinstance(layer, nn.Conv2d): 27 | self.channels.append(layer.out_channels) 28 | 29 | def forward(self, x): 30 | fmaps = [] 31 | for layer in self.layers: 32 | x = layer(x) 33 | if isinstance(layer, nn.ReLU): 34 | fmaps.append(x) 35 | return fmaps 36 | 37 | 38 | class Conv1x1(nn.Module): 39 | def __init__(self, in_channels, out_channels=1): 40 | super().__init__() 41 | self.main = nn.Sequential( 42 | nn.Dropout(0.5), 43 | nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)) 44 | 45 | def forward(self, x): 46 | return self.main(x) 47 | 48 | 49 | class LPIPS(nn.Module): 50 | def __init__(self): 51 | super().__init__() 52 | self.alexnet = AlexNet() 53 | self.lpips_weights = nn.ModuleList() 54 | for channels in self.alexnet.channels: 55 | self.lpips_weights.append(Conv1x1(channels, 1)) 56 | self._load_lpips_weights() 57 | # imagenet normalization for range [-1, 1] 58 | self.mu = torch.tensor([-0.03, -0.088, -0.188]).view(1, 3, 1, 1).cuda() 59 | self.sigma = torch.tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1).cuda() 60 | 61 | def _load_lpips_weights(self): 62 | own_state_dict = self.state_dict() 63 | if torch.cuda.is_available(): 64 | state_dict = torch.load('metrics/lpips_weights.ckpt') 65 | else: 66 | state_dict = torch.load('metrics/lpips_weights.ckpt', 67 | map_location=torch.device('cpu')) 68 | for name, param in state_dict.items(): 69 | if name in own_state_dict: 70 | own_state_dict[name].copy_(param) 71 | 72 | def forward(self, x, y): 73 | x = (x - self.mu) / self.sigma 74 | y = (y - self.mu) / self.sigma 75 | x_fmaps = self.alexnet(x) 76 | y_fmaps = self.alexnet(y) 77 | lpips_value = 0 78 | for x_fmap, y_fmap, conv1x1 in zip(x_fmaps, y_fmaps, self.lpips_weights): 79 | x_fmap = normalize(x_fmap) 80 | y_fmap = normalize(y_fmap) 81 | lpips_value += torch.mean(conv1x1((x_fmap - y_fmap)**2)) 82 | return lpips_value 83 | 84 | 85 | @torch.no_grad() 86 | def calculate_lpips_given_images(group_of_images): 87 | # group_of_images = [torch.randn(N, C, H, W) for _ in range(10)] 88 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 89 | lpips = LPIPS().eval().to(device) 90 | lpips_values = [] 91 | num_rand_outputs = len(group_of_images) 92 | 93 | # calculate the average of pairwise distances among all random outputs 94 | for i in range(num_rand_outputs-1): 95 | for j in range(i+1, num_rand_outputs): 96 | lpips_values.append(lpips(group_of_images[i], group_of_images[j])) 97 | lpips_value = torch.mean(torch.stack(lpips_values, dim=0)) 98 | return lpips_value.item() -------------------------------------------------------------------------------- /metrics/lpips_weights.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/metrics/lpips_weights.ckpt -------------------------------------------------------------------------------- /metrics/prdc.py: -------------------------------------------------------------------------------- 1 | """ 2 | prdc 3 | Copyright (c) 2020-present NAVER Corp. 
4 | MIT license 5 | """ 6 | import numpy as np 7 | import sklearn.metrics 8 | 9 | __all__ = ['compute_prdc'] 10 | 11 | 12 | def compute_pairwise_distance(data_x, data_y=None): 13 | """ 14 | Args: 15 | data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) 16 | data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) 17 | Returns: 18 | numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. 19 | """ 20 | if data_y is None: 21 | data_y = data_x 22 | dists = sklearn.metrics.pairwise_distances( 23 | data_x, data_y, metric='euclidean', n_jobs=8) 24 | return dists 25 | 26 | 27 | def get_kth_value(unsorted, k, axis=-1): 28 | """ 29 | Args: 30 | unsorted: numpy.ndarray of any dimensionality. 31 | k: int 32 | Returns: 33 | kth values along the designated axis. 34 | """ 35 | indices = np.argpartition(unsorted, k, axis=axis)[..., :k] 36 | k_smallests = np.take_along_axis(unsorted, indices, axis=axis) 37 | kth_values = k_smallests.max(axis=axis) 38 | return kth_values 39 | 40 | 41 | def compute_nearest_neighbour_distances(input_features, nearest_k): 42 | """ 43 | Args: 44 | input_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 45 | nearest_k: int 46 | Returns: 47 | Distances to kth nearest neighbours. 48 | """ 49 | distances = compute_pairwise_distance(input_features) 50 | radii = get_kth_value(distances, k=nearest_k + 1, axis=-1) 51 | return radii 52 | 53 | 54 | def compute_prdc(real_features, fake_features, nearest_k): 55 | """ 56 | Computes precision, recall, density, and coverage given two manifolds. 57 | Args: 58 | real_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 59 | fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 60 | nearest_k: int. 61 | Returns: 62 | dict of precision, recall, density, and coverage. 63 | """ 64 | 65 | print('Num real: {} Num fake: {}' 66 | .format(real_features.shape[0], fake_features.shape[0])) 67 | 68 | real_nearest_neighbour_distances = compute_nearest_neighbour_distances( 69 | real_features, nearest_k) 70 | fake_nearest_neighbour_distances = compute_nearest_neighbour_distances( 71 | fake_features, nearest_k) 72 | distance_real_fake = compute_pairwise_distance( 73 | real_features, fake_features) 74 | 75 | precision = ( 76 | distance_real_fake < 77 | np.expand_dims(real_nearest_neighbour_distances, axis=1) 78 | ).any(axis=0).mean() 79 | 80 | recall = ( 81 | distance_real_fake < 82 | np.expand_dims(fake_nearest_neighbour_distances, axis=0) 83 | ).any(axis=1).mean() 84 | 85 | density = (1. 
/ float(nearest_k)) * ( 86 | distance_real_fake < 87 | np.expand_dims(real_nearest_neighbour_distances, axis=1) 88 | ).sum(axis=0).mean() 89 | 90 | coverage = ( 91 | distance_real_fake.min(axis=1) < 92 | real_nearest_neighbour_distances 93 | ).mean() 94 | 95 | return dict(precision=precision, recall=recall, 96 | density=density, coverage=coverage) 97 | -------------------------------------------------------------------------------- /our_interfaceGAN/calculate_score.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | import numpy as np 5 | from scipy import spatial 6 | 7 | def calculate_cos_score(boundrary_1, boundrary_2): 8 | return 1 - spatial.distance.cosine(boundrary_1, boundrary_2) 9 | 10 | if __name__ == '__main__': 11 | 12 | # attr_change = "pose" 13 | # attr_change = "Male" 14 | # attr_change = "Smiling" 15 | attr_change = "gender" 16 | # attr_change = "pose" 17 | # attr_change = "Wavy_Hair" 18 | # attr_change = "Blond_Hair" 19 | # attr_change = "pose" 20 | # attr_change = "Bangs" 21 | 22 | space_list = ["p","pz","z"] 23 | method_list = ["stylegan2-pytorch","StyleMapGAN", "DiagonalGAN", "Controllable-Face-Generation",] 24 | test_attribute_list = ["Male", "Smiling", "pose"] 25 | list_1 = [] 26 | list_2 = [] 27 | 28 | load_dict = {} 29 | for method in method_list: 30 | eval_base_dir = os.path.join("/mnt/lustre/xuyanbo",method, "editing_evaluation") 31 | eval_dir = os.path.join(eval_base_dir, attr_change) 32 | load_dict[method]=np.load(os.path.join(eval_dir, "test_dict_softmax.npy"), allow_pickle=True) 33 | 34 | 35 | for attr_interest in test_attribute_list: 36 | for method in method_list: 37 | if method in ["stylegan2-pytorch", "StyleMapGAN"]: 38 | space_list = ["z"] 39 | else: 40 | space_list = ["p","pz","z"] 41 | 42 | result_list = load_dict[method] 43 | 44 | for space in space_list: 45 | attr_change_score = [] 46 | attr_interest_score = [] 47 | delta_sum_change_pos = 0 48 | delta_sum_change_neg = 0 49 | delta_sum_interest_pos = 0 50 | delta_sum_interest_neg = 0 51 | 52 | for i in range(len(result_list)): 53 | delta_sum_change_pos += np.sum(np.array(result_list[i][attr_change][space][4:7])-np.array(result_list[i][attr_change][space][3:6])) 54 | delta_sum_interest_pos += np.sum(np.array(result_list[i][attr_interest][space][4:7])-np.array(result_list[i][attr_interest][space][3:6])) 55 | delta_sum_change_neg += np.sum(np.array(result_list[i][attr_change][space][0:3])-np.array(result_list[i][attr_change][space][1:4])) 56 | delta_sum_interest_neg += np.sum(np.array(result_list[i][attr_interest][space][0:3])-np.array(result_list[i][attr_interest][space][1:4])) 57 | attr_change_score.append(result_list[i][attr_change][space]) 58 | attr_interest_score.append(result_list[i][attr_interest][space]) 59 | 60 | delta_sum_change_pos /= len(result_list) 61 | delta_sum_interest_pos /= len(result_list) 62 | 63 | delta_sum_change_neg /= len(result_list) 64 | delta_sum_interest_neg /= len(result_list) 65 | 66 | attr_change_score = np.concatenate(attr_change_score) 67 | attr_interest_score = np.concatenate(attr_interest_score) 68 | corralation = np.corrcoef(attr_change_score, attr_interest_score) 69 | 70 | 71 | result = (abs(delta_sum_interest_pos/delta_sum_change_pos) + abs(delta_sum_interest_neg/delta_sum_change_neg))/2 72 | print(method, attr_change, attr_interest, space, result) 73 | -------------------------------------------------------------------------------- /our_interfaceGAN/calculate_score_id.py: 
-------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | import numpy as np 5 | from scipy import spatial 6 | 7 | def calculate_cos_score(boundrary_1, boundrary_2): 8 | return spatial.distance.cosine(boundrary_1, boundrary_2) 9 | 10 | if __name__ == '__main__': 11 | 12 | 13 | 14 | # import pdb; pdb.set_trace() 15 | attr_change = "pose" 16 | # attr_change = "Male" 17 | # attr_change = "Smiling" 18 | # attr_change = "gender" 19 | # attr_change = "Smiling" 20 | # attr_change = "pose" 21 | # attr_change = "Wavy_Hair" 22 | # attr_change = "Blond_Hair" 23 | # attr_change = "age" 24 | # attr_change = "Bangs" 25 | # ttr_change = "Black_Hair" 26 | space_list = ["p","pz","z"] 27 | 28 | method_list = ["stylegan2-pytorch","StyleMapGAN", "DiagonalGAN", "Controllable-Face-Generation",] 29 | test_attribute_list = ["id"] 30 | 31 | list_1 = [] 32 | list_2 = [] 33 | 34 | load_dict = {} 35 | for method in method_list: 36 | eval_base_dir = os.path.join("/mnt/lustre/xuyanbo",method, "editing_evaluation") 37 | eval_dir = os.path.join(eval_base_dir, attr_change) 38 | load_dict[method+"_id"]= np.load(os.path.join(eval_dir, "test_dict_id.npy"), allow_pickle=True) 39 | 40 | load_dict[method]=np.load(os.path.join(eval_dir, "test_dict_softmax.npy"), allow_pickle=True) 41 | 42 | 43 | 44 | for attr_interest in test_attribute_list: 45 | for method in method_list: 46 | if method in ["stylegan2-pytorch", "StyleMapGAN"]: 47 | space_list = ["z"] 48 | else: 49 | space_list = ["p","pz","z"] 50 | 51 | result_list = load_dict[method] 52 | result_list_id = load_dict[method+"_id"] 53 | 54 | for space in space_list: 55 | # attr_change_score = [] 56 | # attr_interest_score = [] 57 | 58 | delta_sum_change_pos = 0 59 | delta_sum_change_neg = 0 60 | delta_sum_id_pos = 0 61 | delta_sum_id_neg = 0 62 | 63 | for i in range(len(result_list)): 64 | delta_sum_change_pos += np.sum(np.array(result_list[i][attr_change][space][6])-np.array(result_list[i][attr_change][space][3])) 65 | delta_sum_change_neg += np.sum(np.array(result_list[i][attr_change][space][0])-np.array(result_list[i][attr_change][space][3])) 66 | delta_sum_id_pos += calculate_cos_score(boundrary_1=np.array(result_list_id[i][attr_interest][space][6]), 67 | boundrary_2=np.array(result_list_id[i][attr_interest][space][3])) 68 | 69 | delta_sum_id_neg += calculate_cos_score(boundrary_1=np.array(result_list_id[i][attr_interest][space][0]), 70 | boundrary_2=np.array(result_list_id[i][attr_interest][space][3])) 71 | 72 | delta_sum_change_pos += np.sum(np.array(result_list[i][attr_change][space][4:7])-np.array(result_list[i][attr_change][space][3:6])) 73 | delta_sum_change_neg += np.sum(np.array(result_list[i][attr_change][space][0:3])-np.array(result_list[i][attr_change][space][1:4])) 74 | 75 | for j in range(3): 76 | delta_sum_id_pos += calculate_cos_score(boundrary_1=np.array(result_list_id[i][attr_interest][space][4+j]), 77 | boundrary_2=np.array(result_list_id[i][attr_interest][space][3+j])) 78 | delta_sum_id_neg += calculate_cos_score(boundrary_1=np.array(result_list_id[i][attr_interest][space][0+j]), 79 | boundrary_2=np.array(result_list_id[i][attr_interest][space][1+j])) 80 | 81 | 82 | 83 | delta_sum_change_pos /= len(result_list) 84 | delta_sum_id_pos /= len(result_list) 85 | 86 | delta_sum_change_neg /= len(result_list) 87 | delta_sum_id_neg /= len(result_list) 88 | 89 | 90 | result = (abs(delta_sum_id_pos/delta_sum_change_pos) + abs(delta_sum_id_neg/delta_sum_change_neg))/2 91 | print(method, attr_change, attr_interest, 
space, result) 92 | print("detail:", (abs(delta_sum_id_pos)+abs(delta_sum_id_neg))/2, (abs(delta_sum_change_pos)+abs(delta_sum_change_neg))/2) 93 | -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/__init__.py: -------------------------------------------------------------------------------- 1 | from .api import _eval as eval 2 | from .api import estimate_score 3 | -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/api.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append("./") 5 | sys.path.append("../") 6 | 7 | import inspect 8 | 9 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 10 | parentdir = os.path.dirname(currentdir) 11 | sys.path.insert(0, parentdir) 12 | sys.path.append(os.path.join(parentdir, "dex")) 13 | 14 | from networks import domain_classifier 15 | 16 | 17 | 18 | dataset_name = 'celebahq' 19 | 20 | def _eval(classifier_name): 21 | classifier = domain_classifier.define_classifier(dataset_name, classifier_name) 22 | return classifier 23 | 24 | def estimate_score(classifier, imgs, no_soft=False): 25 | res = classifier(imgs,no_soft=no_soft) 26 | return res -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/our_interfaceGAN/celebahq_utils/dex/networks/__init__.py -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/networks/classifiers/attribute_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from . import attribute_classifier 4 | import glob 5 | 6 | softmax = torch.nn.Softmax(dim=1) 7 | 8 | def downsample(images, size=256): 9 | # Downsample to 256x256. The attribute classifiers were built for 256x256. 
10 | # follows https://github.com/NVlabs/stylegan/blob/master/metrics/linear_separability.py#L127 11 | if images.shape[2] > size: 12 | factor = images.shape[2] // size 13 | assert(factor * size == images.shape[2]) 14 | images = images.view([-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor]) 15 | images = images.mean(dim=[3, 5]) 16 | return images 17 | else: 18 | assert(images.shape[-1] == 256) 19 | return images 20 | 21 | 22 | def get_logit(net, im): 23 | im_256 = downsample(im) 24 | logit = net(im_256) 25 | return logit 26 | 27 | 28 | def get_softmaxed(net, im): 29 | logit = get_logit(net, im) 30 | softmaxed = softmax(torch.cat([logit, -logit], dim=1))[:, 1] 31 | # logit is (N,) softmaxed is (N,) 32 | return logit[:, 0], softmaxed 33 | 34 | 35 | def load_attribute_classifier(attribute, ckpt_path=None): 36 | if ckpt_path is None: 37 | base_path = os.path.abspath(__file__+"/../../../pth_celeba") 38 | attribute_pkl = os.path.join(base_path, attribute, 'net_best.pth') 39 | ckpt = torch.load(attribute_pkl) 40 | else: 41 | ckpt = torch.load(ckpt_path) 42 | # print("Using classifier at epoch: %d" % ckpt['epoch']) 43 | if 'valacc' in ckpt.keys(): 44 | print("Validation acc on raw images: %0.5f" % ckpt['valacc']) 45 | detector = attribute_classifier.from_state_dict( 46 | ckpt['state_dict'], fixed_size=True, use_mbstd=False).cuda().eval() 47 | return detector 48 | 49 | 50 | class ClassifierWrapper(torch.nn.Module): 51 | def __init__(self, classifier_name, ckpt_path=None, device='cuda'): 52 | super(ClassifierWrapper, self).__init__() 53 | self.net = load_attribute_classifier(classifier_name, ckpt_path).eval().to(device) 54 | 55 | def forward(self, ims, no_soft=False): 56 | # returns (N,) softmax values for binary classification 57 | if not no_soft: 58 | return get_softmaxed(self.net, ims)[1] 59 | else: 60 | return get_softmaxed(self.net, ims)[0] 61 | -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/networks/classifiers/cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 23 | stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion * planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion * planes, 30 | kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion * planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 50 | stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * 53 | planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion * planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion * planes, 60 | kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(self.expansion * planes) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = F.relu(self.bn2(self.conv2(out))) 67 | out = self.bn3(self.conv3(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | def __init__(self, block, num_blocks, num_classes=10): 75 | super(ResNet, self).__init__() 76 | self.in_planes = 64 77 | 78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 79 | stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.linear = nn.Linear(512 * block.expansion, num_classes) 86 | 87 | def _make_layer(self, block, planes, num_blocks, stride): 88 | strides = [stride] + [1] * (num_blocks - 1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_planes, planes, stride)) 92 | self.in_planes = planes * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | out = self.layer1(out) 98 | out = self.layer2(out) 99 | out = self.layer3(out) 100 | out = self.layer4(out) 101 | out = F.avg_pool2d(out, 4) 102 | out = out.view(out.size(0), -1) 103 | out = self.linear(out) 104 | return out 105 | 106 | 107 | def ResNet18(): 108 | return ResNet(BasicBlock, [2, 2, 2, 2]) 109 | 110 | 111 | def ResNet34(): 112 | return ResNet(BasicBlock, [3, 4, 6, 3]) 113 | 114 | 115 | def ResNet50(): 116 | 
return ResNet(Bottleneck, [3, 4, 6, 3]) 117 | 118 | 119 | def ResNet101(): 120 | return ResNet(Bottleneck, [3, 4, 23, 3]) 121 | 122 | 123 | def ResNet152(): 124 | return ResNet(Bottleneck, [3, 8, 36, 3]) 125 | 126 | 127 | def test(): 128 | net = ResNet18() 129 | y = net(torch.randn(1, 3, 32, 32)) 130 | print(y.size()) 131 | 132 | # test() 133 | -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/networks/classifiers/cifar10_utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including:ww 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | import torch 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | 16 | def get_mean_and_std(dataset): 17 | '''Compute the mean and std value of dataset.''' 18 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 19 | mean = torch.zeros(3) 20 | std = torch.zeros(3) 21 | print('==> Computing mean and std..') 22 | for inputs, targets in dataloader: 23 | for i in range(3): 24 | mean[i] += inputs[:, i, :, :].mean() 25 | std[i] += inputs[:, i, :, :].std() 26 | mean.div_(len(dataset)) 27 | std.div_(len(dataset)) 28 | return mean, std 29 | 30 | 31 | def init_params(net): 32 | '''Init layer parameters.''' 33 | for m in net.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | init.kaiming_normal(m.weight, mode='fan_out') 36 | if m.bias: 37 | init.constant(m.bias, 0) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | init.constant(m.weight, 1) 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.Linear): 42 | init.normal(m.weight, std=1e-3) 43 | if m.bias: 44 | init.constant(m.bias, 0) 45 | 46 | 47 | _, term_width = os.popen('stty size', 'r').read().split() 48 | term_width = int(term_width) 49 | 50 | TOTAL_BAR_LENGTH = 65. 51 | last_time = time.time() 52 | begin_time = last_time 53 | 54 | 55 | def progress_bar(current, total, msg=None): 56 | global last_time, begin_time 57 | if current == 0: 58 | begin_time = time.time() # Reset for new bar. 59 | 60 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 61 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 62 | 63 | sys.stdout.write(' [') 64 | for i in range(cur_len): 65 | sys.stdout.write('=') 66 | sys.stdout.write('>') 67 | for i in range(rest_len): 68 | sys.stdout.write('.') 69 | sys.stdout.write(']') 70 | 71 | cur_time = time.time() 72 | step_time = cur_time - last_time 73 | last_time = cur_time 74 | tot_time = cur_time - begin_time 75 | 76 | L = [] 77 | L.append(' Step: %s' % format_time(step_time)) 78 | L.append(' | Tot: %s' % format_time(tot_time)) 79 | if msg: 80 | L.append(' | ' + msg) 81 | 82 | msg = ''.join(L) 83 | sys.stdout.write(msg) 84 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 85 | sys.stdout.write(' ') 86 | 87 | # Go back to the center of the bar. 
88 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2): 89 | sys.stdout.write('\b') 90 | sys.stdout.write(' %d/%d ' % (current + 1, total)) 91 | 92 | if current < total - 1: 93 | sys.stdout.write('\r') 94 | else: 95 | sys.stdout.write('\n') 96 | sys.stdout.flush() 97 | 98 | 99 | def format_time(seconds): 100 | days = int(seconds / 3600 / 24) 101 | seconds = seconds - days * 3600 * 24 102 | hours = int(seconds / 3600) 103 | seconds = seconds - hours * 3600 104 | minutes = int(seconds / 60) 105 | seconds = seconds - minutes * 60 106 | secondsf = int(seconds) 107 | seconds = seconds - secondsf 108 | millis = int(seconds * 1000) 109 | 110 | f = '' 111 | i = 1 112 | if days > 0: 113 | f += str(days) + 'D' 114 | i += 1 115 | if hours > 0 and i <= 2: 116 | f += str(hours) + 'h' 117 | i += 1 118 | if minutes > 0 and i <= 2: 119 | f += str(minutes) + 'm' 120 | i += 1 121 | if secondsf > 0 and i <= 2: 122 | f += str(secondsf) + 's' 123 | i += 1 124 | if millis > 0 and i <= 2: 125 | f += str(millis) + 'ms' 126 | i += 1 127 | if f == '': 128 | f = '0ms' 129 | return f 130 | -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/networks/domain_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def define_classifier(domain, classifier_name=None, 5 | ckpt_path=None, device='cuda'): 6 | # check that name of pretrained model, or direct ckpt path is provided 7 | assert(classifier_name or ckpt_path) 8 | # load the trained classifiers 9 | if 'celebahq' in domain: 10 | from .classifiers import attribute_utils 11 | return attribute_utils.ClassifierWrapper(classifier_name, 12 | ckpt_path=ckpt_path, 13 | device=device) 14 | elif domain == 'cat' or domain == 'car': 15 | import torchvision.models 16 | if ckpt_path is None: 17 | ckpt = torch.load('results/pretrained_classifiers/%s/%s/net_best.pth' % 18 | (domain, classifier_name)) 19 | else: 20 | ckpt = torch.load(ckpt_path) 21 | state_dict = ckpt['state_dict'] 22 | # determine num_classes from the checkpoint 23 | num_classes = state_dict['fc.bias'].shape[0] 24 | net = torchvision.models.resnet18(num_classes=num_classes) 25 | net.load_state_dict(state_dict) 26 | return net.eval().to(device) 27 | elif domain == 'cifar10': 28 | from .classifiers import cifar10_resnet 29 | net = cifar10_resnet.ResNet18() 30 | if ckpt_path is None: 31 | ckpt = torch.load('results/pretrained_classifiers/cifar10/%s/ckpt.pth' % 32 | classifier_name)['net'] 33 | else: 34 | ckpt = torch.load(ckpt_path)['net'] 35 | net.load_state_dict(ckpt) 36 | return net.eval().to(device) 37 | 38 | 39 | softmax = torch.nn.Softmax(dim=-1) 40 | 41 | 42 | def postprocess(classifier_output): 43 | # multiclass classification N x labels 44 | if len(classifier_output.shape) == 2: 45 | postprocessed_outputs = softmax(classifier_output) 46 | # binary classification output, (N,) 47 | # the softmax should already be applied in ClassifierWrapper 48 | elif len(classifier_output.shape) == 1: 49 | postprocessed_outputs = classifier_output 50 | # sanity check 51 | assert(torch.min(postprocessed_outputs) >= 0.) 52 | assert(torch.max(postprocessed_outputs) <= 1.) 
53 | return postprocessed_outputs 54 | -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/networks/perturb_settings.py: -------------------------------------------------------------------------------- 1 | # perturbation ranges and layers for stylegan2, stylegan_idvert, 2 | # and stylegan2_ada models, for each dataset domain 3 | 4 | stylegan2_settings = { 5 | 'ffhq': { 6 | 'isotropic_eps_fine': [0.1, 0.2, 0.3], 7 | 'isotropic_eps_coarse': [0.1, 0.2, 0.3], 8 | 'pca_eps': [1.0, 2.0, 3.0], 9 | 'pca_stats': 'networks/stats/stylegan2_ffhq_stats.npz', 10 | 'fine_layer': 10, 11 | 'coarse_layer': 4, 12 | }, 13 | 'car': { 14 | 'isotropic_eps_fine': [0.3, 0.5, 0.7], 15 | 'isotropic_eps_coarse': [1.0, 1.5, 2.0], 16 | 'pca_eps': [1.0, 2.0, 3.0], 17 | 'pca_stats': 'networks/stats/stylegan2_car_stats.npz', 18 | 'fine_layer': 10, 19 | 'coarse_layer': 4, 20 | }, 21 | 'cat': { 22 | 'isotropic_eps_fine': [0.1, 0.2, 0.3], 23 | 'isotropic_eps_coarse': [0.5, 0.7, 1.0], 24 | 'pca_eps': [0.5, 0.7, 1.0], 25 | 'pca_stats': 'networks/stats/stylegan2_cat_stats.npz', 26 | 'fine_layer': 10, 27 | 'coarse_layer': 4, 28 | }, 29 | } 30 | 31 | stylegan_idinvert_settings = { 32 | 'ffhq': { 33 | 'coarse_layer': 4, 34 | 'fine_layer': 10, 35 | }, 36 | } 37 | 38 | stylegan2_cc_settings = { 39 | 'cifar10': { 40 | 'fine_layer': 7, 41 | 'cc_mean_w': 'networks/stats/stylegan2_cifar10c_wmean.npz' 42 | }, 43 | } 44 | -------------------------------------------------------------------------------- /our_interfaceGAN/celebahq_utils/dex/networks/stylegan2/stylegan2_networks.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | from utils import customnet, util 3 | from argparse import Namespace 4 | from utils.pt_stylegan2 import get_generator 5 | from collections import OrderedDict 6 | import torch.nn as nn 7 | from torch.nn.functional import interpolate 8 | 9 | def stylegan_setting(domain): 10 | outdim = 256 11 | nz = 512 12 | mult = 14 13 | resnet_depth = 34 14 | if domain == 'ffhq': 15 | outdim = 1024 16 | mult = 18 17 | if domain == 'car': 18 | outdim = 512 19 | mult = 16 20 | return dict(outdim=outdim, nz=nz, nlatent=nz*mult, 21 | resnet_depth=resnet_depth) 22 | 23 | def load_stylegan(domain, size=256): 24 | ckpt_path = f'pretrained_models/sgans_stylegan2-{domain}-config-f.pt' 25 | url = 'http://latent-composition.csail.mit.edu/' + ckpt_path 26 | cfg=Namespace(optimize_to_w=True) 27 | generator = get_generator(url, cfg=cfg, size=size).eval() 28 | return generator 29 | 30 | def load_stylegan_encoder(domain, nz=512*14, outdim=256, use_RGBM=True, use_VAE=False, 31 | resnet_depth=34, ckpt_path=None): 32 | halfsize = False # hardcoding 33 | if use_VAE: 34 | nz = nz*2 35 | channels_in = 4 if use_RGBM or use_VAE else 3 36 | print(f"Using halfsize?: {halfsize}") 37 | print(f"Input channels: {channels_in}") 38 | encoder = get_stylegan_encoder(ndim_z=nz, resnet_depth=resnet_depth, 39 | halfsize=halfsize, channels_in=channels_in) 40 | if ckpt_path is None: 41 | # use the pretrained checkpoint path (RGBM model) 42 | assert(use_RGBM) 43 | assert(not use_VAE) 44 | suffix = 'RGBM' 45 | ckpt_path = f'pretrained_models/sgan_encoders_{domain}_{suffix}_model_initial.pth.tar' 46 | # note: a further finetuned version of the encoder is at the 47 | # following path, it may better initialize for optimization 48 | # but we did not use the finetuned version in the paper 49 | # ckpt_path = 
f'pretrained_models/sgan_encoders_{domain}_{suffix}_model_final.pth' 50 | print(f"Using default checkpoint path: {ckpt_path}") 51 | url = 'http://latent-composition.csail.mit.edu/' + ckpt_path 52 | ckpt = torch.hub.load_state_dict_from_url(url) 53 | else: 54 | if util.is_url(ckpt_path): 55 | ckpt = torch.hub.load_state_dict_from_url(ckpt_path) 56 | else: 57 | ckpt = torch.load(ckpt_path) 58 | encoder.load_state_dict(ckpt['state_dict']) 59 | encoder = encoder.eval() 60 | return encoder 61 | 62 | def get_stylegan_encoder(ndim_z=512, add_relu=False, resnet_depth=34, halfsize=True, channels_in=3): 63 | """ 64 | Return encoder. Change to get a different encoder. 65 | """ 66 | def make_resnet(halfsize=True, resize=True, ndim_z=512, add_relu=False, resnet_depth=34, channels_in=3): 67 | # A resnet with the final FC layer removed. 68 | # Instead, we have a final conv5, leaky relu, and global average pooling. 69 | native_size = 128 if halfsize else 256 70 | # Make an encoder model. 71 | def change_out(layers): 72 | numch = 512 if resnet_depth < 50 else 2048 73 | ind = [i for i, (n, l) in enumerate(layers) if n == 'layer4'][0] + 1 74 | newlayer = ('layer5', 75 | torch.nn.Sequential(OrderedDict([ 76 | ('conv5', torch.nn.Conv2d(numch, ndim_z, kernel_size=1)), 77 | ]))) 78 | 79 | layers.insert(ind, newlayer) 80 | 81 | if resize: 82 | layers[:0] = [('downsample', 83 | InterpolationLayer(size=(native_size, native_size)))] 84 | 85 | # Remove FC layer 86 | layers = layers[:-1] 87 | 88 | if add_relu: 89 | layers.append( ('postrelu', torch.nn.LeakyReLU(0.2) )) 90 | 91 | # add reshape layer 92 | layers.append(('to_wplus', customnet.EncoderToWplus())) 93 | 94 | return layers 95 | 96 | encoder = customnet.CustomResNet( 97 | resnet_depth, modify_sequence=change_out, halfsize=halfsize, 98 | channels_in=channels_in) 99 | 100 | # Init using He initialization 101 | def init_weights(m): 102 | if type(m) == torch.nn.Linear or type(m) == torch.nn.Conv2d: 103 | torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') 104 | if m.bias is not None: 105 | m.bias.data.fill_(0.01) 106 | 107 | encoder.apply(init_weights) 108 | return encoder 109 | 110 | encoder = make_resnet(ndim_z=ndim_z, add_relu=add_relu ,resnet_depth=resnet_depth, 111 | channels_in=channels_in, halfsize=halfsize) 112 | return encoder 113 | 114 | 115 | class InterpolationLayer(nn.Module): 116 | def __init__(self, size): 117 | super(InterpolationLayer, self).__init__() 118 | 119 | self.size=size 120 | 121 | def forward(self, x): 122 | return interpolate(x, size=self.size, mode='area') 123 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/0.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [300], 3 | "content_end_distance": [7] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/1.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [110], 3 | "content_end_distance": [5] 4 | } 5 | 6 | 7 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/12.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [70], 3 | "content_end_distance": [5] 4 | } 5 | -------------------------------------------------------------------------------- 
/our_interfaceGAN/config_inversion/13.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [70], 3 | "content_end_distance": [5] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/3.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [10], 3 | "content_end_distance": [15] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/4.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [20], 3 | "content_end_distance": [8] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/8.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [20], 3 | "content_end_distance": [8] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/9.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [30], 3 | "content_end_distance": [8] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/age.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [2], 3 | "content_end_distance": [30] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/gender.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [40], 3 | "content_end_distance": [7] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/pose.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [1], 3 | "content_end_distance": [20] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_inversion/seed.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": [0, 100] 3 | } 4 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/0.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [5], 3 | "content_end_distance": [3] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/1.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [30], 3 | "content_end_distance": [5] 4 | } 5 | 6 | 7 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/12.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [110], 3 | "content_end_distance": [5] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/13.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"style_end_distance": [110], 3 | "content_end_distance": [5] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/3.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [0.5], 3 | "content_end_distance": [6.5] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/8.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [1], 3 | "content_end_distance": [8] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/9.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [0.5], 3 | "content_end_distance": [4] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/age.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [2.5], 3 | "content_end_distance": [1] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/gender.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [3.5], 3 | "content_end_distance": [1] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/pose.json: -------------------------------------------------------------------------------- 1 | { 2 | "style_end_distance": [1], 3 | "content_end_distance": [17] 4 | } 5 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/seed.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": [0, 100] 3 | } 4 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/utils/edit_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms, utils 3 | from PIL import Image 4 | import os 5 | import numpy as np 6 | 7 | 8 | def make_image(tensor): 9 | return ( 10 | tensor.detach() 11 | .clamp_(min=-1, max=1) 12 | .add(1) 13 | .div_(2) 14 | .mul(255) 15 | .type(torch.uint8) 16 | .permute(0, 2, 3, 1) 17 | .to('cpu') 18 | .numpy() 19 | ) 20 | 21 | def visualize(img_path): 22 | img_list=os.listdir(img_path) 23 | img_list.sort() 24 | img_list.sort(key = lambda x: (x[:-4])) ##文件名按数字排序 25 | b = ['0', '12', '24', '36', '48', '60'] 26 | dir = [] 27 | img_nums=len(img_list) 28 | res = [] 29 | 30 | transform = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)) 33 | ] 34 | ) 35 | 36 | for i in range(img_nums): 37 | img_name=os.path.join(img_path, img_list[i]) 38 | ll = img_name.split('/')[-1].split('_')[3] 39 | if ll in b: 40 | dir.append(img_name) 41 | img = Image.open(img_name).convert('RGB') 42 | img2 = transform(img) 43 | array = np.asarray(img2) 44 | data = torch.from_numpy(array).unsqueeze(0) 45 | res.append(data) 46 | sample = torch.cat(res, dim=0) 47 | utils.save_image( 48 | sample, 49 | os.path.join(img_path,'edit.png'), 50 | nrow=int(6), 51 | normalize=True, 52 | range=(-1, 
1), 53 | ) 54 | -------------------------------------------------------------------------------- /our_interfaceGAN/config_noinversion/utils/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def prepare_param(n_sample, args, device, method="batch_same", truncation=1.0): 4 | if method == "batch_same": 5 | return torch.randn(args.para_num, args.latent, device=device).repeat(n_sample, 1, 1) * truncation 6 | elif method == "batch_diff": 7 | return torch.randn(n_sample, args.para_num, args.latent, device=device) * truncation 8 | elif method == "spatial": 9 | # batch, 512,16 10 | return torch.randn(n_sample, args.latent, args.para_num, device=device) * truncation 11 | elif method == "spatial_same": 12 | return torch.randn(args.latent, args.para_num, device=device).repeat(n_sample,1,1) * truncation 13 | 14 | 15 | 16 | def prepare_noise_new(n_sample, args, device, method="multi", truncation=1.0, mode = 'train'): 17 | # used for train_spatial_query, returns (bs, 512, 16) 18 | if method == 'query': 19 | return torch.randn(n_sample, args.latent, args.para_num, device=device) * truncation 20 | elif method == 'query_same': 21 | return torch.randn(args.latent, args.para_num, device=device).repeat(n_sample,1,1) * truncation -------------------------------------------------------------------------------- /our_interfaceGAN/ffhq_utils/dex/__init__.py: -------------------------------------------------------------------------------- 1 | from .api import _eval as eval 2 | from .api import estimate_age 3 | from .api import estimate_gender 4 | -------------------------------------------------------------------------------- /our_interfaceGAN/ffhq_utils/dex/api.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.transforms as transforms 5 | 6 | from .models import Age, Gender, ClassifyModel 7 | 8 | device = 'cuda' 9 | 10 | age_model = Age() 11 | gender_model = Gender() 12 | other_model = ClassifyModel() 13 | 14 | 15 | cwd = os.path.dirname(__file__) 16 | age_model_path = os.path.join(cwd, 'pth/age_sd.pth') 17 | gender_model_path = os.path.join(cwd, 'pth/gender_sd.pth') 18 | pose_model_path = os.path.join(cwd, 'pth/classifier/pose/weight.pkl') 19 | 20 | 21 | def _eval(attribute_name): 22 | global age_model 23 | global gender_model 24 | age_model.load_state_dict(torch.load(age_model_path)) 25 | age_model.eval() 26 | age_model = age_model.to(device) 27 | binarymodel_path = gender_model_path 28 | if attribute_name == 'gender': 29 | binarymodel_path = gender_model_path 30 | gender_model.load_state_dict(torch.load(binarymodel_path)) 31 | gender_model.eval() 32 | gender_model = gender_model.to(device) 33 | else: 34 | if attribute_name != 'age': 35 | if attribute_name == 'pose': 36 | binarymodel_path = pose_model_path 37 | other_model.load_state_dict(torch.load(binarymodel_path)) 38 | other_model.eval() 39 | gender_model = other_model.to(device) 40 | 41 | 42 | def expected_age(tensor): 43 | weight = torch.arange(1, 102, device=device) 44 | return weight * tensor 45 | 46 | 47 | def estimate_age(img): 48 | # tensor = transforms.CenterCrop(224)(img) 49 | # tensor = transforms.CenterCrop(224)(img) 50 | h = img.size(2) 51 | offset = (h - 224) // 2 52 | tensor = img[:, :, offset:-offset, offset:-offset] 53 | # print(tensor.shape) 54 | 55 | with torch.no_grad(): 56 | output = age_model(tensor) 57 | age = expected_age(output) 58 | return torch.sum(age, dim=1) 59 | 60 | 61 | 
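# The age readout above works as follows: `age_model` ends in a softmax over 101
# age bins, `expected_age` weights that distribution by the bin values 1..101, and
# `torch.sum(age, dim=1)` in `estimate_age` therefore returns the expectation
# E[age] = sum_k k * p_k for each image. For example, probability 0.5 on bin 30
# and 0.5 on bin 32 gives a predicted age of 31.0.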
def estimate_gender(img): 62 | tensor = transforms.CenterCrop(224)(img) 63 | with torch.no_grad(): 64 | output = gender_model(tensor)[:,0] # 把第一类作为正样本 65 | return output 66 | 67 | 68 | -------------------------------------------------------------------------------- /our_interfaceGAN/ffhq_utils/dex/models.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | from torchvision.models import resnet18 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def vgg_block(in_channels, out_channels, more=False): 10 | blocklist = [ 11 | ('conv1', nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)), 12 | ('relu1', nn.ReLU(inplace=True)), 13 | ('conv2', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)), 14 | ('relu2', nn.ReLU(inplace=True)), 15 | ] 16 | if more: 17 | blocklist.extend([ 18 | ('conv3', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)), 19 | ('relu3', nn.ReLU(inplace=True)), 20 | ]) 21 | blocklist.append(('maxpool', nn.MaxPool2d(kernel_size=2, stride=2))) 22 | block = nn.Sequential(OrderedDict(blocklist)) 23 | return block 24 | 25 | 26 | # VGG16 architecture 27 | class VGG(nn.Module): 28 | def __init__(self, classes=1000, channels=3): 29 | super().__init__() 30 | self.conv = nn.Sequential( 31 | vgg_block(channels, 64), 32 | vgg_block(64, 128), 33 | vgg_block(128, 256, True), 34 | vgg_block(256, 512, True), 35 | vgg_block(512, 512, True), 36 | ) 37 | self.fc1 = nn.Sequential( 38 | nn.Linear(512 * 7 * 7, 4096), 39 | nn.ReLU(inplace=True), 40 | nn.Dropout(0.5, inplace=True), 41 | ) 42 | self.fc2 = nn.Sequential( 43 | nn.Linear(4096, 4096), 44 | nn.ReLU(inplace=True), 45 | nn.Dropout(0.5, inplace=True), 46 | ) 47 | self.cls = nn.Linear(4096, classes) 48 | 49 | def forward(self, x): 50 | in_size = x.shape[0] 51 | x = self.conv(x) 52 | x = x.view(in_size, -1) 53 | x = self.fc1(x) 54 | x = self.fc2(x) 55 | x = self.cls(x) 56 | x = F.softmax(x, dim=1) 57 | return x 58 | 59 | 60 | class Gender(VGG): 61 | def __init__(self, classes=2, channels=3): 62 | super().__init__() 63 | self.cls = nn.Linear(4096, classes) 64 | 65 | 66 | class Age(VGG): 67 | def __init__(self, classes=101, channels=3): 68 | super().__init__() 69 | self.cls = nn.Linear(4096, classes) 70 | 71 | 72 | 73 | def get_resnet(): 74 | net = resnet18() 75 | modified_net = nn.Sequential(*list(net.children())[:-1]) # fetch all of the layers before the last fc. 76 | return modified_net 77 | 78 | class ClassifyModel(nn.Module): 79 | def __init__(self, n_class=2): 80 | super(ClassifyModel, self).__init__() 81 | self.backbone = get_resnet() 82 | self.extra_layer = nn.Linear(512, n_class) 83 | 84 | def forward(self, x): 85 | out = self.backbone(x) 86 | out = torch.flatten(out, 1) 87 | out = self.extra_layer(out) 88 | out = F.softmax(out, dim=1) 89 | return out -------------------------------------------------------------------------------- /our_interfaceGAN/linear_interpolation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def linear_interpolate(latent_code, 5 | boundary, 6 | start_distance=-100, 7 | end_distance=100, 8 | steps=10): 9 | """Manipulates the given latent code with respect to a particular boundary. 10 | Basically, this function takes a latent code and a boundary as inputs, and 11 | outputs a collection of manipulated latent codes. 
For example, let `steps`
12 |   be 10, then the input `latent_code` has shape [1, latent_space_dim], the input
13 |   `boundary` has shape [1, latent_space_dim] with unit norm, and the output
14 |   has shape [10, latent_space_dim]. The first output latent code is
15 |   `start_distance` away from the given `boundary`, while the last output latent
16 |   code is `end_distance` away from the given `boundary`. Remaining latent codes
17 |   are linearly interpolated.
18 |   Input `latent_code` can also have shape [1, num_layers, latent_space_dim]
19 |   to support W+ space in StyleGAN. In this case, all features in W+ space will
20 |   be manipulated in the same way. Accordingly, the output will have shape
21 |   [10, num_layers, latent_space_dim].
22 |   NOTE: Distance is sign sensitive.
23 |   Args:
24 |     latent_code: The input latent code for manipulation.
25 |     boundary: The semantic boundary as reference.
26 |     start_distance: The distance to the boundary where the manipulation starts.
27 |       (default: -100)
28 |     end_distance: The distance to the boundary where the manipulation ends.
29 |       (default: 100)
30 |     steps: Number of steps to move the latent code from start position to end
31 |       position. (default: 10)
32 |   """
33 |   assert (latent_code.shape[0] == 1 and boundary.shape[0] == 1 and
34 |           len(boundary.shape) == 2 and
35 |           boundary.shape[1] == latent_code.shape[-1])
36 | 
37 |   linspace = np.linspace(start_distance, end_distance, steps)
38 |   if len(latent_code.shape) == 2:
39 |     linspace = linspace - latent_code.dot(boundary.T)
40 |     linspace = linspace.reshape(-1, 1).astype(np.float32)
41 |     return latent_code + linspace * boundary
42 |   if len(latent_code.shape) == 3:
43 |     linspace = linspace.reshape(-1, 1, 1).astype(np.float32)
44 |     return latent_code + linspace * boundary.reshape(1, 1, -1)
45 |   raise ValueError(f'Input `latent_code` should be with shape '
46 |                    f'[1, latent_space_dim] or [1, N, latent_space_dim] for '
47 |                    f'W+ space in StyleGAN!\n'
48 |                    f'But {latent_code.shape} is received.')
49 | 
--------------------------------------------------------------------------------
/our_interfaceGAN/train_boundary.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn import svm
3 | 
4 | 
5 | def train_boundary(latent_codes,
6 |                    scores,
7 |                    chosen_num_or_ratio=0.02,
8 |                    split_ratio=0.7,
9 |                    invalid_value=None):
10 |   """Trains boundary in latent space with offline predicted attribute scores.
11 |   Given a collection of latent codes and the attribute scores predicted from the
12 |   corresponding images, this function will train a linear SVM by treating it as
13 |   a bi-classification problem. Basically, the samples with highest attribute
14 |   scores are treated as positive samples, while those with lowest scores as
15 |   negative. For now, the latent code can ONLY be 1-dimensional.
16 |   NOTE: The returned boundary has shape (1, latent_space_dim), and is also
17 |   normalized to unit norm.
18 |   Args:
19 |     latent_codes: Input latent codes as training data.
20 |     scores: Input attribute scores used to generate training labels.
21 |     chosen_num_or_ratio: How many samples will be chosen as positive (negative)
22 |       samples. If this field lies in range (0, 0.5], `chosen_num_or_ratio *
23 |       latent_codes_num` will be used. Otherwise, `min(chosen_num_or_ratio,
24 |       0.5 * latent_codes_num)` will be used. (default: 0.02)
25 |     split_ratio: Ratio to split training and validation sets. (default: 0.7)
26 |     invalid_value: This field is used to filter out data. (default: None)
27 |   Returns:
28 |     A decision boundary with type `numpy.ndarray`.
29 |   Raises:
30 |     ValueError: If the input `latent_codes` or `scores` have an invalid format.
31 |   """
32 | 
33 |   if (not isinstance(latent_codes, np.ndarray) or
34 |       not len(latent_codes.shape) == 2):
35 |     raise ValueError(f'Input `latent_codes` should be with type '
36 |                      f'`numpy.ndarray`, and shape [num_samples, '
37 |                      f'latent_space_dim]!')
38 |   num_samples = latent_codes.shape[0]
39 |   latent_space_dim = latent_codes.shape[1]
40 |   if (not isinstance(scores, np.ndarray) or not len(scores.shape) == 2 or
41 |       not scores.shape[0] == num_samples or not scores.shape[1] == 1):
42 |     raise ValueError(f'Input `scores` should be with type `numpy.ndarray`, and '
43 |                      f'shape [num_samples, 1], where `num_samples` should be '
44 |                      f'exactly same as that of input `latent_codes`!')
45 |   if chosen_num_or_ratio <= 0:
46 |     raise ValueError(f'Input `chosen_num_or_ratio` should be positive, '
47 |                      f'but {chosen_num_or_ratio} received!')
48 | 
49 |   print('Filtering training data.')
50 |   if invalid_value is not None:
51 |     latent_codes = latent_codes[scores[:, 0] != invalid_value]
52 |     scores = scores[scores[:, 0] != invalid_value]
53 | 
54 |   print('Sorting scores to get positive and negative samples.')
55 |   sorted_idx = np.argsort(scores, axis=0)[::-1, 0]
56 |   latent_codes = latent_codes[sorted_idx]
57 |   scores = scores[sorted_idx]
58 |   num_samples = latent_codes.shape[0]
59 |   if 0 < chosen_num_or_ratio <= 1:
60 |     chosen_num = int(num_samples * chosen_num_or_ratio)  # i.e. `chosen_num` samples
61 |   else:
62 |     chosen_num = int(chosen_num_or_ratio)
63 |   chosen_num = min(chosen_num, num_samples // 2)
64 |   print(f"sample range: >{scores[chosen_num]}, <{scores[-chosen_num]}")
65 | 
66 |   print('Splitting training and validation sets:')
67 |   train_num = int(chosen_num * split_ratio)
68 |   val_num = chosen_num - train_num
69 |   # Positive samples.
70 |   positive_idx = np.arange(chosen_num)
71 |   np.random.shuffle(positive_idx)
72 |   positive_train = latent_codes[:chosen_num][positive_idx[:train_num]]
73 |   positive_val = latent_codes[:chosen_num][positive_idx[train_num:]]
74 |   # Negative samples.
75 |   negative_idx = np.arange(chosen_num)
76 |   np.random.shuffle(negative_idx)
77 |   negative_train = latent_codes[-chosen_num:][negative_idx[:train_num]]
78 |   negative_val = latent_codes[-chosen_num:][negative_idx[train_num:]]
79 |   # Training set.
80 |   train_data = np.concatenate([positive_train, negative_train], axis=0)
81 |   train_label = np.concatenate([np.ones(train_num, dtype=int),
82 |                                 np.zeros(train_num, dtype=int)], axis=0)
83 |   print(f'  Training: {train_num} positive, {train_num} negative.')  # e.g. Training: 1400 positive, 1400 negative.
84 |   # Validation set.
85 |   val_data = np.concatenate([positive_val, negative_val], axis=0)
86 |   val_label = np.concatenate([np.ones(val_num, dtype=int),
87 |                               np.zeros(val_num, dtype=int)], axis=0)
88 |   print(f'  Validation: {val_num} positive, {val_num} negative.')
89 |   # Remaining set.
90 |   remaining_num = num_samples - chosen_num * 2
91 |   remaining_data = latent_codes[chosen_num:-chosen_num]
92 |   remaining_scores = scores[chosen_num:-chosen_num]
93 |   decision_value = (scores[0] + scores[-1]) / 2
94 |   remaining_label = np.ones(remaining_num, dtype=int)
95 |   remaining_label[remaining_scores.ravel() < decision_value] = 0
96 |   remaining_positive_num = np.sum(remaining_label == 1)
97 |   remaining_negative_num = np.sum(remaining_label == 0)
98 |   print(f'  Remaining: {remaining_positive_num} positive, '
99 |         f'{remaining_negative_num} negative.')
100 | 
101 |   print(f'Training boundary.')
102 |   # scaler = StandardScaler()
103 |   # scaler.fit(train_data)
104 |   #
105 |   # train_transformed = scaler.transform(train_data)
106 |   # val_transformed = scaler.transform(val_data)
107 |   # # remaining_transformed = scaler.transform(remaining_data)
108 |   #
109 |   # clf = SGDClassifier()
110 |   # classifier = clf.fit(train_transformed, train_label)
111 |   # print(f"Actual number of iterations: {classifier.n_iter_}")
112 | 
113 |   clf = svm.SVC(kernel='linear')
114 |   classifier = clf.fit(train_data, train_label)
115 |   print(f'Finish training.')
116 | 
117 |   if train_num:
118 |     train_prediction = classifier.predict(train_data)
119 |     correct_num = np.sum(train_label == train_prediction)
120 |     print(f'Accuracy for training set: '
121 |           f'{correct_num} / {train_num * 2} = '
122 |           f'{correct_num / (train_num * 2):.6f}')
123 | 
124 |   if val_num:
125 |     val_prediction = classifier.predict(val_data)
126 |     correct_num = np.sum(val_label == val_prediction)
127 |     print(f'Accuracy for validation set: '
128 |           f'{correct_num} / {val_num * 2} = '
129 |           f'{correct_num / (val_num * 2):.6f}')
130 | 
131 |   # if remaining_num:
132 |   #   remaining_prediction = classifier.predict(remaining_data)
133 |   #   correct_num = np.sum(remaining_label == remaining_prediction)
134 |   #   print(f'Accuracy for remaining set: '
135 |   #         f'{correct_num} / {remaining_num} = '
136 |   #         f'{correct_num / remaining_num:.6f}')
137 | 
138 |   a = classifier.coef_.reshape(1, latent_space_dim).astype(np.float32)
139 |   return a / np.linalg.norm(a)
140 | 
141 | 
--------------------------------------------------------------------------------
/pSp/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Elad Richardson, Yuval Alaluf
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /pSp/configs/data_configs.py: -------------------------------------------------------------------------------- 1 | from pSp.configs import transforms_config 2 | from pSp.configs.paths_config import dataset_paths 3 | 4 | DATASETS = { 5 | 'ffhq_encode': { 6 | 'transforms': transforms_config.EncodeTransforms, 7 | 'train_source_root': dataset_paths['ffhq'], 8 | 'train_target_root': dataset_paths['ffhq'], 9 | 'test_source_root': dataset_paths['celeba_test'], 10 | 'test_target_root': dataset_paths['celeba_test'], 11 | }, 12 | 'ffhq_frontalize': { 13 | 'transforms': transforms_config.FrontalizationTransforms, 14 | 'train_source_root': dataset_paths['ffhq'], 15 | 'train_target_root': dataset_paths['ffhq'], 16 | 'test_source_root': dataset_paths['celeba_test'], 17 | 'test_target_root': dataset_paths['celeba_test'], 18 | }, 19 | 'celebs_sketch_to_face': { 20 | 'transforms': transforms_config.SketchToImageTransforms, 21 | 'train_source_root': dataset_paths['celeba_train_sketch'], 22 | 'train_target_root': dataset_paths['celeba_train'], 23 | 'test_source_root': dataset_paths['celeba_test_sketch'], 24 | 'test_target_root': dataset_paths['celeba_test'], 25 | }, 26 | 'celebs_seg_to_face': { 27 | 'transforms': transforms_config.SegToImageTransforms, 28 | 'train_source_root': dataset_paths['celeba_train_segmentation'], 29 | 'train_target_root': dataset_paths['celeba_train'], 30 | 'test_source_root': dataset_paths['celeba_test_segmentation'], 31 | 'test_target_root': dataset_paths['celeba_test'], 32 | }, 33 | 'celebs_super_resolution': { 34 | 'transforms': transforms_config.SuperResTransforms, 35 | 'train_source_root': dataset_paths['celeba_train'], 36 | 'train_target_root': dataset_paths['celeba_train'], 37 | 'test_source_root': dataset_paths['celeba_test'], 38 | 'test_target_root': dataset_paths['celeba_test'], 39 | }, 40 | } 41 | -------------------------------------------------------------------------------- /pSp/configs/paths_config.py: -------------------------------------------------------------------------------- 1 | dataset_paths = { 2 | 'celeba_train': '', 3 | 'celeba_test': '', 4 | 'celeba_train_sketch': '', 5 | 'celeba_test_sketch': '', 6 | 'celeba_train_segmentation': '', 7 | 'celeba_test_segmentation': '', 8 | 'ffhq': '', 9 | } 10 | 11 | 12 | # yanbo, 86 /mnt/lustre/share_data/xuyanbo 13 | model_paths = { 14 | 'stylegan_ffhq': '/mnt/lustre/share_data/xuyanbo/pretrained_models/stylegan2-ffhq-config-f.pt', 15 | 'ir_se50': '/mnt/lustre/share_data/xuyanbo/pretrained_models/model_ir_se50.pth', 16 | 'circular_face': '/mnt/lustre/share_data/xuyanbo/pretrained_models/CurricularFace_Backbone.pth', 17 | 'mtcnn_pnet': '/mnt/lustre/share_data/xuyanbo/pretrained_models/mtcnn/pnet.npy', 18 | 'mtcnn_rnet': '/mnt/lustre/share_data/xuyanbo/pretrained_models/mtcnn/rnet.npy', 19 | 'mtcnn_onet': '/mnt/lustre/share_data/xuyanbo/pretrained_models/mtcnn/onet.npy', 20 | 'shape_predictor': 'shape_predictor_68_face_landmarks.dat' 21 | } 22 | -------------------------------------------------------------------------------- /pSp/configs/transforms_config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import torchvision.transforms as transforms 4 | 5 | from pSp.datasets import augmentations 6 | 7 | 8 | class TransformsConfig(object): 9 | 10 | def __init__(self, opts): 11 | self.opts = opts 12 | 13 | @abstractmethod 14 | def get_transforms(self): 15 | pass 16 
| 17 | 18 | class EncodeTransforms(TransformsConfig): 19 | 20 | def __init__(self, opts): 21 | super(EncodeTransforms, self).__init__(opts) 22 | 23 | def get_transforms(self): 24 | transforms_dict = { 25 | 'transform_gt_train': transforms.Compose([ 26 | transforms.Resize((256, 256)), 27 | transforms.RandomHorizontalFlip(0.5), 28 | transforms.ToTensor(), 29 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 30 | 'transform_source': None, 31 | 'transform_test': transforms.Compose([ 32 | transforms.Resize((256, 256)), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 35 | 'transform_inference': transforms.Compose([ 36 | transforms.Resize((256, 256)), 37 | transforms.ToTensor(), 38 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 39 | } 40 | return transforms_dict 41 | 42 | 43 | class FrontalizationTransforms(TransformsConfig): 44 | 45 | def __init__(self, opts): 46 | super(FrontalizationTransforms, self).__init__(opts) 47 | 48 | def get_transforms(self): 49 | transforms_dict = { 50 | 'transform_gt_train': transforms.Compose([ 51 | transforms.Resize((256, 256)), 52 | transforms.RandomHorizontalFlip(0.5), 53 | transforms.ToTensor(), 54 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 55 | 'transform_source': transforms.Compose([ 56 | transforms.Resize((256, 256)), 57 | transforms.RandomHorizontalFlip(0.5), 58 | transforms.ToTensor(), 59 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 60 | 'transform_test': transforms.Compose([ 61 | transforms.Resize((256, 256)), 62 | transforms.ToTensor(), 63 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 64 | 'transform_inference': transforms.Compose([ 65 | transforms.Resize((256, 256)), 66 | transforms.ToTensor(), 67 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 68 | } 69 | return transforms_dict 70 | 71 | 72 | class SketchToImageTransforms(TransformsConfig): 73 | 74 | def __init__(self, opts): 75 | super(SketchToImageTransforms, self).__init__(opts) 76 | 77 | def get_transforms(self): 78 | transforms_dict = { 79 | 'transform_gt_train': transforms.Compose([ 80 | transforms.Resize((256, 256)), 81 | transforms.ToTensor(), 82 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 83 | 'transform_source': transforms.Compose([ 84 | transforms.Resize((256, 256)), 85 | transforms.ToTensor()]), 86 | 'transform_test': transforms.Compose([ 87 | transforms.Resize((256, 256)), 88 | transforms.ToTensor(), 89 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 90 | 'transform_inference': transforms.Compose([ 91 | transforms.Resize((256, 256)), 92 | transforms.ToTensor()]), 93 | } 94 | return transforms_dict 95 | 96 | 97 | class SegToImageTransforms(TransformsConfig): 98 | 99 | def __init__(self, opts): 100 | super(SegToImageTransforms, self).__init__(opts) 101 | 102 | def get_transforms(self): 103 | transforms_dict = { 104 | 'transform_gt_train': transforms.Compose([ 105 | transforms.Resize((256, 256)), 106 | transforms.ToTensor(), 107 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 108 | 'transform_source': transforms.Compose([ 109 | transforms.Resize((256, 256)), 110 | augmentations.ToOneHot(self.opts.label_nc), 111 | transforms.ToTensor()]), 112 | 'transform_test': transforms.Compose([ 113 | transforms.Resize((256, 256)), 114 | transforms.ToTensor(), 115 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 116 | 'transform_inference': transforms.Compose([ 117 | transforms.Resize((256, 256)), 118 | 
augmentations.ToOneHot(self.opts.label_nc), 119 | transforms.ToTensor()]) 120 | } 121 | return transforms_dict 122 | 123 | 124 | class SuperResTransforms(TransformsConfig): 125 | 126 | def __init__(self, opts): 127 | super(SuperResTransforms, self).__init__(opts) 128 | 129 | def get_transforms(self): 130 | if self.opts.resize_factors is None: 131 | self.opts.resize_factors = '1,2,4,8,16,32' 132 | factors = [int(f) for f in self.opts.resize_factors.split(",")] 133 | print("Performing down-sampling with factors: {}".format(factors)) 134 | transforms_dict = { 135 | 'transform_gt_train': transforms.Compose([ 136 | transforms.Resize((256, 256)), 137 | transforms.ToTensor(), 138 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 139 | 'transform_source': transforms.Compose([ 140 | transforms.Resize((256, 256)), 141 | augmentations.BilinearResize(factors=factors), 142 | transforms.Resize((256, 256)), 143 | transforms.ToTensor(), 144 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 145 | 'transform_test': transforms.Compose([ 146 | transforms.Resize((256, 256)), 147 | transforms.ToTensor(), 148 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), 149 | 'transform_inference': transforms.Compose([ 150 | transforms.Resize((256, 256)), 151 | augmentations.BilinearResize(factors=factors), 152 | transforms.Resize((256, 256)), 153 | transforms.ToTensor(), 154 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 155 | } 156 | return transforms_dict 157 | -------------------------------------------------------------------------------- /pSp/criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from pSp.configs.paths_config import model_paths 5 | from pSp.models.encoders.model_irse import Backbone 6 | 7 | 8 | class IDLoss(nn.Module): 9 | def __init__(self): 10 | super(IDLoss, self).__init__() 11 | print('Loading ResNet ArcFace') 12 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 13 | self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) 14 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 15 | self.facenet.eval() 16 | 17 | def extract_feats(self, x): 18 | x = x[:, :, 35:223, 32:220] # Crop interesting region 19 | x = self.face_pool(x) 20 | x_feats = self.facenet(x) 21 | return x_feats 22 | 23 | def forward(self, y_hat, y, x): 24 | n_samples = x.shape[0] 25 | x_feats = self.extract_feats(x) 26 | y_feats = self.extract_feats(y) # Otherwise use the feature from there 27 | y_hat_feats = self.extract_feats(y_hat) 28 | y_feats = y_feats.detach() 29 | loss = 0 30 | sim_improvement = 0 31 | id_logs = [] 32 | count = 0 33 | for i in range(n_samples): 34 | diff_target = y_hat_feats[i].dot(y_feats[i]) 35 | diff_input = y_hat_feats[i].dot(x_feats[i]) 36 | diff_views = y_feats[i].dot(x_feats[i]) 37 | id_logs.append({'diff_target': float(diff_target), 38 | 'diff_input': float(diff_input), 39 | 'diff_views': float(diff_views)}) 40 | loss += 1 - diff_target 41 | id_diff = float(diff_target) - float(diff_views) 42 | sim_improvement += id_diff 43 | count += 1 44 | 45 | return loss / count, sim_improvement / count, id_logs 46 | -------------------------------------------------------------------------------- /pSp/criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from pSp.criteria.lpips.networks import get_network, LinLayers 5 | from 
pSp.criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 15 | """ 16 | 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /pSp/criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import Sequence 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchvision import models 7 | 8 | from pSp.criteria.lpips.utils import normalize_activation 9 | 10 | 11 | def get_network(net_type: str): 12 | if net_type == 'alex': 13 | return AlexNet() 14 | elif net_type == 'squeeze': 15 | return SqueezeNet() 16 | elif net_type == 'vgg': 17 | return VGG16() 18 | else: 19 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 20 | 21 | 22 | class LinLayers(nn.ModuleList): 23 | def __init__(self, n_channels_list: Sequence[int]): 24 | super(LinLayers, self).__init__([ 25 | nn.Sequential( 26 | nn.Identity(), 27 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 28 | ) for nc in n_channels_list 29 | ]) 30 | 31 | for param in self.parameters(): 32 | param.requires_grad = False 33 | 34 | 35 | class BaseNet(nn.Module): 36 | def __init__(self): 37 | super(BaseNet, self).__init__() 38 | 39 | # register buffer 40 | self.register_buffer( 41 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 42 | self.register_buffer( 43 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 44 | 45 | def set_requires_grad(self, state: bool): 46 | for param in chain(self.parameters(), self.buffers()): 47 | param.requires_grad = state 48 | 49 | def z_score(self, x: torch.Tensor): 50 | return (x - self.mean) / self.std 51 | 52 | def forward(self, x: torch.Tensor): 53 | x = self.z_score(x) 54 | 55 | output = [] 56 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 57 | x = layer(x) 58 | if i in self.target_layers: 59 | output.append(normalize_activation(x)) 60 | if len(output) == len(self.target_layers): 61 | break 62 | return output 63 | 64 | 65 | class SqueezeNet(BaseNet): 66 | def __init__(self): 67 | super(SqueezeNet, self).__init__() 68 | 69 | self.layers = models.squeezenet1_1(True).features 70 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 71 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 72 | 73 | self.set_requires_grad(False) 74 | 75 | 76 | class AlexNet(BaseNet): 77 | def __init__(self): 78 | super(AlexNet, self).__init__() 79 | 80 | self.layers = models.alexnet(True).features 81 | self.target_layers = [2, 5, 8, 10, 12] 82 | 
self.n_channels_list = [64, 192, 384, 256, 256] 83 | 84 | self.set_requires_grad(False) 85 | 86 | 87 | class VGG16(BaseNet): 88 | def __init__(self): 89 | super(VGG16, self).__init__() 90 | 91 | self.layers = models.vgg16(True).features 92 | self.target_layers = [4, 9, 16, 23, 30] 93 | self.n_channels_list = [64, 128, 256, 512, 512] 94 | 95 | self.set_requires_grad(False) 96 | -------------------------------------------------------------------------------- /pSp/criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /pSp/criteria/w_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class WNormLoss(nn.Module): 6 | 7 | def __init__(self, start_from_latent_avg=True): 8 | super(WNormLoss, self).__init__() 9 | self.start_from_latent_avg = start_from_latent_avg 10 | 11 | def forward(self, latent, latent_avg=None): 12 | if self.start_from_latent_avg: 13 | latent = latent - latent_avg 14 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] 15 | -------------------------------------------------------------------------------- /pSp/datasets/augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torchvision import transforms 6 | 7 | 8 | class ToOneHot(object): 9 | """ Convert the input PIL image to a one-hot torch tensor """ 10 | 11 | def __init__(self, n_classes=None): 12 | self.n_classes = n_classes 13 | 14 | def onehot_initialization(self, a): 15 | if self.n_classes is None: 16 | self.n_classes = len(np.unique(a)) 17 | out = np.zeros(a.shape + (self.n_classes,), dtype=int) 18 | out[self.__all_idx(a, axis=2)] = 1 19 | return out 20 | 21 | def __all_idx(self, idx, axis): 22 | grid = np.ogrid[tuple(map(slice, idx.shape))] 23 | grid.insert(axis, idx) 24 | return tuple(grid) 25 | 26 | def __call__(self, img): 27 | img = np.array(img) 28 | one_hot = self.onehot_initialization(img) 29 | return one_hot 30 | 31 | 32 | class BilinearResize(object): 33 | def __init__(self, factors=[1, 2, 4, 8, 16, 32]): 34 | self.factors = factors 35 | 36 | def __call__(self, image): 37 | factor = np.random.choice(self.factors, size=1)[0] 38 | D = BicubicDownSample(factor=factor, cuda=False) 39 | img_tensor = transforms.ToTensor()(image).unsqueeze(0) 40 | img_tensor_lr = D(img_tensor)[0].clamp(0, 1) 41 | img_low_res = 
transforms.ToPILImage()(img_tensor_lr) 42 | return img_low_res 43 | 44 | 45 | class BicubicDownSample(nn.Module): 46 | def bicubic_kernel(self, x, a=-0.50): 47 | """ 48 | This equation is exactly copied from the website below: 49 | https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic 50 | """ 51 | abs_x = torch.abs(x) 52 | if abs_x <= 1.: 53 | return (a + 2.) * torch.pow(abs_x, 3.) - (a + 3.) * torch.pow(abs_x, 2.) + 1 54 | elif 1. < abs_x < 2.: 55 | return a * torch.pow(abs_x, 3) - 5. * a * torch.pow(abs_x, 2.) + 8. * a * abs_x - 4. * a 56 | else: 57 | return 0.0 58 | 59 | def __init__(self, factor=4, cuda=True, padding='reflect'): 60 | super().__init__() 61 | self.factor = factor 62 | size = factor * 4 63 | k = torch.tensor([self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) 64 | for i in range(size)], dtype=torch.float32) 65 | k = k / torch.sum(k) 66 | k1 = torch.reshape(k, shape=(1, 1, size, 1)) 67 | self.k1 = torch.cat([k1, k1, k1], dim=0) 68 | k2 = torch.reshape(k, shape=(1, 1, 1, size)) 69 | self.k2 = torch.cat([k2, k2, k2], dim=0) 70 | self.cuda = '.cuda' if cuda else '' 71 | self.padding = padding 72 | for param in self.parameters(): 73 | param.requires_grad = False 74 | 75 | def forward(self, x, nhwc=False, clip_round=False, byte_output=False): 76 | filter_height = self.factor * 4 77 | filter_width = self.factor * 4 78 | stride = self.factor 79 | 80 | pad_along_height = max(filter_height - stride, 0) 81 | pad_along_width = max(filter_width - stride, 0) 82 | filters1 = self.k1.type('torch{}.FloatTensor'.format(self.cuda)) 83 | filters2 = self.k2.type('torch{}.FloatTensor'.format(self.cuda)) 84 | 85 | # compute actual padding values for each side 86 | pad_top = pad_along_height // 2 87 | pad_bottom = pad_along_height - pad_top 88 | pad_left = pad_along_width // 2 89 | pad_right = pad_along_width - pad_left 90 | 91 | # apply mirror padding 92 | if nhwc: 93 | x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW 94 | 95 | # downscaling performed by 1-d convolution 96 | x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) 97 | x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) 98 | if clip_round: 99 | x = torch.clamp(torch.round(x), 0.0, 255.) 100 | 101 | x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) 102 | x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) 103 | if clip_round: 104 | x = torch.clamp(torch.round(x), 0.0, 255.) 
105 | 106 | if nhwc: 107 | x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) 108 | if byte_output: 109 | return x.type('torch.ByteTensor'.format(self.cuda)) 110 | else: 111 | return x 112 | -------------------------------------------------------------------------------- /pSp/datasets/gt_res_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | import os 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class GTResDataset(Dataset): 10 | 11 | def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): 12 | self.pairs = [] 13 | for f in os.listdir(root_path): 14 | image_path = os.path.join(root_path, f) 15 | gt_path = os.path.join(gt_dir, f) 16 | if f.endswith(".jpg") or f.endswith(".png"): 17 | self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None]) 18 | self.transform = transform 19 | self.transform_train = transform_train 20 | 21 | def __len__(self): 22 | return len(self.pairs) 23 | 24 | def __getitem__(self, index): 25 | from_path, to_path, _ = self.pairs[index] 26 | from_im = Image.open(from_path).convert('RGB') 27 | to_im = Image.open(to_path).convert('RGB') 28 | 29 | if self.transform: 30 | to_im = self.transform(to_im) 31 | from_im = self.transform(from_im) 32 | 33 | return from_im, to_im 34 | -------------------------------------------------------------------------------- /pSp/datasets/images_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils.data import Dataset 3 | 4 | from pSp.utils import data_utils 5 | 6 | 7 | class ImagesDataset(Dataset): 8 | 9 | def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): 10 | self.source_paths = sorted(data_utils.make_dataset(source_root)) 11 | self.target_paths = sorted(data_utils.make_dataset(target_root)) 12 | self.source_transform = source_transform 13 | self.target_transform = target_transform 14 | self.opts = opts 15 | 16 | def __len__(self): 17 | return len(self.source_paths) 18 | 19 | def __getitem__(self, index): 20 | from_path = self.source_paths[index] 21 | from_im = Image.open(from_path) 22 | from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L') 23 | 24 | to_path = self.target_paths[index] 25 | to_im = Image.open(to_path).convert('RGB') 26 | if self.target_transform: 27 | to_im = self.target_transform(to_im) 28 | 29 | if self.source_transform: 30 | from_im = self.source_transform(from_im) 31 | else: 32 | from_im = to_im 33 | 34 | return from_im, to_im 35 | -------------------------------------------------------------------------------- /pSp/datasets/inference_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils.data import Dataset 3 | 4 | from pSp.utils import data_utils 5 | 6 | 7 | class InferenceDataset(Dataset): 8 | 9 | def __init__(self, root, opts, transform=None): 10 | self.paths = sorted(data_utils.make_dataset(root)) 11 | self.transform = transform 12 | self.opts = opts 13 | 14 | def __len__(self): 15 | return len(self.paths) 16 | 17 | def __getitem__(self, index): 18 | from_path = self.paths[index] 19 | from_im = Image.open(from_path) 20 | from_im = from_im.convert('RGB') if self.opts.label_nc == 0 else from_im.convert('L') 21 | if self.transform: 22 | from_im = self.transform(from_im) 23 | return from_im 24 | 
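ImagesDataset and InferenceDataset are thin wrappers that pair the dataset roots from `data_configs.DATASETS` with the transform dictionaries defined in `transforms_config`. A minimal sketch of how they fit together (roughly what the training coach does; the `Namespace` fields, batch size, and worker count here are illustrative assumptions, and the empty paths in `paths_config.py` must be filled in before this can run):

from argparse import Namespace

from torch.utils.data import DataLoader

from pSp.configs import data_configs
from pSp.datasets.images_dataset import ImagesDataset

# Minimal options object: only the fields the datasets/transforms actually read.
opts = Namespace(dataset_type='ffhq_encode', label_nc=0, resize_factors=None)

dataset_args = data_configs.DATASETS[opts.dataset_type]
transforms_dict = dataset_args['transforms'](opts).get_transforms()

train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
                              target_root=dataset_args['train_target_root'],
                              source_transform=transforms_dict['transform_source'],
                              target_transform=transforms_dict['transform_gt_train'],
                              opts=opts)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)

for x, y in train_loader:
    ...  # x: source images (here identical to targets), y: target images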
-------------------------------------------------------------------------------- /pSp/licenses/LICENSE_HuangYG123: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 HuangYG123 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /pSp/licenses/LICENSE_S-aiueo32: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Sou Uchida 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /pSp/licenses/LICENSE_TreB1eN: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 TreB1eN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /pSp/licenses/LICENSE_rosinality: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /pSp/models/encoders/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 5 | 6 | """ 7 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, input): 13 | return input.view(input.size(0), -1) 14 | 15 | 16 | def l2_norm(input, axis=1): 17 | norm = torch.norm(input, 2, axis, True) 18 | output = torch.div(input, norm) 19 | return output 20 | 21 | 22 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 23 | """ A named tuple describing a ResNet block. """ 24 | 25 | 26 | def get_block(in_channel, depth, num_units, stride=2): 27 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 28 | 29 | 30 | def get_blocks(num_layers): 31 | if num_layers == 50: 32 | blocks = [ 33 | get_block(in_channel=64, depth=64, num_units=3), 34 | get_block(in_channel=64, depth=128, num_units=4), 35 | get_block(in_channel=128, depth=256, num_units=14), 36 | get_block(in_channel=256, depth=512, num_units=3) 37 | ] 38 | elif num_layers == 100: 39 | blocks = [ 40 | get_block(in_channel=64, depth=64, num_units=3), 41 | get_block(in_channel=64, depth=128, num_units=13), 42 | get_block(in_channel=128, depth=256, num_units=30), 43 | get_block(in_channel=256, depth=512, num_units=3) 44 | ] 45 | elif num_layers == 152: 46 | blocks = [ 47 | get_block(in_channel=64, depth=64, num_units=3), 48 | get_block(in_channel=64, depth=128, num_units=8), 49 | get_block(in_channel=128, depth=256, num_units=36), 50 | get_block(in_channel=256, depth=512, num_units=3) 51 | ] 52 | else: 53 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 54 | return blocks 55 | 56 | 57 | class SEModule(Module): 58 | def __init__(self, channels, reduction): 59 | super(SEModule, self).__init__() 60 | self.avg_pool = AdaptiveAvgPool2d(1) 61 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 62 | self.relu = ReLU(inplace=True) 63 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 64 | self.sigmoid = Sigmoid() 65 | 66 | def forward(self, x): 67 | module_input = x 68 | x = self.avg_pool(x) 69 | x = self.fc1(x) 70 | x = self.relu(x) 71 | x = self.fc2(x) 72 | x = self.sigmoid(x) 73 | return module_input * x 74 | 75 | 76 | class bottleneck_IR(Module): 77 | def __init__(self, in_channel, depth, stride): 78 | super(bottleneck_IR, self).__init__() 79 | if in_channel == depth: 80 | self.shortcut_layer = MaxPool2d(1, stride) 81 | else: 82 | self.shortcut_layer = Sequential( 83 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 84 | BatchNorm2d(depth) 85 | ) 86 | self.res_layer = Sequential( 87 | BatchNorm2d(in_channel), 88 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 89 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 90 | ) 91 | 92 | def forward(self, x): 93 | shortcut = self.shortcut_layer(x) 94 | res = self.res_layer(x) 95 | return res + shortcut 96 | 97 | 98 | class bottleneck_IR_SE(Module): 99 | def __init__(self, in_channel, depth, stride): 100 | super(bottleneck_IR_SE, self).__init__() 101 | if in_channel == depth: 102 | self.shortcut_layer = MaxPool2d(1, stride) 103 | else: 104 | self.shortcut_layer = Sequential( 105 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 106 | BatchNorm2d(depth) 107 | ) 108 | self.res_layer = Sequential( 109 | BatchNorm2d(in_channel), 110 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 111 | PReLU(depth), 112 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 113 | BatchNorm2d(depth), 114 | SEModule(depth, 16) 115 | ) 116 | 117 | def forward(self, x): 118 | shortcut = self.shortcut_layer(x) 119 | res = self.res_layer(x) 120 | return res + shortcut 121 | -------------------------------------------------------------------------------- /pSp/models/encoders/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | 3 | from pSp.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 4 | 5 | """ 6 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Backbone(Module): 11 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 12 | super(Backbone, self).__init__() 13 | assert input_size in [112, 224], "input_size should be 112 or 224" 14 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 15 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 16 | blocks = get_blocks(num_layers) 17 | if mode == 'ir': 18 | unit_module = bottleneck_IR 19 | elif mode == 'ir_se': 20 | unit_module = bottleneck_IR_SE 21 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 22 | BatchNorm2d(64), 23 | PReLU(64)) 24 | if input_size == 112: 25 | self.output_layer = Sequential(BatchNorm2d(512), 26 | Dropout(drop_ratio), 27 | Flatten(), 28 | Linear(512 * 7 * 7, 512), 29 | BatchNorm1d(512, affine=affine)) 30 | 
else: 31 | self.output_layer = Sequential(BatchNorm2d(512), 32 | Dropout(drop_ratio), 33 | Flatten(), 34 | Linear(512 * 14 * 14, 512), 35 | BatchNorm1d(512, affine=affine)) 36 | 37 | modules = [] 38 | for block in blocks: 39 | for bottleneck in block: 40 | modules.append(unit_module(bottleneck.in_channel, 41 | bottleneck.depth, 42 | bottleneck.stride)) 43 | self.body = Sequential(*modules) 44 | 45 | def forward(self, x): 46 | x = self.input_layer(x) 47 | x = self.body(x) 48 | x = self.output_layer(x) 49 | return l2_norm(x) 50 | 51 | 52 | def IR_50(input_size): 53 | """Constructs a ir-50 model.""" 54 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 55 | return model 56 | 57 | 58 | def IR_101(input_size): 59 | """Constructs a ir-101 model.""" 60 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 61 | return model 62 | 63 | 64 | def IR_152(input_size): 65 | """Constructs a ir-152 model.""" 66 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 67 | return model 68 | 69 | 70 | def IR_SE_50(input_size): 71 | """Constructs a ir_se-50 model.""" 72 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 73 | return model 74 | 75 | 76 | def IR_SE_101(input_size): 77 | """Constructs a ir_se-101 model.""" 78 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 79 | return model 80 | 81 | 82 | def IR_SE_152(input_size): 83 | """Constructs a ir_se-152 model.""" 84 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 85 | return model 86 | -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | from pSp.models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face 6 | from pSp.models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 7 | from pSp.models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage 8 | from pSp.models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet 9 | 10 | device = 'cuda:0' 11 | 12 | 13 | class MTCNN(): 14 | def __init__(self): 15 | print(device) 16 | self.pnet = PNet().to(device) 17 | self.rnet = RNet().to(device) 18 | self.onet = ONet().to(device) 19 | self.pnet.eval() 20 | self.rnet.eval() 21 | self.onet.eval() 22 | self.refrence = get_reference_facial_points(default_square=True) 23 | 24 | def align(self, img): 25 | _, landmarks = self.detect_faces(img) 26 | if len(landmarks) == 0: 27 | return None, None 28 | facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] 29 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 30 | return Image.fromarray(warped_face), tfm 31 | 32 | def align_multi(self, img, limit=None, min_face_size=30.0): 33 | boxes, landmarks = self.detect_faces(img, min_face_size) 34 | if limit: 35 | boxes = boxes[:limit] 36 | landmarks = landmarks[:limit] 37 | faces = [] 38 | tfms = [] 39 | for landmark in landmarks: 40 | facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)] 41 | warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112)) 42 | faces.append(Image.fromarray(warped_face)) 43 | tfms.append(tfm) 44 | return 
boxes, faces, tfms 45 | 46 | def detect_faces(self, image, min_face_size=20.0, 47 | thresholds=[0.15, 0.25, 0.35], 48 | nms_thresholds=[0.7, 0.7, 0.7]): 49 | """ 50 | Arguments: 51 | image: an instance of PIL.Image. 52 | min_face_size: a float number. 53 | thresholds: a list of length 3. 54 | nms_thresholds: a list of length 3. 55 | 56 | Returns: 57 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 58 | bounding boxes and facial landmarks. 59 | """ 60 | 61 | # BUILD AN IMAGE PYRAMID 62 | width, height = image.size 63 | min_length = min(height, width) 64 | 65 | min_detection_size = 12 66 | factor = 0.707 # sqrt(0.5) 67 | 68 | # scales for scaling the image 69 | scales = [] 70 | 71 | # scales the image so that 72 | # minimum size that we can detect equals to 73 | # minimum face size that we want to detect 74 | m = min_detection_size / min_face_size 75 | min_length *= m 76 | 77 | factor_count = 0 78 | while min_length > min_detection_size: 79 | scales.append(m * factor ** factor_count) 80 | min_length *= factor 81 | factor_count += 1 82 | 83 | # STAGE 1 84 | 85 | # it will be returned 86 | bounding_boxes = [] 87 | 88 | with torch.no_grad(): 89 | # run P-Net on different scales 90 | for s in scales: 91 | boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0]) 92 | bounding_boxes.append(boxes) 93 | 94 | # collect boxes (and offsets, and scores) from different scales 95 | bounding_boxes = [i for i in bounding_boxes if i is not None] 96 | bounding_boxes = np.vstack(bounding_boxes) 97 | 98 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 99 | bounding_boxes = bounding_boxes[keep] 100 | 101 | # use offsets predicted by pnet to transform bounding boxes 102 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 103 | # shape [n_boxes, 5] 104 | 105 | bounding_boxes = convert_to_square(bounding_boxes) 106 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 107 | 108 | # STAGE 2 109 | 110 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 111 | img_boxes = torch.FloatTensor(img_boxes).to(device) 112 | 113 | output = self.rnet(img_boxes) 114 | offsets = output[0].cpu().data.numpy() # shape [n_boxes, 4] 115 | probs = output[1].cpu().data.numpy() # shape [n_boxes, 2] 116 | 117 | keep = np.where(probs[:, 1] > thresholds[1])[0] 118 | bounding_boxes = bounding_boxes[keep] 119 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 120 | offsets = offsets[keep] 121 | 122 | keep = nms(bounding_boxes, nms_thresholds[1]) 123 | bounding_boxes = bounding_boxes[keep] 124 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 125 | bounding_boxes = convert_to_square(bounding_boxes) 126 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 127 | 128 | # STAGE 3 129 | 130 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 131 | if len(img_boxes) == 0: 132 | return [], [] 133 | img_boxes = torch.FloatTensor(img_boxes).to(device) 134 | output = self.onet(img_boxes) 135 | landmarks = output[0].cpu().data.numpy() # shape [n_boxes, 10] 136 | offsets = output[1].cpu().data.numpy() # shape [n_boxes, 4] 137 | probs = output[2].cpu().data.numpy() # shape [n_boxes, 2] 138 | 139 | keep = np.where(probs[:, 1] > thresholds[2])[0] 140 | bounding_boxes = bounding_boxes[keep] 141 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 142 | offsets = offsets[keep] 143 | landmarks = landmarks[keep] 144 | 145 | # compute landmark points 146 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 147 | height = 
bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 148 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 149 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 150 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 151 | 152 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 153 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 154 | bounding_boxes = bounding_boxes[keep] 155 | landmarks = landmarks[keep] 156 | 157 | return bounding_boxes, landmarks 158 | -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .detector import detect_faces 2 | from .visualization_utils import show_bboxes 3 | -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn_pytorch/src/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 5 | from .first_stage import run_first_stage 6 | from .get_nets import PNet, RNet, ONet 7 | 8 | 9 | def detect_faces(image, min_face_size=20.0, 10 | thresholds=[0.6, 0.7, 0.8], 11 | nms_thresholds=[0.7, 0.7, 0.7]): 12 | """ 13 | Arguments: 14 | image: an instance of PIL.Image. 15 | min_face_size: a float number. 16 | thresholds: a list of length 3. 17 | nms_thresholds: a list of length 3. 18 | 19 | Returns: 20 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 21 | bounding boxes and facial landmarks. 22 | """ 23 | 24 | # LOAD MODELS 25 | pnet = PNet() 26 | rnet = RNet() 27 | onet = ONet() 28 | onet.eval() 29 | 30 | # BUILD AN IMAGE PYRAMID 31 | width, height = image.size 32 | min_length = min(height, width) 33 | 34 | min_detection_size = 12 35 | factor = 0.707 # sqrt(0.5) 36 | 37 | # scales for scaling the image 38 | scales = [] 39 | 40 | # scales the image so that 41 | # minimum size that we can detect equals to 42 | # minimum face size that we want to detect 43 | m = min_detection_size / min_face_size 44 | min_length *= m 45 | 46 | factor_count = 0 47 | while min_length > min_detection_size: 48 | scales.append(m * factor ** factor_count) 49 | min_length *= factor 50 | factor_count += 1 51 | 52 | # STAGE 1 53 | 54 | # it will be returned 55 | bounding_boxes = [] 56 | 57 | with torch.no_grad(): 58 | # run P-Net on different scales 59 | for s in scales: 60 | boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0]) 61 | bounding_boxes.append(boxes) 62 | 63 | # collect boxes (and offsets, and scores) from different scales 64 | bounding_boxes = [i for i in bounding_boxes if i is not None] 65 | bounding_boxes = np.vstack(bounding_boxes) 66 | 67 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 68 | bounding_boxes = bounding_boxes[keep] 69 | 70 | # use offsets predicted by pnet to transform bounding boxes 71 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 72 | # shape [n_boxes, 5] 73 | 74 | bounding_boxes = convert_to_square(bounding_boxes) 75 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 76 | 77 | # STAGE 2 78 | 79 | img_boxes = get_image_boxes(bounding_boxes, image, size=24) 80 | img_boxes = torch.FloatTensor(img_boxes) 81 | 82 | output = rnet(img_boxes) 83 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 84 | probs = 
output[1].data.numpy() # shape [n_boxes, 2] 85 | 86 | keep = np.where(probs[:, 1] > thresholds[1])[0] 87 | bounding_boxes = bounding_boxes[keep] 88 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 89 | offsets = offsets[keep] 90 | 91 | keep = nms(bounding_boxes, nms_thresholds[1]) 92 | bounding_boxes = bounding_boxes[keep] 93 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 94 | bounding_boxes = convert_to_square(bounding_boxes) 95 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 96 | 97 | # STAGE 3 98 | 99 | img_boxes = get_image_boxes(bounding_boxes, image, size=48) 100 | if len(img_boxes) == 0: 101 | return [], [] 102 | img_boxes = torch.FloatTensor(img_boxes) 103 | output = onet(img_boxes) 104 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 105 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 106 | probs = output[2].data.numpy() # shape [n_boxes, 2] 107 | 108 | keep = np.where(probs[:, 1] > thresholds[2])[0] 109 | bounding_boxes = bounding_boxes[keep] 110 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,)) 111 | offsets = offsets[keep] 112 | landmarks = landmarks[keep] 113 | 114 | # compute landmark points 115 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 116 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 117 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 118 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5] 119 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10] 120 | 121 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 122 | keep = nms(bounding_boxes, nms_thresholds[2], mode='min') 123 | bounding_boxes = bounding_boxes[keep] 124 | landmarks = landmarks[keep] 125 | 126 | return bounding_boxes, landmarks 127 | -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn_pytorch/src/first_stage.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | 7 | from .box_utils import nms, _preprocess 8 | 9 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | device = 'cuda:0' 11 | 12 | 13 | def run_first_stage(image, net, scale, threshold): 14 | """Run P-Net, generate bounding boxes, and do NMS. 15 | 16 | Arguments: 17 | image: an instance of PIL.Image. 18 | net: an instance of pytorch's nn.Module, P-Net. 19 | scale: a float number, 20 | scale width and height of the image by this number. 21 | threshold: a float number, 22 | threshold on the probability of a face when generating 23 | bounding boxes from predictions of the net. 24 | 25 | Returns: 26 | a float numpy array of shape [n_boxes, 9], 27 | bounding boxes with scores and offsets (4 + 1 + 4). 
28 | """ 29 | 30 | # scale the image and convert it to a float array 31 | width, height = image.size 32 | sw, sh = math.ceil(width * scale), math.ceil(height * scale) 33 | img = image.resize((sw, sh), Image.BILINEAR) 34 | img = np.asarray(img, 'float32') 35 | 36 | img = torch.FloatTensor(_preprocess(img)).to(device) 37 | with torch.no_grad(): 38 | output = net(img) 39 | probs = output[1].cpu().data.numpy()[0, 1, :, :] 40 | offsets = output[0].cpu().data.numpy() 41 | # probs: probability of a face at each sliding window 42 | # offsets: transformations to true bounding boxes 43 | 44 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 45 | if len(boxes) == 0: 46 | return None 47 | 48 | keep = nms(boxes[:, 0:5], overlap_threshold=0.5) 49 | return boxes[keep] 50 | 51 | 52 | def _generate_bboxes(probs, offsets, scale, threshold): 53 | """Generate bounding boxes at places 54 | where there is probably a face. 55 | 56 | Arguments: 57 | probs: a float numpy array of shape [n, m]. 58 | offsets: a float numpy array of shape [1, 4, n, m]. 59 | scale: a float number, 60 | width and height of the image were scaled by this number. 61 | threshold: a float number. 62 | 63 | Returns: 64 | a float numpy array of shape [n_boxes, 9] 65 | """ 66 | 67 | # applying P-Net is equivalent, in some sense, to 68 | # moving 12x12 window with stride 2 69 | stride = 2 70 | cell_size = 12 71 | 72 | # indices of boxes where there is probably a face 73 | inds = np.where(probs > threshold) 74 | 75 | if inds[0].size == 0: 76 | return np.array([]) 77 | 78 | # transformations of bounding boxes 79 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 80 | # they are defined as: 81 | # w = x2 - x1 + 1 82 | # h = y2 - y1 + 1 83 | # x1_true = x1 + tx1*w 84 | # x2_true = x2 + tx2*w 85 | # y1_true = y1 + ty1*h 86 | # y2_true = y2 + ty2*h 87 | 88 | offsets = np.array([tx1, ty1, tx2, ty2]) 89 | score = probs[inds[0], inds[1]] 90 | 91 | # P-Net is applied to scaled images 92 | # so we need to rescale bounding boxes back 93 | bounding_boxes = np.vstack([ 94 | np.round((stride * inds[1] + 1.0) / scale), 95 | np.round((stride * inds[0] + 1.0) / scale), 96 | np.round((stride * inds[1] + 1.0 + cell_size) / scale), 97 | np.round((stride * inds[0] + 1.0 + cell_size) / scale), 98 | score, offsets 99 | ]) 100 | # why one is added? 101 | 102 | return bounding_boxes.T 103 | -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn_pytorch/src/get_nets.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from pSp.configs.paths_config import model_paths 9 | 10 | PNET_PATH = model_paths["mtcnn_pnet"] 11 | ONET_PATH = model_paths["mtcnn_onet"] 12 | RNET_PATH = model_paths["mtcnn_rnet"] 13 | 14 | 15 | class Flatten(nn.Module): 16 | 17 | def __init__(self): 18 | super(Flatten, self).__init__() 19 | 20 | def forward(self, x): 21 | """ 22 | Arguments: 23 | x: a float tensor with shape [batch_size, c, h, w]. 24 | Returns: 25 | a float tensor with shape [batch_size, c*h*w]. 
26 | """ 27 | 28 | # without this pretrained model isn't working 29 | x = x.transpose(3, 2).contiguous() 30 | 31 | return x.view(x.size(0), -1) 32 | 33 | 34 | class PNet(nn.Module): 35 | 36 | def __init__(self): 37 | super().__init__() 38 | 39 | # suppose we have input with size HxW, then 40 | # after first layer: H - 2, 41 | # after pool: ceil((H - 2)/2), 42 | # after second conv: ceil((H - 2)/2) - 2, 43 | # after last conv: ceil((H - 2)/2) - 4, 44 | # and the same for W 45 | 46 | self.features = nn.Sequential(OrderedDict([ 47 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 48 | ('prelu1', nn.PReLU(10)), 49 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)), 50 | 51 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 52 | ('prelu2', nn.PReLU(16)), 53 | 54 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 55 | ('prelu3', nn.PReLU(32)) 56 | ])) 57 | 58 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 59 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 60 | 61 | weights = np.load(PNET_PATH, allow_pickle=True)[()] 62 | for n, p in self.named_parameters(): 63 | p.data = torch.FloatTensor(weights[n]) 64 | 65 | def forward(self, x): 66 | """ 67 | Arguments: 68 | x: a float tensor with shape [batch_size, 3, h, w]. 69 | Returns: 70 | b: a float tensor with shape [batch_size, 4, h', w']. 71 | a: a float tensor with shape [batch_size, 2, h', w']. 72 | """ 73 | x = self.features(x) 74 | a = self.conv4_1(x) 75 | b = self.conv4_2(x) 76 | a = F.softmax(a, dim=-1) 77 | return b, a 78 | 79 | 80 | class RNet(nn.Module): 81 | 82 | def __init__(self): 83 | super().__init__() 84 | 85 | self.features = nn.Sequential(OrderedDict([ 86 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 87 | ('prelu1', nn.PReLU(28)), 88 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 89 | 90 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 91 | ('prelu2', nn.PReLU(48)), 92 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 93 | 94 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 95 | ('prelu3', nn.PReLU(64)), 96 | 97 | ('flatten', Flatten()), 98 | ('conv4', nn.Linear(576, 128)), 99 | ('prelu4', nn.PReLU(128)) 100 | ])) 101 | 102 | self.conv5_1 = nn.Linear(128, 2) 103 | self.conv5_2 = nn.Linear(128, 4) 104 | 105 | weights = np.load(RNET_PATH, allow_pickle=True)[()] 106 | for n, p in self.named_parameters(): 107 | p.data = torch.FloatTensor(weights[n]) 108 | 109 | def forward(self, x): 110 | """ 111 | Arguments: 112 | x: a float tensor with shape [batch_size, 3, h, w]. 113 | Returns: 114 | b: a float tensor with shape [batch_size, 4]. 115 | a: a float tensor with shape [batch_size, 2]. 
116 | """ 117 | x = self.features(x) 118 | a = self.conv5_1(x) 119 | b = self.conv5_2(x) 120 | a = F.softmax(a, dim=-1) 121 | return b, a 122 | 123 | 124 | class ONet(nn.Module): 125 | 126 | def __init__(self): 127 | super().__init__() 128 | 129 | self.features = nn.Sequential(OrderedDict([ 130 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 131 | ('prelu1', nn.PReLU(32)), 132 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)), 133 | 134 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 135 | ('prelu2', nn.PReLU(64)), 136 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)), 137 | 138 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 139 | ('prelu3', nn.PReLU(64)), 140 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)), 141 | 142 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 143 | ('prelu4', nn.PReLU(128)), 144 | 145 | ('flatten', Flatten()), 146 | ('conv5', nn.Linear(1152, 256)), 147 | ('drop5', nn.Dropout(0.25)), 148 | ('prelu5', nn.PReLU(256)), 149 | ])) 150 | 151 | self.conv6_1 = nn.Linear(256, 2) 152 | self.conv6_2 = nn.Linear(256, 4) 153 | self.conv6_3 = nn.Linear(256, 10) 154 | 155 | weights = np.load(ONET_PATH, allow_pickle=True)[()] 156 | for n, p in self.named_parameters(): 157 | p.data = torch.FloatTensor(weights[n]) 158 | 159 | def forward(self, x): 160 | """ 161 | Arguments: 162 | x: a float tensor with shape [batch_size, 3, h, w]. 163 | Returns: 164 | c: a float tensor with shape [batch_size, 10]. 165 | b: a float tensor with shape [batch_size, 4]. 166 | a: a float tensor with shape [batch_size, 2]. 167 | """ 168 | x = self.features(x) 169 | a = self.conv6_1(x) 170 | b = self.conv6_2(x) 171 | c = self.conv6_3(x) 172 | a = F.softmax(a, dim=-1) 173 | return c, b, a 174 | -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn_pytorch/src/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_bboxes(img, bounding_boxes, facial_landmarks=[]): 5 | """Draw bounding boxes and facial landmarks. 6 | 7 | Arguments: 8 | img: an instance of PIL.Image. 9 | bounding_boxes: a float numpy array of shape [n, 5]. 10 | facial_landmarks: a float numpy array of shape [n, 10]. 11 | 12 | Returns: 13 | an instance of PIL.Image. 
14 | """ 15 | 16 | img_copy = img.copy() 17 | draw = ImageDraw.Draw(img_copy) 18 | 19 | for b in bounding_boxes: 20 | draw.rectangle([ 21 | (b[0], b[1]), (b[2], b[3]) 22 | ], outline='white') 23 | 24 | for p in facial_landmarks: 25 | for i in range(5): 26 | draw.ellipse([ 27 | (p[i] - 1.0, p[i + 5] - 1.0), 28 | (p[i] + 1.0, p[i + 5] + 1.0) 29 | ], outline='blue') 30 | 31 | return img_copy 32 | -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/pSp/models/mtcnn/mtcnn_pytorch/src/weights/onet.npy -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/pSp/models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy -------------------------------------------------------------------------------- /pSp/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/pSp/models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy -------------------------------------------------------------------------------- /pSp/options/test_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class TestOptions: 5 | 6 | def __init__(self): 7 | self.parser = ArgumentParser() 8 | self.initialize() 9 | 10 | def initialize(self): 11 | # arguments for inference script 12 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 13 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') 14 | self.parser.add_argument('--data_path', type=str, default='gt_images', 15 | help='Path to directory of images to evaluate') 16 | self.parser.add_argument('--couple_outputs', action='store_true', 17 | help='Whether to also save inputs + outputs side-by-side') 18 | self.parser.add_argument('--resize_outputs', action='store_true', 19 | help='Whether to resize outputs to 256x256 or keep at 1024x1024') 20 | 21 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') 22 | self.parser.add_argument('--test_workers', default=2, type=int, 23 | help='Number of test/inference dataloader workers') 24 | 25 | # arguments for style-mixing script 26 | self.parser.add_argument('--n_images', type=int, default=None, 27 | help='Number of images to output. 
If None, run on all data') 28 | self.parser.add_argument('--n_outputs_to_generate', type=int, default=5, 29 | help='Number of outputs to generate per input image.') 30 | self.parser.add_argument('--mix_alpha', type=float, default=None, help='Alpha value for style-mixing') 31 | self.parser.add_argument('--latent_mask', type=str, default=None, 32 | help='Comma-separated list of latents to perform style-mixing with') 33 | 34 | # arguments for super-resolution 35 | self.parser.add_argument('--resize_factors', type=str, default=None, 36 | help='Downsampling factor for super-res (should be a single value for inference).') 37 | 38 | def parse(self): 39 | opts = self.parser.parse_args() 40 | return opts 41 | -------------------------------------------------------------------------------- /pSp/options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from pSp.configs.paths_config import model_paths 4 | 5 | 6 | class TrainOptions: 7 | 8 | def __init__(self): 9 | self.parser = ArgumentParser() 10 | self.initialize() 11 | 12 | def initialize(self): 13 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 14 | 15 | self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, 16 | help='Type of dataset/experiment to run') 17 | self.parser.add_argument('--encoder_type', default='GradualStyleEncoder', type=str, help='Which encoder to use') 18 | self.parser.add_argument('--input_nc', default=3, type=int, 19 | help='Number of input image channels to the psp encoder') 20 | self.parser.add_argument('--label_nc', default=0, type=int, 21 | help='Number of input label channels to the psp encoder') 22 | self.parser.add_argument('--output_size', default=1024, type=int, help='Output size of generator') 23 | 24 | self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training') 25 | self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') 26 | self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') 27 | self.parser.add_argument('--test_workers', default=2, type=int, 28 | help='Number of test/inference dataloader workers') 29 | 30 | self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate') 31 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') 32 | self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') 33 | self.parser.add_argument('--start_from_latent_avg', action='store_true', 34 | help='Whether to add average latent vector to generate codes from encoder.') 35 | self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space insteaf of w+') 36 | 37 | self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor') 38 | self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') 39 | self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor') 40 | self.parser.add_argument('--w_norm_lambda', default=0, type=float, help='W-norm loss multiplier factor') 41 | self.parser.add_argument('--lpips_lambda_crop', default=0, type=float, 42 | help='LPIPS loss multiplier factor for inner image region') 43 | self.parser.add_argument('--l2_lambda_crop', 
default=0, type=float, 44 | help='L2 loss multiplier factor for inner image region') 45 | 46 | self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str, 47 | help='Path to StyleGAN model weights') 48 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') 49 | 50 | self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps') 51 | self.parser.add_argument('--image_interval', default=1000, type=int, 52 | help='Interval for logging train images during training') 53 | self.parser.add_argument('--board_interval', default=50, type=int, 54 | help='Interval for logging metrics to tensorboard') 55 | self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval') 56 | self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval') 57 | 58 | # arguments for super-resolution 59 | self.parser.add_argument('--resize_factors', type=str, default=None, 60 | help='For super-res, comma-separated resize factors to use for inference.') 61 | 62 | def parse(self): 63 | opts = self.parser.parse_args() 64 | return opts 65 | -------------------------------------------------------------------------------- /pSp/scripts/calc_id_loss_parallel.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import multiprocessing as mp 4 | import os 5 | import sys 6 | import time 7 | from argparse import ArgumentParser 8 | 9 | import numpy as np 10 | import torch 11 | import torchvision.transforms as trans 12 | from PIL import Image 13 | 14 | sys.path.append(".") 15 | sys.path.append("..") 16 | 17 | from models.mtcnn.mtcnn import MTCNN 18 | from models.encoders.model_irse import IR_101 19 | from configs.paths_config import model_paths 20 | 21 | CIRCULAR_FACE_PATH = model_paths['circular_face'] 22 | 23 | 24 | def chunks(lst, n): 25 | """Yield successive n-sized chunks from lst.""" 26 | for i in range(0, len(lst), n): 27 | yield lst[i:i + n] 28 | 29 | 30 | def extract_on_paths(file_paths): 31 | facenet = IR_101(input_size=112) 32 | facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH)) 33 | facenet.cuda() 34 | facenet.eval() 35 | mtcnn = MTCNN() 36 | id_transform = trans.Compose([ 37 | trans.ToTensor(), 38 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 39 | ]) 40 | 41 | pid = mp.current_process().name 42 | print('\t{} is starting to extract on {} images'.format(pid, len(file_paths))) 43 | tot_count = len(file_paths) 44 | count = 0 45 | 46 | scores_dict = {} 47 | for res_path, gt_path in file_paths: 48 | count += 1 49 | if count % 100 == 0: 50 | print('{} done with {}/{}'.format(pid, count, tot_count)) 51 | if True: 52 | input_im = Image.open(res_path) 53 | input_im, _ = mtcnn.align(input_im) 54 | if input_im is None: 55 | print('{} skipping {}'.format(pid, res_path)) 56 | continue 57 | 58 | input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0] 59 | 60 | result_im = Image.open(gt_path) 61 | result_im, _ = mtcnn.align(result_im) 62 | if result_im is None: 63 | print('{} skipping {}'.format(pid, gt_path)) 64 | continue 65 | 66 | result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0] 67 | score = float(input_id.dot(result_id)) 68 | scores_dict[os.path.basename(gt_path)] = score 69 | 70 | return scores_dict 71 | 72 | 73 | def parse_args(): 74 | parser = ArgumentParser(add_help=False) 75 | parser.add_argument('--num_threads', type=int, 
default=4) 76 | parser.add_argument('--data_path', type=str, default='results') 77 | parser.add_argument('--gt_path', type=str, default='gt_images') 78 | args = parser.parse_args() 79 | return args 80 | 81 | 82 | def run(args): 83 | file_paths = [] 84 | for f in os.listdir(args.data_path): 85 | image_path = os.path.join(args.data_path, f) 86 | gt_path = os.path.join(args.gt_path, f) 87 | if f.endswith(".jpg") or f.endswith('.png'): 88 | file_paths.append([image_path, gt_path.replace('.png', '.jpg')]) 89 | 90 | file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads)))) 91 | pool = mp.Pool(args.num_threads) 92 | print('Running on {} paths\nHere we goooo'.format(len(file_paths))) 93 | 94 | tic = time.time() 95 | results = pool.map(extract_on_paths, file_chunks) 96 | scores_dict = {} 97 | for d in results: 98 | scores_dict.update(d) 99 | 100 | all_scores = list(scores_dict.values()) 101 | mean = np.mean(all_scores) 102 | std = np.std(all_scores) 103 | result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std) 104 | print(result_str) 105 | 106 | out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 107 | if not os.path.exists(out_path): 108 | os.makedirs(out_path) 109 | 110 | with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f: 111 | f.write(result_str) 112 | with open(os.path.join(out_path, 'scores_id.json'), 'w') as f: 113 | json.dump(scores_dict, f) 114 | 115 | toc = time.time() 116 | print('Mischief managed in {}s'.format(toc - tic)) 117 | 118 | 119 | if __name__ == '__main__': 120 | args = parse_args() 121 | run(args) 122 | -------------------------------------------------------------------------------- /pSp/scripts/calc_losses_on_images.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from argparse import ArgumentParser 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | sys.path.append(".") 13 | sys.path.append("..") 14 | 15 | from pSp.criteria.lpips.lpips import LPIPS 16 | from pSp.datasets.gt_res_dataset import GTResDataset 17 | 18 | 19 | def parse_args(): 20 | parser = ArgumentParser(add_help=False) 21 | parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2']) 22 | parser.add_argument('--data_path', type=str, default='results') 23 | parser.add_argument('--gt_path', type=str, default='gt_images') 24 | parser.add_argument('--workers', type=int, default=4) 25 | parser.add_argument('--batch_size', type=int, default=4) 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def run(args): 31 | transform = transforms.Compose([transforms.Resize((256, 256)), 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 34 | 35 | print('Loading dataset') 36 | dataset = GTResDataset(root_path=args.data_path, 37 | gt_dir=args.gt_path, 38 | transform=transform) 39 | 40 | dataloader = DataLoader(dataset, 41 | batch_size=args.batch_size, 42 | shuffle=False, 43 | num_workers=int(args.workers), 44 | drop_last=True) 45 | 46 | if args.mode == 'lpips': 47 | loss_func = LPIPS(net_type='alex') 48 | elif args.mode == 'l2': 49 | loss_func = torch.nn.MSELoss() 50 | else: 51 | raise Exception('Not a valid mode!') 52 | loss_func.cuda() 53 | 54 | global_i = 0 55 | scores_dict = {} 56 | all_scores = [] 57 | for result_batch, gt_batch in tqdm(dataloader): 58 | for i in 
range(args.batch_size): 59 | loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda())) 60 | all_scores.append(loss) 61 | im_path = dataset.pairs[global_i][0] 62 | scores_dict[os.path.basename(im_path)] = loss 63 | global_i += 1 64 | 65 | all_scores = list(scores_dict.values()) 66 | mean = np.mean(all_scores) 67 | std = np.std(all_scores) 68 | result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std) 69 | print('Finished with ', args.data_path) 70 | print(result_str) 71 | 72 | out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') 73 | if not os.path.exists(out_path): 74 | os.makedirs(out_path) 75 | 76 | with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f: 77 | f.write(result_str) 78 | with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f: 79 | json.dump(scores_dict, f) 80 | 81 | 82 | if __name__ == '__main__': 83 | args = parse_args() 84 | run(args) 85 | -------------------------------------------------------------------------------- /pSp/scripts/generate_sketch_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | from torch.utils.serialization import load_lua 6 | from torchvision import transforms 7 | from torchvision.utils import save_image 8 | 9 | """ 10 | NOTE!: Must have torch==0.4.1 and torchvision==0.2.1 11 | The sketch simplification model (sketch_gan.t7) from Simo Serra et al. can be downloaded from their official implementation: 12 | https://github.com/bobbens/sketch_simplification 13 | """ 14 | 15 | 16 | def sobel(img): 17 | opImgx = cv2.Sobel(img, cv2.CV_8U, 0, 1, ksize=3) 18 | opImgy = cv2.Sobel(img, cv2.CV_8U, 1, 0, ksize=3) 19 | return cv2.bitwise_or(opImgx, opImgy) 20 | 21 | 22 | def sketch(frame): 23 | frame = cv2.GaussianBlur(frame, (3, 3), 0) 24 | invImg = 255 - frame 25 | edgImg0 = sobel(frame) 26 | edgImg1 = sobel(invImg) 27 | edgImg = cv2.addWeighted(edgImg0, 0.75, edgImg1, 0.75, 0) 28 | opImg = 255 - edgImg 29 | return opImg 30 | 31 | 32 | def get_sketch_image(image_path): 33 | original = cv2.imread(image_path) 34 | original = cv2.cvtColor(original, cv2.COLOR_BGR2GRAY) 35 | sketch_image = sketch(original) 36 | return sketch_image[:, :, np.newaxis] 37 | 38 | 39 | use_cuda = True 40 | 41 | cache = load_lua("/path/to/sketch_gan.t7") 42 | model = cache.model 43 | immean = cache.mean 44 | imstd = cache.std 45 | model.evaluate() 46 | 47 | data_path = "/path/to/data/imgs" 48 | images = [os.path.join(data_path, f) for f in os.listdir(data_path)] 49 | 50 | output_dir = "/path/to/data/edges" 51 | if not os.path.exists(output_dir): 52 | os.makedirs(output_dir) 53 | 54 | for idx, image_path in enumerate(images): 55 | if idx % 50 == 0: 56 | print("{} out of {}".format(idx, len(images))) 57 | data = get_sketch_image(image_path) 58 | data = ((transforms.ToTensor()(data) - immean) / imstd).unsqueeze(0) 59 | if use_cuda: 60 | pred = model.cuda().forward(data.cuda()).float() 61 | else: 62 | pred = model.forward(data) 63 | save_image(pred[0], os.path.join(output_dir, "{}_edges.jpg".format(image_path.split("/")[-1].split('.')[0]))) 64 | -------------------------------------------------------------------------------- /pSp/scripts/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from argparse import Namespace 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torch.nn 
import functional as F 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | sys.path.append(".") 14 | sys.path.append("..") 15 | 16 | from pSp.configs import data_configs 17 | from pSp.datasets.inference_dataset import InferenceDataset 18 | from pSp.utils.common import tensor2im, log_input_image 19 | from pSp.options.test_options import TestOptions 20 | from pSp.models.psp import pSp 21 | import pSp.lpips as lpips 22 | 23 | 24 | def run(): 25 | test_opts = TestOptions().parse() 26 | 27 | if test_opts.resize_factors is not None: 28 | assert len( 29 | test_opts.resize_factors.split(',')) == 1, "When running inference, provide a single downsampling factor!" 30 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results', 31 | 'downsampling_{}'.format(test_opts.resize_factors)) 32 | out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled', 33 | 'downsampling_{}'.format(test_opts.resize_factors)) 34 | else: 35 | out_path_results = os.path.join(test_opts.exp_dir, 'inference_results') 36 | out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled') 37 | 38 | os.makedirs(out_path_results, exist_ok=True) 39 | os.makedirs(out_path_coupled, exist_ok=True) 40 | 41 | # update test options with options used during training 42 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 43 | opts = ckpt['opts'] 44 | opts.update(vars(test_opts)) 45 | if 'learn_in_w' not in opts: 46 | opts['learn_in_w'] = False 47 | opts = Namespace(**opts) 48 | 49 | net = pSp(opts) 50 | net.eval() 51 | net.cuda() 52 | 53 | print('Loading dataset for {}'.format(opts.dataset_type)) 54 | dataset_args = data_configs.DATASETS[opts.dataset_type] 55 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 56 | dataset = InferenceDataset(root=opts.data_path, 57 | transform=transforms_dict['transform_inference'], 58 | opts=opts) 59 | dataloader = DataLoader(dataset, 60 | batch_size=opts.test_batch_size, 61 | shuffle=False, 62 | num_workers=int(opts.test_workers), 63 | drop_last=True) 64 | 65 | if opts.n_images is None: 66 | opts.n_images = len(dataset) 67 | 68 | device = 'cuda' 69 | percept = lpips.PerceptualLoss( 70 | model='net-lin', net='vgg', use_gpu=device.startswith('cuda') 71 | ) 72 | 73 | global_i = 0 74 | global_time = [] 75 | perceptual_values = [] 76 | mse_values = [] 77 | for input_batch in tqdm(dataloader): 78 | if global_i >= opts.n_images: 79 | break 80 | with torch.no_grad(): 81 | input_cuda = input_batch.cuda().float() 82 | tic = time.time() 83 | result_batch = run_on_batch(input_cuda, net, opts) 84 | toc = time.time() 85 | global_time.append(toc - tic) 86 | 87 | p_loss = percept(result_batch, input_cuda) 88 | mse_loss = F.mse_loss(result_batch, input_cuda) 89 | 90 | perceptual_values.append(p_loss.item()) 91 | mse_values.append(mse_loss.item()) 92 | 93 | for i in range(opts.test_batch_size): 94 | result = tensor2im(result_batch[i]) 95 | im_path = dataset.paths[global_i] 96 | 97 | if opts.couple_outputs or global_i % 100 == 0: 98 | input_im = log_input_image(input_batch[i], opts) 99 | resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024) 100 | if opts.resize_factors is not None: 101 | # for super resolution, save the original, down-sampled, and output 102 | source = Image.open(im_path) 103 | res = np.concatenate([np.array(source.resize(resize_amount)), 104 | np.array(input_im.resize(resize_amount, resample=Image.NEAREST)), 105 | np.array(result.resize(resize_amount))], axis=1) 106 | else: 107 | # otherwise, save the 
original and output 108 | res = np.concatenate([np.array(input_im.resize(resize_amount)), 109 | np.array(result.resize(resize_amount))], axis=1) 110 | Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path))) 111 | 112 | im_save_path = os.path.join(out_path_results, os.path.basename(im_path)) 113 | Image.fromarray(np.array(result)).save(im_save_path) 114 | 115 | global_i += 1 116 | 117 | stats_path = os.path.join(opts.exp_dir, 'stats.txt') 118 | result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time)) 119 | print(result_str) 120 | 121 | np.save(os.path.join(opts.exp_dir, 'perceptual.npy'), perceptual_values) 122 | np.save(os.path.join(opts.exp_dir, 'mse.npy'), mse_values) 123 | 124 | with open(stats_path, 'w') as f: 125 | f.write(result_str) 126 | 127 | 128 | def run_on_batch(inputs, net, opts): 129 | if opts.latent_mask is None: 130 | result_batch = net(inputs, randomize_noise=False, resize=opts.resize_outputs) 131 | else: 132 | latent_mask = [int(l) for l in opts.latent_mask.split(",")] 133 | result_batch = [] 134 | for image_idx, input_image in enumerate(inputs): 135 | # get latent vector to inject into our input image 136 | vec_to_inject = np.random.randn(1, 512).astype('float32') 137 | _, latent_to_inject = net(torch.from_numpy(vec_to_inject).to("cuda"), 138 | input_code=True, 139 | return_latents=True) 140 | # get output image with injected style vector 141 | res = net(input_image.unsqueeze(0).to("cuda").float(), 142 | latent_mask=latent_mask, 143 | inject_latent=latent_to_inject, 144 | alpha=opts.mix_alpha, 145 | resize=opts.resize_outputs) 146 | result_batch.append(res) 147 | result_batch = torch.cat(result_batch, dim=0) 148 | return result_batch 149 | 150 | 151 | if __name__ == '__main__': 152 | run() 153 | -------------------------------------------------------------------------------- /pSp/scripts/style_mixing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from argparse import Namespace 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | sys.path.append(".") 12 | sys.path.append("..") 13 | 14 | from pSp.configs import data_configs 15 | from pSp.datasets.inference_dataset import InferenceDataset 16 | from pSp.utils.common import tensor2im, log_input_image 17 | from pSp.options.test_options import TestOptions 18 | from pSp.models.psp import pSp 19 | 20 | 21 | def run(): 22 | test_opts = TestOptions().parse() 23 | 24 | if test_opts.resize_factors is not None: 25 | factors = test_opts.resize_factors.split(',') 26 | assert len(factors) == 1, "When running inference, please provide a single downsampling factor!" 
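# (Editor's note — illustrative sketch, not part of the original style_mixing.py.) A typical
# invocation of this script, assuming a trained pSp checkpoint and a directory of aligned test
# images (the checkpoint name and directories below are hypothetical placeholders), could look like:
#
#   python pSp/scripts/style_mixing.py \
#       --exp_dir out/style_mixing_exp \
#       --checkpoint_path pretrained/psp_ffhq_encode.pt \
#       --data_path gt_images \
#       --latent_mask 8,9,10,11,12,13 \
#       --n_outputs_to_generate 5
#
# --latent_mask must be supplied, since run() splits it unconditionally further down; all of the
# flags shown here are defined in pSp/options/test_options.py above.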
27 | mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing', 28 | 'downsampling_{}'.format(test_opts.resize_factors)) 29 | else: 30 | mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing') 31 | os.makedirs(mixed_path_results, exist_ok=True) 32 | 33 | # update test options with options used during training 34 | ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') 35 | opts = ckpt['opts'] 36 | opts.update(vars(test_opts)) 37 | if 'learn_in_w' not in opts: 38 | opts['learn_in_w'] = False 39 | opts = Namespace(**opts) 40 | 41 | net = pSp(opts) 42 | net.eval() 43 | net.cuda() 44 | 45 | print('Loading dataset for {}'.format(opts.dataset_type)) 46 | dataset_args = data_configs.DATASETS[opts.dataset_type] 47 | transforms_dict = dataset_args['transforms'](opts).get_transforms() 48 | dataset = InferenceDataset(root=opts.data_path, 49 | transform=transforms_dict['transform_inference'], 50 | opts=opts) 51 | dataloader = DataLoader(dataset, 52 | batch_size=opts.test_batch_size, 53 | shuffle=False, 54 | num_workers=int(opts.test_workers), 55 | drop_last=True) 56 | 57 | latent_mask = [int(l) for l in opts.latent_mask.split(",")] 58 | if opts.n_images is None: 59 | opts.n_images = len(dataset) 60 | 61 | global_i = 0 62 | for input_batch in tqdm(dataloader): 63 | if global_i >= opts.n_images: 64 | break 65 | with torch.no_grad(): 66 | input_batch = input_batch.cuda() 67 | for image_idx, input_image in enumerate(input_batch): 68 | # generate random vectors to inject into input image 69 | vecs_to_inject = np.random.randn(opts.n_outputs_to_generate, 512).astype('float32') 70 | multi_modal_outputs = [] 71 | for vec_to_inject in vecs_to_inject: 72 | cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to("cuda") 73 | # get latent vector to inject into our input image 74 | _, latent_to_inject = net(cur_vec, 75 | input_code=True, 76 | return_latents=True) 77 | # get output image with injected style vector 78 | res = net(input_image.unsqueeze(0).to("cuda").float(), 79 | latent_mask=latent_mask, 80 | inject_latent=latent_to_inject, 81 | alpha=opts.mix_alpha, 82 | resize=opts.resize_outputs) 83 | multi_modal_outputs.append(res[0]) 84 | 85 | # visualize multi modal outputs 86 | input_im_path = dataset.paths[global_i] 87 | image = input_batch[image_idx] 88 | input_image = log_input_image(image, opts) 89 | resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024) 90 | res = np.array(input_image.resize(resize_amount)) 91 | for output in multi_modal_outputs: 92 | output = tensor2im(output) 93 | res = np.concatenate([res, np.array(output.resize(resize_amount))], axis=1) 94 | Image.fromarray(res).save(os.path.join(mixed_path_results, os.path.basename(input_im_path))) 95 | global_i += 1 96 | 97 | 98 | if __name__ == '__main__': 99 | run() 100 | -------------------------------------------------------------------------------- /pSp/scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import json 5 | import os 6 | import pprint 7 | import sys 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from pSp.options.train_options import TrainOptions 13 | from pSp.training.coach import Coach 14 | 15 | 16 | def main(): 17 | opts = TrainOptions().parse() 18 | os.makedirs(opts.exp_dir, exist_ok=True) 19 | 20 | opts_dict = vars(opts) 21 | pprint.pprint(opts_dict) 22 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 23 | json.dump(opts_dict, f, 
indent=4, sort_keys=True) 24 | 25 | coach = Coach(opts) 26 | coach.train() 27 | 28 | 29 | if __name__ == '__main__': 30 | main() 31 | -------------------------------------------------------------------------------- /pSp/utils/common.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | # Log images 8 | def log_input_image(x, opts): 9 | if opts.label_nc == 0: 10 | return tensor2im(x) 11 | elif opts.label_nc == 1: 12 | return tensor2sketch(x) 13 | else: 14 | return tensor2map(x) 15 | 16 | 17 | def tensor2im(var): 18 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 19 | var = ((var + 1) / 2) 20 | var[var < 0] = 0 21 | var[var > 1] = 1 22 | var = var * 255 23 | return Image.fromarray(var.astype('uint8')) 24 | 25 | 26 | def tensor2map(var): 27 | mask = np.argmax(var.data.cpu().numpy(), axis=0) 28 | colors = get_colors() 29 | mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3)) 30 | for class_idx in np.unique(mask): 31 | mask_image[mask == class_idx] = colors[class_idx] 32 | mask_image = mask_image.astype('uint8') 33 | return Image.fromarray(mask_image) 34 | 35 | 36 | def tensor2sketch(var): 37 | im = var[0].cpu().detach().numpy() 38 | im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) 39 | im = (im * 255).astype(np.uint8) 40 | return Image.fromarray(im) 41 | 42 | 43 | # Visualization utils 44 | def get_colors(): 45 | # currently support up to 19 classes (for the celebs-hq-mask dataset) 46 | colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], 47 | [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], 48 | [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] 49 | return colors 50 | 51 | 52 | def vis_faces(log_hooks): 53 | display_count = len(log_hooks) 54 | fig = plt.figure(figsize=(8, 4 * display_count)) 55 | gs = fig.add_gridspec(display_count, 3) 56 | for i in range(display_count): 57 | hooks_dict = log_hooks[i] 58 | fig.add_subplot(gs[i, 0]) 59 | if 'diff_input' in hooks_dict: 60 | vis_faces_with_id(hooks_dict, fig, gs, i) 61 | else: 62 | vis_faces_no_id(hooks_dict, fig, gs, i) 63 | plt.tight_layout() 64 | return fig 65 | 66 | 67 | def vis_faces_with_id(hooks_dict, fig, gs, i): 68 | plt.imshow(hooks_dict['input_face']) 69 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) 70 | fig.add_subplot(gs[i, 1]) 71 | plt.imshow(hooks_dict['target_face']) 72 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), 73 | float(hooks_dict['diff_target']))) 74 | fig.add_subplot(gs[i, 2]) 75 | plt.imshow(hooks_dict['output_face']) 76 | plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) 77 | 78 | 79 | def vis_faces_no_id(hooks_dict, fig, gs, i): 80 | plt.imshow(hooks_dict['input_face'], cmap="gray") 81 | plt.title('Input') 82 | fig.add_subplot(gs[i, 1]) 83 | plt.imshow(hooks_dict['target_face']) 84 | plt.title('Target') 85 | fig.add_subplot(gs[i, 2]) 86 | plt.imshow(hooks_dict['output_face']) 87 | plt.title('Output') 88 | -------------------------------------------------------------------------------- /pSp/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adopted from pix2pixHD: 3 | https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py 4 | """ 5 | import os 6 | 7 | 
IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 20 | for root, _, fnames in sorted(os.walk(dir)): 21 | for fname in fnames: 22 | if is_image_file(fname): 23 | path = os.path.join(root, fname) 24 | images.append(path) 25 | return images 26 | -------------------------------------------------------------------------------- /pSp/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | def aggregate_loss_dict(agg_loss_dict): 2 | mean_vals = {} 3 | for output in agg_loss_dict: 4 | for key in output: 5 | mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] 6 | for key in mean_vals: 7 | if len(mean_vals[key]) > 0: 8 | mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) 9 | else: 10 | print('{} has no value'.format(key)) 11 | mean_vals[key] = 0 12 | return mean_vals 13 | -------------------------------------------------------------------------------- /projection/encoder_inversion/celebahq_encode/encoded_p.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/projection/encoder_inversion/celebahq_encode/encoded_p.npy -------------------------------------------------------------------------------- /projection/encoder_inversion/celebahq_encode/encoded_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/projection/encoder_inversion/celebahq_encode/encoded_z.npy -------------------------------------------------------------------------------- /projection/encoder_inversion/ffhq_encode/encoded_p.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/projection/encoder_inversion/ffhq_encode/encoded_p.npy -------------------------------------------------------------------------------- /projection/encoder_inversion/ffhq_encode/encoded_z.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/projection/encoder_inversion/ffhq_encode/encoded_z.npy -------------------------------------------------------------------------------- /psp_spatial_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pprint 4 | import sys 5 | import argparse 6 | import math 7 | 8 | 9 | sys.path.append("./pSp") 10 | 11 | from psp_training_options import TrainOptions 12 | from pSp.training.coach_new import Coach 13 | 14 | 15 | if __name__ == '__main__': 16 | device = 'cuda' 17 | 18 | args = TrainOptions().parse() 19 | 20 | n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 21 | args.distributed = n_gpu > 1 22 | 23 | args.latent = 512 24 | args.token = 2 * (int(math.log(args.size, 2)) - 1) 25 | 26 | args.use_spatial_mapping = not args.no_spatial_map 27 | 28 | coach = Coach(args) 29 | coach.train() 30 | 31 | 32 | # python psp_spatial_train.py 
ffhq/LMDB_train/ --test_path ffhq/LMDB_test/ --ckpt ./out/trans_spatial_squery_multimap_fixed/checkpoint/790000.pt --num_region 1 --num_trans 8 --pixel_norm_op_dim 1" 33 | # python psp_spatial_train.py ffhq/LMDB_train/ --test_path ffhq/LMDB_test/ --ckpt ./out/trans_spatial_squery_multimap_fixed/checkpoint/790000.pt --num_region 1 --num_trans 8 --pixel_norm_op_dim 1 -------------------------------------------------------------------------------- /psp_testing_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from pSp.configs.paths_config import model_paths 4 | 5 | 6 | class TestOptions: 7 | 8 | def __init__(self): 9 | self.parser = ArgumentParser() 10 | self.initialize() 11 | 12 | def initialize(self): 13 | # self.parser.add_argument('--ckpt', type=str, required=True) 14 | # self.parser.add_argument('path', type=str) 15 | # self.parser.add_argument('--test_path', type=str, required=True) 16 | self.parser.add_argument('--ckpt', type=str, required=False) 17 | self.parser.add_argument('--size', type=int, default=256) 18 | self.parser.add_argument('--n_sample', type=int, default=8) 19 | # self.parser.add_argument('--loop_num', type=int, default=10) 20 | self.parser.add_argument('--output_dir', type=str, default='./psp_out') 21 | self.parser.add_argument('--para_num', type=int, default=16) 22 | 23 | 24 | 25 | self.parser.add_argument('--channel_multiplier', type=int, default=2) 26 | 27 | self.parser.add_argument('--inject_noise', action='store_true', default=False) 28 | 29 | self.parser.add_argument('--num_region', type=int, default=1) 30 | self.parser.add_argument('--no_spatial_map', action='store_true', default=False) 31 | self.parser.add_argument('--n_mlp', type=int, default=8) 32 | 33 | self.parser.add_argument('--num_trans', type=int, default=8) 34 | self.parser.add_argument('--no_trans', action='store_true', default=False) 35 | 36 | self.parser.add_argument('--pixel_norm_op_dim', type=int, default=1) 37 | 38 | # for psp 39 | self.parser.add_argument('--exp_dir', type=str, default= "psp_training_dir",help='Path to experiment output directory') 40 | 41 | self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, 42 | help='Type of dataset/experiment to run') 43 | self.parser.add_argument('--encoder_type', default='GradualStyleEncoder', type=str, help='Which encoder to use') 44 | self.parser.add_argument('--input_nc', default=3, type=int, 45 | help='Number of input image channels to the psp encoder') 46 | self.parser.add_argument('--label_nc', default=0, type=int, 47 | help='Number of input label channels to the psp encoder') 48 | self.parser.add_argument('--output_size', default=256, type=int, help='Output size of generator') 49 | 50 | self.parser.add_argument('--batch_size', default=8, type=int, help='Batch size for training') 51 | self.parser.add_argument('--test_batch_size', default=8, type=int, help='Batch size for testing and inference') 52 | self.parser.add_argument('--workers', default=8, type=int, help='Number of train dataloader workers') 53 | self.parser.add_argument('--test_workers', default=8, type=int, 54 | help='Number of test/inference dataloader workers') 55 | 56 | self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate') 57 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') 58 | self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the 
decoder model') 59 | self.parser.add_argument('--start_from_latent_avg', action='store_true', 60 | help='Whether to add average latent vector to generate codes from encoder.') 61 | self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space insteaf of w+') 62 | 63 | self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor') 64 | self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') 65 | self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor') 66 | self.parser.add_argument('--w_norm_lambda', default=0, type=float, help='W-norm loss multiplier factor') 67 | self.parser.add_argument('--lpips_lambda_crop', default=0, type=float, 68 | help='LPIPS loss multiplier factor for inner image region') 69 | self.parser.add_argument('--l2_lambda_crop', default=0, type=float, 70 | help='L2 loss multiplier factor for inner image region') 71 | 72 | self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str, 73 | help='Path to StyleGAN model weights') 74 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') 75 | 76 | self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps') 77 | self.parser.add_argument('--image_interval', default=1000, type=int, 78 | help='Interval for logging train images during training') 79 | self.parser.add_argument('--board_interval', default=50, type=int, 80 | help='Interval for logging metrics to tensorboard') 81 | self.parser.add_argument('--val_interval', default=2500, type=int, help='Validation interval') 82 | self.parser.add_argument('--save_interval', default=5000, type=int, help='Model checkpoint interval') 83 | 84 | # arguments for super-resolution 85 | self.parser.add_argument('--resize_factors', type=str, default=None, 86 | help='For super-res, comma-separated resize factors to use for inference.') 87 | 88 | self.parser.add_argument('--from_plus_space', action='store_true') # invert to p+ and z+ 89 | 90 | 91 | 92 | def parse(self): 93 | opts = self.parser.parse_args() 94 | return opts 95 | -------------------------------------------------------------------------------- /resources/Teaser_v2.1-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/Teaser_v2.1-1.png -------------------------------------------------------------------------------- /resources/edit_blackhair_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/edit_blackhair_celeba.png -------------------------------------------------------------------------------- /resources/edit_ffhq_pose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/edit_ffhq_pose.png -------------------------------------------------------------------------------- /resources/edit_gender_ffhq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/edit_gender_ffhq.png 
-------------------------------------------------------------------------------- /resources/edit_pose_ffhq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/edit_pose_ffhq.png -------------------------------------------------------------------------------- /resources/edit_smile_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/edit_smile_celeba.png -------------------------------------------------------------------------------- /resources/interp_content_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/interp_content_celeba.png -------------------------------------------------------------------------------- /resources/interp_style_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/interp_style_celeba.png -------------------------------------------------------------------------------- /resources/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/teaser.png -------------------------------------------------------------------------------- /resources/teaser_change_order.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/resources/teaser_change_order.png -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from io import BytesIO 3 | 4 | import lmdb 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class MultiResolutionDataset(Dataset): 10 | def __init__(self, path, transform, resolution=256): 11 | 12 | self.env = lmdb.open( 13 | path, 14 | max_readers=32, 15 | readonly=True, 16 | lock=False, 17 | readahead=False, 18 | meminit=False, 19 | ) 20 | 21 | if not self.env: 22 | raise IOError('Cannot open lmdb dataset', path) 23 | 24 | with self.env.begin(write=False) as txn: 25 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 26 | 27 | self.resolution = resolution 28 | self.transform = transform 29 | 30 | def __len__(self): 31 | return self.length 32 | 33 | def __getitem__(self, index): 34 | with self.env.begin(write=False) as txn: 35 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') # b'256-00140' 36 | img_bytes = txn.get(key) 37 | 38 | try: 39 | buffer = BytesIO(img_bytes) 40 | img = Image.open(buffer) 41 | img = self.transform(img) 42 | return img 43 | except Exception as e: 44 | print(e) 45 | return self.__getitem__(random.randint(0, self.length - 1)) 46 | -------------------------------------------------------------------------------- /utils/dataset_projector.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | # import lmdb 4 | from PIL import Image 5 | from 
torch.utils.data import Dataset 6 | import os 7 | import torchvision.transforms as transforms 8 | import random 9 | import torch 10 | 11 | 12 | class MultiResolutionDataset(Dataset): 13 | def __init__(self, path, resolution=8): 14 | 15 | files = sorted(list(os.listdir(path))) 16 | self.imglist =[] 17 | for fir in files: 18 | self.imglist.append(os.path.join(path,fir)) 19 | 20 | self.transform = transforms.Compose( 21 | [ 22 | transforms.Resize(resolution), 23 | # transforms.RandomHorizontalFlip(), 24 | transforms.ToTensor(), 25 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 26 | ] 27 | ) 28 | 29 | self.resolution = resolution 30 | 31 | 32 | def __len__(self): 33 | return len(self.imglist) 34 | 35 | def __getitem__(self, index): 36 | 37 | img = Image.open(self.imglist[index]) 38 | img = self.transform(img) 39 | 40 | return img -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | from torch import distributed as dist 5 | 6 | 7 | def get_rank(): 8 | if not dist.is_available(): 9 | return 0 10 | 11 | if not dist.is_initialized(): 12 | return 0 13 | 14 | return dist.get_rank() 15 | 16 | 17 | def synchronize(): 18 | if not dist.is_available(): 19 | return 20 | 21 | if not dist.is_initialized(): 22 | return 23 | 24 | world_size = dist.get_world_size() 25 | 26 | if world_size == 1: 27 | return 28 | 29 | dist.barrier() 30 | 31 | 32 | def get_world_size(): 33 | if not dist.is_available(): 34 | return 1 35 | 36 | if not dist.is_initialized(): 37 | return 1 38 | 39 | return dist.get_world_size() 40 | 41 | 42 | def reduce_sum(tensor): 43 | if not dist.is_available(): 44 | return tensor 45 | 46 | if not dist.is_initialized(): 47 | return tensor 48 | 49 | tensor = tensor.clone() 50 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 51 | 52 | return tensor 53 | 54 | 55 | def gather_grad(params): 56 | world_size = get_world_size() 57 | 58 | if world_size == 1: 59 | return 60 | 61 | for param in params: 62 | if param.grad is not None: 63 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 64 | param.grad.data.div_(world_size) 65 | 66 | 67 | def all_gather(data): 68 | world_size = get_world_size() 69 | 70 | if world_size == 1: 71 | return [data] 72 | 73 | buffer = pickle.dumps(data) 74 | storage = torch.ByteStorage.from_buffer(buffer) 75 | tensor = torch.ByteTensor(storage).to('cuda') 76 | 77 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 78 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 79 | dist.all_gather(size_list, local_size) 80 | size_list = [int(size.item()) for size in size_list] 81 | max_size = max(size_list) 82 | 83 | tensor_list = [] 84 | for _ in size_list: 85 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 86 | 87 | if local_size != max_size: 88 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 89 | tensor = torch.cat((tensor, padding), 0) 90 | 91 | dist.all_gather(tensor_list, tensor) 92 | 93 | data_list = [] 94 | 95 | for size, tensor in zip(size_list, tensor_list): 96 | buffer = tensor.cpu().numpy().tobytes()[:size] 97 | data_list.append(pickle.loads(buffer)) 98 | 99 | return data_list 100 | 101 | 102 | def reduce_loss_dict(loss_dict): 103 | world_size = get_world_size() 104 | 105 | if world_size < 2: 106 | return loss_dict 107 | 108 | with torch.no_grad(): 109 | keys = [] 110 | losses = [] 111 | 112 | for k in 
sorted(loss_dict.keys()): 113 | keys.append(k) 114 | losses.append(loss_dict[k]) 115 | 116 | losses = torch.stack(losses, 0) 117 | dist.reduce(losses, dst=0) 118 | 119 | if dist.get_rank() == 0: 120 | losses /= world_size 121 | 122 | reduced_losses = {k: v for k, v in zip(keys, losses)} 123 | 124 | return reduced_losses 125 | -------------------------------------------------------------------------------- /utils/editing_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms, utils 3 | from PIL import Image 4 | import os 5 | import numpy as np 6 | 7 | 8 | def make_image(tensor): 9 | return ( 10 | tensor.detach() 11 | .clamp_(min=-1, max=1) 12 | .add(1) 13 | .div_(2) 14 | .mul(255) 15 | .type(torch.uint8) 16 | .permute(0, 2, 3, 1) 17 | .to('cpu') 18 | .numpy() 19 | ) 20 | 21 | def visualize(img_path): 22 | img_list=os.listdir(img_path) 23 | img_list.sort() 24 | img_list.sort(key = lambda x: (x[:-4])) ##文件名按数字排序 25 | b = ['0', '12', '24', '36', '48', '60'] 26 | dir = [] 27 | img_nums=len(img_list) 28 | res = [] 29 | 30 | transform = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)) 33 | ] 34 | ) 35 | 36 | for i in range(img_nums): 37 | img_name=os.path.join(img_path, img_list[i]) 38 | ll = img_name.split('/')[-1].split('_')[3] 39 | if ll in b: 40 | dir.append(img_name) 41 | img = Image.open(img_name).convert('RGB') 42 | img2 = transform(img) 43 | array = np.asarray(img2) 44 | data = torch.from_numpy(array).unsqueeze(0) 45 | res.append(data) 46 | sample = torch.cat(res, dim=0) 47 | utils.save_image( 48 | sample, 49 | os.path.join(img_path,'edit.png'), 50 | nrow=int(6), 51 | normalize=True, 52 | range=(-1, 1), 53 | ) -------------------------------------------------------------------------------- /utils/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | # from skimage.measure import compare_ssim 9 | from skimage.metrics import structural_similarity 10 | 11 | from utils.lpips import dist_model 12 | 13 | 14 | class PerceptualLoss(torch.nn.Module): 15 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 16 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 17 | super(PerceptualLoss, self).__init__() 18 | print('Setting up Perceptual loss...') 19 | self.use_gpu = use_gpu 20 | self.spatial = spatial 21 | self.gpu_ids = gpu_ids 22 | self.model = dist_model.DistModel() 23 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 24 | print('...[%s] initialized'%self.model.name()) 25 | print('...Done') 26 | 27 | def forward(self, pred, target, normalize=False): 28 | """ 29 | Pred and target are Variables. 
30 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 31 | If normalize is False, assumes the images are already between [-1,+1] 32 | 33 | Inputs pred and target are Nx3xHxW 34 | Output pytorch Variable N long 35 | """ 36 | 37 | if normalize: 38 | target = 2 * target - 1 39 | pred = 2 * pred - 1 40 | 41 | return self.model.forward(target, pred) 42 | 43 | def normalize_tensor(in_feat,eps=1e-10): 44 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 45 | return in_feat/(norm_factor+eps) 46 | 47 | def l2(p0, p1, range=255.): 48 | return .5*np.mean((p0 / range - p1 / range)**2) 49 | 50 | def psnr(p0, p1, peak=255.): 51 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 52 | 53 | def dssim(p0, p1, range=255.): 54 | return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2. 55 | 56 | def rgb2lab(in_img,mean_cent=False): 57 | from skimage import color 58 | img_lab = color.rgb2lab(in_img) 59 | if(mean_cent): 60 | img_lab[:,:,0] = img_lab[:,:,0]-50 61 | return img_lab 62 | 63 | def tensor2np(tensor_obj): 64 | # change dimension of a tensor object into a numpy array 65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 66 | 67 | def np2tensor(np_obj): 68 | # change dimenion of np array into tensor array 69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 70 | 71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 72 | # image tensor to lab tensor 73 | from skimage import color 74 | 75 | img = tensor2im(image_tensor) 76 | img_lab = color.rgb2lab(img) 77 | if(mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | if(to_norm and not mc_only): 80 | img_lab[:,:,0] = img_lab[:,:,0]-50 81 | img_lab = img_lab/100. 82 | 83 | return np2tensor(img_lab) 84 | 85 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 86 | from skimage import color 87 | import warnings 88 | warnings.filterwarnings("ignore") 89 | 90 | lab = tensor2np(lab_tensor)*100. 91 | lab[:,:,0] = lab[:,:,0]+50 92 | 93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 94 | if(return_inbnd): 95 | # convert back to lab, see if we match 96 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 97 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 99 | return (im2tensor(rgb_back),mask) 100 | else: 101 | return im2tensor(rgb_back) 102 | 103 | def rgb2lab(input): 104 | from skimage import color 105 | return color.rgb2lab(input / 255.) 106 | 107 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 108 | image_numpy = image_tensor[0].cpu().float().numpy() 109 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 110 | return image_numpy.astype(imtype) 111 | 112 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 113 | return torch.Tensor((image / factor - cent) 114 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 115 | 116 | def tensor2vec(vector_tensor): 117 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 118 | 119 | def voc_ap(rec, prec, use_07_metric=False): 120 | """ ap = voc_ap(rec, prec, [use_07_metric]) 121 | Compute VOC AP given precision and recall. 122 | If use_07_metric is true, uses the 123 | VOC 07 11 point method (default:False). 124 | """ 125 | if use_07_metric: 126 | # 11 point metric 127 | ap = 0. 128 | for t in np.arange(0., 1.1, 0.1): 129 | if np.sum(rec >= t) == 0: 130 | p = 0 131 | else: 132 | p = np.max(prec[rec >= t]) 133 | ap = ap + p / 11. 
134 | else: 135 | # correct AP calculation 136 | # first append sentinel values at the end 137 | mrec = np.concatenate(([0.], rec, [1.])) 138 | mpre = np.concatenate(([0.], prec, [0.])) 139 | 140 | # compute the precision envelope 141 | for i in range(mpre.size - 1, 0, -1): 142 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 143 | 144 | # to calculate area under PR curve, look for points 145 | # where X axis (recall) changes value 146 | i = np.where(mrec[1:] != mrec[:-1])[0] 147 | 148 | # and sum (\Delta recall) * prec 149 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 150 | return ap 151 | 152 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 153 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 154 | image_numpy = image_tensor[0].cpu().float().numpy() 155 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 156 | return image_numpy.astype(imtype) 157 | 158 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 159 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 160 | return torch.Tensor((image / factor - cent) 161 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 162 | -------------------------------------------------------------------------------- /utils/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | 6 | class BaseModel(): 7 | def __init__(self): 8 | pass; 9 | 10 | def name(self): 11 | return 'BaseModel' 12 | 13 | def initialize(self, use_gpu=True, gpu_ids=[0]): 14 | self.use_gpu = use_gpu 15 | self.gpu_ids = gpu_ids 16 | 17 | def forward(self): 18 | pass 19 | 20 | def get_image_paths(self): 21 | pass 22 | 23 | def optimize_parameters(self): 24 | pass 25 | 26 | def get_current_visuals(self): 27 | return self.input 28 | 29 | def get_current_errors(self): 30 | return {} 31 | 32 | def save(self, label): 33 | pass 34 | 35 | # helper saving function that can be used by subclasses 36 | def save_network(self, network, path, network_label, epoch_label): 37 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 38 | save_path = os.path.join(path, save_filename) 39 | torch.save(network.state_dict(), save_path) 40 | 41 | # helper loading function that can be used by subclasses 42 | def load_network(self, network, network_label, epoch_label): 43 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 44 | save_path = os.path.join(self.save_dir, save_filename) 45 | print('Loading network from %s'%save_path) 46 | network.load_state_dict(torch.load(save_path)) 47 | 48 | def update_learning_rate(): 49 | pass 50 | 51 | def get_image_paths(self): 52 | return self.image_paths 53 | 54 | def save_done(self, flag=False): 55 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 56 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 57 | 58 | -------------------------------------------------------------------------------- /utils/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/utils/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /utils/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/utils/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /utils/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/utils/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /utils/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/utils/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /utils/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/utils/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /utils/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BillyXYB/TransEditor/14efcc3a9079e8853946707ad2273f863e6ff72b/utils/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /utils/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /utils/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def 
backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | # self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | if bias: 81 | self.bias = nn.Parameter(torch.zeros(channel)) 82 | else: 83 | self.bias = None 84 | 85 | def forward(self, input): 86 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 87 | 88 | 89 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 90 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 91 | -------------------------------------------------------------------------------- /utils/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /utils/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? 
x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /utils/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /utils/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | 23 | up_x, up_y = up 24 | down_x, down_y = down 25 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 26 | 27 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 28 | 29 | grad_input = upfirdn2d_op.upfirdn2d( 30 | grad_output, 31 | grad_kernel, 32 | down_x, 33 | down_y, 34 | up_x, 35 | up_y, 36 | g_pad_x0, 37 | g_pad_x1, 38 | g_pad_y0, 39 | g_pad_y1, 40 | ) 41 | grad_input = 
grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 42 | 43 | ctx.save_for_backward(kernel) 44 | 45 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 46 | 47 | ctx.up_x = up_x 48 | ctx.up_y = up_y 49 | ctx.down_x = down_x 50 | ctx.down_y = down_y 51 | ctx.pad_x0 = pad_x0 52 | ctx.pad_x1 = pad_x1 53 | ctx.pad_y0 = pad_y0 54 | ctx.pad_y1 = pad_y1 55 | ctx.in_size = in_size 56 | ctx.out_size = out_size 57 | 58 | return grad_input 59 | 60 | @staticmethod 61 | def backward(ctx, gradgrad_input): 62 | kernel, = ctx.saved_tensors 63 | 64 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 65 | 66 | gradgrad_out = upfirdn2d_op.upfirdn2d( 67 | gradgrad_input, 68 | kernel, 69 | ctx.up_x, 70 | ctx.up_y, 71 | ctx.down_x, 72 | ctx.down_y, 73 | ctx.pad_x0, 74 | ctx.pad_x1, 75 | ctx.pad_y0, 76 | ctx.pad_y1, 77 | ) 78 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 79 | gradgrad_out = gradgrad_out.view( 80 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 81 | ) 82 | 83 | return gradgrad_out, None, None, None, None, None, None, None, None 84 | 85 | 86 | class UpFirDn2d(Function): 87 | @staticmethod 88 | def forward(ctx, input, kernel, up, down, pad): 89 | up_x, up_y = up 90 | down_x, down_y = down 91 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 92 | 93 | kernel_h, kernel_w = kernel.shape 94 | batch, channel, in_h, in_w = input.shape 95 | ctx.in_size = input.shape 96 | 97 | input = input.reshape(-1, in_h, in_w, 1) 98 | 99 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 100 | 101 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 102 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 103 | ctx.out_size = (out_h, out_w) 104 | 105 | ctx.up = (up_x, up_y) 106 | ctx.down = (down_x, down_y) 107 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 108 | 109 | g_pad_x0 = kernel_w - pad_x0 - 1 110 | g_pad_y0 = kernel_h - pad_y0 - 1 111 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 112 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 113 | 114 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 115 | 116 | out = upfirdn2d_op.upfirdn2d( 117 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 118 | ) 119 | # out = out.view(major, out_h, out_w, minor) 120 | out = out.view(-1, channel, out_h, out_w) 121 | 122 | return out 123 | 124 | @staticmethod 125 | def backward(ctx, grad_output): 126 | kernel, grad_kernel = ctx.saved_tensors 127 | 128 | grad_input = UpFirDn2dBackward.apply( 129 | grad_output, 130 | kernel, 131 | grad_kernel, 132 | ctx.up, 133 | ctx.down, 134 | ctx.pad, 135 | ctx.g_pad, 136 | ctx.in_size, 137 | ctx.out_size, 138 | ) 139 | 140 | return grad_input, None, None, None, None 141 | 142 | 143 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 144 | out = UpFirDn2d.apply( 145 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 146 | ) 147 | 148 | return out 149 | 150 | 151 | def upfirdn2d_native( 152 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 153 | ): 154 | _, in_h, in_w, minor = input.shape 155 | kernel_h, kernel_w = kernel.shape 156 | 157 | out = input.view(-1, in_h, 1, in_w, 1, minor) 158 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 159 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 160 | 161 | out = F.pad( 162 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 163 | ) 164 | out = out[ 165 | :, 166 | max(-pad_y0, 
0) : out.shape[1] - max(-pad_y1, 0), 167 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 168 | :, 169 | ] 170 | 171 | out = out.permute(0, 3, 1, 2) 172 | out = out.reshape( 173 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 174 | ) 175 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 176 | out = F.conv2d(out, w) 177 | out = out.reshape( 178 | -1, 179 | minor, 180 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 181 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 182 | ) 183 | out = out.permute(0, 2, 3, 1) 184 | 185 | return out[:, ::down_y, ::down_x, :] 186 | 187 | -------------------------------------------------------------------------------- /utils/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def prepare_param(n_sample, args, device, method="batch_same", truncation=1.0): 4 | if method == "batch_same": 5 | return torch.randn(args.para_num, args.latent, device=device).repeat(n_sample, 1, 1) * truncation 6 | elif method == "batch_diff": 7 | return torch.randn(n_sample, args.para_num, args.latent, device=device) * truncation 8 | elif method == "spatial": 9 | # batch, 512,16 10 | return torch.randn(n_sample, args.latent, args.para_num, device=device) * truncation 11 | elif method == "spatial_same": 12 | return torch.randn(args.latent, args.para_num, device=device).repeat(n_sample,1,1) * truncation 13 | 14 | 15 | 16 | def prepare_noise_new(n_sample, args, device, method="multi", truncation=1.0, mode = 'train'): 17 | # used for train_spatial_query, returns (bs, 512, 16) 18 | if method == 'query': 19 | return torch.randn(n_sample, args.latent, args.para_num, device=device) * truncation 20 | elif method == 'query_same': 21 | return torch.randn(args.latent, args.para_num, device=device).repeat(n_sample,1,1) * truncation --------------------------------------------------------------------------------