├── pyproject.toml
├── loss_vis.jpg
├── testing.py
├── resizing.py
├── LICENSE
├── README.md
├── tf_octuplet_loss.py
└── pt_octuplet_loss.py
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 | target-version = ['py38']
--------------------------------------------------------------------------------
/loss_vis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Martlgap/octuplet-loss/HEAD/loss_vis.jpg
--------------------------------------------------------------------------------
/testing.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 | from tf_octuplet_loss import OctupletLoss as TFOctupletLoss
4 | from pt_octuplet_loss import OctupletLoss as PTOctupletLoss
5 | import pickle
6 | import torch
7 |
8 |
import math
import sys

# Default location of the pickled test fixture; override by passing a path on the command line.
DEFAULT_EMBEDDINGS_PATH = "/mnt/ssd2/test_embs.pkl"
# Loss value the TensorFlow implementation produced for the default fixture with LOSS_KWARGS below.
TF_REFERENCE_LOSS = 2660.77685546875
# Identical settings for both implementations so their outputs are directly comparable.
LOSS_KWARGS = dict(margin=500, metric="euclidean_squared", configuration=[True, True, True, True])


def main(path: str = DEFAULT_EMBEDDINGS_PATH) -> None:
    """Compute the octuplet loss with the TensorFlow and PyTorch implementations and compare them.

    :param path: pickle file holding an ``(embeddings, labels)`` pair. Both objects must provide ``.numpy()``
        (presumably TensorFlow tensors - TODO confirm). As the loss functions require, ``embeddings`` must stack
        the ``len(labels)`` high-resolution embeddings on top of their low-resolution counterparts.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted, locally generated fixtures.
    with open(path, "rb") as f:
        data, labels = pickle.load(f)

    tf_loss = float(TFOctupletLoss(**LOSS_KWARGS)(data, labels).numpy())

    # Feed the PyTorch loss exactly the same batch as TensorFlow. Slicing only the embeddings would
    # break the "first half HR, second half LR" layout and make the two results incomparable.
    pt_loss = PTOctupletLoss(**LOSS_KWARGS)(torch.tensor(data.numpy()), torch.tensor(labels.numpy())).item()

    print(f"TensorFlow loss: {tf_loss}")
    print(f"PyTorch loss:    {pt_loss}")
    # Compare with tolerances: float32 results vary slightly across frameworks, hardware and library versions.
    print("TensorFlow matches reference:", math.isclose(tf_loss, TF_REFERENCE_LOSS, rel_tol=1e-6))
    print("PyTorch matches TensorFlow:", math.isclose(pt_loss, tf_loss, rel_tol=1e-4))


if __name__ == "__main__":
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_EMBEDDINGS_PATH)
--------------------------------------------------------------------------------
/resizing.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
def reduce_res(batch, res: int):
    """Down- & Up-sampling of images in a batch to simulate a lower image resolution.

    Every image is shrunk to ``res x res`` pixels with an anti-aliased bicubic filter and then resized back
    to its original height and width with plain bicubic interpolation. The output therefore has the same
    shape and dtype as the input, but only ``res x res`` worth of detail.

    :param batch: batch tensor shape=[batch_size, height, width, channels]
    :param res: the desired (square) low resolution in pixels, as integer
    :return: batch with the same shape and dtype as ``batch``
    """

    # Original (height, width); using both dims keeps non-square images non-square.
    hr_size = tf.shape(batch)[1:3]

    # tf.image.resize processes a 4-D batch image-by-image, so no per-element map_fn is needed.
    lr_batch = tf.image.resize(batch, (res, res), antialias=True, method=tf.image.ResizeMethod.BICUBIC)
    restored = tf.image.resize(lr_batch, hr_size, antialias=False, method=tf.image.ResizeMethod.BICUBIC)

    # Bicubic resize yields floats; saturate_cast clips into the input dtype's range (e.g. 0..255 for uint8).
    return tf.saturate_cast(restored, dtype=batch.dtype)
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Martin Knoche
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Octuplet Loss - Make Face Recognition Robust Against Image Resolution
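2 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
3 | ![License: MIT](https://img.shields.io/badge/license-MIT-blue)
4 | ![Last commit](https://img.shields.io/github/last-commit/martlgap/octuplet-loss)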
5 |
6 |
7 | Here we release the code and model used in the following paper:
8 | - [Octuplet Loss: Make Face Recognition Robust to Image Resolution](https://arxiv.org/abs/2207.06726)
9 |
10 |
11 | ![Visualization of the octuplet loss](loss_vis.jpg)
12 |
13 | ## 🏆 Performance (Accuracy [%])
14 | All models are fine-tuned with the octuplet loss:
15 | | Model | [LFW](http://vis-www.cs.umass.edu/lfw/) | [XQLFW](https://martlgap.github.io/xqlfw/) |
16 | |---|---|---|
17 | | [ArcFace](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deng_ArcFace_Additive_Angular_Margin_Loss_for_Deep_Face_Recognition_CVPR_2019_paper.pdf) | 99.55 | 93.27 |
18 | | [MagFace](https://openaccess.thecvf.com/content/CVPR2021/papers/Meng_MagFace_A_Universal_Representation_for_Face_Recognition_and_Quality_Assessment_CVPR_2021_paper.pdf) | 99.63 | 92.92 |
19 | | [FaceTransformer](https://arxiv.org/abs/2103.14803) | 99.73 | 95.12 |
20 |
21 | ## 💻 Code
22 | We provide implementations of our proposed octuplet loss:
23 | - [tf_octuplet_loss.py (Tensorflow 2)](https://github.com/martlgap/octuplet-loss/blob/main/tf_octuplet_loss.py)
24 | - [pt_octuplet_loss.py (PyTorch)](https://github.com/martlgap/octuplet-loss/blob/main/pt_octuplet_loss.py)
25 |
26 | ## 🧠 Model
27 | We provide the model via Hugging Face:
28 | [https://huggingface.co/Martlgap/FaceTransformerOctupletLoss](https://huggingface.co/Martlgap/FaceTransformerOctupletLoss)
29 |
30 | ## 🥣 Requirements
31 | ![Python 3.8](https://img.shields.io/badge/Python-3.8-blue)
32 |
33 |
34 | ## 📖 Cite
35 | If you use our code, please consider citing:
36 | ~~~tex
37 | @inproceedings{knoche2023octuplet,
38 | title={Octuplet loss: Make face recognition robust to image resolution},
39 | author={Knoche, Martin and Elkadeem, Mohamed and H{\"o}rmann, Stefan and Rigoll, Gerhard},
40 | booktitle={2023 IEEE 17th International Conference on Automatic Face and Gesture Recognition (FG)},
41 | pages={1--8},
42 | year={2023},
43 | organization={IEEE}
44 | }
45 | ~~~
46 |
47 |
48 | ## ✉️ Contact
49 | For any inquiries, please open an [issue](https://github.com/Martlgap/octuplet-loss/issues) on GitHub or send an email to [Martin.Knoche@tum.de](mailto:Martin.Knoche@tum.de).
50 |
--------------------------------------------------------------------------------
/tf_octuplet_loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
@tf.function
def pairwise_distances(embeddings: tf.float32, metric: str = "euclidean") -> tf.float32:
    """Calculates pairwise distances of embeddings

    :param embeddings: embeddings, one per row
    :param metric: 'euclidean', 'euclidean_squared' or 'cosine'
    :raises ValueError: if ``metric`` is not one of the supported names
    :return: pairwise distance matrix (row i, column j = distance between embedding i and embedding j)
    """

    if metric not in ("euclidean", "euclidean_squared", "cosine"):
        raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean', 'euclidean_squared' or 'cosine'")

    if metric == "cosine":
        embeddings_normalized = tf.nn.l2_normalize(embeddings, axis=1)
        return 1.0 - tf.matmul(embeddings_normalized, embeddings_normalized, adjoint_b=True)

    # With help of: ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    dot_product = tf.matmul(embeddings, embeddings, transpose_b=True)
    square_norm = tf.linalg.diag_part(dot_product)
    distances = tf.expand_dims(square_norm, 1) - 2.0 * dot_product + tf.expand_dims(square_norm, 0)
    # Floating-point cancellation can make the expansion slightly negative.
    distances = tf.maximum(distances, 0.0)

    if metric == "euclidean_squared":
        return distances

    # sqrt has an infinite derivative at 0. Route exact zeros through a harmless value and zero them
    # afterwards; unlike adding an epsilon, this works for every float dtype (1e-16 underflows in float16).
    zero_mask = tf.equal(distances, 0.0)
    safe_distances = tf.where(zero_mask, tf.ones_like(distances), distances)
    return tf.where(zero_mask, tf.zeros_like(distances), tf.sqrt(safe_distances))
34 |
35 |
@tf.function
def triplet_loss(distances: tf.float32, mask_pos: tf.bool, mask_neg: tf.bool, margin: float) -> tf.float32:
    """Batch-hard Triplet Loss Function

    :param distances: pairwise distances of all embeddings within batch (assumed non-negative)
    :param mask_pos: mask of distance between A and P (positive distances)
    :param mask_neg: mask of distances between A and N (negative distances)
    :param margin: the margin for the triplet loss
    Formula: Loss = max(0, dist(A,P) - dist(A,N) + margin)
    :return: triplet loss values, one per anchor (row)
    """

    # Cast to the distances' dtype (not a fixed float32) so float16/float64 inputs work as well.
    pos_weight = tf.cast(mask_pos, distances.dtype)
    neg_weight = tf.cast(mask_neg, distances.dtype)

    # Hardest positive: largest masked distance per row (0 if the anchor has no positives).
    hardest_pos_dists = tf.reduce_max(distances * pos_weight, axis=1)

    # Hardest negative: add the row's largest negative distance to every non-negative entry so the
    # row minimum is always taken over negatives.
    neg_dists_max = tf.reduce_max(distances * neg_weight, axis=1, keepdims=True)
    dists_manipulated = distances + neg_dists_max * (1.0 - neg_weight)
    hardest_neg_dist = tf.reduce_min(dists_manipulated, axis=1)

    return tf.maximum(hardest_pos_dists - hardest_neg_dist + margin, 0.0)
56 |
57 |
def OctupletLoss(margin: float = 0.5, metric: str = "euclidean", configuration: list = None):
    """Octuplet Loss Function Generator
    See our paper ->
    https://arxiv.org/abs/2207.06726
    See also ->
    https://omoindrot.github.io/triplet-loss (A nice Blog)

    :param margin: margin for triplet loss
    :param metric: 'euclidean', 'euclidean_squared', or 'cosine'
    :param configuration: four flags; 'True' takes that specific triplet loss term into account.
        Order: [Thhh, Thll, Tlhh, Tlll] (see our paper). Defaults to all four terms enabled.
    :raises ValueError: if ``configuration`` does not contain exactly four flags
    :return: the octuplet loss function
    """

    if configuration is None:
        configuration = [True, True, True, True]
    # Snapshot the flags so later mutation of the caller's list cannot change the loss.
    configuration = list(configuration)
    if len(configuration) != 4:
        raise ValueError(f"configuration must contain exactly 4 flags, got {len(configuration)}")

    def _loss_function(embeddings: tf.float32, labels: tf.int64) -> tf.float32:
        """Octuplet Loss Function

        Uses only dynamic shapes, so it can also be wrapped in ``tf.function`` with an unknown batch size.

        :param embeddings: concatenated high-resolution and low-resolution embeddings (size: 2*batch_size);
            rows [0, batch_size) are HR and rows [batch_size, 2*batch_size) are the corresponding LR embeddings
        :param labels: classes (size: batch_size)
        :raises tf.errors.InvalidArgumentError: (eager mode) if embeddings does not have 2*batch_size rows
        :return: loss value (mean over anchors of the sum of the enabled triplet terms)
        """

        batch_size = tf.shape(labels)[0]
        tf.debugging.assert_equal(
            tf.shape(embeddings)[0],
            2 * batch_size,
            message="embeddings must contain 2 * batch_size rows (HR embeddings followed by LR embeddings)",
        )
        distances = pairwise_distances(embeddings, metric=metric)

        lbls_same = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
        not_eye_bool = tf.logical_not(tf.cast(tf.eye(batch_size), tf.bool))
        # Positives share the anchor's label but are not the anchor's own index; negatives differ in label.
        mask_pos = tf.logical_and(lbls_same, not_eye_bool)
        mask_neg = tf.logical_not(lbls_same)

        # TRIPLETS HR:HR ---------------------------------------------------------------
        dist_hrhr = tf.slice(distances, [0, 0], [batch_size, batch_size])
        loss_hrhr = triplet_loss(dist_hrhr, mask_pos, mask_neg, margin)

        # TRIPLETS HR:LR ---------------------------------------------------------------
        dist_hrlr = tf.slice(distances, [0, batch_size], [batch_size, batch_size])
        loss_hrlr = triplet_loss(dist_hrlr, mask_pos, mask_neg, margin)

        # TRIPLETS LR:HR ---------------------------------------------------------------
        dist_lrhr = tf.slice(distances, [batch_size, 0], [batch_size, batch_size])
        loss_lrhr = triplet_loss(dist_lrhr, mask_pos, mask_neg, margin)

        # TRIPLETS LR:LR ---------------------------------------------------------------
        dist_lrlr = tf.slice(distances, [batch_size, batch_size], [batch_size, batch_size])
        loss_lrlr = triplet_loss(dist_lrlr, mask_pos, mask_neg, margin)

        # Combination of triplet loss terms: shape [4, batch_size], one row per term.
        losses = tf.stack([loss_hrhr, loss_hrlr, loss_lrhr, loss_lrlr], axis=0)
        # A [4, 1] weight column broadcasts each on/off flag over all anchors of its term.
        term_weights = tf.reshape(tf.cast(configuration, losses.dtype), [4, 1])
        summe = tf.reduce_sum(term_weights * losses, axis=0)
        total_loss = tf.reduce_mean(summe, axis=0)

        return total_loss

    return _loss_function
121 |
--------------------------------------------------------------------------------
/pt_octuplet_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def pairwise_distances(embeddings: torch.Tensor, metric: str = "euclidean") -> torch.Tensor:
    """Calculates pairwise distances of embeddings

    :param embeddings: embeddings, one per row
    :param metric: 'euclidean', 'euclidean_squared' or 'cosine'
    :raises ValueError: if ``metric`` is not one of the supported names
    :return: pairwise distance matrix (row i, column j = distance between embedding i and embedding j)
    """

    if metric not in ("euclidean", "euclidean_squared", "cosine"):
        raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean', 'euclidean_squared' or 'cosine'")

    if metric == "cosine":
        # normalize() clamps the norm with a small epsilon, so all-zero embeddings do not produce NaNs.
        embeddings_normalized = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return 1.0 - torch.matmul(embeddings_normalized, embeddings_normalized.T)

    # With help of: ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    dot_product = torch.matmul(embeddings, embeddings.T)
    square_norm = torch.diagonal(dot_product)
    distances = square_norm.unsqueeze(1) - 2.0 * dot_product + square_norm.unsqueeze(0)
    # Floating-point cancellation can make the expansion slightly negative. clamp() keeps the autograd
    # graph intact; wrapping the result in torch.tensor(...) would copy AND detach it (no gradients).
    distances = torch.clamp(distances, min=0.0)

    if metric == "euclidean_squared":
        return distances

    # sqrt has an infinite derivative at 0. Route exact zeros through a harmless value and zero them
    # afterwards, so their gradient is exactly 0 for every float dtype.
    zero_mask = distances == 0.0
    safe_distances = torch.where(zero_mask, torch.ones_like(distances), distances)
    return torch.where(zero_mask, torch.zeros_like(distances), torch.sqrt(safe_distances))
33 |
34 |
def triplet_loss(distances: torch.Tensor, mask_pos: torch.Tensor, mask_neg: torch.Tensor, margin: float) -> torch.Tensor:
    """Batch-hard Triplet Loss Function

    :param distances: pairwise distances of all embeddings within batch (assumed non-negative)
    :param mask_pos: mask of distance between A and P (positive distances); bool or float
    :param mask_neg: mask of distances between A and N (negative distances); bool or float
    :param margin: the margin for the triplet loss
    Formula: Loss = max(0, dist(A,P) - dist(A,N) + margin)
    :return: triplet loss values, one per anchor (row)
    """

    # Convert masks to the distances' dtype: bool masks would otherwise fail at `1.0 - mask`.
    pos_weight = mask_pos.to(distances.dtype)
    neg_weight = mask_neg.to(distances.dtype)

    # Hardest positive: largest masked distance per row (0 if the anchor has no positives).
    hardest_pos_dists = torch.amax(distances * pos_weight, dim=1)

    # Hardest negative: add the row's largest negative distance to every non-negative entry so the
    # row minimum is always taken over negatives.
    neg_dists_max = torch.amax(distances * neg_weight, dim=1, keepdim=True)
    dists_manipulated = distances + neg_dists_max * (1.0 - neg_weight)
    hardest_neg_dist = torch.amin(dists_manipulated, dim=1)

    # clamp() stays on the input's device/dtype (no CPU scalar tensor involved).
    return torch.clamp(hardest_pos_dists - hardest_neg_dist + margin, min=0.0)
54 |
55 |
class OctupletLoss(torch.nn.modules.loss._Loss):
    """Octuplet loss (PyTorch implementation).

    Sums up to four batch-hard triplet losses computed on the distance sub-matrices HR:HR, HR:LR, LR:HR and
    LR:LR of a batch that stacks high-resolution (HR) embeddings on top of their low-resolution (LR) versions.
    """

    def __init__(self, margin: float = 0.5, metric: str = "euclidean", configuration: list = None):
        """Octuplet Loss Function Generator
        See our paper ->
        https://arxiv.org/abs/2207.06726
        See also ->
        https://omoindrot.github.io/triplet-loss (A nice Blog)

        :param margin: margin for triplet loss
        :param metric: 'euclidean', 'euclidean_squared', or 'cosine'
        :param configuration: four flags; 'True' takes that specific triplet loss term into account.
            Order: [Thhh, Thll, Tlhh, Tlll] (see our paper). Defaults to all four terms enabled.
        :raises ValueError: if ``configuration`` does not contain exactly four flags
        """

        super().__init__()
        if configuration is None:
            configuration = [True, True, True, True]
        # Snapshot the flags so later mutation of the caller's list cannot change the loss.
        configuration = list(configuration)
        if len(configuration) != 4:
            raise ValueError(f"configuration must contain exactly 4 flags, got {len(configuration)}")
        self.margin = margin
        self.metric = metric
        self.configuration = configuration

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Octuplet Loss Function

        :param embeddings: concatenated high-resolution and low-resolution embeddings (size: 2*batch_size);
            rows [0, batch_size) are HR and rows [batch_size, 2*batch_size) are the corresponding LR embeddings
        :param labels: classes (size: batch_size)
        :raises ValueError: if embeddings does not have exactly 2*batch_size rows
        :return: loss value (mean over anchors of the sum of the enabled triplet terms)
        """

        batch_size = labels.shape[0]
        if embeddings.shape[0] != 2 * batch_size:
            raise ValueError(
                f"embeddings must have 2 * batch_size = {2 * batch_size} rows "
                f"(HR embeddings followed by LR embeddings), got {embeddings.shape[0]}"
            )
        pairwise_dist = pairwise_distances(embeddings, metric=self.metric)

        lbls_same = torch.eq(labels.unsqueeze(0), labels.unsqueeze(1))
        not_eye_bool = torch.logical_not(torch.eye(batch_size, device=lbls_same.device).bool())
        # Positives share the anchor's label but are not the anchor's own index; negatives differ in label.
        # Masks are moved to the distances' device/dtype so labels and embeddings may live on different devices.
        mask_pos = torch.logical_and(lbls_same, not_eye_bool).to(device=pairwise_dist.device, dtype=pairwise_dist.dtype)
        mask_neg = torch.logical_not(lbls_same).to(device=pairwise_dist.device, dtype=pairwise_dist.dtype)

        # TRIPLETS HR:HR ---------------------------------------------------------------
        dist_hrhr = pairwise_dist[:batch_size, :batch_size]
        loss_hrhr = triplet_loss(dist_hrhr, mask_pos, mask_neg, self.margin)

        # TRIPLETS HR:LR ---------------------------------------------------------------
        dist_hrlr = pairwise_dist[:batch_size, batch_size:]
        loss_hrlr = triplet_loss(dist_hrlr, mask_pos, mask_neg, self.margin)

        # TRIPLETS LR:HR ---------------------------------------------------------------
        dist_lrhr = pairwise_dist[batch_size:, :batch_size]
        loss_lrhr = triplet_loss(dist_lrhr, mask_pos, mask_neg, self.margin)

        # TRIPLETS LR:LR ---------------------------------------------------------------
        dist_lrlr = pairwise_dist[batch_size:, batch_size:]
        loss_lrlr = triplet_loss(dist_lrlr, mask_pos, mask_neg, self.margin)

        # Combination of triplet loss terms: shape [4, batch_size], one row per term.
        losses = torch.stack([loss_hrhr, loss_hrlr, loss_lrhr, loss_lrlr])
        # A [4, 1] weight column on the losses' device (a CPU tensor would fail against CUDA losses).
        term_weights = torch.tensor(self.configuration, dtype=losses.dtype, device=losses.device).unsqueeze(1)
        summe = torch.sum(term_weights * losses, dim=0)
        total_loss = torch.mean(summe, dim=0)

        return total_loss
123 |
--------------------------------------------------------------------------------