├── .gitmodules ├── LICENSE ├── README.md ├── environment.yml ├── helpers ├── data_helpers.py ├── decisionlayer_helpers.py ├── feature_helpers.py ├── nlp_helpers.py └── vis_helpers.py ├── inspect_language_models.ipynb ├── inspect_vision_models.ipynb ├── language ├── datasets.py ├── jigsaw_loaders.py └── models.py ├── main.py ├── pipeline.png ├── pipeline_600x400.png └── requirements.txt /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "lucent"] 2 | path = lucent 3 | url = https://github.com/greentfrapp/lucent.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Eric Wong & Shibani Santurkar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Deep Network Debuggability via Sparse Decision Layers 2 | 3 | This repository contains the code for our paper: 4 | 5 | **Leveraging Sparse Linear Layers for Debuggable Deep Networks**
6 | *Eric Wong\*, Shibani Santurkar\*, Aleksander Madry*
7 | Paper: http://arxiv.org/abs/2105.04857
8 | Blog posts: [Part 1](https://gradientscience.org/glm_saga) and [Part 2](https://gradientscience.org/debugging)<br>
9 | 10 | 11 | ![Pipeline overview](pipeline.png) 12 |
13 | 14 | 15 | ```bibtex 16 | @article{wong2021leveraging, 17 | title={Leveraging Sparse Linear Layers for Debuggable Deep Networks}, 18 | author={Wong, Eric and Santurkar, Shibani and M{\k{a}}dry, Aleksander}, 19 | journal={arXiv preprint arXiv:2105.04857}, 20 | year={2021} 21 | } 22 | ``` 23 | 24 | ## Getting started 25 | *Our code relies on the [MadryLab](http://madry-lab.ml/) public [`robustness`](https://github.com/MadryLab/robustness) library, as well as the [`glm_saga`](https://github.com/MadryLab/glm_saga) library which will be automatically installed when you follow the instructions below. The [`glm_saga`](https://github.com/MadryLab/glm_saga) library contains a standalone implementation of our sparse GLM solver.* 26 | 1. Clone our repo: `git clone https://github.com/microsoft/DebuggableDeepNetworks.git` 27 | 28 | 2. Setup the [lucent](https://github.com/greentfrapp/lucent) submodule using: `git submodule update --init --recursive` 29 | 30 | 3. We recommend using conda for dependencies: 31 | ``` 32 | conda env create -f environment.yml 33 | conda activate debuggable 34 | ``` 35 | 36 | ## Training sparse decision layers 37 | 38 | Contents: 39 | + `main.py` fits a sparse decision layer on top of the deep features of the specified pre-trained (language/vision) deep network 40 | + `helpers/` has some helper functions for loading datasets, models, and features 41 | + `language/` has some additional code for handling language models and datasets 42 | 43 | To run the settings in our paper, you can use the following commands: 44 | ``` 45 | # Sentiment classification 46 | python main.py --dataset sst --dataset-path --dataset-type language --model-path barissayil/bert-sentiment-analysis-sst --arch bert --out-path ./tmp/sst/ --cache 47 | 48 | # Toxic comment classification (biased) 49 | python main.py --dataset jigsaw-toxic --dataset-path --dataset-type language --model-path unitary/toxic-bert --arch bert --out-path ./tmp/jigsaw-toxic/ --cache --balance 50 | 51 | # Toxic comment classification (unbiased) 52 | python main.py --dataset jigsaw-alt-toxic --dataset-path --dataset-type language --model-path unitary/unbiased-toxic-roberta --arch roberta --out-path ./tmp/unbiased-jigsaw-toxic/ --cache --balance 53 | 54 | # Places-10 55 | python main.py --dataset places-10 --dataset-path --dataset-type vision --model-path --arch resnet50 --out-path ./tmp/places/ --cache 56 | 57 | # ImageNet 58 | python main.py --dataset imagenet --dataset-path --dataset-type vision --model-path --arch resnet50 --out-path ./tmp/imagenet/ --cache 59 | ``` 60 | 61 | ## Interpreting deep features 62 | After fitting a sparse GLM with one of the above commands, we provide some 63 | notebooks for inspecting and visualizing the resulting features. See 64 | `inspect_vision_models.ipynb` and `inspect_language_models.ipynb` for the vision and language settings respectively. 
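If you prefer to work outside the notebooks, the fitted regularization path can also be loaded directly with the helpers in this repository. The sketch below is illustrative: it assumes you ran the SST command above with `--cache` and `--out-path ./tmp/sst/`, so that checkpoints live under `./tmp/sst/checkpoint` and cached features under `./tmp/sst/features`; adjust the paths to match your own run.

```python
# Illustrative sketch; paths assume the SST example above (adjust to your --out-path).
from helpers.decisionlayer_helpers import load_glm, select_sparse_model
from helpers.feature_helpers import load_features_mode

# Load the full regularization path (one weight/bias pair per regularization strength).
result = load_glm('./tmp/sst/checkpoint')

# Select a sparser decision layer, trading roughly `factor` points of
# validation accuracy for sparsity (the repo's default criterion).
result = select_sparse_model(result, selection_criterion='absolute', factor=6)
weight, bias = result['weight_sparse'], result['bias_sparse']

# Cached deep features for the test split, plus the normalization statistics.
features, mu, std = load_features_mode('./tmp/sst/features', mode='test')
```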
65 | 66 | # Maintainers 67 | 68 | * [Eric Wong](https://twitter.com/RICEric22) 69 | * [Shibani Santurkar](https://twitter.com/ShibaniSan) 70 | * [Aleksander Madry](https://twitter.com/aleks_madry) 71 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: debuggable 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - backcall=0.2.0=py_0 9 | - blas=1.0=mkl 10 | - ca-certificates=2021.4.13=h06a4308_1 11 | - certifi=2020.12.5=py37h06a4308_0 12 | - cudatoolkit=10.2.89=hfd86e86_1 13 | - cycler=0.10.0=py_2 14 | - dbus=1.13.18=hb2f20db_0 15 | - decorator=4.4.2=py_0 16 | - expat=2.2.10=he6710b0_2 17 | - fontconfig=2.13.0=h9420a91_0 18 | - freetype=2.10.4=h5ab3b9f_0 19 | - glib=2.66.1=h92f7085_0 20 | - gst-plugins-base=1.14.0=hbbd80ab_1 21 | - gstreamer=1.14.0=hb31296c_0 22 | - icu=58.2=he6710b0_3 23 | - intel-openmp=2020.2=254 24 | - ipykernel=5.3.4=py37h5ca1d4c_0 25 | - ipython=7.18.1=py37h5ca1d4c_0 26 | - ipython_genutils=0.2.0=py37_0 27 | - jedi=0.17.2=py37_0 28 | - jpeg=9b=habf39ab_1 29 | - jupyter_client=6.1.7=py_0 30 | - jupyter_core=4.6.3=py37_0 31 | - kiwisolver=1.3.1=py37hc928c03_0 32 | - lcms2=2.11=h396b838_0 33 | - ld_impl_linux-64=2.33.1=h53a641e_7 34 | - libedit=3.1.20191231=h14c3975_1 35 | - libffi=3.3=he6710b0_2 36 | - libgcc-ng=9.1.0=hdf63c60_0 37 | - libgfortran-ng=7.3.0=hdf63c60_0 38 | - libpng=1.6.37=hbc83047_0 39 | - libprotobuf=3.13.0.1=h8b12597_0 40 | - libsodium=1.0.18=h7b6447c_0 41 | - libstdcxx-ng=9.1.0=hdf63c60_0 42 | - libtiff=4.1.0=h2733197_1 43 | - libuuid=1.0.3=h1bed415_2 44 | - libuv=1.40.0=h7b6447c_0 45 | - libxcb=1.14=h7b6447c_0 46 | - libxml2=2.9.10=hb55368b_3 47 | - lz4-c=1.9.2=heb0550a_3 48 | - matplotlib=3.3.2=h06a4308_0 49 | - matplotlib-base=3.3.2=py37h817c723_0 50 | - mkl=2019.4=243 51 | - mkl-service=2.3.0=py37he904b0f_0 52 | - mkl_fft=1.2.0=py37h23d657b_0 53 | - mkl_random=1.0.4=py37hd81dba3_0 54 | - ncurses=6.2=he6710b0_1 55 | - ninja=1.10.1=py37hfd86e86_0 56 | - numpy=1.19.1=py37hbc911f0_0 57 | - numpy-base=1.19.1=py37hfa32c7d_0 58 | - olefile=0.46=py37_0 59 | - openssl=1.1.1k=h27cfd23_0 60 | - pandas=1.1.3=py37he6710b0_0 61 | - parso=0.7.0=py_0 62 | - pcre=8.44=he6710b0_0 63 | - pexpect=4.8.0=py37_1 64 | - pickleshare=0.7.5=py37_1001 65 | - pillow=8.0.0=py37h9a89aac_0 66 | - pip=20.2.4=py37_0 67 | - prompt-toolkit=3.0.8=py_0 68 | - ptyprocess=0.6.0=py37_0 69 | - pygments=2.7.1=py_0 70 | - pyparsing=2.4.7=pyh9f0ad1d_0 71 | - pyqt=5.9.2=py37h05f1152_2 72 | - python=3.7.9=h7579374_0 73 | - python-dateutil=2.8.1=py_0 74 | - python_abi=3.7=1_cp37m 75 | - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0 76 | - pytz=2020.1=py_0 77 | - pyzmq=19.0.2=py37he6710b0_1 78 | - qt=5.9.7=h5867ecd_1 79 | - readline=8.0=h7b6447c_0 80 | - seaborn=0.11.0=py_0 81 | - setuptools=50.3.0=py37hb0f4dca_1 82 | - sip=4.19.8=py37hf484d3e_0 83 | - six=1.15.0=py_0 84 | - sqlite=3.33.0=h62c20be_0 85 | - tensorboardx=2.1=py_0 86 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 87 | - tk=8.6.10=hbc83047_0 88 | - torchvision=0.8.1=py37_cu102 89 | - tornado=6.0.4=py37h7b6447c_1 90 | - traitlets=5.0.5=py_0 91 | - typing_extensions=3.7.4.3=py_0 92 | - wcwidth=0.2.5=py_0 93 | - wheel=0.35.1=py_0 94 | - wordcloud=1.8.1=py37h4abf009_1 95 | - xz=5.2.5=h7b6447c_0 96 | - zeromq=4.3.3=he6710b0_3 97 | - zlib=1.2.11=h7b6447c_3 98 | - zstd=1.4.4=h0b5b093_3 99 | - pip: 100 | - argon2-cffi==20.1.0 101 | - async-generator==1.10 102 | - 
attrs==20.3.0 103 | - bleach==3.3.0 104 | - cffi==1.14.5 105 | - chardet==3.0.4 106 | - click==7.1.2 107 | - coverage==5.5 108 | - coveralls==3.0.1 109 | - cox==0.1.post3 110 | - dataclasses==0.6 111 | - datasets==1.1.3 112 | - defusedxml==0.7.1 113 | - dill==0.3.3 114 | - docopt==0.6.2 115 | - entrypoints==0.3 116 | - filelock==3.0.12 117 | - future==0.18.2 118 | - gitdb==4.0.5 119 | - gitpython==3.1.11 120 | - glm-saga==0.1.1 121 | - grpcio==1.34.0 122 | - idna==2.10 123 | - imageio==2.9.0 124 | - importlib-metadata==3.7.3 125 | - iniconfig==1.1.1 126 | - ipywidgets==7.6.3 127 | - jinja2==2.11.3 128 | - joblib==0.17.0 129 | - jsonschema==3.2.0 130 | - jupyterlab-pygments==0.1.2 131 | - jupyterlab-widgets==1.0.0 132 | - kornia==0.4.1 133 | - lime==0.2.0.1 134 | - markupsafe==1.1.1 135 | - mistune==0.8.4 136 | - multiprocess==0.70.11.1 137 | - nbclient==0.5.3 138 | - nbconvert==6.0.7 139 | - nbformat==5.1.2 140 | - nest-asyncio==1.5.1 141 | - networkx==2.5 142 | - notebook==6.2.0 143 | - packaging==20.7 144 | - pandocfilters==1.4.3 145 | - pluggy==0.13.1 146 | - prometheus-client==0.9.0 147 | - protobuf==3.14.0 148 | - psutil==5.7.3 149 | - py==1.10.0 150 | - py3nvml==0.2.6 151 | - pyarrow==2.0.0 152 | - pycparser==2.20 153 | - pyrsistent==0.17.3 154 | - pytest==6.2.4 155 | - pytest-mock==3.6.1 156 | - pywavelets==1.1.1 157 | - regex==2020.11.13 158 | - requests==2.25.0 159 | - robustness 160 | - sacremoses==0.0.43 161 | - scikit-image==0.17.2 162 | - scikit-learn==0.23.2 163 | - scipy==1.5.4 164 | - send2trash==1.5.0 165 | - sentencepiece==0.1.91 166 | - smmap==3.0.4 167 | - terminado==0.9.3 168 | - testpath==0.4.4 169 | - tifffile==2020.11.26 170 | - tokenizers==0.9.3 171 | - toml==0.10.2 172 | - torch-lucent==0.1.8 173 | - tqdm==4.49.0 174 | - transformers==3.5.1 175 | - urllib3==1.26.2 176 | - webencodings==0.5.1 177 | - widgetsnbextension==3.5.1 178 | - xmltodict==0.12.0 179 | - xxhash==2.0.0 180 | - zipp==3.4.1 181 | -------------------------------------------------------------------------------- /helpers/data_helpers.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('..') 3 | import numpy as np 4 | import torch as ch 5 | from torch.utils.data import TensorDataset 6 | from robustness.datasets import DATASETS as VISION_DATASETS 7 | from robustness.tools.label_maps import CLASS_DICT 8 | from language.datasets import DATASETS as LANGUAGE_DATASETS 9 | from language.models import LANGUAGE_MODEL_DICT 10 | from transformers import AutoTokenizer 11 | 12 | def get_label_mapping(dataset_name): 13 | if dataset_name == 'imagenet': 14 | return CLASS_DICT['ImageNet'] 15 | elif dataset_name == 'places-10': 16 | return CD_PLACES 17 | elif dataset_name == 'sst': 18 | return {0: 'negative', 1: 'positive'} 19 | elif 'jigsaw' in dataset_name: 20 | category = dataset_name.split('jigsaw-')[1] if 'alt' not in dataset_name \ 21 | else dataset_name.split('jigsaw-alt-')[1] 22 | return {0: f'not {category}', 1: f'{category}'} 23 | else: 24 | raise ValueError("Dataset not currently supported...") 25 | 26 | def load_dataset(dataset_name, dataset_path, dataset_type, 27 | batch_size, num_workers, 28 | maxlen_train=256, maxlen_val=256, 29 | shuffle=False, model_path=None, return_sentences=False): 30 | 31 | 32 | if dataset_type == 'vision': 33 | if dataset_name == 'places-10': dataset_name = 'places365' 34 | if dataset_name not in VISION_DATASETS: 35 | raise ValueError("Vision dataset not currently supported...") 36 | dataset = 
VISION_DATASETS[dataset_name](os.path.expandvars(dataset_path)) 37 | 38 | if dataset_name == 'places365': 39 | dataset.num_classes = 10 40 | 41 | train_loader, test_loader = dataset.make_loaders(num_workers, 42 | batch_size, 43 | data_aug=False, 44 | shuffle_train=shuffle, 45 | shuffle_val=shuffle) 46 | return dataset, train_loader, test_loader 47 | else: 48 | if model_path is None: 49 | model_path = LANGUAGE_MODEL_DICT[dataset_name] 50 | 51 | tokenizer = AutoTokenizer.from_pretrained(model_path) 52 | 53 | kwargs = {} if 'jigsaw' not in dataset_name else \ 54 | {'label': dataset_name[11:] if 'alt' in dataset_name \ 55 | else dataset_name[7:]} 56 | kwargs['return_sentences'] = return_sentences 57 | train_set = LANGUAGE_DATASETS(dataset_name)(filename=f'{dataset_path}/train.tsv', 58 | maxlen=maxlen_train, 59 | tokenizer=tokenizer, 60 | **kwargs) 61 | test_set = LANGUAGE_DATASETS(dataset_name)(filename=f'{dataset_path}/test.tsv', 62 | maxlen=maxlen_val, 63 | tokenizer=tokenizer, 64 | **kwargs) 65 | train_loader = ch.utils.data.DataLoader(dataset=train_set, 66 | batch_size=batch_size, 67 | num_workers=num_workers) 68 | test_loader = ch.utils.data.DataLoader(dataset=test_set, 69 | batch_size=batch_size, 70 | num_workers=num_workers) 71 | #assert len(np.unique(train_set.df['label'].values)) == len(np.unique(test_set.df['label'].values)) 72 | train_set.num_classes = 2 73 | # train_loader.dataset.targets = train_loader.dataset.df['label'].values 74 | # test_loader.dataset.targets = test_loader.dataset.df['label'].values 75 | 76 | return train_set, train_loader, test_loader 77 | 78 | 79 | class IndexedTensorDataset(ch.utils.data.TensorDataset): 80 | def __getitem__(self, index): 81 | val = super(IndexedTensorDataset, self).__getitem__(index) 82 | return val + (index,) 83 | 84 | class IndexedDataset(ch.utils.data.Dataset): 85 | def __init__(self, ds, sample_weight=None): 86 | super(ch.utils.data.Dataset, self).__init__() 87 | self.dataset = ds 88 | self.sample_weight=sample_weight 89 | 90 | def __getitem__(self, index): 91 | val = self.dataset[index] 92 | if self.sample_weight is None: 93 | return val + (index,) 94 | else: 95 | weight = self.sample_weight[index] 96 | return val + (weight,index) 97 | def __len__(self): 98 | return len(self.dataset) 99 | 100 | def add_index_to_dataloader(loader, sample_weight=None): 101 | return ch.utils.data.DataLoader( 102 | IndexedDataset(loader.dataset, sample_weight=sample_weight), 103 | batch_size=loader.batch_size, 104 | sampler=loader.sampler, 105 | num_workers=loader.num_workers, 106 | collate_fn=loader.collate_fn, 107 | pin_memory=loader.pin_memory, 108 | drop_last=loader.drop_last, 109 | timeout=loader.timeout, 110 | worker_init_fn=loader.worker_init_fn, 111 | multiprocessing_context=loader.multiprocessing_context 112 | ) 113 | 114 | class NormalizedRepresentation(ch.nn.Module): 115 | def __init__(self, loader, metadata, device='cuda', tol=1e-5): 116 | super(NormalizedRepresentation, self).__init__() 117 | 118 | 119 | assert metadata is not None 120 | self.device = device 121 | self.mu = metadata['X']['mean'] 122 | self.sigma = ch.clamp(metadata['X']['std'], tol) 123 | 124 | def forward(self, X): 125 | return (X - self.mu.to(self.device))/self.sigma.to(self.device) 126 | 127 | CD_PLACES = {0: 'airport_terminal', 128 | 1: 'boat_deck', 129 | 2: 'bridge', 130 | 3: 'butchers_shop', 131 | 4: 'church-outdoor', 132 | 5: 'hotel_room', 133 | 6: 'laundromat', 134 | 7: 'river', 135 | 8: 'ski_slope', 136 | 9: 'volcano'} 137 | 
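# Usage sketch (illustrative): how main.py wires up these helpers for the SST setting.
# SST_PATH is a placeholder for a directory containing train.tsv/test.tsv; the batch
# size and worker count are arbitrary example values.
#
#   from helpers.data_helpers import load_dataset, get_label_mapping
#
#   dataset, train_loader, test_loader = load_dataset('sst', SST_PATH, 'language',
#                                                     batch_size=256, num_workers=2)
#   print(get_label_mapping('sst'))   # -> {0: 'negative', 1: 'positive'}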
-------------------------------------------------------------------------------- /helpers/decisionlayer_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch as ch 4 | import pandas as pd 5 | 6 | def load_glm(result_dir): 7 | 8 | Nlambda = max([int(f.split('params')[1].split('.pth')[0]) 9 | for f in os.listdir(result_dir) if 'params' in f]) + 1 10 | 11 | print(f"Loading regularization path of length {Nlambda}") 12 | 13 | params_dict = {i: ch.load(os.path.join(result_dir, f"params{i}.pth"), 14 | map_location=ch.device('cpu')) for i in range(Nlambda)} 15 | 16 | regularization_strengths = [params_dict[i]['lam'].item() for i in range(Nlambda)] 17 | weights = [params_dict[i]['weight'] for i in range(Nlambda)] 18 | biases = [params_dict[i]['bias'] for i in range(Nlambda)] 19 | 20 | metrics = {'acc_tr': [], 'acc_val': [], 'acc_test': []} 21 | 22 | for k in metrics.keys(): 23 | for i in range(Nlambda): 24 | metrics[k].append(params_dict[i]['metrics'][k]) 25 | metrics[k] = 100 * np.stack(metrics[k]) 26 | metrics = pd.DataFrame(metrics) 27 | metrics = metrics.rename(columns={'acc_tr': 'acc_train'}) 28 | 29 | weights_stacked = ch.stack(weights) 30 | sparsity = ch.sum(weights_stacked != 0, dim=2).numpy() 31 | 32 | return {'metrics': metrics, 33 | 'regularization_strengths': regularization_strengths, 34 | 'weights': weights, 35 | 'biases': biases, 36 | 'sparsity': sparsity, 37 | 'weight_dense': weights[-1], 38 | 'bias_dense': biases[-1]} 39 | 40 | def select_sparse_model(result_dict, 41 | selection_criterion='absolute', 42 | factor=6): 43 | 44 | assert selection_criterion in ['sparsity', 'absolute', 'relative', 'percentile'] 45 | 46 | metrics, sparsity = result_dict['metrics'], result_dict['sparsity'] 47 | 48 | acc_val, acc_test = metrics['acc_val'], metrics['acc_test'] 49 | 50 | if factor == 0: 51 | sel_idx = -1 52 | elif selection_criterion == 'sparsity': 53 | sel_idx = np.argmin(np.abs(np.mean(sparsity, axis=1) - factor)) 54 | elif selection_criterion == 'relative': 55 | sel_idx = np.argmin(np.abs(acc_val - factor * np.max(acc_val))) 56 | elif selection_criterion == 'absolute': 57 | delta = acc_val - (np.max(acc_val) - factor) 58 | lidx = np.where(delta <= 0)[0] 59 | sel_idx = lidx[np.argmin(-delta[lidx])] 60 | elif selection_criterion == 'percentile': 61 | diff = np.max(acc_val) - np.min(acc_val) 62 | sel_idx = np.argmax(acc_val > np.max(acc_val) - factor * diff) 63 | 64 | print(f"Test accuracy | Best: {max(acc_test): .2f},", 65 | f"Sparse: {acc_test[sel_idx]:.2f}", 66 | f"Sparsity: {np.mean(sparsity[sel_idx]):.2f}") 67 | 68 | result_dict.update({'weight_sparse': result_dict['weights'][sel_idx], 69 | 'bias_sparse': result_dict['biases'][sel_idx]}) 70 | return result_dict -------------------------------------------------------------------------------- /helpers/feature_helpers.py: -------------------------------------------------------------------------------- 1 | import os, math, sys 2 | sys.path.append('..') 3 | import numpy as np 4 | import torch as ch 5 | from torch._utils import _accumulate 6 | from torch.utils.data import Subset 7 | from tqdm import tqdm 8 | from robustness.model_utils import make_and_restore_model 9 | from robustness.loaders import LambdaLoader 10 | from transformers import AutoConfig 11 | import language.models as lm 12 | 13 | def load_model(model_root, arch, dataset, dataset_name, 14 | dataset_type, device='cuda'): 15 | """Loads existing vision/language models. 
16 | Args: 17 | model_root (str): Path to model 18 | arch (str): Model architecture 19 | dataset (torch dataset): Dataset on which the model was 20 | trained 21 | dataset_name (str): Name of dataset 22 | dataset_type (str): One of vision or language 23 | device (str): Device on which to keep the model 24 | Returns: 25 | model: Torch model 26 | pooled_output (bool): Whether or not to pool outputs 27 | (only relevant for some language models) 28 | """ 29 | 30 | if model_root is None and dataset_type == 'language': 31 | model_root = lm.LANGUAGE_MODEL_DICT[dataset_name] 32 | 33 | pooled_output = None 34 | if dataset_type == 'vision': 35 | model, _ = make_and_restore_model(arch=arch, 36 | dataset=dataset, 37 | resume_path=model_root, 38 | pytorch_pretrained=(model_root is None) 39 | ) 40 | else: 41 | config = AutoConfig.from_pretrained(model_root) 42 | if config.model_type == 'bert': 43 | if model_root == 'barissayil/bert-sentiment-analysis-sst': 44 | model = lm.BertForSentimentClassification.from_pretrained(model_root) 45 | pooled_output = False 46 | else: 47 | model = lm.BertForSequenceClassification.from_pretrained(model_root) 48 | pooled_output = True 49 | elif config.model_type == 'roberta': 50 | model = lm.RobertaForSequenceClassification.from_pretrained(model_root) 51 | pooled_output = False 52 | else: 53 | raise ValueError('This transformer model is not supported yet.') 54 | 55 | model.eval() 56 | model = ch.nn.DataParallel(model.to(device)) 57 | return model, pooled_output 58 | 59 | def get_features_batch(batch, model, dataset_type, pooled_output=None, device='cuda'): 60 | 61 | if dataset_type == 'vision': 62 | ims, targets = batch 63 | (_,latents), _ = model(ims.to(device), with_latent=True) 64 | else: 65 | (input_ids, attention_mask, targets) = batch 66 | mask = targets != -1 67 | input_ids, attention_mask, targets = [t[mask] for t in (input_ids, attention_mask, targets)] 68 | 69 | if hasattr(model, 'module'): 70 | model = model.module 71 | if hasattr(model, "roberta"): 72 | latents = model.roberta(input_ids=input_ids.to(device), 73 | attention_mask=attention_mask.to(device))[0] 74 | latents = model.classifier.dropout(latents[:,0,:]) 75 | latents = model.classifier.dense(latents) 76 | latents = ch.tanh(latents) 77 | latents = model.classifier.dropout(latents) 78 | else: 79 | latents = model.bert(input_ids=input_ids.to(device), 80 | attention_mask=attention_mask.to(device)) 81 | if pooled_output: 82 | latents = latents[1] 83 | else: 84 | latents = latents[0][:,0] 85 | return latents, targets 86 | 87 | def compute_features(loader, model, dataset_type, pooled_output, 88 | batch_size, num_workers, 89 | shuffle=False, device='cuda', 90 | filename=None, chunk_threshold=20000, balance=False): 91 | 92 | """Compute deep features for a given dataset using a modeln and returnss 93 | them as a pytorch dataset and loader. 94 | Args: 95 | loader : Torch data loader 96 | model: Torch model 97 | dataset_type (str): One of vision or language 98 | pooled_output (bool): Whether or not to pool outputs 99 | (only relevant for some language models) 100 | batch_size (int): Batch size for output loader 101 | num_workers (int): Number of workers to use for output loader 102 | shuffle (bool): Whether or not to shuffle output data loaoder 103 | device (str): Device on which to keep the model 104 | filename (str):Optional file to cache computed feature. Recommended 105 | for large datasets like ImageNet. 
106 | chunk_threshold (int): Size of shard while caching 107 | balance (bool): Whether or not to balance output data loader 108 | (only relevant for some language models) 109 | Returns: 110 | feature_dataset: Torch dataset with deep features 111 | feature_loader: Torch data loader with deep features 112 | """ 113 | 114 | if filename is None or not os.path.exists(os.path.join(filename, f'0_features.npy')): 115 | 116 | all_latents, all_targets = [], [] 117 | Nsamples, chunk_id = 0, 0 118 | 119 | for batch_idx, batch in tqdm(enumerate(loader), total=len(loader)): 120 | 121 | with ch.no_grad(): 122 | latents, targets = get_features_batch(batch, model, dataset_type, 123 | pooled_output=pooled_output, 124 | device=device) 125 | 126 | if batch_idx == 0: 127 | print("Latents shape", latents.shape) 128 | 129 | 130 | Nsamples += latents.size(0) 131 | 132 | all_latents.append(latents.cpu()) 133 | all_targets.append(targets.cpu()) 134 | 135 | if filename is not None and Nsamples > chunk_threshold: 136 | if not os.path.exists(filename): os.makedirs(filename) 137 | np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy()) 138 | np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy()) 139 | all_latents, all_targets, Nsamples = [], [], 0 140 | chunk_id += 1 141 | 142 | if filename is not None and Nsamples > 0: 143 | if not os.path.exists(filename): os.makedirs(filename) 144 | np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy()) 145 | np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy()) 146 | 147 | 148 | feature_dataset = load_features(filename) if filename is not None else \ 149 | ch.utils.data.TensorDataset(ch.cat(all_latents), ch.cat(all_targets)) 150 | if balance: 151 | feature_dataset = balance_dataset(feature_dataset) 152 | 153 | feature_loader = ch.utils.data.DataLoader(feature_dataset, 154 | num_workers=num_workers, 155 | batch_size=batch_size, 156 | shuffle=shuffle) 157 | 158 | return feature_dataset, feature_loader 159 | 160 | def balance_dataset(dataset): 161 | """Balances a given dataset to have the same number of samples/class. 162 | Args: 163 | dataset : Torch dataset 164 | Returns: 165 | Torch dataset with equal number of samples/class 166 | """ 167 | 168 | print("Balancing dataset...") 169 | n = len(dataset) 170 | labels = ch.Tensor([dataset[i][1] for i in range(n)]).int() 171 | n0 = sum(labels).item() 172 | I_pos = labels == 1 173 | 174 | idx = ch.arange(n) 175 | idx_pos = idx[I_pos] 176 | ch.manual_seed(0) 177 | I = ch.randperm(n - n0)[:n0] 178 | idx_neg = idx[~I_pos][I] 179 | idx_bal = ch.cat([idx_pos, idx_neg],dim=0) 180 | return Subset(dataset, idx_bal) 181 | 182 | def load_features_mode(feature_path, mode='test', 183 | num_workers=10, batch_size=128): 184 | """Loads precomputed deep features corresponding to the 185 | train/test set along with normalization statitic. 
186 | Args: 187 | feature_path (str): Path to precomputed deep features 188 | mode (str): One of train or tesst 189 | num_workers (int): Number of workers to use for output loader 190 | batch_size (int): Batch size for output loader 191 | 192 | Returns: 193 | features (np.array): Recovered deep features 194 | feature_mean: Mean of deep features 195 | feature_std: Standard deviation of deep features 196 | """ 197 | feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}')) 198 | feature_loader = ch.utils.data.DataLoader(feature_dataset, 199 | num_workers=num_workers, 200 | batch_size=batch_size, 201 | shuffle=False) 202 | 203 | feature_metadata = ch.load(os.path.join(feature_path, f'metadata_train.pth')) 204 | feature_mean, feature_std = feature_metadata['X']['mean'], feature_metadata['X']['std'] 205 | 206 | 207 | features = [] 208 | 209 | for _, (feature, _) in tqdm(enumerate(feature_loader), total=len(feature_loader)): 210 | features.append(feature) 211 | 212 | features = ch.cat(features).numpy() 213 | return features, feature_mean, feature_std 214 | 215 | def load_features(feature_path): 216 | """Loads precomputed deep features. 217 | Args: 218 | feature_path (str): Path to precomputed deep features 219 | 220 | Returns: 221 | Torch dataset with recovered deep features. 222 | """ 223 | if not os.path.exists(os.path.join(feature_path, f"0_features.npy")): 224 | raise ValueError(f"The provided location {feature_path} does not contain any representation files") 225 | 226 | ds_list, chunk_id = [], 0 227 | while os.path.exists(os.path.join(feature_path, f"{chunk_id}_features.npy")): 228 | features = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_features.npy"))).float() 229 | labels = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_labels.npy"))).long() 230 | ds_list.append(ch.utils.data.TensorDataset(features, labels)) 231 | chunk_id += 1 232 | 233 | print(f"==> loaded {chunk_id} files of representations...") 234 | return ch.utils.data.ConcatDataset(ds_list) 235 | 236 | 237 | def calculate_metadata(loader, num_classes=None, filename=None): 238 | """Calculates mean and standard deviation of the deep features over 239 | a given set of images. 240 | Args: 241 | loader : torch data loader 242 | num_classes (int): Number of classes in the dataset 243 | filename (str): Optional filepath to cache metadata. Recommended 244 | for large datasets like ImageNet. 245 | 246 | Returns: 247 | metadata (dict): Dictionary with desired statistics. 
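        Example (illustrative; `feature_loader` is any loader yielding (features, labels)
        batches, such as the one returned by compute_features above)::

            metadata = calculate_metadata(feature_loader, num_classes=2)
            mu, sigma = metadata['X']['mean'], metadata['X']['std']
            max_lam = metadata['max_reg']['nongrouped']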
248 | """ 249 | 250 | if filename is not None and os.path.exists(filename): 251 | return ch.load(filename) 252 | 253 | # Calculate number of classes if not given 254 | if num_classes is None: 255 | num_classes = 1 256 | for batch in loader: 257 | y = batch[1] 258 | print(y) 259 | num_classes = max(num_classes, y.max().item()+1) 260 | 261 | eye = ch.eye(num_classes) 262 | 263 | X_bar, y_bar, y_max, n = 0, 0, 0, 0 264 | 265 | # calculate means and maximum 266 | print("Calculating means") 267 | for X,y in tqdm(loader, total=len(loader)): 268 | X_bar += X.sum(0) 269 | y_bar += eye[y].sum(0) 270 | y_max = max(y_max, y.max()) 271 | n += y.size(0) 272 | X_bar = X_bar.float()/n 273 | y_bar = y_bar.float()/n 274 | 275 | # calculate std 276 | X_std, y_std = 0, 0 277 | print("Calculating standard deviations") 278 | for X,y in tqdm(loader, total=len(loader)): 279 | X_std += ((X - X_bar)**2).sum(0) 280 | y_std += ((eye[y] - y_bar)**2).sum(0) 281 | X_std = ch.sqrt(X_std.float()/n) 282 | y_std = ch.sqrt(y_std.float()/n) 283 | 284 | # calculate maximum regularization 285 | inner_products = 0 286 | print("Calculating maximum lambda") 287 | for X,y in tqdm(loader, total=len(loader)): 288 | y_map = (eye[y] - y_bar)/y_std 289 | inner_products += X.t().mm(y_map)*y_std 290 | 291 | inner_products_group = inner_products.norm(p=2,dim=1) 292 | 293 | metadata = { 294 | "X": { 295 | "mean": X_bar, 296 | "std": X_std, 297 | "num_features": X.size()[1:], 298 | "num_examples": n 299 | }, 300 | "y": { 301 | "mean": y_bar, 302 | "std": y_std, 303 | "num_classes": y_max+1 304 | }, 305 | "max_reg": { 306 | "group": inner_products_group.abs().max().item()/n, 307 | "nongrouped": inner_products.abs().max().item()/n 308 | } 309 | } 310 | 311 | if filename is not None: 312 | ch.save(metadata, filename) 313 | 314 | return metadata 315 | 316 | def split_dataset(dataset, Ntotal, val_frac, 317 | batch_size, num_workers, 318 | random_seed=0, shuffle=True, balance=False): 319 | """Splits a given dataset into train and validation 320 | Args: 321 | dataset : Torch dataset 322 | Ntotal: Total number of dataset samples 323 | val_frac: Fraction to reserve for validation 324 | batch_size (int): Batch size for output loader 325 | num_workers (int): Number of workers to use for output loader 326 | random_seed (int): Random seed 327 | shuffle (bool): Whether or not to shuffle output data loaoder 328 | balance (bool): Whether or not to balance output data loader 329 | (only relevant for some language models) 330 | 331 | Returns: 332 | split_datasets (list): List of datasets (one each for train and val) 333 | split_loaders (list): List of loaders (one each for train and val) 334 | """ 335 | 336 | Nval = math.floor(Ntotal*val_frac) 337 | train_ds, val_ds = ch.utils.data.random_split(dataset, 338 | [Ntotal - Nval, Nval], 339 | generator=ch.Generator().manual_seed(random_seed)) 340 | if balance: 341 | val_ds = balance_dataset(val_ds) 342 | split_datasets = [train_ds, val_ds] 343 | 344 | split_loaders = [] 345 | for ds in split_datasets: 346 | split_loaders.append(ch.utils.data.DataLoader(ds, 347 | num_workers=num_workers, 348 | batch_size=batch_size, 349 | shuffle=shuffle)) 350 | return split_datasets, split_loaders 351 | 352 | 353 | -------------------------------------------------------------------------------- /helpers/nlp_helpers.py: -------------------------------------------------------------------------------- 1 | import torch as ch 2 | import numpy as np 3 | 4 | from lime.lime_text import LimeTextExplainer 5 | from tqdm import tqdm 6 | from 
wordcloud import WordCloud 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | sns.set_style('darkgrid') 11 | 12 | import os 13 | from collections import defaultdict 14 | 15 | def make_lime_fn(model,val_set, pooled_output, mu, std, bs=128): 16 | device = 'cuda' 17 | def classifier_fn(sentences): 18 | try: 19 | input_ids, attention_mask = zip(*[val_set.process_sentence(s) for s in sentences]) 20 | except: 21 | input_ids, attention_mask = zip(*[val_set.dataset.process_sentence(s) for s in sentences]) 22 | input_ids, attention_mask = ch.stack(input_ids), ch.stack(attention_mask) 23 | all_reps = [] 24 | n = input_ids.size(0) 25 | # bs = args.batch_size 26 | for i in range(0,input_ids.size(0),bs): 27 | i0 = min(i+bs,n) 28 | # reps, _ = 29 | if hasattr(model, "roberta"): 30 | output = model.roberta(input_ids=input_ids[i:i0].to(device), attention_mask=attention_mask[i:i0].to(device)) 31 | output = output[0] 32 | 33 | # do RobertA classification head minus last out_proj classifier 34 | # https://huggingface.co/transformers/_modules/transformers/models/roberta/modeling_roberta.html 35 | output = output[:,0,:] 36 | output = model.classifier.dropout(output) 37 | output = model.classifier.dense(output) 38 | output = ch.tanh(output) 39 | cls_reps = model.classifier.dropout(output) 40 | else: 41 | output = model.bert(input_ids=input_ids[i:i0].to(device), attention_mask=attention_mask[i:i0].to(device)) 42 | if pooled_output: 43 | cls_reps = output[1] 44 | else: 45 | cls_reps = output[0][:,0] 46 | cls_reps = cls_reps.cpu() 47 | cls_reps = (cls_reps - mu)/std 48 | all_reps.append(cls_reps.cpu()) 49 | return ch.cat(all_reps,dim=0).numpy() 50 | return classifier_fn 51 | 52 | def get_lime_features(model, val_set, val_loader, out_dir, pooled_output, mu, std): 53 | os.makedirs(out_dir,exist_ok=True) 54 | explainer = LimeTextExplainer() 55 | reps_size = 768 56 | 57 | clf_fn = make_lime_fn(model,val_set, pooled_output, mu, std) 58 | files = [] 59 | with ch.no_grad(): 60 | print('number of sentences', len(val_loader)) 61 | for i,(sentences, labels) in enumerate(tqdm(val_loader, total=len(val_loader), desc="Generating LIME")): 62 | assert len(sentences) == 1 63 | sentence, label = sentences[0], labels[0] 64 | if label.item() == -1: 65 | continue 66 | out_file = f"{out_dir}/{i}.pth" 67 | try: 68 | files.append(ch.load(out_file)) 69 | continue 70 | except: 71 | pass 72 | # if os.path.exists(out_file): 73 | # continue 74 | exp = explainer.explain_instance(sentence, clf_fn, labels=list(range(reps_size))) 75 | out = { 76 | "sentence": sentence, 77 | "explanation": exp 78 | } 79 | ch.save(out, out_file) 80 | files.append(out) 81 | return files 82 | 83 | def top_and_bottom_words(files, num_features=768): 84 | top_words = [] 85 | bot_words = [] 86 | for j in range(num_features): 87 | exps = [f['explanation'].as_list(label=j) for f in files] 88 | exps_collapsed = [a for e in exps for a in e] 89 | 90 | accumulator = defaultdict(lambda: []) 91 | for word,weight in exps_collapsed: 92 | accumulator[word].append(weight) 93 | exps_collapsed = [(k,np.array(accumulator[k]).mean()) for k in accumulator] 94 | 95 | 96 | exps_collapsed.sort(key=lambda a: a[1]) 97 | 98 | weights = [a[1] for a in exps_collapsed] 99 | l = np.percentile(weights, q=1) 100 | u = np.percentile(weights, q=99) 101 | top_words.append([a for a in reversed(exps_collapsed) if a[1] > u]) 102 | bot_words.append([a for a in exps_collapsed if a[1] < l]) 103 | return top_words,bot_words 104 | 105 | def get_explanations(feature_idxs, 
sparse, top_words, bot_words): 106 | expl_dict = {} 107 | maxfreq = 0 108 | for i,idx in enumerate(feature_idxs): 109 | 110 | expl_dict[idx] = {} 111 | 112 | aligned = sparse[1,idx] > 0 113 | for words, j in zip([top_words, bot_words],[0,1]): 114 | if not aligned: 115 | j = 1-j 116 | expl_dict[idx][j] = { 117 | a[0]:abs(a[1]) for a in words[idx] ## ARE THESE SIGNS CORRECT 118 | } 119 | maxfreq = max(maxfreq, max(list(expl_dict[idx][j].values()))) 120 | 121 | for k in expl_dict: 122 | for s in expl_dict[k]: 123 | expl_dict[k][s] = {kk: vv / maxfreq for kk, vv in expl_dict[k][s].items()} 124 | return expl_dict 125 | 126 | from matplotlib.colors import ListedColormap 127 | 128 | def grey_color_func(word, font_size, position, orientation, random_state=None, 129 | **kwargs): 130 | #cmap = sns.color_palette("RdYlGn_r", as_cmap=True) 131 | cmap = ListedColormap(sns.color_palette("RdYlGn_r").as_hex()) 132 | fs = (font_size - 14) / (42 - 14) 133 | 134 | color_orig = cmap(fs) 135 | color = (255 * np.array(cmap(fs))).astype(np.uint8) 136 | 137 | return tuple(color[:3]) 138 | 139 | def plot_wordcloud(expln_dict, weights, factor=3, transpose=False, labels=("Positive Sentiment", "Negative Sentiment")): 140 | 141 | if transpose: 142 | fig, axs = plt.subplots(2, len(expln_dict), figsize=(factor*3.5*len(expln_dict), factor*4), squeeze=False) 143 | else: 144 | fig, axs = plt.subplots(len(expln_dict), 2, figsize=(factor*7, factor*1.5*len(expln_dict)), squeeze=False) 145 | 146 | for i,idx in enumerate(expln_dict.keys()): 147 | if i == 0: 148 | if transpose: 149 | axs[0,0].set_ylabel(labels[0], fontsize=36) 150 | axs[1,0].set_ylabel(labels[1], fontsize=36) 151 | else: 152 | axs[i,0].set_title(labels[0], fontsize=36) 153 | axs[i,1].set_title(labels[1], fontsize=36) 154 | 155 | for j in [0, 1]: 156 | wc = WordCloud(background_color="white", max_words=1000, min_font_size=14, 157 | max_font_size=42) 158 | # generate word cloud 159 | d = {k[:20]:v for k,v in expln_dict[idx][j].items()} 160 | wc.generate_from_frequencies(d) 161 | default_colors = wc.to_array() 162 | 163 | if transpose: 164 | ax = axs[j,i] 165 | else: 166 | ax = axs[i,j] 167 | ax.imshow(wc.recolor(color_func=grey_color_func, random_state=3), 168 | interpolation="bilinear") 169 | ax.set_xticks([]) 170 | ax.set_yticks([]) 171 | ax.spines['right'].set_visible(False) 172 | ax.spines['top'].set_visible(False) 173 | ax.spines['left'].set_visible(False) 174 | ax.spines['bottom'].set_visible(False) 175 | # ax.set_axis_off() 176 | if not transpose: 177 | axs[i,0].set_ylabel(f"#{idx}\nW={weights[i]:.4f}", fontsize=36) 178 | else: 179 | axs[0,i].set_title(f"#{idx}", fontsize=36) 180 | plt.tight_layout() 181 | # plt.subplots_adjust(left=None, bottom=0.1, right=None, top=0.9, wspace=0.25, hspace=0.0) 182 | return fig, axs -------------------------------------------------------------------------------- /helpers/vis_helpers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./lucent') 3 | import torch as ch 4 | import torchvision 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import seaborn as sns 8 | from functools import partial 9 | from lucent.optvis import render, param, transform, objectives 10 | from lime import lime_image 11 | 12 | def plot_sparsity(results): 13 | """Function to visualize the sparsity-accuracy trade-off of regularized decision 14 | layers 15 | Args: 16 | results (dictionary): Appropriately formatted dictionary with regularization 17 | paths and logs of 
train/val/test accuracy. 18 | """ 19 | 20 | if type(results['metrics']['acc_train'].values[0]) == list: 21 | all_tr = 100 * np.array(results['metrics']['acc_train'].values[0]) 22 | all_val = 100 * np.array(results['metrics']['acc_val'].values[0]) 23 | all_te = 100 * np.array(results['metrics']['acc_test'].values[0]) 24 | else: 25 | all_tr = 100 * np.array(results['metrics']['acc_train'].values) 26 | all_val = 100 * np.array(results['metrics']['acc_val'].values) 27 | all_te = 100 * np.array(results['metrics']['acc_test'].values) 28 | 29 | fig, axarr = plt.subplots(1, 2, figsize=(14, 5)) 30 | axarr[0].plot(all_tr) 31 | axarr[0].plot(all_val) 32 | axarr[0].plot(all_te) 33 | axarr[0].legend(['Train', 'Val', 'Test'], fontsize=16) 34 | axarr[0].set_ylabel("Accuracy (%)", fontsize=18) 35 | axarr[0].set_xlabel("Regularization index", fontsize=18) 36 | 37 | num_features = results['weights'][0].shape[1] 38 | total_sparsity = np.mean(results['sparsity'], axis=1) / num_features 39 | axarr[1].plot(total_sparsity, all_tr, 'o-') 40 | axarr[1].plot(total_sparsity, all_te, 'o-') 41 | axarr[1].legend(['Train', 'Val', 'Test'], fontsize=16) 42 | axarr[1].set_ylabel("Accuracy (%)", fontsize=18) 43 | axarr[1].set_xlabel("1 - Sparsity", fontsize=18) 44 | axarr[1].set_xscale('log') 45 | 46 | plt.show() 47 | 48 | def normalize_weight(w): 49 | """Normalizes weights to a unit vector 50 | Args: 51 | w (tensor): Weight vector for a class. 52 | Returns: 53 | Normalized weight vector in the form of a numpy array. 54 | """ 55 | return w.numpy() / np.linalg.norm(w.numpy()) 56 | 57 | def get_feature_visualization(model, feature_idx, signs): 58 | """Performs feature visualization using Lucid. 59 | Args: 60 | model: deep network whose deep features are to be visualized. 61 | feature_idx: indice of features to visualize. 62 | signs: +/-1 array indicating whether a feature should be maximized/minimized. 63 | Returns: 64 | Batch of feature visualizations . 65 | """ 66 | param_f = lambda: param.image(224, batch=len(feature_idx), fft=True, decorrelate=True) 67 | obj = 0 68 | for fi, (f, s) in enumerate(zip(feature_idx, signs)): 69 | obj += s * objectives.channel('avgpool', f, batch=fi) 70 | op = render.render_vis(model.model, 71 | show_inline=False, 72 | objective_f=obj, 73 | param_f=param_f, 74 | thresholds=(512,))[0] 75 | return ch.tensor(op).permute(0, 3, 1, 2) 76 | 77 | def latent_predict(images, model, mean=None, std=None): 78 | """LIME helper function that computes the deep feature representation 79 | for a given batch of images. 80 | Args: 81 | image (tensor): batch of images. 82 | model: deep network whose deep features are to be visualized. 83 | mean (tensor): mean of deep features. 84 | std (tensor): std deviation of deep features. 85 | Returns: 86 | Normalized deep features for batch of images. 
87 | """ 88 | preprocess_transform = torchvision.transforms.Compose([ 89 | torchvision.transforms.ToTensor(), 90 | ]) 91 | device = 'cuda' if next(model.parameters()).is_cuda else 'cpu' 92 | batch = ch.stack(tuple(preprocess_transform(i) for i in images), dim=0).to(device) 93 | 94 | (_, latents), _ = model(batch.to(ch.float), with_latent=True) 95 | scaled_latents = (latents.detach().cpu() - mean.to(ch.float)) / std.to(ch.float) 96 | return scaled_latents.numpy() 97 | 98 | def parse_lime_explanation(expln, f, sign, NLime=3): 99 | """LIME helper function that extracts a mask from a lime explanation 100 | Args: 101 | expln: LIME explanation from LIME library 102 | f: indice of features to visualize 103 | sign: +/-1 array indicating whether the feature should be maximized/minimized. 104 | images (tensor): batch of images. 105 | NLime (int): Number of top-superpixels to visualize. 106 | Returns: 107 | Tensor where the first and second channels contains superpixels that cause the 108 | deep feature to activate and deactivate respectively. 109 | """ 110 | segs = expln.segments 111 | vis_mask = np.zeros(segs.shape + (3,)) 112 | 113 | weights = sorted([v for v in expln.local_exp[f]], 114 | key=lambda x: -np.abs(x[1])) 115 | weight_values = [w[1] for w in weights] 116 | pos_lim, neg_lim = np.max(weight_values), (1e-8 + np.min(weight_values)) 117 | 118 | if NLime is not None: 119 | weights = weights[:NLime] 120 | 121 | for wi, w in enumerate(weights): 122 | if w[1] >= 0: 123 | si = (w[1] / pos_lim, 0, 0) if sign == 1 else (0, w[1] / pos_lim, 0) 124 | else: 125 | si = (0, w[1] / neg_lim, 0) if sign == 1 else (w[1] / neg_lim, 0, 0) 126 | vis_mask[segs == w[0]] = si 127 | 128 | return ch.tensor(vis_mask.transpose(2, 0, 1)) 129 | 130 | def get_lime_explanation(model, feature_idx, signs, 131 | images, rep_mean, rep_std, 132 | num_samples=1000, 133 | NLime=3, 134 | background_color=0.6): 135 | """Computes LIME explanations for a given set of deep features. The LIME 136 | objective in this case is to identify the superpixels within the specified 137 | images that maximally/minimally activate the corresponding deep feature. 138 | Args: 139 | model: deep network whose deep features are to be visualized. 140 | feature_idx: indice of features to visualize 141 | signs: +/-1 array indicating whether a feature should be maximized/minimized. 142 | images (tensor): batch of images. 143 | rep_mean (tensor): mean of deep features. 144 | rep_std (tensor): std deviation of deep features. 145 | NLime (int): Number of top-superpixels to visualize 146 | background_color (float): Color to assign non-relevant super pixels 147 | Returns: 148 | Tensor comprising LIME explanations for the given set of deep features. 
149 | """ 150 | explainer = lime_image.LimeImageExplainer() 151 | lime_objective = partial(latent_predict, model=model, mean=rep_mean, std=rep_std) 152 | 153 | explanations = [] 154 | for im, feature, sign in zip(images, feature_idx, signs): 155 | explanation = explainer.explain_instance(im.numpy().transpose(1, 2, 0), 156 | lime_objective, 157 | labels=np.array([feature]), 158 | top_labels=None, 159 | hide_color=0, 160 | num_samples=num_samples) 161 | explanation = parse_lime_explanation(explanation, 162 | feature, 163 | sign, 164 | NLime=NLime) 165 | 166 | if sign == 1: 167 | explanation = explanation[:1].unsqueeze(0).repeat(1, 3, 1, 1) 168 | else: 169 | explanation = explanation[1:2].unsqueeze(0).repeat(1, 3, 1, 1) 170 | 171 | interpolated = im * explanation + background_color * ch.ones_like(im) * (1 - explanation) 172 | explanations.append(interpolated) 173 | 174 | return ch.cat(explanations) -------------------------------------------------------------------------------- /language/datasets.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from torch.utils.data import Dataset 4 | from transformers import AutoTokenizer, AutoConfig 5 | 6 | from .jigsaw_loaders import JigsawDataOriginal 7 | 8 | def DATASETS(dataset_name): 9 | if dataset_name == 'sst': return SSTDataset 10 | elif dataset_name.startswith('jigsaw'): return JigsawDataset 11 | else: 12 | raise ValueError("Language dataset is not currently supported...") 13 | 14 | class SSTDataset(Dataset): 15 | """ 16 | Stanford Sentiment Treebank V1.0 17 | Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank 18 | Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher Manning, Andrew Ng and Christopher Potts 19 | Conference on Empirical Methods in Natural Language Processing (EMNLP 2013) 20 | """ 21 | def __init__(self, filename, maxlen, tokenizer, return_sentences=False): 22 | #Store the contents of the file in a pandas dataframe 23 | self.df = pd.read_csv(filename, delimiter = '\t') 24 | #Initialize the tokenizer for the desired transformer model 25 | self.tokenizer = tokenizer 26 | #Maximum length of the tokens list to keep all the sequences of fixed size 27 | self.maxlen = maxlen 28 | #whether to tokenize or return raw setences 29 | self.return_sentences = return_sentences 30 | 31 | def __len__(self): 32 | return len(self.df) 33 | 34 | def __getitem__(self, index): 35 | #Select the sentence and label at the specified index in the data frame 36 | sentence = self.df.loc[index, 'sentence'] 37 | label = self.df.loc[index, 'label'] 38 | #Preprocess the text to be suitable for the transformer 39 | if self.return_sentences: 40 | return sentence, label 41 | else: 42 | input_ids, attention_mask = self.process_sentence(sentence) 43 | return input_ids, attention_mask, label 44 | 45 | def process_sentence(self, sentence): 46 | tokens = self.tokenizer.tokenize(sentence) 47 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 48 | if len(tokens) < self.maxlen: 49 | tokens = tokens + ['[PAD]' for _ in range(self.maxlen - len(tokens))] 50 | else: 51 | tokens = tokens[:self.maxlen-1] + ['[SEP]'] 52 | #Obtain the indices of the tokens in the BERT Vocabulary 53 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 54 | input_ids = torch.tensor(input_ids) 55 | #Obtain the attention mask i.e a tensor containing 1s for no padded tokens and 0s for padded ones 56 | attention_mask = (input_ids != 0).long() 57 | return input_ids, attention_mask 58 | 59 | 
class JigsawDataset(Dataset): 60 | def __init__(self, filename, maxlen, tokenizer, return_sentences=False, label="toxic"): 61 | classes=[label] 62 | if 'train' in filename: 63 | self.dataset = JigsawDataOriginal( 64 | train_csv_file=filename, 65 | test_csv_file=None, 66 | train=True, 67 | create_val_set=False, 68 | add_test_labels=False, 69 | classes=classes 70 | ) 71 | elif 'test' in filename: 72 | self.dataset = JigsawDataOriginal( 73 | train_csv_file=None, 74 | test_csv_file=filename, 75 | train=False, 76 | create_val_set=False, 77 | add_test_labels=True, 78 | classes=classes 79 | ) 80 | else: 81 | raise ValueError("Unknown filename {filename}") 82 | # #Store the contents of the file in a pandas dataframe 83 | # self.df = pd.read_csv(filename, header=None, names=['label', 'sentence']) 84 | #Initialize the tokenizer for the desired transformer model 85 | self.tokenizer = tokenizer 86 | #Maximum length of the tokens list to keep all the sequences of fixed size 87 | self.maxlen = maxlen 88 | #whether to tokenize or return raw setences 89 | self.return_sentences = return_sentences 90 | 91 | def __len__(self): 92 | return len(self.dataset) 93 | 94 | def __getitem__(self, index): 95 | #Select the sentence and label at the specified index in the data frame 96 | sentence, meta = self.dataset[index] 97 | label = meta["multi_target"].squeeze() 98 | 99 | #Preprocess the text to be suitable for the transformer 100 | if self.return_sentences: 101 | return sentence, label 102 | else: 103 | input_ids, attention_mask = self.process_sentence(sentence) 104 | return input_ids, attention_mask, label 105 | 106 | def process_sentence(self, sentence): 107 | # print(sentence) 108 | d = self.tokenizer(sentence, padding='max_length', truncation=True) 109 | input_ids = torch.tensor(d["input_ids"]) 110 | attention_mask = torch.tensor(d["attention_mask"]) 111 | return input_ids, attention_mask -------------------------------------------------------------------------------- /language/jigsaw_loaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | import torch 3 | import pandas as pd 4 | import numpy as np 5 | import language.datasets 6 | from tqdm import tqdm 7 | 8 | 9 | class JigsawData(Dataset): 10 | """Dataloader for the Jigsaw Toxic Comment Classification Challenges. 
11 | If test_csv_file is None and create_val_set is True the train file 12 | specified gets split into a train and validation set according to 13 | train_fraction.""" 14 | 15 | def __init__( 16 | self, 17 | train_csv_file, 18 | test_csv_file, 19 | train=True, 20 | val_fraction=0.9, 21 | add_test_labels=False, 22 | create_val_set=True, 23 | ): 24 | 25 | if train_csv_file is not None: 26 | if isinstance(train_csv_file, list): 27 | train_set_pd = self.load_data(train_csv_file) 28 | else: 29 | train_set_pd = pd.read_csv(train_csv_file) 30 | self.train_set_pd = train_set_pd 31 | if "toxicity" not in train_set_pd.columns: 32 | train_set_pd.rename(columns={"target": "toxicity"}, inplace=True) 33 | self.train_set = datasets.Dataset.from_pandas(train_set_pd) 34 | 35 | if create_val_set: 36 | data = self.train_set.train_test_split(val_fraction) 37 | self.train_set = data["train"] 38 | self.val_set = data["test"] 39 | 40 | if test_csv_file is not None: 41 | val_set = pd.read_csv(test_csv_file) 42 | if add_test_labels: 43 | data_labels = pd.read_csv(test_csv_file[:-4] + "_labels.csv") 44 | for category in data_labels.columns[1:]: 45 | val_set[category] = data_labels[category] 46 | val_set = datasets.Dataset.from_pandas(val_set) 47 | self.val_set = val_set 48 | 49 | if train: 50 | self.data = self.train_set 51 | else: 52 | self.data = self.val_set 53 | 54 | self.train = train 55 | 56 | def __len__(self): 57 | return len(self.data) 58 | 59 | def load_data(self, train_csv_file): 60 | files = [] 61 | cols = ["id", "comment_text", "toxic"] 62 | for file in tqdm(train_csv_file): 63 | file_df = pd.read_csv(file) 64 | file_df = file_df[cols] 65 | file_df = file_df.astype({"id": "string"}, {"toxic": "float64"}) 66 | files.append(file_df) 67 | train = pd.concat(files) 68 | return train 69 | 70 | 71 | class JigsawDataOriginal(JigsawData): 72 | """Dataloader for the original Jigsaw Toxic Comment Classification Challenge. 73 | Source: https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge 74 | """ 75 | 76 | def __init__( 77 | self, 78 | train_csv_file="jigsaw_data/train.csv", 79 | test_csv_file="jigsaw_data/test.csv", 80 | train=True, 81 | val_fraction=0.1, 82 | create_val_set=True, 83 | add_test_labels=True, 84 | classes=["toxic"], 85 | ): 86 | 87 | super().__init__( 88 | train_csv_file=train_csv_file, 89 | test_csv_file=test_csv_file, 90 | train=train, 91 | val_fraction=val_fraction, 92 | add_test_labels=add_test_labels, 93 | create_val_set=create_val_set, 94 | ) 95 | self.classes = classes 96 | 97 | def __getitem__(self, index): 98 | meta = {} 99 | entry = self.data[index] 100 | text_id = entry["id"] 101 | text = entry["comment_text"] 102 | 103 | target_dict = { 104 | label: value for label, value in entry.items() if label in self.classes 105 | } 106 | 107 | meta["multi_target"] = torch.tensor( 108 | list(target_dict.values()), dtype=torch.int32 109 | ) 110 | meta["text_id"] = text_id 111 | 112 | return text, meta 113 | 114 | 115 | class JigsawDataBias(JigsawData): 116 | """Dataloader for the Jigsaw Unintended Bias in Toxicity Classification. 
117 | Source: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/ 118 | """ 119 | 120 | def __init__( 121 | self, 122 | train_csv_file="jigsaw_data/train.csv", 123 | test_csv_file="jigsaw_data/test.csv", 124 | train=True, 125 | val_fraction=0.1, 126 | create_val_set=True, 127 | compute_bias_weights=True, 128 | loss_weight=0.75, 129 | classes=["toxic"], 130 | identity_classes=["female"], 131 | ): 132 | 133 | self.classes = classes 134 | 135 | self.identity_classes = identity_classes 136 | 137 | super().__init__( 138 | train_csv_file=train_csv_file, 139 | test_csv_file=test_csv_file, 140 | train=train, 141 | val_fraction=val_fraction, 142 | create_val_set=create_val_set, 143 | ) 144 | if train: 145 | if compute_bias_weights: 146 | self.weights = self.compute_weigths(self.train_set_pd) 147 | else: 148 | self.weights = None 149 | 150 | self.train = train 151 | self.loss_weight = loss_weight 152 | 153 | def __getitem__(self, index): 154 | meta = {} 155 | entry = self.data[index] 156 | text_id = entry["id"] 157 | text = entry["comment_text"] 158 | 159 | target_dict = {label: 1 if entry[label] >= 0.5 else 0 for label in self.classes} 160 | 161 | identity_target = { 162 | label: -1 if entry[label] is None else entry[label] 163 | for label in self.identity_classes 164 | } 165 | identity_target.update( 166 | {label: 1 for label in identity_target if identity_target[label] >= 0.5} 167 | ) 168 | identity_target.update( 169 | {label: 0 for label in identity_target if 0 <= identity_target[label] < 0.5} 170 | ) 171 | 172 | target_dict.update(identity_target) 173 | 174 | meta["multi_target"] = torch.tensor( 175 | list(target_dict.values()), dtype=torch.float32 176 | ) 177 | meta["text_id"] = text_id 178 | 179 | if self.train: 180 | meta["weights"] = self.weights[index] 181 | toxic_weight = ( 182 | self.weights[index] * self.loss_weight * 1.0 / len(self.classes) 183 | ) 184 | identity_weight = (1 - self.loss_weight) * 1.0 / len(self.identity_classes) 185 | meta["weights1"] = torch.tensor( 186 | [ 187 | *[toxic_weight] * len(self.classes), 188 | *[identity_weight] * len(self.identity_classes), 189 | ] 190 | ) 191 | 192 | return text, meta 193 | 194 | def compute_weigths(self, train_df): 195 | """Inspired from 2nd solution. 196 | Source: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/discussion/100661""" 197 | subgroup_bool = (train_df[self.identity_classes].fillna(0) >= 0.5).sum( 198 | axis=1 199 | ) > 0 200 | positive_bool = train_df["toxicity"] >= 0.5 201 | weights = np.ones(len(train_df)) * 0.25 202 | 203 | # Backgroud Positive and Subgroup Negative 204 | weights[ 205 | ((~subgroup_bool) & (positive_bool)) | ((subgroup_bool) & (~positive_bool)) 206 | ] += 0.25 207 | weights[(subgroup_bool)] += 0.25 208 | return weights 209 | 210 | 211 | class JigsawDataMultilingual(JigsawData): 212 | """Dataloader for the Jigsaw Multilingual Toxic Comment Classification. 
213 | Source: https://www.kaggle.com/c/jigsaw-multilingual-toxic-comment-classification/ 214 | """ 215 | 216 | def __init__( 217 | self, 218 | train_csv_file="jigsaw_data/multilingual_challenge/jigsaw-toxic-comment-train.csv", 219 | test_csv_file="jigsaw_data/multilingual_challenge/validation.csv", 220 | train=True, 221 | val_fraction=0.1, 222 | create_val_set=False, 223 | classes=["toxic"], 224 | ): 225 | 226 | self.classes = classes 227 | super().__init__( 228 | train_csv_file=train_csv_file, 229 | test_csv_file=test_csv_file, 230 | train=train, 231 | val_fraction=val_fraction, 232 | create_val_set=create_val_set, 233 | ) 234 | 235 | def __getitem__(self, index): 236 | meta = {} 237 | entry = self.data[index] 238 | text_id = entry["id"] 239 | if "translated" in entry: 240 | text = entry["translated"] 241 | elif "comment_text_en" in entry: 242 | text = entry["comment_text_en"] 243 | else: 244 | text = entry["comment_text"] 245 | 246 | target_dict = {label: 1 if entry[label] >= 0.5 else 0 for label in self.classes} 247 | meta["target"] = torch.tensor(list(target_dict.values()), dtype=torch.int32) 248 | meta["text_id"] = text_id 249 | 250 | return text, meta 251 | -------------------------------------------------------------------------------- /language/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertModel, BertPreTrainedModel, BertForSequenceClassification, RobertaForSequenceClassification 4 | 5 | LANGUAGE_MODEL_DICT = { 6 | 'sst': 'barissayil/bert-sentiment-analysis-sst', 7 | 'jigsaw-toxic': 'unitary/toxic-bert', 8 | 'jigsaw-severe_toxic': 'unitary/toxic-bert', 9 | 'jigsaw-obscene': 'unitary/toxic-bert', 10 | 'jigsaw-threat': 'unitary/toxic-bert', 11 | 'jigsaw-insult': 'unitary/toxic-bert', 12 | 'jigsaw-identity_hate': 'unitary/toxic-bert', 13 | 'jigsaw-alt-toxic': 'unitary/unbiased-toxic-roberta', 14 | 'jigsaw-alt-severe_toxic': 'unitary/unbiased-toxic-roberta', 15 | 'jigsaw-alt-obscene': 'unitary/unbiased-toxic-roberta', 16 | 'jigsaw-alt-threat': 'unitary/unbiased-toxic-roberta', 17 | 'jigsaw-alt-insult': 'unitary/unbiased-toxic-roberta', 18 | 'jigsaw-alt-identity_hate': 'unitary/unbiased-toxic-roberta' 19 | } 20 | 21 | class BertForSentimentClassification(BertPreTrainedModel): 22 | def __init__(self, config): 23 | super().__init__(config) 24 | self.bert = BertModel(config) 25 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 26 | #The classification layer that takes the [CLS] representation and outputs the logit 27 | self.cls_layer = nn.Linear(config.hidden_size, 1) 28 | 29 | def forward(self, input_ids, attention_mask): 30 | ''' 31 | Inputs: 32 | -input_ids : Tensor of shape [B, T] containing token ids of sequences 33 | -attention_mask : Tensor of shape [B, T] containing attention masks to be used to avoid contibution of PAD tokens 34 | (where B is the batch size and T is the input length) 35 | ''' 36 | #Feed the input to Bert model to obtain contextualized representations 37 | reps, _ = self.bert(input_ids=input_ids, attention_mask=attention_mask) 38 | #Obtain the representations of [CLS] heads 39 | cls_reps = reps[:, 0] 40 | # cls_reps = self.dropout(cls_reps) 41 | logits = self.cls_layer(cls_reps) 42 | return logits -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, math, time 2 | import torch as ch 3 | import torch.nn as 
/main.py:
--------------------------------------------------------------------------------
1 | import os, math, time
2 | import torch as ch
3 | import torch.nn as nn
4 | 
5 | from robustness.datasets import DATASETS
6 | 
7 | # be sure to pip install glm_saga, or clone the repo from
8 | # https://github.com/madrylab/glm_saga
9 | from glm_saga.elasticnet import glm_saga
10 | 
11 | import helpers.data_helpers as data_helpers
12 | import helpers.feature_helpers as feature_helpers
13 | 
14 | from argparse import ArgumentParser
15 | 
16 | ch.manual_seed(0)
17 | ch.set_grad_enabled(False)
18 | 
19 | 
20 | if __name__ == '__main__':
21 |     parser = ArgumentParser()
22 |     parser.add_argument('--dataset', type=str, help='dataset name')
23 |     parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
24 |     parser.add_argument('--dataset-path', type=str, help='path to dataset')
25 |     parser.add_argument('--model-path', type=str, help='path to model checkpoint')
26 |     parser.add_argument('--arch', type=str, help='model architecture type')
27 |     parser.add_argument('--out-path', help='location for saving results')
28 |     parser.add_argument('--cache', action='store_true', help='cache deep features')
29 |     parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')
30 | 
31 |     parser.add_argument('--device', default='cuda')
32 |     parser.add_argument('--random-seed', type=int, default=0)
33 |     parser.add_argument('--num-workers', type=int, default=2)
34 |     parser.add_argument('--batch-size', type=int, default=256)
35 |     parser.add_argument('--val-frac', type=float, default=0.1)
36 |     parser.add_argument('--lr-decay-factor', type=float, default=1)
37 |     parser.add_argument('--lr', type=float, default=0.1)
38 |     parser.add_argument('--alpha', type=float, default=0.99)
39 |     parser.add_argument('--max-epochs', type=int, default=2000)
40 |     parser.add_argument('--verbose', type=int, default=200)
41 |     parser.add_argument('--tol', type=float, default=1e-4)
42 |     parser.add_argument('--lookbehind', type=int, default=3)
43 |     parser.add_argument('--lam-factor', type=float, default=0.001)
44 |     parser.add_argument('--group', action='store_true')
45 |     args = parser.parse_args()
46 | 
47 |     start_time = time.time()
48 | 
49 |     out_dir = args.out_path
50 |     out_dir_ckpt = f'{out_dir}/checkpoint'
51 |     out_dir_feats = f'{out_dir}/features'
52 |     for path in [out_dir, out_dir_ckpt, out_dir_feats]:
53 |         if not os.path.exists(path):
54 |             os.makedirs(path)
55 | 
56 |     print("Initializing dataset and loader...")
57 | 
58 |     dataset, train_loader, test_loader = data_helpers.load_dataset(args.dataset,
59 |                                                                    os.path.expandvars(args.dataset_path),
60 |                                                                    args.dataset_type,
61 |                                                                    args.batch_size,
62 |                                                                    args.num_workers,
63 |                                                                    shuffle=False,
64 |                                                                    model_path=args.model_path)
65 | 
66 |     num_classes = dataset.num_classes
67 |     Ntotal = len(train_loader.dataset)
68 | 
69 |     print("Loading model...")
70 |     model, pooled_output = feature_helpers.load_model(args.model_path,
71 |                                                       args.arch,
72 |                                                       dataset,
73 |                                                       args.dataset,
74 |                                                       args.dataset_type,
75 |                                                       device=args.device)
76 | 
77 |     print("Computing/loading deep features...")
78 |     feature_loaders = {}
79 |     for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
80 |         print(f"For {mode} set...")
81 | 
82 |         sink_path = f"{out_dir_feats}/features_{mode}" if args.cache else None
83 |         metadata_path = f"{out_dir_feats}/metadata_{mode}.pth" if args.cache else None
84 | 
85 |         feature_ds, feature_loader = feature_helpers.compute_features(loader,
86 |                                                                       model,
87 |                                                                       dataset_type=args.dataset_type,
88 |                                                                       pooled_output=pooled_output,
89 |                                                                       batch_size=args.batch_size,
90 |                                                                       num_workers=args.num_workers,
91 |                                                                       shuffle=(mode == 'test'),
92 |                                                                       device=args.device,
93 |                                                                       filename=sink_path,
94 |                                                                       balance=args.balance if mode == 'test' else False)
95 | 
96 |         if mode == 'train':
97 |             metadata = feature_helpers.calculate_metadata(feature_loader,
98 |                                                           num_classes=num_classes,
99 |                                                           filename=metadata_path)
100 |             split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
101 |                                                                           Ntotal,
102 |                                                                           val_frac=args.val_frac,
103 |                                                                           batch_size=args.batch_size,
104 |                                                                           num_workers=args.num_workers,
105 |                                                                           random_seed=args.random_seed,
106 |                                                                           shuffle=True, balance=args.balance)
107 |             feature_loaders.update({mm : data_helpers.add_index_to_dataloader(split_loaders[mi])
108 |                                     for mi, mm in enumerate(['train', 'val'])})
109 | 
110 |         else:
111 |             feature_loaders[mode] = feature_loader
112 | 
113 | 
114 |     num_features = metadata["X"]["num_features"][0]
115 |     assert metadata["y"]["num_classes"].numpy() == num_classes
116 | 
117 |     print("Initializing linear model...")
118 |     linear = nn.Linear(num_features, num_classes).to(args.device)
119 |     for p in [linear.weight, linear.bias]:
120 |         p.data.zero_()
121 | 
122 |     print("Preparing normalization preprocess and indexed dataloader")
123 |     preprocess = data_helpers.NormalizedRepresentation(feature_loaders['train'],
124 |                                                        metadata=metadata,
125 |                                                        device=linear.weight.device)
126 | 
127 |     print("Calculating the regularization path")
128 |     params = glm_saga(linear,
129 |                       feature_loaders['train'],
130 |                       args.lr,
131 |                       args.max_epochs,
132 |                       args.alpha,
133 |                       val_loader=feature_loaders['val'],
134 |                       test_loader=feature_loaders['test'],
135 |                       n_classes=num_classes,
136 |                       checkpoint=out_dir_ckpt,
137 |                       verbose=args.verbose,
138 |                       tol=args.tol,
139 |                       lookbehind=args.lookbehind,
140 |                       lr_decay_factor=args.lr_decay_factor,
141 |                       group=args.group,
142 |                       epsilon=args.lam_factor,
143 |                       metadata=metadata,
144 |                       preprocess=preprocess)
145 | 
146 |     print(f"Total time: {time.time() - start_time}")
147 | 
148 | 
--------------------------------------------------------------------------------
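The regularization path computed above yields a sequence of decision layers of decreasing density. As a point of reference, here is a minimal, self-contained sketch (with hypothetical, randomly sparsified weights rather than an actual `glm_saga` fit) of the kind of sparsity bookkeeping one might run on a fitted `nn.Linear` decision layer:

```python
import torch
import torch.nn as nn

# Hypothetical sizes; in main.py these come from the cached deep features.
num_features, num_classes = 2048, 10
linear = nn.Linear(num_features, num_classes)

# Stand-in for a sparse fit: zero out roughly 99% of the weights at random.
with torch.no_grad():
    linear.weight.mul_((torch.rand_like(linear.weight) < 0.01).float())

density = (linear.weight != 0).float().mean().item()
nonzero_per_class = (linear.weight != 0).sum(dim=1)
print(f"fraction of nonzero weights: {density:.4f}")
print("deep features used per class:", nonzero_per_class.tolist())
```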
/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MadryLab/DebuggableDeepNetworks/a2491c72a0c9a07b05803a88e9936b4dcc8fd080/pipeline.png
--------------------------------------------------------------------------------
/pipeline_600x400.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MadryLab/DebuggableDeepNetworks/a2491c72a0c9a07b05803a88e9936b4dcc8fd080/pipeline_600x400.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | pandas
4 | numpy
5 | scipy
6 | GPUtil
7 | dill
8 | tensorboardX
9 | tables
10 | tqdm
11 | seaborn
12 | jupyter
13 | sklearn
14 | pillow
15 | transformers
16 | datasets
17 | kornia
18 | lime
19 | robustness
20 | cox
21 | glmsaga
--------------------------------------------------------------------------------