├── LICENSE ├── README.md └── ssrnet.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hans Hu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
#-*- coding:utf-8 -*-
# '''
# @Author: Hans Hu
# '''

import torch
import torch.nn as nn
import torch.nn.init as init
import math


class ssrnet(nn.Module):
    """PyTorch implementation of SSR-Net for age and gender estimation.

    Paper: https://github.com/shamangary/SSR-Net/blob/master/ijcai18_ssrnet_pdfa_2b.pdf
    Reference implementation: https://github.com/shamangary/SSR-Net

    Two convolutional streams process the same image: the ``x`` stream
    (32 channels, ReLU, average pooling) and the ``s`` stream (16 channels,
    max pooling).  Feature maps taken at three depths are projected to 10
    channels, flattened and fused by element-wise multiplication.  Every
    stage produces three quantities: non-negative bin scores (``pred_a_s*``),
    per-bin offsets in ``[-1, 1]`` (``local_s*``) and one width adjustment in
    ``[-1, 1]`` (``delta_s*``).  ``forward`` combines them into one scalar per
    sample.

    Args:
        stage_num: Indexable with three integers; ``stage_num[k]`` is the
            number of bins summed at stage ``k`` and also scales that stage's
            denominator.  Each head emits exactly 3 values, so an entry
            larger than 3 indexes out of range in ``forward``.
        lambda_local: Multiplier applied to the per-bin offsets.
        lambda_d: Multiplier applied to the per-stage width adjustment.
        age: If truthy the result is multiplied by 101, otherwise by 1.

    The ``Linear(640, ...)`` / ``Linear(160, ...)`` layers fix the flattened
    feature sizes; the original shape comments assume 3x64x64 inputs.
    """

    def __init__(self, stage_num, lambda_local, lambda_d, age):
        super().__init__()

        self.stage_num = stage_num
        self.lambda_local = lambda_local
        self.lambda_d = lambda_d
        self.age = age

        # NOTE: attribute names and construction order are kept exactly as in
        # the original so state_dict keys and seeded initialisation match.

        # ---- x stream: conv/BN/ReLU blocks separated by 2x2 average pooling
        self.x1 = self._conv_bn_act(3, 32, nn.ReLU(inplace=True))
        self.x_layer1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.x2 = self._conv_bn_act(32, 32, nn.ReLU(inplace=True))
        self.x_layer2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.x3 = self._conv_bn_act(32, 32, nn.ReLU(inplace=True))
        self.x_layer3 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.x4 = self._conv_bn_act(32, 32, nn.ReLU(inplace=True))

        # ---- s stream: conv/BN/activation blocks separated by 2x2 max pooling
        # NOTE(review): s1 uses ReLU while s2-s4 use Tanh -- confirm this is
        # intended; it is left untouched to keep existing behaviour.
        self.s1 = self._conv_bn_act(3, 16, nn.ReLU(inplace=True))
        self.s_layer1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.s2 = self._conv_bn_act(16, 16, nn.Tanh())
        self.s_layer2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.s3 = self._conv_bn_act(16, 16, nn.Tanh())
        self.s_layer3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.s4 = self._conv_bn_act(16, 16, nn.Tanh())

        # ---- stage 1 heads: fed by the deepest maps (s4 / x4), 640 features
        self.s_layer4 = self._pooled_projection(16)
        self.s_layer4_mix = self._bin_mixer(640)
        self.x_layer4 = self._pooled_projection(32)
        self.x_layer4_mix = self._bin_mixer(640)
        self.delta_s1 = self._linear_act(640, 1, nn.Tanh())
        self.feat_a_s1 = self._linear_act(3, 6, nn.ReLU(inplace=True))
        self.pred_a_s1 = self._linear_act(6, 3, nn.ReLU(inplace=True))
        self.local_s1 = self._linear_act(6, 3, nn.Tanh())

        # ---- stage 2 heads: fed by s_layer2 / x_layer2, 160 features
        self.s_layer2_2 = self._pooled_projection(
            16, padding=1, pool=nn.MaxPool2d(kernel_size=4, stride=4))
        self.s_layer2_mix = self._bin_mixer(160)
        self.x_layer2_2 = self._pooled_projection(
            32, padding=1, pool=nn.AvgPool2d(kernel_size=4, stride=4))
        self.x_layer2_mix = self._bin_mixer(160)
        self.delta_s2 = self._linear_act(160, 1, nn.Tanh())
        self.feat_a_s2 = self._linear_act(3, 6, nn.ReLU(inplace=True))
        self.pred_a_s2 = self._linear_act(6, 3, nn.ReLU(inplace=True))
        self.local_s2 = self._linear_act(6, 3, nn.Tanh())

        # ---- stage 3 heads: fed by s_layer1 / x_layer1, 160 features
        self.s_layer1_2 = self._pooled_projection(
            16, padding=1, pool=nn.MaxPool2d(kernel_size=8, stride=8))
        self.s_layer1_mix = self._bin_mixer(160)
        self.x_layer1_2 = self._pooled_projection(
            32, padding=1, pool=nn.AvgPool2d(kernel_size=8, stride=8))
        self.x_layer1_mix = self._bin_mixer(160)
        self.delta_s3 = self._linear_act(160, 1, nn.Tanh())
        self.feat_a_s3 = self._linear_act(3, 6, nn.ReLU(inplace=True))
        self.pred_a_s3 = self._linear_act(6, 3, nn.ReLU(inplace=True))
        self.local_s3 = self._linear_act(6, 3, nn.Tanh())

        self.init_params()

    @staticmethod
    def _conv_bn_act(in_channels, out_channels, activation):
        """Build a 3x3 same-padding conv -> BatchNorm -> ``activation`` block."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            activation,
        )

    @staticmethod
    def _pooled_projection(in_channels, padding=0, pool=None):
        """Build a 1x1 conv to 10 channels -> ReLU, optionally followed by ``pool``."""
        layers = [
            nn.Conv2d(in_channels, 10, kernel_size=1, stride=1, padding=padding),
            nn.ReLU(inplace=True),
        ]
        if pool is not None:
            layers.append(pool)
        return nn.Sequential(*layers)

    @staticmethod
    def _bin_mixer(in_features):
        """Build Dropout(0.2) -> Linear(in_features, 3) -> ReLU."""
        return nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features, 3),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _linear_act(in_features, out_features, activation):
        """Build Linear(in_features, out_features) -> ``activation``."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            activation,
        )

    def init_params(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear submodule.

        Conv2d: Kaiming-normal weights (``fan_out``), bias 0.5.
        BatchNorm2d: weight 0.5, bias 0.6.
        Linear: normal weights with std 0.1, bias 0.3.
        """
        # The in-place ``*_`` initialisers replace the deprecated
        # ``kaiming_normal`` / ``constant`` / ``normal`` aliases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0.5)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 0.5)
                init.constant_(m.bias, 0.6)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.1)
                if m.bias is not None:
                    init.constant_(m.bias, 0.3)

    def _stage_outputs(self, s_map, x_map, s_proj, s_mix, x_proj, x_mix,
                       delta_head, feat_head, pred_head, local_head):
        """Run one stage's heads and return ``(pred, local, delta)``.

        The call order (s branch, then x branch, then delta/feat/pred/local)
        mirrors the original code so dropout draws happen in the same order.
        """
        # torch.flatten returns a view for contiguous inputs (same as the
        # original ``.view``) but also copes with non-contiguous activations,
        # e.g. those produced from channels-last / NHWC-permuted inputs.
        s_flat = torch.flatten(s_proj(s_map), 1)
        s_mixed = s_mix(s_flat)
        x_flat = torch.flatten(x_proj(x_map), 1)
        x_mixed = x_mix(x_flat)

        delta = delta_head(s_flat.mul(x_flat))    # (N, 1), tanh output
        feat = feat_head(s_mixed.mul(x_mixed))    # (N, 3) -> (N, 6)
        return pred_head(feat), local_head(feat), delta

    def _soft_bin_sum(self, pred, local, bins):
        """Return sum_i (i + lambda_local * local[:, i]) * pred[:, i] as (N, 1)."""
        total = pred[:, 0] * 0
        for idx in range(bins):
            total = total + (idx + self.lambda_local * local[:, idx]) * pred[:, idx]
        return torch.unsqueeze(total, 1)

    def forward(self, x):
        """Compute one regression value per sample.

        Args:
            x: Image batch; shape comments below assume (N, 3, 64, 64).

        Returns:
            Tensor of shape (N,): the sum of the three stage contributions,
            multiplied by 101 when ``self.age`` is truthy and by 1 otherwise.
        """
        # x stream: 3*64*64 -> 32*32*32 -> 32*16*16 -> 32*8*8
        x_layer1 = self.x_layer1(self.x1(x))
        x_layer2 = self.x_layer2(self.x2(x_layer1))
        x4 = self.x4(self.x_layer3(self.x3(x_layer2)))

        # s stream: 3*64*64 -> 16*32*32 -> 16*16*16 -> 16*8*8
        s_layer1 = self.s_layer1(self.s1(x))
        s_layer2 = self.s_layer2(self.s2(s_layer1))
        s4 = self.s4(self.s_layer3(self.s3(s_layer2)))

        # stage 1: 8*8 maps -> 10*8*8 -> 640 -> 3
        pred_a_s1, local_s1, delta_s1 = self._stage_outputs(
            s4, x4,
            self.s_layer4, self.s_layer4_mix,
            self.x_layer4, self.x_layer4_mix,
            self.delta_s1, self.feat_a_s1, self.pred_a_s1, self.local_s1)

        # stage 2: 16*16 maps -> 10*18*18 (1x1 conv with padding=1)
        #          -> 10*4*4 (pool 4, floor) -> 160 -> 3
        pred_a_s2, local_s2, delta_s2 = self._stage_outputs(
            s_layer2, x_layer2,
            self.s_layer2_2, self.s_layer2_mix,
            self.x_layer2_2, self.x_layer2_mix,
            self.delta_s2, self.feat_a_s2, self.pred_a_s2, self.local_s2)

        # stage 3: 32*32 maps -> 10*34*34 (1x1 conv with padding=1)
        #          -> 10*4*4 (pool 8, floor) -> 160 -> 3
        pred_a_s3, local_s3, delta_s3 = self._stage_outputs(
            s_layer1, x_layer1,
            self.s_layer1_2, self.s_layer1_mix,
            self.x_layer1_2, self.x_layer1_mix,
            self.delta_s3, self.feat_a_s3, self.pred_a_s3, self.local_s3)

        # Per-stage denominators: stage_num[k] * (1 + lambda_d * delta_k).
        # A delta near -1/lambda_d drives the denominator towards zero.
        width1 = self.stage_num[0] * (1 + self.lambda_d * delta_s1)
        width2 = self.stage_num[1] * (1 + self.lambda_d * delta_s2)
        width3 = self.stage_num[2] * (1 + self.lambda_d * delta_s3)

        # Each later stage is divided by all earlier widths as well.
        a = self._soft_bin_sum(pred_a_s1, local_s1, self.stage_num[0]) / width1
        b = self._soft_bin_sum(pred_a_s2, local_s2, self.stage_num[1]) / width1 / width2
        c = (self._soft_bin_sum(pred_a_s3, local_s3, self.stage_num[2])
             / width1 / width2 / width3)

        V = 101 if self.age else 1
        age = (a + b + c) * V
        return torch.squeeze(age, 1)