├── .gitignore
├── LICENSE
├── README.md
├── agedb-dir
│   ├── README.md
│   ├── data
│   │   ├── agedb.csv
│   │   ├── create_agedb.py
│   │   └── preprocess_agedb.py
│   ├── datasets.py
│   ├── fds.py
│   ├── loss.py
│   ├── resnet.py
│   ├── train.py
│   └── utils.py
├── imdb-wiki-dir
│   ├── README.md
│   ├── data
│   │   ├── create_imdb_wiki.py
│   │   ├── imdb_wiki.csv
│   │   └── preprocess_imdb_wiki.py
│   ├── datasets.py
│   ├── download_imdb_wiki.py
│   ├── fds.py
│   ├── loss.py
│   ├── resnet.py
│   ├── train.py
│   └── utils.py
├── nyud2-dir
│   ├── README.md
│   ├── data
│   │   ├── nyu2_train_FDS_subset.csv
│   │   └── test_balanced_mask.npy
│   ├── download_nyud2.py
│   ├── loaddata.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── fds.py
│   │   ├── modules.py
│   │   ├── net.py
│   │   └── resnet.py
│   ├── nyu_transform.py
│   ├── preprocess_nyud2.py
│   ├── test.py
│   ├── train.py
│   └── util.py
├── sts-b-dir
│   ├── README.md
│   ├── allennlp_mods
│   │   └── numeric_field.py
│   ├── evaluate.py
│   ├── fds.py
│   ├── glove
│   │   └── download_glove.py
│   ├── glue_data
│   │   ├── STS-B
│   │   │   ├── dev.tsv
│   │   │   ├── dev_new.tsv
│   │   │   ├── test.tsv
│   │   │   ├── test_new.tsv
│   │   │   ├── train.tsv
│   │   │   └── train_new.tsv
│   │   └── create_sts.py
│   ├── loss.py
│   ├── models.py
│   ├── preprocess.py
│   ├── requirements.txt
│   ├── tasks.py
│   ├── train.py
│   ├── trainer.py
│   └── util.py
├── teaser
│   ├── agedb_dir.png
│   ├── fds.gif
│   ├── imdb_wiki_dir.png
│   ├── lds.gif
│   ├── nyud2_dir.png
│   ├── overview.gif
│   ├── shhs_dir.png
│   └── stsb_dir.png
└── tutorial
    ├── README.md
    └── tutorial.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | .idea/*
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Yuzhe Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Delving into Deep Imbalanced Regression
2 |
3 | This repository contains the implementation code for paper:
4 | __Delving into Deep Imbalanced Regression__
5 | [Yuzhe Yang](http://www.mit.edu/~yuzhe/), [Kaiwen Zha](https://kaiwenzha.github.io/), [Ying-Cong Chen](https://yingcong.github.io/), [Hao Wang](http://www.wanghao.in/), [Dina Katabi](https://people.csail.mit.edu/dina/)
6 | _38th International Conference on Machine Learning (ICML 2021), **Long Oral**_
7 | [[Project Page](http://dir.csail.mit.edu/)] [[Paper](https://arxiv.org/abs/2102.09554)] [[Video](https://youtu.be/grJGixofQRU)] [[Blog Post](https://towardsdatascience.com/strategies-and-tactics-for-regression-on-imbalanced-data-61eeb0921fca)] [](https://colab.research.google.com/github/YyzHarry/imbalanced-regression/blob/master/tutorial/tutorial.ipynb)
8 | ___
9 |
10 |
11 | Deep Imbalanced Regression (DIR) aims to learn from imbalanced data with continuous targets, tackle potential missing data for certain regions, and generalize to the entire target range.
12 |
13 |
14 |
15 | ## Beyond Imbalanced Classification: Brief Introduction for DIR
16 | Existing techniques for learning from imbalanced data focus on targets with __categorical__ indices, i.e., the targets are different classes. However, many real-world tasks involve __continuous__ and even infinite target values. We systematically investigate _Deep Imbalanced Regression (DIR)_, which aims to learn continuous targets from natural imbalanced data, deal with potential missing data for certain target values, and generalize to the entire target range.
17 |
18 | We curate and benchmark large-scale DIR datasets for common real-world tasks in _computer vision_, _natural language processing_, and _healthcare_ domains, ranging from single-value prediction such as age, text similarity score, health condition score, to dense-value prediction such as depth.
19 |
20 |
21 | ## Usage
22 | We separate the codebase for different datasets into different subfolders. Please go into the subfolders for more information (e.g., installation, dataset preparation, training, evaluation & models).
23 |
24 | #### __[IMDB-WIKI-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir)__ | __[AgeDB-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/agedb-dir)__ | __[NYUD2-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/nyud2-dir)__ | __[STS-B-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/sts-b-dir)__
25 |
26 |
27 | ## Highlights
28 | __(1) :heavy_check_mark: New Task:__ Deep Imbalanced Regression (DIR)
29 |
30 | __(2) :heavy_check_mark: New Techniques:__
31 |
32 | |  |  |
33 | | :-: | :-: |
34 | | Label distribution smoothing (LDS) | Feature distribution smoothing (FDS) |
35 |
36 | __(3) :heavy_check_mark: New Benchmarks:__
37 | - _Computer Vision:_ :bulb: IMDB-WIKI-DIR (age) / AgeDB-DIR (age) / NYUD2-DIR (depth)
38 | - _Natural Language Processing:_ :clipboard: STS-B-DIR (text similarity score)
39 | - _Healthcare:_ :hospital: SHHS-DIR (health condition score)
40 |
41 | | [IMDB-WIKI-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/imdb-wiki-dir) | [AgeDB-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/agedb-dir) | [NYUD2-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/nyud2-dir) | [STS-B-DIR](https://github.com/YyzHarry/imbalanced-regression/tree/main/sts-b-dir) | SHHS-DIR |
42 | | :-: | :-: | :-: | :-: | :-: |
43 | |  |  |  |  |  |
44 |
45 |
46 | ## Apply LDS and FDS on Other Datasets / Models
47 | We provide examples of how to apply LDS and FDS on other customized datasets and/or models.
48 |
49 | ### LDS
50 | To apply LDS on your customized dataset, you will first need to estimate the effective label distribution:
51 | ```python
52 | import numpy as np
53 | from collections import Counter
54 | from scipy.ndimage import convolve1d
55 | from utils import get_lds_kernel_window
56 | # preds, labels: [Ns,], "Ns" is the number of total samples
57 | preds, labels = ..., ...
58 | # assign each label to its corresponding bin (start from 0)
59 | # with your defined get_bin_idx(), return bin_index_per_label: [Ns,]
60 | bin_index_per_label = [get_bin_idx(label) for label in labels]
61 |
62 | # calculate empirical (original) label distribution: [Nb,]
63 | # "Nb" is the number of bins
64 | Nb = max(bin_index_per_label) + 1
65 | num_samples_of_bins = dict(Counter(bin_index_per_label))
66 | emp_label_dist = [num_samples_of_bins.get(i, 0) for i in range(Nb)]
67 |
68 | # lds_kernel_window: [ks,], here for example, we use gaussian, ks=5, sigma=2
69 | lds_kernel_window = get_lds_kernel_window(kernel='gaussian', ks=5, sigma=2)
70 | # calculate effective label distribution: [Nb,]
71 | eff_label_dist = convolve1d(np.array(emp_label_dist), weights=lds_kernel_window, mode='constant')
72 | ```
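The snippet above leaves `get_bin_idx()` to the user. As one concrete, hypothetical choice, fixed-width binning plus a toy run of the kernel smoothing could look like the sketch below; the hand-rolled symmetric window stands in for `get_lds_kernel_window`, and the bin counts are made-up numbers:

```python
import numpy as np
from scipy.ndimage import convolve1d

def get_bin_idx(label, min_label=0.0, bin_width=1.0):
    # Hypothetical helper: map a continuous label to a 0-based bin index.
    return int((label - min_label) // bin_width)

# Toy empirical distribution over 5 bins; bin 2 is rare.
emp_label_dist = np.array([10., 8., 1., 9., 12.])

# A symmetric, normalized window (stand-in for get_lds_kernel_window).
lds_kernel_window = np.array([0.1, 0.2, 0.4, 0.2, 0.1])

# Effective distribution: the rare bin borrows mass from its neighbors.
eff_label_dist = convolve1d(emp_label_dist, weights=lds_kernel_window, mode='constant')
```

After smoothing, the rare bin's effective count rises well above its empirical count of 1, so its samples are no longer treated as extreme outliers by the subsequent re-weighting.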
73 | With the estimated effective label distribution, one straightforward option is to use the loss re-weighting scheme:
74 | ```python
75 | import numpy as np
76 | from loss import weighted_mse_loss
77 | # Use re-weighting based on effective label distribution, sample-wise weights: [Ns,]
78 | eff_num_per_label = [eff_label_dist[bin_idx] for bin_idx in bin_index_per_label]
79 | weights = [np.float32(1 / x) for x in eff_num_per_label]
80 |
81 | # calculate loss
82 | loss = weighted_mse_loss(preds, labels, weights=weights)
83 | ```
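The re-weighting above can be sanity-checked on toy numbers. The following self-contained sketch uses hypothetical counts; the final rescaling step mirrors what the repo's `datasets.py` does so that the weights average to 1. Samples falling in sparse bins receive proportionally larger weights:

```python
import numpy as np

# Hypothetical effective counts for 3 bins, and each sample's bin index.
eff_label_dist = [6.0, 3.0, 12.0]
bin_index_per_label = [0, 0, 1, 2, 2, 2]

eff_num_per_label = [eff_label_dist[b] for b in bin_index_per_label]
weights = [np.float32(1 / x) for x in eff_num_per_label]

# Rescale so the weights average to 1 (as in the repo's datasets.py).
scaling = len(weights) / np.sum(weights)
weights = [scaling * w for w in weights]
```

The sample in the rarest bin (effective count 3.0) ends up with four times the weight of a sample in the densest bin (12.0), exactly the ratio of their effective counts.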
84 |
85 | ### FDS
86 | To apply FDS on your customized data/model, you will first need to define the FDS module in your network:
87 | ```python
88 | import torch.nn as nn
89 | from fds import FDS
90 | config = dict(feature_dim=..., start_update=0, start_smooth=1, kernel='gaussian', ks=5, sigma=2)
91 |
92 | class Network(nn.Module):
93 | def __init__(self, **config):
94 | super().__init__()
95 | self.feature_extractor = ...
96 | self.regressor = nn.Linear(config['feature_dim'], 1) # FDS operates before the final regressor
97 | self.FDS = FDS(**config)
98 |
99 | def forward(self, inputs, labels, epoch):
100 | features = self.feature_extractor(inputs) # features: [batch_size, feature_dim]
101 | # smooth the feature distributions over the target space
102 | smoothed_features = features
103 | if self.training and epoch >= config['start_smooth']:
104 | smoothed_features = self.FDS.smooth(smoothed_features, labels, epoch)
105 | preds = self.regressor(smoothed_features)
106 |
107 | return {'preds': preds, 'features': features}
108 | ```
109 | During training, you will need to update the FDS statistics after each training epoch:
110 | ```python
111 | model = Network(**config)
112 |
113 | for epoch in range(num_epochs):
114 | for (inputs, labels) in train_loader:
115 | # standard training pipeline
116 | ...
117 |
118 | # update FDS statistics after each training epoch
119 | if epoch >= config['start_update']:
120 | # collect features and labels for all training samples
121 | ...
122 | # training_features: [num_samples, feature_dim], training_labels: [num_samples,]
123 | training_features, training_labels = ..., ...
124 | model.FDS.update_last_epoch_stats(epoch)
125 | model.FDS.update_running_stats(training_features, training_labels, epoch)
126 | ```
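Inside `update_running_stats`, each target bucket's feature mean and variance are maintained as an exponential moving average with momentum 0.9 by default, and the factor is zeroed on the first update epoch so the statistics start from that epoch's features. A minimal NumPy sketch of that update rule, detached from the actual module:

```python
import numpy as np

def momentum_update(running, current, momentum=0.9, first_epoch=False):
    # Per-bucket EMA rule: new = (1 - factor) * current + factor * running,
    # with factor forced to 0 on the first update epoch (copy current stats).
    factor = 0.0 if first_epoch else momentum
    return (1 - factor) * current + factor * running

running_mean = np.zeros(4)
# First epoch: statistics are copied from the current epoch's features.
running_mean = momentum_update(running_mean, np.full(4, 2.0), first_epoch=True)
# Later epochs: the running estimate moves slowly toward new statistics.
running_mean = momentum_update(running_mean, np.full(4, 3.0))
```

With momentum 0.9, each later epoch moves the running estimate only 10% of the way toward the current epoch's statistics, which keeps the per-bucket means and variances stable even when a bucket sees few samples in a given epoch.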
127 |
128 |
129 | ## Updates
130 | - [06/2021] We provide a [hands-on tutorial](https://github.com/YyzHarry/imbalanced-regression/tree/main/tutorial) of DIR. Check it out! [](https://colab.research.google.com/github/YyzHarry/imbalanced-regression/blob/master/tutorial/tutorial.ipynb)
131 | - [05/2021] We create a [Blog post](https://towardsdatascience.com/strategies-and-tactics-for-regression-on-imbalanced-data-61eeb0921fca) for this work (version in Chinese is also available [here](https://zhuanlan.zhihu.com/p/369627086)). Check it out for more details!
132 | - [05/2021] Paper accepted to ICML 2021 as a __Long Talk__. We have released the code and models. You can find all reproduced checkpoints via [this link](https://drive.google.com/drive/folders/1UfFJNIG-LPOMecwi1tfYzEViBiAYhNU0?usp=sharing), or go into each subfolder for models for each dataset.
133 | - [02/2021] [arXiv version](https://arxiv.org/abs/2102.09554) posted. Please stay tuned for updates.
134 |
135 |
136 | ## Citation
137 | If you find this code or idea useful, please cite our work:
138 | ```bib
139 | @inproceedings{yang2021delving,
140 | title={Delving into Deep Imbalanced Regression},
141 | author={Yang, Yuzhe and Zha, Kaiwen and Chen, Ying-Cong and Wang, Hao and Katabi, Dina},
142 | booktitle={International Conference on Machine Learning (ICML)},
143 | year={2021}
144 | }
145 | ```
146 |
147 |
148 | ## Contact
149 | If you have any questions, feel free to contact us through email (yuzhe@mit.edu & kzha@mit.edu) or Github issues. Enjoy!
150 |
--------------------------------------------------------------------------------
/agedb-dir/README.md:
--------------------------------------------------------------------------------
1 | # AgeDB-DIR
2 | ## Installation
3 |
4 | #### Prerequisites
5 |
6 | 1. Download AgeDB dataset from [here](https://ibug.doc.ic.ac.uk/resources/agedb/) and extract the zip file (you may need to contact the authors of AgeDB dataset for the zip password) to folder `./data`
7 |
8 | 2. __(Optional)__ We have provided required AgeDB-DIR meta file `agedb.csv` to set up balanced val/test set in folder `./data`. To reproduce the results in the paper, please directly use this file. If you want to try different balanced splits, you can generate it using
9 |
10 | ```bash
11 | python data/create_agedb.py
12 | python data/preprocess_agedb.py
13 | ```
14 |
15 | #### Dependencies
16 |
17 | - PyTorch (>= 1.2, tested on 1.6)
18 | - tensorboard_logger
19 | - numpy, pandas, scipy, tqdm, matplotlib, PIL
20 |
21 | ## Code Overview
22 |
23 | #### Main Files
24 |
25 | - `train.py`: main training and evaluation script
26 | - `create_agedb.py`: create AgeDB raw meta data
27 | - `preprocess_agedb.py`: create AgeDB-DIR meta file `agedb.csv` with balanced val/test set
28 |
29 | #### Main Arguments
30 |
31 | - `--data_dir`: data directory to place data and meta file
32 | - `--lds`: LDS switch (whether to enable LDS)
33 | - `--fds`: FDS switch (whether to enable FDS)
34 | - `--reweight`: cost-sensitive re-weighting scheme to use
35 | - `--retrain_fc`: whether to retrain regressor
36 | - `--loss`: training loss type
37 | - `--resume`: path to resume checkpoint (for both training and evaluation)
38 | - `--evaluate`: evaluate only flag
39 | - `--pretrained`: path to load backbone weights for regressor re-training (RRT)
40 |
41 | ## Getting Started
42 |
43 | #### Train a vanilla model
44 |
45 | ```bash
46 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_dir --reweight none
47 | ```
48 |
49 | Always specify `CUDA_VISIBLE_DEVICES` for the GPU IDs to be used (by default, 4 GPUs) and `--data_dir` when training a model, or fix your default data directory path directly in the code. We omit these arguments below for simplicity.
50 |
51 | #### Train a model using re-weighting
52 |
53 | To perform inverse re-weighting
54 |
55 | ```bash
56 | python train.py --reweight inverse
57 | ```
58 |
59 | To perform square-root inverse re-weighting
60 |
61 | ```bash
62 | python train.py --reweight sqrt_inv
63 | ```
64 |
65 | #### Train a model with different losses
66 |
67 | To use Focal-R loss
68 |
69 | ```bash
70 | python train.py --loss focal_l1
71 | ```
72 |
73 | To use Huber loss
74 |
75 | ```bash
76 | python train.py --loss huber
77 | ```
78 |
79 | #### Train a model using RRT
80 |
81 | ```bash
82 | python train.py [...retrained model arguments...] --retrain_fc --pretrained
83 | ```
84 |
85 | #### Train a model using LDS
86 |
87 | To use Gaussian kernel (kernel size: 5, sigma: 2)
88 |
89 | ```bash
90 | python train.py --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2
91 | ```
92 |
93 | #### Train a model using FDS
94 |
95 | To use Gaussian kernel (kernel size: 5, sigma: 2)
96 |
97 | ```bash
98 | python train.py --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
99 | ```
100 | #### Train a model using LDS + FDS
101 | ```bash
102 | python train.py --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
103 | ```
104 |
105 | #### Evaluate a trained checkpoint
106 |
107 | ```bash
108 | python train.py [...evaluation model arguments...] --evaluate --resume
109 | ```
110 |
111 | ## Reproduced Benchmarks and Model Zoo
112 |
113 | We provide below reproduced results on AgeDB-DIR (base method `SQINV`, metric `MAE`).
114 | Note that some models could give **better** results than the reported numbers in the paper.
115 |
116 | | Model | Overall | Many-Shot | Medium-Shot | Few-Shot | Download |
117 | | :-------: | :-----: | :-------: | :---------: | :------: | :------: |
118 | | LDS | 7.67 | 6.98 | 8.86 | 10.89 | [model](https://drive.google.com/file/d/1CPDlcRCQ1EC4E3x9w955cmILSaVOlkyz/view?usp=sharing) |
119 | | FDS | 7.69 | 7.11 | 8.86 | 9.98 | [model](https://drive.google.com/file/d/1JmRS4V8zmmS9eschsBSmSeQjFWefYlib/view?usp=sharing) |
120 | | LDS + FDS | 7.47 | 6.91 | 8.26 | 10.55 | [model](https://drive.google.com/file/d/1N0nMdu-wuWoAOS1x_m-pnzHcAp61ajY9/view?usp=sharing) |
--------------------------------------------------------------------------------
/agedb-dir/data/create_agedb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pandas as pd
4 | from tqdm import tqdm
5 |
6 |
7 | def get_args():
8 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
9 | parser.add_argument("--data_path", type=str, default="./data")
10 | args = parser.parse_args()
11 | return args
12 |
13 |
14 | def main():
15 | args = get_args()
16 | ages, img_paths = [], []
17 |
18 | for filename in tqdm(os.listdir(os.path.join(args.data_path, 'AgeDB'))):
19 | _, _, age, gender = filename.split('.')[0].split('_')
20 |
21 | ages.append(age)
22 | img_paths.append(f"AgeDB/{filename}")
23 |
24 | outputs = dict(age=ages, path=img_paths)
25 | output_dir = os.path.join(args.data_path, "meta")
26 | os.makedirs(output_dir, exist_ok=True)
27 | output_path = os.path.join(output_dir, "agedb.csv")
28 | df = pd.DataFrame(data=outputs)
29 | df.to_csv(str(output_path), index=False)
30 |
31 |
32 | if __name__ == '__main__':
33 | main()
34 |
--------------------------------------------------------------------------------
/agedb-dir/data/preprocess_agedb.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | BASE_PATH = './data'
7 |
8 |
9 | def visualize_dataset(db="agedb"):
10 | file_path = join(BASE_PATH, "meta", "agedb.csv")
11 | data = pd.read_csv(file_path)
12 | _, ax = plt.subplots(figsize=(6, 3), sharex='all', sharey='all')
13 | ax.hist(data['age'], range(max(data['age']) + 2))
14 | # ax.set_xlim([0, 102])
15 | plt.title(f"{db.upper()} (total: {data.shape[0]})")
16 | plt.tight_layout()
17 | plt.show()
18 |
19 |
20 | def make_balanced_testset(db="agedb", max_size=30, seed=666, verbose=True, vis=True, save=False):
21 | file_path = join(BASE_PATH, "meta", f"{db}.csv")
22 | df = pd.read_csv(file_path)
23 | df['age'] = df.age.astype(int)
24 | val_set, test_set = [], []
25 | import random
26 | random.seed(seed)
27 | for value in range(121):
28 | curr_df = df[df['age'] == value]
29 | curr_data = curr_df['path'].values
30 | random.shuffle(curr_data)
31 | curr_size = min(len(curr_data) // 3, max_size)
32 | val_set += list(curr_data[:curr_size])
33 | test_set += list(curr_data[curr_size:curr_size * 2])
34 | if verbose:
35 | print(f"Val: {len(val_set)}\nTest: {len(test_set)}")
36 | assert len(set(val_set).intersection(set(test_set))) == 0
37 | combined_set = dict(zip(val_set, ['val' for _ in range(len(val_set))]))
38 | combined_set.update(dict(zip(test_set, ['test' for _ in range(len(test_set))])))
39 | df['split'] = df['path'].map(combined_set)
40 | df['split'].fillna('train', inplace=True)
41 | if verbose:
42 | print(df)
43 | if save:
44 | df.to_csv(str(join(BASE_PATH, f"{db}.csv")), index=False)
45 | if vis:
46 | _, ax = plt.subplots(3, figsize=(6, 9), sharex='all')
47 | df_train = df[df['split'] == 'train']
48 | ax[0].hist(df_train['age'], range(max(df['age'])))
49 | ax[0].set_title(f"[{db.upper()}] train: {df_train.shape[0]}")
50 | ax[1].hist(df[df['split'] == 'val']['age'], range(max(df['age'])))
51 | ax[1].set_title(f"[{db.upper()}] val: {df[df['split'] == 'val'].shape[0]}")
52 | ax[2].hist(df[df['split'] == 'test']['age'], range(max(df['age'])))
53 | ax[2].set_title(f"[{db.upper()}] test: {df[df['split'] == 'test'].shape[0]}")
54 | ax[0].set_xlim([0, 120])
55 | plt.tight_layout()
56 | plt.show()
57 |
58 |
59 | if __name__ == '__main__':
60 | make_balanced_testset()
61 | visualize_dataset()
62 |
--------------------------------------------------------------------------------
/agedb-dir/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | from PIL import Image
5 | from scipy.ndimage import convolve1d
6 | from torch.utils import data
7 | import torchvision.transforms as transforms
8 |
9 | from utils import get_lds_kernel_window
10 |
11 | print = logging.info
12 |
13 |
14 | class AgeDB(data.Dataset):
15 | def __init__(self, df, data_dir, img_size, split='train', reweight='none',
16 | lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
17 | self.df = df
18 | self.data_dir = data_dir
19 | self.img_size = img_size
20 | self.split = split
21 |
22 | self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma)
23 |
24 | def __len__(self):
25 | return len(self.df)
26 |
27 | def __getitem__(self, index):
28 | index = index % len(self.df)
29 | row = self.df.iloc[index]
30 | img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB')
31 | transform = self.get_transform()
32 | img = transform(img)
33 | label = np.asarray([row['age']]).astype('float32')
34 | weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)])
35 |
36 | return img, label, weight
37 |
38 | def get_transform(self):
39 | if self.split == 'train':
40 | transform = transforms.Compose([
41 | transforms.Resize((self.img_size, self.img_size)),
42 | transforms.RandomCrop(self.img_size, padding=16),
43 | transforms.RandomHorizontalFlip(),
44 | transforms.ToTensor(),
45 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
46 | ])
47 | else:
48 | transform = transforms.Compose([
49 | transforms.Resize((self.img_size, self.img_size)),
50 | transforms.ToTensor(),
51 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
52 | ])
53 | return transform
54 |
55 | def _prepare_weights(self, reweight, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
56 | assert reweight in {'none', 'inverse', 'sqrt_inv'}
57 | assert reweight != 'none' if lds else True, \
58 | "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS"
59 |
60 | value_dict = {x: 0 for x in range(max_target)}
61 | labels = self.df['age'].values
62 | for label in labels:
63 | value_dict[min(max_target - 1, int(label))] += 1
64 | if reweight == 'sqrt_inv':
65 | value_dict = {k: np.sqrt(v) for k, v in value_dict.items()}
66 | elif reweight == 'inverse':
67 | value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()} # clip weights for inverse re-weight
68 | num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels]
69 | if not len(num_per_label) or reweight == 'none':
70 | return None
71 | print(f"Using re-weighting: [{reweight.upper()}]")
72 |
73 | if lds:
74 | lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
75 | print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
76 | smoothed_value = convolve1d(
77 | np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant')
78 | num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels]
79 |
80 | weights = [np.float32(1 / x) for x in num_per_label]
81 | scaling = len(weights) / np.sum(weights)
82 | weights = [scaling * x for x in weights]
83 | return weights
84 |
--------------------------------------------------------------------------------
/agedb-dir/fds.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | from scipy.ndimage import gaussian_filter1d
4 | from scipy.signal.windows import triang
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from utils import calibrate_mean_var
10 |
11 | print = logging.info
12 |
13 |
14 | class FDS(nn.Module):
15 |
16 | def __init__(self, feature_dim, bucket_num=100, bucket_start=3, start_update=0, start_smooth=1,
17 | kernel='gaussian', ks=5, sigma=2, momentum=0.9):
18 | super(FDS, self).__init__()
19 | self.feature_dim = feature_dim
20 | self.bucket_num = bucket_num
21 | self.bucket_start = bucket_start
22 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma)
23 | self.half_ks = (ks - 1) // 2
24 | self.momentum = momentum
25 | self.start_update = start_update
26 | self.start_smooth = start_smooth
27 |
28 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update))
29 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim))
30 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim))
31 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim))
32 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim))
33 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim))
34 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim))
35 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start))
36 |
37 | @staticmethod
38 | def _get_kernel_window(kernel, ks, sigma):
39 | assert kernel in ['gaussian', 'triang', 'laplace']
40 | half_ks = (ks - 1) // 2
41 | if kernel == 'gaussian':
42 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
43 | base_kernel = np.array(base_kernel, dtype=np.float32)
44 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma))
45 | elif kernel == 'triang':
46 | kernel_window = triang(ks) / sum(triang(ks))
47 | else:
48 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
49 | kernel_window = np.array(list(map(laplace, np.arange(-half_ks, half_ks + 1)))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1)))
50 |
51 | print(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})')
52 | return torch.tensor(kernel_window, dtype=torch.float32).cuda()
53 |
54 | def _update_last_epoch_stats(self):
55 | self.running_mean_last_epoch = self.running_mean
56 | self.running_var_last_epoch = self.running_var
57 |
58 | self.smoothed_mean_last_epoch = F.conv1d(
59 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0),
60 | pad=(self.half_ks, self.half_ks), mode='reflect'),
61 | weight=self.kernel_window.view(1, 1, -1), padding=0
62 | ).permute(2, 1, 0).squeeze(1)
63 | self.smoothed_var_last_epoch = F.conv1d(
64 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0),
65 | pad=(self.half_ks, self.half_ks), mode='reflect'),
66 | weight=self.kernel_window.view(1, 1, -1), padding=0
67 | ).permute(2, 1, 0).squeeze(1)
68 |
69 | def reset(self):
70 | self.running_mean.zero_()
71 | self.running_var.fill_(1)
72 | self.running_mean_last_epoch.zero_()
73 | self.running_var_last_epoch.fill_(1)
74 | self.smoothed_mean_last_epoch.zero_()
75 | self.smoothed_var_last_epoch.fill_(1)
76 | self.num_samples_tracked.zero_()
77 |
78 | def update_last_epoch_stats(self, epoch):
79 | if epoch == self.epoch + 1:
80 | self.epoch += 1
81 | self._update_last_epoch_stats()
82 | print(f"Updated smoothed statistics on Epoch [{epoch}]!")
83 |
84 | def update_running_stats(self, features, labels, epoch):
85 | if epoch < self.epoch:
86 | return
87 |
88 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!"
89 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!"
90 |
91 | for label in torch.unique(labels):
92 | if label > self.bucket_num - 1 or label < self.bucket_start:
93 | continue
94 | elif label == self.bucket_start:
95 | curr_feats = features[labels <= label]
96 | elif label == self.bucket_num - 1:
97 | curr_feats = features[labels >= label]
98 | else:
99 | curr_feats = features[labels == label]
100 | curr_num_sample = curr_feats.size(0)
101 | curr_mean = torch.mean(curr_feats, 0)
102 | curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False)
103 |
104 | self.num_samples_tracked[int(label - self.bucket_start)] += curr_num_sample
105 | factor = self.momentum if self.momentum is not None else \
106 | (1 - curr_num_sample / float(self.num_samples_tracked[int(label - self.bucket_start)]))
107 | factor = 0 if epoch == self.start_update else factor
108 | self.running_mean[int(label - self.bucket_start)] = \
109 | (1 - factor) * curr_mean + factor * self.running_mean[int(label - self.bucket_start)]
110 | self.running_var[int(label - self.bucket_start)] = \
111 | (1 - factor) * curr_var + factor * self.running_var[int(label - self.bucket_start)]
112 |
113 | print(f"Updated running statistics with Epoch [{epoch}] features!")
114 |
115 | def smooth(self, features, labels, epoch):
116 | if epoch < self.start_smooth:
117 | return features
118 |
119 | labels = labels.squeeze(1)
120 | for label in torch.unique(labels):
121 | if label > self.bucket_num - 1 or label < self.bucket_start:
122 | continue
123 | elif label == self.bucket_start:
124 | features[labels <= label] = calibrate_mean_var(
125 | features[labels <= label],
126 | self.running_mean_last_epoch[int(label - self.bucket_start)],
127 | self.running_var_last_epoch[int(label - self.bucket_start)],
128 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)],
129 | self.smoothed_var_last_epoch[int(label - self.bucket_start)])
130 | elif label == self.bucket_num - 1:
131 | features[labels >= label] = calibrate_mean_var(
132 | features[labels >= label],
133 | self.running_mean_last_epoch[int(label - self.bucket_start)],
134 | self.running_var_last_epoch[int(label - self.bucket_start)],
135 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)],
136 | self.smoothed_var_last_epoch[int(label - self.bucket_start)])
137 | else:
138 | features[labels == label] = calibrate_mean_var(
139 | features[labels == label],
140 | self.running_mean_last_epoch[int(label - self.bucket_start)],
141 | self.running_var_last_epoch[int(label - self.bucket_start)],
142 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)],
143 | self.smoothed_var_last_epoch[int(label - self.bucket_start)])
144 | return features
145 |
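The interplay of `start_update` and `start_smooth` in the class above can be summarized as a simple per-epoch schedule: running statistics are tracked from `start_update` on, and features are only calibrated (via `smooth`) from `start_smooth` on. A minimal sketch of that schedule (the phase names here are ours, not from the code):

```python
def fds_phase(epoch, start_update=0, start_smooth=1):
    """Illustrative FDS phase at a given epoch (phase names are ours)."""
    if epoch < start_update:
        return "off"              # neither stats tracking nor smoothing
    if epoch < start_smooth:
        return "update-only"      # update_running_stats() runs; smooth() is a no-op
    return "update-and-smooth"    # smooth() calibrates features with last epoch's stats
```

With the defaults, epoch 0 only collects statistics and smoothing kicks in from epoch 1, which is why `smooth` always works with *last* epoch's statistics.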
--------------------------------------------------------------------------------
/agedb-dir/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def weighted_mse_loss(inputs, targets, weights=None):
6 | loss = (inputs - targets) ** 2
7 | if weights is not None:
8 | loss *= weights.expand_as(loss)
9 | loss = torch.mean(loss)
10 | return loss
11 |
12 |
13 | def weighted_l1_loss(inputs, targets, weights=None):
14 | loss = F.l1_loss(inputs, targets, reduction='none')
15 | if weights is not None:
16 | loss *= weights.expand_as(loss)
17 | loss = torch.mean(loss)
18 | return loss
19 |
20 |
21 | def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1):
22 | loss = (inputs - targets) ** 2
23 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
24 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
25 | if weights is not None:
26 | loss *= weights.expand_as(loss)
27 | loss = torch.mean(loss)
28 | return loss
29 |
30 |
31 | def weighted_focal_l1_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1):
32 | loss = F.l1_loss(inputs, targets, reduction='none')
33 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
34 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
35 | if weights is not None:
36 | loss *= weights.expand_as(loss)
37 | loss = torch.mean(loss)
38 | return loss
39 |
40 |
41 | def weighted_huber_loss(inputs, targets, weights=None, beta=1.):
42 | l1_loss = torch.abs(inputs - targets)
43 | cond = l1_loss < beta
44 | loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta)
45 | if weights is not None:
46 | loss *= weights.expand_as(loss)
47 | loss = torch.mean(loss)
48 | return loss
49 |
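As a quick illustration of what the weighting does, here is a NumPy mirror of `weighted_mse_loss` (a sketch for intuition only; training uses the PyTorch versions above): each sample's squared error is scaled by its weight before averaging, so rare-label samples can contribute more to the loss.

```python
import numpy as np

def weighted_mse(inputs, targets, weights=None):
    # Per-element squared error, optionally scaled per sample, then averaged.
    loss = (inputs - targets) ** 2
    if weights is not None:
        loss = loss * weights
    return loss.mean()

preds  = np.array([[1.0], [3.0]])
labels = np.array([[2.0], [1.0]])
w      = np.array([[0.5], [2.0]])   # up-weight the second (rarer-label) sample

unweighted = weighted_mse(preds, labels)      # (1 + 4) / 2 = 2.5
weighted   = weighted_mse(preds, labels, w)   # (0.5*1 + 2*4) / 2 = 4.25
```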
--------------------------------------------------------------------------------
/agedb-dir/resnet.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import torch.nn as nn
4 | from fds import FDS
5 |
6 | print = logging.info
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1):
10 | """3x3 convolution with padding"""
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.conv2 = conv3x3(planes, planes)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | def forward(self, x):
28 | residual = x
29 | out = self.conv1(x)
30 | out = self.bn1(out)
31 | out = self.relu(out)
32 | out = self.conv2(out)
33 | out = self.bn2(out)
34 | if self.downsample is not None:
35 | residual = self.downsample(x)
36 | out += residual
37 | out = self.relu(out)
38 | return out
39 |
40 |
41 | class Bottleneck(nn.Module):
42 | expansion = 4
43 |
44 | def __init__(self, inplanes, planes, stride=1, downsample=None):
45 | super(Bottleneck, self).__init__()
46 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
47 | self.bn1 = nn.BatchNorm2d(planes)
48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
49 | self.bn2 = nn.BatchNorm2d(planes)
50 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm2d(planes * 4)
52 | self.relu = nn.ReLU(inplace=True)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x):
57 | residual = x
58 | out = self.conv1(x)
59 | out = self.bn1(out)
60 | out = self.relu(out)
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 | out = self.relu(out)
64 | out = self.conv3(out)
65 | out = self.bn3(out)
66 | if self.downsample is not None:
67 | residual = self.downsample(x)
68 | out += residual
69 | out = self.relu(out)
70 | return out
71 |
72 |
73 | class ResNet(nn.Module):
74 |
75 | def __init__(self, block, layers, fds, bucket_num, bucket_start, start_update, start_smooth,
76 | kernel, ks, sigma, momentum, dropout=None):
77 | self.inplanes = 64
78 | super(ResNet, self).__init__()
79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
80 | self.bn1 = nn.BatchNorm2d(64)
81 | self.relu = nn.ReLU(inplace=True)
82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
83 | self.layer1 = self._make_layer(block, 64, layers[0])
84 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
85 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
86 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
87 | self.avgpool = nn.AvgPool2d(7, stride=1)
88 | self.linear = nn.Linear(512 * block.expansion, 1)
89 |
90 | if fds:
91 | self.FDS = FDS(
92 | feature_dim=512 * block.expansion, bucket_num=bucket_num, bucket_start=bucket_start,
93 | start_update=start_update, start_smooth=start_smooth, kernel=kernel, ks=ks, sigma=sigma, momentum=momentum
94 | )
95 | self.fds = fds
96 | self.start_smooth = start_smooth
97 |
98 | self.use_dropout = bool(dropout)
99 | if self.use_dropout:
100 | print(f'Using dropout: {dropout}')
101 | self.dropout = nn.Dropout(p=dropout)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | m.weight.data.normal_(0, math.sqrt(2. / n))
107 | elif isinstance(m, nn.BatchNorm2d):
108 | m.weight.data.fill_(1)
109 | m.bias.data.zero_()
110 |
111 | def _make_layer(self, block, planes, blocks, stride=1):
112 | downsample = None
113 | if stride != 1 or self.inplanes != planes * block.expansion:
114 | downsample = nn.Sequential(
115 | nn.Conv2d(self.inplanes, planes * block.expansion,
116 | kernel_size=1, stride=stride, bias=False),
117 | nn.BatchNorm2d(planes * block.expansion),
118 | )
119 | layers = []
120 | layers.append(block(self.inplanes, planes, stride, downsample))
121 | self.inplanes = planes * block.expansion
122 | for i in range(1, blocks):
123 | layers.append(block(self.inplanes, planes))
124 |
125 | return nn.Sequential(*layers)
126 |
127 | def forward(self, x, targets=None, epoch=None):
128 | x = self.conv1(x)
129 | x = self.bn1(x)
130 | x = self.relu(x)
131 | x = self.maxpool(x)
132 |
133 | x = self.layer1(x)
134 | x = self.layer2(x)
135 | x = self.layer3(x)
136 | x = self.layer4(x)
137 | x = self.avgpool(x)
138 | encoding = x.view(x.size(0), -1)
139 |
140 | encoding_s = encoding
141 |
142 | if self.training and self.fds:
143 | if epoch >= self.start_smooth:
144 | encoding_s = self.FDS.smooth(encoding_s, targets, epoch)
145 |
146 | if self.use_dropout:
147 | encoding_s = self.dropout(encoding_s)
148 | x = self.linear(encoding_s)
149 |
150 | if self.training and self.fds:
151 | return x, encoding
152 | else:
153 | return x
154 |
155 |
156 | def resnet50(**kwargs):
157 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
158 |
--------------------------------------------------------------------------------
/agedb-dir/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import torch
4 | import logging
5 | import numpy as np
6 | from scipy.ndimage import gaussian_filter1d
7 | from scipy.signal.windows import triang
8 |
9 |
10 | class AverageMeter(object):
11 | def __init__(self, name, fmt=':f'):
12 | self.name = name
13 | self.fmt = fmt
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 | def __str__(self):
29 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
30 | return fmtstr.format(**self.__dict__)
31 |
32 |
33 | class ProgressMeter(object):
34 | def __init__(self, num_batches, meters, prefix=""):
35 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
36 | self.meters = meters
37 | self.prefix = prefix
38 |
39 | def display(self, batch):
40 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
41 | entries += [str(meter) for meter in self.meters]
42 | logging.info('\t'.join(entries))
43 |
44 | @staticmethod
45 | def _get_batch_fmtstr(num_batches):
46 | num_digits = len(str(num_batches))
47 | fmt = '{:' + str(num_digits) + 'd}'
48 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
49 |
50 |
51 | def query_yes_no(question):
52 | """ Ask a yes/no question via input() and return their answer. """
53 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
54 | prompt = " [Y/n] "
55 |
56 | while True:
57 | print(question + prompt, end=':')
58 | choice = input().lower()
59 | if choice == '':
60 | return valid['y']
61 | elif choice in valid:
62 | return valid[choice]
63 | else:
64 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")
65 |
66 |
67 | def prepare_folders(args):
68 | folders_util = [args.store_root, os.path.join(args.store_root, args.store_name)]
69 | if os.path.exists(folders_util[-1]) and not args.resume and not args.pretrained and not args.evaluate:
70 | if query_yes_no('overwrite previous folder: {} ?'.format(folders_util[-1])):
71 | shutil.rmtree(folders_util[-1])
72 | print(folders_util[-1] + ' removed.')
73 | else:
74 | raise RuntimeError('Output folder {} already exists'.format(folders_util[-1]))
75 | for folder in folders_util:
76 | if not os.path.exists(folder):
77 | print(f"===> Creating folder: {folder}")
78 | os.mkdir(folder)
79 |
80 |
81 | def adjust_learning_rate(optimizer, epoch, args):
82 | lr = args.lr
83 | for milestone in args.schedule:
84 | lr *= 0.1 if epoch >= milestone else 1.
85 | for param_group in optimizer.param_groups:
86 | param_group['lr'] = lr
87 |
88 |
89 | def save_checkpoint(args, state, is_best, prefix=''):
90 | filename = f"{args.store_root}/{args.store_name}/{prefix}ckpt.pth.tar"
91 | torch.save(state, filename)
92 | if is_best:
93 | logging.info("===> Saving current best checkpoint...")
94 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar'))
95 |
96 |
97 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.1, clip_max=10):
98 | if torch.sum(v1) < 1e-10:
99 | return matrix
100 | if (v1 == 0.).any():
101 | valid = (v1 != 0.)
102 | factor = torch.clamp(v2[valid] / v1[valid], clip_min, clip_max)
103 | matrix[:, valid] = (matrix[:, valid] - m1[valid]) * torch.sqrt(factor) + m2[valid]
104 | return matrix
105 |
106 | factor = torch.clamp(v2 / v1, clip_min, clip_max)
107 | return (matrix - m1) * torch.sqrt(factor) + m2
108 |
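The transform in `calibrate_mean_var` is a standard re-standardization: features with statistics `(m1, v1)` are shifted and rescaled to match target statistics `(m2, v2)`, with the variance ratio clipped to avoid extreme rescaling. A NumPy sketch of its core (omitting the zero-variance special case), using toy data:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(10000, 1))  # features with stats (m1, v1)

m1, v1 = x.mean(0), x.var(0)
m2, v2 = np.array([0.0]), np.array([1.0])            # target statistics (m2, v2)

# Same transform as calibrate_mean_var, with the same variance-ratio clipping.
factor = np.clip(v2 / v1, 0.1, 10)
y = (x - m1) * np.sqrt(factor) + m2
```

After the transform, `y` has mean `m2` and (when the ratio is not clipped) variance `v2`.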
109 |
110 | def get_lds_kernel_window(kernel, ks, sigma):
111 | assert kernel in ['gaussian', 'triang', 'laplace']
112 | half_ks = (ks - 1) // 2
113 | if kernel == 'gaussian':
114 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
115 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
116 | elif kernel == 'triang':
117 | kernel_window = triang(ks)
118 | else:
119 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
120 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))
121 |
122 | return kernel_window
123 |
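To see what `get_lds_kernel_window` produces and how it is used downstream (in `datasets.py`, the window is convolved with the empirical label histogram), here is a small self-contained sketch of the `'gaussian'` case with toy counts:

```python
import numpy as np
from scipy.ndimage import convolve1d, gaussian_filter1d

def gaussian_lds_kernel(ks=5, sigma=2):
    # Same construction as above: smooth a unit impulse, normalize peak to 1.
    half_ks = (ks - 1) // 2
    base = np.array([0.] * half_ks + [1.] + [0.] * half_ks)
    window = gaussian_filter1d(base, sigma=sigma)
    return window / window.max()

# Toy per-label sample counts with an empty bin at label 5.
counts = np.array([1., 4., 10., 4., 1., 0., 2.])
smoothed = convolve1d(counts, weights=gaussian_lds_kernel(), mode='constant')
# The empty bin now borrows density from its neighbors, so inverse-frequency
# weights (1 / smoothed) stay finite and less extreme.
```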
--------------------------------------------------------------------------------
/imdb-wiki-dir/README.md:
--------------------------------------------------------------------------------
1 | # IMDB-WIKI-DIR
2 | ## Installation
3 |
4 | #### Prerequisites
5 |
6 | 1. Download and extract IMDB faces and WIKI faces respectively using
7 |
8 | ```bash
9 | python download_imdb_wiki.py
10 | ```
11 |
12 | 2. __(Optional)__ We provide the required IMDB-WIKI-DIR meta file `imdb_wiki.csv`, which sets up the balanced val/test set, in folder `./data`. To reproduce the results in the paper, please use this file directly. You can also regenerate it using
13 |
14 | ```bash
15 | python data/create_imdb_wiki.py
16 | python data/preprocess_imdb_wiki.py
17 | ```
18 |
19 | #### Dependencies
20 |
21 | - PyTorch (>= 1.2, tested on 1.6)
22 | - tensorboard_logger
23 | - numpy, pandas, scipy, tqdm, matplotlib, PIL, wget
24 |
25 | ## Code Overview
26 |
27 | #### Main Files
28 |
29 | - `train.py`: main training and evaluation script
30 | - `create_imdb_wiki.py`: create IMDB-WIKI raw meta data
31 | - `preprocess_imdb_wiki.py`: create IMDB-WIKI-DIR meta file `imdb_wiki.csv` with balanced val/test set
32 |
33 | #### Main Arguments
34 |
35 | - `--data_dir`: data directory to place data and meta file
36 | - `--lds`: LDS switch (whether to enable LDS)
37 | - `--fds`: FDS switch (whether to enable FDS)
38 | - `--reweight`: cost-sensitive re-weighting scheme to use
39 | - `--retrain_fc`: whether to retrain regressor
40 | - `--loss`: training loss type
41 | - `--resume`: path to resume checkpoint (for both training and evaluation)
42 | - `--evaluate`: evaluate only flag
43 | - `--pretrained`: path to load backbone weights for regressor re-training (RRT)
44 |
45 | ## Getting Started
46 |
47 | #### Train a vanilla model
48 |
49 | ```bash
50 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_dir <path_to_data_dir> --reweight none
51 | ```
52 |
53 | Always specify `CUDA_VISIBLE_DEVICES` with the GPU IDs to use (by default, 4 GPUs) and `--data_dir` with your data directory when training a model, or hard-code your default data directory path in the code. We omit these arguments below for simplicity.
54 |
55 | #### Train a model using re-weighting
56 |
57 | To perform inverse re-weighting
58 |
59 | ```bash
60 | python train.py --reweight inverse
61 | ```
62 |
63 | To perform square-root inverse re-weighting
64 |
65 | ```bash
66 | python train.py --reweight sqrt_inv
67 | ```
68 |
69 | #### Train a model with different losses
70 |
71 | To use Focal-R loss
72 |
73 | ```bash
74 | python train.py --loss focal_l1
75 | ```
76 |
77 | To use Huber loss
78 |
79 | ```bash
80 | python train.py --loss huber
81 | ```
82 |
83 | #### Train a model using RRT
84 |
85 | ```bash
86 | python train.py [...retrained model arguments...] --retrain_fc --pretrained <path_to_checkpoint>
87 | ```
88 |
89 | #### Train a model using LDS
90 |
91 | To use Gaussian kernel (kernel size: 5, sigma: 2)
92 |
93 | ```bash
94 | python train.py --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2
95 | ```
96 |
97 | #### Train a model using FDS
98 |
99 | To use Gaussian kernel (kernel size: 5, sigma: 2)
100 |
101 | ```bash
102 | python train.py --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
103 | ```
104 | #### Train a model using LDS + FDS
105 | ```bash
106 | python train.py --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
107 | ```
108 |
109 | #### Evaluate a trained checkpoint
110 |
111 | ```bash
112 | python train.py [...evaluation model arguments...] --evaluate --resume <path_to_checkpoint>
113 | ```
114 |
115 | ## Reproduced Benchmarks and Model Zoo
116 |
117 | We provide reproduced results on IMDB-WIKI-DIR below (base method `SQINV`, metric `MAE`).
118 | Note that some models could give **better** results than the reported numbers in the paper.
119 |
120 |
121 | | Model | Overall | Many-Shot | Medium-Shot | Few-Shot | Download |
122 | | :-------: | :-----: | :-------: | :---------: | :------: | :------: |
123 | | LDS | 7.87 | 7.31 | 12.45 | 22.60 | [model](https://drive.google.com/file/d/1HnGw1gs6UAlvbol4EulHX_Kqx_pwJZ70/view?usp=sharing) |
124 | | FDS | 7.66 | 7.06 | 12.60 | 22.37 | [model](https://drive.google.com/file/d/1H7_dDMn83-paFrcrEmOiDLZmoham4js9/view?usp=sharing) |
125 | | LDS + FDS | 7.68 | 7.07 | 12.79 | 21.85 | [model](https://drive.google.com/file/d/1C_YxpTW-rhCRIF4wnFShojp5ydAjFmHo/view?usp=sharing) |
--------------------------------------------------------------------------------
/imdb-wiki-dir/data/create_imdb_wiki.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | from scipy.io import loadmat
7 | from datetime import datetime
8 |
9 |
10 | def calc_age(taken, dob):
11 | birth = datetime.fromordinal(max(int(dob) - 366, 1))
12 | # assume the photo was taken in the middle of the year
13 | if birth.month < 7:
14 | return taken - birth.year
15 | else:
16 | return taken - birth.year - 1
17 |
18 |
19 | def get_meta(mat_path, db):
20 | meta = loadmat(mat_path)
21 | full_path = meta[db][0, 0]["full_path"][0]
22 | dob = meta[db][0, 0]["dob"][0] # date
23 | gender = meta[db][0, 0]["gender"][0]
24 | photo_taken = meta[db][0, 0]["photo_taken"][0] # year
25 | face_score = meta[db][0, 0]["face_score"][0]
26 | second_face_score = meta[db][0, 0]["second_face_score"][0]
27 | age = [calc_age(photo_taken[i], dob[i]) for i in range(len(dob))]
28 |
29 | return full_path, dob, gender, photo_taken, face_score, second_face_score, age
30 |
31 |
32 | def load_data(mat_path):
33 | d = loadmat(mat_path)
34 | return d["image"], d["gender"][0], d["age"][0], d["db"][0], d["img_size"][0, 0], d["min_score"][0, 0]
35 |
36 |
37 | def combine_dataset(path='meta'):
38 | args = get_args()
39 | data_imdb = pd.read_csv(os.path.join(args.data_path, path, "imdb.csv"))
40 | data_wiki = pd.read_csv(os.path.join(args.data_path, path, "wiki.csv"))
41 | data_imdb['path'] = data_imdb['path'].apply(lambda x: f"imdb_crop/{x}")
42 | data_wiki['path'] = data_wiki['path'].apply(lambda x: f"wiki_crop/{x}")
43 | df = pd.concat((data_imdb, data_wiki))
44 | output_path = os.path.join(args.data_path, path, "imdb_wiki.csv")
45 | df.to_csv(str(output_path), index=False)
46 |
47 |
48 | def get_args():
49 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
50 | parser.add_argument("--data_path", type=str, default="./data")
51 | parser.add_argument("--min_score", type=float, default=1., help="minimum face score")
52 | args = parser.parse_args()
53 | return args
54 |
55 |
56 | def create(db):
57 | args = get_args()
58 | mat_path = os.path.join(args.data_path, f"{db}_crop", f"{db}.mat")
59 | full_path, dob, gender, photo_taken, face_score, second_face_score, age = get_meta(mat_path, db)
60 |
61 | ages, img_paths = [], []
62 |
63 | for i in tqdm(range(len(face_score))):
64 | if face_score[i] < args.min_score:
65 | continue
66 |
67 | if (not np.isnan(second_face_score[i])) and second_face_score[i] > 0.0:
68 | continue
69 |
70 | if not (0 <= age[i] <= 200):
71 | continue
72 |
73 | ages.append(age[i])
74 | img_paths.append(full_path[i][0])
75 |
76 | outputs = dict(age=ages, path=img_paths)
77 | output_dir = os.path.join(args.data_path, "meta")
78 | os.makedirs(output_dir, exist_ok=True)
79 | output_path = os.path.join(output_dir, f"{db}.csv")
80 | df = pd.DataFrame(data=outputs)
81 | df.to_csv(str(output_path), index=False)
82 |
83 |
84 | if __name__ == '__main__':
85 | create("imdb")
86 | create("wiki")
87 | combine_dataset()
88 |
--------------------------------------------------------------------------------
/imdb-wiki-dir/data/preprocess_imdb_wiki.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | BASE_PATH = './data'
7 |
8 |
9 | def visualize_dataset(db="imdb_wiki"):
10 | file_path = join(BASE_PATH, "meta", f"{db}.csv")
11 | data = pd.read_csv(file_path)
12 | _, ax = plt.subplots(figsize=(6, 3), sharex='all', sharey='all')
13 | ax.hist(data['age'], range(max(data['age'])))
14 | ax.set_xlim([0, 120])
15 | plt.title(f"{db.upper()} (total: {data.shape[0]})")
16 | plt.tight_layout()
17 | plt.show()
18 |
19 |
20 | def make_balanced_testset(db="imdb_wiki", max_size=150, seed=666, verbose=True, vis=True, save=False):
21 | file_path = join(BASE_PATH, "meta", f"{db}.csv")
22 | df = pd.read_csv(file_path)
23 | df['age'] = df.age.astype(int)
24 | val_set, test_set = [], []
25 | import random
26 | random.seed(seed)
27 | for value in range(121):
28 | curr_df = df[df['age'] == value]
29 | curr_data = curr_df['path'].values
30 | random.shuffle(curr_data)
31 | curr_size = min(len(curr_data) // 3, max_size)
32 | val_set += list(curr_data[:curr_size])
33 | test_set += list(curr_data[curr_size:curr_size * 2])
34 | if verbose:
35 | print(f"Val: {len(val_set)}\nTest: {len(test_set)}")
36 | assert len(set(val_set).intersection(set(test_set))) == 0
37 | combined_set = dict(zip(val_set, ['val' for _ in range(len(val_set))]))
38 | combined_set.update(dict(zip(test_set, ['test' for _ in range(len(test_set))])))
39 | df['split'] = df['path'].map(combined_set)
40 | df['split'].fillna('train', inplace=True)
41 | if verbose:
42 | print(df)
43 | if save:
44 | df.to_csv(str(join(BASE_PATH, f"{db}.csv")), index=False)
45 | if vis:
46 | _, ax = plt.subplots(3, figsize=(6, 9), sharex='all')
47 | df_train = df[df['split'] == 'train']
48 | # df_train = df_train[(df_train['age'] <= 20) | (df_train['age'] > 50)]
49 | ax[0].hist(df_train['age'], range(max(df['age'])))
50 | ax[0].set_title(f"[{db.upper()}] train: {df_train.shape[0]}")
51 | ax[1].hist(df[df['split'] == 'val']['age'], range(max(df['age'])))
52 | ax[1].set_title(f"[{db.upper()}] val: {df[df['split'] == 'val'].shape[0]}")
53 | ax[2].hist(df[df['split'] == 'test']['age'], range(max(df['age'])))
54 | ax[2].set_title(f"[{db.upper()}] test: {df[df['split'] == 'test'].shape[0]}")
55 | ax[0].set_xlim([0, 120])
56 | plt.tight_layout()
57 | plt.show()
58 |
59 |
60 | if __name__ == '__main__':
61 | make_balanced_testset()
62 | visualize_dataset(db="imdb_wiki")
63 |
--------------------------------------------------------------------------------
/imdb-wiki-dir/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | from PIL import Image
5 | from scipy.ndimage import convolve1d
6 | from torch.utils import data
7 | import torchvision.transforms as transforms
8 |
9 | from utils import get_lds_kernel_window
10 |
11 | print = logging.info
12 |
13 |
14 | class IMDBWIKI(data.Dataset):
15 | def __init__(self, df, data_dir, img_size, split='train', reweight='none',
16 | lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
17 | self.df = df
18 | self.data_dir = data_dir
19 | self.img_size = img_size
20 | self.split = split
21 |
22 | self.weights = self._prepare_weights(reweight=reweight, lds=lds, lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma)
23 |
24 | def __len__(self):
25 | return len(self.df)
26 |
27 | def __getitem__(self, index):
28 | index = index % len(self.df)
29 | row = self.df.iloc[index]
30 | img = Image.open(os.path.join(self.data_dir, row['path'])).convert('RGB')
31 | transform = self.get_transform()
32 | img = transform(img)
33 | label = np.asarray([row['age']]).astype('float32')
34 | weight = np.asarray([self.weights[index]]).astype('float32') if self.weights is not None else np.asarray([np.float32(1.)])
35 |
36 | return img, label, weight
37 |
38 | def get_transform(self):
39 | if self.split == 'train':
40 | transform = transforms.Compose([
41 | transforms.Resize((self.img_size, self.img_size)),
42 | transforms.RandomCrop(self.img_size, padding=16),
43 | transforms.RandomHorizontalFlip(),
44 | transforms.ToTensor(),
45 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
46 | ])
47 | else:
48 | transform = transforms.Compose([
49 | transforms.Resize((self.img_size, self.img_size)),
50 | transforms.ToTensor(),
51 | transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
52 | ])
53 | return transform
54 |
55 | def _prepare_weights(self, reweight, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
56 | assert reweight in {'none', 'inverse', 'sqrt_inv'}
57 | assert not lds or reweight != 'none', \
58 | "Set reweight to 'sqrt_inv' (default) or 'inverse' when using LDS"
59 |
60 | value_dict = {x: 0 for x in range(max_target)}
61 | labels = self.df['age'].values
62 | for label in labels:
63 | value_dict[min(max_target - 1, int(label))] += 1
64 | if reweight == 'sqrt_inv':
65 | value_dict = {k: np.sqrt(v) for k, v in value_dict.items()}
66 | elif reweight == 'inverse':
67 | value_dict = {k: np.clip(v, 5, 1000) for k, v in value_dict.items()} # clip weights for inverse re-weight
68 | num_per_label = [value_dict[min(max_target - 1, int(label))] for label in labels]
69 | if not len(num_per_label) or reweight == 'none':
70 | return None
71 | print(f"Using re-weighting: [{reweight.upper()}]")
72 |
73 | if lds:
74 | lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
75 | print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
76 | smoothed_value = convolve1d(
77 | np.asarray([v for _, v in value_dict.items()]), weights=lds_kernel_window, mode='constant')
78 | num_per_label = [smoothed_value[min(max_target - 1, int(label))] for label in labels]
79 |
80 | weights = [np.float32(1 / x) for x in num_per_label]
81 | scaling = len(weights) / np.sum(weights)
82 | weights = [scaling * x for x in weights]
83 | return weights
84 |
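The final lines of `_prepare_weights` rescale the raw inverse-frequency weights so that they average to 1 over the training set, which keeps the overall loss magnitude (and thus the effective learning rate) comparable to unweighted training. A toy sketch of that normalization:

```python
import numpy as np

num_per_label = np.array([100., 100., 10., 1.])  # per-sample (smoothed) bin counts
weights = 1.0 / num_per_label                    # raw inverse-frequency weights
weights *= len(weights) / weights.sum()          # rescale so the mean weight is 1
```

Rare labels still get proportionally larger weights; only the global scale changes.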
--------------------------------------------------------------------------------
/imdb-wiki-dir/download_imdb_wiki.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wget
3 |
4 | print("Downloading IMDB faces...")
5 | imdb_file = "./data/imdb_crop.tar"
6 | wget.download("https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/imdb_crop.tar", out=imdb_file)
7 | print("Downloading WIKI faces...")
8 | wiki_file = "./data/wiki_crop.tar"
9 | wget.download("https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar", out=wiki_file)
10 | print("Extracting IMDB faces...")
11 | os.system(f"tar -xvf {imdb_file} -C ./data")
12 | print("Extracting WIKI faces...")
13 | os.system(f"tar -xvf {wiki_file} -C ./data")
14 | os.remove(imdb_file)
15 | os.remove(wiki_file)
16 | print("\nCompleted!")
--------------------------------------------------------------------------------
/imdb-wiki-dir/fds.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | from scipy.ndimage import gaussian_filter1d
4 | from scipy.signal.windows import triang
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from utils import calibrate_mean_var
10 |
11 | print = logging.info
12 |
13 |
14 | class FDS(nn.Module):
15 |
16 | def __init__(self, feature_dim, bucket_num=100, bucket_start=0, start_update=0, start_smooth=1,
17 | kernel='gaussian', ks=5, sigma=2, momentum=0.9):
18 | super(FDS, self).__init__()
19 | self.feature_dim = feature_dim
20 | self.bucket_num = bucket_num
21 | self.bucket_start = bucket_start
22 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma)
23 | self.half_ks = (ks - 1) // 2
24 | self.momentum = momentum
25 | self.start_update = start_update
26 | self.start_smooth = start_smooth
27 |
28 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update))
29 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim))
30 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim))
31 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim))
32 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim))
33 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim))
34 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim))
35 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start))
36 |
37 | @staticmethod
38 | def _get_kernel_window(kernel, ks, sigma):
39 | assert kernel in ['gaussian', 'triang', 'laplace']
40 | half_ks = (ks - 1) // 2
41 | if kernel == 'gaussian':
42 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
43 | base_kernel = np.array(base_kernel, dtype=np.float32)
44 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma))
45 | elif kernel == 'triang':
46 | kernel_window = triang(ks) / sum(triang(ks))
47 | else:
48 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
49 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1)))
50 |
51 | print(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})')
52 | return torch.tensor(kernel_window, dtype=torch.float32).cuda()
53 |
54 | def _update_last_epoch_stats(self):
55 | self.running_mean_last_epoch = self.running_mean
56 | self.running_var_last_epoch = self.running_var
57 |
58 | self.smoothed_mean_last_epoch = F.conv1d(
59 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0),
60 | pad=(self.half_ks, self.half_ks), mode='reflect'),
61 | weight=self.kernel_window.view(1, 1, -1), padding=0
62 | ).permute(2, 1, 0).squeeze(1)
63 | self.smoothed_var_last_epoch = F.conv1d(
64 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0),
65 | pad=(self.half_ks, self.half_ks), mode='reflect'),
66 | weight=self.kernel_window.view(1, 1, -1), padding=0
67 | ).permute(2, 1, 0).squeeze(1)
68 |
69 | def reset(self):
70 | self.running_mean.zero_()
71 | self.running_var.fill_(1)
72 | self.running_mean_last_epoch.zero_()
73 | self.running_var_last_epoch.fill_(1)
74 | self.smoothed_mean_last_epoch.zero_()
75 | self.smoothed_var_last_epoch.fill_(1)
76 | self.num_samples_tracked.zero_()
77 |
78 | def update_last_epoch_stats(self, epoch):
79 | if epoch == self.epoch + 1:
80 | self.epoch += 1
81 | self._update_last_epoch_stats()
82 | print(f"Updated smoothed statistics on Epoch [{epoch}]!")
83 |
84 | def update_running_stats(self, features, labels, epoch):
85 | if epoch < self.epoch:
86 | return
87 |
88 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!"
89 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!"
90 |
91 | for label in torch.unique(labels):
92 | if label > self.bucket_num - 1 or label < self.bucket_start:
93 | continue
94 | elif label == self.bucket_start:
95 | curr_feats = features[labels <= label]
96 | elif label == self.bucket_num - 1:
97 | curr_feats = features[labels >= label]
98 | else:
99 | curr_feats = features[labels == label]
100 | curr_num_sample = curr_feats.size(0)
101 | curr_mean = torch.mean(curr_feats, 0)
102 | curr_var = torch.var(curr_feats, 0, unbiased=(curr_feats.size(0) != 1))
103 |
104 | self.num_samples_tracked[int(label - self.bucket_start)] += curr_num_sample
105 | factor = self.momentum if self.momentum is not None else \
106 | (1 - curr_num_sample / float(self.num_samples_tracked[int(label - self.bucket_start)]))
107 | factor = 0 if epoch == self.start_update else factor
108 | self.running_mean[int(label - self.bucket_start)] = \
109 | (1 - factor) * curr_mean + factor * self.running_mean[int(label - self.bucket_start)]
110 | self.running_var[int(label - self.bucket_start)] = \
111 | (1 - factor) * curr_var + factor * self.running_var[int(label - self.bucket_start)]
112 |
113 | print(f"Updated running statistics with Epoch [{epoch}] features!")
114 |
115 | def smooth(self, features, labels, epoch):
116 | if epoch < self.start_smooth:
117 | return features
118 |
119 | labels = labels.squeeze(1)
120 | for label in torch.unique(labels):
121 | if label > self.bucket_num - 1 or label < self.bucket_start:
122 | continue
123 | elif label == self.bucket_start:
124 | features[labels <= label] = calibrate_mean_var(
125 | features[labels <= label],
126 | self.running_mean_last_epoch[int(label - self.bucket_start)],
127 | self.running_var_last_epoch[int(label - self.bucket_start)],
128 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)],
129 | self.smoothed_var_last_epoch[int(label - self.bucket_start)])
130 | elif label == self.bucket_num - 1:
131 | features[labels >= label] = calibrate_mean_var(
132 | features[labels >= label],
133 | self.running_mean_last_epoch[int(label - self.bucket_start)],
134 | self.running_var_last_epoch[int(label - self.bucket_start)],
135 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)],
136 | self.smoothed_var_last_epoch[int(label - self.bucket_start)])
137 | else:
138 | features[labels == label] = calibrate_mean_var(
139 | features[labels == label],
140 | self.running_mean_last_epoch[int(label - self.bucket_start)],
141 | self.running_var_last_epoch[int(label - self.bucket_start)],
142 | self.smoothed_mean_last_epoch[int(label - self.bucket_start)],
143 | self.smoothed_var_last_epoch[int(label - self.bucket_start)])
144 | return features
145 |
--------------------------------------------------------------------------------
/imdb-wiki-dir/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def weighted_mse_loss(inputs, targets, weights=None):
6 | loss = (inputs - targets) ** 2
7 | if weights is not None:
8 | loss *= weights.expand_as(loss)
9 | loss = torch.mean(loss)
10 | return loss
11 |
12 |
13 | def weighted_l1_loss(inputs, targets, weights=None):
14 | loss = F.l1_loss(inputs, targets, reduction='none')
15 | if weights is not None:
16 | loss *= weights.expand_as(loss)
17 | loss = torch.mean(loss)
18 | return loss
19 |
20 |
21 | def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1):
22 | loss = (inputs - targets) ** 2
23 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
24 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
25 | if weights is not None:
26 | loss *= weights.expand_as(loss)
27 | loss = torch.mean(loss)
28 | return loss
29 |
30 |
31 | def weighted_focal_l1_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1):
32 | loss = F.l1_loss(inputs, targets, reduction='none')
33 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
34 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
35 | if weights is not None:
36 | loss *= weights.expand_as(loss)
37 | loss = torch.mean(loss)
38 | return loss
39 |
40 |
41 | def weighted_huber_loss(inputs, targets, weights=None, beta=1.):
42 | l1_loss = torch.abs(inputs - targets)
43 | cond = l1_loss < beta
44 | loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta)
45 | if weights is not None:
46 | loss *= weights.expand_as(loss)
47 | loss = torch.mean(loss)
48 | return loss
49 |
--------------------------------------------------------------------------------
/imdb-wiki-dir/resnet.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import torch.nn as nn
4 | from fds import FDS
5 |
6 | print = logging.info
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1):
10 | """3x3 convolution with padding"""
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.conv2 = conv3x3(planes, planes)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | def forward(self, x):
28 | residual = x
29 | out = self.conv1(x)
30 | out = self.bn1(out)
31 | out = self.relu(out)
32 | out = self.conv2(out)
33 | out = self.bn2(out)
34 | if self.downsample is not None:
35 | residual = self.downsample(x)
36 | out += residual
37 | out = self.relu(out)
38 | return out
39 |
40 |
41 | class Bottleneck(nn.Module):
42 | expansion = 4
43 |
44 | def __init__(self, inplanes, planes, stride=1, downsample=None):
45 | super(Bottleneck, self).__init__()
46 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
47 | self.bn1 = nn.BatchNorm2d(planes)
48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
49 | self.bn2 = nn.BatchNorm2d(planes)
50 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm2d(planes * 4)
52 | self.relu = nn.ReLU(inplace=True)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x):
57 | residual = x
58 | out = self.conv1(x)
59 | out = self.bn1(out)
60 | out = self.relu(out)
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 | out = self.relu(out)
64 | out = self.conv3(out)
65 | out = self.bn3(out)
66 | if self.downsample is not None:
67 | residual = self.downsample(x)
68 | out += residual
69 | out = self.relu(out)
70 | return out
71 |
72 |
73 | class ResNet(nn.Module):
74 |
75 | def __init__(self, block, layers, fds, bucket_num, bucket_start, start_update, start_smooth,
76 | kernel, ks, sigma, momentum, dropout=None):
77 | self.inplanes = 64
78 | super(ResNet, self).__init__()
79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
80 | self.bn1 = nn.BatchNorm2d(64)
81 | self.relu = nn.ReLU(inplace=True)
82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
83 | self.layer1 = self._make_layer(block, 64, layers[0])
84 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
85 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
86 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
87 | self.avgpool = nn.AvgPool2d(7, stride=1)
88 | self.linear = nn.Linear(512 * block.expansion, 1)
89 |
90 | if fds:
91 | self.FDS = FDS(
92 | feature_dim=512 * block.expansion, bucket_num=bucket_num, bucket_start=bucket_start,
93 | start_update=start_update, start_smooth=start_smooth, kernel=kernel, ks=ks, sigma=sigma, momentum=momentum
94 | )
95 | self.fds = fds
96 | self.start_smooth = start_smooth
97 |
98 | self.use_dropout = True if dropout else False
99 | if self.use_dropout:
100 | print(f'Using dropout: {dropout}')
101 | self.dropout = nn.Dropout(p=dropout)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | m.weight.data.normal_(0, math.sqrt(2. / n))
107 | elif isinstance(m, nn.BatchNorm2d):
108 | m.weight.data.fill_(1)
109 | m.bias.data.zero_()
110 |
111 | def _make_layer(self, block, planes, blocks, stride=1):
112 | downsample = None
113 | if stride != 1 or self.inplanes != planes * block.expansion:
114 | downsample = nn.Sequential(
115 | nn.Conv2d(self.inplanes, planes * block.expansion,
116 | kernel_size=1, stride=stride, bias=False),
117 | nn.BatchNorm2d(planes * block.expansion),
118 | )
119 | layers = []
120 | layers.append(block(self.inplanes, planes, stride, downsample))
121 | self.inplanes = planes * block.expansion
122 | for i in range(1, blocks):
123 | layers.append(block(self.inplanes, planes))
124 |
125 | return nn.Sequential(*layers)
126 |
127 | def forward(self, x, targets=None, epoch=None):
128 | x = self.conv1(x)
129 | x = self.bn1(x)
130 | x = self.relu(x)
131 | x = self.maxpool(x)
132 |
133 | x = self.layer1(x)
134 | x = self.layer2(x)
135 | x = self.layer3(x)
136 | x = self.layer4(x)
137 | x = self.avgpool(x)
138 | encoding = x.view(x.size(0), -1)
139 |
140 | encoding_s = encoding
141 |
142 | if self.training and self.fds:
143 | if epoch >= self.start_smooth:
144 | encoding_s = self.FDS.smooth(encoding_s, targets, epoch)
145 |
146 | if self.use_dropout:
147 | encoding_s = self.dropout(encoding_s)
148 | x = self.linear(encoding_s)
149 |
150 | if self.training and self.fds:
151 | return x, encoding
152 | else:
153 | return x
154 |
155 |
156 | def resnet50(**kwargs):
157 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
158 |
--------------------------------------------------------------------------------
/imdb-wiki-dir/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import torch
4 | import logging
5 | import numpy as np
6 | from scipy.ndimage import gaussian_filter1d
7 | from scipy.signal.windows import triang
8 |
9 |
10 | class AverageMeter(object):
11 | def __init__(self, name, fmt=':f'):
12 | self.name = name
13 | self.fmt = fmt
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 | def __str__(self):
29 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
30 | return fmtstr.format(**self.__dict__)
31 |
32 |
33 | class ProgressMeter(object):
34 | def __init__(self, num_batches, meters, prefix=""):
35 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
36 | self.meters = meters
37 | self.prefix = prefix
38 |
39 | def display(self, batch):
40 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
41 | entries += [str(meter) for meter in self.meters]
42 | logging.info('\t'.join(entries))
43 |
44 | @staticmethod
45 | def _get_batch_fmtstr(num_batches):
46 |         num_digits = len(str(num_batches))
47 | fmt = '{:' + str(num_digits) + 'd}'
48 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
49 |
50 |
51 | def query_yes_no(question):
52 |     """Ask a yes/no question via input() and return the answer as a bool."""
53 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
54 | prompt = " [Y/n] "
55 |
56 | while True:
57 | print(question + prompt, end=':')
58 | choice = input().lower()
59 | if choice == '':
60 | return valid['y']
61 | elif choice in valid:
62 | return valid[choice]
63 | else:
64 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")
65 |
66 |
67 | def prepare_folders(args):
68 | folders_util = [args.store_root, os.path.join(args.store_root, args.store_name)]
69 | if os.path.exists(folders_util[-1]) and not args.resume and not args.pretrained and not args.evaluate:
70 | if query_yes_no('overwrite previous folder: {} ?'.format(folders_util[-1])):
71 | shutil.rmtree(folders_util[-1])
72 | print(folders_util[-1] + ' removed.')
73 | else:
74 | raise RuntimeError('Output folder {} already exists'.format(folders_util[-1]))
75 | for folder in folders_util:
76 | if not os.path.exists(folder):
77 | print(f"===> Creating folder: {folder}")
78 | os.mkdir(folder)
79 |
80 |
81 | def adjust_learning_rate(optimizer, epoch, args):
82 | lr = args.lr
83 | for milestone in args.schedule:
84 | lr *= 0.1 if epoch >= milestone else 1.
85 | for param_group in optimizer.param_groups:
86 | param_group['lr'] = lr
87 |
88 |
89 | def save_checkpoint(args, state, is_best, prefix=''):
90 | filename = f"{args.store_root}/{args.store_name}/{prefix}ckpt.pth.tar"
91 | torch.save(state, filename)
92 | if is_best:
93 | logging.info("===> Saving current best checkpoint...")
94 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar'))
95 |
96 |
97 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.1, clip_max=10):
98 | if torch.sum(v1) < 1e-10:
99 | return matrix
100 | if (v1 == 0.).any():
101 | valid = (v1 != 0.)
102 | factor = torch.clamp(v2[valid] / v1[valid], clip_min, clip_max)
103 | matrix[:, valid] = (matrix[:, valid] - m1[valid]) * torch.sqrt(factor) + m2[valid]
104 | return matrix
105 |
106 | factor = torch.clamp(v2 / v1, clip_min, clip_max)
107 | return (matrix - m1) * torch.sqrt(factor) + m2
108 |
109 |
110 | def get_lds_kernel_window(kernel, ks, sigma):
111 | assert kernel in ['gaussian', 'triang', 'laplace']
112 | half_ks = (ks - 1) // 2
113 | if kernel == 'gaussian':
114 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
115 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
116 | elif kernel == 'triang':
117 | kernel_window = triang(ks)
118 | else:
119 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
120 |         kernel_window = np.array(list(map(laplace, np.arange(-half_ks, half_ks + 1)))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))
121 |
122 | return kernel_window
123 |
--------------------------------------------------------------------------------
/nyud2-dir/README.md:
--------------------------------------------------------------------------------
1 | # NYUD2-DIR
2 | ## Installation
3 |
4 | #### Prerequisites
5 |
6 | 1. Download and extract NYU v2 dataset to folder `./data` using
7 |
8 | ```bash
9 | python download_nyud2.py
10 | ```
11 |
12 | 2. __(Optional)__ We provide the required meta files `nyu2_train_FDS_subset.csv` (for efficient FDS feature statistics computation) and `test_balanced_mask.npy` (the balanced test set mask) in folder `./data`. To reproduce the results in the paper, use these two files directly. If you want to try different FDS computation subsets and balanced test set masks, run
13 |
14 | ```bash
15 | python preprocess_nyud2.py
16 | ```
17 |
18 | #### Dependencies
19 |
20 | - PyTorch (>= 1.2, tested on 1.6)
21 | - numpy, pandas, scipy, tqdm, matplotlib, PIL, gdown, tensorboardX
22 |
23 | ## Code Overview
24 |
25 | #### Main Files
26 |
27 | - `train.py`: main training script
28 | - `test.py`: main evaluation script
29 | - `preprocess_nyud2.py`: create meta files `nyu2_train_FDS_subset.csv` and `test_balanced_mask.npy` for NYUD2-DIR
30 |
31 | #### Main Arguments
32 |
33 | For `train.py`:
34 |
35 | - `--data_dir`: data directory to place data and meta file
36 | - `--lds`: LDS switch (whether to enable LDS)
37 | - `--fds`: FDS switch (whether to enable FDS)
38 | - `--reweight`: cost-sensitive re-weighting scheme to use
39 | - `--resume`: whether to resume training (only for training)
40 | - `--retrain_fc`: whether to retrain regressor
41 | - `--pretrained`: path to load backbone weights for regressor re-training (RRT)
42 |
43 | For `test.py`:
44 |
45 | - `--eval_model`: path to resume checkpoint (only for evaluation)
46 |
47 | ## Getting Started
48 |
49 | #### Train a vanilla model
50 |
51 | ```bash
52 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_dir --reweight none
53 | ```
54 |
55 | Always specify `CUDA_VISIBLE_DEVICES` for the GPU IDs to be used (by default, 4 GPUs) and `--data_dir` when training a model, or hard-code your default data directory path in the code. We omit these arguments below for simplicity.
56 |
57 |
58 | #### Train a model using re-weighting
59 |
60 | To perform inverse re-weighting
61 |
62 | ```bash
63 | python train.py --reweight inverse
64 | ```
65 |
66 | To perform square-root inverse re-weighting
67 |
68 | ```bash
69 | python train.py --reweight sqrt_inv
70 | ```
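Both schemes invert the per-bucket label counts and rescale the weights so that the weighted sample count matches the raw sample count. A minimal NumPy sketch of this normalization (mirroring `depthDataset._get_bucket_weights` in `loaddata.py`; `bucket_weights` is a hypothetical helper name, not part of the repo):

```python
import numpy as np

def bucket_weights(counts, scheme="inverse"):
    # Invert raw counts ("inverse") or their square roots ("sqrt_inv"),
    # then rescale so sum(counts * weights) == sum(counts), as in loaddata.py.
    counts = np.asarray(counts, dtype=np.float64)
    values = np.sqrt(counts) if scheme == "sqrt_inv" else counts
    scaling = counts.sum() / (counts / values).sum()
    return scaling / values

w_inv = bucket_weights([100, 10, 1], "inverse")    # rarest bucket weighted 100x
w_sqrt = bucket_weights([100, 10, 1], "sqrt_inv")  # milder: rarest weighted 10x
```

Square-root inverse is a softer variant: it keeps the same ordering of weights but compresses the ratio between frequent and rare buckets.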
71 |
72 | #### Train a model using RRT
73 |
74 | ```bash
75 | python train.py [...retrained model arguments...] --retrain_fc --pretrained
76 | ```
77 |
78 | #### Train a model using LDS
79 |
80 | To use Gaussian kernel (kernel size: 5, sigma: 2)
81 |
82 | ```bash
83 | python train.py --reweight inverse --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2
84 | ```
85 |
86 | #### Train a model using FDS
87 |
88 | To use Gaussian kernel (kernel size: 5, sigma: 2)
89 |
90 | ```bash
91 | python train.py --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
92 | ```
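At smoothing time, FDS re-standardizes each feature from the running statistics of its label bucket to the kernel-smoothed statistics. A minimal NumPy sketch of this calibration step (the repo's `calibrate_mean_var` in `util.py` implements the same transform in PyTorch, with an extra guard for zero-variance dimensions):

```python
import numpy as np

def calibrate_mean_var(feats, m1, v1, m2, v2, clip_min=0.1, clip_max=10.0):
    # Whiten features with the running stats (m1, v1), then re-color them
    # with the smoothed stats (m2, v2); clamp the variance ratio for stability.
    factor = np.clip(v2 / v1, clip_min, clip_max)
    return (feats - m1) * np.sqrt(factor) + m2

rng = np.random.default_rng(0)
feats = rng.normal(loc=2.0, scale=2.0, size=(1000, 4))
# Calibrate toward zero mean, unit variance (stand-ins for smoothed stats)
out = calibrate_mean_var(feats, feats.mean(0), feats.var(0),
                         np.zeros(4), np.ones(4))
```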
93 |
94 | #### Train a model using LDS + FDS
95 |
96 | ```bash
97 | python train.py --reweight inverse --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
98 | ```
99 |
100 | #### Evaluate a trained checkpoint
101 |
102 | ```bash
103 | python test.py --data_dir --eval_model
104 | ```
105 |
106 | ## Reproduced Benchmarks and Model Zoo
107 |
108 | We provide below reproduced results on NYUD2-DIR (base method `Vanilla`, metric `RMSE`).
109 | Note that some models could give **better** results than the reported numbers in the paper.
110 |
111 | | Model | Overall | Many-Shot | Medium-Shot | Few-Shot | Download |
112 | | :-------: | :-----: | :-------: | :---------: | :------: | :------: |
113 | | LDS | 1.387 | 0.671 | 0.913 | 1.954 | [model](https://drive.google.com/file/d/1RgQx-nreiJ-chH0887xCy7gxah-zrrEO/view?usp=sharing) |
114 | | FDS | 1.442 | 0.615 | 0.940 | 2.059 | [model](https://drive.google.com/file/d/1FEKzBzMPaGubmv9iK4BP6LJng44Mhc7s/view?usp=sharing) |
115 | | LDS + FDS | 1.301 | 0.731 | 0.832 | 1.799 | [model](https://drive.google.com/file/d/1QlZJOPYSyRRFqa1Q-y7-JlTDABiQZJUF/view?usp=sharing) |
--------------------------------------------------------------------------------
/nyud2-dir/data/test_balanced_mask.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/nyud2-dir/data/test_balanced_mask.npy
--------------------------------------------------------------------------------
/nyud2-dir/download_nyud2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gdown
3 | import zipfile
4 |
5 | print("Downloading and extracting NYU v2 dataset to folder './data'...")
6 | data_file = "./data.zip"
7 | gdown.download("https://drive.google.com/uc?id=1WoOZOBpOWfmwe7bknWS5PMUCLBPFKTOw")
8 | print('Extracting...')
9 | with zipfile.ZipFile(data_file) as zip_ref:
10 | zip_ref.extractall('.')
11 | os.remove(data_file)
12 | print("Completed!")
--------------------------------------------------------------------------------
/nyud2-dir/loaddata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import pandas as pd
4 | from torch.utils.data import Dataset, DataLoader
5 | from torchvision import transforms
6 | from nyu_transform import *
7 | from scipy.ndimage import convolve1d
8 | from util import get_lds_kernel_window
9 |
10 | # for data loading efficiency
11 | TRAIN_BUCKET_NUM = [0, 0, 0, 0, 0, 0, 0, 25848691, 24732940, 53324326, 69112955, 54455432, 95637682, 71403954, 117244217,
12 | 84813007, 126524456, 84486706, 133130272, 95464874, 146051415, 146133612, 96561379, 138366677, 89680276,
13 | 127689043, 81608990, 119121178, 74360607, 106839384, 97595765, 66718296, 90661239, 53103021, 83340912,
14 | 51365604, 71262770, 42243737, 65860580, 38415940, 53647559, 54038467, 28335524, 41485143, 32106001,
15 | 35936734, 23966211, 32018765, 19297203, 31503743, 21681574, 16363187, 25743420, 12769509, 17675327,
16 | 13147819, 15798560, 9547180, 14933200, 9663019, 12887283, 11803562, 7656609, 11515700, 7756306, 9046228,
17 | 5114894, 8653419, 6859433, 8001904, 6430700, 3305839, 6318461, 3486268, 5621065, 4030498, 3839488, 3220208,
18 | 4483027, 2555777, 4685983, 3145082, 2951048, 2762369, 2367581, 2546089, 2343867, 2481579, 1722140, 3018892,
19 | 2325197, 1952354, 2047038, 1858707, 2052729, 1348558, 2487278, 1314198, 3338550, 1132666]
20 |
21 | class depthDataset(Dataset):
22 | def __init__(self, data_dir, csv_file, mask_file=None, transform=None, args=None):
23 | self.data_dir = data_dir
24 | self.frame = pd.read_csv(csv_file, header=None)
25 | self.mask = torch.tensor(np.load(mask_file), dtype=torch.bool) if mask_file is not None else None
26 | self.transform = transform
27 | self.bucket_weights = self._get_bucket_weights(args) if args is not None else None
28 |
29 | def _get_bucket_weights(self, args):
30 | assert args.reweight in {'none', 'inverse', 'sqrt_inv'}
31 | assert args.reweight != 'none' if args.lds else True, "Set reweight to \'sqrt_inv\' or \'inverse\' (default) when using LDS"
32 | if args.reweight == 'none':
33 | return None
34 | logging.info(f"Using re-weighting: [{args.reweight.upper()}]")
35 |
36 | if args.lds:
37 | value_lst = TRAIN_BUCKET_NUM[args.bucket_start:]
38 | lds_kernel_window = get_lds_kernel_window(args.lds_kernel, args.lds_ks, args.lds_sigma)
39 | logging.info(f'Using LDS: [{args.lds_kernel.upper()}] ({args.lds_ks}/{args.lds_sigma})')
40 | if args.reweight == 'sqrt_inv':
41 | value_lst = np.sqrt(value_lst)
42 | smoothed_value = convolve1d(np.asarray(value_lst), weights=lds_kernel_window, mode='reflect')
43 | smoothed_value = [smoothed_value[0]] * args.bucket_start + list(smoothed_value)
44 | scaling = np.sum(TRAIN_BUCKET_NUM) / np.sum(np.array(TRAIN_BUCKET_NUM) / np.array(smoothed_value))
45 | bucket_weights = [np.float32(scaling / smoothed_value[bucket]) for bucket in range(args.bucket_num)]
46 | else:
47 | value_lst = [TRAIN_BUCKET_NUM[args.bucket_start]] * args.bucket_start + TRAIN_BUCKET_NUM[args.bucket_start:]
48 | if args.reweight == 'sqrt_inv':
49 | value_lst = np.sqrt(value_lst)
50 | scaling = np.sum(TRAIN_BUCKET_NUM) / np.sum(np.array(TRAIN_BUCKET_NUM) / np.array(value_lst))
51 | bucket_weights = [np.float32(scaling / value_lst[bucket]) for bucket in range(args.bucket_num)]
52 |
53 | return bucket_weights
54 |
55 | def get_bin_idx(self, x):
56 | return min(int(x * np.float32(10)), 99)
57 |
58 | def _get_weights(self, depth):
59 | sp = depth.shape
60 | if self.bucket_weights is not None:
61 | depth = depth.view(-1).cpu().numpy()
62 | assert depth.dtype == np.float32
63 | weights = np.array(list(map(lambda v: self.bucket_weights[self.get_bin_idx(v)], depth)))
64 | weights = torch.tensor(weights, dtype=torch.float32).view(*sp)
65 | else:
66 | weights = torch.tensor([np.float32(1.)], dtype=torch.float32).repeat(*sp)
67 | return weights
68 |
69 | def __getitem__(self, idx):
70 | image_name = self.frame.iloc[idx, 0]
71 | depth_name = self.frame.iloc[idx, 1]
72 |
73 | image_name = os.path.join(self.data_dir, '/'.join(image_name.split('/')[1:]))
74 | depth_name = os.path.join(self.data_dir, '/'.join(depth_name.split('/')[1:]))
75 |
76 | image = Image.open(image_name)
77 | depth = Image.open(depth_name)
78 |
79 | sample = {'image': image, 'depth': depth}
80 |
81 | if self.transform:
82 | sample = self.transform(sample)
83 |
84 | sample['weight'] = self._get_weights(sample['depth'])
85 | sample['idx'] = idx
86 |
87 | if self.mask is not None:
88 | sample['mask'] = self.mask[idx].unsqueeze(0)
89 |
90 | return sample
91 |
92 | def __len__(self):
93 | return len(self.frame)
94 |
95 |
96 | def getTrainingData(args, batch_size=64):
97 | __imagenet_pca = {
98 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
99 | 'eigvec': torch.Tensor([
100 | [-0.5675, 0.7192, 0.4009],
101 | [-0.5808, -0.0045, -0.8140],
102 | [-0.5836, -0.6948, 0.4203],
103 | ])
104 | }
105 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
106 | 'std': [0.229, 0.224, 0.225]}
107 |
108 | transformed_training = depthDataset(data_dir=args.data_dir,
109 | csv_file=os.path.join(args.data_dir, 'nyu2_train.csv'),
110 | transform=transforms.Compose([
111 | Scale(240),
112 | RandomHorizontalFlip(),
113 | RandomRotate(5),
114 | CenterCrop([304, 228], [152, 114]),
115 | ToTensor(),
116 | Lighting(0.1, __imagenet_pca[
117 | 'eigval'], __imagenet_pca['eigvec']),
118 | ColorJitter(
119 | brightness=0.4,
120 | contrast=0.4,
121 | saturation=0.4,
122 | ),
123 | Normalize(__imagenet_stats['mean'],
124 | __imagenet_stats['std'])
125 | ]), args=args)
126 |
127 | dataloader_training = DataLoader(transformed_training, batch_size,
128 | shuffle=True, num_workers=8, pin_memory=False)
129 |
130 | return dataloader_training
131 |
132 | def getTrainingFDSData(args, batch_size=64):
133 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
134 | 'std': [0.229, 0.224, 0.225]}
135 |
136 | transformed_training = depthDataset(data_dir=args.data_dir,
137 | csv_file=os.path.join(args.data_dir, 'nyu2_train_FDS_subset.csv'),
138 | transform=transforms.Compose([
139 | Scale(240),
140 | CenterCrop([304, 228], [152, 114]),
141 | ToTensor(),
142 | Normalize(__imagenet_stats['mean'],
143 | __imagenet_stats['std'])
144 | ]))
145 |
146 | dataloader_training = DataLoader(transformed_training, batch_size,
147 | shuffle=False, num_workers=8, pin_memory=False)
148 | return dataloader_training
149 |
150 |
151 | def getTestingData(args, batch_size=64):
152 |
153 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
154 | 'std': [0.229, 0.224, 0.225]}
155 |
156 | transformed_testing = depthDataset(data_dir=args.data_dir,
157 | csv_file=os.path.join(args.data_dir, 'nyu2_test.csv'),
158 | mask_file=os.path.join(args.data_dir, 'test_balanced_mask.npy'),
159 | transform=transforms.Compose([
160 | Scale(240),
161 | CenterCrop([304, 228], [304, 228]),
162 | ToTensor(is_test=True),
163 | Normalize(__imagenet_stats['mean'],
164 | __imagenet_stats['std'])
165 | ]))
166 |
167 | dataloader_testing = DataLoader(transformed_testing, batch_size,
168 | shuffle=False, num_workers=0, pin_memory=False)
169 |
170 | return dataloader_testing
171 |
--------------------------------------------------------------------------------
/nyud2-dir/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/nyud2-dir/models/__init__.py
--------------------------------------------------------------------------------
/nyud2-dir/models/fds.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from scipy.ndimage import gaussian_filter1d
7 | from scipy.signal.windows import triang
8 | from util import calibrate_mean_var
9 |
10 |
11 | class FDS(nn.Module):
12 |
13 | def __init__(self, feature_dim, bucket_num=100, bucket_start=7, start_update=0, start_smooth=1,
14 | kernel='gaussian', ks=5, sigma=2, momentum=0.9):
15 | super(FDS, self).__init__()
16 | self.feature_dim = feature_dim
17 | self.bucket_num = bucket_num
18 | self.bucket_start = bucket_start
19 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma)
20 | self.half_ks = (ks - 1) // 2
21 | self.momentum = momentum
22 | self.start_update = start_update
23 | self.start_smooth = start_smooth
24 |
25 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update))
26 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim))
27 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim))
28 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim))
29 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim))
30 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim))
31 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim))
32 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start))
33 |
34 | @staticmethod
35 | def _get_kernel_window(kernel, ks, sigma):
36 | assert kernel in ['gaussian', 'triang', 'laplace']
37 | half_ks = (ks - 1) // 2
38 | if kernel == 'gaussian':
39 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
40 | base_kernel = np.array(base_kernel, dtype=np.float32)
41 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma))
42 | elif kernel == 'triang':
43 | kernel_window = triang(ks) / sum(triang(ks))
44 | else:
45 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
46 |             kernel_window = np.array(list(map(laplace, np.arange(-half_ks, half_ks + 1)))) / sum(map(laplace, np.arange(-half_ks, half_ks + 1)))
47 |
48 | logging.info(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})')
49 | return torch.tensor(kernel_window, dtype=torch.float32).cuda()
50 |
51 | def _get_bucket_idx(self, label):
52 | label = np.float32(label.cpu())
53 | return max(min(int(label * np.float32(10)), self.bucket_num - 1), self.bucket_start)
54 |
55 | def _update_last_epoch_stats(self):
56 | self.running_mean_last_epoch = self.running_mean
57 | self.running_var_last_epoch = self.running_var
58 |
59 | self.smoothed_mean_last_epoch = F.conv1d(
60 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0),
61 | pad=(self.half_ks, self.half_ks), mode='reflect'),
62 | weight=self.kernel_window.view(1, 1, -1), padding=0
63 | ).permute(2, 1, 0).squeeze(1)
64 | self.smoothed_var_last_epoch = F.conv1d(
65 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0),
66 | pad=(self.half_ks, self.half_ks), mode='reflect'),
67 | weight=self.kernel_window.view(1, 1, -1), padding=0
68 | ).permute(2, 1, 0).squeeze(1)
69 |
70 | assert self.smoothed_mean_last_epoch.shape == self.running_mean_last_epoch.shape, \
71 | "Smoothed shape is not aligned with running shape!"
72 |
73 | def reset(self):
74 | self.running_mean.zero_()
75 | self.running_var.fill_(1)
76 | self.running_mean_last_epoch.zero_()
77 | self.running_var_last_epoch.fill_(1)
78 | self.smoothed_mean_last_epoch.zero_()
79 | self.smoothed_var_last_epoch.fill_(1)
80 | self.num_samples_tracked.zero_()
81 |
82 | def update_last_epoch_stats(self, epoch):
83 | if epoch == self.epoch + 1:
84 | self.epoch += 1
85 | self._update_last_epoch_stats()
86 | logging.info(f"Updated smoothed statistics of last epoch on Epoch [{epoch}]!")
87 |
88 | def _running_stats_to_device(self, device):
89 | if device == 'cpu':
90 | self.num_samples_tracked = self.num_samples_tracked.cpu()
91 | self.running_mean = self.running_mean.cpu()
92 | self.running_var = self.running_var.cpu()
93 | else:
94 | self.num_samples_tracked = self.num_samples_tracked.cuda()
95 | self.running_mean = self.running_mean.cuda()
96 | self.running_var = self.running_var.cuda()
97 |
98 | def update_running_stats(self, features, labels, epoch):
99 | if epoch < self.epoch:
100 | return
101 |
102 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!"
103 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!"
104 |
105 | self._running_stats_to_device('cpu')
106 |
107 | labels = labels.squeeze(1).view(-1)
108 | features = features.permute(0, 2, 3, 1).contiguous().view(-1, self.feature_dim)
109 |
110 | buckets = np.array([self._get_bucket_idx(label) for label in labels])
111 | for bucket in np.unique(buckets):
112 |             curr_feats = features[torch.tensor((buckets == bucket).astype(bool))]
113 | curr_num_sample = curr_feats.size(0)
114 | curr_mean = torch.mean(curr_feats, 0)
115 | curr_var = torch.var(curr_feats, 0, unbiased=True if curr_feats.size(0) != 1 else False)
116 |
117 | self.num_samples_tracked[bucket - self.bucket_start] += curr_num_sample
118 | factor = self.momentum if self.momentum is not None else \
119 | (1 - curr_num_sample / float(self.num_samples_tracked[bucket - self.bucket_start]))
120 | factor = 0 if epoch == self.start_update else factor
121 | self.running_mean[bucket - self.bucket_start] = \
122 | (1 - factor) * curr_mean + factor * self.running_mean[bucket - self.bucket_start]
123 | self.running_var[bucket - self.bucket_start] = \
124 | (1 - factor) * curr_var + factor * self.running_var[bucket - self.bucket_start]
125 |
126 | self._running_stats_to_device('cuda')
127 | logging.info(f"Updated running statistics with Epoch [{epoch}] features!")
128 |
129 | def smooth(self, features, labels, epoch):
130 | if epoch < self.start_smooth:
131 | return features
132 |
133 | sp = labels.squeeze(1).shape
134 |
135 | labels = labels.squeeze(1).view(-1)
136 | features = features.permute(0, 2, 3, 1).contiguous().view(-1, self.feature_dim)
137 |
138 |         buckets = torch.clamp(torch.floor(labels * 10.).int(),
139 |                               min=self.bucket_start, max=self.bucket_num - 1)
140 | for bucket in torch.unique(buckets):
141 | features[buckets.eq(bucket)] = calibrate_mean_var(
142 | features[buckets.eq(bucket)],
143 | self.running_mean_last_epoch[bucket.item() - self.bucket_start],
144 | self.running_var_last_epoch[bucket.item() - self.bucket_start],
145 | self.smoothed_mean_last_epoch[bucket.item() - self.bucket_start],
146 | self.smoothed_var_last_epoch[bucket.item() - self.bucket_start]
147 | )
148 |
149 | return features.view(*sp, self.feature_dim).permute(0, 3, 1, 2)
150 |
--------------------------------------------------------------------------------
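Note on `update_running_stats` in `fds.py` above: each label bucket keeps an exponential moving average of its feature mean and variance, where the mixing factor is either fixed (the `--fds_mmt` flag) or decays as more samples are tracked. A minimal pure-Python sketch of that update rule (the function name and toy numbers are illustrative, not part of the repo):

```python
def ema_update(running, current, num_tracked, num_new, momentum=None):
    # One FDS-style running-statistic update for a single label bucket.
    # With momentum=None, the factor decays as more samples are tracked,
    # mirroring the (1 - n_new / n_total) rule in update_running_stats.
    total = num_tracked + num_new
    factor = momentum if momentum is not None else 1 - num_new / float(total)
    return (1 - factor) * current + factor * running, total

# Toy numbers: a bucket with running mean 0.0 over 10 samples
# sees 10 new samples whose mean is 1.0.
mean, n = ema_update(running=0.0, current=1.0, num_tracked=10, num_new=10)
```

With a fixed momentum, older statistics are weighted by `momentum` regardless of sample counts; `update_running_stats` additionally forces the factor to 0 at `start_update` so the first tracked epoch overwrites the buffers.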
/nyud2-dir/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .fds import FDS
5 |
6 | class _UpProjection(nn.Sequential):
7 |
8 | def __init__(self, num_input_features, num_output_features):
9 | super(_UpProjection, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(num_input_features, num_output_features,
12 | kernel_size=5, stride=1, padding=2, bias=False)
13 | self.bn1 = nn.BatchNorm2d(num_output_features)
14 | self.relu = nn.ReLU(inplace=True)
15 | self.conv1_2 = nn.Conv2d(num_output_features, num_output_features,
16 | kernel_size=3, stride=1, padding=1, bias=False)
17 | self.bn1_2 = nn.BatchNorm2d(num_output_features)
18 |
19 | self.conv2 = nn.Conv2d(num_input_features, num_output_features,
20 | kernel_size=5, stride=1, padding=2, bias=False)
21 | self.bn2 = nn.BatchNorm2d(num_output_features)
22 |
23 | def forward(self, x, size):
24 |         x = F.interpolate(x, size=size, mode='bilinear')
25 | x_conv1 = self.relu(self.bn1(self.conv1(x)))
26 | bran1 = self.bn1_2(self.conv1_2(x_conv1))
27 | bran2 = self.bn2(self.conv2(x))
28 |
29 | out = self.relu(bran1 + bran2)
30 |
31 | return out
32 |
33 | class E_resnet(nn.Module):
34 |
35 | def __init__(self, original_model, num_features = 2048):
36 | super(E_resnet, self).__init__()
37 | self.conv1 = original_model.conv1
38 | self.bn1 = original_model.bn1
39 | self.relu = original_model.relu
40 | self.maxpool = original_model.maxpool
41 |
42 | self.layer1 = original_model.layer1
43 | self.layer2 = original_model.layer2
44 | self.layer3 = original_model.layer3
45 | self.layer4 = original_model.layer4
46 |
47 |
48 | def forward(self, x):
49 | x = self.conv1(x)
50 | x = self.bn1(x)
51 | x = self.relu(x)
52 | x = self.maxpool(x)
53 |
54 | x_block1 = self.layer1(x)
55 | x_block2 = self.layer2(x_block1)
56 | x_block3 = self.layer3(x_block2)
57 | x_block4 = self.layer4(x_block3)
58 |
59 | return x_block1, x_block2, x_block3, x_block4
60 |
61 | class D(nn.Module):
62 |
63 | def __init__(self, num_features = 2048):
64 | super(D, self).__init__()
65 | self.conv = nn.Conv2d(num_features, num_features //
66 | 2, kernel_size=1, stride=1, bias=False)
67 | num_features = num_features // 2
68 | self.bn = nn.BatchNorm2d(num_features)
69 |
70 | self.up1 = _UpProjection(
71 | num_input_features=num_features, num_output_features=num_features // 2)
72 | num_features = num_features // 2
73 |
74 | self.up2 = _UpProjection(
75 | num_input_features=num_features, num_output_features=num_features // 2)
76 | num_features = num_features // 2
77 |
78 | self.up3 = _UpProjection(
79 | num_input_features=num_features, num_output_features=num_features // 2)
80 | num_features = num_features // 2
81 |
82 | self.up4 = _UpProjection(
83 | num_input_features=num_features, num_output_features=num_features // 2)
84 | num_features = num_features // 2
85 |
86 |
87 | def forward(self, x_block1, x_block2, x_block3, x_block4):
88 | x_d0 = F.relu(self.bn(self.conv(x_block4)))
89 | x_d1 = self.up1(x_d0, [x_block3.size(2), x_block3.size(3)])
90 | x_d2 = self.up2(x_d1, [x_block2.size(2), x_block2.size(3)])
91 | x_d3 = self.up3(x_d2, [x_block1.size(2), x_block1.size(3)])
92 | x_d4 = self.up4(x_d3, [x_block1.size(2)*2, x_block1.size(3)*2])
93 |
94 | return x_d4
95 |
96 | class MFF(nn.Module):
97 |
98 | def __init__(self, block_channel, num_features=64):
99 |
100 | super(MFF, self).__init__()
101 |
102 | self.up1 = _UpProjection(
103 | num_input_features=block_channel[0], num_output_features=16)
104 |
105 | self.up2 = _UpProjection(
106 | num_input_features=block_channel[1], num_output_features=16)
107 |
108 | self.up3 = _UpProjection(
109 | num_input_features=block_channel[2], num_output_features=16)
110 |
111 | self.up4 = _UpProjection(
112 | num_input_features=block_channel[3], num_output_features=16)
113 |
114 | self.conv = nn.Conv2d(
115 | num_features, num_features, kernel_size=5, stride=1, padding=2, bias=False)
116 | self.bn = nn.BatchNorm2d(num_features)
117 |
118 |
119 | def forward(self, x_block1, x_block2, x_block3, x_block4, size):
120 | x_m1 = self.up1(x_block1, size)
121 | x_m2 = self.up2(x_block2, size)
122 | x_m3 = self.up3(x_block3, size)
123 | x_m4 = self.up4(x_block4, size)
124 |
125 | x = self.bn(self.conv(torch.cat((x_m1, x_m2, x_m3, x_m4), 1)))
126 | x = F.relu(x)
127 |
128 | return x
129 |
130 |
131 | class R(nn.Module):
132 | def __init__(self, args, block_channel):
133 |
134 | super(R, self).__init__()
135 |
136 | num_features = 64 + block_channel[3] // 32
137 | self.conv0 = nn.Conv2d(num_features, num_features,
138 | kernel_size=5, stride=1, padding=2, bias=False)
139 | self.bn0 = nn.BatchNorm2d(num_features)
140 |
141 | self.conv1 = nn.Conv2d(num_features, num_features,
142 | kernel_size=5, stride=1, padding=2, bias=False)
143 | self.bn1 = nn.BatchNorm2d(num_features)
144 |
145 | self.conv2 = nn.Conv2d(num_features, 1, kernel_size=5, stride=1, padding=2, bias=True)
146 |
147 | self.args = args
148 |
149 | if args is not None and args.fds:
150 | self.FDS = FDS(feature_dim=num_features, bucket_num=args.bucket_num, bucket_start=args.bucket_start,
151 | start_update=args.start_update, start_smooth=args.start_smooth, kernel=args.fds_kernel,
152 | ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt)
153 |
154 | def forward(self, x, depth=None, epoch=None):
155 | x0 = self.conv0(x)
156 | x0 = self.bn0(x0)
157 | x0 = F.relu(x0)
158 |
159 | x1 = self.conv1(x0)
160 | x1 = self.bn1(x1)
161 | x1 = F.relu(x1)
162 |
163 | x1_s = x1
164 |
165 | if self.training and self.args.fds:
166 | if epoch >= self.args.start_smooth:
167 | x1_s = self.FDS.smooth(x1_s, depth, epoch)
168 |
169 | x2 = self.conv2(x1_s)
170 |
171 | if self.training and self.args.fds:
172 | return x2, x1
173 | else:
174 | return x2
--------------------------------------------------------------------------------
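`R.forward` above routes training features through `self.FDS.smooth`, which re-calibrates each bucket's features via `calibrate_mean_var` (defined in the repo's util module, not shown in this excerpt). The whiten-then-recolor operation it performs can be sketched as follows; this is an illustrative approximation, not the repo's exact implementation:

```python
import numpy as np

def calibrate_mean_var_sketch(x, m1, v1, m2, v2, eps=1e-8):
    # Whiten x with the current running stats (m1, v1), then re-color
    # with the smoothed stats (m2, v2). Illustrative only.
    if v1 < eps:
        return x
    return (x - m1) / np.sqrt(v1) * np.sqrt(v2) + m2

x = np.array([0.0, 2.0])  # mean 1.0, (biased) variance 1.0
out = calibrate_mean_var_sketch(x, m1=1.0, v1=1.0, m2=5.0, v2=4.0)
```

After calibration the features follow the smoothed statistics: here the toy pair is shifted to mean 5.0 with its spread doubled.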
/nyud2-dir/models/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models import modules
4 |
5 | class model(nn.Module):
6 | def __init__(self, args, Encoder, num_features, block_channel):
7 |
8 | super(model, self).__init__()
9 |
10 | self.E = Encoder
11 | self.D = modules.D(num_features)
12 | self.MFF = modules.MFF(block_channel)
13 | self.R = modules.R(args, block_channel)
14 |
15 |
16 | def forward(self, x, depth=None, epoch=None):
17 | x_block1, x_block2, x_block3, x_block4 = self.E(x)
18 | x_decoder = self.D(x_block1, x_block2, x_block3, x_block4)
19 |         x_mff = self.MFF(x_block1, x_block2, x_block3, x_block4, [x_decoder.size(2), x_decoder.size(3)])
20 | out = self.R(torch.cat((x_decoder, x_mff), 1), depth, epoch)
21 |
22 | return out
23 |
--------------------------------------------------------------------------------
/nyud2-dir/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
6 |
7 |
8 | model_urls = {
9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14 | }
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=False)
21 |
22 |
23 | class BasicBlock(nn.Module):
24 | expansion = 1
25 |
26 | def __init__(self, inplanes, planes, stride=1, downsample=None):
27 | super(BasicBlock, self).__init__()
28 | self.conv1 = conv3x3(inplanes, planes, stride)
29 | self.bn1 = nn.BatchNorm2d(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.bn2 = nn.BatchNorm2d(planes)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | residual = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 |
46 | if self.downsample is not None:
47 | residual = self.downsample(x)
48 |
49 | out += residual
50 | out = self.relu(out)
51 |
52 | return out
53 |
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 4
57 |
58 | def __init__(self, inplanes, planes, stride=1, downsample=None):
59 | super(Bottleneck, self).__init__()
60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61 | self.bn1 = nn.BatchNorm2d(planes)
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63 | padding=1, bias=False)
64 | self.bn2 = nn.BatchNorm2d(planes)
65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
66 | self.bn3 = nn.BatchNorm2d(planes * 4)
67 | self.relu = nn.ReLU(inplace=True)
68 | self.downsample = downsample
69 | self.stride = stride
70 |
71 | def forward(self, x):
72 | residual = x
73 |
74 | out = self.conv1(x)
75 | out = self.bn1(out)
76 | out = self.relu(out)
77 |
78 | out = self.conv2(out)
79 | out = self.bn2(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv3(out)
83 | out = self.bn3(out)
84 |
85 | if self.downsample is not None:
86 | residual = self.downsample(x)
87 |
88 | out += residual
89 | out = self.relu(out)
90 |
91 | return out
92 |
93 |
94 | class ResNet(nn.Module):
95 |
96 | def __init__(self, block, layers, num_classes=1000):
97 | self.inplanes = 64
98 | super(ResNet, self).__init__()
99 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
100 | bias=False)
101 | self.bn1 = nn.BatchNorm2d(64)
102 | self.relu = nn.ReLU(inplace=True)
103 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
104 | self.layer1 = self._make_layer(block, 64, layers[0])
105 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
106 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
107 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
108 | self.avgpool = nn.AvgPool2d(7, stride=1)
109 | self.fc = nn.Linear(512 * block.expansion, num_classes)
110 |
111 | for m in self.modules():
112 | if isinstance(m, nn.Conv2d):
113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
114 | m.weight.data.normal_(0, math.sqrt(2. / n))
115 | elif isinstance(m, nn.BatchNorm2d):
116 | m.weight.data.fill_(1)
117 | m.bias.data.zero_()
118 |
119 | def _make_layer(self, block, planes, blocks, stride=1):
120 | downsample = None
121 | if stride != 1 or self.inplanes != planes * block.expansion:
122 | downsample = nn.Sequential(
123 | nn.Conv2d(self.inplanes, planes * block.expansion,
124 | kernel_size=1, stride=stride, bias=False),
125 | nn.BatchNorm2d(planes * block.expansion),
126 | )
127 |
128 | layers = []
129 | layers.append(block(self.inplanes, planes, stride, downsample))
130 | self.inplanes = planes * block.expansion
131 | for i in range(1, blocks):
132 | layers.append(block(self.inplanes, planes))
133 |
134 | return nn.Sequential(*layers)
135 |
136 | def forward(self, x):
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.maxpool(x)
141 |
142 | x = self.layer1(x)
143 | x = self.layer2(x)
144 | x = self.layer3(x)
145 | x = self.layer4(x)
146 |
147 | x = self.avgpool(x)
148 | x = x.view(x.size(0), -1)
149 | x = self.fc(x)
150 |
151 | return x
152 |
153 | def resnet18(pretrained=False, **kwargs):
154 | """Constructs a ResNet-18 model.
155 | Args:
156 | pretrained (bool): If True, returns a model pre-trained on ImageNet
157 | """
158 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
159 | if pretrained:
160 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
161 | return model
162 |
163 |
164 | def resnet34(pretrained=False, **kwargs):
165 | """Constructs a ResNet-34 model.
166 | Args:
167 | pretrained (bool): If True, returns a model pre-trained on ImageNet
168 | """
169 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
170 | if pretrained:
171 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
172 | return model
173 |
174 |
175 | def resnet50(pretrained=False, **kwargs):
176 | """Constructs a ResNet-50 model.
177 | Args:
178 | pretrained (bool): If True, returns a model pre-trained on ImageNet
179 | """
180 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
181 | if pretrained:
182 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'], 'pretrained_model/encoder'))
183 | return model
184 |
185 |
186 | def resnet101(pretrained=False, **kwargs):
187 | """Constructs a ResNet-101 model.
188 | Args:
189 | pretrained (bool): If True, returns a model pre-trained on ImageNet
190 | """
191 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
192 | if pretrained:
193 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
194 | return model
195 |
196 |
197 | def resnet152(pretrained=False, **kwargs):
198 | """Constructs a ResNet-152 model.
199 | Args:
200 | pretrained (bool): If True, returns a model pre-trained on ImageNet
201 | """
202 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
203 | if pretrained:
204 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
205 | return model
206 |
--------------------------------------------------------------------------------
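In `_make_layer` above, a 1x1-conv downsample branch is created whenever the stride is not 1 or the channel count changes. The channel bookkeeping can be traced without torch; `trace_channels` below is a hypothetical helper mirroring the four `_make_layer` calls in `ResNet.__init__`:

```python
def trace_channels(block_expansion, layers, widths=(64, 128, 256, 512)):
    # Trace `inplanes` through the four _make_layer calls and record where
    # a downsample branch is needed (stride != 1 or inplanes != planes * expansion).
    inplanes = 64
    needs_downsample = []
    for i, (planes, _blocks) in enumerate(zip(widths, layers)):
        stride = 1 if i == 0 else 2
        needs_downsample.append(stride != 1 or inplanes != planes * block_expansion)
        inplanes = planes * block_expansion
    return inplanes, needs_downsample

# resnet50: Bottleneck (expansion 4), layers [3, 4, 6, 3]
final50, ds50 = trace_channels(4, [3, 4, 6, 3])
# resnet18: BasicBlock (expansion 1), layers [2, 2, 2, 2]
final18, ds18 = trace_channels(1, [2, 2, 2, 2])
```

This is why `E_resnet` with resnet50 yields `block_channel = [256, 512, 1024, 2048]` (as passed in `test.py`), while BasicBlock variants skip the downsample in `layer1`.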
/nyud2-dir/nyu_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | import collections
5 | try:
6 | import accimage
7 | except ImportError:
8 | accimage = None
9 | import random
10 | import scipy.ndimage as ndimage
11 |
12 |
13 | def _is_pil_image(img):
14 | if accimage is not None:
15 | return isinstance(img, (Image.Image, accimage.Image))
16 | else:
17 | return isinstance(img, Image.Image)
18 |
19 |
20 | def _is_numpy_image(img):
21 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
22 |
23 |
24 | class RandomRotate(object):
25 |     """Random rotation of the image by a random angle in [-angle, angle] (degrees).
26 |     Useful for data augmentation, especially for geometric problems such as flow estimation.
27 |     angle: max angle of the rotation
28 |     interpolation order: Default: 2 (bilinear)
29 |     reshape: Default: False. If True, the output is expanded so that every pixel of the rotated input is kept.
30 |     diff_angle: Default: 0. Must stay below 10 degrees, or the linear approximation of the flowmap will be off.
31 |     """
32 |
33 | def __init__(self, angle, diff_angle=0, order=2, reshape=False):
34 | self.angle = angle
35 | self.reshape = reshape
36 | self.order = order
37 |
38 | def __call__(self, sample):
39 | image, depth = sample['image'], sample['depth']
40 |
41 | applied_angle = random.uniform(-self.angle, self.angle)
42 | angle1 = applied_angle
43 | angle1_rad = angle1 * np.pi / 180
44 |
45 |         image = ndimage.rotate(
46 |             image, angle1, reshape=self.reshape, order=self.order)
47 |         depth = ndimage.rotate(
48 |             depth, angle1, reshape=self.reshape, order=self.order)
49 |
50 | image = Image.fromarray(image)
51 | depth = Image.fromarray(depth)
52 |
53 | return {'image': image, 'depth': depth}
54 |
55 | class RandomHorizontalFlip(object):
56 |
57 | def __call__(self, sample):
58 | image, depth = sample['image'], sample['depth']
59 |
60 | if not _is_pil_image(image):
61 | raise TypeError(
62 | 'img should be PIL Image. Got {}'.format(type(image)))
63 | if not _is_pil_image(depth):
64 | raise TypeError(
65 | 'img should be PIL Image. Got {}'.format(type(depth)))
66 |
67 | if random.random() < 0.5:
68 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
69 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
70 |
71 | return {'image': image, 'depth': depth}
72 |
73 |
74 | class Scale(object):
75 | """ Rescales the inputs and target arrays to the given 'size'.
76 | 'size' will be the size of the smaller edge.
77 | For example, if height > width, then image will be
78 | rescaled to (size * height / width, size)
79 | size: size of the smaller edge
80 | interpolation order: Default: 2 (bilinear)
81 | """
82 |
83 | def __init__(self, size):
84 | self.size = size
85 |
86 | def __call__(self, sample):
87 | image, depth = sample['image'], sample['depth']
88 |
89 | image = self.changeScale(image, self.size)
90 |         depth = self.changeScale(depth, self.size, Image.NEAREST)
91 |
92 | return {'image': image, 'depth': depth}
93 |
94 | def changeScale(self, img, size, interpolation=Image.BILINEAR):
95 |
96 | if not _is_pil_image(img):
97 | raise TypeError(
98 | 'img should be PIL Image. Got {}'.format(type(img)))
99 |         if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
100 | raise TypeError('Got inappropriate size arg: {}'.format(size))
101 |
102 | if isinstance(size, int):
103 | w, h = img.size
104 | if (w <= h and w == size) or (h <= w and h == size):
105 | return img
106 | if w < h:
107 | ow = size
108 | oh = int(size * h / w)
109 | return img.resize((ow, oh), interpolation)
110 | else:
111 | oh = size
112 | ow = int(size * w / h)
113 | return img.resize((ow, oh), interpolation)
114 | else:
115 | return img.resize(size[::-1], interpolation)
116 |
117 |
118 | class CenterCrop(object):
119 | def __init__(self, size_image, size_depth):
120 | self.size_image = size_image
121 | self.size_depth = size_depth
122 |
123 | def __call__(self, sample):
124 | image, depth = sample['image'], sample['depth']
125 |
126 | image = self.centerCrop(image, self.size_image)
127 | depth = self.centerCrop(depth, self.size_image)
128 |
129 | ow, oh = self.size_depth
130 | depth = depth.resize((ow, oh))
131 |
132 | return {'image': image, 'depth': depth}
133 |
134 | def centerCrop(self, image, size):
135 |
136 | w1, h1 = image.size
137 |
138 | tw, th = size
139 |
140 | if w1 == tw and h1 == th:
141 | return image
142 |
143 | x1 = int(round((w1 - tw) / 2.))
144 | y1 = int(round((h1 - th) / 2.))
145 |
146 | image = image.crop((x1, y1, tw + x1, th + y1))
147 |
148 | return image
149 |
150 |
151 | class ToTensor(object):
152 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
153 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
154 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
155 | """
156 | def __init__(self,is_test=False):
157 | self.is_test = is_test
158 |
159 | def __call__(self, sample):
160 | image, depth = sample['image'], sample['depth']
161 | """
162 | Args:
163 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
164 | Returns:
165 | Tensor: Converted image.
166 | """
167 | # ground truth depth of training samples is stored in 8-bit while test samples are saved in 16 bit
168 | image = self.to_tensor(image)
169 | if self.is_test:
170 | depth = self.to_tensor(depth).float() / 1000
171 | else:
172 | depth = self.to_tensor(depth).float() * 10
173 | return {'image': image, 'depth': depth}
174 |
175 | def to_tensor(self, pic):
176 | if not(_is_pil_image(pic) or _is_numpy_image(pic)):
177 | raise TypeError(
178 | 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
179 |
180 | if isinstance(pic, np.ndarray):
181 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
182 |
183 | return img.float().div(255)
184 |
185 | if accimage is not None and isinstance(pic, accimage.Image):
186 | nppic = np.zeros(
187 | [pic.channels, pic.height, pic.width], dtype=np.float32)
188 | pic.copyto(nppic)
189 | return torch.from_numpy(nppic)
190 |
191 | # handle PIL Image
192 | if pic.mode == 'I':
193 | img = torch.from_numpy(np.array(pic, np.int32))
194 | elif pic.mode == 'I;16':
195 | img = torch.from_numpy(np.array(pic, np.int16))
196 | else:
197 | img = torch.ByteTensor(
198 | torch.ByteStorage.from_buffer(pic.tobytes()))
199 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
200 | if pic.mode == 'YCbCr':
201 | nchannel = 3
202 | elif pic.mode == 'I;16':
203 | nchannel = 1
204 | else:
205 | nchannel = len(pic.mode)
206 | img = img.view(pic.size[1], pic.size[0], nchannel)
207 | # put it from HWC to CHW format
208 | # yikes, this transpose takes 80% of the loading time/CPU
209 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
210 | if isinstance(img, torch.ByteTensor):
211 | return img.float().div(255)
212 | else:
213 | return img
214 |
215 |
216 | class Lighting(object):
217 |
218 | def __init__(self, alphastd, eigval, eigvec):
219 | self.alphastd = alphastd
220 | self.eigval = eigval
221 | self.eigvec = eigvec
222 |
223 | def __call__(self, sample):
224 | image, depth = sample['image'], sample['depth']
225 |         if self.alphastd == 0:
226 |             return {'image': image, 'depth': depth}
227 |
228 | alpha = image.new().resize_(3).normal_(0, self.alphastd)
229 | rgb = self.eigvec.type_as(image).clone()\
230 | .mul(alpha.view(1, 3).expand(3, 3))\
231 | .mul(self.eigval.view(1, 3).expand(3, 3))\
232 | .sum(1).squeeze()
233 |
234 | image = image.add(rgb.view(3, 1, 1).expand_as(image))
235 |
236 | return {'image': image, 'depth': depth}
237 |
238 |
239 | class Grayscale(object):
240 |
241 | def __call__(self, img):
242 | gs = img.clone()
243 | gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)
244 | gs[1].copy_(gs[0])
245 | gs[2].copy_(gs[0])
246 | return gs
247 |
248 |
249 | class Saturation(object):
250 |
251 | def __init__(self, var):
252 | self.var = var
253 |
254 | def __call__(self, img):
255 | gs = Grayscale()(img)
256 | alpha = random.uniform(-self.var, self.var)
257 | return img.lerp(gs, alpha)
258 |
259 |
260 | class Brightness(object):
261 |
262 | def __init__(self, var):
263 | self.var = var
264 |
265 | def __call__(self, img):
266 | gs = img.new().resize_as_(img).zero_()
267 | alpha = random.uniform(-self.var, self.var)
268 |
269 | return img.lerp(gs, alpha)
270 |
271 |
272 | class Contrast(object):
273 |
274 | def __init__(self, var):
275 | self.var = var
276 |
277 | def __call__(self, img):
278 | gs = Grayscale()(img)
279 | gs.fill_(gs.mean())
280 | alpha = random.uniform(-self.var, self.var)
281 | return img.lerp(gs, alpha)
282 |
283 |
284 | class RandomOrder(object):
285 | """ Composes several transforms together in random order.
286 | """
287 |
288 | def __init__(self, transforms):
289 | self.transforms = transforms
290 |
291 | def __call__(self, sample):
292 | image, depth = sample['image'], sample['depth']
293 |
294 | if self.transforms is None:
295 | return {'image': image, 'depth': depth}
296 | order = torch.randperm(len(self.transforms))
297 | for i in order:
298 | image = self.transforms[i](image)
299 |
300 | return {'image': image, 'depth': depth}
301 |
302 |
303 | class ColorJitter(RandomOrder):
304 |
305 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
306 | self.transforms = []
307 | if brightness != 0:
308 | self.transforms.append(Brightness(brightness))
309 | if contrast != 0:
310 | self.transforms.append(Contrast(contrast))
311 | if saturation != 0:
312 | self.transforms.append(Saturation(saturation))
313 |
314 |
315 | class Normalize(object):
316 | def __init__(self, mean, std):
317 | self.mean = mean
318 | self.std = std
319 |
320 | def __call__(self, sample):
321 | """
322 | Args:
323 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
324 | Returns:
325 | Tensor: Normalized image.
326 | """
327 | image, depth = sample['image'], sample['depth']
328 |
329 | image = self.normalize(image, self.mean, self.std)
330 |
331 | return {'image': image, 'depth': depth}
332 |
333 | def normalize(self, tensor, mean, std):
334 | """Normalize a tensor image with mean and standard deviation.
335 | See ``Normalize`` for more details.
336 | Args:
337 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
338 |             mean (sequence): Sequence of means for R, G, B channels respectively.
339 |             std (sequence): Sequence of standard deviations for R, G, B channels
340 |                 respectively.
341 | Returns:
342 | Tensor: Normalized image.
343 | """
344 |
345 | for t, m, s in zip(tensor, mean, std):
346 | t.sub_(m).div_(s)
347 | return tensor
348 |
--------------------------------------------------------------------------------
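The `Normalize` transform above applies per-channel `(x - mean) / std` to a `(C, H, W)` tensor in place. An equivalent NumPy sketch (toy values; the helper name is illustrative):

```python
import numpy as np

def normalize(tensor, mean, std):
    # Per-channel (C, H, W) normalization, mirroring Normalize.normalize above.
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (tensor - mean) / std

img = np.ones((3, 2, 2), dtype=np.float32)
out = normalize(img, mean=[1.0, 0.0, 0.5], std=[1.0, 1.0, 0.5])
```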
/nyud2-dir/preprocess_nyud2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pandas as pd
4 | from tqdm import tqdm
5 | from torchvision import transforms
6 | from torch.utils.data import DataLoader
7 | from nyu_transform import *
8 | from loaddata import depthDataset
9 |
10 | def load_data(args):
11 | train_dataset = depthDataset(
12 | csv_file=os.path.join(args.data_dir, 'nyu2_train.csv'),
13 | transform=transforms.Compose([
14 | Scale(240),
15 | CenterCrop([304, 228], [304, 228]),
16 | ToTensor(is_test=False),
17 | ])
18 | )
19 | train_dataloader = DataLoader(train_dataset, 256, shuffle=False, num_workers=16, pin_memory=False)
20 |
21 | test_dataset = depthDataset(
22 | csv_file=os.path.join(args.data_dir, 'nyu2_test.csv'),
23 | transform=transforms.Compose([
24 | Scale(240),
25 | CenterCrop([304, 228], [304, 228]),
26 | ToTensor(is_test=True),
27 | ])
28 | )
29 | # print(train_dataset.__len__(), test_dataset.__len__())
30 | test_dataloader = DataLoader(test_dataset, 256, shuffle=False, num_workers=16, pin_memory=False)
31 |
32 | return train_dataloader, test_dataloader
33 |
34 | def create_FDS_train_subset_id(args):
35 | print('Creating FDS statistics updating subset ids...')
36 | train_dataloader, _ = load_data(args)
37 | train_depth_values = []
38 | for i, sample in enumerate(tqdm(train_dataloader)):
39 | train_depth_values.append(sample['depth'].squeeze())
40 | train_depth_values = torch.cat(train_depth_values, 0)
41 | select_idx = np.random.choice(a=list(range(train_depth_values.size(0))), size=600, replace=False)
42 | np.save(os.path.join(args.data_dir, 'FDS_train_subset_id.npy'), select_idx)
43 |
44 | def create_FDS_train_subset(args):
45 | print('Creating FDS statistics updating subset...')
46 | frame = pd.read_csv(os.path.join(args.data_dir, 'nyu2_train.csv'), header=None)
47 | select_id = np.load(os.path.join(args.data_dir, 'FDS_train_subset_id.npy'))
48 | frame.iloc[select_id].to_csv(os.path.join(args.data_dir, 'nyu2_train_FDS_subset.csv'), index=False, header=False)
49 |
50 | def get_bin_idx(x):
51 | return min(int(x * np.float32(10)), 99)
52 |
53 | def create_balanced_testset(args):
54 | print('Creating balanced test set mask...')
55 | _, test_dataloader = load_data(args)
56 | test_depth_values = []
57 |
58 | for i, sample in enumerate(tqdm(test_dataloader)):
59 | test_depth_values.append(sample['depth'].squeeze())
60 | test_depth_values = torch.cat(test_depth_values, 0)
61 | test_depth_values_flatten = test_depth_values.view(-1).numpy()
62 | test_bins_number, _ = np.histogram(a=test_depth_values_flatten, bins=100, range=(0., 10.))
63 |
64 | select_pixel_num = min(test_bins_number[test_bins_number != 0])
65 |     test_depth_values_flatten_bins = np.array(list(map(get_bin_idx, test_depth_values_flatten)))
66 |
67 | test_depth_flatten_mask = np.zeros(test_depth_values_flatten.shape[0], dtype=np.uint8)
68 | for i in range(7, 100):
69 | bucket_idx = np.where(test_depth_values_flatten_bins == i)[0]
70 | select_bucket_idx = np.random.choice(a=bucket_idx, size=select_pixel_num, replace=False)
71 | test_depth_flatten_mask[select_bucket_idx] = np.uint8(1)
72 | test_depth_mask = test_depth_flatten_mask.reshape(test_depth_values.numpy().shape)
73 | np.save(os.path.join(args.data_dir, 'test_balanced_mask.npy'), test_depth_mask)
74 |
75 | if __name__ == '__main__':
76 | parser = argparse.ArgumentParser(description='')
77 | parser.add_argument('--data_dir', type=str, default='./data', help='data directory')
78 | args = parser.parse_args()
79 |
80 | create_FDS_train_subset_id(args)
81 | create_FDS_train_subset(args)
82 | create_balanced_testset(args)
--------------------------------------------------------------------------------
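`get_bin_idx` in `preprocess_nyud2.py` above maps a depth value (in meters) into one of 100 bins of width 0.1 m, capping at index 99; a quick self-contained check of its edge cases:

```python
import numpy as np

def get_bin_idx(x):
    # Same binning as preprocess_nyud2.py: 0.1 m wide bins, capped at index 99.
    return min(int(x * np.float32(10)), 99)

idxs = [get_bin_idx(v) for v in (0.0, 0.05, 0.95, 9.99, 10.0, 12.0)]
```

`create_balanced_testset` then equalizes pixel counts only over bins 7 through 99, matching the bucket range used for depth in this repo.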
/nyud2-dir/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 |
10 | import loaddata
11 | from models import modules, net, resnet
12 | from util import Evaluator
13 |
14 | def main():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--eval_model', type=str, default='', help='evaluation model path')
17 | parser.add_argument('--data_dir', type=str, default='./data', help='data directory')
18 | args = parser.parse_args()
19 |
20 | logging.root.handlers = []
21 | logging.basicConfig(
22 | level=logging.INFO,
23 | format="%(asctime)s | %(message)s",
24 | handlers=[
25 | logging.StreamHandler()
26 | ])
27 |
28 | model = define_model()
29 | assert os.path.isfile(args.eval_model), f"No checkpoint found at '{args.eval_model}'"
30 | model = torch.nn.DataParallel(model).cuda()
31 | model_state = torch.load(args.eval_model)
32 | logging.info(f"Loading checkpoint from {args.eval_model}")
33 | model.load_state_dict(model_state['state_dict'], strict=False)
34 | logging.info('Loaded successfully!')
35 |
36 | test_loader = loaddata.getTestingData(args, 1)
37 | test(test_loader, model)
38 |
39 | def test(test_loader, model):
40 | model.eval()
41 |
42 | logging.info('Starting testing...')
43 |
44 | evaluator = Evaluator()
45 |
46 | with torch.no_grad():
47 | for i, sample_batched in enumerate(tqdm(test_loader)):
48 | image, depth, mask = sample_batched['image'], sample_batched['depth'], sample_batched['mask']
49 | depth = depth.cuda(non_blocking=True)
50 | image = image.cuda()
51 | output = model(image)
52 | output = nn.functional.interpolate(output, size=[depth.size(2),depth.size(3)], mode='bilinear', align_corners=True)
53 |
54 | evaluator(output[mask], depth[mask])
55 |
56 | logging.info('Finished testing. Start printing statistics below...')
57 | metric_dict = evaluator.evaluate_shot()
58 |
59 | return metric_dict['overall']['RMSE'], metric_dict
60 |
61 | def define_model():
62 | original_model = resnet.resnet50(pretrained = True)
63 | Encoder = modules.E_resnet(original_model)
64 | model = net.model(None, Encoder, num_features=2048, block_channel = [256, 512, 1024, 2048])
65 |
66 | return model
67 |
68 | if __name__ == '__main__':
69 | main()
70 |
--------------------------------------------------------------------------------
/nyud2-dir/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 | import shutil
5 | import logging
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import loaddata
9 | from tqdm import tqdm
10 | from models import modules, net, resnet
11 | from util import query_yes_no
12 | from test import test
13 | from tensorboardX import SummaryWriter
14 |
15 | parser = argparse.ArgumentParser(description='')
16 |
17 | # training/optimization related
18 | parser.add_argument('--epochs', default=10, type=int,
19 | help='number of total epochs to run')
20 | parser.add_argument('--start_epoch', default=0, type=int,
21 | help='manual epoch number (useful on restarts)')
22 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
23 | help='initial learning rate')
24 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
25 | help='weight decay (default: 1e-4)')
26 | parser.add_argument('--batch_size', default=32, type=int, help='batch size (use 8 for a single GPU)')
27 | parser.add_argument('--store_root', type=str, default='checkpoint')
28 | parser.add_argument('--store_name', type=str, default='nyud2')
29 | parser.add_argument('--data_dir', type=str, default='./data', help='data directory')
30 | parser.add_argument('--resume', action='store_true', default=False, help='whether to resume training')
31 |
32 | # imbalanced related
33 | # LDS
34 | parser.add_argument('--lds', action='store_true', default=False, help='whether to enable LDS')
35 | parser.add_argument('--lds_kernel', type=str, default='gaussian',
36 | choices=['gaussian', 'triang', 'laplace'], help='LDS kernel type')
37 | parser.add_argument('--lds_ks', type=int, default=5, help='LDS kernel size: should be odd number')
38 | parser.add_argument('--lds_sigma', type=float, default=2, help='LDS gaussian/laplace kernel sigma')
39 | # FDS
40 | parser.add_argument('--fds', action='store_true', default=False, help='whether to enable FDS')
41 | parser.add_argument('--fds_kernel', type=str, default='gaussian',
42 | choices=['gaussian', 'triang', 'laplace'], help='FDS kernel type')
43 | parser.add_argument('--fds_ks', type=int, default=5, help='FDS kernel size: should be odd number')
44 | parser.add_argument('--fds_sigma', type=float, default=2, help='FDS gaussian/laplace kernel sigma')
45 | parser.add_argument('--start_update', type=int, default=0, help='which epoch to start FDS updating')
46 | parser.add_argument('--start_smooth', type=int, default=1, help='which epoch to start using FDS to smooth features')
47 | parser.add_argument('--bucket_num', type=int, default=100, help='maximum bucket considered for FDS')
48 | parser.add_argument('--bucket_start', type=int, default=7, help='minimum(starting) bucket for FDS, 7 for NYUDv2')
49 | parser.add_argument('--fds_mmt', type=float, default=0.9, help='FDS momentum')
50 |
51 | # re-weighting: SQRT_INV / INV
52 | parser.add_argument('--reweight', type=str, default='none', choices=['none', 'inverse', 'sqrt_inv'],
53 | help='cost-sensitive reweighting scheme')
54 | # two-stage training: RRT
55 | parser.add_argument('--retrain_fc', action='store_true', default=False,
56 | help='whether to retrain last regression layer (regressor)')
57 | parser.add_argument('--pretrained', type=str, default='', help='pretrained checkpoint file path to load backbone weights for RRT')
58 |
59 | def define_model(args):
60 | original_model = resnet.resnet50(pretrained=True)
61 | Encoder = modules.E_resnet(original_model)
62 | model = net.model(args, Encoder, num_features=2048, block_channel=[256, 512, 1024, 2048])
63 |
64 | return model
65 |
66 | def main():
67 | error_best = 1e5
68 | metric_dict_best = {}
69 | epoch_best = -1
70 |
71 | global args
72 | args = parser.parse_args()
73 |
74 | if not args.lds and args.reweight != 'none':
75 | args.store_name += f'_{args.reweight}'
76 | if args.lds:
77 | args.store_name += f'_lds_{args.lds_kernel[:3]}_{args.lds_ks}'
78 | if args.lds_kernel in ['gaussian', 'laplace']:
79 | args.store_name += f'_{args.lds_sigma}'
80 | if args.fds:
81 | args.store_name += f'_fds_{args.fds_kernel[:3]}_{args.fds_ks}'
82 | if args.fds_kernel in ['gaussian', 'laplace']:
83 | args.store_name += f'_{args.fds_sigma}'
84 | args.store_name += f'_{args.start_update}_{args.start_smooth}_{args.fds_mmt}'
85 | if args.retrain_fc:
86 | args.store_name += '_retrain_fc'
87 | args.store_name += f'_lr_{args.lr}_bs_{args.batch_size}'
88 |
89 | args.store_dir = os.path.join(args.store_root, args.store_name)
90 |
91 | if not args.resume:
92 | if os.path.exists(args.store_dir):
93 | if query_yes_no('overwrite previous folder: {} ?'.format(args.store_dir)):
94 | shutil.rmtree(args.store_dir)
95 | print(args.store_dir + ' removed.')
96 | else:
97 | raise RuntimeError('Output folder {} already exists'.format(args.store_dir))
98 | print(f"===> Creating folder: {args.store_dir}")
99 | os.makedirs(args.store_dir)
100 |
101 | logging.root.handlers = []
102 | log_file = os.path.join(args.store_dir, 'training_log.log')
103 | logging.basicConfig(
104 | level=logging.INFO,
105 | format="%(asctime)s | %(message)s",
106 | handlers=[
107 | logging.FileHandler(log_file),
108 | logging.StreamHandler()
109 | ])
110 | logging.info(args)
111 |
112 | writer = SummaryWriter(args.store_dir)
113 |
114 | model = define_model(args)
115 | model = torch.nn.DataParallel(model).cuda()
116 |
117 | if args.resume:
118 | model_state = torch.load(os.path.join(args.store_dir, 'checkpoint.pth.tar'))
119 | logging.info(f"Loading checkpoint from {os.path.join(args.store_dir, 'checkpoint.pth.tar')}"
120 | f" (Epoch [{model_state['epoch']}], RMSE: {model_state['error']:.3f})")
121 | model.load_state_dict(model_state['state_dict'])
122 |
123 | args.start_epoch = model_state['epoch'] + 1
124 | epoch_best = model_state['epoch']
125 | error_best = model_state['error']
126 | metric_dict_best = model_state['metric']
127 |
128 | if args.retrain_fc:
129 | assert os.path.isfile(args.pretrained), f"No checkpoint found at '{args.pretrained}'"
130 | model_state = torch.load(args.pretrained, map_location="cpu")
131 | from collections import OrderedDict
132 | new_state_dict = OrderedDict()
133 | for k, v in model_state['state_dict'].items():
134 | if 'R' not in k:
135 | new_state_dict[k] = v
136 | model.load_state_dict(new_state_dict, strict=False)
137 | logging.info(f'===> Pretrained weights found in total: [{len(list(new_state_dict.keys()))}]')
138 | logging.info(f'===> Pre-trained model loaded: {args.pretrained}')
139 | for name, param in model.named_parameters():
140 | if 'R' not in name:
141 | param.requires_grad = False
142 | logging.info(f'Only optimize parameters: {[n for n, p in model.named_parameters() if p.requires_grad]}')
143 |
144 | cudnn.benchmark = True
145 | if not args.retrain_fc:
146 | optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
147 | else:
148 | parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
149 | optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=args.weight_decay)
150 |
151 | train_loader = loaddata.getTrainingData(args, args.batch_size)
152 | train_fds_loader = loaddata.getTrainingFDSData(args, args.batch_size)
153 | test_loader = loaddata.getTestingData(args, 1)
154 |
155 | for epoch in range(args.start_epoch, args.epochs):
156 | adjust_learning_rate(optimizer, epoch)
157 | train(train_loader, train_fds_loader, model, optimizer, epoch, writer)
158 | error, metric_dict = test(test_loader, model)
159 | if error < error_best:
160 | error_best = error
161 | metric_dict_best = metric_dict
162 | epoch_best = epoch
163 | save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint_best.pth.tar')
164 | save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint.pth.tar')
165 |
166 | save_checkpoint(model.state_dict(), epoch, error, metric_dict, 'checkpoint_final.pth.tar')
167 | logging.info(f'Best epoch: {epoch_best}; RMSE: {error_best:.3f}')
168 | logging.info('***** TEST RESULTS *****')
169 | for shot in ['Overall', 'Many', 'Medium', 'Few']:
170 | logging.info(f" * {shot}: RMSE {metric_dict_best[shot.lower()]['RMSE']:.3f}\t"
171 | f"ABS_REL {metric_dict_best[shot.lower()]['ABS_REL']:.3f}\t"
172 | f"LG10 {metric_dict_best[shot.lower()]['LG10']:.3f}\t"
173 | f"MAE {metric_dict_best[shot.lower()]['MAE']:.3f}\t"
174 | f"DELTA1 {metric_dict_best[shot.lower()]['DELTA1']:.3f}\t"
175 | f"DELTA2 {metric_dict_best[shot.lower()]['DELTA2']:.3f}\t"
176 | f"DELTA3 {metric_dict_best[shot.lower()]['DELTA3']:.3f}\t"
177 | f"NUM {metric_dict_best[shot.lower()]['NUM']}")
178 |
179 | writer.close()
180 |
181 | def train(train_loader, train_fds_loader, model, optimizer, epoch, writer):
182 | batch_time = AverageMeter()
183 | losses = AverageMeter()
184 |
185 | model.train()
186 |
187 | end = time.time()
188 | for i, sample_batched in enumerate(train_loader):
189 | image, depth, weight = sample_batched['image'], sample_batched['depth'], sample_batched['weight']
190 |
191 | depth = depth.cuda(non_blocking=True)
192 | weight = weight.cuda(non_blocking=True)
193 | image = image.cuda()
194 | optimizer.zero_grad()
195 |
196 | if args.fds:
197 | output, feature = model(image, depth, epoch)
198 | else:
199 | output = model(image, depth, epoch)
200 | loss = torch.mean(((output - depth) ** 2) * weight)
201 |
202 | losses.update(loss.item(), image.size(0))
203 | loss.backward()
204 | optimizer.step()
205 |
206 | batch_time.update(time.time() - end)
207 | end = time.time()
208 |
209 | writer.add_scalar('data/loss', loss.item(), i + epoch * len(train_loader))
210 |
211 | logging.info('Epoch: [{0}][{1}/{2}]\t'
212 | 'Time {batch_time.val:.3f} ({batch_time.sum:.3f})\t'
213 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
214 | .format(epoch, i, len(train_loader), batch_time=batch_time, loss=losses))
215 |
216 | if args.fds and epoch >= args.start_update:
217 | logging.info(f"Creating Epoch [{epoch}] features of subsampled training data...")
218 | encodings, depths = [], []
219 | with torch.no_grad():
220 | for i, sample_batched in enumerate(tqdm(train_fds_loader)):
221 | image, depth = sample_batched['image'].cuda(), sample_batched['depth'].cuda()
222 | _, feature = model(image, depth, epoch)
223 | encodings.append(feature.data.cpu())
224 | depths.append(depth.data.cpu())
225 | encodings, depths = torch.cat(encodings, 0), torch.cat(depths, 0)
226 | logging.info(f"Created Epoch [{epoch}] features of subsampled training data (size: {encodings.size(0)})!")
227 | model.module.R.FDS.update_last_epoch_stats(epoch)
228 | model.module.R.FDS.update_running_stats(encodings, depths, epoch)
229 |
230 | def adjust_learning_rate(optimizer, epoch):
231 | lr = args.lr * (0.1 ** (epoch // 5))
232 |
233 | for param_group in optimizer.param_groups:
234 | param_group['lr'] = lr
235 |
236 |
237 | class AverageMeter(object):
238 | def __init__(self):
239 | self.reset()
240 |
241 | def reset(self):
242 | self.val = 0
243 | self.avg = 0
244 | self.sum = 0
245 | self.count = 0
246 |
247 | def update(self, val, n=1):
248 | self.val = val
249 | self.sum += val * n
250 | self.count += n
251 | self.avg = self.sum / self.count
252 |
253 |
254 | def save_checkpoint(state_dict, epoch, error, metric_dict, filename='checkpoint.pth.tar'):
255 | logging.info(f'Saving checkpoint to {os.path.join(args.store_dir, filename)}...')
256 | torch.save({
257 | 'state_dict': state_dict,
258 | 'epoch': epoch,
259 | 'error': error,
260 | 'metric': metric_dict
261 | }, os.path.join(args.store_dir, filename))
262 |
263 | if __name__ == '__main__':
264 | main()
265 |
--------------------------------------------------------------------------------
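The step schedule in `adjust_learning_rate` above decays the learning rate by 10x every 5 epochs. A minimal sketch of the same rule (the helper name is ours):

```python
def step_lr(base_lr, epoch, decay=0.1, step=5):
    # Mirrors adjust_learning_rate: lr = base_lr * decay ** (epoch // step)
    return base_lr * (decay ** (epoch // step))
```

With the defaults `--lr 0.0001 --epochs 10`, epochs 0-4 train at 1e-4 and epochs 5-9 at 1e-5.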
/nyud2-dir/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | import torch
4 | import numpy as np
5 | from scipy.ndimage import gaussian_filter1d
6 | from scipy.signal.windows import triang
7 |
8 | def lg10(x):
9 | return torch.div(torch.log(x), math.log(10))
10 |
11 | def maxOfTwo(x, y):
12 | z = x.clone()
13 | maskYLarger = torch.lt(x, y)
14 | z[maskYLarger.detach()] = y[maskYLarger.detach()]
15 | return z
16 |
17 | def nValid(x):
18 | return torch.sum(torch.eq(x, x).float())
19 |
20 | def getNanMask(x):
21 | return torch.ne(x, x)
22 |
23 | def setNanToZero(input, target):
24 | nanMask = getNanMask(target)
25 | nValidElement = nValid(target)
26 |
27 | _input = input.clone()
28 | _target = target.clone()
29 |
30 | _input[nanMask] = 0
31 | _target[nanMask] = 0
32 |
33 | return _input, _target, nanMask, nValidElement
34 |
35 | class Evaluator:
36 | def __init__(self):
37 | self.shot_idx = {
38 | 'many': [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
39 | 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 49],
40 | 'medium': [7, 8, 46, 48, 50, 51, 52, 53, 54, 55, 56, 58, 60, 61, 63],
41 | 'few': [0, 1, 2, 3, 4, 5, 6, 57, 59, 62, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
42 | 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
43 | }
44 | self.output = torch.tensor([], dtype=torch.float32)
45 | self.depth = torch.tensor([], dtype=torch.float32)
46 |
47 | def __call__(self, output, depth):
48 | output = output.squeeze().view(-1).cpu()
49 | depth = depth.squeeze().view(-1).cpu()
50 | self.output = torch.cat([self.output, output])
51 | self.depth = torch.cat([self.depth, depth])
52 |
53 | def evaluate_shot(self):
54 | metric_dict = {'overall': {}, 'many': {}, 'medium': {}, 'few': {}}
55 | self.depth_bucket = np.array(list(map(lambda v: self.get_bin_idx(v), self.depth.cpu().numpy())))
56 |
57 | for shot in metric_dict.keys():
58 | if shot == 'overall':
59 | metric_dict[shot] = self.evaluate(self.output, self.depth)
60 | else:
61 | mask = np.zeros(self.depth.size(0), dtype=bool)  # np.bool is deprecated/removed in modern NumPy
62 | for i in self.shot_idx[shot]:
63 | mask[np.where(self.depth_bucket == i)[0]] = True
64 | mask = torch.tensor(mask, dtype=torch.bool)
65 | metric_dict[shot] = self.evaluate(self.output[mask], self.depth[mask])
66 |
67 | logging.info('\n***** TEST RESULTS *****')
68 | for shot in ['Overall', 'Many', 'Medium', 'Few']:
69 | logging.info(f" * {shot}: RMSE {metric_dict[shot.lower()]['RMSE']:.3f}\t"
70 | f"ABS_REL {metric_dict[shot.lower()]['ABS_REL']:.3f}\t"
71 | f"LG10 {metric_dict[shot.lower()]['LG10']:.3f}\t"
72 | f"MAE {metric_dict[shot.lower()]['MAE']:.3f}\t"
73 | f"DELTA1 {metric_dict[shot.lower()]['DELTA1']:.3f}\t"
74 | f"DELTA2 {metric_dict[shot.lower()]['DELTA2']:.3f}\t"
75 | f"DELTA3 {metric_dict[shot.lower()]['DELTA3']:.3f}\t"
76 | f"NUM {metric_dict[shot.lower()]['NUM']}")
77 |
78 | return metric_dict
79 |
80 | def reset(self):
81 | self.output = torch.tensor([], dtype=torch.float32)
82 | self.depth = torch.tensor([], dtype=torch.float32)
83 |
84 | @staticmethod
85 | def get_bin_idx(x):
86 | return min(int(x * np.float32(10)), 99)
87 |
88 | @staticmethod
89 | def evaluate(output, target):
90 | errors = {'MSE': 0, 'RMSE': 0, 'ABS_REL': 0, 'LG10': 0,
91 | 'MAE': 0, 'DELTA1': 0, 'DELTA2': 0, 'DELTA3': 0, 'NUM': 0}
92 |
93 | _output, _target, nanMask, nValidElement = setNanToZero(output, target)
94 |
95 | if (nValidElement.data.cpu().numpy() > 0):
96 | diffMatrix = torch.abs(_output - _target)
97 |
98 | errors['MSE'] = torch.sum(torch.pow(diffMatrix, 2)) / nValidElement
99 |
100 | errors['MAE'] = torch.sum(diffMatrix) / nValidElement
101 |
102 | realMatrix = torch.div(diffMatrix, _target)
103 | realMatrix[nanMask] = 0
104 | errors['ABS_REL'] = torch.sum(realMatrix) / nValidElement
105 |
106 | LG10Matrix = torch.abs(lg10(_output) - lg10(_target))
107 | LG10Matrix[nanMask] = 0
108 | errors['LG10'] = torch.sum(LG10Matrix) / nValidElement
109 |
110 | yOverZ = torch.div(_output, _target)
111 | zOverY = torch.div(_target, _output)
112 |
113 | maxRatio = maxOfTwo(yOverZ, zOverY)
114 |
115 | errors['DELTA1'] = torch.sum(
116 | torch.le(maxRatio, 1.25).float()) / nValidElement
117 | errors['DELTA2'] = torch.sum(
118 | torch.le(maxRatio, math.pow(1.25, 2)).float()) / nValidElement
119 | errors['DELTA3'] = torch.sum(
120 | torch.le(maxRatio, math.pow(1.25, 3)).float()) / nValidElement
121 |
122 | errors['MSE'] = float(errors['MSE'].data.cpu().numpy())
123 | errors['ABS_REL'] = float(errors['ABS_REL'].data.cpu().numpy())
124 | errors['LG10'] = float(errors['LG10'].data.cpu().numpy())
125 | errors['MAE'] = float(errors['MAE'].data.cpu().numpy())
126 | errors['DELTA1'] = float(errors['DELTA1'].data.cpu().numpy())
127 | errors['DELTA2'] = float(errors['DELTA2'].data.cpu().numpy())
128 | errors['DELTA3'] = float(errors['DELTA3'].data.cpu().numpy())
129 | errors['NUM'] = int(nValidElement)
130 |
131 | errors['RMSE'] = np.sqrt(errors['MSE'])
132 |
133 | return errors
134 |
135 |
136 | def query_yes_no(question):
137 | """Ask a yes/no question via input() and return the answer as a bool (default: yes)."""
138 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
139 | prompt = " [Y/n] "
140 |
141 | while True:
142 | print(question + prompt, end=':')
143 | choice = input().lower()
144 | if choice == '':
145 | return valid['y']
146 | elif choice in valid:
147 | return valid[choice]
148 | else:
149 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")
150 |
151 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.2, clip_max=5.):
152 | if torch.sum(v1) < 1e-10:
153 | return matrix
154 | if (v1 <= 0.).any() or (v2 < 0.).any():
155 | valid_pos = (((v1 > 0.) + (v2 >= 0.)) == 2)
156 | # print(torch.sum(valid_pos))
157 | factor = torch.clamp(v2[valid_pos] / v1[valid_pos], clip_min, clip_max)
158 | matrix[:, valid_pos] = (matrix[:, valid_pos] - m1[valid_pos]) * torch.sqrt(factor) + m2[valid_pos]
159 | return matrix
160 |
161 | factor = torch.clamp(v2 / v1, clip_min, clip_max)
162 | return (matrix - m1) * torch.sqrt(factor) + m2
163 |
164 | def get_lds_kernel_window(kernel, ks, sigma):
165 | assert kernel in ['gaussian', 'triang', 'laplace']
166 | half_ks = (ks - 1) // 2
167 | if kernel == 'gaussian':
168 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
169 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
170 | elif kernel == 'triang':
171 | kernel_window = triang(ks)
172 | else:
173 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
174 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))
175 |
176 | return kernel_window
177 |
178 |
179 |
180 |
181 |
--------------------------------------------------------------------------------
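For reference, the Laplace branch of `get_lds_kernel_window` above has a simple closed form; a NumPy-only sketch (the helper name is ours), max-normalized like the Gaussian branch so the center weight is 1:

```python
import numpy as np

def laplace_window(ks, sigma):
    # Symmetric Laplace kernel window of odd size ks; the 1/(2*sigma)
    # density factor cancels under max-normalization.
    half_ks = (ks - 1) // 2
    xs = np.arange(-half_ks, half_ks + 1)
    w = np.exp(-np.abs(xs) / sigma)
    return w / w.max()
```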
/sts-b-dir/README.md:
--------------------------------------------------------------------------------
1 | # STS-B-DIR
2 | ## Installation
3 |
4 | #### Prerequisites
5 |
6 | 1. Download GloVe word embeddings (840B tokens, 300D vectors) using
7 |
8 | ```bash
9 | python glove/download_glove.py
10 | ```
11 |
12 | 2. __(Optional)__ We provide both the original STS-B dataset and our balanced STS-B-DIR dataset in the folder `./glue_data/STS-B`. To reproduce the results in the paper, please use the provided STS-B-DIR dataset. If you want to try different balanced splits, you can delete the folder `./glue_data/STS-B` and run
13 |
14 | ```bash
15 | python glue_data/create_sts.py
16 | ```
17 |
18 | #### Dependencies
19 |
20 | The required dependencies for this task are quite different from those of the other three tasks, so it is best to create a new environment for it. If you use conda, you can create the environment and install the dependencies with the following commands:
21 |
22 | ```bash
23 | conda create -n sts python=3.6
24 | conda activate sts
25 | # PyTorch 0.4 (required) + CUDA 9.2
26 | conda install pytorch=0.4.1 cuda92 -c pytorch
27 | # other dependencies
28 | pip install -r requirements.txt
29 | # The latest "overrides" version installed alongside allennlp 0.5.0 raises an error,
30 | # so downgrade "overrides" to version 3.1.0
31 | pip install overrides==3.1.0
32 | ```
33 |
34 | ## Code Overview
35 |
36 | #### Main Files
37 |
38 | - `train.py`: main training and evaluation script
39 | - `create_sts.py`: download the original STS-B dataset and create the STS-B-DIR dataset with balanced val/test sets
40 |
41 | #### Main Arguments
42 |
43 | - `--lds`: LDS switch (whether to enable LDS)
44 | - `--fds`: FDS switch (whether to enable FDS)
45 | - `--reweight`: cost-sensitive re-weighting scheme to use
46 | - `--loss`: training loss type
47 | - `--resume`: whether to resume training (only for training)
48 | - `--evaluate`: evaluate only flag
49 | - `--eval_model`: path to resume checkpoint (only for evaluation)
50 | - `--retrain_fc`: whether to retrain regressor
51 | - `--pretrained`: path to load backbone weights for regressor re-training (RRT)
52 | - `--val_interval`: number of iterations between validation checks
53 | - `--patience`: patience (number of validation checks) for early stopping
54 |
55 | ## Getting Started
56 |
57 | #### Train a vanilla model
58 |
59 | ```bash
60 | python train.py --cuda --reweight none
61 | ```
62 |
63 | Always specify `--cuda` with the ID of the (single) GPU to be used. We omit this argument in the following commands for simplicity.
64 |
65 | #### Train a model using re-weighting
66 |
67 | To perform inverse re-weighting
68 |
69 | ```bash
70 | python train.py --reweight inverse
71 | ```
72 |
73 | To perform square-root inverse re-weighting
74 |
75 | ```bash
76 | python train.py --reweight sqrt_inv
77 | ```
78 |
79 | #### Train a model with different losses
80 |
81 | To use Focal-R loss
82 |
83 | ```bash
84 | python train.py --loss focal_mse
85 | ```
86 |
87 | To use Huber loss
88 |
89 | ```bash
90 | python train.py --loss huber --huber_beta 0.3
91 | ```
92 |
93 | #### Train a model using RRT
94 |
95 | ```bash
96 | python train.py [...retrained model arguments...] --retrain_fc --pretrained
97 | ```
98 |
99 | #### Train a model using LDS
100 |
101 | To use Gaussian kernel (kernel size: 5, sigma: 2)
102 |
103 | ```bash
104 | python train.py --reweight inverse --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2
105 | ```
106 |
107 | #### Train a model using FDS
108 |
109 | To use Gaussian kernel (kernel size: 5, sigma: 2)
110 |
111 | ```bash
112 | python train.py --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
113 | ```
114 |
115 | #### Train a model using LDS + FDS
116 |
117 | ```bash
118 | python train.py --reweight inverse --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
119 | ```
120 |
121 | #### Evaluate a trained checkpoint
122 |
123 | ```bash
124 | python train.py [...evaluation model arguments...] --evaluate --eval_model
125 | ```
126 |
127 | ## Reproduced Benchmarks and Model Zoo
128 |
129 | We provide below the reproduced results on STS-B-DIR (base method `Vanilla`, metric `MSE`).
130 | Note that some models may give **better** results than the numbers reported in the paper.
131 |
132 | | Model | Overall | Many-Shot | Medium-Shot | Few-Shot | Download |
133 | | :-------: | :-----: | :-------: | :---------: | :------: | :------: |
134 | | LDS | 0.914 | 0.819 | 1.319 | 0.955 | [model](https://drive.google.com/file/d/1CVyycq0OMgD9N9gJX5UDcfpRaJdZBjzo/view?usp=sharing) |
135 | | FDS | 0.916 | 0.875 | 1.027 | 1.086 | [model](https://drive.google.com/file/d/13e-1kd-KQrzFFVrJp1FeNDIBwUp3qtYx/view?usp=sharing) |
136 | | LDS + FDS | 0.907 | 0.802 | 1.363 | 0.942 | [model](https://drive.google.com/file/d/1kb_GV2coJRK_o9OxnMcxchq1EKOcpx-h/view?usp=sharing) |
137 |
--------------------------------------------------------------------------------
/sts-b-dir/allennlp_mods/numeric_field.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 | import pdb
3 | import logging
4 |
5 | from overrides import overrides
6 | import numpy
7 | import torch
8 | from torch.autograd import Variable
9 |
10 | from allennlp.data.fields.field import Field
11 | from allennlp.data.vocabulary import Vocabulary
12 | from allennlp.common.checks import ConfigurationError
13 |
14 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
15 |
16 |
17 | class NumericField(Field[numpy.ndarray]):
18 | """
19 | A ``NumericField`` is a real-valued label (e.g., a regression target). Unlike
20 | ``LabelField``, the label is stored directly as a float and is never indexed
21 | through a :class:`Vocabulary`.
22 |
23 | This field gets converted into a 1-element float tensor holding the label value.
24 |
25 | Parameters
26 | ----------
27 | label : ``Union[float, int]``
28 | The numeric label value.
29 | label_namespace : ``str``, optional (default="labels")
30 | The namespace to use for this label. We recommend a namespace ending with
31 | "labels" or "tags" so that UNK and PAD tokens are not added to your
32 | vocabulary by default; see the ``non_padded_namespaces`` parameter in
33 | ``Vocabulary``.
34 | """
41 | def __init__(self,
42 | label: Union[float, int],
43 | label_namespace: str = 'labels') -> None:
44 | self.label = label
45 | self._label_namespace = label_namespace
46 | self._label_id = numpy.array(label, dtype=numpy.float32)
47 | if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
48 | logger.warning("Your label namespace was '%s'. We recommend you use a namespace "
49 | "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
50 | "default to your vocabulary. See documentation for "
51 | "`non_padded_namespaces` parameter in Vocabulary.", self._label_namespace)
52 |
53 | # Numeric labels are stored directly as floats, so there is normally nothing to index or count here.
54 | @overrides
55 | def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
56 | if self._label_id is None:
57 | counter[self._label_namespace][self.label] += 1 # type: ignore
58 |
59 | @overrides
60 | def get_padding_lengths(self) -> Dict[str, int]: # pylint: disable=no-self-use
61 | return {}
62 |
63 | def as_array(self, padding_lengths: Dict[str, int]) -> numpy.ndarray: # pylint: disable=unused-argument
64 | return numpy.asarray([self._label_id])
65 |
66 | @overrides
67 | def as_tensor(self, padding_lengths: Dict[str, int],
68 | cuda_device: int = -1,
69 | for_training: bool = True) -> torch.Tensor: # pylint: disable=unused-argument
70 | label_id = self._label_id.tolist()
71 | tensor = Variable(torch.FloatTensor([label_id]), volatile=not for_training)
72 | return tensor if cuda_device == -1 else tensor.cuda(cuda_device)
73 |
74 | @overrides
75 | def empty_field(self):
76 | return NumericField(0, self._label_namespace)
77 |
--------------------------------------------------------------------------------
/sts-b-dir/evaluate.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import tqdm
3 | import numpy as np
4 |
5 | def evaluate(model, tasks, iterator, cuda_device, split="val"):
6 | '''Evaluate on a dataset'''
7 | model.eval()
8 |
9 | all_preds = {}
10 | n_overall_examples = 0
11 | for task in tasks:
12 | n_examples = 0
13 | task_preds, task_idxs, task_labels = [], [], []
14 | if split == "val":
15 | dataset = task.val_data
16 | elif split == 'train':
17 | dataset = task.train_data
18 | elif split == "test":
19 | dataset = task.test_data
20 | generator = iterator(dataset, num_epochs=1, shuffle=False, cuda_device=cuda_device)
21 | generator_tqdm = tqdm.tqdm(generator, total=iterator.get_num_batches(dataset), disable=True)
22 | for batch in generator_tqdm:
23 | tensor_batch = batch
24 | out = model.forward(task, **tensor_batch)
25 | n_examples += batch['label'].size()[0]
26 | preds, _ = out['logits'].max(dim=1)
27 | task_preds += list(preds.data.cpu().numpy())
28 | task_labels += list(batch['label'].squeeze().data.cpu().numpy())
29 |
30 | task_metrics = task.get_metrics(reset=True)
31 | logging.info('\n***** TEST RESULTS *****')
32 | for shot in ['Overall', 'Many', 'Medium', 'Few']:
33 | logging.info(f" * {shot}: MSE {task_metrics[shot.lower()]['mse']:.3f}\t"
34 | f"L1 {task_metrics[shot.lower()]['l1']:.3f}\t"
35 | f"G-Mean {task_metrics[shot.lower()]['gmean']:.3f}\t"
36 | f"Pearson {task_metrics[shot.lower()]['pearsonr']:.3f}\t"
37 | f"Spearman {task_metrics[shot.lower()]['spearmanr']:.3f}\t"
38 | f"Number {task_metrics[shot.lower()]['num_samples']}")
39 |
40 | n_overall_examples += n_examples
41 | task_preds = [min(max(np.float32(0.), pred * np.float32(5.)), np.float32(5.)) for pred in task_preds]
42 | all_preds[task.name] = (task_preds, task_idxs)
43 |
44 | return task_preds, task_labels, task_metrics['overall']['mse']
--------------------------------------------------------------------------------
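The post-processing at the end of `evaluate` above maps normalized predictions back to the STS-B score range [0, 5] and clips out-of-range values (`min(max(0, pred * 5), 5)`). A minimal sketch (the helper name is ours):

```python
def rescale_preds(preds, scale=5.0):
    # Map predictions in [0, 1] back to [0, scale], clipping out-of-range values.
    return [float(min(max(0.0, p * scale), scale)) for p in preds]
```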
/sts-b-dir/fds.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import logging
6 | from scipy.ndimage import gaussian_filter1d
7 | from util import calibrate_mean_var
8 | from scipy.signal.windows import triang
9 |
10 | class FDS(nn.Module):
11 |
12 | def __init__(self, feature_dim, bucket_num=50, bucket_start=0, start_update=0, start_smooth=1,
13 | kernel='gaussian', ks=5, sigma=2, momentum=0.9):
14 | super(FDS, self).__init__()
15 | self.feature_dim = feature_dim
16 | self.bucket_num = bucket_num
17 | self.bucket_start = bucket_start
18 | self.kernel_window = self._get_kernel_window(kernel, ks, sigma)
19 | self.half_ks = (ks - 1) // 2
20 | self.momentum = momentum
21 | self.start_update = start_update
22 | self.start_smooth = start_smooth
23 |
24 | self.register_buffer('epoch', torch.zeros(1).fill_(start_update))
25 | self.register_buffer('running_mean', torch.zeros(bucket_num - bucket_start, feature_dim))
26 | self.register_buffer('running_var', torch.ones(bucket_num - bucket_start, feature_dim))
27 | self.register_buffer('running_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim))
28 | self.register_buffer('running_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim))
29 | self.register_buffer('smoothed_mean_last_epoch', torch.zeros(bucket_num - bucket_start, feature_dim))
30 | self.register_buffer('smoothed_var_last_epoch', torch.ones(bucket_num - bucket_start, feature_dim))
31 | self.register_buffer('num_samples_tracked', torch.zeros(bucket_num - bucket_start))
32 |
33 | @staticmethod
34 | def _get_kernel_window(kernel, ks, sigma):
35 | assert kernel in ['gaussian', 'triang', 'laplace']
36 | half_ks = (ks - 1) // 2
37 | if kernel == 'gaussian':
38 | base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
39 | base_kernel = np.array(base_kernel, dtype=np.float32)
40 | kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / sum(gaussian_filter1d(base_kernel, sigma=sigma))
41 | elif kernel == 'triang':
42 | kernel_window = triang(ks) / sum(triang(ks))
43 | else:
44 | laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
45 | kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / sum(
46 | map(laplace, np.arange(-half_ks, half_ks + 1)))
47 |
48 | logging.info(f'Using FDS: [{kernel.upper()}] ({ks}/{sigma})')
49 | return torch.tensor(kernel_window, dtype=torch.float32).cuda()
50 |
51 | def _get_bucket_idx(self, label):
52 | label = np.float32(label)
53 | _, bins_edges = np.histogram(a=np.array([], dtype=np.float32), bins=self.bucket_num, range=(0., 5.))
54 | if label == 5.:
55 | return self.bucket_num - 1
56 | else:
57 | return max(np.where(bins_edges > label)[0][0] - 1, self.bucket_start)
58 |
59 | def _update_last_epoch_stats(self):
60 | self.running_mean_last_epoch = self.running_mean
61 | self.running_var_last_epoch = self.running_var
62 |
63 | self.smoothed_mean_last_epoch = F.conv1d(
64 | input=F.pad(self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0),
65 | pad=(self.half_ks, self.half_ks), mode='reflect'),
66 | weight=self.kernel_window.view(1, 1, -1), padding=0
67 | ).permute(2, 1, 0).squeeze(1)
68 | self.smoothed_var_last_epoch = F.conv1d(
69 | input=F.pad(self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0),
70 | pad=(self.half_ks, self.half_ks), mode='reflect'),
71 | weight=self.kernel_window.view(1, 1, -1), padding=0
72 | ).permute(2, 1, 0).squeeze(1)
73 |
74 | def reset(self):
75 | self.running_mean.zero_()
76 | self.running_var.fill_(1)
77 | self.running_mean_last_epoch.zero_()
78 | self.running_var_last_epoch.fill_(1)
79 | self.smoothed_mean_last_epoch.zero_()
80 | self.smoothed_var_last_epoch.fill_(1)
81 | self.num_samples_tracked.zero_()
82 |
83 | def update_last_epoch_stats(self, epoch):
84 | if epoch == self.epoch + 1:
85 | self.epoch += 1
86 | self._update_last_epoch_stats()
87 | logging.info(f"Updated smoothed statistics of last epoch on Epoch [{epoch}]!")
88 |
89 | def update_running_stats(self, features, labels, epoch):
90 | if epoch < self.epoch:
91 | return
92 |
93 | assert self.feature_dim == features.size(1), "Input feature dimension is not aligned!"
94 | assert features.size(0) == labels.size(0), "Dimensions of features and labels are not aligned!"
95 |
96 | buckets = np.array([self._get_bucket_idx(label) for label in labels])
97 | for bucket in np.unique(buckets):
98 |             curr_feats = features[torch.tensor((buckets == bucket).astype(bool))]
99 | curr_num_sample = curr_feats.size(0)
100 | curr_mean = torch.mean(curr_feats, 0)
101 |             curr_var = torch.var(curr_feats, 0, unbiased=curr_feats.size(0) != 1)
102 |
103 | self.num_samples_tracked[bucket - self.bucket_start] += curr_num_sample
104 | factor = self.momentum if self.momentum is not None else \
105 | (1 - curr_num_sample / float(self.num_samples_tracked[bucket - self.bucket_start]))
106 | factor = 0 if epoch == self.start_update else factor
107 | self.running_mean[bucket - self.bucket_start] = \
108 | (1 - factor) * curr_mean + factor * self.running_mean[bucket - self.bucket_start]
109 | self.running_var[bucket - self.bucket_start] = \
110 | (1 - factor) * curr_var + factor * self.running_var[bucket - self.bucket_start]
111 |
112 |         # back-fill statistics for buckets that received no training samples
113 | for bucket in range(self.bucket_start, self.bucket_num):
114 | if bucket not in np.unique(buckets):
115 | if bucket == self.bucket_start:
116 | self.running_mean[0] = self.running_mean[1]
117 | self.running_var[0] = self.running_var[1]
118 | elif bucket == self.bucket_num - 1:
119 | self.running_mean[bucket - self.bucket_start] = self.running_mean[bucket - self.bucket_start - 1]
120 | self.running_var[bucket - self.bucket_start] = self.running_var[bucket - self.bucket_start - 1]
121 | else:
122 | self.running_mean[bucket - self.bucket_start] = (self.running_mean[bucket - self.bucket_start - 1] +
123 | self.running_mean[bucket - self.bucket_start + 1]) / 2.
124 | self.running_var[bucket - self.bucket_start] = (self.running_var[bucket - self.bucket_start - 1] +
125 | self.running_var[bucket - self.bucket_start + 1]) / 2.
126 | logging.info(f"Updated running statistics with Epoch [{epoch}] features!")
127 |
128 | def smooth(self, features, labels, epoch):
129 | if epoch < self.start_smooth:
130 | return features
131 |
132 | labels = labels.squeeze(1)
133 | buckets = np.array([self._get_bucket_idx(label) for label in labels])
134 | for bucket in np.unique(buckets):
135 |             features[torch.tensor((buckets == bucket).astype(bool))] = calibrate_mean_var(
136 |                 features[torch.tensor((buckets == bucket).astype(bool))],
137 | self.running_mean_last_epoch[bucket - self.bucket_start],
138 | self.running_var_last_epoch[bucket - self.bucket_start],
139 | self.smoothed_mean_last_epoch[bucket - self.bucket_start],
140 | self.smoothed_var_last_epoch[bucket - self.bucket_start]
141 | )
142 |
143 | return features
144 |
--------------------------------------------------------------------------------
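A minimal, numpy-only sketch of the bucket lookup that `FDS._get_bucket_idx` performs (same empty-histogram trick to obtain bin edges, same handling of the top label; the 0–5 range and 50 buckets below mirror the STS-B defaults but are otherwise just example values):

```python
import numpy as np

def get_bucket_idx(label, bucket_num=50, bucket_start=0, value_range=(0., 5.)):
    # Bucket edges come from an empty histogram over the label range;
    # the top of the range maps to the last bucket.
    _, bins_edges = np.histogram(a=np.array([], dtype=np.float32),
                                 bins=bucket_num, range=value_range)
    if label == value_range[1]:
        return bucket_num - 1
    return max(np.where(bins_edges > label)[0][0] - 1, bucket_start)

print(get_bucket_idx(0.0), get_bucket_idx(2.55), get_bucket_idx(5.0))   # 0 25 49
```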
/sts-b-dir/glove/download_glove.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wget
3 | import zipfile
4 |
5 | print("Downloading and extracting GloVe word embeddings...")
6 | data_file = "./glove/glove.840B.300d.zip"
7 | wget.download("https://nlp.stanford.edu/data/glove.840B.300d.zip", out=data_file)
8 | with zipfile.ZipFile(data_file) as zip_ref:
9 | zip_ref.extractall('./glove')
10 | os.remove(data_file)
11 | print("\nCompleted!")
--------------------------------------------------------------------------------
/sts-b-dir/glue_data/create_sts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import codecs
4 | import numpy as np
5 | import urllib
6 | if sys.version_info >= (3, 0):
7 | import urllib.request
8 | import zipfile
9 |
10 | URLLIB=urllib
11 | if sys.version_info >= (3, 0):
12 | URLLIB=urllib.request
13 |
14 | ##### Downloading raw STS-B dataset
15 | print("Downloading and extracting STS-B...")
16 | data_file = "./glue_data/STS-B.zip"
17 | URLLIB.urlretrieve("https://dl.fbaipublicfiles.com/glue/data/STS-B.zip", data_file)
18 | with zipfile.ZipFile(data_file) as zip_ref:
19 | zip_ref.extractall('./glue_data')
20 | os.remove(data_file)
21 | print("Completed!")
22 |
23 | ##### Creating STS-B-DIR dataset
24 | print("Creating STS-B-DIR dataset...")
25 | contents = {'train': [], 'dev': [], 'test': []}
26 | target = {'train': [], 'dev': [], 'test': []}
27 |
28 | for set_name in ['train', 'dev']:
29 | with codecs.open(f'./glue_data/STS-B/{set_name}.tsv', 'r', 'utf-8') as data_fh:
30 |         # read the header row once and keep it for the new files
31 |         titles = data_fh.readline()
32 | for row_idx, row in enumerate(data_fh):
33 | contents[set_name].append(row)
34 | row = row.strip().split('\t')
35 | targ = row[9]
36 | target[set_name].append(np.float32(targ))
37 |
38 | bins = 20
39 | target_all = target['train'] + target['dev']
40 | contents_all = contents['train'] + contents['dev']
41 |
42 | bins_numbers, bins_edges = np.histogram(a=target_all, bins=bins, range=(0., 5.))
43 |
44 | def _get_bin_idx(label):
45 | if label == 5.:
46 | return bins - 1
47 | else:
48 | return np.where(bins_edges > label)[0][0] - 1
49 |
50 | bins_contents = [None] * bins
51 | bins_targets = [None] * bins
52 | bins_numbers = list(bins_numbers)
53 |
54 | for i, score in enumerate(target_all):
55 | bin_idx = _get_bin_idx(score)
56 | if bins_contents[bin_idx] is None:
57 | bins_contents[bin_idx] = []
58 | if bins_targets[bin_idx] is None:
59 | bins_targets[bin_idx] = []
60 | bins_contents[bin_idx].append(contents_all[i])
61 | bins_targets[bin_idx].append(score)
62 | contents_bins_numbers = []
63 | targets_bins_numbers = []
64 | for i in range(bins):
65 | contents_bins_numbers.append(len(bins_contents[i]))
66 | targets_bins_numbers.append(len(bins_targets[i]))
67 |
68 | new_contents = {'train': [None] * bins, 'dev': [None] * bins, 'test': [None] * bins}
69 | new_targets = {'train': [None] * bins, 'dev': [None] * bins, 'test': [None] * bins}
70 | select_num = 100
71 | for i in range(bins):
72 | new_index = {}
73 | new_dev_test_index = np.random.choice(a=list(range(bins_numbers[i])), size=select_num, replace=False)
74 | new_index['train'] = np.setdiff1d(np.array(range(bins_numbers[i])), new_dev_test_index)
75 | new_index['dev'] = np.random.choice(a=new_dev_test_index, size=int(select_num / 2), replace=False)
76 | new_index['test'] = np.setdiff1d(new_dev_test_index, new_index['dev'])
77 | for set_name in ['train', 'dev', 'test']:
78 | new_contents[set_name][i] = np.array(bins_contents[i])[new_index[set_name]]
79 | new_targets[set_name][i] = np.array(bins_targets[i])[new_index[set_name]]
80 |
81 | new_contents_merged = {'train': [], 'dev': [], 'test': []}
82 | for i in range(bins):
83 | for set_name in ['train', 'dev', 'test']:
84 | new_contents_merged[set_name] += new_contents[set_name][i].tolist()
85 | print('Number of samples for train/dev/test set in STS-B-DIR:',
86 | len(new_contents_merged['train']), len(new_contents_merged['dev']), len(new_contents_merged['test']))
87 |
88 | for set_name in ['train', 'dev', 'test']:
89 | for i in range(len(new_contents_merged[set_name])):
90 | content_split = new_contents_merged[set_name][i].split('\t')
91 | content_split[0] = str(i)
92 | content_split = '\t'.join(content_split)
93 | new_contents_merged[set_name][i] = content_split
94 | for set_name in ['train', 'dev', 'test']:
95 | with open(f'./glue_data/STS-B/{set_name}_new.tsv', 'w') as f:
96 | f.write(titles)
97 | for i in range(len(new_contents_merged[set_name])):
98 | f.write(new_contents_merged[set_name][i])
99 | print("STS-B-DIR dataset created! ('./glue_data/STS-B/(train_new/dev_new/test_new).tsv')")
100 |
--------------------------------------------------------------------------------
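The per-bin split in `create_sts.py` holds out `select_num = 100` samples from each bin, half for dev and half for test, leaving the rest for train. An isolated sketch of that logic (the bin size of 150 is a made-up example; the script uses the real per-bin counts):

```python
import numpy as np

rng = np.random.default_rng(0)
bin_size, select_num = 150, 100              # hypothetical bin with 150 samples
dev_test_idx = rng.choice(bin_size, size=select_num, replace=False)
train_idx = np.setdiff1d(np.arange(bin_size), dev_test_idx)
dev_idx = rng.choice(dev_test_idx, size=select_num // 2, replace=False)
test_idx = np.setdiff1d(dev_test_idx, dev_idx)
print(len(train_idx), len(dev_idx), len(test_idx))   # 50 50 50
```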
/sts-b-dir/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def weighted_mse_loss(inputs, targets, weights=None):
6 | loss = F.mse_loss(inputs, targets, reduce=False)
7 | if weights is not None:
8 | loss *= weights.expand_as(loss)
9 | loss = torch.mean(loss)
10 | return loss
11 |
12 |
13 | def weighted_l1_loss(inputs, targets, weights=None):
14 | loss = F.l1_loss(inputs, targets, reduce=False)
15 | if weights is not None:
16 | loss *= weights.expand_as(loss)
17 | loss = torch.mean(loss)
18 | return loss
19 |
20 |
21 | def weighted_huber_loss(inputs, targets, weights=None, beta=0.5):
22 | l1_loss = torch.abs(inputs - targets)
23 | cond = l1_loss < beta
24 | loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta)
25 | if weights is not None:
26 | loss *= weights.expand_as(loss)
27 | loss = torch.mean(loss)
28 | return loss
29 |
30 |
31 | def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=20., gamma=1):
32 | loss = F.mse_loss(inputs, targets, reduce=False)
33 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
34 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
35 | if weights is not None:
36 | loss *= weights.expand_as(loss)
37 | loss = torch.mean(loss)
38 | return loss
39 |
40 |
41 | def weighted_focal_l1_loss(inputs, targets, weights=None, activate='sigmoid', beta=20., gamma=1):
42 | loss = F.l1_loss(inputs, targets, reduce=False)
43 | loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
44 | (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
45 | if weights is not None:
46 | loss *= weights.expand_as(loss)
47 | loss = torch.mean(loss)
48 | return loss
49 |
--------------------------------------------------------------------------------
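As a sanity check on `weighted_huber_loss`, here is a pure-Python scalar version (lists instead of tensors, same piecewise formula: quadratic below `beta`, linear beyond, then a weighted mean):

```python
def weighted_huber(inputs, targets, weights=None, beta=0.5):
    # Per-element Huber loss, optionally weighted, averaged at the end.
    losses = []
    for i, (x, y) in enumerate(zip(inputs, targets)):
        d = abs(x - y)
        l = 0.5 * d * d / beta if d < beta else d - 0.5 * beta
        if weights is not None:
            l *= weights[i]
        losses.append(l)
    return sum(losses) / len(losses)

print(weighted_huber([0.0, 1.0], [0.0, 0.0]))   # 0.375
```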
/sts-b-dir/models.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 |
4 | from allennlp.common import Params
5 | from allennlp.models.model import Model
6 | from allennlp.modules import Highway
7 | from allennlp.modules import TimeDistributed
8 | from allennlp.nn import util, InitializerApplicator
9 | from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
10 | from allennlp.modules.token_embedders import Embedding
11 | from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder as s2s_e
12 |
13 | from fds import FDS
14 | from loss import *
15 |
16 | def build_model(args, vocab, pretrained_embs, tasks):
17 | '''
18 | Build model according to arguments
19 | '''
20 | d_word, n_layers_highway = args.d_word, args.n_layers_highway
21 |
22 | # Build embedding layers
23 | if args.glove:
24 | word_embs = pretrained_embs
25 | train_embs = bool(args.train_words)
26 | else:
27 | logging.info("\tLearning embeddings from scratch!")
28 | word_embs = None
29 | train_embs = True
30 | word_embedder = Embedding(vocab.get_vocab_size('tokens'), d_word, weight=word_embs, trainable=train_embs,
31 | padding_index=vocab.get_token_index('@@PADDING@@'))
32 | d_inp_phrase = 0
33 |
34 | token_embedder = {"words": word_embedder}
35 | d_inp_phrase += d_word
36 | text_field_embedder = BasicTextFieldEmbedder(token_embedder)
37 | d_hid_phrase = args.d_hid
38 |
39 | # Build encoders
40 | phrase_layer = s2s_e.by_name('lstm').from_params(Params({'input_size': d_inp_phrase,
41 | 'hidden_size': d_hid_phrase,
42 | 'num_layers': args.n_layers_enc,
43 | 'bidirectional': True}))
44 | pair_encoder = HeadlessPairEncoder(vocab, text_field_embedder, n_layers_highway,
45 | phrase_layer, dropout=args.dropout)
46 | d_pair = 2 * d_hid_phrase
47 |
48 | if args.fds:
49 | _FDS = FDS(feature_dim=d_pair * 4, bucket_num=args.bucket_num, bucket_start=args.bucket_start,
50 | start_update=args.start_update, start_smooth=args.start_smooth,
51 | kernel=args.fds_kernel, ks=args.fds_ks, sigma=args.fds_sigma, momentum=args.fds_mmt)
52 |
53 | # Build model and classifiers
54 | model = MultiTaskModel(args, pair_encoder, _FDS if args.fds else None)
55 | build_regressor(tasks, model, d_pair)
56 |
57 | if args.cuda >= 0:
58 | model = model.cuda()
59 |
60 | return model
61 |
62 | def build_regressor(tasks, model, d_pair):
63 | '''
64 | Build the regressor
65 | '''
66 | for task in tasks:
67 | d_task = d_pair * 4
68 | model.build_regressor(task, d_task)
69 | return
70 |
71 | class MultiTaskModel(nn.Module):
72 | def __init__(self, args, pair_encoder, FDS=None):
73 | super(MultiTaskModel, self).__init__()
74 | self.args = args
75 | self.pair_encoder = pair_encoder
76 |
77 | self.FDS = FDS
78 | self.start_smooth = args.start_smooth
79 |
80 | def build_regressor(self, task, d_inp):
81 | layer = nn.Linear(d_inp, 1)
82 | setattr(self, '%s_pred_layer' % task.name, layer)
83 |
84 | def forward(self, task=None, epoch=None, input1=None, input2=None, mask1=None, mask2=None, label=None, weight=None):
85 | pred_layer = getattr(self, '%s_pred_layer' % task.name)
86 |
87 | pair_emb = self.pair_encoder(input1, input2, mask1, mask2)
88 | pair_emb_s = pair_emb
89 | if self.training and self.FDS is not None:
90 | if epoch >= self.start_smooth:
91 | pair_emb_s = self.FDS.smooth(pair_emb_s, label, epoch)
92 | logits = pred_layer(pair_emb_s)
93 |
94 | out = {}
95 | if self.training and self.FDS is not None:
96 | out['embs'] = pair_emb
97 | out['labels'] = label
98 |
99 | if self.args.loss == 'huber':
100 | loss = globals()[f"weighted_{self.args.loss}_loss"](
101 | inputs=logits, targets=label / torch.tensor(5.).cuda(), weights=weight,
102 | beta=self.args.huber_beta
103 | )
104 | else:
105 | loss = globals()[f"weighted_{self.args.loss}_loss"](
106 | inputs=logits, targets=label / torch.tensor(5.).cuda(), weights=weight
107 | )
108 | out['logits'] = logits
109 | label = label.squeeze(-1).data.cpu().numpy()
110 | logits = logits.squeeze(-1).data.cpu().numpy()
111 | task.scorer(logits, label)
112 | out['loss'] = loss
113 |
114 | return out
115 |
116 | class HeadlessPairEncoder(Model):
117 | def __init__(self, vocab, text_field_embedder, num_highway_layers, phrase_layer,
118 | dropout=0.2, mask_lstms=True, initializer=InitializerApplicator()):
119 | super(HeadlessPairEncoder, self).__init__(vocab)
120 |
121 | self._text_field_embedder = text_field_embedder
122 | d_emb = text_field_embedder.get_output_dim()
123 | self._highway_layer = TimeDistributed(Highway(d_emb, num_highway_layers))
124 |
125 | self._phrase_layer = phrase_layer
126 | self.pad_idx = vocab.get_token_index(vocab._padding_token)
127 | self.output_dim = phrase_layer.get_output_dim()
128 |
129 | if dropout > 0:
130 | self._dropout = torch.nn.Dropout(p=dropout)
131 | else:
132 | self._dropout = lambda x: x
133 | self._mask_lstms = mask_lstms
134 |
135 | initializer(self)
136 |
137 | def forward(self, s1, s2, m1=None, m2=None):
138 | s1_embs = self._highway_layer(self._text_field_embedder(s1) if m1 is None else s1)
139 | s2_embs = self._highway_layer(self._text_field_embedder(s2) if m2 is None else s2)
140 |
141 | s1_embs = self._dropout(s1_embs)
142 | s2_embs = self._dropout(s2_embs)
143 |
144 | # Set up masks
145 | s1_mask = util.get_text_field_mask(s1) if m1 is None else m1.long()
146 | s2_mask = util.get_text_field_mask(s2) if m2 is None else m2.long()
147 |
148 | s1_lstm_mask = s1_mask.float() if self._mask_lstms else None
149 | s2_lstm_mask = s2_mask.float() if self._mask_lstms else None
150 |
151 | # Sentence encodings with LSTMs
152 | s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask)
153 | s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask)
154 |
155 | s1_enc = self._dropout(s1_enc)
156 | s2_enc = self._dropout(s2_enc)
157 |
158 | # Max pooling
159 | s1_mask = s1_mask.unsqueeze(dim=-1)
160 | s2_mask = s2_mask.unsqueeze(dim=-1)
161 | s1_enc.data.masked_fill_(1 - s1_mask.byte().data, -float('inf'))
162 | s2_enc.data.masked_fill_(1 - s2_mask.byte().data, -float('inf'))
163 | s1_enc, _ = s1_enc.max(dim=1)
164 | s2_enc, _ = s2_enc.max(dim=1)
165 |
166 | return torch.cat([s1_enc, s2_enc, torch.abs(s1_enc - s2_enc), s1_enc * s2_enc], 1)
--------------------------------------------------------------------------------
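`HeadlessPairEncoder` returns the classic `[u; v; |u - v|; u * v]` pairing of the two max-pooled sentence encodings, which is why the regressor input in `build_model` is `d_pair * 4`. A numpy sketch with toy values:

```python
import numpy as np

u = np.array([1., 2.])        # max-pooled encoding of sentence 1 (toy values)
v = np.array([3., 1.])        # max-pooled encoding of sentence 2 (toy values)
pair = np.concatenate([u, v, np.abs(u - v), u * v])
print(pair.tolist())          # [1.0, 2.0, 3.0, 1.0, 2.0, 1.0, 3.0, 2.0]
```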
/sts-b-dir/preprocess.py:
--------------------------------------------------------------------------------
1 | '''Preprocessing functions and pipeline'''
2 | import nltk
3 | nltk.download('punkt')
4 | import torch
5 | import logging
6 | import numpy as np
7 | from collections import defaultdict
8 |
9 | from allennlp.data import Instance, Vocabulary, Token
10 | from allennlp.data.fields import TextField
11 | from allennlp.data.token_indexers import SingleIdTokenIndexer
12 | from allennlp_mods.numeric_field import NumericField
13 |
14 | from tasks import STSBTask
15 |
16 | PATH_PREFIX = './glue_data/'
17 |
18 | ALL_TASKS = ['sts-b']
19 | NAME2INFO = {'sts-b': (STSBTask, 'STS-B/')}
20 |
21 | for k, v in NAME2INFO.items():
22 | NAME2INFO[k] = (v[0], PATH_PREFIX + v[1])
23 |
24 | def build_tasks(args):
25 | '''Prepare tasks'''
26 |
27 | task_names = [args.task]
28 | tasks = get_tasks(args, task_names, args.max_seq_len)
29 |
30 | max_v_sizes = {'word': args.max_word_v_size}
31 | token_indexer = {}
32 | token_indexer["words"] = SingleIdTokenIndexer()
33 |
34 | logging.info("\tProcessing tasks from scratch")
35 | word2freq = get_words(tasks)
36 | vocab = get_vocab(word2freq, max_v_sizes)
37 | word_embs = get_embeddings(vocab, args.word_embs_file, args.d_word)
38 | for task in tasks:
39 | train, val, test = process_task(task, token_indexer, vocab)
40 | task.train_data = train
41 | task.val_data = val
42 | task.test_data = test
43 | del_field_tokens(task)
44 | logging.info("\tFinished indexing tasks")
45 |
46 | train_eval_tasks = [task for task in tasks if task.name in task_names]
47 | logging.info('\t Training and evaluating on %s', ', '.join([task.name for task in train_eval_tasks]))
48 |
49 | return train_eval_tasks, vocab, word_embs
50 |
51 | def del_field_tokens(task):
52 | ''' Save memory by deleting the tokens that will no longer be used '''
53 |
54 | all_instances = task.train_data + task.val_data + task.test_data
55 | for instance in all_instances:
56 | if 'input1' in instance.fields:
57 | field = instance.fields['input1']
58 | del field.tokens
59 | if 'input2' in instance.fields:
60 | field = instance.fields['input2']
61 | del field.tokens
62 |
63 | def get_tasks(args, task_names, max_seq_len):
64 | '''Load tasks'''
65 | tasks = []
66 | for name in task_names:
67 | assert name in NAME2INFO, 'Task not found!'
68 | task = NAME2INFO[name][0](args, NAME2INFO[name][1], max_seq_len, name)
69 | tasks.append(task)
70 | logging.info("\tFinished loading tasks: %s.", ' '.join([task.name for task in tasks]))
71 |
72 | return tasks
73 |
74 | def get_words(tasks):
75 | '''
76 | Get all words for all tasks for all splits for all sentences
77 | Return dictionary mapping words to frequencies.
78 | '''
79 | word2freq = defaultdict(int)
80 |
81 | def count_sentence(sentence):
82 | '''Update counts for words in the sentence'''
83 | for word in sentence:
84 | word2freq[word] += 1
85 | return
86 |
87 | for task in tasks:
88 | splits = [task.train_data_text, task.val_data_text, task.test_data_text]
89 | for split in [split for split in splits if split is not None]:
90 | for sentence in split[0]:
91 | count_sentence(sentence)
92 | for sentence in split[1]:
93 | count_sentence(sentence)
94 |
95 | logging.info("\tFinished counting words")
96 |
97 | return word2freq
98 |
99 | def get_vocab(word2freq, max_v_sizes):
100 | '''Build vocabulary'''
101 | vocab = Vocabulary(counter=None, max_vocab_size=max_v_sizes['word'])
102 | words_by_freq = [(word, freq) for word, freq in word2freq.items()]
103 | words_by_freq.sort(key=lambda x: x[1], reverse=True)
104 | for word, _ in words_by_freq[:max_v_sizes['word']]:
105 | vocab.add_token_to_namespace(word, 'tokens')
106 | logging.info("\tFinished building vocab. Using %d words", vocab.get_vocab_size('tokens'))
107 |
108 | return vocab
109 |
110 | def get_embeddings(vocab, vec_file, d_word):
111 | '''Get embeddings for the words in vocab'''
112 | word_v_size, unk_idx = vocab.get_vocab_size('tokens'), vocab.get_token_index(vocab._oov_token)
113 | embeddings = np.random.randn(word_v_size, d_word)
114 |     with open(vec_file, encoding='utf-8') as vec_fh:
115 | for line in vec_fh:
116 | word, vec = line.split(' ', 1)
117 | idx = vocab.get_token_index(word)
118 | if idx != unk_idx:
119 |                 # in-vocabulary word: overwrite its randomly initialized row
120 | embeddings[idx] = np.array(list(map(float, vec.split())))
121 | embeddings[vocab.get_token_index('@@PADDING@@')] = 0.
122 | embeddings = torch.FloatTensor(embeddings)
123 | logging.info("\tFinished loading embeddings")
124 |
125 | return embeddings
126 |
127 | def process_task(task, token_indexer, vocab):
128 | '''
129 | Convert a task's splits into AllenNLP fields then
130 | Index the splits using the given vocab (experiment dependent)
131 | '''
132 | if hasattr(task, 'train_data_text') and task.train_data_text is not None:
133 | train = process_split(task.train_data_text, token_indexer)
134 | else:
135 | train = None
136 | if hasattr(task, 'val_data_text') and task.val_data_text is not None:
137 | val = process_split(task.val_data_text, token_indexer)
138 | else:
139 | val = None
140 | if hasattr(task, 'test_data_text') and task.test_data_text is not None:
141 | test = process_split(task.test_data_text, token_indexer)
142 | else:
143 | test = None
144 |
145 | for instance in train + val + test:
146 | instance.index_fields(vocab)
147 |
148 | return train, val, test
149 |
150 | def process_split(split, indexers):
151 | '''
152 | Convert a dataset of sentences into padded sequences of indices.
153 | '''
154 | inputs1 = [TextField(list(map(Token, sent)), token_indexers=indexers) for sent in split[0]]
155 | inputs2 = [TextField(list(map(Token, sent)), token_indexers=indexers) for sent in split[1]]
156 | labels = [NumericField(l) for l in split[-1]]
157 |
158 | if len(split) == 4: # weight
159 | weights = [NumericField(w) for w in split[2]]
160 | instances = [Instance({"input1": input1, "input2": input2, "label": label, "weight": weight}) for \
161 | (input1, input2, label, weight) in zip(inputs1, inputs2, labels, weights)]
162 | else:
163 | instances = [Instance({"input1": input1, "input2": input2, "label": label}) for \
164 | (input1, input2, label) in zip(inputs1, inputs2, labels)]
165 |
166 | return instances
167 |
--------------------------------------------------------------------------------
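`get_vocab` keeps only the `max_word_v_size` most frequent words from the frequency dictionary built by `get_words`. A stdlib-only sketch with toy sentences:

```python
from collections import defaultdict

word2freq = defaultdict(int)
for sentence in [["a", "cat", "sat"], ["a", "cat"]]:
    for word in sentence:
        word2freq[word] += 1

max_v_size = 2
# Sort by frequency (stable, so ties keep insertion order) and truncate
words_by_freq = sorted(word2freq.items(), key=lambda kv: kv[1], reverse=True)
kept = [word for word, _ in words_by_freq[:max_v_size]]
print(kept)   # ['a', 'cat']
```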
/sts-b-dir/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk
2 | wget
3 | ipdb
4 | scikit-learn
5 | allennlp==0.5.0
--------------------------------------------------------------------------------
/sts-b-dir/tasks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import nltk
3 | import codecs
4 | import logging
5 | import numpy as np
6 | from scipy.ndimage import convolve1d
7 | from util import get_lds_kernel_window, STSShotAverage
8 |
9 | def process_sentence(sent, max_seq_len):
10 | '''process a sentence using NLTK toolkit'''
11 | return nltk.word_tokenize(sent)[:max_seq_len]
12 |
13 | def load_tsv(data_file, max_seq_len, s1_idx=0, s2_idx=1, targ_idx=2, targ_fn=None, skip_rows=0, delimiter='\t', args=None):
14 | '''Load a tsv '''
15 | sent1s, sent2s, targs = [], [], []
16 | with codecs.open(data_file, 'r', 'utf-8') as data_fh:
17 | for _ in range(skip_rows):
18 | data_fh.readline()
19 | for row_idx, row in enumerate(data_fh):
20 | try:
21 | row = row.strip().split(delimiter)
22 | sent1 = process_sentence(row[s1_idx], max_seq_len)
23 | if (targ_idx is not None and not row[targ_idx]) or not len(sent1):
24 | continue
25 |
26 | if targ_idx is not None:
27 | targ = targ_fn(row[targ_idx])
28 | else:
29 | targ = 0
30 |
31 | if s2_idx is not None:
32 | sent2 = process_sentence(row[s2_idx], max_seq_len)
33 | if not len(sent2):
34 | continue
35 | sent2s.append(sent2)
36 |
37 | sent1s.append(sent1)
38 | targs.append(targ)
39 |
40 | except Exception as e:
41 |                 logging.info("%s file: %s, row: %d", e, data_file, row_idx)
42 | continue
43 |
44 | if args is not None and args.reweight != 'none':
45 | assert args.reweight in {'inverse', 'sqrt_inv'}
46 |         assert not args.lds or args.reweight != 'none', "Set reweight to 'inverse' (default) or 'sqrt_inv' when using LDS"
47 |
48 | bins = args.bucket_num
49 | value_lst, bins_edges = np.histogram(targs, bins=bins, range=(0., 5.))
50 |
51 | def get_bin_idx(label):
52 | if label == 5.:
53 | return bins - 1
54 | else:
55 | return np.where(bins_edges > label)[0][0] - 1
56 |
57 | if args.reweight == 'sqrt_inv':
58 | value_lst = [np.sqrt(x) for x in value_lst]
59 | num_per_label = [value_lst[get_bin_idx(label)] for label in targs]
60 |
61 | logging.info(f"Using re-weighting: [{args.reweight.upper()}]")
62 |
63 | if args.lds:
64 | lds_kernel_window = get_lds_kernel_window(args.lds_kernel, args.lds_ks, args.lds_sigma)
65 | logging.info(f'Using LDS: [{args.lds_kernel.upper()}] ({args.lds_ks}/{args.lds_sigma})')
66 | smoothed_value = convolve1d(value_lst, weights=lds_kernel_window, mode='constant')
67 | num_per_label = [smoothed_value[get_bin_idx(label)] for label in targs]
68 |
69 | weights = [np.float32(1 / x) for x in num_per_label]
70 | scaling = len(weights) / np.sum(weights)
71 | weights = [scaling * x for x in weights]
72 |
73 | return sent1s, sent2s, weights, targs
74 |
75 | return sent1s, sent2s, targs
76 |
77 | class STSBTask:
78 |     ''' Task class for Semantic Textual Similarity Benchmark. '''
79 | def __init__(self, args, path, max_seq_len, name="sts-b"):
80 | ''' '''
81 | super(STSBTask, self).__init__()
82 | self.args = args
83 | self.name = name
84 | self.train_data_text, self.val_data_text, self.test_data_text = None, None, None
85 | self.val_metric = 'mse'
86 | self.scorer = STSShotAverage(metric=['mse', 'l1', 'gmean', 'pearsonr', 'spearmanr'])
87 | self.load_data(path, max_seq_len)
88 |
89 | def load_data(self, path, max_seq_len):
90 | ''' '''
91 | tr_data = load_tsv(os.path.join(path, 'train_new.tsv'), max_seq_len, skip_rows=1,
92 | s1_idx=7, s2_idx=8, targ_idx=9, targ_fn=lambda x: np.float32(x), args=self.args)
93 | val_data = load_tsv(os.path.join(path, 'dev_new.tsv'), max_seq_len, skip_rows=1,
94 | s1_idx=7, s2_idx=8, targ_idx=9, targ_fn=lambda x: np.float32(x))
95 | te_data = load_tsv(os.path.join(path, 'test_new.tsv'), max_seq_len, skip_rows=1,
96 | s1_idx=7, s2_idx=8, targ_idx=9, targ_fn=lambda x: np.float32(x))
97 |
98 | self.train_data_text = tr_data
99 | self.val_data_text = val_data
100 | self.test_data_text = te_data
101 | logging.info("\tFinished loading STS Benchmark data.")
102 |
103 | def get_metrics(self, reset=False, type=None):
104 | metric = self.scorer.get_metric(reset, type)
105 |
106 | return metric
107 |
--------------------------------------------------------------------------------
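The LDS branch in `load_tsv` above boils down to: histogram the labels, smooth the per-bin counts with a kernel, then weight each sample by the inverse of its smoothed bin count and rescale so the mean weight is 1. A numpy-only sketch (the counts are made-up values; `np.convolve(..., 'same')` zero-pads at the edges, which for a symmetric kernel matches `scipy.ndimage.convolve1d(mode='constant')` used in the source):

```python
import numpy as np

# Hypothetical per-bin label counts for a 10-bin histogram
counts = np.array([50., 40., 10., 2., 0., 1., 3., 8., 30., 60.])
kernel = np.array([1., 2., 4., 2., 1.])
kernel /= kernel.sum()

smoothed = np.convolve(counts, kernel, mode='same')
weights = 1. / np.clip(smoothed, 1e-8, None)   # inverse of smoothed density
weights *= len(weights) / weights.sum()        # rescale so the mean weight is 1
print(round(weights.mean(), 6))
```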
/sts-b-dir/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import argparse
6 | import shutil
7 | import logging
8 | import torch
9 | import numpy as np
10 | from allennlp.data.iterators import BasicIterator
11 |
12 | from preprocess import build_tasks
13 | from models import build_model
14 | from trainer import build_trainer
15 | from evaluate import evaluate
16 | from util import device_mapping, query_yes_no, resume_checkpoint
17 |
18 | def main(arguments):
19 | parser = argparse.ArgumentParser(description='')
20 |
21 | parser.add_argument('--cuda', help='-1 if no CUDA, else gpu id (single gpu is enough)', type=int, default=0)
22 | parser.add_argument('--random_seed', help='random seed to use', type=int, default=111)
23 |
24 | # Paths and logging
25 | parser.add_argument('--log_file', help='file to log to', type=str, default='training.log')
26 | parser.add_argument('--store_root', help='store root path', type=str, default='checkpoint')
27 | parser.add_argument('--store_name', help='store name prefix for current experiment', type=str, default='sts')
28 | parser.add_argument('--suffix', help='store name suffix for current experiment', type=str, default='')
29 | parser.add_argument('--word_embs_file', help='file containing word embs', type=str, default='glove/glove.840B.300d.txt')
30 |
31 | # Training resuming flag
32 | parser.add_argument('--resume', help='whether to resume training', action='store_true', default=False)
33 |
34 | # Tasks
35 | parser.add_argument('--task', help='training and evaluation task', type=str, default='sts-b')
36 |
37 | # Preprocessing options
38 | parser.add_argument('--max_seq_len', help='max sequence length', type=int, default=40)
39 | parser.add_argument('--max_word_v_size', help='max word vocab size', type=int, default=30000)
40 |
41 | # Embedding options
42 | parser.add_argument('--dropout_embs', help='dropout rate for embeddings', type=float, default=.2)
43 | parser.add_argument('--d_word', help='dimension of word embeddings', type=int, default=300)
44 | parser.add_argument('--glove', help='1 if use glove, else from scratch', type=int, default=1)
45 | parser.add_argument('--train_words', help='1 if make word embs trainable', type=int, default=0)
46 |
47 | # Model options
48 | parser.add_argument('--d_hid', help='hidden dimension size', type=int, default=1500)
49 | parser.add_argument('--n_layers_enc', help='number of RNN layers', type=int, default=2)
50 | parser.add_argument('--n_layers_highway', help='number of highway layers', type=int, default=0)
51 | parser.add_argument('--dropout', help='dropout rate to use in training', type=float, default=0.2)
52 |
53 | # Training options
54 | parser.add_argument('--batch_size', help='batch size', type=int, default=128)
55 | parser.add_argument('--optimizer', help='optimizer to use', type=str, default='adam')
56 | parser.add_argument('--lr', help='starting learning rate', type=float, default=1e-4)
57 | parser.add_argument('--loss', type=str, default='mse', choices=['mse', 'l1', 'focal_l1', 'focal_mse', 'huber'])
58 | parser.add_argument('--huber_beta', type=float, default=0.3, help='beta for huber loss')
59 | parser.add_argument('--max_grad_norm', help='max grad norm', type=float, default=5.)
60 | parser.add_argument('--val_interval', help='number of iterations between validation checks', type=int, default=400)
61 | parser.add_argument('--max_vals', help='maximum number of validation checks', type=int, default=100)
62 | parser.add_argument('--patience', help='patience for early stopping', type=int, default=10)
63 |
64 | # imbalanced related
65 | # LDS
66 | parser.add_argument('--lds', action='store_true', default=False, help='whether to enable LDS')
67 | parser.add_argument('--lds_kernel', type=str, default='gaussian',
68 | choices=['gaussian', 'triang', 'laplace'], help='LDS kernel type')
69 | parser.add_argument('--lds_ks', type=int, default=5, help='LDS kernel size: should be odd number')
70 | parser.add_argument('--lds_sigma', type=float, default=2, help='LDS gaussian/laplace kernel sigma')
71 | # FDS
72 | parser.add_argument('--fds', action='store_true', default=False, help='whether to enable FDS')
73 | parser.add_argument('--fds_kernel', type=str, default='gaussian',
74 | choices=['gaussian', 'triang', 'laplace'], help='FDS kernel type')
75 | parser.add_argument('--fds_ks', type=int, default=5, help='FDS kernel size: should be odd number')
76 | parser.add_argument('--fds_sigma', type=float, default=2, help='FDS gaussian/laplace kernel sigma')
77 | parser.add_argument('--start_update', type=int, default=0, help='which epoch to start FDS updating')
78 | parser.add_argument('--start_smooth', type=int, default=1, help='which epoch to start using FDS to smooth features')
79 | parser.add_argument('--bucket_num', type=int, default=50, help='maximum bucket considered for FDS')
80 | parser.add_argument('--bucket_start', type=int, default=0, help='minimum (starting) bucket for FDS')
81 | parser.add_argument('--fds_mmt', type=float, default=0.9, help='FDS momentum')
82 |
83 | # re-weighting: SQRT_INV / INV
84 | parser.add_argument('--reweight', type=str, default='none', choices=['none', 'sqrt_inv', 'inverse'],
85 | help='cost-sensitive reweighting scheme')
86 | # two-stage training: RRT
87 | parser.add_argument('--retrain_fc', action='store_true', default=False,
88 | help='whether to retrain last regression layer (regressor)')
89 | parser.add_argument('--pretrained', type=str, default='', help='pretrained checkpoint file path to load backbone weights for RRT')
90 | # evaluate only
91 | parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate only flag')
92 | parser.add_argument('--eval_model', type=str, default='', help='the model to evaluate on; if not specified, '
93 | 'use the default best model in store_dir')
94 |
95 | args = parser.parse_args(arguments)
96 |
97 | os.makedirs(args.store_root, exist_ok=True)
98 |
99 | if not args.lds and args.reweight != 'none':
100 | args.store_name += f'_{args.reweight}'
101 | if args.lds:
102 | args.store_name += f'_lds_{args.lds_kernel[:3]}_{args.lds_ks}'
103 | if args.lds_kernel in ['gaussian', 'laplace']:
104 | args.store_name += f'_{args.lds_sigma}'
105 | if args.fds:
106 | args.store_name += f'_fds_{args.fds_kernel[:3]}_{args.fds_ks}'
107 | if args.fds_kernel in ['gaussian', 'laplace']:
108 | args.store_name += f'_{args.fds_sigma}'
109 | args.store_name += f'_{args.start_update}_{args.start_smooth}_{args.fds_mmt}'
110 | if args.retrain_fc:
111 | args.store_name += f'_retrain_fc'
112 |
113 | if args.loss == 'huber':
114 | args.store_name += f'_{args.loss}_beta_{args.huber_beta}'
115 | else:
116 | args.store_name += f'_{args.loss}'
117 |
118 | args.store_name += f'_seed_{args.random_seed}_valint_{args.val_interval}_patience_{args.patience}' \
119 | f'_{args.optimizer}_{args.lr}_{args.batch_size}'
120 | args.store_name += f'_{args.suffix}' if len(args.suffix) else ''
121 |
122 | args.store_dir = os.path.join(args.store_root, args.store_name)
123 |
124 | if not args.evaluate and not args.resume:
125 | if os.path.exists(args.store_dir):
126 | if query_yes_no('overwrite previous folder: {} ?'.format(args.store_dir)):
127 | shutil.rmtree(args.store_dir)
128 | print(args.store_dir + ' removed.\n')
129 | else:
130 | raise RuntimeError('Output folder {} already exists'.format(args.store_dir))
131 | logging.info(f"===> Creating folder: {args.store_dir}")
132 | os.makedirs(args.store_dir)
133 |
134 | # Logistics
135 | logging.root.handlers = []
136 | if os.path.exists(args.store_dir):
137 | log_file = os.path.join(args.store_dir, args.log_file)
138 | logging.basicConfig(
139 | level=logging.INFO,
140 | format="%(asctime)s | %(message)s",
141 | handlers=[
142 | logging.FileHandler(log_file),
143 | logging.StreamHandler()
144 | ])
145 | else:
146 | logging.basicConfig(
147 | level=logging.INFO,
148 | format="%(asctime)s | %(message)s",
149 | handlers=[logging.StreamHandler()]
150 | )
151 | logging.info(args)
152 |
153 | seed = random.randint(1, 10000) if args.random_seed < 0 else args.random_seed
154 | random.seed(seed)
155 | torch.manual_seed(seed)
156 | if args.cuda >= 0:
157 | logging.info("Using GPU %d", args.cuda)
158 | torch.cuda.set_device(args.cuda)
159 | torch.cuda.manual_seed_all(seed)
160 | logging.info("Using random seed %d", seed)
161 |
162 | # Load tasks
163 | logging.info("Loading tasks...")
164 | start_time = time.time()
165 | tasks, vocab, word_embs = build_tasks(args)
166 | logging.info('\tFinished loading tasks in %.3fs', time.time() - start_time)
167 |
168 | # Build model
169 | logging.info('Building model...')
170 | start_time = time.time()
171 | model = build_model(args, vocab, word_embs, tasks)
172 | logging.info('\tFinished building model in %.3fs', time.time() - start_time)
173 |
174 | # Set up trainer
175 | iterator = BasicIterator(args.batch_size)
176 | trainer, train_params, opt_params = build_trainer(args, model, iterator)
177 |
178 | # Train
179 | if tasks and not args.evaluate:
180 | if args.retrain_fc and len(args.pretrained):
181 | model_path = args.pretrained
182 | assert os.path.isfile(model_path), f"No checkpoint found at '{model_path}'"
183 | model_state = torch.load(model_path, map_location=device_mapping(args.cuda))
184 | trainer._model = resume_checkpoint(trainer._model, model_state, backbone_only=True)
185 | logging.info(f'Pre-trained backbone weights loaded: {model_path}')
186 | logging.info('Retrain last regression layer only!')
187 | for name, param in trainer._model.named_parameters():
188 | if "sts-b_pred_layer" not in name:
189 | param.requires_grad = False
190 | logging.info(f'Only optimize parameters: {[n for n, p in trainer._model.named_parameters() if p.requires_grad]}')
191 | to_train = [(n, p) for n, p in trainer._model.named_parameters() if p.requires_grad]
192 | else:
193 | to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
194 |
195 | trainer.train(tasks, args.val_interval, to_train, opt_params, args.resume)
196 | else:
197 | logging.info("Skipping training...")
198 |
199 | logging.info('Testing on test set...')
200 | model_path = os.path.join(args.store_dir, "model_state_best.th") if not len(args.eval_model) else args.eval_model
201 | assert os.path.isfile(model_path), f"No checkpoint found at '{model_path}'"
202 | logging.info(f'Evaluating {model_path}...')
203 | model_state = torch.load(model_path, map_location=device_mapping(args.cuda))
204 | model = resume_checkpoint(model, model_state)
205 | te_preds, te_labels, _ = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="test")
206 | if not len(args.eval_model):
207 | np.savez_compressed(os.path.join(args.store_dir, f"{args.store_name}.npz"), preds=te_preds, labels=te_labels)
208 |
209 | logging.info("Done testing.")
210 |
211 | if __name__ == '__main__':
212 | sys.exit(main(sys.argv[1:]))
213 |
--------------------------------------------------------------------------------
/sts-b-dir/util.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import numpy as np
3 | import torch
4 |
5 | from scipy.ndimage import gaussian_filter1d
6 | from scipy.signal.windows import triang
7 | from scipy.stats import pearsonr, spearmanr, gmean
8 |
9 |
10 | def get_text_field_mask(text_field_tensors: Dict[str, torch.Tensor]) -> torch.LongTensor:
11 | """
12 | Takes the dictionary of tensors produced by a ``TextField`` and returns a mask of shape
13 | ``(batch_size, num_tokens)``. This mask will be 0 where the tokens are padding, and 1
14 | otherwise.
15 |
16 | There could be several entries in the tensor dictionary with different shapes (e.g., one for
17 | word ids, one for character ids). In order to get a token mask, we assume that the tensor in
18 | the dictionary with the lowest number of dimensions has plain token ids. This allows us to
19 | also handle cases where the input is actually a ``ListField[TextField]``.
20 |
21 | NOTE: Our functions for generating masks create torch.LongTensors, because using
22 |     torch.ByteTensors inside Variables makes it easy to run into overflow errors
23 | when doing mask manipulation, such as summing to get the lengths of sequences - see below.
24 | >>> mask = torch.ones([260]).byte()
25 | >>> mask.sum() # equals 260.
26 | >>> var_mask = torch.autograd.Variable(mask)
27 | >>> var_mask.sum() # equals 4, due to 8 bit precision - the sum overflows.
28 | """
29 | tensor_dims = [(tensor.dim(), tensor) for tensor in text_field_tensors.values()]
30 | tensor_dims.sort(key=lambda x: x[0])
31 | token_tensor = tensor_dims[0][1]
32 |
33 | return (token_tensor != 0).long()
34 |
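As an illustration (not part of the original file), the masking rule above amounts to comparing token ids against the padding id 0; a minimal sketch with a toy padded batch:

```python
import torch

# Toy padded batch of token ids, as a TextField might produce (0 = padding)
tokens = {'words': torch.tensor([[4, 7, 9, 0, 0],
                                 [3, 5, 0, 0, 0]])}

# Same rule get_text_field_mask applies: nonzero ids are real tokens
mask = (tokens['words'] != 0).long()

# Summing a LongTensor mask safely recovers per-sentence lengths
lengths = mask.sum(dim=1)
```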
35 | def device_mapping(cuda_device: int):
36 | """
37 | In order to `torch.load()` a GPU-trained model onto a CPU (or specific GPU),
38 | you have to supply a `map_location` function. Call this with
39 | the desired `cuda_device` to get the function that `torch.load()` needs.
40 | """
41 | def inner_device_mapping(storage: torch.Storage, location) -> torch.Storage:
42 | if cuda_device >= 0:
43 | return storage.cuda(cuda_device)
44 | else:
45 | return storage
46 | return inner_device_mapping
47 |
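For illustration, the helper can be exercised without a GPU: a hypothetical in-memory checkpoint is saved and reloaded with `cuda_device=-1`, which maps every storage to the CPU. The helper is copied here so the sketch is self-contained:

```python
import io
import torch

def device_mapping(cuda_device: int):
    # Copy of the helper above, repeated so this sketch is self-contained
    def inner_device_mapping(storage, location):
        return storage.cuda(cuda_device) if cuda_device >= 0 else storage
    return inner_device_mapping

# Save a tiny hypothetical checkpoint to an in-memory buffer
buf = io.BytesIO()
torch.save({'w': torch.randn(3)}, buf)
buf.seek(0)

# cuda_device=-1 keeps every storage on the CPU, so a GPU-trained
# checkpoint can be loaded on a CPU-only machine
state = torch.load(buf, map_location=device_mapping(-1))
```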
48 | def query_yes_no(question):
49 |     """ Ask a yes/no question via input() and return the answer. """
50 | valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
51 | prompt = " [Y/n] "
52 |
53 | while True:
54 | print(question + prompt, end=':')
55 | choice = input().lower()
56 | if choice == '':
57 | return valid['y']
58 | elif choice in valid:
59 | return valid[choice]
60 | else:
61 | print("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")
62 |
63 | def calibrate_mean_var(matrix, m1, v1, m2, v2, clip_min=0.5, clip_max=2.):
64 | if torch.sum(v1) < 1e-10:
65 | return matrix
66 |     if (v1 <= 0.).any() or (v2 < 0.).any():
67 |         valid_pos = (v1 > 0.) & (v2 >= 0.)
68 | factor = torch.clamp(v2[valid_pos] / v1[valid_pos], clip_min, clip_max)
69 | matrix[:, valid_pos] = (matrix[:, valid_pos] - m1[valid_pos]) * torch.sqrt(factor) + m2[valid_pos]
70 | return matrix
71 |
72 | factor = torch.clamp(v2 / v1, clip_min, clip_max)
73 | return (matrix - m1) * torch.sqrt(factor) + m2
74 |
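For illustration, the transform implemented by `calibrate_mean_var` is a per-feature rescaling toward target statistics, with the variance ratio clipped; a minimal sketch of the same math on toy features (all values here are illustrative):

```python
import torch

torch.manual_seed(0)
feats = torch.randn(8, 4) * 2.0 + 3.0           # toy batch of features
m1, v1 = feats.mean(dim=0), feats.var(dim=0)    # current (bin) statistics
m2, v2 = torch.zeros(4), torch.ones(4)          # target (smoothed) statistics

# Same transform as calibrate_mean_var: rescale by the clipped variance
# ratio, then shift to the target mean
factor = torch.clamp(v2 / v1, 0.5, 2.0)
out = (feats - m1) * torch.sqrt(factor) + m2
```

Because the ratio is clipped to `[0.5, 2.0]`, the output variance only moves part of the way toward `v2`, but the mean is matched exactly.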
75 | def resume_checkpoint(model, model_state, backbone_only=False):
76 | model.pair_encoder.load_state_dict(
77 | {k.split('.', 1)[1]: v for k, v in model_state.items() if 'pair_encoder' in k}
78 | )
79 | if not backbone_only:
80 | getattr(model, 'sts-b_pred_layer').load_state_dict(
81 | {k.split('.', 1)[1]: v for k, v in model_state.items() if 'sts-b_pred_layer' in k}
82 | )
83 |
84 | return model
85 |
86 | def get_lds_kernel_window(kernel, ks, sigma):
87 |     assert kernel in ['gaussian', 'triang', 'laplace']
88 |     half_ks = (ks - 1) // 2
89 |     if kernel == 'gaussian':
90 |         base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
91 |         smoothed = gaussian_filter1d(base_kernel, sigma=sigma)
92 |         kernel_window = smoothed / smoothed.max()  # normalize so the center weight is 1
93 |     elif kernel == 'triang':
94 |         kernel_window = triang(ks)
95 |     else:
96 |         laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
97 |         # np.array is needed here: a plain list cannot be divided by a scalar
98 |         kernel_window = np.array(list(map(laplace, np.arange(-half_ks, half_ks + 1))))
99 |         kernel_window = kernel_window / kernel_window.max()
100 |     return kernel_window
100 |
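For illustration (following the LDS recipe, with hypothetical toy counts), the window returned above is convolved with the empirical per-bin label counts to obtain an effective label density, whose inverse can serve as loss re-weights:

```python
import numpy as np
from scipy.ndimage import convolve1d, gaussian_filter1d

# Rebuild a gaussian window as get_lds_kernel_window does (ks=5, sigma=2)
ks, sigma = 5, 2
half_ks = (ks - 1) // 2
base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
window = gaussian_filter1d(base_kernel, sigma=sigma)
window = window / window.max()

# Toy imbalanced per-bin label counts: smoothing spreads mass to neighbors
emp_counts = np.array([0., 0., 10., 0., 0., 1., 0.])
eff_counts = convolve1d(emp_counts, weights=window, mode='constant')

# Inverse effective density -> per-bin sample weights
weights = 1.0 / np.clip(eff_counts, 1e-8, None)
```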
101 | class STSShotAverage:
102 | def __init__(self, metric):
103 | self._pred = []
104 | self._label = []
105 | self._count = 0
106 | self._metric = metric
107 | self._num_bins = 50
108 | # under np.float32 division
109 | self._shot_idx = {
110 | 'many': [0, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 49],
111 | 'medium': [2, 4, 6, 8, 27, 35, 37],
112 | 'few': [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 29, 31, 33, 39, 41, 43, 45, 47]
113 | }
114 |
115 | def get_bin_idx(label):
116 | _, bins_edges = np.histogram(a=np.array([], dtype=np.float32), bins=self._num_bins, range=(0., 5.))
117 | if label == 5.:
118 | return self._num_bins - 1
119 | else:
120 | return np.where(bins_edges > label)[0][0] - 1
121 | self._get_bin_idx = get_bin_idx
122 |
123 | def __call__(self, pred, label):
124 | self._pred += pred.tolist()
125 | self._label += label.tolist()
126 | self._count += len(pred)
127 |
128 | def get_metric(self, reset=False, type=None):
129 | label_bin_idx = list(map(self._get_bin_idx, self._label))
130 | def bin2shot(idx):
131 | if idx in self._shot_idx['many']:
132 | return 'many'
133 | elif idx in self._shot_idx['medium']:
134 | return 'medium'
135 | else:
136 | return 'few'
137 | label_category = np.array(list(map(bin2shot, label_bin_idx)))
138 |
139 | pred_shot = {'many': [], 'medium': [], 'few': [], 'overall': []}
140 | label_shot = {'many': [], 'medium': [], 'few': [], 'overall': []}
141 | metric = {'many': {}, 'medium': {}, 'few': {}, 'overall': {}}
142 | for shot in ['overall', 'many', 'medium', 'few']:
143 | pred_shot[shot] = np.array(self._pred)[label_category == shot] * 5. if shot != 'overall' else np.array(self._pred) * 5.
144 | label_shot[shot] = np.array(self._label)[label_category == shot] if shot != 'overall' else np.array(self._label)
145 | if 'mse' in self._metric:
146 | metric[shot]['mse'] = np.mean((pred_shot[shot] - label_shot[shot]) ** 2) if pred_shot[shot].size > 0 else 0.
147 | if 'l1' in self._metric:
148 | metric[shot]['l1'] = np.mean(np.abs(pred_shot[shot] - label_shot[shot])) if pred_shot[shot].size > 0 else 0.
149 |             if 'gmean' in self._metric:
150 |                 if pred_shot[shot].size <= 0:
151 |                     metric[shot]['gmean'] = 0.
152 |                 else:
153 |                     diff = np.abs(pred_shot[shot] - label_shot[shot])
154 |                     # nudge exact zeros so the geometric mean does not collapse to 0
155 |                     diff[diff == 0.] += 1e-10
156 |                     metric[shot]['gmean'] = gmean(diff)
159 | if 'pearsonr' in self._metric:
160 | metric[shot]['pearsonr'] = pearsonr(pred_shot[shot], label_shot[shot])[0] if pred_shot[shot].size > 1 else 0.
161 | if 'spearmanr' in self._metric:
162 | metric[shot]['spearmanr'] = spearmanr(pred_shot[shot], label_shot[shot])[0] if pred_shot[shot].size > 1 else 0.
163 | metric[shot]['num_samples'] = pred_shot[shot].size
164 | if reset:
165 | self.reset()
166 | return metric['overall'] if type == 'overall' else metric
167 |
169 | def reset(self):
170 | self._pred = []
171 | self._label = []
172 | self._count = 0
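For illustration, the bucketing used by `STSShotAverage` maps a similarity score in [0, 5] to one of 50 equal-width bins (width 0.1), with the maximum score folded into the last bin; a self-contained sketch of `_get_bin_idx`:

```python
import numpy as np

num_bins = 50
# Same construction as in __init__: only the bin edges matter
_, bin_edges = np.histogram(np.array([], dtype=np.float32), bins=num_bins, range=(0., 5.))

def get_bin_idx(label):
    # Scores of exactly 5.0 fall into the last bin
    if label == 5.:
        return num_bins - 1
    return int(np.where(bin_edges > label)[0][0] - 1)
```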
--------------------------------------------------------------------------------
/teaser/agedb_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/agedb_dir.png
--------------------------------------------------------------------------------
/teaser/fds.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/fds.gif
--------------------------------------------------------------------------------
/teaser/imdb_wiki_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/imdb_wiki_dir.png
--------------------------------------------------------------------------------
/teaser/lds.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/lds.gif
--------------------------------------------------------------------------------
/teaser/nyud2_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/nyud2_dir.png
--------------------------------------------------------------------------------
/teaser/overview.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/overview.gif
--------------------------------------------------------------------------------
/teaser/shhs_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/shhs_dir.png
--------------------------------------------------------------------------------
/teaser/stsb_dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YyzHarry/imbalanced-regression/a6fdc45d45c04e6f5c40f43925bc66e580911084/teaser/stsb_dir.png
--------------------------------------------------------------------------------
/tutorial/README.md:
--------------------------------------------------------------------------------
1 | # Hands-on Tutorial of Deep Imbalanced Regression
2 |
10 | This is a hands-on tutorial for **Delving into Deep Imbalanced Regression** [[Paper]](https://arxiv.org/abs/2102.09554).
11 |
12 | ```bib
13 | @inproceedings{yang2021delving,
14 | title={Delving into Deep Imbalanced Regression},
15 | author={Yang, Yuzhe and Zha, Kaiwen and Chen, Ying-Cong and Wang, Hao and Katabi, Dina},
16 | booktitle={International Conference on Machine Learning (ICML)},
17 | year={2021}
18 | }
19 | ```
20 |
21 | In this notebook, we provide a hands-on tutorial for DIR on a small-scale dataset, the [Boston Housing dataset](https://www.cs.toronto.edu/~delve/data/boston/bostonDetail.html), as a quick overview of how to perform practical (deep) imbalanced regression on custom datasets.
22 |
23 | You can open it directly in Colab: [](https://colab.research.google.com/github/YyzHarry/imbalanced-regression/blob/master/tutorial/tutorial.ipynb), or run it locally with Jupyter using the following instructions.
24 |
25 | Required packages:
26 | ```bash
27 | pip install --upgrade pip
28 | pip install --upgrade jupyter notebook
29 | ```
30 |
31 | Then, please clone this repository to your computer using:
32 |
33 | ```bash
34 | git clone https://github.com/YyzHarry/imbalanced-regression.git
35 | ```
36 |
37 | After cloning finishes, go to this tutorial's directory and run
38 |
39 | ```bash
40 | jupyter notebook --port 8888
41 | ```
42 |
43 | to start a Jupyter notebook server and access it through your browser. Finally, explore the notebook `tutorial.ipynb` we prepared!
44 |
--------------------------------------------------------------------------------