├── .gitignore ├── LICENSE ├── README.md ├── agedb-dir ├── README.md ├── data │ ├── agedb.csv │ ├── create_agedb.py │ └── preprocess_agedb.py ├── datasets.py ├── fds.py ├── loss.py ├── resnet.py ├── train.py └── utils.py ├── imdb-wiki-dir ├── README.md ├── data │ ├── create_imdb_wiki.py │ ├── imdb_wiki.csv │ └── preprocess_imdb_wiki.py ├── datasets.py ├── download_imdb_wiki.py ├── fds.py ├── loss.py ├── resnet.py ├── train.py └── utils.py ├── nyud2-dir ├── README.md ├── data │ ├── nyu2_train_FDS_subset.csv │ └── test_balanced_mask.npy ├── download_nyud2.py ├── loaddata.py ├── models │ ├── __init__.py │ ├── fds.py │ ├── modules.py │ ├── net.py │ └── resnet.py ├── nyu_transform.py ├── preprocess_nyud2.py ├── test.py ├── train.py └── util.py ├── sts-b-dir ├── README.md ├── allennlp_mods │ └── numeric_field.py ├── evaluate.py ├── fds.py ├── glove │ └── download_glove.py ├── glue_data │ ├── STS-B │ │ ├── dev.tsv │ │ ├── dev_new.tsv │ │ ├── test.tsv │ │ ├── test_new.tsv │ │ ├── train.tsv │ │ └── train_new.tsv │ └── create_sts.py ├── loss.py ├── models.py ├── preprocess.py ├── requirements.txt ├── tasks.py ├── train.py ├── trainer.py └── util.py ├── teaser ├── agedb_dir.png ├── fds.gif ├── imdb_wiki_dir.png ├── lds.gif ├── nyud2_dir.png ├── overview.gif ├── shhs_dir.png └── stsb_dir.png └── tutorial ├── README.md └── tutorial.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | .idea/* 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yuzhe Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Delving into Deep Imbalanced Regression 2 | 3 | This repository contains the implementation code for the paper:
4 | __Delving into Deep Imbalanced Regression__
5 | [Yuzhe Yang](http://www.mit.edu/~yuzhe/), [Kaiwen Zha](https://kaiwenzha.github.io/), [Ying-Cong Chen](https://yingcong.github.io/), [Hao Wang](http://www.wanghao.in/), [Dina Katabi](https://people.csail.mit.edu/dina/)
6 | _38th International Conference on Machine Learning (ICML 2021), **Long Oral**_
7 | [[Project Page](http://dir.csail.mit.edu/)] [[Paper](https://arxiv.org/abs/2102.09554)] [[Video](https://youtu.be/grJGixofQRU)] [[Blog Post](https://towardsdatascience.com/strategies-and-tactics-for-regression-on-imbalanced-data-61eeb0921fca)] [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/YyzHarry/imbalanced-regression/blob/master/tutorial/tutorial.ipynb) 8 | ___ 9 |

10 |
11 | Deep Imbalanced Regression (DIR) aims to learn from imbalanced data with continuous targets,
tackle potential missing data for certain regions, and generalize to the entire target range. 12 |

13 | 14 | 15 | ## Beyond Imbalanced Classification: Brief Introduction for DIR 16 | Existing techniques for learning from imbalanced data focus on targets with __categorical__ indices, i.e., the targets are different classes. However, many real-world tasks involve __continuous__ and even infinite target values. We systematically investigate _Deep Imbalanced Regression (DIR)_, which aims to learn continuous targets from natural imbalanced data, deal with potential missing data for certain target values, and generalize to the entire target range. 17 | 18 | We curate and benchmark large-scale DIR datasets for common real-world tasks in _computer vision_, _natural language processing_, and _healthcare_ domains, ranging from single-value prediction such as age, text similarity score, health condition score, to dense-value prediction such as depth. 19 | 20 | 21 | ## Usage 22 | We separate the codebase for different datasets into different subfolders. Please go into the subfolders for more information (e.g., installation, dataset preparation, training, evaluation & models). 23 | 24 | #### __[IMDB-WIKI-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir)__  |  __[AgeDB-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/agedb-dir)__  |  __[NYUD2-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/nyud2-dir)__  |  __[STS-B-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/sts-b-dir)__ 25 | 26 | 27 | ## Highlights 28 | __(1) :heavy_check_mark: New Task:__ Deep Imbalanced Regression (DIR) 29 | 30 | __(2) :heavy_check_mark: New Techniques:__ 31 | 32 | | ![image](teaser/lds.gif) | ![image](teaser/fds.gif) | 33 | | :-: | :-: | 34 | | Label distribution smoothing (LDS) | Feature distribution smoothing (FDS) | 35 | 36 | __(3) :heavy_check_mark: New Benchmarks:__
37 | - _Computer Vision:_ :bulb: IMDB-WIKI-DIR (age) / AgeDB-DIR (age) / NYUD2-DIR (depth) 38 | - _Natural Language Processing:_ :clipboard: STS-B-DIR (text similarity score) 39 | - _Healthcare:_ :hospital: SHHS-DIR (health condition score) 40 | 41 | | [IMDB-WIKI-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir) | [AgeDB-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/agedb-dir) | [NYUD2-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/nyud2-dir) | [STS-B-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/sts-b-dir) | SHHS-DIR | 42 | | :-: | :-: | :-: | :-: | :-: | 43 | | ![image](teaser/imdb_wiki_dir.png) | ![image](teaser/agedb_dir.png) | ![image](teaser/nyud2_dir.png) | ![image](teaser/stsb_dir.png) | ![image](teaser/shhs_dir.png) | 44 | 45 | 46 | ## Apply LDS and FDS on Other Datasets / Models 47 | We provide examples of how to apply LDS and FDS on other customized datasets and/or models. 48 | 49 | ### LDS 50 | To apply LDS on your customized dataset, you will first need to estimate the effective label distribution: 51 | ```python 52 | from collections import Counter 53 | import numpy as np 54 | from scipy.ndimage import convolve1d 55 | from utils import get_lds_kernel_window 56 | # preds, labels: [Ns,], "Ns" is the number of total samples 57 | preds, labels = ..., ... 
58 | # assign each label to its corresponding bin (start from 0) 59 | # with your defined get_bin_idx(), return bin_index_per_label: [Ns,] 60 | bin_index_per_label = [get_bin_idx(label) for label in labels] 61 | 62 | # calculate empirical (original) label distribution: [Nb,] 63 | # "Nb" is the number of bins 64 | Nb = max(bin_index_per_label) + 1 65 | num_samples_of_bins = dict(Counter(bin_index_per_label)) 66 | emp_label_dist = [num_samples_of_bins.get(i, 0) for i in range(Nb)] 67 | 68 | # lds_kernel_window: [ks,], here for example, we use gaussian, ks=5, sigma=2 69 | lds_kernel_window = get_lds_kernel_window(kernel='gaussian', ks=5, sigma=2) 70 | # calculate effective label distribution: [Nb,] 71 | eff_label_dist = convolve1d(np.array(emp_label_dist), weights=lds_kernel_window, mode='constant') 72 | ``` 73 | With the estimated effective label distribution, one straightforward option is to use the loss re-weighting scheme: 74 | ```python 75 | import torch 76 | from loss import weighted_mse_loss 77 | # Use re-weighting based on effective label distribution, sample-wise weights: [Ns,] 78 | eff_num_per_label = [eff_label_dist[bin_idx] for bin_idx in bin_index_per_label] 79 | weights = torch.tensor([1 / x for x in eff_num_per_label], dtype=torch.float32) # tensor, since weighted_mse_loss calls weights.expand_as() 80 | 81 | # calculate loss 82 | loss = weighted_mse_loss(preds, labels, weights=weights) 83 | ``` 84 | 85 | ### FDS 86 | To apply FDS on your customized data/model, you will first need to define the FDS module in your network: 87 | ```python 88 | import torch.nn as nn 89 | from fds import FDS 90 | config = dict(feature_dim=..., start_update=0, start_smooth=1, kernel='gaussian', ks=5, sigma=2) 91 | 92 | class Network(nn.Module): 93 | def __init__(self, **config): 94 | super().__init__() 95 | self.feature_extractor = ... 
96 | self.regressor = nn.Linear(config['feature_dim'], 1) # FDS operates before the final regressor 97 | self.FDS = FDS(**config) 98 | 99 | def forward(self, inputs, labels, epoch): 100 | features = self.feature_extractor(inputs) # features: [batch_size, feature_dim] 101 | # smooth the feature distributions over the target space 102 | smoothed_features = features 103 | if self.training and epoch >= config['start_smooth']: 104 | smoothed_features = self.FDS.smooth(smoothed_features, labels, epoch) 105 | preds = self.regressor(smoothed_features) 106 | 107 | return {'preds': preds, 'features': features} 108 | ``` 109 | During training, you will need to update the FDS statistics after each training epoch: 110 | ```python 111 | model = Network(**config) 112 | 113 | for epoch in range(num_epochs): 114 | for (inputs, labels) in train_loader: 115 | # standard training pipeline 116 | ... 117 | 118 | # update FDS statistics after each training epoch 119 | if epoch >= config['start_update']: 120 | # collect features and labels for all training samples 121 | ... 122 | # training_features: [num_samples, feature_dim], training_labels: [num_samples,] 123 | training_features, training_labels = ..., ... 124 | model.FDS.update_last_epoch_stats(epoch) 125 | model.FDS.update_running_stats(training_features, training_labels, epoch) 126 | ``` 127 | 128 | 129 | ## Updates 130 | - [06/2021] We provide a [hands-on tutorial](https://github.com/YyzHarry/imbalanced-regression/tree/main/tutorial) of DIR. Check it out! [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/YyzHarry/imbalanced-regression/blob/master/tutorial/tutorial.ipynb) 131 | - [05/2021] We create a [Blog post](https://towardsdatascience.com/strategies-and-tactics-for-regression-on-imbalanced-data-61eeb0921fca) for this work (version in Chinese is also available [here](https://zhuanlan.zhihu.com/p/369627086)). Check it out for more details! 
132 | - [05/2021] Paper accepted to ICML 2021 as a __Long Talk__. We have released the code and models. You can find all reproduced checkpoints via [this link](https://drive.google.com/drive/folders/1UfFJNIG-LPOMecwi1tfYzEViBiAYhNU0?usp=sharing), or go into each subfolder for models for each dataset. 133 | - [02/2021] [arXiv version](https://arxiv.org/abs/2102.09554) posted. Please stay tuned for updates. 134 | 135 | 136 | ## Citation 137 | If you find this code or idea useful, please cite our work: 138 | ```bib 139 | @inproceedings{yang2021delving, 140 | title={Delving into Deep Imbalanced Regression}, 141 | author={Yang, Yuzhe and Zha, Kaiwen and Chen, Ying-Cong and Wang, Hao and Katabi, Dina}, 142 | booktitle={International Conference on Machine Learning (ICML)}, 143 | year={2021} 144 | } 145 | ``` 146 | 147 | 148 | ## Contact 149 | If you have any questions, feel free to contact us through email (yuzhe@mit.edu & kzha@mit.edu) or Github issues. Enjoy! 150 | -------------------------------------------------------------------------------- /agedb-dir/README.md: -------------------------------------------------------------------------------- 1 | # AgeDB-DIR 2 | ## Installation 3 | 4 | #### Prerequisites 5 | 6 | 1. Download AgeDB dataset from [here](https://ibug.doc.ic.ac.uk/resources/agedb/) and extract the zip file (you may need to contact the authors of AgeDB dataset for the zip password) to folder `./data` 7 | 8 | 2. __(Optional)__ We have provided required AgeDB-DIR meta file `agedb.csv` to set up balanced val/test set in folder `./data`. To reproduce the results in the paper, please directly use this file. 
If you want to try different balanced splits, you can generate them using 9 | 10 | ```bash 11 | python data/create_agedb.py 12 | python data/preprocess_agedb.py 13 | ``` 14 | 15 | #### Dependencies 16 | 17 | - PyTorch (>= 1.2, tested on 1.6) 18 | - tensorboard_logger 19 | - numpy, pandas, scipy, tqdm, matplotlib, PIL 20 | 21 | ## Code Overview 22 | 23 | #### Main Files 24 | 25 | - `train.py`: main training and evaluation script 26 | - `create_agedb.py`: create AgeDB raw meta data 27 | - `preprocess_agedb.py`: create AgeDB-DIR meta file `agedb.csv` with balanced val/test set 28 | 29 | #### Main Arguments 30 | 31 | - `--data_dir`: data directory to place data and meta file 32 | - `--lds`: LDS switch (whether to enable LDS) 33 | - `--fds`: FDS switch (whether to enable FDS) 34 | - `--reweight`: cost-sensitive re-weighting scheme to use 35 | - `--retrain_fc`: whether to retrain regressor 36 | - `--loss`: training loss type 37 | - `--resume`: path to resume checkpoint (for both training and evaluation) 38 | - `--evaluate`: evaluate only flag 39 | - `--pretrained`: path to load backbone weights for regressor re-training (RRT) 40 | 41 | ## Getting Started 42 | 43 | #### Train a vanilla model 44 | 45 | ```bash 46 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_dir --reweight none 47 | ``` 48 | 49 | Always specify `CUDA_VISIBLE_DEVICES` for the GPU IDs to use (by default, 4 GPUs) and `--data_dir` when training a model, or set your default data directory path directly in the code. We omit these arguments in the following examples for simplicity. 
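
As a rough illustration of what the `--reweight` choices listed under Main Arguments mean, the sketch below computes per-sample loss weights from label counts. It is modelled on `_prepare_weights` in `datasets.py` but is not the training code itself: the function name `reweight_sketch` is hypothetical, LDS smoothing is left out, and labels are assumed to be non-negative ages.

```python
import numpy as np


def reweight_sketch(labels, reweight="sqrt_inv", max_target=121):
    """Per-sample loss weights from integer-binned label counts.

    Hypothetical helper for illustration only; it mirrors the re-weighting
    logic of `_prepare_weights` in `datasets.py` without the LDS step.
    Assumes non-negative labels (ages).
    """
    if reweight not in {"none", "inverse", "sqrt_inv"}:
        raise ValueError(f"unknown reweight option: {reweight}")
    labels = np.asarray(labels)
    if reweight == "none" or labels.size == 0:
        return None

    # Bin each label by its integer part; labels beyond the range share the last bin.
    bins = np.minimum(labels.astype(int), max_target - 1)
    counts = np.bincount(bins, minlength=max_target).astype(np.float64)

    if reweight == "sqrt_inv":
        counts = np.sqrt(counts)
    else:
        # "inverse": clip the counts so very rare or very common labels
        # do not produce extreme weights.
        counts = np.clip(counts, 5, 1000)

    weights = 1.0 / counts[bins]
    # Rescale so that the weights average to 1 over the dataset.
    return weights * (len(weights) / weights.sum())
```

With `sqrt_inv`, a label that occurs four times as often receives half the weight. With `inverse`, counts are first clipped to the range [5, 1000], so a label seen once is treated as if it had been seen five times.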
50 | 51 | #### Train a model using re-weighting 52 | 53 | To perform inverse re-weighting 54 | 55 | ```bash 56 | python train.py --reweight inverse 57 | ``` 58 | 59 | To perform square-root inverse re-weighting 60 | 61 | ```bash 62 | python train.py --reweight sqrt_inv 63 | ``` 64 | 65 | #### Train a model with different losses 66 | 67 | To use Focal-R loss 68 | 69 | ```bash 70 | python train.py --loss focal_l1 71 | ``` 72 | 73 | To use Huber loss 74 | 75 | ```bash 76 | python train.py --loss huber 77 | ``` 78 | 79 | #### Train a model using RRT 80 | 81 | ```bash 82 | python train.py [...retrained model arguments...] --retrain_fc --pretrained 83 | ``` 84 | 85 | #### Train a model using LDS 86 | 87 | To use a Gaussian kernel (kernel size: 5, sigma: 2) 88 | 89 | ```bash 90 | python train.py --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 91 | ``` 92 | 93 | #### Train a model using FDS 94 | 95 | To use a Gaussian kernel (kernel size: 5, sigma: 2) 96 | 97 | ```bash 98 | python train.py --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2 99 | ``` 100 | #### Train a model using LDS + FDS 101 | ```bash 102 | python train.py --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2 103 | ``` 104 | 105 | #### Evaluate a trained checkpoint 106 | 107 | ```bash 108 | python train.py [...evaluation model arguments...] --evaluate --resume 109 | ``` 110 | 111 | ## Reproduced Benchmarks and Model Zoo 112 | 113 | We provide below reproduced results on AgeDB-DIR (base method `SQINV`, metric `MAE`). 114 | Note that some models could give **better** results than the reported numbers in the paper. 
115 | 116 | | Model | Overall | Many-Shot | Medium-Shot | Few-Shot | Download | 117 | | :-------: | :-----: | :-------: | :---------: | :------: | :------: | 118 | | LDS | 7.67 | 6.98 | 8.86 | 10.89 | [model](https://drive.google.com/file/d/1CPDlcRCQ1EC4E3x9w955cmILSaVOlkyz/view?usp=sharing) | 119 | | FDS | 7.69 | 7.11 | 8.86 | 9.98 | [model](https://drive.google.com/file/d/1JmRS4V8zmmS9eschsBSmSeQjFWefYlib/view?usp=sharing) | 120 | | LDS + FDS | 7.47 | 6.91 | 8.26 | 10.55 | [model](https://drive.google.com/file/d/1N0nMdu-wuWoAOS1x_m-pnzHcAp61ajY9/view?usp=sharing) | -------------------------------------------------------------------------------- /agedb-dir/data/create_agedb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | parser.add_argument("--data_path", type=str, default="./data") 10 | args = parser.parse_args() 11 | return args 12 | 13 | 14 | def main(): 15 | args = get_args() 16 | ages, img_paths = [], [] 17 | 18 | for filename in tqdm(os.listdir(os.path.join(args.data_path, 'AgeDB'))): 19 | _, _, age, gender = filename.split('.')[0].split('_') 20 | 21 | ages.append(age) 22 | img_paths.append(f"AgeDB/{filename}") 23 | 24 | outputs = dict(age=ages, path=img_paths) 25 | output_dir = os.path.join(args.data_path, "meta") 26 | os.makedirs(output_dir, exist_ok=True) 27 | output_path = os.path.join(output_dir, "agedb.csv") 28 | df = pd.DataFrame(data=outputs) 29 | df.to_csv(str(output_path), index=False) 30 | 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /agedb-dir/data/preprocess_agedb.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import pandas as pd 3 | import 
matplotlib.pyplot as plt 4 | 5 | 6 | BASE_PATH = './data' 7 | 8 | 9 | def visualize_dataset(db="agedb"): 10 | file_path = join(BASE_PATH, "meta", "agedb.csv") 11 | data = pd.read_csv(file_path) 12 | _, ax = plt.subplots(figsize=(6, 3), sharex='all', sharey='all') 13 | ax.hist(data['age'], range(max(data['age']) + 2)) 14 | # ax.set_xlim([0, 102]) 15 | plt.title(f"{db.upper()} (total: {data.shape[0]})") 16 | plt.tight_layout() 17 | plt.show() 18 | 19 | 20 | def make_balanced_testset(db="agedb", max_size=30, seed=666, verbose=True, vis=True, save=False): 21 | file_path = join(BASE_PATH, "meta", f"{db}.csv") 22 | df = pd.read_csv(file_path) 23 | df['age'] = df.age.astype(int) 24 | val_set, test_set = [], [] 25 | import random 26 | random.seed(seed) 27 | for value in range(121): 28 | curr_df = df[df['age'] == value] 29 | curr_data = curr_df['path'].values 30 | random.shuffle(curr_data) 31 | curr_size = min(len(curr_data) // 3, max_size) 32 | val_set += list(curr_data[:curr_size]) 33 | test_set += list(curr_data[curr_size:curr_size * 2]) 34 | if verbose: 35 | print(f"Val: {len(val_set)}\nTest: {len(test_set)}") 36 | assert len(set(val_set).intersection(set(test_set))) == 0 37 | combined_set = dict(zip(val_set, ['val' for _ in range(len(val_set))])) 38 | combined_set.update(dict(zip(test_set, ['test' for _ in range(len(test_set))]))) 39 | df['split'] = df['path'].map(combined_set) 40 | df['split'] = df['split'].fillna('train') # assign back: in-place fillna on a column selection is unreliable in recent pandas 41 | if verbose: 42 | print(df) 43 | if save: 44 | df.to_csv(str(join(BASE_PATH, f"{db}.csv")), index=False) 45 | if vis: 46 | _, ax = plt.subplots(3, figsize=(6, 9), sharex='all') 47 | df_train = df[df['split'] == 'train'] 48 | ax[0].hist(df_train['age'], range(max(df['age']))) 49 | ax[0].set_title(f"[{db.upper()}] train: {df_train.shape[0]}") 50 | ax[1].hist(df[df['split'] == 'val']['age'], range(max(df['age']))) 51 | ax[1].set_title(f"[{db.upper()}] val: {df[df['split'] == 'val'].shape[0]}") 52 | ax[2].hist(df[df['split'] == 'test']['age'], 
range(max(df['age']))) 53 | ax[2].set_title(f"[{db.upper()}] test: {df[df['split'] == 'test'].shape[0]}") 54 | ax[0].set_xlim([0, 120]) 55 | plt.tight_layout() 56 | plt.show() 57 | 58 | 59 | if __name__ == '__main__': 60 | make_balanced_testset() 61 | visualize_dataset() 62 | -------------------------------------------------------------------------------- /agedb-dir/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | from PIL import Image 5 | from scipy.ndimage import convolve1d 6 | from torch.utils import data 7 | import torchvision.transforms as transforms 8 | 9 | from utils import get_lds_kernel_window 10 | 11 | print = logging.info 12 | 13 | 14 | class AgeDB(data.Dataset): 15 | def __init__(self, df, data_dir, img_size, split='train', reweight='none', 16 | lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2): 17 | self.df = df 18 | self.data_dir = data_dir 19 | self.img_size = img_size 20 | self.split = split 21 | 22 | self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma) 23 | 24 | def __len__(self): 25 | return len(self.df) 26 | 27 | def __getitem__(self, index): 28 | index = index % len(self.df) 29 | row = self.df.iloc[index] 30 | img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB') 31 | transform = self.get_transform() 32 | img = transform(img) 33 | label = np.asarray([row['age']]).astype('float32') 34 | weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)]) 35 | 36 | return img, label, weight 37 | 38 | def get_transform(self): 39 | if self.split == 'train': 40 | transform = transforms.Compose([ 41 | transforms.Resize((self.img_size, self.img_size)), 42 | transforms.RandomCrop(self.img_size, padding=16), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | 
transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 46 | ]) 47 | else: 48 | transform = transforms.Compose([ 49 | transforms.Resize((self.img_size, self.img_size)), 50 | transforms.ToTensor(), 51 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 52 | ]) 53 | return transform 54 | 55 | def _prepare_weights(self, reweight, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2): 56 | assert reweight in {'none', 'inverse', 'sqrt_inv'} 57 | assert reweight != 'none' if lds else True, \ 58 | "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS" 59 | 60 | value_dict = {x: 0 for x in range(max_target)} 61 | labels = self.df['age'].values 62 | for label in labels: 63 | value_dict[min(max_target - 1, int(label))] += 1 64 | if reweight == 'sqrt_inv': 65 | value_dict = {k: np.sqrt(v) for k, v in value_dict.items()} 66 | elif reweight == 'inverse': 67 | value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()} # clip weights for inverse re-weight 68 | num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels] 69 | if not len(num_per_label) or reweight == 'none': 70 | return None 71 | print(f"Using re-weighting: [{reweight.upper()}]") 72 | 73 | if lds: 74 | lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma) 75 | print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})') 76 | smoothed_value = convolve1d( 77 | np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant') 78 | num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels] 79 | 80 | weights = [np.float32(1 / x) for x in num_per_label] 81 | scaling = len(weights) / np.sum(weights) 82 | weights = [scaling * x for x in weights] 83 | return weights 84 | -------------------------------------------------------------------------------- /agedb-dir/fds.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as 
np 3 | from scipy.ndimage import gaussian_filter1d 4 | from scipy.signal.windows import triang 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from utils import calibrate_mean_var 10 | 11 | print = logging.info 12 | 13 | 14 | class FDS(nn.Module): 15 | 16 | def __init__(self, feature_dim, bucket_num=100, bucket_start=3, start_update=0, start_smooth=1, 17 | kernel='gaussian', ks=5, sigma=2, momentum=0.9): 18 | super(FDS, self).__init__() 19 | self.feature_dim = feature_dim 20 | self.bucket_num = bucket_num 21 | self.bucket_start = bucket_start 22 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma) 23 | self.half_ks = (ks - 1) // 2 24 | self.momentum = momentum 25 | self.start_update = start_update 26 | self.start_smooth = start_smooth 27 | 28 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update)) 29 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim)) 30 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim)) 31 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 32 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 33 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 34 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 35 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start)) 36 | 37 | @staticmethod 38 | def _get_kernel_window(kernel, ks, sigma): 39 | assert kernel in ['gaussian', 'triang', 'laplace'] 40 | half_ks = (ks - 1) // 2 41 | if kernel == 'gaussian': 42 | base_kernel = [0.] * half_ks + [1.] + [0.] 
* half_ks 43 | base_kernel = np.array(base_kernel, dtype=np.float32) 44 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma)) 45 | elif kernel == 'triang': 46 | kernel_window = triang(ks) / sum(triang(ks)) 47 | else: 48 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma) 49 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1))) 50 | 51 | print(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})') 52 | return torch.tensor(kernel_window, dtype=torch.float32).cuda() 53 | 54 | def _update_last_epoch_stats(self): 55 | self.running_mean_last_epoch = self.running_mean 56 | self.running_var_last_epoch = self.running_var 57 | 58 | self.smoothed_mean_last_epoch = F.conv1d( 59 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0), 60 | pad=(self.half_ks, self.half_ks), mode='reflect'), 61 | weight=self.kernel_window.view(1, 1, -1), padding=0 62 | ).permute(2, 1, 0).squeeze(1) 63 | self.smoothed_var_last_epoch = F.conv1d( 64 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0), 65 | pad=(self.half_ks, self.half_ks), mode='reflect'), 66 | weight=self.kernel_window.view(1, 1, -1), padding=0 67 | ).permute(2, 1, 0).squeeze(1) 68 | 69 | def reset(self): 70 | self.running_mean.zero_() 71 | self.running_var.fill_(1) 72 | self.running_mean_last_epoch.zero_() 73 | self.running_var_last_epoch.fill_(1) 74 | self.smoothed_mean_last_epoch.zero_() 75 | self.smoothed_var_last_epoch.fill_(1) 76 | self.num_samples_tracked.zero_() 77 | 78 | def update_last_epoch_stats(self, epoch): 79 | if epoch == self.epoch + 1: 80 | self.epoch += 1 81 | self._update_last_epoch_stats() 82 | print(f"Updated smoothed statistics on Epoch [{epoch}]!") 83 | 84 | def update_running_stats(self, features, labels, epoch): 85 | if epoch < self.epoch: 86 | return 87 | 88 | assert self.feature_dim == features.size(1), "Input feature dimension is not 
aligned!" 89 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!" 90 | 91 | for label in torch.unique(labels): 92 | if label > self.bucket_num - 1 or label < self.bucket_start: 93 | continue 94 | elif label == self.bucket_start: 95 | curr_feats = features[labels <= label] 96 | elif label == self.bucket_num - 1: 97 | curr_feats = features[labels >= label] 98 | else: 99 | curr_feats = features[labels == label] 100 | curr_num_sample = curr_feats.size(0) 101 | curr_mean = torch.mean(curr_feats, 0) 102 | curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False) 103 | 104 | self.num_samples_tracked[int(label - self.bucket_start)] += curr_num_sample 105 | factor = self.momentum if self.momentum is not None else \ 106 | (1 - curr_num_sample / float(self.num_samples_tracked[int(label - self.bucket_start)])) 107 | factor = 0 if epoch == self.start_update else factor 108 | self.running_mean[int(label - self.bucket_start)] = \ 109 | (1 - factor) * curr_mean + factor * self.running_mean[int(label - self.bucket_start)] 110 | self.running_var[int(label - self.bucket_start)] = \ 111 | (1 - factor) * curr_var + factor * self.running_var[int(label - self.bucket_start)] 112 | 113 | print(f"Updated running statistics with Epoch [{epoch}] features!") 114 | 115 | def smooth(self, features, labels, epoch): 116 | if epoch < self.start_smooth: 117 | return features 118 | 119 | labels = labels.squeeze(1) 120 | for label in torch.unique(labels): 121 | if label > self.bucket_num - 1 or label < self.bucket_start: 122 | continue 123 | elif label == self.bucket_start: 124 | features[labels <= label] = calibrate_mean_var( 125 | features[labels <= label], 126 | self.running_mean_last_epoch[int(label - self.bucket_start)], 127 | self.running_var_last_epoch[int(label - self.bucket_start)], 128 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 129 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 
130 | elif label == self.bucket_num - 1: 131 | features[labels >= label] = calibrate_mean_var( 132 | features[labels >= label], 133 | self.running_mean_last_epoch[int(label - self.bucket_start)], 134 | self.running_var_last_epoch[int(label - self.bucket_start)], 135 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 136 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 137 | else: 138 | features[labels == label] = calibrate_mean_var( 139 | features[labels == label], 140 | self.running_mean_last_epoch[int(label - self.bucket_start)], 141 | self.running_var_last_epoch[int(label - self.bucket_start)], 142 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 143 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 144 | return features 145 | -------------------------------------------------------------------------------- /agedb-dir/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def weighted_mse_loss(inputs, targets, weights=None): 6 | loss = (inputs - targets) ** 2 7 | if weights is not None: 8 | loss *= weights.expand_as(loss) 9 | loss = torch.mean(loss) 10 | return loss 11 | 12 | 13 | def weighted_l1_loss(inputs, targets, weights=None): 14 | loss = F.l1_loss(inputs, targets, reduction='none') 15 | if weights is not None: 16 | loss *= weights.expand_as(loss) 17 | loss = torch.mean(loss) 18 | return loss 19 | 20 | 21 | def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1): 22 | loss = (inputs - targets) ** 2 23 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 24 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 25 | if weights is not None: 26 | loss *= weights.expand_as(loss) 27 | loss = torch.mean(loss) 28 | return loss 29 | 30 | 31 | def weighted_focal_l1_loss(inputs, targets, weights=None, 
activate='sigmoid', beta=.2, gamma=1): 32 | loss = F.l1_loss(inputs, targets, reduction='none') 33 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 34 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 35 | if weights is not None: 36 | loss *= weights.expand_as(loss) 37 | loss = torch.mean(loss) 38 | return loss 39 | 40 | 41 | def weighted_huber_loss(inputs, targets, weights=None, beta=1.): 42 | l1_loss = torch.abs(inputs - targets) 43 | cond = l1_loss < beta 44 | loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta) 45 | if weights is not None: 46 | loss *= weights.expand_as(loss) 47 | loss = torch.mean(loss) 48 | return loss 49 | -------------------------------------------------------------------------------- /agedb-dir/resnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import torch.nn as nn 4 | from fds import FDS 5 | 6 | print = logging.info 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | out += residual 37 | out = self.relu(out) 38 | return out 39 | 40 | 41 | 
class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * 4) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | residual = x 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | out = self.relu(out) 64 | out = self.conv3(out) 65 | out = self.bn3(out) 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | out += residual 69 | out = self.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | 75 | def __init__(self, block, layers, fds, bucket_num, bucket_start, start_update, start_smooth, 76 | kernel, ks, sigma, momentum, dropout=None): 77 | self.inplanes = 64 78 | super(ResNet, self).__init__() 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 83 | self.layer1 = self._make_layer(block, 64, layers[0]) 84 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 85 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 86 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 87 | self.avgpool = nn.AvgPool2d(7, stride=1) 88 | self.linear = nn.Linear(512 * block.expansion, 1) 89 | 90 | if fds: 91 | self.FDS = FDS( 92 | feature_dim=512 * block.expansion, bucket_num=bucket_num, bucket_start=bucket_start, 93 | 
start_update=start_update, start_smooth=start_smooth, kernel=kernel, ks=ks, sigma=sigma, momentum=momentum 94 | ) 95 | self.fds = fds 96 | self.start_smooth = start_smooth 97 | 98 | self.use_dropout = True if dropout else False 99 | if self.use_dropout: 100 | print(f'Using dropout: {dropout}') 101 | self.dropout = nn.Dropout(p=dropout) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | downsample = None 113 | if stride != 1 or self.inplanes != planes * block.expansion: 114 | downsample = nn.Sequential( 115 | nn.Conv2d(self.inplanes, planes * block.expansion, 116 | kernel_size=1, stride=stride, bias=False), 117 | nn.BatchNorm2d(planes * block.expansion), 118 | ) 119 | layers = [] 120 | layers.append(block(self.inplanes, planes, stride, downsample)) 121 | self.inplanes = planes * block.expansion 122 | for i in range(1, blocks): 123 | layers.append(block(self.inplanes, planes)) 124 | 125 | return nn.Sequential(*layers) 126 | 127 | def forward(self, x, targets=None, epoch=None): 128 | x = self.conv1(x) 129 | x = self.bn1(x) 130 | x = self.relu(x) 131 | x = self.maxpool(x) 132 | 133 | x = self.layer1(x) 134 | x = self.layer2(x) 135 | x = self.layer3(x) 136 | x = self.layer4(x) 137 | x = self.avgpool(x) 138 | encoding = x.view(x.size(0), -1) 139 | 140 | encoding_s = encoding 141 | 142 | if self.training and self.fds: 143 | if epoch >= self.start_smooth: 144 | encoding_s = self.FDS.smooth(encoding_s, targets, epoch) 145 | 146 | if self.use_dropout: 147 | encoding_s = self.dropout(encoding_s) 148 | x = self.linear(encoding_s) 149 | 150 | if self.training and self.fds: 151 | return x, encoding 152 | else: 153 | return x 154 | 155 | 156 | def resnet50(**kwargs): 157 | return 
ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 158 | -------------------------------------------------------------------------------- /agedb-dir/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import logging 5 | import numpy as np 6 | from scipy.ndimage import gaussian_filter1d 7 | from scipy.signal.windows import triang 8 | 9 | 10 | class AverageMeter(object): 11 | def __init__(self, name, fmt=':f'): 12 | self.name = name 13 | self.fmt = fmt 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | def __str__(self): 29 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 30 | return fmtstr.format(**self.__dict__) 31 | 32 | 33 | class ProgressMeter(object): 34 | def __init__(self, num_batches, meters, prefix=""): 35 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 36 | self.meters = meters 37 | self.prefix = prefix 38 | 39 | def display(self, batch): 40 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 41 | entries += [str(meter) for meter in self.meters] 42 | logging.info('\t'.join(entries)) 43 | 44 | @staticmethod 45 | def _get_batch_fmtstr(num_batches): 46 | num_digits = len(str(num_batches // 1)) 47 | fmt = '{:' + str(num_digits) + 'd}' 48 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 49 | 50 | 51 | def query_yes_no(question): 52 | """ Ask a yes/no question via input() and return their answer. 
""" 53 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False} 54 | prompt = " [Y/n] " 55 | 56 | while True: 57 | print(question + prompt, end=':') 58 | choice = input().lower() 59 | if choice == '': 60 | return valid['y'] 61 | elif choice in valid: 62 | return valid[choice] 63 | else: 64 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n") 65 | 66 | 67 | def prepare_folders(args): 68 | folders_util = [args.store_root, os.path.join(args.store_root, args.store_name)] 69 | if os.path.exists(folders_util[-1]) and not args.resume and not args.pretrained and not args.evaluate: 70 | if query_yes_no('overwrite previous folder: {} ?'.format(folders_util[-1])): 71 | shutil.rmtree(folders_util[-1]) 72 | print(folders_util[-1] + ' removed.') 73 | else: 74 | raise RuntimeError('Output folder {} already exists'.format(folders_util[-1])) 75 | for folder in folders_util: 76 | if not os.path.exists(folder): 77 | print(f"===> Creating folder: {folder}") 78 | os.mkdir(folder) 79 | 80 | 81 | def adjust_learning_rate(optimizer, epoch, args): 82 | lr = args.lr 83 | for milestone in args.schedule: 84 | lr *= 0.1 if epoch >= milestone else 1. 85 | for param_group in optimizer.param_groups: 86 | param_group['lr'] = lr 87 | 88 | 89 | def save_checkpoint(args, state, is_best, prefix=''): 90 | filename = f"{args.store_root}/{args.store_name}/{prefix}ckpt.pth.tar" 91 | torch.save(state, filename) 92 | if is_best: 93 | logging.info("===> Saving current best checkpoint...") 94 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 95 | 96 | 97 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.1, clip_max=10): 98 | if torch.sum(v1) < 1e-10: 99 | return matrix 100 | if (v1 == 0.).any(): 101 | valid = (v1 != 0.) 
        factor = torch.clamp(v2[valid] / v1[valid], clip_min, clip_max)
        matrix[:, valid] = (matrix[:, valid] - m1[valid]) * torch.sqrt(factor) + m2[valid]
        return matrix

    factor = torch.clamp(v2 / v1, clip_min, clip_max)
    return (matrix - m1) * torch.sqrt(factor) + m2


def get_lds_kernel_window(kernel, ks, sigma):
    assert kernel in ['gaussian', 'triang', 'laplace']
    half_ks = (ks - 1) // 2
    if kernel == 'gaussian':
        base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
        kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
    elif kernel == 'triang':
        kernel_window = triang(ks)
    else:
        laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
        kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))

    return kernel_window

--------------------------------------------------------------------------------
/imdb-wiki-dir/README.md:
--------------------------------------------------------------------------------
# IMDB-WIKI-DIR
## Installation

#### Prerequisites

1. Download and extract IMDB faces and WIKI faces respectively using

```bash
python download_imdb_wiki.py
```

2. __(Optional)__ We have provided required IMDB-WIKI-DIR meta file `imdb_wiki.csv` to set up balanced val/test set in folder `./data`. To reproduce the results in the paper, please directly use this file.
You can also generate it using

```bash
python data/create_imdb_wiki.py
python data/preprocess_imdb_wiki.py
```

#### Dependencies

- PyTorch (>= 1.2, tested on 1.6)
- tensorboard_logger
- numpy, pandas, scipy, tqdm, matplotlib, PIL, wget

## Code Overview

#### Main Files

- `train.py`: main training and evaluation script
- `create_imdb_wiki.py`: create IMDB-WIKI raw metadata
- `preprocess_imdb_wiki.py`: create IMDB-WIKI-DIR meta file `imdb_wiki.csv` with balanced val/test set

#### Main Arguments

- `--data_dir`: data directory to place data and meta file
- `--lds`: LDS switch (whether to enable LDS)
- `--fds`: FDS switch (whether to enable FDS)
- `--reweight`: cost-sensitive re-weighting scheme to use
- `--retrain_fc`: whether to retrain the regressor
- `--loss`: training loss type
- `--resume`: path to the checkpoint to resume from (for both training and evaluation)
- `--evaluate`: evaluate-only flag
- `--pretrained`: path to load backbone weights for regressor re-training (RRT)

## Getting Started

#### Train a vanilla model

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_dir --reweight none
```

Always specify `CUDA_VISIBLE_DEVICES` to select the GPU IDs to use (4 GPUs by default) and `--data_dir` when training a model, or set your default data directory path directly in the code. These arguments are omitted in the commands below for simplicity.

#### Train a model using re-weighting

To perform inverse re-weighting

```bash
python train.py --reweight inverse
```

To perform square-root inverse re-weighting

```bash
python train.py --reweight sqrt_inv
```

#### Train a model with different losses

To use Focal-R loss

```bash
python train.py --loss focal_l1
```

To use Huber loss

```bash
python train.py --loss huber
```

#### Train a model using RRT

```bash
python train.py [...retrained model arguments...] --retrain_fc --pretrained
```

#### Train a model using LDS

To use Gaussian kernel (kernel size: 5, sigma: 2)

```bash
python train.py --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2
```

#### Train a model using FDS

To use Gaussian kernel (kernel size: 5, sigma: 2)

```bash
python train.py --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
```

#### Train a model using LDS + FDS

```bash
python train.py --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
```

#### Evaluate a trained checkpoint

```bash
python train.py [...evaluation model arguments...] --evaluate --resume
```

## Reproduced Benchmarks and Model Zoo

We provide below reproduced results on IMDB-WIKI-DIR (base method `SQINV`, metric `MAE`).
Note that some models could give **better** results than the reported numbers in the paper.

|   Model   | Overall | Many-Shot | Medium-Shot | Few-Shot | Download |
| :-------: | :-----: | :-------: | :---------: | :------: | :------: |
|    LDS    |  7.87   |   7.31    |    12.45    |  22.60   | [model](https://drive.google.com/file/d/1HnGw1gs6UAlvbol4EulHX_Kqx_pwJZ70/view?usp=sharing) |
|    FDS    |  7.66   |   7.06    |    12.60    |  22.37   | [model](https://drive.google.com/file/d/1H7_dDMn83-paFrcrEmOiDLZmoham4js9/view?usp=sharing) |
| LDS + FDS |  7.68   |   7.07    |    12.79    |  21.85   | [model](https://drive.google.com/file/d/1C_YxpTW-rhCRIF4wnFShojp5ydAjFmHo/view?usp=sharing) |

--------------------------------------------------------------------------------
/imdb-wiki-dir/data/create_imdb_wiki.py:
--------------------------------------------------------------------------------
import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat
from datetime import datetime


def calc_age(taken, dob):
    birth = datetime.fromordinal(max(int(dob) - 366, 1))
    # assume the photo was taken in the middle of the year
    if birth.month < 7:
        return taken - birth.year
    else:
        return taken - birth.year - 1


def get_meta(mat_path, db):
    meta = loadmat(mat_path)
    full_path = meta[db][0, 0]["full_path"][0]
    dob = meta[db][0, 0]["dob"][0]  # date
    gender = meta[db][0, 0]["gender"][0]
    photo_taken = meta[db][0, 0]["photo_taken"][0]  # year
    face_score = meta[db][0, 0]["face_score"][0]
    second_face_score = meta[db][0, 0]["second_face_score"][0]
    age = [calc_age(photo_taken[i], dob[i]) for i in range(len(dob))]

    return full_path, dob, gender, photo_taken, face_score, second_face_score, age


def load_data(mat_path):
    d = loadmat(mat_path)
    return d["image"], d["gender"][0], d["age"][0], d["db"][0], d["img_size"][0, 0], d["min_score"][0, 0]


def combine_dataset(path='meta'):
    args = 
get_args() 39 | data_imdb = pd.read_csv(os.path.join(args.data_path, path, "imdb.csv")) 40 | data_wiki = pd.read_csv(os.path.join(args.data_path, path, "wiki.csv")) 41 | data_imdb['path'] = data_imdb['path'].apply(lambda x: f"imdb_crop/{x}") 42 | data_wiki['path'] = data_wiki['path'].apply(lambda x: f"wiki_crop/{x}") 43 | df = pd.concat((data_imdb, data_wiki)) 44 | output_path = os.path.join(args.data_path, path, "imdb_wiki.csv") 45 | df.to_csv(str(output_path), index=False) 46 | 47 | 48 | def get_args(): 49 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 50 | parser.add_argument("--data_path", type=str, default="./data") 51 | parser.add_argument("--min_score", type=float, default=1., help="minimum face score") 52 | args = parser.parse_args() 53 | return args 54 | 55 | 56 | def create(db): 57 | args = get_args() 58 | mat_path = os.path.join(args.data_path, f"{db}_crop", f"{db}.mat") 59 | full_path, dob, gender, photo_taken, face_score, second_face_score, age = get_meta(mat_path, db) 60 | 61 | ages, img_paths = [], [] 62 | 63 | for i in tqdm(range(len(face_score))): 64 | if face_score[i] < args.min_score: 65 | continue 66 | 67 | if (~np.isnan(second_face_score[i])) and second_face_score[i] > 0.0: 68 | continue 69 | 70 | if ~(0 <= age[i] <= 200): 71 | continue 72 | 73 | ages.append(age[i]) 74 | img_paths.append(full_path[i][0]) 75 | 76 | outputs = dict(age=ages, path=img_paths) 77 | output_dir = os.path.join(args.data_path, "meta") 78 | os.makedirs(output_dir, exist_ok=True) 79 | output_path = os.path.join(output_dir, f"{db}.csv") 80 | df = pd.DataFrame(data=outputs) 81 | df.to_csv(str(output_path), index=False) 82 | 83 | 84 | if __name__ == '__main__': 85 | create("imdb") 86 | create("wiki") 87 | combine_dataset() 88 | -------------------------------------------------------------------------------- /imdb-wiki-dir/data/preprocess_imdb_wiki.py: -------------------------------------------------------------------------------- 1 
| from os.path import join 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | BASE_PATH = './data' 7 | 8 | 9 | def visualize_dataset(db="imdb_wiki"): 10 | file_path = join(BASE_PATH, "meta", f"{db}.csv") 11 | data = pd.read_csv(file_path) 12 | _, ax = plt.subplots(figsize=(6, 3), sharex='all', sharey='all') 13 | ax.hist(data['age'], range(max(data['age']))) 14 | ax.set_xlim([0, 120]) 15 | plt.title(f"{db.upper()} (total: {data.shape[0]})") 16 | plt.tight_layout() 17 | plt.show() 18 | 19 | 20 | def make_balanced_testset(db="imdb_wiki", max_size=150, seed=666, verbose=True, vis=True, save=False): 21 | file_path = join(BASE_PATH, "meta", f"{db}.csv") 22 | df = pd.read_csv(file_path) 23 | df['age'] = df.age.astype(int) 24 | val_set, test_set = [], [] 25 | import random 26 | random.seed(seed) 27 | for value in range(121): 28 | curr_df = df[df['age'] == value] 29 | curr_data = curr_df['path'].values 30 | random.shuffle(curr_data) 31 | curr_size = min(len(curr_data) // 3, max_size) 32 | val_set += list(curr_data[:curr_size]) 33 | test_set += list(curr_data[curr_size:curr_size * 2]) 34 | if verbose: 35 | print(f"Val: {len(val_set)}\nTest: {len(test_set)}") 36 | assert len(set(val_set).intersection(set(test_set))) == 0 37 | combined_set = dict(zip(val_set, ['val' for _ in range(len(val_set))])) 38 | combined_set.update(dict(zip(test_set, ['test' for _ in range(len(test_set))]))) 39 | df['split'] = df['path'].map(combined_set) 40 | df['split'].fillna('train', inplace=True) 41 | if verbose: 42 | print(df) 43 | if save: 44 | df.to_csv(str(join(BASE_PATH, f"{db}.csv")), index=False) 45 | if vis: 46 | _, ax = plt.subplots(3, figsize=(6, 9), sharex='all') 47 | df_train = df[df['split'] == 'train'] 48 | # df_train = df_train[(df_train['age'] <= 20) | (df_train['age'] > 50)] 49 | ax[0].hist(df_train['age'], range(max(df['age']))) 50 | ax[0].set_title(f"[{db.upper()}] train: {df_train.shape[0]}") 51 | ax[1].hist(df[df['split'] == 'val']['age'], 
range(max(df['age']))) 52 | ax[1].set_title(f"[{db.upper()}] val: {df[df['split'] == 'val'].shape[0]}") 53 | ax[2].hist(df[df['split'] == 'test']['age'], range(max(df['age']))) 54 | ax[2].set_title(f"[{db.upper()}] test: {df[df['split'] == 'test'].shape[0]}") 55 | ax[0].set_xlim([0, 120]) 56 | plt.tight_layout() 57 | plt.show() 58 | 59 | 60 | if __name__ == '__main__': 61 | make_balanced_testset() 62 | visualize_dataset(db="imdb_wiki") 63 | -------------------------------------------------------------------------------- /imdb-wiki-dir/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | from PIL import Image 5 | from scipy.ndimage import convolve1d 6 | from torch.utils import data 7 | import torchvision.transforms as transforms 8 | 9 | from utils import get_lds_kernel_window 10 | 11 | print = logging.info 12 | 13 | 14 | class IMDBWIKI(data.Dataset): 15 | def __init__(self, df, data_dir, img_size, split='train', reweight='none', 16 | lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2): 17 | self.df = df 18 | self.data_dir = data_dir 19 | self.img_size = img_size 20 | self.split = split 21 | 22 | self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma) 23 | 24 | def __len__(self): 25 | return len(self.df) 26 | 27 | def __getitem__(self, index): 28 | index = index % len(self.df) 29 | row = self.df.iloc[index] 30 | img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB') 31 | transform = self.get_transform() 32 | img = transform(img) 33 | label = np.asarray([row['age']]).astype('float32') 34 | weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)]) 35 | 36 | return img, label, weight 37 | 38 | def get_transform(self): 39 | if self.split == 'train': 40 | transform = transforms.Compose([ 41 | 
transforms.Resize((self.img_size, self.img_size)), 42 | transforms.RandomCrop(self.img_size, padding=16), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 46 | ]) 47 | else: 48 | transform = transforms.Compose([ 49 | transforms.Resize((self.img_size, self.img_size)), 50 | transforms.ToTensor(), 51 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]), 52 | ]) 53 | return transform 54 | 55 | def _prepare_weights(self, reweight, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2): 56 | assert reweight in {'none', 'inverse', 'sqrt_inv'} 57 | assert reweight != 'none' if lds else True, \ 58 | "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS" 59 | 60 | value_dict = {x: 0 for x in range(max_target)} 61 | labels = self.df['age'].values 62 | for label in labels: 63 | value_dict[min(max_target - 1, int(label))] += 1 64 | if reweight == 'sqrt_inv': 65 | value_dict = {k: np.sqrt(v) for k, v in value_dict.items()} 66 | elif reweight == 'inverse': 67 | value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()} # clip weights for inverse re-weight 68 | num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels] 69 | if not len(num_per_label) or reweight == 'none': 70 | return None 71 | print(f"Using re-weighting: [{reweight.upper()}]") 72 | 73 | if lds: 74 | lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma) 75 | print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})') 76 | smoothed_value = convolve1d( 77 | np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant') 78 | num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels] 79 | 80 | weights = [np.float32(1 / x) for x in num_per_label] 81 | scaling = len(weights) / np.sum(weights) 82 | weights = [scaling * x for x in weights] 83 | return weights 84 | 
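
The weighting logic in `_prepare_weights` above can be illustrated with a small, dependency-free sketch. This is not the repository's implementation, only a reading of it under stated assumptions: the window is a closed-form Gaussian with peak value 1 (the repository builds its window with `scipy.ndimage.gaussian_filter1d`, which may differ slightly at the edges), bins outside the label range count as zero (as with `mode='constant'`), and the names `gaussian_window`, `smooth_counts`, and `lds_weights` are hypothetical helpers introduced here for illustration.

```python
import math
from collections import Counter


def gaussian_window(ks=5, sigma=2.0):
    """Gaussian window of odd size `ks` with peak value 1 (assumed closed form)."""
    half_ks = (ks - 1) // 2
    return [math.exp(-(x * x) / (2.0 * sigma * sigma)) for x in range(-half_ks, half_ks + 1)]


def smooth_counts(counts, window):
    """Slide the symmetric window over per-label counts; out-of-range bins count as zero."""
    half_ks = (len(window) - 1) // 2
    smoothed = []
    for i in range(len(counts)):
        total = 0.0
        for j, w in enumerate(window):
            k = i + j - half_ks
            if 0 <= k < len(counts):
                total += w * counts[k]
        smoothed.append(total)
    return smoothed


def lds_weights(labels, max_target=121, reweight='sqrt_inv', lds=True, ks=5, sigma=2.0):
    """Per-sample weights inversely related to the (optionally smoothed) label density."""
    assert reweight in {'inverse', 'sqrt_inv'}

    def clamp(y):
        # Labels beyond the last bin are folded into it, as in datasets.py
        return min(max_target - 1, int(y))

    freq = Counter(clamp(y) for y in labels)
    counts = [float(freq.get(b, 0)) for b in range(max_target)]
    if reweight == 'sqrt_inv':
        counts = [math.sqrt(c) for c in counts]
    else:
        # Clip counts for inverse re-weighting, mirroring np.clip(v, 5, 1000)
        counts = [min(max(c, 5.0), 1000.0) for c in counts]
    if lds:
        counts = smooth_counts(counts, gaussian_window(ks, sigma))
    raw = [1.0 / counts[clamp(y)] for y in labels]
    scaling = len(raw) / sum(raw)  # rescale so the weights average to 1
    return [scaling * w for w in raw]


# Toy label set: ages around 30 are frequent, age 80 is rare
labels = [30] * 50 + [31] * 40 + [32] * 30 + [80] * 2
weights = lds_weights(labels)
```

With these toy labels, the two samples at age 80 receive a larger weight than the 50 samples at age 30, and the weights average to 1 by construction.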
--------------------------------------------------------------------------------
/imdb-wiki-dir/download_imdb_wiki.py:
--------------------------------------------------------------------------------
import os
import wget

print("Downloading IMDB faces...")
imdb_file = "./data/imdb_crop.tar"
wget.download("https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/imdb_crop.tar", out=imdb_file)
print("Downloading WIKI faces...")
wiki_file = "./data/wiki_crop.tar"
wget.download("https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar", out=wiki_file)
print("Extracting IMDB faces...")
os.system(f"tar -xvf {imdb_file} -C ./data")
print("Extracting WIKI faces...")
os.system(f"tar -xvf {wiki_file} -C ./data")
os.remove(imdb_file)
os.remove(wiki_file)
print("\nCompleted!")

--------------------------------------------------------------------------------
/imdb-wiki-dir/fds.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal.windows import triang
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import calibrate_mean_var

print = logging.info


class FDS(nn.Module):

    def __init__(self, feature_dim, bucket_num=100, bucket_start=0, start_update=0, start_smooth=1,
                 kernel='gaussian', ks=5, sigma=2, momentum=0.9):
        super(FDS, self).__init__()
        self.feature_dim = feature_dim
        self.bucket_num = bucket_num
        self.bucket_start = bucket_start
        self.kernel_window = self._get_kernel_window(kernel, ks, sigma)
        self.half_ks = (ks - 1) // 2
        self.momentum = momentum
        self.start_update = start_update
        self.start_smooth = start_smooth

        self.register_buffer('epoch', torch.zeros(1).fill_(start_update))
self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim)) 30 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim)) 31 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 32 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 33 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 34 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 35 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start)) 36 | 37 | @staticmethod 38 | def _get_kernel_window(kernel, ks, sigma): 39 | assert kernel in ['gaussian', 'triang', 'laplace'] 40 | half_ks = (ks - 1) // 2 41 | if kernel == 'gaussian': 42 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks 43 | base_kernel = np.array(base_kernel, dtype=np.float32) 44 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma)) 45 | elif kernel == 'triang': 46 | kernel_window = triang(ks) / sum(triang(ks)) 47 | else: 48 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. 
* sigma) 49 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1))) 50 | 51 | print(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})') 52 | return torch.tensor(kernel_window, dtype=torch.float32).cuda() 53 | 54 | def _update_last_epoch_stats(self): 55 | self.running_mean_last_epoch = self.running_mean 56 | self.running_var_last_epoch = self.running_var 57 | 58 | self.smoothed_mean_last_epoch = F.conv1d( 59 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0), 60 | pad=(self.half_ks, self.half_ks), mode='reflect'), 61 | weight=self.kernel_window.view(1, 1, -1), padding=0 62 | ).permute(2, 1, 0).squeeze(1) 63 | self.smoothed_var_last_epoch = F.conv1d( 64 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0), 65 | pad=(self.half_ks, self.half_ks), mode='reflect'), 66 | weight=self.kernel_window.view(1, 1, -1), padding=0 67 | ).permute(2, 1, 0).squeeze(1) 68 | 69 | def reset(self): 70 | self.running_mean.zero_() 71 | self.running_var.fill_(1) 72 | self.running_mean_last_epoch.zero_() 73 | self.running_var_last_epoch.fill_(1) 74 | self.smoothed_mean_last_epoch.zero_() 75 | self.smoothed_var_last_epoch.fill_(1) 76 | self.num_samples_tracked.zero_() 77 | 78 | def update_last_epoch_stats(self, epoch): 79 | if epoch == self.epoch + 1: 80 | self.epoch += 1 81 | self._update_last_epoch_stats() 82 | print(f"Updated smoothed statistics on Epoch [{epoch}]!") 83 | 84 | def update_running_stats(self, features, labels, epoch): 85 | if epoch < self.epoch: 86 | return 87 | 88 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!" 89 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!" 
90 | 91 | for label in torch.unique(labels): 92 | if label > self.bucket_num - 1 or label < self.bucket_start: 93 | continue 94 | elif label == self.bucket_start: 95 | curr_feats = features[labels <= label] 96 | elif label == self.bucket_num - 1: 97 | curr_feats = features[labels >= label] 98 | else: 99 | curr_feats = features[labels == label] 100 | curr_num_sample = curr_feats.size(0) 101 | curr_mean = torch.mean(curr_feats, 0) 102 | curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False) 103 | 104 | self.num_samples_tracked[int(label - self.bucket_start)] += curr_num_sample 105 | factor = self.momentum if self.momentum is not None else \ 106 | (1 - curr_num_sample / float(self.num_samples_tracked[int(label - self.bucket_start)])) 107 | factor = 0 if epoch == self.start_update else factor 108 | self.running_mean[int(label - self.bucket_start)] = \ 109 | (1 - factor) * curr_mean + factor * self.running_mean[int(label - self.bucket_start)] 110 | self.running_var[int(label - self.bucket_start)] = \ 111 | (1 - factor) * curr_var + factor * self.running_var[int(label - self.bucket_start)] 112 | 113 | print(f"Updated running statistics with Epoch [{epoch}] features!") 114 | 115 | def smooth(self, features, labels, epoch): 116 | if epoch < self.start_smooth: 117 | return features 118 | 119 | labels = labels.squeeze(1) 120 | for label in torch.unique(labels): 121 | if label > self.bucket_num - 1 or label < self.bucket_start: 122 | continue 123 | elif label == self.bucket_start: 124 | features[labels <= label] = calibrate_mean_var( 125 | features[labels <= label], 126 | self.running_mean_last_epoch[int(label - self.bucket_start)], 127 | self.running_var_last_epoch[int(label - self.bucket_start)], 128 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 129 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 130 | elif label == self.bucket_num - 1: 131 | features[labels >= label] = calibrate_mean_var( 132 | 
features[labels >= label], 133 | self.running_mean_last_epoch[int(label - self.bucket_start)], 134 | self.running_var_last_epoch[int(label - self.bucket_start)], 135 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 136 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 137 | else: 138 | features[labels == label] = calibrate_mean_var( 139 | features[labels == label], 140 | self.running_mean_last_epoch[int(label - self.bucket_start)], 141 | self.running_var_last_epoch[int(label - self.bucket_start)], 142 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)], 143 | self.smoothed_var_last_epoch[int(label - self.bucket_start)]) 144 | return features 145 | -------------------------------------------------------------------------------- /imdb-wiki-dir/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def weighted_mse_loss(inputs, targets, weights=None): 6 | loss = (inputs - targets) ** 2 7 | if weights is not None: 8 | loss *= weights.expand_as(loss) 9 | loss = torch.mean(loss) 10 | return loss 11 | 12 | 13 | def weighted_l1_loss(inputs, targets, weights=None): 14 | loss = F.l1_loss(inputs, targets, reduction='none') 15 | if weights is not None: 16 | loss *= weights.expand_as(loss) 17 | loss = torch.mean(loss) 18 | return loss 19 | 20 | 21 | def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1): 22 | loss = (inputs - targets) ** 2 23 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 24 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 25 | if weights is not None: 26 | loss *= weights.expand_as(loss) 27 | loss = torch.mean(loss) 28 | return loss 29 | 30 | 31 | def weighted_focal_l1_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1): 32 | loss = F.l1_loss(inputs, targets, reduction='none') 33 | loss 
*= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 34 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 35 | if weights is not None: 36 | loss *= weights.expand_as(loss) 37 | loss = torch.mean(loss) 38 | return loss 39 | 40 | 41 | def weighted_huber_loss(inputs, targets, weights=None, beta=1.): 42 | l1_loss = torch.abs(inputs - targets) 43 | cond = l1_loss < beta 44 | loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta) 45 | if weights is not None: 46 | loss *= weights.expand_as(loss) 47 | loss = torch.mean(loss) 48 | return loss 49 | -------------------------------------------------------------------------------- /imdb-wiki-dir/resnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import torch.nn as nn 4 | from fds import FDS 5 | 6 | print = logging.info 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | out += residual 37 | out = self.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, inplanes, planes, 
stride=1, downsample=None): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * 4) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | residual = x 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | out = self.relu(out) 64 | out = self.conv3(out) 65 | out = self.bn3(out) 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | out += residual 69 | out = self.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | 75 | def __init__(self, block, layers, fds, bucket_num, bucket_start, start_update, start_smooth, 76 | kernel, ks, sigma, momentum, dropout=None): 77 | self.inplanes = 64 78 | super(ResNet, self).__init__() 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 83 | self.layer1 = self._make_layer(block, 64, layers[0]) 84 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 85 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 86 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 87 | self.avgpool = nn.AvgPool2d(7, stride=1) 88 | self.linear = nn.Linear(512 * block.expansion, 1) 89 | 90 | if fds: 91 | self.FDS = FDS( 92 | feature_dim=512 * block.expansion, bucket_num=bucket_num, bucket_start=bucket_start, 93 | start_update=start_update, start_smooth=start_smooth, kernel=kernel, ks=ks, sigma=sigma, 
momentum=momentum 94 | ) 95 | self.fds = fds 96 | self.start_smooth = start_smooth 97 | 98 | self.use_dropout = True if dropout else False 99 | if self.use_dropout: 100 | print(f'Using dropout: {dropout}') 101 | self.dropout = nn.Dropout(p=dropout) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | downsample = None 113 | if stride != 1 or self.inplanes != planes * block.expansion: 114 | downsample = nn.Sequential( 115 | nn.Conv2d(self.inplanes, planes * block.expansion, 116 | kernel_size=1, stride=stride, bias=False), 117 | nn.BatchNorm2d(planes * block.expansion), 118 | ) 119 | layers = [] 120 | layers.append(block(self.inplanes, planes, stride, downsample)) 121 | self.inplanes = planes * block.expansion 122 | for i in range(1, blocks): 123 | layers.append(block(self.inplanes, planes)) 124 | 125 | return nn.Sequential(*layers) 126 | 127 | def forward(self, x, targets=None, epoch=None): 128 | x = self.conv1(x) 129 | x = self.bn1(x) 130 | x = self.relu(x) 131 | x = self.maxpool(x) 132 | 133 | x = self.layer1(x) 134 | x = self.layer2(x) 135 | x = self.layer3(x) 136 | x = self.layer4(x) 137 | x = self.avgpool(x) 138 | encoding = x.view(x.size(0), -1) 139 | 140 | encoding_s = encoding 141 | 142 | if self.training and self.fds: 143 | if epoch >= self.start_smooth: 144 | encoding_s = self.FDS.smooth(encoding_s, targets, epoch) 145 | 146 | if self.use_dropout: 147 | encoding_s = self.dropout(encoding_s) 148 | x = self.linear(encoding_s) 149 | 150 | if self.training and self.fds: 151 | return x, encoding 152 | else: 153 | return x 154 | 155 | 156 | def resnet50(**kwargs): 157 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 158 | 
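The `FDS.smooth` call in `ResNet.forward` above moves each bucket's features from their running statistics to the kernel-smoothed statistics. A minimal, dependency-free sketch of that calibration step is given below. It mirrors the clipping logic of `calibrate_mean_var` in `utils.py`, but works on plain Python lists for a single feature dimension; the function name `calibrate` and the sample numbers are illustrative assumptions, not part of the repository.

```python
import math


def calibrate(values, m1, v1, m2, v2, clip_min=0.1, clip_max=10.0):
    """Shift and rescale `values` from statistics (m1, v1) to (m2, v2).

    Scalar illustration only: the repository helper applies the same
    formula per feature dimension on torch tensors.
    """
    if v1 < 1e-10:
        # Degenerate running variance: leave the features untouched.
        return list(values)
    # Clip the variance ratio so a noisy bucket cannot blow up the features.
    factor = min(max(v2 / v1, clip_min), clip_max)
    scale = math.sqrt(factor)
    return [(x - m1) * scale + m2 for x in values]


# Hypothetical bucket: running stats (mean 2, var 4), smoothed stats (mean 3, var 1).
feats = [0.0, 2.0, 4.0]
print(calibrate(feats, m1=2.0, v1=4.0, m2=3.0, v2=1.0))
```

With these made-up numbers the variance ratio is 0.25, so deviations from the running mean are halved and re-centred on the smoothed mean.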
-------------------------------------------------------------------------------- /imdb-wiki-dir/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import logging 5 | import numpy as np 6 | from scipy.ndimage import gaussian_filter1d 7 | from scipy.signal.windows import triang 8 | 9 | 10 | class AverageMeter(object): 11 | def __init__(self, name, fmt=':f'): 12 | self.name = name 13 | self.fmt = fmt 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | def __str__(self): 29 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 30 | return fmtstr.format(**self.__dict__) 31 | 32 | 33 | class ProgressMeter(object): 34 | def __init__(self, num_batches, meters, prefix=""): 35 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 36 | self.meters = meters 37 | self.prefix = prefix 38 | 39 | def display(self, batch): 40 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 41 | entries += [str(meter) for meter in self.meters] 42 | logging.info('\t'.join(entries)) 43 | 44 | @staticmethod 45 | def _get_batch_fmtstr(num_batches): 46 | num_digits = len(str(num_batches // 1)) 47 | fmt = '{:' + str(num_digits) + 'd}' 48 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 49 | 50 | 51 | def query_yes_no(question): 52 | """ Ask a yes/no question via input() and return their answer. 
""" 53 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False} 54 | prompt = " [Y/n] " 55 | 56 | while True: 57 | print(question + prompt, end=':') 58 | choice = input().lower() 59 | if choice == '': 60 | return valid['y'] 61 | elif choice in valid: 62 | return valid[choice] 63 | else: 64 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n") 65 | 66 | 67 | def prepare_folders(args): 68 | folders_util = [args.store_root, os.path.join(args.store_root, args.store_name)] 69 | if os.path.exists(folders_util[-1]) and not args.resume and not args.pretrained and not args.evaluate: 70 | if query_yes_no('overwrite previous folder: {} ?'.format(folders_util[-1])): 71 | shutil.rmtree(folders_util[-1]) 72 | print(folders_util[-1] + ' removed.') 73 | else: 74 | raise RuntimeError('Output folder {} already exists'.format(folders_util[-1])) 75 | for folder in folders_util: 76 | if not os.path.exists(folder): 77 | print(f"===> Creating folder: {folder}") 78 | os.mkdir(folder) 79 | 80 | 81 | def adjust_learning_rate(optimizer, epoch, args): 82 | lr = args.lr 83 | for milestone in args.schedule: 84 | lr *= 0.1 if epoch >= milestone else 1. 85 | for param_group in optimizer.param_groups: 86 | param_group['lr'] = lr 87 | 88 | 89 | def save_checkpoint(args, state, is_best, prefix=''): 90 | filename = f"{args.store_root}/{args.store_name}/{prefix}ckpt.pth.tar" 91 | torch.save(state, filename) 92 | if is_best: 93 | logging.info("===> Saving current best checkpoint...") 94 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 95 | 96 | 97 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.1, clip_max=10): 98 | if torch.sum(v1) < 1e-10: 99 | return matrix 100 | if (v1 == 0.).any(): 101 | valid = (v1 != 0.) 
        factor = torch.clamp(v2[valid] / v1[valid], clip_min, clip_max)
        matrix[:, valid] = (matrix[:, valid] - m1[valid]) * torch.sqrt(factor) + m2[valid]
        return matrix

    factor = torch.clamp(v2 / v1, clip_min, clip_max)
    return (matrix - m1) * torch.sqrt(factor) + m2


def get_lds_kernel_window(kernel, ks, sigma):
    assert kernel in ['gaussian', 'triang', 'laplace']
    half_ks = (ks - 1) // 2
    if kernel == 'gaussian':
        base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
        kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
    elif kernel == 'triang':
        kernel_window = triang(ks)
    else:
        laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
        kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))

    return kernel_window

--------------------------------------------------------------------------------
/nyud2-dir/README.md:
--------------------------------------------------------------------------------
# NYUD2-DIR
## Installation

#### Prerequisites

1. Download and extract the NYU v2 dataset to folder `./data` using

```bash
python download_nyud2.py
```

2. __(Optional)__ We provide the required meta files `nyu2_train_FDS_subset.csv` (for efficient FDS feature statistics computation) and `test_balanced_mask.npy` (the balanced test set mask) in folder `./data`. To reproduce the results in the paper, use these two files directly. If you want to try different FDS computation subsets and balanced test set masks, you can run

```bash
python preprocess_nyud2.py
```

#### Dependencies

- PyTorch (>= 1.2, tested on 1.6)
- numpy, pandas, scipy, tqdm, matplotlib, PIL, gdown, tensorboardX

## Code Overview

#### Main Files

- `train.py`: main training script
- `test.py`: main evaluation script
- `preprocess_nyud2.py`: create meta files `nyu2_train_FDS_subset.csv` and `test_balanced_mask.npy` for NYUD2-DIR

#### Main Arguments

For `train.py`:

- `--data_dir`: directory containing the data and meta files
- `--lds`: LDS switch (whether to enable LDS)
- `--fds`: FDS switch (whether to enable FDS)
- `--reweight`: cost-sensitive re-weighting scheme to use
- `--resume`: whether to resume training (only for training)
- `--retrain_fc`: whether to retrain regressor
- `--pretrained`: path to load backbone weights for regressor re-training (RRT)

For `test.py`:

- `--eval_model`: path to the checkpoint to load (only for evaluation)

## Getting Started

#### Train a vanilla model

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_dir --reweight none
```

Always specify `CUDA_VISIBLE_DEVICES` for the GPU IDs to use (by default, 4 GPUs) and `--data_dir` when training a model, or set your default data directory path directly in the code. We omit these arguments in the following commands for simplicity.

#### Train a model using re-weighting

To perform inverse re-weighting

```bash
python train.py --reweight inverse
```

To perform square-root inverse re-weighting

```bash
python train.py --reweight sqrt_inv
```

#### Train a model using RRT

```bash
python train.py [...retrained model arguments...] --retrain_fc --pretrained
```

#### Train a model using LDS

To use a Gaussian kernel (kernel size: 5, sigma: 2)

```bash
python train.py --reweight inverse --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2
```

#### Train a model using FDS

To use a Gaussian kernel (kernel size: 5, sigma: 2)

```bash
python train.py --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
```

#### Train a model using LDS + FDS

```bash
python train.py --reweight inverse --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
```

#### Evaluate a trained checkpoint

```bash
python test.py --data_dir --eval_model
```

## Reproduced Benchmarks and Model Zoo

We provide below reproduced results on NYUD2-DIR (base method `Vanilla`, metric `RMSE`).
Note that some models could give **better** results than the reported numbers in the paper.

| Model | Overall | Many-Shot | Medium-Shot | Few-Shot | Download |
| :-------: | :-----: | :-------: | :---------: | :------: | :------: |
| LDS | 1.387 | 0.671 | 0.913 | 1.954 | [model](https://drive.google.com/file/d/1RgQx-nreiJ-chH0887xCy7gxah-zrrEO/view?usp=sharing) |
| FDS | 1.442 | 0.615 | 0.940 | 2.059 | [model](https://drive.google.com/file/d/1FEKzBzMPaGubmv9iK4BP6LJng44Mhc7s/view?usp=sharing) |
| LDS + FDS | 1.301 | 0.731 | 0.832 | 1.799 | [model](https://drive.google.com/file/d/1QlZJOPYSyRRFqa1Q-y7-JlTDABiQZJUF/view?usp=sharing) |

--------------------------------------------------------------------------------
/nyud2-dir/data/test_balanced_mask.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/nyud2-dir/data/test_balanced_mask.npy

--------------------------------------------------------------------------------
/nyud2-dir/download_nyud2.py:
--------------------------------------------------------------------------------
import os
import gdown
import zipfile

print("Downloading and extracting NYU v2 dataset to folder './data'...")
data_file = "./data.zip"
# Pass the output path explicitly so the archive is saved where it is opened below.
gdown.download("https://drive.google.com/uc?id=1WoOZOBpOWfmwe7bknWS5PMUCLBPFKTOw", output=data_file)
print('Extracting...')
with zipfile.ZipFile(data_file) as zip_ref:
    zip_ref.extractall('.')
os.remove(data_file)
print("Completed!")

--------------------------------------------------------------------------------
/nyud2-dir/loaddata.py:
--------------------------------------------------------------------------------
import os
import logging
# np, torch and Image are used below; import them explicitly instead of
# relying on the star import from nyu_transform.
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from nyu_transform import *
from scipy.ndimage import convolve1d
from util import get_lds_kernel_window

# for data
loading efficiency 11 | TRAIN_BUCKET_NUM = [0, 0, 0, 0, 0, 0, 0, 25848691, 24732940, 53324326, 69112955, 54455432, 95637682, 71403954, 117244217, 12 | 84813007, 126524456, 84486706, 133130272, 95464874, 146051415, 146133612, 96561379, 138366677, 89680276, 13 | 127689043, 81608990, 119121178, 74360607, 106839384, 97595765, 66718296, 90661239, 53103021, 83340912, 14 | 51365604, 71262770, 42243737, 65860580, 38415940, 53647559, 54038467, 28335524, 41485143, 32106001, 15 | 35936734, 23966211, 32018765, 19297203, 31503743, 21681574, 16363187, 25743420, 12769509, 17675327, 16 | 13147819, 15798560, 9547180, 14933200, 9663019, 12887283, 11803562, 7656609, 11515700, 7756306, 9046228, 17 | 5114894, 8653419, 6859433, 8001904, 6430700, 3305839, 6318461, 3486268, 5621065, 4030498, 3839488, 3220208, 18 | 4483027, 2555777, 4685983, 3145082, 2951048, 2762369, 2367581, 2546089, 2343867, 2481579, 1722140, 3018892, 19 | 2325197, 1952354, 2047038, 1858707, 2052729, 1348558, 2487278, 1314198, 3338550, 1132666] 20 | 21 | class depthDataset(Dataset): 22 | def __init__(self, data_dir, csv_file, mask_file=None, transform=None, args=None): 23 | self.data_dir = data_dir 24 | self.frame = pd.read_csv(csv_file, header=None) 25 | self.mask = torch.tensor(np.load(mask_file), dtype=torch.bool) if mask_file is not None else None 26 | self.transform = transform 27 | self.bucket_weights = self._get_bucket_weights(args) if args is not None else None 28 | 29 | def _get_bucket_weights(self, args): 30 | assert args.reweight in {'none', 'inverse', 'sqrt_inv'} 31 | assert args.reweight != 'none' if args.lds else True, "Set reweight to \'sqrt_inv\' or \'inverse\' (default) when using LDS" 32 | if args.reweight == 'none': 33 | return None 34 | logging.info(f"Using re-weighting: [{args.reweight.upper()}]") 35 | 36 | if args.lds: 37 | value_lst = TRAIN_BUCKET_NUM[args.bucket_start:] 38 | lds_kernel_window = get_lds_kernel_window(args.lds_kernel, args.lds_ks, args.lds_sigma) 39 | logging.info(f'Using LDS: 
[{args.lds_kernel.upper()}] ({args.lds_ks}/{args.lds_sigma})') 40 | if args.reweight == 'sqrt_inv': 41 | value_lst = np.sqrt(value_lst) 42 | smoothed_value = convolve1d(np.asarray(value_lst), weights=lds_kernel_window, mode='reflect') 43 | smoothed_value = [smoothed_value[0]] * args.bucket_start + list(smoothed_value) 44 | scaling = np.sum(TRAIN_BUCKET_NUM) / np.sum(np.array(TRAIN_BUCKET_NUM) / np.array(smoothed_value)) 45 | bucket_weights = [np.float32(scaling / smoothed_value[bucket]) for bucket in range(args.bucket_num)] 46 | else: 47 | value_lst = [TRAIN_BUCKET_NUM[args.bucket_start]] * args.bucket_start + TRAIN_BUCKET_NUM[args.bucket_start:] 48 | if args.reweight == 'sqrt_inv': 49 | value_lst = np.sqrt(value_lst) 50 | scaling = np.sum(TRAIN_BUCKET_NUM) / np.sum(np.array(TRAIN_BUCKET_NUM) / np.array(value_lst)) 51 | bucket_weights = [np.float32(scaling / value_lst[bucket]) for bucket in range(args.bucket_num)] 52 | 53 | return bucket_weights 54 | 55 | def get_bin_idx(self, x): 56 | return min(int(x * np.float32(10)), 99) 57 | 58 | def _get_weights(self, depth): 59 | sp = depth.shape 60 | if self.bucket_weights is not None: 61 | depth = depth.view(-1).cpu().numpy() 62 | assert depth.dtype == np.float32 63 | weights = np.array(list(map(lambda v: self.bucket_weights[self.get_bin_idx(v)], depth))) 64 | weights = torch.tensor(weights, dtype=torch.float32).view(*sp) 65 | else: 66 | weights = torch.tensor([np.float32(1.)], dtype=torch.float32).repeat(*sp) 67 | return weights 68 | 69 | def __getitem__(self, idx): 70 | image_name = self.frame.iloc[idx, 0] 71 | depth_name = self.frame.iloc[idx, 1] 72 | 73 | image_name = os.path.join(self.data_dir, '/'.join(image_name.split('/')[1:])) 74 | depth_name = os.path.join(self.data_dir, '/'.join(depth_name.split('/')[1:])) 75 | 76 | image = Image.open(image_name) 77 | depth = Image.open(depth_name) 78 | 79 | sample = {'image': image, 'depth': depth} 80 | 81 | if self.transform: 82 | sample = self.transform(sample) 83 | 84 | 
sample['weight'] = self._get_weights(sample['depth']) 85 | sample['idx'] = idx 86 | 87 | if self.mask is not None: 88 | sample['mask'] = self.mask[idx].unsqueeze(0) 89 | 90 | return sample 91 | 92 | def __len__(self): 93 | return len(self.frame) 94 | 95 | 96 | def getTrainingData(args, batch_size=64): 97 | __imagenet_pca = { 98 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 99 | 'eigvec': torch.Tensor([ 100 | [-0.5675, 0.7192, 0.4009], 101 | [-0.5808, -0.0045, -0.8140], 102 | [-0.5836, -0.6948, 0.4203], 103 | ]) 104 | } 105 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 106 | 'std': [0.229, 0.224, 0.225]} 107 | 108 | transformed_training = depthDataset(data_dir=args.data_dir, 109 | csv_file=os.path.join(args.data_dir, 'nyu2_train.csv'), 110 | transform=transforms.Compose([ 111 | Scale(240), 112 | RandomHorizontalFlip(), 113 | RandomRotate(5), 114 | CenterCrop([304, 228], [152, 114]), 115 | ToTensor(), 116 | Lighting(0.1, __imagenet_pca[ 117 | 'eigval'], __imagenet_pca['eigvec']), 118 | ColorJitter( 119 | brightness=0.4, 120 | contrast=0.4, 121 | saturation=0.4, 122 | ), 123 | Normalize(__imagenet_stats['mean'], 124 | __imagenet_stats['std']) 125 | ]), args=args) 126 | 127 | dataloader_training = DataLoader(transformed_training, batch_size, 128 | shuffle=True, num_workers=8, pin_memory=False) 129 | 130 | return dataloader_training 131 | 132 | def getTrainingFDSData(args, batch_size=64): 133 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 134 | 'std': [0.229, 0.224, 0.225]} 135 | 136 | transformed_training = depthDataset(data_dir=args.data_dir, 137 | csv_file=os.path.join(args.data_dir, 'nyu2_train_FDS_subset.csv'), 138 | transform=transforms.Compose([ 139 | Scale(240), 140 | CenterCrop([304, 228], [152, 114]), 141 | ToTensor(), 142 | Normalize(__imagenet_stats['mean'], 143 | __imagenet_stats['std']) 144 | ])) 145 | 146 | dataloader_training = DataLoader(transformed_training, batch_size, 147 | shuffle=False, num_workers=8, pin_memory=False) 148 | 
return dataloader_training 149 | 150 | 151 | def getTestingData(args, batch_size=64): 152 | 153 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 154 | 'std': [0.229, 0.224, 0.225]} 155 | 156 | transformed_testing = depthDataset(data_dir=args.data_dir, 157 | csv_file=os.path.join(args.data_dir, 'nyu2_test.csv'), 158 | mask_file=os.path.join(args.data_dir, 'test_balanced_mask.npy'), 159 | transform=transforms.Compose([ 160 | Scale(240), 161 | CenterCrop([304, 228], [304, 228]), 162 | ToTensor(is_test=True), 163 | Normalize(__imagenet_stats['mean'], 164 | __imagenet_stats['std']) 165 | ])) 166 | 167 | dataloader_testing = DataLoader(transformed_testing, batch_size, 168 | shuffle=False, num_workers=0, pin_memory=False) 169 | 170 | return dataloader_testing 171 | -------------------------------------------------------------------------------- /nyud2-dir/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/nyud2-dir/models/__init__.py -------------------------------------------------------------------------------- /nyud2-dir/models/fds.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from scipy.ndimage import gaussian_filter1d 7 | from scipy.signal.windows import triang 8 | from util import calibrate_mean_var 9 | 10 | 11 | class FDS(nn.Module): 12 | 13 | def __init__(self, feature_dim, bucket_num=100, bucket_start=7, start_update=0, start_smooth=1, 14 | kernel='gaussian', ks=5, sigma=2, momentum=0.9): 15 | super(FDS, self).__init__() 16 | self.feature_dim = feature_dim 17 | self.bucket_num = bucket_num 18 | self.bucket_start = bucket_start 19 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma) 20 | self.half_ks = (ks - 1) // 2 21 | 
self.momentum = momentum 22 | self.start_update = start_update 23 | self.start_smooth = start_smooth 24 | 25 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update)) 26 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim)) 27 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim)) 28 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 29 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 30 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 31 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 32 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start)) 33 | 34 | @staticmethod 35 | def _get_kernel_window(kernel, ks, sigma): 36 | assert kernel in ['gaussian', 'triang', 'laplace'] 37 | half_ks = (ks - 1) // 2 38 | if kernel == 'gaussian': 39 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks 40 | base_kernel = np.array(base_kernel, dtype=np.float32) 41 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma)) 42 | elif kernel == 'triang': 43 | kernel_window = triang(ks) / sum(triang(ks)) 44 | else: 45 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. 
* sigma) 46 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1))) 47 | 48 | logging.info(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})') 49 | return torch.tensor(kernel_window, dtype=torch.float32).cuda() 50 | 51 | def _get_bucket_idx(self, label): 52 | label = np.float32(label.cpu()) 53 | return max(min(int(label * np.float32(10)), self.bucket_num - 1), self.bucket_start) 54 | 55 | def _update_last_epoch_stats(self): 56 | self.running_mean_last_epoch = self.running_mean 57 | self.running_var_last_epoch = self.running_var 58 | 59 | self.smoothed_mean_last_epoch = F.conv1d( 60 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0), 61 | pad=(self.half_ks, self.half_ks), mode='reflect'), 62 | weight=self.kernel_window.view(1, 1, -1), padding=0 63 | ).permute(2, 1, 0).squeeze(1) 64 | self.smoothed_var_last_epoch = F.conv1d( 65 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0), 66 | pad=(self.half_ks, self.half_ks), mode='reflect'), 67 | weight=self.kernel_window.view(1, 1, -1), padding=0 68 | ).permute(2, 1, 0).squeeze(1) 69 | 70 | assert self.smoothed_mean_last_epoch.shape == self.running_mean_last_epoch.shape, \ 71 | "Smoothed shape is not aligned with running shape!" 
72 | 73 | def reset(self): 74 | self.running_mean.zero_() 75 | self.running_var.fill_(1) 76 | self.running_mean_last_epoch.zero_() 77 | self.running_var_last_epoch.fill_(1) 78 | self.smoothed_mean_last_epoch.zero_() 79 | self.smoothed_var_last_epoch.fill_(1) 80 | self.num_samples_tracked.zero_() 81 | 82 | def update_last_epoch_stats(self, epoch): 83 | if epoch == self.epoch + 1: 84 | self.epoch += 1 85 | self._update_last_epoch_stats() 86 | logging.info(f"Updated smoothed statistics of last epoch on Epoch [{epoch}]!") 87 | 88 | def _running_stats_to_device(self, device): 89 | if device == 'cpu': 90 | self.num_samples_tracked = self.num_samples_tracked.cpu() 91 | self.running_mean = self.running_mean.cpu() 92 | self.running_var = self.running_var.cpu() 93 | else: 94 | self.num_samples_tracked = self.num_samples_tracked.cuda() 95 | self.running_mean = self.running_mean.cuda() 96 | self.running_var = self.running_var.cuda() 97 | 98 | def update_running_stats(self, features, labels, epoch): 99 | if epoch < self.epoch: 100 | return 101 | 102 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!" 103 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!" 

        self._running_stats_to_device('cpu')

        labels = labels.squeeze(1).view(-1)
        features = features.permute(0, 2, 3, 1).contiguous().view(-1, self.feature_dim)

        buckets = np.array([self._get_bucket_idx(label) for label in labels])
        for bucket in np.unique(buckets):
            # `np.bool` was removed in NumPy 1.24; the builtin `bool` is equivalent here.
            curr_feats = features[torch.tensor((buckets == bucket).astype(bool))]
            curr_num_sample = curr_feats.size(0)
            curr_mean = torch.mean(curr_feats, 0)
            curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False)

            self.num_samples_tracked[bucket - self.bucket_start] += curr_num_sample
            factor = self.momentum if self.momentum is not None else \
                (1 - curr_num_sample / float(self.num_samples_tracked[bucket - self.bucket_start]))
            factor = 0 if epoch == self.start_update else factor
            self.running_mean[bucket - self.bucket_start] = \
                (1 - factor) * curr_mean + factor * self.running_mean[bucket - self.bucket_start]
            self.running_var[bucket - self.bucket_start] = \
                (1 - factor) * curr_var + factor * self.running_var[bucket - self.bucket_start]

        self._running_stats_to_device('cuda')
        logging.info(f"Updated running statistics with Epoch [{epoch}] features!")

    def smooth(self, features, labels, epoch):
        if epoch < self.start_smooth:
            return features

        sp = labels.squeeze(1).shape

        labels = labels.squeeze(1).view(-1)
        features = features.permute(0, 2, 3, 1).contiguous().view(-1, self.feature_dim)

        buckets = torch.max(torch.stack([torch.min(torch.stack([torch.floor(labels * torch.tensor([10.]).cuda()).int(),
            torch.zeros(labels.size(0)).fill_(self.bucket_num - 1).int().cuda()], 0), 0)[0], torch.zeros(labels.size(0)).fill_(self.bucket_start).int().cuda()], 0), 0)[0]
        for bucket in torch.unique(buckets):
            features[buckets.eq(bucket)] = calibrate_mean_var(
                features[buckets.eq(bucket)],
                self.running_mean_last_epoch[bucket.item() - self.bucket_start],
                self.running_var_last_epoch[bucket.item() - self.bucket_start],
                self.smoothed_mean_last_epoch[bucket.item() - self.bucket_start],
                self.smoothed_var_last_epoch[bucket.item() - self.bucket_start]
            )

        return features.view(*sp, self.feature_dim).permute(0, 3, 1, 2)

--------------------------------------------------------------------------------
/nyud2-dir/models/modules.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .fds import FDS

class _UpProjection(nn.Sequential):

    def __init__(self, num_input_features, num_output_features):
        super(_UpProjection, self).__init__()

        self.conv1 = nn.Conv2d(num_input_features, num_output_features,
                               kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(num_output_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(num_output_features, num_output_features,
                                 kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1_2 = nn.BatchNorm2d(num_output_features)

        self.conv2 = nn.Conv2d(num_input_features, num_output_features,
                               kernel_size=5, stride=1, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(num_output_features)

    def forward(self, x, size):
        # F.upsample is deprecated; F.interpolate with align_corners=False
        # reproduces its default bilinear behavior without the warning.
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=False)
        x_conv1 = self.relu(self.bn1(self.conv1(x)))
        bran1 = self.bn1_2(self.conv1_2(x_conv1))
        bran2 = self.bn2(self.conv2(x))

        out = self.relu(bran1 + bran2)

        return out

class E_resnet(nn.Module):

    def __init__(self, original_model, num_features = 2048):
        super(E_resnet, self).__init__()
        self.conv1 = original_model.conv1
        self.bn1 = original_model.bn1
        self.relu = original_model.relu
        self.maxpool = original_model.maxpool

        self.layer1 =
original_model.layer1 43 | self.layer2 = original_model.layer2 44 | self.layer3 = original_model.layer3 45 | self.layer4 = original_model.layer4 46 | 47 | 48 | def forward(self, x): 49 | x = self.conv1(x) 50 | x = self.bn1(x) 51 | x = self.relu(x) 52 | x = self.maxpool(x) 53 | 54 | x_block1 = self.layer1(x) 55 | x_block2 = self.layer2(x_block1) 56 | x_block3 = self.layer3(x_block2) 57 | x_block4 = self.layer4(x_block3) 58 | 59 | return x_block1, x_block2, x_block3, x_block4 60 | 61 | class D(nn.Module): 62 | 63 | def __init__(self, num_features = 2048): 64 | super(D, self).__init__() 65 | self.conv = nn.Conv2d(num_features, num_features // 66 | 2, kernel_size=1, stride=1, bias=False) 67 | num_features = num_features // 2 68 | self.bn = nn.BatchNorm2d(num_features) 69 | 70 | self.up1 = _UpProjection( 71 | num_input_features=num_features, num_output_features=num_features // 2) 72 | num_features = num_features // 2 73 | 74 | self.up2 = _UpProjection( 75 | num_input_features=num_features, num_output_features=num_features // 2) 76 | num_features = num_features // 2 77 | 78 | self.up3 = _UpProjection( 79 | num_input_features=num_features, num_output_features=num_features // 2) 80 | num_features = num_features // 2 81 | 82 | self.up4 = _UpProjection( 83 | num_input_features=num_features, num_output_features=num_features // 2) 84 | num_features = num_features // 2 85 | 86 | 87 | def forward(self, x_block1, x_block2, x_block3, x_block4): 88 | x_d0 = F.relu(self.bn(self.conv(x_block4))) 89 | x_d1 = self.up1(x_d0, [x_block3.size(2), x_block3.size(3)]) 90 | x_d2 = self.up2(x_d1, [x_block2.size(2), x_block2.size(3)]) 91 | x_d3 = self.up3(x_d2, [x_block1.size(2), x_block1.size(3)]) 92 | x_d4 = self.up4(x_d3, [x_block1.size(2)*2, x_block1.size(3)*2]) 93 | 94 | return x_d4 95 | 96 | class MFF(nn.Module): 97 | 98 | def __init__(self, block_channel, num_features=64): 99 | 100 | super(MFF, self).__init__() 101 | 102 | self.up1 = _UpProjection( 103 | 
num_input_features=block_channel[0], num_output_features=16) 104 | 105 | self.up2 = _UpProjection( 106 | num_input_features=block_channel[1], num_output_features=16) 107 | 108 | self.up3 = _UpProjection( 109 | num_input_features=block_channel[2], num_output_features=16) 110 | 111 | self.up4 = _UpProjection( 112 | num_input_features=block_channel[3], num_output_features=16) 113 | 114 | self.conv = nn.Conv2d( 115 | num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False) 116 | self.bn = nn.BatchNorm2d(num_features) 117 | 118 | 119 | def forward(self, x_block1, x_block2, x_block3, x_block4, size): 120 | x_m1 = self.up1(x_block1, size) 121 | x_m2 = self.up2(x_block2, size) 122 | x_m3 = self.up3(x_block3, size) 123 | x_m4 = self.up4(x_block4, size) 124 | 125 | x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1))) 126 | x = F.relu(x) 127 | 128 | return x 129 | 130 | 131 | class R(nn.Module): 132 | def __init__(self, args, block_channel): 133 | 134 | super(R, self).__init__() 135 | 136 | num_features = 64 + block_channel[3] // 32 137 | self.conv0 = nn.Conv2d(num_features, num_features, 138 | kernel_size=5, stride=1, padding=2, bias=False) 139 | self.bn0 = nn.BatchNorm2d(num_features) 140 | 141 | self.conv1 = nn.Conv2d(num_features, num_features, 142 | kernel_size=5, stride=1, padding=2, bias=False) 143 | self.bn1 = nn.BatchNorm2d(num_features) 144 | 145 | self.conv2 = nn.Conv2d(num_features, 1, kernel_size=5, stride=1, padding=2, bias=True) 146 | 147 | self.args = args 148 | 149 | if args is not None and args.fds: 150 | self.FDS = FDS(feature_dim=num_features, bucket_num=args.bucket_num, bucket_start=args.bucket_start, 151 | start_update=args.start_update, start_smooth=args.start_smooth, kernel=args.fds_kernel, 152 | ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt) 153 | 154 | def forward(self, x, depth=None, epoch=None): 155 | x0 = self.conv0(x) 156 | x0 = self.bn0(x0) 157 | x0 = F.relu(x0) 158 | 159 | x1 = self.conv1(x0) 160 | x1 
= self.bn1(x1) 161 | x1 = F.relu(x1) 162 | 163 | x1_s = x1 164 | 165 | if self.training and self.args.fds: 166 | if epoch >= self.args.start_smooth: 167 | x1_s = self.FDS.smooth(x1_s, depth, epoch) 168 | 169 | x2 = self.conv2(x1_s) 170 | 171 | if self.training and self.args.fds: 172 | return x2, x1 173 | else: 174 | return x2 -------------------------------------------------------------------------------- /nyud2-dir/models/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models import modules 4 | 5 | class model(nn.Module): 6 | def __init__(self, args, Encoder, num_features, block_channel): 7 | 8 | super(model, self).__init__() 9 | 10 | self.E = Encoder 11 | self.D = modules.D(num_features) 12 | self.MFF = modules.MFF(block_channel) 13 | self.R = modules.R(args, block_channel) 14 | 15 | 16 | def forward(self, x, depth=None, epoch=None): 17 | x_block1, x_block2, x_block3, x_block4 = self.E(x) 18 | x_decoder = self.D(x_block1, x_block2, x_block3, x_block4) 19 | x_mff = self.MFF(x_block1, x_block2, x_block3, x_block4,[x_decoder.size(2),x_decoder.size(3)]) 20 | out = self.R(torch.cat((x_decoder, x_mff), 1), depth, epoch) 21 | 22 | return out 23 | -------------------------------------------------------------------------------- /nyud2-dir/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 6 | 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 
'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * 4) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if 
self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | 96 | def __init__(self, block, layers, num_classes=1000): 97 | self.inplanes = 64 98 | super(ResNet, self).__init__() 99 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 100 | bias=False) 101 | self.bn1 = nn.BatchNorm2d(64) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 104 | self.layer1 = self._make_layer(block, 64, layers[0]) 105 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 106 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 107 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 108 | self.avgpool = nn.AvgPool2d(7, stride=1) 109 | self.fc = nn.Linear(512 * block.expansion, num_classes) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | m.weight.data.fill_(1) 117 | m.bias.data.zero_() 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | x = self.avgpool(x) 148 | x = x.view(x.size(0), -1) 149 | x = self.fc(x) 150 | 151 | return x 152 | 153 | def resnet18(pretrained=False, **kwargs): 154 | """Constructs a ResNet-18 model. 155 | Args: 156 | pretrained (bool): If True, returns a model pre-trained on ImageNet 157 | """ 158 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 159 | if pretrained: 160 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 161 | return model 162 | 163 | 164 | def resnet34(pretrained=False, **kwargs): 165 | """Constructs a ResNet-34 model. 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 170 | if pretrained: 171 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 172 | return model 173 | 174 | 175 | def resnet50(pretrained=False, **kwargs): 176 | """Constructs a ResNet-50 model. 
177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | """ 180 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 181 | if pretrained: 182 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], 'pretrained_model/encoder')) 183 | return model 184 | 185 | 186 | def resnet101(pretrained=False, **kwargs): 187 | """Constructs a ResNet-101 model. 188 | Args: 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 192 | if pretrained: 193 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 194 | return model 195 | 196 | 197 | def resnet152(pretrained=False, **kwargs): 198 | """Constructs a ResNet-152 model. 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | """ 202 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 203 | if pretrained: 204 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 205 | return model 206 | -------------------------------------------------------------------------------- /nyud2-dir/nyu_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | import collections 5 | try: 6 | import accimage 7 | except ImportError: 8 | accimage = None 9 | import random 10 | import scipy.ndimage as ndimage 11 | 12 | 13 | def _is_pil_image(img): 14 | if accimage is not None: 15 | return isinstance(img, (Image.Image, accimage.Image)) 16 | else: 17 | return isinstance(img, Image.Image) 18 | 19 | 20 | def _is_numpy_image(img): 21 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 22 | 23 | 24 | class RandomRotate(object): 25 | """Random rotation of the image from -angle to angle (in degrees) 26 | This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation 27 | angle: max angle of the rotation 28 | interpolation 
order: Default: 2 (quadratic spline) 29 | reshape: Default: false. If set to true, image size will be set to keep every pixel in the image. 30 | diff_angle: Default: 0. Must stay less than 10 degrees, or linear approximation of flowmap will be off. 31 | """ 32 | 33 | def __init__(self, angle, diff_angle=0, order=2, reshape=False): 34 | self.angle = angle 35 | self.reshape = reshape 36 | self.order = order 37 | 38 | def __call__(self, sample): 39 | image, depth = sample['image'], sample['depth'] 40 | 41 | applied_angle = random.uniform(-self.angle, self.angle) 42 | angle1 = applied_angle 43 | angle1_rad = angle1 * np.pi / 180 44 | 45 | image = ndimage.rotate( 46 | image, angle1, reshape=self.reshape, order=self.order) 47 | depth = ndimage.rotate( 48 | depth, angle1, reshape=self.reshape, order=self.order) 49 | 50 | image = Image.fromarray(image) 51 | depth = Image.fromarray(depth) 52 | 53 | return {'image': image, 'depth': depth} 54 | 55 | class RandomHorizontalFlip(object): 56 | 57 | def __call__(self, sample): 58 | image, depth = sample['image'], sample['depth'] 59 | 60 | if not _is_pil_image(image): 61 | raise TypeError( 62 | 'img should be PIL Image. Got {}'.format(type(image))) 63 | if not _is_pil_image(depth): 64 | raise TypeError( 65 | 'img should be PIL Image. Got {}'.format(type(depth))) 66 | 67 | if random.random() < 0.5: 68 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 69 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 70 | 71 | return {'image': image, 'depth': depth} 72 | 73 | 74 | class Scale(object): 75 | """ Rescales the inputs and target arrays to the given 'size'. 76 | 'size' will be the size of the smaller edge. 
77 | For example, if height > width, then image will be 78 | rescaled to (size * height / width, size) 79 | size: size of the smaller edge 80 | interpolation order: Default: 2 (bilinear) 81 | """ 82 | 83 | def __init__(self, size): 84 | self.size = size 85 | 86 | def __call__(self, sample): 87 | image, depth = sample['image'], sample['depth'] 88 | 89 | image = self.changeScale(image, self.size) 90 | depth = self.changeScale(depth, self.size,Image.NEAREST) 91 | 92 | return {'image': image, 'depth': depth} 93 | 94 | def changeScale(self, img, size, interpolation=Image.BILINEAR): 95 | 96 | if not _is_pil_image(img): 97 | raise TypeError( 98 | 'img should be PIL Image. Got {}'.format(type(img))) 99 | if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)): 100 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 101 | 102 | if isinstance(size, int): 103 | w, h = img.size 104 | if (w <= h and w == size) or (h <= w and h == size): 105 | return img 106 | if w < h: 107 | ow = size 108 | oh = int(size * h / w) 109 | return img.resize((ow, oh), interpolation) 110 | else: 111 | oh = size 112 | ow = int(size * w / h) 113 | return img.resize((ow, oh), interpolation) 114 | else: 115 | return img.resize(size[::-1], interpolation) 116 | 117 | 118 | class CenterCrop(object): 119 | def __init__(self, size_image, size_depth): 120 | self.size_image = size_image 121 | self.size_depth = size_depth 122 | 123 | def __call__(self, sample): 124 | image, depth = sample['image'], sample['depth'] 125 | 126 | image = self.centerCrop(image, self.size_image) 127 | depth = self.centerCrop(depth, self.size_image) 128 | 129 | ow, oh = self.size_depth 130 | depth = depth.resize((ow, oh)) 131 | 132 | return {'image': image, 'depth': depth} 133 | 134 | def centerCrop(self, image, size): 135 | 136 | w1, h1 = image.size 137 | 138 | tw, th = size 139 | 140 | if w1 == tw and h1 == th: 141 | return image 142 | 143 | x1 = int(round((w1 - tw) / 2.)) 144 | y1 = 
int(round((h1 - th) / 2.)) 145 | 146 | image = image.crop((x1, y1, tw + x1, th + y1)) 147 | 148 | return image 149 | 150 | 151 | class ToTensor(object): 152 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. 153 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 154 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 155 | """ 156 | def __init__(self,is_test=False): 157 | self.is_test = is_test 158 | 159 | def __call__(self, sample): 160 | image, depth = sample['image'], sample['depth'] 161 | """ 162 | Args: 163 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. 164 | Returns: 165 | Tensor: Converted image. 166 | """ 167 | # ground truth depth of training samples is stored in 8-bit while test samples are saved in 16 bit 168 | image = self.to_tensor(image) 169 | if self.is_test: 170 | depth = self.to_tensor(depth).float() / 1000 171 | else: 172 | depth = self.to_tensor(depth).float() * 10 173 | return {'image': image, 'depth': depth} 174 | 175 | def to_tensor(self, pic): 176 | if not(_is_pil_image(pic) or _is_numpy_image(pic)): 177 | raise TypeError( 178 | 'pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) 179 | 180 | if isinstance(pic, np.ndarray): 181 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 182 | 183 | return img.float().div(255) 184 | 185 | if accimage is not None and isinstance(pic, accimage.Image): 186 | nppic = np.zeros( 187 | [pic.channels, pic.height, pic.width], dtype=np.float32) 188 | pic.copyto(nppic) 189 | return torch.from_numpy(nppic) 190 | 191 | # handle PIL Image 192 | if pic.mode == 'I': 193 | img = torch.from_numpy(np.array(pic, np.int32)) 194 | elif pic.mode == 'I;16': 195 | img = torch.from_numpy(np.array(pic, np.int16)) 196 | else: 197 | img = torch.ByteTensor( 198 | torch.ByteStorage.from_buffer(pic.tobytes())) 199 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 200 | if pic.mode == 'YCbCr': 201 | nchannel = 3 202 | elif pic.mode == 'I;16': 203 | nchannel = 1 204 | else: 205 | nchannel = len(pic.mode) 206 | img = img.view(pic.size[1], pic.size[0], nchannel) 207 | # put it from HWC to CHW format 208 | # yikes, this transpose takes 80% of the loading time/CPU 209 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 210 | if isinstance(img, torch.ByteTensor): 211 | return img.float().div(255) 212 | else: 213 | return img 214 | 215 | 216 | class Lighting(object): 217 | 218 | def __init__(self, alphastd, eigval, eigvec): 219 | self.alphastd = alphastd 220 | self.eigval = eigval 221 | self.eigvec = eigvec 222 | 223 | def __call__(self, sample): 224 | image, depth = sample['image'], sample['depth'] 225 | if self.alphastd == 0: 226 | return {'image': image, 'depth': depth} 227 | 228 | alpha = image.new().resize_(3).normal_(0, self.alphastd) 229 | rgb = self.eigvec.type_as(image).clone()\ 230 | .mul(alpha.view(1, 3).expand(3, 3))\ 231 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 232 | .sum(1).squeeze() 233 | 234 | image = image.add(rgb.view(3, 1, 1).expand_as(image)) 235 | 236 | return {'image': image, 'depth': depth} 237 | 238 | 239 | class Grayscale(object): 240 | 241 | def __call__(self, img): 242 | gs = img.clone() 243 | 
gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114) 244 | gs[1].copy_(gs[0]) 245 | gs[2].copy_(gs[0]) 246 | return gs 247 | 248 | 249 | class Saturation(object): 250 | 251 | def __init__(self, var): 252 | self.var = var 253 | 254 | def __call__(self, img): 255 | gs = Grayscale()(img) 256 | alpha = random.uniform(-self.var, self.var) 257 | return img.lerp(gs, alpha) 258 | 259 | 260 | class Brightness(object): 261 | 262 | def __init__(self, var): 263 | self.var = var 264 | 265 | def __call__(self, img): 266 | gs = img.new().resize_as_(img).zero_() 267 | alpha = random.uniform(-self.var, self.var) 268 | 269 | return img.lerp(gs, alpha) 270 | 271 | 272 | class Contrast(object): 273 | 274 | def __init__(self, var): 275 | self.var = var 276 | 277 | def __call__(self, img): 278 | gs = Grayscale()(img) 279 | gs.fill_(gs.mean()) 280 | alpha = random.uniform(-self.var, self.var) 281 | return img.lerp(gs, alpha) 282 | 283 | 284 | class RandomOrder(object): 285 | """ Composes several transforms together in random order. 
286 | """ 287 | 288 | def __init__(self, transforms): 289 | self.transforms = transforms 290 | 291 | def __call__(self, sample): 292 | image, depth = sample['image'], sample['depth'] 293 | 294 | if self.transforms is None: 295 | return {'image': image, 'depth': depth} 296 | order = torch.randperm(len(self.transforms)) 297 | for i in order: 298 | image = self.transforms[i](image) 299 | 300 | return {'image': image, 'depth': depth} 301 | 302 | 303 | class ColorJitter(RandomOrder): 304 | 305 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 306 | self.transforms = [] 307 | if brightness != 0: 308 | self.transforms.append(Brightness(brightness)) 309 | if contrast != 0: 310 | self.transforms.append(Contrast(contrast)) 311 | if saturation != 0: 312 | self.transforms.append(Saturation(saturation)) 313 | 314 | 315 | class Normalize(object): 316 | def __init__(self, mean, std): 317 | self.mean = mean 318 | self.std = std 319 | 320 | def __call__(self, sample): 321 | """ 322 | Args: 323 | sample (dict): 'image' is a Tensor of size (C, H, W) to be normalized; 'depth' is passed through unchanged. 324 | Returns: 325 | dict: Sample with the normalized image. 326 | """ 327 | image, depth = sample['image'], sample['depth'] 328 | 329 | image = self.normalize(image, self.mean, self.std) 330 | 331 | return {'image': image, 'depth': depth} 332 | 333 | def normalize(self, tensor, mean, std): 334 | """Normalize a tensor image with mean and standard deviation. 335 | See ``Normalize`` for more details. 336 | Args: 337 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 338 | mean (sequence): Sequence of means for R, G, B channels respectively. 339 | std (sequence): Sequence of standard deviations for R, G, B channels 340 | respectively. 341 | Returns: 342 | Tensor: Normalized image. 
343 | """ 344 | 345 | for t, m, s in zip(tensor, mean, std): 346 | t.sub_(m).div_(s) 347 | return tensor 348 | -------------------------------------------------------------------------------- /nyud2-dir/preprocess_nyud2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | from tqdm import tqdm 5 | from torchvision import transforms 6 | from torch.utils.data import DataLoader 7 | from nyu_transform import * 8 | from loaddata import depthDataset 9 | 10 | def load_data(args): 11 | train_dataset = depthDataset( 12 | csv_file=os.path.join(args.data_dir, 'nyu2_train.csv'), 13 | transform=transforms.Compose([ 14 | Scale(240), 15 | CenterCrop([304, 228], [304, 228]), 16 | ToTensor(is_test=False), 17 | ]) 18 | ) 19 | train_dataloader = DataLoader(train_dataset, 256, shuffle=False, num_workers=16, pin_memory=False) 20 | 21 | test_dataset = depthDataset( 22 | csv_file=os.path.join(args.data_dir, 'nyu2_test.csv'), 23 | transform=transforms.Compose([ 24 | Scale(240), 25 | CenterCrop([304, 228], [304, 228]), 26 | ToTensor(is_test=True), 27 | ]) 28 | ) 29 | # print(train_dataset.__len__(), test_dataset.__len__()) 30 | test_dataloader = DataLoader(test_dataset, 256, shuffle=False, num_workers=16, pin_memory=False) 31 | 32 | return train_dataloader, test_dataloader 33 | 34 | def create_FDS_train_subset_id(args): 35 | print('Creating FDS statistics updating subset ids...') 36 | train_dataloader, _ = load_data(args) 37 | train_depth_values = [] 38 | for i, sample in enumerate(tqdm(train_dataloader)): 39 | train_depth_values.append(sample['depth'].squeeze()) 40 | train_depth_values = torch.cat(train_depth_values, 0) 41 | select_idx = np.random.choice(a=list(range(train_depth_values.size(0))), size=600, replace=False) 42 | np.save(os.path.join(args.data_dir, 'FDS_train_subset_id.npy'), select_idx) 43 | 44 | def create_FDS_train_subset(args): 45 | print('Creating FDS statistics updating 
subset...') 46 | frame = pd.read_csv(os.path.join(args.data_dir, 'nyu2_train.csv'), header=None) 47 | select_id = np.load(os.path.join(args.data_dir, 'FDS_train_subset_id.npy')) 48 | frame.iloc[select_id].to_csv(os.path.join(args.data_dir, 'nyu2_train_FDS_subset.csv'), index=False, header=False) 49 | 50 | def get_bin_idx(x): 51 | return min(int(x * np.float32(10)), 99) 52 | 53 | def create_balanced_testset(args): 54 | print('Creating balanced test set mask...') 55 | _, test_dataloader = load_data(args) 56 | test_depth_values = [] 57 | 58 | for i, sample in enumerate(tqdm(test_dataloader)): 59 | test_depth_values.append(sample['depth'].squeeze()) 60 | test_depth_values = torch.cat(test_depth_values, 0) 61 | test_depth_values_flatten = test_depth_values.view(-1).numpy() 62 | test_bins_number, _ = np.histogram(a=test_depth_values_flatten, bins=100, range=(0., 10.)) 63 | 64 | select_pixel_num = min(test_bins_number[test_bins_number != 0]) 65 | test_depth_values_flatten_bins = np.array(list(map(lambda v: get_bin_idx(v), test_depth_values_flatten))) 66 | 67 | test_depth_flatten_mask = np.zeros(test_depth_values_flatten.shape[0], dtype=np.uint8) 68 | for i in range(7, 100): 69 | bucket_idx = np.where(test_depth_values_flatten_bins == i)[0] 70 | select_bucket_idx = np.random.choice(a=bucket_idx, size=select_pixel_num, replace=False) 71 | test_depth_flatten_mask[select_bucket_idx] = np.uint8(1) 72 | test_depth_mask = test_depth_flatten_mask.reshape(test_depth_values.numpy().shape) 73 | np.save(os.path.join(args.data_dir, 'test_balanced_mask.npy'), test_depth_mask) 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser(description='') 77 | parser.add_argument('--data_dir', type=str, default='./data', help='data directory') 78 | args = parser.parse_args() 79 | 80 | create_FDS_train_subset_id(args) 81 | create_FDS_train_subset(args) 82 | create_balanced_testset(args) -------------------------------------------------------------------------------- 
/nyud2-dir/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | 10 | import loaddata 11 | from models import modules, net, resnet 12 | from util import Evaluator 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--eval_model', type=str, default='', help='evaluation model path') 17 | parser.add_argument('--data_dir', type=str, default='./data', help='data directory') 18 | args = parser.parse_args() 19 | 20 | logging.root.handlers = [] 21 | logging.basicConfig( 22 | level=logging.INFO, 23 | format="%(asctime)s | %(message)s", 24 | handlers=[ 25 | logging.StreamHandler() 26 | ]) 27 | 28 | model = define_model() 29 | assert os.path.isfile(args.eval_model), f"No checkpoint found at '{args.eval_model}'" 30 | model = torch.nn.DataParallel(model).cuda() 31 | model_state = torch.load(args.eval_model) 32 | logging.info(f"Loading checkpoint from {args.eval_model}") 33 | model.load_state_dict(model_state['state_dict'], strict=False) 34 | logging.info('Loaded successfully!') 35 | 36 | test_loader = loaddata.getTestingData(args, 1) 37 | test(test_loader, model) 38 | 39 | def test(test_loader, model): 40 | model.eval() 41 | 42 | logging.info('Starting testing...') 43 | 44 | evaluator = Evaluator() 45 | 46 | with torch.no_grad(): 47 | for i, sample_batched in enumerate(tqdm(test_loader)): 48 | image, depth, mask = sample_batched['image'], sample_batched['depth'], sample_batched['mask'] 49 | depth = depth.cuda(non_blocking=True) 50 | image = image.cuda() 51 | output = model(image) 52 | output = nn.functional.interpolate(output, size=[depth.size(2),depth.size(3)], mode='bilinear', align_corners=True) 53 | 54 | evaluator(output[mask], depth[mask]) 55 | 56 | logging.info('Finished testing. 
Start printing statistics below...') 57 | metric_dict = evaluator.evaluate_shot() 58 | 59 | return metric_dict['overall']['RMSE'], metric_dict 60 | 61 | def define_model(): 62 | original_model = resnet.resnet50(pretrained = True) 63 | Encoder = modules.E_resnet(original_model) 64 | model = net.model(None, Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048]) 65 | 66 | return model 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /nyud2-dir/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import shutil 5 | import logging 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import loaddata 9 | from tqdm import tqdm 10 | from models import modules, net, resnet 11 | from util import query_yes_no 12 | from test import test 13 | from tensorboardX import SummaryWriter 14 | 15 | parser = argparse.ArgumentParser(description='') 16 | 17 | # training/optimization related 18 | parser.add_argument('--epochs', default=10, type=int, 19 | help='number of total epochs to run') 20 | parser.add_argument('--start_epoch', default=0, type=int, 21 | help='manual epoch number (useful on restarts)') 22 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, 23 | help='initial learning rate') 24 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, 25 | help='weight decay (default: 1e-4)') 26 | parser.add_argument('--batch_size', default=32, type=int, help='batch size number') # 1 GPU - 8 27 | parser.add_argument('--store_root', type=str, default='checkpoint') 28 | parser.add_argument('--store_name', type=str, default='nyud2') 29 | parser.add_argument('--data_dir', type=str, default='./data', help='data directory') 30 | parser.add_argument('--resume', action='store_true', default=False, help='whether to resume training') 31 | 32 | # imbalanced 
related 33 | # LDS 34 | parser.add_argument('--lds', action='store_true', default=False, help='whether to enable LDS') 35 | parser.add_argument('--lds_kernel', type=str, default='gaussian', 36 | choices=['gaussian', 'triang', 'laplace'], help='LDS kernel type') 37 | parser.add_argument('--lds_ks', type=int, default=5, help='LDS kernel size: should be odd number') 38 | parser.add_argument('--lds_sigma', type=float, default=2, help='LDS gaussian/laplace kernel sigma') 39 | # FDS 40 | parser.add_argument('--fds', action='store_true', default=False, help='whether to enable FDS') 41 | parser.add_argument('--fds_kernel', type=str, default='gaussian', 42 | choices=['gaussian', 'triang', 'laplace'], help='FDS kernel type') 43 | parser.add_argument('--fds_ks', type=int, default=5, help='FDS kernel size: should be odd number') 44 | parser.add_argument('--fds_sigma', type=float, default=2, help='FDS gaussian/laplace kernel sigma') 45 | parser.add_argument('--start_update', type=int, default=0, help='which epoch to start FDS updating') 46 | parser.add_argument('--start_smooth', type=int, default=1, help='which epoch to start using FDS to smooth features') 47 | parser.add_argument('--bucket_num', type=int, default=100, help='maximum bucket considered for FDS') 48 | parser.add_argument('--bucket_start', type=int, default=7, help='minimum(starting) bucket for FDS, 7 for NYUDv2') 49 | parser.add_argument('--fds_mmt', type=float, default=0.9, help='FDS momentum') 50 | 51 | # re-weighting: SQRT_INV / INV 52 | parser.add_argument('--reweight', type=str, default='none', choices=['none', 'inverse', 'sqrt_inv'], 53 | help='cost-sensitive reweighting scheme') 54 | # two-stage training: RRT 55 | parser.add_argument('--retrain_fc', action='store_true', default=False, 56 | help='whether to retrain last regression layer (regressor)') 57 | parser.add_argument('--pretrained', type=str, default='', help='pretrained checkpoint file path to load backbone weights for RRT') 58 | 59 | def 
define_model(args): 60 | original_model = resnet.resnet50(pretrained=True) 61 | Encoder = modules.E_resnet(original_model) 62 | model = net.model(args, Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048]) 63 | 64 | return model 65 | 66 | def main(): 67 | error_best = 1e5 68 | metric_dict_best = {} 69 | epoch_best = -1 70 | 71 | global args 72 | args = parser.parse_args() 73 | 74 | if not args.lds and args.reweight != 'none': 75 | args.store_name += f'_{args.reweight}' 76 | if args.lds: 77 | args.store_name += f'_lds_{args.lds_kernel[:3]}_{args.lds_ks}' 78 | if args.lds_kernel in ['gaussian', 'laplace']: 79 | args.store_name += f'_{args.lds_sigma}' 80 | if args.fds: 81 | args.store_name += f'_fds_{args.fds_kernel[:3]}_{args.fds_ks}' 82 | if args.fds_kernel in ['gaussian', 'laplace']: 83 | args.store_name += f'_{args.fds_sigma}' 84 | args.store_name += f'_{args.start_update}_{args.start_smooth}_{args.fds_mmt}' 85 | if args.retrain_fc: 86 | args.store_name += f'_retrain_fc' 87 | args.store_name += f'_lr_{args.lr}_bs_{args.batch_size}' 88 | 89 | args.store_dir = os.path.join(args.store_root, args.store_name) 90 | 91 | if not args.resume: 92 | if os.path.exists(args.store_dir): 93 | if query_yes_no('overwrite previous folder: {} ?'.format(args.store_dir)): 94 | shutil.rmtree(args.store_dir) 95 | print(args.store_dir + ' removed.') 96 | else: 97 | raise RuntimeError('Output folder {} already exists'.format(args.store_dir)) 98 | print(f"===> Creating folder: {args.store_dir}") 99 | os.makedirs(args.store_dir) 100 | 101 | logging.root.handlers = [] 102 | log_file = os.path.join(args.store_dir, 'training_log.log') 103 | logging.basicConfig( 104 | level=logging.INFO, 105 | format="%(asctime)s | %(message)s", 106 | handlers=[ 107 | logging.FileHandler(log_file), 108 | logging.StreamHandler() 109 | ]) 110 | logging.info(args) 111 | 112 | writer = SummaryWriter(args.store_dir) 113 | 114 | model = define_model(args) 115 | model = 
torch.nn.DataParallel(model).cuda() 116 | 117 | if args.resume: 118 | model_state = torch.load(os.path.join(args.store_dir, 'checkpoint.pth.tar')) 119 | logging.info(f"Loading checkpoint from {os.path.join(args.store_dir, 'checkpoint.pth.tar')}" 120 | f" (Epoch [{model_state['epoch']}], RMSE: {model_state['error']:.3f})") 121 | model.load_state_dict(model_state['state_dict']) 122 | 123 | args.start_epoch = model_state['epoch'] + 1 124 | epoch_best = model_state['epoch'] 125 | error_best = model_state['error'] 126 | metric_dict_best = model_state['metric'] 127 | 128 | if args.retrain_fc: 129 | assert os.path.isfile(args.pretrained), f"No checkpoint found at '{args.pretrained}'" 130 | model_state = torch.load(args.pretrained, map_location="cpu") 131 | from collections import OrderedDict 132 | new_state_dict = OrderedDict() 133 | for k, v in model_state['state_dict'].items(): 134 | if 'R' not in k: 135 | new_state_dict[k] = v 136 | model.load_state_dict(new_state_dict, strict=False) 137 | logging.info(f'===> Pretrained weights found in total: [{len(list(new_state_dict.keys()))}]') 138 | logging.info(f'===> Pre-trained model loaded: {args.pretrained}') 139 | for name, param in model.named_parameters(): 140 | if 'R' not in name: 141 | param.requires_grad = False 142 | logging.info(f'Only optimize parameters: {[n for n, p in model.named_parameters() if p.requires_grad]}') 143 | 144 | cudnn.benchmark = True 145 | if not args.retrain_fc: 146 | optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 147 | else: 148 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 149 | optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=args.weight_decay) 150 | 151 | train_loader = loaddata.getTrainingData(args, args.batch_size) 152 | train_fds_loader = loaddata.getTrainingFDSData(args, args.batch_size) 153 | test_loader = loaddata.getTestingData(args, 1) 154 | 155 | for epoch in range(args.start_epoch, args.epochs): 156 
| adjust_learning_rate(optimizer, epoch) 157 | train(train_loader, train_fds_loader, model, optimizer, epoch, writer) 158 | error, metric_dict = test(test_loader, model) 159 | if error < error_best: 160 | error_best = error 161 | metric_dict_best = metric_dict 162 | epoch_best = epoch 163 | save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint_best.pth.tar') 164 | save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint.pth.tar') 165 | 166 | save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint_final.pth.tar') 167 | logging.info(f'Best epoch: {epoch_best}; RMSE: {error_best:.3f}') 168 | logging.info('***** TEST RESULTS *****') 169 | for shot in ['Overall', 'Many', 'Medium', 'Few']: 170 | logging.info(f" * {shot}: RMSE {metric_dict_best[shot.lower()]['RMSE']:.3f}\t" 171 | f"ABS_REL {metric_dict_best[shot.lower()]['ABS_REL']:.3f}\t" 172 | f"LG10 {metric_dict_best[shot.lower()]['LG10']:.3f}\t" 173 | f"MAE {metric_dict_best[shot.lower()]['MAE']:.3f}\t" 174 | f"DELTA1 {metric_dict_best[shot.lower()]['DELTA1']:.3f}\t" 175 | f"DELTA2 {metric_dict_best[shot.lower()]['DELTA2']:.3f}\t" 176 | f"DELTA3 {metric_dict_best[shot.lower()]['DELTA3']:.3f}\t" 177 | f"NUM {metric_dict_best[shot.lower()]['NUM']}") 178 | 179 | writer.close() 180 | 181 | def train(train_loader, train_fds_loader, model, optimizer, epoch, writer): 182 | batch_time = AverageMeter() 183 | losses = AverageMeter() 184 | 185 | model.train() 186 | 187 | end = time.time() 188 | for i, sample_batched in enumerate(train_loader): 189 | image, depth, weight = sample_batched['image'], sample_batched['depth'], sample_batched['weight'] 190 | 191 | depth = depth.cuda(non_blocking=True) 192 | weight = weight.cuda(non_blocking=True) 193 | image = image.cuda() 194 | optimizer.zero_grad() 195 | 196 | if args.fds: 197 | output, feature = model(image, depth, epoch) 198 | else: 199 | output = model(image, depth, epoch) 200 | loss = torch.mean(((output - depth) ** 2) * 
weight) 201 | 202 | losses.update(loss.item(), image.size(0)) 203 | loss.backward() 204 | optimizer.step() 205 | 206 | batch_time.update(time.time() - end) 207 | end = time.time() 208 | 209 | writer.add_scalar('data/loss', loss.item(), i + epoch * len(train_loader)) 210 | 211 | logging.info('Epoch: [{0}][{1}/{2}]\t' 212 | 'Time {batch_time.val:.3f} ({batch_time.sum:.3f})\t' 213 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 214 | .format(epoch, i, len(train_loader), batch_time=batch_time, loss=losses)) 215 | 216 | if args.fds and epoch >= args.start_update: 217 | logging.info(f"Starting Creating Epoch [{epoch}] features of subsampled training data...") 218 | encodings, depths = [], [] 219 | with torch.no_grad(): 220 | for i, sample_batched in enumerate(tqdm(train_fds_loader)): 221 | image, depth = sample_batched['image'].cuda(), sample_batched['depth'].cuda() 222 | _, feature = model(image, depth, epoch) 223 | encodings.append(feature.data.cpu()) 224 | depths.append(depth.data.cpu()) 225 | encodings, depths = torch.cat(encodings, 0), torch.cat(depths, 0) 226 | logging.info(f"Created Epoch [{epoch}] features of subsampled training data (size: {encodings.size(0)})!") 227 | model.module.R.FDS.update_last_epoch_stats(epoch) 228 | model.module.R.FDS.update_running_stats(encodings, depths, epoch) 229 | 230 | def adjust_learning_rate(optimizer, epoch): 231 | lr = args.lr * (0.1 ** (epoch // 5)) 232 | 233 | for param_group in optimizer.param_groups: 234 | param_group['lr'] = lr 235 | 236 | 237 | class AverageMeter(object): 238 | def __init__(self): 239 | self.reset() 240 | 241 | def reset(self): 242 | self.val = 0 243 | self.avg = 0 244 | self.sum = 0 245 | self.count = 0 246 | 247 | def update(self, val, n=1): 248 | self.val = val 249 | self.sum += val * n 250 | self.count += n 251 | self.avg = self.sum / self.count 252 | 253 | 254 | def save_checkpoint(state_dict, epoch, error, metric_dict, filename='checkpoint.pth.tar'): 255 | logging.info(f'Saving checkpoint to 
{os.path.join(args.store_dir, filename)}...') 256 | torch.save({ 257 | 'state_dict': state_dict, 258 | 'epoch': epoch, 259 | 'error': error, 260 | 'metric': metric_dict 261 | }, os.path.join(args.store_dir, filename)) 262 | 263 | if __name__ == '__main__': 264 | main() 265 | -------------------------------------------------------------------------------- /nyud2-dir/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import torch 4 | import numpy as np 5 | from scipy.ndimage import gaussian_filter1d 6 | from scipy.signal.windows import triang 7 | 8 | def lg10(x): 9 | return torch.div(torch.log(x), math.log(10)) 10 | 11 | def maxOfTwo(x, y): 12 | z = x.clone() 13 | maskYLarger = torch.lt(x, y) 14 | z[maskYLarger.detach()] = y[maskYLarger.detach()] 15 | return z 16 | 17 | def nValid(x): 18 | return torch.sum(torch.eq(x, x).float()) 19 | 20 | def getNanMask(x): 21 | return torch.ne(x, x) 22 | 23 | def setNanToZero(input, target): 24 | nanMask = getNanMask(target) 25 | nValidElement = nValid(target) 26 | 27 | _input = input.clone() 28 | _target = target.clone() 29 | 30 | _input[nanMask] = 0 31 | _target[nanMask] = 0 32 | 33 | return _input, _target, nanMask, nValidElement 34 | 35 | class Evaluator: 36 | def __init__(self): 37 | self.shot_idx = { 38 | 'many': [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 39 | 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 49], 40 | 'medium': [7, 8, 46, 48, 50, 51, 52, 53, 54, 55, 56, 58, 60, 61, 63], 41 | 'few': [0, 1, 2, 3, 4, 5, 6, 57, 59, 62, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 42 | 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] 43 | } 44 | self.output = torch.tensor([], dtype=torch.float32) 45 | self.depth = torch.tensor([], dtype=torch.float32) 46 | 47 | def __call__(self, output, depth): 48 | output = output.squeeze().view(-1).cpu() 49 | 
depth = depth.squeeze().view(-1).cpu() 50 | self.output = torch.cat([self.output, output]) 51 | self.depth = torch.cat([self.depth, depth]) 52 | 53 | def evaluate_shot(self): 54 | metric_dict = {'overall': {}, 'many': {}, 'medium': {}, 'few': {}} 55 | self.depth_bucket = np.array(list(map(lambda v: self.get_bin_idx(v), self.depth.cpu().numpy()))) 56 | 57 | for shot in metric_dict.keys(): 58 | if shot == 'overall': 59 | metric_dict[shot] = self.evaluate(self.output, self.depth) 60 | else: 61 | mask = np.zeros(self.depth.size(0), dtype=bool) 62 | for i in self.shot_idx[shot]: 63 | mask[np.where(self.depth_bucket == i)[0]] = True 64 | mask = torch.tensor(mask, dtype=torch.bool) 65 | metric_dict[shot] = self.evaluate(self.output[mask], self.depth[mask]) 66 | 67 | logging.info('\n***** TEST RESULTS *****') 68 | for shot in ['Overall', 'Many', 'Medium', 'Few']: 69 | logging.info(f" * {shot}: RMSE {metric_dict[shot.lower()]['RMSE']:.3f}\t" 70 | f"ABS_REL {metric_dict[shot.lower()]['ABS_REL']:.3f}\t" 71 | f"LG10 {metric_dict[shot.lower()]['LG10']:.3f}\t" 72 | f"MAE {metric_dict[shot.lower()]['MAE']:.3f}\t" 73 | f"DELTA1 {metric_dict[shot.lower()]['DELTA1']:.3f}\t" 74 | f"DELTA2 {metric_dict[shot.lower()]['DELTA2']:.3f}\t" 75 | f"DELTA3 {metric_dict[shot.lower()]['DELTA3']:.3f}\t" 76 | f"NUM {metric_dict[shot.lower()]['NUM']}") 77 | 78 | return metric_dict 79 | 80 | def reset(self): 81 | self.output = torch.tensor([], dtype=torch.float32) 82 | self.depth = torch.tensor([], dtype=torch.float32) 83 | 84 | @staticmethod 85 | def get_bin_idx(x): 86 | return min(int(x * np.float32(10)), 99) 87 | 88 | @staticmethod 89 | def evaluate(output, target): 90 | errors = {'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0, 91 | 'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0, 'NUM': 0} 92 | 93 | _output, _target, nanMask, nValidElement = setNanToZero(output, target) 94 | 95 | if (nValidElement.data.cpu().numpy() > 0): 96 | diffMatrix = torch.abs(_output - _target) 97 | 98 | errors['MSE'] = 
torch.sum(torch.pow(diffMatrix, 2)) / nValidElement 99 | 100 | errors['MAE'] = torch.sum(diffMatrix) / nValidElement 101 | 102 | realMatrix = torch.div(diffMatrix, _target) 103 | realMatrix[nanMask] = 0 104 | errors['ABS_REL'] = torch.sum(realMatrix) / nValidElement 105 | 106 | LG10Matrix = torch.abs(lg10(_output) - lg10(_target)) 107 | LG10Matrix[nanMask] = 0 108 | errors['LG10'] = torch.sum(LG10Matrix) / nValidElement 109 | 110 | yOverZ = torch.div(_output, _target) 111 | zOverY = torch.div(_target, _output) 112 | 113 | maxRatio = maxOfTwo(yOverZ, zOverY) 114 | 115 | errors['DELTA1'] = torch.sum( 116 | torch.le(maxRatio, 1.25).float()) / nValidElement 117 | errors['DELTA2'] = torch.sum( 118 | torch.le(maxRatio, math.pow(1.25, 2)).float()) / nValidElement 119 | errors['DELTA3'] = torch.sum( 120 | torch.le(maxRatio, math.pow(1.25, 3)).float()) / nValidElement 121 | 122 | errors['MSE'] = float(errors['MSE'].data.cpu().numpy()) 123 | errors['ABS_REL'] = float(errors['ABS_REL'].data.cpu().numpy()) 124 | errors['LG10'] = float(errors['LG10'].data.cpu().numpy()) 125 | errors['MAE'] = float(errors['MAE'].data.cpu().numpy()) 126 | errors['DELTA1'] = float(errors['DELTA1'].data.cpu().numpy()) 127 | errors['DELTA2'] = float(errors['DELTA2'].data.cpu().numpy()) 128 | errors['DELTA3'] = float(errors['DELTA3'].data.cpu().numpy()) 129 | errors['NUM'] = int(nValidElement) 130 | 131 | errors['RMSE'] = np.sqrt(errors['MSE']) 132 | 133 | return errors 134 | 135 | 136 | def query_yes_no(question): 137 | """ Ask a yes/no question via input() and return the answer. 
""" 138 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False} 139 | prompt = " [Y/n] " 140 | 141 | while True: 142 | print(question + prompt, end=':') 143 | choice = input().lower() 144 | if choice == '': 145 | return valid['y'] 146 | elif choice in valid: 147 | return valid[choice] 148 | else: 149 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n") 150 | 151 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.2, clip_max=5.): 152 | if torch.sum(v1) < 1e-10: 153 | return matrix 154 | if (v1 <= 0.).any() or (v2 < 0.).any(): 155 | valid_pos = (v1 > 0.) & (v2 >= 0.) 156 | # print(torch.sum(valid_pos)) 157 | factor = torch.clamp(v2[valid_pos] / v1[valid_pos], clip_min, clip_max) 158 | matrix[:, valid_pos] = (matrix[:, valid_pos] - m1[valid_pos]) * torch.sqrt(factor) + m2[valid_pos] 159 | return matrix 160 | 161 | factor = torch.clamp(v2 / v1, clip_min, clip_max) 162 | return (matrix - m1) * torch.sqrt(factor) + m2 163 | 164 | def get_lds_kernel_window(kernel, ks, sigma): 165 | assert kernel in ['gaussian', 'triang', 'laplace'] 166 | half_ks = (ks - 1) // 2 167 | if kernel == 'gaussian': 168 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks 169 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma)) 170 | elif kernel == 'triang': 171 | kernel_window = triang(ks) 172 | else: 173 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma) 174 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1))) 175 | 176 | return kernel_window 177 | 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /sts-b-dir/README.md: -------------------------------------------------------------------------------- 1 | # STS-B-DIR 2 | ## Installation 3 | 4 | #### Prerequisites 5 | 6 | 1. 
Download GloVe word embeddings (840B tokens, 300D vectors) using 7 | 8 | ```bash 9 | python glove/download_glove.py 10 | ``` 11 | 12 | 2. __(Optional)__ We provide both the original STS-B dataset and our balanced STS-B-DIR dataset in the folder `./glue_data/STS-B`. To reproduce the results in the paper, use the provided STS-B-DIR dataset. To try different balanced splits, delete the folder `./glue_data/STS-B` and run 13 | 14 | ```bash 15 | python glue_data/create_sts.py 16 | ``` 17 | 18 | #### Dependencies 19 | 20 | The dependencies for this task differ substantially from those of the other three tasks, so we recommend creating a separate environment for it. If you use conda, you can create the environment and install the dependencies with the following commands: 21 | 22 | ```bash 23 | conda create -n sts python=3.6 24 | conda activate sts 25 | # PyTorch 0.4 (required) + CUDA 9.2 26 | conda install pytorch=0.4.1 cuda92 -c pytorch 27 | # other dependencies 28 | pip install -r requirements.txt 29 | # The latest "overrides" release installed along with allennlp 0.5.0 raises an error. 
30 | # We need to downgrade "overrides" to version 3.1.0 31 | pip install overrides==3.1.0 32 | ``` 33 | 34 | ## Code Overview 35 | 36 | #### Main Files 37 | 38 | - `train.py`: main training and evaluation script 39 | - `create_sts.py`: downloads the original STS-B dataset and creates the STS-B-DIR dataset with balanced val/test sets 40 | 41 | #### Main Arguments 42 | 43 | - `--lds`: LDS switch (whether to enable LDS) 44 | - `--fds`: FDS switch (whether to enable FDS) 45 | - `--reweight`: cost-sensitive re-weighting scheme to use 46 | - `--loss`: training loss type 47 | - `--resume`: whether to resume training (only for training) 48 | - `--evaluate`: evaluation-only flag 49 | - `--eval_model`: path to the checkpoint to load (only for evaluation) 50 | - `--retrain_fc`: whether to retrain the regressor 51 | - `--pretrained`: path to load backbone weights for regressor re-training (RRT) 52 | - `--val_interval`: number of iterations between validation checks 53 | - `--patience`: patience (number of validation checks) for early stopping 54 | 55 | ## Getting Started 56 | 57 | #### Train a vanilla model 58 | 59 | ```bash 60 | python train.py --cuda --reweight none 61 | ``` 62 | 63 | Always pass `--cuda ` to select the GPU ID (single GPU) to use. The commands below omit this argument for brevity. 64 | 65 | #### Train a model using re-weighting 66 | 67 | To perform inverse re-weighting 68 | 69 | ```bash 70 | python train.py --reweight inverse 71 | ``` 72 | 73 | To perform square-root inverse re-weighting 74 | 75 | ```bash 76 | python train.py --reweight sqrt_inv 77 | ``` 78 | 79 | #### Train a model with different losses 80 | 81 | To use Focal-R loss 82 | 83 | ```bash 84 | python train.py --loss focal_mse 85 | ``` 86 | 87 | To use Huber loss 88 | 89 | ```bash 90 | python train.py --loss huber --huber_beta 0.3 91 | ``` 92 | 93 | #### Train a model using RRT 94 | 95 | ```bash 96 | python train.py [...retrained model arguments...] 
--retrain_fc --pretrained 97 | ``` 98 | 99 | #### Train a model using LDS 100 | 101 | To use Gaussian kernel (kernel size: 5, sigma: 2) 102 | 103 | ```bash 104 | python train.py --reweight inverse --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 105 | ``` 106 | 107 | #### Train a model using FDS 108 | 109 | To use Gaussian kernel (kernel size: 5, sigma: 2) 110 | 111 | ```bash 112 | python train.py --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2 113 | ``` 114 | 115 | #### Train a model using LDS + FDS 116 | 117 | ```bash 118 | python train.py --reweight inverse --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2 119 | ``` 120 | 121 | #### Evaluate a trained checkpoint 122 | 123 | ```bash 124 | python train.py [...evaluation model arguments...] --evaluate --eval_model 125 | ``` 126 | 127 | ## Reproduced Benchmarks and Model Zoo 128 | 129 | We provide below reproduced results on STS-B-DIR (base method `Vanilla`, metric `MSE`). 130 | Note that some models could give **better** results than the reported numbers in the paper. 
131 | 132 | | Model | Overall | Many-Shot | Medium-Shot | Few-Shot | Download | 133 | | :-------: | :-----: | :-------: | :---------: | :------: | :------: | 134 | | LDS | 0.914 | 0.819 | 1.319 | 0.955 | [model](https://drive.google.com/file/d/1CVyycq0OMgD9N9gJX5UDcfpRaJdZBjzo/view?usp=sharing) | 135 | | FDS | 0.916 | 0.875 | 1.027 | 1.086 | [model](https://drive.google.com/file/d/13e-1kd-KQrzFFVrJp1FeNDIBwUp3qtYx/view?usp=sharing) | 136 | | LDS + FDS | 0.907 | 0.802 | 1.363 | 0.942 | [model](https://drive.google.com/file/d/1kb_GV2coJRK_o9OxnMcxchq1EKOcpx-h/view?usp=sharing) | 137 | -------------------------------------------------------------------------------- /sts-b-dir/allennlp_mods/numeric_field.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | import pdb 3 | import logging 4 | 5 | from overrides import overrides 6 | import numpy 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | from allennlp.data.fields.field import Field 11 | from allennlp.data.vocabulary import Vocabulary 12 | from allennlp.common.checks import ConfigurationError 13 | 14 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 15 | 16 | 17 | class NumericField(Field[numpy.ndarray]): 18 | """ 19 | A ``LabelField`` is a categorical label of some kind, where the labels are either strings of 20 | text or 0-indexed integers (if you wish to skip indexing by passing skip_indexing=True). 21 | If the labels need indexing, we will use a :class:`Vocabulary` to convert the string labels 22 | into integers. 23 | 24 | This field will get converted into an integer index representing the class label. 25 | 26 | Parameters 27 | ---------- 28 | label : ``Union[str, int]`` 29 | label_namespace : ``str``, optional (default="labels") 30 | The namespace to use for converting label strings into integers. 
We map label strings to 31 | integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...), 32 | and this namespace tells the ``Vocabulary`` object which mapping from strings to integers 33 | to use (so "entailment" as a label doesn't get the same integer id as "entailment" as a 34 | word). If you have multiple different label fields in your data, you should make sure you 35 | use different namespaces for each one, always using the suffix "labels" (e.g., 36 | "passage_labels" and "question_labels"). 37 | skip_indexing : ``bool``, optional (default=False) 38 | If your labels are 0-indexed integers, you can pass in this flag, and we'll skip the indexing 39 | step. If this is ``False`` and your labels are not strings, this throws a ``ConfigurationError``. 40 | """ 41 | def __init__(self, 42 | label: Union[float, int], 43 | label_namespace: str = 'labels') -> None: 44 | self.label = label 45 | self._label_namespace = label_namespace 46 | self._label_id = numpy.array(label, dtype=numpy.float32) 47 | if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")): 48 | logger.warning("Your label namespace was '%s'. We recommend you use a namespace " 49 | "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by " 50 | "default to your vocabulary. 
See documentation for " 51 | "`non_padded_namespaces` parameter in Vocabulary.", self._label_namespace) 52 | 53 | # idk what this is for 54 | @overrides 55 | def count_vocab_items(self, counter: Dict[str, Dict[str, int]]): 56 | if self._label_id is None: 57 | counter[self._label_namespace][self.label] += 1 # type: ignore 58 | 59 | @overrides 60 | def get_padding_lengths(self) -> Dict[str, int]: # pylint: disable=no-self-use 61 | return {} 62 | 63 | def as_array(self, padding_lengths: Dict[str, int]) -> numpy.ndarray: # pylint: disable=unused-argument 64 | return numpy.asarray([self._label_id]) 65 | 66 | @overrides 67 | def as_tensor(self, padding_lengths: Dict[str, int], 68 | cuda_device: int = -1, 69 | for_training: bool = True) -> torch.Tensor: # pylint: disable=unused-argument 70 | label_id = self._label_id.tolist() 71 | tensor = Variable(torch.FloatTensor([label_id]), volatile=not for_training) 72 | return tensor if cuda_device == -1 else tensor.cuda(cuda_device) 73 | 74 | @overrides 75 | def empty_field(self): 76 | return NumericField(0, self._label_namespace) 77 | -------------------------------------------------------------------------------- /sts-b-dir/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tqdm 3 | import numpy as np 4 | 5 | def evaluate(model, tasks, iterator, cuda_device, split="val"): 6 | '''Evaluate on a dataset''' 7 | model.eval() 8 | 9 | all_preds = {} 10 | n_overall_examples = 0 11 | for task in tasks: 12 | n_examples = 0 13 | task_preds, task_idxs, task_labels = [], [], [] 14 | if split == "val": 15 | dataset = task.val_data 16 | elif split == 'train': 17 | dataset = task.train_data 18 | elif split == "test": 19 | dataset = task.test_data 20 | generator = iterator(dataset, num_epochs=1, shuffle=False, cuda_device=cuda_device) 21 | generator_tqdm = tqdm.tqdm(generator, total=iterator.get_num_batches(dataset), disable=True) 22 | for batch in generator_tqdm: 23 | 
tensor_batch = batch 24 | out = model.forward(task, **tensor_batch) 25 | n_examples += batch['label'].size()[0] 26 | preds, _ = out['logits'].max(dim=1) 27 | task_preds += list(preds.data.cpu().numpy()) 28 | task_labels += list(batch['label'].squeeze().data.cpu().numpy()) 29 | 30 | task_metrics = task.get_metrics(reset=True) 31 | logging.info('\n***** TEST RESULTS *****') 32 | for shot in ['Overall', 'Many', 'Medium', 'Few']: 33 | logging.info(f" * {shot}: MSE {task_metrics[shot.lower()]['mse']:.3f}\t" 34 | f"L1 {task_metrics[shot.lower()]['l1']:.3f}\t" 35 | f"G-Mean {task_metrics[shot.lower()]['gmean']:.3f}\t" 36 | f"Pearson {task_metrics[shot.lower()]['pearsonr']:.3f}\t" 37 | f"Spearman {task_metrics[shot.lower()]['spearmanr']:.3f}\t" 38 | f"Number {task_metrics[shot.lower()]['num_samples']}") 39 | 40 | n_overall_examples += n_examples 41 | task_preds = [min(max(np.float32(0.), pred * np.float32(5.)), np.float32(5.)) for pred in task_preds] 42 | all_preds[task.name] = (task_preds, task_idxs) 43 | 44 | return task_preds, task_labels, task_metrics['overall']['mse'] -------------------------------------------------------------------------------- /sts-b-dir/fds.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import logging 6 | from scipy.ndimage import gaussian_filter1d 7 | from util import calibrate_mean_var 8 | from scipy.signal.windows import triang 9 | 10 | class FDS(nn.Module): 11 | 12 | def __init__(self, feature_dim, bucket_num=50, bucket_start=0, start_update=0, start_smooth=1, 13 | kernel='gaussian', ks=5, sigma=2, momentum=0.9): 14 | super(FDS, self).__init__() 15 | self.feature_dim = feature_dim 16 | self.bucket_num = bucket_num 17 | self.bucket_start = bucket_start 18 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma) 19 | self.half_ks = (ks - 1) // 2 20 | self.momentum = momentum 21 | 
self.start_update = start_update 22 | self.start_smooth = start_smooth 23 | 24 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update)) 25 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim)) 26 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim)) 27 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 28 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 29 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim)) 30 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim)) 31 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start)) 32 | 33 | @staticmethod 34 | def _get_kernel_window(kernel, ks, sigma): 35 | assert kernel in ['gaussian', 'triang', 'laplace'] 36 | half_ks = (ks - 1) // 2 37 | if kernel == 'gaussian': 38 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks 39 | base_kernel = np.array(base_kernel, dtype=np.float32) 40 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma)) 41 | elif kernel == 'triang': 42 | kernel_window = triang(ks) / sum(triang(ks)) 43 | else: 44 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. 
* sigma) 45 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum( 46 | map(laplace, np.arange(-half_ks, half_ks + 1))) 47 | 48 | logging.info(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})') 49 | return torch.tensor(kernel_window, dtype=torch.float32).cuda() 50 | 51 | def _get_bucket_idx(self, label): 52 | label = np.float32(label) 53 | _, bins_edges = np.histogram(a=np.array([], dtype=np.float32), bins=self.bucket_num, range=(0., 5.)) 54 | if label == 5.: 55 | return self.bucket_num - 1 56 | else: 57 | return max(np.where(bins_edges > label)[0][0] - 1, self.bucket_start) 58 | 59 | def _update_last_epoch_stats(self): 60 | self.running_mean_last_epoch = self.running_mean 61 | self.running_var_last_epoch = self.running_var 62 | 63 | self.smoothed_mean_last_epoch = F.conv1d( 64 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0), 65 | pad=(self.half_ks, self.half_ks), mode='reflect'), 66 | weight=self.kernel_window.view(1, 1, -1), padding=0 67 | ).permute(2, 1, 0).squeeze(1) 68 | self.smoothed_var_last_epoch = F.conv1d( 69 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0), 70 | pad=(self.half_ks, self.half_ks), mode='reflect'), 71 | weight=self.kernel_window.view(1, 1, -1), padding=0 72 | ).permute(2, 1, 0).squeeze(1) 73 | 74 | def reset(self): 75 | self.running_mean.zero_() 76 | self.running_var.fill_(1) 77 | self.running_mean_last_epoch.zero_() 78 | self.running_var_last_epoch.fill_(1) 79 | self.smoothed_mean_last_epoch.zero_() 80 | self.smoothed_var_last_epoch.fill_(1) 81 | self.num_samples_tracked.zero_() 82 | 83 | def update_last_epoch_stats(self, epoch): 84 | if epoch == self.epoch + 1: 85 | self.epoch += 1 86 | self._update_last_epoch_stats() 87 | logging.info(f"Updated smoothed statistics of last epoch on Epoch [{epoch}]!") 88 | 89 | def update_running_stats(self, features, labels, epoch): 90 | if epoch < self.epoch: 91 | return 92 | 93 | assert self.feature_dim == features.size(1), "Input 
feature dimension is not aligned!" 94 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!" 95 | 96 | buckets = np.array([self._get_bucket_idx(label) for label in labels]) 97 | for bucket in np.unique(buckets): 98 | curr_feats = features[torch.tensor((buckets == bucket).astype(np.uint8))] 99 | curr_num_sample = curr_feats.size(0) 100 | curr_mean = torch.mean(curr_feats, 0) 101 | curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False) 102 | 103 | self.num_samples_tracked[bucket - self.bucket_start] += curr_num_sample 104 | factor = self.momentum if self.momentum is not None else \ 105 | (1 - curr_num_sample / float(self.num_samples_tracked[bucket - self.bucket_start])) 106 | factor = 0 if epoch == self.start_update else factor 107 | self.running_mean[bucket - self.bucket_start] = \ 108 | (1 - factor) * curr_mean + factor * self.running_mean[bucket - self.bucket_start] 109 | self.running_var[bucket - self.bucket_start] = \ 110 | (1 - factor) * curr_var + factor * self.running_var[bucket - self.bucket_start] 111 | 112 | # make up for zero training samples buckets 113 | for bucket in range(self.bucket_start, self.bucket_num): 114 | if bucket not in np.unique(buckets): 115 | if bucket == self.bucket_start: 116 | self.running_mean[0] = self.running_mean[1] 117 | self.running_var[0] = self.running_var[1] 118 | elif bucket == self.bucket_num - 1: 119 | self.running_mean[bucket - self.bucket_start] = self.running_mean[bucket - self.bucket_start - 1] 120 | self.running_var[bucket - self.bucket_start] = self.running_var[bucket - self.bucket_start - 1] 121 | else: 122 | self.running_mean[bucket - self.bucket_start] = (self.running_mean[bucket - self.bucket_start - 1] + 123 | self.running_mean[bucket - self.bucket_start + 1]) / 2. 124 | self.running_var[bucket - self.bucket_start] = (self.running_var[bucket - self.bucket_start - 1] + 125 | self.running_var[bucket - self.bucket_start + 1]) / 2. 
126 | logging.info(f"Updated running statistics with Epoch [{epoch}] features!") 127 | 128 | def smooth(self, features, labels, epoch): 129 | if epoch < self.start_smooth: 130 | return features 131 | 132 | labels = labels.squeeze(1) 133 | buckets = np.array([self._get_bucket_idx(label) for label in labels]) 134 | for bucket in np.unique(buckets): 135 | features[torch.tensor((buckets == bucket).astype(np.uint8))] = calibrate_mean_var( 136 | features[torch.tensor((buckets == bucket).astype(np.uint8))], 137 | self.running_mean_last_epoch[bucket - self.bucket_start], 138 | self.running_var_last_epoch[bucket - self.bucket_start], 139 | self.smoothed_mean_last_epoch[bucket - self.bucket_start], 140 | self.smoothed_var_last_epoch[bucket - self.bucket_start] 141 | ) 142 | 143 | return features 144 | -------------------------------------------------------------------------------- /sts-b-dir/glove/download_glove.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import zipfile 4 | 5 | print("Downloading and extracting GloVe word embeddings...") 6 | data_file = "./glove/glove.840B.300d.zip" 7 | wget.download("http://nlp.stanford.edu/data/glove.840B.300d.zip", out=data_file) 8 | with zipfile.ZipFile(data_file) as zip_ref: 9 | zip_ref.extractall('./glove') 10 | os.remove(data_file) 11 | print("\nCompleted!") -------------------------------------------------------------------------------- /sts-b-dir/glue_data/create_sts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import codecs 4 | import numpy as np 5 | import urllib 6 | if sys.version_info >= (3, 0): 7 | import urllib.request 8 | import zipfile 9 | 10 | URLLIB=urllib 11 | if sys.version_info >= (3, 0): 12 | URLLIB=urllib.request 13 | 14 | ##### Downloading raw STS-B dataset 15 | print("Downloading and extracting STS-B...") 16 | data_file = "./glue_data/STS-B.zip" 17 | 
URLLIB.urlretrieve("https://dl.fbaipublicfiles.com/glue/data/STS-B.zip", data_file) 18 | with zipfile.ZipFile(data_file) as zip_ref: 19 | zip_ref.extractall('./glue_data') 20 | os.remove(data_file) 21 | print("Completed!") 22 | 23 | ##### Creating STS-B-DIR dataset 24 | print("Creating STS-B-DIR dataset...") 25 | contents = {'train': [], 'dev': [], 'test': []} 26 | target = {'train': [], 'dev': [], 'test': []} 27 | 28 | for set_name in ['train', 'dev']: 29 | with codecs.open(f'./glue_data/STS-B/{set_name}.tsv', 'r', 'utf-8') as data_fh: 30 | for _ in range(1): 31 | titles = data_fh.readline() 32 | for row_idx, row in enumerate(data_fh): 33 | contents[set_name].append(row) 34 | row = row.strip().split('\t') 35 | targ = row[9] 36 | target[set_name].append(np.float32(targ)) 37 | 38 | bins = 20 39 | target_all = target['train'] + target['dev'] 40 | contents_all = contents['train'] + contents['dev'] 41 | 42 | bins_numbers, bins_edges = np.histogram(a=target_all, bins=bins, range=(0., 5.)) 43 | 44 | def _get_bin_idx(label): 45 | if label == 5.: 46 | return bins - 1 47 | else: 48 | return np.where(bins_edges > label)[0][0] - 1 49 | 50 | bins_contents = [None] * bins 51 | bins_targets = [None] * bins 52 | bins_numbers = list(bins_numbers) 53 | 54 | for i, score in enumerate(target_all): 55 | bin_idx = _get_bin_idx(score) 56 | if bins_contents[bin_idx] is None: 57 | bins_contents[bin_idx] = [] 58 | if bins_targets[bin_idx] is None: 59 | bins_targets[bin_idx] = [] 60 | bins_contents[bin_idx].append(contents_all[i]) 61 | bins_targets[bin_idx].append(score) 62 | contents_bins_numbers = [] 63 | targets_bins_numbers = [] 64 | for i in range(bins): 65 | contents_bins_numbers.append(len(bins_contents[i])) 66 | targets_bins_numbers.append(len(bins_targets[i])) 67 | 68 | new_contents = {'train': [None] * bins, 'dev': [None] * bins, 'test': [None] * bins} 69 | new_targets = {'train': [None] * bins, 'dev': [None] * bins, 'test': [None] * bins} 70 | select_num = 100 71 | for i in 
range(bins): 72 | new_index = {} 73 | new_dev_test_index = np.random.choice(a=list(range(bins_numbers[i])), size=select_num, replace=False) 74 | new_index['train'] = np.setdiff1d(np.array(range(bins_numbers[i])), new_dev_test_index) 75 | new_index['dev'] = np.random.choice(a=new_dev_test_index, size=int(select_num / 2), replace=False) 76 | new_index['test'] = np.setdiff1d(new_dev_test_index, new_index['dev']) 77 | for set_name in ['train', 'dev', 'test']: 78 | new_contents[set_name][i] = np.array(bins_contents[i])[new_index[set_name]] 79 | new_targets[set_name][i] = np.array(bins_targets[i])[new_index[set_name]] 80 | 81 | new_contents_merged = {'train': [], 'dev': [], 'test': []} 82 | for i in range(bins): 83 | for set_name in ['train', 'dev', 'test']: 84 | new_contents_merged[set_name] += new_contents[set_name][i].tolist() 85 | print('Number of samples for train/dev/test set in STS-B-DIR:', 86 | len(new_contents_merged['train']), len(new_contents_merged['dev']), len(new_contents_merged['test'])) 87 | 88 | for set_name in ['train', 'dev', 'test']: 89 | for i in range(len(new_contents_merged[set_name])): 90 | content_split = new_contents_merged[set_name][i].split('\t') 91 | content_split[0] = str(i) 92 | content_split = '\t'.join(content_split) 93 | new_contents_merged[set_name][i] = content_split 94 | for set_name in ['train', 'dev', 'test']: 95 | with open(f'./glue_data/STS-B/{set_name}_new.tsv', 'w') as f: 96 | f.write(titles) 97 | for i in range(len(new_contents_merged[set_name])): 98 | f.write(new_contents_merged[set_name][i]) 99 | print("STS-B-DIR dataset created! 
('./glue_data/STS-B/(train_new/dev_new/test_new).tsv')") 100 | -------------------------------------------------------------------------------- /sts-b-dir/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def weighted_mse_loss(inputs, targets, weights=None): 6 | loss = F.mse_loss(inputs, targets, reduce=False) 7 | if weights is not None: 8 | loss *= weights.expand_as(loss) 9 | loss = torch.mean(loss) 10 | return loss 11 | 12 | 13 | def weighted_l1_loss(inputs, targets, weights=None): 14 | loss = F.l1_loss(inputs, targets, reduce=False) 15 | if weights is not None: 16 | loss *= weights.expand_as(loss) 17 | loss = torch.mean(loss) 18 | return loss 19 | 20 | 21 | def weighted_huber_loss(inputs, targets, weights=None, beta=0.5): 22 | l1_loss = torch.abs(inputs - targets) 23 | cond = l1_loss < beta 24 | loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta) 25 | if weights is not None: 26 | loss *= weights.expand_as(loss) 27 | loss = torch.mean(loss) 28 | return loss 29 | 30 | 31 | def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=20., gamma=1): 32 | loss = F.mse_loss(inputs, targets, reduce=False) 33 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 34 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 35 | if weights is not None: 36 | loss *= weights.expand_as(loss) 37 | loss = torch.mean(loss) 38 | return loss 39 | 40 | 41 | def weighted_focal_l1_loss(inputs, targets, weights=None, activate='sigmoid', beta=20., gamma=1): 42 | loss = F.l1_loss(inputs, targets, reduce=False) 43 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \ 44 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma 45 | if weights is not None: 46 | loss *= weights.expand_as(loss) 47 | loss = torch.mean(loss) 48 | 
return loss 49 | -------------------------------------------------------------------------------- /sts-b-dir/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | from allennlp.common import Params 5 | from allennlp.models.model import Model 6 | from allennlp.modules import Highway 7 | from allennlp.modules import TimeDistributed 8 | from allennlp.nn import util, InitializerApplicator 9 | from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder 10 | from allennlp.modules.token_embedders import Embedding 11 | from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder as s2s_e 12 | 13 | from fds import FDS 14 | from loss import * 15 | 16 | def build_model(args, vocab, pretrained_embs, tasks): 17 | ''' 18 | Build model according to arguments 19 | ''' 20 | d_word, n_layers_highway = args.d_word, args.n_layers_highway 21 | 22 | # Build embedding layers 23 | if args.glove: 24 | word_embs = pretrained_embs 25 | train_embs = bool(args.train_words) 26 | else: 27 | logging.info("\tLearning embeddings from scratch!") 28 | word_embs = None 29 | train_embs = True 30 | word_embedder = Embedding(vocab.get_vocab_size('tokens'), d_word, weight=word_embs, trainable=train_embs, 31 | padding_index=vocab.get_token_index('@@PADDING@@')) 32 | d_inp_phrase = 0 33 | 34 | token_embedder = {"words": word_embedder} 35 | d_inp_phrase += d_word 36 | text_field_embedder = BasicTextFieldEmbedder(token_embedder) 37 | d_hid_phrase = args.d_hid 38 | 39 | # Build encoders 40 | phrase_layer = s2s_e.by_name('lstm').from_params(Params({'input_size': d_inp_phrase, 41 | 'hidden_size': d_hid_phrase, 42 | 'num_layers': args.n_layers_enc, 43 | 'bidirectional': True})) 44 | pair_encoder = HeadlessPairEncoder(vocab, text_field_embedder, n_layers_highway, 45 | phrase_layer, dropout=args.dropout) 46 | d_pair = 2 * d_hid_phrase 47 | 48 | if args.fds: 49 | _FDS = FDS(feature_dim=d_pair * 4, 
bucket_num=args.bucket_num, bucket_start=args.bucket_start, 50 | start_update=args.start_update, start_smooth=args.start_smooth, 51 | kernel=args.fds_kernel, ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt) 52 | 53 | # Build model and classifiers 54 | model = MultiTaskModel(args, pair_encoder, _FDS if args.fds else None) 55 | build_regressor(tasks, model, d_pair) 56 | 57 | if args.cuda >= 0: 58 | model = model.cuda() 59 | 60 | return model 61 | 62 | def build_regressor(tasks, model, d_pair): 63 | ''' 64 | Build the regressor 65 | ''' 66 | for task in tasks: 67 | d_task = d_pair * 4 68 | model.build_regressor(task, d_task) 69 | return 70 | 71 | class MultiTaskModel(nn.Module): 72 | def __init__(self, args, pair_encoder, FDS=None): 73 | super(MultiTaskModel, self).__init__() 74 | self.args = args 75 | self.pair_encoder = pair_encoder 76 | 77 | self.FDS = FDS 78 | self.start_smooth = args.start_smooth 79 | 80 | def build_regressor(self, task, d_inp): 81 | layer = nn.Linear(d_inp, 1) 82 | setattr(self, '%s_pred_layer' % task.name, layer) 83 | 84 | def forward(self, task=None, epoch=None, input1=None, input2=None, mask1=None, mask2=None, label=None, weight=None): 85 | pred_layer = getattr(self, '%s_pred_layer' % task.name) 86 | 87 | pair_emb = self.pair_encoder(input1, input2, mask1, mask2) 88 | pair_emb_s = pair_emb 89 | if self.training and self.FDS is not None: 90 | if epoch >= self.start_smooth: 91 | pair_emb_s = self.FDS.smooth(pair_emb_s, label, epoch) 92 | logits = pred_layer(pair_emb_s) 93 | 94 | out = {} 95 | if self.training and self.FDS is not None: 96 | out['embs'] = pair_emb 97 | out['labels'] = label 98 | 99 | if self.args.loss == 'huber': 100 | loss = globals()[f"weighted_{self.args.loss}_loss"]( 101 | inputs=logits, targets=label / torch.tensor(5.).cuda(), weights=weight, 102 | beta=self.args.huber_beta 103 | ) 104 | else: 105 | loss = globals()[f"weighted_{self.args.loss}_loss"]( 106 | inputs=logits, targets=label / torch.tensor(5.).cuda(), 
weights=weight 107 | ) 108 | out['logits'] = logits 109 | label = label.squeeze(-1).data.cpu().numpy() 110 | logits = logits.squeeze(-1).data.cpu().numpy() 111 | task.scorer(logits, label) 112 | out['loss'] = loss 113 | 114 | return out 115 | 116 | class HeadlessPairEncoder(Model): 117 | def __init__(self, vocab, text_field_embedder, num_highway_layers, phrase_layer, 118 | dropout=0.2, mask_lstms=True, initializer=InitializerApplicator()): 119 | super(HeadlessPairEncoder, self).__init__(vocab) 120 | 121 | self._text_field_embedder = text_field_embedder 122 | d_emb = text_field_embedder.get_output_dim() 123 | self._highway_layer = TimeDistributed(Highway(d_emb, num_highway_layers)) 124 | 125 | self._phrase_layer = phrase_layer 126 | self.pad_idx = vocab.get_token_index(vocab._padding_token) 127 | self.output_dim = phrase_layer.get_output_dim() 128 | 129 | if dropout > 0: 130 | self._dropout = torch.nn.Dropout(p=dropout) 131 | else: 132 | self._dropout = lambda x: x 133 | self._mask_lstms = mask_lstms 134 | 135 | initializer(self) 136 | 137 | def forward(self, s1, s2, m1=None, m2=None): 138 | s1_embs = self._highway_layer(self._text_field_embedder(s1) if m1 is None else s1) 139 | s2_embs = self._highway_layer(self._text_field_embedder(s2) if m2 is None else s2) 140 | 141 | s1_embs = self._dropout(s1_embs) 142 | s2_embs = self._dropout(s2_embs) 143 | 144 | # Set up masks 145 | s1_mask = util.get_text_field_mask(s1) if m1 is None else m1.long() 146 | s2_mask = util.get_text_field_mask(s2) if m2 is None else m2.long() 147 | 148 | s1_lstm_mask = s1_mask.float() if self._mask_lstms else None 149 | s2_lstm_mask = s2_mask.float() if self._mask_lstms else None 150 | 151 | # Sentence encodings with LSTMs 152 | s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask) 153 | s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask) 154 | 155 | s1_enc = self._dropout(s1_enc) 156 | s2_enc = self._dropout(s2_enc) 157 | 158 | # Max pooling 159 | s1_mask = s1_mask.unsqueeze(dim=-1) 160 | s2_mask 
= s2_mask.unsqueeze(dim=-1) 161 | s1_enc.data.masked_fill_(1 - s1_mask.byte().data, -float('inf')) 162 | s2_enc.data.masked_fill_(1 - s2_mask.byte().data, -float('inf')) 163 | s1_enc, _ = s1_enc.max(dim=1) 164 | s2_enc, _ = s2_enc.max(dim=1) 165 | 166 | return torch.cat([s1_enc, s2_enc, torch.abs(s1_enc - s2_enc), s1_enc * s2_enc], 1) -------------------------------------------------------------------------------- /sts-b-dir/preprocess.py: -------------------------------------------------------------------------------- 1 | '''Preprocessing functions and pipeline''' 2 | import nltk 3 | nltk.download('punkt') 4 | import torch 5 | import logging 6 | import numpy as np 7 | from collections import defaultdict 8 | 9 | from allennlp.data import Instance, Vocabulary, Token 10 | from allennlp.data.fields import TextField 11 | from allennlp.data.token_indexers import SingleIdTokenIndexer 12 | from allennlp_mods.numeric_field import NumericField 13 | 14 | from tasks import STSBTask 15 | 16 | PATH_PREFIX = './glue_data/' 17 | 18 | ALL_TASKS = ['sts-b'] 19 | NAME2INFO = {'sts-b': (STSBTask, 'STS-B/')} 20 | 21 | for k, v in NAME2INFO.items(): 22 | NAME2INFO[k] = (v[0], PATH_PREFIX + v[1]) 23 | 24 | def build_tasks(args): 25 | '''Prepare tasks''' 26 | 27 | task_names = [args.task] 28 | tasks = get_tasks(args, task_names, args.max_seq_len) 29 | 30 | max_v_sizes = {'word': args.max_word_v_size} 31 | token_indexer = {} 32 | token_indexer["words"] = SingleIdTokenIndexer() 33 | 34 | logging.info("\tProcessing tasks from scratch") 35 | word2freq = get_words(tasks) 36 | vocab = get_vocab(word2freq, max_v_sizes) 37 | word_embs = get_embeddings(vocab, args.word_embs_file, args.d_word) 38 | for task in tasks: 39 | train, val, test = process_task(task, token_indexer, vocab) 40 | task.train_data = train 41 | task.val_data = val 42 | task.test_data = test 43 | del_field_tokens(task) 44 | logging.info("\tFinished indexing tasks") 45 | 46 | train_eval_tasks = [task for task in tasks if 
task.name in task_names] 47 | logging.info('\t Training and evaluating on %s', ', '.join([task.name for task in train_eval_tasks])) 48 | 49 | return train_eval_tasks, vocab, word_embs 50 | 51 | def del_field_tokens(task): 52 | ''' Save memory by deleting the tokens that will no longer be used ''' 53 | 54 | all_instances = task.train_data + task.val_data + task.test_data 55 | for instance in all_instances: 56 | if 'input1' in instance.fields: 57 | field = instance.fields['input1'] 58 | del field.tokens 59 | if 'input2' in instance.fields: 60 | field = instance.fields['input2'] 61 | del field.tokens 62 | 63 | def get_tasks(args, task_names, max_seq_len): 64 | '''Load tasks''' 65 | tasks = [] 66 | for name in task_names: 67 | assert name in NAME2INFO, 'Task not found!' 68 | task = NAME2INFO[name][0](args, NAME2INFO[name][1], max_seq_len, name) 69 | tasks.append(task) 70 | logging.info("\tFinished loading tasks: %s.", ' '.join([task.name for task in tasks])) 71 | 72 | return tasks 73 | 74 | def get_words(tasks): 75 | ''' 76 | Get all words for all tasks for all splits for all sentences 77 | Return dictionary mapping words to frequencies. 
78 | ''' 79 | word2freq = defaultdict(int) 80 | 81 | def count_sentence(sentence): 82 | '''Update counts for words in the sentence''' 83 | for word in sentence: 84 | word2freq[word] += 1 85 | return 86 | 87 | for task in tasks: 88 | splits = [task.train_data_text, task.val_data_text, task.test_data_text] 89 | for split in [split for split in splits if split is not None]: 90 | for sentence in split[0]: 91 | count_sentence(sentence) 92 | for sentence in split[1]: 93 | count_sentence(sentence) 94 | 95 | logging.info("\tFinished counting words") 96 | 97 | return word2freq 98 | 99 | def get_vocab(word2freq, max_v_sizes): 100 | '''Build vocabulary''' 101 | vocab = Vocabulary(counter=None, max_vocab_size=max_v_sizes['word']) 102 | words_by_freq = [(word, freq) for word, freq in word2freq.items()] 103 | words_by_freq.sort(key=lambda x: x[1], reverse=True) 104 | for word, _ in words_by_freq[:max_v_sizes['word']]: 105 | vocab.add_token_to_namespace(word, 'tokens') 106 | logging.info("\tFinished building vocab. Using %d words", vocab.get_vocab_size('tokens')) 107 | 108 | return vocab 109 | 110 | def get_embeddings(vocab, vec_file, d_word): 111 | '''Get embeddings for the words in vocab''' 112 | word_v_size, unk_idx = vocab.get_vocab_size('tokens'), vocab.get_token_index(vocab._oov_token) 113 | embeddings = np.random.randn(word_v_size, d_word) 114 | with open(vec_file, encoding='utf-8') as vec_fh: 115 | for line in vec_fh: 116 | word, vec = line.split(' ', 1) 117 | idx = vocab.get_token_index(word) 118 | if idx != unk_idx: 119 | # in-vocabulary word: replace its random init with the pretrained vector 120 | embeddings[idx] = np.array(list(map(float, vec.split()))) 121 | embeddings[vocab.get_token_index('@@PADDING@@')] = 0.
122 | embeddings = torch.FloatTensor(embeddings) 123 | logging.info("\tFinished loading embeddings") 124 | 125 | return embeddings 126 | 127 | def process_task(task, token_indexer, vocab): 128 | ''' 129 | Convert a task's splits into AllenNLP fields then 130 | Index the splits using the given vocab (experiment dependent) 131 | ''' 132 | if hasattr(task, 'train_data_text') and task.train_data_text is not None: 133 | train = process_split(task.train_data_text, token_indexer) 134 | else: 135 | train = None 136 | if hasattr(task, 'val_data_text') and task.val_data_text is not None: 137 | val = process_split(task.val_data_text, token_indexer) 138 | else: 139 | val = None 140 | if hasattr(task, 'test_data_text') and task.test_data_text is not None: 141 | test = process_split(task.test_data_text, token_indexer) 142 | else: 143 | test = None 144 | 145 | for instance in train + val + test: 146 | instance.index_fields(vocab) 147 | 148 | return train, val, test 149 | 150 | def process_split(split, indexers): 151 | ''' 152 | Convert a dataset of sentences into padded sequences of indices. 
153 | ''' 154 | inputs1 = [TextField(list(map(Token, sent)), token_indexers=indexers) for sent in split[0]] 155 | inputs2 = [TextField(list(map(Token, sent)), token_indexers=indexers) for sent in split[1]] 156 | labels = [NumericField(l) for l in split[-1]] 157 | 158 | if len(split) == 4: # weight 159 | weights = [NumericField(w) for w in split[2]] 160 | instances = [Instance({"input1": input1, "input2": input2, "label": label, "weight": weight}) for \ 161 | (input1, input2, label, weight) in zip(inputs1, inputs2, labels, weights)] 162 | else: 163 | instances = [Instance({"input1": input1, "input2": input2, "label": label}) for \ 164 | (input1, input2, label) in zip(inputs1, inputs2, labels)] 165 | 166 | return instances 167 | -------------------------------------------------------------------------------- /sts-b-dir/requirements.txt: -------------------------------------------------------------------------------- 1 | nltk 2 | wget 3 | ipdb 4 | scikit-learn 5 | allennlp==0.5.0 -------------------------------------------------------------------------------- /sts-b-dir/tasks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nltk 3 | import codecs 4 | import logging 5 | import numpy as np 6 | from scipy.ndimage import convolve1d 7 | from util import get_lds_kernel_window, STSShotAverage 8 | 9 | def process_sentence(sent, max_seq_len): 10 | '''process a sentence using NLTK toolkit''' 11 | return nltk.word_tokenize(sent)[:max_seq_len] 12 | 13 | def load_tsv(data_file, max_seq_len, s1_idx=0, s2_idx=1, targ_idx=2, targ_fn=None, skip_rows=0, delimiter='\t', args=None): 14 | '''Load a tsv ''' 15 | sent1s, sent2s, targs = [], [], [] 16 | with codecs.open(data_file, 'r', 'utf-8') as data_fh: 17 | for _ in range(skip_rows): 18 | data_fh.readline() 19 | for row_idx, row in enumerate(data_fh): 20 | try: 21 | row = row.strip().split(delimiter) 22 | sent1 = process_sentence(row[s1_idx], max_seq_len) 23 | if (targ_idx is 
not None and not row[targ_idx]) or not len(sent1): 24 | continue 25 | 26 | if targ_idx is not None: 27 | targ = targ_fn(row[targ_idx]) 28 | else: 29 | targ = 0 30 | 31 | if s2_idx is not None: 32 | sent2 = process_sentence(row[s2_idx], max_seq_len) 33 | if not len(sent2): 34 | continue 35 | sent2s.append(sent2) 36 | 37 | sent1s.append(sent1) 38 | targs.append(targ) 39 | 40 | except Exception as e: 41 | logging.info("%s file: %s, row: %d", e, data_file, row_idx) 42 | continue 43 | 44 | if args is not None and args.reweight != 'none': 45 | assert args.reweight in {'inverse', 'sqrt_inv'} 46 | assert args.reweight != 'none' if args.lds else True, "Set reweight to \'inverse\' (default) or \'sqrt_inv\' when using LDS" 47 | 48 | bins = args.bucket_num 49 | value_lst, bins_edges = np.histogram(targs, bins=bins, range=(0., 5.)) 50 | 51 | def get_bin_idx(label): 52 | if label == 5.: 53 | return bins - 1 54 | else: 55 | return np.where(bins_edges > label)[0][0] - 1 56 | 57 | if args.reweight == 'sqrt_inv': 58 | value_lst = [np.sqrt(x) for x in value_lst] 59 | num_per_label = [value_lst[get_bin_idx(label)] for label in targs] 60 | 61 | logging.info(f"Using re-weighting: [{args.reweight.upper()}]") 62 | 63 | if args.lds: 64 | lds_kernel_window = get_lds_kernel_window(args.lds_kernel, args.lds_ks, args.lds_sigma) 65 | logging.info(f'Using LDS: [{args.lds_kernel.upper()}] ({args.lds_ks}/{args.lds_sigma})') 66 | smoothed_value = convolve1d(value_lst, weights=lds_kernel_window, mode='constant') 67 | num_per_label = [smoothed_value[get_bin_idx(label)] for label in targs] 68 | 69 | weights = [np.float32(1 / x) for x in num_per_label] 70 | scaling = len(weights) / np.sum(weights) 71 | weights = [scaling * x for x in weights] 72 | 73 | return sent1s, sent2s, weights, targs 74 | 75 | return sent1s, sent2s, targs 76 | 77 | class STSBTask: 78 | ''' Task class for Semantic Textual Similarity Benchmark.
''' 79 | def __init__(self, args, path, max_seq_len, name="sts-b"): 80 | ''' ''' 81 | super(STSBTask, self).__init__() 82 | self.args = args 83 | self.name = name 84 | self.train_data_text, self.val_data_text, self.test_data_text = None, None, None 85 | self.val_metric = 'mse' 86 | self.scorer = STSShotAverage(metric=['mse', 'l1', 'gmean', 'pearsonr', 'spearmanr']) 87 | self.load_data(path, max_seq_len) 88 | 89 | def load_data(self, path, max_seq_len): 90 | ''' ''' 91 | tr_data = load_tsv(os.path.join(path, 'train_new.tsv'), max_seq_len, skip_rows=1, 92 | s1_idx=7, s2_idx=8, targ_idx=9, targ_fn=lambda x: np.float32(x), args=self.args) 93 | val_data = load_tsv(os.path.join(path, 'dev_new.tsv'), max_seq_len, skip_rows=1, 94 | s1_idx=7, s2_idx=8, targ_idx=9, targ_fn=lambda x: np.float32(x)) 95 | te_data = load_tsv(os.path.join(path, 'test_new.tsv'), max_seq_len, skip_rows=1, 96 | s1_idx=7, s2_idx=8, targ_idx=9, targ_fn=lambda x: np.float32(x)) 97 | 98 | self.train_data_text = tr_data 99 | self.val_data_text = val_data 100 | self.test_data_text = te_data 101 | logging.info("\tFinished loading STS Benchmark data.") 102 | 103 | def get_metrics(self, reset=False, type=None): 104 | metric = self.scorer.get_metric(reset, type) 105 | 106 | return metric 107 | -------------------------------------------------------------------------------- /sts-b-dir/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import argparse 6 | import shutil 7 | import logging 8 | import torch 9 | import numpy as np 10 | from allennlp.data.iterators import BasicIterator 11 | 12 | from preprocess import build_tasks 13 | from models import build_model 14 | from trainer import build_trainer 15 | from evaluate import evaluate 16 | from util import device_mapping, query_yes_no, resume_checkpoint 17 | 18 | def main(arguments): 19 | parser = argparse.ArgumentParser(description='') 20 | 21 | 
parser.add_argument('--cuda', help='-1 if no CUDA, else gpu id (single gpu is enough)', type=int, default=0) 22 | parser.add_argument('--random_seed', help='random seed to use', type=int, default=111) 23 | 24 | # Paths and logging 25 | parser.add_argument('--log_file', help='file to log to', type=str, default='training.log') 26 | parser.add_argument('--store_root', help='store root path', type=str, default='checkpoint') 27 | parser.add_argument('--store_name', help='store name prefix for current experiment', type=str, default='sts') 28 | parser.add_argument('--suffix', help='store name suffix for current experiment', type=str, default='') 29 | parser.add_argument('--word_embs_file', help='file containing word embs', type=str, default='glove/glove.840B.300d.txt') 30 | 31 | # Training resuming flag 32 | parser.add_argument('--resume', help='whether to resume training', action='store_true', default=False) 33 | 34 | # Tasks 35 | parser.add_argument('--task', help='training and evaluation task', type=str, default='sts-b') 36 | 37 | # Preprocessing options 38 | parser.add_argument('--max_seq_len', help='max sequence length', type=int, default=40) 39 | parser.add_argument('--max_word_v_size', help='max word vocab size', type=int, default=30000) 40 | 41 | # Embedding options 42 | parser.add_argument('--dropout_embs', help='dropout rate for embeddings', type=float, default=.2) 43 | parser.add_argument('--d_word', help='dimension of word embeddings', type=int, default=300) 44 | parser.add_argument('--glove', help='1 if use glove, else from scratch', type=int, default=1) 45 | parser.add_argument('--train_words', help='1 if make word embs trainable', type=int, default=0) 46 | 47 | # Model options 48 | parser.add_argument('--d_hid', help='hidden dimension size', type=int, default=1500) 49 | parser.add_argument('--n_layers_enc', help='number of RNN layers', type=int, default=2) 50 | parser.add_argument('--n_layers_highway', help='number of highway layers', type=int, default=0) 
51 | parser.add_argument('--dropout', help='dropout rate to use in training', type=float, default=0.2) 52 | 53 | # Training options 54 | parser.add_argument('--batch_size', help='batch size', type=int, default=128) 55 | parser.add_argument('--optimizer', help='optimizer to use', type=str, default='adam') 56 | parser.add_argument('--lr', help='starting learning rate', type=float, default=1e-4) 57 | parser.add_argument('--loss', type=str, default='mse', choices=['mse', 'l1', 'focal_l1', 'focal_mse', 'huber']) 58 | parser.add_argument('--huber_beta', type=float, default=0.3, help='beta for huber loss') 59 | parser.add_argument('--max_grad_norm', help='max grad norm', type=float, default=5.) 60 | parser.add_argument('--val_interval', help='number of iterations between validation checks', type=int, default=400) 61 | parser.add_argument('--max_vals', help='maximum number of validation checks', type=int, default=100) 62 | parser.add_argument('--patience', help='patience for early stopping', type=int, default=10) 63 | 64 | # imbalanced related 65 | # LDS 66 | parser.add_argument('--lds', action='store_true', default=False, help='whether to enable LDS') 67 | parser.add_argument('--lds_kernel', type=str, default='gaussian', 68 | choices=['gaussian', 'triang', 'laplace'], help='LDS kernel type') 69 | parser.add_argument('--lds_ks', type=int, default=5, help='LDS kernel size: should be odd number') 70 | parser.add_argument('--lds_sigma', type=float, default=2, help='LDS gaussian/laplace kernel sigma') 71 | # FDS 72 | parser.add_argument('--fds', action='store_true', default=False, help='whether to enable FDS') 73 | parser.add_argument('--fds_kernel', type=str, default='gaussian', 74 | choices=['gaussian', 'triang', 'laplace'], help='FDS kernel type') 75 | parser.add_argument('--fds_ks', type=int, default=5, help='FDS kernel size: should be odd number') 76 | parser.add_argument('--fds_sigma', type=float, default=2, help='FDS gaussian/laplace kernel sigma') 77 | 
    parser.add_argument('--start_update', type=int, default=0, help='which epoch to start FDS updating')
    parser.add_argument('--start_smooth', type=int, default=1, help='which epoch to start using FDS to smooth features')
    parser.add_argument('--bucket_num', type=int, default=50, help='maximum bucket considered for FDS')
    parser.add_argument('--bucket_start', type=int, default=0, help='minimum(starting) bucket for FDS')
    parser.add_argument('--fds_mmt', type=float, default=0.9, help='FDS momentum')

    # re-weighting: SQRT_INV / INV
    parser.add_argument('--reweight', type=str, default='none', choices=['none', 'sqrt_inv', 'inverse'],
                        help='cost-sensitive reweighting scheme')
    # two-stage training: RRT
    parser.add_argument('--retrain_fc', action='store_true', default=False,
                        help='whether to retrain last regression layer (regressor)')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained checkpoint file path to load backbone weights for RRT')
    # evaluate only
    parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate only flag')
    parser.add_argument('--eval_model', type=str, default='', help='the model to evaluate on; if not specified, '
                                                                   'use the default best model in store_dir')

    args = parser.parse_args(arguments)

    os.makedirs(args.store_root, exist_ok=True)

    if not args.lds and args.reweight != 'none':
        args.store_name += f'_{args.reweight}'
    if args.lds:
        args.store_name += f'_lds_{args.lds_kernel[:3]}_{args.lds_ks}'
        if args.lds_kernel in ['gaussian', 'laplace']:
            args.store_name += f'_{args.lds_sigma}'
    if args.fds:
        args.store_name += f'_fds_{args.fds_kernel[:3]}_{args.fds_ks}'
        if args.fds_kernel in ['gaussian', 'laplace']:
            args.store_name += f'_{args.fds_sigma}'
        args.store_name += f'_{args.start_update}_{args.start_smooth}_{args.fds_mmt}'
    if args.retrain_fc:
        args.store_name += '_retrain_fc'

    if args.loss == 'huber':
        args.store_name += f'_{args.loss}_beta_{args.huber_beta}'
    else:
        args.store_name += f'_{args.loss}'

    args.store_name += f'_seed_{args.random_seed}_valint_{args.val_interval}_patience_{args.patience}' \
                       f'_{args.optimizer}_{args.lr}_{args.batch_size}'
    args.store_name += f'_{args.suffix}' if len(args.suffix) else ''

    args.store_dir = os.path.join(args.store_root, args.store_name)

    if not args.evaluate and not args.resume:
        if os.path.exists(args.store_dir):
            if query_yes_no('overwrite previous folder: {} ?'.format(args.store_dir)):
                shutil.rmtree(args.store_dir)
                print(args.store_dir + ' removed.\n')
            else:
                raise RuntimeError('Output folder {} already exists'.format(args.store_dir))
        logging.info(f"===> Creating folder: {args.store_dir}")
        os.makedirs(args.store_dir)

    # Logistics
    logging.root.handlers = []
    if os.path.exists(args.store_dir):
        log_file = os.path.join(args.store_dir, args.log_file)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(message)s",
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ])
    else:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(message)s",
            handlers=[logging.StreamHandler()]
        )
    logging.info(args)

    seed = random.randint(1, 10000) if args.random_seed < 0 else args.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda >= 0:
        logging.info("Using GPU %d", args.cuda)
        torch.cuda.set_device(args.cuda)
        torch.cuda.manual_seed_all(seed)
    logging.info("Using random seed %d", seed)

    # Load tasks
    logging.info("Loading tasks...")
    start_time = time.time()
    tasks, vocab, word_embs = build_tasks(args)
    logging.info('\tFinished loading tasks in %.3fs', time.time() - start_time)

    # Build model
    logging.info('Building model...')
    start_time = time.time()
    model = build_model(args, vocab, word_embs, tasks)
    logging.info('\tFinished building model in %.3fs', time.time() - start_time)

    # Set up trainer
    iterator = BasicIterator(args.batch_size)
    trainer, train_params, opt_params = build_trainer(args, model, iterator)

    # Train
    if tasks and not args.evaluate:
        if args.retrain_fc and len(args.pretrained):
            model_path = args.pretrained
            assert os.path.isfile(model_path), f"No checkpoint found at '{model_path}'"
            model_state = torch.load(model_path, map_location=device_mapping(args.cuda))
            trainer._model = resume_checkpoint(trainer._model, model_state, backbone_only=True)
            logging.info(f'Pre-trained backbone weights loaded: {model_path}')
            logging.info('Retrain last regression layer only!')
            for name, param in trainer._model.named_parameters():
                if "sts-b_pred_layer" not in name:
                    param.requires_grad = False
            logging.info(f'Only optimize parameters: {[n for n, p in trainer._model.named_parameters() if p.requires_grad]}')
            to_train = [(n, p) for n, p in trainer._model.named_parameters() if p.requires_grad]
        else:
            to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]

        trainer.train(tasks, args.val_interval, to_train, opt_params, args.resume)
    else:
        logging.info("Skipping training...")

    logging.info('Testing on test set...')
    model_path = os.path.join(args.store_dir, "model_state_best.th") if not len(args.eval_model) else args.eval_model
    assert os.path.isfile(model_path), f"No checkpoint found at '{model_path}'"
    logging.info(f'Evaluating {model_path}...')
    model_state = torch.load(model_path, map_location=device_mapping(args.cuda))
    model = resume_checkpoint(model, model_state)
    te_preds, te_labels, _ = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="test")
    if not len(args.eval_model):
        np.savez_compressed(os.path.join(args.store_dir, f"{args.store_name}.npz"), preds=te_preds, labels=te_labels)

    logging.info("Done testing.")

if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))

--------------------------------------------------------------------------------
/sts-b-dir/util.py:
--------------------------------------------------------------------------------
from typing import Dict
import numpy as np
import torch

from scipy.ndimage import gaussian_filter1d
from scipy.signal.windows import triang
from scipy.stats import pearsonr, spearmanr, gmean


def get_text_field_mask(text_field_tensors: Dict[str, torch.Tensor]) -> torch.LongTensor:
    """
    Takes the dictionary of tensors produced by a ``TextField`` and returns a mask of shape
    ``(batch_size, num_tokens)``. This mask will be 0 where the tokens are padding, and 1
    otherwise.

    There could be several entries in the tensor dictionary with different shapes (e.g., one for
    word ids, one for character ids). In order to get a token mask, we assume that the tensor in
    the dictionary with the lowest number of dimensions has plain token ids. This allows us to
    also handle cases where the input is actually a ``ListField[TextField]``.

    NOTE: Our functions for generating masks create torch.LongTensors, because using
    torch.byteTensors inside Variables makes it easy to run into overflow errors
    when doing mask manipulation, such as summing to get the lengths of sequences - see below.
    >>> mask = torch.ones([260]).byte()
    >>> mask.sum() # equals 260.
    >>> var_mask = torch.autograd.Variable(mask)
    >>> var_mask.sum() # equals 4, due to 8 bit precision - the sum overflows.
    """
    tensor_dims = [(tensor.dim(), tensor) for tensor in text_field_tensors.values()]
    tensor_dims.sort(key=lambda x: x[0])
    token_tensor = tensor_dims[0][1]

    return (token_tensor != 0).long()

def device_mapping(cuda_device: int):
    """
    In order to `torch.load()` a GPU-trained model onto a CPU (or specific GPU),
    you have to supply a `map_location` function. Call this with
    the desired `cuda_device` to get the function that `torch.load()` needs.
    """
    def inner_device_mapping(storage: torch.Storage, location) -> torch.Storage:
        if cuda_device >= 0:
            return storage.cuda(cuda_device)
        else:
            return storage
    return inner_device_mapping

def query_yes_no(question):
    """ Ask a yes/no question via input() and return their answer. """
    valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    prompt = " [Y/n] "

    while True:
        print(question + prompt, end=':')
        choice = input().lower()
        if choice == '':
            return valid['y']
        elif choice in valid:
            return valid[choice]
        else:
            print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")

def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.5, clip_max=2.):
    if torch.sum(v1) < 1e-10:
        return matrix
    if (v1 <= 0.).any() or (v2 < 0.).any():
        # Combine the masks with a logical AND. Summing the two masks and comparing to 2
        # never matches on PyTorch versions where comparisons return bool tensors.
        valid_pos = (v1 > 0.) & (v2 >= 0.)
        factor = torch.clamp(v2[valid_pos] / v1[valid_pos], clip_min, clip_max)
        matrix[:, valid_pos] = (matrix[:, valid_pos] - m1[valid_pos]) * torch.sqrt(factor) + m2[valid_pos]
        return matrix

    factor = torch.clamp(v2 / v1, clip_min, clip_max)
    return (matrix - m1) * torch.sqrt(factor) + m2

def resume_checkpoint(model, model_state, backbone_only=False):
    model.pair_encoder.load_state_dict(
        {k.split('.', 1)[1]: v for k, v in model_state.items() if 'pair_encoder' in k}
    )
    if not backbone_only:
        getattr(model, 'sts-b_pred_layer').load_state_dict(
            {k.split('.', 1)[1]: v for k, v in model_state.items() if 'sts-b_pred_layer' in k}
        )

    return model

def get_lds_kernel_window(kernel, ks, sigma):
    assert kernel in ['gaussian', 'triang', 'laplace']
    half_ks = (ks - 1) // 2
    if kernel == 'gaussian':
        base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
        kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
        # kernel = gaussian(ks)
    elif kernel == 'triang':
        kernel_window = triang(ks)
    else:
        laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
        kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))

    return kernel_window

class STSShotAverage:
    def __init__(self, metric):
        self._pred = []
        self._label = []
        self._count = 0
        self._metric = metric
        self._num_bins = 50
        # under np.float32 division
        self._shot_idx = {
            'many': [0, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 49],
            'medium': [2, 4, 6, 8, 27, 35, 37],
            'few': [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 29, 31, 33, 39, 41, 43, 45, 47]
        }

        def get_bin_idx(label):
            _, bins_edges = np.histogram(a=np.array([], dtype=np.float32), bins=self._num_bins, range=(0., 5.))
            if label == 5.:
                return self._num_bins - 1
            else:
                return np.where(bins_edges > label)[0][0] - 1
        self._get_bin_idx = get_bin_idx

    def __call__(self, pred, label):
        self._pred += pred.tolist()
        self._label += label.tolist()
        self._count += len(pred)

    def get_metric(self, reset=False, type=None):
        label_bin_idx = list(map(self._get_bin_idx, self._label))
        def bin2shot(idx):
            if idx in self._shot_idx['many']:
                return 'many'
            elif idx in self._shot_idx['medium']:
                return 'medium'
            else:
                return 'few'
        label_category = np.array(list(map(bin2shot, label_bin_idx)))

        pred_shot = {'many': [], 'medium': [], 'few': [], 'overall': []}
        label_shot = {'many': [], 'medium': [], 'few': [], 'overall': []}
        metric = {'many': {}, 'medium': {}, 'few': {}, 'overall': {}}
        for shot in ['overall', 'many', 'medium', 'few']:
            pred_shot[shot] = np.array(self._pred)[label_category == shot] * 5. if shot != 'overall' else np.array(self._pred) * 5.
            label_shot[shot] = np.array(self._label)[label_category == shot] if shot != 'overall' else np.array(self._label)
            if 'mse' in self._metric:
                metric[shot]['mse'] = np.mean((pred_shot[shot] - label_shot[shot]) ** 2) if pred_shot[shot].size > 0 else 0.
            if 'l1' in self._metric:
                metric[shot]['l1'] = np.mean(np.abs(pred_shot[shot] - label_shot[shot])) if pred_shot[shot].size > 0 else 0.
            if 'gmean' in self._metric:
                if pred_shot[shot].size <= 0:
                    metric[shot]['gmean'] = 0.
                else:
                    diff = np.abs(pred_shot[shot] - label_shot[shot])
                    if diff[diff == 0.].size:
                        diff[diff == 0.] += 1e-10
                        metric[shot]['gmean'] = gmean(diff) if pred_shot[shot].size > 0 else 0.
                    else:
                        metric[shot]['gmean'] = gmean(np.abs(pred_shot[shot] - label_shot[shot])) if pred_shot[shot].size > 0 else 0.
            if 'pearsonr' in self._metric:
                metric[shot]['pearsonr'] = pearsonr(pred_shot[shot], label_shot[shot])[0] if pred_shot[shot].size > 1 else 0.
            if 'spearmanr' in self._metric:
                metric[shot]['spearmanr'] = spearmanr(pred_shot[shot], label_shot[shot])[0] if pred_shot[shot].size > 1 else 0.
            metric[shot]['num_samples'] = pred_shot[shot].size
        if reset:
            self.reset()
        return metric['overall'] if type == 'overall' else metric


    def reset(self):
        self._pred = []
        self._label = []
        self._count = 0

--------------------------------------------------------------------------------
/teaser/agedb_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/agedb_dir.png

--------------------------------------------------------------------------------
/teaser/fds.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/fds.gif

--------------------------------------------------------------------------------
/teaser/imdb_wiki_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/imdb_wiki_dir.png

--------------------------------------------------------------------------------
/teaser/lds.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/lds.gif

--------------------------------------------------------------------------------
/teaser/nyud2_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/nyud2_dir.png

--------------------------------------------------------------------------------
/teaser/overview.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/overview.gif

--------------------------------------------------------------------------------
/teaser/shhs_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/shhs_dir.png

--------------------------------------------------------------------------------
/teaser/stsb_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/stsb_dir.png

--------------------------------------------------------------------------------
/tutorial/README.md:
--------------------------------------------------------------------------------
# Hands-on Tutorial of Deep Imbalanced Regression



This is a hands-on tutorial for **Delving into Deep Imbalanced Regression** [[Paper]](https://arxiv.org/abs/2102.09554).

```bib
@inproceedings{yang2021delving,
  title={Delving into Deep Imbalanced Regression},
  author={Yang, Yuzhe and Zha, Kaiwen and Chen, Ying-Cong and Wang, Hao and Katabi, Dina},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2021}
}
```

This notebook walks through DIR on a small-scale dataset, the [Boston Housing dataset](https://www.cs.toronto.edu/~delve/data/boston/bostonDetail.html), as a quick overview of how to perform practical (deep) imbalanced regression on custom datasets.

You can open it directly in Colab: [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/YyzHarry/imbalanced-regression/blob/master/tutorial/tutorial.ipynb), or run it locally with Jupyter Notebook using the following instructions.

Required packages:
```bash
pip install --upgrade pip
pip install --upgrade jupyter notebook
```

Then clone this repository to your computer:

```bash
git clone https://github.com/YyzHarry/imbalanced-regression.git
```

Once cloning has finished, go to the directory of this tutorial and run

```bash
jupyter notebook --port 8888
```

to start a Jupyter notebook server and access it through the browser. Finally, let's explore the notebook `tutorial.ipynb`!

--------------------------------------------------------------------------------