├── backbones
│   ├── __init__.py
│   └── iresnet.py
├── images
│   └── img.jpg
├── main.py
├── README.md
└── .gitignore

/backbones/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/images/img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jankolf/ser-fiq-pytorch/HEAD/images/img.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# Imports: Project
from backbones.iresnet import iresnet50, iresnet18

# Imports: Installed Packages
import torch

from torchvision import io
from torchvision import transforms

# Normalization applied to input images (maps values from [0, 1] to [-1, 1]).
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])


def load_image(img_path: str) -> torch.Tensor:
    """Load an image, scale it to [0, 1] and normalize it."""
    image: torch.Tensor = io.read_image(img_path)
    image = image.float() / 255.0
    image = normalize(image)
    return image


if __name__ == "__main__":

    torch_device = torch.device("cpu")

    resnet = iresnet18(dropout=0.4, num_features=512, use_se=False).to(torch_device)
    resnet.load_state_dict(
        torch.load("checkpoints/resnet18.pth", map_location=torch_device)
    )
    resnet.eval()

    # Add a batch dimension: the model expects input of shape (N, C, H, W).
    image = load_image("images/img.jpg").unsqueeze(dim=0)

    scores = resnet.calculate_serfiq(image, T=10, scaling=5.0)

    print(f"SER-FIQ Score: {scores[0].item():.8f}")
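    # Note: T (the number of stochastic forward passes) and scaling are
    # SER-FIQ hyperparameters; as described in the README, the scaling in
    # particular may need to be adapted to the dataset/model.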
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### Unofficial PyTorch reimplementation of
# SER-FIQ: Unsupervised Estimation of Face Image Quality Based on Stochastic Embedding Robustness
### CVPR 2020

Authors:
Philipp Terhörst, Jan Niklas Kolf, Naser Damer, Florian Kirchbuchner, Arjan Kuijper

Paper available at TheCVF

Data available to download

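### Method
SER-FIQ estimates the quality of a face image from the robustness of stochastic embeddings: the image is embedded $T$ times with dropout active, and the agreement between the resulting embeddings determines the score. As implemented in `calculate_serfiq` (`backbones/iresnet.py`): with $\bar{d}$ denoting the mean over the upper triangle of the $T \times T$ matrix of pairwise Euclidean distances between the $T$ L2-normalized embeddings,

$$\mathrm{score} = \frac{2}{1 + e^{\,\mathrm{scaling} \cdot \bar{d}}},$$

so the more the stochastic embeddings agree, the closer the score is to 1.
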
### Reimplementation
SER-FIQ is implemented for the iResNet architecture; pre-trained iResNet18 and iResNet50 models are provided.
The scaling of the SER-FIQ scores differs from the original implementation and needs to be adapted to the dataset/model.

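A minimal scoring sketch for the iResNet50 backbone (the checkpoint filename, the input tensor, and the `T`/`scaling` values are assumptions to adapt):
```
from backbones.iresnet import iresnet50

import torch

model = iresnet50(dropout=0.4, num_features=512, use_se=False)
model.load_state_dict(torch.load("checkpoints/resnet50.pth", map_location="cpu"))  # filename assumed
model.eval()

faces = torch.rand(1, 3, 112, 112)  # placeholder; use aligned 112x112 face crops
scores = model.calculate_serfiq(faces, T=10, scaling=5.0)
```
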
### Run
To run the example code,
1. download the [model checkpoints](https://share.jankolf.de/s/F64PNQjsQLpmGLW)
2. create a folder `checkpoints` and copy the downloaded file into it
3. execute the example code with
```
python main.py
```

### Dependencies
Tested with PyTorch 2.0 and torchvision.

### License
This project is licensed under the terms of the Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0) license.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/backbones/iresnet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class SEModule(nn.Module):
    """Squeeze-and-excitation block: channel-wise attention from global pooling."""

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return identity * x


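# Pre-activation residual block used by IResNet:
# BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv -> BN, plus an identity shortcut
# and an optional squeeze-and-excitation block.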
class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1, use_se=False):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if use_se:
            self.se_block = SEModule(planes, 16)

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.use_se:
            out = self.se_block(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, use_se=False):
        super(IResNet, self).__init__()
        self.inplanes = 64
        self.dilation = 1
        self.use_se = use_se
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2, use_se=self.use_se)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0], use_se=self.use_se)
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1], use_se=self.use_se)
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2], use_se=self.use_se)
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
        self.dropout = nn.Dropout(p=dropout, inplace=True)  # applied to the flattened 7x7x512 features
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.num_features = num_features
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, use_se=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, use_se=use_se))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation, use_se=use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        x = self.features(x)
        return x

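    # SER-FIQ: run the deterministic backbone once, then apply the stochastic
    # part (dropout + embedding head) T times per image; the agreement between
    # the resulting T embeddings is mapped to a quality score in (0, 1).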
    @torch.no_grad()
    def calculate_serfiq(self,
                         x: torch.Tensor,
                         T: int = 100,
                         scaling: float = 8.0
                         ) -> torch.Tensor:

        if T < 2:
            raise ValueError(f"SER-FIQ parameter T is {T}, but needs to be T>1")

        # Enable dropout regardless of the model's current mode, remembering
        # the previous state so it can be restored afterwards.
        train_mode = self.dropout.training
        self.dropout.train()

        batch_size = x.shape[0]

        # Deterministic backbone pass (identical for all T samples).
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        # Repeat each image's features T times and apply the stochastic part
        # (dropout + embedding head) to every copy.
        x = x.repeat_interleave(T, dim=0)
        x = self.dropout(x)
        x = self.fc(x)
        x = self.features(x)

        # L2-normalize the embeddings.
        norm = torch.linalg.vector_norm(x, ord=2, dim=1).unsqueeze(dim=1)
        x = x / norm

        scores = torch.empty((batch_size,), device=x.device)

        for i in range(batch_size):
            # Pairwise Euclidean distances between the T embeddings of image i.
            dist = torch.cdist(
                x[i*T:(i+1)*T, :],
                x[i*T:(i+1)*T, :],
                p=2
            )
            # Mean over the upper triangle of the TxT distance matrix (the
            # zeroed lower triangle and diagonal are included in the average),
            # mapped to a quality score in (0, 1).
            mean = torch.triu(dist).mean()
            scores[i] = 2 * (1 / (1 + torch.exp(scaling * mean)))

        # Restore the previous dropout mode.
        if not train_mode:
            self.dropout.eval()

        return scores


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError("Downloading pretrained weights is not supported; "
                         "load a checkpoint via load_state_dict instead.")
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)
--------------------------------------------------------------------------------