├── .gitignore ├── README.md ├── convert.py ├── demo.py ├── dex ├── .gitattributes ├── __init__.py ├── api.py ├── models.py └── pth │ ├── age_sd.pth │ └── gender_sd.pth ├── imgs ├── 1.png ├── 10.jpeg ├── 11.jpeg ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png └── 9.jpeg ├── network ├── age.png └── gender.png └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | *egg-info -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DEX: Deep EXpectation of apparent age from a single image 2 | 3 | This is a PyTorch version of DEX. Refer to its [Home Page](https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/) for more details. 4 | 5 | You can refer to [insight](https://github.com/siriusdemon/hackaway/tree/master/projects/insight) if you want a much smaller model, but it uses `mxnet` instead of `pytorch`. I haven't converted it to `pytorch` yet. 6 | 7 | ## Getting Started 8 | 9 | A separate Python environment is recommended.
10 | + Python3.5+ (Python3.5, Python3.6 are tested) 11 | + PyTorch == 1.0 12 | + opencv4 (opencv3.4.5 is also tested) 13 | + numpy 14 | 15 | Install the dependencies using `pip`: 16 | ```bash 17 | pip3 install numpy opencv-python 18 | pip3 install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl 19 | pip3 install torchvision  # optional 20 | ``` 21 | or install them using `conda`: 22 | ```bash 23 | conda install opencv numpy 24 | conda install pytorch-cpu torchvision-cpu -c pytorch 25 | ``` 26 | 27 | ## Usage 28 | ```bash 29 | git clone https://github.com/siriusdemon/pytorch-DEX.git 30 | cd pytorch-DEX 31 | python demo.py path/to/image 32 | ``` 33 | The pretrained weights (`dex/pth/*.pth`) are stored with Git LFS, so install [git-lfs](https://git-lfs.github.com/) before cloning. 34 | ## Results 35 | 36 | 37 | ``` 38 | predict image: imgs/2.png 39 | woman: 0.994, man: 0.006 40 | age: 21.433 41 | ``` 42 | 43 | 44 | ``` 45 | predict image: imgs/5.png 46 | woman: 0.010, man: 0.990 47 | age: 42.896 48 | ``` 49 | 50 | ## Installation 51 | You can use dex as a separate Python package right now! 52 | ``` 53 | cd pytorch-DEX 54 | pip install . 55 | ``` 56 | See [demo.py](demo.py) for an example.
57 | 58 | ## Citation 59 | @InProceedings{Rothe-ICCVW-2015, 60 | author = {Rasmus Rothe and Radu Timofte and Luc Van Gool}, 61 | title = {DEX: Deep EXpectation of apparent age from a single image}, 62 | booktitle = {IEEE International Conference on Computer Vision Workshops (ICCVW)}, 63 | year = {2015}, 64 | month = {December}, 65 | } 66 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # quick and dirty convertion script 2 | import caffe 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from models import Age, Gender 8 | 9 | 10 | 11 | def convert_age(): 12 | torch_net = Age() 13 | 14 | caffe_net = caffe.Net('age.prototxt', "dex_chalearn_iccv2015.caffemodel", caffe.TEST) 15 | caffe_params = caffe_net.params 16 | 17 | mappings = { 18 | 'conv1_1': torch_net.conv[0].conv1, 19 | 'conv1_2': torch_net.conv[0].conv2, 20 | 'conv2_1': torch_net.conv[1].conv1, 21 | 'conv2_2': torch_net.conv[1].conv2, 22 | 'conv3_1': torch_net.conv[2].conv1, 23 | 'conv3_2': torch_net.conv[2].conv2, 24 | 'conv3_3': torch_net.conv[2].conv3, 25 | 'conv4_1': torch_net.conv[3].conv1, 26 | 'conv4_2': torch_net.conv[3].conv2, 27 | 'conv4_3': torch_net.conv[3].conv3, 28 | 'conv5_1': torch_net.conv[4].conv1, 29 | 'conv5_2': torch_net.conv[4].conv2, 30 | 'conv5_3': torch_net.conv[4].conv3, 31 | 'fc6': torch_net.fc1[0], 32 | 'fc7': torch_net.fc2[0], 33 | 'fc8-101': torch_net.cls, 34 | } 35 | 36 | for k, layer in mappings.items(): 37 | layer.weight.data.copy_(torch.from_numpy(caffe_params[k][0].data)) 38 | layer.bias.data.copy_(torch.from_numpy(caffe_params[k][1].data)) 39 | torch.save(torch_net, 'pth/age.pth') 40 | 41 | def convert_gender(): 42 | torch_net = Gender() 43 | 44 | caffe_net = caffe.Net('gender.prototxt', "gender.caffemodel", caffe.TEST) 45 | caffe_params = caffe_net.params 46 | 47 | mappings = { 48 | 'conv1_1': 
torch_net.conv[0].conv1, 49 | 'conv1_2': torch_net.conv[0].conv2, 50 | 'conv2_1': torch_net.conv[1].conv1, 51 | 'conv2_2': torch_net.conv[1].conv2, 52 | 'conv3_1': torch_net.conv[2].conv1, 53 | 'conv3_2': torch_net.conv[2].conv2, 54 | 'conv3_3': torch_net.conv[2].conv3, 55 | 'conv4_1': torch_net.conv[3].conv1, 56 | 'conv4_2': torch_net.conv[3].conv2, 57 | 'conv4_3': torch_net.conv[3].conv3, 58 | 'conv5_1': torch_net.conv[4].conv1, 59 | 'conv5_2': torch_net.conv[4].conv2, 60 | 'conv5_3': torch_net.conv[4].conv3, 61 | 'fc6': torch_net.fc1[0], 62 | 'fc7': torch_net.fc2[0], 63 | 'fc8-2': torch_net.cls, 64 | } 65 | 66 | for k, layer in mappings.items(): 67 | layer.weight.data.copy_(torch.from_numpy(caffe_params[k][0].data)) 68 | layer.bias.data.copy_(torch.from_numpy(caffe_params[k][1].data)) 69 | torch.save(torch_net, 'pth/gender.pth') 70 | 71 | 72 | if __name__ == '__main__': 73 | convert_age() 74 | convert_gender() -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import dex 3 | 4 | # setup model 5 | dex.eval() 6 | 7 | 8 | 9 | if __name__ == '__main__': 10 | if len(sys.argv) < 2: 11 | print("Usage: python demo.py path/to/img") 12 | sys.exit() 13 | 14 | path = sys.argv[1] 15 | age, female, male = dex.estimate(path) 16 | print("predict image: {}".format(path)) 17 | print("woman: {:.3f}, man: {:.3f}".format(female, male)) 18 | print("age: {:.3f}".format(age)) 19 | -------------------------------------------------------------------------------- /dex/.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /dex/__init__.py: -------------------------------------------------------------------------------- 1 | from .api import estimate 2 | from .api import estimate_age 
3 | from .api import estimate_gender 4 | from .api import _eval as eval 5 | 6 | eval() 7 | -------------------------------------------------------------------------------- /dex/api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | 6 | from .models import Age, Gender 7 | 8 | age_model = Age() 9 | gender_model = Gender() 10 | 11 | cwd = os.path.dirname(__file__) 12 | age_model_path = os.path.join(cwd, 'pth/age_sd.pth') 13 | gender_model_path = os.path.join(cwd, 'pth/gender_sd.pth') 14 | 15 | def _eval(): 16 | global age_model 17 | global gender_model 18 | age_model.load_state_dict(torch.load(age_model_path)) 19 | age_model.eval() 20 | gender_model.load_state_dict(torch.load(gender_model_path)) 21 | gender_model.eval() 22 | 23 | 24 | def preprocess(img): 25 | img = cv2.resize(img, (224, 224)) 26 | img = np.transpose(img, (2, 0, 1)) 27 | img = img[None, :, :, :] 28 | tensor = torch.from_numpy(img) 29 | tensor = tensor.type('torch.FloatTensor') 30 | return tensor 31 | 32 | 33 | def expected_age(vector): 34 | res = [(i+1)*v for i, v in enumerate(vector)] 35 | return sum(res) 36 | 37 | 38 | def estimate_age(img): 39 | if type(img) == str: 40 | img = cv2.imread(img) 41 | tensor = preprocess(img) 42 | with torch.no_grad(): 43 | output = age_model(tensor) 44 | output = output.numpy().squeeze() 45 | age = expected_age(output) 46 | return age 47 | 48 | def estimate_gender(img): 49 | if type(img) == str: 50 | img = cv2.imread(img) 51 | tensor = preprocess(img) 52 | with torch.no_grad(): 53 | output = gender_model(tensor) 54 | output = output.numpy().squeeze() 55 | return output[0], output[1] 56 | 57 | def estimate(img): 58 | """return values as (age, female, male)""" 59 | img = cv2.imread(img) 60 | result = [estimate_age(img)] 61 | result.extend(estimate_gender(img)) 62 | return result 63 | -------------------------------------------------------------------------------- 
/dex/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | 7 | def vgg_block(in_channels, out_channels, more=False): 8 | blocklist = [ 9 | ('conv1', nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)), 10 | ('relu1', nn.ReLU(inplace=True)), 11 | ('conv2', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)), 12 | ('relu2', nn.ReLU(inplace=True)), 13 | ] 14 | if more: 15 | blocklist.extend([ 16 | ('conv3', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)), 17 | ('relu3', nn.ReLU(inplace=True)), 18 | ]) 19 | blocklist.append(('maxpool', nn.MaxPool2d(kernel_size=2, stride=2))) 20 | block = nn.Sequential(OrderedDict(blocklist)) 21 | return block 22 | 23 | # VGG16 architecture 24 | class VGG(nn.Module): 25 | def __init__(self, classes=1000, channels=3): 26 | super().__init__() 27 | self.conv = nn.Sequential( 28 | vgg_block(channels, 64), 29 | vgg_block(64, 128), 30 | vgg_block(128, 256, True), 31 | vgg_block(256, 512, True), 32 | vgg_block(512, 512, True), 33 | ) 34 | self.fc1 = nn.Sequential( 35 | nn.Linear(512*7*7, 4096), 36 | nn.ReLU(inplace=True), 37 | nn.Dropout(0.5, inplace=True), 38 | ) 39 | self.fc2 = nn.Sequential( 40 | nn.Linear(4096, 4096), 41 | nn.ReLU(inplace=True), 42 | nn.Dropout(0.5, inplace=True), 43 | ) 44 | self.cls = nn.Linear(4096, classes) 45 | 46 | def forward(self, x): 47 | in_size = x.shape[0] 48 | x = self.conv(x) 49 | x = x.view(in_size, -1) 50 | x = self.fc1(x) 51 | x = self.fc2(x) 52 | x = self.cls(x) 53 | x = F.softmax(x, dim=1) 54 | return x 55 | 56 | class Gender(VGG): 57 | def __init__(self, classes=2, channels=3): 58 | super().__init__() 59 | self.cls = nn.Linear(4096, classes) 60 | 61 | class Age(VGG): 62 | def __init__(self, classes=101, channels=3): 63 | super().__init__() 64 | self.cls = nn.Linear(4096, 
classes) 65 | 66 | 67 | 68 | if __name__ == '__main__': 69 | net = Gender() 70 | print(net) 71 | 72 | -------------------------------------------------------------------------------- /dex/pth/age_sd.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d754f06341792aecf1c5893269ae849892072532f161b63cd595d713897f5302 3 | size 538703264 4 | -------------------------------------------------------------------------------- /dex/pth/gender_sd.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4e2bae5c6959145725c0152c6e5f0e65f0f2d176e3c2f7ff32b87813102767d1 3 | size 537080850 4 | -------------------------------------------------------------------------------- /imgs/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/1.png -------------------------------------------------------------------------------- /imgs/10.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/10.jpeg -------------------------------------------------------------------------------- /imgs/11.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/11.jpeg -------------------------------------------------------------------------------- /imgs/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/2.png -------------------------------------------------------------------------------- 
/imgs/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/3.png -------------------------------------------------------------------------------- /imgs/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/4.png -------------------------------------------------------------------------------- /imgs/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/5.png -------------------------------------------------------------------------------- /imgs/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/6.png -------------------------------------------------------------------------------- /imgs/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/7.png -------------------------------------------------------------------------------- /imgs/9.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/9.jpeg -------------------------------------------------------------------------------- /network/age.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/network/age.png 
-------------------------------------------------------------------------------- /network/gender.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/network/gender.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", 'r') as f: 4 | long_description = f.read() 5 | 6 | 7 | setuptools.setup( 8 | name = "dex", 9 | version = "0.1.0", 10 | author = "sirius demon", 11 | author_email = "mory2016@126.com", 12 | description="Deep EXpectation model in Pytorch", 13 | long_description=long_description, 14 | long_description_content_type='text/markdown', 15 | url = "https://github.com/siriusdemon/pytorch-dex", 16 | packages=setuptools.find_packages(), 17 | package_data = { 18 | 'dex': ['pth/*.pth'], 19 | }, 20 | classifiers = [ 21 | "Programming Language :: Python :: 3.5", 22 | "License :: OSI Approved :: BSD License", 23 | "Operating System :: OS Independent", 24 | ], 25 | 26 | ) --------------------------------------------------------------------------------