├── codes └── adience_poe │ ├── poe │ ├── __init__.py │ ├── dataset_manger.py │ ├── metrics.py │ ├── utils.py │ ├── options.py │ ├── probordiloss.py │ └── vgg.py │ ├── log │ └── placeholder.txt │ ├── misc │ └── metric_summary.py │ ├── README.md │ ├── scripts │ ├── test_poe.sh │ ├── train_poe.sh │ ├── test_baseline.sh │ └── train_baseline.sh │ ├── environment.yaml │ ├── test.py │ └── train.py ├── imgs └── framework.png └── README.md /codes/adience_poe/poe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /codes/adience_poe/log/placeholder.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Li-Wanhua/POEs/HEAD/imgs/framework.png -------------------------------------------------------------------------------- /codes/adience_poe/misc/metric_summary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | 6 | def process_one_exp(path): 7 | out_ls = [] 8 | for fold_dir in os.listdir(path): 9 | npy_path = os.path.join(path, fold_dir, 'test_result.npy') 10 | if os.path.exists(npy_path) and os.path.isfile(npy_path): 11 | test_result_mat = np.load(npy_path) 12 | out_ls.append(test_result_mat) 13 | 14 | if len(out_ls) > 0: 15 | results_mat = np.vstack(out_ls) 16 | print(results_mat.shape) 17 | means = results_mat.mean(axis=0) 18 | stds = results_mat.std(axis=0) 19 | results = [(mean, std) for mean, std in zip(means, stds)] 20 | print("{}".format(path)) 21 | for k, val in zip(["mae", "acc"], results[1:-1:1]): 22 | print("{}: {}+-{}".format(k, *val)) 23 | 24 | 25 | #%% 26 | if __name__ == "__main__": 27 | AP = argparse.ArgumentParser() 28 | AP.add_argument("--logdir", type=str, default="./log") 29 | args = AP.parse_args() 30 | 31 | os.chdir(args.logdir) 32 | for root_dir in os.listdir('.'): 33 | if os.path.isdir(root_dir): 34 | process_one_exp(root_dir) 35 | -------------------------------------------------------------------------------- /codes/adience_poe/README.md: -------------------------------------------------------------------------------- 1 | # POEs 2 | 3 | PyTorch re-implementation of Learning Probabilistic Ordinal Embeddings for Uncertainty-Aware Regression (CVPR 2021)[[project page](https://li-wanhua.github.io/POEs/)] 4 | 5 | # Codes for Adience Dataset 6 | [Adience Dataset](https://talhassner.github.io/home/projects/Adience/Adience-data.html) 7 | 8 | ## Prepare Environment 9 | Simply create a conda environment by: 10 | ```bash 11 | conda create -f environment.yaml 12 | ``` 13 | The codes is on test on pytorch==1.0.0, but higher version of pytorch should be ok. 14 | 15 | ## Train 16 | Configure the data-related paths in `scripts/*.sh`, specifically the `--train-images-root`, `--test-images-root`, `--train-data-file`, and `--test-data-file` flags. 17 | 18 | ```bash 19 | # Train POEs / baselines 20 | # model_type should be in ['reg', 'cls', 'rank'] 21 | bash ./scripts/train_poe.sh [id_of_gpu='0'] [model_type='cls'] 22 | bash ./scripts/train_baseline.sh [id_of_gpu='0'] [model_type='cls'] 23 | ``` 24 | ## Test 25 | ```bash 26 | # Test POEs / baselines 27 | # model_type should be in ['reg', 'cls', 'rank'] 28 | bash ./scripts/test_poe.sh [id_of_gpu='0'] [model_type='cls'] 29 | bash ./scripts/test_baseline.sh [id_of_gpu='0'] [model_type='cls'] 30 | ``` 31 | ## Performance Summary 32 | ```bash 33 | python ./misc/metric_summary.py 34 | ``` 35 | -------------------------------------------------------------------------------- /codes/adience_poe/scripts/test_poe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | gpu='0' 3 | 4 | main_loss_type='cls' 5 | num_output_neurons=8 6 | 7 | # main_loss_type='rank' 8 | # num_output_neurons=16 9 | 10 | # main_loss_type='reg' 11 | # num_output_neurons=1 12 | 13 | 14 | if [[ $# = 1 ]]; then 15 | gpu=${1} 16 | fi 17 | 18 | if [[ $# = 2 ]]; then 19 | gpu=${1} 20 | main_loss_type=${2} 21 | if [[ $main_loss_type = 'cls' ]]; then 22 | num_output_neurons=8 23 | fi 24 | if [[ $main_loss_type = 'rank' ]]; then 25 | num_output_neurons=16 26 | fi 27 | if [[ $main_loss_type = 'reg' ]]; then 28 | num_output_neurons=1 29 | fi 30 | fi 31 | 32 | 33 | for fold in $(seq 0 1 4); do 34 | CUDA_VISIBLE_DEVICES=${gpu} python -u test.py \ 35 | --batch-size=32 \ 36 | --test-batch-size=32 \ 37 | --max-epochs=10 --lr-decay-epoch=7 --lr=0.0001 --fc-lr=0.0001 \ 38 | --num-output-neurons=${num_output_neurons} --main-loss-type=${main_loss_type} \ 39 | --save-model='./Save_Model' \ 40 | --train-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \ 41 | --test-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \ 42 | --train-data-file=./data_list/test_fold_is_${fold}/age_train.txt \ 43 | --test-data-file=./data_list/test_fold_is_${fold}/age_test.txt \ 44 | --num-workers=4 --distance='JDistance' \ 45 | --alpha-coeff=1e-5 --beta-coeff=1e-4 --margin=5 \ 46 | --exp-name=train_val_fold_${fold} \ 47 | --logdir=./log/adience_${main_loss_type} \ 48 | > ./log/adience_${main_loss_type}_test_fold_${fold}.txt 2>&1 49 | done -------------------------------------------------------------------------------- /codes/adience_poe/scripts/train_poe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | gpu='0' 3 | 4 | main_loss_type='cls' 5 | num_output_neurons=8 6 | 7 | # main_loss_type='rank' 8 | # num_output_neurons=16 9 | 10 | # main_loss_type='reg' 11 | # num_output_neurons=1 12 | 13 | 14 | if [[ $# = 1 ]]; then 15 | gpu=${1} 16 | fi 17 | 18 | if [[ $# = 2 ]]; then 19 | gpu=${1} 20 | main_loss_type=${2} 21 | if [[ $main_loss_type = 'cls' ]]; then 22 | num_output_neurons=8 23 | fi 24 | if [[ $main_loss_type = 'rank' ]]; then 25 | num_output_neurons=16 26 | fi 27 | if [[ $main_loss_type = 'reg' ]]; then 28 | num_output_neurons=1 29 | fi 30 | fi 31 | 32 | 33 | for fold in $(seq 0 1 4); do 34 | CUDA_VISIBLE_DEVICES=${gpu} python -u train.py \ 35 | --batch-size=32 \ 36 | --test-batch-size=32 \ 37 | --max-epochs=50 --lr-decay-epoch=30 --lr=0.0001 --fc-lr=0.0001 \ 38 | --num-output-neurons=${num_output_neurons} --main-loss-type=${main_loss_type} \ 39 | --save-model='./Save_Model' \ 40 | --train-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \ 41 | --test-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \ 42 | --train-data-file=./data_list/test_fold_is_${fold}/age_train.txt \ 43 | --test-data-file=./data_list/test_fold_is_${fold}/age_val.txt \ 44 | --num-workers=4 --distance='JDistance' \ 45 | --alpha-coeff=1e-4 --beta-coeff=1e-4 --margin=5 \ 46 | --exp-name=train_val_fold_${fold} \ 47 | --logdir=./log/adience_${main_loss_type} \ 48 | > ./log/adience_${main_loss_type}_train_val_fold_${fold}.txt 2>&1 49 | done 50 | -------------------------------------------------------------------------------- /codes/adience_poe/scripts/test_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | gpu='0' 3 | 4 | main_loss_type='cls' 5 | num_output_neurons=8 6 | 7 | # main_loss_type='rank' 8 | # num_output_neurons=16 9 | 10 | # main_loss_type='reg' 11 | # num_output_neurons=1 12 | 13 | 14 | if [[ $# = 1 ]]; then 15 | gpu=${1} 16 | fi 17 | 18 | if [[ $# = 2 ]]; then 19 | gpu=${1} 20 | main_loss_type=${2} 21 | if [[ $main_loss_type = 'cls' ]]; then 22 | num_output_neurons=8 23 | fi 24 | if [[ $main_loss_type = 'rank' ]]; then 25 | num_output_neurons=16 26 | fi 27 | if [[ $main_loss_type = 'reg' ]]; then 28 | num_output_neurons=1 29 | fi 30 | fi 31 | 32 | for fold in $(seq 0 1 4); do 33 | CUDA_VISIBLE_DEVICES=${gpu} python -u test.py \ 34 | --batch-size=32 \ 35 | --test-batch-size=32 \ 36 | --max-epochs=10 --lr-decay-epoch=7 --lr=0.0001 --fc-lr=0.0001 \ 37 | --num-output-neurons=${num_output_neurons} --main-loss-type=${main_loss_type} \ 38 | --save-model='./Save_Model' \ 39 | --train-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \ 40 | --test-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \ 41 | --train-data-file=./data_list/test_fold_is_${fold}/age_train.txt \ 42 | --test-data-file=./data_list/test_fold_is_${fold}/age_test.txt \ 43 | --num-workers=4 --distance='JDistance' \ 44 | --alpha-coeff=1e-5 --beta-coeff=1e-4 --margin=5 \ 45 | --exp-name=train_val_fold_${fold} \ 46 | --logdir=./log/adience_baseline_${main_loss_type} \ 47 | --no-sto \ 48 | > ./log/adience_baseline_${main_loss_type}_test_fold_${fold}.txt 2>&1 49 | done -------------------------------------------------------------------------------- /codes/adience_poe/scripts/train_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | gpu='0' 3 | 4 | main_loss_type='cls' 5 | num_output_neurons=8 6 | 7 | # main_loss_type='rank' 8 | # num_output_neurons=16 9 | 10 | # main_loss_type='reg' 11 | # num_output_neurons=1 12 | 13 | 14 | if [[ $# = 1 ]]; then 15 | gpu=${1} 16 | fi 17 | 18 | if [[ $# = 2 ]]; then 19 | gpu=${1} 20 | main_loss_type=${2} 21 | if [[ $main_loss_type = 'cls' ]]; then 22 | num_output_neurons=8 23 | fi 24 | if [[ $main_loss_type = 'rank' ]]; then 25 | num_output_neurons=16 26 | fi 27 | if [[ $main_loss_type = 'reg' ]]; then 28 | num_output_neurons=1 29 | fi 30 | fi 31 | 32 | for fold in $(seq 0 1 4); do 33 | CUDA_VISIBLE_DEVICES=${gpu} python -u train.py \ 34 | --batch-size=32 \ 35 | --test-batch-size=32 \ 36 | --max-epochs=50 --lr-decay-epoch=30 --lr=0.0001 --fc-lr=0.0001 \ 37 | --num-output-neurons=${num_output_neurons} --main-loss-type=${main_loss_type} \ 38 | --save-model='./Save_Model' \ 39 | --train-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \ 40 | --test-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \ 41 | --train-data-file=./data_list/test_fold_is_${fold}/age_train.txt \ 42 | --test-data-file=./data_list/test_fold_is_${fold}/age_val.txt \ 43 | --num-workers=4 --distance='JDistance' \ 44 | --alpha-coeff=1e-5 --beta-coeff=1e-4 --margin=5 \ 45 | --exp-name=train_val_fold_${fold} \ 46 | --logdir=./log/adience_baseline_${main_loss_type} \ 47 | --no-sto \ 48 | > ./log/adience_baseline_${main_loss_type}_train_val_fold_${fold}.txt 2>&1 49 | done -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # POEs 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-probabilistic-ordinal-embeddings-for/age-estimation-on-adience-1)](https://paperswithcode.com/sota/age-estimation-on-adience-1?p=learning-probabilistic-ordinal-embeddings-for) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-probabilistic-ordinal-embeddings-for/historical-color-image-dating-on-hci)](https://paperswithcode.com/sota/historical-color-image-dating-on-hci?p=learning-probabilistic-ordinal-embeddings-for) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-probabilistic-ordinal-embeddings-for/aesthetics-quality-assessment-on-image)](https://paperswithcode.com/sota/aesthetics-quality-assessment-on-image?p=learning-probabilistic-ordinal-embeddings-for) 6 | 7 | 8 | PyTorch implementation of Learning Probabilistic Ordinal Embeddings for Uncertainty-Aware Regression (CVPR 2021) \[[arXiv](https://arxiv.org/abs/2103.13629)\]\[[Homepage](https://li-wanhua.github.io/POEs/)\] 9 | 10 |

11 | 12 |

13 | 14 | If you find our work useful in your research, please consider citing: 15 | ``` 16 | @inproceedings{li2021probabilistic, 17 | title={Learning Probabilistic Ordinal Embeddings for Uncertainty-Aware Regression}, 18 | author={Li, Wanhua and Huang, Xiaoke and Lu, Jiwen and Feng, Jianjiang and Zhou, Jie}, 19 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 20 | year={2021} 21 | } 22 | ``` 23 | 24 | # Codes 25 | [Adience](./codes/adience_poe) 26 | -------------------------------------------------------------------------------- /codes/adience_poe/poe/dataset_manger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torch.utils.data as data 4 | from PIL import Image 5 | import os 6 | 7 | 8 | class dataset_manger(data.Dataset): 9 | def __init__(self, images_root, data_file, transforms=None, num_output_bins=8): 10 | self.images_root = images_root 11 | self.labels = [] 12 | self.images_file = [] 13 | self.transforms = transforms 14 | self.num_output_bins = num_output_bins 15 | with open(data_file) as fin: 16 | for line in fin: 17 | image_file, image_label = line.split() 18 | self.labels.append(int(image_label)) 19 | self.images_file.append(image_file) 20 | 21 | def __getitem__(self, index): 22 | img_file, target = self.images_file[index], self.labels[index] 23 | full_file = os.path.join(self.images_root, img_file) 24 | img = Image.open(full_file) 25 | 26 | if img.mode == 'L': 27 | img = img.convert('RGB') 28 | 29 | if self.transforms: 30 | img = self.transforms(img) 31 | 32 | multi_hot_target = torch.zeros(self.num_output_bins).long() 33 | multi_hot_target[list(range(target))] = 1 34 | 35 | return img, target, multi_hot_target 36 | 37 | def __len__(self): 38 | return len(self.labels) 39 | 40 | 41 | if __name__ == '__main__': 42 | normalize = transforms.Normalize( 43 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 44 | scaler = transforms.Resize((224, 224)) 45 | preprocess = transforms.Compose([scaler, transforms.ToTensor(), normalize]) 46 | t = dataset_manger('/home/share_data/age/CVPR19/datasets/MORPH', './data_list/ET_proto_val.txt', preprocess) 47 | train_loader = data.DataLoader(t, batch_size=2, shuffle=False) 48 | for data_, index, tar in train_loader: 49 | print(data_) 50 | print(index) 51 | print(tar) 52 | break 53 | -------------------------------------------------------------------------------- /codes/adience_poe/poe/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | def cal_mae_acc_rank(logits, targets, is_sto=True): 7 | if is_sto: 8 | r_dim, s_dim, out_dim = logits.shape 9 | assert out_dim % 2 == 0, "outdim {} wrong".format(out_dim) 10 | logits = logits.view(r_dim, s_dim, out_dim / 2, 2) 11 | logits = torch.argmax(logits, dim=-1) 12 | logits = torch.sum(logits, dim=-1) 13 | logits = torch.mean(logits.float(), dim=0) 14 | logits = logits.cpu().data.numpy() 15 | targets = targets.cpu().data.numpy() 16 | mae = sum(abs(logits - targets)) * 1.0 / len(targets) 17 | acc = sum(np.rint(logits) == targets) * 1.0 / len(targets) 18 | else: 19 | s_dim, out_dim = logits.shape 20 | assert out_dim % 2 == 0, "outdim {} wrong".format(out_dim) 21 | logits = logits.view(s_dim, out_dim / 2, 2) 22 | logits = torch.argmax(logits, dim=-1) 23 | logits = torch.sum(logits, dim=-1) 24 | logits = logits.cpu().data.numpy() 25 | targets = targets.cpu().data.numpy() 26 | mae = sum(abs(logits - targets)) * 1.0 / len(targets) 27 | acc = sum(np.rint(logits) == targets) * 1.0 / len(targets) 28 | return mae, acc 29 | 30 | 31 | def cal_mae_acc_reg(logits, targets, is_sto=True): 32 | if is_sto: 33 | logits = logits.mean(dim=0) 34 | 35 | assert logits.view(-1).shape == targets.shape, "logits {}, targets {}".format( 36 | logits.shape, targets.shape) 37 | 38 | logits = logits.cpu().data.numpy().reshape(-1) 39 | targets = targets.cpu().data.numpy() 40 | mae = sum(abs(logits - targets)) * 1.0 / len(targets) 41 | acc = sum(np.rint(logits) == targets) * 1.0 / len(targets) 42 | 43 | return mae, acc 44 | 45 | 46 | def cal_mae_acc_cls(logits, targets, is_sto=True): 47 | if is_sto: 48 | r_dim, s_dim, out_dim = logits.shape 49 | label_arr = torch.arange(0, out_dim).float().cuda() 50 | probs = F.softmax(logits, -1) 51 | exp = torch.sum(probs * label_arr, dim=-1) 52 | exp = torch.mean(exp, dim=0) 53 | max_a = torch.mean(probs, dim=0) 54 | max_data = max_a.cpu().data.numpy() 55 | max_data = np.argmax(max_data, axis=1) 56 | target_data = targets.cpu().data.numpy() 57 | exp_data = exp.cpu().data.numpy() 58 | mae = sum(abs(exp_data - target_data)) * 1.0 / len(target_data) 59 | acc = sum(np.rint(exp_data) == target_data) * 1.0 / len(target_data) 60 | 61 | else: 62 | s_dim, out_dim = logits.shape 63 | probs = F.softmax(logits, -1) 64 | probs_data = probs.cpu().data.numpy() 65 | target_data = targets.cpu().data.numpy() 66 | max_data = np.argmax(probs_data, axis=1) 67 | label_arr = np.array(range(out_dim)) 68 | exp_data = np.sum(probs_data * label_arr, axis=1) 69 | mae = sum(abs(exp_data - target_data)) * 1.0 / len(target_data) 70 | acc = sum(np.rint(exp_data) == target_data) * 1.0 / len(target_data) 71 | 72 | return mae, acc 73 | 74 | 75 | def get_metric(main_loss_type): 76 | assert main_loss_type in ['cls', 'reg', 'rank'], \ 77 | "main_loss_type not in ['cls', 'reg', 'rank'], loss type {%s}" % ( 78 | main_loss_type) 79 | if main_loss_type == 'cls': 80 | return cal_mae_acc_cls 81 | elif main_loss_type == 'reg': 82 | return cal_mae_acc_reg 83 | elif main_loss_type == 'rank': 84 | return cal_mae_acc_rank 85 | else: 86 | raise AttributeError('main loss type: {}'.format(main_loss_type)) 87 | -------------------------------------------------------------------------------- /codes/adience_poe/poe/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from datetime import datetime 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | 9 | def __init__(self, max_count=100): 10 | self.reset(max_count) 11 | 12 | def reset(self, max_count): 13 | self.val = 0 14 | self.avg = 0 15 | self.data_container = [] 16 | self.max_count = max_count 17 | 18 | def update(self, val): 19 | self.val = val 20 | if(len(self.data_container) < self.max_count): 21 | self.data_container.append(val) 22 | self.avg = sum(self.data_container) * 1.0 / \ 23 | len(self.data_container) 24 | else: 25 | self.data_container.pop(0) 26 | self.data_container.append(val) 27 | self.avg = sum(self.data_container) * 1.0 / self.max_count 28 | 29 | 30 | def is_fc(para_name): 31 | split_name = para_name.split('.') 32 | if split_name[-2] == 'final': 33 | return True 34 | else: 35 | return False 36 | 37 | 38 | def display_lr(optimizer): 39 | for param_group in optimizer.param_groups: 40 | print(param_group['lr'], param_group['initial_lr']) 41 | 42 | 43 | def load_model(unload_model, args): 44 | if not os.path.exists(args.save_model): 45 | os.makedirs(args.save_model) 46 | print(args.save_model, 'is created!') 47 | if not os.path.exists(os.path.join(args.save_model, 'checkpoint.txt')): 48 | f = open(os.path.join(args.save_model, 'checkpoint.txt'), 'w') 49 | print('checkpoint', 'is created!') 50 | 51 | start_index = 0 52 | with open(os.path.join(args.save_model, 'checkpoint.txt'), 'r') as fin: 53 | lines = fin.readlines() 54 | if len(lines) > 0: 55 | model_path, model_index = lines[0].split() 56 | print('Resuming from', model_path) 57 | if int(model_index) == 0: 58 | unload_model_dict = unload_model.state_dict() 59 | 60 | pretrained_dict = torch.load( 61 | os.path.join(args.save_model, model_path)) 62 | pretrained_dict['emd.0.weight'] = pretrained_dict['classifier.3.weight'] 63 | pretrained_dict['emd.0.bias'] = pretrained_dict['classifier.3.bias'] 64 | pretrained_dict['final.weight'] = pretrained_dict['classifier.6.weight'] 65 | pretrained_dict['final.bias'] = pretrained_dict['classifier.6.bias'] 66 | 67 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if ( 68 | k in unload_model_dict and pretrained_dict[k].shape == unload_model_dict[k].shape)} 69 | print(len(pretrained_dict)) 70 | for dict_inx, (k, v) in enumerate(pretrained_dict.items()): 71 | print(dict_inx, k, v.shape) 72 | unload_model_dict.update(pretrained_dict) 73 | unload_model.load_state_dict(unload_model_dict) 74 | else: 75 | unload_model.load_state_dict(torch.load( 76 | os.path.join(args.save_model, model_path))) 77 | 78 | start_index = int(model_index) + 1 79 | return start_index 80 | 81 | 82 | def save_model(tosave_model, epoch, args): 83 | model_epoch = '%04d' % (epoch) 84 | model_path = 'model-' + model_epoch + '.pth' 85 | save_path = os.path.join(args.save_model, model_path) 86 | torch.save(tosave_model.state_dict(), save_path) 87 | with open(os.path.join(args.save_model, 'checkpoint.txt'), 'w') as fin: 88 | fin.write(model_path + ' ' + str(epoch) + '\n') 89 | 90 | 91 | def get_current_time(): 92 | _now = datetime.now() 93 | _now = str(_now)[:-7] 94 | return _now 95 | -------------------------------------------------------------------------------- /codes/adience_poe/environment.yaml: -------------------------------------------------------------------------------- 1 | name: p2t1 2 | channels: 3 | - menpo 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - attrs=20.3.0=pyhd3eb1b0_0 9 | - backports=1.0=py_2 10 | - backports.functools_lru_cache=1.5=py_2 11 | - backports.shutil_get_terminal_size=1.0.0=pyhd3eb1b0_3 12 | - backports_abc=0.5=py_0 13 | - blas=1.0=mkl 14 | - bleach=3.2.1=py_0 15 | - ca-certificates=2021.1.19=h06a4308_0 16 | - certifi=2020.6.20=pyhd3eb1b0_3 17 | - cffi=1.12.3=py27h2e261b9_0 18 | - cloudpickle=1.2.1=py_0 19 | - configparser=4.0.2=py27_0 20 | - cuda100=1.0=0 21 | - cycler=0.10.0=py27_0 22 | - cytoolz=0.10.0=py27h7b6447c_0 23 | - dask-core=1.2.2=py_0 24 | - dbus=1.13.6=h746ee38_0 25 | - decorator=4.4.0=py27_1 26 | - defusedxml=0.6.0=py_0 27 | - entrypoints=0.3=py27_0 28 | - enum34=1.1.6=py27_1 29 | - expat=2.2.6=he6710b0_0 30 | - fontconfig=2.13.0=h9420a91_0 31 | - freetype=2.9.1=h8a8886c_1 32 | - functools32=3.2.3.2=py27_1 33 | - futures=3.3.0=py27_0 34 | - get_terminal_size=1.0.0=haa9412d_0 35 | - glib=2.56.2=hd408876_0 36 | - gst-plugins-base=1.14.0=hbbd80ab_1 37 | - gstreamer=1.14.0=hb453b48_1 38 | - icu=58.2=h9c2bf20_1 39 | - imageio=2.5.0=py27_0 40 | - intel-openmp=2019.4=243 41 | - ipykernel=4.9.0=py27_1 42 | - ipython=5.3.0=py27_0 43 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 44 | - ipywidgets=7.4.2=py27_0 45 | - jinja2=2.11.2=pyhd3eb1b0_0 46 | - jpeg=9b=h024ee3a_2 47 | - jsonschema=3.0.2=py27_0 48 | - jupyter=1.0.0=py27_7 49 | - jupyter_client=5.1.0=py27_0 50 | - jupyter_console=5.2.0=py27_0 51 | - jupyter_core=4.5.0=py_0 52 | - kiwisolver=1.1.0=py27he6710b0_0 53 | - libedit=3.1.20181209=hc058e9b_0 54 | - libffi=3.2.1=hd88cf55_4 55 | - libgcc-ng=9.1.0=hdf63c60_0 56 | - libgfortran-ng=7.3.0=hdf63c60_0 57 | - libpng=1.6.37=hbc83047_0 58 | - libsodium=1.0.18=h7b6447c_0 59 | - libstdcxx-ng=9.1.0=hdf63c60_0 60 | - libtiff=4.0.10=h2733197_2 61 | - libuuid=1.0.3=h1bed415_2 62 | - libxcb=1.13=h1bed415_1 63 | - libxml2=2.9.9=hea5a465_1 64 | - markupsafe=1.1.1=py27h7b6447c_0 65 | - matplotlib=2.2.3=py27hb69df0a_0 66 | - mistune=0.8.4=py27h7b6447c_0 67 | - mkl=2019.4=243 68 | - mkl-service=2.0.2=py27h7b6447c_0 69 | - mkl_fft=1.0.14=py27ha843d7b_0 70 | - mkl_random=1.0.2=py27hd81dba3_0 71 | - nbconvert=5.5.0=py_0 72 | - nbformat=4.4.0=py27_0 73 | - ncurses=6.1=he6710b0_1 74 | - networkx=2.2=py27_1 75 | - ninja=1.9.0=py27hfd86e86_0 76 | - notebook=5.0.0=py27_0 77 | - olefile=0.46=py27_0 78 | - opencv=2.4.11=nppy27_0 79 | - openssl=1.1.1i=h27cfd23_0 80 | - packaging=20.8=pyhd3eb1b0_0 81 | - pandas=0.24.2=py27he6710b0_0 82 | - pandoc=2.11=hb0f4dca_0 83 | - pandocfilters=1.4.2=py27_0 84 | - pathlib2=2.3.5=py27_0 85 | - pcre=8.43=he6710b0_0 86 | - pexpect=4.8.0=pyhd3eb1b0_3 87 | - pickleshare=0.7.5=py27_0 88 | - pillow=6.1.0=py27h34e0f95_0 89 | - pip=19.2.2=py27_0 90 | - prompt_toolkit=1.0.15=py27_0 91 | - ptyprocess=0.7.0=pyhd3eb1b0_2 92 | - pycparser=2.19=py27_0 93 | - pygments=2.3.1=py27_0 94 | - pyparsing=2.4.2=py_0 95 | - pyqt=5.9.2=py27h05f1152_2 96 | - pyrsistent=0.14.11=py27h7b6447c_0 97 | - python=2.7.16=h8b3fad2_5 98 | - python-dateutil=2.8.0=py27_0 99 | - pytorch=1.0.0=py2.7_cuda10.0.130_cudnn7.4.1_1 100 | - pytz=2019.2=py_0 101 | - pywavelets=1.0.3=py27hdd07704_1 102 | - pyzmq=16.0.2=py27_0 103 | - qt=5.9.7=h5867ecd_1 104 | - qtconsole=4.7.7=py_0 105 | - qtpy=1.9.0=py_0 106 | - readline=7.0=h7b6447c_5 107 | - scandir=1.10.0=py27h7b6447c_0 108 | - scikit-image=0.14.2=py27he6710b0_0 109 | - scikit-learn=0.20.3=py27hd81dba3_0 110 | - scipy=1.2.1=py27h7c811a0_0 111 | - setuptools=41.0.1=py27_0 112 | - simplegeneric=0.8.1=py27_2 113 | - singledispatch=3.4.0.3=py27_0 114 | - sip=4.19.8=py27hf484d3e_0 115 | - six=1.12.0=py27_0 116 | - sqlite=3.29.0=h7b6447c_0 117 | - subprocess32=3.5.4=py27h7b6447c_0 118 | - terminado=0.6=py27_0 119 | - testpath=0.4.4=py_0 120 | - tk=8.6.8=hbc83047_0 121 | - toolz=0.10.0=py_0 122 | - torchvision=0.2.2=py_3 123 | - tornado=5.1.1=py27h7b6447c_0 124 | - traitlets=4.3.3=py27_0 125 | - wcwidth=0.2.5=py_0 126 | - webencodings=0.5.1=py27_1 127 | - wheel=0.33.4=py27_0 128 | - widgetsnbextension=3.4.2=py27_0 129 | - xz=5.2.4=h14c3975_4 130 | - zeromq=4.1.5=0 131 | - zlib=1.2.11=h7b6447c_3 132 | - zstd=1.3.7=h0b5b093_0 133 | - pip: 134 | - autopep8==1.5.4 135 | - contextlib2==0.6.0.post1 136 | - flake8==3.8.4 137 | - importlib-metadata==2.1.1 138 | - mccabe==0.6.1 139 | - numpy==1.16.0 140 | - pycodestyle==2.6.0 141 | - pyflakes==2.2.0 142 | - thop==0.0.30-2101242100 143 | - toml==0.10.2 144 | - tqdm==4.56.0 145 | - typing==3.7.4.3 146 | - zipp==1.2.0 147 | prefix: /home/xiaoke_wh/anaconda3/envs/p2t1 148 | -------------------------------------------------------------------------------- /codes/adience_poe/poe/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | 5 | # Training settings 6 | def get_args(): 7 | parser = argparse.ArgumentParser( 8 | description='Probabilistic Ordinal Embedding (POE) for age estimation for Adience dataset') 9 | parser.add_argument('--batch-size', type=int, default=32, metavar='N', 10 | help='input batch size for training (default: 32)') 11 | parser.add_argument('--test-batch-size', type=int, default=32, metavar='N', 12 | help='input batch size for testing (default: 32)') 13 | parser.add_argument('--max-epochs', type=int, default=50, metavar='N', 14 | help='number of epochs to train (default: 50)') 15 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', 16 | help='learning rate (default: 0.0001)') 17 | parser.add_argument('--fc-lr', type=float, default=0.0001, 18 | help='fc layer learning rate (default: 0.0001)') 19 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 20 | help='SGD momentum (default: 0.9)') 21 | parser.add_argument('--weight-decay', type=float, default=5e-4, metavar='M', 22 | help='SGD weight decay (default: 5e-4)') 23 | parser.add_argument('--print-freq', default=10, type=int, 24 | metavar='N', help='print frequency (default: 10)') 25 | parser.add_argument('--no-cuda', action='store_true', default=False, 26 | help='disables CUDA training') 27 | parser.add_argument('--AverageMeter-MaxCount', default=100, type=int, 28 | help='maximum capacity for AverageMeter(default: 100)') 29 | parser.add_argument('--num-workers', default=2, type=int, 30 | help='number of load data workers (default: 2)') 31 | parser.add_argument('--train-images-root', type=str, default='/home/share_data/age/CVPR19/datasets/MORPH', 32 | help='images root for train dataset') 33 | parser.add_argument('--test-images-root', type=str, default='/home/share_data/age/CVPR19/datasets/MORPH', 34 | help='images root for test dataset') 35 | parser.add_argument('--train-data-file', type=str, default='./data_list/ET_proto_train.txt', 36 | help='data file for train dataset') 37 | parser.add_argument('--test-data-file', type=str, default='./data_list/ET_proto_val.txt', 38 | help='data file for test dataset') 39 | parser.add_argument('--distance', type=str, default='JDistance', 40 | help='distance metric between two gaussian distribution') 41 | parser.add_argument('--alpha-coeff', type=float, default=1e-5, metavar='M', 42 | help='alpha_coeff (default: 0)') 43 | parser.add_argument('--beta-coeff', type=float, default=1e-4, metavar='M', 44 | help='beta_coeff (default: 1.0)') 45 | parser.add_argument('--margin', type=float, default=5, metavar='M', 46 | help='margin (default: 1.0)') 47 | parser.add_argument('--logdir', type=str, default='./log/', 48 | help='where you save log.') 49 | parser.add_argument('--exp-name', type=str, default='exp', 50 | help='name of your experiment.') 51 | parser.add_argument('--save-freq', default=10, type=int, 52 | metavar='N', help='save checkpoint frequency (default: 10)') 53 | parser.add_argument('--lr-decay-epoch', type=str, default='30', 54 | help='epochs at which learning rate decays. default is 30.') 55 | parser.add_argument('--lr-decay', type=float, default=0.1, 56 | help='decay rate of learning rate. default is 0.1.') 57 | parser.add_argument('--no-sto', action='store_true', default=False, 58 | help='not using stochastic sampling when training or testing.') 59 | parser.add_argument('--test-only', action='store_true', 60 | default=False, help='test your model, no training loop.') 61 | parser.add_argument('--num-output-neurons', type=int, default=1, 62 | help='number of ouput neurons of your model, note that for `reg` model we use 1; `cls` model we use `num_output_classes`; and for `rank` model we use `num_output_class` * 2.') 63 | parser.add_argument('--main-loss-type', type=str, 64 | default='reg', help='loss type in [cls, reg, rank].') 65 | parser.add_argument('--max-t', type=int, default=50, 66 | help='number of samples during sto.') 67 | parser.add_argument('--save-model', type=str, default='./Saved_Model/', 68 | help='where you save model') 69 | parser.add_argument('--checkpoint-path', type=str, default=None, 70 | help="checkpoint to be loaded when testing.") 71 | 72 | args = parser.parse_args() 73 | 74 | args.cuda = not args.no_cuda and torch.cuda.is_available() 75 | args.use_sto = True 76 | 77 | if args.no_sto: 78 | args.use_sto = False 79 | args.alpha_coeff = .0 80 | args.beta_coeff = .0 81 | args.margin = .0 82 | print("no stochastic sampling when training or testing, baseline set up") 83 | 84 | print('------------ Options -------------') 85 | for k, v in sorted(vars(args).items()): 86 | print('{}: {}'.format(str(k), str(v))) 87 | print('-------------- End ----------------') 88 | 89 | return args 90 | 91 | 92 | if __name__ == "__main__": 93 | args = get_args() 94 | -------------------------------------------------------------------------------- /codes/adience_poe/poe/probordiloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def BhattacharyyaDistance(u1, sigma1, u2, sigma2): 7 | sigma_mean = (sigma1 + sigma2) / 2.0 8 | sigma_inv = 1.0 / (sigma_mean) 9 | dis1 = torch.sum(torch.pow(u1 - u2, 2) * sigma_inv, dim=1) / 8.0 10 | dis2 = 0.5 * (torch.sum(torch.log(sigma_mean), dim=1) - 11 | 0.5 * (torch.sum(torch.log(sigma1), dim=1) + torch.sum(torch.log(sigma2), dim=1))) 12 | return dis1 + dis2 13 | 14 | 15 | def HellingerDistance(u1, sigma1, u2, sigma2): 16 | return torch.pow(1.0 - torch.exp(-BhattacharyyaDistance(u1, sigma1, u2, sigma2)), 0.5) 17 | 18 | 19 | def WassersteinDistance(u1, sigma1, u2, sigma2): 20 | dis1 = torch.sum(torch.pow(u1 - u2, 2), dim=1) 21 | dis2 = torch.sum(torch.pow(torch.pow(sigma1, 0.5) - 22 | torch.pow(sigma2, 0.5), 2), dim=1) 23 | return torch.pow(dis1 + dis2, 0.5) 24 | 25 | 26 | def GeodesicDistance(u1, sigma1, u2, sigma2): 27 | u_dis = torch.pow(u1 - u2, 2) 28 | std1 = sigma1.sqrt() 29 | std2 = sigma2.sqrt() 30 | 31 | sig_dis = torch.pow(std1 - std2, 2) 32 | sig_sum = torch.pow(std1 + std2, 2) 33 | delta = torch.div(u_dis + 2 * sig_dis, u_dis + 2 * sig_sum).sqrt() 34 | return torch.sum(torch.pow(torch.log((1.0 + delta) / (1.0 - delta)), 2) * 2, dim=1).sqrt() 35 | 36 | 37 | def ForwardKLDistance(u1, sigma1, u2, sigma2): 38 | return -0.5 * torch.sum(torch.log(sigma1) - torch.log(sigma2) - torch.div(sigma1, sigma2) 39 | - torch.div(torch.pow(u1 - u2, 2), sigma2) + 1, dim=1) 40 | 41 | 42 | def ReverseKLDistance(u2, sigma2, u1, sigma1): 43 | return -0.5 * torch.sum(torch.log(sigma1) - torch.log(sigma2) - torch.div(sigma1, sigma2) 44 | - torch.div(torch.pow(u1 - u2, 2), sigma2) + 1, dim=1) 45 | 46 | 47 | def JDistance(u1, sigma1, u2, sigma2): 48 | return ForwardKLDistance(u1, sigma1, u2, sigma2) + ForwardKLDistance(u2, sigma2, u1, sigma1) 49 | 50 | 51 | class ProbOrdiLoss(nn.Module): 52 | def __init__(self, distance='Bhattacharyya', alpha_coeff=0, beta_coeff=0, margin=0, main_loss_type='cls'): 53 | super(ProbOrdiLoss, self).__init__() 54 | self.alpha_coeff = alpha_coeff 55 | self.beta_coeff = beta_coeff 56 | self.margin = margin 57 | self.zeros = torch.zeros(1).cuda() 58 | 59 | assert main_loss_type in ['cls', 'reg', 'rank'], \ 60 | "main_loss_type not in ['cls', 'reg', 'rank'], loss type {%s}" % ( 61 | main_loss_type) 62 | self.main_loss_type = main_loss_type 63 | 64 | if distance == 'Bhattacharyya': 65 | self.distrance_f = BhattacharyyaDistance 66 | elif distance == 'Wasserstein': 67 | self.distrance_f = WassersteinDistance 68 | elif distance == 'JDistance': 69 | self.distrance_f = JDistance 70 | elif distance == 'ForwardKLDistance': 71 | self.distrance_f = ForwardKLDistance 72 | elif distance == 'HellingerDistance': 73 | self.distrance_f = HellingerDistance 74 | elif distance == 'GeodesicDistance': 75 | self.distrance_f = GeodesicDistance 76 | elif distance == 'ReverseKLDistance': 77 | self.distrance_f = ReverseKLDistance 78 | else: 79 | print('ERROR: this distance is not supported!') 80 | self.distrance_f = None 81 | 82 | def forward(self, logit, emb, log_var, target, mh_target=None, use_sto=True): 83 | class_dim = logit.shape[-1] 84 | sample_size = logit.shape[0] # reparameterized with max_t samples 85 | 86 | if self.main_loss_type == 'cls': 87 | if use_sto: 88 | CEloss = F.cross_entropy( 89 | logit.view(-1, class_dim), target.repeat(sample_size)) 90 | else: 91 | CEloss = F.cross_entropy(logit.view(-1, class_dim), target) 92 | elif self.main_loss_type == 'reg': 93 | if use_sto: 94 | CEloss = F.mse_loss( 95 | logit.view(-1), target.repeat(sample_size).to(dtype=logit.dtype)) 96 | else: 97 | CEloss = F.mse_loss( 98 | logit.view(-1), target.to(dtype=logit.dtype)) 99 | elif self.main_loss_type == 'rank': 100 | assert len(mh_target.shape) == 2, "target shape {} wrong".format( 101 | mh_target.shape) 102 | assert class_dim % 2 == 0, "class dim {} wrong".format(class_dim) 103 | if use_sto: 104 | CEloss = F.cross_entropy( 105 | logit.view(-1, 2), mh_target.repeat([sample_size, 1]).view(-1)) 106 | else: 107 | CEloss = F.cross_entropy(logit.view(-1, 2), mh_target.view(-1)) 108 | else: 109 | raise AttributeError( 110 | 'main loss type: {}'.format(self.main_loss_type)) 111 | 112 | KLLoss = torch.mean(torch.sum(torch.pow(emb, 2) + 113 | torch.exp(log_var) - log_var - 1.0, dim=1) * 0.5) 114 | 115 | var = torch.exp(log_var) 116 | 117 | batch_size = emb.shape[0] 118 | dims = emb.shape[1] 119 | target_dis = torch.abs( 120 | target.view(-1, 1).repeat(1, batch_size) - target.view(1, -1).repeat(batch_size, 1)) 121 | anchor_pos = [i for i in range(batch_size)] 122 | second_pos = [(i + 1) % batch_size for i in anchor_pos] 123 | target_dis = torch.abs(target_dis - torch.abs( 124 | target[anchor_pos] - target[second_pos]).view(-1, 1).repeat(1, batch_size)) 125 | offset_m = torch.eye(batch_size).cuda().to(dtype=target_dis.dtype) 126 | target_dis = target_dis + offset_m * 1000 127 | target_dis[target_dis == 0] = 700 128 | thrid_pos = torch.argmin(target_dis, dim=1) 129 | 130 | anchor_sign = torch.sign(torch.abs( 131 | target[anchor_pos] - target[second_pos]) - torch.abs(target[anchor_pos] - target[thrid_pos])) 132 | 133 | emb_dis_12 = self.distrance_f( 134 | emb[anchor_pos, :], var[anchor_pos, :], emb[second_pos, :], var[second_pos, :]) 135 | emb_dis_13 = self.distrance_f( 136 | emb[anchor_pos, :], var[anchor_pos, :], emb[thrid_pos, :], var[thrid_pos, :]) 137 | 138 | anchor_cons = (emb_dis_13 - emb_dis_12) * \ 139 | anchor_sign.float() + self.margin 140 | 141 | loss_anchor = torch.max(self.zeros, anchor_cons) * \ 142 | torch.abs(anchor_sign).float() 143 | loss_mask = (anchor_cons > 0).to(dtype=anchor_sign.dtype) 144 | if sum(torch.abs(anchor_sign) * loss_mask) > 0: 145 | triple_loss = torch.sum(loss_anchor) / \ 146 | sum(torch.abs(anchor_sign) * loss_mask) 147 | else: 148 | triple_loss = torch.tensor(0.0).cuda() 149 | 150 | return CEloss, KLLoss * self.alpha_coeff, triple_loss * self.beta_coeff, CEloss + self.alpha_coeff * KLLoss + self.beta_coeff * triple_loss 151 | -------------------------------------------------------------------------------- /codes/adience_poe/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from torch.autograd import Variable 4 | from poe import dataset_manger 5 | import os 6 | import time 7 | from poe import vgg 8 | from poe.probordiloss import ProbOrdiLoss 9 | from poe.metrics import get_metric 10 | from poe.utils import get_current_time, AverageMeter 11 | from poe.options import get_args 12 | import numpy as np 13 | 14 | 15 | def test_epoch(epoch): 16 | 17 | batch_time = AverageMeter(args.AverageMeter_MaxCount) 18 | loss1es = AverageMeter(args.AverageMeter_MaxCount) 19 | loss2es = AverageMeter(args.AverageMeter_MaxCount) 20 | loss3es = AverageMeter(args.AverageMeter_MaxCount) 21 | losses = AverageMeter(args.AverageMeter_MaxCount) 22 | mae_sto = AverageMeter(args.AverageMeter_MaxCount) 23 | mae_no_sto = AverageMeter(args.AverageMeter_MaxCount) 24 | acc_sto = AverageMeter(args.AverageMeter_MaxCount) 25 | acc_no_sto = AverageMeter(args.AverageMeter_MaxCount) 26 | 27 | rpr_model.eval() 28 | total = 0 29 | all_mae_sto = .0 30 | all_mae_no_sto = .0 31 | all_acc_sto = .0 32 | all_acc_no_sto = .0 33 | 34 | end_time = time.time() 35 | for batch_idx, (inputs, targets, mh_targets) in enumerate(testloader): 36 | inputs, targets, mh_targets = inputs.cuda(), targets.cuda(), mh_targets.cuda() 37 | inputs, targets = Variable( 38 | inputs, requires_grad=True), Variable(targets) 39 | 40 | logit, emb, log_var = rpr_model(inputs, max_t=args.max_t, use_sto=True) 41 | logit_no_sto, emb, log_var = rpr_model( 42 | inputs, max_t=args.max_t, use_sto=False) 43 | if args.use_sto: 44 | loss1, loss2, loss3, loss = criterion( 45 | logit, emb, log_var, targets, mh_targets, use_sto=args.use_sto) 46 | else: 47 | loss1, loss2, loss3, loss = criterion( 48 | logit_no_sto, emb, log_var, targets, mh_targets, use_sto=args.use_sto) 49 | 50 | total += targets.size(0) 51 | 52 | batch_mae_sto, batch_acc_sto = cal_mae_acc(logit, targets, True) 53 | batch_mae_no_sto, batch_acc_no_sto = cal_mae_acc( 54 | logit_no_sto, targets, False) 55 | 56 | loss1es.update(loss1.cpu().data.numpy()) 57 | loss2es.update(loss2.cpu().data.numpy()) 58 | loss3es.update(loss3.cpu().data.numpy()) 59 | losses.update(loss.cpu().data.numpy()) 60 | 61 | mae_sto.update(batch_mae_sto) 62 | mae_no_sto.update(batch_mae_no_sto) 63 | acc_sto.update(batch_acc_sto) 64 | acc_no_sto.update(batch_acc_no_sto) 65 | 66 | all_mae_sto = all_mae_sto + batch_mae_sto * targets.size(0) 67 | all_mae_no_sto = all_mae_no_sto + batch_mae_no_sto * targets.size(0) 68 | all_acc_sto = all_acc_sto + batch_acc_sto * targets.size(0) 69 | all_acc_no_sto = all_acc_no_sto + batch_acc_no_sto * targets.size(0) 70 | 71 | batch_time.update(time.time() - end_time) 72 | end_time = time.time() 73 | 74 | if batch_idx % args.print_freq == 0: 75 | print('Test: [%d/%d]\t' 76 | 'Time %.3f (%.3f)\t' 77 | 'Loss1 %.3f (%.3f)\t' 78 | 'Loss2 %.3f (%.3f)\t' 79 | 'Loss3 %.3f (%.3f)\t' 80 | 'Loss %.3f (%.3f)\t' 81 | 'MAE_sto %.3f (%.3f)\t' 82 | 'MAE_no_sto %.3f (%.3f)\t' 83 | 'ACC_sto %.3f (%.3f)\t' 84 | 'ACC_no_sto %.3f (%.3f)\t' % (batch_idx, len(testloader), 85 | batch_time.val, batch_time.avg, loss1es.val, loss1es.avg, loss2es.val, loss2es.avg, loss3es.val, loss3es.avg, 86 | losses.val, losses.avg, mae_sto.val, mae_sto.avg, mae_no_sto.val, mae_no_sto.avg, acc_sto.val, acc_sto.avg, acc_no_sto.val, acc_no_sto.avg)) 87 | 88 | print('Test: MAE_sto: %.3f MAE_no_sto: %.3f ACC_sto: %.3f ACC_no_sto: %.3f' % (all_mae_sto * 89 | 1.0 / total, all_mae_no_sto * 1.0 / total, all_acc_sto * 1.0 / total, all_acc_no_sto * 1.0 / total)) 90 | 91 | return all_mae_sto * 1.0 / total, all_mae_no_sto * 1.0 / total, all_acc_sto * 1.0 / total, all_acc_no_sto * 1.0 / total 92 | 93 | 94 | if __name__ == "__main__": 95 | # Training settings 96 | # --------------------------------- 97 | args = get_args() 98 | 99 | torch.manual_seed(0) 100 | np.random.seed(0) 101 | torch.backends.cudnn.deterministic = True 102 | torch.backends.cudnn.benchmark = False 103 | # --------------------------------- 104 | 105 | # dataset prepare 106 | # --------------------------------- 107 | transform_test = transforms.Compose([ 108 | transforms.Resize((256, 256)), 109 | transforms.CenterCrop(224), 110 | transforms.ToTensor(), 111 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 112 | ]) 113 | 114 | testset = dataset_manger.dataset_manger( 115 | images_root=args.test_images_root, data_file=args.test_data_file, transforms=transform_test) 116 | 117 | testloader = torch.utils.data.DataLoader( 118 | testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.num_workers) 119 | # -------------------------------------- 120 | 121 | # define Model 122 | # ---------------------------------------- 123 | rpr_model = vgg.vgg16(num_output_neurons=args.num_output_neurons) 124 | rpr_model.cuda() 125 | # rpr_model = torch.nn.DataParallel(rpr_model) 126 | print("Model finshed") 127 | # --------------------------------- 128 | 129 | # define loss 130 | # ---------------------------------------- 131 | criterion = ProbOrdiLoss(distance=args.distance, alpha_coeff=args.alpha_coeff, 132 | beta_coeff=args.beta_coeff, margin=args.margin, main_loss_type=args.main_loss_type) 133 | criterion.cuda() 134 | 135 | # define Metric 136 | # ---------------------------------------- 137 | cal_mae_acc = get_metric(args.main_loss_type) 138 | # --------------------------------- 139 | 140 | args.logdir = os.path.join(args.logdir, args.exp_name) 141 | if not os.path.exists(args.logdir): 142 | os.makedirs(args.logdir) 143 | print('log dir [{}] {}'.format(args.logdir, 'is created!')) 144 | 145 | if args.checkpoint_path: 146 | rpr_model.load_state_dict(torch.load( 147 | args.checkpoint_path)['model_state_dict']) 148 | print("Load model from checkpoint: {}".foramt(args.checkpoint_path)) 149 | else: 150 | rpr_model.load_state_dict(torch.load(os.path.join( 151 | args.logdir, 'checkpoint.pth'))['model_state_dict']) 152 | print("Load model from checkpoint: {}".format( 153 | os.path.join(args.logdir, 'best.pth'))) 154 | 155 | print("Start Testing...") 156 | with torch.no_grad(): 157 | cur_mae_sto, cur_mae_no_sto, cur_acc_sto, cur_acc_no_sto = test_epoch( 158 | 0) 159 | 160 | print('[{}] end!'.format(get_current_time())) 161 | np.save( 162 | os.path.join(args.logdir, 'test_result.npy'), 163 | np.array([cur_mae_sto, cur_mae_no_sto, cur_acc_sto, cur_acc_no_sto]) 164 | ) 165 | -------------------------------------------------------------------------------- /codes/adience_poe/poe/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import math 4 | import torch 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 19 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 20 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 21 | } 22 | 23 | 24 | class ProbVar(nn.Module): 25 | def __init__(self, dim): 26 | super(ProbVar, self).__init__() 27 | self.trans = nn.Sequential( 28 | nn.Linear(dim, dim), 29 | nn.BatchNorm1d(dim), 30 | nn.ReLU(True), 31 | nn.Linear(dim, dim), 32 | ) 33 | self.act = nn.Linear(1, 1) 34 | self.dim = dim 35 | 36 | def forward(self, x): 37 | x = self.trans(x) 38 | x = x.view(-1, self.dim, 1) 39 | x = self.act(x) 40 | x = x.view(-1, self.dim) 41 | return x 42 | 43 | 44 | class VGG(nn.Module): 45 | 46 | def __init__(self, features, num_output_neurons): 47 | super(VGG, self).__init__() 48 | self.features = features 49 | self.classifier = nn.Sequential( 50 | nn.Linear(512 * 7 * 7, 4096), 51 | nn.ReLU(True), 52 | nn.Dropout(), 53 | ) 54 | self.emd = nn.Sequential( 55 | nn.Linear(4096, 4096), 56 | nn.ReLU(True), 57 | ) 58 | self.var = nn.Sequential(nn.Linear(4096, 4096), 59 | nn.BatchNorm1d(4096, eps=0.001, affine=False), 60 | ) 61 | self.drop = nn.Dropout() 62 | self.final = nn.Linear(4096, num_output_neurons) 63 | self._initialize_weights() 64 | 65 | def forward(self, x, max_t=50, use_sto=True): 66 | x = self.features(x) 67 | x = x.view(x.size(0), -1) 68 | x = self.classifier(x) 69 | emb = self.emd(x) 70 | log_var = self.var(x) 71 | sqrt_var = torch.exp(log_var * 0.5) 72 | if use_sto: 73 | rep_emb = emb[None].expand(max_t, *emb.shape) 74 | rep_sqrt_var = sqrt_var[None].expand(max_t, *sqrt_var.shape) 75 | norm_v = torch.randn_like(rep_emb).cuda() 76 | sto_emb = rep_emb + rep_sqrt_var * norm_v 77 | sto_emb = self.drop(sto_emb) 78 | logit = self.final(sto_emb) 79 | else: 80 | drop_emb = self.drop(emb) 81 | logit = self.final(drop_emb) 82 | 83 | return logit, emb, log_var 84 | 85 | def _initialize_weights(self): 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 89 | m.weight.data.normal_(0, math.sqrt(2. / n)) 90 | if m.bias is not None: 91 | m.bias.data.zero_() 92 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 93 | if m.weight is not None: 94 | m.weight.data.fill_(1) 95 | m.bias.data.zero_() 96 | elif isinstance(m, nn.Linear): 97 | m.weight.data.normal_(0, 0.01) 98 | m.bias.data.zero_() 99 | 100 | 101 | def make_layers(cfg, batch_norm=False): 102 | layers = [] 103 | in_channels = 3 104 | for v in cfg: 105 | if v == 'M': 106 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 107 | else: 108 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 109 | if batch_norm: 110 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 111 | else: 112 | layers += [conv2d, nn.ReLU(inplace=True)] 113 | in_channels = v 114 | return nn.Sequential(*layers) 115 | 116 | 117 | cfg = { 118 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 119 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 120 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 121 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 122 | } 123 | 124 | 125 | def vgg11(pretrained=False, **kwargs): 126 | """VGG 11-layer model (configuration "A") 127 | 128 | Args: 129 | pretrained (bool): If True, returns a model pre-trained on ImageNet 130 | """ 131 | model = VGG(make_layers(cfg['A']), **kwargs) 132 | if pretrained: 133 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 134 | return model 135 | 136 | 137 | def vgg11_bn(pretrained=False, **kwargs): 138 | """VGG 11-layer model (configuration "A") with batch normalization 139 | 140 | Args: 141 | pretrained (bool): If True, returns a model pre-trained on ImageNet 142 | """ 143 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 144 | if pretrained: 145 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 146 | return model 147 | 148 | 149 | def vgg13(pretrained=False, **kwargs): 150 | """VGG 13-layer model (configuration "B") 151 | 152 | Args: 153 | pretrained (bool): If True, returns a model pre-trained on ImageNet 154 | """ 155 | model = VGG(make_layers(cfg['B']), **kwargs) 156 | if pretrained: 157 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 158 | return model 159 | 160 | 161 | def vgg13_bn(pretrained=False, **kwargs): 162 | """VGG 13-layer model (configuration "B") with batch normalization 163 | 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 168 | if pretrained: 169 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 170 | return model 171 | 172 | 173 | def vgg16(pretrained=False, **kwargs): 174 | """VGG 16-layer model (configuration "D") 175 | 176 | Args: 177 | pretrained (bool): If True, returns a model pre-trained on ImageNet 178 | """ 179 | model = VGG(make_layers(cfg['D']), **kwargs) 180 | if pretrained: 181 | unload_model_dict = model.state_dict() 182 | pretrained_dict = model_zoo.load_url(model_urls['vgg16']) 183 | 184 | pretrained_dict['emd.0.weight'] = pretrained_dict['classifier.3.weight'] 185 | pretrained_dict['emd.0.bias'] = pretrained_dict['classifier.3.bias'] 186 | pretrained_dict['final.weight'] = pretrained_dict['classifier.6.weight'] 187 | pretrained_dict['final.bias'] = pretrained_dict['classifier.6.bias'] 188 | 189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if ( 190 | k in unload_model_dict and pretrained_dict[k].shape == unload_model_dict[k].shape)} 191 | 192 | unload_model_dict.update(pretrained_dict) 193 | model.load_state_dict(unload_model_dict) 194 | 195 | return model 196 | 197 | 198 | def vgg16_bn(pretrained=False, **kwargs): 199 | """VGG 16-layer model (configuration "D") with batch normalization 200 | 201 | Args: 202 | pretrained (bool): If True, returns a model pre-trained on ImageNet 203 | """ 204 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 205 | if pretrained: 206 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 207 | return model 208 | 209 | 210 | def vgg19(pretrained=False, **kwargs): 211 | """VGG 19-layer model (configuration "E") 212 | 213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | """ 216 | model = VGG(make_layers(cfg['E']), **kwargs) 217 | if pretrained: 218 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 219 | return model 220 | 221 | 222 | def vgg19_bn(pretrained=False, **kwargs): 223 | """VGG 19-layer model (configuration 'E') with batch normalization 224 | 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | """ 228 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 229 | if pretrained: 230 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 231 | return model 232 | -------------------------------------------------------------------------------- /codes/adience_poe/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torchvision import transforms 4 | from torch.autograd import Variable 5 | from poe import dataset_manger 6 | import os 7 | import time 8 | import numpy as np 9 | from poe import vgg 10 | from poe.probordiloss import ProbOrdiLoss 11 | from poe.metrics import get_metric 12 | from poe.utils import is_fc, get_current_time, AverageMeter 13 | from poe.options import get_args 14 | 15 | 16 | def train_epoch(epoch): 17 | 18 | batch_time = AverageMeter(args.AverageMeter_MaxCount) 19 | loss1es = AverageMeter(args.AverageMeter_MaxCount) 20 | loss2es = AverageMeter(args.AverageMeter_MaxCount) 21 | loss3es = AverageMeter(args.AverageMeter_MaxCount) 22 | losses = AverageMeter(args.AverageMeter_MaxCount) 23 | mae = AverageMeter(args.AverageMeter_MaxCount) 24 | acc = AverageMeter(args.AverageMeter_MaxCount) 25 | 26 | rpr_model.train() 27 | 28 | end_time = time.time() 29 | for batch_idx, (inputs, targets, mh_targets) in enumerate(trainloader): 30 | 31 | inputs, targets, mh_targets = inputs.cuda(), targets.cuda(), mh_targets.cuda() 32 | inputs, targets = Variable( 33 | inputs, requires_grad=True), Variable(targets) 34 | 35 | optimizer.zero_grad() 36 | 37 | logit, emb, log_var = rpr_model( 38 | inputs, max_t=args.max_t, use_sto=args.use_sto) 39 | 40 | loss1, loss2, loss3, loss = criterion( 41 | logit, emb, log_var, targets, mh_targets, use_sto=args.use_sto) 42 | 43 | loss.backward() 44 | optimizer.step() 45 | 46 | batch_mae, batch_acc = cal_mae_acc(logit, targets, args.use_sto) 47 | loss1es.update(loss1.cpu().data.numpy()) 48 | loss2es.update(loss2.cpu().data.numpy()) 49 | loss3es.update(loss3.cpu().data.numpy()) 50 | losses.update(loss.cpu().data.numpy()) 51 | 52 | mae.update(batch_mae) 53 | acc.update(batch_acc) 54 | 55 | batch_time.update(time.time() - end_time) 56 | end_time = time.time() 57 | 58 | if batch_idx % args.print_freq == 0: 59 | print('Epoch: [%d][%d/%d] ' 60 | 'Time %.3f (%.3f)\t' 61 | 'Loss1 %.3f (%.3f)\t' 62 | 'Loss2 %.3f (%.3f)\t' 63 | 'Loss3 %.3f (%.3f)\t' 64 | 'Loss %.3f (%.3f)\t' 65 | 'MAE %.3f (%.3f)\t' 66 | 'ACC %.3f (%.3f)' % (epoch, batch_idx, len(trainloader), 67 | batch_time.val, batch_time.avg, loss1es.val, loss1es.avg, loss2es.val, loss2es.avg, loss3es.val, loss3es.avg, 68 | losses.val, losses.avg, mae.val, mae.avg, acc.val, acc.avg)) 69 | 70 | 71 | def test_epoch(epoch): 72 | 73 | batch_time = AverageMeter(args.AverageMeter_MaxCount) 74 | loss1es = AverageMeter(args.AverageMeter_MaxCount) 75 | loss2es = AverageMeter(args.AverageMeter_MaxCount) 76 | loss3es = AverageMeter(args.AverageMeter_MaxCount) 77 | losses = AverageMeter(args.AverageMeter_MaxCount) 78 | mae_sto = AverageMeter(args.AverageMeter_MaxCount) 79 | mae_no_sto = AverageMeter(args.AverageMeter_MaxCount) 80 | acc_sto = AverageMeter(args.AverageMeter_MaxCount) 81 | acc_no_sto = AverageMeter(args.AverageMeter_MaxCount) 82 | 83 | rpr_model.eval() 84 | total = 0 85 | all_mae_sto = .0 86 | all_mae_no_sto = .0 87 | all_acc_sto = .0 88 | all_acc_no_sto = .0 89 | 90 | end_time = time.time() 91 | for batch_idx, (inputs, targets, mh_targets) in enumerate(testloader): 92 | inputs, targets, mh_targets = inputs.cuda(), targets.cuda(), mh_targets.cuda() 93 | inputs, targets = Variable( 94 | inputs, requires_grad=True), Variable(targets) 95 | 96 | logit, emb, log_var = rpr_model(inputs, max_t=args.max_t, use_sto=True) 97 | logit_no_sto, emb, log_var = rpr_model( 98 | inputs, max_t=args.max_t, use_sto=False) 99 | if args.use_sto: 100 | loss1, loss2, loss3, loss = criterion( 101 | logit, emb, log_var, targets, mh_targets, use_sto=args.use_sto) 102 | else: 103 | loss1, loss2, loss3, loss = criterion( 104 | logit_no_sto, emb, log_var, targets, mh_targets, use_sto=args.use_sto) 105 | 106 | total += targets.size(0) 107 | 108 | batch_mae_sto, batch_acc_sto = cal_mae_acc(logit, targets, True) 109 | batch_mae_no_sto, batch_acc_no_sto = cal_mae_acc( 110 | logit_no_sto, targets, False) 111 | 112 | loss1es.update(loss1.cpu().data.numpy()) 113 | loss2es.update(loss2.cpu().data.numpy()) 114 | loss3es.update(loss3.cpu().data.numpy()) 115 | losses.update(loss.cpu().data.numpy()) 116 | 117 | mae_sto.update(batch_mae_sto) 118 | mae_no_sto.update(batch_mae_no_sto) 119 | acc_sto.update(batch_acc_sto) 120 | acc_no_sto.update(batch_acc_no_sto) 121 | 122 | all_mae_sto = all_mae_sto + batch_mae_sto * targets.size(0) 123 | all_mae_no_sto = all_mae_no_sto + batch_mae_no_sto * targets.size(0) 124 | all_acc_sto = all_acc_sto + batch_acc_sto * targets.size(0) 125 | all_acc_no_sto = all_acc_no_sto + batch_acc_no_sto * targets.size(0) 126 | 127 | batch_time.update(time.time() - end_time) 128 | end_time = time.time() 129 | 130 | if batch_idx % args.print_freq == 0: 131 | print('Test: [%d/%d]\t' 132 | 'Time %.3f (%.3f)\t' 133 | 'Loss1 %.3f (%.3f)\t' 134 | 'Loss2 %.3f (%.3f)\t' 135 | 'Loss3 %.3f (%.3f)\t' 136 | 'Loss %.3f (%.3f)\t' 137 | 'MAE_sto %.3f (%.3f)\t' 138 | 'MAE_no_sto %.3f (%.3f)\t' 139 | 'ACC_sto %.3f (%.3f)\t' 140 | 'ACC_no_sto %.3f (%.3f)\t' % (batch_idx, len(testloader), 141 | batch_time.val, batch_time.avg, loss1es.val, loss1es.avg, loss2es.val, loss2es.avg, loss3es.val, loss3es.avg, 142 | losses.val, losses.avg, mae_sto.val, mae_sto.avg, mae_no_sto.val, mae_no_sto.avg, acc_sto.val, acc_sto.avg, acc_no_sto.val, acc_no_sto.avg)) 143 | 144 | print('Test: MAE_sto: %.3f MAE_no_sto: %.3f ACC_sto: %.3f ACC_no_sto: %.3f' % (all_mae_sto * 145 | 1.0 / total, all_mae_no_sto * 1.0 / total, all_acc_sto * 1.0 / total, all_acc_no_sto * 1.0 / total)) 146 | 147 | return all_mae_sto * 1.0 / total, all_mae_no_sto * 1.0 / total, all_acc_sto * 1.0 / total, all_acc_no_sto * 1.0 / total 148 | 149 | 150 | if __name__ == "__main__": 151 | # Training settings 152 | # --------------------------------- 153 | args = get_args() 154 | 155 | torch.manual_seed(0) 156 | np.random.seed(0) 157 | torch.backends.cudnn.deterministic = True 158 | torch.backends.cudnn.benchmark = False 159 | # --------------------------------- 160 | 161 | # dataset prepare 162 | # --------------------------------- 163 | transform_train = transforms.Compose([ 164 | transforms.Resize((256, 256)), 165 | # transforms.RandomRotation(10), 166 | transforms.RandomCrop(224), 167 | transforms.RandomHorizontalFlip(), 168 | transforms.ToTensor(), 169 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 170 | ]) 171 | 172 | transform_test = transforms.Compose([ 173 | transforms.Resize((256, 256)), 174 | transforms.CenterCrop(224), 175 | transforms.ToTensor(), 176 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 177 | ]) 178 | 179 | trainset = dataset_manger.dataset_manger( 180 | images_root=args.train_images_root, data_file=args.train_data_file, transforms=transform_train) 181 | 182 | trainloader = torch.utils.data.DataLoader( 183 | trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 184 | 185 | testset = dataset_manger.dataset_manger( 186 | images_root=args.test_images_root, data_file=args.test_data_file, transforms=transform_test) 187 | 188 | testloader = torch.utils.data.DataLoader( 189 | testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.num_workers) 190 | # -------------------------------------- 191 | 192 | # define Model 193 | # ---------------------------------------- 194 | start_epoch = 0 195 | rpr_model = vgg.vgg16(pretrained=True, num_output_neurons=args.num_output_neurons) 196 | rpr_model.cuda() 197 | # rpr_model = torch.nn.DataParallel(rpr_model) 198 | print("Model finshed") 199 | # --------------------------------- 200 | 201 | # define Optimizer 202 | # ---------------------------------------- 203 | params = [] 204 | for keys, param_value in rpr_model.named_parameters(): 205 | if (is_fc(keys)): 206 | params += [{'params': [param_value], 'lr':args.fc_lr}] 207 | else: 208 | params += [{'params': [param_value], 'lr':args.lr}] 209 | 210 | optimizer = optim.Adam(params, lr=args.lr, betas=(0.9, 0.999), eps=1e-08) 211 | 212 | lr_decay = args.lr_decay 213 | lr_decay_epoch = [int(i) 214 | for i in args.lr_decay_epoch.split(',')] + [np.inf] 215 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 216 | optimizer, milestones=lr_decay_epoch, gamma=lr_decay, last_epoch=-1) 217 | # --------------------------------- 218 | 219 | # define loss 220 | # ---------------------------------------- 221 | criterion = ProbOrdiLoss(distance=args.distance, alpha_coeff=args.alpha_coeff, 222 | beta_coeff=args.beta_coeff, margin=args.margin, main_loss_type=args.main_loss_type) 223 | criterion.cuda() 224 | 225 | # define Metric 226 | # ---------------------------------------- 227 | cal_mae_acc = get_metric(args.main_loss_type) 228 | # --------------------------------- 229 | 230 | args.logdir = os.path.join(args.logdir, args.exp_name) 231 | if not os.path.exists(args.logdir): 232 | os.makedirs(args.logdir) 233 | print('log dir [{}] {}'.format(args.logdir, 'is created!')) 234 | 235 | print("start training...") 236 | best_mae = np.inf 237 | 238 | for epoch in range(start_epoch, args.max_epochs): 239 | print('[{}] Epoch: {} start!'.format(get_current_time(), epoch)) 240 | 241 | if not args.test_only: 242 | train_epoch(epoch) 243 | lr_scheduler.step() 244 | with torch.no_grad(): 245 | cur_mae_sto, cur_mae_no_sto, cur_acc_sto, cur_acc_no_sto = test_epoch( 246 | epoch) 247 | else: 248 | with torch.no_grad(): 249 | cur_mae_sto, cur_mae_no_sto, cur_acc_sto, cur_acc_no_sto = test_epoch( 250 | epoch) 251 | break 252 | 253 | print('saving model...') 254 | is_best = cur_mae_no_sto < best_mae 255 | best_mae = min(best_mae, cur_mae_no_sto) 256 | if epoch % args.save_freq == 0: 257 | torch.save( 258 | { 259 | 'model_state_dict': rpr_model.state_dict(), 260 | 'optim_state_dict': optimizer.state_dict(), 261 | 'mae': cur_mae_no_sto, 262 | 'acc': cur_acc_no_sto, 263 | 'epoch': epoch + 1 264 | }, 265 | os.path.join(args.logdir, 'checkpoint.pth') 266 | ) 267 | print('save checkpoint at {}'.format( 268 | os.path.join(args.logdir, 'checkpoint.pth'))) 269 | if is_best: 270 | torch.save( 271 | { 272 | 'model_state_dict': rpr_model.state_dict(), 273 | 'optim_state_dict': optimizer.state_dict(), 274 | 'mae': cur_mae_no_sto, 275 | 'acc': cur_acc_no_sto, 276 | 'epoch': epoch + 1 277 | }, 278 | os.path.join(args.logdir, 'best.pth') 279 | ) 280 | print('save best model at {}'.format( 281 | os.path.join(args.logdir, 'best.pth'))) 282 | 283 | print('[{}] Epoch: {} end!'.format(get_current_time(), epoch)) 284 | --------------------------------------------------------------------------------