├── .gitignore
├── LICENSE.txt
├── README.md
├── data_loader.py
├── deeplab.py
├── download_ffhq_aging.py
├── ffhq_aging_labels.csv
├── get_ffhq_aging.bat
├── get_ffhq_aging.sh
├── images
│   ├── age_distribution.png
│   └── dataset_samples_github.png
├── pydrive_utils.py
├── requirements.txt
├── run_deeplab.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | ffhq_aging*/
2 | in-the-wild-images/
3 | deeplab_model/
4 | __pycache__/
5 | *.json
6 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright Roy Or-El, 2020
2 | 
3 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
4 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
5 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
6 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
7 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
8 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
9 | USE OR OTHER DEALINGS IN THE SOFTWARE.
10 | 
11 | The dataset is made available under Creative Commons BY-NC-SA 4.0 license
12 | by University of Washington. You can use, redistribute, and adapt it
13 | for non-commercial purposes, as long as you (a) give appropriate credit
14 | by citing our paper, (b) indicate any changes that you've made,
15 | and (c) distribute any derivative works under the same license.
16 | 
17 | Lifespan Age Transformation Synthesis
18 | Roy Or-El, Soumyadip Sengupta, Ohad Fried, Eli Shechtman, Ira Kemelmacher-Shlizerman
19 | https://arxiv.org/pdf/2003.09764.pdf
20 | 
21 | The individual images were published in Flickr by their respective authors
22 | under either Creative Commons BY 2.0, Creative Commons BY-NC 2.0,
23 | Public Domain Mark 1.0, Public Domain CC0 1.0, or U.S. Government Works
24 | license. All of these licenses allow free use, redistribution, and adaptation
25 | for non-commercial purposes. However, some of them require giving appropriate
26 | credit to the original author, as well as indicating any changes that were
27 | made to the images. The license and original author of each image are
28 | indicated in the metadata.
29 | 
30 | https://creativecommons.org/licenses/by/2.0/
31 | https://creativecommons.org/licenses/by-nc/2.0/
32 | https://creativecommons.org/publicdomain/mark/1.0/
33 | https://creativecommons.org/publicdomain/zero/1.0/
34 | http://www.usa.gov/copyright.shtml
35 | 
36 | The JSON metadata is made available under Creative Commons BY-NC-SA 4.0 license by NVIDIA Corporation.
37 | 
38 | The individual images and JSON metadata are hosted on NVIDIA's Google Drive,
39 | please see the original FFHQ dataset for more details.
40 | 
41 | https://github.com/NVlabs/ffhq-dataset
42 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FFHQ-Aging Dataset
2 | ### [Project Page](https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/) | [Paper](https://arxiv.org/pdf/2003.09764.pdf) | [Lifespan Age Transformation Synthesis Code](https://github.com/royorel/Lifespan_Age_Transformation_Synthesis)
3 | 
4 | 
5 | ## Overview
6 | FFHQ-Aging is a dataset of human faces designed for benchmarking age transformation algorithms, as well as many other vision tasks.
7 | 
8 | This dataset is an extension of the NVIDIA [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset). On top of the 70,000 original FFHQ images, it also contains the following information for each image:
9 | 1. Gender information (male/female with confidence score)
10 | 2. Age group information (10 classes with confidence score)
11 | 3. Head pose (pitch, roll & yaw)
12 | 4. Glasses type (none, normal or dark)
13 | 5. Eye occlusion score (0-100, a separate score for each eye)
14 | 6. Full semantic map (19 classes, based on CelebAMask-HQ labels)
15 | 
16 | If you use this dataset for your work, please cite our paper:
17 | > **Lifespan Age Transformation Synthesis**
18 | > Roy Or-El, Soumyadip Sengupta, Ohad Fried, Eli Shechtman, Ira Kemelmacher-Shlizerman
19 | > ECCV 2020
20 | > https://arxiv.org/pdf/2003.09764.pdf
21 | 
22 | ## Dataset Statistics
23 | The following histogram shows the age class distribution per gender.
24 | 
25 | 
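The per-gender age counts behind this histogram can be recomputed from the label file `ffhq_aging_labels.csv`. Below is a minimal sketch using only the Python standard library. It assumes the CSV has a header row with columns named `gender` and `age_group`; these column names are assumptions, so check the header of the actual file and pass the correct names if they differ.

```python
import csv
from collections import Counter, defaultdict
from typing import Dict, Iterable


def age_histogram_by_gender(
    lines: Iterable[str],
    gender_col: str = "gender",   # assumed column name
    age_col: str = "age_group",   # assumed column name
) -> Dict[str, Counter]:
    """Count images per age class, separately for each gender label.

    `lines` can be an open file handle or any iterable of CSV text lines
    whose first line is the header row.
    """
    reader = csv.DictReader(lines)
    header = set(reader.fieldnames or [])
    missing = {gender_col, age_col} - header
    if missing:
        raise KeyError(f"columns not found in CSV header: {sorted(missing)}")

    counts: Dict[str, Counter] = defaultdict(Counter)
    for row in reader:
        # One count per image: bucket by gender, then by age class label.
        counts[row[gender_col]][row[age_col]] += 1
    return dict(counts)
```

To use it, open `ffhq_aging_labels.csv` with `newline=""` (as the `csv` module recommends) and pass the file handle in; the result maps each gender label to a `Counter` of age-class labels, which any plotting library can turn into a bar chart.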
26 | 
27 | Gender labels & confidence scores, age class labels & confidence scores, head pose, glasses type, and left & right eye occlusion scores for each individual image are stored in **ffhq_aging_labels.csv**.
28 | 
29 | ## Prerequisites
30 | You must have a **GPU with CUDA support** to run the segmentation code.
31 | 
32 | This code requires **PyTorch**. Please go to [PyTorch.org](https://pytorch.org/) for installation info.
33 | In addition, the following Python packages should be installed:
34 | 1. requests
35 | 2. pillow
36 | 3. numpy
37 | 4. scipy
38 | 5. PyDrive
39 | 
40 | If any of these packages are not installed on your computer, you can install them using the supplied `requirements.txt` file:
41 | ```pip install -r requirements.txt```
42 | 
43 | **Note for Windows users:** make sure that you have a 64-bit Python version installed. Otherwise, you might get a memory error when reading the FFHQ JSON file.
44 | 
45 | ## Usage
46 | 
47 | ### Default download method
48 | To download the dataset at the default resolution (256x256), run:
49 | Linux & Mac: ```./get_ffhq_aging.sh```
50 | Windows: ```get_ffhq_aging.bat```
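Both `download_ffhq_aging.py` and `run_deeplab.py` take a `--resolution` option, and the two values must match. If you prefer to drive both steps from Python, the hypothetical sketch below builds both command lines from a single resolution value so they cannot drift apart. It assumes the two scripts are run in this order from the repository folder and need no arguments beyond the documented flags; the actual `get_ffhq_aging.sh`/`.bat` scripts are not shown here and may perform additional steps that this sketch does not reproduce.

```python
import subprocess
import sys
from typing import List


def build_commands(resolution: int = 256, debug: bool = False) -> List[List[str]]:
    """Return the download and segmentation command lines with one shared resolution.

    Hypothetical helper: assumes both scripts live in the current directory and
    accept the documented --resolution flag; no other arguments are passed.
    """
    if resolution <= 0:
        raise ValueError("resolution must be a positive integer")
    res = str(resolution)
    download = [sys.executable, "download_ffhq_aging.py", "--resolution", res]
    if debug:
        download.append("--debug")  # documented flag: download 50 random images
    segment = [sys.executable, "run_deeplab.py", "--resolution", res]
    return [download, segment]


def run_pipeline(resolution: int = 256, debug: bool = False) -> None:
    """Run the two steps in order, stopping at the first failing command."""
    for cmd in build_commands(resolution, debug):
        subprocess.run(cmd, check=True)  # raises CalledProcessError on non-zero exit
```

Calling `run_pipeline(512)` would run the download first and the segmentation second; because of `check=True`, a failed download raises an exception instead of silently segmenting an incomplete dataset.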
51 | 
52 | If you encounter a "quota exceeded" error, see [Downloading with PyDrive](#downloading-with-pydrive).
53 | 
54 | ### Downloading with PyDrive
55 | Google Drive enforces a download quota for anonymous users.
56 | If you encounter a "quota exceeded" error, either wait 24 hours for the quota limit to reset and try again, or follow the procedure below.
57 | 
58 | #### Step 1: Add the original FFHQ dataset to the "Shared With Me" section of your Google Drive
59 | Note: this step does *not* count against your Google Drive storage limit.
60 | 
61 | * Log in to your Google Drive
62 | * Visit [ffhq-dataset](https://drive.google.com/drive/folders/1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP)
63 | 
64 | #### Step 2: Enable the Google Drive API
65 | Note: this only applies to *your* download script and does not give access to other users.
66 | Nevertheless, we recommend revoking the script's access after the download is complete.
67 | 
68 | * Go to: https://developers.google.com/drive/api/v3/quickstart/python
69 | * Click on "Enable the Drive API"
70 | * Select "Desktop app"
71 | * Download the client configuration
72 | * Rename this file to `client_secrets.json` and place it in the same folder as the download script (`download_ffhq_aging.py`).
73 | 
74 | **Update (4/29/2021): Google has updated this page. Please follow the prerequisites section of the updated page to get the credential files.**
75 | 
76 | #### Step 3: Run the script
77 | * To run the code with authentication, edit the `get_ffhq_aging.sh/bat` script and add the `--pydrive` flag when invoking `download_ffhq_aging.py`. This will open a browser authentication window. Log in to your account and allow access.
78 | * If you have no display (e.g., when running on a remote compute server), edit the `get_ffhq_aging.sh/bat` script and also add the `--cmd_auth` flag when invoking `download_ffhq_aging.py`. This will print a Google authentication link to the screen.
Open the link in any browser, allow access, and paste the Google authentication token back into the command line.
79 | 
80 | **Important Note**: using this will let the code access your Google Drive, which might pose a security risk.
81 | We recommend using it only when the default interface consistently returns a "quota exceeded" error.
82 | In addition, we recommend disabling the Drive API and deleting `client_secrets.json` after the dataset download is complete.
83 | 
84 | ### Optional Arguments
85 | **download_ffhq_aging.py**
86 | ```
87 | --debug                   run in debug mode, download 50 random images (default: False)
88 | --pydrive                 use the PyDrive interface to download files. This can work around the Google Drive quota limitation;
89 |                           requires Google credentials (default: False)
90 | --cmd_auth                use command-line Google authentication with the PyDrive interface;
91 |                           useful when running on a server with no display (default: False)
92 | --check_invalid_images    check for any invalid images and download them again
93 | --resolution              final resolution of saved images (default: 256)
94 | --num_threads NUM         number of concurrent download threads (default: 32)
95 | --num_attempts NUM        number of download attempts per file (default: 10)
96 | ```
97 | 
98 | **run_deeplab.py**
99 | ```
100 | --resolution              segmentation output size (default: 256)
101 | --workers                 number of data loading workers (default: 4)
102 | ```
103 | 
104 | Please make sure that the `--resolution` option is the same for both scripts.
105 | 
106 | 
107 | ## License & Privacy
108 | The dataset is made available under [Creative Commons BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license by University of Washington. You can **use, redistribute, and adapt it for non-commercial purposes**, as long as you (a) give appropriate credit by **citing our paper**, (b) **indicate any changes** that you've made, and (c) distribute any derivative works **under the same license**.
109 | 
110 | The individual images were published in Flickr by their respective authors under either [Creative Commons BY 2.0](https://creativecommons.org/licenses/by/2.0/), [Creative Commons BY-NC 2.0](https://creativecommons.org/licenses/by-nc/2.0/), [Public Domain Mark 1.0](https://creativecommons.org/publicdomain/mark/1.0/), [Public Domain CC0 1.0](https://creativecommons.org/publicdomain/zero/1.0/), or [U.S. Government Works](http://www.usa.gov/copyright.shtml) license. All of these licenses allow **free use, redistribution, and adaptation for non-commercial purposes**. However, some of them require giving **appropriate credit** to the original author, as well as **indicating any changes** that were made to the images. The license and original author of each image are indicated in the metadata.
111 | 
112 | * [https://creativecommons.org/licenses/by/2.0/](https://creativecommons.org/licenses/by/2.0/)
113 | * [https://creativecommons.org/licenses/by-nc/2.0/](https://creativecommons.org/licenses/by-nc/2.0/)
114 | * [https://creativecommons.org/publicdomain/mark/1.0/](https://creativecommons.org/publicdomain/mark/1.0/)
115 | * [https://creativecommons.org/publicdomain/zero/1.0/](https://creativecommons.org/publicdomain/zero/1.0/)
116 | * [http://www.usa.gov/copyright.shtml](http://www.usa.gov/copyright.shtml)
117 | 
118 | The JSON metadata is made available under [Creative Commons BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license by NVIDIA Corporation.
119 | 
120 | The individual images and JSON metadata are hosted on NVIDIA's Google Drive; please see the original [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset) for more details.
121 | 
122 | To find out whether your photo is included in the original Flickr-Faces-HQ dataset and/or to get it removed from both this dataset and the original FFHQ dataset, please go to the privacy section of the original [FFHQ Dataset website](https://github.com/NVlabs/ffhq-dataset) and follow the instructions.
123 | 
124 | ## Acknowledgements
125 | We wish to thank Thevina Dokka for helping us collect the dataset.
126 | 
127 | Original face images were collected in the [NVIDIA FFHQ dataset](https://github.com/NVlabs/ffhq-dataset).
128 | > **A Style-Based Generator Architecture for Generative Adversarial Networks**
129 | > Tero Karras, Samuli Laine, Timo Aila, CVPR 2019
130 | > http://openaccess.thecvf.com/content_CVPR_2019/papers/Karras_A_Style-Based_Generator_Architecture_for_Generative_Adversarial_Networks_CVPR_2019_paper.pdf
131 | 
132 | Age & gender labels and confidence scores were collected using the [Appen](https://www.appen.com/) platform.
133 | 
134 | Head pose, glasses type and eye occlusion scores were extracted using the [Face++](https://www.faceplusplus.com/) platform.
135 | 
136 | Face semantic maps were acquired by training a PyTorch implementation of the [DeepLabV3](https://github.com/chenxi116/DeepLabv3.pytorch) network on the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset.
137 | > **Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation**
138 | > Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam, ECCV 2018
139 | > http://openaccess.thecvf.com/content_ECCV_2018/papers/Liang-Chieh_Chen_Encoder-Decoder_with_Atrous_ECCV_2018_paper.pdf
140 | 
141 | > **MaskGAN: Towards Diverse and Interactive Facial Image Manipulation**
142 | > Cheng-Han Lee, Ziwei Liu, Lingyun Wu, Ping Luo, CVPR 2020
143 | > https://arxiv.org/pdf/1907.11922.pdf
144 | 
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Roy Or-El. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # To view a copy of this license, visit
6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
7 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
8 | 
9 | 
10 | import torch.utils.data as data
11 | import os
12 | from PIL import Image
13 | from utils import preprocess_image
14 | 
15 | 
16 | class CelebASegmentation(data.Dataset):
17 |     CLASSES = ['background' ,'skin','nose','eye_g','l_eye','r_eye','l_brow','r_brow','l_ear','r_ear','mouth','u_lip','l_lip','hair','hat','ear_r','neck_l','neck','cloth']
18 | 
19 |     def __init__(self, root, transform=None, crop_size=None):
20 |         self.root = root
21 |         self.transform = transform
22 |         self.crop_size = crop_size
23 | 
24 |         self.images = []
25 |         subdirs = next(os.walk(self.root))[1] #quick trick to get all subdirectories
26 |         for subdir in subdirs:
27 |             curr_images = [os.path.join(self.root,subdir,file) for file in os.listdir(os.path.join(self.root,subdir)) if file.endswith('.png')]
28 |             self.images += curr_images
29 | 
30 | 
31 |     def __getitem__(self, index):
32 |         _img = Image.open(self.images[index]).convert('RGB')
33 |         _img=_img.resize((513,513),Image.BILINEAR)
34 |         _img = preprocess_image(_img,flip=False,scale=None,crop=(self.crop_size, self.crop_size))
35 | 
36 |         if self.transform is not None:
37 |             _img = self.transform(_img)
38 | 
39 |         return _img
40 | 
41 |     def __len__(self):
42 |         return len(self.images)
43 | 
--------------------------------------------------------------------------------
/deeplab.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Roy Or-El. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # To view a copy of this license, visit 6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 7 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 8 | 9 | # This file was taken as is from the https://github.com/chenxi116/DeepLabv3.pytorch repository. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import math 14 | import torch.utils.model_zoo as model_zoo 15 | from torch.nn import functional as F 16 | 17 | 18 | __all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152'] 19 | 20 | 21 | model_urls = { 22 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 23 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 24 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 25 | } 26 | 27 | 28 | class Conv2d(nn.Conv2d): 29 | 30 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 31 | padding=0, dilation=1, groups=1, bias=True): 32 | super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, 33 | padding, dilation, groups, bias) 34 | 35 | def forward(self, x): 36 | # return super(Conv2d, self).forward(x) 37 | weight = self.weight 38 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, 39 | keepdim=True).mean(dim=3, keepdim=True) 40 | weight = weight - weight_mean 41 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 42 | weight = weight / std.expand_as(weight) 43 | return F.conv2d(x, weight, self.bias, self.stride, 44 | self.padding, self.dilation, self.groups) 45 | 46 | 47 | class ASPP(nn.Module): 48 | 49 | def __init__(self, C, depth, num_classes, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1): 50 | super(ASPP, self).__init__() 51 | self._C = C 52 | self._depth = depth 53 | self._num_classes 
= num_classes 54 | 55 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False) 58 | self.aspp2 = conv(C, depth, kernel_size=3, stride=1, 59 | dilation=int(6*mult), padding=int(6*mult), 60 | bias=False) 61 | self.aspp3 = conv(C, depth, kernel_size=3, stride=1, 62 | dilation=int(12*mult), padding=int(12*mult), 63 | bias=False) 64 | self.aspp4 = conv(C, depth, kernel_size=3, stride=1, 65 | dilation=int(18*mult), padding=int(18*mult), 66 | bias=False) 67 | self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False) 68 | self.aspp1_bn = norm(depth, momentum) 69 | self.aspp2_bn = norm(depth, momentum) 70 | self.aspp3_bn = norm(depth, momentum) 71 | self.aspp4_bn = norm(depth, momentum) 72 | self.aspp5_bn = norm(depth, momentum) 73 | self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1, 74 | bias=False) 75 | self.bn2 = norm(depth, momentum) 76 | self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1) 77 | 78 | def forward(self, x): 79 | x1 = self.aspp1(x) 80 | x1 = self.aspp1_bn(x1) 81 | x1 = self.relu(x1) 82 | x2 = self.aspp2(x) 83 | x2 = self.aspp2_bn(x2) 84 | x2 = self.relu(x2) 85 | x3 = self.aspp3(x) 86 | x3 = self.aspp3_bn(x3) 87 | x3 = self.relu(x3) 88 | x4 = self.aspp4(x) 89 | x4 = self.aspp4_bn(x4) 90 | x4 = self.relu(x4) 91 | x5 = self.global_pooling(x) 92 | x5 = self.aspp5(x5) 93 | x5 = self.aspp5_bn(x5) 94 | x5 = self.relu(x5) 95 | x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear', 96 | align_corners=True)(x5) 97 | x = torch.cat((x1, x2, x3, x4, x5), 1) 98 | x = self.conv2(x) 99 | x = self.bn2(x) 100 | x = self.relu(x) 101 | x = self.conv3(x) 102 | 103 | return x 104 | 105 | 106 | class Bottleneck(nn.Module): 107 | expansion = 4 108 | 109 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, conv=None, norm=None): 110 | super(Bottleneck, self).__init__() 111 | self.conv1 = conv(inplanes, planes, 
kernel_size=1, bias=False) 112 | self.bn1 = norm(planes) 113 | self.conv2 = conv(planes, planes, kernel_size=3, stride=stride, 114 | dilation=dilation, padding=dilation, bias=False) 115 | self.bn2 = norm(planes) 116 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, bias=False) 117 | self.bn3 = norm(planes * self.expansion) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | self.stride = stride 121 | 122 | def forward(self, x): 123 | residual = x 124 | 125 | out = self.conv1(x) 126 | out = self.bn1(out) 127 | out = self.relu(out) 128 | 129 | out = self.conv2(out) 130 | out = self.bn2(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv3(out) 134 | out = self.bn3(out) 135 | 136 | if self.downsample is not None: 137 | residual = self.downsample(x) 138 | 139 | out += residual 140 | out = self.relu(out) 141 | 142 | return out 143 | 144 | 145 | class ResNet(nn.Module): 146 | 147 | def __init__(self, block, layers, num_classes, num_groups=None, weight_std=False, beta=False): 148 | self.inplanes = 64 149 | self.norm = lambda planes, momentum=0.05: nn.BatchNorm2d(planes, momentum=momentum) if num_groups is None else nn.GroupNorm(num_groups, planes) 150 | self.conv = Conv2d if weight_std else nn.Conv2d 151 | 152 | super(ResNet, self).__init__() 153 | if not beta: 154 | self.conv1 = self.conv(3, 64, kernel_size=7, stride=2, padding=3, 155 | bias=False) 156 | else: 157 | self.conv1 = nn.Sequential( 158 | self.conv(3, 64, 3, stride=2, padding=1, bias=False), 159 | self.conv(64, 64, 3, stride=1, padding=1, bias=False), 160 | self.conv(64, 64, 3, stride=1, padding=1, bias=False)) 161 | self.bn1 = self.norm(64) 162 | self.relu = nn.ReLU(inplace=True) 163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 164 | self.layer1 = self._make_layer(block, 64, layers[0]) 165 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 167 | 
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 168 | dilation=2) 169 | self.aspp = ASPP(512 * block.expansion, 256, num_classes, conv=self.conv, norm=self.norm) 170 | 171 | for m in self.modules(): 172 | if isinstance(m, self.conv): 173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 174 | m.weight.data.normal_(0, math.sqrt(2. / n)) 175 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm): 176 | m.weight.data.fill_(1) 177 | m.bias.data.zero_() 178 | 179 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 180 | downsample = None 181 | if stride != 1 or dilation != 1 or self.inplanes != planes * block.expansion: 182 | downsample = nn.Sequential( 183 | self.conv(self.inplanes, planes * block.expansion, 184 | kernel_size=1, stride=stride, dilation=max(1, dilation/2), bias=False), 185 | self.norm(planes * block.expansion), 186 | ) 187 | 188 | layers = [] 189 | layers.append(block(self.inplanes, planes, stride, downsample, dilation=max(1, dilation/2), conv=self.conv, norm=self.norm)) 190 | self.inplanes = planes * block.expansion 191 | for i in range(1, blocks): 192 | layers.append(block(self.inplanes, planes, dilation=dilation, conv=self.conv, norm=self.norm)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | size = (x.shape[2], x.shape[3]) 198 | x = self.conv1(x) 199 | x = self.bn1(x) 200 | x = self.relu(x) 201 | x = self.maxpool(x) 202 | 203 | x = self.layer1(x) 204 | x = self.layer2(x) 205 | x = self.layer3(x) 206 | x = self.layer4(x) 207 | 208 | x = self.aspp(x) 209 | x = nn.Upsample(size, mode='bilinear', align_corners=True)(x) 210 | return x 211 | 212 | 213 | def resnet50(pretrained=False, **kwargs): 214 | """Constructs a ResNet-50 model. 
215 | 216 | Args: 217 | pretrained (bool): If True, returns a model pre-trained on ImageNet 218 | """ 219 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 220 | if pretrained: 221 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 222 | return model 223 | 224 | 225 | def resnet101(pretrained=False, num_groups=None, weight_std=False, **kwargs): 226 | """Constructs a ResNet-101 model. 227 | 228 | Args: 229 | pretrained (bool): If True, returns a model pre-trained on ImageNet 230 | """ 231 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_groups=num_groups, weight_std=weight_std, **kwargs) 232 | if pretrained: 233 | model_dict = model.state_dict() 234 | if num_groups and weight_std: 235 | pretrained_dict = torch.load('deeplab_model/R-101-GN-WS.pth.tar') 236 | overlap_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict} 237 | assert len(overlap_dict) == 312 238 | elif not num_groups and not weight_std: 239 | pretrained_dict = model_zoo.load_url(model_urls['resnet101']) 240 | overlap_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 241 | else: 242 | raise ValueError('Currently only support BN or GN+WS') 243 | model_dict.update(overlap_dict) 244 | model.load_state_dict(model_dict) 245 | return model 246 | 247 | 248 | def resnet152(pretrained=False, **kwargs): 249 | """Constructs a ResNet-152 model. 250 | 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | """ 254 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 255 | if pretrained: 256 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 257 | return model 258 | -------------------------------------------------------------------------------- /download_ffhq_aging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Roy Or-El. All rights reserved. 
2 | # 3 | # This work is licensed under the Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # To view a copy of this license, visit 6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 7 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 8 | 9 | # This code is a modification of the download_ffhq.py file from the original FFHQ dataset. 10 | # Here we download an in-the-wild-image, do the alignment and delete the original in-the-wild image. 11 | 12 | """Download Flickr-Face-HQ-Aging (FFHQ-Aging) dataset to current working directory.""" 13 | 14 | import os 15 | import sys 16 | import requests 17 | import html 18 | import hashlib 19 | import PIL.Image 20 | import PIL.ImageFile 21 | import numpy as np 22 | import scipy.ndimage 23 | import threading 24 | import queue 25 | import time 26 | import json 27 | import uuid 28 | import glob 29 | import argparse 30 | import itertools 31 | import shutil 32 | import pydrive_utils 33 | from collections import OrderedDict, defaultdict 34 | from pdb import set_trace as st 35 | 36 | PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True # avoid "Decompressed Data Too Large" error 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | json_spec = dict(file_url='https://drive.google.com/uc?id=16N0RV4fHI6joBuKbQAoG34V_cQk7vxSA', file_path='ffhq-dataset-v2.json', file_size=267793842, file_md5='425ae20f06a4da1d4dc0f46d40ba5fd6') 41 | 42 | license_specs = { 43 | 'json': dict(file_url='https://drive.google.com/uc?id=1SHafCugkpMZzYhbgOz0zCuYiy-hb9lYX', file_path='LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'), 44 | 'images': dict(file_url='https://drive.google.com/uc?id=1sP2qz8TzLkzG2gjwAa4chtdB31THska4', file_path='images1024x1024/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'), 45 | 'thumbs': dict(file_url='https://drive.google.com/uc?id=1iaL1S381LS10VVtqu-b2WfF9TiY75Kmj', 
file_path='thumbnails128x128/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'), 46 | 'wilds': dict(file_url='https://drive.google.com/uc?id=1rsfFOEQvkd6_Z547qhpq5LhDl2McJEzw', file_path='in-the-wild-images/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'), 47 | 'tfrecords': dict(file_url='https://drive.google.com/uc?id=1SYUmqKdLoTYq-kqsnPsniLScMhspvl5v', file_path='tfrecords/ffhq/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'), 48 | } 49 | 50 | #---------------------------------------------------------------------------- 51 | 52 | def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10): 53 | file_path = file_spec['file_path'] 54 | file_url = file_spec['file_url'] 55 | file_dir = os.path.dirname(file_path) 56 | tmp_path = file_path + '.tmp.' + uuid.uuid4().hex 57 | if file_dir: 58 | os.makedirs(file_dir, exist_ok=True) 59 | 60 | for attempts_left in reversed(range(num_attempts)): 61 | data_size = 0 62 | try: 63 | # Download. 64 | data_md5 = hashlib.md5() 65 | with session.get(file_url, stream=True) as res: 66 | res.raise_for_status() 67 | with open(tmp_path, 'wb') as f: 68 | for chunk in res.iter_content(chunk_size=chunk_size<<10): 69 | f.write(chunk) 70 | data_size += len(chunk) 71 | data_md5.update(chunk) 72 | with stats['lock']: 73 | stats['bytes_done'] += len(chunk) 74 | 75 | # Validate. 
76 | if 'file_size' in file_spec and data_size != file_spec['file_size']: 77 | raise IOError('Incorrect file size', file_path) 78 | if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']: 79 | raise IOError('Incorrect file MD5', file_path) 80 | if 'pixel_size' in file_spec or 'pixel_md5' in file_spec: 81 | with PIL.Image.open(tmp_path) as image: 82 | if 'pixel_size' in file_spec and list(image.size) != file_spec['pixel_size']: 83 | raise IOError('Incorrect pixel size', file_path) 84 | if 'pixel_md5' in file_spec and hashlib.md5(np.array(image)).hexdigest() != file_spec['pixel_md5']: 85 | raise IOError('Incorrect pixel MD5', file_path) 86 | break 87 | 88 | except: 89 | with stats['lock']: 90 | stats['bytes_done'] -= data_size 91 | 92 | # Handle known failure cases. 93 | if data_size > 0 and data_size < 8192: 94 | with open(tmp_path, 'rb') as f: 95 | data = f.read() 96 | data_str = data.decode('utf-8') 97 | 98 | # Google Drive virus checker nag. 99 | links = [html.unescape(link) for link in data_str.split('"') if 'export=download' in link] 100 | if len(links) == 1: 101 | if attempts_left: 102 | file_url = requests.compat.urljoin(file_url, links[0]) 103 | continue 104 | 105 | # Google Drive quota exceeded. 106 | if 'Google Drive - Quota exceeded' in data_str: 107 | if not attempts_left: 108 | raise IOError("Google Drive download quota exceeded -- please try again later") 109 | 110 | # Last attempt => raise error. 111 | if not attempts_left: 112 | raise 113 | 114 | # Rename temp file to the correct name. 115 | os.replace(tmp_path, file_path) # atomic 116 | # with stats['lock']: 117 | # stats['files_done'] += 1 118 | 119 | # Attempt to clean up any leftover temps. 
120 | for filename in glob.glob(file_path + '.tmp.*'): 121 | try: 122 | os.remove(filename) 123 | except: 124 | pass 125 | 126 | #---------------------------------------------------------------------------- 127 | 128 | def choose_bytes_unit(num_bytes): 129 | b = int(np.rint(num_bytes)) 130 | if b < (100 << 0): return 'B', (1 << 0) 131 | if b < (100 << 10): return 'kB', (1 << 10) 132 | if b < (100 << 20): return 'MB', (1 << 20) 133 | if b < (100 << 30): return 'GB', (1 << 30) 134 | return 'TB', (1 << 40) 135 | 136 | #---------------------------------------------------------------------------- 137 | 138 | def format_time(seconds): 139 | s = int(np.rint(seconds)) 140 | if s < 60: return '%ds' % s 141 | if s < 60 * 60: return '%dm %02ds' % (s // 60, s % 60) 142 | if s < 24 * 60 * 60: return '%dh %02dm' % (s // (60 * 60), (s // 60) % 60) 143 | if s < 100 * 24 * 60 * 60: return '%dd %02dh' % (s // (24 * 60 * 60), (s // (60 * 60)) % 24) 144 | return '>100d' 145 | 146 | #---------------------------------------------------------------------------- 147 | 148 | def download_files(file_specs, dst_dir='.', output_size=256, check_invalid_images=False, drive=None, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs): 149 | 150 | # Determine which files to download. 
151 | done_specs = {} 152 | for spec in file_specs: 153 | if os.path.isfile(spec['file_path'].replace('in-the-wild-images',dst_dir)): 154 | if check_invalid_images: 155 | try: 156 | test_im = PIL.Image.open(spec['file_path'].replace('in-the-wild-images',dst_dir)) 157 | done_specs.update({spec['file_path']: spec}) 158 | except: 159 | continue 160 | else: 161 | done_specs.update({spec['file_path']: spec}) 162 | 163 | missing_specs = [spec for spec in file_specs if spec['file_path'] not in done_specs] 164 | files_total = len(file_specs) 165 | bytes_total = sum(spec['file_size'] for spec in file_specs) 166 | stats = dict(files_done=len(done_specs), bytes_done=sum(spec['file_size'] for spec in done_specs.values()), lock=threading.Lock()) 167 | if len(done_specs) == files_total: 168 | print('All files already downloaded -- skipping.') 169 | return 170 | 171 | # Launch worker threads. 172 | spec_queue = queue.Queue() 173 | exception_queue = queue.Queue() 174 | for spec in missing_specs: 175 | spec_queue.put(spec) 176 | thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, 177 | stats=stats, dst_dir=dst_dir, output_size=output_size, 178 | drive=drive, download_kwargs=download_kwargs) 179 | for _thread_idx in range(min(num_threads, len(missing_specs))): 180 | threading.Thread(target=_download_thread, kwargs=thread_kwargs, daemon=True).start() 181 | 182 | # Monitor status until done. 
    bytes_unit, bytes_div = choose_bytes_unit(bytes_total)
    spinner = '/-\\|'
    timing = []
    while True:
        spinner = spinner[1:] + spinner[:1]
        if drive is not None:
            # PyDrive downloads do not report byte counts, so only files are tracked.
            with stats['lock']:
                files_done = stats['files_done']

            print('\r{} done processing {}/{} files'.format(spinner[0], files_done, files_total),
                  end='', flush=True)
        else:
            with stats['lock']:
                files_done = stats['files_done']
                bytes_done = stats['bytes_done']
            # Sliding window of (timestamp, bytes_done) samples for the bandwidth estimate.
            timing = timing[max(len(timing) - timing_window + 1, 0):] + [(time.time(), bytes_done)]
            bandwidth = max((timing[-1][1] - timing[0][1]) / max(timing[-1][0] - timing[0][0], 1e-8), 0)
            bandwidth_unit, bandwidth_div = choose_bytes_unit(bandwidth)
            eta = format_time((bytes_total - bytes_done) / max(bandwidth, 1))

            print('\r%s %6.2f%% done processed %d/%d files %-13s %-10s ETA: %-7s ' % (
                spinner[0],
                bytes_done / bytes_total * 100,
                files_done, files_total,
                'downloaded %.2f/%.2f %s' % (bytes_done / bytes_div, bytes_total / bytes_div, bytes_unit),
                '%.2f %s/s' % (bandwidth / bandwidth_div, bandwidth_unit),
                'done' if bytes_total == bytes_done else '...' if len(timing) < timing_window or bandwidth == 0 else eta,
            ), end='', flush=True)

        if files_done == files_total:
            print()
            break

        # Re-raise the first exception reported by a worker thread, if any.
        try:
            exc_info = exception_queue.get(timeout=status_delay)
            raise exc_info[1].with_traceback(exc_info[2])
        except queue.Empty:
            pass

def _download_thread(spec_queue, exception_queue, stats, dst_dir, output_size, drive, download_kwargs):
    with requests.Session() as session:
        while True:
            # get_nowait() avoids blocking forever when another thread takes the
            # last spec between an empty() check and get().
            try:
                spec = spec_queue.get_nowait()
            except queue.Empty:
                break
            try:
                if drive is not None:
                    pydrive_utils.pydrive_download(drive, spec['file_url'], spec['file_path'])
                else:
                    download_file(session, spec, stats, **download_kwargs)

                if spec['file_path'].endswith('.png'):
                    align_in_the_wild_image(spec, dst_dir, output_size)
                    os.remove(spec['file_path'])

            except Exception:
                exception_queue.put(sys.exc_info())

            with stats['lock']:
                stats['files_done'] += 1

#----------------------------------------------------------------------------

def align_in_the_wild_image(spec, dst_dir, output_size, transform_size=4096, enable_padding=True):
    if not os.path.isdir(dst_dir):
        os.makedirs(dst_dir, exist_ok=True)
        shutil.copyfile('LICENSE.txt', os.path.join(dst_dir, 'LICENSE.txt'))

    item_idx = int(os.path.basename(spec['file_path'])[:-4])

    # Parse landmarks.
    # pylint: disable=unused-variable
    lm = np.array(spec['face_landmarks'])
    lm_chin          = lm[0  : 17]  # left-right
    lm_eyebrow_left  = lm[17 : 22]  # left-right
    lm_eyebrow_right = lm[22 : 27]  # left-right
    lm_nose          = lm[27 : 31]  # top-down
    lm_nostrils      = lm[31 : 36]  # top-down
    lm_eye_left      = lm[36 : 42]  # left-clockwise
    lm_eye_right     = lm[42 : 48]  # left-clockwise
    lm_mouth_outer   = lm[48 : 60]  # left-clockwise
    lm_mouth_inner   = lm[60 : 68]  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left     = np.mean(lm_eye_left, axis=0)
    eye_right    = np.mean(lm_eye_right, axis=0)
    eye_avg      = (eye_left + eye_right) * 0.5
    eye_to_eye   = eye_right - eye_left
    mouth_left   = lm_mouth_outer[0]
    mouth_right  = lm_mouth_outer[6]
    mouth_avg    = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # Choose oriented crop rectangle.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    # The 2.2 factor yields larger crops than the original FFHQ. For the original crops, replace 2.2 with 1.8.
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 2.2)
    y = np.flipud(x) * [-1, 1]
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2

    # Load in-the-wild image.
    src_file = spec['file_path']
    if not os.path.isfile(src_file):
        print('\nCannot find source image %s -- skipping alignment.' % src_file)
        return
    img = PIL.Image.open(src_file)

    # Shrink.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
        img = img.resize(rsize, PIL.Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        quad /= shrink
        qsize /= shrink

    # Crop.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]

    # Pad.
    pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
        img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]

    # Transform.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
    if output_size < transform_size:
        img = img.resize((output_size, output_size), PIL.Image.LANCZOS)  # formerly PIL.Image.ANTIALIAS

    # Save aligned image.
    dst_subdir = os.path.join(dst_dir, '%05d' % (item_idx - item_idx % 1000))
    os.makedirs(dst_subdir, exist_ok=True)
    img.save(os.path.join(dst_subdir, '%05d.png' % item_idx))


#----------------------------------------------------------------------------

def run(resolution, debug, pydrive, cmd_auth, check_invalid_images, **download_kwargs):
    if pydrive:
        drive = pydrive_utils.create_drive_manager(cmd_auth)
    else:
        drive = None

    if not os.path.isfile(json_spec['file_path']) or not os.path.isfile('LICENSE.txt'):
        print('Downloading JSON metadata...')
        download_files([json_spec, license_specs['json']], drive=drive, **download_kwargs)

    print('Parsing JSON metadata...')
    with open(json_spec['file_path'], 'rb') as f:
        json_data = json.load(f, object_pairs_hook=OrderedDict)

    specs = [item['in_the_wild'] for item in json_data.values()] + [license_specs['wilds']]

    if len(specs):
        output_size = resolution
        dst_dir = 'ffhq_aging{}x{}'.format(output_size, output_size)
        np.random.shuffle(specs) # to make the workload more homogeneous
        if debug:
            specs = specs[:50] # to create images in multiple directories
        print('Downloading %d files...' % len(specs))
        download_files(specs, dst_dir, output_size, check_invalid_images, drive=drive, **download_kwargs)

    if os.path.isdir('in-the-wild-images'):
        shutil.rmtree('in-the-wild-images')

#----------------------------------------------------------------------------

def run_cmdline(argv):
    parser = argparse.ArgumentParser(prog=argv[0], description='Download Flickr-Face-HQ-Aging (FFHQ-Aging) dataset to current working directory.')
    parser.add_argument('--debug', help='activate debug mode, download 50 random images (default: False)', action='store_true')
    parser.add_argument('--pydrive', help='use the PyDrive interface to download files; this works around the Google Drive download quota '
                                          'but requires Google credentials (default: False)', action='store_true')
    parser.add_argument('--cmd_auth', help='use command line google authentication when using pydrive interface (default: False)', action='store_true')
    parser.add_argument('--check_invalid_images', help='check for invalid images and download them again (default: False)', action='store_true')
    parser.add_argument('--resolution', help='final resolution of saved images (default: 256)', type=int, default=256, metavar='PIXELS')
    parser.add_argument('--num_threads', help='number of concurrent download threads (default: 32)', type=int, default=32, metavar='NUM')
    parser.add_argument('--status_delay', help='time between download status prints (default: 0.2)', type=float, default=0.2, metavar='SEC')
    parser.add_argument('--timing_window', help='samples for estimating download eta (default: 50)', type=int, default=50, metavar='LEN')
    parser.add_argument('--chunk_size', help='chunk size for each download thread (default: 128)', type=int, default=128, metavar='KB')
    parser.add_argument('--num_attempts', help='number of download attempts per file (default: 10)', type=int, default=10, metavar='NUM')

    args = parser.parse_args(argv[1:])
    run(**vars(args))

#----------------------------------------------------------------------------

if __name__ == "__main__":
    run_cmdline(sys.argv)

#----------------------------------------------------------------------------

--------------------------------------------------------------------------------
/get_ffhq_aging.bat:
--------------------------------------------------------------------------------
@echo off

set CUDA_VISIBLE_DEVICES=0

python download_ffhq_aging.py --resolution 256
python run_deeplab.py --resolution 256

--------------------------------------------------------------------------------
/get_ffhq_aging.sh:
--------------------------------------------------------------------------------
python download_ffhq_aging.py --resolution 256
CUDA_VISIBLE_DEVICES=0 python run_deeplab.py --resolution 256

--------------------------------------------------------------------------------
/images/age_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royorel/FFHQ-Aging-Dataset/2ecdcd2511c7e0da7f2a7cf0d839a9f6faafa645/images/age_distribution.png

--------------------------------------------------------------------------------
/images/dataset_samples_github.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royorel/FFHQ-Aging-Dataset/2ecdcd2511c7e0da7f2a7cf0d839a9f6faafa645/images/dataset_samples_github.png

--------------------------------------------------------------------------------
/pydrive_utils.py:
--------------------------------------------------------------------------------
import re
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive


# Authentication + token creation
def create_drive_manager(cmd_auth):
    gAuth = GoogleAuth()

    if cmd_auth:
        gAuth.CommandLineAuth()
    else:
        gAuth.LocalWebserverAuth()

    gAuth.Authorize()
    print("authorized access to google drive API!")

    drive: GoogleDrive = GoogleDrive(gAuth)
    return drive


def extract_files_id(drive, link):
    # The file ID follows "/d/", "id=" or "rs/" and runs to the next "/" or the end of the link.
    match = re.search(r"(?<=/d/|id=|rs/).+?(?=/|$)", link)
    if match is None:
        raise ValueError('Could not extract a Google Drive file ID from link: %s' % link)
    return match[0]


def pydrive_download(drive, link, save_path):
    file_id = extract_files_id(drive, link)
    file_dir = os.path.dirname(save_path)
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    pydrive_file = drive.CreateFile({'id': file_id})
    pydrive_file.GetContentFile(save_path)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
requests
numpy
scipy
pillow
PyDrive

--------------------------------------------------------------------------------
/run_deeplab.py:
--------------------------------------------------------------------------------
# Copyright (c) 2020, Roy Or-El. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

# This code is a modification of the main.py file
# from the https://github.com/chenxi116/DeepLabv3.pytorch repository

import argparse
import os
import requests
import numpy as np
import torch
from PIL import Image

import deeplab
from data_loader import CelebASegmentation
from utils import download_file


parser = argparse.ArgumentParser()
parser.add_argument('--resolution', type=int, default=256,
                    help='segmentation output size')
parser.add_argument('--workers', type=int, default=4,
                    help='number of data loading workers')
args = parser.parse_args()


resnet_file_spec = dict(file_url='https://drive.google.com/uc?id=1oRGgrI4KNdefbWVpw0rRkEP1gbJIRokM', file_path='deeplab_model/R-101-GN-WS.pth.tar', file_size=178260167, file_md5='aa48cc3d3ba3b7ac357c1489b169eb32')
deeplab_file_spec = dict(file_url='https://drive.google.com/uc?id=1w2XjDywFr2NjuUWaLQDRktH7VwIfuNlY', file_path='deeplab_model/deeplab_model.pth', file_size=464446305, file_md5='8e8345b1b9d95e02780f9bed76cc0293')

def main():
    resolution = args.resolution
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    dataset_root = 'ffhq_aging{}x{}'.format(resolution, resolution)
    assert os.path.isdir(dataset_root)
    dataset = CelebASegmentation(dataset_root, crop_size=513)

    if not os.path.isfile(resnet_file_spec['file_path']):
        print('Downloading backbone Resnet Model parameters')
        with requests.Session() as session:
            download_file(session, resnet_file_spec)

        print('Done!')

    model = getattr(deeplab, 'resnet101')(
            pretrained=True,
            num_classes=len(dataset.CLASSES),
            num_groups=32,
            weight_std=True,
            beta=False)

    model = model.cuda()
    model.eval()
    if not os.path.isfile(deeplab_file_spec['file_path']):
        print('Downloading DeeplabV3 Model parameters')
        with requests.Session() as session:
            download_file(session, deeplab_file_spec)

        print('Done!')

    checkpoint = torch.load(deeplab_file_spec['file_path'])
    # k[7:] strips the 7-character 'module.' prefix left by nn.DataParallel;
    # 'num_batches_tracked' entries are skipped.
    state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'tracked' not in k}
    model.load_state_dict(state_dict)
    with torch.no_grad():  # inference only, no autograd graph needed
        for i in range(len(dataset)):
            inputs = dataset[i].cuda()
            outputs = model(inputs.unsqueeze(0))
            _, pred = torch.max(outputs, 1)
            pred = pred.cpu().numpy().squeeze().astype(np.uint8)
            imname = os.path.basename(dataset.images[i])
            mask_pred = Image.fromarray(pred)
            mask_pred = mask_pred.resize((resolution, resolution), Image.NEAREST)
            # Save the parsing map as <image dir>/parsings/<image name>.png.
            parsings_dir = os.path.join(os.path.dirname(dataset.images[i]), 'parsings')
            os.makedirs(parsings_dir, exist_ok=True)
            mask_pred.save(os.path.join(parsings_dir, os.path.splitext(imname)[0] + '.png'))

            print('processed {0}/{1} images'.format(i + 1, len(dataset)))

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) 2020, Roy Or-El. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

# This code is a modification of the utils.py file
# from the https://github.com/chenxi116/DeepLabv3.pytorch repository


import os
import math
import html
import glob
import uuid
import random
import hashlib
import requests
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image


def preprocess_image(image, flip=False, scale=None, crop=None):
    # Note: `crop` is accepted but currently unused.
    if flip:
        if random.random() < 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
    if scale:
        # Sample a scale factor uniformly in log2 space between scale[0] and scale[1].
        w, h = image.size
        rand_log_scale = math.log(scale[0], 2) + random.random() * (math.log(scale[1], 2) - math.log(scale[0], 2))
        random_scale = math.pow(2, rand_log_scale)
        new_size = (int(round(w * random_scale)), int(round(h * random_scale)))
        image = image.resize(new_size, Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10

    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = data_transforms(image)

    return image


def download_file(session, file_spec, chunk_size=128, num_attempts=10):
    file_path = file_spec['file_path']
    file_url = file_spec['file_url']
    file_dir = os.path.dirname(file_path)
    tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    for attempts_left in reversed(range(num_attempts)):
        data_size = 0
        try:
            # Download.
            data_md5 = hashlib.md5()
            with session.get(file_url, stream=True) as res:
                res.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in res.iter_content(chunk_size=chunk_size<<10):
                        f.write(chunk)
                        data_size += len(chunk)
                        data_md5.update(chunk)

            # Validate.
            if 'file_size' in file_spec and data_size != file_spec['file_size']:
                raise IOError('Incorrect file size', file_path)
            if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
                raise IOError('Incorrect file MD5', file_path)
            break

        except Exception:
            # Last attempt => raise error.
            if not attempts_left:
                raise

            # Handle Google Drive virus checker nag (a small HTML page is returned instead of the file).
            if data_size > 0 and data_size < 8192:
                with open(tmp_path, 'rb') as f:
                    data = f.read()
                links = [html.unescape(link) for link in data.decode('utf-8', errors='ignore').split('"') if 'export=download' in link]
                if len(links) == 1:
                    file_url = requests.compat.urljoin(file_url, links[0])
                    continue

    # Rename temp file to the correct name.
    os.replace(tmp_path, file_path) # atomic

    # Attempt to clean up any leftover temps.
    for filename in glob.glob(file_path + '.tmp.*'):
        try:
            os.remove(filename)
        except OSError:
            pass

--------------------------------------------------------------------------------
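The oriented-crop step in `align_in_the_wild_image` (download_ffhq_aging.py) can be isolated into a small NumPy function for experimentation. The sketch below is a hypothetical helper, `oriented_crop_quad`, that is not part of the repository; it assumes the same 68-point landmark layout the script slices (eyes at indices 36-47, outer mouth starting at 48) and reproduces only the quad and size computation, not the shrink, crop, pad, or transform stages.

```python
import numpy as np


def oriented_crop_quad(landmarks, mouth_scale=2.2):
    """Hypothetical helper: return (quad, qsize) for a (68, 2) landmark array.

    Follows the crop-rectangle step of align_in_the_wild_image. mouth_scale=2.2
    matches the FFHQ-Aging crop; per the source comment, 1.8 gives the original FFHQ crop.
    """
    lm = np.asarray(landmarks, dtype=np.float64)
    eye_left = lm[36:42].mean(axis=0)            # lm_eye_left
    eye_right = lm[42:48].mean(axis=0)           # lm_eye_right
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    mouth_avg = (lm[48] + lm[54]) * 0.5          # lm_mouth_outer[0] and lm_mouth_outer[6]
    eye_to_mouth = mouth_avg - eye_avg

    # np.flipud(v) * [-1, 1] maps (a, b) to (-b, a): v rotated by 90 degrees.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)                            # unit "horizontal" direction of the crop
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * mouth_scale)
    y = np.flipud(x) * [-1, 1]                   # perpendicular half-axis of equal length
    c = eye_avg + eye_to_mouth * 0.1             # crop center, nudged toward the mouth
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2                     # side length of the square crop
    return quad, qsize


# Synthetic, axis-aligned face: eye centers at y=40, mouth corners at y=80.
lm = np.zeros((68, 2))
lm[36:42] = (30, 40)
lm[42:48] = (70, 40)
lm[48] = (35, 80)
lm[54] = (65, 80)
quad, qsize = oriented_crop_quad(lm)
print(quad, qsize)
```

For this face the mouth term dominates (40 * 2.2 > 40 * 2.0), so the crop size is driven by the eye-to-mouth distance; with `mouth_scale=1.8` the eye distance wins instead and the crop is tighter.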
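The status line printed by `download_files` estimates bandwidth from a sliding window of `(timestamp, bytes_done)` samples and derives the ETA from it. A minimal, self-contained sketch of that estimate follows; `BandwidthMeter` is a hypothetical class, and its explicit `now` argument is an addition (not in the script) so the behavior can be checked without real time passing.

```python
import time


class BandwidthMeter:
    """Hypothetical sliding-window throughput estimate, modeled on download_files.

    Keeps at most `window` (timestamp, bytes_done) samples and reports the average
    rate between the oldest and the newest sample.
    """

    def __init__(self, window=50):
        self.window = window
        self.samples = []

    def update(self, bytes_done, now=None):
        """Record a sample and return the current bandwidth in bytes per second."""
        now = time.time() if now is None else now
        # Trim so that, after appending, at most `window` samples remain.
        self.samples = self.samples[max(len(self.samples) - self.window + 1, 0):] + [(now, bytes_done)]
        (t0, b0), (t1, b1) = self.samples[0], self.samples[-1]
        # The 1e-8 floor avoids division by zero when only one sample exists.
        return max((b1 - b0) / max(t1 - t0, 1e-8), 0)

    @staticmethod
    def eta_seconds(bytes_total, bytes_done, bandwidth):
        """Remaining time; the 1 B/s floor mirrors max(bandwidth, 1) in the script."""
        return (bytes_total - bytes_done) / max(bandwidth, 1)


meter = BandwidthMeter(window=3)
for t, done in [(0.0, 0), (1.0, 100), (2.0, 300), (3.0, 600)]:
    bw = meter.update(done, now=t)
print(bw, meter.eta_seconds(1000, 600, bw))  # 250.0 1.6
```

Because the oldest sample is dropped once the window is full, the estimate tracks recent throughput rather than the average since the start, which is why the script shows `...` instead of an ETA until `timing_window` samples have been collected.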
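Each worker in `download_files` drains a shared `queue.Queue`, and failures travel through a second queue that the main thread re-raises. The sketch below shows that pattern in isolation. `run_workers` and its `work` callback are hypothetical names, and joining the threads is a simplifying assumption; the script instead polls the exception queue inside its status loop. Using `get_nowait()` lets a worker exit cleanly when the queue runs dry instead of blocking in `get()`.

```python
import queue
import sys
import threading


def run_workers(items, work, num_threads=4):
    """Hypothetical helper: process `items` with daemon threads draining a shared queue.

    Exceptions raised by `work` are collected in a second queue, and the first one
    is re-raised in the calling thread with its original traceback.
    """
    items = list(items)
    tasks = queue.Queue()
    for item in items:
        tasks.put(item)
    errors = queue.Queue()
    results = []
    lock = threading.Lock()

    def worker():
        while True:
            try:
                item = tasks.get_nowait()  # no race between empty() and get()
            except queue.Empty:
                return
            try:
                value = work(item)
                with lock:
                    results.append(value)
            except Exception:
                errors.put(sys.exc_info())

    threads = [threading.Thread(target=worker, daemon=True)
               for _ in range(min(num_threads, len(items)))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    if not errors.empty():
        exc_info = errors.get()
        raise exc_info[1].with_traceback(exc_info[2])
    return results


print(sorted(run_workers(range(5), lambda n: n * n, num_threads=2)))  # [0, 1, 4, 9, 16]
```

Results arrive in completion order, so callers that care about ordering should sort or key them, as the usage line does.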