├── etc ├── readme └── groupface.pptx ├── README.md ├── models ├── GroupFace.py └── resnet.py ├── dataloader.py └── loss └── loss.py /etc/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GroupFace 2 | https://arxiv.org/abs/2005.10497 3 | -------------------------------------------------------------------------------- /etc/groupface.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungyounShin/GroupFace/HEAD/etc/groupface.pptx -------------------------------------------------------------------------------- /models/GroupFace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from resnet import * 4 | 5 | backbone = {18 : resnet_face18(), 6 | 50 : resnet_face50(), 7 | 101 : resnet_face101()} 8 | 9 | class FC(nn.Module): 10 | def __init__(self, inplanes, outplanes): 11 | super(FC, self).__init__() 12 | self.fc = nn.Linear(inplanes, outplanes) 13 | self.bn = nn.BatchNorm1d(outplanes) 14 | self.act = nn.PReLU() 15 | def forward(self, x): 16 | x = self.fc(x) 17 | x = self.bn(x) 18 | return self.act(x) 19 | 20 | class GDN(nn.Module): 21 | def __init__(self, inplanes, outplanes, intermediate_dim = 256): 22 | super(GDN, self).__init__() 23 | self.fc1 = FC(inplanes, intermediate_dim) 24 | self.fc2 = FC(intermediate_dim, outplanes) 25 | self.softmax = nn.Softmax() 26 | def forward(self, x): 27 | intermediate = self.fc1(x) 28 | out = self.fc2(intermediate) 29 | return intermediate, self.softmax(out) 30 | 31 | class GroupFace(nn.Module): 32 | def __init__(self, resnet=18, feature_dim = 512 ,groups = 4, mode='S'): 33 | super(GroupFace, self).__init__() 34 | self.mode = mode 35 | self.groups = groups 36 | self.Backbone = backbone[resnet] 37 | self.instance_fc = FC(4096, feature_dim) 38 | self.GDN = GDN(feature_dim, groups) 39 | self.group_fc = nn.ModuleList([FC(4096,feature_dim) for i in range(groups)]) 40 | 41 | def forward(self, x): 42 | B = x.shape[0] 43 | x = self.Backbone(x) #(B,4096) 44 | instacne_representation = self.instance_fc(x) 45 | 46 | #GDN 47 | group_inter, group_prob = self.GDN(instacne_representation) 48 | print(group_prob) 49 | #group aware repr 50 | v_G = [Gk(x) for Gk in self.group_fc] #(B,512) 51 | 52 | #self distributed labeling 53 | group_label_p = group_prob.data 54 | group_label_E = group_label_p.mean(dim=0) 55 | group_label_u = (group_label_p - group_label_E.unsqueeze(dim=-1).expand(self.groups,B).T)/self.groups + (1/self.groups) 56 | group_label = torch.argmax(group_label_u, dim=1).data 57 | 58 | #group ensemble 59 | group_mul_p_vk = list() 60 | if self.mode == 'S': 61 | for k in range(self.groups): 62 | Pk = group_prob[:,k].unsqueeze(dim=-1).expand(B,512) 63 | group_mul_p_vk.append(torch.mul(v_G[k], Pk)) 64 | group_ensembled = torch.stack(group_mul_p_vk).sum(dim=0) 65 | #instance , group aggregation 66 | final = instacne_representation + group_ensembled 67 | return group_inter, final , group_prob, group_label 68 | 69 | if __name__=="__main__": 70 | x = torch.randn(5,3,112,112) 71 | model = GroupFace(resnet=101) 72 | out = model(x) 73 | print("==output==") 74 | print(out[0].shape, out[1].shape, out[2].shape, out[3].shape) 75 | print(torch.argmax(out[2], dim=1), '\n',out[3]) 76 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torch.utils.data import DataLoader 3 | import os,cv2,random 4 | from itertools import combinations 5 | import mxnet as mx 6 | from tqdm import tqdm 7 | 8 | class CustomFace(Dataset): 9 | def __init__(self,root,transform=None, align_path="/home/yo0n/workspace2/celebrity_lmk", sample=None): 10 | self.root = root 11 | self.identities = [i for i in os.listdir(self.root) if "DS" not in i] 12 | if sample is not None: 13 | self.identities = random.sample(self.identities, sample) 14 | self.img_paths = list() 15 | for iden in self.identities: 16 | self.img_paths += [root + "/"+ iden+ "/"+i for i in os.listdir(root + "/" + iden) if "DS" not in i] 17 | self.transform = transform 18 | self.align_path = align_path 19 | if self.align_path is not None: 20 | f = open(align_path, 'r') 21 | label_txt = f.read() 22 | f.close() 23 | self.label_txt = label_txt.split('\n') 24 | self.label_txt = [i.split(' ') for i in self.label_txt] 25 | self.label_dict = dict() 26 | self.pad = 112 27 | for i in tqdm(range(len(self.label_txt))): 28 | align_info = self.label_txt[i][2:] 29 | align_info = [float(j) for j in align_info] 30 | xalign = [j for idx,j in enumerate(align_info) if idx%2==0] 31 | yalign = [j for idx,j in enumerate(align_info) if (idx+1)%2==0] 32 | xs,ys = sum(xalign)/5.-self.pad, sum(yalign)/5.-self.pad 33 | self.label_dict[self.label_txt[i][0]] = (int(xs),int(ys)) 34 | #print(self.label_dict) 35 | 36 | 37 | def find_alignment(self,idx): 38 | img_path = self.img_paths[idx] 39 | id = img_path.split('/')[-2] 40 | imgP = img_path.split('/')[-1] 41 | query = id + "/" + imgP 42 | for i in range(len(self.label_txt)): 43 | if query in self.label_txt[i]: 44 | align_info = self.label_txt[idx].split(' ')[2:] 45 | align_info = [float(i) for i in align_info] 46 | xalign = [i for idx,i in enumerate(align_info) if idx%2==0] 47 | yalign = [i for idx,i in enumerate(align_info) if (idx+1)%2==0] 48 | xs,ys = sum(xalign)/5.-self.pad, sum(yalign)/5.-self.pad 49 | return int(xs),int(ys) 50 | return None 51 | 52 | def __getitem__(self, idx): 53 | img_path = self.img_paths[idx] 54 | #print(img_path) 55 | iden = self.identities.index(img_path.split("/")[-2]) 56 | img = cv2.imread(img_path) 57 | if self.align_path is not None: 58 | ys,xs = self.label_dict[img_path[img_path.index('celebrity'):]] 59 | img = img[xs:xs+self.pad*2, ys:ys+self.pad*2] 60 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 61 | if self.transform is not None: 62 | img = self.transform(image=img) 63 | return img['image'],iden 64 | return img, iden 65 | 66 | 67 | def __len__(self): 68 | return len(self.img_paths) 69 | 70 | 71 | class CustomFaceValid(Dataset): 72 | def __init__(self,root, transform=None): 73 | self.root = root 74 | self.identities = [i for i in os.listdir(self.root) if "DS" not in i] 75 | self.img_paths = list() 76 | for iden in self.identities: 77 | self.img_paths += [root + "/"+ iden+ "/"+i for i in os.listdir(root + "/" + iden) if "DS" not in i] 78 | self.transform = transform 79 | self.pairs = list(combinations(self.img_paths, 2)) 80 | pos_pairs = [i for i in self.pairs if i[0].split('/')[-2] == i[1].split('/')[-2] ] 81 | neg_pairs = [i for i in self.pairs if i[0].split('/')[-2] != i[1].split('/')[-2] ] 82 | 83 | neg_pairs = random.sample(neg_pairs, k=len(pos_pairs)) 84 | self.pairs = pos_pairs + neg_pairs 85 | 86 | 87 | def __getitem__(self, idx): 88 | pair_paths = self.pairs[idx] 89 | iden_path1, iden_path2 = pair_paths[0], pair_paths[1] 90 | iden1, iden2 = iden_path1.split('/')[-2], iden_path2.split('/')[-2] 91 | label = 0 92 | if iden1 == iden2: 93 | label = 1 94 | 95 | img1 = cv2.imread(iden_path1) 96 | img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB) 97 | img1 = cv2.resize(img1, dsize=(112, 112), interpolation=cv2.INTER_AREA) 98 | 99 | img2 = cv2.imread(iden_path2) 100 | img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB) 101 | img2 = cv2.resize(img2, dsize=(112, 112), interpolation=cv2.INTER_AREA) 102 | return img1, img2, label 103 | 104 | 105 | def __len__(self): 106 | return len(self.pairs) 107 | 108 | class MS1M(Dataset): 109 | def __init__(self,root="/home/yo0n/workspace2/ms1m-retinaface-t1", transform=None): 110 | self.root = root 111 | self.record = mx.recordio.MXIndexedRecordIO("/home/yo0n/workspace2/ms1m-retinaface-t1/train.idx", "/home/yo0n/workspace2/ms1m-retinaface-t1/train.rec", 'r') 112 | #print(self.record.read_idx(20)) 113 | self.transform = transform 114 | self.identities = [i for i in range(93431)] 115 | 116 | def __getitem__(self, idx): 117 | img_mxnet = self.record.read_idx(idx+1) 118 | header, img = mx.recordio.unpack_img(img_mxnet) 119 | 120 | iden = int(header.label[0]) 121 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 122 | if self.transform is not None: 123 | img = self.transform(image=img) 124 | 125 | return img['image'], iden 126 | 127 | 128 | def __len__(self): 129 | return 5179509 130 | 131 | if __name__=="__main__": 132 | import matplotlib.pyplot as plt 133 | import random 134 | """ 135 | ms1m = MS1M() 136 | idx = random.randint(0, len(ms1m)) 137 | print(idx) 138 | img, iden = ms1m[0] 139 | 140 | print(iden) 141 | #plt.imshow(img) 142 | #plt.show() 143 | 144 | 145 | max = -9 146 | for i in range(5179509): 147 | img, iden = ms1m[i] 148 | if(max < iden): 149 | max = iden 150 | print(max) 151 | """ 152 | dataset = CustomFace("/sdb/celebrity") 153 | while True: 154 | idx = random.randint(0, len(dataset)) 155 | print(idx) 156 | img, iden = dataset[idx] 157 | 158 | print(iden) 159 | plt.imshow(img) 160 | plt.show() 161 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from math import pi 6 | from torch.nn import Parameter 7 | import math 8 | 9 | class LiArcFace(nn.Module): 10 | def __init__(self, num_classes, emb_size=512, m=0.45, s=64.0): 11 | super().__init__() 12 | self.weight = nn.Parameter(torch.empty(num_classes, emb_size)).cuda() 13 | nn.init.xavier_normal_(self.weight) 14 | self.m = m 15 | self.s = s 16 | 17 | def forward(self, input, label): 18 | W = F.normalize(self.weight) 19 | input = F.normalize(input) 20 | cosine = input @ W.t() 21 | theta = torch.acos(cosine) 22 | m = torch.zeros_like(theta) 23 | m.scatter_(1, label.view(-1, 1), self.m) 24 | logits = self.s * (pi - 2 * (theta + m)) / pi 25 | return logits 26 | 27 | class ArcMarginProduct(nn.Module): 28 | r"""Implement of large margin arc distance: : 29 | Args: 30 | in_features: size of each input sample 31 | out_features: size of each output sample 32 | s: norm of input feature 33 | m: margin 34 | cos(theta + m) 35 | """ 36 | def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False): 37 | super(ArcMarginProduct, self).__init__() 38 | self.in_features = in_features 39 | self.out_features = out_features 40 | self.s = s 41 | self.m = m 42 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)).cuda() 43 | nn.init.xavier_uniform_(self.weight) 44 | 45 | self.easy_margin = easy_margin 46 | self.cos_m = math.cos(m) 47 | self.sin_m = math.sin(m) 48 | self.th = math.cos(math.pi - m) 49 | self.mm = math.sin(math.pi - m) * m 50 | 51 | def forward(self, input, label): 52 | # --------------------------- cos(theta) & phi(theta) --------------------------- 53 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 54 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 55 | phi = cosine * self.cos_m - sine * self.sin_m 56 | if self.easy_margin: 57 | phi = torch.where(cosine > 0, phi, cosine) 58 | else: 59 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 60 | # --------------------------- convert label to one-hot --------------------------- 61 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 62 | one_hot = torch.zeros(cosine.size(), device='cuda') 63 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 64 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 65 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 66 | output *= self.s 67 | # print(output) 68 | 69 | return output 70 | 71 | 72 | class AddMarginProduct(nn.Module): 73 | r"""Implement of large margin cosine distance: : 74 | Args: 75 | in_features: size of each input sample 76 | out_features: size of each output sample 77 | s: norm of input feature 78 | m: margin 79 | cos(theta) - m 80 | """ 81 | 82 | def __init__(self, in_features, out_features, s=30.0, m=0.40): 83 | super(AddMarginProduct, self).__init__() 84 | self.in_features = in_features 85 | self.out_features = out_features 86 | self.s = s 87 | self.m = m 88 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)).cuda() 89 | nn.init.xavier_uniform_(self.weight) 90 | 91 | def forward(self, input, label): 92 | # --------------------------- cos(theta) & phi(theta) --------------------------- 93 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 94 | phi = cosine - self.m 95 | # --------------------------- convert label to one-hot --------------------------- 96 | one_hot = torch.zeros(cosine.size(), device='cuda') 97 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot 98 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 99 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 100 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 101 | output *= self.s 102 | # print(output) 103 | 104 | return output 105 | 106 | def __repr__(self): 107 | return self.__class__.__name__ + '(' \ 108 | + 'in_features=' + str(self.in_features) \ 109 | + ', out_features=' + str(self.out_features) \ 110 | + ', s=' + str(self.s) \ 111 | + ', m=' + str(self.m) + ')' 112 | 113 | 114 | class SphereProduct(nn.Module): 115 | r"""Implement of large margin cosine distance: : 116 | Args: 117 | in_features: size of each input sample 118 | out_features: size of each output sample 119 | m: margin 120 | cos(m*theta) 121 | """ 122 | def __init__(self, in_features, out_features, m=4): 123 | super(SphereProduct, self).__init__() 124 | self.in_features = in_features 125 | self.out_features = out_features 126 | self.m = m 127 | self.base = 1000.0 128 | self.gamma = 0.12 129 | self.power = 1 130 | self.LambdaMin = 5.0 131 | self.iter = 0 132 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 133 | nn.init.xavier_uniform(self.weight) 134 | 135 | # duplication formula 136 | self.mlambda = [ 137 | lambda x: x ** 0, 138 | lambda x: x ** 1, 139 | lambda x: 2 * x ** 2 - 1, 140 | lambda x: 4 * x ** 3 - 3 * x, 141 | lambda x: 8 * x ** 4 - 8 * x ** 2 + 1, 142 | lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x 143 | ] 144 | 145 | def forward(self, input, label): 146 | # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power)) 147 | self.iter += 1 148 | self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power)) 149 | 150 | # --------------------------- cos(theta) & phi(theta) --------------------------- 151 | cos_theta = F.linear(F.normalize(input), F.normalize(self.weight)) 152 | cos_theta = cos_theta.clamp(-1, 1) 153 | cos_m_theta = self.mlambda[self.m](cos_theta) 154 | theta = cos_theta.data.acos() 155 | k = (self.m * theta / 3.14159265).floor() 156 | phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k 157 | NormOfFeature = torch.norm(input, 2, 1) 158 | 159 | # --------------------------- convert label to one-hot --------------------------- 160 | one_hot = torch.zeros(cos_theta.size()) 161 | one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot 162 | one_hot.scatter_(1, label.view(-1, 1), 1) 163 | 164 | # --------------------------- Calculate output --------------------------- 165 | output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta 166 | output *= NormOfFeature.view(-1, 1) 167 | 168 | return output 169 | 170 | def __repr__(self): 171 | return self.__class__.__name__ + '(' \ 172 | + 'in_features=' + str(self.in_features) \ 173 | + ', out_features=' + str(self.out_features) \ 174 | + ', m=' + str(self.m) + ')' 175 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 18-5-21 下午5:26 4 | @author: ronghuaiyang 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | import torch.utils.model_zoo as model_zoo 10 | import torch.nn.utils.weight_norm as weight_norm 11 | import torch.nn.functional as F 12 | 13 | 14 | # __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 15 | # 'resnet152'] 16 | 17 | 18 | model_urls = { 19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class IRBlock(nn.Module): 66 | expansion = 1 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): 69 | super(IRBlock, self).__init__() 70 | self.bn0 = nn.BatchNorm2d(inplanes) 71 | self.conv1 = conv3x3(inplanes, inplanes) 72 | self.bn1 = nn.BatchNorm2d(inplanes) 73 | self.prelu = nn.PReLU() 74 | self.conv2 = conv3x3(inplanes, planes, stride) 75 | self.bn2 = nn.BatchNorm2d(planes) 76 | self.downsample = downsample 77 | self.stride = stride 78 | self.use_se = use_se 79 | if self.use_se: 80 | self.se = SEBlock(planes) 81 | 82 | def forward(self, x): 83 | residual = x 84 | out = self.bn0(x) 85 | out = self.conv1(out) 86 | out = self.bn1(out) 87 | out = self.prelu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | if self.use_se: 92 | out = self.se(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.prelu(out) 99 | 100 | return out 101 | 102 | 103 | class Bottleneck(nn.Module): 104 | expansion = 4 105 | 106 | def __init__(self, inplanes, planes, stride=1, downsample=None): 107 | super(Bottleneck, self).__init__() 108 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(planes) 110 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 111 | padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(planes) 113 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 114 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.downsample = downsample 117 | self.stride = stride 118 | 119 | def forward(self, x): 120 | residual = x 121 | 122 | out = self.conv1(x) 123 | out = self.bn1(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv2(out) 127 | out = self.bn2(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv3(out) 131 | out = self.bn3(out) 132 | 133 | if self.downsample is not None: 134 | residual = self.downsample(x) 135 | 136 | out += residual 137 | out = self.relu(out) 138 | 139 | return out 140 | 141 | 142 | class SEBlock(nn.Module): 143 | def __init__(self, channel, reduction=16): 144 | super(SEBlock, self).__init__() 145 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 146 | self.fc = nn.Sequential( 147 | nn.Linear(channel, channel // reduction), 148 | nn.PReLU(), 149 | nn.Linear(channel // reduction, channel), 150 | nn.Sigmoid() 151 | ) 152 | 153 | def forward(self, x): 154 | b, c, _, _ = x.size() 155 | y = self.avg_pool(x).view(b, c) 156 | y = self.fc(y).view(b, c, 1, 1) 157 | return x * y 158 | 159 | 160 | class ResNetFace(nn.Module): 161 | def __init__(self, block, layers, use_se=True): 162 | self.inplanes = 64 163 | self.use_se = use_se 164 | super(ResNetFace, self).__init__() 165 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False) 166 | self.bn1 = nn.BatchNorm2d(64) 167 | self.prelu = nn.PReLU() 168 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 169 | self.layer1 = self._make_layer(block, 64, layers[0]) 170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 171 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 172 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 173 | self.bn4 = nn.BatchNorm2d(512) 174 | self.dropout = nn.Dropout() 175 | self.fc5 = nn.Linear(112*112*2, 4096) 176 | self.bn5 = nn.BatchNorm1d(4096) 177 | 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | nn.init.xavier_normal_(m.weight) 181 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 182 | nn.init.constant_(m.weight, 1) 183 | nn.init.constant_(m.bias, 0) 184 | elif isinstance(m, nn.Linear): 185 | nn.init.xavier_normal_(m.weight) 186 | nn.init.constant_(m.bias, 0) 187 | 188 | def _make_layer(self, block, planes, blocks, stride=1): 189 | downsample = None 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | downsample = nn.Sequential( 192 | nn.Conv2d(self.inplanes, planes * block.expansion, 193 | kernel_size=1, stride=stride, bias=False), 194 | nn.BatchNorm2d(planes * block.expansion), 195 | ) 196 | layers = [] 197 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) 198 | self.inplanes = planes 199 | for i in range(1, blocks): 200 | layers.append(block(self.inplanes, planes, use_se=self.use_se)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def forward(self, x): 205 | x = self.conv1(x) 206 | x = self.bn1(x) 207 | x = self.prelu(x) 208 | x = self.maxpool(x) 209 | 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | x = self.layer4(x) 214 | x = self.bn4(x) 215 | x = self.dropout(x) 216 | x = x.view(x.size(0), -1) 217 | #print(x.shape) 218 | x = self.fc5(x) 219 | x = self.bn5(x) 220 | 221 | return x 222 | 223 | 224 | class ResNet(nn.Module): 225 | 226 | def __init__(self, block, layers): 227 | self.inplanes = 64 228 | super(ResNet, self).__init__() 229 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 230 | # bias=False) 231 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, 232 | bias=False) 233 | self.bn1 = nn.BatchNorm2d(64) 234 | self.relu = nn.ReLU(inplace=True) 235 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 236 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 237 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 238 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 239 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 240 | # self.avgpool = nn.AvgPool2d(8, stride=1) 241 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 242 | self.fc5 = nn.Linear(512 * 8 * 8, 512) 243 | 244 | for m in self.modules(): 245 | if isinstance(m, nn.Conv2d): 246 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 247 | elif isinstance(m, nn.BatchNorm2d): 248 | nn.init.constant_(m.weight, 1) 249 | nn.init.constant_(m.bias, 0) 250 | 251 | def _make_layer(self, block, planes, blocks, stride=1): 252 | downsample = None 253 | if stride != 1 or self.inplanes != planes * block.expansion: 254 | downsample = nn.Sequential( 255 | nn.Conv2d(self.inplanes, planes * block.expansion, 256 | kernel_size=1, stride=stride, bias=False), 257 | nn.BatchNorm2d(planes * block.expansion), 258 | ) 259 | 260 | layers = [] 261 | layers.append(block(self.inplanes, planes, stride, downsample)) 262 | self.inplanes = planes * block.expansion 263 | for i in range(1, blocks): 264 | layers.append(block(self.inplanes, planes)) 265 | 266 | return nn.Sequential(*layers) 267 | 268 | def forward(self, x): 269 | x = self.conv1(x) 270 | x = self.bn1(x) 271 | x = self.relu(x) 272 | # x = self.maxpool(x) 273 | 274 | x = self.layer1(x) 275 | x = self.layer2(x) 276 | x = self.layer3(x) 277 | x = self.layer4(x) 278 | # x = nn.AvgPool2d(kernel_size=x.size()[2:])(x) 279 | # x = self.avgpool(x) 280 | x = x.view(x.size(0), -1) 281 | x = self.fc5(x) 282 | 283 | return x 284 | 285 | 286 | def resnet18(pretrained=False, **kwargs): 287 | """Constructs a ResNet-18 model. 288 | Args: 289 | pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | """ 291 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 292 | if pretrained: 293 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 294 | return model 295 | 296 | 297 | def resnet34(pretrained=False, **kwargs): 298 | """Constructs a ResNet-34 model. 299 | Args: 300 | pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | """ 302 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 303 | if pretrained: 304 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 305 | return model 306 | 307 | 308 | def resnet50(pretrained=False, **kwargs): 309 | """Constructs a ResNet-50 model. 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | """ 313 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 314 | if pretrained: 315 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 316 | return model 317 | 318 | 319 | def resnet101(pretrained=False, **kwargs): 320 | """Constructs a ResNet-101 model. 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | """ 324 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 325 | if pretrained: 326 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 327 | return model 328 | 329 | 330 | def resnet152(pretrained=False, **kwargs): 331 | """Constructs a ResNet-152 model. 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | """ 335 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 336 | if pretrained: 337 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 338 | return model 339 | 340 | 341 | def resnet_face18(use_se=True, **kwargs): 342 | model = ResNetFace(IRBlock, [2, 2, 2, 2], use_se=use_se, **kwargs) 343 | return model 344 | 345 | def resnet_face50(use_se=True, **kwargs): 346 | model = ResNetFace(IRBlock, [3, 4, 6, 3], use_se=use_se, **kwargs) 347 | return model 348 | 349 | def resnet_face101(use_se=True, **kwargs): 350 | model = ResNetFace(IRBlock, [3, 4, 23, 3], use_se=use_se, **kwargs) 351 | return model 352 | 353 | if __name__=="__main__": 354 | x = torch.randn(4,3,112,112) 355 | model = resnet_face18() 356 | out = model(x) 357 | 358 | print(out.shape) 359 | --------------------------------------------------------------------------------