├── .gitignore ├── ArcFace_files ├── ArcFace_functions.py └── backbones │ ├── __init__.py │ ├── iresnet.py │ ├── iresnet2060.py │ ├── mobilefacenet.py │ └── vit.py ├── Evaluation ├── CR-FIQA │ ├── getQualityScore_FR_ID-Booth_12-2024.py │ ├── iresnet.py │ └── run_CRFIQA_ID-Booth.ipynb ├── PoseEstimation │ └── estimate_head_pose_ID-Booth.ipynb ├── PyEER_analysis │ ├── analyse_pyeer_ID-Booth.py │ ├── analysis_scripts │ │ ├── SYN_vs_REAL_create_genuine_and_impostor_files.py │ │ ├── __init__.py │ │ ├── analyse_dataset.py │ │ ├── dataset_EER.sh │ │ ├── impact_of_max_off.py │ │ ├── plot_distributions.py │ │ └── plot_logs.py │ ├── backbones │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── iresnet.py │ │ └── utils.py │ ├── create_boundary_data.py │ ├── create_genuine_and_impostor_files.py │ ├── genuine_and_imposter_SynthVsReal.py │ ├── genuine_and_impostor_AmongSynth.py │ ├── plot_genuine_imposter_distributions.ipynb │ ├── pyeer_scripts │ │ ├── __init__.py │ │ ├── cmc_info.py │ │ ├── cmc_stats.py │ │ ├── eer_info.py │ │ ├── eer_stats.py │ │ ├── plot.py │ │ └── report.py │ └── utils │ │ ├── __init__.py │ │ ├── align_trans.py │ │ ├── augmentation.py │ │ ├── countFLOPS.py │ │ ├── dataloader.py │ │ ├── losses.py │ │ ├── rand_augment.py │ │ ├── utils.py │ │ ├── utils_callbacks.py │ │ ├── utils_logging.py │ │ └── verification.py ├── convert_to_conditional_dataset_for_evaluation.ipynb └── dgm-eval │ ├── LICENSE │ ├── dgm_eval │ ├── __init__.py │ ├── __main__.py │ ├── dataloaders.py │ ├── heatmaps │ │ ├── __init__.py │ │ ├── gradcam.py │ │ ├── heatmaps.py │ │ └── heatmaps_utils.py │ ├── metrics │ │ ├── __init__.py │ │ ├── authpct.py │ │ ├── ct.py │ │ ├── fd.py │ │ ├── fls.py │ │ ├── inception_score.py │ │ ├── mmd.py │ │ ├── prdc.py │ │ ├── sw.py │ │ └── vendi.py │ ├── models │ │ ├── __init__.py │ │ ├── clip.py │ │ ├── convnext.py │ │ ├── data2vec.py │ │ ├── dinov2.py │ │ ├── encoder.py │ │ ├── inception.py │ │ ├── load_encoder.py │ │ ├── mae.py │ │ ├── pixel.py │ │ ├── simclr.py │ │ ├── swav.py │ │ └── util │ │ │ └── pos_embed.py │ ├── representations.py │ └── resizer.py │ ├── main_DGM_EVAL.ipynb │ ├── scripts │ └── run_experiments.sh │ ├── setup.cfg │ └── setup.py ├── FR_training ├── backbones │ ├── __init__.py │ ├── activation.py │ ├── iresnet.py │ └── utils.py ├── config │ ├── FR_config.py │ ├── FR_config_Augmented.py │ ├── config_new.py │ ├── config_orig.py │ ├── test_FR_config.py │ └── test_FR_config_Augmented.py ├── moco │ ├── __init__.py │ ├── builder.py │ ├── data_utils.py │ ├── dataloader.py │ └── loader.py ├── test_FR.py ├── test_FR_Augmented.py ├── train_FR.py ├── train_FR_Augmented.py └── utils │ ├── FAA_policy.py │ ├── __init__.py │ ├── augmentation.py │ ├── dataset.py │ ├── losses.py │ ├── rand_augment.py │ ├── utils_callbacks.py │ ├── utils_callbacks_4channel.py │ ├── utils_logging.py │ ├── verification.py │ └── verification_4channel.py ├── README.md ├── assets ├── ARIS_logo_eng_resized.jpg ├── preview_framework.jpg └── preview_samples.jpg ├── configs ├── config_train_SD21.py └── config_train_SD21_FRIDA.py ├── extract_ArcFace_embeds.py ├── inference_ID-Booth.py ├── requirements.txt ├── train_ID-Booth.py └── utils ├── augmentation_with_synthetic_data.py ├── detect_align_crop_data.py └── sorting_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pth 3 | *.bin 4 | *.sbatch 5 | *.sqfs 6 | *.out 7 | *.json* 8 | *.pt* 9 | *.pdf* 10 | *.zip* 11 | 12 | *backup* 13 | *BACKUP* 14 | 15 | *OUTPUT_MODELS* 16 | *GENERATED_SAMPLES* 17 | 
*FACE_DATASETS* 18 | *RESULTS* 19 | *Generated_Split* 20 | *FR-Ready_Datasets* 21 | *EXPERIMENTS* 22 | *VALIDATION_DATASETS* 23 | *CLASS_IMAGES* 24 | *NICE_SAMPLES* 25 | 26 | *facenet_pytorch* 27 | *images_during_training* 28 | *Face_recognition_training* 29 | *FR_DATASETS* 30 | *Results* 31 | *Features* 32 | *SwinFace* 33 | 34 | *Example* 35 | *REC_EXP* 36 | *eDifFIQA* 37 | 38 | *cache* 39 | *images* 40 | b_* -------------------------------------------------------------------------------- /ArcFace_files/ArcFace_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ArcFace_files.backbones import get_model 4 | 5 | ######################################### 6 | """ Replaced with torch cosine similarity """ 7 | # def cosine_similarity(x, y): 8 | # dot = np.sum(np.multiply(x, y), axis=1) 9 | # norm = np.linalg.norm(x, axis=1) * np.linalg.norm(y, axis=1) 10 | # similarity = np.clip(dot/norm, -1., 1.) 11 | # return similarity 12 | 13 | ######################################### 14 | def preprocess_image_for_ArcFace(img): 15 | #img = cv2.resize(img, (112, 112)) 16 | img = img.resize((112, 112)) 17 | img = np.array(img) 18 | # print(np.array(img).shape) 19 | img = np.transpose(img, (2, 0, 1)) 20 | 21 | img = torch.from_numpy(img).unsqueeze(0).float() 22 | img.div_(255).sub_(0.5).div_(0.5) 23 | 24 | return img 25 | 26 | ######################################### 27 | def prepare_locked_ArcFace_model(): 28 | arcface_model = get_model("r100", fp16=True) # TODO turned to true 29 | arcface_weights = "ArcFace_files/ArcFace_r100_ms1mv3_backbone.pth" 30 | arcface_model.load_state_dict(torch.load(arcface_weights)) 31 | arcface_model.eval() 32 | for param in arcface_model.parameters(): 33 | param.requires_grad = False 34 | 35 | #print(arcface_model) 36 | return arcface_model -------------------------------------------------------------------------------- /ArcFace_files/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | from .mobilefacenet import get_mbf 3 | 4 | 5 | def get_model(name, **kwargs): 6 | # resnet 7 | if name == "r18": 8 | return iresnet18(False, **kwargs) 9 | elif name == "r34": 10 | return iresnet34(False, **kwargs) 11 | elif name == "r50": 12 | return iresnet50(False, **kwargs) 13 | elif name == "r100": 14 | return iresnet100(False, **kwargs) 15 | elif name == "r200": 16 | return iresnet200(False, **kwargs) 17 | elif name == "r2060": 18 | from .iresnet2060 import iresnet2060 19 | return iresnet2060(False, **kwargs) 20 | 21 | elif name == "mbf": 22 | fp16 = kwargs.get("fp16", False) 23 | num_features = kwargs.get("num_features", 512) 24 | return get_mbf(fp16=fp16, num_features=num_features) 25 | 26 | elif name == "mbf_large": 27 | from .mobilefacenet import get_mbf_large 28 | fp16 = kwargs.get("fp16", False) 29 | num_features = kwargs.get("num_features", 512) 30 | return get_mbf_large(fp16=fp16, num_features=num_features) 31 | 32 | elif name == "vit_t": 33 | num_features = kwargs.get("num_features", 512) 34 | from .vit import VisionTransformer 35 | return VisionTransformer( 36 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, 37 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) 38 | 39 | elif name == "vit_t_dp005_mask0": # For WebFace42M 40 | num_features = kwargs.get("num_features", 512) 41 | from .vit import 
VisionTransformer 42 | return VisionTransformer( 43 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, 44 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) 45 | 46 | elif name == "vit_s": 47 | num_features = kwargs.get("num_features", 512) 48 | from .vit import VisionTransformer 49 | return VisionTransformer( 50 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, 51 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) 52 | 53 | elif name == "vit_s_dp005_mask_0": # For WebFace42M 54 | num_features = kwargs.get("num_features", 512) 55 | from .vit import VisionTransformer 56 | return VisionTransformer( 57 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, 58 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) 59 | 60 | elif name == "vit_b": 61 | # this is a feature 62 | num_features = kwargs.get("num_features", 512) 63 | from .vit import VisionTransformer 64 | return VisionTransformer( 65 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 66 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True) 67 | 68 | elif name == "vit_b_dp005_mask_005": # For WebFace42M 69 | # this is a feature 70 | num_features = kwargs.get("num_features", 512) 71 | from .vit import VisionTransformer 72 | return VisionTransformer( 73 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 74 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) 75 | 76 | elif name == "vit_l_dp005_mask_005": # For WebFace42M 77 | # this is a feature 78 | num_features = kwargs.get("num_features", 512) 79 | from .vit import VisionTransformer 80 | return VisionTransformer( 81 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24, 82 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) 83 | 84 | else: 85 | raise ValueError() 86 | -------------------------------------------------------------------------------- /ArcFace_files/backbones/iresnet2060.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | assert torch.__version__ >= "1.8.1" 5 | from torch.utils.checkpoint import checkpoint_sequential 6 | 7 | __all__ = ['iresnet2060'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, 13 | out_planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=dilation, 17 | groups=groups, 18 | bias=False, 19 | dilation=dilation) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, 25 | out_planes, 26 | kernel_size=1, 27 | stride=stride, 28 | bias=False) 29 | 30 | 31 | class IBasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, 35 | groups=1, base_width=64, dilation=1): 36 | super(IBasicBlock, self).__init__() 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) 42 | self.conv1 = conv3x3(inplanes, planes) 43 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) 44 | self.prelu = nn.PReLU(planes) 45 | self.conv2 
= conv3x3(planes, planes, stride) 46 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | out = self.bn1(x) 53 | out = self.conv1(out) 54 | out = self.bn2(out) 55 | out = self.prelu(out) 56 | out = self.conv2(out) 57 | out = self.bn3(out) 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | out += identity 61 | return out 62 | 63 | 64 | class IResNet(nn.Module): 65 | fc_scale = 7 * 7 66 | 67 | def __init__(self, 68 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 69 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 70 | super(IResNet, self).__init__() 71 | self.fp16 = fp16 72 | self.inplanes = 64 73 | self.dilation = 1 74 | if replace_stride_with_dilation is None: 75 | replace_stride_with_dilation = [False, False, False] 76 | if len(replace_stride_with_dilation) != 3: 77 | raise ValueError("replace_stride_with_dilation should be None " 78 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 79 | self.groups = groups 80 | self.base_width = width_per_group 81 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 83 | self.prelu = nn.PReLU(self.inplanes) 84 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 85 | self.layer2 = self._make_layer(block, 86 | 128, 87 | layers[1], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[0]) 90 | self.layer3 = self._make_layer(block, 91 | 256, 92 | layers[2], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[1]) 95 | self.layer4 = self._make_layer(block, 96 | 512, 97 | layers[3], 98 | stride=2, 99 | dilate=replace_stride_with_dilation[2]) 100 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) 101 | self.dropout = nn.Dropout(p=dropout, inplace=True) 102 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 103 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 104 | nn.init.constant_(self.features.weight, 1.0) 105 | self.features.weight.requires_grad = False 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.normal_(m.weight, 0, 0.1) 110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | if zero_init_residual: 115 | for m in self.modules(): 116 | if isinstance(m, IBasicBlock): 117 | nn.init.constant_(m.bn2.weight, 0) 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 120 | downsample = None 121 | previous_dilation = self.dilation 122 | if dilate: 123 | self.dilation *= stride 124 | stride = 1 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | conv1x1(self.inplanes, planes * block.expansion, stride), 128 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 129 | ) 130 | layers = [] 131 | layers.append( 132 | block(self.inplanes, planes, stride, downsample, self.groups, 133 | self.base_width, previous_dilation)) 134 | self.inplanes = planes * block.expansion 135 | for _ in range(1, blocks): 136 | layers.append( 137 | block(self.inplanes, 138 | planes, 139 | groups=self.groups, 140 | base_width=self.base_width, 141 | dilation=self.dilation)) 142 | 143 | return nn.Sequential(*layers) 144 | 145 | def checkpoint(self, func, num_seg, x): 146 | if self.training: 147 | return 
checkpoint_sequential(func, num_seg, x) 148 | else: 149 | return func(x) 150 | 151 | def forward(self, x): 152 | with torch.cuda.amp.autocast(self.fp16): 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.prelu(x) 156 | x = self.layer1(x) 157 | x = self.checkpoint(self.layer2, 20, x) 158 | x = self.checkpoint(self.layer3, 100, x) 159 | x = self.layer4(x) 160 | x = self.bn2(x) 161 | x = torch.flatten(x, 1) 162 | x = self.dropout(x) 163 | x = self.fc(x.float() if self.fp16 else x) 164 | x = self.features(x) 165 | return x 166 | 167 | 168 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 169 | model = IResNet(block, layers, **kwargs) 170 | if pretrained: 171 | raise ValueError() 172 | return model 173 | 174 | 175 | def iresnet2060(pretrained=False, progress=True, **kwargs): 176 | return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) 177 | -------------------------------------------------------------------------------- /ArcFace_files/backbones/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py 3 | Original author cavalleria 4 | ''' 5 | 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module 8 | import torch 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class ConvBlock(Module): 17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 18 | super(ConvBlock, self).__init__() 19 | self.layers = nn.Sequential( 20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), 21 | BatchNorm2d(num_features=out_c), 22 | PReLU(num_parameters=out_c) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.layers(x) 27 | 28 | 29 | class LinearBlock(Module): 30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 31 | super(LinearBlock, self).__init__() 32 | self.layers = nn.Sequential( 33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), 34 | BatchNorm2d(num_features=out_c) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.layers(x) 39 | 40 | 41 | class DepthWise(Module): 42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 43 | super(DepthWise, self).__init__() 44 | self.residual = residual 45 | self.layers = nn.Sequential( 46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), 47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), 48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 49 | ) 50 | 51 | def forward(self, x): 52 | short_cut = None 53 | if self.residual: 54 | short_cut = x 55 | x = self.layers(x) 56 | if self.residual: 57 | output = short_cut + x 58 | else: 59 | output = x 60 | return output 61 | 62 | 63 | class Residual(Module): 64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 65 | super(Residual, self).__init__() 66 | modules = [] 67 | for _ in range(num_block): 68 | modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) 69 | self.layers = Sequential(*modules) 70 | 71 | def forward(self, x): 72 | return self.layers(x) 73 | 74 | 75 | class GDC(Module): 76 | def __init__(self, embedding_size): 77 | 
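        # Global Depthwise Convolution (GDC) head: a 7x7 depthwise conv (groups=512)
        # collapses the (N, 512, 7, 7) feature map to (N, 512, 1, 1); the result is
        # flattened and projected by a bias-free Linear plus BatchNorm1d into the
        # final embedding of size `embedding_size`.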
super(GDC, self).__init__() 78 | self.layers = nn.Sequential( 79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), 80 | Flatten(), 81 | Linear(512, embedding_size, bias=False), 82 | BatchNorm1d(embedding_size)) 83 | 84 | def forward(self, x): 85 | return self.layers(x) 86 | 87 | 88 | class MobileFaceNet(Module): 89 | def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2): 90 | super(MobileFaceNet, self).__init__() 91 | self.scale = scale 92 | self.fp16 = fp16 93 | self.layers = nn.ModuleList() 94 | self.layers.append( 95 | ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) 96 | ) 97 | if blocks[0] == 1: 98 | self.layers.append( 99 | ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) 100 | ) 101 | else: 102 | self.layers.append( 103 | Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 104 | ) 105 | 106 | self.layers.extend( 107 | [ 108 | DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), 109 | Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 110 | DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), 111 | Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 112 | DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), 113 | Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 114 | ]) 115 | 116 | self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 117 | self.features = GDC(num_features) 118 | self._initialize_weights() 119 | 120 | def _initialize_weights(self): 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.BatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | elif isinstance(m, nn.Linear): 130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 131 | if m.bias is not None: 132 | m.bias.data.zero_() 133 | 134 | def forward(self, x): 135 | with torch.cuda.amp.autocast(self.fp16): 136 | for func in self.layers: 137 | x = func(x) 138 | x = self.conv_sep(x.float() if self.fp16 else x) 139 | x = self.features(x) 140 | return x 141 | 142 | 143 | def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2): 144 | return MobileFaceNet(fp16, num_features, blocks, scale=scale) 145 | 146 | def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4): 147 | return MobileFaceNet(fp16, num_features, blocks, scale=scale) 148 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/analyse_pyeer_ID-Booth.py: -------------------------------------------------------------------------------- 1 | import os 2 | # num_ids = 0 3 | # num_imgs = 32 4 | from pyeer_scripts.eer_info import get_eer_stats 5 | import os 6 | import pandas as pd 7 | import json 8 | 9 | import seaborn as sns 10 | import os 11 | import argparse 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | import matplotlib.patches as mpatches 15 | from matplotlib import rcParams 16 | import sys 17 | import inspect 18 | #from 
multiprocessing import Process, Queue 19 | import pandas as pd 20 | from genuine_and_impostor_AmongSynth import run_create_gen_imp_files_AmongSynth 21 | from genuine_and_imposter_SynthVsReal import run_create_gen_imp_files_SynthVsReal 22 | 23 | # Cluster setup 24 | # main_folder = "/shared/home/darian.tomasevic/ID-Booth/FR_DATASETS/" 25 | main_folder = "../../FR_DATASETS/" # "../../FR_DATASETS_ABLATION/" 26 | dataset_folders = ["12-2024_SD21_LoRA4_alphaWNone_FINAL_FacePortraitPhoto_Gender_Pose_BackgroundB" 27 | #"12-2024_SD21_LoRA4_alphaWNone_FacePortrait_Photo_Gender_Pose_BackgroundB_100samples" 28 | # "01-2025_ID-Booth_ABLATION_LOSS" 29 | ] 30 | subfolders = ["no_new_Loss", "identity_loss_TimestepWeight", "triplet_prior_loss_TimestepWeight"] 31 | # subfolders = ["no_new_Loss_NoPrior", "no_new_Loss", "triplet_prior_loss_TimestepWeight"] 32 | 33 | # dataset_folders = ["tufts_512_poses_1-7_all_imgs_jpg_per_ID"] 34 | # subfolders = ["images"] 35 | 36 | real_folder = f"{main_folder}tufts_512_poses_1-7_all_imgs_jpg_per_ID/images" 37 | 38 | 39 | # create pyeer report 40 | report_which_metrics = [ 41 | "auc", 42 | "eer", 43 | "eer_th", 44 | "fnmr0", 45 | "fnmr100", 46 | "fnmr1000", 47 | "fmr0", 48 | "fmr100", 49 | "fmr1000", 50 | "gmean", 51 | "gstd", 52 | "imean", 53 | "istd", 54 | "fdr", 55 | "decidability", 56 | "mccoef" 57 | ] 58 | 59 | ######################### 60 | def compute_fdr(stats): 61 | return (stats["gmean"] - stats["imean"]) ** 2 / (stats["gstd"] ** 2 + stats["istd"] ** 2) 62 | 63 | ######################### 64 | 65 | ############################ 66 | # create figures with only nice distribution plots first 67 | def plot_score_histogram(ax, df, stats, which_stat): 68 | TU_DESIGN_COLORS = { 69 | 'Genuine': "#64a0d9",#"#009D81", 70 | 'Imposter': "#d99d64", #"#0083CC", 71 | 'random': "#721085", #"#FDCA00", 72 | 'eer': "#E0221F" 73 | } 74 | sns.histplot(ax=ax, 75 | data=df, 76 | x="scores", 77 | hue="label", 78 | #palette=TU_DESIGN_COLORS, 79 | alpha=0.5, 80 | stat=which_stat, kde=True, bins=100) # binrange=(-1, 1)) 81 | 82 | ax.axvline(x=stats["eer_th"], c=TU_DESIGN_COLORS['eer'], linestyle="--") 83 | 84 | labels = [f'Genuine', f'Imposter'] 85 | handles = [mpatches.Patch(color=TU_DESIGN_COLORS[label], label=label) for label in labels] 86 | 87 | genuine_info = f"${round(stats['gmean'], 3)} \pm {round(stats['gstd'], 3)}$" 88 | imposter_info = f"${round(stats['imean'], 3)} \pm {round(stats['istd'], 3)}$" 89 | 90 | labels = [f'Genuine ({genuine_info})', f'Imposter ({imposter_info})'] 91 | ax.legend(handles=handles, labels=labels, loc="upper left", title="") 92 | 93 | ax.set_title(subfolder.replace("/", "__"), size=10) 94 | ax.set_xlabel("Cosine Similarity", size=14) 95 | ax.set_ylabel("Probability", size=14) 96 | 97 | #plt.legend(loc="upper left") 98 | #ax.set_ylim(0, 0.075) 99 | ############################ 100 | 101 | 102 | for which_config in ["vsSynth", "vsReal"]: # "" 103 | for dataset_folder in dataset_folders: 104 | dataset_folder = os.path.join(main_folder, dataset_folder) 105 | 106 | output_folder = os.path.join(f"RESULTS/{which_config}", os.path.basename(dataset_folder)) 107 | 108 | for subfolder in subfolders: 109 | data_folder = os.path.join(dataset_folder, subfolder) 110 | which_FR_model = "backbones/ArcFace_r100_ms1mv3_backbone.pth" 111 | 112 | if which_config == "vsSynth": 113 | run_create_gen_imp_files_AmongSynth(datadir=data_folder, fr_path=which_FR_model, outdir=output_folder) 114 | elif which_config == "vsReal": 115 | 
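                # "vsReal" protocol: genuine/impostor scores are computed between the
                # synthetic identities in data_folder and the real reference images in
                # real_folder (the Tufts set configured above), rather than only among
                # synthetic samples as in the "vsSynth" branch.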
run_create_gen_imp_files_SynthVsReal(datadir=data_folder, realdir=real_folder, fr_path=which_FR_model, outdir=output_folder) 116 | 117 | print(subfolder) 118 | folder = os.path.join(output_folder, subfolder.replace("/","__")) 119 | print("Folder", folder) 120 | gscore_file = os.path.join(folder, "genuines.txt") 121 | iscore_file = os.path.join(folder, "impostors.txt") 122 | 123 | genuine_scores = pd.read_csv(gscore_file, header=None)[0].to_list() 124 | impostor_scores = pd.read_csv(iscore_file, header=None)[0].to_list() 125 | 126 | # Calculating stats 127 | stats = get_eer_stats(genuine_scores, impostor_scores) 128 | stats = stats._asdict() 129 | print(stats.keys()) 130 | 131 | stats["fdr"] = compute_fdr(stats) 132 | 133 | saving_dict = dict() 134 | for metric in report_which_metrics: 135 | print(f"{metric}: {stats[metric]}") 136 | saving_dict[metric] = stats[metric] 137 | 138 | 139 | #with open(os.path.join(folder, "PyEER_report_" + subfolder.replace("/", "__"))+".json", "w") as outfile: 140 | with open(os.path.join(folder, "PyEER_report.json"), "w") as outfile: 141 | json.dump(saving_dict, outfile, indent=4) 142 | print("==" * 30) 143 | 144 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 145 | parentdir = os.path.dirname(currentdir) 146 | sys.path.insert(0, parentdir) 147 | 148 | datadir = os.path.join(output_folder, subfolder.replace("/", "__"))#"Synth_100_subset_preprocess_both_classes" 149 | print(datadir) 150 | save_dir = os.path.join(datadir) 151 | 152 | gen_sims = list(np.loadtxt(os.path.join(datadir, "genuines.txt"))) 153 | impo_sims = list(np.loadtxt(os.path.join(datadir, "impostors.txt"))) 154 | 155 | # Plot 156 | df = pd.DataFrame() 157 | df['scores'] = gen_sims + impo_sims 158 | df['label'] = ['Genuine'] * len(gen_sims) + ['Imposter'] * len(impo_sims) 159 | # save df before plotting 160 | df.to_csv(os.path.join(datadir, "final_df.csv")) 161 | 162 | print("///" * 30 ) 163 | 164 | 165 | fig = plt.figure()#figsize=(8, 8)) 166 | plot_score_histogram(plt.gca(), df, stats, which_stat="probability") 167 | plt.tight_layout() 168 | savename = os.path.join(save_dir, "distribution_" + subfolder.replace("/", "__") + ".png") 169 | print("Saving to:", savename) 170 | plt.savefig(savename, dpi=256) 171 | plt.close(fig) 172 | 173 | print("====" * 30) -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/analysis_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariant/ID-Booth/807d478b74833d69cf39399799fe50cf5284b314/Evaluation/PyEER_analysis/analysis_scripts/__init__.py -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/analysis_scripts/dataset_EER.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | #DATASET="ExFaceGAN_SG3" 3 | DATASET="Synth_100_subset_preprocess_both_classes" 4 | 5 | geteerinf -p "data_plots/"$DATASET -i "impostors.txt" -g "genuines.txt" -sp "data_plots/"$DATASET -e $DATASET 6 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/analysis_scripts/plot_distributions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib.patches as mpatches 6 | from matplotlib import rcParams 7 | 
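# Plots genuine/impostor cosine-similarity distributions from precomputed
# genuines.txt / impostors.txt score files (see plot_gen_imp_distribution below).
# Example invocation (hypothetical paths, flags as defined in the argparse below):
#   python plot_distributions.py --datadir data_plots/Synth_100_subset --eer 0.05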
import sys 8 | import inspect 9 | from tqdm import tqdm 10 | from multiprocessing import Process, Queue 11 | import pandas as pd 12 | 13 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 14 | parentdir = os.path.dirname(currentdir) 15 | sys.path.insert(0, parentdir) 16 | 17 | import seaborn as sns 18 | 19 | rcParams.update({"figure.autolayout": True}) 20 | 21 | 22 | 23 | ############################ 24 | # create figures with only nice distribution plots first 25 | def plot_score_histogram(ax, df, eer): 26 | TU_DESIGN_COLORS = { 27 | 'Genuine': "#64a0d9",#"#009D81", 28 | 'Imposter': "#d99d64", #"#0083CC", 29 | 'random': "#721085", #"#FDCA00", 30 | 'eer': "#EC6500" 31 | } 32 | sns.histplot(ax=ax, 33 | data=df, x="scores", 34 | hue="label", palette=TU_DESIGN_COLORS, 35 | stat="probability", kde=True, bins=100, binrange=(-1, 1)) 36 | 37 | ax.axvline(x=eer, c=TU_DESIGN_COLORS['eer']) 38 | 39 | 40 | labels = ['Genuine', 'Imposter'] 41 | handles = [mpatches.Patch(color=TU_DESIGN_COLORS[label], label=label) for label in labels] 42 | ax.legend(handles=handles, labels=labels, loc="upper left", title="") 43 | 44 | ax.set_title(f"TODO") 45 | ax.set_xlabel("Cosine Similarity") 46 | ax.set_ylabel("Probability") 47 | 48 | ax.set_ylim(0, 0.075) 49 | 50 | ############################ 51 | 52 | def plot_gen_imp_distribution(datadir, num_ids=0, num_imgs=0): 53 | """plots genuine impostor distribution and saves genuine and impostor scores 54 | args: 55 | datadir: directory containing embeddings folder 56 | num_ids: number of identities 57 | num_imgs: number of images per identity 58 | """ 59 | dataname = args.datadir.split(os.path.sep)[-1] 60 | save_dir = os.path.join("data_plots", dataname + "_both_classes") 61 | genuine_file = os.path.join(datadir, "genuines.txt") 62 | impostor_file = os.path.join(datadir, "impostors.txt") 63 | gen_sims = np.loadtxt(genuine_file) 64 | impo_sims = np.loadtxt(impostor_file) 65 | print("Plot histogram...") 66 | savename = os.path.join(save_dir, "distribution_" + dataname + ".png") 67 | plt.hist(gen_sims, bins=200, alpha=0.5, label="Genuine Similarities", density=True) 68 | plt.hist( 69 | impo_sims, bins=200, alpha=0.5, label="Impostor Similarities", density=True 70 | ) 71 | plt.xlabel("Cosine Similarity", size=14) 72 | plt.ylabel("Probability Density", size=14) 73 | plt.legend(loc="upper left") 74 | plt.savefig(savename) 75 | plt.close() 76 | print("Histogram saved") 77 | 78 | print("Plot histogram new...") 79 | 80 | #df = pd.DataFrame({'Genuine': gen_sims, 'Imposter': impo_sims}) 81 | #x = ['Genuine', 'Imposter'] 82 | #df_long = df.melt('x') 83 | print(np.array(gen_sims)) 84 | 85 | 86 | gen_sims = np.array(gen_sims) 87 | impo_sims = np.array(impo_sims) 88 | 89 | df_gen = pd.DataFrame({'Genuine': gen_sims}) 90 | df_impo = pd.DataFrame({'Imposter': impo_sims}) 91 | 92 | 93 | 94 | fig = plt.figure(figsize=(8, 5)) 95 | plot_score_histogram(plt.gca(), df_gen, eer) 96 | plt.tight_layout() 97 | #plt.savefig(save_path, dpi=256) 98 | #plt.close(fig) 99 | 100 | #sns.histplot(data=df_gen, stat="probability", bins=100, binrange=(-1.1, 1.1), kde=True, palette=TU_DESIGN_COLORS) 101 | #sns.histplot(data=df_impo, stat="probability", bins=100, binrange=(-1.1, 1.1), kde=True, palette=TU_DESIGN_COLORS) 102 | 103 | """ 104 | sns.histplot(ax=ax, 105 | data=model_dfs[model_name][frm_name][plot_type], x="scores", 106 | hue="label", palette=TU_DESIGN_COLORS, 107 | stat="probability", kde=True, bins=100, binrange=(-1, 1)) 108 | """ 109 | savename = 
os.path.join(save_dir, "NEW_distribution_" + dataname + ".png") 110 | #plt.hist(gen_sims / sum(gen_sims), bins=200, alpha=0.5, label="Genuine Similarities")#, density=True) 111 | #plt.hist( 112 | # impo_sims / sum(impo_sims), bins=200, alpha=0.5, label="Impostor Similarities"#, density=True 113 | #) 114 | #plt.xlabel("Cosine Similarity", size=14) 115 | #plt.ylabel("Probability", size=14) 116 | #plt.legend(loc="upper left") 117 | #plt.savefig(savename) 118 | plt.close() 119 | print("Histogram saved") 120 | 121 | 122 | def main(args): 123 | plot_gen_imp_distribution( 124 | args.datadir, 125 | args.num_ids, 126 | args.num_imgs, 127 | ) 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = argparse.ArgumentParser(description="Study of datasets") 132 | parser.add_argument( 133 | "--datadir", 134 | type=str, 135 | default="/data/synthetic_imgs/ExFaceGAN_SG3", 136 | help="path to directory containing the image folder", 137 | ) 138 | parser.add_argument("--batchsize", type=int, default=32) 139 | parser.add_argument("--num_ids", type=int, default=5000) 140 | parser.add_argument("--num_imgs", type=int, default=100) 141 | parser.add_argument("--eer", type=float, default=0) 142 | parser.add_argument( 143 | "--fr_path", 144 | default="path/to/pre-trained/FR/model.pth", 145 | ) 146 | args = parser.parse_args() 147 | main(args) 148 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/analysis_scripts/plot_logs.py: -------------------------------------------------------------------------------- 1 | from os.path import join as ojoin 2 | import matplotlib.pyplot as plt 3 | import argparse 4 | import numpy as np 5 | import re 6 | 7 | 8 | def find_reg(pattern, rows): 9 | strings = re.findall(pattern, rows) 10 | arr = np.array(strings, dtype=float) 11 | return arr 12 | 13 | 14 | def main(args): 15 | log_path = ojoin(args.log_path, "Training.log") 16 | model = args.log_path.split("/")[-1] 17 | 18 | # load the contents of the log file 19 | rows = open(log_path).read().strip() 20 | train_loss2 = find_reg(r"Loss2 (.*) Acc1", rows) 21 | if len(train_loss2) > 0: 22 | train_loss = find_reg(r"Loss (.*) Loss2", rows) 23 | else: 24 | train_loss = find_reg(r"Loss (.*) Acc1", rows) 25 | train_acc1 = find_reg(r"Acc1 (.*) Acc5", rows) 26 | steps = find_reg(r"Step: (.*)/", rows) 27 | steps = np.array(steps, dtype=int) 28 | total_step = np.max(steps) 29 | freq = np.min(steps) // 2 30 | 31 | epochs = find_reg(r"Epoch: (.*) Speed", rows) 32 | max_epoch = int(max(epochs)) 33 | 34 | acc_lfw = find_reg(r"\[lfw\]\[[0-9]*\]Accuracy-Flip: (.*)\+-", rows) * 100 35 | max_acc_lfw = round(np.max(acc_lfw), 2) if len(acc_lfw) > 0 else 0 36 | acc_agedb = find_reg(r"\[agedb_30\]\[[0-9]*\]Accuracy-Flip: (.*)\+-", rows) * 100 37 | max_acc_agedb = round(np.max(acc_agedb), 2) if len(acc_agedb) > 0 else 0 38 | acc_cfpfp = find_reg(r"\[cfp_fp\]\[[0-9]*\]Accuracy-Flip: (.*)\+-", rows) * 100 39 | max_acc_cfpfp = round(np.max(acc_cfpfp), 2) if len(acc_cfpfp) > 0 else 0 40 | acc_calfw = find_reg(r"\[calfw\]\[[0-9]*\]Accuracy-Flip: (.*)\+-", rows) * 100 41 | max_acc_caflw = round(np.max(acc_calfw), 2) if len(acc_calfw) > 0 else 0 42 | acc_cplfw = find_reg(r"\[cplfw\]\[[0-9]*\]Accuracy-Flip: (.*)\+-", rows) * 100 43 | max_acc_cplfw = round(np.max(acc_cplfw), 2) if len(acc_cplfw) > 0 else 0 44 | 45 | loss_x = np.array(range(2 * freq, total_step + 1, freq)) 46 | 47 | # plot the loss 48 | plt.style.use("ggplot") 49 | fig, ax1 = plt.subplots() 50 | ax1.set_title(model) 51 | 
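    # Training loss is drawn on the left y-axis (ax1); Acc1 is added on a twin
    # right y-axis (ax2) further below so both curves share a single figure.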
ax1.set_xlabel("Iteration #") 52 | 53 | ax1.plot(loss_x, train_loss, "r-", label="Loss") 54 | ax1.set_ylabel("Loss", color="r") 55 | if len(train_loss2) > 0: 56 | ax1.plot(loss_x, train_loss2, "g-", label="Loss2") 57 | 58 | ax2 = ax1.twinx() 59 | ax2.plot(loss_x, train_acc1, "b-", label="Acc 1") 60 | ax2.set_ylabel("Acc 1", color="b") 61 | 62 | fig.tight_layout() 63 | plt.savefig(ojoin(args.log_path, "Training.png"), format="png", dpi=600) 64 | plt.close() 65 | 66 | tick_freq = max_epoch // 20 67 | 68 | plt.style.use("ggplot") 69 | plt.figure() 70 | plt.plot(acc_lfw, label=f"LFW: {max_acc_lfw}") 71 | plt.plot(acc_agedb, label=f"AgeDB-30: {max_acc_agedb}") 72 | plt.plot(acc_cfpfp, label=f"CFP-FP: {max_acc_cfpfp}") 73 | plt.plot(acc_calfw, label=f"CALFW: {max_acc_caflw}") 74 | plt.plot(acc_cplfw, label=f"CPLFW: {max_acc_cplfw}") 75 | plt.title(model) 76 | plt.xlabel("Epoch") 77 | plt.ylabel("Acc") 78 | plt.ylim(40, 95) 79 | plt.xticks( 80 | np.arange(1, max_epoch + 1, tick_freq), np.arange(2, max_epoch + 2, tick_freq) 81 | ) 82 | plt.legend(loc="lower right") 83 | plt.savefig(ojoin(args.log_path, "Validation.png"), format="png", dpi=600) 84 | plt.close() 85 | 86 | print("Successfully plotted") 87 | 88 | 89 | if __name__ == "__main__": 90 | parser = argparse.ArgumentParser(description="PyTorch Plot Training Logs") 91 | parser.add_argument( 92 | "--log_path", 93 | type=str, 94 | default="../output/ExFace_SG3_CosFace_RA", 95 | help="folder path to log file", 96 | ) 97 | args = parser.parse_args() 98 | main(args) 99 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/backbones/activation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | import torch 5 | 6 | from inspect import isfunction 7 | 8 | class Identity(nn.Module): 9 | """ 10 | Identity block. 11 | """ 12 | def __init__(self): 13 | super(Identity, self).__init__() 14 | 15 | def forward(self, x): 16 | return x 17 | 18 | def __repr__(self): 19 | return '{name}()'.format(name=self.__class__.__name__) 20 | class HSigmoid(nn.Module): 21 | """ 22 | Approximated sigmoid function, so-called hard-version of sigmoid from 'Searching for MobileNetV3,' 23 | https://arxiv.org/abs/1905.02244. 24 | """ 25 | def forward(self, x): 26 | return F.relu6(x + 3.0, inplace=True) / 6.0 27 | 28 | 29 | class Swish(nn.Module): 30 | """ 31 | Swish activation function from 'Searching for Activation Functions,' https://arxiv.org/abs/1710.05941. 32 | """ 33 | def forward(self, x): 34 | return x * torch.sigmoid(x) 35 | class HSwish(nn.Module): 36 | """ 37 | H-Swish activation function from 'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244. 38 | Parameters: 39 | ---------- 40 | inplace : bool 41 | Whether to use inplace version of the module. 42 | """ 43 | def __init__(self, inplace=False): 44 | super(HSwish, self).__init__() 45 | self.inplace = inplace 46 | 47 | def forward(self, x): 48 | return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 49 | 50 | 51 | def get_activation_layer(activation,param): 52 | """ 53 | Create activation layer from string/function. 
54 | Parameters: 55 | ---------- 56 | activation : function, or str, or nn.Module 57 | Activation function or name of activation function. 58 | Returns: 59 | ------- 60 | nn.Module 61 | Activation layer. 62 | """ 63 | assert (activation is not None) 64 | if isfunction(activation): 65 | return activation() 66 | elif isinstance(activation, str): 67 | if activation == "relu": 68 | return nn.ReLU(inplace=True) 69 | elif activation =="prelu": 70 | return nn.PReLU(param) 71 | elif activation == "relu6": 72 | return nn.ReLU6(inplace=True) 73 | elif activation == "swish": 74 | return Swish() 75 | elif activation == "hswish": 76 | return HSwish(inplace=True) 77 | elif activation == "sigmoid": 78 | return nn.Sigmoid() 79 | elif activation == "hsigmoid": 80 | return HSigmoid() 81 | elif activation == "identity": 82 | return Identity() 83 | else: 84 | raise NotImplementedError() 85 | else: 86 | assert (isinstance(activation, nn.Module)) 87 | return activation -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/pyeer_scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | __copyright__ = 'Copyright 2017' 4 | __author__ = u'Bsc. Manuel Aguado Martínez' 5 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/pyeer_scripts/cmc_info.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import argparse 4 | 5 | from os.path import join 6 | 7 | from .cmc_stats import load_scores_from_file, get_cmc_curve, CMCstats 8 | from .report import generate_cmc_report 9 | from .plot import plot_cmc_stats 10 | 11 | __copyright__ = 'Copyright 2017' 12 | __author__ = u'Manuel Aguado Martínez' 13 | 14 | 15 | def get_cmc_info(): 16 | # Setting script arguments 17 | ap = argparse.ArgumentParser() 18 | ap.add_argument("-p", "--path", required=False, default='.', 19 | help="The path to the scores files") 20 | ap.add_argument("-ms", "--scores_filenames", required=True, 21 | help="The scores file. Multiple files must be" 22 | " separated by a comma") 23 | ap.add_argument("-t", "--true_pairs_file_names", required=True, 24 | help="Genuine pairs file. Multiple files must be" 25 | " separated by a comma.") 26 | ap.add_argument("-e", "--experiment_names", required=True, 27 | help="Experiment ID. Multiple IDS must be separated by " 28 | " comma") 29 | ap.add_argument("-r", "--maximum_rank", required=False, default=20, 30 | help="The maximum rank to calculate the penetration" 31 | " coefficient. (default=20)") 32 | ap.add_argument("-lw", "--line_width", required=False, default=2, 33 | help="Line width for plots (default=2)") 34 | ap.add_argument("-lf", "--legend_font", required=False, default=12, 35 | help="The size of the plots legend font (default=12)") 36 | ap.add_argument("-np", "--no_plots", required=False, action='store_true', 37 | help="Indicates whether to not plot the results") 38 | ap.add_argument("-sp", "--save_path", required=False, default='', 39 | help="Path to save the plots (if -s was specified)" 40 | " and stats report") 41 | ap.add_argument("-pf", "--plots_format", required=False, default='png', 42 | help="Format to save plots. Valid formats are:" 43 | "(png, pdf, ps, eps and svg)") 44 | ap.add_argument("-rf", "--report_format", required=False, default='csv', 45 | help="Format to save the report. Valid formats are:" 46 | " (csv, html, tex, json). 
Default csv.") 47 | ap.add_argument("-sr", "--save_dpi", required=False, default=None, 48 | help="Plots resolution (dots per inch). If not given" 49 | " it will default to the value savefig.dpi in the" 50 | " matplotlibrc file") 51 | ap.add_argument("-ds", "--ds_scores", required=False, action='store_true', 52 | help='Indicates whether the input scores are dissimilarity' 53 | 'scores') 54 | args = ap.parse_args() 55 | 56 | # Parsing script arguments 57 | score_filenames = args.scores_filenames.split(',') 58 | true_pairs_filenames = args.true_pairs_file_names.split(',') 59 | if len(true_pairs_filenames) == 1: 60 | true_pairs_filenames *= len(score_filenames) 61 | experiment_names = args.experiment_names.split(',') 62 | experiments = zip(score_filenames, true_pairs_filenames, experiment_names) 63 | rank = int(args.maximum_rank) 64 | line_width = int(args.line_width) 65 | lgf_size = int(args.legend_font) 66 | ext = '.' + args.plots_format 67 | dpi = None if args.save_dpi is None else int(args.save_dpi) 68 | 69 | # Calculating CMC values for each experiment and plotting them 70 | stats = [] 71 | for i, exp in enumerate(experiments): 72 | s_filename = join(args.path, exp[0]) 73 | tp_filename = join(args.path, exp[1]) 74 | experiment_name = exp[2] 75 | 76 | print('%s: Loading scores file...' % experiment_name) 77 | scores = load_scores_from_file(s_filename, tp_filename, args.ds_scores) 78 | 79 | print('%s: Calculating CMC cruve...' % experiment_name) 80 | rank_values = get_cmc_curve(scores, rank) 81 | 82 | stats.append(CMCstats(exp_id=experiment_name, ranks=rank_values)) 83 | 84 | # Generating reports 85 | print('Generating report...') 86 | 87 | filename = join(args.save_path, 'pyeer_report.' + args.report_format) 88 | generate_cmc_report(stats, rank, filename) 89 | 90 | if not args.no_plots: 91 | print('Plotting...') 92 | plot_cmc_stats(stats, rank, line_width, lgf_size, True, 93 | dpi, args.save_path, ext) 94 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/pyeer_scripts/cmc_stats.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import operator 4 | 5 | from collections import namedtuple 6 | from warnings import warn 7 | 8 | __copyright__ = 'Copyright 2017' 9 | __author__ = u'Manuel Aguado Martínez' 10 | 11 | TEMPLATE_POS = 0 12 | SCORE_POS = 1 13 | 14 | 15 | CMCstats = namedtuple('CMCstats', ['exp_id', # Exp id 16 | 'ranks', # Rank values 17 | ]) 18 | 19 | 20 | def load_scores_from_file(scores_filename, true_pairs_filename, 21 | ds_scores=False, delimiter=' '): 22 | """Loads the match information from the files. 23 | 24 | @param scores_filename: The scores file address. One score per 25 | line with the following format: (query template score) 26 | @type scores_filename: str 27 | @param true_pairs_filename: The true pairs file address. Each line 28 | indicates the corresponding template of each query. Must have 29 | the following format: (query true_template) 30 | @type true_pairs_filename: str 31 | @param ds_scores: Indicates whether te input scores are dissimilarity 32 | scores. 33 | @type ds_scores: bool 34 | @param delimiter: The boundary string of input files. 
35 | @type delimiter: str, default ' ' 36 | 37 | @returns: A dictionary {key=query, value=QueryMatchInfo} 38 | @rtype: dict 39 | """ 40 | matching_scores = {} 41 | 42 | with open(true_pairs_filename) as tpf: 43 | for line in tpf: 44 | query, template = line.split(delimiter, 1) 45 | 46 | if query in matching_scores: 47 | matching_scores[query][TEMPLATE_POS].append(template.strip()) 48 | else: 49 | matching_scores[query] = ([template.strip()], []) 50 | 51 | with open(scores_filename) as sf: 52 | for line in sf: 53 | query, template, score = line.split(delimiter)[:3] 54 | matching_scores[query][SCORE_POS].append((template, float(score))) 55 | 56 | for query_match_info in matching_scores.values(): 57 | query_match_info[SCORE_POS].sort(key=operator.itemgetter(SCORE_POS), 58 | reverse=not ds_scores) 59 | 60 | return matching_scores 61 | 62 | 63 | def get_cmc_curve(scores, max_rank): 64 | """Calculates the values of a CMC curve 65 | 66 | @param scores: The dictionary returned by the function 67 | load_scores_from_file or a similar one. 68 | @type scores: dict 69 | @param max_rank: The maximum rank to calculate the penetration coefficient. 70 | @type max_rank : int 71 | 72 | @return: A list with the rank values. 73 | @rtype: list 74 | """ 75 | ranks_values = [0.0] * (max_rank + 1) 76 | queries_total = len(scores) 77 | 78 | # Calculating identification rates 79 | for r in range(max_rank): 80 | 81 | # Calculating identification rate at Rank-r 82 | in_rank = 0.0 83 | for query_match_info in scores.values(): 84 | 85 | if r < len(query_match_info[SCORE_POS]): 86 | # Candidate at position r 87 | candidate = query_match_info[SCORE_POS][r][TEMPLATE_POS] 88 | 89 | # Checking if candidate is the corresponding positive id 90 | true_template = query_match_info[TEMPLATE_POS] 91 | if candidate in true_template: 92 | in_rank += 1 93 | 94 | # Updating rank values 95 | ranks_values[r + 1] = in_rank / queries_total + ranks_values[r] 96 | if ranks_values[r + 1] >= 1.0: 97 | ranks_values[r + 1:] = [1.0] * (len(ranks_values) - r - 1) 98 | 99 | break 100 | 101 | if ranks_values[-1] < 0.2: 102 | warn("It is possible that you had set the wrong score" 103 | " type. 
Please consider reviewing if you are using" 104 | " dissimilarity or similarity scores") 105 | 106 | return ranks_values[1:] 107 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariant/ID-Booth/807d478b74833d69cf39399799fe50cf5284b314/Evaluation/PyEER_analysis/utils/__init__.py -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/align_trans.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import cv2 4 | from skimage import transform as trans 5 | 6 | 7 | arcface_ref_points = np.array( 8 | [ 9 | [30.2946, 51.6963], 10 | [65.5318, 51.5014], 11 | [48.0252, 71.7366], 12 | [33.5493, 92.3655], 13 | [62.7299, 92.2041], 14 | ], 15 | dtype=np.float32, 16 | ) 17 | 18 | # https://github.com/deepinsight/insightface/blob/master/python-package/insightface/utils/face_align.py 19 | # [:,0] += 8.0 20 | arcface_eval_ref_points = np.array( 21 | [ 22 | [38.2946, 51.6963], 23 | [73.5318, 51.5014], 24 | [56.0252, 71.7366], 25 | [41.5493, 92.3655], 26 | [70.7299, 92.2041], 27 | ], 28 | dtype=np.float32, 29 | ) 30 | 31 | # lmk is prediction; src is template 32 | def estimate_norm(lmk, image_size=112, createEvalDB=False): 33 | """estimate the transformation matrix 34 | :param lmk: detected landmarks 35 | :param image_size: resulting image size (default=112) 36 | :param createEvalDB: (boolean) crop an evaluation or training dataset 37 | :return: transformation matrix M and index 38 | """ 39 | assert lmk.shape == (5, 2) 40 | assert image_size == 112 41 | tform = trans.SimilarityTransform() 42 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1) 43 | min_M = [] 44 | min_index = [] 45 | min_error = float("inf") 46 | if createEvalDB: 47 | src = arcface_eval_ref_points 48 | else: 49 | src = arcface_ref_points 50 | src = np.expand_dims(src, axis=0) 51 | 52 | for i in np.arange(src.shape[0]): 53 | tform.estimate(lmk, src[i]) 54 | M = tform.params[0:2, :] 55 | results = np.dot(M, lmk_tran.T) 56 | results = results.T 57 | error = np.sum(np.sqrt(np.sum((results - src[i]) ** 2, axis=1))) 58 | # print(error) 59 | if error < min_error: 60 | min_error = error 61 | min_M = M 62 | min_index = i 63 | return min_M, min_index 64 | 65 | 66 | # norm_crop from Arcface repository (insightface/recognition/common/face_align.py) 67 | def norm_crop(img, landmark, image_size=112, createEvalDB=False): 68 | """transforms image to match the landmarks with reference landmarks 69 | :param landmark: detected landmarks 70 | :param image_size: resulting image size (default=112) 71 | :param createEvalDB: (boolean) crop an evaluation or training dataset 72 | :return: transformed image 73 | """ 74 | M, pose_index = estimate_norm( 75 | landmark, image_size=image_size, createEvalDB=createEvalDB 76 | ) 77 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 78 | return warped 79 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import torch 3 | import logging 4 | from utils.rand_augment import RandAugment 5 | 6 | 7 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 
std=[0.5, 0.5, 0.5]) 8 | 9 | to_tensor = [transforms.ToTensor(), normalize] 10 | 11 | aug_h_flip = [transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize] 12 | 13 | aug_rand_4_16 = [ 14 | transforms.RandomHorizontalFlip(), 15 | RandAugment(num_ops=4, magnitude=16), 16 | transforms.ToTensor(), 17 | normalize, 18 | ] 19 | 20 | 21 | def get_conventional_aug_policy(aug_type): 22 | """get geometric and color augmentations 23 | args: 24 | aug_type: string defining augmentation type 25 | operation: RA augmentation operation under testing 26 | num_ops: number of sequential operations under testing 27 | mag: magnitude under testing 28 | return: 29 | augmentation policy 30 | """ 31 | aug = aug_type.lower() 32 | if aug == "gan_hf" or aug == "nogan_hf" or aug == "hf": 33 | augmentation = aug_h_flip 34 | elif aug == "gan_ra_4_16" or aug == "nogan_ra_4_16" or aug == "ra_4_16": 35 | augmentation = aug_rand_4_16 36 | elif aug == "totensor": 37 | augmentation = to_tensor 38 | elif aug == "mag_totensor": 39 | augmentation = [transforms.ToTensor()] 40 | else: 41 | logging.error("Unknown augmentation method: {}".format(aug_type)) 42 | exit() 43 | return transforms.Compose(augmentation) 44 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/countFLOPS.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import numpy as np 3 | 4 | import torch 5 | 6 | def count_model_flops(model, input_res=[112, 112], multiply_adds=True): 7 | list_conv = [] 8 | 9 | def conv_hook(self, input, output): 10 | batch_size, input_channels, input_height, input_width = input[0].size() 11 | output_channels, output_height, output_width = output[0].size() 12 | 13 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 14 | bias_ops = 1 if self.bias is not None else 0 15 | 16 | params = output_channels * (kernel_ops + bias_ops) 17 | flops = (kernel_ops * ( 18 | 2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 19 | 20 | list_conv.append(flops) 21 | 22 | list_linear = [] 23 | 24 | def linear_hook(self, input, output): 25 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 26 | 27 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 28 | if self.bias is not None: 29 | bias_ops = self.bias.nelement() if self.bias.nelement() else 0 30 | flops = batch_size * (weight_ops + bias_ops) 31 | else: 32 | flops = batch_size * weight_ops 33 | list_linear.append(flops) 34 | 35 | list_bn = [] 36 | 37 | def bn_hook(self, input, output): 38 | list_bn.append(input[0].nelement() * 2) 39 | 40 | list_relu = [] 41 | 42 | def relu_hook(self, input, output): 43 | list_relu.append(input[0].nelement()) 44 | 45 | list_pooling = [] 46 | 47 | def pooling_hook(self, input, output): 48 | batch_size, input_channels, input_height, input_width = input[0].size() 49 | output_channels, output_height, output_width = output[0].size() 50 | 51 | kernel_ops = self.kernel_size * self.kernel_size 52 | bias_ops = 0 53 | params = 0 54 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 55 | 56 | list_pooling.append(flops) 57 | def pooling_hook_ad(self, input, output): 58 | batch_size, input_channels, input_height, input_width = input[0].size() 59 | input = input[0] 60 | flops = int(np.prod(input.shape)) 61 | list_pooling.append(flops) 62 | 63 | handles = [] 64 | 65 | def foo(net): 
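        # Recursively walk the module tree and register the matching FLOP-counting
        # forward hook on every leaf layer (conv, linear, batchnorm, activation,
        # pooling); unrecognised leaf modules only trigger a printed warning.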
66 | childrens = list(net.children()) 67 | if not childrens: 68 | if isinstance(net, torch.nn.Conv2d) or isinstance(net, torch.nn.ConvTranspose2d): 69 | handles.append(net.register_forward_hook(conv_hook)) 70 | elif isinstance(net, torch.nn.Linear): 71 | handles.append(net.register_forward_hook(linear_hook)) 72 | elif isinstance(net, torch.nn.BatchNorm2d) or isinstance(net, torch.nn.BatchNorm1d): 73 | handles.append(net.register_forward_hook(bn_hook)) 74 | elif isinstance(net, torch.nn.ReLU) or isinstance(net, torch.nn.PReLU): 75 | handles.append(net.register_forward_hook(relu_hook)) 76 | elif isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 77 | handles.append(net.register_forward_hook(pooling_hook)) 78 | else: 79 | print("warning" + str(net)) 80 | return 81 | for c in childrens: 82 | foo(c) 83 | 84 | model.eval() 85 | foo(model) 86 | input = Variable(torch.rand(3, input_res[1], input_res[0]).unsqueeze(0), requires_grad=True) 87 | out = model(input) 88 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling)) 89 | for h in handles: 90 | h.remove() 91 | model.train() 92 | return flops_to_string(total_flops) 93 | 94 | def flops_to_string(flops, units='MFLOPS', precision=4): 95 | if units == 'GFLOPS': 96 | return str(round(flops / 10.**9, precision)) + ' ' + units 97 | elif units == 'MFLOPS': 98 | return str(round(flops / 10.**6, precision)) + ' ' + units 99 | elif units == 'KFLOPS': 100 | return str(round(flops / 10.**3, precision)) + ' ' + units 101 | else: 102 | return str(flops) + ' FLOPS' 103 | 104 | def _calc_width(net): 105 | net_params = filter(lambda p: p.requires_grad, net.parameters()) 106 | weight_count = 0 107 | for param in net_params: 108 | weight_count += np.prod(param.size()) 109 | return weight_count 110 | 111 | 112 | def count_parameters(model, m_name): 113 | """ counts model parameters 114 | args: 115 | model: model 116 | m_name: model name for return string 117 | return: 118 | string with total and trainable parameters 119 | """ 120 | total_params = sum(p.numel() for p in model.parameters()) 121 | train_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 122 | return f"{m_name}: Total parameters: {total_params} Trainable parameters: {train_params}" 123 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from os.path import join as ojoin 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | import numpy as np 7 | 8 | 9 | def check_for_folder_structure(datadir): 10 | """checks if datadir contains folders (like CASIA) or images (synthetic datasets)""" 11 | img_path = sorted(os.listdir(datadir))[0] 12 | img_path = ojoin(datadir, img_path) 13 | return os.path.isdir(img_path) 14 | 15 | 16 | def load_real_paths(datadir, num_imgs=0, num_classes=0): 17 | """loads complete real image paths 18 | args: 19 | datadir: path to image folders 20 | num_imgs: number of total images 21 | num_classes: number of classes that should be loaded 22 | return: 23 | list of image paths 24 | """ 25 | img_paths = [] 26 | id_folders = sorted(os.listdir(datadir)) 27 | if num_classes != 0: 28 | id_folders = id_folders[:num_classes] 29 | for id in id_folders: 30 | img_files = sorted(os.listdir(ojoin(datadir, id))) 31 | img_paths += [os.path.join(datadir, id, f_name) for f_name in img_files] 32 | if 
num_imgs != 0: 33 | img_paths = img_paths[:num_imgs] 34 | return img_paths 35 | 36 | 37 | def load_syn_paths(datadir, num_imgs=0, start_img=0): 38 | """loads first level paths, i.e. image folders for DFG that contain augmentation images 39 | args: 40 | datadir: path to image folder 41 | num_imgs: number of images / folders 42 | start_img: start image index 43 | return: 44 | list of image paths 45 | """ 46 | img_files = sorted(os.listdir(datadir)) 47 | if num_imgs != 0: 48 | img_files = img_files[start_img : start_img + num_imgs] 49 | return [os.path.join(datadir, f_name) for f_name in img_files] 50 | 51 | 52 | def load_supervised_paths(datadir, num_ids, num_imgs): 53 | """load e.g. DFG images with folder structure as supervised dataset 54 | args: 55 | datadir: path to directory containing the images 56 | num_ids: number of identities (folders) that should be loaded 57 | num_imgs: number of images per identity that should be loaded 58 | return: 59 | list of image paths, corresponding list of labels 60 | """ 61 | img_paths, labels = [], [] 62 | id_folders = sorted(os.listdir(datadir))[:num_ids] 63 | for i, id in enumerate(id_folders): 64 | id_path = ojoin(datadir, id) 65 | img_files = sorted(os.listdir(id_path))[:num_imgs] 66 | img_paths += [ojoin(id_path, f_name) for f_name in img_files] 67 | 68 | labels += [int(i)] * len(img_files) 69 | 70 | return img_paths, labels 71 | 72 | 73 | def load_latents(datadir, num_lats=0): 74 | """load numpy latents from directory 75 | args: 76 | datadir: path to latent folder 77 | num_lats: number of latents 78 | return: 79 | numpy array of latents 80 | """ 81 | lat_files = sorted(os.listdir(datadir)) 82 | if num_lats != 0: 83 | lat_files = lat_files[:num_lats] 84 | lats = [] 85 | for lat_file in lat_files: 86 | lats.append(np.load(ojoin(datadir, lat_file))) 87 | return np.array(lats) 88 | 89 | 90 | class LimitedDataset(Dataset): 91 | def __init__(self, datadir, transform, num_persons, num_imgs): 92 | """Similar to ImageDataset, but limit the number of persons and images per person""" 93 | self.img_paths, self.labels = load_supervised_paths( 94 | datadir, num_persons, num_imgs 95 | ) 96 | self.transform = transform 97 | dirname = os.path.basename(os.path.normpath(datadir)) 98 | logging.info(f"{dirname}: {len(self.img_paths)} images") 99 | 100 | def __getitem__(self, index): 101 | """Reads an image from a file and preprocesses it and returns with corresponding label.""" 102 | image = Image.open(self.img_paths[index]) 103 | image = image.convert("RGB") 104 | img = self.transform(image) 105 | return img, self.labels[index] 106 | 107 | def __len__(self): 108 | """Returns the total number of font files.""" 109 | return len(self.img_paths) 110 | 111 | 112 | class LatsDataset(Dataset): 113 | def __init__(self, num_imgs, latent_dim=512, lat_path=None, seed=42): 114 | self.lat_dim = latent_dim 115 | if lat_path == "None": 116 | np.random.seed(seed) 117 | self.latents = np.random.randn(num_imgs, latent_dim) 118 | self.norm = False 119 | print("random latent generation") 120 | else: 121 | self.latents = load_latents(lat_path, num_imgs) 122 | self.norm = False 123 | logging.info(f"Create {len(self.latents)} latent representations") 124 | 125 | def __getitem__(self, index): 126 | latent_codes = self.latents[index] # .reshape(-1, self.lat_dim) 127 | if self.norm: 128 | norm = np.linalg.norm(latent_codes, axis=0, keepdims=True) 129 | latent_codes = latent_codes / norm * np.sqrt(self.lat_dim) 130 | return latent_codes 131 | 132 | def __len__(self): 133 | return 
len(self.latents) 134 | 135 | 136 | class InferenceDataset(Dataset): 137 | def __init__(self, datadir, transform, num_imgs=0, num_ids=0): 138 | """Initializes image paths and preprocessing module.""" 139 | self.is_folder_struct = check_for_folder_structure(datadir) 140 | if self.is_folder_struct: 141 | self.img_paths = load_real_paths( 142 | datadir, num_imgs, num_classes=num_ids 143 | ) # load_first_dfg_path() 144 | else: 145 | self.img_paths = load_syn_paths(datadir, num_imgs) 146 | 147 | self.transform = transform 148 | dirname = os.path.basename(os.path.normpath(datadir)) 149 | logging.info(f"{dirname}: {len(self.img_paths)} images") 150 | 151 | def __getitem__(self, index): 152 | """Reads an image from a file and preprocesses it and returns.""" 153 | image = Image.open(self.img_paths[index]) 154 | image = image.convert("RGB") 155 | img = self.transform(image) 156 | return img, self.img_paths[index] 157 | 158 | def __len__(self): 159 | """Returns the total number of font files.""" 160 | return len(self.img_paths) 161 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | def l2_norm(input, axis=1): 8 | norm = torch.norm(input, 2, axis, True) 9 | output = torch.div(input, norm) 10 | return output 11 | 12 | 13 | class ArcFace(nn.Module): 14 | def __init__(self, in_features, out_features, s=64.0, m=0.50): 15 | super(ArcFace, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.s = s 19 | self.m = m 20 | self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features)) 21 | nn.init.normal_(self.kernel, std=0.01) 22 | 23 | def forward(self, embbedings, label): 24 | embbedings = l2_norm(embbedings, axis=1) 25 | kernel_norm = l2_norm(self.kernel, axis=0) 26 | cos_theta = torch.mm(embbedings, kernel_norm) 27 | cos_theta = cos_theta.clamp(-1, 1) # for numerical stability 28 | index = torch.where(label != -1)[0] 29 | m_hot = torch.zeros( 30 | index.size()[0], cos_theta.size()[1], device=cos_theta.device 31 | ) 32 | m_hot.scatter_(1, label[index, None], self.m) 33 | cos_theta.acos_() 34 | cos_theta[index] += m_hot 35 | cos_theta.cos_().mul_(self.s) 36 | return cos_theta 37 | 38 | 39 | class CosFace(nn.Module): 40 | def __init__(self, in_features, out_features, s=64.0, m=0.35): 41 | super(CosFace, self).__init__() 42 | self.in_features = in_features 43 | self.out_features = out_features 44 | self.s = s 45 | self.m = m 46 | self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features)) 47 | nn.init.normal_(self.kernel, std=0.01) 48 | 49 | def forward(self, embbedings, label): 50 | embbedings = l2_norm(embbedings, axis=1) 51 | kernel_norm = l2_norm(self.kernel, axis=0) 52 | cos_theta = torch.mm(embbedings, kernel_norm) 53 | cos_theta = cos_theta.clamp(-1, 1) # for numerical stability 54 | index = torch.where(label != -1)[0] 55 | m_hot = torch.zeros( 56 | index.size()[0], cos_theta.size()[1], device=cos_theta.device 57 | ) 58 | m_hot.scatter_(1, label[index, None], self.m) 59 | cos_theta[index] -= m_hot 60 | ret = cos_theta * self.s 61 | return ret 62 | 63 | 64 | class ElasticCosFace(nn.Module): 65 | def __init__( 66 | self, in_features, out_features, s=64.0, m=0.35, std=0.0125, plus=False 67 | ): 68 | super(ElasticCosFace, self).__init__() 69 | self.in_features = in_features 70 | 
self.out_features = out_features 71 | self.s = s 72 | self.m = m 73 | self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features)) 74 | nn.init.normal_(self.kernel, std=0.01) 75 | self.std = std 76 | self.plus = plus 77 | 78 | def forward(self, embbedings, label): 79 | embbedings = l2_norm(embbedings, axis=1) 80 | kernel_norm = l2_norm(self.kernel, axis=0) 81 | cos_theta = torch.mm(embbedings, kernel_norm) 82 | cos_theta = cos_theta.clamp(-1, 1) # for numerical stability 83 | index = torch.where(label != -1)[0] 84 | m_hot = torch.zeros( 85 | index.size()[0], cos_theta.size()[1], device=cos_theta.device 86 | ) 87 | margin = torch.normal( 88 | mean=self.m, 89 | std=self.std, 90 | size=label[index, None].size(), 91 | device=cos_theta.device, 92 | ) # Fast converge .clamp(self.m-self.std, self.m+self.std) 93 | if self.plus: 94 | with torch.no_grad(): 95 | distmat = cos_theta[index, label.view(-1)].detach().clone() 96 | _, idicate_cosie = torch.sort(distmat, dim=0, descending=True) 97 | margin, _ = torch.sort(margin, dim=0) 98 | m_hot.scatter_(1, label[index, None], margin[idicate_cosie]) 99 | else: 100 | m_hot.scatter_(1, label[index, None], margin) 101 | cos_theta[index] -= m_hot 102 | ret = cos_theta * self.s 103 | return ret 104 | 105 | 106 | class AdaFace(nn.Module): 107 | def __init__( 108 | self, 109 | embedding_size=512, 110 | classnum=70722, 111 | m=0.4, 112 | h=0.333, 113 | s=64.0, 114 | t_alpha=1.0, 115 | ): 116 | super(AdaFace, self).__init__() 117 | self.classnum = classnum 118 | self.kernel = nn.Parameter(torch.Tensor(embedding_size, classnum)) 119 | 120 | # initial kernel 121 | self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 122 | self.m = m 123 | self.eps = 1e-3 124 | self.h = h 125 | self.s = s 126 | 127 | # ema prep 128 | self.t_alpha = t_alpha 129 | self.register_buffer("t", torch.zeros(1)) 130 | self.register_buffer("batch_mean", torch.ones(1) * (20)) 131 | self.register_buffer("batch_std", torch.ones(1) * 100) 132 | 133 | print("\n\AdaFace with the following property") 134 | print("self.m", self.m) 135 | print("self.h", self.h) 136 | print("self.s", self.s) 137 | print("self.t_alpha", self.t_alpha) 138 | 139 | def forward(self, embbedings, norms, label): 140 | 141 | kernel_norm = l2_norm(self.kernel, axis=0) 142 | cosine = torch.mm(embbedings, kernel_norm) 143 | cosine = cosine.clamp(-1 + self.eps, 1 - self.eps) # for stability 144 | 145 | safe_norms = torch.clip(norms, min=0.001, max=100) # for stability 146 | safe_norms = safe_norms.clone().detach() 147 | 148 | # update batchmean batchstd 149 | with torch.no_grad(): 150 | mean = safe_norms.mean().detach() 151 | std = safe_norms.std().detach() 152 | self.batch_mean = mean * self.t_alpha + (1 - self.t_alpha) * self.batch_mean 153 | self.batch_std = std * self.t_alpha + (1 - self.t_alpha) * self.batch_std 154 | 155 | margin_scaler = (safe_norms - self.batch_mean) / ( 156 | self.batch_std + self.eps 157 | ) # 66% between -1, 1 158 | margin_scaler = margin_scaler * self.h # 68% between -0.333 ,0.333 when h:0.333 159 | margin_scaler = torch.clip(margin_scaler, -1, 1) 160 | # ex: m=0.5, h:0.333 161 | # range 162 | # (66% range) 163 | # -1 -0.333 0.333 1 (margin_scaler) 164 | # -0.5 -0.166 0.166 0.5 (m * margin_scaler) 165 | 166 | # g_angular 167 | m_arc = torch.zeros(label.size()[0], cosine.size()[1], device=cosine.device) 168 | m_arc.scatter_(1, label.reshape(-1, 1), 1.0) 169 | g_angular = self.m * margin_scaler * -1 170 | m_arc = m_arc * g_angular 171 | theta = cosine.acos() 172 | theta_m = 
torch.clip(theta + m_arc, min=self.eps, max=math.pi - self.eps) 173 | cosine = theta_m.cos() 174 | 175 | # g_additive 176 | m_cos = torch.zeros(label.size()[0], cosine.size()[1], device=cosine.device) 177 | m_cos.scatter_(1, label.reshape(-1, 1), 1.0) 178 | g_add = self.m + (self.m * margin_scaler) 179 | m_cos = m_cos * g_add 180 | cosine = cosine - m_cos 181 | 182 | # scale 183 | scaled_cosine_m = cosine * self.s 184 | return scaled_cosine_m 185 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from os.path import join as ojoin 4 | from tqdm import tqdm 5 | from scipy.spatial.distance import cosine 6 | 7 | 8 | def save_emb_2_id(embs, img_paths, save_path): 9 | """assigns embeddings to corresponding identity and saves embeddings per identity 10 | args: 11 | embs: numpy array of embeddings 12 | img_paths: list of image paths in same order as embs 13 | save_path: path to save the embeddings 14 | """ 15 | id_embs = {} 16 | for emb, img in zip(embs, img_paths): 17 | #print(img.split("")) 18 | identity = img.split(os.path.sep)[-1].split("_")[0] 19 | #return 20 | #print(identity) #img.split(os.path.sep)[-2] 21 | if identity in id_embs: 22 | id_embs[identity].append(emb) 23 | else: 24 | id_embs[identity] = [emb] 25 | 26 | #print(id_embs) 27 | print("Num ids:", len(id_embs.keys())) 28 | for i, emb in id_embs.items(): 29 | emb = np.array(emb) 30 | save_file = ojoin(save_path, i + ".npy") 31 | np.save(save_file, emb) 32 | print(f"{len(embs)} embeddings saved in {save_path}") 33 | 34 | 35 | 36 | # def save_combined_emb_2_id(embs, img_paths, save_path): 37 | # """assigns embeddings to corresponding identity and saves embeddings per identity 38 | # args: 39 | # embs: numpy array of embeddings 40 | # img_paths: list of image paths in same order as embs 41 | # save_path: path to save the embeddings 42 | # """ 43 | # id_embs = {} 44 | # for emb, img in zip(embs, img_paths): 45 | # #print(img.split("")) 46 | # identity = img.split(os.path.sep)[-1].split("_")[0] 47 | # #return 48 | # #print(identity) #img.split(os.path.sep)[-2] 49 | # if identity in id_embs: 50 | # id_embs[identity].append(emb) 51 | # else: 52 | # id_embs[identity] = [emb] 53 | 54 | # #print(id_embs) 55 | # print("Num ids:", len(id_embs.keys())) 56 | # for i, emb in id_embs.items(): 57 | # emb = np.array(emb) 58 | # save_file = ojoin(save_path, i + ".npy") 59 | # np.save(save_file, emb) 60 | # print(f"{len(embs)} embeddings saved in {save_path}") 61 | 62 | 63 | def save_embeddings(save_dir, img_paths, embeddings): 64 | """saves embedding under corresponding image filename to save_dir 65 | args: 66 | save_dir: path to directory to save the embeddings 67 | img_paths: list of image paths in same order as embs 68 | embeddings: numpy array of embeddings 69 | """ 70 | for img_path, emb in zip(img_paths, embeddings): 71 | img_name = img_path.split(os.path.sep)[-1].split(".")[0] 72 | save_file = ojoin(save_dir, img_name + ".npy") 73 | np.save(save_file, emb) 74 | print(f"{len(embeddings)} inferred embeddings saved in:", save_dir) 75 | 76 | 77 | def load_embeddings(dir, num_embs=0): 78 | """loads embeddings and slightly incorrect image paths from directory 79 | image paths have the correct image name, which is important for further processing 80 | args: 81 | dir: path to embedding directory 82 | num_embs: number of maximal embeddings that should be loaded 
83 | return: 84 | numpy array of embeddings, list of corresponding image paths 85 | """ 86 | emb_files = sorted(os.listdir(dir)) 87 | if num_embs > 0: 88 | emb_files = emb_files[:num_embs] 89 | embs, img_paths = [], [] 90 | print("Loading embeddings from:", dir) 91 | for emb_file in tqdm(emb_files): 92 | emb = np.load(ojoin(dir, emb_file)) 93 | embs.append(emb) 94 | img_file = emb_file.replace(".npy", ".jpg") 95 | img_path = ojoin(dir, img_file) 96 | img_paths.append(img_path) 97 | print(f"{len(emb_files)} embeddings loaded from", dir) 98 | return np.vstack(embs), img_paths 99 | 100 | 101 | def pairwise_cos_sim(embs1, embs2, show_pbar=False): 102 | """calculates the cosine similarity between each pair 103 | args: 104 | embs1: array num_samples x feature_dim 105 | embs2: array num_samples x feature_dim 106 | show_pbar: bool show progress bar 107 | return: 108 | cosine similarity [emb1_1 * emb2_1, emb1_2 * emb2_2] 109 | """ 110 | if show_pbar: 111 | print("Calculate cosine similarity...") 112 | cos_sims = [] 113 | pbar = zip(embs1, embs2) 114 | if show_pbar: 115 | pbar = tqdm(zip(embs1, embs2), total=len(embs1)) 116 | for e1, e2 in pbar: 117 | cos_dist = cosine(e1, e2) 118 | cos_sims.append(1 - cos_dist) 119 | return np.array(cos_sims) 120 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/utils_callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import List 5 | import torch 6 | 7 | from utils import verification 8 | from utils.utils_logging import AverageMeter 9 | 10 | 11 | class CallBackVerification(object): 12 | def __init__( 13 | self, 14 | frequent, 15 | rank, 16 | val_targets, 17 | rec_prefix, 18 | img_size, 19 | gen_im_path=None, 20 | is_vae=False, 21 | ): 22 | self.frequent: int = frequent 23 | self.rank: int = rank 24 | self.highest_acc: float = 0.0 25 | self.highest_acc_list: List[float] = [0.0] * len(val_targets) 26 | self.ver_list: List[object] = [] 27 | self.ver_name_list: List[str] = [] 28 | self.gen_im_path = gen_im_path 29 | self.is_vae = is_vae 30 | if self.rank == 0: 31 | self.ver_list, self.ver_name_list = self.init_dataset( 32 | val_targets=val_targets, data_dir=rec_prefix, image_size=img_size 33 | ) 34 | 35 | def ver_test(self, backbone: torch.nn.Module, global_step: int): 36 | results = [] 37 | for i in range(len(self.ver_list)): 38 | acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( 39 | self.ver_list[i], 40 | backbone, 41 | batch_size=10, 42 | nfolds=10, 43 | gen_im_path=self.gen_im_path, 44 | is_vae=self.is_vae, 45 | ) 46 | logging.info( 47 | "[%s][%d]XNorm: %f" % (self.ver_name_list[i], global_step, xnorm) 48 | ) 49 | logging.info( 50 | "[%s][%d]Accuracy-Flip: %1.5f+-%1.5f" 51 | % (self.ver_name_list[i], global_step, acc2, std2) 52 | ) 53 | if acc2 > self.highest_acc_list[i]: 54 | self.highest_acc_list[i] = acc2 55 | logging.info( 56 | "[%s][%d]Accuracy-Highest: %1.5f" 57 | % (self.ver_name_list[i], global_step, self.highest_acc_list[i]) 58 | ) 59 | results.append(acc2) 60 | return results 61 | 62 | def init_dataset(self, val_targets, data_dir, image_size): 63 | ver_list = [] 64 | ver_name_list = [] 65 | for name in val_targets: 66 | path = os.path.join(data_dir, name + ".bin") 67 | if os.path.exists(path): 68 | data_set = verification.load_bin(path, image_size) 69 | ver_list.append(data_set) 70 | ver_name_list.append(name) 71 | return ver_list, ver_name_list 72 | 73 | def 
__call__(self, num_update, backbone: torch.nn.Module, do_da=False, ranking=[], curr_beta=0, modified_neurons=[], means=[]): 74 | results = [] 75 | if self.rank == 0 and num_update % self.frequent == 0: 76 | backbone.eval() 77 | if do_da: 78 | results = self.ver_test_da(backbone, num_update, ranking, curr_beta, modified_neurons, means) 79 | else: 80 | results = self.ver_test(backbone, num_update) 81 | backbone.train() 82 | return results 83 | 84 | 85 | class CallBackLogging(object): 86 | def __init__(self, frequent, rank, total_step, batch_size, world_size): 87 | self.frequent: int = frequent 88 | self.rank: int = rank 89 | self.time_start = time.time() 90 | self.total_step: int = total_step 91 | self.batch_size: int = batch_size 92 | self.world_size: int = world_size 93 | 94 | self.init = False 95 | self.tic = 0 96 | 97 | def __call__( 98 | self, 99 | global_step, 100 | loss: AverageMeter, 101 | acc1: AverageMeter, 102 | acc5: AverageMeter, 103 | epoch: int, 104 | ): 105 | if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: 106 | if self.init: 107 | try: 108 | speed: float = ( 109 | self.frequent * self.batch_size / (time.time() - self.tic) 110 | ) 111 | speed_total = speed * self.world_size 112 | except ZeroDivisionError: 113 | speed_total = float("inf") 114 | 115 | time_now = (time.time() - self.time_start) / 3600 116 | time_total = time_now / ((global_step + 1) / self.total_step) 117 | time_for_end = time_total - time_now 118 | 119 | msg = "Epoch: {:>2} Speed {:.2f} samples/sec Loss {:.4f} Acc1 {:.2f} Acc5 {:.2f} Step: {:>4}/{} Required: {:.1f} hours".format( 120 | epoch, 121 | speed_total, 122 | loss.avg, 123 | acc1.avg, 124 | acc5.avg, 125 | global_step, 126 | self.total_step, 127 | time_for_end, 128 | ) 129 | logging.info(msg) 130 | loss.reset() 131 | self.tic = time.time() 132 | else: 133 | self.init = True 134 | self.tic = time.time() 135 | 136 | 137 | class CallBackModelCheckpoint(object): 138 | def __init__(self, rank, output="./"): 139 | self.rank: int = rank 140 | self.output: str = output 141 | 142 | def __call__( 143 | self, 144 | global_step, 145 | backbone: torch.nn.Module, 146 | header: torch.nn.Module = None, 147 | ): 148 | if global_step > 100 and self.rank == 0: 149 | torch.save( 150 | backbone.module.state_dict(), 151 | os.path.join(self.output, str(global_step) + "backbone.pth"), 152 | ) 153 | if global_step > 100 and self.rank == 0 and header is not None: 154 | torch.save( 155 | header.module.state_dict(), 156 | os.path.join(self.output, str(global_step) + "header.pth"), 157 | ) 158 | -------------------------------------------------------------------------------- /Evaluation/PyEER_analysis/utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value 8 | """ 9 | 10 | def __init__(self): 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_logging(log_root, rank, models_root, logfile): 31 | 32 | if rank == 0: 33 | if (not logfile): 34 | logfile= "training.log" 35 | log_root.handlers = [] # This is the key thing for the question! 
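# Clearing any previously attached handlers keeps repeated calls to
# init_logging() from producing duplicate log lines on the same root logger;
# the file and stream handlers below are then attached exactly once per call.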
36 | log_root.setLevel(logging.INFO) 37 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s", datefmt='%Y-%m-%d %H:%M') 38 | handler_file = logging.FileHandler(os.path.join(models_root, logfile)) 39 | handler_stream = logging.StreamHandler(sys.stdout) 40 | handler_file.setFormatter(formatter) 41 | handler_stream.setFormatter(formatter) 42 | log_root.addHandler(handler_file) 43 | log_root.addHandler(handler_stream) 44 | log_root.info('rank_id: %d' % rank) 45 | -------------------------------------------------------------------------------- /Evaluation/convert_to_conditional_dataset_for_evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os \n", 10 | "import shutil \n", 11 | "from tqdm import tqdm \n", 12 | "\n", 13 | "origin_folder = \"FR_DATASETS_ABLATION\"\n", 14 | "output_folder = \"FR_DATASETS_ABLATION_Conditional\"\n", 15 | "\n", 16 | "folders = [\"01-2025_ID-Booth_ABLATION\"]\n", 17 | "#folders = [\"NonFineTuned_21_samples\"]\n", 18 | "\n", 19 | "# losses = [\"no_new_Loss\", \"identity_loss_TimestepWeight\", \"triplet_prior_loss_TimestepWeight\"]\n", 20 | "#losses = [\"+Background_NoNegPrompt\", \"Base_NoNegPrompt\", \"no_new_Loss_NoPrior\"]\n", 21 | "losses = [\"DreamBooth\", \"PortraitBooth\", \"ID-Booth\"]\n", 22 | "\n", 23 | "for folder in folders:\n", 24 | " for loss in losses:\n", 25 | " img_files = os.listdir(os.path.join(origin_folder, folder, loss))\n", 26 | " for img_file in tqdm(img_files): \n", 27 | " img_path = os.path.join(origin_folder, folder, loss, img_file)\n", 28 | " id_of_img = img_file.split(\"_\")[0]\n", 29 | "\n", 30 | " os.makedirs(os.path.join(output_folder, folder, loss, id_of_img), exist_ok=True)\n", 31 | " target_path = os.path.join(output_folder, folder, loss, id_of_img, img_file)\n", 32 | " shutil.copyfile(img_path, target_path)\n", 33 | "\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import os \n", 50 | "import shutil \n", 51 | "from tqdm import tqdm \n", 52 | "\n", 53 | "origin_folder = \"FR_DATASETS\"\n", 54 | "output_folder = \"FR_DATASETS_Conditional\"\n", 55 | "\n", 56 | "# folders = [\"tufts_512_poses_1-7_all_imgs_jpg_per_ID\"]\n", 57 | "folders = [\"FFHQ_512\"]\n", 58 | "\n", 59 | "losses = [\"images\"]\n", 60 | "\n", 61 | "\n", 62 | "for folder in folders:\n", 63 | " for loss in losses:\n", 64 | " img_files = os.listdir(os.path.join(origin_folder, folder, loss))\n", 65 | " for i, img_file in tqdm(enumerate(img_files)): \n", 66 | " if i <= 9999: \n", 67 | " continue \n", 68 | " if i == 19999: \n", 69 | " break \n", 70 | " img_path = os.path.join(origin_folder, folder, loss, img_file)\n", 71 | " id_of_img = img_file.split(\"_\")[0]\n", 72 | "\n", 73 | " os.makedirs(os.path.join(output_folder, folder, loss, id_of_img), exist_ok=True)\n", 74 | " target_path = os.path.join(output_folder, folder, loss, id_of_img, img_file)\n", 75 | " shutil.copyfile(img_path, target_path)\n", 76 | " \n" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | 
"outputs": [], 91 | "source": [] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "dgm_eval", 97 | "language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | "version": "3.10.16" 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 2 115 | } 116 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Layer 6 AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/heatmaps/__init__.py: -------------------------------------------------------------------------------- 1 | from .heatmaps import visualize_heatmaps 2 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/heatmaps/gradcam.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from dgm_eval.heatmaps.heatmaps_utils import get_features, show_heatmap_on_image 7 | 8 | 9 | class GradCAM: 10 | 11 | def __init__(self, model, reps_real, reps_gen, device, **kwargs): 12 | # Register forward and backward hooks to get activations and gradients, respectively. 13 | self.acts_and_gradients = ActivationsAndGradients(network=model, model_name=model.name) 14 | 15 | # Compute feature statistics on real images. 
16 | self.mean_reals = torch.from_numpy(np.mean(reps_real, axis=0)).to(device) 17 | self.cov_reals = torch.from_numpy(np.cov(reps_real, rowvar=False)).to(device) 18 | 19 | self.reps_gen = reps_gen 20 | self.model = model 21 | self.device = device 22 | 23 | def get_map(self, image, idx): 24 | """ 25 | Compute heatmap from the gradients and activation of the Frechet distance 26 | Return the heatmap and image label if possible 27 | """ 28 | self.acts_and_gradients.eval() # Model needs to be in eval mode 29 | 30 | # Computing selected image features 31 | features = get_features(self.acts_and_gradients, image) 32 | 33 | # Compute feature statistics without the selected image 34 | mean_gen = torch.from_numpy(np.mean(np.delete(self.reps_gen, idx, axis=0), axis=0)).to(self.device) 35 | cov_gen = torch.from_numpy(np.cov(np.delete(self.reps_gen, idx, axis=0), rowvar=False)).to(self.device) 36 | 37 | # Compute fid without the selected image 38 | original_fid = wasserstein2_loss(mean_reals=self.mean_reals, mean_gen=mean_gen, 39 | cov_reals=self.cov_reals, cov_gen=cov_gen) 40 | 41 | # Updating feature statistics with the selected image to get gradients 42 | num_images = len(self.reps_gen) 43 | mean = ((num_images - 1) / num_images) * mean_gen + (1 / num_images) * features 44 | cov = ((num_images - 2) / (num_images - 1)) * cov_gen + \ 45 | (1 / num_images) * torch.mm((features - mean_gen).T, (features - mean_gen)) 46 | 47 | # Compute frechet distance and back-propagate loss 48 | loss = wasserstein2_loss(mean_reals=self.mean_reals, mean_gen=mean, 49 | cov_reals=self.cov_reals, cov_gen=cov) 50 | loss.backward() 51 | delta_fid = loss.detach().cpu().numpy() - original_fid.detach().cpu().numpy() 52 | 53 | # Get heatmap from gradients and activation 54 | heatmap = self._get_heatmap_from_grads() 55 | 56 | # Get overlay of heatmap on image 57 | overlay = show_heatmap_on_image(heatmap, image) 58 | 59 | # Get classification label if possible 60 | label = None 61 | if hasattr(self.model, 'get_label'): 62 | label = self.model.get_label(features) 63 | # label = f"{delta_fid:0.5f}" 64 | 65 | return overlay, label 66 | 67 | def _get_heatmap_from_grads(self): 68 | # Get activations and gradients from the target layer by accessing hooks. 69 | activations = self.acts_and_gradients.activations[-1] 70 | gradients = self.acts_and_gradients.gradients[-1] 71 | 72 | if len(activations.shape) == 3: 73 | dim = int(activations.shape[-1] ** 0.5) 74 | activations = activations[:, :, 1:].reshape(*activations.shape[:-1], dim, dim) 75 | gradients = gradients[:, :, 1:].reshape(*gradients.shape[:-1], dim, dim) 76 | 77 | # Turn gradients and activation into heatmap. 
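# Grad-CAM-style weighting: each feature channel receives a scalar weight from
# the spatial average of its (here squared) gradients, and the heatmap is the
# weighted sum over channels, giving a single spatial map per image.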
78 | weights = np.mean(gradients ** 2, axis=(2, 3), keepdims=True) 79 | heatmap = (weights * activations).sum(axis=1) 80 | 81 | return heatmap[0] 82 | 83 | 84 | MODEL_TO_LAYER_NAME_MAP = { 85 | 'inception': 'blocks.3.2', 86 | 'clip': 'visual.transformer.resblocks.11.ln_1', 87 | 'mae': 'blocks.23.norm1', 88 | 'swav': 'layer4.2', 89 | 'dinov2': 'blocks.23.norm1', 90 | 'convnext': 'stages.3.blocks.2', 91 | 'data2vec': 'model.encoder.layer.23.layernorm_before', 92 | 'simclr': 'net.4.blocks.2.net.3' 93 | } 94 | 95 | MODEL_TO_TRANSFORM_MAP = { 96 | 'inception': lambda x : x, 97 | 'clip': lambda x : -x.transpose(1, 2, 0), 98 | 'mae': lambda x : x.transpose(0, 2, 1), 99 | 'swav': lambda x : x, 100 | 'dinov2': lambda x : -x.transpose(0, 2, 1), 101 | 'convnext': lambda x: -x, 102 | 'data2vec': lambda x: x.transpose(0, 2, 1), 103 | 'simclr': lambda x: -x, 104 | } 105 | 106 | 107 | class ActivationsAndGradients: 108 | """Class to obtain intermediate activations and gradients. 109 | Adapted from: https://github.com/jacobgil/pytorch-grad-cam""" 110 | 111 | def __init__(self, 112 | network: Any, 113 | model_name: str, 114 | network_kwargs: dict = None) -> None: 115 | self.network = network 116 | self.network_kwargs = network_kwargs if network_kwargs is not None else {} 117 | self.gradients: List[np.ndarray] = [] 118 | self.activations: List[np.ndarray] = [] 119 | self.transform = MODEL_TO_TRANSFORM_MAP.get(model_name) 120 | 121 | target_layer_name = MODEL_TO_LAYER_NAME_MAP.get(model_name) 122 | target_layer = dict(network.model.named_modules()).get(target_layer_name) 123 | target_layer.register_forward_hook(self.save_activation) 124 | target_layer.register_full_backward_hook(self.save_gradient) 125 | 126 | def save_activation(self, 127 | module: Any, 128 | input: Any, 129 | output: Any) -> None: 130 | """Saves forward pass activations.""" 131 | activation = output 132 | self.activations.append(self.transform(activation.detach().cpu().numpy())) 133 | 134 | def save_gradient(self, 135 | module: Any, 136 | grad_input: Any, 137 | grad_output: Any) -> None: 138 | """Saves backward pass gradients.""" 139 | # Gradients are computed in reverse order. 140 | grad = grad_output[0] 141 | self.gradients = [self.transform(grad.detach().cpu().numpy())] + self.gradients # Prepend current gradients. 
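# Because backward hooks fire from the last layer towards the first, prepending
# keeps self.gradients in the same forward order as self.activations.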
142 | 143 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 144 | """Resets hooked activations and gradients and calls model forward pass.""" 145 | self.gradients = [] 146 | self.activations = [] 147 | return self.network(x, **self.network_kwargs) 148 | 149 | def eval(self): 150 | self.network.eval() 151 | 152 | 153 | def wasserstein2_loss(mean_reals: torch.Tensor, 154 | mean_gen: torch.Tensor, 155 | cov_reals: torch.Tensor, 156 | cov_gen: torch.Tensor, 157 | eps: float = 1e-12) -> torch.Tensor: 158 | """Computes 2-Wasserstein distance.""" 159 | mean_term = torch.sum(torch.square(mean_reals - mean_gen.squeeze(0))) 160 | eigenvalues = torch.real(torch.linalg.eig(torch.matmul(cov_gen, cov_reals))[0]) 161 | cov_term = torch.trace(cov_reals) + torch.trace(cov_gen) - 2 * torch.sum(torch.sqrt(abs(eigenvalues) + eps)) 162 | return mean_term + cov_term 163 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/heatmaps/heatmaps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import numpy as np 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from dgm_eval.heatmaps.gradcam import GradCAM 9 | from dgm_eval.heatmaps.heatmaps_utils import get_image, create_grid, perturb_image, zero_one_scaling 10 | from dgm_eval.models.encoder import Encoder 11 | 12 | 13 | def visualize_heatmaps(reps_real: np.array, 14 | reps_gen: np.array, 15 | model: Encoder, 16 | dataset: torch.utils.data.Dataset, 17 | results_dir: str, 18 | results_suffix: str = 'default', 19 | dataset_name: str = None, 20 | num_rows: int = 4, 21 | num_cols: int = 4, 22 | device: torch.device = torch.device('cpu'), 23 | perturbation: bool = False, 24 | human_exp_indices: str = None, 25 | random_seed: int = 0) -> None: 26 | """Visualizes to which regions in the images FID is the most sensitive to.""" 27 | 28 | visualizer = GradCAM(model, reps_real, reps_gen, device) 29 | 30 | # ---------------------------------------------------------------------------- 31 | # Visualize FID sensitivity heatmaps. 32 | heatmaps, labels, images = [], [], [] 33 | 34 | # Sampling image indices 35 | rnd = np.random.RandomState(random_seed) 36 | if human_exp_indices is not None: 37 | with open(human_exp_indices, 'r') as f_in: 38 | index_to_score = json.load(f_in) 39 | indices = [int(idx) for idx in list(index_to_score.keys()) if int(idx) < len(dataset)] 40 | if len(indices) < len(index_to_score): 41 | raise RuntimeWarning("The datasets were subsampled so the human experiment indices will not be accurate. " 42 | "Please use '--nmax_images' with a higher value") 43 | vis_images_indices = [idx for idx in rnd.choice(indices, size=num_rows * num_cols, replace=False)] 44 | vis_images_scores = [index_to_score[str(idx)] for idx in vis_images_indices] 45 | vis_images_indices = [idx for _, idx in sorted(zip(vis_images_scores, vis_images_indices))] # sorting indices in ascending human score 46 | else: 47 | vis_images_indices = rnd.choice(np.arange(len(dataset)), size=num_rows * num_cols, replace=False) 48 | 49 | print('Visualizing heatmaps...') 50 | for idx in tqdm(vis_images_indices): 51 | 52 | # ---------------------------------------------------------------------------- 53 | # Get selected image and do required transforms 54 | image = get_image(dataset, idx, device, perturbation=perturbation) 55 | 56 | # ---------------------------------------------------------------------------- 57 | # Compute and visualize a sensitivity map. 
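# get_map() performs a leave-one-out estimate: the generated-set mean and
# covariance are recomputed without image `idx`, the 2-Wasserstein (Frechet)
# distance to the fixed real-set statistics,
#   d^2 = ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2 (C_r C_g)^(1/2)),
# is re-evaluated with this image's features folded back in, and that loss is
# backpropagated to obtain the gradients behind the sensitivity heatmap.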
58 | heatmap, label = visualizer.get_map(image, idx) 59 | 60 | heatmaps.append(heatmap) 61 | labels.append(label) 62 | images.append(np.clip(zero_one_scaling(image=image.detach().cpu().numpy().squeeze(0)) * 255, 0.0, 255.0).astype(np.uint8)) 63 | 64 | human_scores = labels 65 | if human_exp_indices is not None: 66 | human_scores = [f"{index_to_score[str(idx)]:0.2f}" for idx in vis_images_indices] 67 | 68 | # ---------------------------------------------------------------------------- 69 | # Create a grid of overlay heatmaps. 70 | heatmap_grid = create_grid(images=heatmaps, labels=labels, num_rows=num_rows, num_cols=num_cols) 71 | image_grid = create_grid(images=images, labels=human_scores, num_rows=num_rows, num_cols=num_cols) 72 | heatmap_grid.save(os.path.join(results_dir, f'sensitivity_grid_{results_suffix}.png')) 73 | image_grid.save(os.path.join(results_dir, f'images_grid_{results_suffix}.png')) 74 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/heatmaps/heatmaps_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import PIL 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def get_image(dataset, idx, device, perturbation=False): 10 | image = dataset[idx] 11 | if isinstance(image, tuple): 12 | # image is likely tuple[images, label] 13 | image = image[0] 14 | if isinstance(image, torch.Tensor): 15 | # add batch dimension 16 | image.unsqueeze_(0) 17 | else: # Special case of data2vec 18 | image = image.data["pixel_values"] 19 | # Convert grayscale to RGB 20 | if image.ndim == 3: 21 | image.unsqueeze_(1) 22 | if image.shape[1] == 1: 23 | image = image.repeat(1, 3, 1, 1) 24 | if perturbation: 25 | image = perturb_image(image) 26 | image = image.to(device) 27 | image.requires_grad = True 28 | return image 29 | 30 | 31 | def get_features(model, image): 32 | features = model(image)[0] 33 | 34 | if not torch.is_tensor(features): # Some encoders output tuples or lists 35 | features = features[0] 36 | 37 | # If model output is not scalar, apply global spatial average pooling. 38 | # This happens if you choose a dimensionality not equal 2048. 39 | if features.dim() > 2: 40 | if features.size(2) != 1 or features.size(3) != 1: 41 | features = torch.nn.functional.adaptive_avg_pool2d(features, output_size=(1, 1)) 42 | 43 | features = features.squeeze(3).squeeze(2) 44 | 45 | if features.dim() == 1: 46 | features = features.unsqueeze(0) 47 | 48 | return features 49 | 50 | 51 | def zero_one_scaling(image: np.ndarray) -> np.ndarray: 52 | """Scales an image to range [0, 1].""" 53 | if np.all(image == 0): 54 | return image 55 | image = image.astype(np.float32) 56 | if (image.max() - image.min()) == 0: 57 | return image 58 | return (image - image.min()) / (image.max() - image.min()) 59 | 60 | 61 | def show_heatmap_on_image(heatmap, image, colormap: int = cv2.COLORMAP_PARULA, heatmap_weight: float = 1.): 62 | image_np = image.detach().cpu().numpy()[0] 63 | _, h, w = image_np.shape 64 | 65 | # Scale heatmap values between 0 and 255. 66 | heatmap = zero_one_scaling(image=heatmap) 67 | heatmap = np.clip((heatmap * 255.0).astype(np.uint8), 0.0, 255.0) 68 | 69 | # Scale to original image size. 
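# The raw heatmap is at the encoder's feature-map resolution; LANCZOS
# resampling upscales it to the input image's (w, h) before the colormap and
# overlay are applied.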
70 | heatmap = np.array( 71 | PIL.Image.fromarray(heatmap).resize((w, h), resample=PIL.Image.LANCZOS).convert( 72 | 'L')) 73 | 74 | # Apply color map 75 | heatmap = cv2.applyColorMap(heatmap, colormap) 76 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 77 | heatmap = heatmap.astype(np.float32) / 255 78 | 79 | # Overlay original RGB image and heatmap with specified weights. 80 | scaled_image = zero_one_scaling(image=image_np) 81 | overlay = heatmap_weight * heatmap.transpose(2, 0, 1) + scaled_image 82 | overlay = zero_one_scaling(image=overlay) 83 | overlay = np.clip(overlay * 255, 0.0, 255.0).astype(np.uint8) 84 | 85 | return overlay 86 | 87 | 88 | def create_grid(images: List[np.ndarray], 89 | num_rows: int, 90 | num_cols: int, 91 | labels: Optional[List[str]] = None, 92 | label_loc: Tuple[int, int] = (0, 0), 93 | fontsize: int = 32, 94 | font_path: str = './data/times-new-roman.ttf') -> PIL.Image: 95 | """Creates an image grid.""" 96 | h, w = 256, 256 97 | if labels is None or len(labels)==0: 98 | labels = [None]*len(images) 99 | assert len(images) == len(labels) 100 | font = PIL.ImageFont.truetype(font_path, fontsize) 101 | 102 | grid = PIL.Image.new('RGB', size=(num_cols * h, num_rows * w)) 103 | 104 | for i in range(num_rows): 105 | for j in range(num_cols): 106 | im = cv2.resize(images.pop(0).transpose((1, 2, 0)), dsize=(h, w), interpolation=cv2.INTER_CUBIC) 107 | im = PIL.Image.fromarray(im) 108 | 109 | label = labels.pop(0) 110 | if label is not None: 111 | draw = PIL.ImageDraw.Draw(im) 112 | draw.text(label_loc, f'{label}'.capitalize(), font=font) 113 | 114 | grid.paste(im, box=(j * w, i * h)) 115 | return grid 116 | 117 | 118 | def perturb_image(image): 119 | # image is (B, N, H, W) 120 | _, _, h, w = image.shape 121 | image[:, :, int(2*h/10):int(3*h/10), int(2*w/10):int(3*w/10)] = 0 122 | return image 123 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .fls import compute_fls, compute_fls_overfit 2 | from .ct import compute_CTscore, compute_CTscore_mode, compute_CTscore_mem 3 | from .authpct import compute_authpct 4 | from .sw import sw_approx 5 | from .fd import compute_FD_with_reps, compute_FD_infinity, compute_FD_with_stats, compute_efficient_FD_with_reps 6 | from .mmd import compute_mmd 7 | from .inception_score import compute_inception_score 8 | from .vendi import compute_vendi_score, compute_per_class_vendi_scores 9 | from .prdc import compute_prdc 10 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/authpct.py: -------------------------------------------------------------------------------- 1 | #https://github.com/marcojira/fls/blob/main/metrics/AuthPct.py 2 | import torch 3 | 4 | def compute_authpct(train_feat, gen_feat): 5 | with torch.no_grad(): 6 | train_feat = torch.tensor(train_feat, dtype=torch.float32) 7 | gen_feat = torch.tensor(gen_feat, dtype=torch.float32) 8 | real_dists = torch.cdist(train_feat, train_feat) 9 | 10 | # Hacky way to get it to ignore distance to self in nearest neighbor calculation 11 | real_dists.fill_diagonal_(float("inf")) 12 | gen_dists = torch.cdist(train_feat, gen_feat) 13 | 14 | real_min_dists = real_dists.min(axis=0) 15 | gen_min_dists = gen_dists.min(dim=0) 16 | 17 | # For every synthetic point, find its closest real point, d1 18 | # Then, for that real point, find its closest real 
point(not itself), d2 19 | # if d2 64: 11 | print(f'running pca for CT score to take first 64 components out of {train_feat.shape[1]}', file=sys.stderr) 12 | pca_xf = PCA(n_components=64).fit(train_feat) 13 | 14 | train_feat = pca_xf.transform(train_feat) 15 | test_feat = pca_xf.transform(test_feat) 16 | gen_feat = pca_xf.transform(gen_feat) 17 | 18 | del pca_xf 19 | return train_feat, test_feat, gen_feat 20 | 21 | def Zu(Pn, Qm, T): 22 | """Extracts distances to training nearest neighbor 23 | L(P_n), L(Q_m), and runs Z-scored Mann Whitney U-test. 24 | For the global test, this is used on the samples within each cell. 25 | Inputs: 26 | Pn: (n X d) np array representing test sample of 27 | length n (with dimension d) 28 | Qm: (m X d) np array representing generated sample of 29 | length n (with dimension d) 30 | T: (l X d) np array representing training sample of 31 | length l (with dimension d) 32 | Ouptuts: 33 | Zu: Z-scored U value. A large value >>0 indicates 34 | underfitting by Qm. A small value <<0 indicates. 35 | """ 36 | m = Qm.shape[0] 37 | n = Pn.shape[0] 38 | 39 | # fit NN model to training sample to get distances to test and generated samples 40 | T_NN = NN(n_neighbors=1).fit(T) 41 | LQm, _ = T_NN.kneighbors(X=Qm, n_neighbors=1) 42 | LPn, _ = T_NN.kneighbors(X=Pn, n_neighbors=1) 43 | 44 | # Get Mann-Whitney U score and manually Z-score it using the conditions of null hypothesis H_0 45 | u, _ = mannwhitneyu(LQm, LPn, alternative="less") 46 | mean = (n * m / 2) - 0.5 # 0.5 is continuity correction 47 | std = np.sqrt(n * m * (n + m + 1) / 12) 48 | Z_u = (u - mean) / std 49 | return Z_u 50 | 51 | 52 | def Zu_cells(Pn, Pn_cells, Qm, Qm_cells, T, T_cells): 53 | """Collects the Zu statistic in each of k cells. 54 | There should be >0 test (Pn) and train (T) samples in each of the cells. 55 | Inputs: 56 | Pn: (n X d) np array representing test sample of length 57 | n (with dimension d) 58 | Pn_cells: (1 X n) np array of integers indicating which 59 | of the k cells each sample belongs to 60 | Qm: (m X d) np array representing generated sample of 61 | length n (with dimension d) 62 | Qm_cells: (1 X m) np array of integers indicating which of the 63 | k cells each sample belongs to 64 | T: (l X d) np array representing training sample of 65 | length l (with dimension d) 66 | T_cells: (1 X l) np array of integers indicating which of the 67 | k cells each sample belongs to 68 | Outputs: 69 | Zus: length k np array, where entry i indicates the Zu score for cell i 70 | """ 71 | # assume cells are labeled 0 to k-1 72 | k = len(np.unique(Pn_cells)) 73 | Zu_cells = np.zeros(k) 74 | 75 | # get samples in each cell and collect Zu 76 | for i in range(k): 77 | Pn_cell_i = Pn[Pn_cells == i] 78 | Qm_cell_i = Qm[Qm_cells == i] 79 | T_cell_i = T[T_cells == i] 80 | # check that the cell has test and training samples present 81 | if len(Pn_cell_i) * len(T_cell_i) == 0: 82 | raise ValueError( 83 | "Cell {:n} lacks test samples and/or training samples. Consider reducing the number of cells in partition.".format( 84 | i 85 | ) 86 | ) 87 | 88 | # if there are no generated samples present, add a 0 for Zu. This cell will be excluded in \Pi_\tau 89 | if len(Qm_cell_i) > 0: 90 | Zu_cells[i] = Zu(Pn_cell_i, Qm_cell_i, T_cell_i) 91 | else: 92 | Zu_cells[i] = 0 93 | print("cell {:n} unrepresented by Qm".format(i), file=sys.stderr) 94 | 95 | return Zu_cells 96 | 97 | 98 | def C_T(Pn, Pn_cells, Qm, Qm_cells, T, T_cells, tau): 99 | """Runs C_T test given samples and their respective cell labels. 
100 | The C_T statistic is a weighted average of the in-cell Zu statistics, weighted 101 | by the share of test samples (Pn) in each cell. Cells with an insufficient number 102 | of generated samples (Qm) are not included in the statistic. 103 | Inputs: 104 | Pn: (n X d) np array representing test sample of length 105 | n (with dimension d) 106 | Pn_cells: (1 X n) np array of integers indicating which 107 | of the k cells each sample belongs to 108 | Qm: (m X d) np array representing generated sample of 109 | length n (with dimension d) 110 | Qm_cells: (1 X m) np array of integers indicating which of the 111 | k cells each sample belongs to 112 | T: (l X d) np array representing training sample of 113 | length l (with dimension d) 114 | T_cells: (1 X l) np array of integers indicating which of the 115 | k cells each sample belongs to 116 | tau: (scalar between 0 and 1) fraction of Qm samples that a 117 | cell needs to be included in C_T statistic. 118 | Outputs: 119 | C_T: The C_T statistic for the three samples Pn, Qm, T 120 | """ 121 | 122 | m = Qm.shape[0] 123 | n = Pn.shape[0] 124 | k = np.max(np.unique(T_cells)) + 1 # number of cells 125 | 126 | # First, determine which of the cells have sufficient generated samples (Qm(pi) > tau) 127 | labels, cts = np.unique(Qm_cells, return_counts=True) 128 | Qm_cts = np.zeros(k) 129 | Qm_cts[labels.astype(int)] = cts # put in order of cell label 130 | Qm_of_pi = Qm_cts / m 131 | Pi_tau = ( 132 | Qm_of_pi > tau 133 | ) # binary array selecting which cells have sufficient samples 134 | 135 | # Get the fraction of test samples in each cell Pn(pi) 136 | labels, cts = np.unique(Pn_cells, return_counts=True) 137 | Pn_cts = np.zeros(k) 138 | Pn_cts[labels.astype(int)] = cts # put in order of cell label 139 | Pn_of_pi = Pn_cts / n 140 | 141 | # Now get the in-cell Zu scores 142 | Zu_scores = Zu_cells(Pn, Pn_cells, Qm, Qm_cells, T, T_cells) 143 | 144 | # compute C_T: 145 | C_T = Pn_of_pi[Pi_tau].dot(Zu_scores[Pi_tau]) / np.sum(Pn_of_pi[Pi_tau]) 146 | 147 | return C_T 148 | 149 | def compute_CTscore(train_feat, test_feat, gen_feat): 150 | train_feat, test_feat, gen_feat = preprcoess_ct(train_feat, test_feat, gen_feat) 151 | print('running kmeans for CT score', file=sys.stderr) 152 | km_clf = KMeans(n_clusters=3, n_init=10).fit(train_feat) 153 | 154 | T_labels = km_clf.predict(train_feat) 155 | Pn_labels = km_clf.predict(test_feat) 156 | Qm_labels = km_clf.predict(gen_feat) 157 | 158 | print('calculating CT score', file=sys.stderr) 159 | C_T_score = C_T( 160 | test_feat, 161 | Pn_labels, 162 | gen_feat, 163 | Qm_labels, 164 | train_feat, 165 | T_labels, 166 | tau=20 / len(gen_feat), 167 | ) 168 | 169 | return C_T_score 170 | 171 | 172 | def compute_CTscore_mem(train_feat, test_feat, gen_feat): 173 | # Swap the training and generated sets 174 | return compute_CTscore(gen_feat, test_feat, train_feat) 175 | 176 | 177 | def compute_CTscore_mode(train_feat, test_feat, gen_feat): 178 | # Split the test set and use half as a training set 179 | test_feat1, test_feat2 = np.array_split(test_feat, 2) 180 | return compute_CTscore(test_feat1, test_feat2, gen_feat) 181 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/fd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | from tqdm import tqdm 4 | from sklearn.linear_model import LinearRegression 5 | 6 | def compute_statistics(reps): 7 | """Compute necessary 
statistics from representtions""" 8 | mu = np.mean(reps, axis=0) 9 | sigma = np.cov(reps, rowvar=False) 10 | mu = np.atleast_1d(mu) 11 | sigma = np.atleast_2d(sigma) 12 | return mu, sigma 13 | 14 | 15 | def compute_FD_with_stats(mu1, mu2, sigma1, sigma2, eps=1e-6): 16 | """ 17 | Numpy implementation of the Frechet Distance. 18 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 19 | and X_2 ~ N(mu_2, C_2) is 20 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 21 | Stable version by Dougal J. Sutherland. 22 | Params: 23 | -- mu1 : Numpy array containing the activations of a layer of the 24 | inception net (like returned by the function 'get_predictions') 25 | for generated samples. 26 | -- mu2 : The sample mean over activations, precalculated on an 27 | representative data set. 28 | -- sigma1: The covariance matrix over activations for generated samples. 29 | -- sigma2: The covariance matrix over activations, precalculated on an 30 | representative data set. 31 | Returns: 32 | -- : The Frechet Distance. 33 | """ 34 | assert mu1.shape == mu2.shape, \ 35 | 'Training and test mean vectors have different lengths' 36 | assert sigma1.shape == sigma2.shape, \ 37 | 'Training and test covariances have different dimensions' 38 | 39 | diff = mu1 - mu2 40 | # Product might be almost singular 41 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 42 | if not np.isfinite(covmean).all(): 43 | msg = ('fd calculation produces singular product; ' 44 | 'adding %s to diagonal of cov estimates') % eps 45 | print(msg) 46 | offset = np.eye(sigma1.shape[0]) * eps 47 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 48 | 49 | # Numerical error might give slight imaginary component 50 | if np.iscomplexobj(covmean): 51 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 52 | m = np.max(np.abs(covmean.imag)) 53 | raise ValueError('Imaginary component {}'.format(m)) 54 | covmean = covmean.real 55 | 56 | tr_covmean = np.trace(covmean) 57 | 58 | # Return mean and covariance terms and intermediate steps, as well as FD 59 | mean_term = diff.dot(diff) 60 | tr1, tr2 = np.trace(sigma1), np.trace(sigma2) 61 | cov_term = tr1 + tr2 - 2 * tr_covmean 62 | 63 | return mean_term+cov_term 64 | 65 | 66 | def compute_FD_with_reps(reps1, reps2, eps=1e-6): 67 | """ 68 | Params: 69 | -- reps1 : activations of a representative data set (usually train) 70 | -- reps2 : activations of generated data set 71 | Returns: 72 | -- : The Frechet Distance. 73 | """ 74 | mu1, sigma1 = compute_statistics(reps1) 75 | mu2, sigma2 = compute_statistics(reps2) 76 | return compute_FD_with_stats(mu1, mu2, sigma1, sigma2, eps=eps) 77 | 78 | 79 | def compute_efficient_FD_with_reps(reps1, reps2): 80 | """ 81 | A more efficient computation of FD as proposed at the following link: 82 | https://www.reddit.com/r/MachineLearning/comments/12hv2u6/d_a_better_way_to_compute_the_fr%C3%A9chet_inception/ 83 | 84 | Confirmed to return identical values as the standard calculation above on all datasets we in our work. 
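    In effect this uses the identity
        FD^2 = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * sum(sqrt(eigvals(sigma1 @ sigma2))),
    i.e. the trace of the matrix square root of sigma1 @ sigma2 is replaced by the
    sum of the square roots of its eigenvalues, avoiding an explicit sqrtm call.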
85 | """ 86 | mu1, sigma1 = compute_statistics(reps1) 87 | mu2, sigma2 = compute_statistics(reps2) 88 | sqrt_trace = np.real(linalg.eigvals(sigma1 @ sigma2)**0.5).sum() 89 | return ((mu1 - mu2)**2).sum() + sigma1.trace() + sigma2.trace() - 2 * sqrt_trace 90 | 91 | 92 | 93 | def compute_FD_infinity(reps1, reps2, num_points=15): 94 | ''' 95 | reps1: 96 | representation of training images 97 | reps2: 98 | representatio of generated images 99 | num_points: 100 | Number of FD_N we evaluate to fit a line. 101 | Default: 15 102 | 103 | ''' 104 | fds = [] 105 | 106 | # Choose the number of images to evaluate FID_N at regular intervals over N 107 | fd_batches = np.linspace(min(5000, max(len(reps2)//10, 2)), len(reps2), num_points).astype('int32') 108 | mu1, sigma1 = compute_statistics(reps1) 109 | 110 | pbar = tqdm(total=num_points, desc='FID-infinity batches') 111 | # Evaluate FD_N 112 | rng = np.random.default_rng() 113 | for fd_batch_size in fd_batches: 114 | # sample, replacement allowed for different sample sizes 115 | fd_activations = rng.choice(reps2, fd_batch_size, replace=False) 116 | mu2, sigma2 = compute_statistics(fd_activations) 117 | fds.append(compute_FD_with_stats(mu1, mu2, sigma1, sigma2, eps=1e-6)) 118 | pbar.update(1) 119 | del pbar 120 | fds = np.array(fds).reshape(-1, 1) 121 | 122 | # Fit linear regression 123 | reg = LinearRegression().fit(1/fd_batches.reshape(-1, 1), fds) 124 | fd_infinity = reg.predict(np.array([[0]]))[0,0] 125 | 126 | return fd_infinity 127 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/fls.py: -------------------------------------------------------------------------------- 1 | """From the FLS repo (https://github.com/marcojira/fls)""" 2 | 3 | import torch 4 | import math 5 | import numpy as np 6 | 7 | def preprocess_fls(train_feat, baseline_feat, test_feat, gen_feat): 8 | # Assert correct device 9 | train_feat = torch.tensor(train_feat, dtype=torch.float32).cuda() 10 | baseline_feat = torch.tensor(baseline_feat, dtype=torch.float32).cuda() 11 | test_feat = torch.tensor(test_feat, dtype=torch.float32).cuda() 12 | gen_feat = torch.tensor(gen_feat, dtype=torch.float32).cuda() 13 | 14 | # Normalize features to 0 mean, unit variance 15 | all_features = torch.cat((train_feat, baseline_feat, test_feat, gen_feat), dim=0) 16 | mean = all_features.mean(dim=0) 17 | std = all_features.std(dim=0) 18 | 19 | def normalize(feat): 20 | return (feat - mean) / std 21 | 22 | train_feat = normalize(train_feat) 23 | baseline_feat = normalize(baseline_feat) 24 | test_feat = normalize(test_feat) 25 | gen_feat = normalize(gen_feat) 26 | return train_feat, baseline_feat, test_feat, gen_feat 27 | 28 | 29 | def tensor_to_numpy(tensor): 30 | """Shortcut to get a np.array corresponding to tensor""" 31 | return tensor.detach().cpu().numpy() 32 | 33 | 34 | def compute_dists(x_data, x_kernel): 35 | """Returns the dists tensor of all L2^2 distances between samples from x_data and x_kernel""" 36 | dists = (torch.cdist(x_data, x_kernel)) ** 2 37 | return dists.detach() 38 | 39 | 40 | def get_pairwise_likelihood(x_data, x_kernel, log_sigmas): 41 | dists = compute_dists(x_data, x_kernel) 42 | exponent_term = (-0.5 * dists) / torch.exp(log_sigmas) 43 | exponent_term -= (x_kernel.shape[1] / 2) * log_sigmas 44 | exponent_term += torch.log(torch.tensor(1 / dists.shape[1])) 45 | return exponent_term 46 | 47 | 48 | def nll(dists, log_sigmas, dim, lambd=0, detailed=False): 49 | """Computes the negative KDE log-likelihood using 
the distances between x_data and x_kernel 50 | Args: 51 | - dists: N x M tensor where the i,j entry is the squared L2 distance between the i-th row of x_data and the j-th row of x_kernel 52 | - x_data is N x dim and x_kernel is M x dim (where x_kernel are the points of the KDE) 53 | - dists is passed as an argument so that it can be computed once and cached (as it is O(N x M x dim)) 54 | - log_sigmas: Tensor of size M where the i-th entry is the log of the bandwidth for the i-th kernel point 55 | - dim: Dimension of the data (passed as argument since it cannot be inferred from the dists) 56 | - lambd: Optional regularization parameter for the sigmas (default 0) 57 | - detailed: If True, returns the NLL of each datapoint as well as the mean 58 | Returns: The NLL of the above 59 | """ 60 | exponent_term = (-0.5 * dists) / torch.exp(log_sigmas) 61 | 62 | # Here we use that dividing by x is equivalent to multiplying by e^{-ln(x)} 63 | # allows for use of logsumexp 64 | exponent_term -= (dim / 2) * log_sigmas 65 | exponent_term += torch.log(torch.tensor(1 / dists.shape[1])) 66 | inner_term = torch.logsumexp(exponent_term, dim=1) 67 | 68 | total_nll = torch.mean(-inner_term) 69 | 70 | reg_term = lambd / 2 * torch.norm(torch.exp(-log_sigmas)) ** 2 71 | final_nll = total_nll + reg_term 72 | 73 | if detailed: 74 | return final_nll, -inner_term 75 | 76 | return final_nll 77 | 78 | 79 | def optimize_sigmas(x_data, x_kernel, init_val=1, verbose=False): 80 | """Find the sigmas that minimize the NLL of x_data under a kernel given by x_kernel 81 | Args: 82 | - x_data: N x dim tensor we are evaluating the NLL of 83 | - x_kernel: M x dim tensor of points to use as kernels for KDE 84 | - init_val: Initial value of tensor of log_sigmas 85 | - verbose: Whether to print optimization progress 86 | Returns: (log_sigmas, losses) 87 | - log_sigmas: Tensor of log of sigmas 88 | - losses: List of losses at each step of optimization 89 | """ 90 | # Tracking 91 | losses = [] 92 | 93 | log_sigmas = torch.ones(x_kernel.shape[0], requires_grad=True, device="cuda") 94 | log_sigmas.data = init_val * log_sigmas 95 | 96 | optim = torch.optim.Adam([log_sigmas], lr=0.5) 97 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[50], gamma=0.1) 98 | 99 | dim = x_data.shape[1] 100 | 101 | # Precompute dists 102 | dists = compute_dists(x_data, x_kernel) 103 | 104 | for i in range(100): 105 | loss = nll(dists, log_sigmas, dim) 106 | 107 | optim.zero_grad() 108 | loss.backward() 109 | optim.step() 110 | scheduler.step() 111 | 112 | # Here we clamp log_sigmas to stop values exploding for identical samples 113 | with torch.no_grad(): 114 | log_sigmas.data = log_sigmas.clamp(-100, 20).data 115 | 116 | if verbose and i % 25 == 0: 117 | print( 118 | f"Loss: {loss:.2f} | Sigmas: min({torch.min(log_sigmas):.4f}), mean({torch.mean(log_sigmas):.4f}), max({torch.max(log_sigmas):.2f})" 119 | ) 120 | 121 | losses.append(tensor_to_numpy(loss)) 122 | 123 | return log_sigmas, losses 124 | 125 | 126 | def evaluate_set(evaluated_set, x_kernel, log_sigmas): 127 | """Gets the NLL of the test set using given kernel/bandwidths""" 128 | dists = compute_dists(evaluated_set, x_kernel) 129 | nlls = nll(dists, log_sigmas, x_kernel.shape[1], detailed=True) 130 | return tensor_to_numpy(nlls[0]), nlls[1] 131 | 132 | 133 | def compute_fls(train_feat, baseline_feat, test_feat, gen_feat): 134 | """From the FLS repo https://github.com/marcojira/fls/""" 135 | train_feat, baseline_feat, test_feat, gen_feat = preprocess_fls( 136 | train_feat, 
baseline_feat, test_feat, gen_feat 137 | ) 138 | 139 | gen_log_sigmas, gen_losses = optimize_sigmas(train_feat, gen_feat, init_val=0) 140 | 141 | # Get gen_nll 142 | gen_nlls = evaluate_set(test_feat, gen_feat, gen_log_sigmas) 143 | gen_nll = gen_nlls[0].item() 144 | 145 | # Get baseline_nll 146 | base_log_sigmas, base_losses = optimize_sigmas( 147 | train_feat, baseline_feat, init_val=0 148 | ) 149 | baseline_nlls = evaluate_set(test_feat, baseline_feat, base_log_sigmas) 150 | baseline_nll = baseline_nlls[0].item() 151 | 152 | diff = 2 * (gen_nll - baseline_nll) / train_feat.shape[1] 153 | score = math.e ** (-diff) * 100 154 | return score 155 | 156 | 157 | def compute_fls_overfit(train_feat, baseline_feat, test_feat, gen_feat): 158 | """From the FLS repo https://github.com/marcojira/fls/""" 159 | train_feat, baseline_feat, test_feat, gen_feat = preprocess_fls( 160 | train_feat, baseline_feat, test_feat, gen_feat 161 | ) 162 | 163 | gen_log_sigmas, gen_losses = optimize_sigmas(train_feat, gen_feat, init_val=0) 164 | # Ensure both sets have the same amount of data points 165 | size = min(test_feat.shape[0], train_feat.shape[0]) 166 | 167 | train_lls = get_pairwise_likelihood( 168 | train_feat[:size], gen_feat, gen_log_sigmas 169 | ) 170 | test_lls = get_pairwise_likelihood( 171 | test_feat[:size], gen_feat, gen_log_sigmas 172 | ) 173 | 174 | ll_diff = train_lls.logsumexp(axis=0) - test_lls.logsumexp(axis=0) 175 | score = ((ll_diff > 0).sum().item() / ll_diff.shape[0]) * 100 176 | return score - 50 177 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # a refactored version from https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | import numpy as np 7 | from scipy.stats import entropy 8 | 9 | def compute_inception_score(model, DataLoader=None, splits=10, device=None): 10 | """Computes the inception score of the generated images""" 11 | score = {} 12 | preds = get_preds(model, DataLoader, device) 13 | inception_score, std = calculate_score(preds, splits=splits, N=DataLoader.nimages) 14 | score['inception score'] = inception_score 15 | score['inception std'] = std 16 | return score 17 | 18 | def get_preds(model, DataLoader, device): 19 | model.eval() 20 | start_idx = 0 21 | 22 | for ibatch, batch in enumerate(tqdm(DataLoader.data_loader)): 23 | if isinstance(batch, list): 24 | # batch is likely list[array(images), array(labels)] 25 | batch = batch[0] 26 | 27 | # Convert grayscale to RGB 28 | if batch.ndim == 3: 29 | batch.unsqueeze_(1) 30 | if batch.shape[1] == 1: 31 | batch = batch.repeat(1, 3, 1, 1) 32 | 33 | batch = batch.to(device) 34 | 35 | with torch.no_grad(): 36 | pred = model(batch) 37 | if not torch.is_tensor(pred): # Some encoders output tuples or lists 38 | pred = pred[0] 39 | pred = pred.cpu().numpy() 40 | 41 | 42 | if ibatch==0: 43 | # initialize output array with full dataset size 44 | dims = pred.shape[-1] 45 | pred_arr = np.empty((DataLoader.nimages, dims)) 46 | 47 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 48 | 49 | start_idx = start_idx + pred.shape[0] 50 | return pred_arr 51 | 52 | def calculate_score(preds, splits=10, N=50000, shuffle=True, rng_seed=2020): 53 | if shuffle: 54 | rng = np.random.RandomState(rng_seed) 55 | preds = preds[rng.permutation(N), :] 56 | # Compute the mean kl-div 57 |
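    # The loop below implements the standard Inception Score estimator,
    #   IS = exp( E_x [ KL( p(y|x) || p(y) ) ] ).
    # For each split, py is the marginal (mean) prediction of that split,
    # scipy's entropy(pyx, py) returns KL(p(y|x) || p(y)) for one sample, and the
    # exponentiated mean of these KL terms gives the split's score; the mean and
    # standard deviation over all splits are returned.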
split_scores = [] 58 | for k in range(splits): 59 | part = preds[k * (N // splits): (k+1) * (N // splits), :] 60 | py = np.mean(part, axis=0) 61 | scores = [] 62 | for i in range(part.shape[0]): 63 | pyx = part[i, :] 64 | scores.append(entropy(pyx, py)) 65 | split_scores.append(np.exp(np.mean(scores))) 66 | 67 | return np.mean(split_scores), np.std(split_scores) 68 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/mmd.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from sklearn.metrics.pairwise import polynomial_kernel 3 | import numpy as np 4 | 5 | def compute_mmd(feat_real, feat_gen, n_subsets=100, subset_size=1000, **kernel_args): 6 | m = min(feat_real.shape[0], feat_gen.shape[0]) 7 | subset_size = min(subset_size, m) 8 | mmds = np.zeros(n_subsets) 9 | choice = np.random.choice 10 | 11 | with tqdm(range(n_subsets), desc='MMD') as bar: 12 | for i in bar: 13 | g = feat_real[choice(len(feat_real), subset_size, replace=False)] 14 | r = feat_gen[choice(len(feat_gen), subset_size, replace=False)] 15 | o = compute_polynomial_mmd(g, r, **kernel_args) 16 | mmds[i] = o 17 | bar.set_postfix({'mean': mmds[:i+1].mean()}) 18 | return mmds 19 | 20 | 21 | def compute_polynomial_mmd(feat_r, feat_gen, degree=3, gamma=None, coef0=1): 22 | # use k(x, y) = (gamma + coef0)^degree 23 | # default gamma is 1 / dim 24 | X = feat_r 25 | Y = feat_gen 26 | 27 | K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) 28 | K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) 29 | K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) 30 | 31 | return _mmd2_and_variance(K_XX, K_XY, K_YY) 32 | 33 | 34 | def _mmd2_and_variance(K_XX, K_XY, K_YY): 35 | # based on https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py 36 | # but changed to not compute the full kernel matrix at once 37 | m = K_XX.shape[0] 38 | assert K_XX.shape == (m, m) 39 | assert K_XY.shape == (m, m) 40 | assert K_YY.shape == (m, m) 41 | 42 | diag_X = np.diagonal(K_XX) 43 | diag_Y = np.diagonal(K_YY) 44 | 45 | Kt_XX_sums = K_XX.sum(axis=1) - diag_X 46 | Kt_YY_sums = K_YY.sum(axis=1) - diag_Y 47 | K_XY_sums_0 = K_XY.sum(axis=0) 48 | 49 | Kt_XX_sum = Kt_XX_sums.sum() 50 | Kt_YY_sum = Kt_YY_sums.sum() 51 | K_XY_sum = K_XY_sums_0.sum() 52 | 53 | mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1)) 54 | mmd2 -= 2 * K_XY_sum / (m * m) 55 | return mmd2 56 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/prdc.py: -------------------------------------------------------------------------------- 1 | """ 2 | prdc from https://github.com/clovaai/generative-evaluation-prdc 3 | Copyright (c) 2020-present NAVER Corp. 4 | MIT license 5 | Modified to also report realism score from https://arxiv.org/abs/1904.06991 6 | """ 7 | import numpy as np 8 | import sklearn.metrics 9 | import sys 10 | 11 | __all__ = ['compute_prdc'] 12 | 13 | 14 | def compute_pairwise_distance(data_x, data_y=None): 15 | """ 16 | Args: 17 | data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) 18 | data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) 19 | Returns: 20 | numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. 
21 | """ 22 | if data_y is None: 23 | data_y = data_x 24 | dists = sklearn.metrics.pairwise_distances( 25 | data_x, data_y, metric='euclidean', n_jobs=8) 26 | return dists 27 | 28 | 29 | def get_kth_value(unsorted, k, axis=-1): 30 | """ 31 | Args: 32 | unsorted: numpy.ndarray of any dimensionality. 33 | k: int 34 | Returns: 35 | kth values along the designated axis. 36 | """ 37 | indices = np.argpartition(unsorted, k, axis=axis)[..., :k] 38 | k_smallests = np.take_along_axis(unsorted, indices, axis=axis) 39 | kth_values = k_smallests.max(axis=axis) 40 | return kth_values 41 | 42 | 43 | def compute_nearest_neighbour_distances(input_features, nearest_k): 44 | """ 45 | Args: 46 | input_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 47 | nearest_k: int 48 | Returns: 49 | Distances to kth nearest neighbours. 50 | """ 51 | distances = compute_pairwise_distance(input_features) 52 | radii = get_kth_value(distances, k=nearest_k + 1, axis=-1) 53 | return radii 54 | 55 | 56 | def compute_prdc(real_features, fake_features, nearest_k, realism=False): 57 | """ 58 | Computes precision, recall, density, and coverage given two manifolds. 59 | 60 | Args: 61 | real_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 62 | fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32) 63 | nearest_k: int. 64 | Returns: 65 | dict of precision, recall, density, and coverage. 66 | """ 67 | 68 | print('Num real: {} Num fake: {}' 69 | .format(real_features.shape[0], fake_features.shape[0]), file=sys.stderr) 70 | 71 | real_nearest_neighbour_distances = compute_nearest_neighbour_distances( 72 | real_features, nearest_k) 73 | fake_nearest_neighbour_distances = compute_nearest_neighbour_distances( 74 | fake_features, nearest_k) 75 | distance_real_fake = compute_pairwise_distance( 76 | real_features, fake_features) 77 | 78 | precision = ( 79 | distance_real_fake < 80 | np.expand_dims(real_nearest_neighbour_distances, axis=1) 81 | ).any(axis=0).mean() 82 | 83 | recall = ( 84 | distance_real_fake < 85 | np.expand_dims(fake_nearest_neighbour_distances, axis=0) 86 | ).any(axis=1).mean() 87 | 88 | density = (1. / float(nearest_k)) * ( 89 | distance_real_fake < 90 | np.expand_dims(real_nearest_neighbour_distances, axis=1) 91 | ).sum(axis=0).mean() 92 | 93 | coverage = ( 94 | distance_real_fake.min(axis=1) < 95 | real_nearest_neighbour_distances 96 | ).mean() 97 | 98 | d = dict(precision=precision, recall=recall, 99 | density=density, coverage=coverage) 100 | 101 | if realism: 102 | """ 103 | Large errors, even if they are rare, would undermine the usefulness of the metric. 104 | We tackle this problem by discarding half of the hyperspheres with the largest radii. 105 | In other words, the maximum in Equation 3 is not taken over all φr ∈ Φr but only over 106 | those φr whose associated hypersphere is smaller than the median. 
107 | """ 108 | mask = real_nearest_neighbour_distances < np.median(real_nearest_neighbour_distances) 109 | 110 | d['realism'] = ( 111 | np.expand_dims(real_nearest_neighbour_distances[mask], axis=1)/distance_real_fake[mask] 112 | ).max(axis=0) 113 | 114 | return d 115 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/sw.py: -------------------------------------------------------------------------------- 1 | from numpy import linalg 2 | 3 | def sw_approx(X, Y): 4 | '''Approximate Sliced W2 without 5 | Monte Carlo From https://arxiv.org/pdf/2106.15427.pdf''' 6 | d = X.shape[1] 7 | mean_X = X.mean(axis=0) 8 | mean_Y = Y.mean(axis=0) 9 | mean_term = linalg.norm(mean_X - mean_Y) ** 2 / d 10 | m2_Xc = (linalg.norm(X - mean_X, axis=1) ** 2).mean() / d 11 | m2_Yc = (linalg.norm(Y - mean_Y, axis=1) ** 2).mean() / d 12 | approx_sw = (mean_term + (m2_Xc ** (1 / 2) - m2_Yc ** (1 / 2)) ** 2) ** (1/2) 13 | return approx_sw 14 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/metrics/vendi.py: -------------------------------------------------------------------------------- 1 | from sklearn import preprocessing 2 | from sklearn.metrics.pairwise import polynomial_kernel 3 | import scipy 4 | import scipy.linalg 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | def compute_vendi_score(X, q=1, normalize=True, kernel='linear'): 9 | if normalize: 10 | X = preprocessing.normalize(X, axis=1) 11 | n = X.shape[0] 12 | if kernel == 'linear': 13 | S = X @ X.T 14 | elif kernel == 'polynomial': 15 | S = polynomial_kernel(X, degree=3, gamma=None, coef0=1) # currently hardcoding kernel params to match KID 16 | else: 17 | raise NotImplementedError("kernel not implemented") 18 | # print('similarity matrix of shape {}'.format(S.shape)) 19 | w = scipy.linalg.eigvalsh(S / n) 20 | return np.exp(entropy_q(w, q=q)) 21 | 22 | def entropy_q(p, q=1): 23 | p_ = p[p > 0] 24 | if q == 1: 25 | return -(p_ * np.log(p_)).sum() 26 | if q == "inf": 27 | return -np.log(np.max(p)) 28 | return np.log((p_ ** q).sum()) / (1 - q) 29 | 30 | def compute_per_class_vendi_scores(reps, labels): 31 | num_classes = len(np.unique(labels)) 32 | vendi_per_class = np.zeros(shape=num_classes) 33 | with tqdm(total=num_classes) as pbar: 34 | for i in range(num_classes): 35 | reps_class = reps[labels==i] 36 | vendi_per_class[i] = compute_vendi_score(reps_class) 37 | pbar.update(1) 38 | return vendi_per_class 39 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .inception import InceptionEncoder 2 | from .swav import ResNet50Encoder #, ResNet18Encoder 3 | from .mae import VisionTransformerEncoder 4 | from .data2vec import HuggingFaceTransformerEncoder 5 | from .clip import CLIPEncoder 6 | from .convnext import ConvNeXTEncoder 7 | from .dinov2 import DINOv2Encoder 8 | from .load_encoder import load_encoder, MODELS 9 | from .simclr import SimCLRResNetEncoder -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import open_clip 3 | from torchvision.transforms import Normalize, InterpolationMode 4 | import torchvision.transforms.functional as TF 5 | from torch.utils.checkpoint import 
checkpoint 6 | 7 | from .encoder import Encoder 8 | from ..resizer import pil_resize 9 | 10 | ARCH_WEIGHT_DEFAULTS = { 11 | 'ViT-B-32': 'laion2b_s34b_b79k', 12 | 'ViT-B-16': 'laion2b_s34b_b88k', 13 | 'ViT-L-14': 'datacomp_xl_s13b_b90k', 14 | 'ViT-bigG-14': 'laion2b_s39b_b160k', 15 | } 16 | class CLIPEncoder(Encoder): 17 | def setup(self, arch:bool=None, pretrained_weights:bool=None, clean_resize:bool=False, depth:int=0): 18 | 19 | if arch is None: 20 | arch = 'ViT-L-14' 21 | if pretrained_weights is None: 22 | pretrained_weights=ARCH_WEIGHT_DEFAULTS[arch] 23 | 24 | self.model = open_clip.create_model(arch, pretrained_weights) 25 | self.clean_resize = clean_resize 26 | self.depth = depth 27 | 28 | def transform(self, image): 29 | mean = (0.48145466, 0.4578275, 0.40821073) 30 | std = (0.26862954, 0.26130258, 0.27577711) 31 | size = self.model.visual.image_size 32 | if self.clean_resize: 33 | image = pil_resize(image, size) 34 | else: 35 | image = TF.resize(image, size, interpolation=InterpolationMode.BICUBIC).convert('RGB') 36 | image = TF.to_tensor(image) 37 | image = TF.center_crop(image, size) 38 | return Normalize(mean, std)(image) 39 | 40 | def forward(self, x: torch.Tensor): 41 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] 42 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 43 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 44 | x = torch.cat( 45 | [self.model.visual.class_embedding.to(x.dtype) + 46 | torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 47 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 48 | x = x + self.model.visual.positional_embedding.to(x.dtype) 49 | 50 | x = self.model.visual.patch_dropout(x) 51 | x = self.model.visual.ln_pre(x) 52 | 53 | x = x.permute(1, 0, 2) # NLD -> LND 54 | blocks = self.model.visual.transformer.resblocks 55 | if self.depth<0: 56 | blocks = self.model.visual.transformer.resblocks[:self.depth] 57 | for r in blocks: 58 | x = r(x, attn_mask=None) 59 | x = x.permute(1, 0, 2) # LND -> NLD 60 | 61 | if self.model.visual.global_average_pool: 62 | x = x.mean(dim=1) 63 | else: 64 | x = x[:, 0] 65 | 66 | x = self.model.visual.ln_post(x) 67 | 68 | if self.model.visual.proj is not None and self.depth==1: 69 | x = x @ self.model.visual.proj 70 | 71 | return x 72 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/convnext.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | from torchvision import transforms 3 | from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize, CenterCrop 4 | import torchvision.transforms.functional as TF 5 | 6 | from .encoder import Encoder 7 | from ..resizer import pil_resize 8 | from timm.models import create_model 9 | from timm.data.constants import \ 10 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ 11 | OPENAI_CLIP_MEAN, OPENAI_CLIP_STD 12 | import sys 13 | 14 | class ConvNeXTEncoder(Encoder): 15 | """ 16 | requires timm version: 0.8.19.dev0 17 | model_arch options: 18 | convnext_xlarge_in22k (imagenet 21k); default 19 | convnext_xxlarge.clip_laion2b_rewind (clip objective trained on laion2b) 20 | 21 | see more options https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py 22 | 23 | """ 24 | def setup(self, arch='convnext_large_in22k', clean_resize=False): 25 | if arch==None: arch = 'convnext_large_in22k' 26 | 
self.arch = arch 27 | self.model = create_model( 28 | arch, 29 | pretrained=True, 30 | ) 31 | self.model.eval() 32 | 33 | if arch == "convnext_large_in22k": 34 | self.input_size = 224 35 | elif arch in ["convnext_base.clip_laion2b_augreg", "convnext_xxlarge.clip_laion2b_rewind"]: 36 | self.input_size = 256 37 | 38 | self.clean_resize = clean_resize 39 | self.build_transform() 40 | 41 | 42 | def build_transform(self): 43 | # get mean & std based on the model arch 44 | if self.arch == "convnext_large_in22k": 45 | print("IMAGENET MEAN STD", file=sys.stderr) 46 | mean = IMAGENET_DEFAULT_MEAN 47 | std = IMAGENET_DEFAULT_STD 48 | elif "clip" in self.arch: 49 | print("OPENAI CLIP MEAN STD", file=sys.stderr) 50 | mean = OPENAI_CLIP_MEAN 51 | std = OPENAI_CLIP_STD 52 | 53 | t = [] 54 | 55 | # warping (no cropping) when evaluated at 384 or larger 56 | if self.input_size >= 384: 57 | t.append( 58 | transforms.Resize((self.input_size, self.input_size), 59 | interpolation=transforms.InterpolationMode.BICUBIC), 60 | ) 61 | print(f"Warping {self.input_size} size input images...", file=sys.stderr) 62 | else: 63 | size = 256 64 | t.append( 65 | # to maintain same ratio 66 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 67 | ) 68 | t.append(transforms.CenterCrop(self.input_size)) 69 | 70 | t.append(transforms.ToTensor()) 71 | t.append(transforms.Normalize(mean, std)) 72 | self.transform_ops = transforms.Compose(t) 73 | return 74 | 75 | def transform(self, image): 76 | return self.transform_ops(image) 77 | 78 | def forward(self, x): 79 | # forward features + global_pool + norm + flatten => output dims () 80 | outputs = self.model.forward_features(x) 81 | outputs = self.model.head.global_pool(outputs) 82 | outputs = self.model.head.norm(outputs) 83 | outputs = self.model.head.flatten(outputs) 84 | return outputs 85 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/data2vec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torchvision.transforms as TF 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | from transformers import Data2VecVisionConfig, Data2VecVisionModel 21 | from transformers import AutoFeatureExtractor, AutoModel, AutoConfig, AutoImageProcessor 22 | 23 | from .encoder import Encoder 24 | 25 | try: 26 | from torchvision.models.utils import load_state_dict_from_url 27 | except ImportError: 28 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 29 | 30 | # DATA2VEC_WEIGHTS_URL = 'https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet.pt' 31 | 32 | from transformers import Data2VecVisionConfig, Data2VecVisionModel 33 | 34 | 35 | class HuggingFaceTransformer(nn.Module): 36 | """ Vision Transformer with support for global average pooling 37 | """ 38 | def __init__(self, **kwargs): 39 | super(HuggingFaceTransformer, self).__init__(**kwargs) 40 | 41 | # checkpoint = load_state_dict_from_url(DATA2VEC_WEIGHTS_URL, progress=True) 42 | # print(checkpoint) 43 | 44 | # self.model = AutoModel.from_pretrained("facebook/data2vec-vision-base", add_pooling_layer=True) 45 | self.model = AutoModel.from_pretrained("facebook/data2vec-vision-large", add_pooling_layer=True) 46 | 47 | def forward(self, inputs, mask=None, **kwargs): 48 | 49 | outputs = self.model(inputs, return_dict=True) 50 | 51 | # print('Encoder out shape = ', encoder_out.shape) 52 | return outputs.pooler_output 53 | 54 | 55 | class HuggingFaceTransformerEncoder(Encoder): 56 | def setup(self, ckpt=None): 57 | 58 | # self.image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base") 59 | self.image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-large") 60 | 61 | self.model = HuggingFaceTransformer() 62 | 63 | def transform(self, image): 64 | return self.image_processor(image, return_tensors="pt") 65 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/dinov2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | import torchvision.transforms as TF 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | import sys 19 | 20 | from .encoder import Encoder 21 | 22 | from ..resizer import pil_resize 23 | 24 | VALID_ARCHITECTURES = [ 25 | 'vits14', 26 | 'vitb14', 27 | 'vitl14', 28 | 'vitg14', 29 | ] 30 | 31 | class DINOv2Encoder(Encoder): 32 | def setup(self, arch=None, clean_resize:bool=False): 33 | if arch is None: 34 | arch = 'vitl14' 35 | 36 | self.arch = arch 37 | 38 | arch_str = f'dinov2_{self.arch}' 39 | 40 | if self.arch not in VALID_ARCHITECTURES: 41 | sys.exit(f"arch={self.arch} is not a valid architecture. Choose from {VALID_ARCHITECTURES}") 42 | 43 | self.model = torch.hub.load('facebookresearch/dinov2', arch_str) 44 | self.clean_resize = clean_resize 45 | 46 | def transform(self, image): 47 | 48 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 49 | imagenet_std = np.array([0.229, 0.224, 0.225]) 50 | 51 | if self.clean_resize: 52 | image = pil_resize(image, (224, 224)) 53 | else: 54 | image = TF.Compose([ 55 | TF.Resize((224, 224), TF.InterpolationMode.BICUBIC), 56 | TF.ToTensor(), 57 | ])(image) 58 | 59 | return TF.Normalize(imagenet_mean, imagenet_std)(image) 60 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class Encoder(ABC, nn.Module): 7 | def __init__(self, *args, **kwargs): 8 | nn.Module.__init__(self) 9 | self.setup(*args, **kwargs) 10 | self.name = 'encoder' 11 | 12 | @abstractmethod 13 | def setup(self, *args, **kwargs): 14 | pass 15 | 16 | @abstractmethod 17 | def transform(self, x): 18 | """Converts a PIL Image to an input for the model""" 19 | pass 20 | 21 | def forward(self, *args, **kwargs): 22 | return self.model(*args, **kwargs) 23 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/load_encoder.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from .encoder import Encoder 4 | from .inception import InceptionEncoder 5 | from .swav import ResNet50Encoder #ResNet18Encoder, ResNet18MocoEncoder, 6 | from .mae import VisionTransformerEncoder 7 | from .data2vec import HuggingFaceTransformerEncoder 8 | from .clip import CLIPEncoder 9 | from .pixel import PixelEncoder 10 | from .convnext import ConvNeXTEncoder 11 | from .clip import CLIPEncoder 12 | from .dinov2 import DINOv2Encoder 13 | from .simclr import SimCLRResNetEncoder 14 | MODELS = { 15 | "inception" : InceptionEncoder, 16 | "sinception" : InceptionEncoder, 17 | "mae": VisionTransformerEncoder, 18 | "data2vec": HuggingFaceTransformerEncoder, 19 | "swav": ResNet50Encoder, 20 | "clip": CLIPEncoder, 21 | 'pixel': PixelEncoder, 22 | "convnext": ConvNeXTEncoder, 23 | "dinov2": DINOv2Encoder, 24 | "simclr": SimCLRResNetEncoder, 25 | } 26 | 27 | 28 | def load_encoder(model_name, device, **kwargs): 29 | """Load feature extractor""" 30 | 31 | model_cls = MODELS[model_name] 32 | 33 | # Get names of model_cls.setup 
arguments 34 | signature = inspect.signature(model_cls.setup) 35 | arguments = list(signature.parameters.keys()) 36 | arguments = arguments[1:] # Omit `self` arg 37 | 38 | # Initialize model using the `arguments` that have been passed in the `kwargs` dict 39 | encoder = model_cls(**{arg: kwargs[arg] for arg in arguments if arg in kwargs}) 40 | encoder.name = model_name 41 | 42 | assert isinstance(encoder, Encoder), "Can only get representations with Encoder subclasses!" 43 | 44 | return encoder.to(device) 45 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torchvision.transforms as TF 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | import timm.models.vision_transformer 21 | 22 | from .encoder import Encoder 23 | from ..resizer import pil_resize 24 | from .util.pos_embed import interpolate_pos_embed 25 | import sys 26 | 27 | try: 28 | from torchvision.models.utils import load_state_dict_from_url 29 | except ImportError: 30 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 31 | 32 | MAE_WEIGHTS_URL = 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth' 33 | 34 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 35 | """ Vision Transformer with support for global average pooling 36 | """ 37 | def __init__(self, global_pool=False, **kwargs): 38 | super(VisionTransformer, self).__init__(**kwargs) 39 | 40 | self.global_pool = global_pool 41 | if self.global_pool: 42 | norm_layer = kwargs['norm_layer'] 43 | embed_dim = kwargs['embed_dim'] 44 | self.fc_norm = norm_layer(embed_dim) 45 | 46 | del self.norm # remove the original norm 47 | 48 | def forward_features(self, x): 49 | 50 | B = x.shape[0] 51 | x = self.patch_embed(x) 52 | 53 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 54 | x = torch.cat((cls_tokens, x), dim=1) 55 | x = x + self.pos_embed 56 | x = self.pos_drop(x) 57 | 58 | for blk in self.blocks: 59 | x = blk(x) 60 | 61 | if self.global_pool: 62 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 63 | outcome = self.fc_norm(x) 64 | else: 65 | x = self.norm(x) 66 | outcome = x[:, 0] 67 | 68 | return outcome 69 | 70 | def vit_large_patch16(**kwargs): 71 | model = VisionTransformer( 72 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | return model 75 | 76 | class VisionTransformerEncoder(Encoder): 77 | def setup(self, model=None, ckpt=None, clean_resize=False): 78 | 79 | 80 | # Model at https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth 81 | checkpoint = load_state_dict_from_url(MAE_WEIGHTS_URL, progress=True) 82 | self.model = vit_large_patch16() 83 | self.clean_resize = clean_resize 84 
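        # The released MAE checkpoint stores its weights under the 'model' key.
        # The code below drops any classification-head weights whose shape does not
        # match this backbone, interpolates the position embeddings to the current
        # patch grid, and loads the rest with strict=False, so only 'head.weight'
        # and 'head.bias' are expected to be missing.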
| 85 | checkpoint_model = checkpoint['model'] 86 | state_dict = self.model.state_dict() 87 | for k in ['head.weight', 'head.bias']: 88 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 89 | print(f"Removing key {k} from pretrained checkpoint", file=sys.stderr) 90 | del checkpoint_model[k] 91 | 92 | # interpolate position embedding 93 | interpolate_pos_embed(self.model, checkpoint_model) 94 | 95 | # load pre-trained model 96 | msg = self.model.load_state_dict(checkpoint_model, strict=False) 97 | 98 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 99 | 100 | self.model.forward = self.model.forward_features 101 | 102 | def transform(self, image): 103 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 104 | imagenet_std = np.array([0.229, 0.224, 0.225]) 105 | resize_function = TF.Resize(224, interpolation=TF.InterpolationMode.BICUBIC)#.convert('RGB') 106 | to_tensor = TF.ToTensor() 107 | if self.clean_resize: 108 | image = pil_resize(image, (224, 224)) 109 | else: 110 | image = resize_function(image).convert('RGB') 111 | image = to_tensor(image) 112 | # image = TF.center_crop(image, size) TODO: Add crop center if it makes sense 113 | return TF.Normalize(imagenet_mean, imagenet_std)(image) 114 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/pixel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import timm 7 | 8 | try: 9 | from torchvision.models.utils import load_state_dict_from_url 10 | except ImportError: 11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 12 | from torchvision import transforms 13 | import torchvision.transforms.functional as TF 14 | 15 | 16 | from .encoder import Encoder 17 | 18 | from ..resizer import pil_resize 19 | 20 | class PixelEncoder(Encoder): 21 | def setup(self): 22 | self.model = torch.nn.Identity() 23 | pass 24 | 25 | def transform(self, image): 26 | image = pil_resize(image, (32, 32)) 27 | return image 28 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/models/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import sys 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 19 | # MoCo v3: https://github.com/facebookresearch/moco-v3 20 | # -------------------------------------------------------- 21 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 22 | """ 23 | grid_size: int of the grid height and width 24 | return: 25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 26 | """ 27 | grid_h = np.arange(grid_size, dtype=np.float32) 28 | grid_w = np.arange(grid_size, dtype=np.float32) 29 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 30 | grid = np.stack(grid, axis=0) 31 | 32 | grid = grid.reshape([2, 1, grid_size, grid_size]) 33 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 34 | if cls_token: 35 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 36 | return pos_embed 37 | 38 | 39 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 40 | assert embed_dim % 2 == 0 41 | 42 | # use half of dimensions to encode grid_h 43 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 44 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 45 | 46 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 47 | return emb 48 | 49 | 50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 51 | """ 52 | embed_dim: output dimension for each position 53 | pos: a list of positions to be encoded: size (M,) 54 | out: (M, D) 55 | """ 56 | assert embed_dim % 2 == 0 57 | omega = np.arange(embed_dim // 2, dtype=np.float64) 58 | omega /= embed_dim / 2. 59 | omega = 1.
/ 10000**omega # (D/2,) 60 | 61 | pos = pos.reshape(-1) # (M,) 62 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 63 | 64 | emb_sin = np.sin(out) # (M, D/2) 65 | emb_cos = np.cos(out) # (M, D/2) 66 | 67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 68 | return emb 69 | 70 | 71 | # -------------------------------------------------------- 72 | # Interpolate position embeddings for high-resolution 73 | # References: 74 | # DeiT: https://github.com/facebookresearch/deit 75 | # -------------------------------------------------------- 76 | def interpolate_pos_embed(model, checkpoint_model): 77 | if 'pos_embed' in checkpoint_model: 78 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 79 | embedding_size = pos_embed_checkpoint.shape[-1] 80 | num_patches = model.patch_embed.num_patches 81 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 82 | # height (== width) for the checkpoint position embedding 83 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 84 | # height (== width) for the new position embedding 85 | new_size = int(num_patches ** 0.5) 86 | # class_token and dist_token are kept unchanged 87 | if orig_size != new_size: 88 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size), file=sys.stderr) 89 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 90 | # only the position tokens are interpolated 91 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 92 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 93 | pos_tokens = torch.nn.functional.interpolate( 94 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 95 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 96 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 97 | checkpoint_model['pos_embed'] = new_pos_embed 98 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/representations.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.nn.functional import adaptive_avg_pool2d 4 | import torch 5 | import pathlib 6 | 7 | try: 8 | from tqdm import tqdm 9 | except ImportError: 10 | # If tqdm is not available, provide a mock version of it 11 | def tqdm(x): 12 | return x 13 | 14 | def get_representations(model, DataLoader, device, normalized=False): 15 | """Extracts features from all images in DataLoader given model. 16 | 17 | Params: 18 | -- model : Instance of Encoder such as inception or CLIP or dinov2 19 | -- DataLoader : DataLoader containing image files, or torchvision.dataset 20 | 21 | Returns: 22 | -- A numpy array of dimension (num images, dims) that contains the 23 | activations of the given tensor when feeding inception with the 24 | query tensor. 25 | """ 26 | model.eval() 27 | 28 | start_idx = 0 29 | 30 | for ibatch, batch in enumerate(tqdm(DataLoader.data_loader)): 31 | if isinstance(batch, list): 32 | # batch is likely list[array(images), array(labels)] 33 | batch = batch[0] 34 | 35 | if not torch.is_tensor(batch): 36 | # assume batch is then e.g. 
AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base") 37 | batch = batch['pixel_values'] 38 | batch = batch[:,0] 39 | 40 | # Convert grayscale to RGB 41 | if batch.ndim == 3: 42 | batch.unsqueeze_(1) 43 | if batch.shape[1] == 1: 44 | batch = batch.repeat(1, 3, 1, 1) 45 | 46 | batch = batch.to(device) 47 | 48 | with torch.no_grad(): 49 | pred = model(batch) 50 | 51 | if not torch.is_tensor(pred): # Some encoders output tuples or lists 52 | pred = pred[0] 53 | 54 | # If model output is not scalar, apply global spatial average pooling. 55 | # This happens if you choose a dimensionality not equal 2048. 56 | if pred.dim() > 2: 57 | if pred.size(2) != 1 or pred.size(3) != 1: 58 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 59 | 60 | pred = pred.squeeze(3).squeeze(2) 61 | 62 | if normalized: 63 | pred = torch.nn.functional.normalize(pred, dim=-1) 64 | pred = pred.cpu().numpy() 65 | 66 | if ibatch==0: 67 | # initialize output array with full dataset size 68 | dims = pred.shape[-1] 69 | pred_arr = np.empty((DataLoader.nimages, dims)) 70 | 71 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 72 | 73 | start_idx = start_idx + pred.shape[0] 74 | 75 | return pred_arr 76 | 77 | 78 | def save_outputs(output_dir, reps, model, checkpoint, DataLoader): 79 | """Save representations and other info to disk at file_path""" 80 | out_path = get_path(output_dir, model, checkpoint, DataLoader) 81 | 82 | pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) 83 | 84 | hyperparams = vars(DataLoader).copy() # Remove keys that can't be pickled 85 | hyperparams.pop("transform") 86 | hyperparams.pop("data_loader") 87 | hyperparams.pop("data_set") 88 | 89 | np.savez(out_path, model=model, reps=reps, hparams=hyperparams) 90 | 91 | def load_reps_from_path(saved_dir, model, checkpoint, DataLoader): 92 | """Save representations and other info to disk at file_path""" 93 | save_path = get_path(saved_dir, model, checkpoint, DataLoader) 94 | reps = None 95 | print('Loading from:', save_path) 96 | if os.path.exists(f'{save_path}.npz'): 97 | saved_file = np.load(f'{save_path}.npz') 98 | reps = saved_file['reps'] 99 | return reps 100 | 101 | def get_path(output_dir, model, checkpoint, DataLoader): 102 | train_str = 'train' if DataLoader.train_set else 'test' 103 | 104 | ckpt_str = '' if checkpoint is None else f'_ckpt-{os.path.splitext(os.path.basename(checkpoint))[0]}' 105 | 106 | hparams_str = f'reps_{DataLoader.dataset_name}_{model}{ckpt_str}_nimage-{len(DataLoader.data_set)}_{train_str}' 107 | return os.path.join(output_dir, hparams_str) 108 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/dgm_eval/resizer.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | from torchvision.transforms.functional import to_tensor 4 | 5 | def pil_resize(x, output_size): 6 | s1, s2 = output_size 7 | def resize_single_channel(x): 8 | img = Image.fromarray(x, mode='F') 9 | img = img.resize(output_size, resample=Image.BICUBIC) 10 | return np.asarray(img).clip(0, 255).reshape(s2, s1, 1) 11 | x = np.array(x.convert('RGB')).astype(np.float32) 12 | x = [resize_single_channel(x[:, :, idx]) for idx in range(3)] 13 | x = np.concatenate(x, axis=2).astype(np.float32) 14 | return to_tensor(x)/255 -------------------------------------------------------------------------------- /Evaluation/dgm-eval/scripts/run_experiments.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | nsample=512 4 | reduced_n=256 5 | batch_size=256 6 | 7 | models='inception dinov2' 8 | 9 | metrics='fd fd-infinity mmd is prdc authpct authpct_test asw ct ct_mem ct_mode fls fls_overfit' 10 | 11 | output_dir=experiments/CIFAR10/ 12 | 13 | reference_dataset='CIFAR10:train' 14 | test_path='CIFAR10:test' 15 | 16 | gen_dir='data/GAN-images/CIFAR10/' 17 | 18 | # List of generated datasets 19 | # Metrics will be computed for each dataset of path specified here 20 | test_datasets="CIFAR10:test \ 21 | CIFAR10:test \ 22 | " 23 | 24 | for model in $models 25 | do 26 | echo 'Running on model:' $model 27 | 28 | python -m dgm_eval \ 29 | $reference_dataset \ 30 | $test_datasets \ 31 | --model $model \ 32 | --nsample $nsample -bs $batch_size \ 33 | --output_dir $output_dir --metrics $metrics \ 34 | --reduced_n $reduced_n --save \ 35 | --test_path $test_path \ 36 | # --heatmaps \ 37 | 38 | done 39 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select=F,W,E,I,B,B9 3 | ignore=W503,B950 4 | max-line-length=79 5 | 6 | [isort] 7 | multi_line_output=1 8 | line_length=79 9 | -------------------------------------------------------------------------------- /Evaluation/dgm-eval/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import setuptools 4 | 5 | 6 | def read(rel_path): 7 | base_path = os.path.abspath(os.path.dirname(__file__)) 8 | with open(os.path.join(base_path, rel_path), 'r') as f: 9 | return f.read() 10 | 11 | 12 | if __name__ == '__main__': 13 | setuptools.setup( 14 | name='dgm-eval', 15 | author='Layer 6', 16 | description=('Package for evaluating deep generative models'), 17 | long_description=read('README.md'), 18 | long_description_content_type='text/markdown', 19 | packages=['dgm_eval'], 20 | classifiers=[ 21 | 'Programming Language :: Python :: 3', 22 | 'License :: OSI Approved :: Apache Software License', 23 | ], 24 | python_requires='>=3.7', 25 | entry_points={ 26 | 'console_scripts': [ 27 | 'dgm-eval = dgm_eval:main', 28 | ], 29 | }, 30 | install_requires=[ 31 | 'numpy==1.23.3', 32 | 'opencv-python==4.6.0.66', 33 | 'open_clip_torch==2.19.0', 34 | 'pandas==1.5.3', 35 | 'pillow==9.2.0', 36 | 'scikit-image==0.19.3', 37 | 'scikit-learn==1.1.3', 38 | 'scipy==1.9.3', 39 | 'timm==0.8.19.dev0', 40 | 'torch>=2.0.0', 41 | 'torchvision>=0.2.2', 42 | 'transformers==4.26.0', 43 | 'xformers==0.0.18', 44 | ], 45 | extras_require={'dev': ['flake8', 46 | 'flake8-bugbear', 47 | 'flake8-isort', 48 | 'nox']}, 49 | ) 50 | -------------------------------------------------------------------------------- /FR_training/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FR_training/backbones/activation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | import torch 5 | 6 | from inspect import isfunction 7 | 8 | class Identity(nn.Module): 9 | """ 10 | Identity block. 
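    Passes its input through unchanged; returned by get_activation_layer below
    when the requested activation is "identity".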
11 | """ 12 | def __init__(self): 13 | super(Identity, self).__init__() 14 | 15 | def forward(self, x): 16 | return x 17 | 18 | def __repr__(self): 19 | return '{name}()'.format(name=self.__class__.__name__) 20 | class HSigmoid(nn.Module): 21 | """ 22 | Approximated sigmoid function, so-called hard-version of sigmoid from 'Searching for MobileNetV3,' 23 | https://arxiv.org/abs/1905.02244. 24 | """ 25 | def forward(self, x): 26 | return F.relu6(x + 3.0, inplace=True) / 6.0 27 | 28 | 29 | class Swish(nn.Module): 30 | """ 31 | Swish activation function from 'Searching for Activation Functions,' https://arxiv.org/abs/1710.05941. 32 | """ 33 | def forward(self, x): 34 | return x * torch.sigmoid(x) 35 | class HSwish(nn.Module): 36 | """ 37 | H-Swish activation function from 'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244. 38 | Parameters: 39 | ---------- 40 | inplace : bool 41 | Whether to use inplace version of the module. 42 | """ 43 | def __init__(self, inplace=False): 44 | super(HSwish, self).__init__() 45 | self.inplace = inplace 46 | 47 | def forward(self, x): 48 | return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 49 | 50 | 51 | def get_activation_layer(activation,param): 52 | """ 53 | Create activation layer from string/function. 54 | Parameters: 55 | ---------- 56 | activation : function, or str, or nn.Module 57 | Activation function or name of activation function. 58 | Returns: 59 | ------- 60 | nn.Module 61 | Activation layer. 62 | """ 63 | assert (activation is not None) 64 | if isfunction(activation): 65 | return activation() 66 | elif isinstance(activation, str): 67 | if activation == "relu": 68 | return nn.ReLU(inplace=True) 69 | elif activation =="prelu": 70 | return nn.PReLU(param) 71 | elif activation == "relu6": 72 | return nn.ReLU6(inplace=True) 73 | elif activation == "swish": 74 | return Swish() 75 | elif activation == "hswish": 76 | return HSwish(inplace=True) 77 | elif activation == "sigmoid": 78 | return nn.Sigmoid() 79 | elif activation == "hsigmoid": 80 | return HSigmoid() 81 | elif activation == "identity": 82 | return Identity() 83 | else: 84 | raise NotImplementedError() 85 | else: 86 | assert (isinstance(activation, nn.Module)) 87 | return activation -------------------------------------------------------------------------------- /FR_training/config/FR_config.py: -------------------------------------------------------------------------------- 1 | architecture = "resnet50"#"resnet50"#"resnet50" 2 | 3 | # root_folder = "/shared/home/darian.tomasevic/ID-Booth/" 4 | root_folder = ".." 5 | folder_to_test = "FacePortrait_Photo_21_Gender_Pose_Background" # TODO 6 | models = ["DreamBooth", "PortraitBooth", "ID-Booth"] 7 | 8 | # folder_to_test = "tufts_512_poses_1-7_all_imgs_jpg_per_ID" 9 | # models = ["images"] 10 | 11 | dataset_folder = f"{root_folder}/FR_DATASETS/{folder_to_test}" 12 | 13 | 14 | model = "TODO" 15 | benchmark_folder = f"{root_folder}/FR_training/VALIDATION_DATASETS_from_webface" 16 | augment = False 17 | stopping_condition_epochs = 6 18 | stop_only_after_epoch_schedule = False 19 | 20 | verification_frequency = 1 21 | output_folder_name_start = f"REC_EXP_01_2025_LFW_Verification{verification_frequency}"#_AllBench" 22 | 23 | EMBEDDING_TYPE = [ 24 | "." 
25 | ] 26 | 27 | embedding_type = EMBEDDING_TYPE[0] 28 | 29 | width = 0 30 | depth = 0 31 | 32 | batch_size = 128 # 128 # 256 33 | workers = 8 # 32 34 | embedding_size = 512 35 | learning_rate = 0.1 36 | momentum = 0.9 37 | weight_decay = 5e-4 38 | 39 | global_step = 0 # to resume 40 | start_epoch = 0 41 | 42 | s = 64.0 43 | m = 0.35 44 | loss = "AdaFace"#"" 45 | dropout_ratio = 0.4 46 | 47 | augmentation = "ra_4_16" # hf, ra_4_16 48 | 49 | 50 | print_freq = 1 #50 51 | val_path = "TODO"#"/data/Biometrics/database/faces_emore" # "/data/fboutros/faces_emore" 52 | val_targets = ["lfw"] 53 | # val_targets = ["lfw", "agedb_30", "cfp_fp", "calfw", "cplfw"] 54 | 55 | 56 | auto_schedule = True 57 | num_epoch = 200 58 | schedule = [22, 30, 35] 59 | 60 | 61 | def lr_step_func(epoch): 62 | return ( 63 | ((epoch + 1) / (4 + 1)) ** 2 64 | if epoch < -1 65 | else 0.1 ** len([m for m in schedule if m - 1 <= epoch]) 66 | ) 67 | 68 | 69 | lr_func = lr_step_func 70 | 71 | -------------------------------------------------------------------------------- /FR_training/config/FR_config_Augmented.py: -------------------------------------------------------------------------------- 1 | architecture = "resnet50"#"resnet50"#"resnet50" 2 | 3 | root_folder = "/shared/home/darian.tomasevic/ID-Booth/" 4 | # root_folder = "/home/darian/Desktop/Diffusion/ID-Booth/" 5 | 6 | # folder_to_test = "12-2024_SD21_LoRA4_alphaW0.1_FINAL_FacePortraitPhoto_Gender_Pose_BackgroundB" 7 | # folder_to_test = "12-2024_SD21_LoRA4_alphaW0.1_FINAL_FacePortraitPhoto_Gender_Pose_AgePhases_Expression_BackgroundB" 8 | # folder_to_test = "12-2024_SD21_LoRA4_alphaWNone_FINAL_FacePortraitPhoto_Gender_Pose_BackgroundB" 9 | folder_to_test = "12-2024_SD21_LoRA4_alphaWNone_FacePortrait_Photo_Gender_Pose_BackgroundB_100samples" 10 | 11 | dataset_folder = f"{root_folder}/FR_DATASETS_AUGMENTED_samples/{folder_to_test}" 12 | 13 | # dataset_folder = f"{root_folder}/FR_DATASETS_AUGMENTED_+21_samples/12-2024_SD21_LoRA4_alphaW0.1_Face_Poses_Environments" 14 | 15 | # # TODO 16 | #dataset_folder = f"{root_folder}/FR_DATASETS/FR_DATASETS_SDXL/tufts_512_poses_1-7_all_imgs_jpg_per_ID" 17 | #models = ["images"] 18 | 19 | models = ["no_new_Loss", "identity_loss_TimestepWeight", "triplet_prior_loss_TimestepWeight"] 20 | 21 | 22 | model = "TODO" 23 | benchmark_folder = f"{root_folder}/FR_training/VALIDATION_DATASETS_from_webface" 24 | augment = False 25 | stopping_condition_epochs = 6 #6 #6 # 10 # TODO was 6 26 | stop_only_after_epoch_schedule = False 27 | 28 | verification_frequency = 1 29 | output_folder_name_start = f"REC_EXP_01_2025_TFD+Synth_LFW_Verification{verification_frequency}"#_AllBench" 30 | 31 | EMBEDDING_TYPE = [ 32 | "." 
33 | ] 34 | 35 | embedding_type = EMBEDDING_TYPE[0] 36 | 37 | width = 0 38 | depth = 0 39 | 40 | batch_size = 128 # 128 # 256 41 | workers = 8 # 32 42 | embedding_size = 512 43 | learning_rate = 0.1 44 | momentum = 0.9 45 | weight_decay = 5e-4 46 | 47 | global_step = 0 # to resume 48 | start_epoch = 0 49 | 50 | s = 64.0 51 | m = 0.35 52 | loss = "AdaFace"#"" 53 | dropout_ratio = 0.4 54 | 55 | augmentation = "ra_4_16" # hf, ra_4_16 56 | 57 | 58 | print_freq = 1 #50 59 | val_path = "TODO"#"/data/Biometrics/database/faces_emore" # "/data/fboutros/faces_emore" 60 | val_targets = ["lfw"] 61 | # val_targets = ["lfw", "agedb_30", "cfp_fp", "calfw", "cplfw"] 62 | 63 | 64 | auto_schedule = True 65 | num_epoch = 200 66 | schedule = [22, 30, 35] 67 | 68 | 69 | def lr_step_func(epoch): 70 | return ( 71 | ((epoch + 1) / (4 + 1)) ** 2 72 | if epoch < -1 73 | else 0.1 ** len([m for m in schedule if m - 1 <= epoch]) 74 | ) 75 | 76 | 77 | lr_func = lr_step_func 78 | 79 | -------------------------------------------------------------------------------- /FR_training/config/config_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | architecture = "resnet18"#"resnet50" 4 | 5 | dataset_folder = ""#"specify the path to the aligned images folder" 6 | 7 | #MODELS = [ 8 | # "TUFTS_dataset_cropped_and_aligned/VIS" 9 | #"unet-cond-ca-bs512-150K", 10 | #"unet-cond-ca-bs512-150K-cpd25", 11 | #"unet-cond-ca-bs512-150K-cpd50" 12 | #] 13 | 14 | EMBEDDING_TYPE = [ 15 | "." 16 | #"random_synthetic_uniform_5000", 17 | #"random_synthetic_learned_5000", 18 | #"random_synthetic_extracted_5000" 19 | ] 20 | 21 | #model = MODELS[0] 22 | embedding_type = EMBEDDING_TYPE[0] 23 | 24 | width = 0 25 | depth = 0 26 | 27 | batch_size = 16 # TODO Was 16 #128 # 256 28 | workers = 8 # 32 29 | embedding_size = 112 #512 30 | learning_rate = 0.1 31 | momentum = 0.9 32 | weight_decay = 5e-4 33 | 34 | global_step = 0 # to resume 35 | start_epoch = 0 36 | 37 | s = 64.0 38 | m = 0.35 39 | loss = "CosFace" 40 | dropout_ratio = 0.4 41 | 42 | augmentation = "ra_4_16" # hf, ra_4_16 43 | 44 | print_freq = 1 #50 45 | val_path = "TODO"#"/data/Biometrics/database/faces_emore" # "/data/fboutros/faces_emore" 46 | val_targets = ["TODO"]#"lfw"]#, "agedb_30", "cfp_fp", "calfw", "cplfw"] 47 | model = "TODO" 48 | 49 | auto_schedule = True 50 | num_epoch = 200 51 | schedule = [22, 30, 35] 52 | 53 | 54 | def lr_step_func(epoch): 55 | return ( 56 | ((epoch + 1) / (4 + 1)) ** 2 57 | if epoch < -1 58 | else 0.1 ** len([m for m in schedule if m - 1 <= epoch]) 59 | ) 60 | 61 | 62 | lr_func = lr_step_func 63 | -------------------------------------------------------------------------------- /FR_training/config/config_orig.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | architecture = "resnet50" 4 | 5 | dataset_folder = "specify the path to the aligned images folder" 6 | 7 | MODELS = [ 8 | "unet-cond-ca-bs512-150K", 9 | "unet-cond-ca-bs512-150K-cpd25", 10 | "unet-cond-ca-bs512-150K-cpd50" 11 | ] 12 | 13 | EMBEDDING_TYPE = [ 14 | "random_synthetic_uniform_5000", 15 | "random_synthetic_learned_5000", 16 | "random_synthetic_extracted_5000" 17 | ] 18 | 19 | model = MODELS[0] 20 | embedding_type = EMBEDDING_TYPE[0] 21 | 22 | width = 0 23 | depth = 0 24 | 25 | batch_size = 128 # 256 26 | workers = 8 # 32 27 | embedding_size = 512 28 | learning_rate = 0.1 29 | momentum = 0.9 30 | weight_decay = 5e-4 31 | 32 | global_step = 0 # to resume 33 | start_epoch = 0 34 | 35 
| s = 64.0 36 | m = 0.35 37 | loss = "CosFace" 38 | dropout_ratio = 0.4 39 | 40 | augmentation = "ra_4_16" # hf, ra_4_16 41 | 42 | print_freq = 50 43 | val_path = "/data/Biometrics/database/faces_emore" # "/data/fboutros/faces_emore" 44 | val_targets = ["lfw", "agedb_30", "cfp_fp", "calfw", "cplfw"] 45 | 46 | auto_schedule = True 47 | num_epoch = 200 48 | schedule = [22, 30, 35] 49 | 50 | 51 | def lr_step_func(epoch): 52 | return ( 53 | ((epoch + 1) / (4 + 1)) ** 2 54 | if epoch < -1 55 | else 0.1 ** len([m for m in schedule if m - 1 <= epoch]) 56 | ) 57 | 58 | 59 | lr_func = lr_step_func 60 | -------------------------------------------------------------------------------- /FR_training/config/test_FR_config.py: -------------------------------------------------------------------------------- 1 | architecture = "resnet50"#"resnet50" 2 | 3 | # root_folder = "/shared/home/darian.tomasevic/ID-Booth/" 4 | root_folder = ".." 5 | folder_to_test = "FacePortrait_Photo_21_Gender_Pose_Background" 6 | models = ["DreamBooth", "PortraitBooth", "ID-Booth"] 7 | 8 | # folder_to_test = "tufts_512_poses_1-7_all_imgs_jpg_per_ID" 9 | # models = ["images"] 10 | 11 | dataset_folder = f"{root_folder}/FR_DATASETS/{folder_to_test}" 12 | 13 | 14 | model = "TODO" 15 | benchmark_folder = f"{root_folder}/FR_training/VALIDATION_DATASETS_from_webface" 16 | augment = False 17 | stopping_condition_epochs = 0 18 | verification_frequency = 1 19 | output_folder_name_start = f"REC_EXP_01_2025_LFW_Verification{verification_frequency}" 20 | 21 | EMBEDDING_TYPE = [ 22 | "." 23 | ] 24 | 25 | embedding_type = EMBEDDING_TYPE[0] 26 | 27 | width = 0 28 | depth = 0 29 | 30 | batch_size = 128 # 128 # 256 31 | workers = 8 # 32 32 | embedding_size = 512 33 | learning_rate = 0.1 34 | momentum = 0.9 35 | weight_decay = 5e-4 36 | 37 | global_step = 0 # to resume 38 | start_epoch = 0 39 | 40 | s = 64.0 41 | m = 0.35 42 | loss = "AdaFace" 43 | dropout_ratio = 0.4 44 | 45 | augmentation = "ra_4_16" # hf, ra_4_16 46 | 47 | print_freq = 1 #50 48 | val_path = "TODO"#"/data/Biometrics/database/faces_emore" # "/data/fboutros/faces_emore" 49 | val_targets = ["lfw", "agedb_30", "cfp_fp", "calfw", "cplfw"]#"lfw"]#, "agedb_30", "cfp_fp", "calfw", "cplfw"] 50 | 51 | auto_schedule = True 52 | num_epoch = 200 53 | schedule = [22, 30, 35] 54 | 55 | 56 | def lr_step_func(epoch): 57 | return ( 58 | ((epoch + 1) / (4 + 1)) ** 2 59 | if epoch < -1 60 | else 0.1 ** len([m for m in schedule if m - 1 <= epoch]) 61 | ) 62 | 63 | 64 | lr_func = lr_step_func 65 | 66 | -------------------------------------------------------------------------------- /FR_training/config/test_FR_config_Augmented.py: -------------------------------------------------------------------------------- 1 | architecture = "resnet50"#"resnet50" 2 | 3 | root_folder = "/shared/home/darian.tomasevic/ID-Booth/" 4 | #dataset_folder = "../Generated_Split_Images_112x112/" 5 | 6 | # dataset_folder = f"{root_folder}/FR_DATASETS/FR_DATASETS_SDXL/SDXL_DB_LoRA_Tufts_base_prompt_16_07_png" 7 | # dataset_folder = f"{root_folder}/FR_DATASETS/FR_DATASETS_SDXL/SDXL_DB_LoRA_Tufts_combined_16_07_png" 8 | # dataset_folder = f"{root_folder}/FR_DATASETS/FR_DATASETS_AUGMENTED_+21_samples/SDXL_DB_LoRA_Tufts_base_prompt_16_07_png" 9 | # dataset_folder = f"{root_folder}/FR_DATASETS/FR_DATASETS_AUGMENTED_+21_samples/SDXL_DB_LoRA_Tufts_combined_16_07_png" 10 | 11 | # dataset_folder = f"{root_folder}/FR_DATASETS/FR_DATASETS_AUGMENTED_+10_samples/SDXL_DB_LoRA_Tufts_combined_16_07_png" 12 | # dataset_folder = 
f"{root_folder}/FR_DATASETS/FR_DATASETS_AUGMENTED_+10_samples/SDXL_DB_LoRA_Tufts_base_prompt_16_07_png" 13 | 14 | 15 | # # TODO 16 | # dataset_folder = f"{root_folder}/FR_DATASETS/FR_DATASETS_SDXL/tufts_512_poses_1-7_all_imgs_jpg_per_ID" 17 | # models = ["images"] 18 | 19 | # dataset_folder = f"{root_folder}/FR_DATASETS_AUGMENTED_+21_samples/12-2024_SD21_LoRA4_alphaW0.1_Face_Poses_Environments" 20 | # folder_to_test = "12-2024_SD21_LoRA4_alphaWNone_FINAL_FacePortraitPhoto_Gender_Pose_BackgroundB" 21 | folder_to_test = "12-2024_SD21_LoRA4_alphaWNone_FacePortrait_Photo_Gender_Pose_BackgroundB_100samples" 22 | 23 | dataset_folder = f"{root_folder}/FR_DATASETS_AUGMENTED_samples/{folder_to_test}" 24 | 25 | models = ["no_new_Loss", "identity_loss_TimestepWeight", "triplet_prior_loss_TimestepWeight"] 26 | 27 | 28 | model = "TODO" 29 | benchmark_folder = f"{root_folder}/FR_training/VALIDATION_DATASETS_from_webface" 30 | augment = False 31 | stopping_condition_epochs = 0 32 | 33 | verification_frequency = 1 34 | output_folder_name_start = f"REC_EXP_01_2025_TFD+Synth_LFW_Verification{verification_frequency}"#_AllBench" 35 | 36 | EMBEDDING_TYPE = [ 37 | "." 38 | ] 39 | 40 | embedding_type = EMBEDDING_TYPE[0] 41 | 42 | width = 0 43 | depth = 0 44 | 45 | batch_size = 128 # 128 # 256 46 | workers = 8 # 32 47 | embedding_size = 512 48 | learning_rate = 0.1 49 | momentum = 0.9 50 | weight_decay = 5e-4 51 | 52 | global_step = 0 # to resume 53 | start_epoch = 0 54 | 55 | s = 64.0 56 | m = 0.35 57 | loss = "AdaFace"#"" 58 | dropout_ratio = 0.4 59 | 60 | augmentation = "ra_4_16" # hf, ra_4_16 61 | 62 | print_freq = 1 #50 63 | val_path = "TODO"#"/data/Biometrics/database/faces_emore" # "/data/fboutros/faces_emore" 64 | val_targets = ["lfw", "agedb_30", "cfp_fp", "calfw", "cplfw"]#"lfw"]#, "agedb_30", "cfp_fp", "calfw", "cplfw"] 65 | 66 | auto_schedule = True 67 | num_epoch = 200 68 | schedule = [22, 30, 35] 69 | 70 | 71 | def lr_step_func(epoch): 72 | return ( 73 | ((epoch + 1) / (4 + 1)) ** 2 74 | if epoch < -1 75 | else 0.1 ** len([m for m in schedule if m - 1 <= epoch]) 76 | ) 77 | 78 | 79 | lr_func = lr_step_func 80 | 81 | -------------------------------------------------------------------------------- /FR_training/moco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariant/ID-Booth/807d478b74833d69cf39399799fe50cf5284b314/FR_training/moco/__init__.py -------------------------------------------------------------------------------- /FR_training/moco/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from PIL import ImageFilter 3 | import random 4 | 5 | 6 | class TwoCropsTransform: 7 | """Augment query with base transform and key with key augmentation. 
8 | If key augmentation == None, key is augmented with base transform""" 9 | 10 | def __init__(self, base_transform, key_augmentation=None): 11 | self.base_transform = base_transform 12 | self.key_augment = ( 13 | base_transform if key_augmentation is None else key_augmentation 14 | ) 15 | 16 | def __call__(self, x): 17 | q = self.base_transform(x) 18 | k = self.key_augment(x) 19 | return [q, k] 20 | 21 | 22 | class GaussianBlur(object): 23 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 24 | 25 | def __init__(self, sigma=[0.1, 2.0]): 26 | self.sigma = sigma 27 | 28 | def __call__(self, x): 29 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 30 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 31 | return x 32 | -------------------------------------------------------------------------------- /FR_training/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariant/ID-Booth/807d478b74833d69cf39399799fe50cf5284b314/FR_training/utils/__init__.py -------------------------------------------------------------------------------- /FR_training/utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import torch 3 | import logging 4 | import moco.loader 5 | from utils.rand_augment import RandAugment 6 | from utils.FAA_policy import IResNet50CasiaPolicy, ReducedImageNetPolicy 7 | 8 | normalize_moco = transforms.Normalize( 9 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 10 | ) 11 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 12 | 13 | 14 | # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709 15 | aug_plus = [ 16 | transforms.RandomResizedCrop(112, scale=(0.2, 1.0)), 17 | transforms.RandomApply( 18 | [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8 # not strengthened 19 | ), 20 | transforms.RandomGrayscale(p=0.2), 21 | transforms.RandomApply([moco.loader.GaussianBlur([0.1, 2.0])], p=0.5), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.ToTensor(), 24 | normalize_moco, 25 | ] 26 | 27 | # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978 28 | aug_default = [ 29 | transforms.RandomResizedCrop(112, scale=(0.2, 1.0)), 30 | transforms.RandomGrayscale(p=0.2), 31 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | normalize_moco, 35 | ] 36 | 37 | to_tensor = [transforms.ToTensor(), normalize] 38 | 39 | only_normalize = [normalize] 40 | 41 | online_RA_2_16 = [ 42 | transforms.RandomHorizontalFlip(), 43 | RandAugment(num_ops=4, magnitude=16), 44 | normalize, 45 | ] 46 | 47 | aug_h_flip = [transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize] 48 | 49 | aug_rand_4_16 = [ 50 | transforms.RandomHorizontalFlip(), 51 | RandAugment(num_ops=4, magnitude=16), 52 | transforms.ToTensor(), 53 | normalize, 54 | ] 55 | 56 | aug_online_rand_4_16 = [ 57 | transforms.RandomHorizontalFlip(), 58 | RandAugment(num_ops=4, magnitude=16), 59 | ] 60 | 61 | aug_rand_2_9 = [ 62 | transforms.RandomHorizontalFlip(), 63 | RandAugment(num_ops=2, magnitude=9), 64 | transforms.ToTensor(), 65 | normalize, 66 | ] 67 | 68 | aug_rand_4_24 = [ 69 | transforms.RandomHorizontalFlip(), 70 | RandAugment(num_ops=4, magnitude=24), 71 | transforms.ToTensor(), 72 | normalize, 73 | ] 74 | 75 | aug_CASIA_FAA = [ 76 | transforms.RandomHorizontalFlip(), 77 | 
IResNet50CasiaPolicy(), 78 | transforms.ToTensor(), 79 | normalize, 80 | ] 81 | 82 | aug_ImgNet_FAA = [ 83 | transforms.RandomHorizontalFlip(), 84 | ReducedImageNetPolicy(), 85 | transforms.ToTensor(), 86 | normalize, 87 | ] 88 | 89 | 90 | def get_randaug(n, m): 91 | """return RandAugment transforms with 92 | n: number of operations 93 | m: magnitude 94 | """ 95 | return [ 96 | transforms.RandomHorizontalFlip(), 97 | RandAugment(num_ops=n, magnitude=m), 98 | transforms.ToTensor(), 99 | normalize, 100 | ] 101 | 102 | 103 | def select_x_operation(x): 104 | """enable only the x operation to RandAug 105 | x: string of the available augmentation 106 | """ 107 | return [ 108 | transforms.RandomHorizontalFlip(), 109 | RandAugment(num_ops=1, magnitude=9, available_aug=x), 110 | transforms.ToTensor(), 111 | normalize, 112 | ] 113 | 114 | 115 | def get_conventional_aug_policy(aug_type, operation="none", num_ops=0, mag=0): 116 | """get geometric and color augmentations 117 | args: 118 | aug_type: string defining augmentation type 119 | operation: RA augmentation operation under testing 120 | num_ops: number of sequential operations under testing 121 | mag: magnitude under testing 122 | return: 123 | augmentation policy 124 | """ 125 | aug = aug_type.lower() 126 | if aug == "gan_hf" or aug == "nogan_hf" or aug == "hf": 127 | augmentation = aug_h_flip 128 | elif aug == "gan_ra_4_16" or aug == "nogan_ra_4_16" or aug == "ra_4_16": 129 | augmentation = aug_rand_4_16 130 | elif aug == "moco": 131 | augmentation = aug_default 132 | elif aug == "faa_casia": 133 | augmentation = aug_CASIA_FAA 134 | elif aug == "faa_imgnet": 135 | augmentation = aug_ImgNet_FAA 136 | elif aug == "num_mag_exp": 137 | augmentation = get_randaug(num_ops, mag) 138 | elif aug == "aug_operation_exp": 139 | logging.info("Augmentation under testing: " + operation) 140 | augmentation = select_x_operation(operation) 141 | elif aug == "totensor": 142 | augmentation = to_tensor 143 | elif aug == "mag_totensor": 144 | augmentation = [transforms.ToTensor()] 145 | else: 146 | logging.error("Unknown augmentation method: {}".format(aug_type)) 147 | exit() 148 | return transforms.Compose(augmentation) 149 | 150 | 151 | """ default simCLR augmentation """ 152 | s = 1 153 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 154 | aug_simCLR = [ 155 | transforms.RandomResizedCrop( 156 | size=[112, 112], scale=(0.75, 1) 157 | ), # scale not given in default SimCLR augmentation 158 | transforms.RandomHorizontalFlip(), # with 0.5 probability 159 | transforms.RandomApply([color_jitter], p=0.8), 160 | transforms.RandomGrayscale(p=0.2), 161 | transforms.ToTensor(), 162 | normalize, # not given in default SimCLR augmentation 163 | ] 164 | 165 | 166 | class GaussianNoise(object): 167 | def __init__(self, mean=0.0, std=1.0): 168 | self.std = std 169 | self.mean = mean 170 | 171 | def __call__(self, tensor): 172 | return tensor + torch.randn(tensor.size()) * self.std + self.mean 173 | 174 | def __repr__(self): 175 | return self.__class__.__name__ + "(mean={0}, std={1})".format( 176 | self.mean, self.std 177 | ) 178 | -------------------------------------------------------------------------------- /FR_training/utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | # import logging 2 | from accelerate import logging 3 | import os 4 | import sys 5 | 6 | 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value 10 | """ 11 | 12 | def __init__(self): 13 
| self.val = None 14 | self.avg = None 15 | self.sum = None 16 | self.count = None 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | 32 | def init_logging(log_root, rank, models_root, logfile): 33 | 34 | if rank == 0: 35 | if (not logfile): 36 | logfile= "training.log" 37 | log_root.handlers = [] # This is the key thing for the question! 38 | log_root.setLevel(logging.INFO) 39 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s", datefmt='%Y-%m-%d %H:%M') 40 | handler_file = logging.FileHandler(os.path.join(models_root, logfile)) 41 | handler_stream = logging.StreamHandler(sys.stdout) 42 | handler_file.setFormatter(formatter) 43 | handler_stream.setFormatter(formatter) 44 | log_root.addHandler(handler_file) 45 | log_root.addHandler(handler_stream) 46 | log_root.info('rank_id: %d' % rank) 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ID-Booth: Identity-consistent Face Generation with Diffusion Models 2 | 3 |
4 | Darian Tomašević¹, Fadi Boutros², Chenhao Lin³, Naser Damer²,⁴, Vitomir Štruc⁵, Peter Peer¹ 5 |
6 |
7 | ¹ University of Ljubljana, Faculty of Computer and Information Science, Ljubljana, Slovenia
8 | ² Fraunhofer Institute for Computer Graphics Research IGD, Darmstadt, Germany
9 | ³ Xi’an Jiaotong University, School of Cyber Science and Engineering, Xi’an, China
10 | ⁴ Department of Computer Science, TU Darmstadt, Germany
11 | ⁵ University of Ljubljana, Faculty of Electrical Engineering, Ljubljana, Slovenia 12 | 13 |
25 | 26 | This is the official implementation of the [ID-Booth framework](https://arxiv.org/abs/2504.07392), which: 27 | 28 |  🔥 generates in-the-wild images of consenting identities captured in a constrained environment
29 |  🔥 uses a triplet identity loss to fine-tune Stable Diffusion for identity-consistent yet diverse image generation
30 |  🔥 can augment small-scale datasets to improve their suitability for training face recognition models
##
Installation
47 | 48 | ```bash 49 | conda create -n id-booth python=3.10 50 | conda activate id-booth 51 | pip install -r requirements.txt 52 | ``` 53 | 54 | ##
Download links for pretrained models
55 | 56 | To generate images of identities found in the paper, download their fine-tuned [ID-Booth LoRA weights](https://unilj-my.sharepoint.com/:u:/g/personal/darian_tomasevic_fri1_uni-lj_si/EXnMuZvIG49HuNkOomD_2K8BNvba8f8kJcWb6M7hGPCA0w). 57 | To create your own fine-tuned model with ID-Booth, download the pretrained [ArcFace recognition model](https://unilj-my.sharepoint.com/:u:/g/personal/darian_tomasevic_fri1_uni-lj_si/EfSmDfvsVlZEuOBqieDl4zEBJkTJ65aBnUtrC4q5nT2a-g?e=PBYj7o), and place the weights into the [ArcFace_files](https://github.com/dariant/ID-Booth/tree/main/ArcFace_files) directory. 58 | 59 | 60 | 61 | ##
Generating identity-specific images
62 | 63 | To generate images of a desired identity with [Stable Diffusion 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1), use the [diffusers](https://huggingface.co/docs/diffusers/index) library to load the corresponding LoRA weights, which were trained with the ID-Booth framework. The following example generates in-the-wild images of ID_1: 64 | 65 | ```python 66 | import torch 67 | from diffusers import StableDiffusionPipeline, DDPMScheduler 68 | 69 | base_model = "stabilityai/stable-diffusion-2-1-base" 70 | lora_checkpoint = "Trained_LoRA_Models/ID-Booth/ID_1/checkpoint-31-6400" # Download or train your own 71 | 72 | prompt = "face portrait photo of male sks person, city street background" 73 | negative_prompt = "cartoon, render, illustration, painting, drawing, black and white, bad body proportions, landscape" 74 | 75 | pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16).to("cuda:0") 76 | pipe.scheduler = DDPMScheduler.from_pretrained(base_model, subfolder="scheduler") 77 | pipe.load_lora_weights(lora_checkpoint) 78 | 79 | image = pipe(prompt=prompt, 80 | negative_prompt=negative_prompt, 81 | num_inference_steps=30, 82 | guidance_scale=5.0).images[0] 83 | 84 | image.save(f"ID_1_{prompt}.png") 85 | ``` 86 | Results in the paper can be reproduced with data generated by the [inference_ID-Booth.py](https://github.com/dariant/ID-Booth/blob/main/inference_ID-Booth.py) script. 87 | 88 | 89 | 90 | ##
ID-Booth fine-tuning on new identities
91 | 92 | To perform ID-Booth fine-tuning of [Stable Diffusion 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) on a new identity, you can follow the [train_ID-Booth.py](https://github.com/dariant/ID-Booth/blob/main/train_ID-Booth.py) script. The training dataset for a desired identity should include a handful of images along with ID embeddings extracted from these images with a pretrained ArcFace recognition model: 93 | ``` 94 | FACE_DATASET 95 | └─── ID_1 96 | │ └─── images 97 | │ | sample_0.png 98 | │ | sample_1.png 99 | │ | ... 100 | │ └─── ArcFace_embeds 101 | │ sample_0.pt 102 | │ sample_1.pt 103 | │ ... 104 | └─── ID_2 105 | └─── ... 106 | ``` 107 | The required ID embeddings can be extracted with the [extract_ArcFace_embeds.py](https://github.com/dariant/ID-Booth/blob/main/extract_ArcFace_embeds.py) script. 108 | Before running [train_ID-Booth.py](https://github.com/dariant/ID-Booth/blob/main/train_ID-Booth.py), specify the path to the source folder with identity images in [config_train_SD21.py](https://github.com/dariant/ID-Booth/blob/main/configs/config_train_SD21.py). 109 | 110 | 111 | ##
Evaluating the synthetic data
112 | 113 | For the evaluation of generated synthetic images, we rely on the following repositories: 114 | * [dgm-eval](https://github.com/layer6ai-labs/dgm-eval) to measure quality, fidelity and diversity, 115 | * [CR-FIQA](https://github.com/fdbtrs/CR-FIQA) to determine the face image quality, 116 | * [6DRepNet](https://github.com/thohemp/6DRepNet) to estimate the pitch, yaw and roll of head poses, 117 | * [PyEER](https://github.com/manuelaguadomtz/pyeer) to analyse identity consistency and separability. 118 | 119 | Notebooks and scripts for reproducing the results in the paper can be found in the [Evaluation](https://github.com/dariant/ID-Booth/tree/main/Evaluation) directory, while fine-tuned LoRA weights of different approaches can be downloaded [here](https://unilj-my.sharepoint.com/:f:/g/personal/darian_tomasevic_fri1_uni-lj_si/Esv2DimWDExAtkBjx-6SDoMBvP4TiD_N-gBaPhf10ekKrA?e=BCKzQb). 120 | 121 | To also evaluate the utility of the produced data, we also use it to train a deep face recognition model, following the [train_FR.py](https://github.com/dariant/ID-Booth/blob/main/FR_training/train_FR.py) script. The performance of these models is then evaluated on state-of-the-art verification benchmarks with [test_FR.py](https://github.com/dariant/ID-Booth/blob/main/FR_training/test_FR.py). 122 | 123 | 124 | ## Citation 125 | 126 | If you use the code or results from this repository, please cite the ID-Booth paper: 127 | 128 | ``` 129 | @article{tomasevic2025IDBooth, 130 | title={{ID-Booth}: Identity-consistent Face Generation with Diffusion Models}, 131 | author={Toma{\v{s}}evi{\'c}, Darian and Boutros, Fadi and Lin, Chenhao and Damer, Naser and {\v{S}}truc, Vitomir and Peer, Peter}, 132 | journal={arXiv preprint arXiv:2504.07392}, 133 | year={2025} 134 | } 135 | ``` 136 | 137 | ## Acknowledgements 138 | 139 | Supported in parts by the Slovenian Research and Innovation Agency (ARIS) through the Research Programmes P2-0250 (B) "Metrology and Biometric Systems" and P2--0214 (A) “Computer Vision”, the ARIS Project J2-50065 "DeepFake DAD" and the ARIS Young Researcher Programme. 
140 | 141 | ARIS_logo_eng_resized 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- /assets/ARIS_logo_eng_resized.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariant/ID-Booth/807d478b74833d69cf39399799fe50cf5284b314/assets/ARIS_logo_eng_resized.jpg -------------------------------------------------------------------------------- /assets/preview_framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariant/ID-Booth/807d478b74833d69cf39399799fe50cf5284b314/assets/preview_framework.jpg -------------------------------------------------------------------------------- /assets/preview_samples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariant/ID-Booth/807d478b74833d69cf39399799fe50cf5284b314/assets/preview_samples.jpg -------------------------------------------------------------------------------- /configs/config_train_SD21.py: -------------------------------------------------------------------------------- 1 | pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base" 2 | 3 | mixed_precision = "fp16" 4 | logging_dir = "Logs" 5 | report_to = "tensorboard" 6 | 7 | revision = None 8 | variant = None 9 | 10 | resume_from_checkpoint = None 11 | 12 | # Source images 13 | # source_folder = "/shared/home/darian.tomasevic/ID-Booth/FACE_DATASETS/tufts_512_poses_1-7_all_imgs_jpg_per_ID/images" 14 | source_folder = "../tufts_512_poses_1-7_all_imgs_jpg_per_ID/images" 15 | resolution = 512 16 | instance_prompt = "photo of sks person" # "face portrait photo of fid person" 17 | 18 | # Prior preservation images 19 | with_prior_preservation = True 20 | class_prompt = "photo of a person" #"face portrait photo of a person" # 21 | # class_data_dir = "/shared/home/darian.tomasevic/ID-Booth/CLASS_IMAGES/SD21_Class_imgs_200/images" 22 | class_data_dir = "../CLASS_IMAGES/SD21_Class_imgs_200/images" 23 | num_class_images = 200 24 | prior_loss_weight = 1.0 25 | 26 | validation_prompt = "photo of sks person with blue hair" 27 | validation_negative_prompt = "" #"cartoon, cgi, render, illustration, painting, drawing, black and white, bad body proportions, landscape" 28 | validation_prompt_path = "FACE_DATASETS/Samples_for_validation/validation_prompt.pt" 29 | 30 | num_validation_images = 4 31 | 32 | dataloader_num_workers = 0 33 | use_8bit_adam = False 34 | enable_xformers_memory_efficient_attention = False 35 | allow_tf32 = True # "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training." 36 | prior_generation_precision = "fp16" # Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Default to fp16 if a GPU is available, otherwise fp32. 37 | 38 | local_rank = 1 39 | tokenizer_max_length = None 40 | tokenizer_name = None 41 | text_encoder_use_attention_mask = False 42 | class_labels_conditioning = None 43 | max_train_steps = None 44 | 45 | seed = 0 46 | 47 | # Training parameters to experiment with 48 | lora_rank = 4 49 | train_batch_size = 1 50 | gradient_accumulation_steps = 1 51 | gradient_checkpointing = False 52 | 53 | num_train_epochs = 32 54 | validation_epochs = 8 55 | checkpointing_epochs = 8 56 | checkpoints_total_limit = None 57 | 58 | learning_rate = 1e-4 # TODO 59 | lr_scheduler = "cosine" 60 | lr_warmup_steps = 0 # TODO 500? 
61 | 62 | adam_beta1 = 0.9 63 | adam_beta2 = 0.999 64 | adam_weight_decay = 1e-2 65 | adam_epsilon = 1e-08 66 | max_grad_norm = 1.0 67 | 68 | scale_lr = False 69 | lr_num_cycles = 1 # "Number of hard resets of the lr in cosine_with_restarts scheduler." 70 | lr_power = 1.0 # Power factor of the polynomial scheduler 71 | 72 | train_text_encoder = False 73 | pre_compute_text_embeddings = False 74 | 75 | 76 | losses_to_test = ["", "identity", "triplet_prior"] 77 | timestep_loss_weighting = True 78 | alpha_id_loss_weighting = 0.1 79 | 80 | sample_batch_size = 1 # 4 81 | 82 | output_folder = f"Trained_LoRA_Models/" #_NoPriorAblation" 83 | show_tqdm = True -------------------------------------------------------------------------------- /configs/config_train_SD21_FRIDA.py: -------------------------------------------------------------------------------- 1 | pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base" 2 | 3 | mixed_precision = "fp16" 4 | logging_dir = "Logs" 5 | report_to = "tensorboard" 6 | 7 | revision = None 8 | variant = None 9 | 10 | resume_from_checkpoint = None 11 | 12 | # Source images 13 | source_folder = "/shared/home/darian.tomasevic/ID-Booth/FACE_DATASETS/tufts_512_poses_1-7_all_imgs_jpg_per_ID/images" 14 | resolution = 512 15 | instance_prompt = "photo of sks person" # "face portrait photo of fid person" 16 | 17 | # Prior preservation images 18 | with_prior_preservation = True 19 | class_prompt = "photo of a person" #"face portrait photo of a person" # 20 | class_data_dir ="/shared/home/darian.tomasevic/ID-Booth/CLASS_IMAGES/SD21_Class_imgs_200/images" # "./CLASS_IMAGES/epiCRealism_SD15/images" 21 | num_class_images = 200 22 | prior_loss_weight = 1.0 23 | 24 | validation_prompt = "photo of sks person with blue hair"#"photo of [ID] person, portrait"# photo of [ID] person with blue hair" #"face portrait photo of fid person with blue hair" # 25 | validation_negative_prompt = "" #"cartoon, cgi, render, illustration, painting, drawing, black and white, bad body proportions, landscape" 26 | validation_prompt_path = "FACE_DATASETS/Samples_for_validation/validation_prompt.pt" 27 | 28 | num_validation_images = 4 29 | 30 | dataloader_num_workers = 0 31 | use_8bit_adam = False 32 | enable_xformers_memory_efficient_attention = False 33 | allow_tf32 = True # "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training." 34 | prior_generation_precision = "fp16" # Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Default to fp16 if a GPU is available, otherwise fp32. 35 | 36 | local_rank = 1 37 | tokenizer_max_length = None 38 | tokenizer_name = None 39 | text_encoder_use_attention_mask = False 40 | class_labels_conditioning = None 41 | max_train_steps = None 42 | 43 | seed = 0 44 | 45 | # Training parameters to experiment with 46 | lora_rank = 4 47 | train_batch_size = 1 48 | gradient_accumulation_steps = 1 49 | gradient_checkpointing = False 50 | 51 | num_train_epochs = 32 52 | validation_epochs = 8 53 | checkpointing_epochs = 8 54 | checkpoints_total_limit = None 55 | 56 | learning_rate = 1e-4 # TODO 57 | lr_scheduler = "cosine" 58 | lr_warmup_steps = 0 # TODO 500? 59 | 60 | adam_beta1 = 0.9 61 | adam_beta2 = 0.999 62 | adam_weight_decay = 1e-2 63 | adam_epsilon = 1e-08 64 | max_grad_norm = 1.0 65 | 66 | scale_lr = False 67 | lr_num_cycles = 1 # "Number of hard resets of the lr in cosine_with_restarts scheduler." 
68 | lr_power = 1.0 # Power factor of the polynomial scheduler 69 | 70 | train_text_encoder = False 71 | pre_compute_text_embeddings = False 72 | 73 | 74 | losses_to_test = [""]#, "identity", "triplet_prior"] # TODO #[ "", "identity", "triplet_prior"] 75 | timestep_loss_weighting = True 76 | alpha_id_loss_weighting = 0.1 77 | 78 | sample_batch_size = 1 # 4 79 | 80 | output_folder = f"OUTPUT_MODELS/12-2024_SD21_LoRA{lora_rank}" #_NoPriorAblation" 81 | show_tqdm = True -------------------------------------------------------------------------------- /extract_ArcFace_embeds.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import json 5 | import os 6 | 7 | from Arcface_files.ArcFace_functions import prepare_locked_ArcFace_model 8 | from facenet_pytorch import MTCNN 9 | from PIL import Image 10 | from torchvision import transforms 11 | from tqdm import tqdm 12 | 13 | def prepare_for_arcface_model_torch(img): 14 | img = img.permute(2,0,1) 15 | img = transforms.functional.resize(img, (112, 112)) 16 | img = img.float() 17 | img = ((img / 255) - 0.5) / 0.5 18 | img = img[None, :, :, :] 19 | return img 20 | 21 | 22 | origin_path = "FACE_DATASET" 23 | device = "cuda:0" 24 | 25 | arcface_model = prepare_locked_ArcFace_model() 26 | arcface_model.to(device=device) 27 | 28 | mtcnn_model = MTCNN(image_size=112,device=device, margin=0) 29 | 30 | files_without_faces = dict() 31 | files_without_faces["files_without_faces"] = [] 32 | 33 | folders = os.listdir(os.path.join(origin_path, "images")) 34 | 35 | for folder in tqdm(folders): 36 | folder_path = os.path.join(origin_path, "images", folder) 37 | 38 | output_path = folder_path.replace("images", "ArcFace_embeds") 39 | os.makedirs(output_path, exist_ok=True) 40 | img_files = os.listdir(folder_path) 41 | 42 | images = [] 43 | for img_name in (img_files): 44 | img_path = os.path.join(folder_path,img_name) 45 | image = torch.from_numpy(np.array(Image.open(img_path))).to(device) 46 | images.append(image) 47 | 48 | images = torch.stack(images, dim=0) 49 | 50 | # image must be (x, x, 3) .. 
RGB 51 | # detect face bounding box 52 | bboxs, probs = mtcnn_model.detect(images, landmarks=False) 53 | 54 | images_cropped = [] 55 | for image, bbox in zip(images, bboxs): 56 | if bbox is None: 57 | files_without_faces["files_without_faces"].append(img_path) 58 | continue 59 | 60 | # crop to face 61 | bbox = bbox[0].astype(int) 62 | initial_size = image.shape[0] 63 | img_cropped = image[max(0,bbox[1]): min(bbox[3], initial_size ), 64 | max(0, bbox[0]): min(bbox[2], initial_size)] 65 | 66 | # transform to (1, 3, x, x) 67 | img_cropped = prepare_for_arcface_model_torch(img_cropped) 68 | images_cropped.append(img_cropped) 69 | 70 | images_cropped = torch.stack(images_cropped, dim=0) 71 | 72 | face_embeds = arcface_model(img_cropped) 73 | 74 | embed_file_name = folder + ".pt" 75 | torch.save(face_embeds, os.path.join(output_path, embed_file_name)) 76 | 77 | json_pth = f"{origin_path}/files_without_faces.json" 78 | 79 | print(json_pth) 80 | 81 | with open(json_pth, 'w') as fp: 82 | json.dump(files_without_faces, fp) -------------------------------------------------------------------------------- /inference_ID-Booth.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | import torch 3 | import os 4 | from torchvision.utils import save_image 5 | from diffusers import DPMSolverMultistepScheduler 6 | from diffusers import DDPMScheduler 7 | import random 8 | from accelerate.utils import set_seed 9 | from tqdm import tqdm 10 | import re 11 | import json 12 | #from compel import Compel 13 | from diffusers import AutoPipelineForText2Image 14 | from itertools import product 15 | from utils.sorting_utils import natural_keys 16 | 17 | backgrounds_list = ["","forest", "city street", "beach", "office", "bus", "laboratory", "factory", "construction site", "hospital", "night club"] 18 | backgrounds_list = [f"{b} background" if b != "" else "" for b in backgrounds_list]# 19 | 20 | age_phases = ["", "young", "middle-aged", "old"] 21 | 22 | num_samples_per_prompt = 1 23 | num_prompts = 21 # 21 #21 #50 #21 #len(additions_list) 24 | 25 | add_gender = True 26 | add_pose = True 27 | add_age = False # should be first in combination 28 | add_background = True 29 | 30 | do_not_use_negative_prompt = False 31 | use_non_finetuned = False 32 | 33 | if add_age and add_background: 34 | all_prompt_combinations = list(product(age_phases, backgrounds_list)) 35 | elif add_background: 36 | if num_prompts == 100: 37 | all_prompt_combinations = list(backgrounds_list[1:] * 10) 38 | else: 39 | all_prompt_combinations = list([""] + backgrounds_list[1:] * 2) 40 | 41 | elif add_age: 42 | all_prompt_combinations = list(age_phases * 6) 43 | else: 44 | all_prompt_combinations = list([""] * num_prompts) 45 | print(all_prompt_combinations) 46 | 47 | device = "cuda:0" 48 | seed = 0 49 | guidance_scale = 5.0 50 | num_inference_steps = 30 51 | 52 | folder_of_models = f"Trained_LoRA_Models" 53 | models_to_test = ["DreamBooth", "PortraitBooth", "ID-Booth"] 54 | checkpoint = "checkpoint-31-6400" 55 | 56 | folder_output = f"Generated_Samples/FacePortrait_Photo_21" # _NonFinetuned 57 | if add_gender: folder_output += "_Gender" 58 | if add_pose: folder_output+= "_Pose" 59 | if add_age: folder_output+= "_Age" 60 | if add_background: folder_output += "_Background" 61 | if do_not_use_negative_prompt: folder_output += "_NoNegPrompt" 62 | 63 | architectures = ["stabilityai/stable-diffusion-2-1-base"] 64 | model_architecture = architectures[0] 65 | arch = 
model_architecture.split("/")[1] 66 | 67 | set_seed(seed) 68 | 69 | width, height = 512, 512 70 | 71 | ids = os.listdir(os.path.join(folder_of_models, models_to_test[0])) 72 | ids = [i for i in ids if ".json" not in i] 73 | ids.sort(key=natural_keys) 74 | 75 | print(ids) 76 | if add_gender: 77 | with open("tufts_gender_dict.json", "r") as fp: 78 | gender_dict = json.load(fp) 79 | 80 | 81 | negative_prompt = "cartoon, cgi, render, illustration, painting, drawing, black and white, bad body proportions, landscape" 82 | original_prompt = f"face portrait photo of sks person" 83 | 84 | prompt = "" 85 | 86 | for id_number, which_id in enumerate(ids): 87 | print("\n", which_id) 88 | 89 | if add_gender: 90 | gender = gender_dict[which_id] 91 | if gender == "M": gender = "male" 92 | elif gender == "F": gender = "female" 93 | 94 | all_prompts_for_id = random.sample(all_prompt_combinations, num_prompts) 95 | 96 | comparison_image_list = [] 97 | for model_name in models_to_test: 98 | full_model_path = os.path.join(folder_of_models, model_name, which_id, checkpoint) 99 | 100 | output_dir = os.path.join(folder_output, model_name)#"GENERATED_SAMPLES/FINAL_No_ID_loss_TEST" 101 | print("Load:", full_model_path) 102 | 103 | pipe = StableDiffusionPipeline.from_pretrained(model_architecture, torch_dtype=torch.float16).to(device) 104 | pipe.scheduler = DDPMScheduler.from_pretrained(model_architecture, subfolder="scheduler") 105 | 106 | if not use_non_finetuned: 107 | pipe.load_lora_weights(full_model_path) 108 | pipe.set_progress_bar_config(disable=True) 109 | 110 | os.makedirs(output_dir, exist_ok=True) 111 | generator = torch.Generator(device=device).manual_seed(id_number) 112 | 113 | for i, num_prompt in enumerate(tqdm(range(num_prompts))): 114 | prompt_additions = all_prompts_for_id[i] 115 | prompt = original_prompt 116 | if add_age: 117 | age_insert = "" 118 | if isinstance(prompt_additions, str): age_insert = prompt_additions 119 | else: 120 | age_insert = prompt_additions[0] 121 | prompt_additions = prompt_additions[1:] 122 | if age_insert != "": prompt = prompt.replace(" sks person", f" {age_insert} sks person") 123 | 124 | 125 | if add_gender: prompt = prompt.replace(" sks person", f" {gender} sks person") 126 | if add_pose and random.choice([True, False]): prompt = prompt.replace("portrait", "side-portrait") 127 | 128 | if add_background: 129 | if isinstance(prompt_additions, str): 130 | prompt += f", {prompt_additions}" 131 | else: 132 | for addition in prompt_additions: 133 | if addition != "": 134 | prompt += f", {addition}" 135 | 136 | # generate samples 137 | for j in range(num_samples_per_prompt): 138 | output = pipe(prompt=prompt, negative_prompt=negative_prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, width=width, height=height) 139 | output = torch.Tensor(output.images) 140 | comparison_image_list.append(output) 141 | output = torch.permute(output, (0, 3, 1, 2)) 142 | path_to_output = f"{output_dir}/{which_id}_{checkpoint}_{arch}" 143 | os.makedirs(path_to_output, exist_ok=True) 144 | save_image(output, fp=f"{path_to_output}/{i}_{j}_{prompt}.png") 145 | 146 | images = torch.cat(comparison_image_list) 147 | images = torch.permute(images, (0, 3, 1, 2)) # permute dimensions to be (x, 3, 512, 512) 148 | 149 | print("Saving comparison image") 150 | comparison_folder = "Comparison" 151 | 152 | comparison_folder = os.path.join(folder_output, comparison_folder) 153 | os.makedirs(comparison_folder, exist_ok=True) 154 | save_path = 
f"{comparison_folder}/{which_id}_{checkpoint}_{arch}_{guidance_scale}.jpg" 155 | print(save_path) 156 | save_image(images, fp=save_path, nrow=num_prompts*num_samples_per_prompt, padding=0) 157 | 158 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<1.26.0 2 | torch==2.1.2 3 | torchvision==0.16.2 4 | diffusers==0.32.2 5 | transformers==4.34.1 6 | peft 7 | accelerate 8 | facenet-pytorch -------------------------------------------------------------------------------- /utils/augmentation_with_synthetic_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from tqdm import tqdm 4 | 5 | root_folder = "../FR_DATASETS" 6 | # With additional 10 or 21 synthetic images 7 | how_many_samples = 100 8 | target_root = f"../FR_DATASETS_AUGMENTED_+{how_many_samples}_samples" 9 | 10 | starting_dataset = f"{root_folder}/tufts_512_poses_1-7_all_imgs_jpg_per_ID/images" 11 | # which_folders = ["12-2024_SD21_LoRA4_alphaWNone_FINAL_FacePortraitPhoto_Gender_Pose_BackgroundB"] 12 | which_folders = ["12-2024_SD21_LoRA4_alphaWNone_FacePortrait_Photo_Gender_Pose_BackgroundB_100samples"] 13 | 14 | for which_folder in which_folders: 15 | in_folder = os.path.join(root_folder, which_folder) 16 | model_folders = os.listdir(in_folder) 17 | 18 | for model_fold in tqdm(model_folders): 19 | 20 | print(model_fold) 21 | if "embeddings" in model_fold: continue 22 | if ".json" in model_fold: continue 23 | 24 | in_model = os.path.join(root_folder, which_folder, model_fold) 25 | 26 | # go across generated samples, copy them 27 | for img_name in os.listdir(in_model): 28 | # only copy how many samples 29 | sample_number = int(img_name.split("_")[1]) 30 | if sample_number >= how_many_samples: 31 | continue 32 | 33 | src_img_path = os.path.join(in_model, img_name) 34 | 35 | tar_fold = os.path.join(target_root, which_folder, model_fold) 36 | 37 | 38 | os.makedirs(tar_fold, exist_ok=True) 39 | 40 | tar_img_path = os.path.join(tar_fold, img_name) 41 | 42 | shutil.copyfile(src_img_path, tar_img_path) 43 | 44 | # go across real samples, copy them to the same folder as well 45 | for img_name in os.listdir(starting_dataset): 46 | 47 | src_img_path = os.path.join(starting_dataset, img_name) 48 | tar_fold = os.path.join(target_root, which_folder, model_fold) 49 | 50 | tar_img_path = os.path.join(tar_fold, img_name.replace(".jpg", ".png")) 51 | 52 | shutil.copyfile(src_img_path, tar_img_path) 53 | 54 | -------------------------------------------------------------------------------- /utils/sorting_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | ############################################## 4 | def atoi(text): 5 | return int(text) if text.isdigit() else text 6 | 7 | def natural_keys(text): 8 | ''' 9 | alist.sort(key=natural_keys) sorts in human order 10 | http://nedbatchelder.com/blog/200712/human_sorting.html 11 | (See Toothy's implementation in the comments) 12 | ''' 13 | return [ atoi(c) for c in re.split(r'(\d+)', text) ] 14 | ############################################## --------------------------------------------------------------------------------