├── .gitignore ├── LICENSE.md ├── README.md ├── arcface ├── __init__.py └── iresnet.py ├── boundaries_ours ├── Eyeglasses_boundary.npy ├── Heavy_Makeup_boundary.npy └── Smiling_boundary.npy ├── configs └── 001.yaml ├── data ├── celeba_hq │ ├── 0.jpg │ ├── 1.jpg │ ├── 10.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ ├── 8.jpg │ └── 9.jpg ├── stylegan2-generate-images │ └── seeds_pytorch_1.8.1.npy └── video │ └── FP010363HD03.mp4 ├── download_models.sh ├── environment.yml ├── face_parsing ├── model.py ├── resnet.py └── test.py ├── generate_imgs.py ├── images └── teaser.png ├── inference.ipynb ├── lpips ├── __init__.py ├── lpips.py ├── pretrained_networks.py ├── trainer.py └── weights │ ├── v0.0 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth │ └── v0.1 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── nets └── feature_style_encoder.py ├── pixel2style2pixel ├── LICENSE ├── download-weights.sh └── models │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── psp.cpython-36.pyc │ └── stylegan2 │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── model.cpython-36.pyc │ ├── model.py │ └── op │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── fused_act.cpython-36.pyc │ └── upfirdn2d.cpython-36.pyc │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── ranger.py ├── run_video_inversion_editing.sh ├── test.py ├── test ├── 00020.jpg ├── 00109.jpg ├── 00128.jpg ├── 00299.jpg ├── 00610.jpg └── 00962.jpg ├── train.py ├── trainer.py ├── utils ├── .DS_Store ├── datasets.py ├── functions.py └── video_utils.py └── video_processing.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | __pycache__/ 3 | matshow/ 4 | .ipynb_checkpoints/ 5 | tmp/ 6 | RAFT/ 7 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## LIMITED SOFTWARE EVALUATION LICENSE AGREEMENT 2 | 3 | The following Limited Software Evaluation License (the “License”) constitutes an agreement between you (the “Licensee”) and InterDigital Communications, Inc, a company organized and existing under the laws of the State of Delaware, USA, with its registered offices located at 200 Bellevue Parkway, Suite 300, Wilmington, DE 19809, USA (hereinafter “InterDigital”). 4 | This License governs the download and use of the Software (as defined below). Your use of the Software is subject to the terms and conditions set forth in this License. By installing, using, accessing or copying the Software, you hereby irrevocably accept the terms and conditions of this License. If you do not accept all parts of the terms and conditions of this License, you cannot install, use, access nor copy the Software. 5 | 6 | # Article 1. Definitions 7 | “Affiliate” as used herein shall mean any entity that, directly or indirectly, through one or more intermediates, is controlled by, controls, or is under common control with InterDigital or The Licensee, as the case may be. 
For purposes of this definition only, the term “control” means the possession of the power to direct or cause the direction of the management and policies of an entity, whether by ownership of voting stock or partnership interest, by contract, or otherwise, including direct or indirect ownership of more than fifty percent (50%) of the voting interest in the entity in question. 8 | “Authorized Purpose” means any use of the Software for fundamental research work with the exclusion of any commercial use. A commercial use includes, without limitation, any sublicense granted on the Software against a fee whatever its nature, any use of the Software in a product that is offered (either free or for a price) to any third party, any use of the Software to provide a service to a third party and/or any use of the Software to create a competing product of the Software ("Purpose") 9 | “Documentation” means textual materials delivered by InterDigital to the Licensee pursuant to this License relating to the Software, in written or electronic format, including but not limited to, technical reference manuals, technical notes, user manuals, and application guides. 10 | “Effective Date” means the date Licensee first installs a copy of the Software on any computer. 11 | 12 | “Limited Period” means the life of the copyright owned by InterDigital on the Software in each and every country where such copyright would exist. 13 | “Intellectual Property Rights” means all copyrights, trademarks, trade secrets, patents and any other intellectual property rights recognized in any jurisdiction worldwide, including all applications and registrations with respect thereto. 14 | "Open Source Software" shall mean any software, including where appropriate, any and all modifications, derivative works, enhancements, upgrades, improvements, fixed bugs, and/or statically linked to the source code of such software, released under a free or open source software license that requires, as a condition of usage, copy, modification and/or redistribution of such software, that the party: 15 | • Redistribute the Open Source Software royalty-free; and/or 16 | • Redistribute the Open Source Software under the same license/distribution terms as those contained in the open source or free software license under which it was originally released; and/or 17 | • Release to the public, disclose or otherwise make available the source code of the Open Source Software. 18 | For purposes of this License, by means of example and without limitation, any software that is released or distributed under any of the following licenses shall be qualified as Open Source Software: (i) GNU General Public License (GPL); (ii) GNU Lesser/Library GPL (LGPL); (iii) the Artistic License; (iv) the Mozilla Public License; (v) the Common Public License; (vi) the Sun Community Source License (SCSL); (vii) the Sun Industry Standards Source License (SISSL); (viii) BSD License; (ix) MIT License; (x) Apache Software License; (xi) Open SSL License; (xii) IBM Public License; and (xiii) Open Software License. 19 | “Software” means the Software with which this license was downloaded, namely FeatureStyleEncoder in object code. 20 | # Article 2. License 21 | InterDigital grants Licensee a free, worldwide, non-exclusive, license to InterDigital’s copyright on the Software to download, use and reproduce solely for the Authorized Purpose for the Limited Period. 22 | Licensee shall not pay any royalty, license fee or maintenance fee, or other fee of any nature under this License. 
23 | # Article 3. Restrictions on use of the Software 24 | Licensee shall not have the right to correct, adapt, modify, reverse engineer, disassemble, decompile or/and otherwise perform or conduct any action leading to the transformation of the Software. 25 | Licensee shall not remove, obscure or modify any copyright, trademark or other proprietary rights notices, marks or labels contained on or within the Software, falsify or delete any author attributions, legal notices or other labels of the origin or source of the material. 26 | Licensee may reproduce and distribute copies of the Software in any medium provided that Licensee gives any other recipients of the Software a copy of this License. 27 | 28 | # Article 4. Ownership 29 | Title to and ownership of the Software, the Documentation, and/or any Intellectual Property Right protecting the Software and/or the Documentation shall at all times remain with InterDigital. Licensee agrees that except for the limited rights granted to the Software as set forth in Section 2 above, in no event shall anything in this License grant, provide, or convey any other rights, privileges, immunities, or interest in or to any Intellectual Property Rights (including but not limited to patent rights) of InterDigital or any of its Affiliates, whether by implication, estoppel, or otherwise. 30 | 31 | # Article 5. Publication/Communication 32 | Any publication or oral communication resulting from the use of the Software shall be elaborated in good faith and shall not be driven by a deliberate will to denigrate InterDigital or any of its products. In any publication and on any support joined to an oral communication (e.g., a PowerPoint presentation) relating to the Software, the following statement shall be inserted: 33 | “FeatureStyleEncoder” is an InterDigital product” 34 | In any publication, the latest publication about the software shall be properly cited. The latest publication currently is: 35 | A Style-Based GAN Encoder for High Fidelity Reconstruction of Images and Videos, Xu Yao, Alasdair Newson, Yann Gousseau, Pierre Hellier, ECCV European Conference on Computer Vision, 2022 (https://arxiv.org/pdf/2202.02183.pdf). 36 | In any oral communication relating to the Software and/or its use, the Licensee shall orally indicate that the Software is InterDigital’s property. 37 | 38 | # Article 6. No Warranty - Disclaimer 39 | THE SOFTWARE AND DOCUMENTATION ARE PROVIDED TO LICENSEE ON AN “AS IS” BASIS. INTERDIGITAL MAKES NO WARRANTY THAT THE SOFTWARE WILL OPERATE ON ANY PARTICULAR HARDWARE, PLATFORM, OR ENVIRONMENT. THERE IS NO WARRANTY THAT THE OPERATION OF THE SOFTWARE SHALL BE UNINTERRUPTED, WITHOUT BUGS OR ERROR FREE. THE SOFTWARE AND DOCUMENTATION ARE PROVIDED HEREUNDER WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED LIABILITIES AND WARRANTIES OF NONINFRINGEMENT OF INTELLECTUAL PROPERTY, FREEDOM FROM INHERENT DEFECTS, CONFORMITY TO A SAMPLE OR MODEL, MERCHANTABILITY, FITNESS AND/OR SUITABILITY FOR A SPECIFIC OR GENERAL PURPOSE AND THOSE ARISING BY STATUTE OR BY LAW, OR FROM A CAUSE OF DEALING OR USAGE OF TRADE. ANY AND ALL SUCH IMPLIED WARRANTIES ARE FULLY DISCLAIMED BY INTERDIGITAL TO THE MAXIMUM EXTENT ALLOWED BY LAW, AND LICENSEE ACKNOWLEDGES THAT THIS DISCLAIMER OF ALL EXPRESS AND IMPLIED WARRANTIES BY INTERDIGITAL, AS WELL AS LICENSEE’S ACCEPTANCE AND ACKNOWLEDGEMENT OF THE SAME, IS A MATERIAL PART OF THE CONSIDERATION FOR THIS LICENSE. 
40 | InterDigital shall not be obligated to perform or provide any modifications, derivative works, enhancements, upgrades, updates or improvements of the Software or Documentation, or to fix any bug that could arise. 41 | Licensee at all times uses the Software at its own cost, risk and responsibility. InterDigital shall not be liable for any damages that could accrue by or to Licensee as a result of its use of the Software, either in accordance with this License or not. 42 | InterDigital shall not be liable for any consequential or indirect losses, including any indirect loss of profits, revenues, business, and/or anticipated savings, whether or not in the contemplation of the Parties at the time of entering into this License unless expressly set out in this License, or arising from gross negligence, willful misconduct or fraud. 43 | Licensee agrees that it will defend, indemnify and hold harmless InterDigital and its Affiliates against any and all losses, damages, costs and expenses arising from a breach by the Licensee of any of its obligations or representations hereunder, including, without limitation, any third party claims, and/or any claims in connection with any such breach and/or any use of the Software, including any claim from third party arising from access, use, or any other activity in relation to this Software. 44 | Licensee shall not make any warranty, representation, or commitment on behalf of InterDigital to any other third party. 45 | 46 | # Article 7. Open Source Software 47 | Licensee hereby represents, warrants, and covenants to InterDigital that Licensee’s use of the Software shall not result in the Contamination of all or any part of the Software, directly or indirectly, or of any Intellectual Property of InterDigital or its Affiliates. 48 | As used herein, “Contamination” shall mean that the licensing terms under which any Open Source Software, distinct from the Software, is released would also apply to the Software herein, by virtue of such Open Source Software being linked to, combined with, or otherwise connected to the Software. 49 | Licensee agree that some Open Source Software are included in the distribution. A list of such is provided in exhibit A with the relevant licenses applicable. For the avoidance of doubt, regarding such open source parts, the relevant license will apply exclusively. 50 | 51 | # Article 8. No Future Contract Obligation 52 | Neither this License nor the furnishing of the Software, nor any other InterDigital information provided to Licensee, shall be construed to obligate either party to: (a) enter into any further agreement or negotiation concerning the deployment of the Software; (b) refrain from entering into any agreement or negotiation with any other third party regarding the same or any other subject matter; or (c) refrain from pursuing its business in whatever manner it elects even if this involves competing with the other party. 53 | 54 | # Article 9. General Provisions 55 | 9.1 Severability. If any provision of this License shall be held to be in contravention of applicable law, this License shall be construed as if such provision were not a part thereof, and in all other respects the terms hereof shall remain in full force and effect. 56 | 9.2 Governing Law. 
Regardless of the place of execution, delivery, performance or any other aspect of this License, this License and all of the rights of the parties under this License shall be governed by, construed under and enforced in accordance with the substantive law of the State of Delaware, USA, without regard to conflicts of law principles. In case of a dispute that cannot be settled amicably, the state and federal courts located in New Castle County, Delaware, USA, shall have exclusive jurisdiction over such dispute, and each party hereby irrevocably waives any objection to the jurisdiction of such courts, including but not limited to objections of lack of in personam jurisdiction or based on principles of forum non conveniens. 57 | 9.3 Survival. The provisions of articles 1, 3, 4, 6, 7, 8, 9.1, 9.2 and 9.5 shall survive termination of this License. 58 | 9.4 Assignment. InterDigital may assign this license to any third Party. Licensee may not assign this agreement to any third party without InterDigital’s prior written approval. 59 | 9.5 Entire Agreement. This License constitutes the entire agreement between the parties hereto with respect to the subject matter hereof and supersedes any prior agreements or understanding. 60 |   61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A Style-Based GAN Encoder for High Fidelity Reconstruction of Images and Videos 2 | 3 | Official implementation of the paper: A Style-Based GAN Encoder for High Fidelity Reconstruction of Images and Videos. 4 | 5 | [[Video Editing Results]](https://drive.google.com/file/d/1ebih6TZxb2eLKxJdbO8GnsInDKSegfYL/view?usp=sharing) 6 | 7 | ![teaser](images/teaser.png) 8 | 9 | > **Abstract** We propose a novel architecture for GAN inversion, which we call Feature-Style encoder. The style encoder is key for the manipulation of the obtained latent codes, while the feature encoder is crucial for optimal image reconstruction. Our model achieves accurate inversion of real images from the latent space of a pre-trained style-based GAN model, obtaining better perceptual quality and lower reconstruction error than existing methods. Thanks to its encoder structure, the model allows fast and accurate image editing. Additionally, we demonstrate that the proposed encoder is especially well-suited for inversion and editing on videos. We conduct extensive experiments for several style-based generators pre-trained on different data domains. Our proposed method yields state-of-the-art results for style-based GAN inversion, significantly outperforming competing approaches. 10 | 11 | 12 | ## Requirements 13 | 14 | ### Dependencies 15 | 16 | - Python 3.6 17 | - PyTorch 1.8 18 | - OpenCV 19 | 20 | You can create a new environment for this repo by running 21 | ``` 22 | conda env create -f environment.yml 23 | conda activate feature_style 24 | ``` 25 | 26 | ### Prepare StyleGAN2 model and other necessary models 27 | 28 | * We adapt the StyleGAN2 model implementation from the paper [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](https://arxiv.org/pdf/2008.00951.pdf). Here is their [official implementation](https://github.com/eladrich/pixel2style2pixel.git). 
29 | 30 | * Download and save the pretrained models by running 31 | ``` 32 | sh download_models.sh 33 | ``` 34 | 35 | 36 | ## Training 37 | 38 | * Prepare the training data 39 | 40 | To train the encoder for StyleGAN, we use synthetic images generated by StyleGAN as well as real images from the [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset). 41 | You can generate the synthetic images by running 42 | ``` 43 | python generate_imgs.py 44 | ``` 45 | and download the FFHQ dataset (aligned faces) to `data/ffhq-dataset/images/`. 46 | 47 | * Training 48 | 49 | You can modify the training options in the config files under `configs/`, then launch training with 50 | ``` 51 | python train.py --config 001 52 | ``` 53 | 54 | ## Testing 55 | 56 | * Inversion 57 | 58 | You can test the encoder on the images in `test/`. The output images are saved in `output/image/`. 59 | ``` 60 | python test.py --pretrained_model_path './pretrained_models/143_enc.pth' --input_path './test/' 61 | ``` 62 | * Inversion and editing in notebook 63 | 64 | You can explore the encoder and the attribute editing code in the notebook `inference.ipynb`. You can also open it in Google Colab [here](https://colab.research.google.com/github/InterDigitalInc/FeatureStyleEncoder/blob/master/inference.ipynb). 65 | 66 | 67 | ## Video Manipulation 68 | 69 | We provide a script that performs inversion and attribute manipulation on the videos in the test directory `data/video/`. You can use your own video and modify the options in `run_video_inversion_editing.sh`. 70 | 71 | ``` 72 | sh run_video_inversion_editing.sh 73 | ``` 74 | 75 | ## Citation 76 | ``` 77 | @inproceedings{xuyao2022, 78 | title={A Style-Based GAN Encoder for High Fidelity Reconstruction of Images and Videos}, 79 | author={Yao, Xu and Newson, Alasdair and Gousseau, Yann and Hellier, Pierre}, 80 | booktitle={European Conference on Computer Vision (ECCV)}, 81 | year={2022} 82 | } 83 | ``` 84 | ## License 85 | 86 | Copyright © 2022, InterDigital R&D France. All rights reserved. 87 | 88 | This source code is made available under the license found in the LICENSE.md file in the root directory of this source tree. 
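## Appendix: Latent Editing Sketch

The attribute editing demonstrated in `inference.ipynb` and in the video scripts relies on the latent directions stored in `boundaries_ours/` (`Smiling`, `Eyeglasses`, `Heavy_Makeup`). The snippet below is a minimal sketch of how such a boundary is typically applied to an inverted latent code, assuming the `.npy` files store InterfaceGAN-style hyperplane normals of shape `(1, 512)`; the variable names and loading code here are illustrative only, and the authoritative editing code is the one in `inference.ipynb`.

```python
import numpy as np
import torch

# Load an attribute direction (assumed: a unit normal of the attribute hyperplane in the 512-d style space).
boundary = np.load('boundaries_ours/Smiling_boundary.npy')     # assumed shape: (1, 512)
direction = torch.from_numpy(boundary).float().view(1, 1, -1)  # reshape so it broadcasts over all style layers

# `w` stands for an inverted latent code in W+ space, shape (1, 18, 512) for the 1024x1024 StyleGAN2;
# in the real pipeline it comes from the encoder, here a random stand-in for illustration.
w = torch.randn(1, 18, 512)

alpha = 3.0                      # editing strength; negative values reverse the attribute
w_edit = w + alpha * direction   # shift the latent code along the attribute direction

# Decoding would then reuse the generator call shown in generate_imgs.py, e.g.
# x_edit, _ = StyleGAN([w_edit], input_is_latent=True, noise=n)
```

Smaller values of `alpha` preserve identity better, while larger values produce stronger edits at the risk of artifacts.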
89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /arcface/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | -------------------------------------------------------------------------------- /arcface/iresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, 10 | out_planes, 11 | kernel_size=3, 12 | stride=stride, 13 | padding=dilation, 14 | groups=groups, 15 | bias=False, 16 | dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, 22 | out_planes, 23 | kernel_size=1, 24 | stride=stride, 25 | bias=False) 26 | 27 | 28 | class IBasicBlock(nn.Module): 29 | expansion = 1 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, 31 | groups=1, base_width=64, dilation=1): 32 | super(IBasicBlock, self).__init__() 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 38 | self.conv1 = conv3x3(inplanes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 40 | self.prelu = nn.PReLU(planes) 41 | self.conv2 = conv3x3(planes, planes, stride) 42 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | identity = x 48 | out = self.bn1(x) 49 | out = self.conv1(out) 50 | out = self.bn2(out) 51 | out = self.prelu(out) 52 | out = self.conv2(out) 53 | out = self.bn3(out) 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | out += identity 57 | return out 58 | 59 | 60 | class IResNet(nn.Module): 61 | fc_scale = 7 * 7 62 | def __init__(self, 63 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 64 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 65 | super(IResNet, self).__init__() 66 | self.fp16 = fp16 67 | self.inplanes = 64 68 | self.dilation = 1 69 | if replace_stride_with_dilation is None: 70 | replace_stride_with_dilation = [False, False, False] 71 | if len(replace_stride_with_dilation) != 3: 72 | raise ValueError("replace_stride_with_dilation should be None " 73 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 74 | self.groups = groups 75 | self.base_width = width_per_group 76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 78 | self.prelu = nn.PReLU(self.inplanes) 79 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 80 | self.layer2 = self._make_layer(block, 81 | 128, 82 | layers[1], 83 | stride=2, 84 | dilate=replace_stride_with_dilation[0]) 85 | self.layer3 = self._make_layer(block, 86 | 256, 87 | layers[2], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[1]) 90 | self.layer4 = self._make_layer(block, 91 | 512, 92 | layers[3], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[2]) 95 | self.bn2 
= nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 96 | self.dropout = nn.Dropout(p=dropout, inplace=True) 97 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 98 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 99 | nn.init.constant_(self.features.weight, 1.0) 100 | self.features.weight.requires_grad = False 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | nn.init.normal_(m.weight, 0, 0.1) 105 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 106 | nn.init.constant_(m.weight, 1) 107 | nn.init.constant_(m.bias, 0) 108 | 109 | if zero_init_residual: 110 | for m in self.modules(): 111 | if isinstance(m, IBasicBlock): 112 | nn.init.constant_(m.bn2.weight, 0) 113 | 114 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 115 | downsample = None 116 | previous_dilation = self.dilation 117 | if dilate: 118 | self.dilation *= stride 119 | stride = 1 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | conv1x1(self.inplanes, planes * block.expansion, stride), 123 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 124 | ) 125 | layers = [] 126 | layers.append( 127 | block(self.inplanes, planes, stride, downsample, self.groups, 128 | self.base_width, previous_dilation)) 129 | self.inplanes = planes * block.expansion 130 | for _ in range(1, blocks): 131 | layers.append( 132 | block(self.inplanes, 133 | planes, 134 | groups=self.groups, 135 | base_width=self.base_width, 136 | dilation=self.dilation)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x, return_features=False): 141 | out = [] 142 | with torch.cuda.amp.autocast(self.fp16): 143 | x = self.conv1(x) 144 | x = self.bn1(x) 145 | x = self.prelu(x) 146 | x = self.layer1(x) 147 | out.append(x) 148 | x = self.layer2(x) 149 | out.append(x) 150 | x = self.layer3(x) 151 | out.append(x) 152 | x = self.layer4(x) 153 | out.append(x) 154 | x = self.bn2(x) 155 | x = torch.flatten(x, 1) 156 | x = self.dropout(x) 157 | x = self.fc(x.float() if self.fp16 else x) 158 | x = self.features(x) 159 | 160 | if return_features: 161 | out.append(x) 162 | return out 163 | return x 164 | 165 | 166 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 167 | model = IResNet(block, layers, **kwargs) 168 | if pretrained: 169 | raise ValueError() 170 | return model 171 | 172 | 173 | def iresnet18(pretrained=False, progress=True, **kwargs): 174 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 175 | progress, **kwargs) 176 | 177 | 178 | def iresnet34(pretrained=False, progress=True, **kwargs): 179 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 180 | progress, **kwargs) 181 | 182 | 183 | def iresnet50(pretrained=False, progress=True, **kwargs): 184 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 185 | progress, **kwargs) 186 | 187 | 188 | def iresnet100(pretrained=False, progress=True, **kwargs): 189 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 190 | progress, **kwargs) 191 | 192 | 193 | def iresnet200(pretrained=False, progress=True, **kwargs): 194 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, 195 | progress, **kwargs) 196 | -------------------------------------------------------------------------------- /boundaries_ours/Eyeglasses_boundary.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/boundaries_ours/Eyeglasses_boundary.npy -------------------------------------------------------------------------------- /boundaries_ours/Heavy_Makeup_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/boundaries_ours/Heavy_Makeup_boundary.npy -------------------------------------------------------------------------------- /boundaries_ours/Smiling_boundary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/boundaries_ours/Smiling_boundary.npy -------------------------------------------------------------------------------- /configs/001.yaml: -------------------------------------------------------------------------------- 1 | # Input data 2 | resolution: 1024 3 | age_min: 20 4 | age_max: 70 5 | use_realimg: True 6 | # Training hyperparameters 7 | batch_size: 1 8 | epochs: 12 9 | iter_per_epoch: 10000 10 | device: 'cuda' 11 | # Optimizer parameters 12 | optimizer: 'ranger' 13 | lr: 0.0001 14 | beta_1: 0.95 15 | beta_2: 0.999 16 | weight_decay: 0 17 | # Learning rate scheduler 18 | step_size: 10 19 | gamma: 0.1 20 | # Tensorboard log options 21 | image_save_iter: 100 22 | log_iter: 10 23 | # Network setting 24 | use_fs_encoder: True 25 | use_fs_encoder_v2: True 26 | fs_stride: 2 27 | pretrained_weight_for_fs: False 28 | enc_resolution: 256 29 | enc_residual: False 30 | truncation_psi: 1 31 | use_noise: True 32 | randomize_noise: False # If generator use a different random noise at each time of generating a image from z 33 | # Loss setting 34 | use_parsing_net: True 35 | multi_layer_idloss: True 36 | real_image_as_image_loss: False 37 | feature_match_loss: False 38 | feature_match_loss_G: False 39 | use_random_noise: True 40 | optimize_on_z: False 41 | multiscale_lpips: True 42 | # Loss weight 43 | w: 44 | l1: 0 45 | l2: 1 46 | lpips: 0.2 47 | id: 0.1 48 | landmark: 0.1 49 | f_recon: 0.01 -------------------------------------------------------------------------------- /data/celeba_hq/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/0.jpg -------------------------------------------------------------------------------- /data/celeba_hq/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/1.jpg -------------------------------------------------------------------------------- /data/celeba_hq/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/10.jpg -------------------------------------------------------------------------------- /data/celeba_hq/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/2.jpg 
-------------------------------------------------------------------------------- /data/celeba_hq/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/3.jpg -------------------------------------------------------------------------------- /data/celeba_hq/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/4.jpg -------------------------------------------------------------------------------- /data/celeba_hq/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/5.jpg -------------------------------------------------------------------------------- /data/celeba_hq/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/6.jpg -------------------------------------------------------------------------------- /data/celeba_hq/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/7.jpg -------------------------------------------------------------------------------- /data/celeba_hq/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/8.jpg -------------------------------------------------------------------------------- /data/celeba_hq/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/celeba_hq/9.jpg -------------------------------------------------------------------------------- /data/stylegan2-generate-images/seeds_pytorch_1.8.1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/stylegan2-generate-images/seeds_pytorch_1.8.1.npy -------------------------------------------------------------------------------- /data/video/FP010363HD03.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/data/video/FP010363HD03.mp4 -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | pip install gdown 3 | mkdir pretrained_models 4 | cd pretrained_models 5 | 6 | # download pretrained encoder 7 | gdown --fuzzy https://drive.google.com/file/d/1RnnBL77j_Can0dY1KOiXHvG224MxjvzC/view?usp=sharing 8 | 9 | # download arcface pretrained model 10 | gdown --fuzzy https://drive.google.com/file/d/1coFTz-Kkgvoc_gRT8JFzqCgeC3lAFWQp/view?usp=sharing 11 | 12 | # download face parsing model from 
https://github.com/zllrunning/face-parsing.PyTorch 13 | gdown --fuzzy https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812 14 | 15 | # download pSp pretrained model from https://github.com/eladrich/pixel2style2pixel.git 16 | cd ../pixel2style2pixel 17 | mkdir pretrained_models 18 | cd pretrained_models 19 | gdown --fuzzy https://drive.google.com/file/d/1bMTNWkh5LArlaWSc_wa8VKyq2V42T2z0/view?usp=sharing 20 | cd .. 21 | cd .. 22 | 23 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: feature_style 2 | channels: 3 | - pytorch 4 | - 1adrianb 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1 9 | - argon2-cffi=20.1.0 10 | - async_generator=1.10 11 | - attrs=20.3.0 12 | - backcall=0.2.0 13 | - blas=1.0 14 | - bleach=3.3.0 15 | - bzip2=1.0.8 16 | - ca-certificates=2021.10.8 17 | - certifi=2021.5.30 18 | - cffi=1.14.5 19 | - cloudpickle=1.6.0 20 | - cudatoolkit=10.2.89 21 | - cycler=0.10.0 22 | - cytoolz=0.11.0 23 | - dask-core=1.1.4 24 | - dataclasses=0.8 25 | - decorator=5.0.6 26 | - defusedxml=0.7.1 27 | - entrypoints=0.3 28 | - face_alignment=1.3.4 29 | - ffmpeg=4.3 30 | - freetype=2.10.4 31 | - gmp=6.2.1 32 | - gnutls=3.6.15 33 | - hdf5=1.10.2 34 | - imageio=2.9.0 35 | - importlib-metadata=3.10.0 36 | - importlib_metadata=3.10.0 37 | - intel-openmp=2020.2 38 | - ipykernel=5.3.4 39 | - ipython=7.16.1 40 | - ipython_genutils=0.2.0 41 | - ipywidgets=7.6.3 42 | - jedi=0.17.0 43 | - jinja2=2.11.3 44 | - jpeg=9b 45 | - jsonschema=3.2.0 46 | - jupyter_client=6.1.12 47 | - jupyter_core=4.7.1 48 | - jupyterlab_pygments=0.1.2 49 | - jupyterlab_widgets=1.0.0 50 | - kiwisolver=1.3.1 51 | - lame=3.100 52 | - lcms2=2.12 53 | - ld_impl_linux-64=2.33.1 54 | - libffi=3.3 55 | - libgcc-ng=9.1.0 56 | - libgfortran=3.0.0 57 | - libgfortran-ng=7.3.0 58 | - libiconv=1.15 59 | - libidn2=2.3.0 60 | - libllvm10=10.0.1 61 | - libpng=1.6.37 62 | - libsodium=1.0.18 63 | - libstdcxx-ng=9.1.0 64 | - libtasn1=4.16.0 65 | - libtiff=4.1.0 66 | - libunistring=0.9.10 67 | - libuv=1.40.0 68 | - llvmlite=0.36.0 69 | - lz4-c=1.9.3 70 | - markupsafe=1.1.1 71 | - matplotlib-base=3.3.4 72 | - mistune=0.8.4 73 | - mkl=2020.2 74 | - mkl-service=2.3.0 75 | - mkl_fft=1.3.0 76 | - mkl_random=1.1.1 77 | - nbclient=0.5.3 78 | - nbconvert=6.0.7 79 | - nbformat=5.1.3 80 | - ncurses=6.2 81 | - nest-asyncio=1.5.1 82 | - nettle=3.7.2 83 | - networkx=2.2 84 | - notebook=6.3.0 85 | - numba=0.53.1 86 | - numpy=1.19.2 87 | - numpy-base=1.19.2 88 | - olefile=0.46 89 | - opencv=3.4.1 90 | - openh264=2.1.0 91 | - openssl=1.1.1l 92 | - packaging=20.9 93 | - pandoc=2.12 94 | - pandocfilters=1.4.3 95 | - parso=0.8.2 96 | - pexpect=4.8.0 97 | - pickleshare=0.7.5 98 | - pillow=8.2.0 99 | - pip=21.0.1 100 | - prometheus_client=0.10.1 101 | - prompt-toolkit=3.0.17 102 | - ptyprocess=0.7.0 103 | - pycparser=2.20 104 | - pygments=2.8.1 105 | - pyparsing=2.4.7 106 | - pyrsistent=0.17.3 107 | - python=3.6.13 108 | - python-dateutil=2.8.1 109 | - python_abi=3.6 110 | - pytorch=1.8.1 111 | - pywavelets=1.1.1 112 | - pyzmq=20.0.0 113 | - readline=8.1 114 | - scikit-image=0.17.2 115 | - scipy=1.5.2 116 | - send2trash=1.5.0 117 | - setuptools=52.0.0 118 | - six=1.15.0 119 | - sqlite=3.35.4 120 | - tbb=2020.3 121 | - terminado=0.9.4 122 | - testpath=0.4.4 123 | - tifffile=2020.10.1 124 | - tk=8.6.10 125 | - toolz=0.11.1 126 | - torchvision=0.9.1 127 | - tornado=6.1 128 | - tqdm=4.59.0 129 | - 
traitlets=4.3.3 130 | - typing_extensions=3.7.4.3 131 | - wcwidth=0.2.5 132 | - webencodings=0.5.1 133 | - wheel=0.36.2 134 | - widgetsnbextension=3.5.1 135 | - xz=5.2.5 136 | - yaml=0.2.5 137 | - zeromq=4.3.4 138 | - zipp=3.4.1 139 | - zlib=1.2.11 140 | - zstd=1.4.9 141 | - pip: 142 | - charset-normalizer==2.0.7 143 | - click==8.0.3 144 | - cython==0.29.23 145 | - dlib==19.22.1 146 | - idna==3.2 147 | - joblib==1.1.0 148 | - ninja==1.10.2.2 149 | - onnxruntime==1.7.0 150 | - opencv-python==4.5.1.48 151 | - protobuf==3.15.8 152 | - pytorch-msssim==0.2.1 153 | - pyyaml==5.4.1 154 | - requests==2.26.0 155 | - scikit-learn==0.22 156 | - sklearn==0.0 157 | - tensorboard-logger==0.1.0 158 | - threadpoolctl==3.0.0 159 | - urllib3==1.26.7 160 | -------------------------------------------------------------------------------- /face_parsing/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | import sys 11 | sys.path.append('..') 12 | from face_parsing.resnet import Resnet18 13 | # from modules.bn import InPlaceABNSync as BatchNorm2d 14 | 15 | 16 | class ConvBNReLU(nn.Module): 17 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 18 | super(ConvBNReLU, self).__init__() 19 | self.conv = nn.Conv2d(in_chan, 20 | out_chan, 21 | kernel_size = ks, 22 | stride = stride, 23 | padding = padding, 24 | bias = False) 25 | self.bn = nn.BatchNorm2d(out_chan) 26 | self.init_weight() 27 | 28 | def forward(self, x): 29 | x = self.conv(x) 30 | x = F.relu(self.bn(x)) 31 | return x 32 | 33 | def init_weight(self): 34 | for ly in self.children(): 35 | if isinstance(ly, nn.Conv2d): 36 | nn.init.kaiming_normal_(ly.weight, a=1) 37 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 38 | 39 | class BiSeNetOutput(nn.Module): 40 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 41 | super(BiSeNetOutput, self).__init__() 42 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 43 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 44 | self.init_weight() 45 | 46 | def forward(self, x): 47 | x = self.conv(x) 48 | x = self.conv_out(x) 49 | return x 50 | 51 | def init_weight(self): 52 | for ly in self.children(): 53 | if isinstance(ly, nn.Conv2d): 54 | nn.init.kaiming_normal_(ly.weight, a=1) 55 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 56 | 57 | def get_params(self): 58 | wd_params, nowd_params = [], [] 59 | for name, module in self.named_modules(): 60 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 61 | wd_params.append(module.weight) 62 | if not module.bias is None: 63 | nowd_params.append(module.bias) 64 | elif isinstance(module, nn.BatchNorm2d): 65 | nowd_params += list(module.parameters()) 66 | return wd_params, nowd_params 67 | 68 | 69 | class AttentionRefinementModule(nn.Module): 70 | def __init__(self, in_chan, out_chan, *args, **kwargs): 71 | super(AttentionRefinementModule, self).__init__() 72 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 73 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 74 | self.bn_atten = nn.BatchNorm2d(out_chan) 75 | self.sigmoid_atten = nn.Sigmoid() 76 | self.init_weight() 77 | 78 | def forward(self, x): 79 | feat = self.conv(x) 80 | atten = F.avg_pool2d(feat, feat.size()[2:]) 81 | atten = 
self.conv_atten(atten) 82 | atten = self.bn_atten(atten) 83 | atten = self.sigmoid_atten(atten) 84 | out = torch.mul(feat, atten) 85 | return out 86 | 87 | def init_weight(self): 88 | for ly in self.children(): 89 | if isinstance(ly, nn.Conv2d): 90 | nn.init.kaiming_normal_(ly.weight, a=1) 91 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 92 | 93 | 94 | class ContextPath(nn.Module): 95 | def __init__(self, *args, **kwargs): 96 | super(ContextPath, self).__init__() 97 | self.resnet = Resnet18() 98 | self.arm16 = AttentionRefinementModule(256, 128) 99 | self.arm32 = AttentionRefinementModule(512, 128) 100 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 101 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 102 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 103 | 104 | self.init_weight() 105 | 106 | def forward(self, x): 107 | H0, W0 = x.size()[2:] 108 | feat8, feat16, feat32 = self.resnet(x) 109 | H8, W8 = feat8.size()[2:] 110 | H16, W16 = feat16.size()[2:] 111 | H32, W32 = feat32.size()[2:] 112 | 113 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 114 | avg = self.conv_avg(avg) 115 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 116 | 117 | feat32_arm = self.arm32(feat32) 118 | feat32_sum = feat32_arm + avg_up 119 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 120 | feat32_up = self.conv_head32(feat32_up) 121 | 122 | feat16_arm = self.arm16(feat16) 123 | feat16_sum = feat16_arm + feat32_up 124 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 125 | feat16_up = self.conv_head16(feat16_up) 126 | 127 | return feat8, feat16_up, feat32_up # x8, x8, x16 128 | 129 | def init_weight(self): 130 | for ly in self.children(): 131 | if isinstance(ly, nn.Conv2d): 132 | nn.init.kaiming_normal_(ly.weight, a=1) 133 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 134 | 135 | def get_params(self): 136 | wd_params, nowd_params = [], [] 137 | for name, module in self.named_modules(): 138 | if isinstance(module, (nn.Linear, nn.Conv2d)): 139 | wd_params.append(module.weight) 140 | if not module.bias is None: 141 | nowd_params.append(module.bias) 142 | elif isinstance(module, nn.BatchNorm2d): 143 | nowd_params += list(module.parameters()) 144 | return wd_params, nowd_params 145 | 146 | 147 | ### This is not used, since I replace this with the resnet feature with the same size 148 | class SpatialPath(nn.Module): 149 | def __init__(self, *args, **kwargs): 150 | super(SpatialPath, self).__init__() 151 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 152 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 153 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 154 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 155 | self.init_weight() 156 | 157 | def forward(self, x): 158 | feat = self.conv1(x) 159 | feat = self.conv2(feat) 160 | feat = self.conv3(feat) 161 | feat = self.conv_out(feat) 162 | return feat 163 | 164 | def init_weight(self): 165 | for ly in self.children(): 166 | if isinstance(ly, nn.Conv2d): 167 | nn.init.kaiming_normal_(ly.weight, a=1) 168 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 169 | 170 | def get_params(self): 171 | wd_params, nowd_params = [], [] 172 | for name, module in self.named_modules(): 173 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 174 | wd_params.append(module.weight) 175 | if not module.bias is None: 176 | nowd_params.append(module.bias) 177 | elif isinstance(module, nn.BatchNorm2d): 
178 | nowd_params += list(module.parameters()) 179 | return wd_params, nowd_params 180 | 181 | 182 | class FeatureFusionModule(nn.Module): 183 | def __init__(self, in_chan, out_chan, *args, **kwargs): 184 | super(FeatureFusionModule, self).__init__() 185 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 186 | self.conv1 = nn.Conv2d(out_chan, 187 | out_chan//4, 188 | kernel_size = 1, 189 | stride = 1, 190 | padding = 0, 191 | bias = False) 192 | self.conv2 = nn.Conv2d(out_chan//4, 193 | out_chan, 194 | kernel_size = 1, 195 | stride = 1, 196 | padding = 0, 197 | bias = False) 198 | self.relu = nn.ReLU(inplace=True) 199 | self.sigmoid = nn.Sigmoid() 200 | self.init_weight() 201 | 202 | def forward(self, fsp, fcp): 203 | fcat = torch.cat([fsp, fcp], dim=1) 204 | feat = self.convblk(fcat) 205 | atten = F.avg_pool2d(feat, feat.size()[2:]) 206 | atten = self.conv1(atten) 207 | atten = self.relu(atten) 208 | atten = self.conv2(atten) 209 | atten = self.sigmoid(atten) 210 | feat_atten = torch.mul(feat, atten) 211 | feat_out = feat_atten + feat 212 | return feat_out 213 | 214 | def init_weight(self): 215 | for ly in self.children(): 216 | if isinstance(ly, nn.Conv2d): 217 | nn.init.kaiming_normal_(ly.weight, a=1) 218 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 219 | 220 | def get_params(self): 221 | wd_params, nowd_params = [], [] 222 | for name, module in self.named_modules(): 223 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 224 | wd_params.append(module.weight) 225 | if not module.bias is None: 226 | nowd_params.append(module.bias) 227 | elif isinstance(module, nn.BatchNorm2d): 228 | nowd_params += list(module.parameters()) 229 | return wd_params, nowd_params 230 | 231 | 232 | class BiSeNet(nn.Module): 233 | def __init__(self, n_classes, *args, **kwargs): 234 | super(BiSeNet, self).__init__() 235 | self.cp = ContextPath() 236 | ## here self.sp is deleted 237 | self.ffm = FeatureFusionModule(256, 256) 238 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 239 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 240 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 241 | self.init_weight() 242 | 243 | def forward(self, x): 244 | H, W = x.size()[2:] 245 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 246 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 247 | feat_fuse = self.ffm(feat_sp, feat_cp8) 248 | 249 | feat_out = self.conv_out(feat_fuse) 250 | feat_out16 = self.conv_out16(feat_cp8) 251 | feat_out32 = self.conv_out32(feat_cp16) 252 | 253 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 254 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 255 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 256 | return feat_out, feat_out16, feat_out32 257 | 258 | def extract_fuse_layer(self, x): 259 | H, W = x.size()[2:] 260 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 261 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 262 | feat_fuse = self.ffm(feat_sp, feat_cp8) 263 | return [feat_fuse] 264 | 265 | def init_weight(self): 266 | for ly in self.children(): 267 | if isinstance(ly, nn.Conv2d): 268 | nn.init.kaiming_normal_(ly.weight, a=1) 269 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 270 | 271 | def get_params(self): 272 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 273 | for name, child 
in self.named_children(): 274 | child_wd_params, child_nowd_params = child.get_params() 275 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 276 | lr_mul_wd_params += child_wd_params 277 | lr_mul_nowd_params += child_nowd_params 278 | else: 279 | wd_params += child_wd_params 280 | nowd_params += child_nowd_params 281 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 282 | 283 | 284 | if __name__ == "__main__": 285 | net = BiSeNet(19) 286 | net.cuda() 287 | net.eval() 288 | in_ten = torch.randn(16, 3, 640, 480).cuda() 289 | out, out16, out32 = net(in_ten) 290 | print(out.shape) 291 | 292 | net.get_params() 293 | -------------------------------------------------------------------------------- /face_parsing/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | import os 10 | 11 | 12 | # from modules.bn import InPlaceABNSync as BatchNorm2d 13 | 14 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | def __init__(self, in_chan, out_chan, stride=1): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = conv3x3(in_chan, out_chan, stride) 27 | self.bn1 = nn.BatchNorm2d(out_chan) 28 | self.conv2 = conv3x3(out_chan, out_chan) 29 | self.bn2 = nn.BatchNorm2d(out_chan) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = None 32 | if in_chan != out_chan or stride != 1: 33 | self.downsample = nn.Sequential( 34 | nn.Conv2d(in_chan, out_chan, 35 | kernel_size=1, stride=stride, bias=False), 36 | nn.BatchNorm2d(out_chan), 37 | ) 38 | 39 | def forward(self, x): 40 | residual = self.conv1(x) 41 | residual = F.relu(self.bn1(residual)) 42 | residual = self.conv2(residual) 43 | residual = self.bn2(residual) 44 | 45 | shortcut = x 46 | if self.downsample is not None: 47 | shortcut = self.downsample(x) 48 | 49 | out = shortcut + residual 50 | out = self.relu(out) 51 | return out 52 | 53 | 54 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 55 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 56 | for i in range(bnum-1): 57 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 58 | return nn.Sequential(*layers) 59 | 60 | 61 | class Resnet18(nn.Module): 62 | def __init__(self): 63 | super(Resnet18, self).__init__() 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 65 | bias=False) 66 | self.bn1 = nn.BatchNorm2d(64) 67 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 68 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 69 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 70 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 71 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 72 | self.init_weight() 73 | 74 | def forward(self, x): 75 | x = self.conv1(x) 76 | x = F.relu(self.bn1(x)) 77 | x = self.maxpool(x) 78 | 79 | x = self.layer1(x) 80 | feat8 = self.layer2(x) # 1/8 81 | feat16 = self.layer3(feat8) # 1/16 82 | feat32 = self.layer4(feat16) # 1/32 83 | return feat8, feat16, feat32 84 | 85 | def init_weight(self): 86 | state_dict = 
modelzoo.load_url(resnet18_url) 87 | self_state_dict = self.state_dict() 88 | for k, v in state_dict.items(): 89 | if 'fc' in k: continue 90 | self_state_dict.update({k: v}) 91 | self.load_state_dict(self_state_dict) 92 | 93 | def get_params(self): 94 | wd_params, nowd_params = [], [] 95 | for name, module in self.named_modules(): 96 | if isinstance(module, (nn.Linear, nn.Conv2d)): 97 | wd_params.append(module.weight) 98 | if not module.bias is None: 99 | nowd_params.append(module.bias) 100 | elif isinstance(module, nn.BatchNorm2d): 101 | nowd_params += list(module.parameters()) 102 | return wd_params, nowd_params 103 | 104 | 105 | if __name__ == "__main__": 106 | net = Resnet18() 107 | x = torch.randn(16, 3, 224, 224) 108 | out = net(x) 109 | print(out[0].size()) 110 | print(out[1].size()) 111 | print(out[2].size()) 112 | net.get_params() 113 | -------------------------------------------------------------------------------- /face_parsing/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from model import BiSeNet 5 | 6 | import torch 7 | 8 | import os 9 | import os.path as osp 10 | import numpy as np 11 | from PIL import Image 12 | import torchvision.transforms as transforms 13 | import cv2 14 | 15 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'): 16 | # Colors for all 20 parts 17 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], 18 | [255, 0, 85], [255, 0, 170], 19 | [0, 255, 0], [85, 255, 0], [170, 255, 0], 20 | [0, 255, 85], [0, 255, 170], 21 | [0, 0, 255], [85, 0, 255], [170, 0, 255], 22 | [0, 85, 255], [0, 170, 255], 23 | [255, 255, 0], [255, 255, 85], [255, 255, 170], 24 | [255, 0, 255], [255, 85, 255], [255, 170, 255], 25 | [0, 255, 255], [85, 255, 255], [170, 255, 255]] 26 | 27 | im = np.array(im) 28 | vis_im = im.copy().astype(np.uint8) 29 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) 30 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) 31 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 32 | 33 | num_of_class = np.max(vis_parsing_anno) 34 | 35 | for pi in range(1, num_of_class + 1): 36 | index = np.where(vis_parsing_anno == pi) 37 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] 38 | 39 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 40 | # print(vis_parsing_anno_color.shape, vis_im.shape) 41 | vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) 42 | 43 | # Save result or not 44 | if save_im: 45 | cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno) 46 | cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) 47 | 48 | # return vis_im 49 | 50 | def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): 51 | 52 | if not os.path.exists(respth): 53 | os.makedirs(respth) 54 | 55 | n_classes = 19 56 | net = BiSeNet(n_classes=n_classes) 57 | net.cuda() 58 | save_pth = osp.join('', cp) 59 | net.load_state_dict(torch.load(save_pth)) 60 | net.eval() 61 | 62 | to_tensor = transforms.Compose([ 63 | transforms.ToTensor(), 64 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 65 | ]) 66 | with torch.no_grad(): 67 | for image_path in os.listdir(dspth): 68 | img = Image.open(osp.join(dspth, image_path)) 69 | image = img.resize((512, 512), Image.BILINEAR) 70 | img 
= to_tensor(image) 71 | img = torch.unsqueeze(img, 0) 72 | img = img.cuda() 73 | out = net(img) 74 | out = out[0] 75 | parsing = out.squeeze(0).cpu().numpy().argmax(0) 76 | # print(parsing) 77 | print(np.unique(parsing)) 78 | 79 | vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path)) 80 | 81 | 82 | 83 | 84 | if __name__ == "__main__": 85 | evaluate(dspth='../face-parsing/imgs/', cp='./pretrained_models/79999_iter.pth') 86 | 87 | 88 | -------------------------------------------------------------------------------- /generate_imgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.data as data 8 | import yaml 9 | 10 | from PIL import Image 11 | from torchvision import transforms, utils 12 | from tensorboard_logger import Logger 13 | from tqdm import tqdm 14 | from utils.functions import * 15 | 16 | import sys 17 | sys.path.append('pixel2style2pixel/') 18 | from pixel2style2pixel.models.stylegan2.model import Generator, get_keys 19 | 20 | torch.backends.cudnn.enabled = True 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = True 23 | torch.autograd.set_detect_anomaly(True) 24 | Image.MAX_IMAGE_PIXELS = None 25 | device = torch.device('cuda') 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--config', type=str, default='002', help='Path to the config file.') 29 | parser.add_argument('--dataset_path', type=str, default='./data/stylegan2-generate-images/', help='dataset path') 30 | parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='pretrained stylegan model') 31 | opts = parser.parse_args() 32 | 33 | 34 | StyleGAN = Generator(1024, 512, 8) 35 | state_dict = torch.load(opts.stylegan_model_path, map_location='cpu') 36 | StyleGAN.load_state_dict(get_keys(state_dict, 'decoder'), strict=True) 37 | StyleGAN.to(device) 38 | 39 | #seeds = np.array([torch.random.seed() for i in range(100000)]) 40 | seeds = np.load(opts.dataset_path + 'seeds_pytorch_1.8.1.npy') 41 | 42 | with torch.no_grad(): 43 | os.makedirs(opts.dataset_path + 'ims/', exist_ok=True) 44 | 45 | for i, seed in enumerate(tqdm(seeds)): 46 | 47 | torch.manual_seed(seed) 48 | z = torch.randn(1, 512).to(device) 49 | n = StyleGAN.make_noise() 50 | w = StyleGAN.get_latent(z) 51 | x, _ = StyleGAN([w], input_is_latent=True, noise=n) 52 | utils.save_image(clip_img(x), opts.dataset_path + 'ims/%06d.jpg'%i) 53 | -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/images/teaser.png -------------------------------------------------------------------------------- /lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | # from torch.autograd import Variable 9 | 10 | from lpips.trainer import * 11 | from lpips.lpips import * 12 | 13 | # class PerceptualLoss(torch.nn.Module): 14 | # def __init__(self, model='lpips', net='alex', spatial=False, 
use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | # super(PerceptualLoss, self).__init__() 17 | # print('Setting up Perceptual loss...') 18 | # self.use_gpu = use_gpu 19 | # self.spatial = spatial 20 | # self.gpu_ids = gpu_ids 21 | # self.model = dist_model.DistModel() 22 | # self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 23 | # print('...[%s] initialized'%self.model.name()) 24 | # print('...Done') 25 | 26 | # def forward(self, pred, target, normalize=False): 27 | # """ 28 | # Pred and target are Variables. 29 | # If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | # If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | # Inputs pred and target are Nx3xHxW 33 | # Output pytorch Variable N long 34 | # """ 35 | 36 | # if normalize: 37 | # target = 2 * target - 1 38 | # pred = 2 * pred - 1 39 | 40 | # return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2+1e-8,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | from skimage.measure import compare_ssim 54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 55 | 56 | def rgb2lab(in_img,mean_cent=False): 57 | from skimage import color 58 | img_lab = color.rgb2lab(in_img) 59 | if(mean_cent): 60 | img_lab[:,:,0] = img_lab[:,:,0]-50 61 | return img_lab 62 | 63 | def tensor2np(tensor_obj): 64 | # change dimension of a tensor object into a numpy array 65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 66 | 67 | def np2tensor(np_obj): 68 | # change dimenion of np array into tensor array 69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 70 | 71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 72 | # image tensor to lab tensor 73 | from skimage import color 74 | 75 | img = tensor2im(image_tensor) 76 | img_lab = color.rgb2lab(img) 77 | if(mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | if(to_norm and not mc_only): 80 | img_lab[:,:,0] = img_lab[:,:,0]-50 81 | img_lab = img_lab/100. 82 | 83 | return np2tensor(img_lab) 84 | 85 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 86 | from skimage import color 87 | import warnings 88 | warnings.filterwarnings("ignore") 89 | 90 | lab = tensor2np(lab_tensor)*100. 91 | lab[:,:,0] = lab[:,:,0]+50 92 | 93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 94 | if(return_inbnd): 95 | # convert back to lab, see if we match 96 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 97 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 
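        # keep only pixels whose Lab values survive the Lab -> RGB -> Lab round trip (atol=2 per channel); the per-channel flags are collapsed into a single-channel mask just below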
98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 99 | return (im2tensor(rgb_back),mask) 100 | else: 101 | return im2tensor(rgb_back) 102 | 103 | def load_image(path): 104 | if(path[-3:] == 'dng'): 105 | import rawpy 106 | with rawpy.imread(path) as raw: 107 | img = raw.postprocess() 108 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg'): 109 | import cv2 110 | return cv2.imread(path)[:,:,::-1] 111 | else: 112 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 113 | 114 | return img 115 | 116 | def rgb2lab(input): 117 | from skimage import color 118 | return color.rgb2lab(input / 255.) 119 | 120 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 121 | image_numpy = image_tensor[0].cpu().float().numpy() 122 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 123 | return image_numpy.astype(imtype) 124 | 125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 126 | return torch.Tensor((image / factor - cent) 127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 128 | 129 | def tensor2vec(vector_tensor): 130 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 131 | 132 | 133 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 134 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 135 | image_numpy = image_tensor[0].cpu().float().numpy() 136 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 137 | return image_numpy.astype(imtype) 138 | 139 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 140 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 141 | return torch.Tensor((image / factor - cent) 142 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 143 | 144 | 145 | 146 | def voc_ap(rec, prec, use_07_metric=False): 147 | """ ap = voc_ap(rec, prec, [use_07_metric]) 148 | Compute VOC AP given precision and recall. 149 | If use_07_metric is true, uses the 150 | VOC 07 11 point method (default:False). 151 | """ 152 | if use_07_metric: 153 | # 11 point metric 154 | ap = 0. 155 | for t in np.arange(0., 1.1, 0.1): 156 | if np.sum(rec >= t) == 0: 157 | p = 0 158 | else: 159 | p = np.max(prec[rec >= t]) 160 | ap = ap + p / 11. 161 | else: 162 | # correct AP calculation 163 | # first append sentinel values at the end 164 | mrec = np.concatenate(([0.], rec, [1.])) 165 | mpre = np.concatenate(([0.], prec, [0.])) 166 | 167 | # compute the precision envelope 168 | for i in range(mpre.size - 1, 0, -1): 169 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 170 | 171 | # to calculate area under PR curve, look for points 172 | # where X axis (recall) changes value 173 | i = np.where(mrec[1:] != mrec[:-1])[0] 174 | 175 | # and sum (\Delta recall) * prec 176 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 177 | return ap 178 | 179 | -------------------------------------------------------------------------------- /lpips/lpips.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from . 
import pretrained_networks as pn 10 | import torch.nn 11 | 12 | import lpips 13 | 14 | def spatial_average(in_tens, keepdim=True): 15 | return in_tens.mean([2,3],keepdim=keepdim) 16 | 17 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W 18 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 19 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) 20 | 21 | # Learned perceptual metric 22 | class LPIPS(nn.Module): 23 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, 24 | pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True): 25 | # lpips - [True] means with linear calibration on top of base network 26 | # pretrained - [True] means load linear weights 27 | 28 | super(LPIPS, self).__init__() 29 | if(verbose): 30 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'% 31 | ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) 32 | 33 | self.pnet_type = net 34 | self.pnet_tune = pnet_tune 35 | self.pnet_rand = pnet_rand 36 | self.spatial = spatial 37 | self.lpips = lpips # false means baseline of just averaging all layers 38 | self.version = version 39 | self.scaling_layer = ScalingLayer() 40 | 41 | if(self.pnet_type in ['vgg','vgg16']): 42 | net_type = pn.vgg16 43 | self.chns = [64,128,256,512,512] 44 | elif(self.pnet_type=='alex'): 45 | net_type = pn.alexnet 46 | self.chns = [64,192,384,256,256] 47 | elif(self.pnet_type=='squeeze'): 48 | net_type = pn.squeezenet 49 | self.chns = [64,128,256,384,384,512,512] 50 | self.L = len(self.chns) 51 | 52 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 53 | 54 | if(lpips): 55 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 60 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 61 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 62 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 63 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 64 | self.lins+=[self.lin5,self.lin6] 65 | self.lins = nn.ModuleList(self.lins) 66 | 67 | if(pretrained): 68 | if(model_path is None): 69 | import inspect 70 | import os 71 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) 72 | 73 | if(verbose): 74 | print('Loading model from: %s'%model_path) 75 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) 76 | 77 | if(eval_mode): 78 | self.eval() 79 | 80 | def forward(self, in0, in1, retPerLayer=False, normalize=False): 81 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 82 | in0 = 2 * in0 - 1 83 | in1 = 2 * in1 - 1 84 | 85 | # v0.0 - original release had a bug, where input was not scaled 86 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 87 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 88 | feats0, feats1, diffs = {}, {}, {} 89 | 90 | for kk in range(self.L): 91 | feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk]) 92 | diffs[kk] = (feats0[kk]-feats1[kk])**2 
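        # diffs holds squared differences of channel-normalised features; below, the 1x1 'lin' layers apply the learned LPIPS weights to each map, which is then either upsampled to the input size (spatial=True) or spatially averaged to one scalar per layer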
93 | 94 | if(self.lpips): 95 | if(self.spatial): 96 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 97 | else: 98 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 99 | else: 100 | if(self.spatial): 101 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 102 | else: 103 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 104 | 105 | val = res[0] 106 | for l in range(1,self.L): 107 | val += res[l] 108 | 109 | # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 110 | # b = torch.max(self.lins[kk](feats0[kk]**2)) 111 | # for kk in range(self.L): 112 | # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 113 | # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) 114 | # a = a/self.L 115 | # from IPython import embed 116 | # embed() 117 | # return 10*torch.log10(b/a) 118 | 119 | if(retPerLayer): 120 | return (val, res) 121 | else: 122 | return val 123 | 124 | 125 | class ScalingLayer(nn.Module): 126 | def __init__(self): 127 | super(ScalingLayer, self).__init__() 128 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 129 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 130 | 131 | def forward(self, inp): 132 | return (inp - self.shift) / self.scale 133 | 134 | 135 | class NetLinLayer(nn.Module): 136 | ''' A single linear layer which does a 1x1 conv ''' 137 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 138 | super(NetLinLayer, self).__init__() 139 | 140 | layers = [nn.Dropout(),] if(use_dropout) else [] 141 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 142 | self.model = nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | return self.model(x) 146 | 147 | class Dist2LogitLayer(nn.Module): 148 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 149 | def __init__(self, chn_mid=32, use_sigmoid=True): 150 | super(Dist2LogitLayer, self).__init__() 151 | 152 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 153 | layers += [nn.LeakyReLU(0.2,True),] 154 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 155 | layers += [nn.LeakyReLU(0.2,True),] 156 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 157 | if(use_sigmoid): 158 | layers += [nn.Sigmoid(),] 159 | self.model = nn.Sequential(*layers) 160 | 161 | def forward(self,d0,d1,eps=0.1): 162 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 163 | 164 | class BCERankingLoss(nn.Module): 165 | def __init__(self, chn_mid=32): 166 | super(BCERankingLoss, self).__init__() 167 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 168 | # self.parameters = list(self.net.parameters()) 169 | self.loss = torch.nn.BCELoss() 170 | 171 | def forward(self, d0, d1, judge): 172 | per = (judge+1.)/2. 
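        # judge arrives in [-1, 1]; per rescales it to a [0, 1] preference probability, the target for the BCE loss on the Dist2LogitLayer output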
173 | self.logit = self.net.forward(d0,d1) 174 | return self.loss(self.logit, per) 175 | 176 | # L2, DSSIM metrics 177 | class FakeNet(nn.Module): 178 | def __init__(self, use_gpu=True, colorspace='Lab'): 179 | super(FakeNet, self).__init__() 180 | self.use_gpu = use_gpu 181 | self.colorspace = colorspace 182 | 183 | class L2(FakeNet): 184 | def forward(self, in0, in1, retPerLayer=None): 185 | assert(in0.size()[0]==1) # currently only supports batchSize 1 186 | 187 | if(self.colorspace=='RGB'): 188 | (N,C,X,Y) = in0.size() 189 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 190 | return value 191 | elif(self.colorspace=='Lab'): 192 | value = lpips.l2(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)), 193 | lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 194 | ret_var = Variable( torch.Tensor((value,) ) ) 195 | if(self.use_gpu): 196 | ret_var = ret_var.cuda() 197 | return ret_var 198 | 199 | class DSSIM(FakeNet): 200 | 201 | def forward(self, in0, in1, retPerLayer=None): 202 | assert(in0.size()[0]==1) # currently only supports batchSize 1 203 | 204 | if(self.colorspace=='RGB'): 205 | value = lpips.dssim(1.*lpips.tensor2im(in0.data), 1.*lpips.tensor2im(in1.data), range=255.).astype('float') 206 | elif(self.colorspace=='Lab'): 207 | value = lpips.dssim(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)), 208 | lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 209 | ret_var = Variable( torch.Tensor((value,) ) ) 210 | if(self.use_gpu): 211 | ret_var = ret_var.cuda() 212 | return ret_var 213 | 214 | def print_network(net): 215 | num_params = 0 216 | for param in net.parameters(): 217 | num_params += param.numel() 218 | print('Network',net) 219 | print('Total number of parameters: %d' % num_params) 220 | -------------------------------------------------------------------------------- /lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | import os 6 | 7 | class squeezenet(torch.nn.Module): 8 | def __init__(self, requires_grad=False, pretrained=True): 9 | super(squeezenet, self).__init__() 10 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 11 | self.slice1 = torch.nn.Sequential() 12 | self.slice2 = torch.nn.Sequential() 13 | self.slice3 = torch.nn.Sequential() 14 | self.slice4 = torch.nn.Sequential() 15 | self.slice5 = torch.nn.Sequential() 16 | self.slice6 = torch.nn.Sequential() 17 | self.slice7 = torch.nn.Sequential() 18 | self.N_slices = 7 19 | for x in range(2): 20 | self.slice1.add_module(str(x), pretrained_features[x]) 21 | for x in range(2,5): 22 | self.slice2.add_module(str(x), pretrained_features[x]) 23 | for x in range(5, 8): 24 | self.slice3.add_module(str(x), pretrained_features[x]) 25 | for x in range(8, 10): 26 | self.slice4.add_module(str(x), pretrained_features[x]) 27 | for x in range(10, 11): 28 | self.slice5.add_module(str(x), pretrained_features[x]) 29 | for x in range(11, 12): 30 | self.slice6.add_module(str(x), pretrained_features[x]) 31 | for x in range(12, 13): 32 | self.slice7.add_module(str(x), pretrained_features[x]) 33 | if not requires_grad: 34 | for param in self.parameters(): 35 | param.requires_grad = False 36 | 37 | def forward(self, X): 38 | h = self.slice1(X) 39 | h_relu1 = h 40 | h = self.slice2(h) 41 | 
h_relu2 = h 42 | h = self.slice3(h) 43 | h_relu3 = h 44 | h = self.slice4(h) 45 | h_relu4 = h 46 | h = self.slice5(h) 47 | h_relu5 = h 48 | h = self.slice6(h) 49 | h_relu6 = h 50 | h = self.slice7(h) 51 | h_relu7 = h 52 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 53 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 54 | 55 | return out 56 | 57 | 58 | class alexnet(torch.nn.Module): 59 | def __init__(self, requires_grad=False, pretrained=True): 60 | super(alexnet, self).__init__() 61 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 62 | self.slice1 = torch.nn.Sequential() 63 | self.slice2 = torch.nn.Sequential() 64 | self.slice3 = torch.nn.Sequential() 65 | self.slice4 = torch.nn.Sequential() 66 | self.slice5 = torch.nn.Sequential() 67 | self.N_slices = 5 68 | for x in range(2): 69 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 70 | for x in range(2, 5): 71 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(5, 8): 73 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(8, 10): 75 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 76 | for x in range(10, 12): 77 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 78 | if not requires_grad: 79 | for param in self.parameters(): 80 | param.requires_grad = False 81 | 82 | def forward(self, X): 83 | h = self.slice1(X) 84 | h_relu1 = h 85 | h = self.slice2(h) 86 | h_relu2 = h 87 | h = self.slice3(h) 88 | h_relu3 = h 89 | h = self.slice4(h) 90 | h_relu4 = h 91 | h = self.slice5(h) 92 | h_relu5 = h 93 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 94 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 95 | 96 | return out 97 | 98 | class vgg16(torch.nn.Module): 99 | def __init__(self, requires_grad=False, pretrained=True): 100 | super(vgg16, self).__init__() 101 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 102 | self.slice1 = torch.nn.Sequential() 103 | self.slice2 = torch.nn.Sequential() 104 | self.slice3 = torch.nn.Sequential() 105 | self.slice4 = torch.nn.Sequential() 106 | self.slice5 = torch.nn.Sequential() 107 | self.N_slices = 5 108 | for x in range(4): 109 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(4, 9): 111 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(9, 16): 113 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(16, 23): 115 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 116 | for x in range(23, 30): 117 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 118 | if not requires_grad: 119 | for param in self.parameters(): 120 | param.requires_grad = False 121 | 122 | def forward(self, X): 123 | h = self.slice1(X) 124 | h_relu1_2 = h 125 | h = self.slice2(h) 126 | h_relu2_2 = h 127 | h = self.slice3(h) 128 | h_relu3_3 = h 129 | h = self.slice4(h) 130 | h_relu4_3 = h 131 | h = self.slice5(h) 132 | h_relu5_3 = h 133 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 134 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 135 | 136 | return out 137 | 138 | 139 | 140 | class resnet(torch.nn.Module): 141 | def __init__(self, requires_grad=False, pretrained=True, num=18): 142 | super(resnet, self).__init__() 143 | if(num==18): 144 | self.net = 
tv.resnet18(pretrained=pretrained) 145 | elif(num==34): 146 | self.net = tv.resnet34(pretrained=pretrained) 147 | elif(num==50): 148 | self.net = tv.resnet50(pretrained=pretrained) 149 | elif(num==101): 150 | self.net = tv.resnet101(pretrained=pretrained) 151 | elif(num==152): 152 | self.net = tv.resnet152(pretrained=pretrained) 153 | self.N_slices = 5 154 | 155 | self.conv1 = self.net.conv1 156 | self.bn1 = self.net.bn1 157 | self.relu = self.net.relu 158 | self.maxpool = self.net.maxpool 159 | self.layer1 = self.net.layer1 160 | self.layer2 = self.net.layer2 161 | self.layer3 = self.net.layer3 162 | self.layer4 = self.net.layer4 163 | 164 | def forward(self, X): 165 | h = self.conv1(X) 166 | h = self.bn1(h) 167 | h = self.relu(h) 168 | h_relu1 = h 169 | h = self.maxpool(h) 170 | h = self.layer1(h) 171 | h_conv2 = h 172 | h = self.layer2(h) 173 | h_conv3 = h 174 | h = self.layer3(h) 175 | h_conv4 = h 176 | h = self.layer4(h) 177 | h_conv5 = h 178 | 179 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 180 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 181 | 182 | return out 183 | -------------------------------------------------------------------------------- /lpips/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from collections import OrderedDict 8 | from torch.autograd import Variable 9 | from scipy.ndimage import zoom 10 | from tqdm import tqdm 11 | import lpips 12 | import os 13 | 14 | 15 | class Trainer(): 16 | def name(self): 17 | return self.model_name 18 | 19 | def initialize(self, model='lpips', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 20 | use_gpu=True, printNet=False, spatial=False, 21 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 22 | ''' 23 | INPUTS 24 | model - ['lpips'] for linearly calibrated network 25 | ['baseline'] for off-the-shelf network 26 | ['L2'] for L2 distance in Lab colorspace 27 | ['SSIM'] for ssim in RGB colorspace 28 | net - ['squeeze','alex','vgg'] 29 | model_path - if None, will look in weights/[NET_NAME].pth 30 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 31 | use_gpu - bool - whether or not to use a GPU 32 | printNet - bool - whether or not to print network architecture out 33 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 34 | is_train - bool - [True] for training mode 35 | lr - float - initial learning rate 36 | beta1 - float - initial momentum term for adam 37 | version - 0.1 for latest, 0.0 was original (with a bug) 38 | gpu_ids - int array - [0] by default, gpus to use 39 | ''' 40 | self.use_gpu = use_gpu 41 | self.gpu_ids = gpu_ids 42 | self.model = model 43 | self.net = net 44 | self.is_train = is_train 45 | self.spatial = spatial 46 | self.model_name = '%s [%s]'%(model,net) 47 | 48 | if(self.model == 'lpips'): # pretrained net + linear layer 49 | self.net = lpips.LPIPS(pretrained=not is_train, net=net, version=version, lpips=True, spatial=spatial, 50 | pnet_rand=pnet_rand, pnet_tune=pnet_tune, 51 | use_dropout=True, model_path=model_path, eval_mode=False) 52 | elif(self.model=='baseline'): # pretrained network 53 | self.net = lpips.LPIPS(pnet_rand=pnet_rand, net=net, lpips=False) 54 | elif(self.model in ['L2','l2']): 55 | self.net = lpips.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for 
testing 56 | self.model_name = 'L2' 57 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 58 | self.net = lpips.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 59 | self.model_name = 'SSIM' 60 | else: 61 | raise ValueError("Model [%s] not recognized." % self.model) 62 | 63 | self.parameters = list(self.net.parameters()) 64 | 65 | if self.is_train: # training mode 66 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 67 | self.rankLoss = lpips.BCERankingLoss() 68 | self.parameters += list(self.rankLoss.net.parameters()) 69 | self.lr = lr 70 | self.old_lr = lr 71 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 72 | else: # test mode 73 | self.net.eval() 74 | 75 | if(use_gpu): 76 | self.net.to(gpu_ids[0]) 77 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 78 | if(self.is_train): 79 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 80 | 81 | if(printNet): 82 | print('---------- Networks initialized -------------') 83 | networks.print_network(self.net) 84 | print('-----------------------------------------------') 85 | 86 | def forward(self, in0, in1, retPerLayer=False): 87 | ''' Function computes the distance between image patches in0 and in1 88 | INPUTS 89 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 90 | OUTPUT 91 | computed distances between in0 and in1 92 | ''' 93 | 94 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 95 | 96 | # ***** TRAINING FUNCTIONS ***** 97 | def optimize_parameters(self): 98 | self.forward_train() 99 | self.optimizer_net.zero_grad() 100 | self.backward_train() 101 | self.optimizer_net.step() 102 | self.clamp_weights() 103 | 104 | def clamp_weights(self): 105 | for module in self.net.modules(): 106 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 107 | module.weight.data = torch.clamp(module.weight.data,min=0) 108 | 109 | def set_input(self, data): 110 | self.input_ref = data['ref'] 111 | self.input_p0 = data['p0'] 112 | self.input_p1 = data['p1'] 113 | self.input_judge = data['judge'] 114 | 115 | if(self.use_gpu): 116 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 117 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 118 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 119 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 120 | 121 | self.var_ref = Variable(self.input_ref,requires_grad=True) 122 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 123 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 124 | 125 | def forward_train(self): # run forward pass 126 | self.d0 = self.forward(self.var_ref, self.var_p0) 127 | self.d1 = self.forward(self.var_ref, self.var_p1) 128 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 129 | 130 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 131 | 132 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
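        # input_judge is a fraction in [0, 1] (see the score_2afc_dataset docstring below); BCERankingLoss expects judgements in [-1, 1], hence the *2. - 1. rescaling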
133 | 134 | return self.loss_total 135 | 136 | def backward_train(self): 137 | torch.mean(self.loss_total).backward() 138 | 139 | def compute_accuracy(self,d0,d1,judge): 140 | ''' d0, d1 are Variables, judge is a Tensor ''' 141 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 197 | self.old_lr = lr 198 | 199 | 200 | def get_image_paths(self): 201 | return self.image_paths 202 | 203 | def save_done(self, flag=False): 204 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 205 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 206 | 207 | 208 | def score_2afc_dataset(data_loader, func, name=''): 209 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 210 | distance function 'func' in dataset 'data_loader' 211 | INPUTS 212 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 213 | func - callable distance function - calling d=func(in0,in1) should take 2 214 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 215 | OUTPUTS 216 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 217 | [1] - dictionary with following elements 218 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 219 | gts - N array in [0,1], preferred patch selected by human evaluators 220 | (closer to "0" for left patch p0, "1" for right patch p1, 221 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 222 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 223 | CONSTS 224 | N - number of test triplets in data_loader 225 | ''' 226 | 227 | d0s = [] 228 | d1s = [] 229 | gts = [] 230 | 231 | for data in tqdm(data_loader.load_data(), desc=name): 232 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 233 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 234 | gts+=data['judge'].cpu().numpy().flatten().tolist() 235 | 236 | d0s = np.array(d0s) 237 | d1s = np.array(d1s) 238 | gts = np.array(gts) 239 | scores = (d0s 1: 83 | kernel = kernel * (upsample_factor ** 2) 84 | 85 | self.register_buffer('kernel', kernel) 86 | 87 | self.pad = pad 88 | 89 | def forward(self, input): 90 | out = upfirdn2d(input, self.kernel, pad=self.pad) 91 | 92 | return out 93 | 94 | 95 | class EqualConv2d(nn.Module): 96 | def __init__( 97 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 98 | ): 99 | super().__init__() 100 | 101 | self.weight = nn.Parameter( 102 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 103 | ) 104 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 105 | 106 | self.stride = stride 107 | self.padding = padding 108 | 109 | if bias: 110 | self.bias = nn.Parameter(torch.zeros(out_channel)) 111 | 112 | else: 113 | self.bias = None 114 | 115 | def forward(self, input): 116 | out = F.conv2d( 117 | input, 118 | self.weight * self.scale, 119 | bias=self.bias, 120 | stride=self.stride, 121 | padding=self.padding, 122 | ) 123 | 124 | return out 125 | 126 | def __repr__(self): 127 | return ( 128 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 129 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 130 | ) 131 | 132 | 133 | class EqualLinear(nn.Module): 134 | def __init__( 135 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 136 | ): 137 | super().__init__() 138 | 139 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 140 | 141 | if bias: 
142 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 143 | 144 | else: 145 | self.bias = None 146 | 147 | self.activation = activation 148 | 149 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 150 | self.lr_mul = lr_mul 151 | 152 | def forward(self, input): 153 | if self.activation: 154 | out = F.linear(input, self.weight * self.scale) 155 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 156 | 157 | else: 158 | out = F.linear( 159 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 160 | ) 161 | 162 | return out 163 | 164 | def __repr__(self): 165 | return ( 166 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 167 | ) 168 | 169 | 170 | class ScaledLeakyReLU(nn.Module): 171 | def __init__(self, negative_slope=0.2): 172 | super().__init__() 173 | 174 | self.negative_slope = negative_slope 175 | 176 | def forward(self, input): 177 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 178 | 179 | return out * math.sqrt(2) 180 | 181 | 182 | class ModulatedConv2d(nn.Module): 183 | def __init__( 184 | self, 185 | in_channel, 186 | out_channel, 187 | kernel_size, 188 | style_dim, 189 | demodulate=True, 190 | upsample=False, 191 | downsample=False, 192 | blur_kernel=[1, 3, 3, 1], 193 | ): 194 | super().__init__() 195 | 196 | self.eps = 1e-8 197 | self.kernel_size = kernel_size 198 | self.in_channel = in_channel 199 | self.out_channel = out_channel 200 | self.upsample = upsample 201 | self.downsample = downsample 202 | 203 | if upsample: 204 | factor = 2 205 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 206 | pad0 = (p + 1) // 2 + factor - 1 207 | pad1 = p // 2 + 1 208 | 209 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 210 | 211 | if downsample: 212 | factor = 2 213 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 214 | pad0 = (p + 1) // 2 215 | pad1 = p // 2 216 | 217 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 218 | 219 | fan_in = in_channel * kernel_size ** 2 220 | self.scale = 1 / math.sqrt(fan_in) 221 | self.padding = kernel_size // 2 222 | 223 | self.weight = nn.Parameter( 224 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 225 | ) 226 | 227 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 228 | 229 | self.demodulate = demodulate 230 | 231 | def __repr__(self): 232 | return ( 233 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' 234 | f'upsample={self.upsample}, downsample={self.downsample})' 235 | ) 236 | 237 | def forward(self, input, style): 238 | batch, in_channel, height, width = input.shape 239 | 240 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 241 | weight = self.scale * self.weight * style 242 | 243 | if self.demodulate: 244 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 245 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 246 | 247 | weight = weight.view( 248 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 249 | ) 250 | 251 | if self.upsample: 252 | input = input.view(1, batch * in_channel, height, width) 253 | weight = weight.view( 254 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 255 | ) 256 | weight = weight.transpose(1, 2).reshape( 257 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 258 | ) 259 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 260 | _, _, height, width = out.shape 261 | out = out.view(batch, self.out_channel, 
height, width) 262 | out = self.blur(out) 263 | 264 | elif self.downsample: 265 | input = self.blur(input) 266 | _, _, height, width = input.shape 267 | input = input.view(1, batch * in_channel, height, width) 268 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 269 | _, _, height, width = out.shape 270 | out = out.view(batch, self.out_channel, height, width) 271 | 272 | else: 273 | input = input.view(1, batch * in_channel, height, width) 274 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 275 | _, _, height, width = out.shape 276 | out = out.view(batch, self.out_channel, height, width) 277 | 278 | return out 279 | 280 | 281 | class NoiseInjection(nn.Module): 282 | def __init__(self): 283 | super().__init__() 284 | 285 | self.weight = nn.Parameter(torch.zeros(1)) 286 | 287 | def forward(self, image, noise=None): 288 | if noise is None: 289 | batch, _, height, width = image.shape 290 | noise = image.new_empty(batch, 1, height, width).normal_() 291 | 292 | return image + self.weight * noise 293 | 294 | 295 | class ConstantInput(nn.Module): 296 | def __init__(self, channel, size=4): 297 | super().__init__() 298 | 299 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 300 | 301 | def forward(self, input): 302 | batch = input.shape[0] 303 | out = self.input.repeat(batch, 1, 1, 1) 304 | 305 | return out 306 | 307 | 308 | class StyledConv(nn.Module): 309 | def __init__( 310 | self, 311 | in_channel, 312 | out_channel, 313 | kernel_size, 314 | style_dim, 315 | upsample=False, 316 | blur_kernel=[1, 3, 3, 1], 317 | demodulate=True, 318 | ): 319 | super().__init__() 320 | 321 | self.conv = ModulatedConv2d( 322 | in_channel, 323 | out_channel, 324 | kernel_size, 325 | style_dim, 326 | upsample=upsample, 327 | blur_kernel=blur_kernel, 328 | demodulate=demodulate, 329 | ) 330 | 331 | self.noise = NoiseInjection() 332 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 333 | # self.activate = ScaledLeakyReLU(0.2) 334 | self.activate = FusedLeakyReLU(out_channel) 335 | 336 | def forward(self, input, style, noise=None): 337 | out = self.conv(input, style) 338 | out = self.noise(out, noise=noise) 339 | # out = out + self.bias 340 | out = self.activate(out) 341 | 342 | return out 343 | 344 | 345 | class ToRGB(nn.Module): 346 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 347 | super().__init__() 348 | 349 | if upsample: 350 | self.upsample = Upsample(blur_kernel) 351 | 352 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 353 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 354 | 355 | def forward(self, input, style, skip=None): 356 | out = self.conv(input, style) 357 | out = out + self.bias 358 | 359 | if skip is not None: 360 | skip = self.upsample(skip) 361 | 362 | out = out + skip 363 | 364 | return out 365 | 366 | 367 | class Generator(nn.Module): 368 | def __init__( 369 | self, 370 | size, 371 | style_dim, 372 | n_mlp, 373 | channel_multiplier=2, 374 | blur_kernel=[1, 3, 3, 1], 375 | lr_mlp=0.01, 376 | ): 377 | super().__init__() 378 | 379 | self.size = size 380 | 381 | self.style_dim = style_dim 382 | 383 | layers = [PixelNorm()] 384 | 385 | for i in range(n_mlp): 386 | layers.append( 387 | EqualLinear( 388 | style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' 389 | ) 390 | ) 391 | 392 | self.style = nn.Sequential(*layers) 393 | 394 | self.channels = { 395 | 4: 512, 396 | 8: 512, 397 | 16: 512, 398 | 32: 512, 399 | 64: 256 * channel_multiplier, 400 | 128: 128 * 
channel_multiplier, 401 | 256: 64 * channel_multiplier, 402 | 512: 32 * channel_multiplier, 403 | 1024: 16 * channel_multiplier, 404 | } 405 | 406 | self.input = ConstantInput(self.channels[4]) 407 | self.conv1 = StyledConv( 408 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 409 | ) 410 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 411 | 412 | self.log_size = int(math.log(size, 2)) 413 | self.num_layers = (self.log_size - 2) * 2 + 1 414 | 415 | self.convs = nn.ModuleList() 416 | self.upsamples = nn.ModuleList() 417 | self.to_rgbs = nn.ModuleList() 418 | self.noises = nn.Module() 419 | 420 | in_channel = self.channels[4] 421 | 422 | for layer_idx in range(self.num_layers): 423 | res = (layer_idx + 5) // 2 424 | shape = [1, 1, 2 ** res, 2 ** res] 425 | self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) 426 | 427 | for i in range(3, self.log_size + 1): 428 | out_channel = self.channels[2 ** i] 429 | 430 | self.convs.append( 431 | StyledConv( 432 | in_channel, 433 | out_channel, 434 | 3, 435 | style_dim, 436 | upsample=True, 437 | blur_kernel=blur_kernel, 438 | ) 439 | ) 440 | 441 | self.convs.append( 442 | StyledConv( 443 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 444 | ) 445 | ) 446 | 447 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 448 | 449 | in_channel = out_channel 450 | 451 | self.n_latent = self.log_size * 2 - 2 452 | 453 | def make_noise(self): 454 | device = self.input.input.device 455 | 456 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 457 | 458 | for i in range(3, self.log_size + 1): 459 | for _ in range(2): 460 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 461 | 462 | return noises 463 | 464 | def mean_latent(self, n_latent): 465 | latent_in = torch.randn( 466 | n_latent, self.style_dim, device=self.input.input.device 467 | ) 468 | latent = self.style(latent_in).mean(0, keepdim=True) 469 | 470 | return latent 471 | 472 | def get_latent(self, input): 473 | return self.style(input) 474 | 475 | def forward( 476 | self, 477 | styles, 478 | return_latents=False, 479 | return_features=False, 480 | inject_index=None, 481 | truncation=1, 482 | truncation_latent=None, 483 | input_is_latent=False, 484 | noise=None, 485 | randomize_noise=True, 486 | features_in=None, 487 | feature_scale=1.0 488 | ): 489 | if not input_is_latent: 490 | styles = [self.style(s) for s in styles] 491 | 492 | if noise is None: 493 | if randomize_noise: 494 | noise = [None] * self.num_layers 495 | else: 496 | noise = [ 497 | getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) 498 | ] 499 | 500 | if truncation < 1: 501 | style_t = [] 502 | 503 | for style in styles: 504 | style_t.append( 505 | truncation_latent + truncation * (style - truncation_latent) 506 | ) 507 | 508 | styles = style_t 509 | 510 | if len(styles) < 2: 511 | inject_index = self.n_latent 512 | 513 | if styles[0].ndim < 3: 514 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 515 | else: 516 | latent = styles[0] 517 | 518 | else: 519 | if inject_index is None: 520 | inject_index = random.randint(1, self.n_latent - 1) 521 | 522 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 523 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 524 | 525 | latent = torch.cat([latent, latent2], 1) 526 | 527 | def insert_feature(x, layer_idx): 528 | if features_in is not None and features_in[layer_idx] is not None: 529 | x = (1 - feature_scale) * x + feature_scale * 
features_in[layer_idx].type_as(x) 530 | return x 531 | 532 | outs = [] 533 | out = self.input(latent) 534 | outs.append(out) 535 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 536 | outs.append(out) 537 | 538 | skip = self.to_rgb1(out, latent[:, 1]) 539 | 540 | i = 1 541 | for conv1, conv2, noise1, noise2, to_rgb in zip( 542 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 543 | ): 544 | out = insert_feature(out, i) 545 | out = conv1(out, latent[:, i], noise=noise1) 546 | outs.append(out) 547 | out = insert_feature(out, i + 1) 548 | out = conv2(out, latent[:, i + 1], noise=noise2) 549 | outs.append(out) 550 | skip = to_rgb(out, latent[:, i + 2], skip) 551 | 552 | i += 2 553 | 554 | image = skip 555 | 556 | if return_latents: 557 | return image, latent 558 | elif return_features: 559 | return image, outs 560 | else: 561 | return image, None 562 | 563 | 564 | class ConvLayer(nn.Sequential): 565 | def __init__( 566 | self, 567 | in_channel, 568 | out_channel, 569 | kernel_size, 570 | downsample=False, 571 | blur_kernel=[1, 3, 3, 1], 572 | bias=True, 573 | activate=True, 574 | ): 575 | layers = [] 576 | 577 | if downsample: 578 | factor = 2 579 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 580 | pad0 = (p + 1) // 2 581 | pad1 = p // 2 582 | 583 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 584 | 585 | stride = 2 586 | self.padding = 0 587 | 588 | else: 589 | stride = 1 590 | self.padding = kernel_size // 2 591 | 592 | layers.append( 593 | EqualConv2d( 594 | in_channel, 595 | out_channel, 596 | kernel_size, 597 | padding=self.padding, 598 | stride=stride, 599 | bias=bias and not activate, 600 | ) 601 | ) 602 | 603 | if activate: 604 | if bias: 605 | layers.append(FusedLeakyReLU(out_channel)) 606 | 607 | else: 608 | layers.append(ScaledLeakyReLU(0.2)) 609 | 610 | super().__init__(*layers) 611 | 612 | 613 | class ResBlock(nn.Module): 614 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 615 | super().__init__() 616 | 617 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 618 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 619 | 620 | self.skip = ConvLayer( 621 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 622 | ) 623 | 624 | def forward(self, input): 625 | out = self.conv1(input) 626 | out = self.conv2(out) 627 | 628 | skip = self.skip(input) 629 | out = (out + skip) / math.sqrt(2) 630 | 631 | return out 632 | 633 | 634 | class Discriminator(nn.Module): 635 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 636 | super().__init__() 637 | 638 | channels = { 639 | 4: 512, 640 | 8: 512, 641 | 16: 512, 642 | 32: 512, 643 | 64: 256 * channel_multiplier, 644 | 128: 128 * channel_multiplier, 645 | 256: 64 * channel_multiplier, 646 | 512: 32 * channel_multiplier, 647 | 1024: 16 * channel_multiplier, 648 | } 649 | 650 | convs = [ConvLayer(3, channels[size], 1)] 651 | 652 | log_size = int(math.log(size, 2)) 653 | 654 | in_channel = channels[size] 655 | 656 | for i in range(log_size, 2, -1): 657 | out_channel = channels[2 ** (i - 1)] 658 | 659 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 660 | 661 | in_channel = out_channel 662 | 663 | self.convs = nn.Sequential(*convs) 664 | 665 | self.stddev_group = 4 666 | self.stddev_feat = 1 667 | 668 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 669 | self.final_linear = nn.Sequential( 670 | EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), 671 | 
EqualLinear(channels[4], 1), 672 | ) 673 | 674 | def forward(self, input): 675 | out = self.convs(input) 676 | 677 | batch, channel, height, width = out.shape 678 | group = min(batch, self.stddev_group) 679 | stddev = out.view( 680 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 681 | ) 682 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 683 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 684 | stddev = stddev.repeat(group, 1, height, width) 685 | out = torch.cat([out, stddev], 1) 686 | 687 | out = self.final_conv(out) 688 | 689 | out = out.view(batch, -1) 690 | out = self.final_linear(out) 691 | 692 | return out 693 | -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/pixel2style2pixel/models/stylegan2/op/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/__pycache__/fused_act.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/pixel2style2pixel/models/stylegan2/op/__pycache__/fused_act.cpython-36.pyc -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/__pycache__/upfirdn2d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/pixel2style2pixel/models/stylegan2/op/__pycache__/upfirdn2d.cpython-36.pyc -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | 
gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } 
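A quick, illustrative sketch of how the fused op bound above is used from Python. The wrapper in the next file exposes upfirdn2d(input, kernel, up, down, pad), which pads, optionally up/down-samples and filters with a 2D FIR kernel in one CUDA launch; the Blur module in model.py calls it directly. This is a sketch only: it assumes the repository root is on sys.path, that a CUDA device is available, and that the JIT build of the .cpp/.cu sources in this folder succeeds; the [1, 3, 3, 1] binomial kernel and the pad values are illustrative choices, not taken from any config in this repo.

import torch
from pixel2style2pixel.models.stylegan2.op import upfirdn2d  # importing the package JIT-compiles the extension

# separable [1, 3, 3, 1] binomial kernel, normalised to sum to 1 (same shape as model.py's blur_kernel)
k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[:, None] * k[None, :]
kernel = (kernel / kernel.sum()).cuda()

x = torch.randn(1, 3, 64, 64, device='cuda')

# plain blur: no resampling; pad=(2, 1) keeps 64x64 for a 4-tap kernel
y = upfirdn2d(x, kernel, up=1, down=1, pad=(2, 1))
print(y.shape)   # torch.Size([1, 3, 64, 64])

# blur + 2x downsample, the pattern StyleGAN2 uses for antialiased pooling
y_down = upfirdn2d(x, kernel, up=1, down=2, pad=(1, 1))
print(y_down.shape)   # torch.Size([1, 3, 32, 32])

One caveat worth noting while reading the wrapper: the reference upfirdn2d_native function further down in upfirdn2d.py calls F.pad and F.conv2d, but the file never imports torch.nn.functional as F, so that CPU path would need the extra import before it could be used.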
-------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | up_x, up_y = up 23 | down_x, down_y = down 24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 25 | 26 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 27 | 28 | grad_input = upfirdn2d_op.upfirdn2d( 29 | grad_output, 30 | grad_kernel, 31 | down_x, 32 | down_y, 33 | up_x, 34 | up_y, 35 | g_pad_x0, 36 | g_pad_x1, 37 | g_pad_y0, 38 | g_pad_y1, 39 | ) 40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 41 | 42 | ctx.save_for_backward(kernel) 43 | 44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 45 | 46 | ctx.up_x = up_x 47 | ctx.up_y = up_y 48 | ctx.down_x = down_x 49 | ctx.down_y = down_y 50 | ctx.pad_x0 = pad_x0 51 | ctx.pad_x1 = pad_x1 52 | ctx.pad_y0 = pad_y0 53 | ctx.pad_y1 = pad_y1 54 | ctx.in_size = in_size 55 | ctx.out_size = out_size 56 | 57 | return grad_input 58 | 59 | @staticmethod 60 | def backward(ctx, gradgrad_input): 61 | kernel, = ctx.saved_tensors 62 | 63 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 64 | 65 | gradgrad_out = upfirdn2d_op.upfirdn2d( 66 | gradgrad_input, 67 | kernel, 68 | ctx.up_x, 69 | ctx.up_y, 70 | ctx.down_x, 71 | ctx.down_y, 72 | ctx.pad_x0, 73 | ctx.pad_x1, 74 | ctx.pad_y0, 75 | ctx.pad_y1, 76 | ) 77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 78 | gradgrad_out = gradgrad_out.view( 79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 80 | ) 81 | 82 | return gradgrad_out, None, None, None, None, None, None, None, None 83 | 84 | 85 | class UpFirDn2d(Function): 86 | @staticmethod 87 | def forward(ctx, input, kernel, up, down, pad): 88 | up_x, up_y = up 89 | down_x, down_y = down 90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 91 | 92 | kernel_h, kernel_w = kernel.shape 93 | batch, channel, in_h, in_w = input.shape 94 | ctx.in_size = input.shape 95 | 96 | input = input.reshape(-1, in_h, in_w, 1) 97 | 98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 99 | 100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 102 | ctx.out_size = (out_h, out_w) 103 | 104 | ctx.up = (up_x, up_y) 105 | ctx.down = (down_x, down_y) 106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 107 | 108 | g_pad_x0 = kernel_w - pad_x0 - 1 109 | g_pad_y0 = kernel_h - pad_y0 - 1 110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 112 | 113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 114 | 115 | out = upfirdn2d_op.upfirdn2d( 116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 117 | ) 118 | # out = out.view(major, out_h, out_w, minor) 119 | out = out.view(-1, channel, out_h, out_w) 120 | 121 | return out 122 
| 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | kernel, grad_kernel = ctx.saved_tensors 126 | 127 | grad_input = UpFirDn2dBackward.apply( 128 | grad_output, 129 | kernel, 130 | grad_kernel, 131 | ctx.up, 132 | ctx.down, 133 | ctx.pad, 134 | ctx.g_pad, 135 | ctx.in_size, 136 | ctx.out_size, 137 | ) 138 | 139 | return grad_input, None, None, None, None 140 | 141 | 142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 143 | out = UpFirDn2d.apply( 144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 145 | ) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native( 151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 152 | ): 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad( 161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 162 | ) 163 | out = out[ 164 | :, 165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), 166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), 167 | :, 168 | ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape( 172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 173 | ) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | 184 | return out[:, ::down_y, ::down_x, :] 185 | -------------------------------------------------------------------------------- /pixel2style2pixel/models/stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int 
rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<<>>( 229 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | 
upfirdn2d_kernel<<>>( 236 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<<>>( 243 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<<>>( 250 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<<>>( 257 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<<>>( 264 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | use_gc=True, gc_conv_only=False 35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
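# Illustrative usage sketch (model, criterion and loader below are placeholders,
# not objects from this repository):
#
#   optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6,
#                      use_gc=True, gc_conv_only=False)
#   for x, y in loader:
#       optimizer.zero_grad()
#       loss = criterion(model(x), y)
#       loss.backward()
#       optimizer.step()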
52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, 55 | eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | def __setstate__(self, state): 76 | super(Ranger, self).__setstate__(state) 77 | 78 | def step(self, closure=None): 79 | loss = None 80 | 81 | # Evaluate averages and grad, update param tensors 82 | for group in self.param_groups: 83 | 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | grad = p.grad.data.float() 88 | 89 | if grad.is_sparse: 90 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 91 | 92 | p_data_fp32 = p.data.float() 93 | 94 | state = self.state[p] # get state dict for this param 95 | 96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 97 | # if self.first_run_check==0: 98 | # self.first_run_check=1 99 | # print("Initializing slow buffer...should not see this at load from saved model!") 100 | state['step'] = 0 101 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 103 | 104 | # look ahead weight storage now in state dict 105 | state['slow_buffer'] = torch.empty_like(p.data) 106 | state['slow_buffer'].copy_(p.data) 107 | 108 | else: 109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 111 | 112 | # begin computations 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | # GC operation for Conv layers and FC layers 117 | if grad.dim() > self.gc_gradient_threshold: 118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 119 | 120 | state['step'] += 1 121 | 122 | # compute variance mov avg 123 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 124 | # compute mean moving avg 125 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 126 | 127 | buffered = self.radam_buffer[int(state['step'] % 10)] 128 | 129 | if state['step'] == buffered[0]: 130 | N_sma, step_size = buffered[1], buffered[2] 131 | else: 132 | buffered[0] = state['step'] 133 | beta2_t = beta2 ** state['step'] 134 | N_sma_max = 2 / (1 - beta2) - 1 135 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 136 | buffered[1] = N_sma 137 | if N_sma > self.N_sma_threshhold: 138 | step_size = math.sqrt( 139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 140 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 141 | else: 142 | step_size = 1.0 / (1 - beta1 ** state['step']) 143 | buffered[2] = step_size 144 | 145 | if group['weight_decay'] != 0: 146 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 147 | 148 | # apply lr 149 | if N_sma > self.N_sma_threshhold: 150 | denom = exp_avg_sq.sqrt().add_(group['eps']) 151 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 152 | else: 153 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 154 | 155 | p.data.copy_(p_data_fp32) 156 | 157 | # integrated look 
ahead... 158 | # we do it at the param level instead of group level 159 | if state['step'] % group['k'] == 0: 160 | slow_p = state['slow_buffer'] # get access to slow param tensor 161 | slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha 162 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor 163 | 164 | return loss -------------------------------------------------------------------------------- /run_video_inversion_editing.sh: -------------------------------------------------------------------------------- 1 | VideoName='FP010363HD03' 2 | Attribute='Heavy_Makeup' 3 | Scale='1' 4 | Sigma='3' # Choose an appropriate gaussian filter size 5 | VideoDir='./data/video/' 6 | OutputDir='./output/video/' 7 | 8 | 9 | # Cut video to frames 10 | python video_processing.py --function 'video_to_frames' --video_path ${VideoDir}/${VideoName}.mp4 --output_path ${OutputDir} #--resize 11 | 12 | # Crop and align the faces in each frame 13 | python video_processing.py --function 'align_frames' --video_path ${VideoDir}/${VideoName}.mp4 --output_path ${OutputDir} --filter_size=${Sigma} --optical_flow 14 | 15 | # Inversion 16 | python test.py --config 143 --input_path ${OutputDir}/${VideoName}/${VideoName}_crop_align/ --save_path ${OutputDir}/${VideoName}/${VideoName}_inversion/ 17 | 18 | # Apply the latent manipulation 19 | python video_processing.py --function 'latent_manipulation' --video_path ${VideoDir}/${VideoName}.mp4 --attr ${Attribute} --alpha=${Scale} 20 | 21 | # Reproject the manipulated frames to the original video 22 | python video_processing.py --function 'reproject_origin' --video_path ${VideoDir}/${VideoName}.mp4 --seamless 23 | python video_processing.py --function 'reproject_manipulate' --video_path ${VideoDir}/${VideoName}.mp4 --attr ${Attribute} --seamless 24 | python video_processing.py --function 'compare_frames' --video_path ${VideoDir}/${VideoName}.mp4 --attr ${Attribute} --strs 'Original,Projected,Manipulated' 25 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.data as data 9 | import yaml 10 | 11 | from PIL import Image 12 | from tqdm import tqdm 13 | from torchvision import transforms, utils 14 | 15 | from utils.datasets import * 16 | from utils.functions import * 17 | from trainer import * 18 | 19 | torch.backends.cudnn.enabled = True 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = True 22 | torch.autograd.set_detect_anomaly(True) 23 | Image.MAX_IMAGE_PIXELS = None 24 | device = torch.device('cuda') 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--config', type=str, default='001', help='Name of the config file.') 28 | parser.add_argument('--pretrained_model_path', type=str, default='./pretrained_models/143_enc.pth', help='pretrained feature-style encoder checkpoint') 29 | parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='pretrained stylegan2 model') 30 | parser.add_argument('--arcface_model_path', type=str, default='./pretrained_models/backbone.pth', help='pretrained arcface model') 31 | parser.add_argument('--parsing_model_path', type=str, default='./pretrained_models/79999_iter.pth', help='pretrained parsing model') 32 |
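# Example invocation, mirroring the call made in run_video_inversion_editing.sh
# above (the directory names are placeholders):
#   python test.py --config 143 --input_path <aligned_frames_dir>/ --save_path <inversion_output_dir>/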
parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 33 | parser.add_argument('--resume', action='store_true', help='resume from checkpoint') 34 | parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path') 35 | parser.add_argument('--checkpoint_noiser', type=str, default='', help='checkpoint file path') 36 | parser.add_argument('--multigpu', type=bool, default=False, help='use multiple gpus') 37 | parser.add_argument('--input_path', type=str, default='./test/', help='evaluation data file path') 38 | parser.add_argument('--save_path', type=str, default='./output/image/', help='output data save path') 39 | 40 | opts = parser.parse_args() 41 | 42 | log_dir = os.path.join(opts.log_path, opts.config) + '/' 43 | config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r'), Loader=yaml.FullLoader) 44 | 45 | # Initialize trainer 46 | trainer = Trainer(config, opts) 47 | trainer.initialize(opts.stylegan_model_path, opts.arcface_model_path, opts.parsing_model_path) 48 | trainer.to(device) 49 | 50 | state_dict = torch.load(opts.pretrained_model_path)#os.path.join(opts.log_path, opts.config + '/checkpoint.pth')) 51 | trainer.enc.load_state_dict(torch.load(opts.pretrained_model_path)) 52 | trainer.enc.eval() 53 | 54 | img_to_tensor = transforms.Compose([ 55 | transforms.Resize((1024, 1024)), 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 58 | ]) 59 | 60 | # simple inference 61 | image_dir = opts.input_path 62 | save_dir = opts.save_path 63 | os.makedirs(save_dir, exist_ok=True) 64 | 65 | with torch.no_grad(): 66 | img_list = [glob.glob1(image_dir, ext) for ext in ['*jpg','*png']] 67 | img_list = [item for sublist in img_list for item in sublist] 68 | img_list.sort() 69 | for i, img_name in enumerate(img_list): 70 | #print(i, img_name) 71 | image_A = img_to_tensor(Image.open(image_dir + img_name)).unsqueeze(0).to(device) 72 | output = trainer.test(img=image_A, return_latent=True) 73 | feature = output.pop() 74 | latent = output.pop() 75 | #np.save(save_dir + 'latent_code_%d.npy'%i, latent.cpu().numpy()) 76 | utils.save_image(clip_img(output[1]), save_dir + img_name) 77 | if i > 1000: 78 | break 79 | -------------------------------------------------------------------------------- /test/00020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/test/00020.jpg -------------------------------------------------------------------------------- /test/00109.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/test/00109.jpg -------------------------------------------------------------------------------- /test/00128.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/test/00128.jpg -------------------------------------------------------------------------------- /test/00299.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/test/00299.jpg -------------------------------------------------------------------------------- 
/test/00610.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/test/00610.jpg -------------------------------------------------------------------------------- /test/00962.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/test/00962.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.data as data 8 | import yaml 9 | 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from torchvision import transforms, utils 13 | from tensorboard_logger import Logger 14 | 15 | from utils.datasets import * 16 | from utils.functions import * 17 | from trainer import * 18 | 19 | torch.backends.cudnn.enabled = True 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = True 22 | torch.autograd.set_detect_anomaly(True) 23 | Image.MAX_IMAGE_PIXELS = None 24 | device = torch.device('cuda') 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--config', type=str, default='001', help='Name of the config file.') 28 | parser.add_argument('--real_dataset_path', type=str, default='./data/ffhq-dataset/images/', help='real image dataset path') 29 | parser.add_argument('--dataset_path', type=str, default='./data/stylegan2-generate-images/ims/', help='synthetic (StyleGAN-generated) dataset path') 30 | parser.add_argument('--label_path', type=str, default='./data/stylegan2-generate-images/seeds_pytorch_1.8.1.npy', help='label path') 31 | parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='pretrained stylegan2 model') 32 | parser.add_argument('--arcface_model_path', type=str, default='./pretrained_models/backbone.pth', help='pretrained arcface model') 33 | parser.add_argument('--parsing_model_path', type=str, default='./pretrained_models/79999_iter.pth', help='pretrained parsing model') 34 | parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 35 | parser.add_argument('--resume', action='store_true', help='resume from checkpoint') 36 | parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path') 37 | opts = parser.parse_args() 38 | 39 | log_dir = os.path.join(opts.log_path, opts.config) + '/' 40 | os.makedirs(log_dir, exist_ok=True) 41 | logger = Logger(log_dir) 42 | 43 | config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r'), Loader=yaml.FullLoader) 44 | 45 | batch_size = config['batch_size'] 46 | epochs = config['epochs'] 47 | iter_per_epoch = config['iter_per_epoch'] 48 | img_size = (config['resolution'], config['resolution']) 49 | video_data_input = False 50 | 51 | 52 | img_to_tensor = transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 55 | ]) 56 | img_to_tensor_car = transforms.Compose([ 57 | transforms.Resize((384, 512)), 58 | transforms.Pad(padding=(0, 64, 0, 64)), 59 | transforms.ToTensor(), 60 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 61 | ]) 62 | 63 | # Initialize trainer 64 | trainer = Trainer(config,
opts) 65 | trainer.initialize(opts.stylegan_model_path, opts.arcface_model_path, opts.parsing_model_path) 66 | trainer.to(device) 67 | 68 | noise_exemple = trainer.noise_inputs 69 | train_data_split = 0.9 if 'train_split' not in config else config['train_split'] 70 | 71 | # Load synthetic dataset 72 | dataset_A = MyDataSet(image_dir=opts.dataset_path, label_dir=opts.label_path, output_size=img_size, noise_in=noise_exemple, training_set=True, train_split=train_data_split) 73 | loader_A = data.DataLoader(dataset_A, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) 74 | # Load real dataset 75 | dataset_B = MyDataSet(image_dir=opts.real_dataset_path, label_dir=None, output_size=img_size, noise_in=noise_exemple, training_set=True, train_split=train_data_split) 76 | loader_B = data.DataLoader(dataset_B, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) 77 | 78 | # Start Training 79 | epoch_0 = 0 80 | 81 | # check if checkpoint exist 82 | if 'checkpoint.pth' in os.listdir(log_dir): 83 | epoch_0 = trainer.load_checkpoint(os.path.join(log_dir, 'checkpoint.pth')) 84 | 85 | if opts.resume: 86 | epoch_0 = trainer.load_checkpoint(os.path.join(opts.log_path, opts.checkpoint)) 87 | 88 | torch.manual_seed(0) 89 | os.makedirs(log_dir + 'validation/', exist_ok=True) 90 | 91 | print("Start!") 92 | 93 | for n_epoch in tqdm(range(epoch_0, epochs)): 94 | 95 | iter_A = iter(loader_A) 96 | iter_B = iter(loader_B) 97 | iter_0 = n_epoch*iter_per_epoch 98 | 99 | trainer.enc_opt.zero_grad() 100 | 101 | for n_iter in range(iter_0, iter_0 + iter_per_epoch): 102 | 103 | if opts.dataset_path is None: 104 | z, noise = next(iter_A) 105 | img_A = None 106 | else: 107 | z, img_A, noise = next(iter_A) 108 | img_A = img_A.to(device) 109 | 110 | z = z.to(device) 111 | noise = [ee.to(device) for ee in noise] 112 | w = trainer.mapping(z) 113 | if 'fixed_noise' in config and config['fixed_noise']: 114 | img_A, noise = None, None 115 | 116 | img_B = None 117 | if 'use_realimg' in config and config['use_realimg']: 118 | try: 119 | img_B = next(iter_B) 120 | if img_B.size(0) != batch_size: 121 | iter_B = iter(loader_B) 122 | img_B = next(iter_B) 123 | except StopIteration: 124 | iter_B = iter(loader_B) 125 | img_B = next(iter_B) 126 | img_B = img_B.to(device) 127 | 128 | trainer.update(w=w, img=img_A, noise=noise, real_img=img_B, n_iter=n_iter) 129 | if (n_iter+1) % config['log_iter'] == 0: 130 | trainer.log_loss(logger, n_iter, prefix='train') 131 | if (n_iter+1) % config['image_save_iter'] == 0: 132 | trainer.save_image(log_dir, n_epoch, n_iter, prefix='/train/', w=w, img=img_A, noise=noise) 133 | trainer.save_image(log_dir, n_epoch, n_iter+1, prefix='/train/', w=w, img=img_B, noise=noise, training_mode=False) 134 | 135 | trainer.enc_scheduler.step() 136 | trainer.save_checkpoint(n_epoch, log_dir) 137 | 138 | # Test the model on celeba hq dataset 139 | with torch.no_grad(): 140 | trainer.enc.eval() 141 | for i in range(10): 142 | image_A = img_to_tensor(Image.open('./data/celeba_hq/%d.jpg' % i)).unsqueeze(0).to(device) 143 | output = trainer.test(img=image_A) 144 | out_img = torch.cat(output, 3) 145 | utils.save_image(clip_img(out_img[:1]), log_dir + 'validation/' + 'epoch_' +str(n_epoch+1) + '_' + str(i) + '.jpg') 146 | trainer.compute_loss(w=w, img=img_A, noise=noise, real_img=img_B) 147 | trainer.log_loss(logger, n_iter, prefix='validation') 148 | 149 | trainer.save_model(log_dir) -------------------------------------------------------------------------------- /trainer.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | from torch.autograd import grad 11 | from torchvision import transforms, utils 12 | 13 | import face_alignment 14 | import lpips 15 | 16 | sys.path.append('pixel2style2pixel/') 17 | from pixel2style2pixel.models.stylegan2.model import Generator, get_keys 18 | 19 | from nets.feature_style_encoder import * 20 | from utils.functions import * 21 | from arcface.iresnet import * 22 | from face_parsing.model import BiSeNet 23 | from ranger import Ranger 24 | 25 | class Trainer(nn.Module): 26 | def __init__(self, config, opts): 27 | super(Trainer, self).__init__() 28 | # Load Hyperparameters 29 | self.config = config 30 | self.device = torch.device(self.config['device']) 31 | self.scale = int(np.log2(config['resolution']/config['enc_resolution'])) 32 | self.scale_mode = 'bilinear' 33 | self.opts = opts 34 | self.n_styles = 2 * int(np.log2(config['resolution'])) - 2 35 | self.idx_k = 5 36 | if 'idx_k' in self.config: 37 | self.idx_k = self.config['idx_k'] 38 | if 'stylegan_version' in self.config and self.config['stylegan_version'] == 3: 39 | self.n_styles = 16 40 | # Networks 41 | in_channels = 256 42 | if 'in_c' in self.config: 43 | in_channels = config['in_c'] 44 | enc_residual = False 45 | if 'enc_residual' in self.config: 46 | enc_residual = self.config['enc_residual'] 47 | enc_residual_coeff = False 48 | if 'enc_residual_coeff' in self.config: 49 | enc_residual_coeff = self.config['enc_residual_coeff'] 50 | resnet_layers = [4,5,6] 51 | if 'enc_start_layer' in self.config: 52 | st_l = self.config['enc_start_layer'] 53 | resnet_layers = [st_l, st_l+1, st_l+2] 54 | if 'scale_mode' in self.config: 55 | self.scale_mode = self.config['scale_mode'] 56 | # Load encoder 57 | self.stride = (self.config['fs_stride'], self.config['fs_stride']) 58 | self.enc = fs_encoder_v2(n_styles=self.n_styles, opts=opts, residual=enc_residual, use_coeff=enc_residual_coeff, resnet_layer=resnet_layers, stride=self.stride) 59 | 60 | ########################## 61 | # Other nets 62 | self.StyleGAN = self.init_stylegan(config) 63 | self.Arcface = iresnet50() 64 | self.parsing_net = BiSeNet(n_classes=19) 65 | # Optimizers 66 | # Latent encoder 67 | self.enc_params = list(self.enc.parameters()) 68 | if 'freeze_iresnet' in self.config and self.config['freeze_iresnet']: 69 | self.enc_params = list(self.enc.styles.parameters()) 70 | if 'optimizer' in self.config and self.config['optimizer'] == 'ranger': 71 | self.enc_opt = Ranger(self.enc_params, lr=config['lr'], betas=(config['beta_1'], config['beta_2']), weight_decay=config['weight_decay']) 72 | else: 73 | self.enc_opt = torch.optim.Adam(self.enc_params, lr=config['lr'], betas=(config['beta_1'], config['beta_2']), weight_decay=config['weight_decay']) 74 | self.enc_scheduler = torch.optim.lr_scheduler.StepLR(self.enc_opt, step_size=config['step_size'], gamma=config['gamma']) 75 | 76 | self.fea_avg = None 77 | 78 | def initialize(self, stylegan_model_path, arcface_model_path, parsing_model_path): 79 | # load StyleGAN model 80 | stylegan_state_dict = torch.load(stylegan_model_path, map_location='cpu') 81 | self.StyleGAN.load_state_dict(get_keys(stylegan_state_dict, 'decoder'), strict=True) 82 | self.StyleGAN.to(self.device) 83 | # get StyleGAN average latent in w space and the noise inputs 84 | 
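# Note: dlatent_avg is added back to the encoder output in encode() / get_image()
# below, so the encoder effectively predicts an offset from the average W latent.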
self.dlatent_avg = stylegan_state_dict['latent_avg'].to(self.device) 85 | self.noise_inputs = [getattr(self.StyleGAN.noises, f'noise_{i}').to(self.device) for i in range(self.StyleGAN.num_layers)] 86 | # load Arcface weight 87 | self.Arcface.load_state_dict(torch.load(self.opts.arcface_model_path)) 88 | self.Arcface.eval() 89 | # load face parsing net weight 90 | self.parsing_net.load_state_dict(torch.load(self.opts.parsing_model_path)) 91 | self.parsing_net.eval() 92 | # load lpips net weight 93 | self.loss_fn = lpips.LPIPS(net='alex', spatial=False) 94 | self.loss_fn.to(self.device) 95 | 96 | def init_stylegan(self, config): 97 | """StyleGAN = G_main( 98 | truncation_psi=config['truncation_psi'], 99 | resolution=config['resolution'], 100 | use_noise=config['use_noise'], 101 | randomize_noise=config['randomize_noise'] 102 | )""" 103 | StyleGAN = Generator(1024, 512, 8) 104 | return StyleGAN 105 | 106 | def mapping(self, z): 107 | return self.StyleGAN.get_latent(z).detach() 108 | 109 | def L1loss(self, input, target): 110 | return nn.L1Loss()(input,target) 111 | 112 | def L2loss(self, input, target): 113 | return nn.MSELoss()(input,target) 114 | 115 | def CEloss(self, x, target_age): 116 | return nn.CrossEntropyLoss()(x, target_age) 117 | 118 | def LPIPS(self, input, target, multi_scale=False): 119 | if multi_scale: 120 | out = 0 121 | for k in range(3): 122 | out += self.loss_fn.forward(downscale(input, k, self.scale_mode), downscale(target, k, self.scale_mode)).mean() 123 | else: 124 | out = self.loss_fn.forward(downscale(input, self.scale, self.scale_mode), downscale(target, self.scale, self.scale_mode)).mean() 125 | return out 126 | 127 | def IDloss(self, input, target): 128 | x_1 = F.interpolate(input, (112,112)) 129 | x_2 = F.interpolate(target, (112,112)) 130 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 131 | if 'multi_layer_idloss' in self.config and self.config['multi_layer_idloss']: 132 | id_1 = self.Arcface(x_1, return_features=True) 133 | id_2 = self.Arcface(x_2, return_features=True) 134 | return sum([1 - cos(id_1[i].flatten(start_dim=1), id_2[i].flatten(start_dim=1)) for i in range(len(id_1))]) 135 | else: 136 | id_1 = self.Arcface(x_1) 137 | id_2 = self.Arcface(x_2) 138 | return 1 - cos(id_1, id_2) 139 | 140 | def landmarkloss(self, input, target): 141 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 142 | x_1 = stylegan_to_classifier(input, out_size=(512, 512)) 143 | x_2 = stylegan_to_classifier(target, out_size=(512,512)) 144 | out_1 = self.parsing_net(x_1) 145 | out_2 = self.parsing_net(x_2) 146 | parsing_loss = sum([1 - cos(out_1[i].flatten(start_dim=1), out_2[i].flatten(start_dim=1)) for i in range(len(out_1))]) 147 | return parsing_loss.mean() 148 | 149 | 150 | def feature_match(self, enc_feat, dec_feat, layer_idx=None): 151 | loss = [] 152 | if layer_idx is None: 153 | layer_idx = [i for i in range(len(enc_feat))] 154 | for i in layer_idx: 155 | loss.append(self.L1loss(enc_feat[i], dec_feat[i])) 156 | return loss 157 | 158 | def encode(self, img): 159 | w_recon, fea = self.enc(downscale(img, self.scale, self.scale_mode)) 160 | w_recon = w_recon + self.dlatent_avg 161 | return w_recon, fea 162 | 163 | def get_image(self, w=None, img=None, noise=None, zero_noise_input=True, training_mode=True): 164 | 165 | x_1, n_1 = img, noise 166 | if x_1 is None: 167 | x_1, _ = self.StyleGAN([w], input_is_latent=True, noise = n_1) 168 | 169 | w_delta = None 170 | fea = None 171 | features = None 172 | return_features = False 173 | # Reconstruction 174 | k = 0 175 | if 'use_fs_encoder' in 
self.config and self.config['use_fs_encoder']: 176 | return_features = True 177 | k = self.idx_k 178 | w_recon, fea = self.enc(downscale(x_1, self.scale, self.scale_mode)) 179 | w_recon = w_recon + self.dlatent_avg 180 | features = [None]*k + [fea] + [None]*(17-k) 181 | else: 182 | w_recon = self.enc(downscale(x_1, self.scale, self.scale_mode)) + self.dlatent_avg 183 | 184 | # generate image 185 | x_1_recon, fea_recon = self.StyleGAN([w_recon], input_is_latent=True, return_features=True, features_in=features, feature_scale=min(1.0, 0.0001*self.n_iter)) 186 | fea_recon = fea_recon[k].detach() 187 | return [x_1_recon, x_1[:,:3,:,:], w_recon, w_delta, n_1, fea, fea_recon] 188 | 189 | def compute_loss(self, w=None, img=None, noise=None, real_img=None): 190 | return self.compute_loss_stylegan2(w=w, img=img, noise=noise, real_img=real_img) 191 | 192 | def compute_loss_stylegan2(self, w=None, img=None, noise=None, real_img=None): 193 | 194 | if img is None: 195 | # generate synthetic images 196 | if noise is None: 197 | noise = [torch.randn(w.size()[:1] + ee.size()[1:]).to(self.device) for ee in self.noise_inputs] 198 | img, _ = self.StyleGAN([w], input_is_latent=True, noise = noise) 199 | img = img.detach() 200 | 201 | if img is not None and real_img is not None: 202 | # concat synthetic and real data 203 | img = torch.cat([img, real_img], dim=0) 204 | noise = [torch.cat([ee, ee], dim=0) for ee in noise] 205 | 206 | out = self.get_image(w=w, img=img, noise=noise) 207 | x_1_recon, x_1, w_recon, w_delta, n_1, fea_1, fea_recon = out 208 | 209 | # Loss setting 210 | w_l2, w_lpips, w_id = self.config['w']['l2'], self.config['w']['lpips'], self.config['w']['id'] 211 | b = x_1.size(0)//2 212 | if 'l2loss_on_real_image' in self.config and self.config['l2loss_on_real_image']: 213 | b = x_1.size(0) 214 | self.l2_loss = self.L2loss(x_1_recon[:b], x_1[:b]) if w_l2 > 0 else torch.tensor(0) # l2 loss only on synthetic data 215 | # LPIPS 216 | multiscale_lpips=False if 'multiscale_lpips' not in self.config else self.config['multiscale_lpips'] 217 | self.lpips_loss = self.LPIPS(x_1_recon, x_1, multi_scale=multiscale_lpips).mean() if w_lpips > 0 else torch.tensor(0) 218 | self.id_loss = self.IDloss(x_1_recon, x_1).mean() if w_id > 0 else torch.tensor(0) 219 | self.landmark_loss = self.landmarkloss(x_1_recon, x_1) if self.config['w']['landmark'] > 0 else torch.tensor(0) 220 | 221 | if 'use_fs_encoder' in self.config and self.config['use_fs_encoder']: 222 | k = self.idx_k 223 | features = [None]*k + [fea_1] + [None]*(17-k) 224 | x_1_recon_2, _ = self.StyleGAN([w_recon], noise=n_1, input_is_latent=True, features_in=features, feature_scale=min(1.0, 0.0001*self.n_iter)) 225 | self.lpips_loss += self.LPIPS(x_1_recon_2, x_1, multi_scale=multiscale_lpips).mean() if w_lpips > 0 else torch.tensor(0) 226 | self.id_loss += self.IDloss(x_1_recon_2, x_1).mean() if w_id > 0 else torch.tensor(0) 227 | self.landmark_loss += self.landmarkloss(x_1_recon_2, x_1) if self.config['w']['landmark'] > 0 else torch.tensor(0) 228 | 229 | # downscale image 230 | x_1 = downscale(x_1, self.scale, self.scale_mode) 231 | x_1_recon = downscale(x_1_recon, self.scale, self.scale_mode) 232 | 233 | # Total loss 234 | w_l2, w_lpips, w_id = self.config['w']['l2'], self.config['w']['lpips'], self.config['w']['id'] 235 | self.loss = w_l2*self.l2_loss + w_lpips*self.lpips_loss + w_id*self.id_loss 236 | 237 | if 'f_recon' in self.config['w']: 238 | self.feature_recon_loss = self.L2loss(fea_1, fea_recon) 239 | self.loss += 
self.config['w']['f_recon']*self.feature_recon_loss 240 | if 'l1' in self.config['w'] and self.config['w']['l1']>0: 241 | self.l1_loss = self.L1loss(x_1_recon, x_1) 242 | self.loss += self.config['w']['l1']*self.l1_loss 243 | if 'landmark' in self.config['w']: 244 | self.loss += self.config['w']['landmark']*self.landmark_loss 245 | return self.loss 246 | 247 | def test(self, w=None, img=None, noise=None, zero_noise_input=True, return_latent=False, training_mode=False): 248 | if 'n_iter' not in self.__dict__.keys(): 249 | self.n_iter = 1e5 250 | out = self.get_image(w=w, img=img, noise=noise, training_mode=training_mode) 251 | x_1_recon, x_1, w_recon, w_delta, n_1, fea_1 = out[:6] 252 | output = [x_1, x_1_recon] 253 | if return_latent: 254 | output += [w_recon, fea_1] 255 | return output 256 | 257 | def log_loss(self, logger, n_iter, prefix='train'): 258 | logger.log_value(prefix + '/l2_loss', self.l2_loss.item(), n_iter + 1) 259 | logger.log_value(prefix + '/lpips_loss', self.lpips_loss.item(), n_iter + 1) 260 | logger.log_value(prefix + '/id_loss', self.id_loss.item(), n_iter + 1) 261 | logger.log_value(prefix + '/total_loss', self.loss.item(), n_iter + 1) 262 | if 'f_recon' in self.config['w']: 263 | logger.log_value(prefix + '/feature_recon_loss', self.feature_recon_loss.item(), n_iter + 1) 264 | if 'l1' in self.config['w'] and self.config['w']['l1']>0: 265 | logger.log_value(prefix + '/l1_loss', self.l1_loss.item(), n_iter + 1) 266 | if 'landmark' in self.config['w']: 267 | logger.log_value(prefix + '/landmark_loss', self.landmark_loss.item(), n_iter + 1) 268 | 269 | def save_image(self, log_dir, n_epoch, n_iter, prefix='/train/', w=None, img=None, noise=None, training_mode=True): 270 | return self.save_image_stylegan2(log_dir=log_dir, n_epoch=n_epoch, n_iter=n_iter, prefix=prefix, w=w, img=img, noise=noise, training_mode=training_mode) 271 | 272 | def save_image_stylegan2(self, log_dir, n_epoch, n_iter, prefix='/train/', w=None, img=None, noise=None, training_mode=True): 273 | os.makedirs(log_dir + prefix, exist_ok=True) 274 | with torch.no_grad(): 275 | out = self.get_image(w=w, img=img, noise=noise, training_mode=training_mode) 276 | x_1_recon, x_1, w_recon, w_delta, n_1, fea_1 = out[:6] 277 | x_1 = downscale(x_1, self.scale, self.scale_mode) 278 | x_1_recon = downscale(x_1_recon, self.scale, self.scale_mode) 279 | out_img = torch.cat((x_1, x_1_recon), dim=3) 280 | #fs 281 | if 'use_fs_encoder' in self.config and self.config['use_fs_encoder']: 282 | k = self.idx_k 283 | features = [None]*k + [fea_1] + [None]*(17-k) 284 | x_1_recon_2, _ = self.StyleGAN([w_recon], noise=n_1, input_is_latent=True, features_in=features, feature_scale=min(1.0, 0.0001*self.n_iter)) 285 | x_1_recon_2 = downscale(x_1_recon_2, self.scale, self.scale_mode) 286 | out_img = torch.cat((x_1, x_1_recon, x_1_recon_2), dim=3) 287 | utils.save_image(clip_img(out_img[:1]), log_dir + prefix + 'epoch_' +str(n_epoch+1) + '_iter_' + str(n_iter+1) + '_0.jpg') 288 | if out_img.size(0)>1: 289 | utils.save_image(clip_img(out_img[1:]), log_dir + prefix + 'epoch_' +str(n_epoch+1) + '_iter_' + str(n_iter+1) + '_1.jpg') 290 | 291 | def save_model(self, log_dir): 292 | torch.save(self.enc.state_dict(),'{:s}/enc.pth.tar'.format(log_dir)) 293 | 294 | def save_checkpoint(self, n_epoch, log_dir): 295 | checkpoint_state = { 296 | 'n_epoch': n_epoch, 297 | 'enc_state_dict': self.enc.state_dict(), 298 | 'enc_opt_state_dict': self.enc_opt.state_dict(), 299 | 'enc_scheduler_state_dict': self.enc_scheduler.state_dict() 300 | } 301 | 
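# The checkpoint bundles the encoder weights with the optimizer and scheduler
# state so that load_checkpoint() below can resume training; a numbered copy is
# also kept every 10 epochs.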
torch.save(checkpoint_state, '{:s}/checkpoint.pth'.format(log_dir)) 302 | if (n_epoch+1)%10 == 0 : 303 | torch.save(checkpoint_state, '{:s}/checkpoint'.format(log_dir)+'_'+str(n_epoch+1)+'.pth') 304 | 305 | def load_model(self, log_dir): 306 | self.enc.load_state_dict(torch.load('{:s}/enc.pth.tar'.format(log_dir))) 307 | 308 | def load_checkpoint(self, checkpoint_path): 309 | state_dict = torch.load(checkpoint_path) 310 | self.enc.load_state_dict(state_dict['enc_state_dict']) 311 | self.enc_opt.load_state_dict(state_dict['enc_opt_state_dict']) 312 | self.enc_scheduler.load_state_dict(state_dict['enc_scheduler_state_dict']) 313 | return state_dict['n_epoch'] + 1 314 | 315 | def update(self, w=None, img=None, noise=None, real_img=None, n_iter=0): 316 | self.n_iter = n_iter 317 | self.enc_opt.zero_grad() 318 | self.compute_loss(w=w, img=img, noise=noise, real_img=real_img).backward() 319 | self.enc_opt.step() 320 | 321 | 322 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/FeatureStyleEncoder/4aa0069aedec3993dab6f3471f481616617d7bc9/utils/.DS_Store -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | from torchvision import transforms, utils 11 | 12 | class MyDataSet(data.Dataset): 13 | def __init__(self, image_dir=None, label_dir=None, output_size=(256, 256), noise_in=None, training_set=True, video_data=False, train_split=0.9): 14 | self.image_dir = image_dir 15 | self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 16 | self.resize = transforms.Compose([ 17 | transforms.Resize(output_size), 18 | transforms.ToTensor() 19 | ]) 20 | self.noise_in = noise_in 21 | self.video_data = video_data 22 | self.random_rotation = transforms.Compose([ 23 | transforms.Resize(output_size), 24 | transforms.RandomPerspective(distortion_scale=0.05, p=1.0), 25 | transforms.ToTensor() 26 | ]) 27 | 28 | # load image file 29 | train_len = None 30 | self.length = 0 31 | self.image_dir = image_dir 32 | if image_dir is not None: 33 | img_list = [glob.glob1(self.image_dir, ext) for ext in ['*jpg','*png']] 34 | image_list = [item for sublist in img_list for item in sublist] 35 | image_list.sort() 36 | train_len = int(train_split*len(image_list)) 37 | if training_set: 38 | self.image_list = image_list[:train_len] 39 | else: 40 | self.image_list = image_list[train_len:] 41 | self.length = len(self.image_list) 42 | 43 | # load label file 44 | self.label_dir = label_dir 45 | if label_dir is not None: 46 | self.seeds = np.load(label_dir) 47 | if train_len is None: 48 | train_len = int(train_split*len(self.seeds)) 49 | if training_set: 50 | self.seeds = self.seeds[:train_len] 51 | else: 52 | self.seeds = self.seeds[train_len:] 53 | if self.length == 0: 54 | self.length = len(self.seeds) 55 | 56 | def __len__(self): 57 | return self.length 58 | 59 | def __getitem__(self, idx): 60 | img = None 61 | if self.image_dir is not None: 62 | img_name = os.path.join(self.image_dir, self.image_list[idx]) 63 | image = Image.open(img_name) 64 | img = self.resize(image) 65 | if img.size(0) == 1: 66 | img = 
torch.cat((img, img, img), dim=0) 67 | img = self.normalize(img) 68 | 69 | # generate image 70 | if self.label_dir is not None: 71 | torch.manual_seed(self.seeds[idx]) 72 | z = torch.randn(1, 512)[0] 73 | if self.noise_in is None: 74 | n = [torch.randn(1, 1)] 75 | else: 76 | n = [torch.randn(noise.size())[0] for noise in self.noise_in] 77 | if img is None: 78 | return z, n 79 | else: 80 | return z, img, n 81 | else: 82 | return img 83 | 84 | class Car_DataSet(data.Dataset): 85 | def __init__(self, image_dir=None, label_dir=None, output_size=(512, 512), noise_in=None, training_set=True, video_data=False, train_split=0.9): 86 | self.image_dir = image_dir 87 | self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 88 | self.resize = transforms.Compose([ 89 | transforms.Resize((384, 512)), 90 | transforms.Pad(padding=(0, 64, 0, 64)), 91 | transforms.ToTensor() 92 | ]) 93 | self.noise_in = noise_in 94 | self.video_data = video_data 95 | self.random_rotation = transforms.Compose([ 96 | transforms.Resize(output_size), 97 | transforms.RandomPerspective(distortion_scale=0.05, p=1.0), 98 | transforms.ToTensor() 99 | ]) 100 | 101 | # load image file 102 | train_len = None 103 | self.length = 0 104 | self.image_dir = image_dir 105 | if image_dir is not None: 106 | img_list = [glob.glob1(self.image_dir, ext) for ext in ['*jpg','*png']] 107 | image_list = [item for sublist in img_list for item in sublist] 108 | image_list.sort() 109 | train_len = int(train_split*len(image_list)) 110 | if training_set: 111 | self.image_list = image_list[:train_len] 112 | else: 113 | self.image_list = image_list[train_len:] 114 | self.length = len(self.image_list) 115 | 116 | # load label file 117 | self.label_dir = label_dir 118 | if label_dir is not None: 119 | self.seeds = np.load(label_dir) 120 | if train_len is None: 121 | train_len = int(train_split*len(self.seeds)) 122 | if training_set: 123 | self.seeds = self.seeds[:train_len] 124 | else: 125 | self.seeds = self.seeds[train_len:] 126 | if self.length == 0: 127 | self.length = len(self.seeds) 128 | 129 | def __len__(self): 130 | return self.length 131 | 132 | def __getitem__(self, idx): 133 | img = None 134 | if self.image_dir is not None: 135 | img_name = os.path.join(self.image_dir, self.image_list[idx]) 136 | image = Image.open(img_name) 137 | img = self.resize(image) 138 | if img.size(0) == 1: 139 | img = torch.cat((img, img, img), dim=0) 140 | img = self.normalize(img) 141 | if self.video_data: 142 | img_2 = self.random_rotation(image) 143 | img_2 = self.normalize(img_2) 144 | img_2 = torch.where(img_2 > -1, img_2, img) 145 | img = torch.cat([img, img_2], dim=0) 146 | 147 | # generate image 148 | if self.label_dir is not None: 149 | torch.manual_seed(self.seeds[idx]) 150 | z = torch.randn(1, 512)[0] 151 | n = [torch.randn_like(noise[0]) for noise in self.noise_in] 152 | if img is None: 153 | return z, n 154 | else: 155 | return z, img, n 156 | else: 157 | return img 158 | 159 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.data as data 7 | 8 | from PIL import Image 9 | from torch.autograd import grad 10 | 11 | 12 | def clip_img(x): 13 | """Clip stylegan generated image to range(0,1)""" 14 | img_tmp = x.clone()[0] 15 | img_tmp = (img_tmp + 1) / 2 16 | img_tmp = 
torch.clamp(img_tmp, 0, 1) 17 | return [img_tmp.detach().cpu()] 18 | 19 | def tensor_byte(x): 20 | return x.element_size()*x.nelement() 21 | 22 | def count_parameters(net): 23 | s = sum([np.prod(list(mm.size())) for mm in net.parameters()]) 24 | print(s) 25 | 26 | def stylegan_to_classifier(x, out_size=(224, 224)): 27 | """Clip image to range(0,1)""" 28 | img_tmp = x.clone() 29 | img_tmp = torch.clamp((0.5*img_tmp + 0.5), 0, 1) 30 | img_tmp = F.interpolate(img_tmp, size=out_size, mode='bilinear') 31 | img_tmp[:,0] = (img_tmp[:,0] - 0.485)/0.229 32 | img_tmp[:,1] = (img_tmp[:,1] - 0.456)/0.224 33 | img_tmp[:,2] = (img_tmp[:,2] - 0.406)/0.225 34 | #img_tmp = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img_tmp) 35 | return img_tmp 36 | 37 | def downscale(x, scale_times=1, mode='bilinear'): 38 | for i in range(scale_times): 39 | x = F.interpolate(x, scale_factor=0.5, mode=mode) 40 | return x 41 | 42 | def upscale(x, scale_times=1, mode='bilinear'): 43 | for i in range(scale_times): 44 | x = F.interpolate(x, scale_factor=2, mode=mode) 45 | return x 46 | 47 | def hist_transform(source_tensor, target_tensor): 48 | """Histogram transformation""" 49 | c, h, w = source_tensor.size() 50 | s_t = source_tensor.view(c, -1) 51 | t_t = target_tensor.view(c, -1) 52 | s_t_sorted, s_t_indices = torch.sort(s_t) 53 | t_t_sorted, t_t_indices = torch.sort(t_t) 54 | for i in range(c): 55 | s_t[i, s_t_indices[i]] = t_t_sorted[i] 56 | return s_t.view(c, h, w) 57 | 58 | def init_weights(m): 59 | """Initialize layers with Xavier uniform distribution""" 60 | if type(m) == nn.Conv2d: 61 | nn.init.xavier_uniform_(m.weight) 62 | elif type(m) == nn.Linear: 63 | nn.init.uniform_(m.weight, 0.0, 1.0) 64 | if m.bias is not None: 65 | nn.init.constant_(m.bias, 0.01) 66 | 67 | def total_variation(x, delta=1): 68 | """Total variation, x: tensor of size (B, C, H, W)""" 69 | out = torch.mean(torch.abs(x[:, :, :, :-delta] - x[:, :, :, delta:]))\ 70 | + torch.mean(torch.abs(x[:, :, :-delta, :] - x[:, :, delta:, :])) 71 | return out 72 | 73 | def vgg_transform(x): 74 | """Adapt image for vgg network, x: image of range(0,1) subtracting ImageNet mean""" 75 | r, g, b = torch.split(x, 1, 1) 76 | out = torch.cat((b, g, r), dim = 1) 77 | out = F.interpolate(out, size=(224, 224), mode='bilinear') 78 | out = out*255. 
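# (The BGR channel order and 0-255 value range above presumably match
# Caffe-style VGG pretrained weights; this is an interpretive note only.)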
79 | return out 80 | 81 | # warp image with flow 82 | def normalize_axis(x,L): 83 | return (x-1-(L-1)/2)*2/(L-1) 84 | 85 | def unnormalize_axis(x,L): 86 | return x*(L-1)/2+1+(L-1)/2 87 | 88 | def torch_flow_to_th_sampling_grid(flow,h_src,w_src,use_cuda=False): 89 | b,c,h_tgt,w_tgt=flow.size() 90 | grid_y, grid_x = torch.meshgrid(torch.tensor(range(1,w_tgt+1)),torch.tensor(range(1,h_tgt+1))) 91 | disp_x=flow[:,0,:,:] 92 | disp_y=flow[:,1,:,:] 93 | source_x=grid_x.unsqueeze(0).repeat(b,1,1).type_as(flow)+disp_x 94 | source_y=grid_y.unsqueeze(0).repeat(b,1,1).type_as(flow)+disp_y 95 | source_x_norm=normalize_axis(source_x,w_src) 96 | source_y_norm=normalize_axis(source_y,h_src) 97 | sampling_grid=torch.cat((source_x_norm.unsqueeze(3), source_y_norm.unsqueeze(3)), dim=3) 98 | if use_cuda: 99 | sampling_grid = sampling_grid.cuda() 100 | return sampling_grid 101 | 102 | def warp_image_torch(image, flow): 103 | """ 104 | Warp image (tensor, shape=[b, 3, h_src, w_src]) with flow (tensor, shape=[b, h_tgt, w_tgt, 2]) 105 | """ 106 | b,c,h_src,w_src=image.size() 107 | sampling_grid_torch = torch_flow_to_th_sampling_grid(flow, h_src, w_src) 108 | warped_image_torch = F.grid_sample(image, sampling_grid_torch) 109 | return warped_image_torch -------------------------------------------------------------------------------- /utils/video_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 5 | 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import os 10 | import face_alignment 11 | import torch 12 | 13 | from PIL import Image, ImageFilter 14 | from scipy import ndimage 15 | from scipy.ndimage import gaussian_filter1d 16 | from skimage import io 17 | from torchvision import transforms, utils 18 | 19 | 20 | def pil_to_cv2(pil_image): 21 | open_cv_image = np.array(pil_image) 22 | return open_cv_image[:, :, ::-1].copy() 23 | 24 | 25 | def cv2_to_pil(open_cv_image): 26 | return Image.fromarray(open_cv_image[:, :, ::-1].copy()) 27 | 28 | 29 | def put_text(img, text): 30 | font = cv2.FONT_HERSHEY_SIMPLEX 31 | bottomLeftCornerOfText = (10,50) 32 | fontScale = 1.5 33 | fontColor = (255,255,0) 34 | lineType = 2 35 | return cv2.putText(img, text, 36 | bottomLeftCornerOfText, 37 | font, 38 | fontScale, 39 | fontColor, 40 | lineType) 41 | 42 | 43 | # Compare frames in two directory 44 | def compare_frames(save_dir, origin_dir, target_dir, strs='Original,Projected,Manipulated', dim=None): 45 | 46 | os.makedirs(save_dir, exist_ok=True) 47 | try: 48 | if not isinstance(target_dir, list): 49 | target_dir = [target_dir] 50 | image_list = glob.glob1(origin_dir,'frame*') 51 | image_list.sort() 52 | for name in image_list: 53 | img_l = [] 54 | for idx, dir_path in enumerate([origin_dir] + list(target_dir)): 55 | img_1 = cv2.imread(dir_path + name) 56 | img_1 = put_text(img_1, strs.split(',')[idx]) 57 | img_l.append(img_1) 58 | img = np.concatenate(img_l, axis=1) 59 | cv2.imwrite(save_dir + name, img) 60 | except FileNotFoundError: 61 | pass 62 | 63 | 64 | # Save frames into video 65 | def create_video(image_folder, fps=24, video_format='.mp4', resize_ratio=1): 66 | 67 | video_name = os.path.dirname(image_folder) + video_format 68 | img_list = glob.glob1(image_folder,'frame*') 69 | img_list.sort() 70 | frame = cv2.imread(os.path.join(image_folder, img_list[0])) 71 | frame = 
cv2.resize(frame, (0,0), fx=resize_ratio, fy=resize_ratio) 72 | height, width, layers = frame.shape 73 | if video_format == '.mp4': 74 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 75 | elif video_format == '.avi': 76 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 77 | video = cv2.VideoWriter(video_name, fourcc, fps, (width,height)) 78 | for image_name in img_list: 79 | frame = cv2.imread(os.path.join(image_folder, image_name)) 80 | frame = cv2.resize(frame, (0,0), fx=resize_ratio, fy=resize_ratio) 81 | video.write(frame) 82 | 83 | 84 | # Split video into frames 85 | def video_to_frames(video_path, frame_path, img_format='.jpg', count_num=1000, resize=False): 86 | 87 | os.makedirs(frame_path, exist_ok=True) 88 | vidcap = cv2.VideoCapture(video_path) 89 | success,image = vidcap.read() 90 | count = 0 91 | while success: 92 | if resize: 93 | image = cv2.resize(image, (0,0), fx=0.5, fy=0.5) 94 | cv2.imwrite(frame_path + '/frame%04d' % count + img_format, image) 95 | success,image = vidcap.read() 96 | count += 1 97 | if count >= count_num: 98 | break 99 | 100 | # Align faces 101 | def align_frames(img_dir, save_dir, output_size=1024, transform_size=1024, optical_flow=True, gaussian=True, filter_size=3): 102 | 103 | os.makedirs(save_dir, exist_ok=True) 104 | 105 | # load face landmark detector 106 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device='cuda') 107 | 108 | # list images in the directory 109 | img_list = glob.glob1(img_dir, 'frame*') 110 | img_list.sort() 111 | 112 | # save align statistics 113 | stat_dict = {'quad':[], 'qsize':[], 'coord':[], 'crop':[]} 114 | lms = [] 115 | for idx, img_name in enumerate(img_list): 116 | 117 | img_path = os.path.join(img_dir, img_name) 118 | img = io.imread(img_path) 119 | lm = [] 120 | 121 | preds = fa.get_landmarks(img) 122 | for kk in range(68): 123 | lm.append((preds[0][kk][0], preds[0][kk][1])) 124 | 125 | # Eye distance 126 | lm_eye_left = lm[36 : 42] # left-clockwise 127 | lm_eye_right = lm[42 : 48] # left-clockwise 128 | eye_left = np.mean(lm_eye_left, axis=0) 129 | eye_right = np.mean(lm_eye_right, axis=0) 130 | eye_to_eye = eye_right - eye_left 131 | 132 | if optical_flow: 133 | if idx > 0: 134 | s = int(np.hypot(*eye_to_eye)/4) 135 | lk_params = dict(winSize=(s, s), maxLevel=5, criteria = (cv2.TERM_CRITERIA_COUNT | cv2.TERM_CRITERIA_EPS, 10, 0.03)) 136 | points_arr = np.array(lm, np.float32) 137 | points_prevarr = np.array(prev_lm, np.float32) 138 | points_arr,status, err = cv2.calcOpticalFlowPyrLK(prev_img, img, points_prevarr, points_arr, **lk_params) 139 | sigma =100 140 | points_arr_float = np.array(points_arr,np.float32) 141 | points = points_arr_float.tolist() 142 | for k in range(0, len(lm)): 143 | d = cv2.norm(np.array(prev_lm[k]) - np.array(lm[k])) 144 | alpha = np.exp(-d*d/sigma) 145 | lm[k] = (1 - alpha) * np.array(lm[k]) + alpha * np.array(points[k]) 146 | prev_img = img 147 | prev_lm = lm 148 | 149 | lms.append(lm) 150 | 151 | # Apply gaussian filter on landmarks 152 | if gaussian: 153 | lm_filtered = np.array(lms) 154 | for kk in range(68): 155 | lm_filtered[:, kk, 0] = gaussian_filter1d(lm_filtered[:, kk, 0], filter_size) 156 | lm_filtered[:, kk, 1] = gaussian_filter1d(lm_filtered[:, kk, 1], filter_size) 157 | lms = lm_filtered.tolist() 158 | 159 | # save landmarks 160 | landmark_out_dir = os.path.dirname(img_dir) + '_landmark/' 161 | os.makedirs(landmark_out_dir, exist_ok=True) 162 | 163 | for idx, img_name in enumerate(img_list): 164 | 165 | img_path = os.path.join(img_dir, img_name) 166 
| img = io.imread(img_path) 167 | 168 | lm = lms[idx] 169 | img_lm = img.copy() 170 | for kk in range(68): 171 | img_lm = cv2.circle(img_lm, (int(lm[kk][0]),int(lm[kk][1])), radius=3, color=(255, 0, 255), thickness=-1) 172 | # Save landmark images 173 | cv2.imwrite(landmark_out_dir + img_name, img_lm[:,:,::-1]) 174 | 175 | # Save mask images 176 | """ 177 | seg_mask = np.zeros(img.shape, img.dtype) 178 | poly = np.array(lm[0:17] + lm[17:27][::-1], np.int32) 179 | cv2.fillPoly(seg_mask, [poly], (255, 255, 255)) 180 | cv2.imwrite(img_dir + "mask%04d.jpg"%idx, seg_mask); 181 | """ 182 | 183 | # Parse landmarks. 184 | lm_eye_left = lm[36 : 42] # left-clockwise 185 | lm_eye_right = lm[42 : 48] # left-clockwise 186 | lm_mouth_outer = lm[48 : 60] # left-clockwise 187 | 188 | # Calculate auxiliary vectors. 189 | eye_left = np.mean([lm_eye_left[0], lm_eye_left[3]], axis=0) 190 | eye_right = np.mean([lm_eye_right[0], lm_eye_right[3]], axis=0) 191 | eye_avg = (eye_left + eye_right) * 0.5 192 | eye_to_eye = eye_right - eye_left 193 | mouth_left = np.array(lm_mouth_outer[0]) 194 | mouth_right = np.array(lm_mouth_outer[6]) 195 | mouth_avg = (mouth_left + mouth_right) * 0.5 196 | eye_to_mouth = mouth_avg - eye_avg 197 | 198 | # Choose oriented crop rectangle. 199 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 200 | x /= np.hypot(*x) 201 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 202 | y = np.flipud(x) * [-1, 1] 203 | c = eye_avg + eye_to_mouth * 0.1 204 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 205 | qsize = np.hypot(*x) * 2 206 | 207 | stat_dict['coord'].append(quad) 208 | stat_dict['qsize'].append(qsize) 209 | 210 | # Apply gaussian filter on crops 211 | if gaussian: 212 | quads = np.array(stat_dict['coord']) 213 | quads = gaussian_filter1d(quads, 2*filter_size, axis=0) 214 | stat_dict['coord'] = quads.tolist() 215 | qsize = np.array(stat_dict['qsize']) 216 | qsize = gaussian_filter1d(qsize, 2*filter_size, axis=0) 217 | stat_dict['qsize'] = qsize.tolist() 218 | 219 | for idx, img_name in enumerate(img_list): 220 | 221 | img_path = os.path.join(img_dir, img_name) 222 | img = Image.open(img_path) 223 | 224 | qsize = stat_dict['qsize'][idx] 225 | quad = np.array(stat_dict['coord'][idx]) 226 | 227 | # Crop. 228 | border = max(int(np.rint(qsize * 0.1)), 3) 229 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 230 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 231 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 232 | img = img.crop(crop) 233 | quad -= crop[0:2] 234 | 235 | stat_dict['crop'].append(crop) 236 | stat_dict['quad'].append((quad + 0.5).flatten()) 237 | 238 | # Pad. 239 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 240 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 241 | if max(pad) > border - 4: 242 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 243 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 244 | h, w, _ = img.shape 245 | y, x, _ = np.ogrid[:h, :w, :1] 246 | img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 247 | quad += pad[:2] 248 | # Transform. 
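# PIL's Image.QUAD transform maps a source quadrilateral, given as eight flattened
# coordinates in (upper-left, lower-left, lower-right, upper-right) order, onto the
# full transform_size x transform_size output, so the rotated face crop defined by
# `quad` is resampled into an axis-aligned square in a single bilinear warp.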
249 | img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR) 250 | 251 | # resizing 252 | img_pil = img.resize((output_size, output_size), Image.LANCZOS) 253 | img_pil.save(save_dir+img_name) 254 | 255 | create_video(landmark_out_dir) 256 | np.save(save_dir+'stat_dict.npy', stat_dict) 257 | 258 | 259 | def find_coeffs(pa, pb): 260 | 261 | matrix = [] 262 | for p1, p2 in zip(pa, pb): 263 | matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0]*p1[0], -p2[0]*p1[1]]) 264 | matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1]*p1[0], -p2[1]*p1[1]]) 265 | A = np.matrix(matrix, dtype=np.float) 266 | B = np.array(pb).reshape(8) 267 | res = np.dot(np.linalg.inv(A.T * A) * A.T, B) 268 | return np.array(res).reshape(8) 269 | 270 | # reproject aligned frames to the original video 271 | def video_reproject(orig_dir_path, recon_dir_path, save_dir_path, state_dir_path, mask_dir_path, seamless=False): 272 | 273 | if not os.path.exists(save_dir_path): 274 | os.makedirs(save_dir_path) 275 | 276 | img_list_0 = glob.glob1(orig_dir_path,'frame*') 277 | img_list_2 = glob.glob1(recon_dir_path,'frame*') 278 | img_list_0.sort() 279 | img_list_2.sort() 280 | stat_dict = np.load(state_dir_path + 'stat_dict.npy', allow_pickle=True).item() 281 | counter = len(img_list_2) 282 | 283 | for idx in range(counter): 284 | 285 | img_0 = Image.open(orig_dir_path + img_list_0[idx]) 286 | img_2 = Image.open(recon_dir_path + img_list_2[idx]) 287 | 288 | quad_f = stat_dict['quad'][idx] 289 | quad_0 = stat_dict['crop'][idx] 290 | 291 | coeffs = find_coeffs( 292 | [(quad_f[0], quad_f[1]), (quad_f[2] , quad_f[3]), (quad_f[4], quad_f[5]), (quad_f[6], quad_f[7])], 293 | [(0, 0), (0, 1024), (1024, 1024), (1024, 0)]) 294 | crop_size = (quad_0[2] - quad_0[0], quad_0[3] - quad_0[1]) 295 | img_2 = img_2.transform(crop_size, Image.PERSPECTIVE, coeffs, Image.BICUBIC) 296 | output = img_0.copy() 297 | output.paste(img_2, (int(quad_0[0]), int(quad_0[1]))) 298 | 299 | """ 300 | mask = cv2.imread(orig_dir_path + 'mask%04d.jpg'%idx) 301 | kernel = np.ones((10,10), np.uint8) 302 | mask = cv2.dilate(mask, kernel, iterations=5) 303 | """ 304 | crop_mask = Image.open(mask_dir_path + img_list_0[idx]) 305 | crop_mask = crop_mask.transform(crop_size, Image.PERSPECTIVE, coeffs, Image.BICUBIC) 306 | mask = Image.fromarray(np.zeros(np.array(img_0).shape, np.array(img_0).dtype)) 307 | mask.paste(crop_mask, (int(quad_0[0]), int(quad_0[1]))) 308 | mask = pil_to_cv2(mask) 309 | # Apply mask 310 | if not seamless: 311 | mask = cv2_to_pil(mask).filter(ImageFilter.GaussianBlur(radius=10)).convert('L') 312 | mask = np.array(mask)[:, :, np.newaxis]/255. 
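# The blurred parsing mask now acts as a soft matte in [0, 1]: the composite below
# keeps the reprojected (edited) face where the mask is 1, keeps the original frame
# where it is 0, and feathers the transition in between to hide the crop boundary.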
313 | output = np.array(img_0)*(1-mask) + np.array(output)*mask 314 | output = Image.fromarray(output.astype(np.uint8)) 315 | output.save(save_dir_path + img_list_2[idx]) 316 | else: 317 | src = pil_to_cv2(output) 318 | dst = pil_to_cv2(img_0) 319 | # clone 320 | br = cv2.boundingRect(cv2.split(mask)[0]) # bounding rect (x,y,width,height) 321 | center = (br[0] + br[2] // 2, br[1] + br[3] // 2) 322 | output = cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE) 323 | cv2.imwrite(save_dir_path + img_list_2[idx], output) 324 | 325 | 326 | 327 | 328 | # Align faces 329 | def align_image(img_dir, save_dir, output_size=1024, transform_size=1024, format='*.png'): 330 | os.makedirs(save_dir, exist_ok=True) 331 | 332 | # load face landmark detector 333 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device='cuda') 334 | # list images in the directory 335 | img_list = glob.glob1(img_dir, format) 336 | #img_list = os.listdir(img_dir) 337 | img_list.sort() 338 | 339 | # save align statistics 340 | stat_dict = {'quad':[], 'qsize':[], 'coord':[], 'crop':[]} 341 | 342 | for idx, img_name in enumerate(img_list): 343 | 344 | img_path = os.path.join(img_dir, img_name) 345 | img = Image.open(img_path).convert('RGB') 346 | img_np = np.array(img) 347 | lm = [] 348 | 349 | preds = fa.get_landmarks(img_np) 350 | for kk in range(68): 351 | lm.append((preds[0][kk][0], preds[0][kk][1])) 352 | if len(lm)==0: 353 | continue 354 | 355 | # Parse landmarks. Code extracted from ffhq-dataset 356 | # pylint: disable=unused-variable 357 | lm_chin = lm[0 : 17] # left-right 358 | lm_eyebrow_left = lm[17 : 22] # left-right 359 | lm_eyebrow_right = lm[22 : 27] # left-right 360 | lm_nose = lm[27 : 31] # top-down 361 | lm_nostrils = lm[31 : 36] # top-down 362 | lm_eye_left = lm[36 : 42] # left-clockwise 363 | lm_eye_right = lm[42 : 48] # left-clockwise 364 | lm_mouth_outer = lm[48 : 60] # left-clockwise 365 | lm_mouth_inner = lm[60 : 68] # left-clockwise 366 | 367 | # Calculate auxiliary vectors. 368 | eye_left = np.mean([lm_eye_left[0], lm_eye_left[3]], axis=0) 369 | eye_right = np.mean([lm_eye_right[0], lm_eye_right[3]], axis=0) 370 | eye_avg = (eye_left + eye_right) * 0.5 371 | eye_to_eye = eye_right - eye_left 372 | mouth_left = np.array(lm_mouth_outer[0]) 373 | mouth_right = np.array(lm_mouth_outer[6]) 374 | mouth_avg = (mouth_left + mouth_right) * 0.5 375 | eye_to_mouth = mouth_avg - eye_avg 376 | 377 | # Choose oriented crop rectangle. 378 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 379 | x /= np.hypot(*x) 380 | x *= np.hypot(*eye_to_eye) * 2.0#max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 381 | 382 | y = np.flipud(x) * [-1, 1] 383 | c = eye_avg + eye_to_mouth * 0.1 384 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 385 | qsize = np.hypot(*x) * 2 386 | 387 | stat_dict['coord'].append(quad) 388 | stat_dict['qsize'].append(qsize) 389 | 390 | qsize = stat_dict['qsize'][idx] 391 | quad = np.array(stat_dict['coord'][idx]) 392 | """ 393 | # Shrink. 394 | shrink = int(np.floor(qsize / output_size * 0.5)) 395 | if shrink > 1: 396 | print('shrink!') 397 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 398 | img = img.resize(rsize, Image.ANTIALIAS) 399 | quad /= shrink 400 | qsize /= shrink 401 | """ 402 | # Crop. 
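# Take the axis-aligned bounding box of the oriented quad, grow it by a border of
# roughly 10% of qsize (at least 3 px), and clamp it to the image bounds; cropping
# to this box keeps the subsequent padding and warping cheap on large frames.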
403 | border = max(int(np.rint(qsize * 0.1)), 3) 404 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 405 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 406 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 407 | img = img.crop(crop) 408 | quad -= crop[0:2] 409 | 410 | stat_dict['crop'].append(crop) 411 | stat_dict['quad'].append((quad + 0.5).flatten()) 412 | #img = img.crop(crop) 413 | # Pad. 414 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 415 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 416 | if max(pad) > border - 4: 417 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 418 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'edge') 419 | h, w, _ = img.shape 420 | y, x, _ = np.ogrid[:h, :w, :1] 421 | img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 422 | quad += pad[:2] 423 | # Transform. 424 | img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR) 425 | img_pil = img.resize((output_size, output_size), Image.LANCZOS) 426 | 427 | # resizing 428 | img_pil.save(save_dir+img_name) 429 | 430 | np.save(save_dir+'stat_dict.npy', stat_dict) 431 | 432 | img_to_tensor = transforms.Compose([ 433 | transforms.ToTensor(), 434 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 435 | ]) 436 | 437 | def generate_mask(img_dir, save_dir, parsing_net, labels=[1,2,3,4,5,6,9,10,11,12,13], output_size=(1024, 1024), device=torch.device('cuda')): 438 | os.makedirs(save_dir, exist_ok=True) 439 | img_list = glob.glob1(img_dir, 'frame*') 440 | img_list.sort() 441 | 442 | for img_name in img_list: 443 | img_path = os.path.join(img_dir, img_name) 444 | img = Image.open(img_path).resize((512, 512), Image.LANCZOS) 445 | x_1 = img_to_tensor(img).unsqueeze(0).to(device) 446 | out_1 = parsing_net(x_1) 447 | parsing = out_1[0].squeeze(0).detach().cpu().numpy().argmax(0) 448 | mask = np.uint8(parsing) 449 | for j in labels: 450 | mask = np.where(mask==j, 255, mask) 451 | mask = np.where(mask==255, 255, 0) 452 | mask_pil = Image.fromarray(np.uint8(mask)).resize(output_size, Image.LANCZOS) 453 | save_path = os.path.join(save_dir, img_name) 454 | mask_pil.save(save_path) -------------------------------------------------------------------------------- /video_processing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 
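# Driver script for video inversion and editing: depending on --function it extracts
# frames from the input video (video_to_frames), aligns the faces and builds parsing
# masks (align_frames + generate_mask), moves the encoded latents along a semantic
# boundary (latent_manipulation), pastes the generated faces back into the original
# frames (reproject_origin / reproject_manipulate), and assembles side-by-side
# comparison videos (compare_frames).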
5 | 6 | import argparse 7 | import copy 8 | import glob 9 | import numpy as np 10 | import os 11 | import torch 12 | import yaml 13 | import time 14 | 15 | from PIL import Image 16 | from torchvision import transforms, utils, models 17 | 18 | from utils.video_utils import * 19 | from face_parsing.model import BiSeNet 20 | from trainer import * 21 | 22 | 23 | torch.backends.cudnn.enabled = True 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = True 26 | torch.autograd.set_detect_anomaly(True) 27 | Image.MAX_IMAGE_PIXELS = None 28 | device = torch.device('cuda') 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--config', type=str, default='001', help='Path to the config file.') 32 | parser.add_argument('--attr', type=str, default='Eyeglasses', help='attribute for manipulation.') 33 | parser.add_argument('--alpha', type=str, default='1.', help='scale for manipulation.') 34 | parser.add_argument('--label_file', type=str, default='./data/celebahq_anno.npy', help='label file path') 35 | parser.add_argument('--pretrained_model_path', type=str, default='./pretrained_models/143_enc.pth', help='pretrained stylegan2 model') 36 | parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='pretrained stylegan2 model') 37 | parser.add_argument('--arcface_model_path', type=str, default='./pretrained_models/backbone.pth', help='pretrained arcface model') 38 | parser.add_argument('--parsing_model_path', type=str, default='./pretrained_models/79999_iter.pth', help='pretrained parsing model') 39 | parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 40 | parser.add_argument('--function', type=str, default='', help='Calling function by name.') 41 | parser.add_argument('--video_path', type=str, default='./data/video/FP006911MD02.mp4', help='video file path') 42 | parser.add_argument('--output_path', type=str, default='./output/video/', help='output video file path') 43 | parser.add_argument('--boundary_path', type=str, default='./boundaries_ours/', help='output video file path') 44 | parser.add_argument('--optical_flow', action='store_true', help='use optical flow') 45 | parser.add_argument('--resize', action='store_true', help='downscale image size') 46 | parser.add_argument('--seamless', action='store_true', help='seamless cloning') 47 | parser.add_argument('--filter_size', type=float, default=3, help='filter size') 48 | parser.add_argument('--strs', type=str, default='Original,Projected,Manipulated', help='strs to be added on video') 49 | opts = parser.parse_args() 50 | 51 | # Celeba attribute list 52 | attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \ 53 | 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \ 54 | 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \ 55 | 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \ 56 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \ 57 | 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \ 58 | 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \ 59 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 60 | 61 | img_to_tensor = transforms.Compose([ 62 | transforms.Resize((1024, 
1024)), 63 | transforms.ToTensor(), 64 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 65 | ]) 66 | 67 | # linear interpolation 68 | def linear_interpolate(latent_code, 69 | boundary, 70 | start_distance=-3.0, 71 | end_distance=3.0, 72 | steps=10): 73 | assert (latent_code.shape[0] == 1 and boundary.shape[0] == 1 and 74 | len(boundary.shape) == 2 and 75 | boundary.shape[1] == latent_code.shape[-1]) 76 | 77 | linspace = np.linspace(start_distance, end_distance, steps) 78 | if len(latent_code.shape) == 2: 79 | linspace = linspace.reshape(-1, 1).astype(np.float32) 80 | return latent_code + linspace * boundary 81 | if len(latent_code.shape) == 3: 82 | linspace = linspace.reshape(-1, 1, 1).astype(np.float32) 83 | return latent_code + linspace * boundary.reshape(1, 1, -1) 84 | 85 | # Latent code manipulation 86 | def latent_manipulation(opts, align_dir_path, process_dir_path): 87 | 88 | os.makedirs(process_dir_path, exist_ok=True) 89 | #attrs = opts.attr.split(',') 90 | #alphas = opts.alpha.split(',') 91 | step_scale = 15 * float(opts.alpha) 92 | n_steps = 5 93 | 94 | boundary = np.load(opts.boundary_path +'%s_boundary.npy'%opts.attr) 95 | 96 | # Initialize trainer 97 | config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r'), Loader=yaml.FullLoader) 98 | trainer = Trainer(config, opts) 99 | trainer.initialize(opts.stylegan_model_path, opts.arcface_model_path, opts.parsing_model_path) 100 | trainer.to(device) 101 | 102 | state_dict = torch.load(opts.pretrained_model_path)#os.path.join(opts.log_path, opts.config + '/checkpoint.pth')) 103 | trainer.enc.load_state_dict(state_dict) 104 | trainer.enc.eval() 105 | 106 | with torch.no_grad(): 107 | img_list = [glob.glob1(align_dir_path, ext) for ext in ['*jpg','*png']] 108 | img_list = [item for sublist in img_list for item in sublist] 109 | img_list.sort() 110 | n_1 = trainer.StyleGAN.make_noise() 111 | 112 | for i, img_name in enumerate(img_list): 113 | #print(i, img_name) 114 | image_A = img_to_tensor(Image.open(align_dir_path + img_name)).unsqueeze(0).to(device) 115 | w_0, f_0 = trainer.encode(image_A) 116 | 117 | w_0_np = w_0.cpu().numpy().reshape(1, -1) 118 | out = linear_interpolate(w_0_np, boundary, start_distance=-step_scale, end_distance=step_scale, steps=n_steps) 119 | w_1 = torch.tensor(out[-1]).view(1, -1, 512).to(device) 120 | 121 | _, fea_0 = trainer.StyleGAN([w_0], noise=n_1, input_is_latent=True, return_features=True) 122 | _, fea_1 = trainer.StyleGAN([w_1], noise=n_1, input_is_latent=True, return_features=True) 123 | 124 | features = [None]*5 + [f_0 + fea_1[5] - fea_0[5]] + [None]*(17-5) 125 | x_1, _ = trainer.StyleGAN([w_1], noise=n_1, input_is_latent=True, features_in=features, feature_scale=1.0) 126 | utils.save_image(clip_img(x_1), process_dir_path + 'frame%04d'%i+'.jpg') 127 | 128 | 129 | video_path = opts.video_path 130 | video_name = video_path.split('/')[-1] 131 | orig_dir_path = opts.output_path + video_name.split('.')[0] + '/' + video_name.split('.')[0] + '/' 132 | align_dir_path = os.path.dirname(orig_dir_path) + '_crop_align/' 133 | mask_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_mask/' 134 | latent_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_latent/' 135 | process_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_' + opts.attr.replace(',','_') + '/' 136 | reproject_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_' + opts.attr.replace(',','_') + '_reproject/' 137 | 138 | 139 | print(opts.function) 140 | start_time =
time.perf_counter() 141 | 142 | if opts.function == 'video_to_frames': 143 | video_to_frames(video_path, orig_dir_path, count_num=120, resize=opts.resize) 144 | create_video(orig_dir_path) 145 | elif opts.function == 'align_frames': 146 | align_frames(orig_dir_path, align_dir_path, output_size=1024, optical_flow=opts.optical_flow, filter_size=opts.filter_size) 147 | # parsing mask 148 | parsing_net = BiSeNet(n_classes=19) 149 | parsing_net.load_state_dict(torch.load(opts.parsing_model_path)) 150 | parsing_net.eval() 151 | parsing_net.to(device) 152 | generate_mask(align_dir_path, mask_dir_path, parsing_net) 153 | elif opts.function == 'latent_manipulation': 154 | latent_manipulation(opts, align_dir_path, process_dir_path) 155 | elif opts.function == 'reproject_origin': 156 | process_dir_path = os.path.dirname(orig_dir_path) + '_inversion/' 157 | reproject_dir_path = os.path.dirname(orig_dir_path) + '_inversion_reproject/' 158 | video_reproject(orig_dir_path, process_dir_path, reproject_dir_path, align_dir_path, mask_dir_path, seamless=opts.seamless) 159 | create_video(reproject_dir_path) 160 | elif opts.function == 'reproject_manipulate': 161 | video_reproject(orig_dir_path, process_dir_path, reproject_dir_path, align_dir_path, mask_dir_path, seamless=opts.seamless) 162 | create_video(reproject_dir_path) 163 | elif opts.function == 'compare_frames': 164 | process_dir_paths = [] 165 | process_dir_paths.append(os.path.dirname(orig_dir_path) + '_inversion_reproject/') 166 | if len(opts.attr.split(','))>0: 167 | process_dir_paths.append(reproject_dir_path) 168 | save_dir = os.path.dirname(orig_dir_path) + '_crop_align_' + opts.attr.replace(',','_') + '_compare/' 169 | compare_frames(save_dir, orig_dir_path, process_dir_paths, strs=opts.strs, dim=1) 170 | create_video(save_dir, video_format='.avi', resize_ratio=1) 171 | 172 | 173 | count_time = time.perf_counter() - start_time 174 | print("Elapsed time: %0.4f seconds"%count_time) --------------------------------------------------------------------------------
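For reference, a minimal, self-contained sketch of the latent-editing rule implemented by linear_interpolate above, run on toy data; the 512-dimensional shapes, the random boundary, and the step range are illustrative assumptions rather than values taken from the repository:

import numpy as np

# Toy stand-ins: one latent code and one unit-norm boundary direction, both (1, 512).
rng = np.random.RandomState(0)
latent_code = rng.randn(1, 512).astype(np.float32)
boundary = rng.randn(1, 512).astype(np.float32)
boundary /= np.linalg.norm(boundary)

# Same rule as the 2-D branch of linear_interpolate: offsets from -step_scale to
# +step_scale are broadcast along the boundary direction and added to the code.
step_scale, n_steps = 15.0, 5
offsets = np.linspace(-step_scale, step_scale, n_steps).reshape(-1, 1).astype(np.float32)
edited_codes = latent_code + offsets * boundary

print(edited_codes.shape)  # (5, 512); latent_manipulation keeps only the last row

In latent_manipulation this is applied per aligned frame, with the boundary loaded from boundaries_ours/<attr>_boundary.npy and the final step decoded by StyleGAN before the edited face is reprojected into the original video.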