├── codes
└── adience_poe
│ ├── poe
│ ├── __init__.py
│ ├── dataset_manger.py
│ ├── metrics.py
│ ├── utils.py
│ ├── options.py
│ ├── probordiloss.py
│ └── vgg.py
│ ├── log
│ └── placeholder.txt
│ ├── misc
│ └── metric_summary.py
│ ├── README.md
│ ├── scripts
│ ├── test_poe.sh
│ ├── train_poe.sh
│ ├── test_baseline.sh
│ └── train_baseline.sh
│ ├── environment.yaml
│ ├── test.py
│ └── train.py
├── imgs
└── framework.png
└── README.md
/codes/adience_poe/poe/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/codes/adience_poe/log/placeholder.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Li-Wanhua/POEs/HEAD/imgs/framework.png
--------------------------------------------------------------------------------
/codes/adience_poe/misc/metric_summary.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 |
def process_one_exp(path):
    """Aggregate per-fold `test_result.npy` files under *path* and print
    the mean+-std of the reported metrics (mae, acc)."""
    # Collect every fold's saved result matrix.
    fold_results = [
        np.load(os.path.join(path, entry, 'test_result.npy'))
        for entry in os.listdir(path)
        if os.path.isfile(os.path.join(path, entry, 'test_result.npy'))
    ]

    if not fold_results:
        return
    stacked = np.vstack(fold_results)
    print(stacked.shape)
    means = stacked.mean(axis=0)
    stds = stacked.std(axis=0)
    stats = list(zip(means, stds))
    print("{}".format(path))
    # Columns 1..-2 hold the (mae, acc) pair — same slice as the original.
    for name, (mean_v, std_v) in zip(["mae", "acc"], stats[1:-1]):
        print("{}: {}+-{}".format(name, mean_v, std_v))
23 |
24 |
#%%
if __name__ == "__main__":
    # CLI entry point: print a mean+-std summary for every experiment
    # directory found directly under --logdir.
    AP = argparse.ArgumentParser()
    AP.add_argument("--logdir", type=str, default="./log")
    args = AP.parse_args()

    # Work relative to the log root so fold paths stay short.
    os.chdir(args.logdir)
    for root_dir in os.listdir('.'):
        if os.path.isdir(root_dir):
            process_one_exp(root_dir)
35 |
--------------------------------------------------------------------------------
/codes/adience_poe/README.md:
--------------------------------------------------------------------------------
1 | # POEs
2 |
3 | PyTorch re-implementation of Learning Probabilistic Ordinal Embeddings for Uncertainty-Aware Regression (CVPR 2021)[[project page](https://li-wanhua.github.io/POEs/)]
4 |
5 | # Codes for Adience Dataset
6 | [Adience Dataset](https://talhassner.github.io/home/projects/Adience/Adience-data.html)
7 |
8 | ## Prepare Environment
9 | Simply create a conda environment by:
10 | ```bash
11 | conda env create -f environment.yaml
12 | ```
13 | The code has been tested on pytorch==1.0.0, but higher versions of pytorch should also work.
14 |
15 | ## Train
16 | Configure the data-related paths in `scripts/*.sh`, specifically the `--train-images-root`, `--test-images-root`, `--train-data-file`, and `--test-data-file` flags.
17 |
18 | ```bash
19 | # Train POEs / baselines
20 | # model_type should be in ['reg', 'cls', 'rank']
21 | bash ./scripts/train_poe.sh [id_of_gpu='0'] [model_type='cls']
22 | bash ./scripts/train_baseline.sh [id_of_gpu='0'] [model_type='cls']
23 | ```
24 | ## Test
25 | ```bash
26 | # Test POEs / baselines
27 | # model_type should be in ['reg', 'cls', 'rank']
28 | bash ./scripts/test_poe.sh [id_of_gpu='0'] [model_type='cls']
29 | bash ./scripts/test_baseline.sh [id_of_gpu='0'] [model_type='cls']
30 | ```
31 | ## Performance Summary
32 | ```bash
33 | python ./misc/metric_summary.py
34 | ```
35 |
--------------------------------------------------------------------------------
/codes/adience_poe/scripts/test_poe.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Evaluate trained POE models on Adience, one run per cross-validation fold.
# Usage: bash test_poe.sh [gpu_id='0'] [model_type='cls'|'rank'|'reg']
gpu='0'

# Default head: classification over 8 age groups.
main_loss_type='cls'
num_output_neurons=8

# main_loss_type='rank'
# num_output_neurons=16

# main_loss_type='reg'
# num_output_neurons=1


# One positional argument: GPU id only.
if [[ $# = 1 ]]; then
gpu=${1}
fi

# Two positional arguments: GPU id + head type; the neuron count is tied
# to the head (cls: 1 logit per class, rank: 2 per threshold, reg: scalar).
if [[ $# = 2 ]]; then
gpu=${1}
main_loss_type=${2}
if [[ $main_loss_type = 'cls' ]]; then
num_output_neurons=8
fi
if [[ $main_loss_type = 'rank' ]]; then
num_output_neurons=16
fi
if [[ $main_loss_type = 'reg' ]]; then
num_output_neurons=1
fi
fi


# Five-fold CV; each fold's stdout/stderr is captured under ./log/.
for fold in $(seq 0 1 4); do
CUDA_VISIBLE_DEVICES=${gpu} python -u test.py \
--batch-size=32 \
--test-batch-size=32 \
--max-epochs=10 --lr-decay-epoch=7 --lr=0.0001 --fc-lr=0.0001 \
--num-output-neurons=${num_output_neurons} --main-loss-type=${main_loss_type} \
--save-model='./Save_Model' \
--train-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \
--test-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \
--train-data-file=./data_list/test_fold_is_${fold}/age_train.txt \
--test-data-file=./data_list/test_fold_is_${fold}/age_test.txt \
--num-workers=4 --distance='JDistance' \
--alpha-coeff=1e-5 --beta-coeff=1e-4 --margin=5 \
--exp-name=train_val_fold_${fold} \
--logdir=./log/adience_${main_loss_type} \
> ./log/adience_${main_loss_type}_test_fold_${fold}.txt 2>&1
done
--------------------------------------------------------------------------------
/codes/adience_poe/scripts/train_poe.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train POE models on Adience (validation split), one run per CV fold.
# Usage: bash train_poe.sh [gpu_id='0'] [model_type='cls'|'rank'|'reg']
gpu='0'

# Default head: classification over 8 age groups.
main_loss_type='cls'
num_output_neurons=8

# main_loss_type='rank'
# num_output_neurons=16

# main_loss_type='reg'
# num_output_neurons=1


# One positional argument: GPU id only.
if [[ $# = 1 ]]; then
gpu=${1}
fi

# Two positional arguments: GPU id + head type; the neuron count is tied
# to the head (cls: 1 logit per class, rank: 2 per threshold, reg: scalar).
if [[ $# = 2 ]]; then
gpu=${1}
main_loss_type=${2}
if [[ $main_loss_type = 'cls' ]]; then
num_output_neurons=8
fi
if [[ $main_loss_type = 'rank' ]]; then
num_output_neurons=16
fi
if [[ $main_loss_type = 'reg' ]]; then
num_output_neurons=1
fi
fi


# Five-fold CV; trains against the fold's val split, logs under ./log/.
for fold in $(seq 0 1 4); do
CUDA_VISIBLE_DEVICES=${gpu} python -u train.py \
--batch-size=32 \
--test-batch-size=32 \
--max-epochs=50 --lr-decay-epoch=30 --lr=0.0001 --fc-lr=0.0001 \
--num-output-neurons=${num_output_neurons} --main-loss-type=${main_loss_type} \
--save-model='./Save_Model' \
--train-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \
--test-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \
--train-data-file=./data_list/test_fold_is_${fold}/age_train.txt \
--test-data-file=./data_list/test_fold_is_${fold}/age_val.txt \
--num-workers=4 --distance='JDistance' \
--alpha-coeff=1e-4 --beta-coeff=1e-4 --margin=5 \
--exp-name=train_val_fold_${fold} \
--logdir=./log/adience_${main_loss_type} \
> ./log/adience_${main_loss_type}_train_val_fold_${fold}.txt 2>&1
done
50 |
--------------------------------------------------------------------------------
/codes/adience_poe/scripts/test_baseline.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Evaluate deterministic baseline models (--no-sto) on Adience, per CV fold.
# Usage: bash test_baseline.sh [gpu_id='0'] [model_type='cls'|'rank'|'reg']
gpu='0'

# Default head: classification over 8 age groups.
main_loss_type='cls'
num_output_neurons=8

# main_loss_type='rank'
# num_output_neurons=16

# main_loss_type='reg'
# num_output_neurons=1


# One positional argument: GPU id only.
if [[ $# = 1 ]]; then
gpu=${1}
fi

# Two positional arguments: GPU id + head type; the neuron count is tied
# to the head (cls: 1 logit per class, rank: 2 per threshold, reg: scalar).
if [[ $# = 2 ]]; then
gpu=${1}
main_loss_type=${2}
if [[ $main_loss_type = 'cls' ]]; then
num_output_neurons=8
fi
if [[ $main_loss_type = 'rank' ]]; then
num_output_neurons=16
fi
if [[ $main_loss_type = 'reg' ]]; then
num_output_neurons=1
fi
fi

# Five-fold CV; --no-sto disables stochastic embeddings (baseline).
for fold in $(seq 0 1 4); do
CUDA_VISIBLE_DEVICES=${gpu} python -u test.py \
--batch-size=32 \
--test-batch-size=32 \
--max-epochs=10 --lr-decay-epoch=7 --lr=0.0001 --fc-lr=0.0001 \
--num-output-neurons=${num_output_neurons} --main-loss-type=${main_loss_type} \
--save-model='./Save_Model' \
--train-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \
--test-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \
--train-data-file=./data_list/test_fold_is_${fold}/age_train.txt \
--test-data-file=./data_list/test_fold_is_${fold}/age_test.txt \
--num-workers=4 --distance='JDistance' \
--alpha-coeff=1e-5 --beta-coeff=1e-4 --margin=5 \
--exp-name=train_val_fold_${fold} \
--logdir=./log/adience_baseline_${main_loss_type} \
--no-sto \
> ./log/adience_baseline_${main_loss_type}_test_fold_${fold}.txt 2>&1
done
--------------------------------------------------------------------------------
/codes/adience_poe/scripts/train_baseline.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train deterministic baseline models (--no-sto) on Adience, per CV fold.
# Usage: bash train_baseline.sh [gpu_id='0'] [model_type='cls'|'rank'|'reg']
gpu='0'

# Default head: classification over 8 age groups.
main_loss_type='cls'
num_output_neurons=8

# main_loss_type='rank'
# num_output_neurons=16

# main_loss_type='reg'
# num_output_neurons=1


# One positional argument: GPU id only.
if [[ $# = 1 ]]; then
gpu=${1}
fi

# Two positional arguments: GPU id + head type; the neuron count is tied
# to the head (cls: 1 logit per class, rank: 2 per threshold, reg: scalar).
if [[ $# = 2 ]]; then
gpu=${1}
main_loss_type=${2}
if [[ $main_loss_type = 'cls' ]]; then
num_output_neurons=8
fi
if [[ $main_loss_type = 'rank' ]]; then
num_output_neurons=16
fi
if [[ $main_loss_type = 'reg' ]]; then
num_output_neurons=1
fi
fi

# Five-fold CV; --no-sto disables stochastic embeddings (baseline).
for fold in $(seq 0 1 4); do
CUDA_VISIBLE_DEVICES=${gpu} python -u train.py \
--batch-size=32 \
--test-batch-size=32 \
--max-epochs=50 --lr-decay-epoch=30 --lr=0.0001 --fc-lr=0.0001 \
--num-output-neurons=${num_output_neurons} --main-loss-type=${main_loss_type} \
--save-model='./Save_Model' \
--train-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \
--test-images-root='/home/share_data/huangxiaoke/datasets/adience_dataset/aligned' \
--train-data-file=./data_list/test_fold_is_${fold}/age_train.txt \
--test-data-file=./data_list/test_fold_is_${fold}/age_val.txt \
--num-workers=4 --distance='JDistance' \
--alpha-coeff=1e-5 --beta-coeff=1e-4 --margin=5 \
--exp-name=train_val_fold_${fold} \
--logdir=./log/adience_baseline_${main_loss_type} \
--no-sto \
> ./log/adience_baseline_${main_loss_type}_train_val_fold_${fold}.txt 2>&1
done
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # POEs
2 |
3 | [](https://paperswithcode.com/sota/age-estimation-on-adience-1?p=learning-probabilistic-ordinal-embeddings-for)
4 | [](https://paperswithcode.com/sota/historical-color-image-dating-on-hci?p=learning-probabilistic-ordinal-embeddings-for)
5 | [](https://paperswithcode.com/sota/aesthetics-quality-assessment-on-image?p=learning-probabilistic-ordinal-embeddings-for)
6 |
7 |
8 | PyTorch implementation of Learning Probabilistic Ordinal Embeddings for Uncertainty-Aware Regression (CVPR 2021) \[[arXiv](https://arxiv.org/abs/2103.13629)\]\[[Homepage](https://li-wanhua.github.io/POEs/)\]
9 |
10 |
11 |
12 |
13 |
14 | If you find our work useful in your research, please consider citing:
15 | ```
16 | @inproceedings{li2021probabilistic,
17 | title={Learning Probabilistic Ordinal Embeddings for Uncertainty-Aware Regression},
18 | author={Li, Wanhua and Huang, Xiaoke and Lu, Jiwen and Feng, Jianjiang and Zhou, Jie},
19 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
20 | year={2021}
21 | }
22 | ```
23 |
24 | # Codes
25 | [Adience](./codes/adience_poe)
26 |
--------------------------------------------------------------------------------
/codes/adience_poe/poe/dataset_manger.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import torch.utils.data as data
4 | from PIL import Image
5 | import os
6 |
7 |
class dataset_manger(data.Dataset):
    """Image/label dataset read from a whitespace-separated list file.

    Each line of *data_file* is ``<relative_image_path> <integer_label>``.
    ``__getitem__`` returns ``(image, label, multi_hot_label)`` where the
    multi-hot vector has ones in positions ``0..label-1`` (rank encoding).
    """

    def __init__(self, images_root, data_file, transforms=None, num_output_bins=8):
        self.images_root = images_root
        self.labels = []
        self.images_file = []
        self.transforms = transforms
        self.num_output_bins = num_output_bins
        with open(data_file) as fin:
            for line in fin:
                image_file, image_label = line.split()
                self.labels.append(int(image_label))
                self.images_file.append(image_file)

    def __getitem__(self, index):
        img_file, target = self.images_file[index], self.labels[index]
        full_file = os.path.join(self.images_root, img_file)
        img = Image.open(full_file)

        # Normalize any non-RGB image (grayscale 'L', palette 'P', 'RGBA',
        # 'CMYK', ...) to 3-channel RGB.  The original check only handled
        # 'L', so palette/alpha images reached the transform with the
        # wrong number of channels.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        if self.transforms:
            img = self.transforms(img)

        # Rank-style multi-hot encoding: first `target` bins set to 1.
        multi_hot_target = torch.zeros(self.num_output_bins).long()
        multi_hot_target[list(range(target))] = 1

        return img, target, multi_hot_target

    def __len__(self):
        return len(self.labels)
39 |
40 |
if __name__ == '__main__':
    # Smoke test: load one batch from a hard-coded local dataset path and
    # print it.  NOTE(review): paths are machine-specific; adjust before use.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
    scaler = transforms.Resize((224, 224))
    preprocess = transforms.Compose([scaler, transforms.ToTensor(), normalize])
    t = dataset_manger('/home/share_data/age/CVPR19/datasets/MORPH', './data_list/ET_proto_val.txt', preprocess)
    train_loader = data.DataLoader(t, batch_size=2, shuffle=False)
    for data_, index, tar in train_loader:
        print(data_)
        print(index)
        print(tar)
        break
53 |
--------------------------------------------------------------------------------
/codes/adience_poe/poe/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 |
def cal_mae_acc_rank(logits, targets, is_sto=True):
    """MAE and accuracy for the ordinal-ranking ('rank') head.

    Logits come in consecutive (negative, positive) pairs, one pair per
    rank threshold; the predicted label is the number of thresholds whose
    second logit wins the argmax.

    Args:
        logits: (r, s, 2k) tensor when ``is_sto`` (r stochastic samples),
            otherwise (s, 2k).
        targets: (s,) tensor of integer ground-truth labels.
        is_sto: whether *logits* carries a leading sample dimension.

    Returns:
        (mae, acc) scalars.
    """
    if is_sto:
        r_dim, s_dim, out_dim = logits.shape
        assert out_dim % 2 == 0, "outdim {} wrong".format(out_dim)
        # `out_dim // 2`: integer division — the original `out_dim / 2`
        # yields a float under Python 3 and Tensor.view() rejects it.
        logits = logits.view(r_dim, s_dim, out_dim // 2, 2)
        logits = torch.argmax(logits, dim=-1)       # winner of each pair
        logits = torch.sum(logits, dim=-1)          # rank count = predicted label
        logits = torch.mean(logits.float(), dim=0)  # average over the r draws
        logits = logits.cpu().data.numpy()
        targets = targets.cpu().data.numpy()
        mae = sum(abs(logits - targets)) * 1.0 / len(targets)
        acc = sum(np.rint(logits) == targets) * 1.0 / len(targets)
    else:
        s_dim, out_dim = logits.shape
        assert out_dim % 2 == 0, "outdim {} wrong".format(out_dim)
        logits = logits.view(s_dim, out_dim // 2, 2)
        logits = torch.argmax(logits, dim=-1)
        logits = torch.sum(logits, dim=-1)
        logits = logits.cpu().data.numpy()
        targets = targets.cpu().data.numpy()
        mae = sum(abs(logits - targets)) * 1.0 / len(targets)
        acc = sum(np.rint(logits) == targets) * 1.0 / len(targets)
    return mae, acc
29 |
30 |
def cal_mae_acc_reg(logits, targets, is_sto=True):
    """MAE and accuracy for the regression ('reg') head.

    When ``is_sto`` the leading dimension of *logits* indexes stochastic
    samples and is averaged away first.  Accuracy counts predictions whose
    rounded value matches the target exactly.
    """
    preds = logits.mean(dim=0) if is_sto else logits

    assert preds.view(-1).shape == targets.shape, "logits {}, targets {}".format(
        preds.shape, targets.shape)

    pred_arr = preds.cpu().data.numpy().reshape(-1)
    target_arr = targets.cpu().data.numpy()
    n = len(target_arr)
    mae = np.abs(pred_arr - target_arr).sum() * 1.0 / n
    acc = (np.rint(pred_arr) == target_arr).sum() * 1.0 / n

    return mae, acc
44 |
45 |
def cal_mae_acc_cls(logits, targets, is_sto=True):
    """MAE and accuracy for the classification ('cls') head.

    MAE comes from the probability-weighted expected label; accuracy
    counts samples whose rounded expectation equals the target.  The
    argmax prediction is also computed but (as in the original) unused.
    """
    if is_sto:
        # logits: (r, s, c) with r stochastic samples; expected label is
        # averaged over the r draws.  NOTE: requires CUDA, as before.
        r_dim, s_dim, out_dim = logits.shape
        label_arr = torch.arange(0, out_dim).float().cuda()
        probs = F.softmax(logits, -1)
        exp = torch.mean(torch.sum(probs * label_arr, dim=-1), dim=0)
        mean_probs = torch.mean(probs, dim=0)
        max_data = np.argmax(mean_probs.cpu().data.numpy(), axis=1)
        target_data = targets.cpu().data.numpy()
        exp_data = exp.cpu().data.numpy()
    else:
        # logits: (s, c), plain single-pass evaluation on any device.
        s_dim, out_dim = logits.shape
        probs_data = F.softmax(logits, -1).cpu().data.numpy()
        target_data = targets.cpu().data.numpy()
        max_data = np.argmax(probs_data, axis=1)
        exp_data = np.sum(probs_data * np.arange(out_dim), axis=1)

    n = len(target_data)
    mae = np.abs(exp_data - target_data).sum() * 1.0 / n
    acc = (np.rint(exp_data) == target_data).sum() * 1.0 / n

    return mae, acc
73 |
74 |
def get_metric(main_loss_type):
    """Return the (mae, acc) metric function matching *main_loss_type*.

    Raises AssertionError for anything outside ['cls', 'reg', 'rank'].
    """
    assert main_loss_type in ['cls', 'reg', 'rank'], \
        "main_loss_type not in ['cls', 'reg', 'rank'], loss type {%s}" % (
        main_loss_type)
    dispatch = {
        'cls': cal_mae_acc_cls,
        'reg': cal_mae_acc_reg,
        'rank': cal_mae_acc_rank,
    }
    try:
        return dispatch[main_loss_type]
    except KeyError:
        # Unreachable after the assert; kept to mirror the original contract.
        raise AttributeError('main loss type: {}'.format(main_loss_type))
87 |
--------------------------------------------------------------------------------
/codes/adience_poe/poe/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from datetime import datetime
4 |
5 |
class AverageMeter(object):
    """Tracks the latest value and a running average over a bounded window.

    At most ``max_count`` recent values are kept; once the window is full,
    the oldest value is evicted before each new one is recorded.
    """

    def __init__(self, max_count=100):
        self.reset(max_count)

    def reset(self, max_count):
        """Clear the window and set a new capacity."""
        self.val = 0
        self.avg = 0
        self.data_container = []
        self.max_count = max_count

    def update(self, val):
        """Record *val* and refresh the windowed average."""
        self.val = val
        if len(self.data_container) >= self.max_count:
            # Window full: drop the oldest entry first.
            self.data_container.pop(0)
        self.data_container.append(val)
        self.avg = sum(self.data_container) * 1.0 / len(self.data_container)
28 |
29 |
def is_fc(para_name):
    """Return True when a dotted parameter name belongs to the final FC
    layer, i.e. its second-to-last component is 'final'."""
    return para_name.split('.')[-2] == 'final'
36 |
37 |
def display_lr(optimizer):
    """Print (current_lr, initial_lr) for every parameter group."""
    for group in optimizer.param_groups:
        print(group['lr'], group['initial_lr'])
41 |
42 |
def load_model(unload_model, args):
    """Resume *unload_model* from the checkpoint recorded in
    ``args.save_model/checkpoint.txt``.

    The checkpoint file holds one line: ``<model_filename> <epoch>``.
    Epoch 0 is treated as a torchvision-VGG-style checkpoint whose
    ``classifier.*`` tensors are remapped onto the ``emd``/``final``
    layers; any later epoch is a plain state_dict of this model.

    Returns:
        The epoch index to resume from (0 when no checkpoint exists yet).
    """
    if not os.path.exists(args.save_model):
        os.makedirs(args.save_model)
        print(args.save_model, 'is created!')
    checkpoint_file = os.path.join(args.save_model, 'checkpoint.txt')
    if not os.path.exists(checkpoint_file):
        # The original `f = open(...)` leaked the handle; create the empty
        # checkpoint file and close it immediately.
        with open(checkpoint_file, 'w'):
            pass
        print('checkpoint', 'is created!')

    start_index = 0
    with open(checkpoint_file, 'r') as fin:
        lines = fin.readlines()
        if len(lines) > 0:
            model_path, model_index = lines[0].split()
            print('Resuming from', model_path)
            if int(model_index) == 0:
                # Epoch-0 checkpoint: remap pretrained VGG classifier
                # weights onto this model's embedding / final layers.
                unload_model_dict = unload_model.state_dict()

                pretrained_dict = torch.load(
                    os.path.join(args.save_model, model_path))
                pretrained_dict['emd.0.weight'] = pretrained_dict['classifier.3.weight']
                pretrained_dict['emd.0.bias'] = pretrained_dict['classifier.3.bias']
                pretrained_dict['final.weight'] = pretrained_dict['classifier.6.weight']
                pretrained_dict['final.bias'] = pretrained_dict['classifier.6.bias']

                # Keep only tensors that exist in this model with matching shapes.
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if (
                    k in unload_model_dict and pretrained_dict[k].shape == unload_model_dict[k].shape)}
                print(len(pretrained_dict))
                for dict_inx, (k, v) in enumerate(pretrained_dict.items()):
                    print(dict_inx, k, v.shape)
                unload_model_dict.update(pretrained_dict)
                unload_model.load_state_dict(unload_model_dict)
            else:
                unload_model.load_state_dict(torch.load(
                    os.path.join(args.save_model, model_path)))

            start_index = int(model_index) + 1
    return start_index
80 |
81 |
def save_model(tosave_model, epoch, args):
    """Write the model's state_dict into ``args.save_model`` and record it
    in checkpoint.txt as the latest checkpoint."""
    model_path = 'model-{:04d}.pth'.format(epoch)
    torch.save(tosave_model.state_dict(),
               os.path.join(args.save_model, model_path))
    with open(os.path.join(args.save_model, 'checkpoint.txt'), 'w') as fout:
        fout.write('{} {}\n'.format(model_path, epoch))
89 |
90 |
def get_current_time():
    """Return the current local time as 'YYYY-MM-DD HH:MM:SS'.

    The original sliced ``str(datetime.now())[:-7]`` to drop '.ffffff';
    when ``microsecond == 0`` the string has no fractional part and the
    slice chopped off part of the time instead.  Explicit formatting is
    correct in every case.
    """
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
95 |
--------------------------------------------------------------------------------
/codes/adience_poe/environment.yaml:
--------------------------------------------------------------------------------
1 | name: p2t1
2 | channels:
3 | - menpo
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - attrs=20.3.0=pyhd3eb1b0_0
9 | - backports=1.0=py_2
10 | - backports.functools_lru_cache=1.5=py_2
11 | - backports.shutil_get_terminal_size=1.0.0=pyhd3eb1b0_3
12 | - backports_abc=0.5=py_0
13 | - blas=1.0=mkl
14 | - bleach=3.2.1=py_0
15 | - ca-certificates=2021.1.19=h06a4308_0
16 | - certifi=2020.6.20=pyhd3eb1b0_3
17 | - cffi=1.12.3=py27h2e261b9_0
18 | - cloudpickle=1.2.1=py_0
19 | - configparser=4.0.2=py27_0
20 | - cuda100=1.0=0
21 | - cycler=0.10.0=py27_0
22 | - cytoolz=0.10.0=py27h7b6447c_0
23 | - dask-core=1.2.2=py_0
24 | - dbus=1.13.6=h746ee38_0
25 | - decorator=4.4.0=py27_1
26 | - defusedxml=0.6.0=py_0
27 | - entrypoints=0.3=py27_0
28 | - enum34=1.1.6=py27_1
29 | - expat=2.2.6=he6710b0_0
30 | - fontconfig=2.13.0=h9420a91_0
31 | - freetype=2.9.1=h8a8886c_1
32 | - functools32=3.2.3.2=py27_1
33 | - futures=3.3.0=py27_0
34 | - get_terminal_size=1.0.0=haa9412d_0
35 | - glib=2.56.2=hd408876_0
36 | - gst-plugins-base=1.14.0=hbbd80ab_1
37 | - gstreamer=1.14.0=hb453b48_1
38 | - icu=58.2=h9c2bf20_1
39 | - imageio=2.5.0=py27_0
40 | - intel-openmp=2019.4=243
41 | - ipykernel=4.9.0=py27_1
42 | - ipython=5.3.0=py27_0
43 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
44 | - ipywidgets=7.4.2=py27_0
45 | - jinja2=2.11.2=pyhd3eb1b0_0
46 | - jpeg=9b=h024ee3a_2
47 | - jsonschema=3.0.2=py27_0
48 | - jupyter=1.0.0=py27_7
49 | - jupyter_client=5.1.0=py27_0
50 | - jupyter_console=5.2.0=py27_0
51 | - jupyter_core=4.5.0=py_0
52 | - kiwisolver=1.1.0=py27he6710b0_0
53 | - libedit=3.1.20181209=hc058e9b_0
54 | - libffi=3.2.1=hd88cf55_4
55 | - libgcc-ng=9.1.0=hdf63c60_0
56 | - libgfortran-ng=7.3.0=hdf63c60_0
57 | - libpng=1.6.37=hbc83047_0
58 | - libsodium=1.0.18=h7b6447c_0
59 | - libstdcxx-ng=9.1.0=hdf63c60_0
60 | - libtiff=4.0.10=h2733197_2
61 | - libuuid=1.0.3=h1bed415_2
62 | - libxcb=1.13=h1bed415_1
63 | - libxml2=2.9.9=hea5a465_1
64 | - markupsafe=1.1.1=py27h7b6447c_0
65 | - matplotlib=2.2.3=py27hb69df0a_0
66 | - mistune=0.8.4=py27h7b6447c_0
67 | - mkl=2019.4=243
68 | - mkl-service=2.0.2=py27h7b6447c_0
69 | - mkl_fft=1.0.14=py27ha843d7b_0
70 | - mkl_random=1.0.2=py27hd81dba3_0
71 | - nbconvert=5.5.0=py_0
72 | - nbformat=4.4.0=py27_0
73 | - ncurses=6.1=he6710b0_1
74 | - networkx=2.2=py27_1
75 | - ninja=1.9.0=py27hfd86e86_0
76 | - notebook=5.0.0=py27_0
77 | - olefile=0.46=py27_0
78 | - opencv=2.4.11=nppy27_0
79 | - openssl=1.1.1i=h27cfd23_0
80 | - packaging=20.8=pyhd3eb1b0_0
81 | - pandas=0.24.2=py27he6710b0_0
82 | - pandoc=2.11=hb0f4dca_0
83 | - pandocfilters=1.4.2=py27_0
84 | - pathlib2=2.3.5=py27_0
85 | - pcre=8.43=he6710b0_0
86 | - pexpect=4.8.0=pyhd3eb1b0_3
87 | - pickleshare=0.7.5=py27_0
88 | - pillow=6.1.0=py27h34e0f95_0
89 | - pip=19.2.2=py27_0
90 | - prompt_toolkit=1.0.15=py27_0
91 | - ptyprocess=0.7.0=pyhd3eb1b0_2
92 | - pycparser=2.19=py27_0
93 | - pygments=2.3.1=py27_0
94 | - pyparsing=2.4.2=py_0
95 | - pyqt=5.9.2=py27h05f1152_2
96 | - pyrsistent=0.14.11=py27h7b6447c_0
97 | - python=2.7.16=h8b3fad2_5
98 | - python-dateutil=2.8.0=py27_0
99 | - pytorch=1.0.0=py2.7_cuda10.0.130_cudnn7.4.1_1
100 | - pytz=2019.2=py_0
101 | - pywavelets=1.0.3=py27hdd07704_1
102 | - pyzmq=16.0.2=py27_0
103 | - qt=5.9.7=h5867ecd_1
104 | - qtconsole=4.7.7=py_0
105 | - qtpy=1.9.0=py_0
106 | - readline=7.0=h7b6447c_5
107 | - scandir=1.10.0=py27h7b6447c_0
108 | - scikit-image=0.14.2=py27he6710b0_0
109 | - scikit-learn=0.20.3=py27hd81dba3_0
110 | - scipy=1.2.1=py27h7c811a0_0
111 | - setuptools=41.0.1=py27_0
112 | - simplegeneric=0.8.1=py27_2
113 | - singledispatch=3.4.0.3=py27_0
114 | - sip=4.19.8=py27hf484d3e_0
115 | - six=1.12.0=py27_0
116 | - sqlite=3.29.0=h7b6447c_0
117 | - subprocess32=3.5.4=py27h7b6447c_0
118 | - terminado=0.6=py27_0
119 | - testpath=0.4.4=py_0
120 | - tk=8.6.8=hbc83047_0
121 | - toolz=0.10.0=py_0
122 | - torchvision=0.2.2=py_3
123 | - tornado=5.1.1=py27h7b6447c_0
124 | - traitlets=4.3.3=py27_0
125 | - wcwidth=0.2.5=py_0
126 | - webencodings=0.5.1=py27_1
127 | - wheel=0.33.4=py27_0
128 | - widgetsnbextension=3.4.2=py27_0
129 | - xz=5.2.4=h14c3975_4
130 | - zeromq=4.1.5=0
131 | - zlib=1.2.11=h7b6447c_3
132 | - zstd=1.3.7=h0b5b093_0
133 | - pip:
134 | - autopep8==1.5.4
135 | - contextlib2==0.6.0.post1
136 | - flake8==3.8.4
137 | - importlib-metadata==2.1.1
138 | - mccabe==0.6.1
139 | - numpy==1.16.0
140 | - pycodestyle==2.6.0
141 | - pyflakes==2.2.0
142 | - thop==0.0.30-2101242100
143 | - toml==0.10.2
144 | - tqdm==4.56.0
145 | - typing==3.7.4.3
146 | - zipp==1.2.0
147 | prefix: /home/xiaoke_wh/anaconda3/envs/p2t1
148 |
--------------------------------------------------------------------------------
/codes/adience_poe/poe/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 |
# Training settings
def get_args(argv=None):
    """Parse command-line options for POE training / testing on Adience.

    Args:
        argv: optional explicit argument list (e.g. ``[]`` or
            ``['--no-sto']``).  ``None`` (the default) keeps the old
            behavior of reading ``sys.argv``, so existing callers are
            unaffected.

    Returns:
        argparse.Namespace with two derived attributes added:
        ``cuda``    -- True when CUDA is requested and available;
        ``use_sto`` -- stochastic-sampling flag; ``--no-sto`` disables it
                       and also zeroes alpha/beta/margin (baseline setup).
    """
    parser = argparse.ArgumentParser(
        description='Probabilistic Ordinal Embedding (POE) for age estimation for Adience dataset')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
                        help='input batch size for testing (default: 32)')
    parser.add_argument('--max-epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--fc-lr', type=float, default=0.0001,
                        help='fc layer learning rate (default: 0.0001)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=5e-4, metavar='M',
                        help='SGD weight decay (default: 5e-4)')
    parser.add_argument('--print-freq', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--AverageMeter-MaxCount', default=100, type=int,
                        help='maximum capacity for AverageMeter(default: 100)')
    parser.add_argument('--num-workers', default=2, type=int,
                        help='number of load data workers (default: 2)')
    parser.add_argument('--train-images-root', type=str, default='/home/share_data/age/CVPR19/datasets/MORPH',
                        help='images root for train dataset')
    parser.add_argument('--test-images-root', type=str, default='/home/share_data/age/CVPR19/datasets/MORPH',
                        help='images root for test dataset')
    parser.add_argument('--train-data-file', type=str, default='./data_list/ET_proto_train.txt',
                        help='data file for train dataset')
    parser.add_argument('--test-data-file', type=str, default='./data_list/ET_proto_val.txt',
                        help='data file for test dataset')
    parser.add_argument('--distance', type=str, default='JDistance',
                        help='distance metric between two gaussian distribution')
    # NOTE: the three help strings below previously reported wrong default
    # values (0 / 1.0 / 1.0); they now match the actual `default=`.
    parser.add_argument('--alpha-coeff', type=float, default=1e-5, metavar='M',
                        help='alpha_coeff (default: 1e-5)')
    parser.add_argument('--beta-coeff', type=float, default=1e-4, metavar='M',
                        help='beta_coeff (default: 1e-4)')
    parser.add_argument('--margin', type=float, default=5, metavar='M',
                        help='margin (default: 5)')
    parser.add_argument('--logdir', type=str, default='./log/',
                        help='where you save log.')
    parser.add_argument('--exp-name', type=str, default='exp',
                        help='name of your experiment.')
    parser.add_argument('--save-freq', default=10, type=int,
                        metavar='N', help='save checkpoint frequency (default: 10)')
    parser.add_argument('--lr-decay-epoch', type=str, default='30',
                        help='epochs at which learning rate decays. default is 30.')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    parser.add_argument('--no-sto', action='store_true', default=False,
                        help='not using stochastic sampling when training or testing.')
    parser.add_argument('--test-only', action='store_true',
                        default=False, help='test your model, no training loop.')
    parser.add_argument('--num-output-neurons', type=int, default=1,
                        help='number of output neurons of your model, note that for `reg` model we use 1; `cls` model we use `num_output_classes`; and for `rank` model we use `num_output_class` * 2.')
    parser.add_argument('--main-loss-type', type=str,
                        default='reg', help='loss type in [cls, reg, rank].')
    parser.add_argument('--max-t', type=int, default=50,
                        help='number of samples during sto.')
    parser.add_argument('--save-model', type=str, default='./Saved_Model/',
                        help='where you save model')
    parser.add_argument('--checkpoint-path', type=str, default=None,
                        help="checkpoint to be loaded when testing.")

    args = parser.parse_args(argv)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.use_sto = True

    if args.no_sto:
        # Baseline mode: deterministic embeddings, regularizers disabled.
        args.use_sto = False
        args.alpha_coeff = .0
        args.beta_coeff = .0
        args.margin = .0
        print("no stochastic sampling when training or testing, baseline set up")

    print('------------ Options -------------')
    for k, v in sorted(vars(args).items()):
        print('{}: {}'.format(str(k), str(v)))
    print('-------------- End ----------------')

    return args
90 |
91 |
if __name__ == "__main__":
    # Quick manual check: parse the command line and echo every option.
    args = get_args()
94 |
--------------------------------------------------------------------------------
/codes/adience_poe/poe/probordiloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
def BhattacharyyaDistance(u1, sigma1, u2, sigma2):
    """Bhattacharyya distance between diagonal Gaussians N(u1, sigma1) and
    N(u2, sigma2); the sigmas are per-dimension variance vectors.  Returns
    one value per row (sum over dim=1)."""
    avg_var = (sigma1 + sigma2) / 2.0
    inv_avg = 1.0 / avg_var
    mean_term = torch.sum(torch.pow(u1 - u2, 2) * inv_avg, dim=1) / 8.0
    log_det_term = 0.5 * (
        torch.sum(torch.log(avg_var), dim=1)
        - 0.5 * (torch.sum(torch.log(sigma1), dim=1)
                 + torch.sum(torch.log(sigma2), dim=1)))
    return mean_term + log_det_term
13 |
14 |
def HellingerDistance(u1, sigma1, u2, sigma2):
    """Hellinger distance derived from the Bhattacharyya distance:
    sqrt(1 - exp(-BD))."""
    coefficient = torch.exp(-BhattacharyyaDistance(u1, sigma1, u2, sigma2))
    return torch.pow(1.0 - coefficient, 0.5)
17 |
18 |
def WassersteinDistance(u1, sigma1, u2, sigma2):
    """2-Wasserstein distance between diagonal Gaussians; the sigmas are
    variance vectors (square-rooted to standard deviations inside)."""
    mean_gap = torch.sum(torch.pow(u1 - u2, 2), dim=1)
    std_gap = torch.sum(
        torch.pow(torch.pow(sigma1, 0.5) - torch.pow(sigma2, 0.5), 2), dim=1)
    return torch.pow(mean_gap + std_gap, 0.5)
24 |
25 |
def GeodesicDistance(u1, sigma1, u2, sigma2):
    """Geodesic distance between diagonal Gaussians (sigmas are variance
    vectors)."""
    mean_gap = torch.pow(u1 - u2, 2)
    std1 = sigma1.sqrt()
    std2 = sigma2.sqrt()

    std_gap = torch.pow(std1 - std2, 2)
    std_sum = torch.pow(std1 + std2, 2)
    # Per-dimension ratio in [0, 1); mapped through a log-odds expression.
    delta = torch.div(mean_gap + 2 * std_gap, mean_gap + 2 * std_sum).sqrt()
    per_dim = torch.pow(torch.log((1.0 + delta) / (1.0 - delta)), 2) * 2
    return torch.sum(per_dim, dim=1).sqrt()
35 |
36 |
def ForwardKLDistance(u1, sigma1, u2, sigma2):
    """KL(N(u1, sigma1) || N(u2, sigma2)) for diagonal Gaussians, summed
    over dimensions; the sigmas are variance vectors."""
    var_ratio = torch.div(sigma1, sigma2)
    mean_term = torch.div(torch.pow(u1 - u2, 2), sigma2)
    return -0.5 * torch.sum(
        torch.log(sigma1) - torch.log(sigma2) - var_ratio - mean_term + 1, dim=1)
40 |
41 |
def ReverseKLDistance(u2, sigma2, u1, sigma1):
    """Reverse KL divergence between diagonal Gaussians.

    NOTE the deliberately swapped parameter names: a positional call
    ReverseKLDistance(a, va, b, vb) computes KL( N(b, vb) || N(a, va) ),
    i.e. the same formula as ForwardKLDistance with the roles exchanged.
    """
    log_ratio = torch.log(sigma1) - torch.log(sigma2)
    var_ratio = torch.div(sigma1, sigma2)
    mean_term = torch.div((u1 - u2) ** 2, sigma2)
    return 0.5 * torch.sum(var_ratio + mean_term - log_ratio - 1.0, dim=1)
45 |
46 |
def JDistance(u1, sigma1, u2, sigma2):
    """Symmetrised (Jeffreys) KL divergence: KL(P||Q) + KL(Q||P)."""
    forward = ForwardKLDistance(u1, sigma1, u2, sigma2)
    backward = ForwardKLDistance(u2, sigma2, u1, sigma1)
    return forward + backward
49 |
50 |
class ProbOrdiLoss(nn.Module):
    """Probabilistic ordinal embedding loss.

    Combines a main task loss (classification / regression / ranking) on
    the sampled logits, a KL regulariser pulling the embedding Gaussian
    towards N(0, I), and an ordinal triplet constraint on distributional
    distances between embeddings.

    NOTE(review): `torch.zeros(1).cuda()` in __init__ (and the .cuda()
    calls in forward) make this loss CUDA-only — confirm this is intended.
    """
    def __init__(self, distance='Bhattacharyya', alpha_coeff=0, beta_coeff=0, margin=0, main_loss_type='cls'):
        # alpha_coeff scales the KL term, beta_coeff scales the triplet
        # term, margin is the triplet hinge margin.
        super(ProbOrdiLoss, self).__init__()
        self.alpha_coeff = alpha_coeff
        self.beta_coeff = beta_coeff
        self.margin = margin
        self.zeros = torch.zeros(1).cuda()

        assert main_loss_type in ['cls', 'reg', 'rank'], \
            "main_loss_type not in ['cls', 'reg', 'rank'], loss type {%s}" % (
                main_loss_type)
        self.main_loss_type = main_loss_type

        # Select the distributional distance used by the triplet term.
        if distance == 'Bhattacharyya':
            self.distrance_f = BhattacharyyaDistance
        elif distance == 'Wasserstein':
            self.distrance_f = WassersteinDistance
        elif distance == 'JDistance':
            self.distrance_f = JDistance
        elif distance == 'ForwardKLDistance':
            self.distrance_f = ForwardKLDistance
        elif distance == 'HellingerDistance':
            self.distrance_f = HellingerDistance
        elif distance == 'GeodesicDistance':
            self.distrance_f = GeodesicDistance
        elif distance == 'ReverseKLDistance':
            self.distrance_f = ReverseKLDistance
        else:
            # Unknown distance: warn and leave distrance_f unset; forward
            # will fail with a TypeError if it is ever used.
            print('ERROR: this distance is not supported!')
            self.distrance_f = None

    def forward(self, logit, emb, log_var, target, mh_target=None, use_sto=True):
        """Return (main_loss, alpha*KL, beta*triplet, total).

        logit: (max_t, batch, C) when use_sto else (batch, C);
        emb/log_var: (batch, dim) embedding mean and log-variance;
        target: (batch,) labels; mh_target: extra targets used only when
        main_loss_type == 'rank'.
        """
        class_dim = logit.shape[-1]
        sample_size = logit.shape[0]  # reparameterized with max_t samples

        # --- main task loss ---
        if self.main_loss_type == 'cls':
            if use_sto:
                # Every Monte-Carlo sample shares the same label.
                CEloss = F.cross_entropy(
                    logit.view(-1, class_dim), target.repeat(sample_size))
            else:
                CEloss = F.cross_entropy(logit.view(-1, class_dim), target)
        elif self.main_loss_type == 'reg':
            if use_sto:
                CEloss = F.mse_loss(
                    logit.view(-1), target.repeat(sample_size).to(dtype=logit.dtype))
            else:
                CEloss = F.mse_loss(
                    logit.view(-1), target.to(dtype=logit.dtype))
        elif self.main_loss_type == 'rank':
            # Ranking head: class_dim is interpreted as pairs of binary
            # (>=k vs <k) decisions, trained with 2-way cross-entropy.
            assert len(mh_target.shape) == 2, "target shape {} wrong".format(
                mh_target.shape)
            assert class_dim % 2 == 0, "class dim {} wrong".format(class_dim)
            if use_sto:
                CEloss = F.cross_entropy(
                    logit.view(-1, 2), mh_target.repeat([sample_size, 1]).view(-1))
            else:
                CEloss = F.cross_entropy(logit.view(-1, 2), mh_target.view(-1))
        else:
            raise AttributeError(
                'main loss type: {}'.format(self.main_loss_type))

        # --- KL( N(emb, exp(log_var)) || N(0, I) ), averaged over batch ---
        KLLoss = torch.mean(torch.sum(torch.pow(emb, 2) +
                                      torch.exp(log_var) - log_var - 1.0, dim=1) * 0.5)

        var = torch.exp(log_var)

        # --- ordinal triplet construction ---
        # Anchor i is paired with its cyclic neighbour i+1 ("second") and
        # with the sample whose label gap to the anchor best matches the
        # anchor-second gap ("thrid"), chosen via argmin below.
        batch_size = emb.shape[0]
        dims = emb.shape[1]
        target_dis = torch.abs(
            target.view(-1, 1).repeat(1, batch_size) - target.view(1, -1).repeat(batch_size, 1))
        anchor_pos = [i for i in range(batch_size)]
        second_pos = [(i + 1) % batch_size for i in anchor_pos]
        target_dis = torch.abs(target_dis - torch.abs(
            target[anchor_pos] - target[second_pos]).view(-1, 1).repeat(1, batch_size))
        # Large constants mask out self-pairs (1000) and exact-gap matches
        # (700) so argmin picks a genuinely different third sample.
        offset_m = torch.eye(batch_size).cuda().to(dtype=target_dis.dtype)
        target_dis = target_dis + offset_m * 1000
        target_dis[target_dis == 0] = 700
        thrid_pos = torch.argmin(target_dis, dim=1)

        # +1 / -1 / 0 depending on which pair has the larger label gap;
        # 0 (a tie) disables the triplet for that anchor.
        anchor_sign = torch.sign(torch.abs(
            target[anchor_pos] - target[second_pos]) - torch.abs(target[anchor_pos] - target[thrid_pos]))

        emb_dis_12 = self.distrance_f(
            emb[anchor_pos, :], var[anchor_pos, :], emb[second_pos, :], var[second_pos, :])
        emb_dis_13 = self.distrance_f(
            emb[anchor_pos, :], var[anchor_pos, :], emb[thrid_pos, :], var[thrid_pos, :])

        # Hinge: embedding-space distances should be ordered like label gaps.
        anchor_cons = (emb_dis_13 - emb_dis_12) * \
            anchor_sign.float() + self.margin

        loss_anchor = torch.max(self.zeros, anchor_cons) * \
            torch.abs(anchor_sign).float()
        loss_mask = (anchor_cons > 0).to(dtype=anchor_sign.dtype)
        # Average over active (violating, non-tied) triplets only.
        if sum(torch.abs(anchor_sign) * loss_mask) > 0:
            triple_loss = torch.sum(loss_anchor) / \
                sum(torch.abs(anchor_sign) * loss_mask)
        else:
            triple_loss = torch.tensor(0.0).cuda()

        return CEloss, KLLoss * self.alpha_coeff, triple_loss * self.beta_coeff, CEloss + self.alpha_coeff * KLLoss + self.beta_coeff * triple_loss
151 |
--------------------------------------------------------------------------------
/codes/adience_poe/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | from torch.autograd import Variable
4 | from poe import dataset_manger
5 | import os
6 | import time
7 | from poe import vgg
8 | from poe.probordiloss import ProbOrdiLoss
9 | from poe.metrics import get_metric
10 | from poe.utils import get_current_time, AverageMeter
11 | from poe.options import get_args
12 | import numpy as np
13 |
14 |
def test_epoch(epoch):
    """Evaluate `rpr_model` over `testloader` once.

    Metrics are computed both with stochastic embedding sampling (sto)
    and with the deterministic mean embedding (no_sto).  Relies on the
    module-level globals args, rpr_model, testloader, criterion and
    cal_mae_acc.  Returns dataset-level
    (mae_sto, mae_no_sto, acc_sto, acc_no_sto).
    """

    batch_time = AverageMeter(args.AverageMeter_MaxCount)
    loss1es = AverageMeter(args.AverageMeter_MaxCount)
    loss2es = AverageMeter(args.AverageMeter_MaxCount)
    loss3es = AverageMeter(args.AverageMeter_MaxCount)
    losses = AverageMeter(args.AverageMeter_MaxCount)
    mae_sto = AverageMeter(args.AverageMeter_MaxCount)
    mae_no_sto = AverageMeter(args.AverageMeter_MaxCount)
    acc_sto = AverageMeter(args.AverageMeter_MaxCount)
    acc_no_sto = AverageMeter(args.AverageMeter_MaxCount)

    rpr_model.eval()
    total = 0
    all_mae_sto = .0
    all_mae_no_sto = .0
    all_acc_sto = .0
    all_acc_no_sto = .0

    end_time = time.time()
    for batch_idx, (inputs, targets, mh_targets) in enumerate(testloader):
        inputs, targets, mh_targets = inputs.cuda(), targets.cuda(), mh_targets.cuda()
        # NOTE(review): requires_grad=True on inputs looks unnecessary for
        # evaluation — confirm.
        inputs, targets = Variable(
            inputs, requires_grad=True), Variable(targets)

        # Two forward passes: Monte-Carlo sampled logits and deterministic ones.
        logit, emb, log_var = rpr_model(inputs, max_t=args.max_t, use_sto=True)
        logit_no_sto, emb, log_var = rpr_model(
            inputs, max_t=args.max_t, use_sto=False)
        # Report the loss matching the training configuration (args.use_sto).
        if args.use_sto:
            loss1, loss2, loss3, loss = criterion(
                logit, emb, log_var, targets, mh_targets, use_sto=args.use_sto)
        else:
            loss1, loss2, loss3, loss = criterion(
                logit_no_sto, emb, log_var, targets, mh_targets, use_sto=args.use_sto)

        total += targets.size(0)

        batch_mae_sto, batch_acc_sto = cal_mae_acc(logit, targets, True)
        batch_mae_no_sto, batch_acc_no_sto = cal_mae_acc(
            logit_no_sto, targets, False)

        loss1es.update(loss1.cpu().data.numpy())
        loss2es.update(loss2.cpu().data.numpy())
        loss3es.update(loss3.cpu().data.numpy())
        losses.update(loss.cpu().data.numpy())

        mae_sto.update(batch_mae_sto)
        mae_no_sto.update(batch_mae_no_sto)
        acc_sto.update(batch_acc_sto)
        acc_no_sto.update(batch_acc_no_sto)

        # Weight batch metrics by batch size for exact dataset averages.
        all_mae_sto = all_mae_sto + batch_mae_sto * targets.size(0)
        all_mae_no_sto = all_mae_no_sto + batch_mae_no_sto * targets.size(0)
        all_acc_sto = all_acc_sto + batch_acc_sto * targets.size(0)
        all_acc_no_sto = all_acc_no_sto + batch_acc_no_sto * targets.size(0)

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        if batch_idx % args.print_freq == 0:
            print('Test: [%d/%d]\t'
                  'Time %.3f (%.3f)\t'
                  'Loss1 %.3f (%.3f)\t'
                  'Loss2 %.3f (%.3f)\t'
                  'Loss3 %.3f (%.3f)\t'
                  'Loss %.3f (%.3f)\t'
                  'MAE_sto %.3f (%.3f)\t'
                  'MAE_no_sto %.3f (%.3f)\t'
                  'ACC_sto %.3f (%.3f)\t'
                  'ACC_no_sto %.3f (%.3f)\t' % (batch_idx, len(testloader),
                                                batch_time.val, batch_time.avg, loss1es.val, loss1es.avg, loss2es.val, loss2es.avg, loss3es.val, loss3es.avg,
                                                losses.val, losses.avg, mae_sto.val, mae_sto.avg, mae_no_sto.val, mae_no_sto.avg, acc_sto.val, acc_sto.avg, acc_no_sto.val, acc_no_sto.avg))

    print('Test: MAE_sto: %.3f MAE_no_sto: %.3f ACC_sto: %.3f ACC_no_sto: %.3f' % (all_mae_sto *
                                                                                   1.0 / total, all_mae_no_sto * 1.0 / total, all_acc_sto * 1.0 / total, all_acc_no_sto * 1.0 / total))

    return all_mae_sto * 1.0 / total, all_mae_no_sto * 1.0 / total, all_acc_sto * 1.0 / total, all_acc_no_sto * 1.0 / total
92 |
93 |
if __name__ == "__main__":
    # Evaluation settings
    # ---------------------------------
    args = get_args()

    # Fix RNG seeds and cuDNN modes so evaluation is reproducible.
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # ---------------------------------

    # dataset prepare
    # ---------------------------------
    transform_test = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    testset = dataset_manger.dataset_manger(
        images_root=args.test_images_root, data_file=args.test_data_file, transforms=transform_test)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.num_workers)
    # --------------------------------------

    # define Model
    # ----------------------------------------
    rpr_model = vgg.vgg16(num_output_neurons=args.num_output_neurons)
    rpr_model.cuda()
    # rpr_model = torch.nn.DataParallel(rpr_model)
    print("Model finished")
    # ---------------------------------

    # define loss
    # ----------------------------------------
    criterion = ProbOrdiLoss(distance=args.distance, alpha_coeff=args.alpha_coeff,
                             beta_coeff=args.beta_coeff, margin=args.margin, main_loss_type=args.main_loss_type)
    criterion.cuda()

    # define Metric
    # ----------------------------------------
    cal_mae_acc = get_metric(args.main_loss_type)
    # ---------------------------------

    args.logdir = os.path.join(args.logdir, args.exp_name)
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
        print('log dir [{}] {}'.format(args.logdir, 'is created!'))

    # Resolve which checkpoint to evaluate: an explicit --checkpoint_path
    # wins, otherwise fall back to <logdir>/checkpoint.pth.
    # BUGFIX: the original called "...".foramt(...) (AttributeError) on the
    # explicit-checkpoint branch, and printed 'best.pth' while actually
    # loading 'checkpoint.pth' on the fallback branch.
    if args.checkpoint_path:
        ckpt_path = args.checkpoint_path
    else:
        ckpt_path = os.path.join(args.logdir, 'checkpoint.pth')
    rpr_model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    print("Load model from checkpoint: {}".format(ckpt_path))

    print("Start Testing...")
    with torch.no_grad():
        cur_mae_sto, cur_mae_no_sto, cur_acc_sto, cur_acc_no_sto = test_epoch(
            0)

    print('[{}] end!'.format(get_current_time()))
    # Persist the four summary metrics for misc/metric_summary.py.
    np.save(
        os.path.join(args.logdir, 'test_result.npy'),
        np.array([cur_mae_sto, cur_mae_no_sto, cur_acc_sto, cur_acc_no_sto])
    )
165 |
--------------------------------------------------------------------------------
/codes/adience_poe/poe/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 | import math
4 | import torch
5 |
# Public API of this module (torchvision-style VGG factory functions).
__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


# Download locations of the official torchvision ImageNet checkpoints.
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
22 |
23 |
class ProbVar(nn.Module):
    """Two-layer MLP over a `dim`-sized vector followed by a single shared
    scalar affine map applied independently to every coordinate."""

    def __init__(self, dim):
        super(ProbVar, self).__init__()
        # Vector-wise transform: Linear -> BatchNorm -> ReLU -> Linear.
        self.trans = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(True),
            nn.Linear(dim, dim),
        )
        # 1-in/1-out affine layer shared across all `dim` coordinates.
        self.act = nn.Linear(1, 1)
        self.dim = dim

    def forward(self, x):
        hidden = self.trans(x)
        # Reshape to (N*dim, 1)-like view so the scalar affine map is
        # applied element-wise, then restore the (N, dim) layout.
        scaled = self.act(hidden.view(-1, self.dim, 1))
        return scaled.view(-1, self.dim)
42 |
43 |
class VGG(nn.Module):
    """VGG backbone with a probabilistic embedding head.

    On top of the VGG trunk (`features`) and the first FC stage
    (`classifier`), the model produces an embedding mean (`emd`) and a
    log-variance (`var`); the output layer (`final`) is applied either to
    stochastic samples of the embedding (reparameterisation trick) or to
    the mean embedding directly.
    """

    def __init__(self, features, num_output_neurons):
        super(VGG, self).__init__()
        self.features = features
        # First FC stage of the classic VGG classifier (fc6 + ReLU + dropout).
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
        )
        # Embedding mean head.
        self.emd = nn.Sequential(
            nn.Linear(4096, 4096),
            nn.ReLU(True),
        )
        # Log-variance head; affine-free BatchNorm keeps log-variances centred.
        self.var = nn.Sequential(nn.Linear(4096, 4096),
                                 nn.BatchNorm1d(4096, eps=0.001, affine=False),
                                 )
        self.drop = nn.Dropout()
        self.final = nn.Linear(4096, num_output_neurons)
        self._initialize_weights()

    def forward(self, x, max_t=50, use_sto=True):
        """Return (logit, emb, log_var).

        When use_sto is True, logit carries an extra leading dimension of
        max_t Monte-Carlo samples drawn via the reparameterisation trick;
        otherwise logit is computed from the mean embedding.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        emb = self.emd(x)
        log_var = self.var(x)
        sqrt_var = torch.exp(log_var * 0.5)  # standard deviation
        if use_sto:
            # Broadcast mean/std to (max_t, batch, dim) without copying.
            rep_emb = emb[None].expand(max_t, *emb.shape)
            rep_sqrt_var = sqrt_var[None].expand(max_t, *sqrt_var.shape)
            # NOTE(review): .cuda() makes this path CUDA-only even though
            # randn_like already matches rep_emb's device — confirm.
            norm_v = torch.randn_like(rep_emb).cuda()
            sto_emb = rep_emb + rep_sqrt_var * norm_v
            sto_emb = self.drop(sto_emb)
            logit = self.final(sto_emb)
        else:
            drop_emb = self.drop(emb)
            logit = self.final(drop_emb)

        return logit, emb, log_var

    def _initialize_weights(self):
        # He-style init for convs, unit scale / zero shift for norm layers,
        # N(0, 0.01) for fully-connected layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
99 |
100 |
def make_layers(cfg, batch_norm=False):
    """Build the VGG convolutional trunk from a configuration list.

    Each entry of `cfg` is an output-channel count for a 3x3 conv (with
    optional BatchNorm) followed by ReLU, while 'M' inserts a 2x2 max-pool.
    """
    layers = []
    channels = 3
    for item in cfg:
        if item == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(channels, item, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(item))
        layers.append(nn.ReLU(inplace=True))
        channels = item
    return nn.Sequential(*layers)
115 |
116 |
# Per-variant channel configurations ('M' marks a 2x2 max-pool):
# A = VGG-11, B = VGG-13, D = VGG-16, E = VGG-19.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
123 |
124 |
def vgg11(pretrained=False, **kwargs):
    """Build a VGG 11-layer model (configuration "A").

    Args:
        pretrained (bool): if True, load the ImageNet checkpoint.
    """
    net = VGG(make_layers(cfg['A']), **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
    return net
135 |
136 |
def vgg11_bn(pretrained=False, **kwargs):
    """Build a VGG 11-layer model (configuration "A") with batch norm.

    Args:
        pretrained (bool): if True, load the ImageNet checkpoint.
    """
    net = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
    return net
147 |
148 |
def vgg13(pretrained=False, **kwargs):
    """Build a VGG 13-layer model (configuration "B").

    Args:
        pretrained (bool): if True, load the ImageNet checkpoint.
    """
    net = VGG(make_layers(cfg['B']), **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
    return net
159 |
160 |
def vgg13_bn(pretrained=False, **kwargs):
    """Build a VGG 13-layer model (configuration "B") with batch norm.

    Args:
        pretrained (bool): if True, load the ImageNet checkpoint.
    """
    net = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
    return net
171 |
172 |
def vgg16(pretrained=False, **kwargs):
    """Build a VGG 16-layer model (configuration "D").

    When `pretrained` is True the ImageNet checkpoint is loaded; its
    classifier.3 / classifier.6 tensors are remapped onto the `emd` and
    `final` heads, and only shape-compatible tensors are copied into the
    freshly-initialised model.
    """
    model = VGG(make_layers(cfg['D']), **kwargs)
    if not pretrained:
        return model

    own_state = model.state_dict()
    ckpt = model_zoo.load_url(model_urls['vgg16'])

    # Reuse the original fully-connected weights for the new heads.
    for new_key, old_key in (('emd.0.weight', 'classifier.3.weight'),
                             ('emd.0.bias', 'classifier.3.bias'),
                             ('final.weight', 'classifier.6.weight'),
                             ('final.bias', 'classifier.6.bias')):
        ckpt[new_key] = ckpt[old_key]

    # Keep only tensors that exist in this model with a matching shape
    # (e.g. `final` is skipped when num_output_neurons != 1000).
    compatible = {k: v for k, v in ckpt.items()
                  if k in own_state and v.shape == own_state[k].shape}
    own_state.update(compatible)
    model.load_state_dict(own_state)

    return model
196 |
197 |
def vgg16_bn(pretrained=False, **kwargs):
    """Build a VGG 16-layer model (configuration "D") with batch norm.

    Args:
        pretrained (bool): if True, load the ImageNet checkpoint.
    """
    net = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
    return net
208 |
209 |
def vgg19(pretrained=False, **kwargs):
    """Build a VGG 19-layer model (configuration "E").

    Args:
        pretrained (bool): if True, load the ImageNet checkpoint.
    """
    net = VGG(make_layers(cfg['E']), **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
    return net
220 |
221 |
def vgg19_bn(pretrained=False, **kwargs):
    """Build a VGG 19-layer model (configuration "E") with batch norm.

    Args:
        pretrained (bool): if True, load the ImageNet checkpoint.
    """
    net = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
    if not pretrained:
        return net
    net.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
    return net
232 |
--------------------------------------------------------------------------------
/codes/adience_poe/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torchvision import transforms
4 | from torch.autograd import Variable
5 | from poe import dataset_manger
6 | import os
7 | import time
8 | import numpy as np
9 | from poe import vgg
10 | from poe.probordiloss import ProbOrdiLoss
11 | from poe.metrics import get_metric
12 | from poe.utils import is_fc, get_current_time, AverageMeter
13 | from poe.options import get_args
14 |
15 |
def train_epoch(epoch):
    """Run one training epoch over `trainloader`.

    Relies on module-level globals: args, rpr_model, trainloader,
    optimizer, criterion and cal_mae_acc.  Logs running averages every
    args.print_freq batches.
    """

    batch_time = AverageMeter(args.AverageMeter_MaxCount)
    loss1es = AverageMeter(args.AverageMeter_MaxCount)
    loss2es = AverageMeter(args.AverageMeter_MaxCount)
    loss3es = AverageMeter(args.AverageMeter_MaxCount)
    losses = AverageMeter(args.AverageMeter_MaxCount)
    mae = AverageMeter(args.AverageMeter_MaxCount)
    acc = AverageMeter(args.AverageMeter_MaxCount)

    rpr_model.train()

    end_time = time.time()
    for batch_idx, (inputs, targets, mh_targets) in enumerate(trainloader):

        inputs, targets, mh_targets = inputs.cuda(), targets.cuda(), mh_targets.cuda()
        # NOTE(review): requires_grad=True on inputs looks unnecessary for
        # optimising model parameters — confirm.
        inputs, targets = Variable(
            inputs, requires_grad=True), Variable(targets)

        optimizer.zero_grad()

        logit, emb, log_var = rpr_model(
            inputs, max_t=args.max_t, use_sto=args.use_sto)

        # loss1: main task loss, loss2: scaled KL, loss3: scaled triplet,
        # loss: weighted total actually optimised.
        loss1, loss2, loss3, loss = criterion(
            logit, emb, log_var, targets, mh_targets, use_sto=args.use_sto)

        loss.backward()
        optimizer.step()

        batch_mae, batch_acc = cal_mae_acc(logit, targets, args.use_sto)
        loss1es.update(loss1.cpu().data.numpy())
        loss2es.update(loss2.cpu().data.numpy())
        loss3es.update(loss3.cpu().data.numpy())
        losses.update(loss.cpu().data.numpy())

        mae.update(batch_mae)
        acc.update(batch_acc)

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        if batch_idx % args.print_freq == 0:
            print('Epoch: [%d][%d/%d] '
                  'Time %.3f (%.3f)\t'
                  'Loss1 %.3f (%.3f)\t'
                  'Loss2 %.3f (%.3f)\t'
                  'Loss3 %.3f (%.3f)\t'
                  'Loss %.3f (%.3f)\t'
                  'MAE %.3f (%.3f)\t'
                  'ACC %.3f (%.3f)' % (epoch, batch_idx, len(trainloader),
                                       batch_time.val, batch_time.avg, loss1es.val, loss1es.avg, loss2es.val, loss2es.avg, loss3es.val, loss3es.avg,
                                       losses.val, losses.avg, mae.val, mae.avg, acc.val, acc.avg))
69 |
70 |
def test_epoch(epoch):
    """Evaluate `rpr_model` over `testloader` once (validation pass).

    Metrics are computed both with stochastic embedding sampling (sto)
    and with the deterministic mean embedding (no_sto).  Relies on the
    module-level globals args, rpr_model, testloader, criterion and
    cal_mae_acc.  Returns dataset-level
    (mae_sto, mae_no_sto, acc_sto, acc_no_sto).
    """

    batch_time = AverageMeter(args.AverageMeter_MaxCount)
    loss1es = AverageMeter(args.AverageMeter_MaxCount)
    loss2es = AverageMeter(args.AverageMeter_MaxCount)
    loss3es = AverageMeter(args.AverageMeter_MaxCount)
    losses = AverageMeter(args.AverageMeter_MaxCount)
    mae_sto = AverageMeter(args.AverageMeter_MaxCount)
    mae_no_sto = AverageMeter(args.AverageMeter_MaxCount)
    acc_sto = AverageMeter(args.AverageMeter_MaxCount)
    acc_no_sto = AverageMeter(args.AverageMeter_MaxCount)

    rpr_model.eval()
    total = 0
    all_mae_sto = .0
    all_mae_no_sto = .0
    all_acc_sto = .0
    all_acc_no_sto = .0

    end_time = time.time()
    for batch_idx, (inputs, targets, mh_targets) in enumerate(testloader):
        inputs, targets, mh_targets = inputs.cuda(), targets.cuda(), mh_targets.cuda()
        # NOTE(review): requires_grad=True on inputs looks unnecessary for
        # evaluation — confirm.
        inputs, targets = Variable(
            inputs, requires_grad=True), Variable(targets)

        # Two forward passes: Monte-Carlo sampled logits and deterministic ones.
        logit, emb, log_var = rpr_model(inputs, max_t=args.max_t, use_sto=True)
        logit_no_sto, emb, log_var = rpr_model(
            inputs, max_t=args.max_t, use_sto=False)
        # Report the loss matching the training configuration (args.use_sto).
        if args.use_sto:
            loss1, loss2, loss3, loss = criterion(
                logit, emb, log_var, targets, mh_targets, use_sto=args.use_sto)
        else:
            loss1, loss2, loss3, loss = criterion(
                logit_no_sto, emb, log_var, targets, mh_targets, use_sto=args.use_sto)

        total += targets.size(0)

        batch_mae_sto, batch_acc_sto = cal_mae_acc(logit, targets, True)
        batch_mae_no_sto, batch_acc_no_sto = cal_mae_acc(
            logit_no_sto, targets, False)

        loss1es.update(loss1.cpu().data.numpy())
        loss2es.update(loss2.cpu().data.numpy())
        loss3es.update(loss3.cpu().data.numpy())
        losses.update(loss.cpu().data.numpy())

        mae_sto.update(batch_mae_sto)
        mae_no_sto.update(batch_mae_no_sto)
        acc_sto.update(batch_acc_sto)
        acc_no_sto.update(batch_acc_no_sto)

        # Weight batch metrics by batch size for exact dataset averages.
        all_mae_sto = all_mae_sto + batch_mae_sto * targets.size(0)
        all_mae_no_sto = all_mae_no_sto + batch_mae_no_sto * targets.size(0)
        all_acc_sto = all_acc_sto + batch_acc_sto * targets.size(0)
        all_acc_no_sto = all_acc_no_sto + batch_acc_no_sto * targets.size(0)

        batch_time.update(time.time() - end_time)
        end_time = time.time()

        if batch_idx % args.print_freq == 0:
            print('Test: [%d/%d]\t'
                  'Time %.3f (%.3f)\t'
                  'Loss1 %.3f (%.3f)\t'
                  'Loss2 %.3f (%.3f)\t'
                  'Loss3 %.3f (%.3f)\t'
                  'Loss %.3f (%.3f)\t'
                  'MAE_sto %.3f (%.3f)\t'
                  'MAE_no_sto %.3f (%.3f)\t'
                  'ACC_sto %.3f (%.3f)\t'
                  'ACC_no_sto %.3f (%.3f)\t' % (batch_idx, len(testloader),
                                                batch_time.val, batch_time.avg, loss1es.val, loss1es.avg, loss2es.val, loss2es.avg, loss3es.val, loss3es.avg,
                                                losses.val, losses.avg, mae_sto.val, mae_sto.avg, mae_no_sto.val, mae_no_sto.avg, acc_sto.val, acc_sto.avg, acc_no_sto.val, acc_no_sto.avg))

    print('Test: MAE_sto: %.3f MAE_no_sto: %.3f ACC_sto: %.3f ACC_no_sto: %.3f' % (all_mae_sto *
                                                                                   1.0 / total, all_mae_no_sto * 1.0 / total, all_acc_sto * 1.0 / total, all_acc_no_sto * 1.0 / total))

    return all_mae_sto * 1.0 / total, all_mae_no_sto * 1.0 / total, all_acc_sto * 1.0 / total, all_acc_no_sto * 1.0 / total
148 |
149 |
if __name__ == "__main__":
    # Training settings
    # ---------------------------------
    args = get_args()

    # Fix RNG seeds and cuDNN modes for reproducibility.
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # ---------------------------------

    # dataset prepare
    # ---------------------------------
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        # transforms.RandomRotation(10),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    trainset = dataset_manger.dataset_manger(
        images_root=args.train_images_root, data_file=args.train_data_file, transforms=transform_train)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = dataset_manger.dataset_manger(
        images_root=args.test_images_root, data_file=args.test_data_file, transforms=transform_test)

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.num_workers)
    # --------------------------------------

    # define Model (ImageNet-pretrained VGG-16 backbone)
    # ----------------------------------------
    start_epoch = 0
    rpr_model = vgg.vgg16(pretrained=True, num_output_neurons=args.num_output_neurons)
    rpr_model.cuda()
    # rpr_model = torch.nn.DataParallel(rpr_model)
    print("Model finshed")
    # ---------------------------------

    # define Optimizer
    # ----------------------------------------
    # Fully-connected layers get their own learning rate (args.fc_lr);
    # all other parameters use the base rate (args.lr).
    params = []
    for keys, param_value in rpr_model.named_parameters():
        if (is_fc(keys)):
            params += [{'params': [param_value], 'lr':args.fc_lr}]
        else:
            params += [{'params': [param_value], 'lr':args.lr}]

    optimizer = optim.Adam(params, lr=args.lr, betas=(0.9, 0.999), eps=1e-08)

    # Step decay at the configured epochs; the trailing np.inf is a
    # sentinel milestone that is never reached.
    lr_decay = args.lr_decay
    lr_decay_epoch = [int(i)
                      for i in args.lr_decay_epoch.split(',')] + [np.inf]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=lr_decay_epoch, gamma=lr_decay, last_epoch=-1)
    # ---------------------------------

    # define loss
    # ----------------------------------------
    criterion = ProbOrdiLoss(distance=args.distance, alpha_coeff=args.alpha_coeff,
                             beta_coeff=args.beta_coeff, margin=args.margin, main_loss_type=args.main_loss_type)
    criterion.cuda()

    # define Metric
    # ----------------------------------------
    cal_mae_acc = get_metric(args.main_loss_type)
    # ---------------------------------

    args.logdir = os.path.join(args.logdir, args.exp_name)
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
        print('log dir [{}] {}'.format(args.logdir, 'is created!'))

    print("start training...")
    best_mae = np.inf

    for epoch in range(start_epoch, args.max_epochs):
        print('[{}] Epoch: {} start!'.format(get_current_time(), epoch))

        if not args.test_only:
            train_epoch(epoch)
            lr_scheduler.step()
            with torch.no_grad():
                cur_mae_sto, cur_mae_no_sto, cur_acc_sto, cur_acc_no_sto = test_epoch(
                    epoch)
        else:
            with torch.no_grad():
                cur_mae_sto, cur_mae_no_sto, cur_acc_sto, cur_acc_no_sto = test_epoch(
                    epoch)
            break

        print('saving model...')
        # Model selection uses the deterministic (no-sampling) MAE.
        is_best = cur_mae_no_sto < best_mae
        best_mae = min(best_mae, cur_mae_no_sto)
        if epoch % args.save_freq == 0:
            torch.save(
                {
                    'model_state_dict': rpr_model.state_dict(),
                    'optim_state_dict': optimizer.state_dict(),
                    'mae': cur_mae_no_sto,
                    'acc': cur_acc_no_sto,
                    'epoch': epoch + 1
                },
                os.path.join(args.logdir, 'checkpoint.pth')
            )
            print('save checkpoint at {}'.format(
                os.path.join(args.logdir, 'checkpoint.pth')))
        if is_best:
            torch.save(
                {
                    'model_state_dict': rpr_model.state_dict(),
                    'optim_state_dict': optimizer.state_dict(),
                    'mae': cur_mae_no_sto,
                    'acc': cur_acc_no_sto,
                    'epoch': epoch + 1
                },
                os.path.join(args.logdir, 'best.pth')
            )
            print('save best model at {}'.format(
                os.path.join(args.logdir, 'best.pth')))

        print('[{}] Epoch: {} end!'.format(get_current_time(), epoch))
284 |
--------------------------------------------------------------------------------