├── .gitignore
├── LICENSE.txt
├── README.md
├── data_loader.py
├── deeplab.py
├── download_ffhq_aging.py
├── ffhq_aging_labels.csv
├── get_ffhq_aging.bat
├── get_ffhq_aging.sh
├── images
│   ├── age_distribution.png
│   └── dataset_samples_github.png
├── pydrive_utils.py
├── requirements.txt
├── run_deeplab.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ffhq_aging*/
2 | in-the-wild-images/
3 | deeplab_model/
4 | __pycache__/
5 | *.json
6 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright Roy Or-El, 2020
2 |
3 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
4 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
5 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
6 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
7 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
8 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
9 | USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
11 | The dataset is made available under Creative Commons BY-NC-SA 4.0 license
12 | by University of Washington. You can use, redistribute, and adapt it
13 | for non-commercial purposes, as long as you (a) give appropriate credit
14 | by citing our paper, (b) indicate any changes that you've made,
15 | and (c) distribute any derivative works under the same license.
16 |
17 | Lifespan Age Transformation Synthesis
18 | Roy Or-El, Soumyadip Sengupta, Ohad Fried, Eli Shechtman, Ira Kemelmacher-Shlizerman
19 | https://arxiv.org/pdf/2003.09764.pdf
20 |
21 | The individual images were published in Flickr by their respective authors
22 | under either Creative Commons BY 2.0, Creative Commons BY-NC 2.0,
23 | Public Domain Mark 1.0, Public Domain CC0 1.0, or U.S. Government Works
24 | license. All of these licenses allow free use, redistribution, and adaptation
25 | for non-commercial purposes. However, some of them require giving appropriate
26 | credit to the original author, as well as indicating any changes that were
27 | made to the images. The license and original author of each image are
28 | indicated in the metadata.
29 |
30 | https://creativecommons.org/licenses/by/2.0/
31 | https://creativecommons.org/licenses/by-nc/2.0/
32 | https://creativecommons.org/publicdomain/mark/1.0/
33 | https://creativecommons.org/publicdomain/zero/1.0/
34 | http://www.usa.gov/copyright.shtml
35 |
36 | The JSON metadata is made available under Creative Commons BY-NC-SA 4.0 license by NVIDIA Corporation.
37 |
38 | The individual images and JSON metadata are hosted on NVIDIA's Google Drive,
39 | please see the original FFHQ dataset for more details.
40 |
41 | https://github.com/NVlabs/ffhq-dataset
42 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FFHQ-Aging Dataset
2 | ### [Project Page](https://grail.cs.washington.edu/projects/lifespan_age_transformation_synthesis/) | [Paper](https://arxiv.org/pdf/2003.09764.pdf) | [Lifespan Age Transformation Synthesis Code](https://github.com/royorel/Lifespan_Age_Transformation_Synthesis)
3 | 
4 |
5 | ## Overview
6 | FFHQ-Aging is a dataset of human faces designed for benchmarking age transformation algorithms, as well as many other possible vision tasks.
7 |
8 | This dataset is an extension of the NVIDIA [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset). On top of the 70,000 original FFHQ images, it also contains the following information for each image:
9 | 1. Gender information (male/female with confidence score)
10 | 2. Age group information (10 classes with confidence score)
11 | 3. Head pose (pitch, roll & yaw)
12 | 4. Glasses type (none, normal or dark)
13 | 5. Eye occlusion score (0-100, different score for each eye)
14 | 6. Full semantic map (19 classes, based on CelebAMask-HQ labels; see the decoding sketch below)
15 |
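A saved semantic map can be decoded back into class names with a few lines of Python. This is a minimal sketch: it assumes the dataset was downloaded at the default 256x256 resolution and `run_deeplab.py` was already run, and the example path below is hypothetical. The class order is taken from `CelebASegmentation.CLASSES` in `data_loader.py`.

```
import numpy as np
from PIL import Image

# Class order taken from data_loader.CelebASegmentation.CLASSES
CLASSES = ['background', 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow',
           'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair',
           'hat', 'ear_r', 'neck_l', 'neck', 'cloth']

# Hypothetical example path; run_deeplab.py writes one parsing map per image
mask = np.array(Image.open('ffhq_aging256x256/00000/parsings/00000.png'))
for idx in np.unique(mask):
    print(idx, CLASSES[idx])  # pixel value -> semantic class name
```
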
16 | If you use this dataset for your work, please cite our paper:
17 | > **Lifespan Age Transformation Synthesis**
18 | > Roy Or-El, Soumyadip Sengupta, Ohad Fried, Eli Shechtman, Ira Kemelmacher-Shlizerman
19 | > ECCV 2020
20 | > https://arxiv.org/pdf/2003.09764.pdf
21 |
22 | ## Dataset Statistics
23 | The following histogram shows the age class distribution per gender.
24 |
25 | 
26 |
27 | The gender label & confidence score, age class label & confidence score, head pose, glasses type, and left & right eye occlusion scores for each individual image are stored in **ffhq_aging_labels.csv**.
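For a quick look at the labels, the file can be read with Python's standard `csv` module. This is a minimal sketch; the exact column names are not listed here, so inspect the header first:

```
import csv

with open('ffhq_aging_labels.csv', newline='') as f:
    reader = csv.DictReader(f)
    print(reader.fieldnames)  # inspect the actual column names
    print(next(reader))       # labels of the first image
```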
28 |
29 | ## Prerequisites
30 | You must have a **GPU with CUDA support** in order to run the segmentation code.
31 |
32 | This code requires **PyTorch**; please go to [Pytorch.org](https://pytorch.org/) for installation info.
33 | In addition, the following Python packages should be installed:
34 | 1. requests
35 | 2. pillow
36 | 3. numpy
37 | 4. scipy
38 | 5. PyDrive
39 |
40 | If any of these packages are not installed on your computer, you can install them using the supplied `requirements.txt` file:
41 | ```pip install -r requirements.txt```
42 |
43 | **Note for Windows users:** make sure that you have a 64-bit Python version installed. Otherwise you might get a memory error when reading the FFHQ JSON file.
44 |
45 | ## Usage
46 |
47 | ### Default download method
48 | To download the dataset in the default resolution (256x256) run:
49 | Linux & Mac: ```./get_ffhq_aging.sh```
50 | Windows: ```get_ffhq_aging.bat```
51 |
52 | If you encounter a "quota exceeded" error, see [Downloading with PyDrive](#downloading-with-pydrive).
53 |
54 | ### Downloading with PyDrive
55 | Google Drive enforces a quota on file downloads by anonymous users.
56 | If you encounter a "quota exceeded" error, either wait 24 hours for the quota limit to reset and try again, or follow the procedure below.
57 |
58 | #### Step 1: Add the original FFHQ dataset to the "Shared With Me" section of your Google Drive
59 | Note: this step does *not* count against your Google Drive storage limit.
60 |
61 | * Log in to your Google Drive
62 | * Visit [ffhq-dataset](https://drive.google.com/drive/folders/1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP)
63 |
64 | #### Step 2: Enable the Google Drive API
65 | Note: this only applies to *your* download script, and does not give access to other users.
66 | Nevertheless, we recommend revoking the script's access after the download is complete.
67 |
68 | * Go to: https://developers.google.com/drive/api/v3/quickstart/python
69 | * Click on enable drive API
70 | * Select Desktop app
71 | * Download client configuration
72 | * Rename this file to `client_secrets.json` and place it in the same folder as the download script (`download_ffhq_aging.py`).
73 |
74 | **Update (4/29/2021): Google has updated this page. Please follow the prerequisites section of the updated page to get the credentials file.**
75 |
76 | #### Step 3: Run the script
77 | * In order to run the code with authentication, edit the `get_ffhq_aging.sh/bat` script and add the `--pydrive` flag when invoking `download_ffhq_aging.py` (see the example below). This will open a browser authentication window. Log in to your account and allow access.
78 | * If you have no display (e.g. when running from a remote compute server), edit the `get_ffhq_aging.sh/bat` script and also add the `--cmd_auth` flag when invoking `download_ffhq_aging.py`. This will print a Google authentication link to the screen. Open the link in any browser, allow access, and paste the Google authentication token back into the command line.
79 |
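For example, the download line in `get_ffhq_aging.sh` would become the following (a sketch; include `--cmd_auth` only when running without a display):

```
python download_ffhq_aging.py --resolution 256 --pydrive --cmd_auth
```
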
80 | **Important Note**: using this will let the code access your Google Drive, which might pose a security risk.
81 | We recommend using it only when the default interface consistently returns a "quota exceeded" error.
82 | In addition, we recommend disabling the Drive API and deleting `client_secrets.json` after the dataset download is complete.
83 |
84 | ### Optional Arguments
85 | **download_ffhq_aging.py**
86 | ```
87 | --debug                 run in debug mode, download 50 random images (default: False)
88 | --pydrive               use the PyDrive interface to download files; it can override the Google Drive
89 |                         quota limitation, but requires Google credentials (default: False)
90 | --cmd_auth              use command line Google authentication when using the PyDrive interface;
91 |                         useful when running on a server with no display (default: False)
92 | --check_invalid_images  check for any invalid images and download them again (default: False)
93 | --resolution            final resolution of saved images (default: 256)
94 | --num_threads NUM       number of concurrent download threads (default: 32)
95 | --num_attempts NUM      number of download attempts per file (default: 10)
96 | ```
97 |
98 | **run_deeplab.py**
99 | ```
100 | --resolution segmentation output size (default: 256)
101 | --workers number of data loading workers (default: 4)
102 | ```
103 |
104 | Please make sure that the `--resolution` option for both scripts is the same, as in the example below.
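For example, to build the dataset at 512x512 instead of the default (a sketch mirroring `get_ffhq_aging.sh`; both scripts must receive the same value):

```
python download_ffhq_aging.py --resolution 512
CUDA_VISIBLE_DEVICES=0 python run_deeplab.py --resolution 512
```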
105 |
106 |
107 | ## License & Privacy
108 | The dataset is made available under [Creative Commons BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license by University of Washington. You can **use, redistribute, and adapt it for non-commercial purposes**, as long as you (a) give appropriate credit by **citing our paper**, (b) **indicate any changes** that you've made, and (c) distribute any derivative works **under the same license**.
109 |
110 | The individual images were published in Flickr by their respective authors under either [Creative Commons BY 2.0](https://creativecommons.org/licenses/by/2.0/), [Creative Commons BY-NC 2.0](https://creativecommons.org/licenses/by-nc/2.0/), [Public Domain Mark 1.0](https://creativecommons.org/publicdomain/mark/1.0/), [Public Domain CC0 1.0](https://creativecommons.org/publicdomain/zero/1.0/), or [U.S. Government Works](http://www.usa.gov/copyright.shtml) license. All of these licenses allow **free use, redistribution, and adaptation for non-commercial purposes**. However, some of them require giving **appropriate credit** to the original author, as well as **indicating any changes** that were made to the images. The license and original author of each image are indicated in the metadata.
111 |
112 | * [https://creativecommons.org/licenses/by/2.0/](https://creativecommons.org/licenses/by/2.0/)
113 | * [https://creativecommons.org/licenses/by-nc/2.0/](https://creativecommons.org/licenses/by-nc/2.0/)
114 | * [https://creativecommons.org/publicdomain/mark/1.0/](https://creativecommons.org/publicdomain/mark/1.0/)
115 | * [https://creativecommons.org/publicdomain/zero/1.0/](https://creativecommons.org/publicdomain/zero/1.0/)
116 | * [http://www.usa.gov/copyright.shtml](http://www.usa.gov/copyright.shtml)
117 |
118 | The JSON metadata is made available under [Creative Commons BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license by NVIDIA Corporation.
119 |
120 | The individual images and JSON metadata are hosted on NVIDIA's Google Drive, please see the original [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset) for more details.
121 |
122 | To find out whether your photo is included in the original Flickr-Faces-HQ dataset and/or get it removed from both this dataset and the original FFHQ dataset please go to the privacy section in the original [FFHQ Dataset website](https://github.com/NVlabs/ffhq-dataset) and follow the instructions.
123 |
124 | ## Acknowledgements
125 | We wish to thank Thevina Dokka for helping us collect the dataset.
126 |
127 | Original face images were collected in the [NVIDIA FFHQ dataset](https://github.com/NVlabs/ffhq-dataset).
128 | > **A Style-Based Generator Architecture for Generative Adversarial Networks**
129 | > Tero Karras, Samuli Laine, Timo Aila, CVPR 2019
130 | > http://openaccess.thecvf.com/content_CVPR_2019/papers/Karras_A_Style-Based_Generator_Architecture_for_Generative_Adversarial_Networks_CVPR_2019_paper.pdf
131 |
132 | Age & gender labels and confidence scores were collected using the [Appen](https://www.appen.com/) platform.
133 |
134 | Head pose, glasses type and eye occlusion score were extracted using the [Face++](https://www.faceplusplus.com/) platform.
135 |
136 | Face semantic maps were acquired by training a PyTorch implementation of the [DeepLabV3](https://github.com/chenxi116/DeepLabv3.pytorch) network on the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset.
137 | > **Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation**
138 | > Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam, ECCV 2018
139 | > http://openaccess.thecvf.com/content_ECCV_2018/papers/Liang-Chieh_Chen_Encoder-Decoder_with_Atrous_ECCV_2018_paper.pdf
140 |
141 | > **MaskGAN: Towards Diverse and Interactive Facial Image Manipulation**
142 | > Cheng-Han Lee, Ziwei Liu, Lingyun Wu, Ping Luo, CVPR 2020
143 | > https://arxiv.org/pdf/1907.11922.pdf
144 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Roy Or-El. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # To view a copy of this license, visit
6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
7 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
8 |
9 |
10 | import torch.utils.data as data
11 | import os
12 | from PIL import Image
13 | from utils import preprocess_image
14 |
15 |
16 | class CelebASegmentation(data.Dataset):
17 |     CLASSES = ['background', 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
18 |
19 | def __init__(self, root, transform=None, crop_size=None):
20 | self.root = root
21 | self.transform = transform
22 | self.crop_size = crop_size
23 |
24 | self.images = []
25 |         subdirs = next(os.walk(self.root))[1]  # quick trick to get all subdirectories
26 | for subdir in subdirs:
27 | curr_images = [os.path.join(self.root,subdir,file) for file in os.listdir(os.path.join(self.root,subdir)) if file.endswith('.png')]
28 | self.images += curr_images
29 |
30 |
31 | def __getitem__(self, index):
32 | _img = Image.open(self.images[index]).convert('RGB')
33 |         _img = _img.resize((513, 513), Image.BILINEAR)  # resize to the 513x513 input size the DeepLab model expects
34 |         _img = preprocess_image(_img, flip=False, scale=None, crop=(self.crop_size, self.crop_size))
35 |
36 | if self.transform is not None:
37 | _img = self.transform(_img)
38 |
39 | return _img
40 |
41 | def __len__(self):
42 | return len(self.images)
43 |
--------------------------------------------------------------------------------
/deeplab.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Roy Or-El. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # To view a copy of this license, visit
6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
7 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
8 |
9 | # This file was taken as is from the https://github.com/chenxi116/DeepLabv3.pytorch repository.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import math
14 | import torch.utils.model_zoo as model_zoo
15 | from torch.nn import functional as F
16 |
17 |
18 | __all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
19 |
20 |
21 | model_urls = {
22 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
23 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
24 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
25 | }
26 |
27 |
28 | class Conv2d(nn.Conv2d):
29 |
30 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
31 | padding=0, dilation=1, groups=1, bias=True):
32 | super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
33 | padding, dilation, groups, bias)
34 |
35 | def forward(self, x):
36 |         # Weight Standardization: standardize each filter to zero mean and unit std before convolving
37 | weight = self.weight
38 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
39 | keepdim=True).mean(dim=3, keepdim=True)
40 | weight = weight - weight_mean
41 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
42 | weight = weight / std.expand_as(weight)
43 | return F.conv2d(x, weight, self.bias, self.stride,
44 | self.padding, self.dilation, self.groups)
45 |
46 |
47 | class ASPP(nn.Module):
48 |
49 | def __init__(self, C, depth, num_classes, conv=nn.Conv2d, norm=nn.BatchNorm2d, momentum=0.0003, mult=1):
50 | super(ASPP, self).__init__()
51 | self._C = C
52 | self._depth = depth
53 | self._num_classes = num_classes
54 |
55 | self.global_pooling = nn.AdaptiveAvgPool2d(1)
56 | self.relu = nn.ReLU(inplace=True)
57 | self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False)
58 | self.aspp2 = conv(C, depth, kernel_size=3, stride=1,
59 | dilation=int(6*mult), padding=int(6*mult),
60 | bias=False)
61 | self.aspp3 = conv(C, depth, kernel_size=3, stride=1,
62 | dilation=int(12*mult), padding=int(12*mult),
63 | bias=False)
64 | self.aspp4 = conv(C, depth, kernel_size=3, stride=1,
65 | dilation=int(18*mult), padding=int(18*mult),
66 | bias=False)
67 | self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False)
68 | self.aspp1_bn = norm(depth, momentum)
69 | self.aspp2_bn = norm(depth, momentum)
70 | self.aspp3_bn = norm(depth, momentum)
71 | self.aspp4_bn = norm(depth, momentum)
72 | self.aspp5_bn = norm(depth, momentum)
73 | self.conv2 = conv(depth * 5, depth, kernel_size=1, stride=1,
74 | bias=False)
75 | self.bn2 = norm(depth, momentum)
76 | self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1)
77 |
78 | def forward(self, x):
79 | x1 = self.aspp1(x)
80 | x1 = self.aspp1_bn(x1)
81 | x1 = self.relu(x1)
82 | x2 = self.aspp2(x)
83 | x2 = self.aspp2_bn(x2)
84 | x2 = self.relu(x2)
85 | x3 = self.aspp3(x)
86 | x3 = self.aspp3_bn(x3)
87 | x3 = self.relu(x3)
88 | x4 = self.aspp4(x)
89 | x4 = self.aspp4_bn(x4)
90 | x4 = self.relu(x4)
91 | x5 = self.global_pooling(x)
92 | x5 = self.aspp5(x5)
93 | x5 = self.aspp5_bn(x5)
94 | x5 = self.relu(x5)
95 | x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
96 | align_corners=True)(x5)
97 | x = torch.cat((x1, x2, x3, x4, x5), 1)
98 | x = self.conv2(x)
99 | x = self.bn2(x)
100 | x = self.relu(x)
101 | x = self.conv3(x)
102 |
103 | return x
104 |
105 |
106 | class Bottleneck(nn.Module):
107 | expansion = 4
108 |
109 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, conv=None, norm=None):
110 | super(Bottleneck, self).__init__()
111 | self.conv1 = conv(inplanes, planes, kernel_size=1, bias=False)
112 | self.bn1 = norm(planes)
113 | self.conv2 = conv(planes, planes, kernel_size=3, stride=stride,
114 | dilation=dilation, padding=dilation, bias=False)
115 | self.bn2 = norm(planes)
116 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, bias=False)
117 | self.bn3 = norm(planes * self.expansion)
118 | self.relu = nn.ReLU(inplace=True)
119 | self.downsample = downsample
120 | self.stride = stride
121 |
122 | def forward(self, x):
123 | residual = x
124 |
125 | out = self.conv1(x)
126 | out = self.bn1(out)
127 | out = self.relu(out)
128 |
129 | out = self.conv2(out)
130 | out = self.bn2(out)
131 | out = self.relu(out)
132 |
133 | out = self.conv3(out)
134 | out = self.bn3(out)
135 |
136 | if self.downsample is not None:
137 | residual = self.downsample(x)
138 |
139 | out += residual
140 | out = self.relu(out)
141 |
142 | return out
143 |
144 |
145 | class ResNet(nn.Module):
146 |
147 | def __init__(self, block, layers, num_classes, num_groups=None, weight_std=False, beta=False):
148 | self.inplanes = 64
149 | self.norm = lambda planes, momentum=0.05: nn.BatchNorm2d(planes, momentum=momentum) if num_groups is None else nn.GroupNorm(num_groups, planes)
150 | self.conv = Conv2d if weight_std else nn.Conv2d
151 |
152 | super(ResNet, self).__init__()
153 | if not beta:
154 | self.conv1 = self.conv(3, 64, kernel_size=7, stride=2, padding=3,
155 | bias=False)
156 | else:
157 | self.conv1 = nn.Sequential(
158 | self.conv(3, 64, 3, stride=2, padding=1, bias=False),
159 | self.conv(64, 64, 3, stride=1, padding=1, bias=False),
160 | self.conv(64, 64, 3, stride=1, padding=1, bias=False))
161 | self.bn1 = self.norm(64)
162 | self.relu = nn.ReLU(inplace=True)
163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
164 | self.layer1 = self._make_layer(block, 64, layers[0])
165 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
167 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
168 | dilation=2)
169 | self.aspp = ASPP(512 * block.expansion, 256, num_classes, conv=self.conv, norm=self.norm)
170 |
171 | for m in self.modules():
172 | if isinstance(m, self.conv):
173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
174 | m.weight.data.normal_(0, math.sqrt(2. / n))
175 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
176 | m.weight.data.fill_(1)
177 | m.bias.data.zero_()
178 |
179 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
180 | downsample = None
181 | if stride != 1 or dilation != 1 or self.inplanes != planes * block.expansion:
182 | downsample = nn.Sequential(
183 | self.conv(self.inplanes, planes * block.expansion,
184 | kernel_size=1, stride=stride, dilation=max(1, dilation/2), bias=False),
185 | self.norm(planes * block.expansion),
186 | )
187 |
188 | layers = []
189 | layers.append(block(self.inplanes, planes, stride, downsample, dilation=max(1, dilation/2), conv=self.conv, norm=self.norm))
190 | self.inplanes = planes * block.expansion
191 | for i in range(1, blocks):
192 | layers.append(block(self.inplanes, planes, dilation=dilation, conv=self.conv, norm=self.norm))
193 |
194 | return nn.Sequential(*layers)
195 |
196 | def forward(self, x):
197 | size = (x.shape[2], x.shape[3])
198 | x = self.conv1(x)
199 | x = self.bn1(x)
200 | x = self.relu(x)
201 | x = self.maxpool(x)
202 |
203 | x = self.layer1(x)
204 | x = self.layer2(x)
205 | x = self.layer3(x)
206 | x = self.layer4(x)
207 |
208 | x = self.aspp(x)
209 | x = nn.Upsample(size, mode='bilinear', align_corners=True)(x)
210 | return x
211 |
212 |
213 | def resnet50(pretrained=False, **kwargs):
214 | """Constructs a ResNet-50 model.
215 |
216 | Args:
217 | pretrained (bool): If True, returns a model pre-trained on ImageNet
218 | """
219 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
220 | if pretrained:
221 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
222 | return model
223 |
224 |
225 | def resnet101(pretrained=False, num_groups=None, weight_std=False, **kwargs):
226 | """Constructs a ResNet-101 model.
227 |
228 | Args:
229 | pretrained (bool): If True, returns a model pre-trained on ImageNet
230 | """
231 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_groups=num_groups, weight_std=weight_std, **kwargs)
232 | if pretrained:
233 | model_dict = model.state_dict()
234 | if num_groups and weight_std:
235 | pretrained_dict = torch.load('deeplab_model/R-101-GN-WS.pth.tar')
236 |         overlap_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}  # strip the 'module.' prefix
237 | assert len(overlap_dict) == 312
238 | elif not num_groups and not weight_std:
239 | pretrained_dict = model_zoo.load_url(model_urls['resnet101'])
240 | overlap_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
241 | else:
242 | raise ValueError('Currently only support BN or GN+WS')
243 | model_dict.update(overlap_dict)
244 | model.load_state_dict(model_dict)
245 | return model
246 |
247 |
248 | def resnet152(pretrained=False, **kwargs):
249 | """Constructs a ResNet-152 model.
250 |
251 | Args:
252 | pretrained (bool): If True, returns a model pre-trained on ImageNet
253 | """
254 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
255 | if pretrained:
256 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
257 | return model
258 |
--------------------------------------------------------------------------------
/download_ffhq_aging.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Roy Or-El. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # To view a copy of this license, visit
6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
7 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
8 |
9 | # This code is a modification of the download_ffhq.py file from the original FFHQ dataset.
10 | # Here we download an in-the-wild-image, do the alignment and delete the original in-the-wild image.
11 |
12 | """Download Flickr-Face-HQ-Aging (FFHQ-Aging) dataset to current working directory."""
13 |
14 | import os
15 | import sys
16 | import requests
17 | import html
18 | import hashlib
19 | import PIL.Image
20 | import PIL.ImageFile
21 | import numpy as np
22 | import scipy.ndimage
23 | import threading
24 | import queue
25 | import time
26 | import json
27 | import uuid
28 | import glob
29 | import argparse
30 | import itertools
31 | import shutil
32 | import pydrive_utils
33 | from collections import OrderedDict, defaultdict
34 | from pdb import set_trace as st
35 |
36 | PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True # avoid "Decompressed Data Too Large" error
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | json_spec = dict(file_url='https://drive.google.com/uc?id=16N0RV4fHI6joBuKbQAoG34V_cQk7vxSA', file_path='ffhq-dataset-v2.json', file_size=267793842, file_md5='425ae20f06a4da1d4dc0f46d40ba5fd6')
41 |
42 | license_specs = {
43 | 'json': dict(file_url='https://drive.google.com/uc?id=1SHafCugkpMZzYhbgOz0zCuYiy-hb9lYX', file_path='LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'),
44 | 'images': dict(file_url='https://drive.google.com/uc?id=1sP2qz8TzLkzG2gjwAa4chtdB31THska4', file_path='images1024x1024/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'),
45 | 'thumbs': dict(file_url='https://drive.google.com/uc?id=1iaL1S381LS10VVtqu-b2WfF9TiY75Kmj', file_path='thumbnails128x128/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'),
46 | 'wilds': dict(file_url='https://drive.google.com/uc?id=1rsfFOEQvkd6_Z547qhpq5LhDl2McJEzw', file_path='in-the-wild-images/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'),
47 | 'tfrecords': dict(file_url='https://drive.google.com/uc?id=1SYUmqKdLoTYq-kqsnPsniLScMhspvl5v', file_path='tfrecords/ffhq/LICENSE.txt', file_size=1610, file_md5='724f3831aaecd61a84fe98500079abc2'),
48 | }
49 |
50 | #----------------------------------------------------------------------------
51 |
52 | def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10):
53 | file_path = file_spec['file_path']
54 | file_url = file_spec['file_url']
55 | file_dir = os.path.dirname(file_path)
56 | tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
57 | if file_dir:
58 | os.makedirs(file_dir, exist_ok=True)
59 |
60 | for attempts_left in reversed(range(num_attempts)):
61 | data_size = 0
62 | try:
63 | # Download.
64 | data_md5 = hashlib.md5()
65 | with session.get(file_url, stream=True) as res:
66 | res.raise_for_status()
67 | with open(tmp_path, 'wb') as f:
68 | for chunk in res.iter_content(chunk_size=chunk_size<<10):
69 | f.write(chunk)
70 | data_size += len(chunk)
71 | data_md5.update(chunk)
72 | with stats['lock']:
73 | stats['bytes_done'] += len(chunk)
74 |
75 | # Validate.
76 | if 'file_size' in file_spec and data_size != file_spec['file_size']:
77 | raise IOError('Incorrect file size', file_path)
78 | if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
79 | raise IOError('Incorrect file MD5', file_path)
80 | if 'pixel_size' in file_spec or 'pixel_md5' in file_spec:
81 | with PIL.Image.open(tmp_path) as image:
82 | if 'pixel_size' in file_spec and list(image.size) != file_spec['pixel_size']:
83 | raise IOError('Incorrect pixel size', file_path)
84 | if 'pixel_md5' in file_spec and hashlib.md5(np.array(image)).hexdigest() != file_spec['pixel_md5']:
85 | raise IOError('Incorrect pixel MD5', file_path)
86 | break
87 |
88 | except:
89 | with stats['lock']:
90 | stats['bytes_done'] -= data_size
91 |
92 | # Handle known failure cases.
93 | if data_size > 0 and data_size < 8192:
94 | with open(tmp_path, 'rb') as f:
95 | data = f.read()
96 | data_str = data.decode('utf-8')
97 |
98 | # Google Drive virus checker nag.
99 | links = [html.unescape(link) for link in data_str.split('"') if 'export=download' in link]
100 | if len(links) == 1:
101 | if attempts_left:
102 | file_url = requests.compat.urljoin(file_url, links[0])
103 | continue
104 |
105 | # Google Drive quota exceeded.
106 | if 'Google Drive - Quota exceeded' in data_str:
107 | if not attempts_left:
108 | raise IOError("Google Drive download quota exceeded -- please try again later")
109 |
110 | # Last attempt => raise error.
111 | if not attempts_left:
112 | raise
113 |
114 | # Rename temp file to the correct name.
115 | os.replace(tmp_path, file_path) # atomic
116 | # with stats['lock']:
117 | # stats['files_done'] += 1
118 |
119 | # Attempt to clean up any leftover temps.
120 | for filename in glob.glob(file_path + '.tmp.*'):
121 | try:
122 | os.remove(filename)
123 | except:
124 | pass
125 |
126 | #----------------------------------------------------------------------------
127 |
128 | def choose_bytes_unit(num_bytes):
129 | b = int(np.rint(num_bytes))
130 | if b < (100 << 0): return 'B', (1 << 0)
131 | if b < (100 << 10): return 'kB', (1 << 10)
132 | if b < (100 << 20): return 'MB', (1 << 20)
133 | if b < (100 << 30): return 'GB', (1 << 30)
134 | return 'TB', (1 << 40)
135 |
136 | #----------------------------------------------------------------------------
137 |
138 | def format_time(seconds):
139 | s = int(np.rint(seconds))
140 | if s < 60: return '%ds' % s
141 | if s < 60 * 60: return '%dm %02ds' % (s // 60, s % 60)
142 | if s < 24 * 60 * 60: return '%dh %02dm' % (s // (60 * 60), (s // 60) % 60)
143 | if s < 100 * 24 * 60 * 60: return '%dd %02dh' % (s // (24 * 60 * 60), (s // (60 * 60)) % 24)
144 | return '>100d'
145 |
146 | #----------------------------------------------------------------------------
147 |
148 | def download_files(file_specs, dst_dir='.', output_size=256, check_invalid_images=False, drive=None, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):
149 |
150 | # Determine which files to download.
151 | done_specs = {}
152 | for spec in file_specs:
153 | if os.path.isfile(spec['file_path'].replace('in-the-wild-images',dst_dir)):
154 | if check_invalid_images:
155 | try:
156 | test_im = PIL.Image.open(spec['file_path'].replace('in-the-wild-images',dst_dir))
157 | done_specs.update({spec['file_path']: spec})
158 | except:
159 | continue
160 | else:
161 | done_specs.update({spec['file_path']: spec})
162 |
163 | missing_specs = [spec for spec in file_specs if spec['file_path'] not in done_specs]
164 | files_total = len(file_specs)
165 | bytes_total = sum(spec['file_size'] for spec in file_specs)
166 | stats = dict(files_done=len(done_specs), bytes_done=sum(spec['file_size'] for spec in done_specs.values()), lock=threading.Lock())
167 | if len(done_specs) == files_total:
168 | print('All files already downloaded -- skipping.')
169 | return
170 |
171 | # Launch worker threads.
172 | spec_queue = queue.Queue()
173 | exception_queue = queue.Queue()
174 | for spec in missing_specs:
175 | spec_queue.put(spec)
176 | thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue,
177 | stats=stats, dst_dir=dst_dir, output_size=output_size,
178 | drive=drive, download_kwargs=download_kwargs)
179 | for _thread_idx in range(min(num_threads, len(missing_specs))):
180 | threading.Thread(target=_download_thread, kwargs=thread_kwargs, daemon=True).start()
181 |
182 | # Monitor status until done.
183 | bytes_unit, bytes_div = choose_bytes_unit(bytes_total)
184 | spinner = '/-\\|'
185 | timing = []
186 | while True:
187 | spinner = spinner[1:] + spinner[:1]
188 |         if drive is not None:
189 | with stats['lock']:
190 | files_done = stats['files_done']
191 |
192 | print('\r{} done processing {}/{} files'.format(spinner[0], files_done, files_total),
193 | end='', flush=True)
194 | else:
195 | with stats['lock']:
196 | files_done = stats['files_done']
197 | bytes_done = stats['bytes_done']
198 | timing = timing[max(len(timing) - timing_window + 1, 0):] + [(time.time(), bytes_done)]
199 | bandwidth = max((timing[-1][1] - timing[0][1]) / max(timing[-1][0] - timing[0][0], 1e-8), 0)
200 | bandwidth_unit, bandwidth_div = choose_bytes_unit(bandwidth)
201 | eta = format_time((bytes_total - bytes_done) / max(bandwidth, 1))
202 |
203 | print('\r%s %6.2f%% done processed %d/%d files %-13s %-10s ETA: %-7s ' % (
204 | spinner[0],
205 | bytes_done / bytes_total * 100,
206 | files_done, files_total,
207 | 'downloaded %.2f/%.2f %s' % (bytes_done / bytes_div, bytes_total / bytes_div, bytes_unit),
208 | '%.2f %s/s' % (bandwidth / bandwidth_div, bandwidth_unit),
209 | 'done' if bytes_total == bytes_done else '...' if len(timing) < timing_window or bandwidth == 0 else eta,
210 | ), end='', flush=True)
211 |
212 | if files_done == files_total:
213 | print()
214 | break
215 |
216 |
217 | try:
218 | exc_info = exception_queue.get(timeout=status_delay)
219 | raise exc_info[1].with_traceback(exc_info[2])
220 | except queue.Empty:
221 | pass
222 |
223 | def _download_thread(spec_queue, exception_queue, stats, dst_dir, output_size, drive, download_kwargs):
224 | with requests.Session() as session:
225 | while not spec_queue.empty():
226 | spec = spec_queue.get()
227 | try:
228 |                 if drive is not None:
229 | pydrive_utils.pydrive_download(drive, spec['file_url'], spec['file_path'])
230 | else:
231 | download_file(session, spec, stats, **download_kwargs)
232 |
233 | if spec['file_path'].endswith('.png'):
234 | align_in_the_wild_image(spec, dst_dir, output_size)
235 | os.remove(spec['file_path'])
236 |
237 | except:
238 | exception_queue.put(sys.exc_info())
239 |
240 | with stats['lock']:
241 | stats['files_done'] += 1
242 |
243 | #----------------------------------------------------------------------------
244 |
245 | def align_in_the_wild_image(spec, dst_dir, output_size, transform_size=4096, enable_padding=True):
246 | if not os.path.isdir(dst_dir):
247 | os.makedirs(dst_dir, exist_ok=True)
248 | shutil.copyfile('LICENSE.txt', os.path.join(dst_dir, 'LICENSE.txt'))
249 |
250 | item_idx = int(os.path.basename(spec['file_path'])[:-4])
251 |
252 | # Parse landmarks.
253 | # pylint: disable=unused-variable
254 | lm = np.array(spec['face_landmarks'])
255 | lm_chin = lm[0 : 17] # left-right
256 | lm_eyebrow_left = lm[17 : 22] # left-right
257 | lm_eyebrow_right = lm[22 : 27] # left-right
258 | lm_nose = lm[27 : 31] # top-down
259 | lm_nostrils = lm[31 : 36] # top-down
260 | lm_eye_left = lm[36 : 42] # left-clockwise
261 | lm_eye_right = lm[42 : 48] # left-clockwise
262 | lm_mouth_outer = lm[48 : 60] # left-clockwise
263 | lm_mouth_inner = lm[60 : 68] # left-clockwise
264 |
265 | # Calculate auxiliary vectors.
266 | eye_left = np.mean(lm_eye_left, axis=0)
267 | eye_right = np.mean(lm_eye_right, axis=0)
268 | eye_avg = (eye_left + eye_right) * 0.5
269 | eye_to_eye = eye_right - eye_left
270 | mouth_left = lm_mouth_outer[0]
271 | mouth_right = lm_mouth_outer[6]
272 | mouth_avg = (mouth_left + mouth_right) * 0.5
273 | eye_to_mouth = mouth_avg - eye_avg
274 |
275 | # Choose oriented crop rectangle.
276 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
277 | x /= np.hypot(*x)
278 |     x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 2.2) # This results in larger crops than the original FFHQ. For the original crops, replace 2.2 with 1.8
279 | y = np.flipud(x) * [-1, 1]
280 | c = eye_avg + eye_to_mouth * 0.1
281 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
282 | qsize = np.hypot(*x) * 2
283 |
284 | # Load in-the-wild image.
285 | src_file = spec['file_path']
286 | if not os.path.isfile(src_file):
287 | print('\nCannot find source image. Please run "--wilds" before "--align".')
288 | return
289 | img = PIL.Image.open(src_file)
290 |
291 | # Shrink.
292 | shrink = int(np.floor(qsize / output_size * 0.5))
293 | if shrink > 1:
294 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
295 | img = img.resize(rsize, PIL.Image.ANTIALIAS)
296 | quad /= shrink
297 | qsize /= shrink
298 |
299 | # Crop.
300 | border = max(int(np.rint(qsize * 0.1)), 3)
301 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
302 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
303 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
304 | img = img.crop(crop)
305 | quad -= crop[0:2]
306 |
307 | # Pad.
308 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
309 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
310 | if enable_padding and max(pad) > border - 4:
311 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
312 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
313 | h, w, _ = img.shape
314 | y, x, _ = np.ogrid[:h, :w, :1]
315 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
316 | blur = qsize * 0.02
317 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
318 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
319 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
320 | quad += pad[:2]
321 |
322 | # Transform.
323 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
324 | if output_size < transform_size:
325 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
326 |
327 | # Save aligned image.
328 | dst_subdir = os.path.join(dst_dir, '%05d' % (item_idx - item_idx % 1000))
329 | os.makedirs(dst_subdir, exist_ok=True)
330 | img.save(os.path.join(dst_subdir, '%05d.png' % item_idx))
331 |
332 |
333 | #----------------------------------------------------------------------------
334 |
335 | def run(resolution, debug, pydrive, cmd_auth, check_invalid_images, **download_kwargs):
336 | if pydrive:
337 | drive = pydrive_utils.create_drive_manager(cmd_auth)
338 | else:
339 | drive = None
340 |
341 | if not os.path.isfile(json_spec['file_path']) or not os.path.isfile('LICENSE.txt'):
342 | print('Downloading JSON metadata...')
343 | download_files([json_spec, license_specs['json']], drive=drive, **download_kwargs)
344 |
345 | print('Parsing JSON metadata...')
346 | with open(json_spec['file_path'], 'rb') as f:
347 | json_data = json.load(f, object_pairs_hook=OrderedDict)
348 |
349 | specs = [item['in_the_wild'] for item in json_data.values()] + [license_specs['wilds']]
350 |
351 | if len(specs):
352 | output_size = resolution
353 | dst_dir = 'ffhq_aging{}x{}'.format(output_size,output_size)
354 | np.random.shuffle(specs) # to make the workload more homogeneous
355 | if debug:
356 | specs = specs[:50] # to create images in multiple directories
357 | print('Downloading %d files...' % len(specs))
358 | download_files(specs, dst_dir, output_size, check_invalid_images, drive=drive, **download_kwargs)
359 |
360 | if os.path.isdir('in-the-wild-images'):
361 | shutil.rmtree('in-the-wild-images')
362 |
363 | #----------------------------------------------------------------------------
364 |
365 | def run_cmdline(argv):
366 | parser = argparse.ArgumentParser(prog=argv[0], description='Download Flickr-Face-HQ-Aging (FFHQ-Aging) dataset to current working directory.')
367 | parser.add_argument('--debug', help='activate debug mode, download 50 random images (default: False)', action='store_true')
368 |     parser.add_argument('--pydrive', help='use pydrive interface to download files; it can override the Google Drive quota limitation, but \
369 |                         requires Google credentials (default: False)', action='store_true')
370 | parser.add_argument('--cmd_auth', help='use command line google authentication when using pydrive interface (default: False)', action='store_true')
371 | parser.add_argument('--check_invalid_images', help='checks for any invalid images and downloads them again', action='store_true')
372 | parser.add_argument('--resolution', help='final resolution of saved images (default: 256)', type=int, default=256, metavar='PIXELS')
373 | parser.add_argument('--num_threads', help='number of concurrent download threads (default: 32)', type=int, default=32, metavar='NUM')
374 | parser.add_argument('--status_delay', help='time between download status prints (default: 0.2)', type=float, default=0.2, metavar='SEC')
375 | parser.add_argument('--timing_window', help='samples for estimating download eta (default: 50)', type=int, default=50, metavar='LEN')
376 | parser.add_argument('--chunk_size', help='chunk size for each download thread (default: 128)', type=int, default=128, metavar='KB')
377 | parser.add_argument('--num_attempts', help='number of download attempts per file (default: 10)', type=int, default=10, metavar='NUM')
378 |
379 | args = parser.parse_args()
380 | run(**vars(args))
381 |
382 | #----------------------------------------------------------------------------
383 |
384 | if __name__ == "__main__":
385 | run_cmdline(sys.argv)
386 |
387 | #----------------------------------------------------------------------------
388 |
--------------------------------------------------------------------------------
/get_ffhq_aging.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | set CUDA_VISIBLE_DEVICES=0
4 |
5 | python download_ffhq_aging.py --resolution 256
6 | python run_deeplab.py --resolution 256
7 |
--------------------------------------------------------------------------------
/get_ffhq_aging.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python download_ffhq_aging.py --resolution 256
3 | CUDA_VISIBLE_DEVICES=0 python run_deeplab.py --resolution 256
4 |
--------------------------------------------------------------------------------
/images/age_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royorel/FFHQ-Aging-Dataset/2ecdcd2511c7e0da7f2a7cf0d839a9f6faafa645/images/age_distribution.png
--------------------------------------------------------------------------------
/images/dataset_samples_github.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/royorel/FFHQ-Aging-Dataset/2ecdcd2511c7e0da7f2a7cf0d839a9f6faafa645/images/dataset_samples_github.png
--------------------------------------------------------------------------------
/pydrive_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | from pydrive.auth import GoogleAuth
4 | from pydrive.drive import GoogleDrive
5 |
6 |
7 | # Authentication + token creation
8 | def create_drive_manager(cmd_auth):
9 | gAuth = GoogleAuth()
10 |
11 | if cmd_auth:
12 | gAuth.CommandLineAuth()
13 | else:
14 | gAuth.LocalWebserverAuth()
15 |
16 | gAuth.Authorize()
17 | print("authorized access to google drive API!")
18 |
19 | drive: GoogleDrive = GoogleDrive(gAuth)
20 | return drive
21 |
22 |
23 | def extract_files_id(drive, link):
24 | try:
25 |         fileID = re.search(r"(?<=/d/|id=|rs/).+?(?=/|$)", link)[0]  # extract the file ID from the share link
26 |         return fileID
27 |     except Exception as error:
28 |         print("error: " + str(error))
29 | print("Link is probably invalid")
30 | print(link)
31 |
32 |
33 | def pydrive_download(drive, link, save_path):
34 |     file_id = extract_files_id(drive, link)
35 |     file_dir = os.path.dirname(save_path)
36 |     if file_dir:
37 |         os.makedirs(file_dir, exist_ok=True)
38 |
39 |     pydrive_file = drive.CreateFile({'id': file_id})
40 |     pydrive_file.GetContentFile(save_path)
41 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | numpy
3 | scipy
4 | pillow
5 | PyDrive
6 |
--------------------------------------------------------------------------------
/run_deeplab.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Roy Or-El. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # To view a copy of this license, visit
6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
7 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
8 |
9 | # This code is a modification of the main.py file
10 | # from the https://github.com/chenxi116/DeepLabv3.pytorch repository
11 |
12 | import argparse
13 | import os
14 | import requests
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | from pdb import set_trace as st
19 | from PIL import Image
20 | from torchvision import transforms
21 |
22 | import deeplab
23 | from data_loader import CelebASegmentation
24 | from utils import download_file
25 |
26 |
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--resolution', type=int, default=256,
29 | help='segmentation output size')
30 | parser.add_argument('--workers', type=int, default=4,
31 | help='number of data loading workers')
32 | args = parser.parse_args()
33 |
34 |
35 | resnet_file_spec = dict(file_url='https://drive.google.com/uc?id=1oRGgrI4KNdefbWVpw0rRkEP1gbJIRokM', file_path='deeplab_model/R-101-GN-WS.pth.tar', file_size=178260167, file_md5='aa48cc3d3ba3b7ac357c1489b169eb32')
36 | deeplab_file_spec = dict(file_url='https://drive.google.com/uc?id=1w2XjDywFr2NjuUWaLQDRktH7VwIfuNlY', file_path='deeplab_model/deeplab_model.pth', file_size=464446305, file_md5='8e8345b1b9d95e02780f9bed76cc0293')
37 |
38 | def main():
39 | resolution = args.resolution
40 | assert torch.cuda.is_available()
41 | torch.backends.cudnn.benchmark = True
42 | model_fname = 'deeplab_model/deeplab_model.pth'
43 | dataset_root = 'ffhq_aging{}x{}'.format(resolution,resolution)
44 | assert os.path.isdir(dataset_root)
45 | dataset = CelebASegmentation(dataset_root, crop_size=513)
46 |
47 | if not os.path.isfile(resnet_file_spec['file_path']):
48 | print('Downloading backbone Resnet Model parameters')
49 | with requests.Session() as session:
50 | download_file(session, resnet_file_spec)
51 |
52 | print('Done!')
53 |
54 | model = getattr(deeplab, 'resnet101')(
55 | pretrained=True,
56 | num_classes=len(dataset.CLASSES),
57 | num_groups=32,
58 | weight_std=True,
59 | beta=False)
60 |
61 | model = model.cuda()
62 | model.eval()
63 | if not os.path.isfile(deeplab_file_spec['file_path']):
64 | print('Downloading DeeplabV3 Model parameters')
65 | with requests.Session() as session:
66 | download_file(session, deeplab_file_spec)
67 |
68 | print('Done!')
69 |
70 | checkpoint = torch.load(model_fname)
71 |     state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'tracked' not in k}  # strip the 'module.' prefix and drop 'num_batches_tracked' entries
72 | model.load_state_dict(state_dict)
73 | for i in range(len(dataset)):
74 |         inputs = dataset[i]
75 | inputs = inputs.cuda()
76 | outputs = model(inputs.unsqueeze(0))
77 | _, pred = torch.max(outputs, 1)
78 | pred = pred.data.cpu().numpy().squeeze().astype(np.uint8)
79 | imname = os.path.basename(dataset.images[i])
80 | mask_pred = Image.fromarray(pred)
81 |         mask_pred = mask_pred.resize((resolution, resolution), Image.NEAREST)  # nearest-neighbor keeps label indices intact
82 | try:
83 | mask_pred.save(dataset.images[i].replace(imname,'parsings/'+imname[:-4]+'.png'))
84 |         except FileNotFoundError:
85 |             os.makedirs(os.path.join(os.path.dirname(dataset.images[i]), 'parsings'))  # create the parsings folder on the first save into it
86 | mask_pred.save(dataset.images[i].replace(imname,'parsings/'+imname[:-4]+'.png'))
87 |
88 | print('processed {0}/{1} images'.format(i + 1, len(dataset)))
89 |
90 | if __name__ == "__main__":
91 | main()
92 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Roy Or-El. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # To view a copy of this license, visit
6 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to
7 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
8 |
9 | # This code is a modification of the utils.py file
10 | # from the https://github.com/chenxi116/DeepLabv3.pytorch repository
11 |
12 |
13 | import os
14 | import math
15 | import html
16 | import glob
17 | import uuid
18 | import random
19 | import hashlib
20 | import requests
21 | import numpy as np
22 | import torch
23 | import torchvision.transforms as transforms
24 | from PIL import Image
25 |
26 |
27 | def preprocess_image(image, flip=False, scale=None, crop=None):
28 | if flip:
29 | if random.random() < 0.5:
30 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
31 | if scale:
32 | w, h = image.size
33 | rand_log_scale = math.log(scale[0], 2) + random.random() * (math.log(scale[1], 2) - math.log(scale[0], 2))
34 | random_scale = math.pow(2, rand_log_scale)
35 | new_size = (int(round(w * random_scale)), int(round(h * random_scale)))
36 | image = image.resize(new_size, Image.ANTIALIAS)
37 |
38 | data_transforms = transforms.Compose([
39 | transforms.ToTensor(),
40 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
41 | ])
42 | image = data_transforms(image)
43 |
44 | return image
45 |
46 |
47 | def download_file(session, file_spec, chunk_size=128, num_attempts=10):
48 | file_path = file_spec['file_path']
49 | file_url = file_spec['file_url']
50 | file_dir = os.path.dirname(file_path)
51 | tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
52 | if file_dir:
53 | os.makedirs(file_dir, exist_ok=True)
54 |
55 | for attempts_left in reversed(range(num_attempts)):
56 | data_size = 0
57 | try:
58 | # Download.
59 | data_md5 = hashlib.md5()
60 | with session.get(file_url, stream=True) as res:
61 | res.raise_for_status()
62 | with open(tmp_path, 'wb') as f:
63 | for chunk in res.iter_content(chunk_size=chunk_size<<10):
64 | f.write(chunk)
65 | data_size += len(chunk)
66 | data_md5.update(chunk)
67 |
68 | # Validate.
69 | if 'file_size' in file_spec and data_size != file_spec['file_size']:
70 | raise IOError('Incorrect file size', file_path)
71 | if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
72 | raise IOError('Incorrect file MD5', file_path)
73 | break
74 |
75 | except:
76 | # Last attempt => raise error.
77 | if not attempts_left:
78 | raise
79 |
80 | # Handle Google Drive virus checker nag.
81 | if data_size > 0 and data_size < 8192:
82 | with open(tmp_path, 'rb') as f:
83 | data = f.read()
84 | links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'export=download' in link]
85 | if len(links) == 1:
86 | file_url = requests.compat.urljoin(file_url, links[0])
87 | continue
88 |
89 | # Rename temp file to the correct name.
90 | os.replace(tmp_path, file_path) # atomic
91 |
92 | # Attempt to clean up any leftover temps.
93 | for filename in glob.glob(file_path + '.tmp.*'):
94 | try:
95 | os.remove(filename)
96 | except:
97 | pass
98 |
--------------------------------------------------------------------------------