├── .gitignore
├── README.md
├── download_data.sh
├── example.ipynb
├── generate_scores
│   ├── README.md
│   ├── download_cifar.sh
│   ├── train_imagenet
│   │   ├── convert.py
│   │   ├── download.py
│   │   ├── get_simclr_representations.py
│   │   ├── resnet.py
│   │   ├── run_get_simclr_representations.sh
│   │   ├── run_train_linear_and_get_logits.sh
│   │   └── train_linear_and_get_logits.py
│   └── train_models
│       ├── cifar-100.ipynb
│       ├── cifar_utils.py
│       ├── data_utils.py
│       ├── torchvision_dataset_utils.py
│       ├── train.py
│       ├── train_inaturalist.sh
│       ├── train_inaturalist_family.sh
│       └── train_places365.sh
├── notebooks
│   ├── create_heatmaps.ipynb
│   ├── create_latex_table.ipynb
│   ├── create_varying_n_plots.ipynb
│   ├── get_dataset_characteristics.ipynb
│   ├── imagenet_case_study.ipynb
│   └── neurips_rebuttal_undercov.ipynb
├── requirements.txt
├── run_experiment.py
├── run_experiment.sh
├── run_heatmap_experiments.sh
└── utils
    ├── clustering_utils.py
    ├── conformal_utils.py
    └── experiment_utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | old/ 2 | data/ 3 | figs/ 4 | *r152_3x_sk1* 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | This is the code release accompanying the paper [Class-Conditional Conformal Prediction with Many Classes](https://arxiv.org/abs/2306.09335). 3 | 4 | Citation: 5 | ``` 6 | @article{ding2023classconditional, 7 | title={Class-Conditional Conformal Prediction with Many Classes}, 8 | author={Ding, Tiffany and Angelopoulos, Anastasios N and Bates, 9 | Stephen and Jordan, Michael I and Tibshirani, Ryan J}, 10 | journal={arXiv preprint arXiv:2306.09335}, 11 | year={2023} 12 | } 13 | ``` 14 | 15 | 16 | ## Setup 17 | 18 | First, create a virtual environment and install the necessary packages by running 19 | 20 | ``` 21 | conda create --name env 22 | conda activate env 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | To make the environment accessible from Jupyter notebooks, run 27 | 28 | ``` 29 | ipython3 kernel install --user --name=conformal_env 30 | ``` 31 | 32 | This adds a kernel called `conformal_env` to your list of Jupyter kernels. 33 | 34 | Download the datasets by running 35 | 36 | ``` 37 | sh download_data.sh 38 | ``` 39 | 40 | which will create a folder called `data/` and download the data described in the following section. 41 | 42 | ## Data description 43 | 44 | 1. `imagenet` (4.62 GB): `(115301, 1000)` array of softmax scores and `(115301,)` array of labels 45 | 1. `cifar-100` (0.01 GB): `(30000, 100)` array of softmax scores and `(30000,)` array of labels 46 | 1. `places365` (0.54 GB): `(183996, 365)` array of softmax scores and `(183996,)` array of labels 47 | 1. 
`inaturalist` (6.72 GB): `(1324900, 633)` array of softmax scores and `(1324900,)` array of labels 48 | 49 | The code for training models on the raw datasets to produce the softmax scores is located in `generate_scores/`. 50 | 51 | ## Running Clustered Conformal 52 | 53 | See `example.ipynb` for an example of how to run clustered conformal prediction. 54 | 55 | ## Reproducing our experiments 56 | 57 | Run `sh run_experiment.sh` to run our main set of experiments. Run `sh run_heatmap_experiments.sh` for experiments that test the sensitivity of clustered conformal to the hyperparameter values. To view the main results, run `jupyter notebook` from Terminal, then run the notebooks in the `notebooks/` directory. 58 | 59 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | cd data 3 | gdown 1Ax-mP7PJXHEWo4TbUMru5vVjGjmpK5gE # ImageNet (4.62 GB) 4 | gdown 12huZJjuubElkMKB9y05M1Bz_VYH3E0T_ # CIFAR-100 (0.01 GB) 5 | gdown 1LtEDNYiru2hJIOlJFKfa2rS3z_oHcvC5 # Places365 (0.54 GB) 6 | gdown 1D8H_vgzRc66cHAukyOgT0GUm5_u8dX5x # iNaturalist (6.72 GB) 7 | cd .. -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "a90def29", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "\n", 12 | "from utils.conformal_utils import clustered_conformal, random_split\n", 13 | "from utils.experiment_utils import load_dataset" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "id": "99ef4e62", 19 | "metadata": {}, 20 | "source": [ 21 | "This notebook shows how to apply _Clustered Conformal Prediction_ to a set of softmax scores and labels." 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "b1feb024", 27 | "metadata": {}, 28 | "source": [ 29 | "## 0) Specify desired coverage level" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "d1d93258", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "alpha = 0.1 # Corresponds to 90% coverage" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "fd75b1f4", 45 | "metadata": {}, 46 | "source": [ 47 | "## 1) Get conformal scores\n", 48 | "* softmax_scores: `(num_instances, num_classes)` array\n", 49 | "* labels: `(num_instances,)` array" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "id": "9aa1d5a5", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "softmax_scores, labels = load_dataset('imagenet')" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "id": "ba81d289", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "scores_all = 1 - softmax_scores" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "682dbc1b", 75 | "metadata": {}, 76 | "source": [ 77 | "## 2) Split into calibration and validation datasets" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "id": "96e063bd", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# Specify size of calibration dataset\n", 88 | "n_avg = 30 # Average number of examples per class \n", 89 | "cal_scores_all, cal_labels, val_scores_all, val_labels = random_split(scores_all, labels, n_avg)" 90 | ] 91 | }, 92 | { 93 | "cell_type": 
"markdown", 94 | "id": "730531ff", 95 | "metadata": {}, 96 | "source": [ 97 | "## 3) Use the calibration dataset to estimate conformal quantiles" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 6, 103 | "id": "96d49c24", 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "n_clustering=12, num_clusters=6\n", 111 | "0 of 1000 classes are rare in the clustering set and will be assigned to the null cluster\n", 112 | "Cluster sizes: [186, 185, 180, 171, 153, 125]\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "q_hats = clustered_conformal(cal_scores_all, cal_labels, alpha)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 7, 123 | "id": "9c4d3379", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "# You can pass the quantiles into a wrapper to get a prediction set function \n", 128 | "get_pred_set = lambda softmax_vec: np.where(softmax_vec <= q_hats)[0]" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "20e56bef", 134 | "metadata": {}, 135 | "source": [ 136 | "## 4) Apply prediction set function to new examples \n", 137 | "\n", 138 | "You can rerun the following cell to generate prediction sets for different randomly sampled test points" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 8, 144 | "id": "1e40489a", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Prediction set: [433 457 463 529 615 631 638 667 773 804 837 868 898 911 999]\n", 152 | "True label: 433\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "# Get a test softmax vector from the calibration dataset\n", 158 | "i = np.random.choice(np.arange(len(val_labels)))\n", 159 | "softmax_vec = val_scores_all[i]\n", 160 | "true_label = val_labels[i]\n", 161 | "\n", 162 | "print('Prediction set:', get_pred_set(softmax_vec))\n", 163 | "print('True label:', true_label)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "id": "2086e071", 169 | "metadata": {}, 170 | "source": [ 171 | "### Evaluation\n", 172 | "\n", 173 | "To compute coverage and set size metrics, you can pass `val_scores_all` and `val_labels` into the call to `clustered_conformal()`:" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 9, 179 | "id": "4552f04a", 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "n_clustering=12, num_clusters=6\n", 187 | "0 of 1000 classes are rare in the clustering set and will be assigned to the null cluster\n", 188 | "Cluster sizes: [186, 185, 180, 171, 153, 125]\n", 189 | "CLASS COVERAGE GAP: 0.03313341096404464\n", 190 | "AVERAGE SET SIZE: 2.808151188147288\n" 191 | ] 192 | } 193 | ], 194 | "source": [ 195 | "qhats, preds, class_cov_metrics, set_size_metrics = clustered_conformal(cal_scores_all, cal_labels,\n", 196 | " alpha,\n", 197 | " val_scores_all=val_scores_all, \n", 198 | " val_labels=val_labels)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "id": "102337d7", 204 | "metadata": {}, 205 | "source": [ 206 | "Additional metrics can be found in `class_cov_metrics` and `set_size_metrics`" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "id": "d6837e15", 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": 
"conformal_env", 221 | "language": "python", 222 | "name": "conformal_env" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.10.4" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 5 239 | } 240 | -------------------------------------------------------------------------------- /generate_scores/README.md: -------------------------------------------------------------------------------- 1 | # Code for training models and generating conformal scores 2 | 3 | 4 | 5 | ## ImageNet 6 | 7 | First, download ImageNet and update the `path` variable in `run()` (located in `train_imagenet/get_simclr_representations.py`). Then `cd` into the `generate_scores/train_imagenet` folder and run: 8 | 9 | ``` 10 | python download.py r152_3x_sk1 11 | python convert.py r152_3x_sk1/model.ckpt-250228 12 | python get_simclr_representations.py train 13 | python train_linear_and_get_softmax.py 14 | ``` 15 | (Code based on https://github.com/Separius/SimCLRv2-Pytorch) 16 | 17 | ## CIFAR-100 18 | 19 | Download CIFAR-100 if necessary by running `sh download_cifar.sh`. 20 | Go to the `train_models` folder and open `cifar-100.ipynb`. Running all cells will train the model (if we have not already trained) and save the softmax scores and labels for the validation dataset will be saved to `train_models/.cache/`. 21 | 22 | ## Places365 23 | 24 | Download the data from http://places.csail.mit.edu/index.html if necessary. Update the `root` argument of the `datasets.Places365()` dataloader in `get_dataloaders` (located in `train_models/torchvision_dataset_utils.py`) to point to the data. 25 | Go to the `train_models` folder. If you work on a cluster with a SLURM scheduler, running `sbatch train_places365.sh` will train the model on the Places365 dataset. 0.1 of the dataset is reserved for validation. The softmax scores and labels for the validation dataset will be saved to `train_models/.cache/`. 26 | 27 | ## iNaturalist 28 | Download the data from https://github.com/visipedia/inat_comp/tree/master/2021 if necessary. Update the `root` argument of the `datasets.INaturalist()` dataloader in `get_dataloaders` (located in `train_models/torchvision_dataset_utils.py`). 29 | Go to the `train_models` folder. If you work on a cluster with a SLURM scheduler, running `sbatch train_inaturalist.sh` will train the model on the Places365 dataset. 0.5 of the dataset is reserved for validation. The softmax scores and labels for the validation dataset will be saved to `train_models/.cache/`. -------------------------------------------------------------------------------- /generate_scores/download_cifar.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data 2 | cd data 3 | wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz 4 | tar -xzvf cifar-100-python.tar.gz 5 | cd .. 
-------------------------------------------------------------------------------- /generate_scores/train_imagenet/convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | import tensorflow as tf 7 | 8 | from resnet import get_resnet, name_to_params 9 | 10 | parser = argparse.ArgumentParser(description='SimCLR converter') 11 | parser.add_argument('tf_path', type=str, help='path of the input tensorflow file (ex: model.ckpt-250228)') 12 | parser.add_argument('--ema', action='store_true') 13 | parser.add_argument('--supervised', action='store_true') 14 | args = parser.parse_args() 15 | 16 | 17 | def main(): 18 | use_ema_model = args.ema 19 | prefix = ('ema_model/' if use_ema_model else '') + 'base_model/' 20 | head_prefix = ('ema_model/' if use_ema_model else '') + 'head_contrastive/' 21 | # 1. read tensorflow weight into a python dict 22 | vars_list = [] 23 | contrastive_vars = [] 24 | for v in tf.train.list_variables(args.tf_path): 25 | if v[0].startswith(prefix) and not v[0].endswith('/Momentum'): 26 | vars_list.append(v[0]) 27 | elif v[0] in {'head_supervised/linear_layer/dense/bias', 'head_supervised/linear_layer/dense/kernel'}: 28 | vars_list.append(v[0]) 29 | elif v[0].startswith(head_prefix) and not v[0].endswith('/Momentum'): 30 | contrastive_vars.append(v[0]) 31 | 32 | sd = {} 33 | ckpt_reader = tf.train.load_checkpoint(args.tf_path) 34 | for v in vars_list: 35 | sd[v] = ckpt_reader.get_tensor(v) 36 | 37 | split_idx = 2 if use_ema_model else 1 38 | # 2. convert the state_dict to PyTorch format 39 | conv_keys = [k for k in sd.keys() if k.split('/')[split_idx].split('_')[0] == 'conv2d'] 40 | conv_idx = [] 41 | for k in conv_keys: 42 | mid = k.split('/')[split_idx] 43 | if len(mid) == 6: 44 | conv_idx.append(0) 45 | else: 46 | conv_idx.append(int(mid[7:])) 47 | arg_idx = np.argsort(conv_idx) 48 | conv_keys = [conv_keys[idx] for idx in arg_idx] 49 | 50 | bn_keys = list(set([k.split('/')[split_idx] for k in sd.keys() 51 | if k.split('/')[split_idx].split('_')[0] == 'batch'])) 52 | bn_idx = [] 53 | for k in bn_keys: 54 | if len(k.split('_')) == 2: 55 | bn_idx.append(0) 56 | else: 57 | bn_idx.append(int(k.split('_')[2])) 58 | arg_idx = np.argsort(bn_idx) 59 | bn_keys = [bn_keys[idx] for idx in arg_idx] 60 | 61 | depth, width, sk_ratio = name_to_params(args.tf_path) 62 | model, head = get_resnet(depth, width, sk_ratio) 63 | 64 | conv_op = [] 65 | bn_op = [] 66 | for m in model.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | conv_op.append(m) 69 | elif isinstance(m, nn.BatchNorm2d): 70 | bn_op.append(m) 71 | assert len(vars_list) == (len(conv_op) + len(bn_op) * 4 + 2) # 2 for fc 72 | 73 | for i_conv in range(len(conv_keys)): 74 | m = conv_op[i_conv] 75 | w = torch.from_numpy(sd[conv_keys[i_conv]]).permute(3, 2, 0, 1) 76 | assert w.shape == m.weight.shape, f'size mismatch {w.shape} <> {m.weight.shape}' 77 | m.weight.data = w 78 | 79 | for i_bn in range(len(bn_keys)): 80 | m = bn_op[i_bn] 81 | gamma = torch.from_numpy(sd[prefix + bn_keys[i_bn] + '/gamma']) 82 | assert m.weight.shape == gamma.shape, f'size mismatch {gamma.shape} <> {m.weight.shape}' 83 | m.weight.data = gamma 84 | m.bias.data = torch.from_numpy(sd[prefix + bn_keys[i_bn] + '/beta']) 85 | m.running_mean = torch.from_numpy(sd[prefix + bn_keys[i_bn] + '/moving_mean']) 86 | m.running_var = torch.from_numpy(sd[prefix + bn_keys[i_bn] + '/moving_variance']) 87 | 88 | w = 
torch.from_numpy(sd['head_supervised/linear_layer/dense/kernel']).t() 89 | assert model.fc.weight.shape == w.shape 90 | model.fc.weight.data = w 91 | b = torch.from_numpy(sd['head_supervised/linear_layer/dense/bias']) 92 | assert model.fc.bias.shape == b.shape 93 | model.fc.bias.data = b 94 | 95 | if args.supervised: 96 | save_location = f'r{depth}_{width}x_sk{1 if sk_ratio != 0 else 0}{"_ema" if use_ema_model else ""}.pth' 97 | torch.save({'resnet': model.state_dict(), 'head': head.state_dict()}, save_location) 98 | return 99 | sd = {} 100 | for v in contrastive_vars: 101 | sd[v] = ckpt_reader.get_tensor(v) 102 | linear_op = [] 103 | bn_op = [] 104 | for m in head.modules(): 105 | if isinstance(m, nn.Linear): 106 | linear_op.append(m) 107 | elif isinstance(m, nn.BatchNorm1d): 108 | bn_op.append(m) 109 | for i, (l, m) in enumerate(zip(linear_op, bn_op)): 110 | l.weight.data = torch.from_numpy(sd[f'{head_prefix}nl_{i}/dense/kernel']).t() 111 | common_prefix = f'{head_prefix}nl_{i}/batch_normalization/' 112 | m.weight.data = torch.from_numpy(sd[f'{common_prefix}gamma']) 113 | if i != 2: 114 | m.bias.data = torch.from_numpy(sd[f'{common_prefix}beta']) 115 | m.running_mean = torch.from_numpy(sd[f'{common_prefix}moving_mean']) 116 | m.running_var = torch.from_numpy(sd[f'{common_prefix}moving_variance']) 117 | 118 | # 3. dump the PyTorch weights. 119 | save_location = f'r{depth}_{width}x_sk{1 if sk_ratio != 0 else 0}{"_ema" if use_ema_model else ""}.pth' 120 | torch.save({'resnet': model.state_dict(), 'head': head.state_dict()}, save_location) 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /generate_scores/train_imagenet/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from math import ceil 4 | 5 | import requests 6 | from tqdm import tqdm 7 | 8 | available_simclr_models = ['r50_1x_sk0', 'r50_1x_sk1', 'r50_2x_sk0', 'r50_2x_sk1', 9 | 'r101_1x_sk0', 'r101_1x_sk1', 'r101_2x_sk0', 'r101_2x_sk1', 10 | 'r152_1x_sk0', 'r152_1x_sk1', 'r152_2x_sk0', 'r152_2x_sk1', 'r152_3x_sk1'] 11 | simclr_base_url = 'https://storage.googleapis.com/simclr-checkpoints/simclrv2/{category}/{model}/' 12 | files = ['checkpoint', 'graph.pbtxt', 'model.ckpt-{category}.data-00000-of-00001', 13 | 'model.ckpt-{category}.index', 'model.ckpt-{category}.meta'] 14 | simclr_categories = {'finetuned_100pct': 37535, 'finetuned_10pct': 3754, 15 | 'finetuned_1pct': 751, 'pretrained': 250228, 'supervised': 28151} 16 | chunk_size = 1024 * 8 17 | 18 | 19 | def download(url, destination): 20 | if os.path.exists(destination): 21 | return 22 | response = requests.get(url, stream=True) 23 | save_response_content(response, destination) 24 | 25 | 26 | def save_response_content(response, destination): 27 | if 'Content-length' in response.headers: 28 | total = int(ceil(int(response.headers['Content-length']) / chunk_size)) 29 | else: 30 | total = None 31 | with open(destination, 'wb') as f: 32 | for data in tqdm(response.iter_content(chunk_size=chunk_size), leave=False, total=total): 33 | f.write(data) 34 | 35 | 36 | def run(): 37 | parser = argparse.ArgumentParser(description='Model Downloader') 38 | parser.add_argument('model', type=str, choices=available_simclr_models) 39 | parser.add_argument('--simclr_category', type=str, choices=list(simclr_categories.keys()), default='pretrained') 40 | args = parser.parse_args() 41 | model = args.model 42 | os.makedirs(model, 
exist_ok=True) 43 | url = simclr_base_url.format(model=model, category=args.simclr_category) 44 | model_category = simclr_categories[args.simclr_category] 45 | for file in tqdm(files): 46 | f = file.format(category=model_category) 47 | download(url + f, os.path.join(model, f)) 48 | 49 | 50 | if __name__ == '__main__': 51 | run() 52 | -------------------------------------------------------------------------------- /generate_scores/train_imagenet/get_simclr_representations.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Based on https://github.com/ae-foster/pytorch-simclr/blob/simclr-master/gradient_linear_clf.py 3 | ''' 4 | 5 | 6 | import os 7 | import numpy as np 8 | import argparse 9 | from collections import Counter 10 | import pdb 11 | 12 | import torch 13 | import torchvision 14 | from PIL import Image 15 | from tqdm import tqdm 16 | import torchvision.transforms as transforms 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from resnet import get_resnet, name_to_params 20 | 21 | #class ImagenetValidationDataset(Dataset): 22 | # def __init__(self, val_path, ground_truth_path): # Modified to take in separate path for ground truth 23 | # super().__init__() 24 | # self.val_path = val_path 25 | # self.transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]) 26 | # with open(ground_truth_path) as f: 27 | # self.labels = [int(l) - 1 for l in f.readlines()] 28 | # 29 | # def __len__(self): 30 | # return len(self.labels) 31 | # 32 | # def __getitem__(self, item): 33 | # img = Image.open(os.path.join(self.val_path, f'ILSVRC2012_val_{item + 1:08d}.JPEG')).convert('RGB') 34 | # return self.transform(img), self.labels[item] 35 | 36 | 37 | 38 | @torch.no_grad() 39 | def run(pth_path, train_or_val, batch_size): 40 | device = 'cuda' 41 | path = f'/home/tding/data/imagenet/{train_or_val}' 42 | # if train_or_val == 'train': 43 | # path = '/home/eecs/tiffany_ding/data/imagenet/train' 44 | # else: 45 | # path = 46 | # path = '/data/imagenetwhole/ilsvrc2012/val' 47 | # if not os.path.isdir(path): 48 | # print('ERROR: TO access ImageNet-val, use the node "ace". 
Exiting...') 49 | # exit() 50 | 51 | 52 | print(f'Loading data from {path}') 53 | dataset = torchvision.datasets.ImageFolder(path, 54 | transform=transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])) 55 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=0) 56 | net, _ = get_resnet(*name_to_params(pth_path)) # renamed model --> net 57 | 58 | print('==> loading encoder from checkpoint..') 59 | net.load_state_dict(torch.load(pth_path)['resnet']) 60 | 61 | print('Number of GPUs available:', torch.cuda.device_count()) 62 | if device == 'cuda': 63 | net = torch.nn.DataParallel(net) 64 | torch.backends.cudnn.benchmark = True 65 | 66 | net = net.to(device) 67 | net.eval() 68 | 69 | features = [] 70 | labels = [] 71 | 72 | t = tqdm(enumerate(dataloader), total=len(dataloader), bar_format='{desc}{bar}{r_bar}') 73 | for batch_idx, (inputs, targets) in t: 74 | inputs = inputs.to(device) 75 | representation = net(inputs) 76 | features += [representation.cpu()] 77 | labels += [targets.cpu()] 78 | # break # TO TEST 79 | 80 | features = torch.cat(features,dim=0) 81 | labels = torch.cat(labels,dim=0) 82 | 83 | save_to = f'/home/eecs/tiffany_ding/code/SimCLRv2-Pytorch/.cache/simclr_representations/imagenet_{train_or_val}' 84 | torch.save(features,save_to + '_features.pt') 85 | torch.save(labels,save_to + '_labels.pt') 86 | print(f'Saved features and labels to {save_to}_{{features, labels}}.pt') 87 | 88 | if __name__ == '__main__': 89 | parser = argparse.ArgumentParser(description='Get SimCLR representations') 90 | parser.add_argument('dataset', type=str, help='which ImageNet dataset to use (train or val)') 91 | parser.add_argument('--pth_path', type=str, default='r152_3x_sk1.pth', help='path of the input checkpoint file') 92 | parser.add_argument("--batch_size", type=int, default=64, help='batch size') 93 | args = parser.parse_args() 94 | run(args.pth_path, args.dataset, args.batch_size) 95 | 96 | 97 | ##### 98 | 99 | 100 | -------------------------------------------------------------------------------- /generate_scores/train_imagenet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | BATCH_NORM_EPSILON = 1e-5 6 | BATCH_NORM_DECAY = 0.9 # == pytorch's default value as well 7 | 8 | 9 | class BatchNormRelu(nn.Sequential): 10 | def __init__(self, num_channels, relu=True): 11 | super().__init__(nn.BatchNorm2d(num_channels, eps=BATCH_NORM_EPSILON), nn.ReLU() if relu else nn.Identity()) 12 | 13 | 14 | def conv(in_channels, out_channels, kernel_size=3, stride=1, bias=False): 15 | return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 16 | stride=stride, padding=(kernel_size - 1) // 2, bias=bias) 17 | 18 | 19 | class SelectiveKernel(nn.Module): 20 | def __init__(self, in_channels, out_channels, stride, sk_ratio, min_dim=32): 21 | super().__init__() 22 | assert sk_ratio > 0.0 23 | self.main_conv = nn.Sequential(conv(in_channels, 2 * out_channels, stride=stride), 24 | BatchNormRelu(2 * out_channels)) 25 | mid_dim = max(int(out_channels * sk_ratio), min_dim) 26 | self.mixing_conv = nn.Sequential(conv(out_channels, mid_dim, kernel_size=1), BatchNormRelu(mid_dim), 27 | conv(mid_dim, 2 * out_channels, kernel_size=1)) 28 | 29 | def forward(self, x): 30 | x = self.main_conv(x) 31 | x = torch.stack(torch.chunk(x, 2, dim=1), dim=0) # 2, B, C, H, W 32 | g = x.sum(dim=0).mean(dim=[2, 3], 
keepdim=True) 33 | m = self.mixing_conv(g) 34 | m = torch.stack(torch.chunk(m, 2, dim=1), dim=0) # 2, B, C, 1, 1 35 | return (x * F.softmax(m, dim=0)).sum(dim=0) 36 | 37 | 38 | class Projection(nn.Module): 39 | def __init__(self, in_channels, out_channels, stride, sk_ratio=0): 40 | super().__init__() 41 | if sk_ratio > 0: 42 | self.shortcut = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), 43 | # kernel_size = 2 => padding = 1 44 | nn.AvgPool2d(kernel_size=2, stride=stride, padding=0), 45 | conv(in_channels, out_channels, kernel_size=1)) 46 | else: 47 | self.shortcut = conv(in_channels, out_channels, kernel_size=1, stride=stride) 48 | self.bn = BatchNormRelu(out_channels, relu=False) 49 | 50 | def forward(self, x): 51 | return self.bn(self.shortcut(x)) 52 | 53 | 54 | class BottleneckBlock(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, in_channels, out_channels, stride, sk_ratio=0, use_projection=False): 58 | super().__init__() 59 | if use_projection: 60 | self.projection = Projection(in_channels, out_channels * 4, stride, sk_ratio) 61 | else: 62 | self.projection = nn.Identity() 63 | ops = [conv(in_channels, out_channels, kernel_size=1), BatchNormRelu(out_channels)] 64 | if sk_ratio > 0: 65 | ops.append(SelectiveKernel(out_channels, out_channels, stride, sk_ratio)) 66 | else: 67 | ops.append(conv(out_channels, out_channels, stride=stride)) 68 | ops.append(BatchNormRelu(out_channels)) 69 | ops.append(conv(out_channels, out_channels * 4, kernel_size=1)) 70 | ops.append(BatchNormRelu(out_channels * 4, relu=False)) 71 | self.net = nn.Sequential(*ops) 72 | 73 | def forward(self, x): 74 | shortcut = self.projection(x) 75 | return F.relu(shortcut + self.net(x)) 76 | 77 | 78 | class Blocks(nn.Module): 79 | def __init__(self, num_blocks, in_channels, out_channels, stride, sk_ratio=0): 80 | super().__init__() 81 | self.blocks = nn.ModuleList([BottleneckBlock(in_channels, out_channels, stride, sk_ratio, True)]) 82 | self.channels_out = out_channels * BottleneckBlock.expansion 83 | for _ in range(num_blocks - 1): 84 | self.blocks.append(BottleneckBlock(self.channels_out, out_channels, 1, sk_ratio)) 85 | 86 | def forward(self, x): 87 | for b in self.blocks: 88 | x = b(x) 89 | return x 90 | 91 | 92 | class Stem(nn.Sequential): 93 | def __init__(self, sk_ratio, width_multiplier): 94 | ops = [] 95 | channels = 64 * width_multiplier // 2 96 | if sk_ratio > 0: 97 | ops.append(conv(3, channels, stride=2)) 98 | ops.append(BatchNormRelu(channels)) 99 | ops.append(conv(channels, channels)) 100 | ops.append(BatchNormRelu(channels)) 101 | ops.append(conv(channels, channels * 2)) 102 | else: 103 | ops.append(conv(3, channels * 2, kernel_size=7, stride=2)) 104 | ops.append(BatchNormRelu(channels * 2)) 105 | ops.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 106 | super().__init__(*ops) 107 | 108 | 109 | class ResNet(nn.Module): 110 | def __init__(self, layers, width_multiplier, sk_ratio): 111 | super().__init__() 112 | ops = [Stem(sk_ratio, width_multiplier)] 113 | channels_in = 64 * width_multiplier 114 | ops.append(Blocks(layers[0], channels_in, 64 * width_multiplier, 1, sk_ratio)) 115 | channels_in = ops[-1].channels_out 116 | ops.append(Blocks(layers[1], channels_in, 128 * width_multiplier, 2, sk_ratio)) 117 | channels_in = ops[-1].channels_out 118 | ops.append(Blocks(layers[2], channels_in, 256 * width_multiplier, 2, sk_ratio)) 119 | channels_in = ops[-1].channels_out 120 | ops.append(Blocks(layers[3], channels_in, 512 * width_multiplier, 2, sk_ratio)) 121 | channels_in = ops[-1].channels_out 
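        # Note added for clarity: channels_in is now the encoder's final feature width,
        # 512 * width_multiplier * 4 (the bottleneck expansion) -- e.g., 6144 for the
        # r152_3x_sk1 checkpoint, matching simclr_feature_dim in train_linear_and_get_logits.py.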
122 | self.channels_out = channels_in 123 | #breakpoint() 124 | self.net = nn.Sequential(*ops) 125 | self.fc = nn.Linear(channels_in, 1000) 126 | 127 | def forward(self, x, apply_fc=False): 128 | h = self.net(x).mean(dim=[2, 3]) 129 | if apply_fc: 130 | h = self.fc(h) 131 | return h 132 | 133 | 134 | class ContrastiveHead(nn.Module): 135 | def __init__(self, channels_in, out_dim=128, num_layers=3): 136 | super().__init__() 137 | self.layers = nn.ModuleList() 138 | for i in range(num_layers): 139 | if i != num_layers - 1: 140 | dim, relu = channels_in, True 141 | else: 142 | dim, relu = out_dim, False 143 | self.layers.append(nn.Linear(channels_in, dim, bias=False)) 144 | bn = nn.BatchNorm1d(dim, eps=BATCH_NORM_EPSILON, affine=True) 145 | if i == num_layers - 1: 146 | nn.init.zeros_(bn.bias) 147 | self.layers.append(bn) 148 | if relu: 149 | self.layers.append(nn.ReLU()) 150 | 151 | def forward(self, x): 152 | for b in self.layers: 153 | x = b(x) 154 | return x 155 | 156 | 157 | def get_resnet(depth=50, width_multiplier=1, sk_ratio=0): # sk_ratio=0.0625 is recommended 158 | layers = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}[depth] 159 | resnet = ResNet(layers, width_multiplier, sk_ratio) 160 | return resnet, ContrastiveHead(resnet.channels_out) 161 | 162 | 163 | def name_to_params(checkpoint): 164 | sk_ratio = 0.0625 if '_sk1' in checkpoint else 0 165 | if 'r50_' in checkpoint: 166 | depth = 50 167 | elif 'r101_' in checkpoint: 168 | depth = 101 169 | elif 'r152_' in checkpoint: 170 | depth = 152 171 | else: 172 | raise NotImplementedError 173 | 174 | if '_1x_' in checkpoint: 175 | width = 1 176 | elif '_2x_' in checkpoint: 177 | width = 2 178 | elif '_3x_' in checkpoint: 179 | width = 3 180 | else: 181 | raise NotImplementedError 182 | 183 | return depth, width, sk_ratio 184 | -------------------------------------------------------------------------------- /generate_scores/train_imagenet/run_get_simclr_representations.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run this file using "sbatch my_script.sh" 4 | 5 | # the SBATCH directives must appear before any executable 6 | # line in this script 7 | 8 | #SBATCH -p rise # partition (queue) 9 | #SBATCH --cpus-per-task=72 # number of cores per task 10 | # I think gpu:8 will request 8 of any kind of gpu per node, 11 | # and gpu:v100_32:8 should request 8 v100_32 per node 12 | #SBATCH --mem-per-cpu=64G 13 | #SBATCH --gres=gpu:1 14 | #SBATCH -w como # Request como specifically 15 | #SBATCH --exclude=freddie,flaminio,blaze # nodes not yet on SLURM-only 16 | #SBATCH -t 0-48:00 # time requested (D-HH:MM) 17 | # slurm will cd to this directory before running the script 18 | # you can also just run sbatch submit.sh from the directory 19 | # you want to be in 20 | #SBATCH -D /home/eecs/tiffany_ding/code/SimCLRv2-Pytorch 21 | # use these two lines to control the output file. Default is 22 | # slurm-<jobid>.out. 
By default stdout and stderr go to the same 23 | # place, but if you use both commands below they'll be split up 24 | # filename patterns here: https://slurm.schedmd.com/sbatch.html 25 | # %N is the hostname (if used, will create output(s) per node) 26 | # %j is jobid 27 | #SBATCH -o /home/eecs/tiffany_ding/slurm_output/simclr_repr_train.out # STDOUT 28 | #SBATCH -e /home/eecs/tiffany_ding/slurm_output/simclr_repr_train.err # STDERR 29 | # if you want to get emails as your jobs run/fail 30 | ##SBATCH --mail-type=NONE # Mail events (NONE, BEGIN, END, FAIL, ALL) 31 | ##SBATCH --mail-user=tiffany_ding@eecs.berkeley.edu # Where to send mail 32 | #seff $SLURM_JOBID 33 | # print some info for context 34 | pwd | xargs -I{} echo "Current directory:" {} 35 | hostname | xargs -I{} echo "Node:" {} 36 | python get_simclr_representations.py train --batch_size=400 37 | -------------------------------------------------------------------------------- /generate_scores/train_imagenet/run_train_linear_and_get_logits.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run this file using "sbatch my_script.sh" 4 | 5 | # the SBATCH directives must appear before any executable 6 | # line in this script 7 | 8 | #SBATCH -p rise # partition (queue) 9 | #SBATCH --cpus-per-task=72 # number of cores per task 10 | # I think gpu:8 will request 8 of any kind of gpu per node, 11 | # and gpu:v100_32:8 should request 8 v100_32 per node 12 | #SBATCH --mem-per-cpu=64G 13 | #SBATCH --gres=gpu:1 14 | #SBATCH -w ace # Request ace specifically 15 | #SBATCH --exclude=freddie,flaminio,blaze # nodes not yet on SLURM-only 16 | #SBATCH -t 0-48:00 # time requested (D-HH:MM) 17 | # slurm will cd to this directory before running the script 18 | # you can also just run sbatch submit.sh from the directory 19 | # you want to be in 20 | #SBATCH -D /home/eecs/tiffany_ding/code/SimCLRv2-Pytorch 21 | # use these two lines to control the output file. Default is 22 | # slurm-<jobid>.out. 
By default stdout and stderr go to the same 23 | # place, but if you use both commands below they'll be split up 24 | # filename patterns here: https://slurm.schedmd.com/sbatch.html 25 | # %N is the hostname (if used, will create output(s) per node) 26 | # %j is jobid 27 | #SBATCH -o /home/eecs/tiffany_ding/slurm_output/simclr_repr_train.out # STDOUT 28 | #SBATCH -e /home/eecs/tiffany_ding/slurm_output/simclr_repr_train.err # STDERR 29 | # if you want to get emails as your jobs run/fail 30 | ##SBATCH --mail-type=NONE # Mail events (NONE, BEGIN, END, FAIL, ALL) 31 | ##SBATCH --mail-user=tiffany_ding@eecs.berkeley.edu # Where to send mail 32 | #seff $SLURM_JOBID 33 | # print some info for context 34 | pwd | xargs -I{} echo "Current directory:" {} 35 | hostname | xargs -I{} echo "Node:" {} 36 | python train_linear_and_get_logits.py 37 | -------------------------------------------------------------------------------- /generate_scores/train_imagenet/train_linear_and_get_logits.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from collections import Counter 4 | 5 | import torch 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | from resnet import get_resnet, name_to_params 12 | import pdb 13 | 14 | 15 | device = 'cuda' # Other option: 'cpu' 16 | 17 | 18 | def test(testloader, device, clf): 19 | criterion = torch.nn.CrossEntropyLoss() 20 | clf.eval() 21 | test_clf_loss = 0 22 | correct = 0 23 | total = 0 24 | with torch.no_grad(): 25 | t = tqdm(enumerate(testloader), total=len(testloader), desc='Loss: **** | Test Acc: ****% ', 26 | bar_format='{desc}{bar}{r_bar}') 27 | for batch_idx, (features, targets) in t: 28 | features, targets = features.to(device), targets.to(device) 29 | logits = clf(features) 30 | clf_loss = criterion(logits, targets) 31 | 32 | 33 | test_clf_loss += clf_loss.item() 34 | _, predicted = logits.max(1) 35 | total += targets.size(0) 36 | correct += predicted.eq(targets).sum().item() 37 | 38 | t.set_description('Loss: %.3f | Test Acc: %.3f%% ' % (test_clf_loss / (batch_idx + 1), 100. * correct / total)) 39 | 40 | acc = 100. 
* correct / total 41 | return acc 42 | 43 | def get_train_and_val_dataloaders(batch_size=64): 44 | ''' 45 | Returns 2 DataLoaders, one containing all of ImageNet-train and another containing all of ImageNet-val 46 | ''' 47 | dataloader_list = [] 48 | for dataset_name in ['train', 'val']: 49 | print(f'Loading SimCLR representations for ImageNet {dataset_name}...') 50 | representation_location = f'/home/eecs/tiffany_ding/code/SimCLRv2-Pytorch/.cache/simclr_representations/imagenet_{dataset_name}' 51 | features = torch.load(representation_location+'_features.pt') 52 | labels = torch.load(representation_location+'_labels.pt') 53 | # breakpoint() 54 | print('Dimension of features:', features.shape) 55 | dataset = torch.utils.data.TensorDataset(features,labels) 56 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=0) 57 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=0) 58 | dataloader_list.append(dataloader) 59 | 60 | return dataloader_list[0], dataloader_list[1] 61 | 62 | def get_split_dataloaders(dataset_name, train_split, test_split='default', no_train=False): 63 | ''' 64 | Load in either ImageNet train or val representations and split the dataset into two or three splits 65 | ''' 66 | 67 | representation_location = f'/home/eecs/tiffany_ding/code/SimCLRv2-Pytorch/.cache/simclr_representations/imagenet_{dataset_name}' 68 | #dataset_name = 'train' # 'train' or 'val' 69 | #train_split = 0.7 # what fraction of data to use for training 70 | 71 | # I moved imagenet_train_features.pt to /data/tiffany_ding on `ace` 72 | if dataset_name == 'train': 73 | features = torch.load('/data/tiffany_ding/imagenet_train_features.pt') 74 | else: 75 | features = torch.load(representation_location+'_features.pt') 76 | labels = torch.load(representation_location+'_labels.pt') 77 | 78 | clfdataset = torch.utils.data.TensorDataset(features,labels) 79 | train_size = int(len(clfdataset) * train_split) 80 | if test_split == 'default': # Everything that is not in train will be included in test 81 | test_size = len(clfdataset) - train_size 82 | calib_size = 0 83 | splits = [train_size, test_size, calib_size] 84 | train_dataset, test_dataset, calibration_dataset = torch.utils.data.random_split(clfdataset, splits, generator=torch.Generator().manual_seed(0)) 85 | else: 86 | test_size = int(len(clfdataset) * test_split) 87 | splits = [train_size, test_size, len(clfdataset) - (train_size + test_size)] 88 | train_dataset, test_dataset, calibration_dataset = torch.utils.data.random_split(clfdataset, splits, generator=torch.Generator().manual_seed(0)) 89 | 90 | if no_train: 91 | clftrainloader = None 92 | else: 93 | clftrainloader = DataLoader(train_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=0) 94 | 95 | testloader = DataLoader(test_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=0) 96 | calibloader = DataLoader(calibration_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=0) 97 | 98 | print(f'Size of classifier training set: {train_size}') 99 | print(f'Size of classifier test set: {test_size}') 100 | print(f'Size of calibration/evaluation set: {len(calibration_dataset)}') 101 | return clftrainloader, testloader, calibloader 102 | 103 | 104 | def get_logits(save_prefix, num_epochs=90, weights_path=None, get_train_logits=False): 105 | 106 | clf_train_split = 0.1 # What fraction of ImageNet train we should use to train the classifier 107 | clf_test_split = 0.01 # 
What fraction of ImageNet train we should use to compute classifier accuracy 108 | 109 | if weights_path is None: 110 | # Train classifier 111 | 112 | clftrainloader, clftestloader, calib_loader = get_split_dataloaders('train', train_split=clf_train_split, test_split=clf_test_split) 113 | weights_path = '.cache/trained_classifiers/train-0.1.pt' 114 | clf = train(clftrainloader, clftestloader, weights_path, num_epochs, learning_rate=.001) 115 | else: 116 | # Load classifier weights 117 | simclr_feature_dim = 6144 118 | num_classes = 1000 119 | clf = torch.nn.Linear(simclr_feature_dim, num_classes) 120 | clf.load_state_dict(torch.load(weights_path)) 121 | clf.to(device) 122 | 123 | if get_train_logits: 124 | clftrainloader, clftestloader, calib_loader = get_split_dataloaders('train', train_split=clf_train_split, test_split=clf_test_split) 125 | else: 126 | _, clftestloader, calib_loader = get_split_dataloaders('train', train_split=clf_train_split, test_split=clf_test_split, no_train=True) 127 | 128 | 129 | with torch.no_grad(): 130 | logits = [] 131 | labels = [] 132 | 133 | print('Computing logits...') 134 | # We only compute logits for the data we haven't already used to train the classifier 135 | for loader in [clftestloader, calib_loader]: 136 | t = tqdm(enumerate(loader), total=len(loader), desc='Batch:') 137 | for batch_idx, (features, targets) in t: 138 | features, targets = features.to(device), targets.to(device) 139 | curr_logits = clf(features) 140 | 141 | logits += [curr_logits] 142 | labels += [targets] 143 | 144 | # Concatenate 145 | logits = torch.cat(logits,dim=0) 146 | labels = torch.cat(labels,dim=0) 147 | 148 | # Save test+cal logits 149 | torch.save(logits,save_prefix + '_logits.pt') 150 | torch.save(labels,save_prefix + '_labels.pt') 151 | print(f'Saved logits to', save_prefix + '_logits.pt') 152 | print(f'Saved labels to', save_prefix + '_labels.pt') 153 | 154 | # Optionally, we can also compute the logits for the training data 155 | if get_train_logits: 156 | as_softmax = True 157 | loader = clftrainloader 158 | 159 | logits = [] 160 | labels = [] 161 | 162 | t = tqdm(enumerate(loader), total=len(loader), desc='Batch:') 163 | for batch_idx, (features, targets) in t: 164 | features, targets = features.to(device), targets.to(device) 165 | curr_logits = clf(features) 166 | 167 | logits += [curr_logits] 168 | labels += [targets] 169 | 170 | logits = torch.cat(logits,dim=0) 171 | labels = torch.cat(labels,dim=0) 172 | 173 | 174 | # Save logits 175 | torch.save(logits,save_prefix + '_logits_TRAIN.pt') 176 | torch.save(labels,save_prefix + '_labels_TRAIN.pt') 177 | print(f'Saved logits to', save_prefix + '_logits_TRAIN.pt') 178 | print(f'Saved labels to', save_prefix + '_labels_TRAIN.pt') 179 | 180 | 181 | 182 | def train(clftrainloader, clftestloader, save_to, num_epochs, learning_rate=.001): 183 | 184 | print(f'After training, weights will be saved to {save_to}') 185 | 186 | simclr_feature_dim = 6144 187 | num_classes = 1000 188 | clf = torch.nn.Linear(simclr_feature_dim, num_classes).to(device) 189 | clf.train() 190 | 191 | criterion = torch.nn.CrossEntropyLoss() 192 | clf_optimizer = torch.optim.Adam(clf.parameters(), lr=learning_rate, weight_decay=1e-6) 193 | 194 | print(f'Training for {num_epochs} epochs') 195 | for epoch in range(num_epochs): 196 | print('Epoch', epoch) 197 | train_loss = 0 198 | t = tqdm(enumerate(clftrainloader), desc='Loss: **** ', total=len(clftrainloader), bar_format='{desc}{bar}{r_bar}') 199 | for batch_idx, (features, targets) in t: 200 | 
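            # Comment added for clarity: this is a standard linear-probe update --
            # clear the gradients, score the frozen SimCLR features with the linear
            # classifier, and backpropagate the cross-entropy loss.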
clf_optimizer.zero_grad() 201 | features, targets = features.to(device), targets.to(device) 202 | predictions = clf(features) 203 | loss = criterion(predictions, targets) 204 | loss.backward() 205 | clf_optimizer.step() 206 | 207 | train_loss += loss.item() 208 | 209 | t.set_description('Loss: %.3f ' % (train_loss / (batch_idx + 1))) 210 | 211 | acc = test(clftestloader, device, clf) 212 | print(f"Accuracy: {acc:.3f}%") 213 | 214 | # Save trained classifier weights 215 | save_to = save_to + f'acc={acc / 100:.4f}.pt' 216 | torch.save(clf.state_dict(), save_to) 217 | print(f'Saved classifier weights to {save_to}') 218 | 219 | return clf 220 | 221 | 222 | def run(args): 223 | 224 | # Set location to save weights 225 | save_prefix = 'train-val' # UPDATE THIS AS NECESSARY 226 | #save_prefix = 'train-0.7' 227 | save_to = f'.cache/trained_classifiers/{save_prefix}_epochs={args.num_epochs}' 228 | print(f'After training, weights will be saved to {save_to}[...]') 229 | 230 | # OPTION 1: Train classifier and save weights 231 | # Load data 232 | # # clftrainloader, clftestloader, _ = get_split_dataloaders("train", train_split=0.7) 233 | # clftrainloader, clftestloader = get_train_and_val_dataloaders(batch_size=64) 234 | # train(clftrainloader, clftestloader, save_to, args.num_epochs, learning_rate=.01) 235 | 236 | 237 | # # OPTION 2: Train classifier and apply classifier to get logits for data not used to train 238 | # save_prefix = f'.cache/logits/imagenet_train_subset' 239 | # get_logits(save_prefix, num_epochs=args.num_epochs, weights_path=None) 240 | 241 | # OPTION 3: Load pretrained classifier weights and apply classifier for data not used to train 242 | save_prefix = f'.cache/logits/imagenet_train_subset' 243 | weights_path = f'.cache/trained_classifiers/train-all_epochs=10.pt' 244 | # get_logits(save_prefix, weights_path=weights_path) 245 | get_logits(save_prefix, weights_path=weights_path, get_train_logits=True) # Get logits for training data too 246 | 247 | 248 | 249 | if __name__ == '__main__': 250 | parser = argparse.ArgumentParser(description='Train downstream classifier with gradients.') 251 | parser.add_argument('--num_epochs', default=10, type=int, help='Number of epochs') 252 | run(parser.parse_args()) 253 | 254 | -------------------------------------------------------------------------------- /generate_scores/train_models/cifar-100.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "bc1280f6", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "import numpy as np\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import pickle\n", 15 | "import os, time, copy\n", 16 | "import torch\n", 17 | "import torchvision as tv\n", 18 | "import torchvision.transforms as transforms\n", 19 | "import torchvision.datasets as datasets\n", 20 | "from torch.utils.data import TensorDataset\n", 21 | "from torchvision.models import resnet50\n", 22 | "import torch.optim as optim\n", 23 | "import torch.nn as nn\n", 24 | "from cifar_utils import get_model, get_data, show_img" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "d0709702", 30 | "metadata": {}, 31 | "source": [ 32 | "The easiest way to access a GPU for model training is to request one using srun or sbatch and run `python cifar_utils.py` (make sure to update the config in that file as appropriate)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 
| "execution_count": 2, 38 | "id": "73ea3e13", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Model parameters\n", 43 | "config = {\n", 44 | " 'num_classes' : 100,\n", 45 | " 'batch_size' : 128,\n", 46 | " 'lr' : 0.0001,\n", 47 | " 'feature_extract' : False, # If False, fine tune all layers. If True, fine tune last layer only\n", 48 | " 'num_epochs' : 30,\n", 49 | " 'device' : 'cpu',\n", 50 | " 'frac_val' : 0.5, # CHANGED FROM 0.3\n", 51 | " 'model_filename' : 'best-cifar100-model-fracval=0.7', # CHANGED FROM no suffix\n", 52 | " 'num_workers' : 4,\n", 53 | "}" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "id": "a7d0910f", 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Params to learn:\n", 67 | "\t conv1.weight\n", 68 | "\t bn1.weight\n", 69 | "\t bn1.bias\n", 70 | "\t layer1.0.conv1.weight\n", 71 | "\t layer1.0.bn1.weight\n", 72 | "\t layer1.0.bn1.bias\n", 73 | "\t layer1.0.conv2.weight\n", 74 | "\t layer1.0.bn2.weight\n", 75 | "\t layer1.0.bn2.bias\n", 76 | "\t layer1.0.conv3.weight\n", 77 | "\t layer1.0.bn3.weight\n", 78 | "\t layer1.0.bn3.bias\n", 79 | "\t layer1.0.downsample.0.weight\n", 80 | "\t layer1.0.downsample.1.weight\n", 81 | "\t layer1.0.downsample.1.bias\n", 82 | "\t layer1.1.conv1.weight\n", 83 | "\t layer1.1.bn1.weight\n", 84 | "\t layer1.1.bn1.bias\n", 85 | "\t layer1.1.conv2.weight\n", 86 | "\t layer1.1.bn2.weight\n", 87 | "\t layer1.1.bn2.bias\n", 88 | "\t layer1.1.conv3.weight\n", 89 | "\t layer1.1.bn3.weight\n", 90 | "\t layer1.1.bn3.bias\n", 91 | "\t layer1.2.conv1.weight\n", 92 | "\t layer1.2.bn1.weight\n", 93 | "\t layer1.2.bn1.bias\n", 94 | "\t layer1.2.conv2.weight\n", 95 | "\t layer1.2.bn2.weight\n", 96 | "\t layer1.2.bn2.bias\n", 97 | "\t layer1.2.conv3.weight\n", 98 | "\t layer1.2.bn3.weight\n", 99 | "\t layer1.2.bn3.bias\n", 100 | "\t layer2.0.conv1.weight\n", 101 | "\t layer2.0.bn1.weight\n", 102 | "\t layer2.0.bn1.bias\n", 103 | "\t layer2.0.conv2.weight\n", 104 | "\t layer2.0.bn2.weight\n", 105 | "\t layer2.0.bn2.bias\n", 106 | "\t layer2.0.conv3.weight\n", 107 | "\t layer2.0.bn3.weight\n", 108 | "\t layer2.0.bn3.bias\n", 109 | "\t layer2.0.downsample.0.weight\n", 110 | "\t layer2.0.downsample.1.weight\n", 111 | "\t layer2.0.downsample.1.bias\n", 112 | "\t layer2.1.conv1.weight\n", 113 | "\t layer2.1.bn1.weight\n", 114 | "\t layer2.1.bn1.bias\n", 115 | "\t layer2.1.conv2.weight\n", 116 | "\t layer2.1.bn2.weight\n", 117 | "\t layer2.1.bn2.bias\n", 118 | "\t layer2.1.conv3.weight\n", 119 | "\t layer2.1.bn3.weight\n", 120 | "\t layer2.1.bn3.bias\n", 121 | "\t layer2.2.conv1.weight\n", 122 | "\t layer2.2.bn1.weight\n", 123 | "\t layer2.2.bn1.bias\n", 124 | "\t layer2.2.conv2.weight\n", 125 | "\t layer2.2.bn2.weight\n", 126 | "\t layer2.2.bn2.bias\n", 127 | "\t layer2.2.conv3.weight\n", 128 | "\t layer2.2.bn3.weight\n", 129 | "\t layer2.2.bn3.bias\n", 130 | "\t layer2.3.conv1.weight\n", 131 | "\t layer2.3.bn1.weight\n", 132 | "\t layer2.3.bn1.bias\n", 133 | "\t layer2.3.conv2.weight\n", 134 | "\t layer2.3.bn2.weight\n", 135 | "\t layer2.3.bn2.bias\n", 136 | "\t layer2.3.conv3.weight\n", 137 | "\t layer2.3.bn3.weight\n", 138 | "\t layer2.3.bn3.bias\n", 139 | "\t layer3.0.conv1.weight\n", 140 | "\t layer3.0.bn1.weight\n", 141 | "\t layer3.0.bn1.bias\n", 142 | "\t layer3.0.conv2.weight\n", 143 | "\t layer3.0.bn2.weight\n", 144 | "\t layer3.0.bn2.bias\n", 145 | "\t layer3.0.conv3.weight\n", 146 | "\t layer3.0.bn3.weight\n", 147 | "\t 
layer3.0.bn3.bias\n", 148 | "\t layer3.0.downsample.0.weight\n", 149 | "\t layer3.0.downsample.1.weight\n", 150 | "\t layer3.0.downsample.1.bias\n", 151 | "\t layer3.1.conv1.weight\n", 152 | "\t layer3.1.bn1.weight\n", 153 | "\t layer3.1.bn1.bias\n", 154 | "\t layer3.1.conv2.weight\n", 155 | "\t layer3.1.bn2.weight\n", 156 | "\t layer3.1.bn2.bias\n", 157 | "\t layer3.1.conv3.weight\n", 158 | "\t layer3.1.bn3.weight\n", 159 | "\t layer3.1.bn3.bias\n", 160 | "\t layer3.2.conv1.weight\n", 161 | "\t layer3.2.bn1.weight\n", 162 | "\t layer3.2.bn1.bias\n", 163 | "\t layer3.2.conv2.weight\n", 164 | "\t layer3.2.bn2.weight\n", 165 | "\t layer3.2.bn2.bias\n", 166 | "\t layer3.2.conv3.weight\n", 167 | "\t layer3.2.bn3.weight\n", 168 | "\t layer3.2.bn3.bias\n", 169 | "\t layer3.3.conv1.weight\n", 170 | "\t layer3.3.bn1.weight\n", 171 | "\t layer3.3.bn1.bias\n", 172 | "\t layer3.3.conv2.weight\n", 173 | "\t layer3.3.bn2.weight\n", 174 | "\t layer3.3.bn2.bias\n", 175 | "\t layer3.3.conv3.weight\n", 176 | "\t layer3.3.bn3.weight\n", 177 | "\t layer3.3.bn3.bias\n", 178 | "\t layer3.4.conv1.weight\n", 179 | "\t layer3.4.bn1.weight\n", 180 | "\t layer3.4.bn1.bias\n", 181 | "\t layer3.4.conv2.weight\n", 182 | "\t layer3.4.bn2.weight\n", 183 | "\t layer3.4.bn2.bias\n", 184 | "\t layer3.4.conv3.weight\n", 185 | "\t layer3.4.bn3.weight\n", 186 | "\t layer3.4.bn3.bias\n", 187 | "\t layer3.5.conv1.weight\n", 188 | "\t layer3.5.bn1.weight\n", 189 | "\t layer3.5.bn1.bias\n", 190 | "\t layer3.5.conv2.weight\n", 191 | "\t layer3.5.bn2.weight\n", 192 | "\t layer3.5.bn2.bias\n", 193 | "\t layer3.5.conv3.weight\n", 194 | "\t layer3.5.bn3.weight\n", 195 | "\t layer3.5.bn3.bias\n", 196 | "\t layer4.0.conv1.weight\n", 197 | "\t layer4.0.bn1.weight\n", 198 | "\t layer4.0.bn1.bias\n", 199 | "\t layer4.0.conv2.weight\n", 200 | "\t layer4.0.bn2.weight\n", 201 | "\t layer4.0.bn2.bias\n", 202 | "\t layer4.0.conv3.weight\n", 203 | "\t layer4.0.bn3.weight\n", 204 | "\t layer4.0.bn3.bias\n", 205 | "\t layer4.0.downsample.0.weight\n", 206 | "\t layer4.0.downsample.1.weight\n", 207 | "\t layer4.0.downsample.1.bias\n", 208 | "\t layer4.1.conv1.weight\n", 209 | "\t layer4.1.bn1.weight\n", 210 | "\t layer4.1.bn1.bias\n", 211 | "\t layer4.1.conv2.weight\n", 212 | "\t layer4.1.bn2.weight\n", 213 | "\t layer4.1.bn2.bias\n", 214 | "\t layer4.1.conv3.weight\n", 215 | "\t layer4.1.bn3.weight\n", 216 | "\t layer4.1.bn3.bias\n", 217 | "\t layer4.2.conv1.weight\n", 218 | "\t layer4.2.bn1.weight\n", 219 | "\t layer4.2.bn1.bias\n", 220 | "\t layer4.2.conv2.weight\n", 221 | "\t layer4.2.bn2.weight\n", 222 | "\t layer4.2.bn2.bias\n", 223 | "\t layer4.2.conv3.weight\n", 224 | "\t layer4.2.bn3.weight\n", 225 | "\t layer4.2.bn3.bias\n", 226 | "\t fc.weight\n", 227 | "\t fc.bias\n", 228 | "Epoch 0/29\n", 229 | "----------\n", 230 | "train Loss: 3.7631 Acc: 0.1689\n", 231 | "val Loss: 2.6022 Acc: 0.3606\n", 232 | "\n", 233 | "Epoch 1/29\n", 234 | "----------\n", 235 | "train Loss: 2.0157 Acc: 0.4753\n", 236 | "val Loss: 1.8349 Acc: 0.5079\n", 237 | "\n", 238 | "Epoch 2/29\n", 239 | "----------\n", 240 | "train Loss: 1.2808 Acc: 0.6487\n", 241 | "val Loss: 1.6775 Acc: 0.5454\n", 242 | "\n", 243 | "Epoch 3/29\n", 244 | "----------\n", 245 | "train Loss: 0.8410 Acc: 0.7673\n", 246 | "val Loss: 1.6588 Acc: 0.5647\n", 247 | "\n", 248 | "Epoch 4/29\n", 249 | "----------\n", 250 | "train Loss: 0.5408 Acc: 0.8571\n", 251 | "val Loss: 1.6504 Acc: 0.5784\n", 252 | "\n", 253 | "Epoch 5/29\n", 254 | "----------\n", 255 | "train Loss: 0.3501 Acc: 0.9127\n", 256 | 
"val Loss: 1.7247 Acc: 0.5827\n", 257 | "\n", 258 | "Epoch 6/29\n", 259 | "----------\n", 260 | "train Loss: 0.2378 Acc: 0.9456\n", 261 | "val Loss: 1.7234 Acc: 0.5889\n", 262 | "\n", 263 | "Epoch 7/29\n", 264 | "----------\n", 265 | "train Loss: 0.1623 Acc: 0.9646\n", 266 | "val Loss: 1.7923 Acc: 0.5907\n", 267 | "\n", 268 | "Epoch 8/29\n", 269 | "----------\n", 270 | "train Loss: 0.1198 Acc: 0.9744\n", 271 | "val Loss: 1.8636 Acc: 0.5910\n", 272 | "\n", 273 | "Epoch 9/29\n", 274 | "----------\n", 275 | "train Loss: 0.1034 Acc: 0.9778\n", 276 | "val Loss: 1.9842 Acc: 0.5821\n", 277 | "\n", 278 | "Epoch 10/29\n", 279 | "----------\n", 280 | "train Loss: 0.0949 Acc: 0.9782\n", 281 | "val Loss: 1.9945 Acc: 0.5859\n", 282 | "\n", 283 | "Epoch 11/29\n", 284 | "----------\n", 285 | "train Loss: 0.1030 Acc: 0.9759\n", 286 | "val Loss: 2.0277 Acc: 0.5912\n", 287 | "\n", 288 | "Epoch 12/29\n", 289 | "----------\n", 290 | "train Loss: 0.0888 Acc: 0.9782\n", 291 | "val Loss: 2.1023 Acc: 0.5889\n", 292 | "\n", 293 | "Epoch 13/29\n", 294 | "----------\n", 295 | "train Loss: 0.1023 Acc: 0.9738\n", 296 | "val Loss: 2.0661 Acc: 0.5906\n", 297 | "\n", 298 | "Epoch 14/29\n", 299 | "----------\n", 300 | "train Loss: 0.0824 Acc: 0.9792\n", 301 | "val Loss: 2.1193 Acc: 0.5921\n", 302 | "\n", 303 | "Epoch 15/29\n", 304 | "----------\n", 305 | "train Loss: 0.0715 Acc: 0.9810\n", 306 | "val Loss: 2.0828 Acc: 0.5881\n", 307 | "\n", 308 | "Epoch 16/29\n", 309 | "----------\n", 310 | "train Loss: 0.0603 Acc: 0.9841\n", 311 | "val Loss: 2.0802 Acc: 0.5953\n", 312 | "\n", 313 | "Epoch 17/29\n", 314 | "----------\n", 315 | "train Loss: 0.0534 Acc: 0.9858\n", 316 | "val Loss: 2.2534 Acc: 0.5940\n", 317 | "\n", 318 | "Epoch 18/29\n", 319 | "----------\n", 320 | "train Loss: 0.0493 Acc: 0.9869\n", 321 | "val Loss: 2.2264 Acc: 0.5930\n", 322 | "\n", 323 | "Epoch 19/29\n", 324 | "----------\n", 325 | "train Loss: 0.0535 Acc: 0.9854\n", 326 | "val Loss: 2.3266 Acc: 0.5909\n", 327 | "\n", 328 | "Epoch 20/29\n", 329 | "----------\n", 330 | "train Loss: 0.0598 Acc: 0.9835\n", 331 | "val Loss: 2.2506 Acc: 0.5923\n", 332 | "\n", 333 | "Epoch 21/29\n", 334 | "----------\n", 335 | "train Loss: 0.0443 Acc: 0.9875\n", 336 | "val Loss: 2.4576 Acc: 0.5931\n", 337 | "\n", 338 | "Epoch 22/29\n", 339 | "----------\n", 340 | "train Loss: 0.0379 Acc: 0.9899\n", 341 | "val Loss: 2.3325 Acc: 0.5998\n", 342 | "\n", 343 | "Epoch 23/29\n", 344 | "----------\n", 345 | "train Loss: 0.0473 Acc: 0.9868\n", 346 | "val Loss: 2.5001 Acc: 0.5946\n", 347 | "\n", 348 | "Epoch 24/29\n", 349 | "----------\n", 350 | "train Loss: 0.0418 Acc: 0.9887\n", 351 | "val Loss: 2.4391 Acc: 0.5938\n", 352 | "\n", 353 | "Epoch 25/29\n", 354 | "----------\n", 355 | "train Loss: 0.0419 Acc: 0.9888\n", 356 | "val Loss: 2.3594 Acc: 0.5964\n", 357 | "\n", 358 | "Epoch 26/29\n", 359 | "----------\n", 360 | "train Loss: 0.0363 Acc: 0.9900\n", 361 | "val Loss: 2.4542 Acc: 0.5992\n", 362 | "\n", 363 | "Epoch 27/29\n", 364 | "----------\n", 365 | "train Loss: 0.0355 Acc: 0.9899\n", 366 | "val Loss: 2.5777 Acc: 0.5904\n", 367 | "\n", 368 | "Epoch 28/29\n", 369 | "----------\n", 370 | "train Loss: 0.0411 Acc: 0.9882\n", 371 | "val Loss: 2.3777 Acc: 0.5949\n", 372 | "\n", 373 | "Epoch 29/29\n", 374 | "----------\n", 375 | "train Loss: 0.0383 Acc: 0.9890\n", 376 | "val Loss: 2.4047 Acc: 0.5958\n", 377 | "\n", 378 | "Training complete in 49m 4s\n", 379 | "Best val Acc: 0.599767\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "model = get_model(config)" 385 | ] 386 | }, 387 | { 
388 | "cell_type": "code", 389 | "execution_count": 4, 390 | "id": "6b942258", 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "'''\n", 395 | "=== With frac_val = 0.7 === \n", 396 | "Epoch 29/29\n", 397 | "----------\n", 398 | "train Loss: 0.0295 Acc: 0.9938\n", 399 | "val Loss: 2.5458 Acc: 0.5470\n", 400 | "\n", 401 | "Training complete in 4m 43s\n", 402 | "Best val Acc: 0.546976\n", 403 | "'''\n", 404 | "\n", 405 | "'''\n", 406 | "=== With frac_val = 0.5 === <-- THIS IS WHAT WE USE\n", 407 | "Epoch 29/29\n", 408 | "----------\n", 409 | "train Loss: 0.0315 Acc: 0.9913\n", 410 | "val Loss: 2.3326 Acc: 0.5974\n", 411 | "\n", 412 | "Training complete in 6m 44s\n", 413 | "Best val Acc: 0.597367\n", 414 | "'''\n", 415 | "\n", 416 | "\n", 417 | "'''\n", 418 | "=== With frac_val = 0.3 === \n", 419 | "Epoch 0/0\n", 420 | "----------\n", 421 | "train Loss: 3.3722 Acc: 0.2335\n", 422 | "val Loss: 2.2128 Acc: 0.4491\n", 423 | "\n", 424 | "Training complete in 3m 27s\n", 425 | "Best val Acc: 0.449056\n", 426 | "\n", 427 | "\n", 428 | "Epoch 29/29\n", 429 | "----------\n", 430 | "train Loss: 0.0373 Acc: 0.9889\n", 431 | "val Loss: 2.4569 Acc: 0.6305\n", 432 | "\n", 433 | "Training complete in 64m 27s\n", 434 | "Best val Acc: 0.630556\n", 435 | "'''\n", 436 | "None" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 5, 442 | "id": "0a976e04", 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [ 446 | "val_imgs = np.load('./.cache/' + config['model_filename'] + f'-valdata_frac={config[\"frac_val\"]}.npy')\n", 447 | "val_labels = np.load('./.cache/' + config['model_filename'] + f'-vallabels_frac={config[\"frac_val\"]}.npy')" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 6, 453 | "id": "9d5c6e00", 454 | "metadata": {}, 455 | "outputs": [ 456 | { 457 | "data": { 458 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAY7ElEQVR4nO3c248kh1XH8dN16eru6bnubHZ2d9a7Kzv2anEuCiEXCAnEKA9IXBQeglCwBAIegD8HJMQDEJB4RCEQGVlWwLYSY2vjxEnY2N44u97L7M59enqmu6urq5qHSEe85fyQLS76fp7PHNVUVfev66F+rfl8PjcAAMws+Z8+AADA/x6EAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAARygAAFwWHXz5q1/VFmdpeLZdFNJuRVF0pPm6qePDTSPtntXx3XmeS7uTpCXNKxrxt8NkWoZnk3b4FjQzs/HpKDw7GAyk3Zl4ztvCvHo9lfs2SbTdzSw+OxwOpd3j8SQ8OxKupZlZ3tY+y1kav7dy4fvKzGwmfE9Mxtr/mQrHMhe/g774e7/zU2d4UgAAOEIBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgAuXg6yfXX/fDqKuhTIWM6umVXi2mcVnf/IH8S6RVOztUTqelO4oM7OWWH1UVULHk3h9eou9+HEI19LMrCja4dn1de2eFaqpzEytvtI6aqyJ/16bif03e3uH4dm7d+5Ku2+/cyc8+3BrR9qdJNrn7cLmhfDstWtPSbvPnF0Lz/b7C9JupSNN7T6K4EkBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgAu/N66+Tt3M4/NJS8umZj4Pz+ZiXURRxCsaxuOJtLssR+FZ9dX4qtLqIqR6iUS7PmkeryNIM626oCyn4dksy6XduTgv3OJW1dr12dk5CM/+SKiWMDP7wc23w7PbW4+k3Ts7u+HZ4fFY2l2OtbqVxaVb4dmHj+Ln28zss7/06fDsxY0VaXcifN7e+5ILnhQAAP8FoQAAcIQCAMARCgAARygAAByhAABwhAIAwBEKAABHKAAAHKEAAHCEAgDAhYtnxpN4b4+Z1guz0O9Lu7NU6zN6v3S7HWm+nsW7W5T+EzO9P0rZnxeFtLsS/k+1uyVJWuHZXq8r7S4n8U4tM7PDo0F49t3bd6Tdr716Izz7zt0taffwJN7Z1elo1z5rxz8Ti8va56exU2l+ezfeZ/TKv78u7c6Ez8QXPv9z0u6V1eXwbPo+fBfypAAAcIQCAMARCgAARygAAByhAABwhAIAwBEKAABHKAAAHKEAAHCEAgDAhWsuup2etHg0itdiTMtS2p1m8Ve7y3Iq7a6mVXi2t6Cdk0J4Nb7Vitc5mJl1xMqNRughaRqxjEKo0EjEoovRaByevXP7nrR7b3cozW/d2wnPvnnzh9LuO7fvhmeL5RVpd1LE75VJVUu7G8vjs+Fvn5/o9OP1D2Zm5Sx+b+0eaNf+pRe/FZ598upZaffHz34sPqz2xATwpAAAcIQCAMARCgAARygAAByhAABwhAIAwBEKAABHKAAAHKEAAHCEAgDAEQoAABduH+l029LiWa11pijqWXx3nsd7kszMkkSZn0u7U6ETaF5ru1WVcA63d7ak3W++9U549sH9R9Lu+/fifUaPHm5Lu4v2ojTf68Q7hPIi3glkZvYzH7oWnt07nki7R8JnMym0Dq6kpfzOFH+TzrXP8urqUnj2wZZ2Hx7t74VnX3gh3pNkZrZ5+Up49syqds9G8KQAAHCEAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAARygAAByhAABwhAIAwIVrLspyKi3O8/hr/ZPxWNqt1Ev0uvEqAjOzsoxXBpRVJe3upr3wbGum1SLUTSPNDwbD8OytW3ek3TdevRGefeutH0u7V5b64dn+gnbtr1y+IM1fvno2PJu3tbqI0TD80bR3Hwyk3SfT+Ge5rmfS7plQn9It4p8HM7PBYCTNt5fi+5fEe+VeEa/9ufVOvBLDzOz179wKz/7Czz8t7Y7gSQEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAACOUAAAOEIBAOAIBQCAIxQAAC5csNJbWJAW10IHyunJibS7quLdLb2FeEeJqhL7oDKhD2pWxfudzMws1bp17t7dCs8+3N6Xdm9evhyeXV5dk3avnVkJz2aJ1h917my8y8jMbHk5vn97J36+zcyyLH7frqyuSLs7s3hnV5rFO5jMzAZH8R6m3sKitLvRqsakbrKFntbDdH7jfHj2RKx2+94bb4ZnP/KRJ7TlATwpAAAcoQAAcIQCAMARCgAARygAAByhAABwhAIAwBEKAABHKAAAHKEAAHDhd9iLopAWl1aGZ5dXlqXd4/FImNZe00+SeL1Ef0l7Nd6m8SqKlqXS6u19rYriW698Nzz7g5tvSbufevJKePby5UvS7kyoXRidTqTdN28+kOYPdrfjxzIeSruvXX86PJsVWpVLM4vXsyz3tXqbPI9fnzrehGNmZj3xWKrBcXh2NFK+U8yKbvycb5xfl3bfu3cvPPvKy69Ju5/942d/6gxPCgAARygAAByhAABwhAIAwBEKAABHKAAAHKEAAHCEAgDAEQoAAEcoAAAcoQAAcOGiksHRQFqc5/n7Mmtm1szinUP9ntaXUpbxXpi6nmm76/ju0anWxfLqa9+V5l968UZ49uG21qu0s3skzGqdQFUVP+f7+4fS7uF+/PqYmU2EbqV2of3+Wlu/GJ49c0Hr1rF4vZednJxKq3u9bni2KrVuqtW1FWm+FsqVToYn0u5MqJvqL2ndVPvb8Y601298V9odwZMCAMARCgAARygAAByhAABwhAIAwBEKAABHKAAAHKEAAHCEAgDAEQoAAEcoAABcuPtI1TRNeFbpKDEzGx7He0qqKn4cZmYP7m8Jx6H1wnS68Q6U+48eSruff/5Faf7ho6P4cBrvszEz2xX6jAbH/yHtNuFyzoV70Mysky1L80kS7+CaVqW0e/dgJzy7vNGXdhdF/D48PDiSdl/c3AjPJuLXT29hSZpXeszu3b0v7Z5WVXg27Wrdbr2FTnh2KF6fCJ4UAACOUAAAOEIBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAALj4e+bNXFpcNbP4QWSptDvN46+N3303XlthZvaXf/5X4dmN82el3U9cfzw8+/ob35d2/+Dm29K8teKv0jdWaLuFnxpVo/0uSYTxPI//j2Zmpt3ilqbxuogka0m7184vhmd7/fhxmJlVZfzzNhWqIn5yLAvh2bLUakiU6hwzs9HpKDybZVrlxslpvGonSeLfhWZmWR6/ydNE++6M4EkBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAAAuXPiRF2K/SlXFZ6fxWTOzRsiy51/4prT7zR/dD8/eeXdX2n3jjVvh2b2TA2n3dK7l+0woEZqLvx1aLaG7Rey9aoQOrsa0rhxrafdhVc
W7dXpFvK/LzOzK5Uvh2c1LF6TdJ4MyPLt9/560O23q8OxyX+vUGhxNpPmHW/Hes3ah9WRVR8PwbCdeB2VmZkVH6NQ6s6otj+x8zzcCAP7PIhQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAAAuXHPRNFplQJrHX+tXKjHMzAajaXj27dvx2gozs7S7GD+OYfyVfjOzo/3D8GyVaueklYqv6TdCvUSiXfssi//WqFvx2gozsySZhWe1zWaJWBWitGg0pXg0VSs8euXihrT6sH0Unt1a6Um78yZeobG8Ev+smZlZHb/2Zmaj4+PwbH9NO4czi3+/dXOtImg8i99Y7Z62O4InBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAACOUAAAuHD3UX9B60CZlPF+ol5P2731vXvh2XKs9fZMxpPw7GhyKu1upcKxJFpez8VuqiyNd7eoJUKp0Kv0fv4qKRutK2dlQeuPKjpFeLaqtS6rsXAfriwuSbtz4ax/+CPXpd2KIteu/tKi9j2h3Fxj4fvKzKyq4vdWVYW/Zs3MLEvin82TsfYdFMGTAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAARygAAByhAABwhAIAwBEKAAAXfv+6abSuA+U1/W5Xqxf45kuvhmfv3Loj7W7mo/Ds+Q3ttfupUEVxb2cg7VZlLaHmoqX9dkiF/zNJ4lURZmZlFT+WxuJ1G2ZmeaHdh500Pt9LutLuS49txofFn3bLy0IthnIcZlZN43UeavWHNm2WJPH7cDIeitvr8OThwYm0eXWlH549GY6l3RE8KQAAHKEAAHCEAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAARygAAByhAABwhAIAwIW7jyZjrWOjKNrhWbFax+pqGp5dW1uQdn/i058Mz37uVz4u7T44inegfOVvvy7tvvX2bWk+sXiXVSpen3QW74WZlvGuKTOzROgbMhP6nf4blM6u/f19afedH78bnv3cZ7X7sBzFz3mr1ZJ2Ly4thme3d7el3U2jtR/lRfzG3d3Vusa6neXw7M6W+H8K/+ZoGP8ujOJJAQDgCAUAgCMUAACOUAAAOEIBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAALt59VGodG8sr8W6QPA8fhpmZ/cEffSk8+87te9LuK5cvhGfPXTwj7U7yIjw7mkir7S/+7CvS/NH+aXh2ua/0DZl94Zc/FT+O05m0++UbPwzP7h1rfV2jU21+tRfv+en2tHO4txfvSjo5GUq7u3m8lyxJtN+NaZLGZ7P4rJnZvNQ+FJuXNsKzt26/Lu3O8vj1rKp4F5iZ2eFh/HrWQs9YFE8KAABHKAAAHKEAAHCEAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAARygAAFy4X2JWV9LiWR2vLyjLUtq9+dhaeHb9XE/aXZVNeLYp59LutBV/rb/fiVdimJktiTUKmeXh2U6u/Xb44m89E55d33xM2n35n18Mz/713/yDtPsDK/HaCjOzpImflzyLn28zs+0H2+HZeha/Z83Mzm6eDc92utp9pdRizBPtO6WcaTUXT157Ijy7N9B2v/XWfWG6Je2elvHz0tJWh/CkAABwhAIAwBEKAABHKAAAHKEAAHCEAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAAF+4+6i8tS4vrJt4LdHB4JO0+l8e7jyanWqfJ4HAUnl0/uyHt/tfnXgjPvvJvL0u7P3rtqjT/9NMfCs/+09e+Lu0uq/g5X1vT+oY2VuPzF89q9+yXv/Rr0vyDO1vh2eef+4a0O8/i/UTr62ek3fNZHZ5N03hfl5lZt9sVdmvFPSdj7bP8wavx3qannrom7f77v/vH8OxLL9+UdmtNVu89nhQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAACOUAAAuHDNxeBkIC1uhFfpx2W8WsLMrCpX4rPam/G2v30Qnl1eih+HmdnG2fj8Zz/1UWn31auXpfnHn3g8PNtUR9LuH775o/Bsd2lV2v3ai98Kzz7z6Y9Kuz/zi/HqDzOz+pPxaoS9vXvS7tFwGp5N4h9jMzNrhCKFvJ1LuxVFUUjz586c0/an8e+s9fV4dY6Z2W9/6VfDswPt682+/e3vh2fn8/e+FIMnBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAACOUAAAuHj30f6utHgyjpcOtYuOtDtvx7Nsua/tnl1YD88ODnek3Y8/eSk8e/XaprS7abQOlP65fnj214WeFzOz5/7l+fBsajNp95/86bPh2U63J+2upifS/Obm+fDsM898Rtr93Ne+IUxr177TXQzPTst4B5OZ2clJ/BymaSrtzjJtvlO0w7PVtJR2X7/+VHj29/9wWdpdlvHvzu98+zvS7gieFAAAjlAAADhCAQDgCAUAgCMUAACOUAAAOEIBAOAIBQCAIxQAAI5QAAC4cM1FeXosLU6FV9IPth9Iu7eLeJatLK9Iu7Mk/op50WtJuy2rwqPDciytLivtNf3pdry+YCRUF5iZPba5Fp4dHj2Udq+snYsPa+0P1u/m0vxg/yA8u/fokbR7fW0pPNvraVUuc6ESJc/DXxFmZjY8HoZnW4n2mzTPteujaCXa/zkenYZnP3B2Rdr9u8/+Znj2/Ln4fRLFkwIAwBEKAABHKAAAHKEAAHCEAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAARygAAFy48GM6PpQWN/P4bJFr2XS4ez88W50eSbtHo3jPz/qZdWm3UDlje/tb0u79fe36KB01hdh/025m4dmJ2DlzPIj/n/2e1gvzwSefkObTLN45VI/jnVpmZj/7sevh2ZbyYTOzY+Ha522tb2g2q8OzRRHvRzMz6/cXpPnRaBSePRX7vVpp/Ly000La/cHHL4RnN778G9LuCJ4UAACOUAAAOEIBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAALhwx0BTaa/pL68sh2d7vZ60u5rEX18fj/ak3bOqCs/u746l3fVRvNLhUKgiMDOryqk2f3Iank3ylrS704n/1mgareqgIzQG1JV23CcD7V7p99fCs5/8xIel3Y9duRqezRLt/6zn8VqM6lS7x8+sx89JqyVeH+GeNTPL0vi9tdDvS7trobOmaGtVLifDMjy7sqp9d0bwpAAAcIQCAMARCgAARygAAByhAABwhAIAwBEKAABHKAAAHKEAAHCEAgDAEQoAABcu5SjyeF+KmVlVxntKhlW8y8jMrBrH58ux1t2ixGRvdUVarZyTzryWdteV1n3Us1l4Nk+07pbE4t0taSaUGZlZXcc7uOa1dg7v37krzS8uxfupFpfWpd15Gu8FOjw8kHaPpvF+r8UlrRMoy/LwbCP0B5mZHR0eSfML/YXwbLfTkXbXQkdaXmj9Xmbx3dOx9rmP4EkBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAAAuXGqTz
+NdOWZm45N4P9Fc7ECxWbzTRm0dyfJ4d8u0inf8mJmdHp+EZ9uZ1sWymGv9RB3rhmeTLN7DY2Y2Fu6VPNWuUJrEr0+nuyjtzpr4bjOzxaVeePbMuSVp98P9R+HZmfbRtKwd75uaTuL3rJnZVLgPe0W8m8jMbKWvncNC6DNKE/E+zOO/p4uudl8J1W4mVqSF8KQAAHCEAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAARygAAByhAABwhAIAwIXfSW9mU2lxN4+/Nq6+pj+v4u929/rxKgIzs1kar3Q4ODySdtfVPDzbSbSaCzXf+0vxyoCR8t69mY3m8WNpCbUVZma10IiS5/E6BzOzSxsb0nxj8Xvl8Ghf2m1Czcl0ciqt3n00CM82Qp2DmVlaxK/nxQuPSbt7vb40XwjfQernZ1TGv7SqUqvxqYXv2vFYqyGJ4EkBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAAAu3H3U6WpdPPUs3vcxLifS7iWhz2gm9CSZmY3KMjzb72tdLHnaDs/OSu24e+2uNF/P4vubuXYsi5349ekvLEi7yzLeHzU61nph9rNdaX5pMd4fNZloBV+9pXiH0Hh8KO3e2bkbnk078XvWzGzaxO+V7a34cZiZrfbXpfm19ZXw7OhU6/dSfk/nC1r/2mgc/z6cyMf90/GkAABwhAIAwBEKAABHKAAAHKEAAHCEAgDAEQoAAEcoAAAcoQAAcIQCAMCFay6yNJUWj07jr2pnibZbqa4YnWpVB3Uez8luP3z6zMysHE/Ds9WoknardRHK9aln8WoJM7NcGD8+OpJ2F3n8/+wVWg3JXKg4MTNbubQYnr1++bK0+5uvvhGeHR7tS7v7QmPNcXks7S7r+H1bihU0kxOtzmNwHK9+qSqthqSVxD/7eR6/T8zMut14LcZYqMSI4kkBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgCMUAAAuXOAxnsR7e1S9Xrzrw8ysPB6FZ/OikHbXTbyPpSy1c9IIs6nYNTU40jpqFO1c63iqZ/EemVPhWpqZ5WvxeyWxlrS7EY7bzKwux+HZ6UTrqElNKJCqtfswFe7Ebq7dh0IlkPyLtCy1HrNqGL+32kVb2t3L471ava72HbS6shyefXh6Ku2O4EkBAOAIBQCAIxQAAI5QAAA4QgEA4AgFAIAjFAAAjlAAADhCAQDgCAUAgAu/lF6VpbS4buKv6R8fD6XdRZKHZ9XX9Eej+LGUlbTaMqEDIFX6AswsScRKB+H6ZJl2LHke/62RLom7W/Hr2VTxyhIzs6KnncOdra3w7MGhdo8Pdo/Cs9VIvBEtfl6qmXYOLYufw5Zpn800jX/uzczqWvg/p9r/OU3j57ybaBUnw4O98Gw91XZH8KQAAHCEAgDAEQoAAEcoAAAcoQAAcIQCAMARCgAARygAAByhAABwhAIAwBEKAADXms/n8RIcAMD/azwpAAAcoQAAcIQCAMARCgAARygAAByhAABwhAIAwBEKAABHKAAA3H8CNuPM4NWXOnAAAAAASUVORK5CYII=", 459 | "text/plain": [ 460 | "
" 461 | ] 462 | }, 463 | "metadata": {}, 464 | "output_type": "display_data" 465 | } 466 | ], 467 | "source": [ 468 | "show_img(val_imgs[np.random.choice(val_imgs.shape[0])])" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 7, 474 | "id": "afa2d72a", 475 | "metadata": {}, 476 | "outputs": [ 477 | { 478 | "name": "stdout", 479 | "output_type": "stream", 480 | "text": [ 481 | "CPU times: user 5min 39s, sys: 9min 22s, total: 15min 1s\n", 482 | "Wall time: 26.7 s\n" 483 | ] 484 | } 485 | ], 486 | "source": [ 487 | "%%time\n", 488 | "\n", 489 | "# Get softmax scores\n", 490 | "with torch.no_grad():\n", 491 | " logits = model(torch.from_numpy(val_imgs))\n", 492 | " \n", 493 | "softmax_scores = torch.nn.functional.softmax(logits,dim=1)\n", 494 | "softmax_scores = softmax_scores.numpy()" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 8, 500 | "id": "59ad2db4", 501 | "metadata": {}, 502 | "outputs": [ 503 | { 504 | "name": "stdout", 505 | "output_type": "stream", 506 | "text": [ 507 | "Save softmax scores to ./.cache/best-cifar100-model-fracval=0.7-valsoftmax_frac=0.5.npy\n" 508 | ] 509 | } 510 | ], 511 | "source": [ 512 | "# Save softmax scores\n", 513 | "pth = './.cache/' + config['model_filename'] + f'-valsoftmax_frac={config[\"frac_val\"]}.npy'\n", 514 | "\n", 515 | "np.save(pth, softmax_scores)\n", 516 | "print(f'Save softmax scores to {pth}')" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": null, 522 | "id": "114c1685", 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [] 526 | } 527 | ], 528 | "metadata": { 529 | "kernelspec": { 530 | "display_name": "env", 531 | "language": "python", 532 | "name": "env" 533 | }, 534 | "language_info": { 535 | "codemirror_mode": { 536 | "name": "ipython", 537 | "version": 3 538 | }, 539 | "file_extension": ".py", 540 | "mimetype": "text/x-python", 541 | "name": "python", 542 | "nbconvert_exporter": "python", 543 | "pygments_lexer": "ipython3", 544 | "version": "3.10.10" 545 | } 546 | }, 547 | "nbformat": 4, 548 | "nbformat_minor": 5 549 | } 550 | -------------------------------------------------------------------------------- /generate_scores/train_models/cifar_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pickle 4 | import os, time, copy 5 | import torch 6 | import torchvision as tv 7 | import torchvision.transforms as transforms 8 | import torchvision.datasets as datasets 9 | from torch.utils.data import TensorDataset 10 | from torchvision.models import resnet50 11 | import torch.optim as optim 12 | import torch.nn as nn 13 | from sklearn.model_selection import train_test_split 14 | import pdb 15 | 16 | def set_parameter_requires_grad(model, feature_extracting): 17 | if feature_extracting: 18 | for param in model.parameters(): 19 | param.requires_grad = False 20 | model.fc.weight.requires_grad = True 21 | model.fc.bias.requires_grad = True 22 | 23 | def unpickle(file): 24 | with open(file, 'rb') as fo: 25 | dict = pickle.load(fo, encoding='bytes') 26 | return dict 27 | 28 | def get_data(): 29 | # Load and unpack data 30 | orig_train_data = unpickle('../data/cifar-100-python/train') 31 | orig_test_data = unpickle('../data/cifar-100-python/test') 32 | tr_imgs = orig_train_data[b'data'].astype(np.float32) 33 | te_imgs = orig_test_data[b'data'].astype(np.float32) 34 | tr_labels = torch.tensor(np.array(orig_train_data[b'fine_labels']).astype(int)) 35 | 
te_labels = torch.tensor(np.array(orig_test_data[b'fine_labels']).astype(int)) 36 | 37 | # Fuse train and val sets 38 | imgs = np.concatenate([tr_imgs, te_imgs], axis=0) 39 | labels = np.concatenate([tr_labels, te_labels], axis=0) 40 | 41 | # Reshape and normalize images to mean 0 std 1 42 | imgs = imgs.reshape(imgs.shape[0], 3, 32, 32) 43 | total_pixels_per_channel = imgs.shape[0] * imgs.shape[2] * imgs.shape[3] 44 | means = imgs.sum(axis=2).sum(axis=2).sum(axis=0) / total_pixels_per_channel 45 | stds = np.sqrt(((imgs - means[None,:,None,None])**2).sum(axis=2).sum(axis=2).sum(axis=0)/total_pixels_per_channel) 46 | imgs = (imgs - means[None,:,None,None])/stds[None,:,None,None] 47 | return imgs, labels 48 | 49 | def get_dataloaders(config, frac_val=0.1): 50 | assert 0 <= frac_val <= 1 51 | 52 | imgs, labels = get_data() 53 | 54 | data_transforms = { 55 | 'train': transforms.Compose([ 56 | transforms.RandomResizedCrop(224), # 224 is due to Imagenet input size 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor() 59 | ]), 60 | 'val': transforms.Compose([ 61 | transforms.Resize((224,224)), 62 | transforms.CenterCrop(224), 63 | transforms.ToTensor(), 64 | ]), 65 | } 66 | 67 | train_imgs, val_imgs, train_labels, val_labels = train_test_split(imgs, labels, test_size=frac_val, random_state=0) 68 | 69 | # Create training and validation datasets 70 | image_datasets = { 71 | 'train' : TensorDataset(torch.tensor(train_imgs).float(), torch.tensor(train_labels).long()), 72 | 'val' : TensorDataset(torch.tensor(val_imgs).float(), torch.tensor(val_labels).long()) 73 | } 74 | # Create training and validation dataloaders 75 | dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers']) for x in ['train', 'val']} 76 | 77 | return dataloaders_dict 78 | 79 | def train_model(model, dataloaders, config): 80 | # Gather the parameters to be optimized/updated in this run. If we are 81 | # finetuning we will be updating all parameters. However, if we are 82 | # doing feature extract method, we will only update the parameters 83 | # that we have just initialized, i.e. the parameters with requires_grad 84 | # is True. 85 | set_parameter_requires_grad(model, config['feature_extract']) 86 | 87 | params_to_update = model.parameters() 88 | print("Params to learn:") 89 | if config['feature_extract']: 90 | params_to_update = [] 91 | for name,param in model.named_parameters(): 92 | if param.requires_grad == True: 93 | params_to_update.append(param) 94 | print("\t",name) 95 | else: 96 | for name,param in model.named_parameters(): 97 | if param.requires_grad == True: 98 | print("\t",name) 99 | 100 | # The above prints show which layers are being optimized 101 | 102 | criterion = nn.CrossEntropyLoss() 103 | 104 | optimizer = optim.Adam(params_to_update, lr=config['lr']) 105 | 106 | since = time.time() 107 | 108 | val_acc_history = [] 109 | 110 | best_model_wts = copy.deepcopy(model.state_dict()) 111 | best_acc = 0.0 112 | 113 | for epoch in range(config['num_epochs']): 114 | print('Epoch {}/{}'.format(epoch, config['num_epochs'] - 1)) 115 | print('-' * 10) 116 | 117 | # Each epoch has a training and validation phase 118 | for phase in ['train', 'val']: 119 | if phase == 'train': 120 | model.train() # Set model to training mode 121 | else: 122 | model.eval() # Set model to evaluate mode 123 | 124 | running_loss = 0.0 125 | running_corrects = 0 126 | 127 | # Iterate over data. 
128 | for inputs, labels in dataloaders[phase]: 129 | inputs = inputs.to(config['device']) 130 | labels = labels.to(config['device']) 131 | 132 | # zero the parameter gradients 133 | optimizer.zero_grad() 134 | 135 | # forward 136 | # track history if only in train 137 | with torch.set_grad_enabled(phase == 'train'): 138 | # Get model outputs and calculate loss 139 | outputs = model(inputs) 140 | loss = criterion(outputs, labels) 141 | 142 | _, preds = torch.max(outputs, 1) 143 | 144 | # backward + optimize only if in training phase 145 | if phase == 'train': 146 | loss.backward() 147 | optimizer.step() 148 | 149 | # statistics 150 | running_loss += loss.item() * inputs.size(0) 151 | running_corrects += torch.sum(preds == labels.data) 152 | 153 | epoch_loss = running_loss / len(dataloaders[phase].dataset) 154 | epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) 155 | 156 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 157 | 158 | # deep copy the model 159 | if phase == 'val' and epoch_acc > best_acc: 160 | best_acc = epoch_acc 161 | best_model_wts = copy.deepcopy(model.state_dict()) 162 | if phase == 'val': 163 | val_acc_history.append(epoch_acc.item()) 164 | 165 | print() 166 | 167 | time_elapsed = time.time() - since 168 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 169 | print('Best val Acc: {:4f}'.format(best_acc)) 170 | 171 | # load best model weights 172 | model.load_state_dict(best_model_wts) 173 | 174 | # Save best model weights 175 | os.makedirs('./.cache', exist_ok=True) 176 | torch.save(best_model_wts, './.cache/' + config['model_filename'] + '.pth') 177 | np.save('./.cache/' + config['model_filename'] + f'-valdata_frac={config["frac_val"]}.npy', dataloaders['val'].dataset.tensors[0].numpy()) 178 | np.save('./.cache/' + config['model_filename'] + f'-vallabels_frac={config["frac_val"]}.npy', dataloaders['val'].dataset.tensors[1].numpy()) 179 | with open('./.cache/' + config['model_filename'] + '-config.pkl', 'wb') as f: 180 | pickle.dump(config, f) 181 | 182 | return model, val_acc_history 183 | 184 | def get_model(config): 185 | model = resnet50(weights="IMAGENET1K_V2") 186 | model.fc = nn.Linear(model.fc.in_features, config['num_classes']) 187 | model = model.to(config['device']) 188 | try: 189 | state_dict = torch.load('./.cache/' + config['model_filename'] + '.pth', map_location=config['device']) 190 | model.load_state_dict(state_dict) 191 | model.eval() 192 | with open('./.cache/' + config['model_filename'] + '-config.pkl', 'rb') as f: 193 | loaded_config = pickle.load(f) 194 | assert config['num_classes'] == loaded_config['num_classes'] # If the configs aren't equal, retrain 195 | assert config['batch_size'] == loaded_config['batch_size'] # If the configs aren't equal, retrain 196 | assert config['num_epochs'] == loaded_config['num_epochs'] # If the configs aren't equal, retrain 197 | assert config['frac_val'] == loaded_config['frac_val'] # If the configs aren't equal, retrain 198 | assert config['lr'] == loaded_config['lr'] # If the configs aren't equal, retrain 199 | except: 200 | model = model.to(config['device']) 201 | dataloaders = get_dataloaders(config, frac_val = config['frac_val']) 202 | 203 | model, val_acc_history = train_model(model, dataloaders, config) 204 | return model 205 | 206 | def show_img(x): 207 | x = x.transpose(1,2,0) 208 | x = (x - x.min())/(x.max() - x.min()) 209 | plt.imshow(x) 210 | plt.axis('off') 211 | plt.show() 212 | 213 | if __name__ == 
"__main__": 214 | config = { 215 | 'num_classes' : 100, 216 | 'batch_size' : 128, 217 | 'lr' : 0.0001, 218 | 'feature_extract' : False, 219 | 'num_epochs' : 30, 220 | 'device' : 'cuda', 221 | 'frac_val' : 0.7, # CHANGED 222 | 'model_filename' : 'best-cifar100-model-fracval=0.7', # CHANGED 223 | 'num_workers' : 4, 224 | } 225 | get_model(config) 226 | -------------------------------------------------------------------------------- /generate_scores/train_models/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def split_X_and_y(X, y, n_k, num_classes, seed=0): 4 | ''' 5 | Randomly generate two subsets of features X and corresponding labels y such that the 6 | first subset contains n_k instances of each class k and the second subset contains all 7 | other instances 8 | 9 | Inputs: 10 | X: n x d array (e.g., matrix of softmax vectors) 11 | y: n x 1 array 12 | n_k: positive int or n x 1 array 13 | num_classes: total number of classes, corresponding to max(y) 14 | seed: random seed 15 | 16 | Output: 17 | X1, y1 18 | X2, y2 19 | ''' 20 | np.random.seed(seed) 21 | 22 | if not hasattr(n_k, '__iter__'): 23 | n_k = n_k * np.ones((num_classes,), dtype=int) 24 | 25 | X1 = np.zeros((np.sum(n_k), X.shape[1])) 26 | y1 = np.zeros((np.sum(n_k), ), dtype=np.int32) 27 | 28 | all_selected_indices = np.zeros(y.shape) 29 | 30 | i = 0 31 | for k in range(num_classes): 32 | 33 | # Randomly select n instances of class k 34 | idx = np.argwhere(y==k).flatten() 35 | selected_idx = np.random.choice(idx, replace=False, size=(n_k[k],)) 36 | 37 | X1[i:i+n_k[k], :] = X[selected_idx, :] 38 | y1[i:i+n_k[k]] = k 39 | i += n_k[k] 40 | 41 | all_selected_indices[selected_idx] = 1 42 | 43 | X2 = X[all_selected_indices == 0] 44 | y2 = y[all_selected_indices == 0] 45 | 46 | return X1, y1, X2, y2 -------------------------------------------------------------------------------- /generate_scores/train_models/torchvision_dataset_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pickle 4 | import os, time, copy 5 | import torch 6 | import torchvision as tv 7 | import torchvision.transforms as transforms 8 | import torchvision.datasets as datasets 9 | from torch.utils.data import TensorDataset, Subset 10 | from torchvision.models import resnet50 11 | import torch.optim as optim 12 | import torch.nn as nn 13 | 14 | from scipy.special import softmax 15 | from sklearn.model_selection import train_test_split 16 | from tqdm import tqdm 17 | import pdb 18 | 19 | def set_parameter_requires_grad(model, feature_extracting): 20 | if feature_extracting: 21 | for param in model.parameters(): 22 | param.requires_grad = False 23 | model.fc.weight.requires_grad = True 24 | model.fc.bias.requires_grad = True 25 | 26 | def calc_mean_std(my_dataset): 27 | image_data_loader = torch.utils.data.DataLoader( 28 | my_dataset, 29 | batch_size=512, 30 | shuffle=True, 31 | num_workers=2 32 | ) 33 | 34 | X, y = iter(image_data_loader).__next__() 35 | 36 | mean = X.mean(dim=(0,2,3)) 37 | std = X.std(dim=(0,2,3)) 38 | return mean, std 39 | 40 | def load_and_process_dataset(dset_fn, target_fn, min_train_instances_class): 41 | # Standard transform 42 | transform_img = transforms.Compose([ 43 | transforms.Resize(224), 44 | transforms.CenterCrop(224), 45 | transforms.ToTensor(), 46 | # here do not use transforms.Normalize(mean, std) 47 | ]) 48 | train_dataset, val_dataset = 
dset_fn(transform_img) 49 | 50 | # Calculate mean and std dev to use when normalizing 51 | mean, std = calc_mean_std(val_dataset) 52 | 53 | # Filter out rare classes 54 | if min_train_instances_class > 0: 55 | train_targets = target_fn(train_dataset) 56 | val_targets = target_fn(val_dataset) 57 | unique_classes, counts = np.unique(train_targets, return_counts=True) 58 | counts_large_enough = counts >= min_train_instances_class 59 | final_classes = unique_classes[ counts_large_enough ] 60 | final_train_idxs = np.where( np.isin(train_targets, final_classes) )[0] 61 | final_val_idxs = np.where( np.isin(val_targets, final_classes) )[0] 62 | 63 | # Map class labels to consecutive 0,1,2,... 64 | label_remapping = {} 65 | idx = 0 66 | for k in final_classes: 67 | label_remapping[k] = idx 68 | idx += 1 69 | def transform_label(k): 70 | return label_remapping[k] 71 | target_transform = transform_label 72 | 73 | else: 74 | final_train_idxs = np.arange(len(train_dataset)) 75 | final_val_idxs = np.arange(len(val_dataset)) 76 | target_transform = None 77 | 78 | transform_img = transforms.Compose([transform_img, transforms.Normalize(mean, std)]) 79 | 80 | train_dataset, val_dataset = dset_fn(transform_img, target_transform=target_transform) 81 | 82 | dataset = torch.utils.data.ConcatDataset([Subset(train_dataset, final_train_idxs), Subset(val_dataset, final_val_idxs)]) 83 | 84 | return dataset 85 | 86 | def get_dataloaders(config): 87 | # Load and unpack data 88 | if config['dataset_name'] == 'iNaturalist': 89 | def dset_fn(transform_img, target_transform=None): 90 | train_dataset = datasets.INaturalist(root = '/checkpoints/aa/inaturalist/train', 91 | version = '2021_train', 92 | download=False, 93 | target_type = config['target_type'], 94 | transform=transform_img, 95 | target_transform=target_transform) 96 | val_dataset = datasets.INaturalist(root = '/checkpoints/aa/inaturalist/val', 97 | version = '2021_valid', 98 | download=False, 99 | target_type = config['target_type'], 100 | transform=transform_img, 101 | target_transform=target_transform) 102 | return train_dataset, val_dataset 103 | def target_fn(dset): 104 | return np.array([x[0] for x in dset.index]) 105 | if config['dataset_name'] == 'Places365': 106 | def dset_fn(transform_img, target_transform=None): 107 | train_dataset = datasets.Places365(root = '/checkpoints/aa/places/', 108 | split = 'train-standard', 109 | download=False, 110 | transform=transform_img, 111 | target_transform=target_transform) 112 | val_dataset = datasets.Places365(root = '/checkpoints/aa/places/', 113 | split = 'val', 114 | transform=transform_img, 115 | download=False, 116 | target_transform=target_transform) 117 | return train_dataset, val_dataset 118 | def target_fn(dset): 119 | return np.array(dset.targets) 120 | 121 | dataset = load_and_process_dataset(dset_fn, target_fn, config['min_train_instances_class'] ) 122 | 123 | assert 0 <= config['frac_val'] <= 1 124 | 125 | generator1 = torch.Generator().manual_seed(0) # For reproducibility 126 | train, val = torch.utils.data.random_split(dataset, [1-config['frac_val'], config['frac_val']], generator=generator1) 127 | 128 | # Create training and validation datasets 129 | image_datasets = { 130 | 'train' : train, 131 | 'val' : val 132 | } 133 | # Create training and validation dataloaders 134 | dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers']) for x in ['train', 'val']} 135 | 136 | return dataloaders_dict 137 | 138 | 
def train_model(model, dataloaders, config): 139 | # Gather the parameters to be optimized/updated in this run. If we are 140 | # finetuning we will be updating all parameters. However, if we are 141 | # doing feature extract method, we will only update the parameters 142 | # that we have just initialized, i.e. the parameters with requires_grad 143 | # is True. 144 | save_every_epoch = False # save weights every epoch if accuracy is better than previous best 145 | 146 | set_parameter_requires_grad(model, config['feature_extract']) 147 | 148 | params_to_update = model.parameters() 149 | print("Params to learn:") 150 | if config['feature_extract']: 151 | params_to_update = [] 152 | for name,param in model.named_parameters(): 153 | if param.requires_grad == True: 154 | params_to_update.append(param) 155 | print("\t",name) 156 | else: 157 | for name,param in model.named_parameters(): 158 | if param.requires_grad == True: 159 | print("\t",name) 160 | 161 | # The above prints show which layers are being optimized 162 | 163 | criterion = nn.CrossEntropyLoss() 164 | 165 | optimizer = optim.Adam(params_to_update, lr=config['lr']) 166 | 167 | since = time.time() 168 | 169 | val_acc_history = [] 170 | 171 | best_model_wts = copy.deepcopy(model.state_dict()) 172 | best_acc = 0.0 173 | 174 | for epoch in range(config['num_epochs']): 175 | print('Epoch {}/{}'.format(epoch, config['num_epochs'] - 1)) 176 | print('-' * 10) 177 | 178 | # Each epoch has a training and validation phase 179 | for phase in ['train', 'val']: 180 | if phase == 'train': 181 | model.train() # Set model to training mode 182 | else: 183 | model.eval() # Set model to evaluate mode 184 | 185 | running_loss = 0.0 186 | running_corrects = 0 187 | 188 | # Iterate over data. 189 | for inputs, labels in tqdm(dataloaders[phase]): 190 | inputs = inputs.to(config['device']) 191 | labels = labels.to(config['device']) 192 | 193 | # pdb.set_trace() 194 | 195 | # zero the parameter gradients 196 | optimizer.zero_grad() 197 | 198 | # forward 199 | # track history if only in train 200 | with torch.set_grad_enabled(phase == 'train'): 201 | # Get model outputs and calculate loss 202 | outputs = model(inputs) 203 | loss = criterion(outputs, labels) 204 | 205 | _, preds = torch.max(outputs, 1) 206 | 207 | # backward + optimize only if in training phase 208 | if phase == 'train': 209 | loss.backward() 210 | optimizer.step() 211 | 212 | # statistics 213 | running_loss += loss.item() * inputs.size(0) 214 | running_corrects += torch.sum(preds == labels.data) 215 | 216 | epoch_loss = running_loss / len(dataloaders[phase].dataset) 217 | epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) 218 | 219 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 220 | 221 | # deep copy the model 222 | if phase == 'val' and epoch_acc > best_acc: 223 | best_acc = epoch_acc 224 | best_model_wts = copy.deepcopy(model.state_dict()) 225 | 226 | # Save model weights 227 | if save_every_epoch: 228 | print(f'Saving epoch {epoch} model') 229 | os.makedirs('./.cache', exist_ok=True) 230 | torch.save(best_model_wts, './.cache/' + config['model_filename'] + '.pth') 231 | with open('./.cache/' + config['model_filename'] + '-config.pkl', 'wb') as f: 232 | pickle.dump(config, f) 233 | 234 | if phase == 'val': 235 | val_acc_history.append(epoch_acc.item()) 236 | 237 | print() 238 | 239 | time_elapsed = time.time() - since 240 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 241 | print('Best val Acc: 
{:4f}'.format(best_acc)) 242 | 243 | # load best model weights 244 | model.load_state_dict(best_model_wts) 245 | 246 | # Save best model weights 247 | torch.save(best_model_wts, './.cache/' + config['model_filename'] + '.pth') 248 | 249 | with open('./.cache/' + config['model_filename'] + '-config.pkl', 'wb') as f: 250 | pickle.dump(config, f) 251 | 252 | # Save val softmax scores and labels 253 | val_softmax = np.zeros((len(dataloaders['val'].dataset),config['num_classes'])) 254 | val_labels = np.zeros((len(dataloaders['val'].dataset),), dtype=int) 255 | j = 0 256 | with torch.no_grad(): 257 | for inputs, labels in tqdm(dataloaders['val']): 258 | inputs = inputs.to(config['device']) 259 | val_labels[j:j+inputs.shape[0]] = labels.numpy() 260 | 261 | # Get model outputs 262 | val_softmax[j:j+inputs.shape[0],:] = model(inputs).detach().cpu().numpy() 263 | j = j + inputs.shape[0] 264 | 265 | # Apply softmax to logits 266 | val_softmax = softmax(val_softmax, axis=1) 267 | 268 | os.makedirs('./.cache', exist_ok=True) 269 | np.save('./.cache/' + config['model_filename'] + f'-valsoftmax_frac={config["frac_val"]}.npy', val_softmax) 270 | np.save('./.cache/' + config['model_filename'] + f'-vallabels_frac={config["frac_val"]}.npy', val_labels) 271 | print('Saved val set softmax scores and labels') 272 | 273 | return model, val_acc_history 274 | 275 | def get_model(config): 276 | model = resnet50(weights="IMAGENET1K_V2") 277 | model.fc = nn.Linear(model.fc.in_features, config['num_classes']) 278 | model = model.to(config['device']) 279 | try: 280 | # ADDED 281 | assert False 282 | state_dict = torch.load('./.cache/' + config['model_filename'] + '.pth', map_location=config['device']) 283 | model.load_state_dict(state_dict) 284 | model.eval() 285 | with open('./.cache/' + config['model_filename'] + '-config.pkl', 'rb') as f: 286 | loaded_config = pickle.load(f) 287 | 288 | for setting in ['num_classes', 'batch_size', 'num_epochs', 289 | 'frac_val', 'lr', 'dataset_name', 'min_train_instances_class', 'target_type']: 290 | assert config[setting] == loaded_config[setting] # If the configs aren't equal, retrain 291 | 292 | except: 293 | model = model.to(config['device']) 294 | dataloaders = get_dataloaders(config) 295 | 296 | model, val_acc_history = train_model(model, dataloaders, config) 297 | return model 298 | 299 | def show_img(x): 300 | x = x.transpose(1,2,0) 301 | x = (x - x.min())/(x.max() - x.min()) 302 | plt.imshow(x) 303 | plt.axis('off') 304 | plt.show() 305 | 306 | def postprocess_config(config): 307 | if config['dataset_name'] == 'Places365': 308 | config['num_classes'] = 365 309 | elif config['dataset_name'] == 'iNaturalist' and config['target_type'] == 'full': 310 | config['num_classes'] = 6414 311 | elif config['dataset_name'] == 'iNaturalist' and config['target_type'] == 'family': 312 | config['num_classes'] = 1103 313 | else: 314 | raise NotImplementedError 315 | return config 316 | 317 | if __name__ == "__main__": 318 | # config = { 319 | # 'batch_size' : 128, 320 | # 'lr' : 0.0001, 321 | # 'feature_extract' : False, 322 | # 'num_epochs' : 30, 323 | # 'device' : 'cuda', 324 | # 'frac_val' : 0.1, # For Places 365, this corresponds to >= 500 examples for calibration/val 325 | # 'num_workers' : 4, 326 | # 'dataset_name' : 'Places365', 327 | # 'model_filename' : 'best-places365-model', 328 | # 'target_type': 'full', 329 | # 'min_train_instances_class' : 10 330 | # } 331 | config = { 332 | 'batch_size' : 128, 333 | 'lr' : 0.0001, 334 | 'feature_extract' : False, 335 | 'num_epochs' : 30, 336 | 
'device' : 'cuda', 337 | 'frac_val' : 0.5, 338 | 'num_workers' : 4, 339 | 'dataset_name' : 'iNaturalist', 340 | 'model_filename' : 'best-inaturalist-model', 341 | 'target_type': 'full', 342 | 'min_train_instances_class' : 290 343 | } 344 | config = postprocess_config(config) 345 | get_model(config) 346 | 347 | -------------------------------------------------------------------------------- /generate_scores/train_models/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from torchvision_dataset_utils import * 4 | 5 | if __name__ == "__main__": 6 | 7 | parser = argparse.ArgumentParser(description='Train model') 8 | parser.add_argument('dataset', type=str, choices=['Places365', 'iNaturalist'], 9 | help='Name of the dataset to train model on') 10 | parser.add_argument('frac_val', type=float, 11 | help='Fraction of data to reserve for validation') 12 | parser.add_argument('--min_train_instances', type=int, default=0, 13 | help='Classes with fewer than this many classes in the published train dataset will be filtered out') 14 | parser.add_argument('--num_epochs', type=int, default=30, 15 | help='Number of epochs to train for') 16 | parser.add_argument('--target_type', type=str, default='full', 17 | help="Only used when dataset==iNaturalist. Options are ['full', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus'] ") 18 | 19 | args = parser.parse_args() 20 | 21 | config = { 22 | 'batch_size' : 128, 23 | 'lr' : 0.0001, 24 | 'feature_extract' : False, 25 | 'num_epochs' : args.num_epochs, 26 | 'device' : 'cuda', 27 | 'frac_val' : args.frac_val, 28 | 'num_workers' : 4, 29 | 'dataset_name' : args.dataset, 30 | 'model_filename' : f'best-{args.dataset}-model', 31 | 'target_type': args.target_type, 32 | 'min_train_instances_class' : args.min_train_instances 33 | } 34 | config = postprocess_config(config) 35 | get_model(config) -------------------------------------------------------------------------------- /generate_scores/train_models/train_inaturalist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run this file using "sbatch my_script.sh" 4 | 5 | # the SBATCH directives must appear before any executable 6 | # line in this script 7 | 8 | #SBATCH --gres=gpu:1 9 | #SBATCH -t 0-48:00 # time requested (D-HH:MM) 10 | # slurm will cd to this directory before running the script 11 | # you can also just run sbatch submit.sh from the directory 12 | # you want to be in 13 | #SBATCH -D /home/tding/code/class-conditional-conformal-datasets/notebooks 14 | # use these two lines to control the output file. Default is 15 | # slurm-.out. 
By default stdout and stderr go to the same 16 | # place, but if you use both commands below they'll be split up 17 | # filename patterns here: https://slurm.schedmd.com/sbatch.html 18 | # %N is the hostname (if used, will create output(s) per node) 19 | # %j is jobid 20 | #SBATCH -o /home/tding/slurm_output/train_inaturalist_job=%j.out # STDOUT 21 | #SBATCH -e /home/tding/slurm_output/train_inaturalist_job=%j.err # STDERR 22 | # if you want to get emails as your jobs run/fail 23 | ##SBATCH --mail-type=NONE # Mail events (NONE, BEGIN, END, FAIL, ALL) 24 | ##SBATCH --mail-user=tiffany_ding@eecs.berkeley.edu # Where to send mail 25 | #seff $SLURM_JOBID 26 | # print some info for context 27 | pwd | xargs -I{} echo "Current directory:" {} 28 | hostname | xargs -I{} echo "Node:" {} 29 | python train.py 'iNaturalist' .5 --num_epochs 1 --min_train_instances 290 -------------------------------------------------------------------------------- /generate_scores/train_models/train_inaturalist_family.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run this file using "sbatch my_script.sh" 4 | 5 | # the SBATCH directives must appear before any executable 6 | # line in this script 7 | 8 | #SBATCH --gres=gpu:1 9 | #SBATCH -t 0-48:00 # time requested (D-HH:MM) 10 | # slurm will cd to this directory before running the script 11 | # you can also just run sbatch submit.sh from the directory 12 | # you want to be in 13 | #SBATCH -D /home/tding/code/class-conditional-conformal-datasets/notebooks 14 | # use these two lines to control the output file. Default is 15 | # slurm-.out. By default stdout and stderr go to the same 16 | # place, but if you use both commands below they'll be split up 17 | # filename patterns here: https://slurm.schedmd.com/sbatch.html 18 | # %N is the hostname (if used, will create output(s) per node) 19 | # %j is jobid 20 | #SBATCH -o /home/tding/slurm_output/train_inaturalist_family_job=%j.out # STDOUT 21 | #SBATCH -e /home/tding/slurm_output/train_inaturalist_family_job=%j.err # STDERR 22 | # if you want to get emails as your jobs run/fail 23 | ##SBATCH --mail-type=NONE # Mail events (NONE, BEGIN, END, FAIL, ALL) 24 | ##SBATCH --mail-user=tiffany_ding@eecs.berkeley.edu # Where to send mail 25 | #seff $SLURM_JOBID 26 | # print some info for context 27 | pwd | xargs -I{} echo "Current directory:" {} 28 | hostname | xargs -I{} echo "Node:" {} 29 | python train.py 'iNaturalist' .5 --num_epochs 1 --target_type family -------------------------------------------------------------------------------- /generate_scores/train_models/train_places365.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run this file using "sbatch my_script.sh" 4 | 5 | # the SBATCH directives must appear before any executable 6 | # line in this script 7 | 8 | #SBATCH --gres=gpu:1 9 | #SBATCH -t 0-48:00 # time requested (D-HH:MM) 10 | # slurm will cd to this directory before running the script 11 | # you can also just run sbatch submit.sh from the directory 12 | # you want to be in 13 | #SBATCH -D /home/tding/code/class-conditional-conformal-datasets/notebooks 14 | # use these two lines to control the output file. Default is 15 | # slurm-.out. 
By default stdout and stderr go to the same 16 | # place, but if you use both commands below they'll be split up 17 | # filename patterns here: https://slurm.schedmd.com/sbatch.html 18 | # %N is the hostname (if used, will create output(s) per node) 19 | # %j is jobid 20 | #SBATCH -o /home/tding/slurm_output/train_places365_job=%j.out # STDOUT 21 | #SBATCH -e /home/tding/slurm_output/train_places365_job=%j.err # STDERR 22 | # if you want to get emails as your jobs run/fail 23 | ##SBATCH --mail-type=NONE # Mail events (NONE, BEGIN, END, FAIL, ALL) 24 | ##SBATCH --mail-user=tiffany_ding@eecs.berkeley.edu # Where to send mail 25 | #seff $SLURM_JOBID 26 | # print some info for context 27 | pwd | xargs -I{} echo "Current directory:" {} 28 | hostname | xargs -I{} echo "Node:" {} 29 | python train.py Places365 .1 --num_epochs 1 30 | -------------------------------------------------------------------------------- /notebooks/get_dataset_characteristics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "9edb6bec", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys; sys.path.append(\"../\") # For relative imports\n", 11 | "\n", 12 | "from scipy.stats import entropy\n", 13 | "\n", 14 | "from utils.experiment_utils import *\n", 15 | "\n", 16 | "%load_ext autoreload\n", 17 | "%autoreload 2" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "id": "953cb8e1", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Change to class-conditional-conformal directory\n", 28 | "if os.getcwd()[-9:] == 'notebooks':\n", 29 | " os.chdir('..')" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "a477f3d8", 35 | "metadata": {}, 36 | "source": [ 37 | "In this notebook, we compute various measures of class imbalance for each dataset. The metrics are computed on the data not used for model training. The metric we use in the paper is `Normalized fraction of mass in rarest 0.05 of classes`, since we find that this metric best captures the type of imbalance that is challenging for our problem setting." 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "id": "d104af2a", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "\n", 51 | "==== Dataset: imagenet ====\n", 52 | "Min count: 663\n", 53 | "Max count: 1201\n", 54 | "Min/max ratio: 0.552\n", 55 | "Normalized fraction of mass in rarest 0.05 of classes: 0.7905461250196218\n", 56 | "# of examples in rarest 0.05 of classes divided by expected number if uniform: 0.7905461250196218\n", 57 | "Normalized Shannon entropy: 0.9997548276966274\n", 58 | "[.25, .5, .75, .9] class count quantiles: [1159. 1168. 1176. 1184.]\n", 59 | "\n", 60 | "==== Dataset: cifar-100 ====\n", 61 | "Min count: 257\n", 62 | "Max count: 330\n", 63 | "Min/max ratio: 0.779\n", 64 | "Normalized fraction of mass in rarest 0.05 of classes: 0.9039999999999999\n", 65 | "# of examples in rarest 0.05 of classes divided by expected number if uniform: 0.904\n", 66 | "Normalized Shannon entropy: 0.9997848210317252\n", 67 | "[.25, .5, .75, .9] class count quantiles: [290. 
301.5 310.25 316.1 ]\n", 68 | "\n", 69 | "==== Dataset: places365 ====\n", 70 | "Min count: 300\n", 71 | "Max count: 576\n", 72 | "Min/max ratio: 0.521\n", 73 | "Normalized fraction of mass in rarest 0.05 of classes: 0.7687123633122458\n", 74 | "# of examples in rarest 0.05 of classes divided by expected number if uniform: 0.7687123633122458\n", 75 | "Normalized Shannon entropy: 0.9995684899390082\n", 76 | "[.25, .5, .75, .9] class count quantiles: [493. 508. 523. 538.6]\n", 77 | "\n", 78 | "==== Dataset: inaturalist ====\n", 79 | "Min count: 250\n", 80 | "Max count: 68838\n", 81 | "Min/max ratio: 0.004\n", 82 | "Normalized fraction of mass in rarest 0.05 of classes: 0.12139784134651672\n", 83 | "# of examples in rarest 0.05 of classes divided by expected number if uniform: 0.12139784134651672\n", 84 | "Normalized Shannon entropy: 0.8596172827554369\n", 85 | "[.25, .5, .75, .9] class count quantiles: [ 373. 697. 1822. 5099.4]\n" 86 | ] 87 | } 88 | ], 89 | "source": [ 90 | "dataset_list = ['imagenet', 'cifar-100', 'places365', 'inaturalist']\n", 91 | "\n", 92 | "for dataset in dataset_list:\n", 93 | " print(f'\\n==== Dataset: {dataset} ====')\n", 94 | " softmax_scores, labels = load_dataset(dataset)\n", 95 | " cts = Counter(labels).values()\n", 96 | " cts = sorted(np.array(list(cts)))\n", 97 | " num_classes = len(cts)\n", 98 | " print('Min count:', min(cts))\n", 99 | " print('Max count:', max(cts))\n", 100 | " print(f'Min/max ratio: { min(cts)/max(cts):.3f}')\n", 101 | " frac = .05\n", 102 | " print(f'Normalized fraction of mass in rarest {frac} of classes: {(np.sum(cts[:int(frac*num_classes)])/len(labels)) / .05}')\n", 103 | " print(f'# of examples in rarest {frac} of classes divided by expected number if uniform: {np.sum(cts[:int(frac*num_classes)])/(len(labels) * .05)}') # Another view\n", 104 | " print('Normalized Shannon entropy:', entropy(cts) / np.log(len(cts))) # See https://stats.stackexchange.com/questions/239973/a-general-measure-of-data-set-imbalance\n", 105 | " print('[.25, .5, .75, .9] class count quantiles:', np.quantile(cts, [.25, .5, .75, .9]))" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "57d8a4ec", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [] 115 | } 116 | ], 117 | "metadata": { 118 | "kernelspec": { 119 | "display_name": "conformal_env", 120 | "language": "python", 121 | "name": "conformal_env" 122 | }, 123 | "language_info": { 124 | "codemirror_mode": { 125 | "name": "ipython", 126 | "version": 3 127 | }, 128 | "file_extension": ".py", 129 | "mimetype": "text/x-python", 130 | "name": "python", 131 | "nbconvert_exporter": "python", 132 | "pygments_lexer": "ipython3", 133 | "version": "3.10.4" 134 | } 135 | }, 136 | "nbformat": 4, 137 | "nbformat_minor": 5 138 | } 139 | -------------------------------------------------------------------------------- /notebooks/imagenet_case_study.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "f6e4867c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys; sys.path.append(\"../\") # For relative imports\n", 11 | "\n", 12 | "from utils.experiment_utils import *\n", 13 | "\n", 14 | "%load_ext autoreload\n", 15 | "%autoreload 2" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "323738a7", 21 | "metadata": {}, 22 | "source": [ 23 | "In this notebook, we investigate the class-conditional coverage properties of 
standard conformal on ImageNet. \n", 24 | "\n", 25 | "**Note**: Before running this notebook, run `sh run_experiment.sh` (or just a single experiment using standard conformal). \n", 26 | "\n", 27 | "Mapping from ImageNet labels to names: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "id": "34949a8e", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "dataset = 'imagenet'\n", 38 | "pth = f'../.cache/paper/varying_n/{dataset}/random_calset/n_totalcal=20/score=softmax/seed=0_allresults.pkl'" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "id": "f2a22c8d", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "with open(pth, 'rb') as f:\n", 49 | " results = pickle.load(f)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 4, 55 | "id": "47be7b22", 56 | "metadata": { 57 | "scrolled": false 58 | }, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEWCAYAAACJ0YulAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAqe0lEQVR4nO3debxVdb3/8ddbQCBERAUvikyKqKAg92iRCZYmpjngdcDuT8MJrYysHmamqTncvNchS1PDQrNBMgglRcNrKnrLARRRMSdEPYKAmkwKMXx+f6zv2WwO+xw255x99hnez8djP1jru6bPd+/D/uzvd631XYoIzMzMALYqdwBmZtZ0OCmYmVmOk4KZmeU4KZiZWY6TgpmZ5TgpmJlZjpOCNRhJB0uqbMD9jZH0REPtr8hj7iRphqTlkq5rzGMXiKVB30+zYjgptDCSPifpb5KWSvpQ0v9J2j8ta/Qv2WZoLPA+sG1EfLfcwTSGppp8JF0mKSSdkFfWNpX1KWL7Jlmvps5JoQWRtC1wH3AjsD2wC/AjYHU54yqGpLbljiHpDcyNOtzV2YTq0JJ8CFwuqU25A2ktnBRalj0AIuKuiFgXEZ9ExPSImCNpL+BWYJikFZI+ApB0pKTnJC2T9I6ky6p2JqlP+lX2VUlvS3pf0kV5yztKukPSPyXNBfbPD0bS9yW9kbpi5koalbdsTGrF/ETSh8BlknaQNDXF8jSwW22VzWsVfZRiH5PKu0i6U9ISSW9JuljSVnnHfULStSnuNyV9KS27A/gq8L30Hh0qqb2kGyQtSK8bJLVP6x8sqVLSBZLeA25Pv27/KOm3qd4vSNpD0oWSFqc4D8urw2mSXk7rzpN0drEftqTPSnomtQqfkfTZvGWPSroivcfLJU2XtGOR+31U0pXpvV0h6c/ps/ld+myeyf+lLumnqV7LJM2SdFDeso6Sfp3e65clfS//17uknSVNTp/Vm5LGVQvnQeBfwP+rIdb26bN8W9IiSbemY3YCHgB2TnVYIWnnYurf6kWEXy3kBWwLfAD8GvgS0LXa8jHAE9XKDgb2IfuBsC+wCDg2LesDBHAb0BEYTNbq2Cstvxp4nKxVsivwIlCZt+8TgJ3Tvk8CVgI98mJZC3wTaJv2PxG4G+gEDALerR5v3r57AcuBk4F2wA7AkLTsTuBeoHOqw6vAGXnHXQOcBbQBvgYsAJSW3wFcmXecy4Enge5AN+BvwBV5791a4L+B9qkOlwGrgJGpXncCbwIXpTjPAt7M2/+RZMlPwAjgY2Bo3v4ra6j/9sA/gVPScU5O8zuk5Y8Cb5D9UOiY5q+uYV8bHSet+3qKqwswN72Hh+bV6fa89f9fev/bAt8F3gM65P2NPAZ0BXoCc6qORfZ3MQu4BNga6AfMA0am5ZcBvwWOTuXt0jEC6JPWuQGYmt6PzsCfgR9v7v3zq5bvkXIH4FcDf6CwV/piq0xfWFOBndKyMdTwJZu3/Q3AT9J0n/QfsGfe8qeB0Wl6HnB43rKxtf0nBGYDx+TF8nbesjZkX9Z75pX9V03xAhcCUwqUtyFLXHvnlZ0NPJp33Nfzln0q1fHf0vwdbJwU3gCOyJsfCcxP0weT/YrtkLf8MuChvPmjgBVAmzTfOR1vuxrqdQ/wrbz915QUTgGerlb2d2BMmn4UuDhv2deBB2vY10bHSdtelDd/HfBAtTrNruVz/icwOO9vZGTesjPZkBQ+nf83kPe53p73Xv42TT9FlsBzSYEska4Edsvbfhgp6db2/vlV88vdRy1MRLwcEWMioifZr+2dyb7oC5L0aUmPpOb7UuAcoHo3w3t50x8D26TpnYF38pa9VW3fp0qanbp3Pkrx5O87f9tuZP/ha9xfNbuSfWFXtyPZr878bd8iO79SJVefiPg4TW5DYTsX2Fd+N8SSiFhVbZtFedOfAO9HxLq8+dzxJH1J0pPKLgr4CDiCTd//YuKqiq1gPdn4cytG9TpUn8/tS9J3U9fQ0lSHLmyoQ/W/kfzp3mTdOx/l/Y38ANipQDwXk7W2OuSVdSNL6rPytn8wlVsdOSm0YBHxD7JfvoOqigqs9nuy1sSuEdGF7LyDijzEQrIv5yq9qiYk9SbrdjqXrEtjO7Lupfx958ezhKxlU3B/BbxD4XMO75O1OHpX28+7teyrNgsK7GtB3nydhxlO5yYmA9eStea2A6ZR3PtfPa6q2OpazzpJ5w8uAE4k667cDljKhjosJOs2qpL/+b5D9qt+u7xX54g4ovpxIuIhsi6tr+cVv0+WoAbmbd8lIqoSloeArgMnhRZE0p7pV1vPNL8rWV/zk2mVRUBPSVvnbdYZ+DAiVkk6APjKFhzybuBCSV3TMb+Zt6wT2X/KJSmW09iQnDaRfkn/ieyE86ck7U120rcmvwMOlXSisssUd5A0JO3nbuAqSZ1TcvoOWd90XdwFXCypWzpRe0k99lXd1mTnIpYAa9MJ78Nq3yRnGrCHpK+k+p8E7E129Vlj6kyWzJcAbSVdQnZuq0r+38guZD8SqjwNLEsn
[base64-encoded PNG image data omitted: rendered histogram of class-conditional coverage for standard conformal on ImageNet]\n", 63 | "text/plain": [ 64 | "
" 65 | ] 66 | }, 67 | "metadata": { 68 | "needs_background": "light" 69 | }, 70 | "output_type": "display_data" 71 | } 72 | ], 73 | "source": [ 74 | "bins = np.linspace(0.5, 1, 20)\n", 75 | "plt.hist(results['standard'][2]['raw_class_coverages'], bins=bins)\n", 76 | "plt.xlabel('Class-conditional coverage')\n", 77 | "plt.ylabel('Number of classes')\n", 78 | "\n", 79 | "ymin, ymax = plt.ylim()\n", 80 | "plt.vlines(x=0.9, ymin=ymin, ymax=ymax, label='Desired coverage (0.9)', color='limegreen')\n", 81 | "plt.ylim(ymin, ymax)\n", 82 | "\n", 83 | "plt.title('Standard conformal on ImageNet')\n", 84 | "plt.legend(loc='upper left')\n", 85 | "plt.show()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "id": "e6700b77", 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "Class-conditional coverages:\n", 99 | "[0.95848057 0.98090278 0.94117647 0.90685859 0.92560554 0.88640275\n", 100 | " 0.94952951 0.96210164 0.92355556 0.98613518 0.96681223 0.98515284\n", 101 | " 0.98526863 0.98339161 0.98594025 0.97535211 0.99136442 0.97554585\n", 102 | " 0.9554413 0.99035933 0.95175439 0.95934256 0.98415493 0.96167247\n", 103 | " 0.9877836 0.95192308 0.88 0.92699491 0.9470538 0.96327925\n", 104 | " 0.96709957 0.96007098 0.90933333 0.90517241 0.89063867 0.88763066\n", 105 | " 0.8558952 0.95742832 0.92650919 0.93825043 0.95829713 0.90932868\n", 106 | " 0.94596988 0.91758794 0.92643285 0.96832192 0.836966 0.96575342\n", 107 | " 0.98615917 0.93787748 0.95438596 0.98931434 0.83333333 0.92788879\n", 108 | " 0.8194325 0.86215979 0.95750217 0.96440972 0.90740741 0.91989437\n", 109 | " 0.74558304 0.93410508 0.84479167 0.95081967 0.92957746 0.90339426\n", 110 | " 0.79020979 0.88636364 0.85008666 0.9675491 0.96462468 0.96575342\n", 111 | " 0.98167539 0.77873812 0.78497409 0.97243755 0.96140351 0.94323144\n", 112 | " 0.94220665 0.95368782 0.90964379 0.97048611 0.91804708 0.92857143\n", 113 | " 0.97829861 0.95462478 0.7677643 0.98792062 0.9792567 0.97826087\n", 114 | " 0.98685364 0.97635727 0.97723292 0.97695035 0.98502203 0.97891037\n", 115 | " 0.9806338 0.96847636 0.96616915 0.92248062 0.98607485 0.90079017\n", 116 | " 0.99026549 0.96108597 0.97416021 0.99120493 0.96309315 0.96818573\n", 117 | " 0.91373239 0.97053726 0.94371152 0.94283276 0.87413194 0.94395797\n", 118 | " 0.92121212 0.95833333 0.95909487 0.96925705 0.8462867 0.90972222\n", 119 | " 0.93474427 0.88742304 0.9291958 0.92491468 0.85989492 0.95644599\n", 120 | " 0.91441048 0.95617879 0.92400691 0.98320071 0.99050086 0.95017483\n", 121 | " 0.95431034 0.96788194 0.90791738 0.98183391 0.9720524 0.98101812\n", 122 | " 0.97222222 0.97666379 0.97224631 0.97767857 0.97865073 0.97219809\n", 123 | " 0.96844873 0.98792062 0.98242531 0.95172414 0.98006932 0.97797357\n", 124 | " 0.9200695 0.84729494 0.84750733 0.94185022 0.94194107 0.90394511\n", 125 | " 0.94545455 0.93593074 0.7243173 0.92962641 0.94323144 0.92700088\n", 126 | " 0.84364261 0.85204526 0.92629482 0.7136294 0.83628319 0.79429429\n", 127 | " 0.8600175 0.94425087 0.80948276 0.87971905 0.91356958 0.90854185\n", 128 | " 0.9293617 0.82126348 0.90751945 0.95478261 0.95774648 0.84114583\n", 129 | " 0.76977153 0.94336283 0.89023336 0.93171608 0.81701031 0.79406632\n", 130 | " 0.8633218 0.93386243 0.84907834 0.85311699 0.92694611 0.91777188\n", 131 | " 0.85677308 0.85551664 0.91707799 0.94478528 0.89785408 0.91418564\n", 132 | " 0.88770999 0.89759036 0.83958152 0.81849913 0.87336245 0.93413174\n", 133 | " 0.81730769 
0.941331 0.90432383 0.89098712 0.8469657 0.8969697\n", 134 | " 0.95196507 0.91695804 0.88888889 0.94234079 0.91465517 0.92678725\n", 135 | " 0.92254134 0.93593074 0.94044444 0.8961039 0.86194996 0.91860465\n", 136 | " 0.93554007 0.91725979 0.93017241 0.89991372 0.8558952 0.81659389\n", 137 | " 0.90877797 0.95263158 0.87735849 0.85221239 0.92354474 0.93085566\n", 138 | " 0.91659501 0.91790393 0.9200695 0.90901137 0.83217993 0.92974848\n", 139 | " 0.81778929 0.87348354 0.85901926 0.93868739 0.88811189 0.8858885\n", 140 | " 0.84330986 0.95829713 0.77413793 0.89035088 0.7890223 0.96980155\n", 141 | " 0.92343934 0.90139616 0.93874029 0.95993031 0.91659574 0.86701209\n", 142 | " 0.94587629 0.94791667 0.97212544 0.96748682 0.91382979 0.92676548\n", 143 | " 0.86548673 0.81746725 0.8784965 0.90496949 0.93558282 0.90250216\n", 144 | " 0.9083045 0.87734242 0.883821 0.88675958 0.92017167 0.94449263\n", 145 | " 0.97339056 0.93904594 0.85143354 0.95536481 0.88675958 0.91361257\n", 146 | " 0.70959378 0.97112861 0.9826087 0.88687392 0.96200345 0.91703057\n", 147 | " 0.87274291 0.99059829 0.95087719 0.98686515 0.98789974 0.97998259\n", 148 | " 0.97366111 0.96807593 0.97864769 0.96234676 0.85327511 0.98277347\n", 149 | " 0.92771084 0.97827976 0.93421053 0.89423904 0.86683849 0.92682927\n", 150 | " 0.94649123 0.94031142 0.93545535 0.94763948 0.90933333 0.91737649\n", 151 | " 0.9206626 0.92173913 0.89741379 0.98006932 0.95296167 0.95913043\n", 152 | " 0.9644714 0.95411255 0.95716783 0.99567848 0.96701389 0.99557913\n", 153 | " 0.97579948 0.95599655 0.96732588 0.9534687 0.95482189 0.94570928\n", 154 | " 0.93950178 0.9231441 0.96620451 0.97998259 0.96052632 0.96660482\n", 155 | " 0.96239316 0.91529818 0.96031061 0.99654875 0.99040976 0.88927944\n", 156 | " 0.9090106 0.9737303 0.97975352 0.8921484 0.96822595 0.98631309\n", 157 | " 0.94829097 0.92268041 0.97320657 0.97236615 0.92161929 0.92977778\n", 158 | " 0.98774081 0.97584124 0.78508772 0.90666667 0.8712522 0.94449339\n", 159 | " 0.9226087 0.97120419 0.94468832 0.96296296 0.98106713 0.96200345\n", 160 | " 0.97212544 0.92105263 0.91515152 0.95096322 0.88024476 0.8221831\n", 161 | " 0.96092362 0.90431034 0.92116538 0.95996519 0.98607485 0.93971631\n", 162 | " 0.92125984 0.92173913 0.78097731 0.78987993 0.95733333 0.96187175\n", 163 | " 0.94199134 0.94148021 0.94661922 0.99120493 0.99465717 0.89965695\n", 164 | " 0.92138063 0.92959002 0.970726 0.9745167 0.95855379 0.94303243\n", 165 | " 0.97610619 0.98421053 0.90529974 0.91225022 0.92795139 0.90917186\n", 166 | " 0.92602263 0.96660959 0.95789474 0.95986038 0.89947552 0.91116974\n", 167 | " 0.88917526 0.84622068 0.9506383 0.90627687 0.71378092 0.85539001\n", 168 | " 0.84939759 0.73093777 0.9372313 0.92524186 0.82268227 0.78050922\n", 169 | " 0.92857143 0.82661996 0.87554777 0.907585 0.9122807 0.93674177\n", 170 | " 0.88425047 0.83666377 0.86026201 0.90297203 0.9537925 0.9250646\n", 171 | " 0.85914261 0.85248714 0.85863874 0.84155844 0.79861711 0.91034483\n", 172 | " 0.8845815 0.94637817 0.83496007 0.88626422 0.90522586 0.91428571\n", 173 | " 0.93981083 0.86262799 0.81315789 0.84219554 0.92838654 0.91002571\n", 174 | " 0.97569444 0.89767842 0.86869565 0.93333333 0.84499557 0.88601036\n", 175 | " 0.89382071 0.80299033 0.91896552 0.87640449 0.82854656 0.82563208\n", 176 | " 0.82623805 0.71043478 0.80856643 0.90267176 0.97152718 0.93419913\n", 177 | " 0.83461211 0.85038693 0.79930495 0.91838488 0.91543156 0.80895009\n", 178 | " 0.91507799 0.95316565 0.9720035 0.85862966 0.81880932 0.79531657\n", 179 | " 0.90853659 
0.85466795 0.71699741 0.93151888 0.91371872 0.8115688\n", 180 | " 0.87920621 0.75218914 0.77413793 0.87435009 0.91184097 0.88530806\n", 181 | " 0.87783595 0.92724046 0.74336283 0.90017212 0.93033135 0.85021645\n", 182 | " 0.88501742 0.76659751 0.97764402 0.73856209 0.80363322 0.85575049\n", 183 | " 0.82493369 0.84142114 0.85047537 0.82653061 0.8415493 0.87314086\n", 184 | " 0.95070423 0.96097138 0.86113537 0.82619863 0.84843206 0.79487179\n", 185 | " 0.84117125 0.9083045 0.84347826 0.82136602 0.90123457 0.85605338\n", 186 | " 0.90414508 0.79126638 0.81989708 0.94479074 0.83433995 0.75234842\n", 187 | " 0.88792354 0.86451049 0.82882096 0.87516088 0.90043668 0.98146514\n", 188 | " 0.87392055 0.95175439 0.81708449 0.9797891 0.85065502 0.82033304\n", 189 | " 0.95126197 0.84594835 0.7211704 0.87347295 0.86089907 0.90220264\n", 190 | " 0.93367786 0.92413793 0.91513561 0.88908297 0.88534397 0.85980392\n", 191 | " 0.93859649 0.81717687 0.9614711 0.93263342 0.87291849 0.87105038\n", 192 | " 0.82560137 0.83000867 0.95070423 0.93280977 0.9256993 0.94347826\n", 193 | " 0.9717564 0.94015611 0.93989547 0.81668194 0.94655172 0.94641314\n", 194 | " 0.85201401 0.90463918 0.85764192 0.96587927 0.9505737 0.94325021\n", 195 | " 0.98945518 0.92559787 0.8830156 0.91202091 0.88179465 0.87359307\n", 196 | " 0.78761823 0.92638037 0.81433506 0.63914027 0.9135274 0.74070796\n", 197 | " 0.84092863 0.81184669 0.85869565 0.86879433 0.90980736 0.84020619\n", 198 | " 0.96983141 0.86188811 0.76100629 0.94759825 0.93398751 0.92107546\n", 199 | " 0.61149826 0.8948291 0.92467532 0.97442681 0.95851721 0.89137931\n", 200 | " 0.84201236 0.97525597 0.85664639 0.93556929 0.90239574 0.93613298\n", 201 | " 0.97001764 0.83840139 0.95902354 0.82214473 0.87929515 0.86558219\n", 202 | " 0.69913043 0.78119658 0.76943005 0.90854185 0.87065972 0.82730455\n", 203 | " 0.79177603 0.92367067 0.76876618 0.9469496 0.94736842 0.87817704\n", 204 | " 0.9244713 0.743521 0.75131349 0.77641921 0.90932868 0.88484252\n", 205 | " 0.79444926 0.8684669 0.80137575 0.75720524 0.96017699 0.82414698\n", 206 | " 0.95907928 0.83727034 0.84118674 0.98366294 0.89896373 0.83678756\n", 207 | " 0.88533333 0.95611015 0.76909871 0.90138408 0.86230637 0.90322581\n", 208 | " 0.88504754 0.84816754 0.85516045 0.80642361 0.93345009 0.80378657\n", 209 | " 0.91386736 0.97110333 0.80254154 0.76167315 0.7486911 0.89554795\n", 210 | " 0.86631944 0.90600522 0.93782837 0.94086957 0.92407248 0.93968531\n", 211 | " 0.96542783 0.73485514 0.8788143 0.89566929 0.84142716 0.84432718\n", 212 | " 0.78173077 0.93183779 0.8359375 0.78947368 0.94189072 0.86171132\n", 213 | " 0.90941073 0.96413322 0.90744467 0.95633188 0.89519651 0.88044693\n", 214 | " 0.93139842 0.85788562 0.77300613 0.82692308 0.95742832 0.87305699\n", 215 | " 0.72679965 0.92449517 0.87202118 0.93073593 0.76835664 0.91229579\n", 216 | " 0.88898601 0.90622261 0.89170306 0.86147186 0.91006843 0.8962766\n", 217 | " 0.87561214 0.86730269 0.81818182 0.82585752 0.90700344 0.92052402\n", 218 | " 0.83912612 0.95611015 0.8399654 0.88810573 0.80489939 0.94618056\n", 219 | " 0.84750219 0.87057808 0.83612335 0.92089249 0.94306418 0.80158033\n", 220 | " 0.93362069 0.92748433 0.84501062 0.88602151 0.91368788 0.55039439\n", 221 | " 0.94050744 0.65104167 0.90079017 0.89700704 0.95171203 0.84494774\n", 222 | " 0.87085515 0.96937883 0.83227446 0.89496157 0.80839895 0.85065502\n", 223 | " 0.81164384 0.84215168 0.90305677 0.8956044 0.7037037 0.88859878\n", 224 | " 0.88378144 0.93712317 0.92774309 0.86133333 0.70454545 0.96066434\n", 225 | " 
0.8501292 0.90940767 0.85137457 0.94185022 0.87894737 0.85864794\n", 226 | " 0.80772532 0.87826087 0.76855124 0.91035683 0.92024014 0.78571429\n", 227 | " 0.92869875 0.84426947 0.93133047 0.8373102 0.80444444 0.84688581\n", 228 | " 0.90450928 0.87598253 0.90933099 0.88888889 0.81298036 0.94516971\n", 229 | " 0.92595819 0.96768559 0.73209549 0.83013937 0.6897747 0.93642612\n", 230 | " 0.91035683 0.82276281 0.79861111 0.87647593 0.85211268 0.90388007\n", 231 | " 0.80574913 0.89193825 0.87401575 0.91803279 0.87640449 0.83661249\n", 232 | " 0.9344894 0.78165939 0.96754386 0.92857143 0.97558849 0.93179805\n", 233 | " 0.80749129 0.92447917 0.90113736 0.89544236 0.89044698 0.79268293\n", 234 | " 0.82589771 0.7258248 0.95711679 0.62195122 0.96915167 0.95021834\n", 235 | " 0.93368237 0.87673611 0.76274165 0.90564374 0.98962835 0.95445545\n", 236 | " 0.94596913 0.86086957 0.73426573 0.95225694 0.8487395 0.8407699\n", 237 | " 0.77244987 0.94759825 0.86026201 0.90425532 0.96474635 0.94334764\n", 238 | " 0.86257563 0.89354276 0.6537133 0.80229479 0.74211503 0.93379791\n", 239 | " 0.79529617 0.75837321 0.86615516 0.94493783 0.86758383 0.77065026\n", 240 | " 0.87826087 0.95172414 0.7840708 0.85886403 0.92348285 0.82414698\n", 241 | " 0.76255319 0.96779661 0.91273375 0.88418323 0.8826087 0.91103203\n", 242 | " 0.90534619 0.86543536 0.84111311 0.85839161 0.8781331 0.971078\n", 243 | " 0.89938217 0.87190813 0.90971625 0.91884817 0.71221178 0.87982833\n", 244 | " 0.9288225 0.93203883 0.85867446 0.93920705 0.94987035 0.90529974\n", 245 | " 0.80352423 0.9206626 0.90592334 0.86339755 0.90034662 0.91383812\n", 246 | " 0.82450043 0.81500873 0.85082394 0.70042194 0.95438898 0.94667832\n", 247 | " 0.96512642 0.90748899 0.9619469 0.91716567 0.84475282 0.82413793\n", 248 | " 0.89382071 0.95462329 0.88184188 0.90743945 0.69223986 0.49956635\n", 249 | " 0.971113 0.85074627 0.78218695 0.82909728 0.84750219 0.78916173\n", 250 | " 0.89229341 0.81358885 0.87456446 0.8125 0.78829604 0.66839378\n", 251 | " 0.82704126 0.88120567 0.9213691 0.96223176 0.93648069 0.90635739\n", 252 | " 0.93744493 0.93187773 0.91494845 0.83101392 0.91782007 0.78664354\n", 253 | " 0.89264069 0.85595568 0.9404878 0.92489083 0.83840139 0.80960699\n", 254 | " 0.90751945 0.90652557 0.91463415 0.96187175 0.92175408 0.91093474\n", 255 | " 0.92974848 0.93652174 0.91398784 0.90393013 0.86996337 0.93684211\n", 256 | " 0.89708405 0.89102005 0.93858131 0.93934142 0.95495495 0.76186368\n", 257 | " 0.89670139 0.93362445 0.90158172 0.891748 0.93787748 0.9237435\n", 258 | " 0.90901213 0.97222222 0.97398092 0.95395308 0.94031142 0.96296296\n", 259 | " 0.8600175 0.8772688 0.89168111 0.94385965 0.9283247 0.94658494\n", 260 | " 0.79391304 0.88860104 0.71538462 0.81506196 0.91179039 0.91559001\n", 261 | " 0.90713672 0.84027778 0.98699046 0.79570815 0.83333333 0.79965157\n", 262 | " 0.83520276 0.8685567 0.9453125 0.941331 0.90853659 0.94230769\n", 263 | " 0.95804196 0.98091934 0.98846495 0.77729636 0.94390027 0.9622807\n", 264 | " 0.95724258 0.94107452 0.94622723 0.96143734 0.96214789 0.87640449\n", 265 | " 0.85639687 0.95888014 0.81081081 0.70593293]\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "print('Class-conditional coverages:')\n", 271 | "print(results['standard'][2]['raw_class_coverages'])" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 6, 277 | "id": "1d2d0a40", 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "name": "stdout", 282 | "output_type": "stream", 283 | "text": [ 284 | "==== Most undercovered class under 
standard ====\n", 285 | "Class: 899\n", 286 | "Coverage under standard : 0.49956634865568084\n", 287 | "Coverage under clustered: 0.7476149176062445\n", 288 | "Coverage under classwise: 0.8907198612315698\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "print('==== Most undercovered class under standard ====')\n", 294 | "most_undercov = results['standard'][2]['raw_class_coverages'].argmin()\n", 295 | "print('Class:', most_undercov)\n", 296 | "print('Coverage under standard :', results['standard'][2]['raw_class_coverages'][most_undercov])\n", 297 | "print('Coverage under clustered:', results['cluster_random'][2]['raw_class_coverages'][most_undercov])\n", 298 | "print('Coverage under classwise:', results['classwise'][2]['raw_class_coverages'][most_undercov])" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "id": "9f10a372", 304 | "metadata": {}, 305 | "source": [ 306 | "Class 899 corresponds to water jug" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 7, 312 | "id": "1d69e4d2", 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "name": "stdout", 317 | "output_type": "stream", 318 | "text": [ 319 | "==== Most overcovered class under standard ====\n", 320 | "Class: 339\n", 321 | "Coverage under standard : 0.996548748921484\n", 322 | "Coverage under clustered: 0.9836065573770492\n", 323 | "Coverage under classwise: 0.9301121656600517\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "print('==== Most overcovered class under standard ====')\n", 329 | "most_overcov = results['standard'][2]['raw_class_coverages'].argmax()\n", 330 | "print('Class:', most_overcov)\n", 331 | "print('Coverage under standard :', results['standard'][2]['raw_class_coverages'][most_overcov])\n", 332 | "print('Coverage under clustered:', results['cluster_random'][2]['raw_class_coverages'][most_overcov])\n", 333 | "print('Coverage under classwise:', results['classwise'][2]['raw_class_coverages'][most_overcov])" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "id": "a4e975e4", 339 | "metadata": {}, 340 | "source": [ 341 | "Class 339 corresponds to \"sorrel\" (a reddish-brown horse), which is not a very common thing. 
Let's find another overcovered class that is more familiar" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 8, 347 | "id": "98c5416d", 348 | "metadata": {}, 349 | "outputs": [ 350 | { 351 | "name": "stdout", 352 | "output_type": "stream", 353 | "text": [ 354 | "Classes with coverage above 99%:\n" 355 | ] 356 | }, 357 | { 358 | "data": { 359 | "text/plain": [ 360 | "(array([ 16, 19, 102, 105, 130, 289, 321, 323, 339, 340, 387, 388]),)" 361 | ] 362 | }, 363 | "execution_count": 8, 364 | "metadata": {}, 365 | "output_type": "execute_result" 366 | } 367 | ], 368 | "source": [ 369 | "print('Classes with coverage above 99%:')\n", 370 | "np.where(results['standard'][2]['raw_class_coverages'] > .99)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "id": "f52fc18a", 376 | "metadata": {}, 377 | "source": [ 378 | "Some classes that correspond to common things:\n", 379 | "* 105 = koala\n", 380 | "* 130 = flamingo\n", 381 | "* 289 = snow leopard" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 9, 387 | "id": "8eafc8bc", 388 | "metadata": {}, 389 | "outputs": [ 390 | { 391 | "name": "stdout", 392 | "output_type": "stream", 393 | "text": [ 394 | "Class-conditional coverage of Class 105:\n" 395 | ] 396 | }, 397 | { 398 | "data": { 399 | "text/plain": [ 400 | "0.9912049252418645" 401 | ] 402 | }, 403 | "execution_count": 9, 404 | "metadata": {}, 405 | "output_type": "execute_result" 406 | } 407 | ], 408 | "source": [ 409 | "cls = 105\n", 410 | "\n", 411 | "print(f'Class-conditional coverage of Class {cls}:')\n", 412 | "results['standard'][2]['raw_class_coverages'][cls]" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "id": "f0340541", 418 | "metadata": {}, 419 | "source": [ 420 | "## Some additional code for comparing against other conformal methods" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 10, 426 | "id": "ff8ed2e3", 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "# # For comparison\n", 431 | "# print('==== Most undercovered class under clustered ====')\n", 432 | "# method = 'cluster_proportional'\n", 433 | "# most_undercov = results[method][2]['raw_class_coverages'].argmin()\n", 434 | "# print('Class:', most_undercov)\n", 435 | "# print('Coverage:', results[method][2]['raw_class_coverages'][most_undercov])" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 11, 441 | "id": "bb378d48", 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "# n_totalcal = 20\n", 446 | "# score = 'softmax'\n", 447 | "# folder = f'/home/tding/code/class-conditional-conformal/.cache/paper/varying_n/{dataset}/random_calset/n_totalcal={n_totalcal}/score={score}'\n", 448 | "# plot_class_coverage_histogram(folder, desired_cov=0.9, vmin=.5, vmax=1, nbins=30, \n", 449 | "# methods = ['standard', 'classwise', 'cluster_random'],\n", 450 | "# title=f'ImageNet, n={n_totalcal}, {score}')" 451 | ] 452 | } 453 | ], 454 | "metadata": { 455 | "kernelspec": { 456 | "display_name": "conformal_env", 457 | "language": "python", 458 | "name": "conformal_env" 459 | }, 460 | "language_info": { 461 | "codemirror_mode": { 462 | "name": "ipython", 463 | "version": 3 464 | }, 465 | "file_extension": ".py", 466 | "mimetype": "text/x-python", 467 | "name": "python", 468 | "nbconvert_exporter": "python", 469 | "pygments_lexer": "ipython3", 470 | "version": "3.10.4" 471 | } 472 | }, 473 | "nbformat": 4, 474 | "nbformat_minor": 5 475 | } 476 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gdown # for downloading Google Drive files 2 | ipykernel 3 | joblib 4 | jupyter 5 | pandas 6 | matplotlib 7 | numpy 8 | scikit-learn 9 | seaborn 10 | tensorflow 11 | torch 12 | torchvision 13 | tqdm 14 | -------------------------------------------------------------------------------- /run_experiment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from utils.experiment_utils import * 4 | 5 | 6 | if __name__ == "__main__": 7 | 8 | parser = argparse.ArgumentParser(description='Run experiment') 9 | parser.add_argument('dataset', type=str, choices=['imagenet', 'cifar-100', 'places365', 'inaturalist'], 10 | help='Name of the dataset to train model on') 11 | parser.add_argument('avg_num_per_class', type=int, 12 | help='Number of examples per class, on average, to include in calibration dataset') 13 | parser.add_argument('-score_functions', type=str, nargs='+', 14 | help='Conformal score functions to use. List with a space in between. Options are ' 15 | '"softmax", "APS", "RAPS"') 16 | parser.add_argument('-methods', type=str, nargs='+', 17 | help='Conformal methods to use. List with a space in between. Options include ' 18 | '"standard", "classwise", "classwise_default_standard", "cluster_random"') 19 | parser.add_argument('-seeds', type=int, nargs='+', 20 | help='Seeds for random splits into calibration and validation sets. ' 21 | 'List with spaces in between') 22 | 23 | 24 | parser.add_argument('--calibration_sampling', type=str, default='random', 25 | help='How to sample the calibration set. Options are "random" and "balanced"') 26 | parser.add_argument('--frac_clustering', type=float, default=-1, 27 | help='For clustered conformal: the fraction of data used for clustering. ' 
28 | 'If frac_clustering and num_clusters are both -1, then a heuristic will be used to choose these values.') 29 | parser.add_argument('--num_clusters', type=int, default=-1, 30 | help='For clustered conformal: the number of clusters to request. ' 31 | 'If frac_clustering and num_clusters are both -1, then a heuristic will be used to choose these values.') 32 | parser.add_argument('--alpha', type=float, default=0.1, 33 | help='Desired coverage is 1-alpha') 34 | parser.add_argument('--save_folder', type=str, default='.cache/paper/varying_n', 35 | help='Folder to save results to') 36 | 37 | 38 | args = parser.parse_args() 39 | if args.frac_clustering != -1 and args.num_clusters != -1: 40 | run_one_experiment(args.dataset, args.save_folder, args.alpha, 41 | args.avg_num_per_class, args.score_functions, args.methods, args.seeds, 42 | cluster_args={'frac_clustering': args.frac_clustering, 'num_clusters': args.num_clusters}, 43 | save_preds=False, calibration_sampling=args.calibration_sampling) 44 | else: # choose frac_clustering and num_clusters automatically 45 | run_one_experiment(args.dataset, args.save_folder, args.alpha, 46 | args.avg_num_per_class, args.score_functions, args.methods, args.seeds, 47 | save_preds=False, calibration_sampling=args.calibration_sampling) -------------------------------------------------------------------------------- /run_experiment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # print some info for context 4 | pwd | xargs -I{} echo "Current directory:" {} 5 | hostname | xargs -I{} echo "Node:" {} 6 | 7 | # Run all experiments in parallel. To run sequentially, remove "&" 8 | for calibration_sampling in 'random'; 9 | do for dataset in 'imagenet' 'cifar-100' 'places365' 'inaturalist'; 10 | do for n in 10 20 30 40 50 75 100 150; 11 | do python3 run_experiment.py $dataset $n -score_functions softmax APS RAPS -methods standard classwise cluster_random exact_coverage_standard exact_coverage_classwise exact_coverage_cluster --calibration_sampling $calibration_sampling -seeds 0 1 2 3 4 5 6 7 8 9 & 12 | done; 13 | done; 14 | done 15 | 16 | 17 | ## Run a single experiment 18 | # python run_experiment.py cifar-100 30 -score_functions softmax APS -methods standard cluster_random -seeds 0 1 19 | 20 | -------------------------------------------------------------------------------- /run_heatmap_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -N 1 # number of nodes requested 4 | #SBATCH -n 20 # number of tasks (i.e. 
processes) 5 | #SBATCH -t 0-12:00 # time requested (D-HH:MM) 6 | #SBATCH -o /home/tding/slurm_output/heatmaps.%j.out # STDOUT 7 | #SBATCH -e /home/tding/slurm_output/heatmaps.%j.err # STDERR 8 | 9 | 10 | 11 | # print some info for context 12 | pwd | xargs -I{} echo "Current directory:" {} 13 | hostname | xargs -I{} echo "Node:" {} 14 | 15 | # Run all experiments 16 | calibration_sampling='random' 17 | dataset='imagenet' 18 | 19 | for n in 10 50; 20 | do for frac_clustering in .1 .2 .3 .4 .5 .6 .7 .8 .9; 21 | do for num_clusters in 2 3 4 5 6 8 10 15 20 50; 22 | do save_folder=".cache/paper/heatmaps/${dataset}/frac=${frac_clustering}_numclusters=${num_clusters}" 23 | echo "Save folder: ${save_folder}" 24 | python3 run_experiment.py $dataset $n -score_functions softmax APS RAPS -methods cluster_random --calibration_sampling $calibration_sampling -seeds 0 1 2 3 4 5 6 7 8 9 --frac_clustering $frac_clustering --num_clusters $num_clusters --save_folder $save_folder & 25 | done; 26 | done; 27 | done 28 | 29 | -------------------------------------------------------------------------------- /utils/clustering_utils.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numpy as np 3 | 4 | from scipy import stats 5 | from sklearn.cluster import KMeans 6 | 7 | 8 | #======================================== 9 | # Computing embeddings for k-means 10 | #======================================== 11 | 12 | def quantile_embedding(samples, q=[0.5, 0.6, 0.7, 0.8, 0.9]): 13 | ''' 14 | Computes the q-quantiles of samples and returns the vector of quantiles 15 | ''' 16 | return np.quantile(samples, q) 17 | 18 | def embed_all_classes(scores_all, labels, q=[0.5, 0.6, 0.7, 0.8, 0.9], return_cts=False): 19 | ''' 20 | Input: 21 | - scores_all: num_instances x num_classes array where 22 | scores_all[i,j] = score of class j for instance i 23 | Alternatively, num_instances-length array where scores_all[i] = score of true class for instance i 24 | - labels: num_instances-length array of true class labels 25 | - q: quantiles to include in embedding 26 | - return_cts: if True, return an array containing the counts for each class 27 | 28 | Output: 29 | - embeddings: num_classes x len(q) array where ith row is the embeddings of class i 30 | - (Optional) cts: num_classes-length array where cts[i] = # of times class i 31 | appears in labels 32 | ''' 33 | num_classes = len(np.unique(labels)) 34 | 35 | embeddings = np.zeros((num_classes, len(q))) 36 | cts = np.zeros((num_classes,)) 37 | 38 | for i in range(num_classes): 39 | if len(scores_all.shape) == 2: 40 | class_i_scores = scores_all[labels==i,i] 41 | else: 42 | class_i_scores = scores_all[labels==i] 43 | cts[i] = class_i_scores.shape[0] 44 | embeddings[i,:] = quantile_embedding(class_i_scores, q=q) 45 | 46 | if return_cts: 47 | return embeddings, cts 48 | else: 49 | return embeddings 50 | 51 | 52 | #======================================== 53 | # Generating synthetic data 54 | #======================================== 55 | 56 | def generate_synthetic_clustered_data(num_clusters, num_classes, num_samples_per_class, 57 | cluster_probs=None, dist_between_means=1000, sd=1): 58 | ''' 59 | Generate clusters where cluster i is a N(i*dist_between_means, sd) distribution. 60 | Randomly assign classes to clusters with probabilities determined by cluster_probs. Then sample 61 | num_samples_per_class from each class. 
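For example, with num_clusters=3, dist_between_means=1000, and sd=1, each class is assigned to one of N(0, 1), N(1000, 1), or N(2000, 1), and all of that class's samples are drawn from its assigned distribution. 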
62 | 63 | Inputs: 64 | - num_clusters: Number of clusters 65 | - num_classes: Total number of classes 66 | - num_samples_per_class: Number of samples to generate per class 67 | - cluster_probs: If None, then every class has equal probability of being assigned 68 | to each cluster. Otherwise, it must be an array of probabilities of length num_clusters 69 | such that cluster_probs[i] = probability that a class is assigned to cluster i 70 | - dist_between_means: Distance between means of Normal distributions 71 | - sd: Standard deviation of Normal distributions 72 | 73 | Output: cluster_assignments, samples 74 | - cluster_assignments: (num_classes,) array of cluster assignments 75 | - samples: (num_classes, num_samples_per_class) array containing the generated samples 76 | ''' 77 | cluster_assignments = np.zeros((num_classes,)) 78 | samples = np.zeros((num_classes, num_samples_per_class)) 79 | 80 | for i in range(num_classes): 81 | cluster_assignments[i] = np.random.choice(np.arange(num_clusters), p=cluster_probs) 82 | samples[i,:] = np.random.normal(loc=cluster_assignments[i] * dist_between_means, 83 | scale=sd, 84 | size=(num_samples_per_class,)) 85 | 86 | return cluster_assignments, samples 87 | 88 | 89 | def sample_from_empirical_distr(data, num_samples): 90 | samples = np.random.choice(data, size=num_samples) 91 | 92 | return samples 93 | 94 | def generate_realistic_clustered_data(samples_list, 95 | num_classes, 96 | num_samples_per_class, 97 | cluster_probs=None): 98 | ''' 99 | Generate clusters where cluster i has the same distribution as the samples in samples_list[i]. 100 | Randomly assign classes to clusters with probabilities determined by cluster_probs. Then sample 101 | num_samples_per_class from each class. 102 | 103 | Inputs: 104 | - samples_list: num_cluster length list, where samples_list[i] is an 105 | array of samples from distribution i 106 | - num_classes: Total number of classes 107 | - num_samples_per_class: Number of samples to generate per class 108 | - cluster_probs: If None, then every class has equal probability of being assigned 109 | to each cluster. 
Otherwise, it must be an array of probabilities of length num_clusters 110 | such that cluster_probs[i] = probability that a class is assigned to cluster i 111 | 112 | Output: cluster_assignments, samples 113 | - cluster_assignments: (num_classes,) array of cluster assignments 114 | - samples: (num_classes, num_samples_per_class) array containing the generated samples 115 | ''' 116 | 117 | num_clusters = len(samples_list) 118 | 119 | cluster_assignments = np.zeros((num_classes,), dtype=int) 120 | samples = np.zeros((num_classes, num_samples_per_class)) 121 | for i in range(num_classes): 122 | cluster_assignments[i] = np.random.choice(np.arange(num_clusters), p=cluster_probs) 123 | samples[i,:] = sample_from_empirical_distr(samples_list[cluster_assignments[i]], num_samples_per_class) 124 | 125 | return cluster_assignments, samples -------------------------------------------------------------------------------- /utils/experiment_utils.py: -------------------------------------------------------------------------------- 1 | import glob # For getting file names 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import pickle 6 | import seaborn as sns 7 | 8 | import pdb 9 | 10 | from collections import Counter 11 | from scipy import stats, cluster 12 | 13 | from utils.clustering_utils import * 14 | from utils.conformal_utils import * 15 | 16 | 17 | # Used for processing iNaturalist dataset 18 | def remove_rare_classes(softmax_scores, labels, thresh = 250): 19 | ''' 20 | Filter out classes with fewer than thresh examples 21 | (removes full rows and softmax score entries corresponding to those classes) 22 | 23 | Note: Make sure to use raw softmax scores instead of 1-softmax in order for 24 | normalization to work correctly 25 | ''' 26 | classes, cts = np.unique(labels, return_counts=True) 27 | non_rare_classes = classes[cts >= thresh] 28 | print(f'Data preprocessing: Keeping {len(non_rare_classes)} of {len(classes)} classes that have >= {thresh} examples') 29 | 30 | # Filter labels and re-index 31 | remaining_label_idx = np.isin(labels, non_rare_classes) 32 | labels = labels[remaining_label_idx] 33 | new_idx = 0 34 | mapping = {} # old to new 35 | for i, label in enumerate(labels): 36 | if label not in mapping: 37 | mapping[label] = new_idx 38 | new_idx += 1 39 | labels[i] = mapping[label] 40 | 41 | # Remove rows and columns corresponding to rare classes from scores matrix 42 | softmax_scores = softmax_scores[remaining_label_idx,:] 43 | new_softmax_scores = np.zeros((len(labels), len(non_rare_classes))) 44 | for k in non_rare_classes: 45 | new_softmax_scores[:, mapping[k]] = softmax_scores[:,k] 46 | 47 | # Renormalize each row to sum to 1 48 | new_softmax_scores = new_softmax_scores / np.expand_dims(np.sum(new_softmax_scores, axis=1), axis=1) 49 | 50 | return new_softmax_scores, labels 51 | 52 | 53 | 54 | def load_dataset(dataset, data_folder='data'): 55 | ''' 56 | Load softmax scores and labels for a dataset 57 | 58 | Input: 59 | - dataset: string specifying dataset. 
Options are 'imagenet', 'cifar-100', 'places365', 'inaturalist' 60 | - data_folder: string specifying folder containing the .npz files 61 | 62 | Output: softmax_scores, labels 63 | 64 | ''' 65 | assert dataset in ['imagenet', 'cifar-100', 'places365', 'inaturalist'] 66 | 67 | 68 | data = np.load(f'{data_folder}/{dataset}.npz') 69 | softmax_scores = data['softmax'] 70 | labels = data['labels'] 71 | 72 | return softmax_scores, labels 73 | 74 | 75 | def run_one_experiment(dataset, save_folder, alpha, n_totalcal, score_function_list, methods, seeds, 76 | cluster_args={'frac_clustering':'auto', 'num_clusters':'auto'}, 77 | save_preds=False, calibration_sampling='random', save_labels=False): 78 | ''' 79 | Run experiment and save results 80 | 81 | Inputs: 82 | - dataset: string specifying dataset. Options are 'imagenet', 'cifar-100', 'places365', 'inaturalist' 83 | - n_totalcal: *average* number of examples per class. Calibration dataset is generated by sampling 84 | n_totalcal x num_classes examples uniformly at random 85 | - methods: List of conformal calibration methods. Options are 'standard', 'classwise', 86 | 'classwise_default_standard', 'cluster_proportional', 'cluster_doubledip', 'cluster_random' 87 | - cluster_args: Dict of arguments to be passed into cluster_random 88 | - save_preds: if True, the val prediction sets are included in the saved outputs 89 | - calibration_sampling: Method for sampling calibration dataset. Options are 90 | 'random' or 'balanced' 91 | - save_labels: If True, save the labels for each random seed in {save_folder}seed={seed}_labels.npy 92 | ''' 93 | np.random.seed(0) 94 | 95 | softmax_scores, labels = load_dataset(dataset) 96 | 97 | for score_function in score_function_list: 98 | curr_folder = os.path.join(save_folder, f'{dataset}/{calibration_sampling}_calset/n_totalcal={n_totalcal}/score={score_function}') 99 | os.makedirs(curr_folder, exist_ok=True) 100 | 101 | print(f'====== score_function={score_function} ======') 102 | 103 | print('Computing conformal score...') 104 | if score_function == 'softmax': 105 | scores_all = 1 - softmax_scores 106 | elif score_function == 'APS': 107 | scores_all = get_APS_scores_all(softmax_scores, randomize=True) 108 | elif score_function == 'RAPS': 109 | # RAPS hyperparameters (currently using ImageNet defaults) 110 | lmbda = .01 111 | kreg = 5 112 | 113 | scores_all = get_RAPS_scores_all(softmax_scores, lmbda, kreg, randomize=True) 114 | else: 115 | raise Exception('Undefined score function') 116 | 117 | for seed in seeds: 118 | print(f'\nseed={seed}') 119 | save_to = os.path.join(curr_folder, f'seed={seed}_allresults.pkl') 120 | if os.path.exists(save_to): 121 | with open(save_to,'rb') as f: 122 | all_results = pickle.load(f) 123 | print('Loaded existing results file containing results for', list(all_results.keys())) 124 | else: 125 | all_results = {} # Each value is (qhat(s), preds, coverage_metrics, set_size_metrics) 126 | 127 | # Split data 128 | if calibration_sampling == 'random': 129 | totalcal_scores_all, totalcal_labels, val_scores_all, val_labels = random_split(scores_all, 130 | labels, 131 | n_totalcal, 132 | seed=seed) 133 | elif calibration_sampling == 'balanced': 134 | num_classes = scores_all.shape[1] 135 | totalcal_scores_all, totalcal_labels, val_scores_all, val_labels = split_X_and_y(scores_all, 136 | labels, n_totalcal, num_classes, 137 | seed=seed, split='balanced') 138 | else: 139 | raise Exception('Invalid calibration_sampling option') 140 | 141 | # Inspect class imbalance of total calibration set 142 | 
cts = Counter(totalcal_labels).values() 143 | print(f'Class counts range from {min(cts)} to {max(cts)}') 144 | 145 | for method in methods: 146 | print(f'----- dataset={dataset}, n={n_totalcal}, score_function={score_function}, seed={seed}, method={method} ----- ') 147 | 148 | if method == 'standard': 149 | # Standard conformal 150 | all_results[method] = standard_conformal(totalcal_scores_all, totalcal_labels, 151 | val_scores_all, val_labels, alpha) 152 | 153 | elif method == 'classwise': 154 | # Classwise conformal 155 | all_results[method] = classwise_conformal(totalcal_scores_all, totalcal_labels, 156 | val_scores_all, val_labels, alpha, 157 | num_classes=totalcal_scores_all.shape[1], 158 | default_qhat=np.inf, regularize=False) 159 | 160 | elif method == 'classwise_default_standard': 161 | # Classwise conformal, but use standard qhat as default value instead of infinity 162 | all_results[method] = classwise_conformal(totalcal_scores_all, totalcal_labels, 163 | val_scores_all, val_labels, alpha, 164 | num_classes=totalcal_scores_all.shape[1], 165 | default_qhat='standard', regularize=False) 166 | elif method == 'classwise_default_max': 167 | # Classwise conformal, but use largest conformal score in calibration dataset for each y 168 | # as default value instead of infinity 169 | all_results[method] = classwise_conformal(totalcal_scores_all, totalcal_labels, 170 | val_scores_all, val_labels, alpha, 171 | num_classes=totalcal_scores_all.shape[1], 172 | default_qhat='max', regularize=False) 173 | 174 | elif method == 'cluster_proportional': 175 | # Clustered conformal with proportionally sampled clustering set 176 | all_results[method] = clustered_conformal(totalcal_scores_all, totalcal_labels, 177 | alpha, 178 | val_scores_all, val_labels, 179 | split='proportional') 180 | 181 | elif method == 'cluster_doubledip': 182 | # Clustered conformal with double dipping for clustering and calibration 183 | all_results[method] = clustered_conformal(totalcal_scores_all, totalcal_labels, 184 | alpha, 185 | val_scores_all, val_labels, 186 | split='doubledip') 187 | 188 | elif method == 'cluster_random': 189 | # [RECOMMENDED] Clustered conformal with a randomly sampled clustering set 190 | all_results[method] = clustered_conformal(totalcal_scores_all, totalcal_labels, 191 | alpha, 192 | val_scores_all, val_labels, 193 | frac_clustering=cluster_args['frac_clustering'], 194 | num_clusters=cluster_args['num_clusters'], 195 | split='random') 196 | elif method == 'regularized_classwise': 197 | # Empirical-Bayes-inspired regularized classwise conformal (shrink class qhats to standard) 198 | all_results[method] = classwise_conformal(totalcal_scores_all, totalcal_labels, 199 | val_scores_all, val_labels, alpha, 200 | num_classes=totalcal_scores_all.shape[1], 201 | default_qhat='standard', regularize=True) 202 | 203 | elif method == 'exact_coverage_standard': 204 | # Apply randomization to qhat to achieve exact coverage 205 | all_results[method] = standard_conformal(totalcal_scores_all, totalcal_labels, 206 | val_scores_all, val_labels, alpha, 207 | exact_coverage=True) 208 | 209 | elif method == 'exact_coverage_classwise': 210 | # Apply randomization to qhats to achieve exact coverage 211 | all_results[method] = classwise_conformal(totalcal_scores_all, totalcal_labels, 212 | val_scores_all, val_labels, alpha, 213 | num_classes=totalcal_scores_all.shape[1], 214 | default_qhat=np.inf, regularize=False, 215 | exact_coverage=True) 216 | 217 | 218 | elif method == 'exact_coverage_cluster': 219 | # 
Apply randomization to qhats to achieve exact coverage 220 | all_results[method] = clustered_conformal(totalcal_scores_all, totalcal_labels, 221 | alpha, 222 | val_scores_all, val_labels, 223 | frac_clustering=cluster_args['frac_clustering'], 224 | num_clusters=cluster_args['num_clusters'], 225 | split='random', 226 | exact_coverage=True) 227 | 228 | else: 229 | raise Exception('Invalid method selected') 230 | 231 | # Optionally remove predictions from saved output to reduce memory usage 232 | if not save_preds: 233 | for m in all_results.keys(): 234 | all_results[m] = (all_results[m][0], None, all_results[m][2], all_results[m][3]) 235 | 236 | # Optionally save val labels 237 | if save_labels: 238 | save_labels_to = os.path.join(curr_folder, f'seed={seed}_labels.npy') 239 | np.save(save_labels_to, val_labels) 240 | print(f'Saved labels to {save_labels_to}') 241 | 242 | # Save results 243 | with open(save_to,'wb') as f: 244 | pickle.dump(all_results, f) 245 | print(f'Saved results to {save_to}') 246 | 247 | # Helper function 248 | def initialize_metrics_dict(methods): 249 | 250 | metrics = {} 251 | for method in methods: 252 | metrics[method] = {'class_cov_gap': [], 253 | 'max_class_cov_gap': [], 254 | 'avg_set_size': [], 255 | 'marginal_cov': [], 256 | 'very_undercovered': [], 257 | 'undercov_gap': [], 258 | 'overcov_gap': []} # Could also retrieve other metrics 259 | 260 | return metrics 261 | 262 | # Original version, without undercov_gap and overcov_gap 263 | # def average_results_across_seeds(folder, print_results=True, display_table=True, show_seed_ct=False, 264 | # methods=['standard', 'classwise', 'cluster_balanced'], 265 | # max_seeds=np.inf): 266 | # ''' 267 | # Input: 268 | # - max_seeds: If we discover more than max_seeds random seeds, only use max_seeds of them 269 | # ''' 270 | 271 | 272 | # file_names = sorted(glob.glob(os.path.join(folder, '*.pkl'))) 273 | # num_seeds = len(file_names) 274 | # if show_seed_ct: 275 | # print('Number of seeds found:', num_seeds) 276 | # if max_seeds < np.inf and num_seeds > max_seeds: 277 | # print(f'Only using {max_seeds} seeds') 278 | # file_names = file_names[:max_seeds] 279 | 280 | # metrics = initialize_metrics_dict(methods) 281 | 282 | # for pth in file_names: 283 | # with open(pth, 'rb') as f: 284 | # results = pickle.load(f) 285 | 286 | # for method in methods: 287 | # try: 288 | # metrics[method]['class_cov_gap'].append(results[method][2]['mean_class_cov_gap']) 289 | # metrics[method]['avg_set_size'].append(results[method][3]['mean']) 290 | # metrics[method]['max_class_cov_gap'].append(results[method][2]['max_gap']) 291 | # metrics[method]['marginal_cov'].append(results[method][2]['marginal_cov']) 292 | # metrics[method]['very_undercovered'].append(results[method][2]['very_undercovered']) 293 | # except: 294 | # print(f'Missing {method} in {pth}') 295 | 296 | # # print(folder) 297 | # # for method in methods: 298 | # # print(method, metrics[method]['class_cov_gap']) 299 | 300 | # cov_means = [] 301 | # cov_ses = [] 302 | # set_size_means = [] 303 | # set_size_ses = [] 304 | # max_cov_gap_means = [] 305 | # max_cov_gap_ses = [] 306 | # marginal_cov_means = [] 307 | # marginal_cov_ses = [] 308 | # very_undercovered_means = [] 309 | # very_undercovered_ses = [] 310 | 311 | # if print_results: 312 | # print('Avg class coverage gap for each random seed:') 313 | # for method in methods: 314 | # n = num_seeds 315 | # if print_results: 316 | # print(f' {method}:', np.array(metrics[method]['class_cov_gap'])*100) 317 | # 
cov_means.append(np.mean(metrics[method]['class_cov_gap'])) 318 | # cov_ses.append(np.std(metrics[method]['class_cov_gap'])/np.sqrt(n)) 319 | 320 | # set_size_means.append(np.mean(metrics[method]['avg_set_size'])) 321 | # set_size_ses.append(np.std(metrics[method]['avg_set_size'])/np.sqrt(n)) 322 | 323 | # max_cov_gap_means.append(np.mean(metrics[method]['max_class_cov_gap'])) 324 | # max_cov_gap_ses.append(np.std(metrics[method]['max_class_cov_gap'])/np.sqrt(n)) 325 | 326 | # marginal_cov_means.append(np.mean(metrics[method]['marginal_cov'])) 327 | # marginal_cov_ses.append(np.std(metrics[method]['marginal_cov'])/np.sqrt(n)) 328 | 329 | # very_undercovered_means.append(np.mean(metrics[method]['very_undercovered'])) 330 | # very_undercovered_ses.append(np.std(metrics[method]['very_undercovered'])/np.sqrt(n)) 331 | 332 | # df = pd.DataFrame({'method': methods, 333 | # 'class_cov_gap_mean': np.array(cov_means)*100, 334 | # 'class_cov_gap_se': np.array(cov_ses)*100, 335 | # 'max_class_cov_gap_mean': np.array(max_cov_gap_means)*100, 336 | # 'max_class_cov_gap_se': np.array(max_cov_gap_ses)*100, 337 | # 'avg_set_size_mean': set_size_means, 338 | # 'avg_set_size_se': set_size_ses, 339 | # 'marginal_cov_mean': marginal_cov_means, 340 | # 'marginal_cov_se': marginal_cov_ses, 341 | # 'very_undercovered_mean': very_undercovered_means, 342 | # 'very_undercovered_se': very_undercovered_ses}) 343 | 344 | # if display_table: 345 | # display(df) # For Jupyter notebooks 346 | 347 | # return df 348 | 349 | def average_results_across_seeds(folder, print_results=True, display_table=True, show_seed_ct=False, 350 | methods=['standard', 'classwise', 'cluster_balanced'], 351 | max_seeds=np.inf): 352 | ''' 353 | Input: 354 | - max_seeds: If we discover more than max_seeds random seeds, only use max_seeds of them 355 | ''' 356 | 357 | 358 | file_names = sorted(glob.glob(os.path.join(folder, '*.pkl'))) 359 | num_seeds = len(file_names) 360 | if show_seed_ct: 361 | print('Number of seeds found:', num_seeds) 362 | if max_seeds < np.inf and num_seeds > max_seeds: 363 | print(f'Only using {max_seeds} seeds') 364 | file_names = file_names[:max_seeds] 365 | 366 | metrics = initialize_metrics_dict(methods) 367 | 368 | for pth in file_names: 369 | with open(pth, 'rb') as f: 370 | results = pickle.load(f) 371 | 372 | for method in methods: 373 | try: 374 | metrics[method]['class_cov_gap'].append(results[method][2]['mean_class_cov_gap']) 375 | metrics[method]['avg_set_size'].append(results[method][3]['mean']) 376 | metrics[method]['max_class_cov_gap'].append(results[method][2]['max_gap']) 377 | metrics[method]['marginal_cov'].append(results[method][2]['marginal_cov']) 378 | metrics[method]['very_undercovered'].append(results[method][2]['very_undercovered']) 379 | metrics[method]['undercov_gap'].append(results[method][2]['undercov_gap']) # ADDED 380 | metrics[method]['overcov_gap'].append(results[method][2]['overcov_gap']) # ADDED 381 | except: 382 | print(f'Missing {method} in {pth}') 383 | 384 | # print(folder) 385 | # for method in methods: 386 | # print(method, metrics[method]['class_cov_gap']) 387 | 388 | cov_means = [] 389 | cov_ses = [] 390 | set_size_means = [] 391 | set_size_ses = [] 392 | max_cov_gap_means = [] 393 | max_cov_gap_ses = [] 394 | marginal_cov_means = [] 395 | marginal_cov_ses = [] 396 | very_undercovered_means = [] 397 | very_undercovered_ses = [] 398 | undercov_means = [] 399 | undercov_ses = [] 400 | overcov_means = [] 401 | overcov_ses = [] 402 | 403 | if print_results: 404 | print('Avg class 
coverage gap for each random seed:') 405 | for method in methods: 406 | n = len(file_names) # number of seeds actually used (num_seeds may exceed max_seeds) 407 | if print_results: 408 | print(f' {method}:', np.array(metrics[method]['class_cov_gap'])*100) 409 | cov_means.append(np.mean(metrics[method]['class_cov_gap'])) 410 | cov_ses.append(np.std(metrics[method]['class_cov_gap'])/np.sqrt(n)) 411 | 412 | set_size_means.append(np.mean(metrics[method]['avg_set_size'])) 413 | set_size_ses.append(np.std(metrics[method]['avg_set_size'])/np.sqrt(n)) 414 | 415 | max_cov_gap_means.append(np.mean(metrics[method]['max_class_cov_gap'])) 416 | max_cov_gap_ses.append(np.std(metrics[method]['max_class_cov_gap'])/np.sqrt(n)) 417 | 418 | marginal_cov_means.append(np.mean(metrics[method]['marginal_cov'])) 419 | marginal_cov_ses.append(np.std(metrics[method]['marginal_cov'])/np.sqrt(n)) 420 | 421 | very_undercovered_means.append(np.mean(metrics[method]['very_undercovered'])) 422 | very_undercovered_ses.append(np.std(metrics[method]['very_undercovered'])/np.sqrt(n)) 423 | 424 | undercov_means.append(np.mean(metrics[method]['undercov_gap'])) 425 | undercov_ses.append(np.std(metrics[method]['undercov_gap'])/np.sqrt(n)) 426 | 427 | overcov_means.append(np.mean(metrics[method]['overcov_gap'])) 428 | overcov_ses.append(np.std(metrics[method]['overcov_gap'])/np.sqrt(n)) 429 | 430 | df = pd.DataFrame({'method': methods, 431 | 'class_cov_gap_mean': np.array(cov_means)*100, 432 | 'class_cov_gap_se': np.array(cov_ses)*100, 433 | 'max_class_cov_gap_mean': np.array(max_cov_gap_means)*100, 434 | 'max_class_cov_gap_se': np.array(max_cov_gap_ses)*100, 435 | 'avg_set_size_mean': set_size_means, 436 | 'avg_set_size_se': set_size_ses, 437 | 'marginal_cov_mean': marginal_cov_means, 438 | 'marginal_cov_se': marginal_cov_ses, 439 | 'very_undercovered_mean': very_undercovered_means, 440 | 'very_undercovered_se': very_undercovered_ses, 441 | 'undercov_gap_mean': np.array(undercov_means)*100, 442 | 'undercov_gap_se': np.array(undercov_ses)*100, 443 | 'overcov_gap_mean': np.array(overcov_means)*100, 444 | 'overcov_gap_se': np.array(overcov_ses)*100}) 445 | 446 | if display_table: 447 | display(df) # For Jupyter notebooks 448 | 449 | return df 450 | 451 | # Helper function for get_metric_df 452 | def initialize_dict(metrics, methods, suffixes=['mean', 'se']): 453 | d = {} 454 | for suffix in suffixes: 455 | for metric in metrics: 456 | d[f'{metric}_{suffix}'] = {} 457 | 458 | for method in methods: 459 | 460 | d[f'{metric}_{suffix}'][method] = [] 461 | 462 | 463 | return d 464 | 465 | def get_metric_df(dataset, cal_sampling, metric, 466 | score_function, 467 | method_list = ['standard', 'classwise', 'cluster_random'], 468 | n_list = [10, 20, 30, 40, 50, 75, 100, 150], 469 | show_seed_ct=False, 470 | print_folder=True, 471 | save_folder='../.cache/paper/varying_n'): # May have to update this path 472 | ''' 473 | Similar to average_results_across_seeds 474 | ''' 475 | 476 | aggregated_results = initialize_dict([metric], method_list) 477 | 478 | for n_totalcal in n_list: 479 | 480 | curr_folder = f'{save_folder}/{dataset}/{cal_sampling}_calset/n_totalcal={n_totalcal}/score={score_function}' 481 | if print_folder: 482 | print(curr_folder) 483 | 484 | df = average_results_across_seeds(curr_folder, print_results=False, 485 | display_table=False, methods=method_list, max_seeds=10, 486 | show_seed_ct=show_seed_ct) 487 | 488 | for method in method_list: 489 | 490 | for suffix in ['mean', 'se']: # Extract mean and SE 491 | 492 | 
# Not used in paper
def plot_class_coverage_histogram(folder, desired_cov=None, vmin=.6, vmax=1, nbins=30,
                                  title=None, methods=['standard', 'classwise', 'cluster_random']):
    '''
    For each method, aggregate class coverages across all random seeds and then
    plot the histogram. This is equivalent to estimating a density for each
    random seed individually and then averaging.

    Inputs:
        - folder: (str) path to folder of saved results
        - desired_cov: (float) desired coverage level
        - vmin, vmax: (floats) specify bin edges
        - nbins: (int) number of histogram bins
        - title: (str) plot title
        - methods: (list of str) methods to plot
    '''
    sns.set_style(style='white', rc={'axes.spines.right': False, 'axes.spines.top': False})
    sns.set_palette('pastel')
    sns.set_context('talk')  # 'paper', 'talk', 'poster'

    # For plotting
    map_to_label = {'standard': 'Standard',
                    'classwise': 'Classwise',
                    'cluster_random': 'Clustered'}
    map_to_color = {'standard': 'gray',
                    'classwise': 'lightcoral',
                    'cluster_random': 'dodgerblue'}

    bin_edges = np.linspace(vmin, vmax, nbins+1)

    file_names = sorted(glob.glob(os.path.join(folder, '*.pkl')))
    num_seeds = len(file_names)
    print('Number of seeds found:', num_seeds)

    # OPTION 1: Plot average with 95% CIs
    cts_dict = {}
    for method in methods:
        cts_dict[method] = np.zeros((num_seeds, nbins))

    for i, pth in enumerate(file_names):
        with open(pth, 'rb') as f:
            results = pickle.load(f)

        for method in methods:
            cts, _ = np.histogram(results[method][2]['raw_class_coverages'], bins=bin_edges)
            cts_dict[method][i,:] = cts

    for method in methods:
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        graph = sns.lineplot(x=np.tile(bin_centers, num_seeds), y=cts_dict[method].flatten(),
                             label=map_to_label[method], color=map_to_color[method])

    if desired_cov is not None:
        graph.axvline(desired_cov, color='black', linestyle='dashed', label='Desired coverage')

    plt.xlabel('Class-conditional coverage')
    plt.ylabel('Number of classes')
    plt.title(title)
    plt.ylim(bottom=0)
    plt.xlim(right=vmax)
    plt.legend()
    plt.show()

    # OPTION 2: Plot average, no CIs
    # class_coverages = {}
    # for method in methods:
    #     class_coverages[method] = []

    # for pth in file_names:
    #     with open(pth, 'rb') as f:
    #         results = pickle.load(f)

    #     for method in methods:
    #         class_coverages[method].append(results[method][2]['raw_class_coverages'])

    # bin_edges = np.linspace(vmin, vmax, 30)  # Can adjust

    # for method in methods:
    #     aggregated_scores = np.concatenate(class_coverages[method], axis=0)
    #     cts, _ = np.histogram(aggregated_scores, bins=bin_edges, density=False)
    #     cts = cts / num_seeds
    #     plt.plot((bin_edges[:-1] + bin_edges[1:]) / 2, cts, '-o', label=method, alpha=0.7)

    # plt.xlabel('Class-conditional coverage')
    # plt.ylabel('Number of classes')
    # plt.legend()

    # OPTION 3: Plot separate lines
    # class_coverages = {}
    # for method in methods:
    #     class_coverages[method] = []

    # for pth in file_names:
    #     with open(pth, 'rb') as f:
    #         results = pickle.load(f)

    #     for method in methods:
    #         class_coverages[method].append(results[method][2]['raw_class_coverages'])

    # bin_edges = np.linspace(vmin, vmax, 30)  # Can adjust

    # for method in methods:
    #     for class_covs in class_coverages[method]:
    #         cts, _ = np.histogram(class_covs, bins=bin_edges, density=False)
    #         plt.plot((bin_edges[:-1] + bin_edges[1:]) / 2, cts, '-', alpha=0.3,
    #                  label=map_to_label[method], color=map_to_color[method])

    # plt.xlabel('Class-conditional coverage')
    # plt.ylabel('Number of classes')
    # plt.legend()
    # plt.show()

# For square-root scaling in plots
import matplotlib.scale as mscale
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import matplotlib.ticker as ticker
import numpy as np

class SquareRootScale(mscale.ScaleBase):
    """
    ScaleBase class for generating square root scale.
    """

    name = 'squareroot'

    def __init__(self, axis, **kwargs):
        # In matplotlib >= 3.1, ScaleBase.__init__ also takes the axis being scaled
        # (older versions used mscale.ScaleBase.__init__(self))
        mscale.ScaleBase.__init__(self, axis)

    def set_default_locators_and_formatters(self, axis):
        axis.set_major_locator(ticker.AutoLocator())
        axis.set_major_formatter(ticker.ScalarFormatter())
        axis.set_minor_locator(ticker.NullLocator())
        axis.set_minor_formatter(ticker.NullFormatter())

    def limit_range_for_scale(self, vmin, vmax, minpos):
        # The square root is only defined for nonnegative values
        return max(0., vmin), vmax

    class SquareRootTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def transform_non_affine(self, a):
            return np.array(a)**0.5

        def inverted(self):
            return SquareRootScale.InvertedSquareRootTransform()

    class InvertedSquareRootTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def transform_non_affine(self, a):
            return np.array(a)**2

        def inverted(self):
            return SquareRootScale.SquareRootTransform()

    def get_transform(self):
        return self.SquareRootTransform()

mscale.register_scale(SquareRootScale)
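
# Example: once registered, the scale can be selected by name on any axis
# (a minimal sketch with made-up data):
#
#   fig, ax = plt.subplots()
#   ax.plot(np.arange(100), np.arange(100))
#   ax.set_yscale('squareroot')
#   plt.show()
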
--------------------------------------------------------------------------------