├── .gitignore
├── README.md
├── convert.py
├── demo.py
├── dex
│   ├── .gitattributes
│   ├── __init__.py
│   ├── api.py
│   ├── models.py
│   └── pth
│       ├── age_sd.pth
│       └── gender_sd.pth
├── imgs
│   ├── 1.png
│   ├── 10.jpeg
│   ├── 11.jpeg
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   ├── 6.png
│   ├── 7.png
│   └── 9.jpeg
├── network
│   ├── age.png
│   └── gender.png
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .vscode
3 | *egg-info
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DEX: Deep EXpectation of apparent age from a single image
2 |
3 | This is a PyTorch version of DEX. Refer to its [Home Page](https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/) for more details.
4 |
5 | If you want a much smaller model, have a look at [insight](https://github.com/siriusdemon/hackaway/tree/master/projects/insight); note that it uses `mxnet` instead of `pytorch`, and I haven't converted it to `pytorch` yet.
6 |
7 | ## Getting Started
8 |
9 | A separate Python environment is recommended.
10 | + Python 3.5+ (3.5 and 3.6 are tested)
11 | + PyTorch == 1.0
12 | + OpenCV 4 (3.4.5 is also tested)
13 | + NumPy
14 |
15 | Install the dependencies using `pip`:
16 | ```bash
17 | pip3 install numpy opencv-python
18 | pip3 install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
19 | pip3 install torchvision  # optional
20 | ```
21 | or install them using `conda`:
22 | ```bash
23 | conda install opencv numpy
24 | conda install pytorch-cpu torchvision-cpu -c pytorch
25 | ```
26 |
27 | ## Usage
28 | ```bash
29 | git clone https://github.com/siriusdemon/pytorch-DEX.git
30 | cd pytorch-DEX
31 | python demo.py path/to/image
32 | ```
33 |
34 | ## Results
35 |
36 |
37 | ```
38 | predict image: imgs/2.png
39 | woman: 0.994, man: 0.006
40 | age: 21.433
41 | ```
42 |
43 |
44 | ```
45 | predict image: imgs/5.png
46 | woman: 0.010, man: 0.990
47 | age: 42.896
48 | ```
49 |
50 | ## Installation
51 | You can also install `dex` as a standalone Python package:
52 | ```bash
53 | cd pytorch-DEX
54 | pip install .
55 | ```
56 | See [demo.py](demo.py) for an example.
57 |
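58 | You can also call the library directly instead of running `demo.py`. Below is a minimal sketch using the functions exported from `dex` (`estimate`, `estimate_age`, `estimate_gender`); the sample path points at one of the images shipped in `imgs/`:
59 | ```python
60 | import cv2
61 | import dex
62 |
63 | # The pretrained weights are loaded when the package is imported
64 | # (dex/__init__.py calls eval()), so no extra setup is needed here.
65 | age, female, male = dex.estimate("imgs/2.png")  # estimate() expects an image path
66 | print("age: {:.1f}, woman: {:.3f}, man: {:.3f}".format(age, female, male))
67 |
68 | # estimate_age / estimate_gender also accept an image that has
69 | # already been loaded with cv2.imread.
70 | img = cv2.imread("imgs/2.png")
71 | age = dex.estimate_age(img)
72 | female, male = dex.estimate_gender(img)
73 | ```
74 |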
75 | ## Citation
76 | ```
77 | @InProceedings{Rothe-ICCVW-2015,
78 |   author = {Rasmus Rothe and Radu Timofte and Luc Van Gool},
79 |   title = {DEX: Deep EXpectation of apparent age from a single image},
80 |   booktitle = {IEEE International Conference on Computer Vision Workshops (ICCVW)},
81 |   year = {2015},
82 |   month = {December},
83 | }
84 | ```
85 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | # Quick and dirty conversion script: copies weights from the original Caffe DEX models into the PyTorch models defined in models.py
2 | import caffe
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from models import Age, Gender
8 |
9 |
10 |
11 | def convert_age():
12 | torch_net = Age()
13 |
14 | caffe_net = caffe.Net('age.prototxt', "dex_chalearn_iccv2015.caffemodel", caffe.TEST)
15 | caffe_params = caffe_net.params
16 |
17 | mappings = {
18 | 'conv1_1': torch_net.conv[0].conv1,
19 | 'conv1_2': torch_net.conv[0].conv2,
20 | 'conv2_1': torch_net.conv[1].conv1,
21 | 'conv2_2': torch_net.conv[1].conv2,
22 | 'conv3_1': torch_net.conv[2].conv1,
23 | 'conv3_2': torch_net.conv[2].conv2,
24 | 'conv3_3': torch_net.conv[2].conv3,
25 | 'conv4_1': torch_net.conv[3].conv1,
26 | 'conv4_2': torch_net.conv[3].conv2,
27 | 'conv4_3': torch_net.conv[3].conv3,
28 | 'conv5_1': torch_net.conv[4].conv1,
29 | 'conv5_2': torch_net.conv[4].conv2,
30 | 'conv5_3': torch_net.conv[4].conv3,
31 | 'fc6': torch_net.fc1[0],
32 | 'fc7': torch_net.fc2[0],
33 | 'fc8-101': torch_net.cls,
34 | }
35 |
36 | for k, layer in mappings.items():
37 | layer.weight.data.copy_(torch.from_numpy(caffe_params[k][0].data))
38 | layer.bias.data.copy_(torch.from_numpy(caffe_params[k][1].data))
39 |     torch.save(torch_net.state_dict(), 'pth/age_sd.pth')  # save the state dict that dex/api.py loads
40 |
41 | def convert_gender():
42 | torch_net = Gender()
43 |
44 | caffe_net = caffe.Net('gender.prototxt', "gender.caffemodel", caffe.TEST)
45 | caffe_params = caffe_net.params
46 |
47 | mappings = {
48 | 'conv1_1': torch_net.conv[0].conv1,
49 | 'conv1_2': torch_net.conv[0].conv2,
50 | 'conv2_1': torch_net.conv[1].conv1,
51 | 'conv2_2': torch_net.conv[1].conv2,
52 | 'conv3_1': torch_net.conv[2].conv1,
53 | 'conv3_2': torch_net.conv[2].conv2,
54 | 'conv3_3': torch_net.conv[2].conv3,
55 | 'conv4_1': torch_net.conv[3].conv1,
56 | 'conv4_2': torch_net.conv[3].conv2,
57 | 'conv4_3': torch_net.conv[3].conv3,
58 | 'conv5_1': torch_net.conv[4].conv1,
59 | 'conv5_2': torch_net.conv[4].conv2,
60 | 'conv5_3': torch_net.conv[4].conv3,
61 | 'fc6': torch_net.fc1[0],
62 | 'fc7': torch_net.fc2[0],
63 | 'fc8-2': torch_net.cls,
64 | }
65 |
66 | for k, layer in mappings.items():
67 | layer.weight.data.copy_(torch.from_numpy(caffe_params[k][0].data))
68 | layer.bias.data.copy_(torch.from_numpy(caffe_params[k][1].data))
69 |     torch.save(torch_net.state_dict(), 'pth/gender_sd.pth')  # save the state dict that dex/api.py loads
70 |
71 |
72 | if __name__ == '__main__':
73 | convert_age()
74 | convert_gender()
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import dex
3 |
4 | # setup model: load the pretrained weights and switch both models to eval mode
5 | dex.eval()
6 |
7 |
8 |
9 | if __name__ == '__main__':
10 | if len(sys.argv) < 2:
11 | print("Usage: python demo.py path/to/img")
12 | sys.exit()
13 |
14 | path = sys.argv[1]
15 | age, female, male = dex.estimate(path)
16 | print("predict image: {}".format(path))
17 | print("woman: {:.3f}, man: {:.3f}".format(female, male))
18 | print("age: {:.3f}".format(age))
19 |
--------------------------------------------------------------------------------
/dex/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pth filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/dex/__init__.py:
--------------------------------------------------------------------------------
1 | from .api import estimate
2 | from .api import estimate_age
3 | from .api import estimate_gender
4 | from .api import _eval as eval
5 |
6 | eval()
7 |
--------------------------------------------------------------------------------
/dex/api.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 |
6 | from .models import Age, Gender
7 |
8 | age_model = Age()
9 | gender_model = Gender()
10 |
11 | cwd = os.path.dirname(__file__)
12 | age_model_path = os.path.join(cwd, 'pth/age_sd.pth')
13 | gender_model_path = os.path.join(cwd, 'pth/gender_sd.pth')
14 |
15 | def _eval():
16 | global age_model
17 | global gender_model
18 | age_model.load_state_dict(torch.load(age_model_path))
19 | age_model.eval()
20 | gender_model.load_state_dict(torch.load(gender_model_path))
21 | gender_model.eval()
22 |
23 |
24 | def preprocess(img):
25 |     img = cv2.resize(img, (224, 224))   # VGG16 input size
26 |     img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
27 |     img = img[None, :, :, :]            # add a batch dimension
28 |     tensor = torch.from_numpy(img)
29 |     tensor = tensor.float()
30 |     return tensor
31 |
32 |
33 | def expected_age(vector):
34 |     res = [(i+1)*v for i, v in enumerate(vector)]  # expected value over the age bins (DEX)
35 | return sum(res)
36 |
37 |
38 | def estimate_age(img):
39 |     if isinstance(img, str):
40 | img = cv2.imread(img)
41 | tensor = preprocess(img)
42 | with torch.no_grad():
43 | output = age_model(tensor)
44 | output = output.numpy().squeeze()
45 | age = expected_age(output)
46 | return age
47 |
48 | def estimate_gender(img):
49 | if type(img) == str:
50 | img = cv2.imread(img)
51 | tensor = preprocess(img)
52 | with torch.no_grad():
53 | output = gender_model(tensor)
54 | output = output.numpy().squeeze()
55 | return output[0], output[1]
56 |
57 | def estimate(img):
58 |     """Estimate age and gender from an image path; return (age, female, male)."""
59 | img = cv2.imread(img)
60 | result = [estimate_age(img)]
61 | result.extend(estimate_gender(img))
62 | return result
63 |
--------------------------------------------------------------------------------
/dex/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from collections import OrderedDict
5 |
6 |
7 | def vgg_block(in_channels, out_channels, more=False):
8 | blocklist = [
9 | ('conv1', nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)),
10 | ('relu1', nn.ReLU(inplace=True)),
11 | ('conv2', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)),
12 | ('relu2', nn.ReLU(inplace=True)),
13 | ]
14 | if more:
15 | blocklist.extend([
16 | ('conv3', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)),
17 | ('relu3', nn.ReLU(inplace=True)),
18 | ])
19 | blocklist.append(('maxpool', nn.MaxPool2d(kernel_size=2, stride=2)))
20 | block = nn.Sequential(OrderedDict(blocklist))
21 | return block
22 |
23 | # VGG16 architecture
24 | class VGG(nn.Module):
25 | def __init__(self, classes=1000, channels=3):
26 | super().__init__()
27 | self.conv = nn.Sequential(
28 | vgg_block(channels, 64),
29 | vgg_block(64, 128),
30 | vgg_block(128, 256, True),
31 | vgg_block(256, 512, True),
32 | vgg_block(512, 512, True),
33 | )
34 | self.fc1 = nn.Sequential(
35 | nn.Linear(512*7*7, 4096),
36 | nn.ReLU(inplace=True),
37 | nn.Dropout(0.5, inplace=True),
38 | )
39 | self.fc2 = nn.Sequential(
40 | nn.Linear(4096, 4096),
41 | nn.ReLU(inplace=True),
42 | nn.Dropout(0.5, inplace=True),
43 | )
44 | self.cls = nn.Linear(4096, classes)
45 |
46 | def forward(self, x):
47 | in_size = x.shape[0]
48 | x = self.conv(x)
49 | x = x.view(in_size, -1)
50 | x = self.fc1(x)
51 | x = self.fc2(x)
52 | x = self.cls(x)
53 | x = F.softmax(x, dim=1)
54 | return x
55 |
56 | class Gender(VGG):
57 | def __init__(self, classes=2, channels=3):
58 | super().__init__()
59 | self.cls = nn.Linear(4096, classes)
60 |
61 | class Age(VGG):
62 | def __init__(self, classes=101, channels=3):
63 | super().__init__()
64 | self.cls = nn.Linear(4096, classes)
65 |
66 |
67 |
68 | if __name__ == '__main__':
69 | net = Gender()
70 | print(net)
71 |
72 |
--------------------------------------------------------------------------------
/dex/pth/age_sd.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d754f06341792aecf1c5893269ae849892072532f161b63cd595d713897f5302
3 | size 538703264
4 |
--------------------------------------------------------------------------------
/dex/pth/gender_sd.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:4e2bae5c6959145725c0152c6e5f0e65f0f2d176e3c2f7ff32b87813102767d1
3 | size 537080850
4 |
--------------------------------------------------------------------------------
/imgs/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/1.png
--------------------------------------------------------------------------------
/imgs/10.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/10.jpeg
--------------------------------------------------------------------------------
/imgs/11.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/11.jpeg
--------------------------------------------------------------------------------
/imgs/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/2.png
--------------------------------------------------------------------------------
/imgs/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/3.png
--------------------------------------------------------------------------------
/imgs/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/4.png
--------------------------------------------------------------------------------
/imgs/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/5.png
--------------------------------------------------------------------------------
/imgs/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/6.png
--------------------------------------------------------------------------------
/imgs/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/7.png
--------------------------------------------------------------------------------
/imgs/9.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/imgs/9.jpeg
--------------------------------------------------------------------------------
/network/age.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/network/age.png
--------------------------------------------------------------------------------
/network/gender.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siriusdemon/pytorch-DEX/e11c695f032e61a48c6db8bec46146faf1972e63/network/gender.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", 'r') as f:
4 | long_description = f.read()
5 |
6 |
7 | setuptools.setup(
8 | name = "dex",
9 | version = "0.1.0",
10 | author = "sirius demon",
11 | author_email = "mory2016@126.com",
12 | description="Deep EXpectation model in Pytorch",
13 | long_description=long_description,
14 | long_description_content_type='text/markdown',
15 | url = "https://github.com/siriusdemon/pytorch-dex",
16 | packages=setuptools.find_packages(),
17 | package_data = {
18 | 'dex': ['pth/*.pth'],
19 | },
20 | classifiers = [
21 | "Programming Language :: Python :: 3.5",
22 | "License :: OSI Approved :: BSD License",
23 | "Operating System :: OS Independent",
24 | ],
25 |
26 | )
--------------------------------------------------------------------------------