├── LICENSE.md ├── README.md ├── dataset.py ├── defaults.py ├── demo.py ├── ignore_list.csv ├── misc │   ├── example.png │   └── tfboard.png ├── model.py ├── requirements.txt ├── test.py └── train.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yusuke Uchida 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Age Estimation PyTorch 2 | PyTorch-based CNN implementation for estimating age from face images. 3 | Currently only the APPA-REAL dataset is supported. 4 | A similar Keras-based project can be found [here](https://github.com/yu4u/age-gender-estimation). 5 | 6 | 7 | 8 | ## Requirements 9 | 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Demo 15 | A webcam is required. 16 | See `python demo.py -h` for detailed options. 17 | 18 | ```bash 19 | python demo.py 20 | ``` 21 | 22 | With the `--img_dir` argument, images in that directory are used as input instead of the webcam: 23 | 24 | ```bash 25 | python demo.py --img_dir [PATH/TO/IMAGE_DIRECTORY] 26 | ``` 27 | 28 | If the `--output_dir` argument is also given, 29 | the resulting images are saved in that directory instead of being displayed in a window: 30 | 31 | ```bash 32 | python demo.py --img_dir [PATH/TO/IMAGE_DIRECTORY] --output_dir [PATH/TO/OUTPUT_DIRECTORY] 33 | ``` 34 | 35 | ## Train 36 | 37 | #### Download Dataset 38 | 39 | Download and extract the [APPA-REAL dataset](http://chalearnlap.cvc.uab.es/dataset/26/description/). 40 | 41 | > The APPA-REAL database contains 7,591 images with associated real and apparent age labels. The total number of apparent votes is around 250,000. On average we have around 38 votes per each image and this makes the average apparent age very stable (0.3 standard error of the mean). 42 | 43 | ```bash 44 | wget http://158.109.8.102/AppaRealAge/appa-real-release.zip 45 | unzip appa-real-release.zip 46 | ``` 47 | 48 | #### Train Model 49 | Train a model using the APPA-REAL dataset. 50 | See `python train.py -h` for detailed options.
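After extraction, `dataset.py` expects the data root to contain per-split ground-truth CSVs and pre-cropped face images (note the `*_face.jpg` suffix); a sketch of the expected layout:

```
appa-real-release/
├── gt_avg_train.csv
├── gt_avg_valid.csv
├── gt_avg_test.csv
├── train/   # e.g. 000025.jpg_face.jpg
├── valid/
└── test/
```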
51 | 52 | ```bash 53 | python train.py --data_dir [PATH/TO/appa-real-release] --tensorboard tf_log 54 | ``` 55 | 56 | Check training progress: 57 | 58 | ```bash 59 | tensorboard --logdir=tf_log 60 | ``` 61 | 62 | 63 | 64 | #### Training Options 65 | You can change training parameters including model architecture using additional arguments like this: 66 | 67 | ```bash 68 | python train.py --data_dir [PATH/TO/appa-real-release] --tensorboard tf_log MODEL.ARCH se_resnet50 TRAIN.OPT sgd TRAIN.LR 0.1 69 | ``` 70 | 71 | All default parameters defined in [defaults.py](defaults.py) can be changed using this style. 72 | 73 | 74 | #### Test Trained Model 75 | Evaluate the trained model using the APPA-REAL test dataset. 76 | 77 | ```bash 78 | python test.py --data_dir [PATH/TO/appa-real-release] --resume [PATH/TO/BEST_MODEL.pth] 79 | ``` 80 | 81 | After evaluation, you can see something like this: 82 | 83 | ```bash 84 | 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:08<00:00, 1.28it/s] 85 | test mae: 4.800 86 | ``` 87 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import better_exceptions 3 | from pathlib import Path 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | import cv2 8 | from torch.utils.data import Dataset 9 | from imgaug import augmenters as iaa 10 | 11 | 12 | class ImgAugTransform: 13 | def __init__(self): 14 | self.aug = iaa.Sequential([ 15 | iaa.OneOf([ 16 | iaa.Sometimes(0.25, iaa.AdditiveGaussianNoise(scale=0.1 * 255)), 17 | iaa.Sometimes(0.25, iaa.GaussianBlur(sigma=(0, 3.0))) 18 | ]), 19 | iaa.Affine( 20 | rotate=(-20, 20), mode="edge", 21 | scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, 22 | translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)} 23 | ), 24 | iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True), 25 | iaa.GammaContrast((0.3, 2)), 26 | iaa.Fliplr(0.5), 27 | ]) 28 | 29 | def __call__(self, img): 30 | img = np.array(img) 31 | img = self.aug.augment_image(img) 32 | return img 33 | 34 | 35 | class FaceDataset(Dataset): 36 | def __init__(self, data_dir, data_type, img_size=224, augment=False, age_stddev=1.0): 37 | assert(data_type in ("train", "valid", "test")) 38 | csv_path = Path(data_dir).joinpath(f"gt_avg_{data_type}.csv") 39 | img_dir = Path(data_dir).joinpath(data_type) 40 | self.img_size = img_size 41 | self.augment = augment 42 | self.age_stddev = age_stddev 43 | 44 | if augment: 45 | self.transform = ImgAugTransform() 46 | else: 47 | self.transform = lambda i: i 48 | 49 | self.x = [] 50 | self.y = [] 51 | self.std = [] 52 | df = pd.read_csv(str(csv_path)) 53 | ignore_path = Path(__file__).resolve().parent.joinpath("ignore_list.csv") 54 | ignore_img_names = list(pd.read_csv(str(ignore_path))["img_name"].values) 55 | 56 | for _, row in df.iterrows(): 57 | img_name = row["file_name"] 58 | 59 | if img_name in ignore_img_names: 60 | continue 61 | 62 | img_path = img_dir.joinpath(img_name + "_face.jpg") 63 | assert(img_path.is_file()) 64 | self.x.append(str(img_path)) 65 | self.y.append(row["apparent_age_avg"]) 66 | self.std.append(row["apparent_age_std"]) 67 | 68 | def __len__(self): 69 | return len(self.y) 70 | 71 | def __getitem__(self, idx): 72 | img_path = self.x[idx] 73 | age = self.y[idx] 74 | 75 | if self.augment: 76 | age += np.random.randn() * self.std[idx] * self.age_stddev 77 | 78 | img 
= cv2.imread(str(img_path), 1) 79 | img = cv2.resize(img, (self.img_size, self.img_size)) 80 | img = self.transform(img).astype(np.float32) 81 | return torch.from_numpy(np.transpose(img, (2, 0, 1))), np.clip(round(age), 0, 100) 82 | 83 | 84 | def main(): 85 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 86 | parser.add_argument("--data_dir", type=str, required=True) 87 | args = parser.parse_args() 88 | dataset = FaceDataset(args.data_dir, "train") 89 | print("train dataset len: {}".format(len(dataset))) 90 | dataset = FaceDataset(args.data_dir, "valid") 91 | print("valid dataset len: {}".format(len(dataset))) 92 | dataset = FaceDataset(args.data_dir, "test") 93 | print("test dataset len: {}".format(len(dataset))) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | 5 | # Model 6 | _C.MODEL = CN() 7 | _C.MODEL.ARCH = "se_resnext50_32x4d" # check python train.py -h for available models 8 | _C.MODEL.IMG_SIZE = 224 9 | 10 | # Train 11 | _C.TRAIN = CN() 12 | _C.TRAIN.OPT = "adam" # adam or sgd 13 | _C.TRAIN.WORKERS = 8 14 | _C.TRAIN.LR = 0.001 15 | _C.TRAIN.LR_DECAY_STEP = 20 16 | _C.TRAIN.LR_DECAY_RATE = 0.2 17 | _C.TRAIN.MOMENTUM = 0.9 18 | _C.TRAIN.WEIGHT_DECAY = 0.0 19 | _C.TRAIN.BATCH_SIZE = 128 20 | _C.TRAIN.EPOCHS = 80 21 | _C.TRAIN.AGE_STDDEV = 1.0 22 | 23 | # Test 24 | _C.TEST = CN() 25 | _C.TEST.WORKERS = 8 26 | _C.TEST.BATCH_SIZE = 128 27 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import better_exceptions 3 | from pathlib import Path 4 | from contextlib import contextmanager 5 | import urllib.request 6 | import numpy as np 7 | import cv2 8 | import dlib 9 | import torch 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.nn.functional as F 15 | from model import get_model 16 | from defaults import _C as cfg 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser(description="Age estimation demo", 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument("--resume", type=str, default=None, 23 | help="Model weight file to be used; a pretrained model is downloaded if not set") 24 | parser.add_argument("--margin", type=float, default=0.4, 25 | help="Margin around the detected face for age estimation") 26 | parser.add_argument("--img_dir", type=str, default=None, 27 | help="Target image directory; if set, images in img_dir are used instead of the webcam") 28 | parser.add_argument("--output_dir", type=str, default=None, 29 | help="Output directory in which resulting images will be stored if set") 30 | parser.add_argument("opts", default=[], nargs=argparse.REMAINDER, 31 | help="Modify config options using the command-line") 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX, 37 | font_scale=0.8, thickness=1): 38 | size = cv2.getTextSize(label, font, font_scale, thickness)[0] 39 | x, y = point 40 | cv2.rectangle(image, (x, y - size[1]), (x + size[0], y), (255, 0, 0), cv2.FILLED) 41 | cv2.putText(image, label, point, font, font_scale, (255, 255, 255), thickness, lineType=cv2.LINE_AA) 42 | 43 |
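# Note: video_capture below wraps cv2.VideoCapture in a context manager so that
# cap.release() always runs, even if the frame generator is abandoned early.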
44 | @contextmanager 45 | def video_capture(*args, **kwargs): 46 | cap = cv2.VideoCapture(*args, **kwargs) 47 | try: 48 | yield cap 49 | finally: 50 | cap.release() 51 | 52 | 53 | def yield_images(): 54 | with video_capture(0) as cap: 55 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640) 56 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) 57 | 58 | while True: 59 | ret, img = cap.read() 60 | 61 | if not ret: 62 | raise RuntimeError("Failed to capture image") 63 | 64 | yield img, None 65 | 66 | 67 | def yield_images_from_dir(img_dir): 68 | img_dir = Path(img_dir) 69 | 70 | for img_path in img_dir.glob("*.*"): 71 | img = cv2.imread(str(img_path), 1) 72 | 73 | if img is not None: 74 | h, w, _ = img.shape 75 | r = 640 / max(w, h) 76 | yield cv2.resize(img, (int(w * r), int(h * r))), img_path.name 77 | 78 | 79 | def main(): 80 | args = get_args() 81 | 82 | if args.opts: 83 | cfg.merge_from_list(args.opts) 84 | 85 | cfg.freeze() 86 | 87 | if args.output_dir is not None: 88 | if args.img_dir is None: 89 | raise ValueError("=> --img_dir argument is required if --output_dir is used") 90 | 91 | output_dir = Path(args.output_dir) 92 | output_dir.mkdir(parents=True, exist_ok=True) 93 | 94 | # create model 95 | print("=> creating model '{}'".format(cfg.MODEL.ARCH)) 96 | model = get_model(model_name=cfg.MODEL.ARCH, pretrained=None) 97 | device = "cuda" if torch.cuda.is_available() else "cpu" 98 | model = model.to(device) 99 | 100 | # load checkpoint 101 | resume_path = args.resume 102 | 103 | if resume_path is None: 104 | resume_path = Path(__file__).resolve().parent.joinpath("misc", "epoch044_0.02343_3.9984.pth") 105 | 106 | if not resume_path.is_file(): 107 | print(f"=> model path is not set; start downloading trained model to {resume_path}") 108 | url = "https://github.com/yu4u/age-estimation-pytorch/releases/download/v1.0/epoch044_0.02343_3.9984.pth" 109 | urllib.request.urlretrieve(url, str(resume_path)) 110 | print("=> download finished") 111 | 112 | if Path(resume_path).is_file(): 113 | print("=> loading checkpoint '{}'".format(resume_path)) 114 | checkpoint = torch.load(resume_path, map_location="cpu") 115 | model.load_state_dict(checkpoint['state_dict']) 116 | print("=> loaded checkpoint '{}'".format(resume_path)) 117 | else: 118 | raise ValueError("=> no checkpoint found at '{}'".format(resume_path)) 119 | 120 | if device == "cuda": 121 | cudnn.benchmark = True 122 | 123 | model.eval() 124 | margin = args.margin 125 | img_dir = args.img_dir 126 | detector = dlib.get_frontal_face_detector() 127 | img_size = cfg.MODEL.IMG_SIZE 128 | image_generator = yield_images_from_dir(img_dir) if img_dir else yield_images() 129 | 130 | with torch.no_grad(): 131 | for img, name in image_generator: 132 | input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 133 | img_h, img_w, _ = np.shape(input_img) 134 | 135 | # detect faces using dlib detector 136 | detected = detector(input_img, 1) 137 | faces = np.empty((len(detected), img_size, img_size, 3)) 138 | 139 | if len(detected) > 0: 140 | for i, d in enumerate(detected): 141 | x1, y1, x2, y2, w, h = d.left(), d.top(), d.right() + 1, d.bottom() + 1, d.width(), d.height() 142 | xw1 = max(int(x1 - margin * w), 0) 143 | yw1 = max(int(y1 - margin * h), 0) 144 | xw2 = min(int(x2 + margin * w), img_w - 1) 145 | yw2 = min(int(y2 + margin * h), img_h - 1) 146 | cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2) 147 | cv2.rectangle(img, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2) 148 | faces[i] = cv2.resize(img[yw1:yw2 + 1, xw1:xw2 + 1], (img_size, img_size)) 149 | 150 | # predict ages 151 
| inputs = torch.from_numpy(np.transpose(faces.astype(np.float32), (0, 3, 1, 2))).to(device) 152 | outputs = F.softmax(model(inputs), dim=-1).cpu().numpy() 153 | ages = np.arange(0, 101) 154 | predicted_ages = (outputs * ages).sum(axis=-1) 155 | 156 | # draw results 157 | for i, d in enumerate(detected): 158 | label = "{}".format(int(predicted_ages[i])) 159 | draw_label(img, (d.left(), d.top()), label) 160 | 161 | if args.output_dir is not None: 162 | output_path = output_dir.joinpath(name) 163 | cv2.imwrite(str(output_path), img) 164 | else: 165 | cv2.imshow("result", img) 166 | key = cv2.waitKey(-1) if img_dir else cv2.waitKey(30) 167 | 168 | if key == 27: # ESC 169 | break 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /ignore_list.csv: -------------------------------------------------------------------------------- 1 | img_name 2 | 000025.jpg 3 | 000049.jpg 4 | 000067.jpg 5 | 000085.jpg 6 | 000095.jpg 7 | 000100.jpg 8 | 000127.jpg 9 | 000145.jpg 10 | 000191.jpg 11 | 000215.jpg 12 | 000320.jpg 13 | 000373.jpg 14 | 000392.jpg 15 | 000407.jpg 16 | 000488.jpg 17 | 000503.jpg 18 | 000506.jpg 19 | 000510.jpg 20 | 000536.jpg 21 | 000605.jpg 22 | 000625.jpg 23 | 000639.jpg 24 | 000707.jpg 25 | 000708.jpg 26 | 000712.jpg 27 | 000813.jpg 28 | 000837.jpg 29 | 000848.jpg 30 | 000856.jpg 31 | 000891.jpg 32 | 000892.jpg 33 | 001022.jpg 34 | 001044.jpg 35 | 001095.jpg 36 | 001098.jpg 37 | 001122.jpg 38 | 001125.jpg 39 | 001137.jpg 40 | 001156.jpg 41 | 001227.jpg 42 | 001251.jpg 43 | 001267.jpg 44 | 001282.jpg 45 | 001328.jpg 46 | 001349.jpg 47 | 001380.jpg 48 | 001427.jpg 49 | 001460.jpg 50 | 001475.jpg 51 | 001697.jpg 52 | 001744.jpg 53 | 001864.jpg 54 | 001957.jpg 55 | 001968.jpg 56 | 001973.jpg 57 | 002029.jpg 58 | 002063.jpg 59 | 002109.jpg 60 | 002112.jpg 61 | 002115.jpg 62 | 002123.jpg 63 | 002162.jpg 64 | 002175.jpg 65 | 002179.jpg 66 | 002221.jpg 67 | 002250.jpg 68 | 002303.jpg 69 | 002359.jpg 70 | 002360.jpg 71 | 002412.jpg 72 | 002417.jpg 73 | 002435.jpg 74 | 002460.jpg 75 | 002466.jpg 76 | 002472.jpg 77 | 002488.jpg 78 | 002535.jpg 79 | 002543.jpg 80 | 002565.jpg 81 | 002615.jpg 82 | 002630.jpg 83 | 002633.jpg 84 | 002661.jpg 85 | 002733.jpg 86 | 002756.jpg 87 | 002860.jpg 88 | 002883.jpg 89 | 002887.jpg 90 | 002890.jpg 91 | 002948.jpg 92 | 002995.jpg 93 | 003018.jpg 94 | 003130.jpg 95 | 003164.jpg 96 | 003233.jpg 97 | 003258.jpg 98 | 003271.jpg 99 | 003329.jpg 100 | 003351.jpg 101 | 003357.jpg 102 | 003371.jpg 103 | 003415.jpg 104 | 003427.jpg 105 | 003441.jpg 106 | 003447.jpg 107 | 003458.jpg 108 | 003570.jpg 109 | 003625.jpg 110 | 003669.jpg 111 | 003711.jpg 112 | 003747.jpg 113 | 003749.jpg 114 | 003758.jpg 115 | 003763.jpg 116 | 003772.jpg 117 | 003805.jpg 118 | 003814.jpg 119 | 003903.jpg 120 | -------------------------------------------------------------------------------- /misc/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yu4u/age-estimation-pytorch/4724da9cc87104f64959b9317787860c8e7fc0aa/misc/example.png -------------------------------------------------------------------------------- /misc/tfboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yu4u/age-estimation-pytorch/4724da9cc87104f64959b9317787860c8e7fc0aa/misc/tfboard.png -------------------------------------------------------------------------------- /model.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import pretrainedmodels 3 | import pretrainedmodels.utils 4 | 5 | 6 | def get_model(model_name="se_resnext50_32x4d", num_classes=101, pretrained="imagenet"): 7 | model = pretrainedmodels.__dict__[model_name](pretrained=pretrained) 8 | dim_feats = model.last_linear.in_features 9 | model.last_linear = nn.Linear(dim_feats, num_classes) 10 | model.avg_pool = nn.AdaptiveAvgPool2d(1) 11 | return model 12 | 13 | 14 | def main(): 15 | model = get_model() 16 | print(model) 17 | 18 | 19 | if __name__ == '__main__': 20 | main() 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | better-exceptions==0.2.2 2 | dlib==19.17.0 3 | future==0.17.1 4 | imgaug==0.2.9 5 | numpy==1.22.0 6 | opencv-python==4.2.0.32 7 | pandas==0.24.2 8 | pretrainedmodels==0.7.4 9 | tensorboard==1.14.0 10 | torch==1.1.0 11 | torchvision==0.3.0 12 | tqdm==4.32.2 13 | yacs==0.1.6 14 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import better_exceptions 3 | from pathlib import Path 4 | import torch 5 | import torch.nn.parallel 6 | import torch.backends.cudnn as cudnn 7 | import torch.optim 8 | import torch.utils.data 9 | from torch.utils.data import DataLoader 10 | import pretrainedmodels 11 | import pretrainedmodels.utils 12 | from model import get_model 13 | from dataset import FaceDataset 14 | from defaults import _C as cfg 15 | from train import validate 16 | 17 | 18 | def get_args(): 19 | model_names = sorted(name for name in pretrainedmodels.__dict__ 20 | if not name.startswith("__") 21 | and name.islower() 22 | and callable(pretrainedmodels.__dict__[name])) 23 | parser = argparse.ArgumentParser(description=f"available models: {model_names}", 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 25 | parser.add_argument("--data_dir", type=str, required=True, help="Data root directory") 26 | parser.add_argument("--resume", type=str, required=True, help="Model weight to be tested") 27 | parser.add_argument("opts", default=[], nargs=argparse.REMAINDER, 28 | help="Modify config options using the command-line") 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def main(): 34 | args = get_args() 35 | 36 | if args.opts: 37 | cfg.merge_from_list(args.opts) 38 | 39 | cfg.freeze() 40 | 41 | # create model 42 | print("=> creating model '{}'".format(cfg.MODEL.ARCH)) 43 | model = get_model(model_name=cfg.MODEL.ARCH, pretrained=None) 44 | device = "cuda" if torch.cuda.is_available() else "cpu" 45 | model = model.to(device) 46 | 47 | # load checkpoint 48 | resume_path = args.resume 49 | 50 | if Path(resume_path).is_file(): 51 | print("=> loading checkpoint '{}'".format(resume_path)) 52 | checkpoint = torch.load(resume_path, map_location="cpu") 53 | model.load_state_dict(checkpoint['state_dict']) 54 | print("=> loaded checkpoint '{}'".format(resume_path)) 55 | else: 56 | raise ValueError("=> no checkpoint found at '{}'".format(resume_path)) 57 | 58 | if device == "cuda": 59 | cudnn.benchmark = True 60 | 61 | test_dataset = FaceDataset(args.data_dir, "test", img_size=cfg.MODEL.IMG_SIZE, augment=False) 62 | test_loader = DataLoader(test_dataset, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, 63 | num_workers=cfg.TEST.WORKERS,
drop_last=False) 64 | 65 | print("=> start testing") 66 | _, _, test_mae = validate(test_loader, model, None, 0, device) 67 | print(f"test mae: {test_mae:.3f}") 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import better_exceptions 3 | from pathlib import Path 4 | from collections import OrderedDict 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | from torch.optim.lr_scheduler import StepLR 13 | import torch.utils.data 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | from torch.utils.tensorboard import SummaryWriter 17 | import pretrainedmodels 18 | import pretrainedmodels.utils 19 | from model import get_model 20 | from dataset import FaceDataset 21 | from defaults import _C as cfg 22 | 23 | 24 | def get_args(): 25 | model_names = sorted(name for name in pretrainedmodels.__dict__ 26 | if not name.startswith("__") 27 | and name.islower() 28 | and callable(pretrainedmodels.__dict__[name])) 29 | parser = argparse.ArgumentParser(description=f"available models: {model_names}", 30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | parser.add_argument("--data_dir", type=str, required=True, help="Data root directory") 32 | parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint if any") 33 | parser.add_argument("--checkpoint", type=str, default="checkpoint", help="Checkpoint directory") 34 | parser.add_argument("--tensorboard", type=str, default=None, help="Tensorboard log directory") 35 | parser.add_argument('--multi_gpu', action="store_true", help="Use multiple GPUs (data parallel)") 36 | parser.add_argument("opts", default=[], nargs=argparse.REMAINDER, 37 | help="Modify config options using the command-line") 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | class AverageMeter(object): 43 | def __init__(self): 44 | self.val = 0 45 | self.avg = 0 46 | self.sum = 0 47 | self.count = 0 48 | 49 | def update(self, val, n=1): 50 | self.val = val 51 | self.sum += val 52 | self.count += n 53 | self.avg = self.sum / self.count 54 | 55 | 56 | def train(train_loader, model, criterion, optimizer, epoch, device): 57 | model.train() 58 | loss_monitor = AverageMeter() 59 | accuracy_monitor = AverageMeter() 60 | 61 | with tqdm(train_loader) as _tqdm: 62 | for x, y in _tqdm: 63 | x = x.to(device) 64 | y = y.to(device) 65 | 66 | # compute output 67 | outputs = model(x) 68 | 69 | # calc loss 70 | loss = criterion(outputs, y) 71 | cur_loss = loss.item() 72 | 73 | # calc accuracy 74 | _, predicted = outputs.max(1) 75 | correct_num = predicted.eq(y).sum().item() 76 | 77 | # measure accuracy and record loss 78 | sample_num = x.size(0) 79 | loss_monitor.update(cur_loss, sample_num) 80 | accuracy_monitor.update(correct_num, sample_num) 81 | 82 | # compute gradient and do SGD step 83 | optimizer.zero_grad() 84 | loss.backward() 85 | optimizer.step() 86 | 87 | _tqdm.set_postfix(OrderedDict(stage="train", epoch=epoch, loss=loss_monitor.avg), 88 | acc=accuracy_monitor.avg, correct=correct_num, sample_num=sample_num) 89 | 90 | return loss_monitor.avg, accuracy_monitor.avg 91 | 92 | 93 | def validate(validate_loader, model, criterion, epoch, device): 94 | model.eval() 95 | loss_monitor =
AverageMeter() 96 | accuracy_monitor = AverageMeter() 97 | preds = [] 98 | gt = [] 99 | 100 | with torch.no_grad(): 101 | with tqdm(validate_loader) as _tqdm: 102 | for i, (x, y) in enumerate(_tqdm): 103 | x = x.to(device) 104 | y = y.to(device) 105 | 106 | # compute output 107 | outputs = model(x) 108 | preds.append(F.softmax(outputs, dim=-1).cpu().numpy()) 109 | gt.append(y.cpu().numpy()) 110 | 111 | # valid for validation, not used for test 112 | if criterion is not None: 113 | # calc loss 114 | loss = criterion(outputs, y) 115 | cur_loss = loss.item() 116 | 117 | # calc accuracy 118 | _, predicted = outputs.max(1) 119 | correct_num = predicted.eq(y).sum().item() 120 | 121 | # measure accuracy and record loss 122 | sample_num = x.size(0) 123 | loss_monitor.update(cur_loss, sample_num) 124 | accuracy_monitor.update(correct_num, sample_num) 125 | _tqdm.set_postfix(OrderedDict(stage="val", epoch=epoch, loss=loss_monitor.avg), 126 | acc=accuracy_monitor.avg, correct=correct_num, sample_num=sample_num) 127 | 128 | preds = np.concatenate(preds, axis=0) 129 | gt = np.concatenate(gt, axis=0) 130 | ages = np.arange(0, 101) 131 | ave_preds = (preds * ages).sum(axis=-1) 132 | diff = ave_preds - gt 133 | mae = np.abs(diff).mean() 134 | 135 | return loss_monitor.avg, accuracy_monitor.avg, mae 136 | 137 | 138 | def main(): 139 | args = get_args() 140 | 141 | if args.opts: 142 | cfg.merge_from_list(args.opts) 143 | 144 | cfg.freeze() 145 | start_epoch = 0 146 | checkpoint_dir = Path(args.checkpoint) 147 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 148 | 149 | # create model 150 | print("=> creating model '{}'".format(cfg.MODEL.ARCH)) 151 | model = get_model(model_name=cfg.MODEL.ARCH) 152 | 153 | if cfg.TRAIN.OPT == "sgd": 154 | optimizer = torch.optim.SGD(model.parameters(), lr=cfg.TRAIN.LR, 155 | momentum=cfg.TRAIN.MOMENTUM, 156 | weight_decay=cfg.TRAIN.WEIGHT_DECAY) 157 | else: 158 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.TRAIN.LR) 159 | 160 | device = "cuda" if torch.cuda.is_available() else "cpu" 161 | model = model.to(device) 162 | 163 | # optionally resume from a checkpoint 164 | resume_path = args.resume 165 | 166 | if resume_path: 167 | if Path(resume_path).is_file(): 168 | print("=> loading checkpoint '{}'".format(resume_path)) 169 | checkpoint = torch.load(resume_path, map_location="cpu") 170 | start_epoch = checkpoint['epoch'] 171 | model.load_state_dict(checkpoint['state_dict']) 172 | print("=> loaded checkpoint '{}' (epoch {})" 173 | .format(resume_path, checkpoint['epoch'])) 174 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 175 | else: 176 | print("=> no checkpoint found at '{}'".format(resume_path)) 177 | 178 | if args.multi_gpu: 179 | model = nn.DataParallel(model) 180 | 181 | if device == "cuda": 182 | cudnn.benchmark = True 183 | 184 | criterion = nn.CrossEntropyLoss().to(device) 185 | train_dataset = FaceDataset(args.data_dir, "train", img_size=cfg.MODEL.IMG_SIZE, augment=True, 186 | age_stddev=cfg.TRAIN.AGE_STDDEV) 187 | train_loader = DataLoader(train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, 188 | num_workers=cfg.TRAIN.WORKERS, drop_last=True) 189 | 190 | val_dataset = FaceDataset(args.data_dir, "valid", img_size=cfg.MODEL.IMG_SIZE, augment=False) 191 | val_loader = DataLoader(val_dataset, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, 192 | num_workers=cfg.TRAIN.WORKERS, drop_last=False) 193 | 194 | scheduler = StepLR(optimizer, step_size=cfg.TRAIN.LR_DECAY_STEP, gamma=cfg.TRAIN.LR_DECAY_RATE, 195 | last_epoch=start_epoch - 
1) 196 | best_val_mae = 10000.0 197 | train_writer = None 198 | 199 | if args.tensorboard is not None: 200 | opts_prefix = "_".join(args.opts) 201 | train_writer = SummaryWriter(log_dir=args.tensorboard + "/" + opts_prefix + "_train") 202 | val_writer = SummaryWriter(log_dir=args.tensorboard + "/" + opts_prefix + "_val") 203 | 204 | for epoch in range(start_epoch, cfg.TRAIN.EPOCHS): 205 | # train 206 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, device) 207 | 208 | # validate 209 | val_loss, val_acc, val_mae = validate(val_loader, model, criterion, epoch, device) 210 | 211 | if args.tensorboard is not None: 212 | train_writer.add_scalar("loss", train_loss, epoch) 213 | train_writer.add_scalar("acc", train_acc, epoch) 214 | val_writer.add_scalar("loss", val_loss, epoch) 215 | val_writer.add_scalar("acc", val_acc, epoch) 216 | val_writer.add_scalar("mae", val_mae, epoch) 217 | 218 | # checkpoint 219 | if val_mae < best_val_mae: 220 | print(f"=> [epoch {epoch:03d}] best val mae was improved from {best_val_mae:.3f} to {val_mae:.3f}") 221 | model_state_dict = model.module.state_dict() if args.multi_gpu else model.state_dict() 222 | torch.save( 223 | { 224 | 'epoch': epoch + 1, 225 | 'arch': cfg.MODEL.ARCH, 226 | 'state_dict': model_state_dict, 227 | 'optimizer_state_dict': optimizer.state_dict() 228 | }, 229 | str(checkpoint_dir.joinpath("epoch{:03d}_{:.5f}_{:.4f}.pth".format(epoch, val_loss, val_mae))) 230 | ) 231 | best_val_mae = val_mae 232 | else: 233 | print(f"=> [epoch {epoch:03d}] best val mae was not improved from {best_val_mae:.3f} ({val_mae:.3f})") 234 | 235 | # adjust learning rate 236 | scheduler.step() 237 | 238 | print("=> training finished") 239 | print(f"additional opts: {args.opts}") 240 | print(f"best val mae: {best_val_mae:.3f}") 241 | 242 | 243 | if __name__ == '__main__': 244 | main() 245 | --------------------------------------------------------------------------------
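For reference, the model treats age estimation as a 101-way classification over ages 0-100 and decodes a prediction as the softmax expectation, as in `demo.py` and `validate()` in `train.py`. A minimal standalone sketch of that decoding (the checkpoint and image paths here are assumptions, not part of the repo):

```python
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from model import get_model
from defaults import _C as cfg

# build the network and load trained weights (path is hypothetical)
model = get_model(model_name=cfg.MODEL.ARCH, pretrained=None)
checkpoint = torch.load("checkpoint/best_model.pth", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# preprocess as the repo does: BGR image, resized, no mean/std normalization
img = cv2.imread("face.jpg", 1)  # hypothetical pre-cropped face image
img = cv2.resize(img, (cfg.MODEL.IMG_SIZE, cfg.MODEL.IMG_SIZE)).astype(np.float32)
x = torch.from_numpy(np.transpose(img, (2, 0, 1)))[None]  # NCHW, batch of one

with torch.no_grad():
    probs = F.softmax(model(x), dim=-1)  # distribution over ages 0..100
    age = (probs * torch.arange(101, dtype=torch.float)).sum(-1)  # expected value
print(f"predicted age: {age.item():.1f}")
```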