├── .gitattributes
├── .gitignore
├── MainModel.py
├── README.md
├── evaluation_protocols
│   └── classifier_metrics
│       └── metrics.py
├── images
│   ├── paper_training_algorithm.png
│   └── vggface_tsne_base_ft_models_8.png
├── main.py
├── requirements.txt
├── trainer.py
├── utils.py
├── vggface2_custom_dataset.py
└── vggface2_data_manager.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | senet50_ft_pytorch.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | *.pyc 3 | *pycache* 4 | *.npy 5 | *.pth 6 | *.ipynb 7 | *.sh 8 | *.idea 9 | *.hdf5 10 | *.sh 11 | *.jpg 12 | *.pdf 13 | *.save 14 | *.png 15 | *.csv 16 | *.prototxt 17 | *.tar 18 | *.pt 19 | *__init__* 20 | *.pyc 21 | *_old* 22 | experiments_results/* -------------------------------------------------------------------------------- /MainModel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __weights_dict = dict() 7 | 8 | def load_weights(weight_file): 9 | if weight_file is None: 10 | return None 11 | 12 | try: 13 | weights_dict = np.load(weight_file, allow_pickle=True).item()  # allow_pickle=True is required to load a pickled dict with numpy >= 1.16.3 14 | except Exception: 15 | weights_dict = np.load(weight_file, encoding='bytes', allow_pickle=True).item() 16 | 17 | return weights_dict 18 | 19 | class KitModel(nn.Module): 20 | 21 | 22 | def __init__(self, weight_file): 23 | super(KitModel, self).__init__() 24 | global __weights_dict 25 | __weights_dict = load_weights(weight_file) 26 | 27 | self.conv1_7x7_s2 = self.__conv(2, name='conv1/7x7_s2', in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=False) 28 | self.conv1_7x7_s2_bn = self.__batch_normalization(2, 'conv1/7x7_s2/bn', num_features=64, eps=9.99999974738e-06, momentum=0.0) 29 | self.conv2_1_1x1_reduce = self.__conv(2, name='conv2_1_1x1_reduce', in_channels=64, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 30 | self.conv2_1_1x1_proj = self.__conv(2, name='conv2_1_1x1_proj', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 31 | self.conv2_1_1x1_reduce_bn = self.__batch_normalization(2, 'conv2_1_1x1_reduce/bn', num_features=64, eps=9.99999974738e-06, momentum=0.0) 32 | self.conv2_1_1x1_proj_bn = self.__batch_normalization(2, 'conv2_1_1x1_proj/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 33 | self.conv2_1_3x3 = self.__conv(2, name='conv2_1_3x3', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 34 | self.conv2_1_3x3_bn = self.__batch_normalization(2, 'conv2_1_3x3/bn', num_features=64, eps=9.99999974738e-06, momentum=0.0) 35 | self.conv2_1_1x1_increase = self.__conv(2, name='conv2_1_1x1_increase', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 36 | self.conv2_1_1x1_increase_bn = self.__batch_normalization(2, 'conv2_1_1x1_increase/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 37 | self.conv2_1_1x1_down = self.__conv(2, name='conv2_1_1x1_down', in_channels=256, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 38 | self.conv2_1_1x1_up = self.__conv(2, name='conv2_1_1x1_up', in_channels=16, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 39 | self.conv2_2_1x1_reduce = self.__conv(2,
name='conv2_2_1x1_reduce', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 40 | self.conv2_2_1x1_reduce_bn = self.__batch_normalization(2, 'conv2_2_1x1_reduce/bn', num_features=64, eps=9.99999974738e-06, momentum=0.0) 41 | self.conv2_2_3x3 = self.__conv(2, name='conv2_2_3x3', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 42 | self.conv2_2_3x3_bn = self.__batch_normalization(2, 'conv2_2_3x3/bn', num_features=64, eps=9.99999974738e-06, momentum=0.0) 43 | self.conv2_2_1x1_increase = self.__conv(2, name='conv2_2_1x1_increase', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 44 | self.conv2_2_1x1_increase_bn = self.__batch_normalization(2, 'conv2_2_1x1_increase/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 45 | self.conv2_2_1x1_down = self.__conv(2, name='conv2_2_1x1_down', in_channels=256, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 46 | self.conv2_2_1x1_up = self.__conv(2, name='conv2_2_1x1_up', in_channels=16, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 47 | self.conv2_3_1x1_reduce = self.__conv(2, name='conv2_3_1x1_reduce', in_channels=256, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 48 | self.conv2_3_1x1_reduce_bn = self.__batch_normalization(2, 'conv2_3_1x1_reduce/bn', num_features=64, eps=9.99999974738e-06, momentum=0.0) 49 | self.conv2_3_3x3 = self.__conv(2, name='conv2_3_3x3', in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 50 | self.conv2_3_3x3_bn = self.__batch_normalization(2, 'conv2_3_3x3/bn', num_features=64, eps=9.99999974738e-06, momentum=0.0) 51 | self.conv2_3_1x1_increase = self.__conv(2, name='conv2_3_1x1_increase', in_channels=64, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 52 | self.conv2_3_1x1_increase_bn = self.__batch_normalization(2, 'conv2_3_1x1_increase/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 53 | self.conv2_3_1x1_down = self.__conv(2, name='conv2_3_1x1_down', in_channels=256, out_channels=16, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 54 | self.conv2_3_1x1_up = self.__conv(2, name='conv2_3_1x1_up', in_channels=16, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 55 | self.conv3_1_1x1_proj = self.__conv(2, name='conv3_1_1x1_proj', in_channels=256, out_channels=512, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False) 56 | self.conv3_1_1x1_reduce = self.__conv(2, name='conv3_1_1x1_reduce', in_channels=256, out_channels=128, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False) 57 | self.conv3_1_1x1_proj_bn = self.__batch_normalization(2, 'conv3_1_1x1_proj/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 58 | self.conv3_1_1x1_reduce_bn = self.__batch_normalization(2, 'conv3_1_1x1_reduce/bn', num_features=128, eps=9.99999974738e-06, momentum=0.0) 59 | self.conv3_1_3x3 = self.__conv(2, name='conv3_1_3x3', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 60 | self.conv3_1_3x3_bn = self.__batch_normalization(2, 'conv3_1_3x3/bn', num_features=128, eps=9.99999974738e-06, momentum=0.0) 61 | self.conv3_1_1x1_increase = self.__conv(2, name='conv3_1_1x1_increase', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 62 | self.conv3_1_1x1_increase_bn = self.__batch_normalization(2, 'conv3_1_1x1_increase/bn', num_features=512, 
eps=9.99999974738e-06, momentum=0.0) 63 | self.conv3_1_1x1_down = self.__conv(2, name='conv3_1_1x1_down', in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 64 | self.conv3_1_1x1_up = self.__conv(2, name='conv3_1_1x1_up', in_channels=32, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 65 | self.conv3_2_1x1_reduce = self.__conv(2, name='conv3_2_1x1_reduce', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 66 | self.conv3_2_1x1_reduce_bn = self.__batch_normalization(2, 'conv3_2_1x1_reduce/bn', num_features=128, eps=9.99999974738e-06, momentum=0.0) 67 | self.conv3_2_3x3 = self.__conv(2, name='conv3_2_3x3', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 68 | self.conv3_2_3x3_bn = self.__batch_normalization(2, 'conv3_2_3x3/bn', num_features=128, eps=9.99999974738e-06, momentum=0.0) 69 | self.conv3_2_1x1_increase = self.__conv(2, name='conv3_2_1x1_increase', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 70 | self.conv3_2_1x1_increase_bn = self.__batch_normalization(2, 'conv3_2_1x1_increase/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 71 | self.conv3_2_1x1_down = self.__conv(2, name='conv3_2_1x1_down', in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 72 | self.conv3_2_1x1_up = self.__conv(2, name='conv3_2_1x1_up', in_channels=32, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 73 | self.conv3_3_1x1_reduce = self.__conv(2, name='conv3_3_1x1_reduce', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 74 | self.conv3_3_1x1_reduce_bn = self.__batch_normalization(2, 'conv3_3_1x1_reduce/bn', num_features=128, eps=9.99999974738e-06, momentum=0.0) 75 | self.conv3_3_3x3 = self.__conv(2, name='conv3_3_3x3', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 76 | self.conv3_3_3x3_bn = self.__batch_normalization(2, 'conv3_3_3x3/bn', num_features=128, eps=9.99999974738e-06, momentum=0.0) 77 | self.conv3_3_1x1_increase = self.__conv(2, name='conv3_3_1x1_increase', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 78 | self.conv3_3_1x1_increase_bn = self.__batch_normalization(2, 'conv3_3_1x1_increase/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 79 | self.conv3_3_1x1_down = self.__conv(2, name='conv3_3_1x1_down', in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 80 | self.conv3_3_1x1_up = self.__conv(2, name='conv3_3_1x1_up', in_channels=32, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 81 | self.conv3_4_1x1_reduce = self.__conv(2, name='conv3_4_1x1_reduce', in_channels=512, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 82 | self.conv3_4_1x1_reduce_bn = self.__batch_normalization(2, 'conv3_4_1x1_reduce/bn', num_features=128, eps=9.99999974738e-06, momentum=0.0) 83 | self.conv3_4_3x3 = self.__conv(2, name='conv3_4_3x3', in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 84 | self.conv3_4_3x3_bn = self.__batch_normalization(2, 'conv3_4_3x3/bn', num_features=128, eps=9.99999974738e-06, momentum=0.0) 85 | self.conv3_4_1x1_increase = self.__conv(2, name='conv3_4_1x1_increase', in_channels=128, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 86 
| self.conv3_4_1x1_increase_bn = self.__batch_normalization(2, 'conv3_4_1x1_increase/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 87 | self.conv3_4_1x1_down = self.__conv(2, name='conv3_4_1x1_down', in_channels=512, out_channels=32, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 88 | self.conv3_4_1x1_up = self.__conv(2, name='conv3_4_1x1_up', in_channels=32, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 89 | self.conv4_1_1x1_proj = self.__conv(2, name='conv4_1_1x1_proj', in_channels=512, out_channels=1024, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False) 90 | self.conv4_1_1x1_reduce = self.__conv(2, name='conv4_1_1x1_reduce', in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False) 91 | self.conv4_1_1x1_proj_bn = self.__batch_normalization(2, 'conv4_1_1x1_proj/bn', num_features=1024, eps=9.99999974738e-06, momentum=0.0) 92 | self.conv4_1_1x1_reduce_bn = self.__batch_normalization(2, 'conv4_1_1x1_reduce/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 93 | self.conv4_1_3x3 = self.__conv(2, name='conv4_1_3x3', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 94 | self.conv4_1_3x3_bn = self.__batch_normalization(2, 'conv4_1_3x3/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 95 | self.conv4_1_1x1_increase = self.__conv(2, name='conv4_1_1x1_increase', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 96 | self.conv4_1_1x1_increase_bn = self.__batch_normalization(2, 'conv4_1_1x1_increase/bn', num_features=1024, eps=9.99999974738e-06, momentum=0.0) 97 | self.conv4_1_1x1_down = self.__conv(2, name='conv4_1_1x1_down', in_channels=1024, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 98 | self.conv4_1_1x1_up = self.__conv(2, name='conv4_1_1x1_up', in_channels=64, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 99 | self.conv4_2_1x1_reduce = self.__conv(2, name='conv4_2_1x1_reduce', in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 100 | self.conv4_2_1x1_reduce_bn = self.__batch_normalization(2, 'conv4_2_1x1_reduce/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 101 | self.conv4_2_3x3 = self.__conv(2, name='conv4_2_3x3', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 102 | self.conv4_2_3x3_bn = self.__batch_normalization(2, 'conv4_2_3x3/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 103 | self.conv4_2_1x1_increase = self.__conv(2, name='conv4_2_1x1_increase', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 104 | self.conv4_2_1x1_increase_bn = self.__batch_normalization(2, 'conv4_2_1x1_increase/bn', num_features=1024, eps=9.99999974738e-06, momentum=0.0) 105 | self.conv4_2_1x1_down = self.__conv(2, name='conv4_2_1x1_down', in_channels=1024, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 106 | self.conv4_2_1x1_up = self.__conv(2, name='conv4_2_1x1_up', in_channels=64, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 107 | self.conv4_3_1x1_reduce = self.__conv(2, name='conv4_3_1x1_reduce', in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 108 | self.conv4_3_1x1_reduce_bn = self.__batch_normalization(2, 'conv4_3_1x1_reduce/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 109 | self.conv4_3_3x3 
= self.__conv(2, name='conv4_3_3x3', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 110 | self.conv4_3_3x3_bn = self.__batch_normalization(2, 'conv4_3_3x3/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 111 | self.conv4_3_1x1_increase = self.__conv(2, name='conv4_3_1x1_increase', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 112 | self.conv4_3_1x1_increase_bn = self.__batch_normalization(2, 'conv4_3_1x1_increase/bn', num_features=1024, eps=9.99999974738e-06, momentum=0.0) 113 | self.conv4_3_1x1_down = self.__conv(2, name='conv4_3_1x1_down', in_channels=1024, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 114 | self.conv4_3_1x1_up = self.__conv(2, name='conv4_3_1x1_up', in_channels=64, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 115 | self.conv4_4_1x1_reduce = self.__conv(2, name='conv4_4_1x1_reduce', in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 116 | self.conv4_4_1x1_reduce_bn = self.__batch_normalization(2, 'conv4_4_1x1_reduce/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 117 | self.conv4_4_3x3 = self.__conv(2, name='conv4_4_3x3', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 118 | self.conv4_4_3x3_bn = self.__batch_normalization(2, 'conv4_4_3x3/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 119 | self.conv4_4_1x1_increase = self.__conv(2, name='conv4_4_1x1_increase', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 120 | self.conv4_4_1x1_increase_bn = self.__batch_normalization(2, 'conv4_4_1x1_increase/bn', num_features=1024, eps=9.99999974738e-06, momentum=0.0) 121 | self.conv4_4_1x1_down = self.__conv(2, name='conv4_4_1x1_down', in_channels=1024, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 122 | self.conv4_4_1x1_up = self.__conv(2, name='conv4_4_1x1_up', in_channels=64, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 123 | self.conv4_5_1x1_reduce = self.__conv(2, name='conv4_5_1x1_reduce', in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 124 | self.conv4_5_1x1_reduce_bn = self.__batch_normalization(2, 'conv4_5_1x1_reduce/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 125 | self.conv4_5_3x3 = self.__conv(2, name='conv4_5_3x3', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 126 | self.conv4_5_3x3_bn = self.__batch_normalization(2, 'conv4_5_3x3/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 127 | self.conv4_5_1x1_increase = self.__conv(2, name='conv4_5_1x1_increase', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 128 | self.conv4_5_1x1_increase_bn = self.__batch_normalization(2, 'conv4_5_1x1_increase/bn', num_features=1024, eps=9.99999974738e-06, momentum=0.0) 129 | self.conv4_5_1x1_down = self.__conv(2, name='conv4_5_1x1_down', in_channels=1024, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 130 | self.conv4_5_1x1_up = self.__conv(2, name='conv4_5_1x1_up', in_channels=64, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 131 | self.conv4_6_1x1_reduce = self.__conv(2, name='conv4_6_1x1_reduce', in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 132 | 
self.conv4_6_1x1_reduce_bn = self.__batch_normalization(2, 'conv4_6_1x1_reduce/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 133 | self.conv4_6_3x3 = self.__conv(2, name='conv4_6_3x3', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 134 | self.conv4_6_3x3_bn = self.__batch_normalization(2, 'conv4_6_3x3/bn', num_features=256, eps=9.99999974738e-06, momentum=0.0) 135 | self.conv4_6_1x1_increase = self.__conv(2, name='conv4_6_1x1_increase', in_channels=256, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 136 | self.conv4_6_1x1_increase_bn = self.__batch_normalization(2, 'conv4_6_1x1_increase/bn', num_features=1024, eps=9.99999974738e-06, momentum=0.0) 137 | self.conv4_6_1x1_down = self.__conv(2, name='conv4_6_1x1_down', in_channels=1024, out_channels=64, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 138 | self.conv4_6_1x1_up = self.__conv(2, name='conv4_6_1x1_up', in_channels=64, out_channels=1024, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 139 | self.conv5_1_1x1_proj = self.__conv(2, name='conv5_1_1x1_proj', in_channels=1024, out_channels=2048, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False) 140 | self.conv5_1_1x1_reduce = self.__conv(2, name='conv5_1_1x1_reduce', in_channels=1024, out_channels=512, kernel_size=(1, 1), stride=(2, 2), groups=1, bias=False) 141 | self.conv5_1_1x1_proj_bn = self.__batch_normalization(2, 'conv5_1_1x1_proj/bn', num_features=2048, eps=9.99999974738e-06, momentum=0.0) 142 | self.conv5_1_1x1_reduce_bn = self.__batch_normalization(2, 'conv5_1_1x1_reduce/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 143 | self.conv5_1_3x3 = self.__conv(2, name='conv5_1_3x3', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 144 | self.conv5_1_3x3_bn = self.__batch_normalization(2, 'conv5_1_3x3/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 145 | self.conv5_1_1x1_increase = self.__conv(2, name='conv5_1_1x1_increase', in_channels=512, out_channels=2048, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 146 | self.conv5_1_1x1_increase_bn = self.__batch_normalization(2, 'conv5_1_1x1_increase/bn', num_features=2048, eps=9.99999974738e-06, momentum=0.0) 147 | self.conv5_1_1x1_down = self.__conv(2, name='conv5_1_1x1_down', in_channels=2048, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 148 | self.conv5_1_1x1_up = self.__conv(2, name='conv5_1_1x1_up', in_channels=128, out_channels=2048, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 149 | self.conv5_2_1x1_reduce = self.__conv(2, name='conv5_2_1x1_reduce', in_channels=2048, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 150 | self.conv5_2_1x1_reduce_bn = self.__batch_normalization(2, 'conv5_2_1x1_reduce/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 151 | self.conv5_2_3x3 = self.__conv(2, name='conv5_2_3x3', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 152 | self.conv5_2_3x3_bn = self.__batch_normalization(2, 'conv5_2_3x3/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 153 | self.conv5_2_1x1_increase = self.__conv(2, name='conv5_2_1x1_increase', in_channels=512, out_channels=2048, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 154 | self.conv5_2_1x1_increase_bn = self.__batch_normalization(2, 'conv5_2_1x1_increase/bn', num_features=2048, eps=9.99999974738e-06, momentum=0.0) 155 | self.conv5_2_1x1_down = 
self.__conv(2, name='conv5_2_1x1_down', in_channels=2048, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 156 | self.conv5_2_1x1_up = self.__conv(2, name='conv5_2_1x1_up', in_channels=128, out_channels=2048, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 157 | self.conv5_3_1x1_reduce = self.__conv(2, name='conv5_3_1x1_reduce', in_channels=2048, out_channels=512, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 158 | self.conv5_3_1x1_reduce_bn = self.__batch_normalization(2, 'conv5_3_1x1_reduce/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 159 | self.conv5_3_3x3 = self.__conv(2, name='conv5_3_3x3', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=False) 160 | self.conv5_3_3x3_bn = self.__batch_normalization(2, 'conv5_3_3x3/bn', num_features=512, eps=9.99999974738e-06, momentum=0.0) 161 | self.conv5_3_1x1_increase = self.__conv(2, name='conv5_3_1x1_increase', in_channels=512, out_channels=2048, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=False) 162 | self.conv5_3_1x1_increase_bn = self.__batch_normalization(2, 'conv5_3_1x1_increase/bn', num_features=2048, eps=9.99999974738e-06, momentum=0.0) 163 | self.conv5_3_1x1_down = self.__conv(2, name='conv5_3_1x1_down', in_channels=2048, out_channels=128, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 164 | self.conv5_3_1x1_up = self.__conv(2, name='conv5_3_1x1_up', in_channels=128, out_channels=2048, kernel_size=(1, 1), stride=(1, 1), groups=1, bias=True) 165 | self.classifier_1 = self.__dense(name = 'classifier_1', in_features = 2048, out_features = 8631, bias = True) 166 | 167 | def forward(self, x): 168 | conv1_7x7_s2_pad = F.pad(x, (3, 3, 3, 3)) 169 | conv1_7x7_s2 = self.conv1_7x7_s2(conv1_7x7_s2_pad) 170 | conv1_7x7_s2_bn = self.conv1_7x7_s2_bn(conv1_7x7_s2) 171 | conv1_relu_7x7_s2 = F.relu(conv1_7x7_s2_bn) 172 | pool1_3x3_s2_pad = F.pad(conv1_relu_7x7_s2, (0, 1, 0, 1), value=float('-inf')) 173 | pool1_3x3_s2 = F.max_pool2d(pool1_3x3_s2_pad, kernel_size=(3, 3), stride=(2, 2), padding=0, ceil_mode=False) 174 | conv2_1_1x1_reduce = self.conv2_1_1x1_reduce(pool1_3x3_s2) 175 | conv2_1_1x1_proj = self.conv2_1_1x1_proj(pool1_3x3_s2) 176 | conv2_1_1x1_reduce_bn = self.conv2_1_1x1_reduce_bn(conv2_1_1x1_reduce) 177 | conv2_1_1x1_proj_bn = self.conv2_1_1x1_proj_bn(conv2_1_1x1_proj) 178 | conv2_1_1x1_reduce_relu = F.relu(conv2_1_1x1_reduce_bn) 179 | conv2_1_3x3_pad = F.pad(conv2_1_1x1_reduce_relu, (1, 1, 1, 1)) 180 | conv2_1_3x3 = self.conv2_1_3x3(conv2_1_3x3_pad) 181 | conv2_1_3x3_bn = self.conv2_1_3x3_bn(conv2_1_3x3) 182 | conv2_1_3x3_relu = F.relu(conv2_1_3x3_bn) 183 | conv2_1_1x1_increase = self.conv2_1_1x1_increase(conv2_1_3x3_relu) 184 | conv2_1_1x1_increase_bn = self.conv2_1_1x1_increase_bn(conv2_1_1x1_increase) 185 | conv2_1_global_pool = F.avg_pool2d(conv2_1_1x1_increase_bn, kernel_size=(56, 56), stride=(1, 1), padding=(0,), ceil_mode=False) 186 | conv2_1_1x1_down = self.conv2_1_1x1_down(conv2_1_global_pool) 187 | conv2_1_1x1_down_relu = F.relu(conv2_1_1x1_down) 188 | conv2_1_1x1_up = self.conv2_1_1x1_up(conv2_1_1x1_down_relu) 189 | conv2_1_prob = F.sigmoid(conv2_1_1x1_up) 190 | conv2_1_1x1_increase_bn_scale = conv2_1_prob * conv2_1_1x1_increase_bn 191 | conv2_1 = conv2_1_1x1_increase_bn_scale + conv2_1_1x1_proj_bn 192 | conv2_1_relu = F.relu(conv2_1) 193 | conv2_2_1x1_reduce = self.conv2_2_1x1_reduce(conv2_1_relu) 194 | conv2_2_1x1_reduce_bn = self.conv2_2_1x1_reduce_bn(conv2_2_1x1_reduce) 195 | conv2_2_1x1_reduce_relu = 
F.relu(conv2_2_1x1_reduce_bn) 196 | conv2_2_3x3_pad = F.pad(conv2_2_1x1_reduce_relu, (1, 1, 1, 1)) 197 | conv2_2_3x3 = self.conv2_2_3x3(conv2_2_3x3_pad) 198 | conv2_2_3x3_bn = self.conv2_2_3x3_bn(conv2_2_3x3) 199 | conv2_2_3x3_relu = F.relu(conv2_2_3x3_bn) 200 | conv2_2_1x1_increase = self.conv2_2_1x1_increase(conv2_2_3x3_relu) 201 | conv2_2_1x1_increase_bn = self.conv2_2_1x1_increase_bn(conv2_2_1x1_increase) 202 | conv2_2_global_pool = F.avg_pool2d(conv2_2_1x1_increase_bn, kernel_size=(56, 56), stride=(1, 1), padding=(0,), ceil_mode=False) 203 | conv2_2_1x1_down = self.conv2_2_1x1_down(conv2_2_global_pool) 204 | conv2_2_1x1_down_relu = F.relu(conv2_2_1x1_down) 205 | conv2_2_1x1_up = self.conv2_2_1x1_up(conv2_2_1x1_down_relu) 206 | conv2_2_prob = F.sigmoid(conv2_2_1x1_up) 207 | conv2_2_1x1_increase_bn_scale = conv2_2_prob * conv2_2_1x1_increase_bn 208 | conv2_2 = conv2_2_1x1_increase_bn_scale + conv2_1_relu 209 | conv2_2_relu = F.relu(conv2_2) 210 | conv2_3_1x1_reduce = self.conv2_3_1x1_reduce(conv2_2_relu) 211 | conv2_3_1x1_reduce_bn = self.conv2_3_1x1_reduce_bn(conv2_3_1x1_reduce) 212 | conv2_3_1x1_reduce_relu = F.relu(conv2_3_1x1_reduce_bn) 213 | conv2_3_3x3_pad = F.pad(conv2_3_1x1_reduce_relu, (1, 1, 1, 1)) 214 | conv2_3_3x3 = self.conv2_3_3x3(conv2_3_3x3_pad) 215 | conv2_3_3x3_bn = self.conv2_3_3x3_bn(conv2_3_3x3) 216 | conv2_3_3x3_relu = F.relu(conv2_3_3x3_bn) 217 | conv2_3_1x1_increase = self.conv2_3_1x1_increase(conv2_3_3x3_relu) 218 | conv2_3_1x1_increase_bn = self.conv2_3_1x1_increase_bn(conv2_3_1x1_increase) 219 | conv2_3_global_pool = F.avg_pool2d(conv2_3_1x1_increase_bn, kernel_size=(56, 56), stride=(1, 1), padding=(0,), ceil_mode=False) 220 | conv2_3_1x1_down = self.conv2_3_1x1_down(conv2_3_global_pool) 221 | conv2_3_1x1_down_relu = F.relu(conv2_3_1x1_down) 222 | conv2_3_1x1_up = self.conv2_3_1x1_up(conv2_3_1x1_down_relu) 223 | conv2_3_prob = F.sigmoid(conv2_3_1x1_up) 224 | conv2_3_1x1_increase_bn_scale = conv2_3_prob * conv2_3_1x1_increase_bn 225 | conv2_3 = conv2_3_1x1_increase_bn_scale + conv2_2_relu 226 | conv2_3_relu = F.relu(conv2_3) 227 | conv3_1_1x1_proj = self.conv3_1_1x1_proj(conv2_3_relu) 228 | conv3_1_1x1_reduce = self.conv3_1_1x1_reduce(conv2_3_relu) 229 | conv3_1_1x1_proj_bn = self.conv3_1_1x1_proj_bn(conv3_1_1x1_proj) 230 | conv3_1_1x1_reduce_bn = self.conv3_1_1x1_reduce_bn(conv3_1_1x1_reduce) 231 | conv3_1_1x1_reduce_relu = F.relu(conv3_1_1x1_reduce_bn) 232 | conv3_1_3x3_pad = F.pad(conv3_1_1x1_reduce_relu, (1, 1, 1, 1)) 233 | conv3_1_3x3 = self.conv3_1_3x3(conv3_1_3x3_pad) 234 | conv3_1_3x3_bn = self.conv3_1_3x3_bn(conv3_1_3x3) 235 | conv3_1_3x3_relu = F.relu(conv3_1_3x3_bn) 236 | conv3_1_1x1_increase = self.conv3_1_1x1_increase(conv3_1_3x3_relu) 237 | conv3_1_1x1_increase_bn = self.conv3_1_1x1_increase_bn(conv3_1_1x1_increase) 238 | conv3_1_global_pool = F.avg_pool2d(conv3_1_1x1_increase_bn, kernel_size=(28, 28), stride=(1, 1), padding=(0,), ceil_mode=False) 239 | conv3_1_1x1_down = self.conv3_1_1x1_down(conv3_1_global_pool) 240 | conv3_1_1x1_down_relu = F.relu(conv3_1_1x1_down) 241 | conv3_1_1x1_up = self.conv3_1_1x1_up(conv3_1_1x1_down_relu) 242 | conv3_1_prob = F.sigmoid(conv3_1_1x1_up) 243 | conv3_1_1x1_increase_bn_scale = conv3_1_prob * conv3_1_1x1_increase_bn 244 | conv3_1 = conv3_1_1x1_increase_bn_scale + conv3_1_1x1_proj_bn 245 | conv3_1_relu = F.relu(conv3_1) 246 | conv3_2_1x1_reduce = self.conv3_2_1x1_reduce(conv3_1_relu) 247 | conv3_2_1x1_reduce_bn = self.conv3_2_1x1_reduce_bn(conv3_2_1x1_reduce) 248 | conv3_2_1x1_reduce_relu = 
F.relu(conv3_2_1x1_reduce_bn) 249 | conv3_2_3x3_pad = F.pad(conv3_2_1x1_reduce_relu, (1, 1, 1, 1)) 250 | conv3_2_3x3 = self.conv3_2_3x3(conv3_2_3x3_pad) 251 | conv3_2_3x3_bn = self.conv3_2_3x3_bn(conv3_2_3x3) 252 | conv3_2_3x3_relu = F.relu(conv3_2_3x3_bn) 253 | conv3_2_1x1_increase = self.conv3_2_1x1_increase(conv3_2_3x3_relu) 254 | conv3_2_1x1_increase_bn = self.conv3_2_1x1_increase_bn(conv3_2_1x1_increase) 255 | conv3_2_global_pool = F.avg_pool2d(conv3_2_1x1_increase_bn, kernel_size=(28, 28), stride=(1, 1), padding=(0,), ceil_mode=False) 256 | conv3_2_1x1_down = self.conv3_2_1x1_down(conv3_2_global_pool) 257 | conv3_2_1x1_down_relu = F.relu(conv3_2_1x1_down) 258 | conv3_2_1x1_up = self.conv3_2_1x1_up(conv3_2_1x1_down_relu) 259 | conv3_2_prob = F.sigmoid(conv3_2_1x1_up) 260 | conv3_2_1x1_increase_bn_scale = conv3_2_prob * conv3_2_1x1_increase_bn 261 | conv3_2 = conv3_2_1x1_increase_bn_scale + conv3_1_relu 262 | conv3_2_relu = F.relu(conv3_2) 263 | conv3_3_1x1_reduce = self.conv3_3_1x1_reduce(conv3_2_relu) 264 | conv3_3_1x1_reduce_bn = self.conv3_3_1x1_reduce_bn(conv3_3_1x1_reduce) 265 | conv3_3_1x1_reduce_relu = F.relu(conv3_3_1x1_reduce_bn) 266 | conv3_3_3x3_pad = F.pad(conv3_3_1x1_reduce_relu, (1, 1, 1, 1)) 267 | conv3_3_3x3 = self.conv3_3_3x3(conv3_3_3x3_pad) 268 | conv3_3_3x3_bn = self.conv3_3_3x3_bn(conv3_3_3x3) 269 | conv3_3_3x3_relu = F.relu(conv3_3_3x3_bn) 270 | conv3_3_1x1_increase = self.conv3_3_1x1_increase(conv3_3_3x3_relu) 271 | conv3_3_1x1_increase_bn = self.conv3_3_1x1_increase_bn(conv3_3_1x1_increase) 272 | conv3_3_global_pool = F.avg_pool2d(conv3_3_1x1_increase_bn, kernel_size=(28, 28), stride=(1, 1), padding=(0,), ceil_mode=False) 273 | conv3_3_1x1_down = self.conv3_3_1x1_down(conv3_3_global_pool) 274 | conv3_3_1x1_down_relu = F.relu(conv3_3_1x1_down) 275 | conv3_3_1x1_up = self.conv3_3_1x1_up(conv3_3_1x1_down_relu) 276 | conv3_3_prob = F.sigmoid(conv3_3_1x1_up) 277 | conv3_3_1x1_increase_bn_scale = conv3_3_prob * conv3_3_1x1_increase_bn 278 | conv3_3 = conv3_3_1x1_increase_bn_scale + conv3_2_relu 279 | conv3_3_relu = F.relu(conv3_3) 280 | conv3_4_1x1_reduce = self.conv3_4_1x1_reduce(conv3_3_relu) 281 | conv3_4_1x1_reduce_bn = self.conv3_4_1x1_reduce_bn(conv3_4_1x1_reduce) 282 | conv3_4_1x1_reduce_relu = F.relu(conv3_4_1x1_reduce_bn) 283 | conv3_4_3x3_pad = F.pad(conv3_4_1x1_reduce_relu, (1, 1, 1, 1)) 284 | conv3_4_3x3 = self.conv3_4_3x3(conv3_4_3x3_pad) 285 | conv3_4_3x3_bn = self.conv3_4_3x3_bn(conv3_4_3x3) 286 | conv3_4_3x3_relu = F.relu(conv3_4_3x3_bn) 287 | conv3_4_1x1_increase = self.conv3_4_1x1_increase(conv3_4_3x3_relu) 288 | conv3_4_1x1_increase_bn = self.conv3_4_1x1_increase_bn(conv3_4_1x1_increase) 289 | conv3_4_global_pool = F.avg_pool2d(conv3_4_1x1_increase_bn, kernel_size=(28, 28), stride=(1, 1), padding=(0,), ceil_mode=False) 290 | conv3_4_1x1_down = self.conv3_4_1x1_down(conv3_4_global_pool) 291 | conv3_4_1x1_down_relu = F.relu(conv3_4_1x1_down) 292 | conv3_4_1x1_up = self.conv3_4_1x1_up(conv3_4_1x1_down_relu) 293 | conv3_4_prob = F.sigmoid(conv3_4_1x1_up) 294 | conv3_4_1x1_increase_bn_scale = conv3_4_prob * conv3_4_1x1_increase_bn 295 | conv3_4 = conv3_4_1x1_increase_bn_scale + conv3_3_relu 296 | conv3_4_relu = F.relu(conv3_4) 297 | conv4_1_1x1_proj = self.conv4_1_1x1_proj(conv3_4_relu) 298 | conv4_1_1x1_reduce = self.conv4_1_1x1_reduce(conv3_4_relu) 299 | conv4_1_1x1_proj_bn = self.conv4_1_1x1_proj_bn(conv4_1_1x1_proj) 300 | conv4_1_1x1_reduce_bn = self.conv4_1_1x1_reduce_bn(conv4_1_1x1_reduce) 301 | conv4_1_1x1_reduce_relu = 
F.relu(conv4_1_1x1_reduce_bn) 302 | conv4_1_3x3_pad = F.pad(conv4_1_1x1_reduce_relu, (1, 1, 1, 1)) 303 | conv4_1_3x3 = self.conv4_1_3x3(conv4_1_3x3_pad) 304 | conv4_1_3x3_bn = self.conv4_1_3x3_bn(conv4_1_3x3) 305 | conv4_1_3x3_relu = F.relu(conv4_1_3x3_bn) 306 | conv4_1_1x1_increase = self.conv4_1_1x1_increase(conv4_1_3x3_relu) 307 | conv4_1_1x1_increase_bn = self.conv4_1_1x1_increase_bn(conv4_1_1x1_increase) 308 | conv4_1_global_pool = F.avg_pool2d(conv4_1_1x1_increase_bn, kernel_size=(14, 14), stride=(1, 1), padding=(0,), ceil_mode=False) 309 | conv4_1_1x1_down = self.conv4_1_1x1_down(conv4_1_global_pool) 310 | conv4_1_1x1_down_relu = F.relu(conv4_1_1x1_down) 311 | conv4_1_1x1_up = self.conv4_1_1x1_up(conv4_1_1x1_down_relu) 312 | conv4_1_prob = F.sigmoid(conv4_1_1x1_up) 313 | conv4_1_1x1_increase_bn_scale = conv4_1_prob * conv4_1_1x1_increase_bn 314 | conv4_1 = conv4_1_1x1_increase_bn_scale + conv4_1_1x1_proj_bn 315 | conv4_1_relu = F.relu(conv4_1) 316 | conv4_2_1x1_reduce = self.conv4_2_1x1_reduce(conv4_1_relu) 317 | conv4_2_1x1_reduce_bn = self.conv4_2_1x1_reduce_bn(conv4_2_1x1_reduce) 318 | conv4_2_1x1_reduce_relu = F.relu(conv4_2_1x1_reduce_bn) 319 | conv4_2_3x3_pad = F.pad(conv4_2_1x1_reduce_relu, (1, 1, 1, 1)) 320 | conv4_2_3x3 = self.conv4_2_3x3(conv4_2_3x3_pad) 321 | conv4_2_3x3_bn = self.conv4_2_3x3_bn(conv4_2_3x3) 322 | conv4_2_3x3_relu = F.relu(conv4_2_3x3_bn) 323 | conv4_2_1x1_increase = self.conv4_2_1x1_increase(conv4_2_3x3_relu) 324 | conv4_2_1x1_increase_bn = self.conv4_2_1x1_increase_bn(conv4_2_1x1_increase) 325 | conv4_2_global_pool = F.avg_pool2d(conv4_2_1x1_increase_bn, kernel_size=(14, 14), stride=(1, 1), padding=(0,), ceil_mode=False) 326 | conv4_2_1x1_down = self.conv4_2_1x1_down(conv4_2_global_pool) 327 | conv4_2_1x1_down_relu = F.relu(conv4_2_1x1_down) 328 | conv4_2_1x1_up = self.conv4_2_1x1_up(conv4_2_1x1_down_relu) 329 | conv4_2_prob = F.sigmoid(conv4_2_1x1_up) 330 | conv4_2_1x1_increase_bn_scale = conv4_2_prob * conv4_2_1x1_increase_bn 331 | conv4_2 = conv4_2_1x1_increase_bn_scale + conv4_1_relu 332 | conv4_2_relu = F.relu(conv4_2) 333 | conv4_3_1x1_reduce = self.conv4_3_1x1_reduce(conv4_2_relu) 334 | conv4_3_1x1_reduce_bn = self.conv4_3_1x1_reduce_bn(conv4_3_1x1_reduce) 335 | conv4_3_1x1_reduce_relu = F.relu(conv4_3_1x1_reduce_bn) 336 | conv4_3_3x3_pad = F.pad(conv4_3_1x1_reduce_relu, (1, 1, 1, 1)) 337 | conv4_3_3x3 = self.conv4_3_3x3(conv4_3_3x3_pad) 338 | conv4_3_3x3_bn = self.conv4_3_3x3_bn(conv4_3_3x3) 339 | conv4_3_3x3_relu = F.relu(conv4_3_3x3_bn) 340 | conv4_3_1x1_increase = self.conv4_3_1x1_increase(conv4_3_3x3_relu) 341 | conv4_3_1x1_increase_bn = self.conv4_3_1x1_increase_bn(conv4_3_1x1_increase) 342 | conv4_3_global_pool = F.avg_pool2d(conv4_3_1x1_increase_bn, kernel_size=(14, 14), stride=(1, 1), padding=(0,), ceil_mode=False) 343 | conv4_3_1x1_down = self.conv4_3_1x1_down(conv4_3_global_pool) 344 | conv4_3_1x1_down_relu = F.relu(conv4_3_1x1_down) 345 | conv4_3_1x1_up = self.conv4_3_1x1_up(conv4_3_1x1_down_relu) 346 | conv4_3_prob = F.sigmoid(conv4_3_1x1_up) 347 | conv4_3_1x1_increase_bn_scale = conv4_3_prob * conv4_3_1x1_increase_bn 348 | conv4_3 = conv4_3_1x1_increase_bn_scale + conv4_2_relu 349 | conv4_3_relu = F.relu(conv4_3) 350 | conv4_4_1x1_reduce = self.conv4_4_1x1_reduce(conv4_3_relu) 351 | conv4_4_1x1_reduce_bn = self.conv4_4_1x1_reduce_bn(conv4_4_1x1_reduce) 352 | conv4_4_1x1_reduce_relu = F.relu(conv4_4_1x1_reduce_bn) 353 | conv4_4_3x3_pad = F.pad(conv4_4_1x1_reduce_relu, (1, 1, 1, 1)) 354 | conv4_4_3x3 = 
self.conv4_4_3x3(conv4_4_3x3_pad) 355 | conv4_4_3x3_bn = self.conv4_4_3x3_bn(conv4_4_3x3) 356 | conv4_4_3x3_relu = F.relu(conv4_4_3x3_bn) 357 | conv4_4_1x1_increase = self.conv4_4_1x1_increase(conv4_4_3x3_relu) 358 | conv4_4_1x1_increase_bn = self.conv4_4_1x1_increase_bn(conv4_4_1x1_increase) 359 | conv4_4_global_pool = F.avg_pool2d(conv4_4_1x1_increase_bn, kernel_size=(14, 14), stride=(1, 1), padding=(0,), ceil_mode=False) 360 | conv4_4_1x1_down = self.conv4_4_1x1_down(conv4_4_global_pool) 361 | conv4_4_1x1_down_relu = F.relu(conv4_4_1x1_down) 362 | conv4_4_1x1_up = self.conv4_4_1x1_up(conv4_4_1x1_down_relu) 363 | conv4_4_prob = F.sigmoid(conv4_4_1x1_up) 364 | conv4_4_1x1_increase_bn_scale = conv4_4_prob * conv4_4_1x1_increase_bn 365 | conv4_4 = conv4_4_1x1_increase_bn_scale + conv4_3_relu 366 | conv4_4_relu = F.relu(conv4_4) 367 | conv4_5_1x1_reduce = self.conv4_5_1x1_reduce(conv4_4_relu) 368 | conv4_5_1x1_reduce_bn = self.conv4_5_1x1_reduce_bn(conv4_5_1x1_reduce) 369 | conv4_5_1x1_reduce_relu = F.relu(conv4_5_1x1_reduce_bn) 370 | conv4_5_3x3_pad = F.pad(conv4_5_1x1_reduce_relu, (1, 1, 1, 1)) 371 | conv4_5_3x3 = self.conv4_5_3x3(conv4_5_3x3_pad) 372 | conv4_5_3x3_bn = self.conv4_5_3x3_bn(conv4_5_3x3) 373 | conv4_5_3x3_relu = F.relu(conv4_5_3x3_bn) 374 | conv4_5_1x1_increase = self.conv4_5_1x1_increase(conv4_5_3x3_relu) 375 | conv4_5_1x1_increase_bn = self.conv4_5_1x1_increase_bn(conv4_5_1x1_increase) 376 | conv4_5_global_pool = F.avg_pool2d(conv4_5_1x1_increase_bn, kernel_size=(14, 14), stride=(1, 1), padding=(0,), ceil_mode=False) 377 | conv4_5_1x1_down = self.conv4_5_1x1_down(conv4_5_global_pool) 378 | conv4_5_1x1_down_relu = F.relu(conv4_5_1x1_down) 379 | conv4_5_1x1_up = self.conv4_5_1x1_up(conv4_5_1x1_down_relu) 380 | conv4_5_prob = F.sigmoid(conv4_5_1x1_up) 381 | conv4_5_1x1_increase_bn_scale = conv4_5_prob * conv4_5_1x1_increase_bn 382 | conv4_5 = conv4_5_1x1_increase_bn_scale + conv4_4_relu 383 | conv4_5_relu = F.relu(conv4_5) 384 | conv4_6_1x1_reduce = self.conv4_6_1x1_reduce(conv4_5_relu) 385 | conv4_6_1x1_reduce_bn = self.conv4_6_1x1_reduce_bn(conv4_6_1x1_reduce) 386 | conv4_6_1x1_reduce_relu = F.relu(conv4_6_1x1_reduce_bn) 387 | conv4_6_3x3_pad = F.pad(conv4_6_1x1_reduce_relu, (1, 1, 1, 1)) 388 | conv4_6_3x3 = self.conv4_6_3x3(conv4_6_3x3_pad) 389 | conv4_6_3x3_bn = self.conv4_6_3x3_bn(conv4_6_3x3) 390 | conv4_6_3x3_relu = F.relu(conv4_6_3x3_bn) 391 | conv4_6_1x1_increase = self.conv4_6_1x1_increase(conv4_6_3x3_relu) 392 | conv4_6_1x1_increase_bn = self.conv4_6_1x1_increase_bn(conv4_6_1x1_increase) 393 | conv4_6_global_pool = F.avg_pool2d(conv4_6_1x1_increase_bn, kernel_size=(14, 14), stride=(1, 1), padding=(0,), ceil_mode=False) 394 | conv4_6_1x1_down = self.conv4_6_1x1_down(conv4_6_global_pool) 395 | conv4_6_1x1_down_relu = F.relu(conv4_6_1x1_down) 396 | conv4_6_1x1_up = self.conv4_6_1x1_up(conv4_6_1x1_down_relu) 397 | conv4_6_prob = F.sigmoid(conv4_6_1x1_up) 398 | conv4_6_1x1_increase_bn_scale = conv4_6_prob * conv4_6_1x1_increase_bn 399 | conv4_6 = conv4_6_1x1_increase_bn_scale + conv4_5_relu 400 | conv4_6_relu = F.relu(conv4_6) 401 | conv5_1_1x1_proj = self.conv5_1_1x1_proj(conv4_6_relu) 402 | conv5_1_1x1_reduce = self.conv5_1_1x1_reduce(conv4_6_relu) 403 | conv5_1_1x1_proj_bn = self.conv5_1_1x1_proj_bn(conv5_1_1x1_proj) 404 | conv5_1_1x1_reduce_bn = self.conv5_1_1x1_reduce_bn(conv5_1_1x1_reduce) 405 | conv5_1_1x1_reduce_relu = F.relu(conv5_1_1x1_reduce_bn) 406 | conv5_1_3x3_pad = F.pad(conv5_1_1x1_reduce_relu, (1, 1, 1, 1)) 407 | conv5_1_3x3 = 
self.conv5_1_3x3(conv5_1_3x3_pad) 408 | conv5_1_3x3_bn = self.conv5_1_3x3_bn(conv5_1_3x3) 409 | conv5_1_3x3_relu = F.relu(conv5_1_3x3_bn) 410 | conv5_1_1x1_increase = self.conv5_1_1x1_increase(conv5_1_3x3_relu) 411 | conv5_1_1x1_increase_bn = self.conv5_1_1x1_increase_bn(conv5_1_1x1_increase) 412 | conv5_1_global_pool = F.avg_pool2d(conv5_1_1x1_increase_bn, kernel_size=(7, 7), stride=(1, 1), padding=(0,), ceil_mode=False) 413 | conv5_1_1x1_down = self.conv5_1_1x1_down(conv5_1_global_pool) 414 | conv5_1_1x1_down_relu = F.relu(conv5_1_1x1_down) 415 | conv5_1_1x1_up = self.conv5_1_1x1_up(conv5_1_1x1_down_relu) 416 | conv5_1_prob = F.sigmoid(conv5_1_1x1_up) 417 | conv5_1_1x1_increase_bn_scale = conv5_1_prob * conv5_1_1x1_increase_bn 418 | conv5_1 = conv5_1_1x1_increase_bn_scale + conv5_1_1x1_proj_bn 419 | conv5_1_relu = F.relu(conv5_1) 420 | conv5_2_1x1_reduce = self.conv5_2_1x1_reduce(conv5_1_relu) 421 | conv5_2_1x1_reduce_bn = self.conv5_2_1x1_reduce_bn(conv5_2_1x1_reduce) 422 | conv5_2_1x1_reduce_relu = F.relu(conv5_2_1x1_reduce_bn) 423 | conv5_2_3x3_pad = F.pad(conv5_2_1x1_reduce_relu, (1, 1, 1, 1)) 424 | conv5_2_3x3 = self.conv5_2_3x3(conv5_2_3x3_pad) 425 | conv5_2_3x3_bn = self.conv5_2_3x3_bn(conv5_2_3x3) 426 | conv5_2_3x3_relu = F.relu(conv5_2_3x3_bn) 427 | conv5_2_1x1_increase = self.conv5_2_1x1_increase(conv5_2_3x3_relu) 428 | conv5_2_1x1_increase_bn = self.conv5_2_1x1_increase_bn(conv5_2_1x1_increase) 429 | conv5_2_global_pool = F.avg_pool2d(conv5_2_1x1_increase_bn, kernel_size=(7, 7), stride=(1, 1), padding=(0,), ceil_mode=False) 430 | conv5_2_1x1_down = self.conv5_2_1x1_down(conv5_2_global_pool) 431 | conv5_2_1x1_down_relu = F.relu(conv5_2_1x1_down) 432 | conv5_2_1x1_up = self.conv5_2_1x1_up(conv5_2_1x1_down_relu) 433 | conv5_2_prob = F.sigmoid(conv5_2_1x1_up) 434 | conv5_2_1x1_increase_bn_scale = conv5_2_prob * conv5_2_1x1_increase_bn 435 | conv5_2 = conv5_2_1x1_increase_bn_scale + conv5_1_relu 436 | conv5_2_relu = F.relu(conv5_2) 437 | conv5_3_1x1_reduce = self.conv5_3_1x1_reduce(conv5_2_relu) 438 | conv5_3_1x1_reduce_bn = self.conv5_3_1x1_reduce_bn(conv5_3_1x1_reduce) 439 | conv5_3_1x1_reduce_relu = F.relu(conv5_3_1x1_reduce_bn) 440 | conv5_3_3x3_pad = F.pad(conv5_3_1x1_reduce_relu, (1, 1, 1, 1)) 441 | conv5_3_3x3 = self.conv5_3_3x3(conv5_3_3x3_pad) 442 | conv5_3_3x3_bn = self.conv5_3_3x3_bn(conv5_3_3x3) 443 | conv5_3_3x3_relu = F.relu(conv5_3_3x3_bn) 444 | conv5_3_1x1_increase = self.conv5_3_1x1_increase(conv5_3_3x3_relu) 445 | conv5_3_1x1_increase_bn = self.conv5_3_1x1_increase_bn(conv5_3_1x1_increase) 446 | conv5_3_global_pool = F.avg_pool2d(conv5_3_1x1_increase_bn, kernel_size=(7, 7), stride=(1, 1), padding=(0,), ceil_mode=False) 447 | conv5_3_1x1_down = self.conv5_3_1x1_down(conv5_3_global_pool) 448 | conv5_3_1x1_down_relu = F.relu(conv5_3_1x1_down) 449 | conv5_3_1x1_up = self.conv5_3_1x1_up(conv5_3_1x1_down_relu) 450 | conv5_3_prob = F.sigmoid(conv5_3_1x1_up) 451 | conv5_3_1x1_increase_bn_scale = conv5_3_prob * conv5_3_1x1_increase_bn 452 | conv5_3 = conv5_3_1x1_increase_bn_scale + conv5_2_relu 453 | conv5_3_relu = F.relu(conv5_3) 454 | pool5_7x7_s1 = F.avg_pool2d(conv5_3_relu, kernel_size=(7, 7), stride=(1, 1), padding=(0,), ceil_mode=False) 455 | classifier_0 = pool5_7x7_s1.view(pool5_7x7_s1.size(0), -1) 456 | classifier_1 = self.classifier_1(classifier_0) 457 | return classifier_0, classifier_1 458 | 459 | 460 | @staticmethod 461 | def __batch_normalization(dim, name, **kwargs): 462 | if dim == 1: layer = nn.BatchNorm1d(**kwargs) 463 | elif dim == 2: layer = 
nn.BatchNorm2d(**kwargs) 464 | elif dim == 3: layer = nn.BatchNorm3d(**kwargs) 465 | else: raise NotImplementedError() 466 | 467 | if 'scale' in __weights_dict[name]: 468 | layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['scale'])) 469 | else: 470 | layer.weight.data.fill_(1) 471 | 472 | if 'bias' in __weights_dict[name]: 473 | layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias'])) 474 | else: 475 | layer.bias.data.fill_(0) 476 | 477 | layer.state_dict()['running_mean'].copy_(torch.from_numpy(__weights_dict[name]['mean'])) 478 | layer.state_dict()['running_var'].copy_(torch.from_numpy(__weights_dict[name]['var'])) 479 | return layer 480 | 481 | @staticmethod 482 | def __conv(dim, name, **kwargs): 483 | if dim == 1: layer = nn.Conv1d(**kwargs) 484 | elif dim == 2: layer = nn.Conv2d(**kwargs) 485 | elif dim == 3: layer = nn.Conv3d(**kwargs) 486 | else: raise NotImplementedError() 487 | 488 | layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['weights'])) 489 | if 'bias' in __weights_dict[name]: 490 | layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias'])) 491 | return layer 492 | 493 | @staticmethod 494 | def __dense(name, **kwargs): 495 | layer = nn.Linear(**kwargs) 496 | layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['weights'])) 497 | if 'bias' in __weights_dict[name]: 498 | layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias'])) 499 | return layer 500 | 501 | @staticmethod 502 | def __dropout(name, **kwargs): 503 | return nn.Dropout(p=0.5)  # nn.Dropout does not accept a 'name' argument 504 | 505 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cross-resolution learning for Face Recognition 2 | 3 | This repository contains the code for the paper "[Cross-resolution learning for Face Recognition](https://www.sciencedirect.com/science/article/pii/S0262885620300597)" by Fabio Valerio Massoli (ISTI - CNR), Giuseppe Amato (ISTI - CNR), and Fabrizio Falchi (ISTI - CNR). 4 | 5 | It presents a new training procedure for cross-resolution-robust deep neural networks. 6 | 7 | **Please note:** 8 | We are researchers, not a software company, and have no personnel devoted to documenting and maintaining this research code. Therefore, this code is offered "AS IS". Exact reproduction of the numbers in the paper depends on exact reproduction of many factors, including the versions of all software dependencies and the choice of underlying hardware (GPU model, etc.). You should therefore expect to re-tune your hyperparameters slightly for your setup. 9 | 10 | ## Cross-resolution training 11 | 12 | Proposed training approach: 13 | ![Proposed training approach](images/paper_training_algorithm.png)
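The procedure trains a student copy of the network to mimic, on down-sampled inputs, the features that a fixed teacher copy extracts from the original images (see `trainer.py` below; the optimizer in `main.py` updates only the student). A minimal sketch of the objective, assuming both models return `(features, logits)` as `KitModel` in `MainModel.py` does; the `torch.no_grad()` wrapper is an assumption of this sketch (`trainer.py` simply calls the teacher directly):

```python
import torch
import torch.nn.functional as F

def cross_resolution_loss(student, teacher, batch_lr, batch_hr, labels, lambda_=0.1):
    # The teacher sees the original, high-resolution images...
    with torch.no_grad():
        teacher_features, _ = teacher(batch_hr)
    # ...while the student sees the (possibly down-sampled) images.
    student_features, student_logits = student(batch_lr)
    # Classification loss plus a feature-regression term that pulls the
    # student features towards the teacher features, weighted by lambda.
    return (F.cross_entropy(student_logits, labels)
            + lambda_ * F.mse_loss(student_features, teacher_features))
```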
19 | 2D t-SNE embeddings for 20 different identities randomly extracted from the VGGFace2 dataset. All the images were down-sampled to a resolution of 8 pixels. Left: "Base Model". Right: model trained with our approach.

22 | ![t-SNE](images/vggface_tsne_base_ft_models_8.png)
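Here, "down-sampling to 8 pixels" means shrinking an image to 8×8 and resizing it back to the network input size, so the spatial dimensions stay fixed while the fine detail is destroyed. A minimal sketch of such a transform, assuming PIL with the bilinear interpolation configured in `main.py`; the candidate resolutions and the 224-pixel input size are illustrative assumptions (the repository applies its own version through `vggface2_custom_dataset.py`, with probability `--downsampling-prob`):

```python
import random
from PIL import Image

def random_downsample(img, prob=0.1, resolutions=(8, 16, 24, 32, 64, 128), size=224):
    # With probability `prob`, down-sample to a random low resolution and
    # resize back, discarding high-frequency facial detail.
    if random.random() < prob:
        res = random.choice(resolutions)
        img = img.resize((res, res), Image.BILINEAR).resize((size, size), Image.BILINEAR)
    return img
```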
24 | 25 | ## How to run the code 26 | The current version of the code requires Python 3.6 and PyTorch 1.4.0. 27 | 28 | Inside the dataset folder, the code expects to find two subdirs: "train" and "validation". 29 | 30 | Minimal usage: 31 | 32 | ``` 33 | python -W ignore main.py --model-base-path path_to_base_model_weight_file --dset-base-path path_to_data_folder 34 | ``` 35 | 36 | The base model is the SE-ResNet-50 (pretrained on the VGGFace2 dataset) that is available [here](https://github.com/fvmassoli/cross-resolution-face-recognition/releases/tag/v1.0). 37 | 38 | The model outputs 2048-dimensional feature vectors. 39 | 40 | **BE VERY CAREFUL** 41 | 42 | When you download the VGGFace2 dataset, you should NOT use the test set while training. To create a validation set, just take a subset of the training set. 43 | 44 | ## Reference 45 | For all the details about the training procedure and the experimental results, please have a look at the [paper](https://www.sciencedirect.com/science/article/pii/S0262885620300597). 46 | 47 | To cite our work, please use the following form: 48 | 49 | ``` 50 | @article{massoli2020cross, 51 | title={Cross-resolution learning for Face Recognition}, 52 | author={Massoli, Fabio Valerio and Amato, Giuseppe and Falchi, Fabrizio}, 53 | journal={Image and Vision Computing}, 54 | pages={103927}, 55 | year={2020}, 56 | publisher={Elsevier} 57 | } 58 | ``` 59 | 60 | ## Contacts & Model Request 61 | If you have any questions about our work, please contact [Dr. Fabio Valerio Massoli](mailto:fabio.massoli@isti.cnr.it). 62 | 63 | **Currently, we cannot supply the trained model checkpoint.** 64 | 65 | Have fun! :-D 66 | -------------------------------------------------------------------------------- /evaluation_protocols/classifier_metrics/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from sklearn import metrics 8 | from scipy.interpolate import interp1d 9 | 10 | 11 | def eval_roc(scores_dict, resolutions, model=None, template_matching=None): 12 | """ 13 | 14 | Evaluates the ROC for each model at each resolution 15 | 16 | :param scores_dict: dict of scores, k: resolution, v: scores 17 | :param resolutions: list of resolution values 18 | :param template_matching: list of templates to match and the corresponding labels 19 | :param model: model name (currently unused) 20 | 21 | :return: numpy array of (fpr, tpr, thr) 22 | """ 23 | 24 | return np.asarray([metrics.roc_curve(template_matching, scores_dict[res], pos_label=1) for res in resolutions]) 25 | 26 | 27 | def get_cross_resolution_scores(models_base_path, model, dataset_path, date, template_matching, far=1.e-3): 28 | """ 29 | 30 | Evaluates TAR @ FAR 31 | 32 | :param models_base_path: base path to models 33 | :param model: model name 34 | :param dataset_path: dataset name 35 | :param date: date of the data generation 36 | :param template_matching: list of template matchings and the corresponding labels 37 | :param far: FAR value at which to evaluate the TAR 38 | 39 | :return: TAR @ FAR 40 | """ 41 | 42 | print('Loading scores...') 43 | path = os.path.join(models_base_path, model, dataset_path, date, 'scores/scores_cross_correlation') 44 | cross_correlation_scores = np.load(os.path.join(path, 'similarity_scores.npy')) 45 | print('Scores loaded!!!') 46 | 47 | results = [] 48 | for i in tqdm(range(len(cross_correlation_scores)),
total=len(cross_correlation_scores), desc='Looping!!!!'): 49 | 50 | tmp = [] 51 | for j in range(len(cross_correlation_scores[i])): 52 | 53 | (fpr, tpr, thr) = metrics.roc_curve(template_matching, cross_correlation_scores[i][j][:], pos_label=1) 54 | r_f = interp1d(fpr, tpr) 55 | tmp.append(round(r_f(far) * 100., 2)) 56 | 57 | results.append(np.asarray(tmp)) 58 | 59 | return np.asarray(results) 60 | 61 | 62 | def eval_cmc(probe_features_norm, probe_subj_ids, gallery_features_norm, gallery_subj_ids, mAP=False): 63 | """ 64 | 65 | Eval the Cumulative Match Characteristics for the 1:N Face Identification protocol (closed-set scenario) 66 | 67 | :param probe_features_norm: normalized probe features 68 | :param probe_subj_ids: subject ids of probe features 69 | :param gallery_features_norm: normalized gallery features 70 | :param gallery_subj_ids: subject ids of gallery features 71 | :param mAP: boolean. If True, skip the final CMC evaluation 72 | 73 | :return: points: ranks at which the CMC has been evaluated, retrieval_rates: CMC for each rank value 74 | """ 75 | 76 | if len(gallery_subj_ids.shape) == 1: 77 | gallery_subj_ids = gallery_subj_ids[:, np.newaxis] 78 | if len(probe_subj_ids.shape) == 1: 79 | probe_subj_ids = probe_subj_ids[:, np.newaxis] 80 | 81 | print('\t\tEvaluating distance matrix...') 82 | 83 | distance_matrix = np.dot(probe_features_norm, gallery_features_norm.T)  # cosine similarities (features are L2-normalized): higher means closer 84 | ranking = distance_matrix.argsort(axis=1)[:, ::-1] 85 | ranked_scores = distance_matrix[np.arange(probe_features_norm.shape[0])[:, np.newaxis], ranking] 86 | print('\t\tDistance matrix evaluated!!!') 87 | 88 | gallery_ids_expanded = np.tile(gallery_subj_ids, probe_features_norm.shape[0]).T 89 | gallery_ids_ranked = gallery_ids_expanded[np.arange(probe_features_norm.shape[0])[:, np.newaxis], ranking] 90 | 91 | ranked_gt = (gallery_ids_ranked == probe_subj_ids).astype(np.int8) 92 | 93 | nb_points = 50 94 | points = np.arange(1, nb_points+1) 95 | retrieval_rates = np.empty(shape=(nb_points, 1)) 96 | 97 | if not mAP: 98 | for k in points: 99 | retrieval_rates_ = ranked_gt[:, :k].sum(axis=1) 100 | retrieval_rates_[retrieval_rates_ > 1] = 1 101 | retrieval_rates[k - 1] = np.average(retrieval_rates_) 102 | 103 | return points, retrieval_rates, ranked_scores, ranked_gt 104 | 105 | 106 | def eval_fpir(unmated_query_indeces, distance_matrix): 107 | """ 108 | 109 | Eval FPIR: False Positive Identification Rate. 110 | How many unmated probes have a score higher than a specific threshold 111 | 112 | :param unmated_query_indeces: list of indices (rows of distance_matrix) corresponding to unmated probe subject ids 113 | :param distance_matrix: matrix of similarity scores between probe and gallery features 114 | 115 | :return: thresholds, FPIR 116 | """ 117 | 118 | highest_scores = [] 119 | 120 | for n in unmated_query_indeces: 121 | # n is the index of a query that has no mate in the gallery 122 | # get the scores of a single query against all gallery templates 123 | score = distance_matrix[n] 124 | 125 | # sort in descending order 126 | score = np.sort(score)[::-1] 127 | 128 | # only consider the highest score (the most similar retrieved face) 129 | highest_scores.append(score[0]) 130 | 131 | # the searching step 132 | highest_scores = np.asarray(highest_scores) 133 | min_ = np.min(highest_scores) 134 | max_ = np.max(highest_scores) 135 | step = (max_ - min_) / 1000.  # step for the (commented-out) np.arange search below
136 | thresholds = [] 137 | FPIRs = [] 138 | 139 | for thr in np.logspace(-4, 0, 4000): #np.arange(min_, max_, step): 140 | # for each threshold value, count how many queries returned their most similar 141 | # face above the threshold. In this case we expect the curve to drop as fast as possible 142 | # since none of these queries belongs to any gallery template 143 | current_fpir = np.sum((highest_scores > thr).astype(np.uint8)) / highest_scores.size 144 | 145 | # these thresholds will be reused when calculating the corresponding FNIR 146 | thresholds.append(thr) 147 | FPIRs.append(current_fpir) 148 | 149 | return thresholds, np.asarray(FPIRs) 150 | 151 | 152 | def eval_fnir(mated_indeces, distance_matrix, thresholds, probe_subject_id, gallery_subject_id, r=20): 153 | """ 154 | 155 | Eval FNIR: False Negative Identification Rate. 156 | How many mated matches, within a certain rank, are below a certain threshold. 157 | 158 | :param mated_indeces: indices (rows of the distance matrix) of the queries whose 159 | subject id has a mate in the gallery 160 | :param distance_matrix: matrix of similarity scores between probe and gallery features 161 | :param thresholds: thresholds returned by eval_fpir 162 | :param probe_subject_id: subject ids of the probes 163 | :param gallery_subject_id: subject ids of the gallery templates 164 | :param r: highest rank within which to look for results 165 | 166 | :return: thresholds, FNIR 167 | """ 168 | 169 | gt_scores = [] 170 | for mi in mated_indeces: 171 | 172 | # subject id of the query that has a mate in the gallery 173 | mated_id = probe_subject_id[mi] 174 | 175 | # gallery indices holding that subject id 176 | good_index = np.where(gallery_subject_id == mated_id)[0] 177 | 178 | score = distance_matrix[mi] 179 | 180 | rank = np.argsort(score)[::-1] 181 | 182 | flag = 0 183 | 184 | for i in range(r): 185 | find_ = np.where(good_index == rank[i])[0] 186 | if len(find_) != 0: 187 | gt_scores.append(score[rank[i]])  # score of the mate retrieved at rank i 188 | flag = 1 189 | break 190 | 191 | if flag == 0: 192 | gt_scores.append(0.0) # matching fails 193 | 194 | gt_scores_ar = np.asarray(gt_scores) 195 | FNIRs = [] 196 | 197 | for thr in thresholds: 198 | curr_fnir = np.sum((gt_scores_ar < thr).astype(np.uint8)) / gt_scores_ar.size 199 | FNIRs.append(curr_fnir) 200 | 201 | return thresholds, np.asarray(FNIRs) 202 | -------------------------------------------------------------------------------- /images/paper_training_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fvmassoli/cross-resolution-face-recognition/81f4a88058b3712945d23c39b3d4f9e8b774e266/images/paper_training_algorithm.png -------------------------------------------------------------------------------- /images/vggface_tsne_base_ft_models_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fvmassoli/cross-resolution-face-recognition/81f4a88058b3712945d23c39b3d4f9e8b774e266/images/vggface_tsne_base_ft_models_8.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import logging 4 | import argparse 5 | import numpy as np 6 | 7 | import torch 8 | from torch.optim import SGD 9 | import torch.backends.cudnn as cudnn 10 | from torch.optim.lr_scheduler import ReduceLROnPlateau 11 | 12 | from tensorboardX import SummaryWriter 13 | 14 | from utils import * 15 | from trainer import Trainer 16 | from vggface2_data_manager import VGGFace2DataManager 17 | 18 | 19 | parser =
argparse.ArgumentParser("CR-FR") 20 | # Generic usage 21 | parser.add_argument('-s', '--seed', type=int, default=41, 22 | help='Set random seed (default: 41)') 23 | # Model related options 24 | parser.add_argument('-bp', '--model-base-path', default='./senet50_ft_pytorch.pth', 25 | help='Path to base model checkpoint') 26 | parser.add_argument('-ckp', '--model-ckp', 27 | help='Path to fine tuned model checkpoint') 28 | parser.add_argument('-ep', '--experimental-path', default='experiments_results', 29 | help='Output main path') 30 | parser.add_argument('-tp', '--tensorboard-path', default='experiments_results', 31 | help='Tensorboard main log dir path') 32 | # Training Options 33 | parser.add_argument('-dp', '--dset-base-path', 34 | help='Base path to datasets') 35 | parser.add_argument('-l', '--lambda_', default=0.1, type=float, 36 | help='Lambda for features regression loss (default: 0.1)') 37 | parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, 38 | help='Learning rate (default: 1.e-3)') 39 | parser.add_argument('-m', '--momentum', default=0.9, type=float, 40 | help='Optimizer momentum (default: 0.9)') 41 | parser.add_argument('-nt', '--nesterov', action='store_true', 42 | help='Use Nesterov (default: False)') 43 | parser.add_argument('-lp', '--downsampling-prob', default=0.1, type=float, 44 | help='Downsampling probability (default: 0.1)') 45 | parser.add_argument('-e', '--epochs', type=int, default=1, help='Training epochs (default: 1)') 46 | parser.add_argument('-rs', '--train-steps', type=int, default=1, 47 | help='Set number of training iterations before each validation run (default: 1)') 48 | parser.add_argument('-c', '--curriculum', action='store_true', 49 | help='Use curriculum learning (default: False)') 50 | parser.add_argument('-cs', '--curr-step-iterations', type=int, default=35000, 51 | help='Number of images for each curriculum step (default: 35000)') 52 | parser.add_argument('-sp', '--scheduler-patience', type=int, default=10, 53 | help='Scheduler patience (default: 10)') 54 | parser.add_argument('-b', '--batch-size', type=int, default=32, 55 | help='Batch size (default: 32)') 56 | parser.add_argument('-ba', '--batch-accumulation', type=int, default=8, 57 | help='Batch accumulation iterations (default: 8)') 58 | parser.add_argument('-fr', '--valid-fix-resolution', type=int, default=8, 59 | help='Resolution on validation images (default: 8)') 60 | parser.add_argument('-nw', '--num-workers', type=int, default=8, 61 | help='Number of workers (default: 8)') 62 | args = parser.parse_args() 63 | 64 | 65 | # ----------------------------- GENERAL ---------------------------------------- 66 | tmp = ( 67 | f"{args.lambda_}-{args.learning_rate}-{args.downsampling_prob}-" 68 | f"{args.train_steps}-{args.curriculum}-{args.curr_step_iterations}" 69 | ) 70 | 71 | out_dir = os.path.join(args.experimental_path, tmp) 72 | if not os.path.exists(out_dir): 73 | os.makedirs(out_dir) 74 | 75 | logging.basicConfig( 76 | level=logging.INFO, 77 | format="%(asctime)s | %(message)s", 78 | handlers=[ 79 | logging.FileHandler(os.path.join(out_dir, 'training.log')), 80 | logging.StreamHandler() 81 | ]) 82 | logger = logging.getLogger() 83 | 84 | tb_writer = SummaryWriter(os.path.join(args.tensorboard_path, 'tb_runs', tmp)) 85 | 86 | logging.info(f"Training outputs will be saved at: {out_dir}") 87 | # ------------------------------------------------------------------------------ 88 | 89 | 90 | # --------------------------- CUDA SET UP -------------------------------------- 91 | 
92 | 
93 | np.random.seed(args.seed)
94 | torch.manual_seed(args.seed)
95 | torch.cuda.manual_seed(args.seed)
96 | 
97 | cuda_available = torch.cuda.is_available()
98 | device = torch.device('cuda' if cuda_available else 'cpu')
99 | # ------------------------------------------------------------------------------
100 | 
101 | 
102 | # ---------------- LOAD MODEL & OPTIMIZER & SCHEDULER --------------------------
103 | sm, tm = load_models(args.model_base_path, device, args.model_ckp)
104 | optimizer = SGD(
105 |     params=sm.parameters(),
106 |     lr=args.learning_rate,
107 |     momentum=args.momentum,
108 |     weight_decay=5e-04,
109 |     nesterov=args.nesterov
110 | )
111 | scheduler = ReduceLROnPlateau(
112 |     optimizer=optimizer,
113 |     mode='min',
114 |     factor=0.5,
115 |     patience=args.scheduler_patience,
116 |     verbose=True,
117 |     min_lr=1.e-7,
118 |     threshold=0.1
119 | )
120 | # ------------------------------------------------------------------------------
121 | 
122 | 
123 | # ---------------------------- LOAD DATA ---------------------------------------
124 | kwargs = {
125 |     'batch_size': args.batch_size,
126 |     'downsampling_prob': args.downsampling_prob,
127 |     'curriculum': args.curriculum,
128 |     'curr_step_iterations': args.curr_step_iterations,
129 |     'algo_name': 'bilinear',
130 |     'algo_val': PIL.Image.BILINEAR,
131 |     'valid_fix_resolution': args.valid_fix_resolution,
132 |     'num_of_workers': args.num_workers
133 | }
134 | data_manager = VGGFace2DataManager(
135 |     dataset_path=args.dset_base_path,
136 |     img_folders=['train', 'validation'],
137 |     transforms=[get_transforms(mode='train'), get_transforms(mode='eval')],
138 |     device=device,
139 |     logging=logging,
140 |     **kwargs
141 | )
142 | # ------------------------------------------------------------------------------
143 | 
144 | 
145 | if __name__ == '__main__':
146 |     Trainer(
147 |         student=sm,
148 |         teacher=tm,
149 |         optimizer=optimizer,
150 |         scheduler=scheduler,
151 |         loaders=data_manager.get_loaders(),
152 |         device=device,
153 |         batch_accumulation=args.batch_accumulation,
154 |         lambda_=args.lambda_,
155 |         train_steps=args.train_steps,
156 |         out_dir=out_dir,
157 |         tb_writer=tb_writer,
158 |         logging=logging
159 |     ).train(args.epochs)
160 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | # Requires Python 3.6.9 (the interpreter itself is not a pip-installable package)
 2 | h5py==2.10.0
 3 | matplotlib==3.2.0
 4 | notebook==6.0.3
 5 | numpy==1.18.4
 6 | opencv-python==4.0.0.21
 7 | pandas==1.0.4
 8 | # pickle (protocol 4) is part of the standard library and must not be pip-installed
 9 | Pillow==7.0.0
10 | psutil==5.6.7
11 | pymc3==3.7
12 | scikit-image==0.14.2
13 | scikit-learn==0.22.2.post1
14 | scipy==1.4.1
15 | seaborn==0.9.0
16 | # sklearn==0.0 removed: it is only a deprecated alias of scikit-learn (pinned above)
17 | statsmodels==0.11.1
18 | tensorboard==1.8.0
19 | tensorboardX==2.0
20 | torch==1.4.0
21 | torchfile==0.1.0
22 | torchvision==0.4.2
23 | tqdm==4.31.1
24 | 
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from tqdm import tqdm
 3 | import torch
 4 | import torch.nn.functional as F
 5 | 
 6 | from utils import *
 7 | 
 8 | 
 9 | class Trainer(object):
10 |     def __init__(self, student, teacher, optimizer, scheduler, loaders, device, batch_accumulation,
11 |                  lambda_, train_steps, out_dir, tb_writer, logging):
12 |         self._student = student
13 |         self._teacher = teacher
14 |         self._optimizer = optimizer
15 |         self._scheduler = scheduler
16 |         self._train_loader, self._valid_loader_lr, self._valid_loader = loaders
17 |         self._device = device
18 |         self._batch_accumulation = batch_accumulation
19 |         self._lambda = lambda_
20 |         self._train_steps = train_steps
21 |         self._out_dir = out_dir
22 |         self._tb_writer = tb_writer
23 |         self._logging = logging
24 |         self._it_t = 0
25 |         self._it_v = 0
26 | 
27 |     def _eval_batch(self, loader_idx, data):
28 |         curr_index = -1
29 |         downsampling_prob = -1
30 |         if loader_idx == 0:  # ImageFolder for original sized images
31 |             batch_original = batch = data[0]
32 |             labels = data[1]
33 |         else:  # custom data set for down sampled images
34 |             batch = data[0]
35 |             batch_original = data[1]
36 |             labels = data[2]
37 |             curr_index = data[3]
38 |             downsampling_prob = data[4]
39 |         teacher_features, teacher_logits = self._teacher(batch_original.to(self._device))
40 |         student_features, student_logits = self._student(batch.to(self._device))
41 |         correct = (student_logits.argmax(dim=1).cpu() == labels).sum().item()
42 |         loss = F.cross_entropy(student_logits, labels.to(self._device)) + self._lambda * F.mse_loss(student_features, teacher_features)
43 |         return loss, student_logits, labels, correct, curr_index, downsampling_prob
44 | 
45 |     def _train(self, epoch):
46 |         self._logging.info("#"*30)
47 |         self._logging.info(f'Training at epoch: {epoch}')
48 |         self._logging.info("#"*30)
49 | 
50 |         self._student.train()
51 |         self._optimizer.zero_grad()
52 | 
53 |         j = 1
54 |         loss_ = 0
55 |         best_acc = 0
56 |         correct_ = 0
57 |         n_samples_ = 0
58 |         nb_backward_steps = 0
59 | 
60 |         for batch_idx, data in enumerate(self._train_loader, 1):
61 | 
62 |             if nb_backward_steps == self._train_steps:
63 |                 nb_backward_steps = 0
64 |                 v_l_, tmp_best_acc = self._val(epoch)
65 |                 self._student.train()
66 |                 self._scheduler.step(v_l_, epoch+1)
67 |                 ## Save best model
68 |                 if tmp_best_acc > best_acc:
69 |                     best_acc = tmp_best_acc
70 |                     save_model_checkpoint(
71 |                         best_acc,
72 |                         batch_idx,
73 |                         epoch,
74 |                         self._student.state_dict(),
75 |                         self._out_dir,
76 |                         self._logging
77 |                     )
78 | 
79 |             # loader_idx=-1 routes _eval_batch to the custom-dataset branch
80 |             loss, logits, labels, correct, curr_index, downsampling_prob = self._eval_batch(loader_idx=-1, data=data)
81 | 
82 |             loss_ += loss.item()
83 |             correct_ += correct
84 |             n_samples_ += labels.shape[0]
85 | 
86 |             loss.backward()
87 |             if j % self._batch_accumulation == 0:
88 |                 self._logging.info(
89 |                     f'Train [{epoch}] - [{batch_idx}]/[{len(self._train_loader)}]:'
90 |                     f'\n\t\t\tLoss LR: {loss_/batch_idx:.3f} --- Acc LR: {(correct_/n_samples_)*100:.2f}%'
91 |                     f'\n\t\t\tcurr_index: {curr_index[0]} --- downsampling_prob: {downsampling_prob[0]}'
92 |                 )
93 |                 if nb_backward_steps % 5 == 1:
94 |                     self._it_t += 1
95 |                     self._tb_writer.add_scalar('train/loss', loss_/batch_idx, self._it_t)
96 |                     self._tb_writer.add_scalar('train/accuracy', correct_/n_samples_, self._it_t)
97 | 
98 |                 j = 1
99 |                 nb_backward_steps += 1
100 |                 self._optimizer.step()
101 |                 self._optimizer.zero_grad()
102 | 
103 |             else:
104 |                 j += 1
105 | 
106 |     def _val(self, epoch):
107 |         self._student.eval()
108 | 
109 |         with torch.no_grad():
110 |             for loader_idx, local_loader in enumerate([self._valid_loader, self._valid_loader_lr]):
111 |                 loss_ = 0.0
112 |                 correct_ = 0.0
113 |                 n_samples = 0
114 |                 desc = 'Validation HR' if loader_idx == 0 else 'Validation LR'
115 | 
116 |                 for batch_id, data in enumerate(tqdm(local_loader, total=len(local_loader), desc=desc, leave=False)):
117 |                     loss, logits, labels, correct, _, _ = self._eval_batch(loader_idx, data)
118 | 
119 |                     loss_ += loss.item()
120 |                     correct_ += correct
121 |                     n_samples += labels.shape[0]
122 | 
123 |                 loss_ = loss_ / len(local_loader)
124 |                 acc_ = (correct_ / n_samples) * 100
125 | 
126 |                 if loader_idx == 0:
127 |                     self._logging.info(f'Valid loss HR: {loss_:.3f} --- Valid acc HR: {acc_:.2f}%')
128 |                     self._tb_writer.add_scalar('validation/loss_hr', loss_, self._it_v)
129 |                     self._tb_writer.add_scalar('validation/accuracy_hr', acc_, self._it_v)
130 |                 else:
131 |                     self._logging.info(f'Valid loss LR: {loss_:.3f} --- Valid acc LR: {acc_:.2f}%')
132 |                     self._tb_writer.add_scalar('validation/loss_lr', loss_, self._it_v)
133 |                     self._tb_writer.add_scalar('validation/accuracy_lr', acc_, self._it_v)
134 | 
135 |                 self._it_v += 1
136 | 
137 |         # note: loss_ and acc_ come from the last loader iterated, i.e. the low-resolution one
138 |         return loss_, acc_
139 | 
140 |     def train(self, epochs):
141 |         self._val(0)
142 |         for epoch in range(1, epochs+1):
143 |             self._train(epoch)
144 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import torch
 4 | import torch.nn as nn
 5 | import torchvision.transforms as t
 6 | 
 7 | 
 8 | def load_models(model_base_path, device="cpu", model_ckp=None):
 9 |     assert os.path.exists(model_base_path), f"Base model checkpoint not found at: {model_base_path}"
10 |     sm = torch.load(model_base_path)
11 |     tm = torch.load(model_base_path)
12 |     if model_ckp is not None:
13 |         assert os.path.exists(model_ckp), f"Model checkpoint not found at: {model_ckp}"
14 |         ckp = torch.load(model_ckp, map_location='cpu')
15 |         for n, p in sm.named_parameters():
16 |             p.data.copy_(torch.from_numpy(ckp['model_state_dict'][n].numpy()))
17 |         for n, m in sm.named_modules():
18 |             if isinstance(m, nn.BatchNorm2d):
19 |                 m.momentum = 0.1
20 |                 m.running_var = ckp['model_state_dict'][n + '.running_var']
21 |                 m.running_mean = ckp['model_state_dict'][n + '.running_mean']
22 |                 m.num_batches_tracked = ckp['model_state_dict'][n + '.num_batches_tracked']
23 |     ## Freeze all params for the teacher model
24 |     for param in tm.parameters():
25 |         param.requires_grad = False
26 |     return sm.to(device), tm.to(device)
27 | 
28 | 
29 | def save_model_checkpoint(best_acc, batch_idx, epoch, model_state_dict, out_dir, logging):
30 |     state_dict = {
31 |         'best_acc': best_acc,
32 |         'epoch': epoch,
33 |         'model_state_dict': model_state_dict
34 |     }
35 |     file_name = os.path.join(out_dir, f'models_ckp_{epoch}_{batch_idx}.pth')
36 |     torch.save(state_dict, file_name)
37 |     logging.info(
38 |         f"Saved model with best acc: {best_acc}"
39 |         f"\nAt epoch: {epoch}"
40 |         f"\nAt iter: {batch_idx}"
41 |         f"\nModel saved at: {file_name}"
42 |     )
43 | 
44 | 
45 | def get_transforms(mode, resize=256, grayed_prob=0.2, crop_size=224):
46 |     def subtract_mean(x):
47 |         # BGR mean pixel values of the VGGFace2 training data (images are loaded with OpenCV, hence BGR order)
48 |         mean_vector = [91.4953, 103.8827, 131.0912]
49 |         x *= 255.
50 |         x[0] -= mean_vector[0]
51 |         x[1] -= mean_vector[1]
52 |         x[2] -= mean_vector[2]
53 |         return x
54 |     if mode == 'train':
55 |         return t.Compose([
56 |             t.Resize(resize),
57 |             t.RandomGrayscale(p=grayed_prob),
58 |             t.RandomCrop(crop_size),
59 |             t.ToTensor(),
60 |             t.Lambda(lambda x: subtract_mean(x))
61 |         ])
62 |     else:
63 |         return t.Compose([
64 |             t.Resize(resize),
65 |             t.CenterCrop(crop_size),
66 |             t.ToTensor(),
67 |             t.Lambda(lambda x: subtract_mean(x))
68 |         ])
--------------------------------------------------------------------------------
/vggface2_custom_dataset.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import cv2
 3 | import sys
 4 | import torch
 5 | from tqdm import tqdm
 6 | from PIL import Image
 7 | 
 8 | from torch.utils.data import Dataset
 9 | 
10 | 
11 | class VGGFace2Dataset(Dataset):
12 |     def __init__(self, root, transforms, train, logging, **kwargs):
13 |         self._root = root
14 |         self.transforms = transforms
15 |         self._train = train
16 |         self._curr_step_iterations = kwargs['curr_step_iterations']
17 |         self._algo_name = kwargs['algo_name']
18 |         self._algo = kwargs['algo_val']
19 |         self._curriculum = kwargs['curriculum']
20 |         self._curriculum_index = 0
21 |         if self._train:
22 |             self._downsampling_prob = 0.1 if self._curriculum else kwargs['downsampling_prob']
23 |         else:
24 |             self._downsampling_prob = 1.0  # validation images are always down sampled
25 |         self._valid_resolution = kwargs['valid_fix_resolution']
26 |         self._classes, self._class_to_idx = self._find_classes()
27 |         self._samples = self._make_dataset()
28 |         self._loader = self._get_loader
29 |         tr = 'training' if self._train else 'validation'
30 |         logging.info(
31 |             f'VGGFace2 custom {tr} dataset info:'
32 |             f'\n\t\t\t\tRoot folder: {self._root}'
33 |             f'\n\t\t\t\tDownsampling prob: {self._downsampling_prob}'
34 |             f'\n\t\t\t\tUse Curriculum: {self._curriculum and self._train}'
35 |             f'\n\t\t\t\tValid resolution: {self._valid_resolution}'
36 |         )
37 | 
38 |     def _find_classes(self):
39 |         if sys.version_info >= (3, 5):
40 |             classes = [d.name for d in os.scandir(self._root) if d.is_dir()]
41 |         else:
42 |             classes = [d for d in os.listdir(self._root) if os.path.isdir(os.path.join(self._root, d))]
43 |         classes.sort()
44 |         class_to_idx = {classes[i]: i for i in range(len(classes))}
45 |         return classes, class_to_idx
46 | 
47 |     def _make_dataset(self):
48 |         images = []
49 |         root_dir = os.path.expanduser(self._root)
50 |         progress_bar = tqdm(
51 |             sorted(self._class_to_idx.keys()),
52 |             desc='Building training set' if self._train else 'Building validation set',
53 |             total=len(self._class_to_idx.keys()),
54 |             leave=False
55 |         )
56 |         for target in progress_bar:
57 |             d = os.path.join(root_dir, target)
58 |             if not os.path.isdir(d):
59 |                 continue
60 |             for root, _, fnames in sorted(os.walk(d)):
61 |                 for fname in sorted(fnames):
62 |                     path = os.path.join(root, fname)
63 |                     item = (path, self._class_to_idx[target])
64 |                     images.append(item)
65 |         # tqdm advances automatically while iterating over progress_bar
66 |         progress_bar.close()
67 |         return images
68 | 
69 |     @staticmethod
70 |     def _get_loader(path):
71 |         # cv2 reads images in BGR order, matching the BGR mean subtraction in utils.py
72 |         return Image.fromarray(cv2.imread(path))
73 | 
74 |     def _lower_resolution(self, img):
75 |         w_i, h_i = img.size
76 |         r = h_i/float(w_i)
77 |         if self._train:
78 |             # random target resolution in {8, 16, 32, 64, 128} pixels on the shorter side
79 |             res = torch.rand(1).item()
80 |             res = 3 + 5*res
81 |             res = 2**int(res)
82 |         else:
83 |             res = self._valid_resolution
84 |         if res >= w_i or res >= h_i:
85 |             return img
86 |         if h_i < w_i:
87 |             h_n = res
88 |             w_n = h_n/float(r)
89 |         else:
90 |             w_n = res
91 |             h_n = w_n*float(r)
92 |         img2 = img.resize((int(w_n), int(h_n)), self._algo)
93 |         img2 = img2.resize((w_i, h_i), self._algo)
94 |         return img2
95 | 
96 |     def __len__(self):
97 |         return len(self._samples)
98 | 
99 |     def __getitem__(self, idx):
100 |         if self._train and self._curriculum:
101 |             # NB: with num_workers > 0 each DataLoader worker holds its own copy of the
102 |             # dataset, so the curriculum counter advances independently in every worker
103 |             self._curriculum_index += 1
104 |             if (self._curriculum_index % self._curr_step_iterations) == 0 and self._downsampling_prob < 1.0:
105 |                 self._downsampling_prob += 0.1
106 |         path, label = self._samples[idx]
107 |         img = self._loader(path)
108 |         orig_img = self._loader(path)
109 |         if torch.rand(1).item() < self._downsampling_prob:
110 |             img = self._lower_resolution(img)
111 |         if self.transforms:
112 |             img = self.transforms(img)
113 |             orig_img = self.transforms(orig_img)
114 |         return img, orig_img, label, torch.tensor(self._curriculum_index), torch.tensor(self._downsampling_prob)
--------------------------------------------------------------------------------
/vggface2_data_manager.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | 
 4 | from torchvision.datasets import ImageFolder
 5 | from torch.utils.data import DataLoader, Subset
 6 | 
 7 | from vggface2_custom_dataset import VGGFace2Dataset
 8 | 
 9 | 
10 | class VGGFace2DataManager():
11 |     def __init__(self, dataset_path, img_folders, transforms, device, logging, **kwargs):
12 |         self._dataset_path = dataset_path
13 |         self._train_img_folders = img_folders[0]
14 |         self._valid_img_folders = img_folders[1]
15 |         self._train_transforms = transforms[0]
16 |         self._valid_transforms = transforms[1]
17 |         # compare the device type explicitly: equality between a torch.device and a
18 |         # plain string is not reliable across PyTorch versions
19 |         self._use_cuda = device.type == 'cuda'
20 |         self._logging = logging
21 |         self._kwargs = kwargs
22 |         self._batch_size = kwargs['batch_size']
23 |         self._num_of_workers = kwargs['num_of_workers']
24 |         self._datasets = self._init_datasets()
25 |         self._data_loaders = self._init_data_loaders()
26 |         self._print_summary()
27 | 
28 |     def _init_datasets(self):
29 |         self._logging.info('Initializing VGGFace2 data sets...')
30 |         train_dataset = VGGFace2Dataset(
31 |             root=os.path.join(self._dataset_path, self._train_img_folders),
32 |             transforms=self._train_transforms,
33 |             train=True,
34 |             logging=self._logging,
35 |             **self._kwargs
36 |         )
37 |         valid_dataset_lr = VGGFace2Dataset(
38 |             root=os.path.join(self._dataset_path, self._valid_img_folders),
39 |             transforms=self._valid_transforms,
40 |             train=False,
41 |             logging=self._logging,
42 |             **self._kwargs
43 |         )
44 |         valid_dataset = ImageFolder(
45 |             root=os.path.join(self._dataset_path, self._valid_img_folders),
46 |             transform=self._valid_transforms
47 |         )
48 |         self._logging.info('Datasets initialized!')
49 |         return train_dataset, valid_dataset_lr, valid_dataset
50 | 
51 |     def _init_data_loaders(self):
52 |         self._logging.info('Initializing VGGFace2 data loaders...')
53 |         train_data_loader = DataLoader(
54 |             dataset=self._datasets[0],
55 |             batch_size=self._batch_size,
56 |             shuffle=True,
57 |             num_workers=self._num_of_workers,
58 |             pin_memory=self._use_cuda
59 |         )
60 |         # validate on every 30th image only and keep the second half of that subsample;
61 |         # the same indices are used for both validation loaders so they cover the same
62 |         # images (both datasets enumerate the class folders in sorted order)
63 |         dataset_len = len(self._datasets[1])
64 |         indices = list(np.arange(0, dataset_len, 30))
65 |         split = int(np.floor(len(indices) * 0.5))
66 |         valid_indices = indices[split:]
67 |         tmp_valid_dataset_lr = Subset(self._datasets[1], valid_indices)
68 |         tmp_valid_dataset = Subset(self._datasets[2], valid_indices)
69 |         valid_data_loader_lr = DataLoader(
70 |             dataset=tmp_valid_dataset_lr,
71 |             batch_size=self._batch_size,
72 |             num_workers=self._num_of_workers,
73 |             pin_memory=self._use_cuda
74 |         )
75 |         valid_data_loader = DataLoader(
76 |             dataset=tmp_valid_dataset,
77 |             batch_size=self._batch_size,
78 |             num_workers=self._num_of_workers,
79 |             pin_memory=self._use_cuda
80 |         )
81 |         return train_data_loader, valid_data_loader_lr, valid_data_loader
82 | 
83 |     def _print_summary(self):
84 |         self._logging.info("VGGFace2 data summary:")
85 |         self._logging.info(
86 |             f'\tBatch size: {self._batch_size}'
87 |             f'\n\t\t\t\tNumber of workers: {self._num_of_workers}'
88 |             f'\n\t\t\t\tTraining images: {len(self._data_loaders[0].dataset)}'
89 |             f'\n\t\t\t\tTraining batches: {len(self._data_loaders[0])}'
90 |             f'\n\t\t\t\tValidation images: {len(self._data_loaders[1].dataset)}'
91 |             f'\n\t\t\t\tValidation batches: {len(self._data_loaders[1])}'
92 |             f'\n\t\t\t\tValidation original images: {len(self._data_loaders[2].dataset)}'
93 |             f'\n\t\t\t\tValidation original batches: {len(self._data_loaders[2])}'
94 |             f'\n\t\t\t\tPin Memory: {self._use_cuda}\n'
95 |         )
96 | 
97 |     def get_loaders(self):
98 |         return self._data_loaders
99 | 
--------------------------------------------------------------------------------
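The data manager expects dataset_path to contain a 'train' and a 'validation' folder (the names are hard-coded via img_folders in main.py), each holding one sub-folder per identity: this is both what ImageFolder requires and what VGGFace2Dataset._find_classes scans for. With VGGFace2-style identity folders the layout would look roughly like the illustrative tree below (folder and file names are examples only):

<dset-base-path>/
├── train/
│   ├── n000002/
│   │   ├── 0001_01.jpg
│   │   └── ...
│   └── n000003/
│       └── ...
└── validation/
    ├── n000001/
    │   └── ...
    └── ...

A typical invocation (with hypothetical values, using the flags defined in main.py) would then be: python main.py -dp <dset-base-path> -b 32 -e 1 -c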
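For inference, the helpers in utils.py are enough to rebuild a fine-tuned student model outside the training loop. Below is a minimal sketch, assuming a checkpoint produced by save_model_checkpoint exists at the hypothetical path used; note that images should be read with OpenCV so that the BGR channel order matches the BGR mean subtraction performed inside get_transforms.

import cv2
import torch
from PIL import Image

from utils import load_models, get_transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hypothetical checkpoint path, following the models_ckp_<epoch>_<iter>.pth naming
student, _ = load_models('./senet50_ft_pytorch.pth', device, model_ckp='models_ckp_1_100.pth')
student.eval()

transform = get_transforms(mode='eval')
# BGR image, mirroring VGGFace2Dataset._get_loader ('face.jpg' is a placeholder)
img = Image.fromarray(cv2.imread('face.jpg'))

with torch.no_grad():
    # the model returns (features, logits), as assumed by Trainer._eval_batch
    features, logits = student(transform(img).unsqueeze(0).to(device))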
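The two metric functions in evaluation_protocols/classifier_metrics/metrics.py are designed to be used as a pair: the FPIR sweep fixes the thresholds, and eval_fnir re-evaluates the mated searches at those same thresholds, so that (FPIR, FNIR) pairs trace a DET-style curve. A minimal usage sketch follows, assuming metrics.py is importable (e.g. the script runs next to it) and that the "distance_matrix" argument holds similarity scores, higher meaning more similar, as the descending argsort in eval_fnir implies; every input and variable name below is synthetic/hypothetical.

import numpy as np

from metrics import eval_fnir  # assuming we run alongside metrics.py

rng = np.random.RandomState(0)

# hypothetical L2-normalised embeddings: 50 probes against a gallery of 20 templates
probe_feats = rng.randn(50, 256)
probe_feats /= np.linalg.norm(probe_feats, axis=1, keepdims=True)
gallery_feats = rng.randn(20, 256)
gallery_feats /= np.linalg.norm(gallery_feats, axis=1, keepdims=True)
probe_ids = rng.randint(0, 40, size=50)   # some probe subjects have no gallery mate
gallery_ids = np.arange(20)

similarity = probe_feats @ gallery_feats.T   # cosine similarities (higher = better)

# FPIR on the non-mated probes, mirroring the loop in metrics.py: for each
# threshold, the fraction of impostor queries whose best gallery score exceeds it
non_mated = np.where(~np.isin(probe_ids, gallery_ids))[0]
highest_scores = similarity[non_mated].max(axis=1)
thresholds = list(np.logspace(-4, 0, 4000))
FPIRs = np.asarray([np.mean(highest_scores > thr) for thr in thresholds])

# FNIR on the mated probes, reusing the same thresholds
mated_indices = np.where(np.isin(probe_ids, gallery_ids))[0]
_, FNIRs = eval_fnir(mated_indices, similarity, thresholds, probe_ids, gallery_ids, r=20)

# (FPIRs, FNIRs) now parameterise a DET-style curve over the threshold sweep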