├── backbones
│   ├── __init__.py
│   └── iresnet.py
├── images
│   └── img.jpg
├── main.py
├── README.md
└── .gitignore
/backbones/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jankolf/ser-fiq-pytorch/HEAD/images/img.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Imports: Project
2 | from backbones.iresnet import iresnet50, iresnet18
3 |
4 | # Imports: Python
5 | from pathlib import Path
6 |
7 | # Imports: Installed Packages
8 | import torch
9 |
10 | from torchvision import io
11 | from torchvision import transforms
12 |
13 |
14 | def load_image(img_path: str) -> torch.Tensor:
15 |     image: torch.Tensor = io.read_image(img_path)  # uint8 CHW tensor in [0, 255]
16 |     image = image.float() / 255.0                   # scale to [0, 1]
17 |     image = normalize(image)                        # normalize is defined in the __main__ block below
18 |     return image
19 |
20 |
21 | if __name__ == "__main__":
22 |
23 | torch_device = torch.device("cpu")
24 |     normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # maps [0, 1] inputs to [-1, 1]
25 |
26 |     resnet = iresnet18(dropout=0.4, num_features=512, use_se=False).to(torch_device)
27 | resnet.load_state_dict(
28 | torch.load("checkpoints/resnet18.pth", map_location=torch_device)
29 | )
30 | resnet.eval()
31 |
32 | image = load_image("images/img.jpg").unsqueeze(dim=0)
33 |
34 |     scores = resnet.calculate_serfiq(image, T=10, scaling=5.0)  # T stochastic forward passes per image
35 |
36 | print(f"SER-FIQ Score: {scores[0].item():.8f}")
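37 |
38 |     # Illustrative extension (a sketch, not part of the original example):
39 |     # calculate_serfiq accepts a batch of shape (N, 3, H, W) and returns one
40 |     # score per image, so a whole folder of aligned faces can be scored at once.
41 |     image_paths = sorted(Path("images").glob("*.jpg"))
42 |     batch = torch.stack([load_image(str(p)) for p in image_paths])
43 |     batch_scores = resnet.calculate_serfiq(batch, T=10, scaling=5.0)
44 |     for path, score in zip(image_paths, batch_scores):
45 |         print(f"{path.name}: {score.item():.8f}")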
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Unofficial PyTorch reimplementation of
2 | # SER-FIQ: Unsupervised Estimation of Face Image Quality Based on Stochastic Embedding Robustness
3 | ### CVPR 2020
4 |
5 |
6 | Authors:
7 |
8 | Philipp Terhörst, Jan Niklas Kolf, Naser Damer, Florian Kirchbuchner, Arjan Kuijper
9 |
10 |
11 | <img src="images/img.jpg">
12 |
13 | ### Reimplementation
14 | SER-FIQ is implemented for the iResNet architecture; pre-trained iResNet18 and iResNet50 checkpoints are provided.
15 | The scaling of the SER-FIQ scores differs from the original implementation and needs to be adapted to the dataset/model.
16 |
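17 | As a quick illustration (a sketch following the setup in `main.py`, not part of the original README): a higher `T` averages more stochastic forward passes and stabilizes the score, while `scaling` stretches the score distribution:
18 |
19 | ```
20 | # assumes `resnet` and `image` are prepared as in main.py
21 | for scaling in (2.0, 5.0, 8.0):
22 |     scores = resnet.calculate_serfiq(image, T=10, scaling=scaling)
23 |     print(f"scaling={scaling}: {scores[0].item():.4f}")
24 | ```
25 |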
26 | ### Run
27 | To run the example code,
28 | 1. download the [model checkpoints](https://share.jankolf.de/s/F64PNQjsQLpmGLW)
29 | 2. create a folder `checkpoints` and copy the downloaded file into it
30 | 3. execute the example code with
31 | ```
32 | python main.py
33 | ```
34 |
35 | ### Dependencies
36 | Tested with PyTorch 2.0 and a matching torchvision release.
37 |
38 | ### License
39 | This project is licensed under the terms of the Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0) license.
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/backbones/iresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100']
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8 | """3x3 convolution with padding"""
9 | return nn.Conv2d(in_planes,
10 | out_planes,
11 | kernel_size=3,
12 | stride=stride,
13 | padding=dilation,
14 | groups=groups,
15 | bias=False,
16 | dilation=dilation)
17 |
18 |
19 | def conv1x1(in_planes, out_planes, stride=1):
20 | """1x1 convolution"""
21 |     return nn.Conv2d(in_planes, out_planes,
22 |                      kernel_size=1, stride=stride,
23 |                      bias=False)
24 |
25 |
26 | class SEModule(nn.Module):
27 | def __init__(self, channels, reduction):
28 | super(SEModule, self).__init__()
29 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
30 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
33 | self.sigmoid = nn.Sigmoid()
34 |
35 | def forward(self, x):
36 |         module_input = x
37 |         x = self.avg_pool(x)
38 |         x = self.fc1(x)
39 |         x = self.relu(x)
40 |         x = self.fc2(x)
41 |         x = self.sigmoid(x)
42 |
43 |         return module_input * x
44 |
45 | class IBasicBlock(nn.Module):
46 | expansion = 1
47 | def __init__(self, inplanes, planes, stride=1, downsample=None,
48 | groups=1, base_width=64, dilation=1,use_se=False):
49 | super(IBasicBlock, self).__init__()
50 | if groups != 1 or base_width != 64:
51 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
52 | if dilation > 1:
53 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
54 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
55 | self.conv1 = conv3x3(inplanes, planes)
56 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
57 | self.prelu = nn.PReLU(planes)
58 | self.conv2 = conv3x3(planes, planes, stride)
59 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
60 | self.downsample = downsample
61 | self.stride = stride
62 |         self.use_se = use_se
63 |         if use_se:
64 |             self.se_block = SEModule(planes, 16)
65 |
66 | def forward(self, x):
67 | identity = x
68 | out = self.bn1(x)
69 | out = self.conv1(out)
70 | out = self.bn2(out)
71 | out = self.prelu(out)
72 | out = self.conv2(out)
73 | out = self.bn3(out)
74 |         if self.use_se:
75 |             out = self.se_block(out)
76 | if self.downsample is not None:
77 | identity = self.downsample(x)
78 | out += identity
79 | return out
80 |
81 |
82 | class IResNet(nn.Module):
83 |     fc_scale = 7 * 7  # spatial size of the final feature map for 112x112 inputs
84 | def __init__(self,
85 | block, layers, dropout=0, num_features=512, zero_init_residual=False,
86 | groups=1, width_per_group=64, replace_stride_with_dilation=None, use_se=False):
87 | super(IResNet, self).__init__()
88 | self.inplanes = 64
89 | self.dilation = 1
90 |         self.use_se = use_se
91 | if replace_stride_with_dilation is None:
92 | replace_stride_with_dilation = [False, False, False]
93 | if len(replace_stride_with_dilation) != 3:
94 | raise ValueError("replace_stride_with_dilation should be None "
95 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
96 | self.groups = groups
97 | self.base_width = width_per_group
98 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
99 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
100 | self.prelu = nn.PReLU(self.inplanes)
101 |         self.layer1 = self._make_layer(block, 64, layers[0], stride=2, use_se=self.use_se)
102 | self.layer2 = self._make_layer(block,
103 | 128,
104 | layers[1],
105 | stride=2,
106 |                                        dilate=replace_stride_with_dilation[0], use_se=self.use_se)
107 | self.layer3 = self._make_layer(block,
108 | 256,
109 | layers[2],
110 | stride=2,
111 |                                        dilate=replace_stride_with_dilation[1], use_se=self.use_se)
112 | self.layer4 = self._make_layer(block,
113 | 512,
114 | layers[3],
115 | stride=2,
116 |                                        dilate=replace_stride_with_dilation[2], use_se=self.use_se)
117 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
118 |         self.dropout = nn.Dropout(p=dropout, inplace=True)  # applied to the flattened 7x7x512 feature map
119 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
120 | self.num_features = num_features
121 | self.features = nn.BatchNorm1d(num_features, eps=1e-05)
122 | nn.init.constant_(self.features.weight, 1.0)
123 | self.features.weight.requires_grad = False
124 |
125 | for m in self.modules():
126 | if isinstance(m, nn.Conv2d):
127 | nn.init.normal_(m.weight, 0, 0.1)
128 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
129 | nn.init.constant_(m.weight, 1)
130 | nn.init.constant_(m.bias, 0)
131 |
132 | if zero_init_residual:
133 | for m in self.modules():
134 | if isinstance(m, IBasicBlock):
135 | nn.init.constant_(m.bn2.weight, 0)
136 |
137 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False, use_se=False):
138 | downsample = None
139 | previous_dilation = self.dilation
140 | if dilate:
141 | self.dilation *= stride
142 | stride = 1
143 | if stride != 1 or self.inplanes != planes * block.expansion:
144 | downsample = nn.Sequential(
145 | conv1x1(self.inplanes, planes * block.expansion, stride),
146 |                 nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
147 | )
148 | layers = []
149 | layers.append(
150 | block(self.inplanes, planes, stride, downsample, self.groups,
151 |                   self.base_width, previous_dilation, use_se=use_se))
152 | self.inplanes = planes * block.expansion
153 | for _ in range(1, blocks):
154 | layers.append(
155 | block(self.inplanes,
156 | planes,
157 | groups=self.groups,
158 | base_width=self.base_width,
159 |                       dilation=self.dilation, use_se=use_se))
160 |
161 | return nn.Sequential(*layers)
162 |
163 | def forward(self, x):
164 | x = self.conv1(x)
165 | x = self.bn1(x)
166 | x = self.prelu(x)
167 | x = self.layer1(x)
168 | x = self.layer2(x)
169 | x = self.layer3(x)
170 | x = self.layer4(x)
171 | x = self.bn2(x)
172 | x = torch.flatten(x, 1)
173 | x = self.dropout(x)
174 | x = self.fc(x)
175 | x = self.features(x)
176 | return x
177 |
178 |
179 | @torch.no_grad()
180 | def calculate_serfiq(self,
181 |                          x: torch.Tensor,
182 |                          T: int = 100,
183 |                          scaling: float = 8.0
184 |                          ) -> torch.Tensor:
185 |
186 | if T < 2:
187 | raise ValueError(f"SER-FIQ parameter T is {T}, but needs to be T>1")
188 |
189 |         train_mode = self.dropout.training  # remember the current mode to restore it later
190 |         self.dropout.train()                # enable dropout to get stochastic embeddings
191 |
192 | batch_size = x.shape[0]
193 |
194 | x = self.conv1(x)
195 | x = self.bn1(x)
196 | x = self.prelu(x)
197 | x = self.layer1(x)
198 | x = self.layer2(x)
199 | x = self.layer3(x)
200 | x = self.layer4(x)
201 | x = self.bn2(x)
202 | x = torch.flatten(x, 1)
203 |         x = x.repeat_interleave(T, dim=0)  # T copies per image; dropout draws a different mask per copy
204 | x = self.dropout(x)
205 | x = self.fc(x)
206 | x = self.features(x)
207 |
208 |         norm = torch.linalg.vector_norm(x, ord=2, dim=1).unsqueeze(dim=1)
209 |         x = x / norm  # L2-normalize the stochastic embeddings
210 |
211 | scores = torch.empty((batch_size,), device=x.device)
212 |
213 | for i in range(batch_size):
214 | dist = torch.cdist(
215 | x[i*T:(i+1)*T,:],
216 | x[i*T:(i+1)*T,:],
217 | p=2
218 | )
219 |             mean = torch.triu(dist).mean()  # note: averages over all T*T entries, not only the unique pairs
220 |             scores[i] = 2.0 / (1.0 + torch.exp(scaling * mean))  # = 2*sigmoid(-scaling*mean), in (0, 1]
221 |
222 |         if not train_mode:
223 |             self.dropout.eval()  # restore the dropout layer's original mode
224 |
225 | return scores
226 |
227 |
228 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
229 |     model = IResNet(block, layers, **kwargs)
230 |     if pretrained:
231 |         raise ValueError(f"no pretrained weights available for download for '{arch}'; load a checkpoint via load_state_dict instead")
232 |     return model
233 |
234 |
235 | def iresnet18(pretrained=False, progress=True, **kwargs):
236 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
237 | progress, **kwargs)
238 |
239 |
240 | def iresnet34(pretrained=False, progress=True, **kwargs):
241 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
242 | progress, **kwargs)
243 |
244 |
245 | def iresnet50(pretrained=False, progress=True, **kwargs):
246 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
247 | progress, **kwargs)
248 |
249 |
250 | def iresnet100(pretrained=False, progress=True, **kwargs):
251 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
252 | progress, **kwargs)
253 |
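254 |
255 | # Illustrative self-check (a sketch, not part of the original file). With
256 | # randomly initialized weights the scores are meaningless, but this verifies
257 | # that the SER-FIQ machinery runs end to end and returns one score per input.
258 | if __name__ == "__main__":
259 |     model = iresnet18(dropout=0.4, num_features=512)
260 |     model.eval()
261 |     dummy = torch.randn(2, 3, 112, 112)  # two random 112x112 RGB "faces"
262 |     print(model.calculate_serfiq(dummy, T=4, scaling=5.0))  # tensor of shape (2,)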
--------------------------------------------------------------------------------