├── LICENSE.md
├── README.md
├── dataset.py
├── defaults.py
├── demo.py
├── ignore_list.csv
├── misc
│   ├── example.png
│   └── tfboard.png
├── model.py
├── requirements.txt
├── test.py
└── train.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yusuke Uchida
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Age Estimation PyTorch
2 | PyTorch-based CNN implementation for estimating age from face images.
3 | Currently only the APPA-REAL dataset is supported.
4 | A similar Keras-based project can be found [here](https://github.com/yu4u/age-gender-estimation).
5 |
6 | ![example](misc/example.png)
7 |
8 | ## Requirements
9 |
10 | ```bash
11 | pip install -r requirements.txt
12 | ```
13 |
14 | ## Demo
15 | A webcam is required for the live demo.
16 | See `python demo.py -h` for detailed options.
17 |
18 | ```bash
19 | python demo.py
20 | ```
21 |
22 | With the `--img_dir` argument, images in that directory are used as input instead of the webcam:
23 |
24 | ```bash
25 | python demo.py --img_dir [PATH/TO/IMAGE_DIRECTORY]
26 | ```
27 |
28 | If the `--output_dir` argument is also given,
29 | the resulting images are saved to that directory instead of being displayed in a window:
30 |
31 | ```bash
32 | python demo.py --img_dir [PATH/TO/IMAGE_DIRECTORY] --output_dir [PATH/TO/OUTPUT_DIRECTORY]
33 | ```
34 |
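For context, the demo detects faces with dlib's frontal face detector and expands each detection by `--margin` (0.4 by default) on every side before resizing the crop to the network input size. A condensed sketch of the cropping step in `demo.py` (`crop_face` is a hypothetical helper wrapping that logic):

```python
import cv2

def crop_face(img, d, margin=0.4, img_size=224):
    """Crop a dlib detection `d` with extra margin, clamped to the image bounds."""
    img_h, img_w, _ = img.shape
    x1, y1, x2, y2 = d.left(), d.top(), d.right() + 1, d.bottom() + 1
    w, h = d.width(), d.height()
    xw1 = max(int(x1 - margin * w), 0)
    yw1 = max(int(y1 - margin * h), 0)
    xw2 = min(int(x2 + margin * w), img_w - 1)
    yw2 = min(int(y2 + margin * h), img_h - 1)
    return cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1], (img_size, img_size))
```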
35 | ## Train
36 |
37 | #### Download Dataset
38 |
39 | Download and extract the [APPA-REAL dataset](http://chalearnlap.cvc.uab.es/dataset/26/description/).
40 |
41 | > The APPA-REAL database contains 7,591 images with associated real and apparent age labels. The total number of apparent votes is around 250,000. On average we have around 38 votes per each image and this makes the average apparent age very stable (0.3 standard error of the mean).
42 |
43 | ```bash
44 | wget http://158.109.8.102/AppaRealAge/appa-real-release.zip
45 | unzip appa-real-release.zip
46 | ```
47 |
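After extraction, `dataset.py` expects the following layout under the dataset root (only the pre-cropped `*_face.jpg` images and the `gt_avg_*.csv` label files are used; other files in the release are ignored):

```
appa-real-release/
├── gt_avg_train.csv   # apparent age (mean and std) per training image
├── gt_avg_valid.csv
├── gt_avg_test.csv
├── train/             # face crops named like 000000.jpg_face.jpg
├── valid/
└── test/
```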
48 | #### Train Model
49 | Train a model using the APPA-REAL dataset.
50 | See `python train.py -h` for detailed options.
51 |
52 | ```bash
53 | python train.py --data_dir [PATH/TO/appa-real-release] --tensorboard tf_log
54 | ```
55 |
56 | Check training progress:
57 |
58 | ```bash
59 | tensorboard --logdir=tf_log
60 | ```
61 |
62 | ![tensorboard](misc/tfboard.png)
63 |
64 | #### Training Options
65 | You can change training parameters, including the model architecture, by appending extra `KEY VALUE` arguments:
66 |
67 | ```bash
68 | python train.py --data_dir [PATH/TO/appa-real-release] --tensorboard tf_log MODEL.ARCH se_resnet50 TRAIN.OPT sgd TRAIN.LR 0.1
69 | ```
70 |
71 | All default parameters defined in [defaults.py](defaults.py) can be changed using this style.
72 |
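Under the hood, `train.py` collects these trailing `KEY VALUE` tokens via `argparse.REMAINDER` and merges them into the yacs config; a minimal sketch of the mechanism:

```python
from defaults import _C as cfg

# trailing command-line tokens, as collected by argparse.REMAINDER
opts = ["MODEL.ARCH", "se_resnet50", "TRAIN.OPT", "sgd", "TRAIN.LR", "0.1"]

cfg.merge_from_list(opts)  # override the defaults from defaults.py
cfg.freeze()               # make the config read-only from here on
print(cfg.MODEL.ARCH, cfg.TRAIN.OPT, cfg.TRAIN.LR)  # se_resnet50 sgd 0.1
```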
73 |
74 | #### Test Trained Model
75 | Evaluate the trained model using the APPA-REAL test dataset.
76 |
77 | ```bash
78 | python test.py --data_dir [PATH/TO/appa-real-release] --resume [PATH/TO/BEST_MODEL.pth]
79 | ```
80 |
81 | After evaluation, you should see output like this:
82 |
83 | ```bash
84 | 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:08<00:00, 1.28it/s]
85 | test mae: 4.800
86 | ```
87 |
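The reported MAE comes from treating the 101-way classifier output as a distribution over ages 0-100 and comparing its expected value against the ground-truth apparent age (see `validate` in `train.py`); a minimal standalone sketch with random stand-in values:

```python
import numpy as np
import torch
import torch.nn.functional as F

logits = torch.randn(4, 101)                  # stand-in for model(x)
probs = F.softmax(logits, dim=-1).numpy()     # shape (batch, 101)

ages = np.arange(0, 101)
predicted_ages = (probs * ages).sum(axis=-1)  # expected age per image

gt = np.array([23, 45, 31, 67])               # stand-in ground-truth ages
mae = np.abs(predicted_ages - gt).mean()
print(f"test mae: {mae:.3f}")
```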
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import better_exceptions
3 | from pathlib import Path
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | import cv2
8 | from torch.utils.data import Dataset
9 | from imgaug import augmenters as iaa
10 |
11 |
12 | class ImgAugTransform:
13 | def __init__(self):
14 | self.aug = iaa.Sequential([
15 | iaa.OneOf([
16 | iaa.Sometimes(0.25, iaa.AdditiveGaussianNoise(scale=0.1 * 255)),
17 | iaa.Sometimes(0.25, iaa.GaussianBlur(sigma=(0, 3.0)))
18 | ]),
19 | iaa.Affine(
20 | rotate=(-20, 20), mode="edge",
21 | scale={"x": (0.95, 1.05), "y": (0.95, 1.05)},
22 | translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)}
23 | ),
24 | iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True),
25 | iaa.GammaContrast((0.3, 2)),
26 | iaa.Fliplr(0.5),
27 | ])
28 |
29 | def __call__(self, img):
30 | img = np.array(img)
31 | img = self.aug.augment_image(img)
32 | return img
33 |
34 |
35 | class FaceDataset(Dataset):
36 | def __init__(self, data_dir, data_type, img_size=224, augment=False, age_stddev=1.0):
37 | assert data_type in ("train", "valid", "test")
38 | csv_path = Path(data_dir).joinpath(f"gt_avg_{data_type}.csv")
39 | img_dir = Path(data_dir).joinpath(data_type)
40 | self.img_size = img_size
41 | self.augment = augment
42 | self.age_stddev = age_stddev
43 |
44 | if augment:
45 | self.transform = ImgAugTransform()
46 | else:
47 | self.transform = lambda i: i
48 |
49 | self.x = []
50 | self.y = []
51 | self.std = []
52 | df = pd.read_csv(str(csv_path))
53 | ignore_path = Path(__file__).resolve().parent.joinpath("ignore_list.csv")
54 | ignore_img_names = set(pd.read_csv(str(ignore_path))["img_name"].values)
55 |
56 | for _, row in df.iterrows():
57 | img_name = row["file_name"]
58 |
59 | if img_name in ignore_img_names:
60 | continue
61 |
62 | img_path = img_dir.joinpath(img_name + "_face.jpg")
63 | assert img_path.is_file()
64 | self.x.append(str(img_path))
65 | self.y.append(row["apparent_age_avg"])
66 | self.std.append(row["apparent_age_std"])
67 |
68 | def __len__(self):
69 | return len(self.y)
70 |
71 | def __getitem__(self, idx):
72 | img_path = self.x[idx]
73 | age = self.y[idx]
74 |
75 | if self.augment:
76 | age += np.random.randn() * self.std[idx] * self.age_stddev
77 |
78 | img = cv2.imread(str(img_path), 1)
79 | img = cv2.resize(img, (self.img_size, self.img_size))
80 | img = self.transform(img).astype(np.float32)
81 | return torch.from_numpy(np.transpose(img, (2, 0, 1))), np.clip(round(age), 0, 100)  # HWC -> CHW; age as class index in [0, 100]
82 |
83 |
84 | def main():
85 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
86 | parser.add_argument("--data_dir", type=str, required=True)
87 | args = parser.parse_args()
88 | dataset = FaceDataset(args.data_dir, "train")
89 | print("train dataset len: {}".format(len(dataset)))
90 | dataset = FaceDataset(args.data_dir, "valid")
91 | print("valid dataset len: {}".format(len(dataset)))
92 | dataset = FaceDataset(args.data_dir, "test")
93 | print("test dataset len: {}".format(len(dataset)))
94 |
95 |
96 | if __name__ == '__main__':
97 | main()
98 |
--------------------------------------------------------------------------------
/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _C = CN()
4 |
5 | # Model
6 | _C.MODEL = CN()
7 | _C.MODEL.ARCH = "se_resnext50_32x4d" # check python train.py -h for available models
8 | _C.MODEL.IMG_SIZE = 224
9 |
10 | # Train
11 | _C.TRAIN = CN()
12 | _C.TRAIN.OPT = "adam" # adam or sgd
13 | _C.TRAIN.WORKERS = 8
14 | _C.TRAIN.LR = 0.001
15 | _C.TRAIN.LR_DECAY_STEP = 20
16 | _C.TRAIN.LR_DECAY_RATE = 0.2
17 | _C.TRAIN.MOMENTUM = 0.9
18 | _C.TRAIN.WEIGHT_DECAY = 0.0
19 | _C.TRAIN.BATCH_SIZE = 128
20 | _C.TRAIN.EPOCHS = 80
21 | _C.TRAIN.AGE_STDDEV = 1.0
22 |
23 | # Test
24 | _C.TEST = CN()
25 | _C.TEST.WORKERS = 8
26 | _C.TEST.BATCH_SIZE = 128
27 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import better_exceptions
3 | from pathlib import Path
4 | from contextlib import contextmanager
5 | import urllib.request
6 | import numpy as np
7 | import cv2
8 | import dlib
9 | import torch
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim
13 | import torch.utils.data
14 | import torch.nn.functional as F
15 | from model import get_model
16 | from defaults import _C as cfg
17 |
18 |
19 | def get_args():
20 | parser = argparse.ArgumentParser(description="Age estimation demo",
21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22 | parser.add_argument("--resume", type=str, default=None,
23 | help="Model weight to be tested")
24 | parser.add_argument("--margin", type=float, default=0.4,
25 | help="Margin around detected face for age-gender estimation")
26 | parser.add_argument("--img_dir", type=str, default=None,
27 | help="Target image directory; if set, images in image_dir are used instead of webcam")
28 | parser.add_argument("--output_dir", type=str, default=None,
29 | help="Output directory to which resulting images will be stored if set")
30 | parser.add_argument("opts", default=[], nargs=argparse.REMAINDER,
31 | help="Modify config options using the command-line")
32 | args = parser.parse_args()
33 | return args
34 |
35 |
36 | def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX,
37 | font_scale=0.8, thickness=1):
38 | size = cv2.getTextSize(label, font, font_scale, thickness)[0]
39 | x, y = point
40 | cv2.rectangle(image, (x, y - size[1]), (x + size[0], y), (255, 0, 0), cv2.FILLED)
41 | cv2.putText(image, label, point, font, font_scale, (255, 255, 255), thickness, lineType=cv2.LINE_AA)
42 |
43 |
44 | @contextmanager
45 | def video_capture(*args, **kwargs):
46 | cap = cv2.VideoCapture(*args, **kwargs)
47 | try:
48 | yield cap
49 | finally:
50 | cap.release()
51 |
52 |
53 | def yield_images():
54 | with video_capture(0) as cap:
55 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
56 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
57 |
58 | while True:
59 | ret, img = cap.read()
60 |
61 | if not ret:
62 | raise RuntimeError("Failed to capture image")
63 |
64 | yield img, None
65 |
66 |
67 | def yield_images_from_dir(img_dir):
68 | img_dir = Path(img_dir)
69 |
70 | for img_path in img_dir.glob("*.*"):
71 | img = cv2.imread(str(img_path), 1)
72 |
73 | if img is not None:
74 | h, w, _ = img.shape
75 | r = 640 / max(w, h)
76 | yield cv2.resize(img, (int(w * r), int(h * r))), img_path.name
77 |
78 |
79 | def main():
80 | args = get_args()
81 |
82 | if args.opts:
83 | cfg.merge_from_list(args.opts)
84 |
85 | cfg.freeze()
86 |
87 | if args.output_dir is not None:
88 | if args.img_dir is None:
89 | raise ValueError("=> --img_dir argument is required if --output_dir is used")
90 |
91 | output_dir = Path(args.output_dir)
92 | output_dir.mkdir(parents=True, exist_ok=True)
93 |
94 | # create model
95 | print("=> creating model '{}'".format(cfg.MODEL.ARCH))
96 | model = get_model(model_name=cfg.MODEL.ARCH, pretrained=None)
97 | device = "cuda" if torch.cuda.is_available() else "cpu"
98 | model = model.to(device)
99 |
100 | # load checkpoint
101 | resume_path = args.resume
102 |
103 | if resume_path is None:
104 | resume_path = Path(__file__).resolve().parent.joinpath("misc", "epoch044_0.02343_3.9984.pth")
105 |
106 | if not resume_path.is_file():
107 | print(f"=> model path is not set; start downloading trained model to {resume_path}")
108 | url = "https://github.com/yu4u/age-estimation-pytorch/releases/download/v1.0/epoch044_0.02343_3.9984.pth"
109 | urllib.request.urlretrieve(url, str(resume_path))
110 | print("=> download finished")
111 |
112 | if Path(resume_path).is_file():
113 | print("=> loading checkpoint '{}'".format(resume_path))
114 | checkpoint = torch.load(resume_path, map_location="cpu")
115 | model.load_state_dict(checkpoint['state_dict'])
116 | print("=> loaded checkpoint '{}'".format(resume_path))
117 | else:
118 | raise ValueError("=> no checkpoint found at '{}'".format(resume_path))
119 |
120 | if device == "cuda":
121 | cudnn.benchmark = True
122 |
123 | model.eval()
124 | margin = args.margin
125 | img_dir = args.img_dir
126 | detector = dlib.get_frontal_face_detector()
127 | img_size = cfg.MODEL.IMG_SIZE
128 | image_generator = yield_images_from_dir(img_dir) if img_dir else yield_images()
129 |
130 | with torch.no_grad():
131 | for img, name in image_generator:
132 | input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
133 | img_h, img_w, _ = np.shape(input_img)
134 |
135 | # detect faces using dlib detector
136 | detected = detector(input_img, 1)
137 | faces = np.empty((len(detected), img_size, img_size, 3))
138 |
139 | if len(detected) > 0:
140 | for i, d in enumerate(detected):
141 | x1, y1, x2, y2, w, h = d.left(), d.top(), d.right() + 1, d.bottom() + 1, d.width(), d.height()
142 | xw1 = max(int(x1 - margin * w), 0)
143 | yw1 = max(int(y1 - margin * h), 0)
144 | xw2 = min(int(x2 + margin * w), img_w - 1)
145 | yw2 = min(int(y2 + margin * h), img_h - 1)
146 | cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2)
147 | cv2.rectangle(img, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2)
148 | faces[i] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1], (img_size, img_size))
149 |
150 | # predict ages
151 | inputs = torch.from_numpy(np.transpose(faces.astype(np.float32), (0, 3, 1, 2))).to(device)
152 | outputs = F.softmax(model(inputs), dim=-1).cpu().numpy()
153 | ages = np.arange(0, 101)
154 | predicted_ages = (outputs * ages).sum(axis=-1)
155 |
156 | # draw results
157 | for i, d in enumerate(detected):
158 | label = "{}".format(int(predicted_ages[i]))
159 | draw_label(img, (d.left(), d.top()), label)
160 |
161 | if args.output_dir is not None:
162 | output_path = output_dir.joinpath(name)
163 | cv2.imwrite(str(output_path), img)
164 | else:
165 | cv2.imshow("result", img)
166 | key = cv2.waitKey(-1) if img_dir else cv2.waitKey(30)
167 |
168 | if key == 27: # ESC
169 | break
170 |
171 |
172 | if __name__ == '__main__':
173 | main()
174 |
--------------------------------------------------------------------------------
/ignore_list.csv:
--------------------------------------------------------------------------------
1 | img_name
2 | 000025.jpg
3 | 000049.jpg
4 | 000067.jpg
5 | 000085.jpg
6 | 000095.jpg
7 | 000100.jpg
8 | 000127.jpg
9 | 000145.jpg
10 | 000191.jpg
11 | 000215.jpg
12 | 000320.jpg
13 | 000373.jpg
14 | 000392.jpg
15 | 000407.jpg
16 | 000488.jpg
17 | 000503.jpg
18 | 000506.jpg
19 | 000510.jpg
20 | 000536.jpg
21 | 000605.jpg
22 | 000625.jpg
23 | 000639.jpg
24 | 000707.jpg
25 | 000708.jpg
26 | 000712.jpg
27 | 000813.jpg
28 | 000837.jpg
29 | 000848.jpg
30 | 000856.jpg
31 | 000891.jpg
32 | 000892.jpg
33 | 001022.jpg
34 | 001044.jpg
35 | 001095.jpg
36 | 001098.jpg
37 | 001122.jpg
38 | 001125.jpg
39 | 001137.jpg
40 | 001156.jpg
41 | 001227.jpg
42 | 001251.jpg
43 | 001267.jpg
44 | 001282.jpg
45 | 001328.jpg
46 | 001349.jpg
47 | 001380.jpg
48 | 001427.jpg
49 | 001460.jpg
50 | 001475.jpg
51 | 001697.jpg
52 | 001744.jpg
53 | 001864.jpg
54 | 001957.jpg
55 | 001968.jpg
56 | 001973.jpg
57 | 002029.jpg
58 | 002063.jpg
59 | 002109.jpg
60 | 002112.jpg
61 | 002115.jpg
62 | 002123.jpg
63 | 002162.jpg
64 | 002175.jpg
65 | 002179.jpg
66 | 002221.jpg
67 | 002250.jpg
68 | 002303.jpg
69 | 002359.jpg
70 | 002360.jpg
71 | 002412.jpg
72 | 002417.jpg
73 | 002435.jpg
74 | 002460.jpg
75 | 002466.jpg
76 | 002472.jpg
77 | 002488.jpg
78 | 002535.jpg
79 | 002543.jpg
80 | 002565.jpg
81 | 002615.jpg
82 | 002630.jpg
83 | 002633.jpg
84 | 002661.jpg
85 | 002733.jpg
86 | 002756.jpg
87 | 002860.jpg
88 | 002883.jpg
89 | 002887.jpg
90 | 002890.jpg
91 | 002948.jpg
92 | 002995.jpg
93 | 003018.jpg
94 | 003130.jpg
95 | 003164.jpg
96 | 003233.jpg
97 | 003258.jpg
98 | 003271.jpg
99 | 003329.jpg
100 | 003351.jpg
101 | 003357.jpg
102 | 003371.jpg
103 | 003415.jpg
104 | 003427.jpg
105 | 003441.jpg
106 | 003447.jpg
107 | 003458.jpg
108 | 003570.jpg
109 | 003625.jpg
110 | 003669.jpg
111 | 003711.jpg
112 | 003747.jpg
113 | 003749.jpg
114 | 003758.jpg
115 | 003763.jpg
116 | 003772.jpg
117 | 003805.jpg
118 | 003814.jpg
119 | 003903.jpg
120 |
--------------------------------------------------------------------------------
/misc/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yu4u/age-estimation-pytorch/4724da9cc87104f64959b9317787860c8e7fc0aa/misc/example.png
--------------------------------------------------------------------------------
/misc/tfboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yu4u/age-estimation-pytorch/4724da9cc87104f64959b9317787860c8e7fc0aa/misc/tfboard.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import pretrainedmodels
3 | import pretrainedmodels.utils
4 |
5 |
6 | def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"):
7 | model = pretrainedmodels.__dict__[model_name](pretrained=pretrained)
8 | dim_feats = model.last_linear.in_features
9 | model.last_linear = nn.Linear(dim_feats, num_classes)
10 | model.avg_pool = nn.AdaptiveAvgPool2d(1)  # accept input sizes other than the backbone's default
11 | return model
12 |
13 |
14 | def main():
15 | model = get_model()
16 | print(model)
17 |
18 |
19 | if __name__ == '__main__':
20 | main()
21 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | better-exceptions==0.2.2
2 | dlib==19.17.0
3 | future==0.17.1
4 | imgaug==0.2.9
5 | numpy==1.22.0
6 | opencv-python==4.2.0.32
7 | pandas==0.24.2
8 | pretrainedmodels==0.7.4
9 | tensorboard==1.14.0
10 | torch==1.1.0
11 | torchvision==0.3.0
12 | tqdm==4.32.2
13 | yacs==0.1.6
14 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import better_exceptions
3 | from pathlib import Path
4 | import torch
5 | import torch.nn.parallel
6 | import torch.backends.cudnn as cudnn
7 | import torch.optim
8 | import torch.utils.data
9 | from torch.utils.data import DataLoader
10 | import pretrainedmodels
11 | import pretrainedmodels.utils
12 | from model import get_model
13 | from dataset import FaceDataset
14 | from defaults import _C as cfg
15 | from train import validate
16 |
17 |
18 | def get_args():
19 | model_names = sorted(name for name in pretrainedmodels.__dict__
20 | if not name.startswith("__")
21 | and name.islower()
22 | and callable(pretrainedmodels.__dict__[name]))
23 | parser = argparse.ArgumentParser(description=f"available models: {model_names}",
24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
25 | parser.add_argument("--data_dir", type=str, required=True, help="Data root directory")
26 | parser.add_argument("--resume", type=str, required=True, help="Model weight to be tested")
27 | parser.add_argument("opts", default=[], nargs=argparse.REMAINDER,
28 | help="Modify config options using the command-line")
29 | args = parser.parse_args()
30 | return args
31 |
32 |
33 | def main():
34 | args = get_args()
35 |
36 | if args.opts:
37 | cfg.merge_from_list(args.opts)
38 |
39 | cfg.freeze()
40 |
41 | # create model
42 | print("=> creating model '{}'".format(cfg.MODEL.ARCH))
43 | model = get_model(model_name=cfg.MODEL.ARCH, pretrained=None)
44 | device = "cuda" if torch.cuda.is_available() else "cpu"
45 | model = model.to(device)
46 |
47 | # load checkpoint
48 | resume_path = args.resume
49 |
50 | if Path(resume_path).is_file():
51 | print("=> loading checkpoint '{}'".format(resume_path))
52 | checkpoint = torch.load(resume_path, map_location="cpu")
53 | model.load_state_dict(checkpoint['state_dict'])
54 | print("=> loaded checkpoint '{}'".format(resume_path))
55 | else:
56 | raise ValueError("=> no checkpoint found at '{}'".format(resume_path))
57 |
58 | if device == "cuda":
59 | cudnn.benchmark = True
60 |
61 | test_dataset = FaceDataset(args.data_dir, "test", img_size=cfg.MODEL.IMG_SIZE, augment=False)
62 | test_loader = DataLoader(test_dataset, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False,
63 | num_workers=cfg.TEST.WORKERS, drop_last=False)
64 |
65 | print("=> start testing")
66 | _, _, test_mae = validate(test_loader, model, None, 0, device)
67 | print(f"test mae: {test_mae:.3f}")
68 |
69 |
70 | if __name__ == '__main__':
71 | main()
72 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import better_exceptions
3 | from pathlib import Path
4 | from collections import OrderedDict
5 | from tqdm import tqdm
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 | from torch.optim.lr_scheduler import StepLR
13 | import torch.utils.data
14 | from torch.utils.data import DataLoader
15 | import torch.nn.functional as F
16 | from torch.utils.tensorboard import SummaryWriter
17 | import pretrainedmodels
18 | import pretrainedmodels.utils
19 | from model import get_model
20 | from dataset import FaceDataset
21 | from defaults import _C as cfg
22 |
23 |
24 | def get_args():
25 | model_names = sorted(name for name in pretrainedmodels.__dict__
26 | if not name.startswith("__")
27 | and name.islower()
28 | and callable(pretrainedmodels.__dict__[name]))
29 | parser = argparse.ArgumentParser(description=f"available models: {model_names}",
30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
31 | parser.add_argument("--data_dir", type=str, required=True, help="Data root directory")
32 | parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint if any")
33 | parser.add_argument("--checkpoint", type=str, default="checkpoint", help="Checkpoint directory")
34 | parser.add_argument("--tensorboard", type=str, default=None, help="Tensorboard log directory")
35 | parser.add_argument('--multi_gpu', action="store_true", help="Use multi GPUs (data parallel)")
36 | parser.add_argument("opts", default=[], nargs=argparse.REMAINDER,
37 | help="Modify config options using the command-line")
38 | args = parser.parse_args()
39 | return args
40 |
41 |
42 | class AverageMeter(object):
43 | def __init__(self):
44 | self.val = 0
45 | self.avg = 0
46 | self.sum = 0
47 | self.count = 0
48 |
49 | def update(self, val, n=1):
50 | self.val = val
51 | self.sum += val * n  # weight each update by its sample count
52 | self.count += n
53 | self.avg = self.sum / self.count
54 |
55 |
56 | def train(train_loader, model, criterion, optimizer, epoch, device):
57 | model.train()
58 | loss_monitor = AverageMeter()
59 | accuracy_monitor = AverageMeter()
60 |
61 | with tqdm(train_loader) as _tqdm:
62 | for x, y in _tqdm:
63 | x = x.to(device)
64 | y = y.to(device)
65 |
66 | # compute output
67 | outputs = model(x)
68 |
69 | # calc loss
70 | loss = criterion(outputs, y)
71 | cur_loss = loss.item()
72 |
73 | # calc accuracy
74 | _, predicted = outputs.max(1)
75 | correct_num = predicted.eq(y).sum().item()
76 |
77 | # measure accuracy and record loss
78 | sample_num = x.size(0)
79 | loss_monitor.update(cur_loss, sample_num)
80 | accuracy_monitor.update(correct_num / sample_num, sample_num)
81 |
82 | # compute gradient and do SGD step
83 | optimizer.zero_grad()
84 | loss.backward()
85 | optimizer.step()
86 |
87 | _tqdm.set_postfix(OrderedDict(stage="train", epoch=epoch, loss=loss_monitor.avg,
88 | acc=accuracy_monitor.avg, correct=correct_num, sample_num=sample_num))
89 |
90 | return loss_monitor.avg, accuracy_monitor.avg
91 |
92 |
93 | def validate(validate_loader, model, criterion, epoch, device):
94 | model.eval()
95 | loss_monitor = AverageMeter()
96 | accuracy_monitor = AverageMeter()
97 | preds = []
98 | gt = []
99 |
100 | with torch.no_grad():
101 | with tqdm(validate_loader) as _tqdm:
102 | for i, (x, y) in enumerate(_tqdm):
103 | x = x.to(device)
104 | y = y.to(device)
105 |
106 | # compute output
107 | outputs = model(x)
108 | preds.append(F.softmax(outputs, dim=-1).cpu().numpy())
109 | gt.append(y.cpu().numpy())
110 |
111 | # criterion is provided for validation but not for test
112 | if criterion is not None:
113 | # calc loss
114 | loss = criterion(outputs, y)
115 | cur_loss = loss.item()
116 |
117 | # calc accuracy
118 | _, predicted = outputs.max(1)
119 | correct_num = predicted.eq(y).sum().item()
120 |
121 | # measure accuracy and record loss
122 | sample_num = x.size(0)
123 | loss_monitor.update(cur_loss, sample_num)
124 | accuracy_monitor.update(correct_num / sample_num, sample_num)
125 | _tqdm.set_postfix(OrderedDict(stage="val", epoch=epoch, loss=loss_monitor.avg,
126 | acc=accuracy_monitor.avg, correct=correct_num, sample_num=sample_num))
127 |
128 | preds = np.concatenate(preds, axis=0)
129 | gt = np.concatenate(gt, axis=0)
130 | ages = np.arange(0, 101)
131 | ave_preds = (preds * ages).sum(axis=-1)
132 | diff = ave_preds - gt
133 | mae = np.abs(diff).mean()
134 |
135 | return loss_monitor.avg, accuracy_monitor.avg, mae
136 |
137 |
138 | def main():
139 | args = get_args()
140 |
141 | if args.opts:
142 | cfg.merge_from_list(args.opts)
143 |
144 | cfg.freeze()
145 | start_epoch = 0
146 | checkpoint_dir = Path(args.checkpoint)
147 | checkpoint_dir.mkdir(parents=True, exist_ok=True)
148 |
149 | # create model
150 | print("=> creating model '{}'".format(cfg.MODEL.ARCH))
151 | model = get_model(model_name=cfg.MODEL.ARCH)
152 |
153 | if cfg.TRAIN.OPT == "sgd":
154 | optimizer = torch.optim.SGD(model.parameters(), lr=cfg.TRAIN.LR,
155 | momentum=cfg.TRAIN.MOMENTUM,
156 | weight_decay=cfg.TRAIN.WEIGHT_DECAY)
157 | else:
158 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.TRAIN.LR)
159 |
160 | device = "cuda" if torch.cuda.is_available() else "cpu"
161 | model = model.to(device)
162 |
163 | # optionally resume from a checkpoint
164 | resume_path = args.resume
165 |
166 | if resume_path:
167 | if Path(resume_path).is_file():
168 | print("=> loading checkpoint '{}'".format(resume_path))
169 | checkpoint = torch.load(resume_path, map_location="cpu")
170 | start_epoch = checkpoint['epoch']
171 | model.load_state_dict(checkpoint['state_dict'])
172 | print("=> loaded checkpoint '{}' (epoch {})"
173 | .format(resume_path, checkpoint['epoch']))
174 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
175 | else:
176 | print("=> no checkpoint found at '{}'".format(resume_path))
177 |
178 | if args.multi_gpu:
179 | model = nn.DataParallel(model)
180 |
181 | if device == "cuda":
182 | cudnn.benchmark = True
183 |
184 | criterion = nn.CrossEntropyLoss().to(device)
185 | train_dataset = FaceDataset(args.data_dir, "train", img_size=cfg.MODEL.IMG_SIZE, augment=True,
186 | age_stddev=cfg.TRAIN.AGE_STDDEV)
187 | train_loader = DataLoader(train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True,
188 | num_workers=cfg.TRAIN.WORKERS, drop_last=True)
189 |
190 | val_dataset = FaceDataset(args.data_dir, "valid", img_size=cfg.MODEL.IMG_SIZE, augment=False)
191 | val_loader = DataLoader(val_dataset, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False,
192 | num_workers=cfg.TRAIN.WORKERS, drop_last=False)
193 |
194 | scheduler = StepLR(optimizer, step_size=cfg.TRAIN.LR_DECAY_STEP, gamma=cfg.TRAIN.LR_DECAY_RATE,
195 | last_epoch=start_epoch - 1)
196 | best_val_mae = 10000.0
197 | train_writer = None
198 |
199 | if args.tensorboard is not None:
200 | opts_prefix = "_".join(args.opts)
201 | train_writer = SummaryWriter(log_dir=args.tensorboard + "/" + opts_prefix + "_train")
202 | val_writer = SummaryWriter(log_dir=args.tensorboard + "/" + opts_prefix + "_val")
203 |
204 | for epoch in range(start_epoch, cfg.TRAIN.EPOCHS):
205 | # train
206 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, device)
207 |
208 | # validate
209 | val_loss, val_acc, val_mae = validate(val_loader, model, criterion, epoch, device)
210 |
211 | if args.tensorboard is not None:
212 | train_writer.add_scalar("loss", train_loss, epoch)
213 | train_writer.add_scalar("acc", train_acc, epoch)
214 | val_writer.add_scalar("loss", val_loss, epoch)
215 | val_writer.add_scalar("acc", val_acc, epoch)
216 | val_writer.add_scalar("mae", val_mae, epoch)
217 |
218 | # checkpoint
219 | if val_mae < best_val_mae:
220 | print(f"=> [epoch {epoch:03d}] best val mae was improved from {best_val_mae:.3f} to {val_mae:.3f}")
221 | model_state_dict = model.module.state_dict() if args.multi_gpu else model.state_dict()
222 | torch.save(
223 | {
224 | 'epoch': epoch + 1,
225 | 'arch': cfg.MODEL.ARCH,
226 | 'state_dict': model_state_dict,
227 | 'optimizer_state_dict': optimizer.state_dict()
228 | },
229 | str(checkpoint_dir.joinpath("epoch{:03d}_{:.5f}_{:.4f}.pth".format(epoch, val_loss, val_mae)))
230 | )
231 | best_val_mae = val_mae
232 | else:
233 | print(f"=> [epoch {epoch:03d}] best val mae was not improved from {best_val_mae:.3f} ({val_mae:.3f})")
234 |
235 | # adjust learning rate
236 | scheduler.step()
237 |
238 | print("=> training finished")
239 | print(f"additional opts: {args.opts}")
240 | print(f"best val mae: {best_val_mae:.3f}")
241 |
242 |
243 | if __name__ == '__main__':
244 | main()
245 |
--------------------------------------------------------------------------------