├── networks ├── util.py ├── base.py └── gol.py ├── data ├── get_datasets_OL_basic.py ├── datasets │ ├── basic.py │ ├── OL_basic_val.py │ ├── OL_basic_train.py │ ├── OL_hard_train.py │ ├── OL_triple_order_train.py │ ├── OL_lossbalancing_train.py │ ├── OL_mining_train.py │ ├── OL_triplet_train.py │ ├── angl_triplet_train.py │ └── OL_angl_train.py ├── get_datasets_NN_test.py ├── get_datasets_OLMining.py ├── get_datasets_tr_Metric_val_NN.py ├── get_datasets_basic.py ├── get_datasets_OL_align.py ├── get_datasets_align_only.py ├── get_datasets_tr_OLhard_val_NN.py ├── get_datasets_tr_OL_lossbalancing.py └── get_datasets_tr_OLbasic_val_NN.py ├── README.md ├── utils ├── scheduler_util.py ├── metric_util.py ├── sampling_utils.py ├── loss_util.py ├── util.py └── comparison_utils.py ├── config └── basic.py └── train └── train.py /networks/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from networks.gol import GOL 4 | 5 | 6 | def prepare_model(opt): 7 | model = eval(opt.model)(opt) 8 | return model 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /data/get_datasets_OL_basic.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from torch.utils.data import DataLoader 3 | 4 | from data.datasets import OL_basic_train, OL_basic_val 5 | 6 | 7 | def get_datasets(cfg): 8 | with open(cfg.train_file, 'rb') as f: 9 | data = pickle.load(f) 10 | tr_imgs = data['data'] 11 | tr_ages = data['age'] 12 | 13 | with open(cfg.test_file, 'rb') as f: 14 | data = pickle.load(f) 15 | te_imgs = data['data'] 16 | te_ages = data['age'] 17 | 18 | loader_dict = dict() 19 | loader_dict['train'] = DataLoader(OL_basic_train.OLBasic_Train(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau), 20 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 21 | loader_dict['val'] = DataLoader(OL_basic_val.OLBasic_Val([te_imgs, te_ages], [tr_imgs, tr_ages], cfg.transform_te, cfg.tau), 22 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 23 | return loader_dict 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GOL 2 | [NeurIPS 2022] Geometric order learning for rank estimation [[paper]](https://openreview.net/pdf?id=agNTJU1QNw) 3 | 4 | [Seon-Ho Lee](https://scholar.google.co.kr/citations?user=_LtQ4TcAAAAJ&hl=en), Nyeong-Ho Shin, and Chang-Su Kim 5 | 6 | --- 7 | ## Dependencies 8 | * Python 3.8 9 | * Pytorch 1.7.1 10 | --- 11 | ## Datasets 12 | - [MORPH II](https://ebill.uncw.edu/C20231_ustores/web/classic/product_detail.jsp?PRODUCTID=8) 13 | * For MORPH II experiments, we follow the same fold settings in this [OL](https://github.com/changsukim-ku/order-learning/tree/master/index) repo. 14 | - [Adience](https://talhassner.github.io/home/projects/Adience/Adience-data.html) 15 | * For Adience experiments, we follow the official splits. 16 | - [CACD] 17 | - [UTK] 18 | --- 19 | ## Usage 20 | ``` 21 | $ python train.py 22 | ``` 23 | * Modify 'cfg.dataset' and 'cfg.setting' for training on other/custom dataset 24 | * You may need to change 'cfg.ref_point_num' and 'cfg.margin' to obtain decent results. 25 | --- 26 | 27 | ## Citation 28 | Please cite our paper if you use this repository. 29 | ``` 30 | @inproceedings{GOL2022lee, 31 | author = {LEE, Seon-Ho and Shin, Nyeong-Ho and Kim, Chang-Su}, 32 | title = {Geometric Order Learning for Rank Estimation}, 33 | booktitle = {Advances in Neural Information Processing Systems}, 34 | year = {2022} 35 | } 36 | ``` 37 | --- 38 | ## License 39 | MIT License 40 | 41 | -------------------------------------------------------------------------------- /data/datasets/basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | class Basic(Dataset): 8 | def __init__(self, imgs, labels, transform, norm_age=True, is_filelist=False, return_ranks=False, std=None): 9 | super(Dataset, self).__init__() 10 | self.transform = transform 11 | self.imgs = imgs 12 | self.labels = labels 13 | 14 | self.n_imgs = len(self.imgs) 15 | self.is_filelist = is_filelist 16 | if norm_age: 17 | self.labels = self.labels - min(self.labels) 18 | self.return_ranks = return_ranks 19 | self.std = std 20 | 21 | rank = 0 22 | self.mapping = dict() 23 | for cls in np.unique(self.labels): 24 | self.mapping[cls] = rank 25 | rank += 1 26 | self.ranks = np.array([self.mapping[l] for l in self.labels]) 27 | 28 | 29 | 30 | def __getitem__(self, item): 31 | if self.is_filelist: 32 | img = np.asarray(load_one_image(self.imgs[item])).astype('uint8') 33 | else: 34 | img = np.asarray(self.imgs[item]).astype('uint8') 35 | img = self.transform(img) 36 | 37 | if self.return_ranks: 38 | return img, self.labels[item], self.ranks[item], item 39 | else: 40 | return img, self.labels[item], item 41 | 42 | def __len__(self): 43 | return len(self.imgs) 44 | -------------------------------------------------------------------------------- /data/get_datasets_NN_test.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import numpy as np 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | from data.datasets import basic 8 | 9 | 10 | def get_datasets_NN_test(cfg): 11 | if cfg.is_filelist: 12 | img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 13 | tr_list = pd.read_csv(cfg.train_file, sep=cfg.delimeter) 14 | tr_list = np.array(tr_list) 15 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, cfg.img_idx]] 16 | tr_ages = tr_list[:, cfg.lb_idx] 17 | 18 | te_list = pd.read_csv(cfg.test_file, sep=cfg.delimeter) 19 | te_list = np.array(te_list) 20 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, cfg.img_idx]] 21 | te_ages = te_list[:, cfg.lb_idx] 22 | 23 | else: 24 | with open(cfg.train_file, 'rb') as f: 25 | data = pickle.load(f) 26 | tr_imgs = data['data'] 27 | tr_ages = data['age'] 28 | 29 | with open(cfg.test_file, 'rb') as f: 30 | data = pickle.load(f) 31 | te_imgs = data['data'] 32 | te_ages = data['age'] 33 | 34 | loader_dict = dict() 35 | loader_dict['train'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_te, cfg.tau, is_filelist=cfg.is_filelist), 36 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 37 | loader_dict['test'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, cfg.tau, is_filelist=cfg.is_filelist), 38 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 39 | return loader_dict 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /data/get_datasets_OLMining.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from data.datasets import OL_mining_train, basic 7 | 8 | 9 | 10 | def get_datasets(cfg): 11 | if cfg.is_filelist: 12 | img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 13 | tr_list = pd.read_csv(cfg.train_file, sep=" ") 14 | tr_list = np.array(tr_list) 15 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, 3]] 16 | tr_ages = tr_list[:, 2] 17 | 18 | te_list = pd.read_csv(cfg.test_file, sep=" ") 19 | te_list = np.array(te_list) 20 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, 3]] 21 | te_ages = te_list[:, 2] 22 | 23 | else: 24 | with open(cfg.train_file, 'rb') as f: 25 | data = pickle.load(f) 26 | tr_imgs = data['data'] 27 | tr_ages = data['age'] 28 | 29 | with open(cfg.test_file, 'rb') as f: 30 | data = pickle.load(f) 31 | te_imgs = data['data'] 32 | te_ages = data['age'] 33 | 34 | loader_dict = dict() 35 | loader_dict['train'] = DataLoader(OL_mining_train.OLMining_Train(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, max_epoch=cfg.epochs), 36 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 37 | loader_dict['train_for_val'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_te, cfg.tau), 38 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, 39 | num_workers=cfg.num_workers) 40 | 41 | loader_dict['val'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, cfg.tau), 42 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 43 | return loader_dict 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /data/get_datasets_tr_Metric_val_NN.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | from data.datasets import basic 8 | 9 | 10 | def get_datasets(cfg): 11 | if cfg.is_filelist: 12 | img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 13 | tr_list = pd.read_csv(cfg.train_file, sep=cfg.delimeter) 14 | tr_list = np.array(tr_list) 15 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, cfg.img_idx]] 16 | tr_ages = tr_list[:, cfg.lb_idx] 17 | 18 | te_list = pd.read_csv(cfg.test_file, sep=cfg.delimeter) 19 | te_list = np.array(te_list) 20 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, cfg.img_idx]] 21 | te_ages = te_list[:, cfg.lb_idx] 22 | 23 | else: 24 | with open(cfg.train_file, 'rb') as f: 25 | data = pickle.load(f) 26 | tr_imgs = data['data'] 27 | tr_ages = data['age'] 28 | 29 | with open(cfg.test_file, 'rb') as f: 30 | data = pickle.load(f) 31 | te_imgs = data['data'] 32 | te_ages = data['age'] 33 | 34 | loader_dict = dict() 35 | loader_dict['train'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, is_filelist=cfg.is_filelist), 36 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 37 | loader_dict['val'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, cfg.tau, is_filelist=cfg.is_filelist), 38 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 39 | loader_dict['train_for_val'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, is_filelist=cfg.is_filelist), 40 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 41 | 42 | return loader_dict 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /data/get_datasets_basic.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | from data.datasets import basic 8 | 9 | 10 | def get_datasets(cfg): 11 | if cfg.is_filelist: 12 | img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 13 | tr_list = pd.read_csv(cfg.train_file, sep=cfg.delimeter) 14 | tr_list = np.array(tr_list) 15 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, cfg.img_idx]] 16 | tr_ages = tr_list[:, cfg.lb_idx] 17 | 18 | te_list = pd.read_csv(cfg.test_file, sep=cfg.delimeter) 19 | te_list = np.array(te_list) 20 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, cfg.img_idx]] 21 | te_ages = te_list[:, cfg.lb_idx] 22 | 23 | else: 24 | with open(cfg.train_file, 'rb') as f: 25 | data = pickle.load(f) 26 | tr_imgs = data['data'] 27 | tr_ages = data['age'] 28 | 29 | with open(cfg.test_file, 'rb') as f: 30 | data = pickle.load(f) 31 | te_imgs = data['data'] 32 | te_ages = data['age'] 33 | 34 | loader_dict = dict() 35 | loader_dict['train'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, is_filelist=cfg.is_filelist, return_ranks=True), 36 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 37 | loader_dict['val'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, cfg.tau, is_filelist=cfg.is_filelist), 38 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 39 | loader_dict['train_for_val'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, is_filelist=cfg.is_filelist), 40 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 41 | 42 | return loader_dict 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /data/get_datasets_OL_align.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from data.datasets import OL_triplet_train, basic 7 | 8 | 9 | 10 | def get_datasets(cfg): 11 | if cfg.is_filelist: 12 | img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 13 | tr_list = pd.read_csv(cfg.train_file, sep=" ") 14 | tr_list = np.array(tr_list) 15 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, 3]] 16 | tr_ages = tr_list[:, 2] 17 | 18 | te_list = pd.read_csv(cfg.test_file, sep=" ") 19 | te_list = np.array(te_list) 20 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, 3]] 21 | te_ages = te_list[:, 2] 22 | 23 | else: 24 | with open(cfg.train_file, 'rb') as f: 25 | data = pickle.load(f) 26 | tr_imgs = data['data'] 27 | tr_ages = data['age'] 28 | 29 | with open(cfg.test_file, 'rb') as f: 30 | data = pickle.load(f) 31 | te_imgs = data['data'] 32 | te_ages = data['age'] 33 | 34 | loader_dict = dict() 35 | loader_dict['train'] = DataLoader(OL_triplet_train.OLTriplet_Train(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, logscale=cfg.logscale, is_filelist=cfg.is_filelist), 36 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 37 | loader_dict['train_for_val'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_te, is_filelist=cfg.is_filelist), 38 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, 39 | num_workers=cfg.num_workers) 40 | 41 | loader_dict['val'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, is_filelist=cfg.is_filelist), 42 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 43 | return loader_dict 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /data/get_datasets_align_only.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from data.datasets import OL_angl_train, basic 7 | 8 | 9 | 10 | def get_datasets(cfg): 11 | if cfg.is_filelist: 12 | img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 13 | tr_list = pd.read_csv(cfg.train_file, sep=cfg.delimeter) 14 | tr_list = np.array(tr_list) 15 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, cfg.img_idx]] 16 | tr_ages = tr_list[:, cfg.lb_idx] 17 | 18 | te_list = pd.read_csv(cfg.test_file, sep=cfg.delimeter) 19 | te_list = np.array(te_list) 20 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, cfg.img_idx]] 21 | te_ages = te_list[:, cfg.lb_idx] 22 | 23 | else: 24 | with open(cfg.train_file, 'rb') as f: 25 | data = pickle.load(f) 26 | tr_imgs = data['data'] 27 | tr_ages = data['age'] 28 | 29 | with open(cfg.test_file, 'rb') as f: 30 | data = pickle.load(f) 31 | te_imgs = data['data'] 32 | te_ages = data['age'] 33 | 34 | loader_dict = dict() 35 | loader_dict['train'] = DataLoader(OL_angl_train.OLAngle_Train(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, logscale=cfg.logscale, is_filelist=cfg.is_filelist), 36 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 37 | loader_dict['train_for_val'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_te, is_filelist=cfg.is_filelist), 38 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, 39 | num_workers=cfg.num_workers) 40 | 41 | loader_dict['val'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, is_filelist=cfg.is_filelist), 42 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 43 | return loader_dict 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /data/get_datasets_tr_OLhard_val_NN.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from data.datasets import OL_hard_train, basic 7 | 8 | 9 | 10 | def get_datasets(cfg): 11 | if cfg.is_filelist: 12 | img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 13 | tr_list = pd.read_csv(cfg.train_file, sep=cfg.delimeter) 14 | tr_list = np.array(tr_list) 15 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, cfg.img_idx]] 16 | tr_ages = tr_list[:, cfg.lb_idx] 17 | 18 | te_list = pd.read_csv(cfg.test_file, sep=cfg.delimeter) 19 | te_list = np.array(te_list) 20 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, cfg.img_idx]] 21 | te_ages = te_list[:, cfg.lb_idx] 22 | 23 | else: 24 | with open(cfg.train_file, 'rb') as f: 25 | data = pickle.load(f) 26 | tr_imgs = data['data'] 27 | tr_ages = data['age'] 28 | 29 | with open(cfg.test_file, 'rb') as f: 30 | data = pickle.load(f) 31 | te_imgs = data['data'] 32 | te_ages = data['age'] 33 | 34 | loader_dict = dict() 35 | loader_dict['train'] = DataLoader(OL_hard_train.OLHard_Train(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, logscale=cfg.logscale, is_filelist=cfg.is_filelist), 36 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 37 | loader_dict['train_for_val'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_te, is_filelist=cfg.is_filelist), 38 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, 39 | num_workers=cfg.num_workers) 40 | 41 | loader_dict['val'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, is_filelist=cfg.is_filelist), 42 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 43 | return loader_dict 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /data/get_datasets_tr_OL_lossbalancing.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from data.datasets import OL_lossbalancing_train, basic 7 | 8 | 9 | 10 | def get_datasets(cfg): 11 | if cfg.is_filelist: 12 | img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 13 | tr_list = pd.read_csv(cfg.train_file, sep=cfg.delimeter) 14 | tr_list = np.array(tr_list) 15 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, cfg.img_idx]] 16 | tr_ages = tr_list[:, cfg.lb_idx] 17 | 18 | te_list = pd.read_csv(cfg.test_file, sep=cfg.delimeter) 19 | te_list = np.array(te_list) 20 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, cfg.img_idx]] 21 | te_ages = te_list[:, cfg.lb_idx] 22 | 23 | else: 24 | with open(cfg.train_file, 'rb') as f: 25 | data = pickle.load(f) 26 | tr_imgs = data['data'] 27 | tr_ages = data['age'] 28 | 29 | with open(cfg.test_file, 'rb') as f: 30 | data = pickle.load(f) 31 | te_imgs = data['data'] 32 | te_ages = data['age'] 33 | 34 | loader_dict = dict() 35 | loader_dict['train'] = DataLoader(OL_lossbalancing_train.OLLossBalancing_Train(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, logscale=cfg.logscale, is_filelist=cfg.is_filelist), 36 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 37 | loader_dict['train_for_val'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_te, is_filelist=cfg.is_filelist), 38 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, 39 | num_workers=cfg.num_workers) 40 | 41 | loader_dict['val'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, is_filelist=cfg.is_filelist), 42 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 43 | return loader_dict 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /data/datasets/OL_basic_val.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | class OLBasic_Val(Dataset): 8 | def __init__(self, base_data, ref_data, transform, tau, norm_age=True, is_filelist=False): 9 | super(Dataset, self).__init__() 10 | self.transform = transform 11 | self.ref_imgs, self.ref_labels = ref_data 12 | self.base_imgs, self.base_labels = base_data 13 | 14 | self.n_base_imgs = len(self.base_imgs) 15 | self.n_ref_imgs = len(self.ref_imgs) 16 | self.tau = tau 17 | self.is_filelist = is_filelist 18 | 19 | if norm_age: 20 | self.base_labels = self.base_labels - min(self.base_labels) 21 | self.ref_labels = self.ref_labels - min(self.ref_labels) 22 | 23 | def __getitem__(self, item): 24 | ref_idx = np.random.choice(self.n_ref_imgs, 1)[0] 25 | if self.is_filelist: 26 | base_img = np.asarray(load_one_image(self.base_imgs[item])).astype('uint8') 27 | ref_img = np.asarray(load_one_image(self.ref_imgs[ref_idx])).astype('uint8') 28 | else: 29 | base_img = np.asarray(self.base_imgs[item]).astype('uint8') 30 | ref_img = np.asarray(self.ref_imgs[ref_idx]).astype('uint8') 31 | 32 | base_img = self.transform(base_img) 33 | ref_img = self.transform(ref_img) 34 | 35 | # order label generation 36 | order_labels = self.get_order_labels(item, ref_idx) 37 | 38 | # gt ages 39 | # base_age = self.base_labels[item] 40 | # ref_age = self.ref_labels[ref_idx] 41 | return base_img, ref_img, order_labels, item 42 | 43 | def __len__(self): 44 | return len(self.base_imgs) 45 | 46 | def get_order_labels(self, base_idx, ref_idx): 47 | base_ranks = self.base_labels[base_idx] 48 | ref_ranks = self.ref_labels[ref_idx] 49 | 50 | if base_ranks > ref_ranks + self.tau: 51 | order_labels = 0 52 | elif base_ranks < ref_ranks - self.tau: 53 | order_labels = 1 54 | else: 55 | order_labels = 2 56 | return order_labels -------------------------------------------------------------------------------- /utils/scheduler_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class CosineAnnealingWarmUpRestarts(_LRScheduler): 6 | def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1): 7 | if T_0 <= 0 or not isinstance(T_0, int): 8 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 9 | if T_mult < 1 or not isinstance(T_mult, int): 10 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) 11 | if T_up < 0 or not isinstance(T_up, int): 12 | raise ValueError("Expected positive integer T_up, but got {}".format(T_up)) 13 | self.T_0 = T_0 14 | self.T_mult = T_mult 15 | self.base_eta_max = eta_max 16 | self.eta_max = eta_max 17 | self.T_up = T_up 18 | self.T_i = T_0 19 | self.gamma = gamma 20 | self.cycle = 0 21 | self.T_cur = last_epoch 22 | super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch) 23 | 24 | def get_lr(self): 25 | if self.T_cur == -1: 26 | return self.base_lrs 27 | elif self.T_cur < self.T_up: 28 | return [(self.eta_max - base_lr) * self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs] 29 | else: 30 | return [base_lr + (self.eta_max - base_lr) * ( 31 | 1 + math.cos(math.pi * (self.T_cur - self.T_up) / (self.T_i - self.T_up))) / 2 32 | for base_lr in self.base_lrs] 33 | 34 | def step(self, epoch=None): 35 | if epoch is None: 36 | epoch = self.last_epoch + 1 37 | self.T_cur = self.T_cur + 1 38 | if self.T_cur >= self.T_i: 39 | self.cycle += 1 40 | self.T_cur = self.T_cur - self.T_i 41 | self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up 42 | else: 43 | if epoch >= self.T_0: 44 | if self.T_mult == 1: 45 | self.T_cur = epoch % self.T_0 46 | self.cycle = epoch // self.T_0 47 | else: 48 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 49 | self.cycle = n 50 | self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) 51 | self.T_i = self.T_0 * self.T_mult ** (n) 52 | else: 53 | self.T_i = self.T_0 54 | self.T_cur = epoch 55 | 56 | self.eta_max = self.base_eta_max * (self.gamma ** self.cycle) 57 | self.last_epoch = math.floor(epoch) 58 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 59 | param_group['lr'] = lr -------------------------------------------------------------------------------- /networks/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torchvision.models as models 5 | 6 | 7 | class BaseModel(nn.Module): 8 | def __init__(self, cfg): 9 | super().__init__() 10 | if cfg.backbone == 'resnet18': 11 | backbone = models.resnet18(pretrained=True) 12 | backbone.fc = nn.Identity() 13 | self.encoder = backbone 14 | 15 | elif cfg.backbone == 'vgg16': 16 | backbone = models.vgg16_bn(pretrained=True) 17 | backbone.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 18 | backbone.classifier = nn.Identity() 19 | self.encoder = backbone 20 | 21 | elif cfg.backbone == 'vgg16v2': # no bn, relu, maxpool after last convolution 22 | backbone = models.vgg16_bn(pretrained=True) 23 | backbone.features[41] = nn.Identity() 24 | backbone.features[42] = nn.Identity() 25 | backbone.features[43] = nn.Identity() 26 | backbone.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 27 | backbone.classifier = nn.Identity() 28 | self.encoder = backbone 29 | 30 | elif cfg.backbone == 'vgg16v2norm': # no bn, relu, maxpool after last convolution 31 | backbone = models.vgg16_bn(pretrained=True) 32 | backbone.features[41] = nn.Identity() 33 | backbone.features[42] = nn.Identity() 34 | backbone.features[43] = nn.Identity() 35 | backbone.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 36 | class Normalization(torch.nn.Module): 37 | def __init__(self, dim=-1): 38 | super().__init__() 39 | self.dim = dim 40 | def forward(self, x): 41 | return nn.functional.normalize(x, dim=self.dim) 42 | backbone.classifier = Normalization() 43 | self.encoder = backbone 44 | 45 | elif cfg.backbone == 'vgg16fc': 46 | backbone = models.vgg16_bn(pretrained=True) 47 | backbone.classifier[5] = nn.Identity() 48 | backbone.classifier[6] = nn.Identity() 49 | self.encoder = backbone 50 | else: 51 | raise ValueError(f'Not supported backbone architecture {cfg.backbone}') 52 | 53 | def forward(self, x_base, x_ref=None): 54 | # feature extraction 55 | base_embs = self.encoder(x_base) 56 | if x_ref is not None: 57 | ref_embs = self.encoder(x_ref) 58 | out = self._forward(base_embs, ref_embs) 59 | return out, base_embs, ref_embs 60 | else: 61 | out = self._forward(base_embs) 62 | return out 63 | 64 | def _forward(self, base_embs, ref_embs=None): 65 | raise NotImplementedError('Suppose to be implemented by subclass') -------------------------------------------------------------------------------- /data/get_datasets_tr_OLbasic_val_NN.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import numpy as np 4 | from torch.utils.data import DataLoader 5 | 6 | from data.datasets import OL_basic_train, basic 7 | 8 | 9 | 10 | def get_datasets(cfg): 11 | tr_std = None 12 | te_std = None 13 | if cfg.dataset =='morph': 14 | img_root = cfg.img_root 15 | tr_list = pd.read_csv(cfg.train_file, sep=cfg.delimeter) 16 | tr_list = np.array(tr_list) 17 | tr_imgs = [f'{img_root}/{i_path}' for i_path in tr_list[:, cfg.img_idx]] 18 | tr_ages = tr_list[:, cfg.lb_idx] 19 | 20 | te_list = pd.read_csv(cfg.test_file, sep=cfg.delimeter) 21 | te_list = np.array(te_list) 22 | te_imgs = [f'{img_root}/{i_path}' for i_path in te_list[:, cfg.img_idx]] 23 | te_ages = te_list[:, cfg.lb_idx] 24 | 25 | elif cfg.dataset =='clap': 26 | img_root = cfg.img_root 27 | tr_list = pd.read_csv(cfg.train_file, sep=cfg.delimeter) 28 | tr_list = np.array(tr_list) 29 | tr_ages = tr_list[:, cfg.lb_idx] 30 | tr_imgs = [f'{img_root}/{tr_list[i, 3]}/{tr_list[i, cfg.img_idx]}' for i in range(len(tr_list))] 31 | tr_std = tr_list[:, 2] 32 | # 33 | # # debug for n_ranks and margin relation 34 | # idx = np.argwhere(tr_ages < 60).flatten() 35 | # tr_ages = tr_ages[idx] 36 | # tr_imgs = np.array(tr_imgs)[idx] 37 | # tr_std = tr_std[idx] 38 | 39 | te_list = pd.read_csv(cfg.test_file, sep=cfg.delimeter) 40 | te_list = np.array(te_list) 41 | te_imgs = [f'{img_root}/{te_list[i, 3]}/{te_list[i, cfg.img_idx]}' for i in range(len(te_list))] 42 | te_ages = te_list[:, cfg.lb_idx] 43 | te_std = te_list[:, 2] 44 | # 45 | # # debug for n_ranks and margin relation 46 | # idx = np.argwhere(te_ages < 60).flatten() 47 | # te_ages = te_ages[idx] 48 | # te_imgs = np.array(te_imgs)[idx] 49 | # te_std = te_std[idx] 50 | 51 | else: 52 | with open(cfg.train_file, 'rb') as f: 53 | data = pickle.load(f) 54 | tr_imgs = data['data'] 55 | tr_ages = data['age'] 56 | 57 | with open(cfg.test_file, 'rb') as f: 58 | data = pickle.load(f) 59 | te_imgs = data['data'] 60 | te_ages = data['age'] 61 | 62 | loader_dict = dict() 63 | loader_dict['train'] = DataLoader(OL_basic_train.OLBasic_Train(tr_imgs, tr_ages, cfg.transform_tr, cfg.tau, logscale=cfg.logscale, is_filelist=cfg.is_filelist), 64 | batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.num_workers) 65 | loader_dict['train_for_val'] = DataLoader(basic.Basic(tr_imgs, tr_ages, cfg.transform_te, is_filelist=cfg.is_filelist, norm_age=False), 66 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, 67 | num_workers=cfg.num_workers) 68 | 69 | loader_dict['val'] = DataLoader(basic.Basic(te_imgs, te_ages, cfg.transform_te, is_filelist=cfg.is_filelist, std=te_std, norm_age=False), 70 | batch_size=cfg.batch_size, shuffle=False, drop_last=False, num_workers=cfg.num_workers) 71 | return loader_dict 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /utils/metric_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.util import to_device 4 | 5 | 6 | def get_matches_and_diffs(labels, ref_labels=None): 7 | if ref_labels is None: 8 | ref_labels = labels 9 | labels1 = labels.unsqueeze(1) 10 | labels2 = ref_labels.unsqueeze(0) 11 | matches = (labels1 == labels2).byte() 12 | diffs = matches ^ 1 13 | if ref_labels is labels: 14 | matches.fill_diagonal_(0) 15 | return matches, diffs 16 | 17 | 18 | def get_all_pairs_indices(labels, ref_labels=None): 19 | """ 20 | Given a tensor of labels, this will return 4 tensors. 21 | The first 2 tensors are the indices which form all positive pairs 22 | The second 2 tensors are the indices which form all negative pairs 23 | """ 24 | matches, diffs = get_matches_and_diffs(labels, ref_labels) 25 | a1_idx, p_idx = torch.where(matches) 26 | a2_idx, n_idx = torch.where(diffs) 27 | return a1_idx, p_idx, a2_idx, n_idx 28 | 29 | 30 | def get_all_triplets_indices(labels, ref_labels=None): 31 | matches, diffs = get_matches_and_diffs(labels, ref_labels) 32 | triplets = matches.unsqueeze(2) * diffs.unsqueeze(1) 33 | return torch.where(triplets) 34 | 35 | 36 | # sample triplets, with a weighted distribution if weights is specified. 37 | def get_random_triplet_indices( 38 | labels, ref_labels=None, t_per_anchor=None, weights=None 39 | ): 40 | a_idx, p_idx, n_idx = [], [], [] 41 | labels_device = labels.device 42 | ref_labels = labels if ref_labels is None else ref_labels 43 | unique_labels = torch.unique(labels) 44 | for label in unique_labels: 45 | # Get indices of positive samples for this label. 46 | p_inds = torch.where(ref_labels == label)[0] 47 | if ref_labels is labels: 48 | a_inds = p_inds 49 | else: 50 | a_inds = torch.where(labels == label)[0] 51 | n_inds = torch.where(ref_labels != label)[0] 52 | n_a = len(a_inds) 53 | n_p = len(p_inds) 54 | min_required_p = 2 if ref_labels is labels else 1 55 | if (n_p < min_required_p) or (len(n_inds) < 1): 56 | continue 57 | 58 | k = n_p if t_per_anchor is None else t_per_anchor 59 | num_triplets = n_a * k 60 | p_inds_ = p_inds.expand((n_a, n_p)) 61 | # Remove anchors from list of possible positive samples. 62 | if ref_labels is labels: 63 | p_inds_ = p_inds_[~torch.eye(n_a).bool()].view((n_a, n_a - 1)) 64 | # Get indices of indices of k random positive samples for each anchor. 65 | p_ = torch.randint(0, p_inds_.shape[1], (num_triplets,)) 66 | # Get indices of indices of corresponding anchors. 67 | a_ = torch.arange(n_a).view(-1, 1).repeat(1, k).view(num_triplets) 68 | p = p_inds_[a_, p_] 69 | a = a_inds[a_] 70 | 71 | # Get indices of negative samples for this label. 72 | if weights is not None: 73 | w = weights[:, n_inds][a] 74 | non_zero_rows = torch.where(torch.sum(w, dim=1) > 0)[0] 75 | if len(non_zero_rows) == 0: 76 | continue 77 | w = w[non_zero_rows] 78 | a = a[non_zero_rows] 79 | p = p[non_zero_rows] 80 | # Sample the negative indices according to the weights. 81 | if w.dtype == torch.float16: 82 | # special case needed due to pytorch cuda bug 83 | # https://github.com/pytorch/pytorch/issues/19900 84 | w = w.type(torch.float32) 85 | n_ = torch.multinomial(w, 1, replacement=True).flatten() 86 | else: 87 | # Sample the negative indices uniformly. 88 | n_ = torch.randint(0, len(n_inds), (num_triplets,)) 89 | n = n_inds[n_] 90 | a_idx.append(a) 91 | p_idx.append(p) 92 | n_idx.append(n) 93 | 94 | if len(a_idx) > 0: 95 | a_idx = to_device(torch.cat(a_idx), device=labels_device, dtype=torch.long) 96 | p_idx = to_device(torch.cat(p_idx), device=labels_device, dtype=torch.long) 97 | n_idx = to_device(torch.cat(n_idx), device=labels_device, dtype=torch.long) 98 | assert len(a_idx) == len(p_idx) == len(n_idx) 99 | return a_idx, p_idx, n_idx 100 | else: 101 | empty = torch.tensor([], device=labels_device, dtype=torch.long) 102 | return empty.clone(), empty.clone(), empty.clone() 103 | -------------------------------------------------------------------------------- /data/datasets/OL_basic_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | 8 | class OLBasic_Train(Dataset): 9 | def __init__(self, imgs, labels, transform, tau, norm_age=True, logscale=False, is_filelist=False): 10 | super(Dataset, self).__init__() 11 | self.imgs = imgs 12 | self.labels = labels 13 | self.transform = transform 14 | self.n_imgs = len(self.imgs) 15 | self.min_age_bf_norm = self.labels.min() 16 | if logscale: 17 | self.labels = np.log(labels.astype(np.float32)) 18 | else: 19 | if norm_age: 20 | self.labels = self.labels - min(self.labels) 21 | 22 | self.max_age = self.labels.max() 23 | self.min_age = self.labels.min() 24 | self.tau = tau 25 | self.is_filelist = is_filelist 26 | 27 | # mapping age to rank : because there are omitted ages 28 | rank = 0 29 | self.mapping = dict() 30 | for cls in np.unique(self.labels): 31 | self.mapping[cls] = rank 32 | rank += 1 33 | self.ranks = np.array([self.mapping[l] for l in self.labels]) 34 | 35 | def __getitem__(self, item): 36 | order_label, ref_idx = self.find_reference(self.labels[item], self.labels, min_rank=self.min_age, 37 | max_rank=self.max_age) 38 | if self.is_filelist: 39 | base_img = np.asarray(load_one_image(self.imgs[item])).astype('uint8') 40 | ref_img = np.asarray(load_one_image(self.imgs[ref_idx])).astype('uint8') 41 | else: 42 | base_img = np.asarray(self.imgs[item]).astype('uint8') 43 | ref_img = np.asarray(self.imgs[ref_idx]).astype('uint8') 44 | base_img = self.transform(base_img) 45 | ref_img = self.transform(ref_img) 46 | 47 | base_age = self.labels[item] 48 | ref_age = self.labels[ref_idx] 49 | 50 | # gt ranks 51 | base_rank = self.ranks[item] 52 | ref_rank = self.ranks[ref_idx] 53 | 54 | return base_img, ref_img, order_label, [base_rank, ref_rank], item 55 | 56 | def __len__(self): 57 | return self.n_imgs 58 | 59 | def find_reference(self, base_rank, ref_ranks, min_rank=0, max_rank=32, epsilon=1e-4): 60 | 61 | def get_indices_in_range(search_range, ages): 62 | """find indices of values within range[0] <= x <= range[1]""" 63 | return np.argwhere(np.logical_and(search_range[0] <= ages, ages <= search_range[1])) 64 | 65 | rng = np.random.default_rng() 66 | order = np.random.randint(0, 3) 67 | ref_idx = -1 68 | debug_flag = 0 69 | while ref_idx == -1: 70 | if debug_flag == 3: 71 | raise ValueError(f'Failed to find reference... base_score: {base_rank}') 72 | if order == 0: # base_rank > ref_rank + tau 73 | ref_range_min = min_rank 74 | ref_range_max = base_rank - self.tau - epsilon 75 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 76 | if len(candidates) > 0: 77 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 78 | else: 79 | order = (order + 1) % 3 80 | debug_flag += 1 81 | continue 82 | elif order == 1: # base_rank < ref_rank - tau 83 | ref_range_min = base_rank + self.tau + epsilon 84 | ref_range_max = max_rank 85 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 86 | if len(candidates) > 0: 87 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 88 | else: 89 | order = (order + 1) % 3 90 | debug_flag += 1 91 | continue 92 | 93 | else: # |base_rank - ref_rank| <= tau 94 | ref_range_min = base_rank - self.tau - epsilon 95 | ref_range_max = base_rank + self.tau + epsilon 96 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 97 | if len(candidates) > 0: 98 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 99 | else: 100 | order = (order + 1) % 3 101 | debug_flag += 1 102 | return order, ref_idx -------------------------------------------------------------------------------- /data/datasets/OL_hard_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | 8 | class OLHard_Train(Dataset): 9 | def __init__(self, imgs, labels, transform, tau, norm_age=True, logscale=False, is_filelist=False): 10 | super(Dataset, self).__init__() 11 | self.imgs = imgs 12 | self.labels = labels 13 | self.transform = transform 14 | self.n_imgs = len(self.imgs) 15 | 16 | if logscale: 17 | self.labels = np.log(labels.astype(np.float32)) 18 | else: 19 | if norm_age: 20 | self.labels = self.labels - min(self.labels) 21 | 22 | self.max_age = self.labels.max() 23 | self.min_age = self.labels.min() 24 | self.tau = tau 25 | self.is_filelist = is_filelist 26 | 27 | # mapping age to rank : because there are omitted ages 28 | rank = 0 29 | self.mapping = dict() 30 | for cls in np.unique(self.labels): 31 | self.mapping[cls] = rank 32 | rank += 1 33 | self.ranks = np.array([self.mapping[l] for l in self.labels]) 34 | self.max_rank = self.ranks.max() 35 | self.min_rank = self.ranks.min() 36 | 37 | def __getitem__(self, item): 38 | order_label, ref_idx = self.find_reference(self.ranks[item], self.ranks, min_rank=self.min_rank, 39 | max_rank=self.max_rank) 40 | if self.is_filelist: 41 | base_img = np.asarray(load_one_image(self.imgs[item])).astype('uint8') 42 | ref_img = np.asarray(load_one_image(self.imgs[ref_idx])).astype('uint8') 43 | else: 44 | base_img = np.asarray(self.imgs[item]).astype('uint8') 45 | ref_img = np.asarray(self.imgs[ref_idx]).astype('uint8') 46 | base_img = self.transform(base_img) 47 | ref_img = self.transform(ref_img) 48 | 49 | base_age = self.labels[item] 50 | ref_age = self.labels[ref_idx] 51 | 52 | # gt ranks 53 | base_rank = self.ranks[item] 54 | ref_rank = self.ranks[ref_idx] 55 | 56 | return base_img, ref_img, order_label, [base_rank, ref_rank], item 57 | 58 | def __len__(self): 59 | return self.n_imgs 60 | 61 | def find_reference(self, base_rank, ref_ranks, min_rank=0, max_rank=32, epsilon=1e-4): 62 | 63 | def get_indices_in_range(search_range, ages): 64 | """find indices of values within range[0] <= x <= range[1]""" 65 | return np.argwhere(np.logical_and(search_range[0] <= ages, ages <= search_range[1])) 66 | 67 | rng = np.random.default_rng() 68 | order = np.random.randint(0, 3) 69 | ref_idx = -1 70 | debug_flag = 0 71 | while ref_idx == -1: 72 | if debug_flag == 3: 73 | raise ValueError(f'Failed to find reference... base_score: {base_rank}') 74 | if order == 0: # base_rank > ref_rank + tau 75 | ref_range_min = max(min_rank, base_rank - 2*self.tau - epsilon) 76 | ref_range_max = base_rank - self.tau - epsilon 77 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 78 | if len(candidates) > 0: 79 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 80 | else: 81 | order = (order + 1) % 3 82 | debug_flag += 1 83 | continue 84 | elif order == 1: # base_rank < ref_rank - tau 85 | ref_range_min = base_rank + self.tau + epsilon 86 | ref_range_max = min(max_rank, base_rank + 2*self.tau + epsilon) 87 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 88 | if len(candidates) > 0: 89 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 90 | else: 91 | order = (order + 1) % 3 92 | debug_flag += 1 93 | continue 94 | 95 | else: # |base_rank - ref_rank| <= tau 96 | ref_range_min = base_rank - self.tau - epsilon 97 | ref_range_max = base_rank + self.tau + epsilon 98 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 99 | if len(candidates) > 0: 100 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 101 | else: 102 | order = (order + 1) % 3 103 | debug_flag += 1 104 | return order, ref_idx -------------------------------------------------------------------------------- /networks/gol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from networks.base import BaseModel 5 | 6 | class GOL(BaseModel): 7 | def __init__(self, cfg): 8 | super().__init__(cfg) 9 | if cfg.backbone == 'resnet12': 10 | hdim = 640 11 | elif cfg.backbone == 'resnet18': 12 | hdim = 640 13 | elif cfg.backbone == 'vgg16': 14 | hdim = 512 15 | elif cfg.backbone == 'vgg16v2': 16 | hdim = 512 17 | elif cfg.backbone == 'vgg16v2norm': 18 | hdim = 512 19 | elif cfg.backbone == 'vgg16fc': 20 | hdim = 4096 21 | else: 22 | raise ValueError('no backbone was found.') 23 | 24 | if cfg.ref_mode =='fix': 25 | if cfg.ref_point_num == 2: 26 | self.ref_points = torch.randn([hdim]) 27 | self.ref_points = torch.stack([self.ref_points, -self.ref_points]) 28 | self.ref_points = nn.functional.normalize(self.ref_points, dim=-1) 29 | 30 | elif cfg.ref_point_num == 3: 31 | self.ref_points = nn.functional.normalize(torch.randn([hdim]), dim=0) 32 | noise = (1e-4)*torch.randn(hdim) 33 | max_point = nn.functional.normalize(-self.ref_points + noise, dim=0) 34 | mid_point = nn.functional.normalize(self.ref_points + max_point, dim=0) 35 | self.ref_points = torch.stack([self.ref_points, mid_point, -self.ref_points]) 36 | 37 | elif cfg.ref_point_num == 5: 38 | self.ref_points = nn.functional.normalize(torch.randn([hdim]), dim=0) 39 | noise = (1e-4)*torch.randn(hdim) 40 | max_point = nn.functional.normalize(-self.ref_points + noise, dim=0) 41 | r2_point = nn.functional.normalize(self.ref_points + max_point, dim=0) 42 | r1_point = nn.functional.normalize(self.ref_points + r2_point, dim=0) 43 | r3_point = nn.functional.normalize(r2_point-self.ref_points, dim=0) 44 | self.ref_points = torch.stack([self.ref_points, r1_point, r2_point, r3_point, -self.ref_points]) 45 | print(torch.sum(self.ref_points[0] * self.ref_points[1])) 46 | print(torch.sum(self.ref_points[1] * self.ref_points[2])) 47 | print(torch.sum(self.ref_points[2] * self.ref_points[3])) 48 | print(torch.sum(self.ref_points[3] * self.ref_points[4])) 49 | else: 50 | self.ref_points = torch.randn([cfg.ref_point_num, hdim]) 51 | self.ref_points = nn.functional.normalize(self.ref_points, dim=-1) 52 | 53 | 54 | self.ref_points = nn.parameter.Parameter(self.ref_points) 55 | self.ref_points.requires_grad = False 56 | 57 | elif cfg.ref_mode == 'flex_reference': 58 | self.ref_points = torch.randn([cfg.n_ranks, hdim]) 59 | self.ref_points = nn.parameter.Parameter(self.ref_points) 60 | 61 | else: 62 | if cfg.ref_point_num == 2: 63 | self.ref_points = torch.randn([hdim]) 64 | self.ref_points = torch.stack([self.ref_points, -self.ref_points]) 65 | self.ref_points = nn.functional.normalize(self.ref_points, dim=-1) 66 | 67 | elif cfg.ref_point_num == 3: 68 | self.ref_points = nn.functional.normalize(torch.randn([hdim]), dim=0) 69 | noise = (1e-3)*torch.randn(hdim) 70 | max_point = nn.functional.normalize(-self.ref_points + noise, dim=0) 71 | mid_point = nn.functional.normalize(self.ref_points + max_point, dim=0) 72 | self.ref_points = torch.stack([self.ref_points, mid_point, -self.ref_points]) 73 | 74 | elif cfg.ref_point_num == 5: 75 | self.ref_points = nn.functional.normalize(torch.randn([hdim]), dim=0) 76 | noise = (1e-4)*torch.randn(hdim) 77 | max_point = nn.functional.normalize(-self.ref_points + noise, dim=0) 78 | r2_point = nn.functional.normalize(self.ref_points + max_point, dim=0) 79 | r1_point = nn.functional.normalize(self.ref_points + r2_point, dim=0) 80 | r3_point = nn.functional.normalize(r2_point-self.ref_points, dim=0) 81 | self.ref_points = torch.stack([self.ref_points, r1_point, r2_point, r3_point, -self.ref_points]) 82 | print(torch.sum(self.ref_points[0] * self.ref_points[1])) 83 | print(torch.sum(self.ref_points[1] * self.ref_points[2])) 84 | print(torch.sum(self.ref_points[2] * self.ref_points[3])) 85 | print(torch.sum(self.ref_points[3] * self.ref_points[4])) 86 | 87 | else: 88 | self.ref_points = torch.randn([cfg.ref_point_num, hdim]) 89 | if cfg.start_norm: 90 | self.ref_points = nn.functional.normalize(self.ref_points, dim=-1) 91 | 92 | self.ref_points = nn.parameter.Parameter(self.ref_points) 93 | 94 | def _forward(self, base_embs, ref_embs=None): 95 | return base_embs 96 | -------------------------------------------------------------------------------- /data/datasets/OL_triple_order_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | 8 | class OLNonary_Train(Dataset): 9 | def __init__(self, imgs, labels, transform, tau, norm_age=True, logscale=False, is_filelist=False): 10 | super(Dataset, self).__init__() 11 | self.imgs = imgs 12 | self.labels = labels 13 | self.transform = transform 14 | self.n_imgs = len(self.imgs) 15 | 16 | if logscale: 17 | self.labels = np.log(labels.astype(np.float32)) 18 | else: 19 | if norm_age: 20 | self.labels = self.labels - min(self.labels) 21 | 22 | self.max_age = self.labels.max() 23 | self.min_age = self.labels.min() 24 | self.tau = tau 25 | self.is_filelist = is_filelist 26 | 27 | # mapping age to rank : because there are omitted ages 28 | rank = 0 29 | self.mapping = dict() 30 | for cls in np.unique(self.labels): 31 | self.mapping[cls] = rank 32 | rank += 1 33 | self.ranks = np.array([self.mapping[l] for l in self.labels]) 34 | 35 | self.hardness_multiplier = 5 # thus, interval is 4 (=5-1) tau 36 | 37 | def __getitem__(self, item): 38 | order1, ref_idx1 = self.find_reference(self.labels[item], self.labels, min_rank=self.min_age, 39 | max_rank=self.max_age) 40 | order2, ref_idx2 = self.find_reference(self.labels[item], self.labels, min_rank=self.min_age, 41 | max_rank=self.max_age) 42 | order_label = 3*order1 + order2 43 | 44 | if self.is_filelist: 45 | base_img = np.asarray(load_one_image(self.imgs[item])).astype('uint8') 46 | ref_img1 = np.asarray(load_one_image(self.imgs[ref_idx1])).astype('uint8') 47 | ref_img2 = np.asarray(load_one_image(self.imgs[ref_idx2])).astype('uint8') 48 | else: 49 | base_img = np.asarray(self.imgs[item]).astype('uint8') 50 | ref_img1 = np.asarray(self.imgs[ref_idx1]).astype('uint8') 51 | ref_img2 = np.asarray(self.imgs[ref_idx2]).astype('uint8') 52 | 53 | base_img = self.transform(base_img) 54 | ref_img1 = self.transform(ref_img1) 55 | ref_img2 = self.transform(ref_img2) 56 | 57 | # base_age = self.labels[item] 58 | # ref_age1 = self.labels[ref_idx1] 59 | # ref_age2 = self.labels[ref_idx2] 60 | # 61 | # # gt ranks 62 | # base_rank = self.ranks[item] 63 | # ref_rank1 = self.ranks[ref_idx1] 64 | # ref_rank2 = self.ranks[ref_idx2] 65 | 66 | return base_img, ref_img1, ref_img2, order_label, item 67 | 68 | def __len__(self): 69 | return self.n_imgs 70 | 71 | def find_reference(self, base_rank, ref_ranks, min_rank=0, max_rank=32, epsilon=1e-4): 72 | 73 | def get_indices_in_range(search_range, ages): 74 | """find indices of values within range[0] <= x <= range[1]""" 75 | return np.argwhere(np.logical_and(search_range[0] <= ages, ages <= search_range[1])) 76 | 77 | rng = np.random.default_rng() 78 | order = np.random.randint(0, 3) 79 | ref_idx = -1 80 | debug_flag = 0 81 | while ref_idx == -1: 82 | if debug_flag == 3: 83 | raise ValueError(f'Failed to find reference... base_score: {base_rank}') 84 | if order == 0: # base_rank > ref_rank + tau 85 | ref_range_min = min_rank 86 | ref_range_max = base_rank - self.tau - epsilon 87 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 88 | if len(candidates) > 0: 89 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 90 | else: 91 | order = (order + 1) % 3 92 | debug_flag += 1 93 | continue 94 | elif order == 1: # |base_rank - ref_rank| <= tau 95 | ref_range_min = base_rank - self.tau - epsilon 96 | ref_range_max = base_rank + self.tau + epsilon 97 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 98 | if len(candidates) > 0: 99 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 100 | else: 101 | order = (order + 1) % 3 102 | debug_flag += 1 103 | 104 | else: # base_rank < ref_rank - tau 105 | ref_range_min = base_rank + self.tau + epsilon 106 | ref_range_max = max_rank 107 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 108 | if len(candidates) > 0: 109 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 110 | else: 111 | order = (order + 1) % 3 112 | debug_flag += 1 113 | continue 114 | return order, ref_idx -------------------------------------------------------------------------------- /data/datasets/OL_lossbalancing_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | 8 | class OLLossBalancing_Train(Dataset): 9 | def __init__(self, imgs, labels, transform, tau, norm_age=True, logscale=False, is_filelist=False): 10 | super(Dataset, self).__init__() 11 | self.imgs = imgs 12 | self.labels = labels 13 | self.transform = transform 14 | self.n_imgs = len(self.imgs) 15 | self.min_age_bf_norm = self.labels.min() 16 | if logscale: 17 | self.labels = np.log(labels.astype(np.float32)) 18 | else: 19 | if norm_age: 20 | self.labels = self.labels - min(self.labels) 21 | 22 | self.max_age = self.labels.max() 23 | self.min_age = self.labels.min() 24 | self.tau = tau 25 | self.is_filelist = is_filelist 26 | 27 | # mapping age to rank : because there are omitted ages 28 | rank = 0 29 | self.mapping = dict() 30 | for cls in np.unique(self.labels): 31 | self.mapping[cls] = rank 32 | rank += 1 33 | self.ranks = np.array([self.mapping[l] for l in self.labels]) 34 | self.n_ranks = self.ranks.max()+1 35 | self.probs = None 36 | self.sample_idxs_per_rank = dict() 37 | for r in range(self.n_ranks): 38 | self.sample_idxs_per_rank[r] = np.argwhere(self.ranks==r).flatten() 39 | 40 | def __getitem__(self, item): 41 | if self.probs is not None: 42 | rng = np.random.default_rng() 43 | base_rank = rng.choice(np.arange(self.n_ranks), 1, p=self.probs)[0] 44 | base_idx = rng.choice(self.sample_idxs_per_rank[base_rank], 1)[0] 45 | else: 46 | base_idx = item 47 | order_label, ref_idx = self.find_reference(self.labels[base_idx], self.labels, min_rank=self.min_age, 48 | max_rank=self.max_age) 49 | if self.is_filelist: 50 | base_img = np.asarray(load_one_image(self.imgs[base_idx])).astype('uint8') 51 | ref_img = np.asarray(load_one_image(self.imgs[ref_idx])).astype('uint8') 52 | else: 53 | base_img = np.asarray(self.imgs[base_idx]).astype('uint8') 54 | ref_img = np.asarray(self.imgs[ref_idx]).astype('uint8') 55 | base_img = self.transform(base_img) 56 | ref_img = self.transform(ref_img) 57 | 58 | # gt ranks 59 | base_rank = self.ranks[base_idx] 60 | ref_rank = self.ranks[ref_idx] 61 | 62 | return base_img, ref_img, order_label, [base_rank, ref_rank], item 63 | 64 | def __len__(self): 65 | return self.n_imgs 66 | 67 | def find_reference(self, base_rank, ref_ranks, min_rank=0, max_rank=32, epsilon=1e-4): 68 | 69 | def get_indices_in_range(search_range, ages): 70 | """find indices of values within range[0] <= x <= range[1]""" 71 | return np.argwhere(np.logical_and(search_range[0] <= ages, ages <= search_range[1])) 72 | 73 | rng = np.random.default_rng() 74 | order = np.random.randint(0, 3) 75 | ref_idx = -1 76 | debug_flag = 0 77 | while ref_idx == -1: 78 | if debug_flag == 3: 79 | raise ValueError(f'Failed to find reference... base_score: {base_rank}') 80 | if order == 0: # base_rank > ref_rank + tau 81 | ref_range_min = min_rank 82 | ref_range_max = base_rank - self.tau - epsilon 83 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 84 | if len(candidates) > 0: 85 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 86 | else: 87 | order = (order + 1) % 3 88 | debug_flag += 1 89 | continue 90 | elif order == 1: # base_rank < ref_rank - tau 91 | ref_range_min = base_rank + self.tau + epsilon 92 | ref_range_max = max_rank 93 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 94 | if len(candidates) > 0: 95 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 96 | else: 97 | order = (order + 1) % 3 98 | debug_flag += 1 99 | continue 100 | 101 | else: # |base_rank - ref_rank| <= tau 102 | ref_range_min = base_rank - self.tau - epsilon 103 | ref_range_max = base_rank + self.tau + epsilon 104 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 105 | if len(candidates) > 0: 106 | ref_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 107 | else: 108 | order = (order + 1) % 3 109 | debug_flag += 1 110 | return order, ref_idx 111 | 112 | def update_probs(self, loss_record): 113 | self.probs = np.exp(loss_record) / np.sum(np.exp(loss_record)) -------------------------------------------------------------------------------- /data/datasets/OL_mining_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | class OLMining_Train(Dataset): 8 | def __init__(self, imgs, labels, transform, tau, norm_age=True, is_filelist=False, max_epoch=350): 9 | super(Dataset, self).__init__() 10 | self.imgs = imgs 11 | self.labels = labels 12 | self.transform = transform 13 | self.n_imgs = len(self.imgs) 14 | if norm_age: 15 | self.labels = self.labels - min(self.labels) 16 | 17 | self.max_age = self.labels.max() 18 | self.min_age = self.labels.min() 19 | self.tau = tau 20 | self.is_filelist = is_filelist 21 | 22 | self.max_hard_sample_prob = 0.75 23 | self.init_hardness_multiplier = 3 24 | self.hard_sample_prob = 0 25 | self.hardness_multiplier = 10 26 | 27 | self.epoch = 0 28 | self.max_epoch = max_epoch 29 | 30 | def __getitem__(self, item): 31 | order_label, ref_idx = self.find_reference(self.labels[item], self.labels, min_rank=self.min_age, 32 | max_rank=self.max_age) 33 | if self.is_filelist: 34 | base_img = np.asarray(load_one_image(self.imgs[item])).astype('uint8') 35 | ref_img = np.asarray(load_one_image(self.imgs[ref_idx])).astype('uint8') 36 | else: 37 | base_img = np.asarray(self.imgs[item]).astype('uint8') 38 | ref_img = np.asarray(self.imgs[ref_idx]).astype('uint8') 39 | base_img = self.transform(base_img) 40 | ref_img = self.transform(ref_img) 41 | 42 | one_hot_vector = self.convert_to_onehot(order_label) 43 | 44 | # gt ages 45 | base_age = self.labels[item] 46 | ref_age = self.labels[ref_idx] 47 | return base_img, ref_img, one_hot_vector, order_label, [base_age, ref_age], item 48 | 49 | def __len__(self): 50 | return self.n_imgs 51 | 52 | def convert_to_onehot(self, order): 53 | if order == 0: 54 | one_hot_vector = torch.tensor([1, 0], dtype=torch.float32) 55 | elif order == 1: 56 | one_hot_vector = torch.tensor([0, 1], dtype=torch.float32) 57 | elif order == 2: 58 | one_hot_vector = torch.tensor([0.5, 0.5], dtype=torch.float32) 59 | else: 60 | raise ValueError(f'order value {order} is out of expected range.') 61 | return one_hot_vector 62 | 63 | def update_mining_params(self, ): 64 | self.epoch += 1 65 | 66 | self.hard_sample_prob = self.max_hard_sample_prob*self.epoch/self.max_epoch 67 | self.hardness_multiplier = -((self.init_hardness_multiplier - 1)/self.max_epoch*self.epoch) + self.init_hardness_multiplier 68 | 69 | def find_reference(self, base_rank, ref_ranks, min_rank=0, max_rank=32, epsilon=1e-4): 70 | 71 | def get_indices_in_range(search_range, ages): 72 | """find indices of values within range[0] <= x <= range[1]""" 73 | return np.argwhere(np.logical_and(search_range[0] <= ages, ages <= search_range[1])) 74 | 75 | is_normal = np.random.choice([True, False], p=[1-self.hard_sample_prob, self.hard_sample_prob]) 76 | 77 | order = np.random.randint(0, 3) 78 | ref_idx = -1 79 | debug_flag = 0 80 | while ref_idx == -1: 81 | if debug_flag == 3: 82 | raise ValueError(f'Failed to find reference... base_score: {base_rank}') 83 | if order == 0: # base_rank > ref_rank 84 | ref_range_min = min_rank if is_normal else max(base_rank - (self.tau*self.hardness_multiplier), min_rank) 85 | ref_range_max = base_rank - self.tau - epsilon 86 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 87 | if len(candidates) > 0: 88 | ref_idx = candidates[np.random.choice(len(candidates), 1)[0]][0] 89 | else: 90 | order = (order + 1) % 3 91 | debug_flag += 1 92 | continue 93 | elif order == 1: # base_rank < ref_rank 94 | ref_range_min = base_rank + self.tau + epsilon 95 | ref_range_max = max_rank if is_normal else min(base_rank + (self.tau*self.hardness_multiplier), max_rank) 96 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 97 | if len(candidates) > 0: 98 | ref_idx = candidates[np.random.choice(len(candidates), 1)[0]][0] 99 | else: 100 | order = (order + 1) % 3 101 | debug_flag += 1 102 | continue 103 | 104 | else: # base_rank = ref_rank 105 | ref_range_min = base_rank - self.tau - epsilon 106 | ref_range_max = base_rank + self.tau + epsilon 107 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 108 | if len(candidates) > 0: 109 | ref_idx = candidates[np.random.choice(len(candidates), 1)[0]][0] 110 | else: 111 | order = (order + 1) % 3 112 | debug_flag += 1 113 | return order, ref_idx -------------------------------------------------------------------------------- /data/datasets/OL_triplet_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | 8 | class OLTriplet_Train(Dataset): 9 | def __init__(self, imgs, labels, transform, tau, norm_age=True, logscale=False, is_filelist=False): 10 | super(Dataset, self).__init__() 11 | self.imgs = imgs 12 | self.labels = labels 13 | self.transform = transform 14 | self.n_imgs = len(self.imgs) 15 | 16 | if logscale: 17 | self.labels = np.log(labels.astype(np.float32)) 18 | else: 19 | if norm_age: 20 | self.labels = self.labels - min(self.labels) 21 | 22 | self.max_age = self.labels.max() 23 | self.min_age = self.labels.min() 24 | self.tau = tau 25 | self.is_filelist = is_filelist 26 | 27 | # mapping age to rank : because there are omitted ages 28 | rank = 0 29 | self.mapping = dict() 30 | for cls in np.unique(self.labels): 31 | self.mapping[cls] = rank 32 | rank += 1 33 | self.ranks = np.array([self.mapping[l] for l in self.labels]) 34 | 35 | self.hardness_multiplier = 20 # thus, interval is 4 (=5-1) tau 36 | 37 | def __getitem__(self, item): 38 | order_label, ref_idx = self.find_reference_triplet(self.labels[item], self.labels, min_rank=self.min_age, 39 | max_rank=self.max_age) 40 | if self.is_filelist: 41 | base_img = np.asarray(load_one_image(self.imgs[item])).astype('uint8') 42 | ref_img1 = np.asarray(load_one_image(self.imgs[ref_idx[0]])).astype('uint8') 43 | ref_img2 = np.asarray(load_one_image(self.imgs[ref_idx[1]])).astype('uint8') 44 | else: 45 | base_img = np.asarray(self.imgs[item]).astype('uint8') 46 | ref_img1 = np.asarray(self.imgs[ref_idx[0]]).astype('uint8') 47 | ref_img2 = np.asarray(self.imgs[ref_idx[1]]).astype('uint8') 48 | base_img = self.transform(base_img) 49 | ref_img1 = self.transform(ref_img1) 50 | ref_img2 = self.transform(ref_img2) 51 | 52 | base_age = self.labels[item] 53 | ref_age1 = self.labels[ref_idx[0]] 54 | ref_age2 = self.labels[ref_idx[1]] 55 | 56 | # gt ranks 57 | base_rank = self.ranks[item] 58 | ref_rank1 = self.ranks[ref_idx[0]] 59 | ref_rank2 = self.ranks[ref_idx[1]] 60 | 61 | return base_img, ref_img1, ref_img2, order_label, [base_age, ref_age1, ref_age2, base_rank, ref_rank1, ref_rank2], item 62 | 63 | def __len__(self): 64 | return self.n_imgs 65 | 66 | def find_reference_triplet(self, base_rank, ref_ranks, min_rank=0, max_rank=32, epsilon=1e-4): 67 | 68 | def get_indices_in_range(search_range, ages): 69 | """find indices of values within range[0] <= x <= range[1]""" 70 | return np.argwhere(np.logical_and(search_range[0] <= ages, ages <= search_range[1])) 71 | 72 | rng = np.random.default_rng() 73 | case = rng.choice([0,1], 1, p=[2/3, 1/3]) 74 | ref_idx2 = -1 75 | debug_flag = 0 76 | while ref_idx2 == -1: 77 | if debug_flag == 2: 78 | raise ValueError(f'Failed to find reference... base_score: {base_rank}') 79 | if case == 0: # ref1_rank + tau < base_rank < ref2_rank - tau 80 | ref_range_min = max(base_rank - (self.tau*self.hardness_multiplier), min_rank) 81 | ref_range_max = base_rank - self.tau - epsilon 82 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 83 | if len(candidates) > 0: 84 | ref_idx1 = candidates[rng.choice(len(candidates), 1)[0]][0] 85 | else: 86 | case = (case + 1) % 2 87 | debug_flag += 1 88 | continue 89 | 90 | ref_range_min = base_rank + self.tau + epsilon 91 | ref_range_max = min(base_rank + (self.tau*self.hardness_multiplier), max_rank) 92 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 93 | if len(candidates) > 0: 94 | ref_idx2 = candidates[rng.choice(len(candidates), 1)[0]][0] 95 | else: 96 | case = (case + 1) % 2 97 | debug_flag += 1 98 | continue 99 | 100 | elif case == 1: # |base_rank - ref_rank| <= tau 101 | ref_range_min = base_rank - self.tau - epsilon 102 | ref_range_max = base_rank + self.tau + epsilon 103 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 104 | if len(candidates) > 0: 105 | ref_idx1 = candidates[rng.choice(len(candidates), 1)[0]][0] 106 | ref_idx2 = candidates[rng.choice(len(candidates), 1)[0]][0] 107 | else: 108 | case = (case + 1) % 2 109 | debug_flag += 1 110 | else: 111 | raise ValueError(f'[!] something is wrong... base rank{base_rank}, case{case}') 112 | 113 | if case == 0: 114 | order = [0, 1] 115 | else: 116 | order = [2, 2] 117 | 118 | return order, [ref_idx1, ref_idx2] -------------------------------------------------------------------------------- /data/datasets/angl_triplet_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | 8 | class AngleTriplet_Train(Dataset): 9 | def __init__(self, imgs, labels, transform, tau, norm_age=True, logscale=False, is_filelist=False): 10 | super(Dataset, self).__init__() 11 | self.imgs = imgs 12 | self.labels = labels 13 | self.transform = transform 14 | self.n_imgs = len(self.imgs) 15 | 16 | if logscale: 17 | self.labels = np.log(labels.astype(np.float32)) 18 | else: 19 | if norm_age: 20 | self.labels = self.labels - min(self.labels) 21 | 22 | self.max_age = self.labels.max() 23 | self.min_age = self.labels.min() 24 | self.tau = tau 25 | self.is_filelist = is_filelist 26 | 27 | # mapping age to rank : because there are omitted ages 28 | rank = 0 29 | self.mapping = dict() 30 | for cls in np.unique(self.labels): 31 | self.mapping[cls] = rank 32 | rank += 1 33 | self.ranks = np.array([self.mapping[l] for l in self.labels]) 34 | 35 | self.hardness_multiplier = 20 # thus, interval is 4 (=5-1) tau 36 | 37 | def __getitem__(self, item): 38 | order_label, ref_idx = self.find_reference_triplet(self.labels[item], self.labels, min_rank=self.min_age, 39 | max_rank=self.max_age) 40 | if self.is_filelist: 41 | base_img = np.asarray(load_one_image(self.imgs[item])).astype('uint8') 42 | ref_img1 = np.asarray(load_one_image(self.imgs[ref_idx[0]])).astype('uint8') 43 | ref_img2 = np.asarray(load_one_image(self.imgs[ref_idx[1]])).astype('uint8') 44 | else: 45 | base_img = np.asarray(self.imgs[item]).astype('uint8') 46 | ref_img1 = np.asarray(self.imgs[ref_idx[0]]).astype('uint8') 47 | ref_img2 = np.asarray(self.imgs[ref_idx[1]]).astype('uint8') 48 | base_img = self.transform(base_img) 49 | ref_img1 = self.transform(ref_img1) 50 | ref_img2 = self.transform(ref_img2) 51 | 52 | base_age = self.labels[item] 53 | ref_age1 = self.labels[ref_idx[0]] 54 | ref_age2 = self.labels[ref_idx[1]] 55 | 56 | # gt ranks 57 | base_rank = self.ranks[item] 58 | ref_rank1 = self.ranks[ref_idx[0]] 59 | ref_rank2 = self.ranks[ref_idx[1]] 60 | 61 | return base_img, ref_img1, ref_img2, order_label, [base_age, ref_age1, ref_age2, base_rank, ref_rank1, ref_rank2], item 62 | 63 | def __len__(self): 64 | return self.n_imgs 65 | 66 | def find_reference_triplet(self, base_rank, ref_ranks, min_rank=0, max_rank=32, epsilon=1e-4): 67 | 68 | def get_indices_in_range(search_range, ages): 69 | """find indices of values within range[0] <= x <= range[1]""" 70 | return np.argwhere(np.logical_and(search_range[0] <= ages, ages <= search_range[1])) 71 | 72 | rng = np.random.default_rng() 73 | case = rng.choice([0,1], 1, p=[2/3, 1/3]) 74 | ref_idx2 = -1 75 | debug_flag = 0 76 | while ref_idx2 == -1: 77 | if debug_flag == 2: 78 | raise ValueError(f'Failed to find reference... base_score: {base_rank}') 79 | if case == 0: # ref1_rank + tau < base_rank < ref2_rank - tau 80 | ref_range_min = max(base_rank - (self.tau*self.hardness_multiplier), min_rank) 81 | ref_range_max = base_rank - self.tau - epsilon 82 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 83 | if len(candidates) > 0: 84 | ref_idx1 = candidates[rng.choice(len(candidates), 1)[0]][0] 85 | else: 86 | case = (case + 1) % 2 87 | debug_flag += 1 88 | continue 89 | 90 | ref_range_min = base_rank + self.tau + epsilon 91 | ref_range_max = min(base_rank + (self.tau*self.hardness_multiplier), max_rank) 92 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 93 | if len(candidates) > 0: 94 | ref_idx2 = candidates[rng.choice(len(candidates), 1)[0]][0] 95 | else: 96 | case = (case + 1) % 2 97 | debug_flag += 1 98 | continue 99 | 100 | elif case == 1: # |base_rank - ref_rank| <= tau 101 | ref_range_min = base_rank - self.tau - epsilon 102 | ref_range_max = base_rank + self.tau + epsilon 103 | candidates = get_indices_in_range([ref_range_min, ref_range_max], ref_ranks) 104 | if len(candidates) > 0: 105 | ref_idx1 = candidates[rng.choice(len(candidates), 1)[0]][0] 106 | ref_idx2 = candidates[rng.choice(len(candidates), 1)[0]][0] 107 | else: 108 | case = (case + 1) % 2 109 | debug_flag += 1 110 | else: 111 | raise ValueError(f'[!] something is wrong... base rank{base_rank}, case{case}') 112 | 113 | if case == 0: 114 | order = [0, 1] 115 | else: 116 | order = [2, 2] 117 | 118 | return order, [ref_idx1, ref_idx2] -------------------------------------------------------------------------------- /data/datasets/OL_angl_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils.util import load_one_image 6 | 7 | 8 | class OLAngle_Train(Dataset): 9 | def __init__(self, imgs, labels, transform, tau, norm_age=True, logscale=False, is_filelist=False): 10 | super(Dataset, self).__init__() 11 | self.imgs = imgs 12 | self.labels = labels 13 | self.transform = transform 14 | self.n_imgs = len(self.imgs) 15 | 16 | if logscale: 17 | self.labels = np.log(labels.astype(np.float32)) 18 | else: 19 | if norm_age: 20 | self.labels = self.labels - min(self.labels) 21 | 22 | self.max_age = self.labels.max() 23 | self.min_age = self.labels.min() 24 | self.tau = tau 25 | self.is_filelist = is_filelist 26 | 27 | # mapping age to rank : because there are omitted ages 28 | rank = 0 29 | self.mapping = dict() 30 | for cls in np.unique(self.labels): 31 | self.mapping[cls] = rank 32 | rank += 1 33 | self.ranks = np.array([self.mapping[l] for l in self.labels]) 34 | self.max_rank = self.ranks.max() 35 | self.min_rank = self.ranks.min() 36 | self.margin = 0 # should smaller than tau 37 | assert self.margin < self.tau 38 | 39 | def __getitem__(self, item): 40 | left_idx, center_idx, right_idx, pos_idx = self.find_references(item, self.ranks, margin=self.margin) 41 | 42 | if self.is_filelist: 43 | l_img = np.asarray(load_one_image(self.imgs[left_idx])).astype('uint8') 44 | c_img = np.asarray(load_one_image(self.imgs[center_idx])).astype('uint8') 45 | r_img = np.asarray(load_one_image(self.imgs[right_idx])).astype('uint8') 46 | p_img = np.asarray(load_one_image(self.imgs[pos_idx])).astype('uint8') 47 | 48 | else: 49 | l_img = np.asarray(self.imgs[left_idx]).astype('uint8') 50 | c_img = np.asarray(self.imgs[center_idx]).astype('uint8') 51 | r_img = np.asarray(self.imgs[right_idx]).astype('uint8') 52 | p_img = np.asarray(self.imgs[pos_idx]).astype('uint8') 53 | 54 | l_img = self.transform(l_img) 55 | c_img = self.transform(c_img) 56 | r_img = self.transform(r_img) 57 | p_img = self.transform(p_img) 58 | 59 | return l_img, c_img, r_img, p_img, [self.ranks[left_idx], self.ranks[center_idx], self.ranks[right_idx], self.ranks[pos_idx]] 60 | 61 | def __len__(self): 62 | return self.n_imgs 63 | 64 | def find_references(self, item, ranks, margin=1, epsilon=1e-4): 65 | 66 | def get_indices_in_range(search_range, ages): 67 | """find indices of values within range[0] <= x <= range[1]""" 68 | return np.argwhere(np.logical_and(search_range[0] <= ages, ages <= search_range[1])) 69 | 70 | base_rank = ranks[item] 71 | rng = np.random.default_rng() 72 | 73 | if base_rank - self.tau - margin < self.min_rank: 74 | left_idx = item 75 | 76 | # pick center 77 | candidates = get_indices_in_range([base_rank+self.tau-margin-epsilon, base_rank+self.tau+margin+epsilon], ranks) 78 | center_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 79 | center_rank = ranks[center_idx] 80 | 81 | # pick right 82 | candidates = get_indices_in_range([center_rank + self.tau -margin-epsilon, center_rank + self.tau + margin+epsilon], ranks) 83 | right_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 84 | 85 | # pick pos 86 | candidates = get_indices_in_range( 87 | [center_rank - epsilon, center_rank + epsilon], ranks) 88 | pos_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 89 | 90 | elif base_rank + self.tau + margin > self.max_rank: 91 | right_idx = item 92 | 93 | # pick center 94 | candidates = get_indices_in_range( 95 | [base_rank - self.tau - margin - epsilon, base_rank - self.tau + margin + epsilon], ranks) 96 | center_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 97 | center_rank = ranks[center_idx] 98 | 99 | # pick left 100 | candidates = get_indices_in_range([center_rank - self.tau - margin - epsilon, 101 | center_rank - self.tau + margin + epsilon], ranks) 102 | left_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 103 | 104 | # pick pos 105 | candidates = get_indices_in_range( 106 | [center_rank - epsilon, center_rank + epsilon], ranks) 107 | pos_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 108 | 109 | else: 110 | center_idx = item 111 | 112 | # pick left 113 | candidates = get_indices_in_range([base_rank - self.tau - margin - epsilon, 114 | base_rank - self.tau + margin + epsilon], ranks) 115 | left_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 116 | 117 | # pick right 118 | candidates = get_indices_in_range([base_rank + self.tau - margin - epsilon, 119 | base_rank + self.tau + margin + epsilon], ranks) 120 | right_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 121 | 122 | # pick pos 123 | candidates = get_indices_in_range( 124 | [base_rank - epsilon, base_rank + epsilon], ranks) 125 | pos_idx = candidates[rng.choice(len(candidates), 1)[0]][0] 126 | return left_idx, center_idx, right_idx, pos_idx -------------------------------------------------------------------------------- /config/basic.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | from PIL import Image 3 | 4 | class ConfigBasic: 5 | def __init__(self,): 6 | self.dataset = None 7 | self.setting = None 8 | self.logscale = False 9 | self.set_optimizer_parameters() 10 | self.set_training_opts() 11 | self.set_network() 12 | 13 | def set_dataset(self): 14 | if self.dataset == 'morph': 15 | if self.logscale: 16 | self.tau = 0.1 17 | else: 18 | self.tau = 2 19 | 20 | self.img_root = '/hdd/2020/Research/datasets/Agedataset/img/morph' 21 | if self.setting == 'A': 22 | self.is_filelist = True 23 | self.train_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_A/setting_A_train_fold{self.fold}.txt' 24 | self.test_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_A/setting_A_test_fold{self.fold}.txt' 25 | self.delimeter = "," 26 | self.img_idx = 4 27 | self.lb_idx = 3 28 | 29 | elif self.setting == 'B': 30 | self.delimeter = " " 31 | self.img_idx = 3 32 | self.lb_idx = 2 33 | if self.fold == 1: 34 | self.is_filelist = True 35 | self.train_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_B/Setting_B_S1_train.txt' 36 | self.test_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_B/Setting_B_S2+S3_test.txt' 37 | else: 38 | self.is_filelist = True 39 | self.train_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_B/Setting_B_S2_train.txt' 40 | self.test_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_B/Setting_B_S1+S3_test.txt' 41 | 42 | elif self.setting == 'C': 43 | self.delimeter = " " 44 | self.img_idx = 0 45 | self.lb_idx = 2 46 | self.is_filelist = True 47 | self.train_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_C/setting_C_train_fold{self.fold}.txt' 48 | self.test_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_C/setting_C_test_fold{self.fold}.txt' 49 | 50 | elif self.setting == 'D': 51 | self.delimeter = " " 52 | self.img_idx = 0 53 | self.lb_idx = 2 54 | self.is_filelist = True 55 | self.train_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_D/setting_D_train_fold{self.fold}.txt' 56 | self.test_file = f'/hdd/2020/Research/datasets/Agedataset/morph_fold/Setting_D/setting_D_test_fold{self.fold}.txt' 57 | else: 58 | raise ValueError(f'setting {self.setting} is out of range.') 59 | 60 | elif self.dataset == 'adience': 61 | self.is_filelist = False 62 | self.train_file = f'/hdd/2021/Research/99_dataset/Adience/adience_F{self.fold}_train_algn_[0_7].pickle' 63 | self.test_file = f'/hdd/2021/Research/99_dataset/Adience/adience_F{self.fold}_test_algn_[0_7].pickle' 64 | self.tau = 1 65 | 66 | elif self.dataset =='clap': 67 | self.delimeter = " " 68 | self.img_idx = 0 69 | self.lb_idx = 1 70 | self.is_filelist = True 71 | self.img_root = '/hdd/2020/Research/datasets/Agedataset/img/CLAP/2015' 72 | if self.fold == 'eval_on_test': 73 | self.train_file = '/hdd/2020/Research/datasets/Agedataset/clap_split/CLAP_trainval.txt' 74 | self.test_file = '/hdd/2020/Research/datasets/Agedataset/clap_split/CLAP_test.txt' 75 | elif self.fold == 'eval_on_val': 76 | self.train_file = '/hdd/2020/Research/datasets/Agedataset/clap_split/CLAP_train.txt' 77 | self.test_file = '/hdd/2020/Research/datasets/Agedataset/clap_split/CLAP_val.txt' 78 | else: 79 | raise ValueError(f'check fold: it should be [eval_on_test] or [eval_on_val], but {self.fold} is given.') 80 | else: 81 | raise ValueError(f'{self.dataset} is out of range!') 82 | 83 | self.mean = [0.485, 0.456, 0.406] 84 | self.std = [0.229, 0.224, 0.225] 85 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std) 86 | self.transform_tr = transforms.Compose([ 87 | lambda x: Image.fromarray(x), 88 | transforms.RandomCrop(224), 89 | transforms.RandomHorizontalFlip(), 90 | transforms.ToTensor(), 91 | self.normalize 92 | ]) 93 | 94 | self.transform_te = transforms.Compose([ 95 | lambda x: Image.fromarray(x), 96 | transforms.CenterCrop(224), 97 | transforms.ToTensor(), 98 | self.normalize 99 | ]) 100 | 101 | def set_optimizer_parameters(self): 102 | # *** Optimizer 103 | self.adam = True 104 | self.learning_rate = 0.0001 105 | self.lr_decay_epochs = [30, 50, 100] 106 | self.lr_decay_rate = 0.1 107 | self.momentum = 0.9 108 | self.weight_decay = 0.0005 109 | 110 | # *** Scheduler 111 | self.scheduler = 'cosine' 112 | 113 | def set_network(self): 114 | self.model = 'T_v0' 115 | self.backbone = 'vgg16bn' 116 | self.ckpt = None 117 | 118 | def set_training_opts(self): 119 | # *** Print Option 120 | self.val_freq = 3 121 | self.print_freq = 50 122 | 123 | # *** Training 124 | self.batch_size = 16 125 | self.num_workers = 1 126 | self.epochs = 100 127 | 128 | # *** Save option 129 | self.save_freq = 100 130 | self.wandb = False 131 | 132 | def set_test_opts(self): 133 | self.ckpt = None 134 | -------------------------------------------------------------------------------- /utils/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def update_loss_matrix(A, loss, base_ranks, ref_ranks=None): 6 | batch_size = len(base_ranks) 7 | if ref_ranks is not None: 8 | for i in range(batch_size): 9 | A[0][base_ranks[i], ref_ranks[i]] += loss[i] 10 | A[1][base_ranks[i], ref_ranks[i]] += 1 11 | else: 12 | for i in range(batch_size): 13 | A[0][base_ranks[i]] += loss[i] 14 | A[1][base_ranks[i]] += 1 15 | return A 16 | 17 | 18 | def get_pairs_equally(ranks, tau, m=32): 19 | orders = [] 20 | base_idx = [] 21 | ref_idx = [] 22 | N = len(ranks) 23 | for i in range(N): 24 | for j in range(i+1, N): 25 | if np.random.rand(1) > 0.5: 26 | base_idx.append(i) 27 | ref_idx.append(j) 28 | order_ij = get_order_labels(ranks[i], ranks[j], tau) 29 | orders.append(order_ij) 30 | else: 31 | base_idx.append(j) 32 | ref_idx.append(i) 33 | order_ji = get_order_labels(ranks[j], ranks[i], tau) 34 | orders.append(order_ji) 35 | refine = [] 36 | orders = np.array(orders) 37 | for o in range(3): 38 | o_idxs = np.argwhere(orders==o).flatten() 39 | if len(o_idxs) > m: 40 | sel = np.random.choice(o_idxs, m, replace=False) 41 | refine.append(sel) 42 | else: 43 | refine.append(o_idxs) 44 | refine = np.concatenate(refine) 45 | base_idx = np.array(base_idx)[refine] 46 | ref_idx = np.array(ref_idx)[refine] 47 | orders = orders[refine] 48 | return base_idx, ref_idx, orders 49 | 50 | 51 | def get_pairs_loss_balancing(ranks, loss_record, tau, m=32): 52 | orders = [] 53 | base_idx = [] 54 | ref_idx = [] 55 | loss_val = [] 56 | N = len(ranks) 57 | 58 | for i in range(N): 59 | for j in range(i+1, N): 60 | if np.random.rand(1) > 0.5: 61 | base_idx.append(i) 62 | ref_idx.append(j) 63 | order_ij = get_order_labels(ranks[i], ranks[j], tau) 64 | orders.append(order_ij) 65 | loss_val.append(loss_record[ranks[i], ranks[j]]) 66 | 67 | else: 68 | base_idx.append(j) 69 | ref_idx.append(i) 70 | order_ji = get_order_labels(ranks[j], ranks[i], tau) 71 | orders.append(order_ji) 72 | loss_val.append(loss_record[ranks[j], ranks[i]]) 73 | orders = np.array(orders) 74 | loss_val = np.array(loss_val) 75 | refine = [] 76 | 77 | for o in range(3): 78 | o_idxs = np.argwhere(orders==o).flatten() 79 | if len(o_idxs) > m: 80 | sel = np.random.choice(o_idxs, m, p=np.exp(loss_val[o_idxs]) / np.sum(np.exp(loss_val[o_idxs])), 81 | replace=False) 82 | refine.append(sel) 83 | else: 84 | refine.append(o_idxs) 85 | 86 | refine = np.concatenate(refine) 87 | base_idx = np.array(base_idx)[refine] 88 | ref_idx = np.array(ref_idx)[refine] 89 | orders = orders[refine] 90 | return base_idx, ref_idx, orders 91 | 92 | 93 | class LossTracker: 94 | def __init__(self, cfg): 95 | self.n_ranks = cfg.n_ranks 96 | 97 | self.pairwise_loss_record = dict() 98 | self.pairwise_loss_record['drct'] = np.zeros([self.n_ranks, self.n_ranks]) 99 | self.pairwise_loss_record['dist'] = np.zeros([self.n_ranks, self.n_ranks]) 100 | self.pairwise_loss_record['total'] = np.zeros([self.n_ranks, self.n_ranks]) 101 | 102 | self.samplewise_loss_record = dict() 103 | self.samplewise_loss_record['center'] = np.zeros([self.n_ranks,]) 104 | 105 | self.counter = dict() 106 | self.counter['drct'] = np.zeros([self.n_ranks, self.n_ranks]) 107 | self.counter['dist'] = np.zeros([self.n_ranks, self.n_ranks]) 108 | self.counter['total'] = np.zeros([self.n_ranks, self.n_ranks]) 109 | self.counter['center'] = np.zeros([self.n_ranks, ]) 110 | self.tau = cfg.tau 111 | 112 | def update_pairwise_loss_matrix(self, loss, base_ranks, ref_ranks, losstype='drct'): 113 | pair_size = len(base_ranks) 114 | for i in range(pair_size): 115 | self.pairwise_loss_record[losstype][base_ranks[i], ref_ranks[i]] += loss[i] 116 | self.counter[losstype][base_ranks[i], ref_ranks[i]] += 1 117 | 118 | def update_pairwise_loss_matrix_total(self, drct_loss, dist_loss, base_ranks, ref_ranks): 119 | pair_size = len(base_ranks) 120 | drct_idx = 0 121 | for i in range(pair_size): 122 | if abs(base_ranks[i] - ref_ranks[i]) <= self.tau: 123 | pass 124 | else: 125 | self.pairwise_loss_record['drct'][base_ranks[i], ref_ranks[i]] += drct_loss[drct_idx] 126 | self.pairwise_loss_record['total'][base_ranks[i], ref_ranks[i]] += drct_loss[drct_idx] 127 | self.counter['drct'][base_ranks[i], ref_ranks[i]] += 1 128 | drct_idx += 1 129 | self.pairwise_loss_record['dist'][base_ranks[i], ref_ranks[i]] += dist_loss[i] 130 | self.pairwise_loss_record['total'][base_ranks[i], ref_ranks[i]] += dist_loss[i] 131 | self.counter['dist'][base_ranks[i], ref_ranks[i]] += 1 132 | self.counter['total'] = self.counter['dist'] 133 | 134 | def update_samplewise_loss(self, loss, ranks, losstype='center'): 135 | batch_size = len(ranks) 136 | for i in range(batch_size): 137 | self.samplewise_loss_record[losstype][ranks[i]] += loss[i] 138 | self.counter[losstype][ranks[i]] += 1 139 | 140 | def get_avg_samplewise_loss(self, losstypes=['total']): 141 | avg_samplewise_loss = np.zeros_like(self.samplewise_loss_record['center']) 142 | avg_samplewise_loss += (self.samplewise_loss_record['center']/(self.counter['center']+1e-7)) 143 | for losstype in losstypes: 144 | samplewise_loss_sum = np.sum(self.pairwise_loss_record[losstype], axis=-1) 145 | samplewise_cnt = np.sum(self.counter[losstype], axis=-1) 146 | avg_samplewise_loss += samplewise_loss_sum / (samplewise_cnt+1e-7) 147 | cnt_zero_idx = np.argwhere(self.counter['center']==0).flatten() 148 | mean_loss_val = np.mean(avg_samplewise_loss) 149 | avg_samplewise_loss[cnt_zero_idx] = mean_loss_val # assign mean loss value to some ranks if they didn't occur during previous training period 150 | return avg_samplewise_loss 151 | 152 | def get_avg_pairwise_loss(self, ): 153 | avg_pairwise_loss = np.zeros_like(self.pairwise_loss_record['total']) 154 | avg_pairwise_loss += (self.pairwise_loss_record['total']/(self.counter['total']+1e-7)) 155 | mean_loss_val = np.mean(avg_pairwise_loss) 156 | cnt_zero_idx = np.argwhere(self.counter['total'] == 0) 157 | avg_pairwise_loss[cnt_zero_idx[:,0], cnt_zero_idx[:,1]] = mean_loss_val 158 | avg_samplewise_loss = (self.samplewise_loss_record['center']/(self.counter['center']+1e-7)).reshape(-1,1) 159 | avg_pairwise_loss = avg_pairwise_loss + avg_samplewise_loss 160 | return avg_pairwise_loss 161 | 162 | def restart_record(self): 163 | self.pairwise_loss_record = dict() 164 | self.pairwise_loss_record['drct'] = np.zeros([self.n_ranks, self.n_ranks]) 165 | self.pairwise_loss_record['dist'] = np.zeros([self.n_ranks, self.n_ranks]) 166 | self.pairwise_loss_record['total'] = np.zeros([self.n_ranks, self.n_ranks]) 167 | 168 | self.samplewise_loss_record = dict() 169 | self.samplewise_loss_record['center'] = np.zeros([self.n_ranks, ]) 170 | 171 | self.counter = dict() 172 | self.counter['drct'] = np.zeros([self.n_ranks, self.n_ranks]) 173 | self.counter['dist'] = np.zeros([self.n_ranks, self.n_ranks]) 174 | self.counter['total'] = np.zeros([self.n_ranks, self.n_ranks]) 175 | self.counter['center'] = np.zeros([self.n_ranks, ]) 176 | 177 | def update_pairwise_loss_matrix_v2(A, drct_loss, dist_loss, base_ranks, ref_ranks=None): 178 | batch_size = len(base_ranks) 179 | if ref_ranks is not None: 180 | for i in range(batch_size): 181 | A[0][base_ranks[i], ref_ranks[i]] += (drct_loss[i] + dist_loss[i]) 182 | A[1][base_ranks[i], ref_ranks[i]] += 1 183 | else: 184 | for i in range(batch_size): 185 | A[0][base_ranks[i]] += (drct_loss[i] + dist_loss[i]) 186 | A[1][base_ranks[i]] += 1 187 | return A 188 | 189 | 190 | 191 | def get_order_labels(rank_base, rank_ref, tau): 192 | if rank_base > rank_ref + tau: 193 | order = 0 194 | elif rank_base < rank_ref - tau: 195 | order = 1 196 | else: 197 | order = 2 198 | return order 199 | -------------------------------------------------------------------------------- /utils/loss_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import itertools 5 | 6 | from utils.util import to_np 7 | 8 | 9 | def compute_order_loss(embs, base_idx, ref_idx, rank_labels, fdc_points, cfg, record=False): 10 | def get_forward_and_backward_idxs(base_idx, ref_idx, ranks, fdc_ranks, cfg): 11 | batch_size = len(base_idx) 12 | base_ranks = ranks[base_idx] 13 | ref_ranks = ranks[ref_idx] 14 | forward_idxs = [] 15 | backward_idxs = [] 16 | mask = [] 17 | gt = [] 18 | for i in range(batch_size): 19 | if base_ranks[i] > ref_ranks[i]: 20 | fdc_1_idx = cfg.fiducial_point_num - np.sum(fdc_ranks - ref_ranks[i] > 0) - 1 21 | fdc_2_idx = cfg.fiducial_point_num - np.sum(fdc_ranks - base_ranks[i] >= 0) 22 | fdc_3_idx = fdc_2_idx + 1 23 | 24 | backward_idxs.append([fdc_1_idx, fdc_2_idx]) 25 | if fdc_3_idx >= len(fdc_points): 26 | forward_idxs.append([fdc_2_idx, fdc_2_idx-1]) 27 | else: 28 | forward_idxs.append([fdc_3_idx, fdc_2_idx]) 29 | 30 | mask.append(True) 31 | gt.append(0) 32 | elif base_ranks[i] < ref_ranks[i]: 33 | fdc_1_idx = cfg.fiducial_point_num - np.sum(fdc_ranks - base_ranks[i] > 0) - 1 34 | fdc_2_idx = cfg.fiducial_point_num - np.sum(fdc_ranks - ref_ranks[i] >= 0) 35 | fdc_3_idx = fdc_1_idx - 1 36 | forward_idxs.append([fdc_2_idx, fdc_1_idx]) 37 | if fdc_3_idx < 0: 38 | backward_idxs.append([fdc_1_idx, fdc_1_idx+1]) 39 | else: 40 | backward_idxs.append([fdc_3_idx, fdc_1_idx]) 41 | mask.append(True) 42 | gt.append(1) 43 | else: 44 | mask.append(False) 45 | 46 | return np.array(forward_idxs), np.array(backward_idxs), torch.tensor(gt).cuda(), base_idx[mask], ref_idx[mask] 47 | 48 | fdc_points = nn.functional.normalize(fdc_points, dim=-1) 49 | hdim = fdc_points.shape[-1] 50 | fdc_point_ranks = np.array([((cfg.n_ranks-1) / (cfg.fiducial_point_num-1)) * i for i in range(cfg.fiducial_point_num)]) 51 | 52 | direction_matrix = fdc_points.view(cfg.fiducial_point_num, 1, hdim).expand(cfg.fiducial_point_num, cfg.fiducial_point_num, hdim) - fdc_points.view(1, cfg.fiducial_point_num, hdim).expand(cfg.fiducial_point_num, cfg.fiducial_point_num, hdim) 53 | direction_matrix = nn.functional.normalize(direction_matrix, dim=-1) 54 | 55 | forward_idxs, backward_idxs, gt, base_idx, ref_idx = get_forward_and_backward_idxs(base_idx, ref_idx, rank_labels, fdc_point_ranks, cfg) 56 | batch_size = base_idx.shape[0] 57 | 58 | v_xy = nn.functional.normalize(embs[ref_idx] - embs[base_idx], dim=-1) 59 | v_forward = direction_matrix[forward_idxs[:,0], forward_idxs[:,1]] 60 | v_backward = direction_matrix[backward_idxs[:,0], backward_idxs[:,1]] 61 | 62 | v_fb = torch.stack([v_backward, v_forward], dim=-1) 63 | logits = 20*torch.matmul(v_xy.view(batch_size, 1, hdim), v_fb).squeeze(1) 64 | 65 | 66 | if record: 67 | loss_per_pair = nn.CrossEntropyLoss(reduction='none')(logits, gt) 68 | loss = torch.mean(loss_per_pair) 69 | return loss, logits, gt, to_np(loss_per_pair) 70 | else: 71 | loss = nn.CrossEntropyLoss()(logits, gt) 72 | return loss, logits, gt 73 | 74 | 75 | def compute_metric_loss(embs, base_idx, ref_idx, rank_labels, fdc_points, margin, cfg, record=False): 76 | fdc_points = nn.functional.normalize(fdc_points, dim=-1) 77 | fdc_point_ranks = np.array( 78 | [((cfg.n_ranks - 1) / (cfg.fiducial_point_num - 1)) * i for i in range(cfg.fiducial_point_num)]) 79 | 80 | if cfg.metric == 'L2': 81 | dists = torch.cdist(fdc_points, embs) 82 | elif cfg.metric == 'cosine': 83 | dists = 1 - torch.matmul(fdc_points, embs.transpose(1, 0)) 84 | def get_pos_neg_idxs(base_idx, ref_idx, ranks, fdc_ranks, cfg): 85 | batch_size = len(base_idx) 86 | base_ranks = ranks[base_idx] 87 | ref_ranks = ranks[ref_idx] 88 | row_idxs = [] 89 | pos_idxs = [] 90 | neg_idxs = [] 91 | split_idxs = [] 92 | 93 | sim_row_idxs = [] 94 | sim_pos_idxs = [] 95 | sim_neg_idxs = [] 96 | 97 | for i in range(batch_size): 98 | if base_ranks[i] > (ref_ranks[i] + cfg.tau): 99 | fdc_1_idx = cfg.fiducial_point_num - np.sum(fdc_ranks - ref_ranks[i] > 0) - 1 100 | fdc_2_idx = cfg.fiducial_point_num - np.sum(fdc_ranks - base_ranks[i] >= 0) 101 | 102 | row_idxs.append(np.arange(fdc_1_idx+1)) 103 | pos_idxs.append([ref_idx[i]]*(fdc_1_idx+1)) 104 | neg_idxs.append([base_idx[i]]*(fdc_1_idx+1)) 105 | row_idxs.append(np.arange(fdc_2_idx, cfg.fiducial_point_num)) 106 | pos_idxs.append([base_idx[i]]*(cfg.fiducial_point_num-fdc_2_idx)) 107 | neg_idxs.append([ref_idx[i]]*(cfg.fiducial_point_num-fdc_2_idx)) 108 | split_idxs.append(fdc_1_idx + 1 + cfg.fiducial_point_num - fdc_2_idx) 109 | 110 | elif base_ranks[i] < (ref_ranks[i] - cfg.tau): 111 | fdc_1_idx = cfg.fiducial_point_num - np.sum(fdc_point_ranks - rank_labels[base_idx[i]] > 0) - 1 112 | fdc_2_idx = cfg.fiducial_point_num - np.sum(fdc_point_ranks - rank_labels[ref_idx[i]] >= 0) 113 | 114 | row_idxs.append(np.arange(fdc_1_idx + 1)) 115 | pos_idxs.append([base_idx[i]] * (fdc_1_idx + 1)) 116 | neg_idxs.append([ref_idx[i]] * (fdc_1_idx + 1)) 117 | row_idxs.append(np.arange(fdc_2_idx, cfg.fiducial_point_num)) 118 | pos_idxs.append([ref_idx[i]] * (cfg.fiducial_point_num-fdc_2_idx)) 119 | neg_idxs.append([base_idx[i]] * (cfg.fiducial_point_num-fdc_2_idx)) 120 | split_idxs.append(fdc_1_idx + 1 + cfg.fiducial_point_num - fdc_2_idx) 121 | else: 122 | sim_row_idxs.append(np.arange(cfg.fiducial_point_num)) 123 | sim_pos_idxs.append([base_idx[i]]*cfg.fiducial_point_num) 124 | sim_neg_idxs.append([ref_idx[i]]*cfg.fiducial_point_num) 125 | split_idxs.append(cfg.fiducial_point_num) 126 | row_idxs = np.concatenate(row_idxs) 127 | pos_idxs = np.concatenate(pos_idxs) 128 | neg_idxs = np.concatenate(neg_idxs) 129 | sim_row_idxs = np.concatenate(sim_row_idxs) 130 | sim_pos_idxs = np.concatenate(sim_pos_idxs) 131 | sim_neg_idxs = np.concatenate(sim_neg_idxs) 132 | return row_idxs, pos_idxs, neg_idxs, sim_row_idxs, sim_pos_idxs, sim_neg_idxs, split_idxs 133 | 134 | row_idxs, pos_idxs, neg_idxs, sim_row_idxs, sim_pos_idxs, sim_neg_idxs, split_idxs = get_pos_neg_idxs(base_idx, ref_idx, rank_labels, fdc_point_ranks, cfg) 135 | 136 | violation = dists[row_idxs, pos_idxs] - dists[row_idxs,neg_idxs] 137 | violation = violation + margin 138 | 139 | if len(sim_row_idxs) > 0: 140 | if cfg.tau == 0: 141 | sim_violation = torch.abs(dists[sim_row_idxs, sim_pos_idxs] - dists[sim_row_idxs, sim_neg_idxs]) 142 | else: 143 | sim_violation = torch.abs(dists[sim_row_idxs,sim_pos_idxs] - dists[sim_row_idxs,sim_neg_idxs]) - margin 144 | loss = torch.cat([nn.functional.relu(violation), nn.functional.relu(sim_violation)]) 145 | 146 | else: 147 | loss = nn.functional.relu(violation) 148 | if record: 149 | loss_per_pairs = torch.tensor([torch.sum(s) for s in torch.split(loss, split_idxs)]) 150 | return torch.sum(loss) / len(base_idx), to_np(loss_per_pairs) 151 | return torch.sum(loss) / len(base_idx) 152 | 153 | 154 | 155 | 156 | def compute_center_loss(embs, rank_labels, fdc_points, cfg, record=False): 157 | fdc_points = nn.functional.normalize(fdc_points, dim=-1) 158 | fdc_point_ranks = np.array([((cfg.n_ranks - 1) / (cfg.fiducial_point_num - 1)) * i for i in range(cfg.fiducial_point_num)]) 159 | 160 | def get_pos_neg_idxs(ranks, fdc_ranks, cfg): 161 | adaptive_margin = cfg.n_ranks != cfg.fiducial_point_num 162 | if adaptive_margin: 163 | nn_idxs = [] 164 | margins = [] 165 | emb_idxs = [] 166 | emb_idx = 0 167 | for r in ranks: 168 | abs_diff = np.abs(fdc_ranks-r) 169 | min_val = abs_diff.min() 170 | nn = np.argwhere(abs_diff==min_val).flatten() 171 | nn_idxs.append(nn) 172 | 173 | margin_val = min_val*cfg.margin/(max(cfg.tau, 1)) 174 | margins.append([margin_val]*len(nn)) 175 | emb_idxs.append([emb_idx]*len(nn)) 176 | emb_idx += 1 177 | nn_idxs = np.concatenate(nn_idxs) 178 | margins = np.concatenate(margins) 179 | emb_idxs = np.concatenate(emb_idxs) 180 | else: 181 | nn_idxs = ranks 182 | margins = np.array([0.5 * cfg.margin / (max(cfg.tau, 1))] * len(nn_idxs)) 183 | emb_idxs = np.arange(len(nn_idxs)) 184 | 185 | return nn_idxs, emb_idxs, margins 186 | 187 | nn_idxs, emb_idxs, margins = get_pos_neg_idxs(rank_labels, fdc_point_ranks, cfg) 188 | 189 | if cfg.metric == 'L2': 190 | dists = torch.cdist(fdc_points, embs) 191 | elif cfg.metric == 'cosine': 192 | dists = 1 - torch.matmul(fdc_points, embs.transpose(1, 0)) 193 | 194 | 195 | loss = dists[nn_idxs, emb_idxs] 196 | 197 | # loss = nn.functional.relu(violation) 198 | # loss = torch.tensor([torch.sum(s) for s in torch.split(loss, split_idxs)]) 199 | if record: 200 | return torch.sum(loss) / (torch.sum(loss > 0) + 1e-7), to_np(loss) 201 | return torch.sum(loss) / (torch.sum(loss > 0) + 1e-7) 202 | 203 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | from copy import deepcopy 5 | 6 | import numpy as np 7 | import wandb 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn as nn 11 | import torch.backends.cudnn as cudnn 12 | import matplotlib 13 | matplotlib.use('Agg') 14 | import matplotlib.pyplot as plt 15 | from collections import defaultdict 16 | 17 | from config.basic import ConfigBasic 18 | from utils.util import write_log, get_current_time, to_np, make_dir, log_configs, save_ckpt, set_wandb 19 | from utils.util import adjust_learning_rate, AverageMeter, ClassWiseAverageMeter, cls_accuracy, extract_embs, print_eval_result_by_groups_and_k 20 | from utils.loss_util import compute_order_loss, compute_metric_loss, compute_center_loss 21 | from utils.comparison_utils import find_kNN 22 | from networks.util import prepare_model 23 | from data.get_datasets_tr_OLbasic_val_NN import get_datasets 24 | 25 | 26 | def set_local_config(cfg): 27 | # Dataset 28 | cfg.dataset = 'morph' 29 | cfg.setting = 'D' 30 | cfg.fold = 4 31 | 32 | cfg.logscale = False 33 | cfg.set_dataset() 34 | cfg.tau = 1 35 | 36 | # Model 37 | cfg.model = 'GOL' 38 | cfg.backbone = 'vgg16v2norm' 39 | cfg.metric = 'L2' 40 | cfg.k = np.arange(2, 60, 2) 41 | cfg.epochs = 100 42 | cfg.scheduler = 'cosine' 43 | cfg.lr_decay_epochs = [100, 200, 300] 44 | cfg.period = 3 45 | 46 | cfg.margin = 0.25 47 | cfg.ref_mode = 'flex' 48 | cfg.ref_point_num = 60 # 60 Fold1, 58 Fold0 setting D // 56 setting c // 58 setting B // 55 setting A 49 | cfg.drct_wieght = 1 50 | cfg.start_norm = True 51 | cfg.learning_rate = 0.0001 52 | 53 | # Log 54 | cfg.wandb = False 55 | cfg.experiment_name = 'EXP_NAME' 56 | cfg.save_folder = f'../../RESULT_FOLDER_NAME/{cfg.dataset}/setting{cfg.setting}/{cfg.experiment_name}/PREFIX_{cfg.margin}_tau{cfg.tau}_F{cfg.fold}_{cfg.model}_{cfg.backbone}_{get_current_time()}' 57 | make_dir(cfg.save_folder) 58 | 59 | cfg.n_gpu = torch.cuda.device_count() 60 | cfg.num_workers = 1 61 | return cfg 62 | 63 | 64 | def main(): 65 | np.random.seed(999) 66 | 67 | cfg = ConfigBasic() 68 | cfg = set_local_config(cfg) 69 | cfg.logfile = log_configs(cfg, log_file='train_log.txt') 70 | 71 | # dataloader 72 | loader_dict = get_datasets(cfg) 73 | cfg.n_ranks = loader_dict['train'].dataset.ranks.max() + 1 74 | print(f'[*] {cfg.n_ranks} ranks exist. ') 75 | 76 | # model 77 | model = prepare_model(cfg) 78 | if cfg.wandb: 79 | set_wandb(cfg) 80 | wandb.watch(model) 81 | 82 | if cfg.adam: 83 | optimizer = optim.Adam(model.parameters(), 84 | lr=cfg.learning_rate, 85 | weight_decay=cfg.weight_decay) 86 | else: 87 | optimizer = optim.SGD(model.parameters(), 88 | lr=cfg.learning_rate, 89 | momentum=cfg.momentum, 90 | weight_decay=cfg.weight_decay) 91 | if cfg.scheduler == 'cosine': 92 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.epochs, eta_min=cfg.learning_rate*0.001) 93 | elif cfg.scheduler == 'multistep': 94 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.lr_decay_epochs, gamma=cfg.lr_decay_rate) 95 | 96 | # criterion = torch.nn.CrossEntropyLoss(reduction='none') 97 | 98 | if torch.cuda.is_available(): 99 | if cfg.n_gpu > 1: 100 | model = nn.DataParallel(model) 101 | model = model.cuda() 102 | # criterion = criterion.cuda() 103 | cudnn.benchmark = True 104 | 105 | val_mae_best = 2.9 106 | log_dict = dict() 107 | # init loss matrix 108 | loss_record = dict() 109 | loss_record['angle'] = [np.zeros([cfg.n_ranks, cfg.n_ranks]), np.zeros([cfg.n_ranks, cfg.n_ranks])] 110 | # loss_record['angle'] = [np.zeros([n_ranks, ]), np.zeros([n_ranks, ])] 111 | 112 | for epoch in range(cfg.epochs): 113 | print("==> training...") 114 | 115 | time1 = time.time() 116 | train_loss, loss_record = train(epoch, loader_dict['train'], model, optimizer, cfg, 117 | prev_loss_record=loss_record) 118 | 119 | 120 | if cfg.scheduler: 121 | scheduler.step() 122 | time2 = time.time() 123 | print('epoch {}, loss {:.4f}, total time {:.2f}'.format(epoch, train_loss, time2 - time1)) 124 | 125 | if epoch % cfg.val_freq == 0: 126 | print('==> validation...') 127 | val_mae, best_k = validate(loader_dict, model, cfg) 128 | if val_mae < val_mae_best: 129 | val_mae_best = val_mae 130 | save_ckpt(cfg, model, f'ep_{epoch}_val_best_{val_mae:.3f}_k{best_k}.pth') 131 | 132 | # if train_acc > best: 133 | # best = train_acc 134 | # save_ckpt(cfg, model, f'ep_{epoch}_train_best_{best:.3f}.pth') 135 | 136 | elif epoch % cfg.save_freq == 0: 137 | save_ckpt(cfg, model, f'ep_{epoch}.pth') 138 | 139 | if cfg.wandb: 140 | log_dict['Epoch'] = epoch 141 | log_dict['Train Loss'] = train_loss 142 | log_dict['Val Mae'] = val_mae 143 | log_dict['LR'] = scheduler.get_lr()[0] if scheduler else cfg.learning_rate 144 | wandb.log(log_dict) 145 | 146 | print('[*] Training ends') 147 | 148 | 149 | def update_loss_matrix(A, loss, base_ranks, ref_ranks=None): 150 | batch_size = len(base_ranks) 151 | if ref_ranks is not None: 152 | for i in range(batch_size): 153 | A[0][base_ranks[i], ref_ranks[i]] += loss[i] 154 | A[1][base_ranks[i], ref_ranks[i]] += 1 155 | else: 156 | for i in range(batch_size): 157 | A[0][base_ranks[i]] += loss[i] 158 | A[1][base_ranks[i]] += 1 159 | return A 160 | 161 | 162 | def get_pairs_equally(ranks, tau, m=32): 163 | orders = [] 164 | base_idx = [] 165 | ref_idx = [] 166 | N = len(ranks) 167 | for i in range(N): 168 | for j in range(i+1, N): 169 | if np.random.rand(1) > 0.5: 170 | base_idx.append(i) 171 | ref_idx.append(j) 172 | order_ij = get_order_labels(ranks[i], ranks[j], tau) 173 | orders.append(order_ij) 174 | else: 175 | base_idx.append(j) 176 | ref_idx.append(i) 177 | order_ji = get_order_labels(ranks[j], ranks[i], tau) 178 | orders.append(order_ji) 179 | refine = [] 180 | orders = np.array(orders) 181 | for o in range(3): 182 | o_idxs = np.argwhere(orders==o).flatten() 183 | if len(o_idxs) > m: 184 | sel = np.random.choice(o_idxs, m, replace=False) 185 | refine.append(sel) 186 | else: 187 | refine.append(o_idxs) 188 | refine = np.concatenate(refine) 189 | base_idx = np.array(base_idx)[refine] 190 | ref_idx = np.array(ref_idx)[refine] 191 | orders = orders[refine] 192 | return base_idx, ref_idx, orders 193 | 194 | 195 | def get_order_labels(rank_base, rank_ref, tau): 196 | if rank_base > rank_ref + tau: 197 | order = 0 198 | elif rank_base < rank_ref - tau: 199 | order = 1 200 | else: 201 | order = 2 202 | return order 203 | 204 | 205 | def train(epoch, train_loader, model, optimizer, cfg, prev_loss_record): 206 | """One epoch training""" 207 | model.train() 208 | 209 | batch_time = AverageMeter() 210 | data_time = AverageMeter() 211 | losses = AverageMeter() 212 | angle_losses = AverageMeter() 213 | dist_losses = AverageMeter() 214 | center_losses = AverageMeter() 215 | angle_acc_meter = ClassWiseAverageMeter(2) 216 | # dist_acc_meter = ClassWiseAverageMeter(2) 217 | 218 | loss_record = deepcopy(prev_loss_record) 219 | end = time.time() 220 | for idx, (x_base, x_ref, _, ranks, _) in enumerate(train_loader): 221 | 222 | labels_np = torch.cat(ranks).detach().numpy() 223 | 224 | base_idx, ref_idx, order_labels = get_pairs_equally(labels_np, cfg.tau) 225 | 226 | if torch.cuda.is_available(): 227 | x_base = x_base.cuda() 228 | x_ref = x_ref.cuda() 229 | 230 | # order_labels = order_labels.cuda() 231 | data_time.update(time.time() - end) 232 | 233 | # ===================forward===================== 234 | embs = model.encoder(torch.cat([x_base, x_ref], dim=0)) 235 | 236 | # =====================loss====================== 237 | tic = time.time() 238 | dist_loss = compute_metric_loss(embs, base_idx, ref_idx, labels_np, model.ref_points, cfg.margin, cfg) 239 | dist_loss_time = time.time() - tic 240 | tic = time.time() 241 | angle_loss, logits, order_gt = compute_order_loss(embs, base_idx, ref_idx, labels_np, model.ref_points, cfg) 242 | angle_loss_time = time.time() - tic 243 | center_loss = compute_center_loss(embs, labels_np, model.ref_points, cfg) 244 | 245 | 246 | total_loss = (cfg.drct_wieght * angle_loss) + dist_loss + center_loss 247 | losses.update(total_loss.item(), x_base.size(0)) 248 | angle_losses.update(angle_loss.item(), x_base.size(0)) 249 | dist_losses.update(dist_loss.item(), x_base.size(0)) 250 | center_losses.update(center_loss.item(), x_base.size(0)) 251 | # ===================backward===================== 252 | optimizer.zero_grad() 253 | total_loss.backward() 254 | optimizer.step() 255 | 256 | acc, cnt = cls_accuracy(nn.functional.softmax(logits, dim=-1), order_gt, n_cls=2) 257 | # dist_acc, dist_cnt = cls_accuracy(nn.functional.softmax(dist_logits, dim=-1), dist_gt, n_cls=2) 258 | 259 | angle_acc_meter.update(acc, cnt) 260 | # dist_acc_meter.update(dist_acc, dist_cnt) 261 | 262 | # ===================meters===================== 263 | batch_time.update(time.time() - end) 264 | end = time.time() 265 | 266 | # update loss matrix 267 | # loss_record['angle'] = update_loss_matrix(loss_record['angle'], to_np(angle_loss), labels_np[base_idx], labels_np[ref_idx]) 268 | 269 | # print info 270 | if idx % cfg.print_freq == 0: 271 | write_log(cfg.logfile, 272 | f'Epoch [{epoch}][{idx}/{len(train_loader)}]\t' 273 | f'Time {batch_time.val:.3f}\t' 274 | f'Data {data_time.val:3f}\t' 275 | f'Loss {losses.val:.4f}\t' 276 | f'Angle-Loss {angle_losses.val:.4f}\t' 277 | f'Dist-Loss {dist_losses.val:.4f}\t' 278 | f'Center-Loss {center_losses.val:.4f}\t' 279 | f'Angle-Acc [{angle_acc_meter.val[0]:.3f} {angle_acc_meter.val[1]:.3f}] [{angle_acc_meter.total_avg:.3f}]\t' 280 | ) 281 | sys.stdout.flush() 282 | write_log(cfg.logfile, f' * Angle-Acc [{angle_acc_meter.avg[0]:.3f} {angle_acc_meter.avg[1]:.3f}] [{angle_acc_meter.total_avg:.3f}]\n') 283 | # write_log(cfg.logfile, f' * Dist-Acc [{dist_acc_meter.avg[0]:.3f} {dist_acc_meter.avg[1]:.3f}] [{dist_acc_meter.total_avg:.3f}]\n') 284 | 285 | return losses.avg, loss_record 286 | 287 | 288 | def validate(loader_dict, model, cfg): 289 | model.eval() 290 | data_time = AverageMeter() 291 | 292 | embs_train = extract_embs(model.encoder, loader_dict['train_for_val']) 293 | embs_train = embs_train.cuda() 294 | 295 | embs_test = extract_embs(model.encoder, loader_dict['val']) 296 | embs_test = embs_test.cuda() 297 | n_test = len(embs_test) 298 | n_batch = int(np.ceil(n_test / cfg.batch_size)) 299 | test_labels = loader_dict['val'].dataset.labels 300 | train_labels = loader_dict['train_for_val'].dataset.labels 301 | 302 | preds_all = defaultdict(list) 303 | 304 | with torch.no_grad(): 305 | end = time.time() 306 | for idx in range(n_batch): 307 | data_time.update(time.time() - end) 308 | i_st = idx * cfg.batch_size 309 | i_end = min(i_st + cfg.batch_size, n_test) 310 | 311 | # ===================meters===================== 312 | vals, inds = find_kNN(embs_test[i_st:i_end].view(i_end - i_st, -1), embs_train, k=max(cfg.k), 313 | metric=cfg.metric) 314 | inds = np.squeeze(to_np(inds), 0) 315 | for k in cfg.k: 316 | nn_labels = train_labels[inds[:, :k]] 317 | pred_mean = np.round(np.mean(nn_labels, axis=-1, dtype=np.float32)) 318 | preds_all[k].append(pred_mean) 319 | 320 | for key in preds_all.keys(): 321 | preds_all[key] = np.concatenate(preds_all[key]) 322 | 323 | best_mae, best_k = print_eval_result_by_groups_and_k(test_labels, train_labels, preds_all, cfg.logfile, interval=3) 324 | acc = np.sum(test_labels==preds_all[best_k])/len(test_labels) 325 | write_log(cfg.logfile, f'Acc : {acc*100:.2f}') 326 | sys.stdout.flush() 327 | return best_mae, best_k 328 | 329 | 330 | if __name__ == "__main__": 331 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 332 | main() 333 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from sklearn.metrics import confusion_matrix 7 | from datetime import datetime 8 | import wandb 9 | from copy import deepcopy 10 | 11 | class LabelSmoothing(nn.Module): 12 | """ 13 | NLL loss with label smoothing. 14 | """ 15 | def __init__(self, smoothing=0.0): 16 | """ 17 | Constructor for the LabelSmoothing module. 18 | :param smoothing: label smoothing factor 19 | """ 20 | super(LabelSmoothing, self).__init__() 21 | self.confidence = 1.0 - smoothing 22 | self.smoothing = smoothing 23 | 24 | def forward(self, x, target): 25 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 26 | 27 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 28 | nll_loss = nll_loss.squeeze(1) 29 | smooth_loss = -logprobs.mean(dim=-1) 30 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 31 | return loss.mean() 32 | 33 | 34 | class BCEWithLogitsLoss(nn.Module): 35 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None, num_classes=64): 36 | super(BCEWithLogitsLoss, self).__init__() 37 | self.num_classes = num_classes 38 | self.criterion = nn.BCEWithLogitsLoss(weight=weight, 39 | size_average=size_average, 40 | reduce=reduce, 41 | reduction=reduction, 42 | pos_weight=pos_weight) 43 | def forward(self, input, target): 44 | target_onehot = F.one_hot(target, num_classes=self.num_classes) 45 | return self.criterion(input, target_onehot) 46 | 47 | 48 | class AverageMeter(object): 49 | """Computes and stores the average and current value""" 50 | def __init__(self): 51 | self.reset() 52 | 53 | def reset(self): 54 | self.val = 0 55 | self.avg = 0 56 | self.sum = 0 57 | self.count = 0 58 | 59 | def update(self, val, n=1): 60 | self.val = val 61 | self.sum += val * n 62 | self.count += n 63 | self.avg = self.sum / self.count 64 | 65 | 66 | class ClassWiseAverageMeter(object): 67 | """Computes and stores the average and current value""" 68 | def __init__(self, n_cls): 69 | self.n_cls = n_cls 70 | self.reset() 71 | 72 | def reset(self): 73 | self.val = np.zeros([self.n_cls,]) 74 | self.avg = np.zeros([self.n_cls,]) 75 | self.sum = np.zeros([self.n_cls,]) 76 | self.count = np.ones([self.n_cls,]) * 1e-7 77 | self.total_avg = 0 78 | 79 | def update(self, val, n=[1,1,1]): 80 | self.val = val 81 | self.sum += val * n 82 | self.count += n 83 | self.avg = self.sum / self.count 84 | self.total_avg = np.sum(self.sum) / np.sum(self.count) 85 | 86 | 87 | def adjust_learning_rate(epoch, opt, optimizer): 88 | """Sets the learning rate to the initial LR decayed by decay rate every steep step""" 89 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs)) 90 | if steps > 0: 91 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps) 92 | for param_group in optimizer.param_groups: 93 | param_group['lr'] = new_lr 94 | 95 | 96 | def accuracy(output, target, topk=(1,)): 97 | """Computes the accuracy over the k top predictions for the specified values of k""" 98 | with torch.no_grad(): 99 | maxk = max(topk) 100 | batch_size = target.size(0) 101 | 102 | _, pred = output.topk(maxk, 1, True, True) 103 | pred = pred.t() 104 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 105 | 106 | res = [] 107 | for k in topk: 108 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 109 | res.append(correct_k.mul_(100.0 / batch_size)) 110 | return res 111 | 112 | 113 | def cls_accuracy(output, target, n_cls=3): 114 | with torch.no_grad(): 115 | _, pred = output.topk(1, 1, True, True) 116 | pred = pred.view(-1) 117 | correct = pred.eq(target).cpu().numpy() 118 | accs = np.zeros([n_cls,]) 119 | cnts = np.ones([n_cls,]) * 1e-5 120 | target = target.cpu().numpy() 121 | for i_cls in range(n_cls): 122 | i_cls_idx = np.argwhere(target == i_cls).flatten() 123 | if len(i_cls_idx) > 0: 124 | cnts[i_cls] = len(i_cls_idx) 125 | accs[i_cls] = np.sum(correct[i_cls_idx])/len(i_cls_idx)*100 126 | 127 | return accs, cnts 128 | 129 | 130 | def cls_accuracy_bc(output, target, cls=[0,1,2], delta=0.1): 131 | with torch.no_grad(): 132 | accs = np.zeros([3, ]) 133 | cnts = np.ones([3,])* 1e-7 134 | _, pred = output.topk(1, 1, True, True) 135 | pred = pred.view(-1) 136 | correct = pred.eq(target).cpu().numpy() 137 | for i in range(len(target)): 138 | if target[i] == cls[0]: 139 | accs[0] += correct[i] 140 | cnts[0] += 1 141 | elif target[i] == cls[1]: 142 | accs[1] += correct[i] 143 | cnts[1] += 1 144 | elif target[i] == cls[2]: 145 | i_correct = np.abs(output[i][0].cpu().numpy() - 0.5) < delta 146 | accs[2] += i_correct 147 | cnts[2] += 1 148 | else: 149 | raise ValueError(f'Out of range error! {target[i]} is given') 150 | accs = accs/ cnts *100 151 | return accs, cnts 152 | 153 | 154 | def get_confusion_matrix_bc(output, target, cls=[-1,0,1], delta=0.1): 155 | with torch.no_grad(): 156 | _, pred = output.topk(1, 1, True, True) 157 | pred = pred.view(-1).cpu().numpy() 158 | 159 | for i in range(len(target)): 160 | if target[i] == cls[0]: 161 | if np.abs(output[i][0].cpu().numpy()-0.5) < delta: 162 | pred[i] = -1 163 | else: 164 | continue 165 | 166 | pred = np.transpose(pred) 167 | cm = confusion_matrix(target.cpu().numpy(), pred) 168 | 169 | return cm, np.diag(cm)/np.sum(cm, axis=-1) 170 | 171 | 172 | def get_confusion_matrix(output, target): 173 | with torch.no_grad(): 174 | _, pred = output.topk(1, 1, True, True) 175 | pred = pred.t() 176 | cm = confusion_matrix(target.cpu().numpy(), pred.cpu().numpy()) 177 | 178 | return cm, np.diag(cm)/np.sum(cm, axis=-1) 179 | 180 | 181 | def split_weights(net): 182 | """split network weights into to categlories, 183 | one are weights in conv layer and linear layer, 184 | others are other learnable paramters(conv bias, 185 | bn weights, bn bias, linear bias) 186 | Args: 187 | net: network architecture 188 | 189 | Returns: 190 | a dictionary of params splite into to categlories 191 | """ 192 | 193 | decay = [] 194 | no_decay = [] 195 | 196 | for m in net.modules(): 197 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 198 | decay.append(m.weight) 199 | 200 | if m.bias is not None: 201 | no_decay.append(m.bias) 202 | 203 | else: 204 | if hasattr(m, 'weight'): 205 | no_decay.append(m.weight) 206 | if hasattr(m, 'bias'): 207 | no_decay.append(m.bias) 208 | 209 | assert len(list(net.parameters())) == len(decay) + len(no_decay) 210 | 211 | return [dict(params=decay), dict(params=no_decay, weight_decay=0)] 212 | 213 | 214 | def write_log(log_file, out_str): 215 | log_file.write(out_str + '\n') 216 | log_file.flush() 217 | print(out_str) 218 | 219 | 220 | def cross_entropy_loss_with_one_hot_labels(logits, labels): 221 | log_probs = nn.functional.log_softmax(logits, dim=1) 222 | loss = -torch.sum(log_probs*labels, dim=1) 223 | return loss.mean() 224 | 225 | 226 | def cross_entropy_loss_with_one_hot_labels_with_weights(logits, labels, weights): 227 | log_probs = nn.functional.log_softmax(logits, dim=1) 228 | loss = -torch.sum(log_probs*labels, dim=1) * weights 229 | return loss.mean() 230 | 231 | 232 | def mix_ce_and_kl_loss(logits, labels, mask, alpha=1): 233 | inv_mask = mask.__invert__() 234 | log_probs = nn.functional.log_softmax(logits, dim=1) 235 | ce_loss = -torch.sum(log_probs[mask]*labels[mask], dim=1) 236 | kl_loss = torch.nn.KLDivLoss(reduction='batchmean')(log_probs[inv_mask], labels[inv_mask]) 237 | loss = ce_loss.mean() + alpha*kl_loss 238 | return loss 239 | 240 | 241 | def load_one_image(img_path, width=256, height=256): 242 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 243 | img = cv2.resize(img, (width, height)) 244 | return img 245 | 246 | 247 | def load_images(img_root, img_name_list, width=256, height=256): 248 | num_images = len(img_name_list) 249 | images = np.zeros([num_images, height, width, 3], dtype=np.uint8) 250 | for idx, img_path in enumerate(img_name_list): 251 | img = cv2.imread(os.path.join(img_root, img_path), cv2.IMREAD_COLOR) 252 | images[idx] = cv2.resize(img, (width, height)) 253 | return images 254 | 255 | def to_np(x): 256 | return x.cpu().detach().numpy() 257 | 258 | 259 | def get_current_time(): 260 | _now = datetime.now() 261 | _now = str(_now)[:-7] 262 | return _now 263 | 264 | 265 | def display_lr(optimizer): 266 | for param_group in optimizer.param_groups: 267 | print(param_group['lr'], param_group['initial_lr']) 268 | 269 | 270 | def get_distribution(data): 271 | cls, cnt = np.unique(data, return_counts=True) 272 | for i_cls, i_cnt in zip(cls, cnt): 273 | print(f'{i_cls}: {i_cnt} ({i_cnt/len(data)*100:.2f}%)') 274 | print(f'total: {len(data)}') 275 | 276 | 277 | def make_dir(path): 278 | if not os.path.isdir(path): 279 | os.makedirs(path) 280 | 281 | 282 | def log_configs(cfg, log_file='log.txt'): 283 | if os.path.exists(f'{cfg.save_folder}/{log_file}'): 284 | log_file = open(f'{cfg.save_folder}/{log_file}', 'a') 285 | else: 286 | log_file = open(f'{cfg.save_folder}/{log_file}', 'w') 287 | opt_dict = vars(cfg) 288 | for key in opt_dict.keys(): 289 | write_log(log_file, f'{key}: {opt_dict[key]}') 290 | return log_file 291 | 292 | 293 | def save_ckpt(cfg, model, postfix): 294 | state = { 295 | 'model': model.state_dict() if cfg.n_gpu <= 1 else model.module.state_dict(), 296 | } 297 | save_file = os.path.join(cfg.save_folder, f'{postfix}') 298 | torch.save(state, save_file) 299 | print(f'ckpt saved to {save_file}.') 300 | 301 | 302 | def set_wandb(cfg, key='private_key'): 303 | wandb.login(key=key) 304 | wandb.init(project=cfg.experiment_name, tags=[cfg.dataset]) 305 | wandb.config.update(cfg) 306 | wandb.save('*.py') 307 | wandb.run.save() 308 | 309 | 310 | def extract_embs(encoder, data_loader): 311 | encoder.eval() 312 | embs = [] 313 | inds = [] 314 | with torch.no_grad(): 315 | for x_base, _, item in data_loader: 316 | x_base = x_base.cuda() 317 | embs.append(encoder(x_base).cpu()) 318 | inds.append(item) 319 | embs = torch.cat(embs) 320 | inds = torch.cat(inds) 321 | embs_temp = deepcopy(embs) 322 | embs[inds] = embs_temp 323 | 324 | return embs 325 | 326 | 327 | def to_dtype(x, tensor=None, dtype=None): 328 | if not torch.is_autocast_enabled(): 329 | dt = dtype if dtype is not None else tensor.dtype 330 | if x.dtype != dt: 331 | x = x.type(dt) 332 | return x 333 | 334 | def to_device(x, tensor=None, device=None, dtype=None): 335 | dv = device if device is not None else tensor.device 336 | if x.device != dv: 337 | x = x.to(dv) 338 | if dtype is not None: 339 | x = to_dtype(x, dtype=dtype) 340 | return x 341 | 342 | 343 | def print_eval_result_by_groups_and_k(gt, ref_gt, preds_all, log_file, interval=10): 344 | test_cls_arr, cnt = np.unique(gt, return_counts=True) 345 | test_cls_min = test_cls_arr.min() 346 | test_cls_max = test_cls_arr.max() 347 | n_groups = int((test_cls_max - test_cls_min + 1) / interval + 0.5) 348 | 349 | title = 'Group \\ K |' 350 | for k in preds_all.keys(): 351 | title += f" {k:<4} " 352 | title = title + ' | Best K | #Test | #Train ' 353 | write_log(log_file, title) 354 | for i_group in range(n_groups): 355 | min_rank = interval * i_group 356 | max_rank = min(test_cls_max + 1, min_rank + interval) 357 | sample_idx_in_group = np.argwhere(np.logical_and(gt >= min_rank, gt < max_rank)).flatten() 358 | ref_sample_idx_in_group = np.argwhere(np.logical_and(ref_gt >= min_rank, ref_gt < max_rank)).flatten() 359 | 360 | if len(sample_idx_in_group) < 1: 361 | continue 362 | to_print = f' {min_rank:<3}~ {max_rank - 1:<3} |' 363 | 364 | best_k = -1 365 | best_mae = 1000 366 | for k in preds_all.keys(): 367 | i_group_errors_at_k = np.abs(preds_all[k][sample_idx_in_group] - gt[sample_idx_in_group]) 368 | i_group_mean_at_k = np.mean(i_group_errors_at_k) 369 | to_print += f' {i_group_mean_at_k:.3f}' if i_group_mean_at_k<10 else f' {i_group_mean_at_k:.2f}' 370 | if i_group_mean_at_k < best_mae: 371 | best_mae = i_group_mean_at_k 372 | best_k = k 373 | to_print += f' | {best_k:<2} | {len(sample_idx_in_group):<4} | {len(ref_sample_idx_in_group):<4} ' 374 | write_log(log_file, to_print) 375 | 376 | mean_all = ' Total |' 377 | best_k = -1 378 | best_mae = 1000 379 | for k in preds_all.keys(): 380 | mean_at_k = np.mean(np.abs(preds_all[k] - gt)) 381 | mean_all += f' {mean_at_k:.3f}' 382 | if mean_at_k < best_mae: 383 | best_mae = mean_at_k 384 | best_k = k 385 | mean_all += f' | {best_k:<2} | {len(gt):<5} | {len(ref_gt):<5}' 386 | write_log(log_file, mean_all) 387 | write_log(log_file, f'Best Total MAE : {best_mae:.3f}\n') 388 | return best_mae, best_k 389 | 390 | 391 | 392 | def sample_fdcs(model, fdc_pts, train_labels, cfg): 393 | to_select = np.unique(train_labels) 394 | model.select_reference_points(to_select.astype(np.int32), fdc_pts) 395 | cfg.fiducial_point_num = len(to_select) 396 | return model, cfg 397 | -------------------------------------------------------------------------------- /utils/comparison_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | # ==================================================================================================================== # 5 | # compute order # 6 | # ==================================================================================================================== # 7 | # --- functions for ternary order 8 | def compute_ternary_order_tau(base_rank, ref_rank, tau=1e-4): 9 | if base_rank - ref_rank > tau: 10 | order = 0 11 | elif abs(base_rank - ref_rank) <= tau: 12 | order = 2 ################################################## CHECK LATER!!!! 13 | elif base_rank - ref_rank < -tau: 14 | order = 1 15 | else: 16 | raise ValueError(f'order relation is wrong. (base,ref,tau): {base_rank, ref_rank, tau}') 17 | return order 18 | 19 | 20 | def compute_ternary_order_fixed_base(base_rank, ref_ranks, tau=1e-4): 21 | # base score is fixed. 22 | num_scores = len(ref_ranks) 23 | orders = np.zeros((num_scores,), dtype=np.int32) 24 | for idx in range(num_scores): 25 | orders[idx] = compute_ternary_order_tau(base_rank, ref_ranks[idx], tau) 26 | return orders 27 | 28 | 29 | def compute_ternary_order_fixed_ref(ref_rank, base_ranks, tau=1e-4): 30 | # ref score is fixed. 31 | num_scores = len(base_ranks) 32 | orders = np.zeros((num_scores,), dtype=np.int32) 33 | for idx in range(num_scores): 34 | orders[idx] = compute_ternary_order_tau(base_ranks[idx], ref_rank, tau) 35 | return orders 36 | 37 | 38 | # --- functions for binary order 39 | def compute_binary_order(base_rank, ref_rank): 40 | if base_rank > ref_rank: 41 | order = 0 42 | elif base_rank < ref_rank: 43 | order = 1 44 | elif base_rank == ref_rank: 45 | order = 2 46 | else: 47 | raise ValueError(f'order relation is wrong. (base,ref,tau): {base_rank, ref_rank}') 48 | return order 49 | 50 | 51 | def compute_binary_order_fixed_base(base_rank, ref_ranks): 52 | # base score is fixed. 53 | num_scores = len(ref_ranks) 54 | orders = np.zeros((num_scores,), dtype=np.int32) 55 | for idx in range(num_scores): 56 | orders[idx] = compute_binary_order(base_rank, ref_ranks[idx]) 57 | return orders 58 | 59 | 60 | def compute_binary_order_fixed_ref(ref_rank, base_ranks): 61 | # ref score is fixed. 62 | num_scores = len(base_ranks) 63 | orders = np.zeros((num_scores,), dtype=np.int32) 64 | for idx in range(num_scores): 65 | orders[idx] = compute_binary_order(base_ranks[idx], ref_rank) 66 | return orders 67 | 68 | # ==================================================================================================================== # 69 | # estimation method (saaty, voting) # 70 | # ==================================================================================================================== # 71 | def one_step_voting_ternary(orders, ranks, rank_levels, tau=1e-4): 72 | num_refs = len(orders) 73 | votes = np.zeros_like(rank_levels, dtype=np.int32) 74 | 75 | for idx in range(num_refs): 76 | # compute where to vote 77 | order = orders[idx] 78 | if order == 0: 79 | try: 80 | min_idx = np.argwhere(rank_levels > (ranks[idx] + tau))[0, 0] 81 | except: 82 | min_idx = -1 83 | max_idx = len(rank_levels) 84 | elif order == 2: ############ CHECK LATER!!!!! 85 | try: 86 | min_idx = np.argwhere(rank_levels >= (ranks[idx] - tau))[0, 0] 87 | except: 88 | min_idx = -1 89 | 90 | max_idx = np.argwhere(rank_levels <= (ranks[idx] + tau))[-1, 0] + 1 91 | 92 | elif order == 1: ############ CHECK LATER!!!!! 93 | min_idx = 0 94 | try: 95 | max_idx = np.argwhere(rank_levels < (ranks[idx] - tau))[-1, 0] + 1 96 | except: 97 | max_idx = 0 98 | else: 99 | raise ValueError(f'order value is out of range: {order}') 100 | 101 | # voting 102 | votes[min_idx:max_idx] += 1 103 | winners = np.argwhere(votes == np.amax(votes)) 104 | # elected_index = winners[(len(winners)/2)][0] # take the middle value when multiple winners exist. 105 | elected_index = winners[0][0] # take the min value when multiple winners exist. 106 | return rank_levels[elected_index], votes 107 | 108 | 109 | def one_step_voting_binary(orders, ranks, rank_levels, tau=0.0): 110 | num_refs = len(orders) 111 | votes = np.zeros_like(rank_levels, dtype=np.int32) 112 | votes_for_sum = np.zeros_like(rank_levels, dtype=np.float32) 113 | 114 | for idx in range(num_refs): 115 | # compute where to vote 116 | order = orders[idx] 117 | if order == 0: 118 | try: 119 | min_idx = np.argwhere(rank_levels >= (ranks[idx] + tau))[0, 0] 120 | except: 121 | min_idx = -1 122 | max_idx = len(rank_levels) 123 | 124 | elif order == 1: 125 | min_idx = 0 126 | try: 127 | max_idx = np.argwhere(rank_levels < (ranks[idx] - tau))[-1, 0] + 1 128 | except: 129 | max_idx = 0 130 | else: 131 | raise ValueError(f'order value is out of range: {order}') 132 | 133 | # voting 134 | votes[min_idx:max_idx] += 1 135 | # votes_for_sum[min_idx:max_idx] += 1/(len(rank_levels) - min_idx) 136 | winners = np.argwhere(votes == np.amax(votes)) 137 | elected_index = winners[int((len(winners)/2))][0] # take the middle value when multiple winners exist. 138 | # elected_index = winners[0][0] # take the min value when multiple winners exist. 139 | return rank_levels[elected_index], votes 140 | 141 | 142 | def soft_voting_ternary(probs, ranks, rank_levels, tau=1e-4): 143 | num_refs = len(probs) 144 | rank_levels = rank_levels.astype(np.float32) 145 | p_x = np.zeros_like(rank_levels) 146 | 147 | for i_ref, ref_score in enumerate(ranks): 148 | cond_p_per_levels = np.zeros((len(rank_levels), 3)) 149 | cond_probs = _conditional_probs_uniform_ternary(ref_score, tau, rank_levels) 150 | order_per_levels = compute_ternary_order_fixed_ref(ref_score, rank_levels, tau) 151 | for i_level, order in enumerate(order_per_levels): 152 | if order == 1: 153 | cond_p_per_levels[i_level, order] = cond_probs[order] 154 | else: 155 | cond_p_per_levels[i_level, order] = cond_probs[order] 156 | p_x += np.matmul(cond_p_per_levels, probs[i_ref]) 157 | p_x = p_x / num_refs 158 | max_idx = np.argmax(p_x) 159 | # plt.scatter(score_levels, p_x) 160 | # plt.xticks(score_levels) 161 | # plt.grid() 162 | 163 | # for binary classification : summation method 164 | low_scores = np.squeeze(np.argwhere(rank_levels < 5.0)) 165 | high_scores = np.squeeze(np.argwhere(rank_levels >= 5.0)) 166 | if np.sum(p_x[low_scores]) <= np.sum(p_x[high_scores]): 167 | pred_by_sum = 0 168 | else: 169 | pred_by_sum = 1 170 | 171 | return rank_levels[max_idx], np.sum(rank_levels * p_x), pred_by_sum 172 | 173 | 174 | def soft_voting_binary(probs, ranks, rank_levels, tau=0.0): 175 | num_refs = len(probs) 176 | rank_levels = rank_levels.astype(np.float32) 177 | p_x = np.zeros_like(rank_levels) 178 | 179 | for i_ref, ref_score in enumerate(ranks): 180 | cond_p_per_levels = np.zeros((len(rank_levels), 2)) 181 | cond_probs = _conditional_probs_uniform_binary(ref_score, rank_levels, tau=tau) 182 | order_per_levels = compute_binary_order_fixed_ref(ref_score, rank_levels) 183 | for i_level, order in enumerate(order_per_levels): 184 | cond_p_per_levels[i_level, order] = cond_probs[order] 185 | p_x += np.matmul(cond_p_per_levels, probs[i_ref]) 186 | 187 | # normalize the sum of probs to be 1.0 188 | p_x = p_x / num_refs 189 | max_idx = np.argmax(p_x) 190 | # plt.scatter(score_levels, p_x) 191 | # plt.xticks(score_levels) 192 | # plt.grid() 193 | 194 | # for binary classification : summation method 195 | low_scores = np.squeeze(np.argwhere(rank_levels < 5.0)) 196 | high_scores = np.squeeze(np.argwhere(rank_levels >= 5.0)) 197 | if np.sum(p_x[low_scores]) <= np.sum(p_x[high_scores]): 198 | pred_by_sum = 0 199 | else: 200 | pred_by_sum = 1 201 | 202 | return rank_levels[max_idx], np.sum(rank_levels * p_x), pred_by_sum 203 | 204 | 205 | def MAP_rule_binary(probs, ranks, rank_levels, tau=0.0): 206 | num_refs = len(probs) 207 | rank_levels = rank_levels.astype(np.float32) 208 | p_x = np.zeros_like(rank_levels) 209 | 210 | for i_ref, ref_score in enumerate(ranks): 211 | cond_p_per_levels = np.zeros((len(rank_levels), 2)) 212 | cond_probs = _conditional_probs_uniform_binary(ref_score, rank_levels, tau=tau) 213 | for i_level, i_rank in enumerate(rank_levels): 214 | if ref_score == i_rank: 215 | cond_p_per_levels[i_level, 0] = cond_probs[0]*(1/2) 216 | cond_p_per_levels[i_level, 1] = cond_probs[1]*(1/2) 217 | elif ref_score > i_rank: 218 | cond_p_per_levels[i_level, 1] = cond_probs[1] 219 | elif ref_score < i_rank: 220 | cond_p_per_levels[i_level, 0] = cond_probs[0] 221 | 222 | p_x += np.matmul(cond_p_per_levels, probs[i_ref]) 223 | 224 | # normalize the sum of probs to be 1.0 225 | p_x = p_x / num_refs 226 | winners = np.argwhere(p_x == np.amax(p_x)) 227 | max_idx = winners[int((len(winners) / 2))][0] 228 | # max_idx = np.argmax(p_x) 229 | # plt.scatter(score_levels, p_x) 230 | # plt.xticks(score_levels) 231 | # plt.grid() 232 | # 233 | # # for binary classification : summation method 234 | # low_scores = np.squeeze(np.argwhere(rank_levels < 5.0)) 235 | # high_scores = np.squeeze(np.argwhere(rank_levels >= 5.0)) 236 | # if np.sum(p_x[low_scores]) <= np.sum(p_x[high_scores]): 237 | # pred_by_sum = 0 238 | # else: 239 | # pred_by_sum = 1 240 | 241 | # return rank_levels[max_idx], np.sum(rank_levels * p_x), pred_by_sum 242 | return rank_levels[max_idx], np.sum(rank_levels * p_x) 243 | 244 | ### for debug 245 | 246 | 247 | def MC_and_MAP_rule(orders, probs, ranks, rank_levels, gt, tau=0.0, is_debug=False): 248 | # MC rule 249 | num_refs = len(orders) 250 | votes = np.zeros_like(rank_levels, dtype=np.int32) 251 | votes_for_sum = np.zeros_like(rank_levels, dtype=np.float32) 252 | 253 | for idx in range(num_refs): 254 | # compute where to vote 255 | order = orders[idx] 256 | if order == 0: 257 | try: 258 | min_idx = np.argwhere(rank_levels > (ranks[idx] + tau))[0, 0] 259 | except: 260 | min_idx = -1 261 | max_idx = len(rank_levels) 262 | 263 | elif order == 1: 264 | min_idx = 0 265 | try: 266 | max_idx = np.argwhere(rank_levels < (ranks[idx] - tau))[-1, 0] + 1 267 | except: 268 | max_idx = 0 269 | elif order == 2: 270 | min_idx = np.argwhere(rank_levels==ranks[idx]) 271 | max_idx = np.argwhere(rank_levels==ranks[idx]) 272 | else: 273 | raise ValueError(f'order value is out of range: {order}') 274 | 275 | # voting 276 | if order == 2: 277 | votes[min_idx] += 1 278 | else: 279 | votes[min_idx:max_idx] += 1 280 | # eq_idx = np.argwhere(rank_levels == ranks[idx]).flatten()[0] 281 | # votes[eq_idx] += 0.5 282 | # votes_for_sum[min_idx:max_idx] += 1/(len(rank_levels) - min_idx) 283 | winners = np.argwhere(votes == np.amax(votes)) 284 | elected_index = winners[int((len(winners) / 2))][0] # take the middle value when multiple winners exist. 285 | # elected_index = winners[0][0] # take the min value when multiple winners exist. 286 | MC_estimation = rank_levels[elected_index] 287 | 288 | # MAP rule 289 | num_refs = len(probs) 290 | rank_levels = rank_levels.astype(np.float32) 291 | p_x = np.zeros_like(rank_levels) 292 | 293 | for i_ref, ref_score in enumerate(ranks): 294 | cond_p_per_levels = np.zeros((len(rank_levels), 2)) 295 | cond_probs = _conditional_probs_uniform_binary(ref_score, rank_levels, tau=tau) 296 | for i_level, i_rank in enumerate(rank_levels): 297 | if ref_score == i_rank: 298 | cond_p_per_levels[i_level, 0] = cond_probs[0] * (1 / 2) 299 | cond_p_per_levels[i_level, 1] = cond_probs[1] * (1 / 2) 300 | elif ref_score > i_rank: 301 | cond_p_per_levels[i_level, 1] = cond_probs[1] 302 | elif ref_score < i_rank: 303 | cond_p_per_levels[i_level, 0] = cond_probs[0] 304 | 305 | p_x += np.matmul(cond_p_per_levels, probs[i_ref]) 306 | 307 | # normalize the sum of probs to be 1.0 308 | p_x = p_x / num_refs 309 | winners = np.argwhere(p_x == np.amax(p_x)) 310 | max_idx = winners[int((len(winners) / 2))][0] 311 | MAP_estimation = rank_levels[max_idx] 312 | 313 | # MAP_estimation = np.sum(rank_levels * p_x) 314 | 315 | # window_size = 5 316 | # window_idx = np.argmax(np.convolve(p_x, np.ones(window_size), 'valid')) 317 | # p_x_in_window = p_x[window_idx: window_idx+window_size] 318 | # p_x_in_window = p_x_in_window / np.sum(p_x_in_window) 319 | # ranks_in_window = rank_levels[window_idx: window_idx+window_size] 320 | # MAP_estimation2 = np.sum(ranks_in_window * p_x_in_window) 321 | 322 | if abs(abs(MAP_estimation-gt) - abs(MC_estimation-gt)) > 5 and is_debug: 323 | print(f'MAP:{MAP_estimation}, MC:{MC_estimation}, GT:{gt}') 324 | print(f'MAP estimation error : {abs(MAP_estimation - gt)}') 325 | print(f'MC estimation error : {abs(MC_estimation - gt)}') 326 | 327 | return MC_estimation, MAP_estimation 328 | 329 | 330 | def _conditional_probs_uniform_ternary(ref_rank, rank_levels, tau=1e-4, assertion=False): 331 | n_high = len(np.argwhere(rank_levels > (ref_rank + tau))) 332 | n_similar = len(np.argwhere(np.logical_and((ref_rank - tau) <= rank_levels, rank_levels <= (ref_rank + tau)))) 333 | n_low = len(np.argwhere(rank_levels < (ref_rank - tau))) 334 | 335 | if assertion: 336 | assert((n_high + n_similar + n_low) == len(rank_levels)) 337 | 338 | cond_probs = np.zeros((3,)) 339 | for idx, n_levels in enumerate([n_high, n_similar, n_low]): 340 | if n_levels < 1: # to prevent dividing by zero 341 | continue 342 | cond_probs[idx] = 1/n_levels 343 | return cond_probs 344 | 345 | 346 | def _conditional_probs_uniform_binary(ref_rank, rank_levels, tau=0.0): 347 | n_high = len(np.argwhere(rank_levels > (ref_rank + tau))) + 0.5 348 | n_low = len(np.argwhere(rank_levels < (ref_rank - tau))) + 0.5 349 | # assert((n_high + n_low) == len(rank_levels)) 350 | 351 | cond_probs = np.zeros((2,)) 352 | for idx, n_levels in enumerate([n_high, n_low]): 353 | # if n_levels < 1: 354 | # continue 355 | cond_probs[idx] = 1/n_levels 356 | return cond_probs 357 | 358 | 359 | def find_kNN(queries, samples, k=1, metric='L2'): 360 | """ 361 | :param queries: BxNxC 362 | :param samples: BxMxC 363 | :param metric: 364 | :return: 365 | """ 366 | if len(queries.shape) == 2: 367 | queries = queries.view(1, queries.shape[0], queries.shape[1]) 368 | if len(samples.shape) == 2: 369 | samples = samples.view(1, samples.shape[0], samples.shape[1]) 370 | 371 | if metric == 'L2': 372 | dist_mat = -torch.cdist(queries, samples) # BxNxM 373 | 374 | elif metric == 'cosine': 375 | # queries = torch.nn.functional.normalize(queries, dim=-1) 376 | # samples = torch.nn.functional.normalize(samples, dim=-1) 377 | 378 | dist_mat = torch.matmul(queries, samples.transpose(2,1)) 379 | 380 | vals, inds = torch.topk(dist_mat, k, dim=-1) 381 | return vals, inds 382 | --------------------------------------------------------------------------------