├── pyproject.toml ├── loss_vis.jpg ├── testing.py ├── resizing.py ├── LICENSE ├── README.md ├── tf_octuplet_loss.py └── pt_octuplet_loss.py /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py38'] -------------------------------------------------------------------------------- /loss_vis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Martlgap/octuplet-loss/HEAD/loss_vis.jpg -------------------------------------------------------------------------------- /testing.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | from tf_octuplet_loss import OctupletLoss as TFOctupletLoss 4 | from pt_octuplet_loss import OctupletLoss as PTOctupletLoss 5 | import pickle 6 | import torch 7 | 8 | 9 | with open("/mnt/ssd2/test_embs.pkl", "rb") as f: 10 | data, labels = pickle.load(f) 11 | 12 | tf_loss_fn = TFOctupletLoss(margin=500, metric="euclidean_squared", configuration=[True, True, True, True]) 13 | tf_output = tf_loss_fn(data, labels) 14 | print(2660.77685546875 == float(tf_output.numpy())) 15 | #print(1.0572563409805298 == float(tf_output.numpy())) 16 | 17 | pt_loss_fn = PTOctupletLoss(margin=500, metric="euclidean_squared", configuration=[True, True, True, True]) 18 | pt_output = pt_loss_fn(torch.tensor(data.numpy()[:8]), torch.tensor(labels.numpy()[:4])) 19 | print(float(pt_output.numpy())) 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /resizing.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def reduce_res(batch, res: int): 5 | """Down- & Up-sampling of images in a batch 6 | 7 | :param batch: batch tensor shape=[batch_size, height, width, channels] 8 | :param res: the desired resolution as integer 9 | :return: batch 10 | """ 11 | 12 | def reduce_res_fn(inputs): 13 | img_hr = inputs[0] 14 | dim_lr = inputs[1] 15 | dim_hr = tf.shape(img_hr)[0] 16 | img_rlr = tf.image.resize(img_hr, (dim_lr, dim_lr), antialias=True, method=tf.image.ResizeMethod.BICUBIC) 17 | img_lr = tf.image.resize(img_rlr, (dim_hr, dim_hr), antialias=False, method=tf.image.ResizeMethod.BICUBIC) 18 | return tf.saturate_cast(img_lr, dtype=img_hr.dtype) 19 | 20 | batch_size = tf.shape(batch)[0] 21 | batch_res = tf.repeat( 22 | tf.gather([res], tf.random.uniform([], minval=0, maxval=1, dtype=tf.int32)), 23 | repeats=batch_size, 24 | ) 25 | return tf.map_fn( 26 | reduce_res_fn, elems=(batch, batch_res), fn_output_signature=batch.dtype 27 | ) # Apply function to each element of batch 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Martin Knoche 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Octuplet Loss - Make Face Recognition Robust Against Image Resolution 2 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 3 | [![License](https://img.shields.io/badge/license-MIT-blue)](https://img.shields.io/badge/license-MIT-blue) 4 | [![Last Commit](https://img.shields.io/github/last-commit/martlgap/octuplet-loss)](https://img.shields.io/github/last-commit/martlgap/octuplet-loss) 5 | 6 | 7 | Here, we release our code and model utilized in the following paper: 8 | - [Octuplet Loss - Make Face Recognition Robust Against Image Resolution 9 | ](https://arxiv.org/abs/2207.06726) 10 | 11 | ![Loss Visualization](https://github.com/martlgap/octuplet-loss/blob/main/loss_vis.jpg?raw=true) 12 | 13 | ## 🏆 Performance (Accuracy [%]) 14 | All models are finetuned with Octuplet-Loss: 15 | | Model | [LFW](http://vis-www.cs.umass.edu/lfw/) | [XQLFW](https://martlgap.github.io/xqlfw/) | 16 | |---|---|---| 17 | | [ArcFace](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deng_ArcFace_Additive_Angular_Margin_Loss_for_Deep_Face_Recognition_CVPR_2019_paper.pdf) | 99.55 | 93.27 | 18 | | [MagFace](https://openaccess.thecvf.com/content/CVPR2021/papers/Meng_MagFace_A_Universal_Representation_for_Face_Recognition_and_Quality_Assessment_CVPR_2021_paper.pdf) | 99.63 | 92.92 | 19 | | [FaceTransformer](https://arxiv.org/abs/2103.14803) | 99.73 | 95.12 | 20 | 21 | ## 💻 Code 22 | We provide the code of our proposed octuplet loss: 23 | - [tf_octuplet_loss.py (Tensorflow 2)](https://github.com/martlgap/octuplet-loss/blob/main/tf_octuplet_loss.py) 24 | - [pt_octuplet_loss.py (PyTorch)](https://github.com/martlgap/octuplet-loss/blob/main/pt_octuplet_loss.py). 25 | 26 | ## 🧠 Model 27 | We provide the model via huggingface 28 | [https://huggingface.co/Martlgap/FaceTransformerOctupletLoss](https://huggingface.co/Martlgap/FaceTransformerOctupletLoss) 29 | 30 | ## 🥣 Requirements 31 | [![Python 3.8](https://img.shields.io/badge/Python-3.8-blue)](https://img.shields.io/badge/Python-3.8-blue) 32 | 33 | 34 | ## 📖 Cite 35 | If you use our code please consider citing: 36 | ~~~tex 37 | @inproceedings{knoche2023octuplet, 38 | title={Octuplet loss: Make face recognition robust to image resolution}, 39 | author={Knoche, Martin and Elkadeem, Mohamed and H{\"o}rmann, Stefan and Rigoll, Gerhard}, 40 | booktitle={2023 IEEE 17th International Conference on Automatic Face and Gesture Recognition (FG)}, 41 | pages={1--8}, 42 | year={2023}, 43 | organization={IEEE} 44 | } 45 | ~~~ 46 | 47 | 48 | ## ✉️ Contact 49 | For any inquiries, please open an [issue](https://github.com/Martlgap/octuplet-loss/issues) on GitHub or send an E-Mail to: [Martin.Knoche@tum.de](mailto:Martin.Knoche@tum.de) 50 | -------------------------------------------------------------------------------- /tf_octuplet_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | @tf.function 5 | def pairwise_distances(embeddings: tf.float32, metric: str = "euclidean") -> tf.float32: 6 | """Calculates pairwise distances of embeddings 7 | 8 | :param embeddings: embeddings 9 | :param metric: 'euclidean', 'euclidean_squared' or 'cosine' 10 | :return: pairwise distance matrix 11 | """ 12 | 13 | if metric == "cosine": 14 | distances_normalized = tf.nn.l2_normalize(embeddings, axis=1) 15 | distances = tf.matmul(distances_normalized, distances_normalized, adjoint_b=True) 16 | return 1.0 - distances 17 | 18 | # With help of: ||a - b||^2 = ||a||^2 - 2 + ||b||^2 19 | dot_product = tf.matmul(embeddings, tf.transpose(embeddings)) 20 | square_norm = tf.linalg.diag_part(dot_product) 21 | distances = tf.expand_dims(square_norm, 1) - 2.0 * dot_product + tf.expand_dims(square_norm, 0) 22 | distances = tf.maximum(distances, 0.0) 23 | 24 | if metric == "euclidean_squared": 25 | return distances 26 | 27 | # Prevent square root from error with 0.0 -> sqrt(0.0) 28 | mask = tf.cast(tf.equal(distances, 0.0), tf.float32) 29 | distances = distances + mask * 1e-16 30 | distances = tf.sqrt(distances) 31 | distances = distances * (1.0 - mask) 32 | 33 | return distances 34 | 35 | 36 | @tf.function 37 | def triplet_loss(distances: tf.float32, mask_pos: tf.bool, mask_neg: tf.bool, margin: float) -> tf.float32: 38 | """Triplet Loss Function 39 | 40 | :param distances: pairwise distances of all embeddings within batch 41 | :param mask_pos: mask of distance between A and P (positive distances) 42 | :param mask_neg: mask of distances between A and N (negative distances 43 | :param margin: the margin for the triplet loss 44 | Formula: Loss = max(0, dist(A,P) - dist(A,N) + margin) 45 | :return: triplet loss values 46 | """ 47 | 48 | pos_dists = tf.multiply(distances, tf.cast(mask_pos, tf.float32)) 49 | hardest_pos_dists = tf.reduce_max(pos_dists, axis=1) 50 | neg_dists = tf.multiply(distances, tf.cast(mask_neg, tf.float32)) 51 | neg_dists_max = tf.reduce_max(neg_dists, axis=1, keepdims=True) 52 | dists_manipulated = distances + neg_dists_max * (1.0 - tf.cast(mask_neg, tf.float32)) 53 | hardest_neg_dist = tf.reduce_min(dists_manipulated, axis=1) 54 | 55 | return tf.maximum(hardest_pos_dists - hardest_neg_dist + margin, 0.0) 56 | 57 | 58 | def OctupletLoss(margin: float = 0.5, metric: str = "euclidean", configuration: list = None): 59 | """Octuplet Loss Function Generator 60 | See our paper -> TODO 61 | https://arxiv.TBD/ 62 | See also -> 63 | https://omoindrot.github.io/triplet-loss (A nice Blog) 64 | 65 | :param margin: margin for triplet loss 66 | :param metric: 'euclidean', 'euclidean_squared', or 'cosine' 67 | :param configuration: configuration of triplet loss functions 'True' takes that specific loss term into account: 68 | Explanation: [Thhh, Thll, Tlhh, Tlll] (see our paper) 69 | :return: the octuplet loss function 70 | """ 71 | 72 | if configuration is None: 73 | configuration = [True, True, True, True] 74 | 75 | #@tf.function 76 | def _loss_function(embeddings: tf.float32, labels: tf.int64) -> tf.float32: 77 | """Octuplet Loss Function 78 | 79 | :param embeddings: concatenated high-resolution and low-resolution embeddings (size: 2*batch_size) 80 | :param labels: classes (size: batch_size) 81 | :return: loss value 82 | """ 83 | 84 | batch_size = labels.shape[0] 85 | distances = pairwise_distances(embeddings, metric=metric) 86 | 87 | lbls_same = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1)) 88 | not_eye_bool = tf.logical_not(tf.cast(tf.eye(batch_size, batch_size), tf.bool)) 89 | mask_pos = tf.equal(lbls_same, not_eye_bool) 90 | mask_neg = tf.logical_not(lbls_same) 91 | 92 | # TRIPLETS HR:HR --------------------------------------------------------------- 93 | dist_hrhr = tf.slice(distances, [0, 0], [batch_size, batch_size]) 94 | loss_hrhr = triplet_loss(dist_hrhr, mask_pos, mask_neg, margin) 95 | 96 | # TRIPLETS HR:LR --------------------------------------------------------------- 97 | dist_hrlr = tf.slice(distances, [0, batch_size], [batch_size, batch_size]) 98 | loss_hrlr = triplet_loss(dist_hrlr, mask_pos, mask_neg, margin) 99 | 100 | # TRIPLETS LR:HR --------------------------------------------------------------- 101 | dist_lrhr = tf.slice(distances, [batch_size, 0], [batch_size, batch_size]) 102 | loss_lrhr = triplet_loss(dist_lrhr, mask_pos, mask_neg, margin) 103 | 104 | # TRIPLETS LR:LR --------------------------------------------------------------- 105 | dist_lrlr = tf.slice(distances, [batch_size, batch_size], [batch_size, batch_size]) 106 | loss_lrlr = triplet_loss(dist_lrlr, mask_pos, mask_neg, margin) 107 | 108 | # Combination of triplet loss terms 109 | losses = tf.transpose(tf.cast([configuration] * batch_size, tf.float32)) * [ 110 | loss_hrhr, 111 | loss_hrlr, 112 | loss_lrhr, 113 | loss_lrlr, 114 | ] 115 | summe = tf.reduce_sum(losses, axis=0) 116 | total_loss = tf.reduce_mean(summe, axis=0) 117 | 118 | return total_loss 119 | 120 | return _loss_function 121 | -------------------------------------------------------------------------------- /pt_octuplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pairwise_distances(embeddings: torch.float32, metric: str = "euclidean") -> torch.float32: 5 | """Calculates pairwise distances of embeddings 6 | 7 | :param embeddings: embeddings 8 | :param metric: 'euclidean', 'euclidean_squared' or 'cosine' 9 | :return: pairwise distance matrix 10 | """ 11 | 12 | if metric=="cosine": 13 | norms = torch.norm(embeddings, p=2, dim=1, keepdim=True) 14 | embeddings_normalized = embeddings.div(norms.expand_as(embeddings)) 15 | dists = torch.matmul(embeddings_normalized, embeddings_normalized.T) 16 | return 1.-dists 17 | 18 | # With help of: ||a - b||^2 = ||a||^2 - 2 + ||b||^2 19 | dot_product = torch.matmul(embeddings, embeddings.T) 20 | square_norm = torch.diagonal(dot_product) 21 | distances = torch.unsqueeze(square_norm, 1) - 2.0 * dot_product + torch.unsqueeze(square_norm, 0) 22 | distances = torch.maximum(torch.tensor(distances), torch.tensor(0.0)) 23 | 24 | if metric=="euclidean_squared": 25 | return distances 26 | 27 | # Prevent square root from error with 0.0 -> sqrt(0.0) 28 | mask = torch.eq(distances, torch.tensor(0.0)).float() 29 | distances = distances + mask * 1e-16 30 | distances = torch.sqrt(distances) 31 | distances = distances * (1.0 - mask) 32 | return distances 33 | 34 | 35 | def triplet_loss(distances: torch.float, mask_pos: torch.bool, mask_neg: torch.bool, margin: float) -> torch.float: 36 | """Triplet Loss Function 37 | 38 | :param distances: pairwise distances of all embeddings within batch 39 | :param mask_pos: mask of distance between A and P (positive distances) 40 | :param mask_neg: mask of distances between A and N (negative distances 41 | :param margin: the margin for the triplet loss 42 | Formula: Loss = max(0, dist(A,P) - dist(A,N) + margin) 43 | :return: triplet loss values 44 | """ 45 | 46 | pos_dists = torch.multiply(distances, mask_pos) 47 | hardest_pos_dists = torch.amax(pos_dists, dim=1) 48 | neg_dists = torch.multiply(distances, mask_neg) 49 | neg_dists_max = torch.amax(neg_dists, dim=1, keepdim=True) 50 | dists_manipulated = distances + neg_dists_max * (1.0 - mask_neg) 51 | hardest_neg_dist = torch.amin(dists_manipulated, dim=1) 52 | 53 | return torch.maximum(hardest_pos_dists - hardest_neg_dist + margin, torch.tensor(0.0)) 54 | 55 | 56 | class OctupletLoss(torch.nn.modules.loss._Loss): 57 | def __init__(self, margin: float = 0.5, metric: str = "euclidean", configuration: list = None): 58 | """Octuplet Loss Function Generator 59 | See our paper -> TODO 60 | https://arxiv.TBD/ 61 | See also -> 62 | https://omoindrot.github.io/triplet-loss (A nice Blog) 63 | 64 | :param margin: margin for triplet loss 65 | :param metric: 'euclidean', 'euclidean_squared', or 'cosine' 66 | :param configuration: configuration of triplet loss functions 'True' takes that specific loss term into account: 67 | Explanation: [Thhh, Thll, Tlhh, Tlll] (see our paper) 68 | :return: the octuplet loss function 69 | """ 70 | 71 | super(OctupletLoss, self).__init__() 72 | if configuration is None: 73 | configuration = [True, True, True, True] 74 | self.margin = margin 75 | self.metric = metric 76 | self.configuration = configuration 77 | 78 | 79 | def forward(self, embeddings: torch.float, labels: torch.float) -> torch.float: 80 | """Octuplet Loss Function 81 | 82 | :param embeddings: concatenated high-resolution and low-resolution embeddings (size: 2*batch_size) 83 | :param labels: classes (size: batch_size) 84 | :return: loss value 85 | """ 86 | 87 | # Concat embeddings with HR and LR images 88 | batch_size = labels.shape[0] 89 | pairwise_dist = pairwise_distances(embeddings, metric=self.metric) 90 | 91 | lbls_same = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1)) 92 | not_eye_bool = torch.logical_not(torch.eye(batch_size, batch_size, device=lbls_same.device).bool()) 93 | mask_pos = torch.eq(lbls_same, not_eye_bool).float() 94 | mask_neg = torch.logical_not(lbls_same).float() 95 | 96 | # TRIPLETS HR:HR --------------------------------------------------------------- 97 | dist_hrhr = pairwise_dist[0:batch_size, 0:batch_size] 98 | loss_hrhr = triplet_loss(dist_hrhr, mask_pos, mask_neg, self.margin) 99 | 100 | # TRIPLETS HR:LR --------------------------------------------------------------- 101 | dist_hrlr = pairwise_dist[0:batch_size, batch_size:2*batch_size] 102 | loss_hrlr = triplet_loss(dist_hrlr, mask_pos, mask_neg, self.margin) 103 | 104 | # TRIPLETS LR:HR --------------------------------------------------------------- 105 | dist_lrhr = pairwise_dist[batch_size:2*batch_size, 0:batch_size] 106 | loss_lrhr = triplet_loss(dist_lrhr, mask_pos, mask_neg, self.margin) 107 | 108 | # TRIPLETS LR:LR --------------------------------------------------------------- 109 | dist_lrlr = pairwise_dist[batch_size:2*batch_size, batch_size:2*batch_size] 110 | loss_lrlr = triplet_loss(dist_lrlr, mask_pos, mask_neg, self.margin) 111 | 112 | # Combination of triplet loss terms 113 | losses = torch.transpose(torch.tensor([self.configuration] * batch_size), 0, 1) * torch.stack([ 114 | loss_hrhr, 115 | loss_hrlr, 116 | loss_lrhr, 117 | loss_lrlr, 118 | ]) 119 | summe = torch.sum(losses, dim=0) 120 | total_loss = torch.mean(summe, dim=0) 121 | 122 | return total_loss 123 | --------------------------------------------------------------------------------