├── .gitmodules
├── LICENSE
├── README.md
├── environment.yml
├── helpers
│   ├── data_helpers.py
│   ├── decisionlayer_helpers.py
│   ├── feature_helpers.py
│   ├── nlp_helpers.py
│   └── vis_helpers.py
├── inspect_language_models.ipynb
├── inspect_vision_models.ipynb
├── language
│   ├── datasets.py
│   ├── jigsaw_loaders.py
│   └── models.py
├── main.py
├── pipeline.png
├── pipeline_600x400.png
└── requirements.txt
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "lucent"]
2 | path = lucent
3 | url = https://github.com/greentfrapp/lucent.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Eric Wong & Shibani Santurkar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving Deep Network Debuggability via Sparse Decision Layers
2 |
3 | This repository contains the code for our paper:
4 |
5 | **Leveraging Sparse Linear Layers for Debuggable Deep Networks**
6 | *Eric Wong\*, Shibani Santurkar\*, Aleksander Madry*
7 | Paper: http://arxiv.org/abs/2105.04857
8 | Blog posts: [Part1](https://gradientscience.org/glm_saga) and [Part2](https://gradientscience.org/debugging)
9 |
10 | ![Overview of our pipeline: fitting a sparse decision layer on top of deep features](pipeline_600x400.png)
11 |
12 |
13 |
14 |
15 | ```bibtex
16 | @article{wong2021leveraging,
17 | title={Leveraging Sparse Linear Layers for Debuggable Deep Networks},
18 | author={Wong, Eric and Santurkar, Shibani and M{\k{a}}dry, Aleksander},
19 | journal={arXiv preprint arXiv:2105.04857},
20 | year={2021}
21 | }
22 | ```
23 |
24 | ## Getting started
25 | *Our code relies on the [MadryLab](http://madry-lab.ml/) public [`robustness`](https://github.com/MadryLab/robustness) library, as well as the [`glm_saga`](https://github.com/MadryLab/glm_saga) library, both of which are installed automatically when you follow the instructions below. The [`glm_saga`](https://github.com/MadryLab/glm_saga) library contains a standalone implementation of our sparse GLM solver.*
26 | 1. Clone our repo: `git clone https://github.com/microsoft/DebuggableDeepNetworks.git`
27 |
28 | 2. Setup the [lucent](https://github.com/greentfrapp/lucent) submodule using: `git submodule update --init --recursive`
29 |
30 | 3. We recommend using conda for dependencies:
31 | ```
32 | conda env create -f environment.yml
33 | conda activate debuggable
34 | ```
35 |
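If you prefer pip, the repository also ships a `requirements.txt` that should cover the same dependencies (a sketch of an alternative setup; the conda environment above remains the documented one):
```
pip install -r requirements.txt
```
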
36 | ## Training sparse decision layers
37 |
38 | Contents:
39 | + `main.py` fits a sparse decision layer on top of the deep features of the specified pre-trained (language/vision) deep network
40 | + `helpers/` has some helper functions for loading datasets, models, and features
41 | + `language/` has some additional code for handling language models and datasets
42 |
43 | To run the settings in our paper, you can use the following commands:
44 | ```
45 | # Sentiment classification
46 | python main.py --dataset sst --dataset-path <path-to-dataset> --dataset-type language --model-path barissayil/bert-sentiment-analysis-sst --arch bert --out-path ./tmp/sst/ --cache
47 |
48 | # Toxic comment classification (biased)
49 | python main.py --dataset jigsaw-toxic --dataset-path <path-to-dataset> --dataset-type language --model-path unitary/toxic-bert --arch bert --out-path ./tmp/jigsaw-toxic/ --cache --balance
50 |
51 | # Toxic comment classification (unbiased)
52 | python main.py --dataset jigsaw-alt-toxic --dataset-path <path-to-dataset> --dataset-type language --model-path unitary/unbiased-toxic-roberta --arch roberta --out-path ./tmp/unbiased-jigsaw-toxic/ --cache --balance
53 |
54 | # Places-10
55 | python main.py --dataset places-10 --dataset-path <path-to-dataset> --dataset-type vision --model-path <path-to-model> --arch resnet50 --out-path ./tmp/places/ --cache
56 |
57 | # ImageNet
58 | python main.py --dataset imagenet --dataset-path <path-to-dataset> --dataset-type vision --model-path <path-to-model> --arch resnet50 --out-path ./tmp/imagenet/ --cache
59 | ```
60 |
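With `--cache`, the computed deep features are written to disk and can be reloaded directly through the repository's helpers. A minimal sketch, assuming the cache directory below (illustrative) matches where `main.py` stored the features under `--out-path`:

```python
# Minimal sketch: reload cached deep features via the repo's helper.
# './tmp/sst/features' is a hypothetical location; point it at the cache
# written by main.py (it must contain features_{train,test}/ and metadata_train.pth).
from helpers.feature_helpers import load_features_mode

features, feature_mean, feature_std = load_features_mode(
    './tmp/sst/features', mode='test')
print(features.shape)  # (num_examples, feature_dim)
```
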
61 | ## Interpreting deep features
62 | After fitting a sparse GLM with one of the above commands, we provide some
63 | notebooks for inspecting and visualizing the resulting features. See
64 | `inspect_vision_models.ipynb` and `inspect_language_models.ipynb` for the vision and language settings respectively.
65 |
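Outside the notebooks, the fitted regularization path can also be inspected programmatically with the helpers in `helpers/decisionlayer_helpers.py` and `helpers/vis_helpers.py` (the latter imports the `lucent` submodule, so complete the setup step first). A minimal sketch, where the checkpoint directory is illustrative and should contain the `params{i}.pth` files saved during training:

```python
# Minimal sketch: load the regularization path, select a sparse decision layer,
# and plot the sparsity/accuracy trade-off. './tmp/sst/glm_path' is hypothetical.
from helpers.decisionlayer_helpers import load_glm, select_sparse_model
from helpers.vis_helpers import plot_sparsity

results = load_glm('./tmp/sst/glm_path')    # directory with params{i}.pth checkpoints
results = select_sparse_model(results,      # adds 'weight_sparse'/'bias_sparse' entries
                              selection_criterion='absolute',
                              factor=6)     # ~6 accuracy points below the best val accuracy
plot_sparsity(results)                      # accuracy vs. regularization index and sparsity
```
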
66 | # Maintainers
67 |
68 | * [Eric Wong](https://twitter.com/RICEric22)
69 | * [Shibani Santurkar](https://twitter.com/ShibaniSan)
70 | * [Aleksander Madry](https://twitter.com/aleks_madry)
71 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: debuggable
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - anaconda
6 | - defaults
7 | dependencies:
8 | - backcall=0.2.0=py_0
9 | - blas=1.0=mkl
10 | - ca-certificates=2021.4.13=h06a4308_1
11 | - certifi=2020.12.5=py37h06a4308_0
12 | - cudatoolkit=10.2.89=hfd86e86_1
13 | - cycler=0.10.0=py_2
14 | - dbus=1.13.18=hb2f20db_0
15 | - decorator=4.4.2=py_0
16 | - expat=2.2.10=he6710b0_2
17 | - fontconfig=2.13.0=h9420a91_0
18 | - freetype=2.10.4=h5ab3b9f_0
19 | - glib=2.66.1=h92f7085_0
20 | - gst-plugins-base=1.14.0=hbbd80ab_1
21 | - gstreamer=1.14.0=hb31296c_0
22 | - icu=58.2=he6710b0_3
23 | - intel-openmp=2020.2=254
24 | - ipykernel=5.3.4=py37h5ca1d4c_0
25 | - ipython=7.18.1=py37h5ca1d4c_0
26 | - ipython_genutils=0.2.0=py37_0
27 | - jedi=0.17.2=py37_0
28 | - jpeg=9b=habf39ab_1
29 | - jupyter_client=6.1.7=py_0
30 | - jupyter_core=4.6.3=py37_0
31 | - kiwisolver=1.3.1=py37hc928c03_0
32 | - lcms2=2.11=h396b838_0
33 | - ld_impl_linux-64=2.33.1=h53a641e_7
34 | - libedit=3.1.20191231=h14c3975_1
35 | - libffi=3.3=he6710b0_2
36 | - libgcc-ng=9.1.0=hdf63c60_0
37 | - libgfortran-ng=7.3.0=hdf63c60_0
38 | - libpng=1.6.37=hbc83047_0
39 | - libprotobuf=3.13.0.1=h8b12597_0
40 | - libsodium=1.0.18=h7b6447c_0
41 | - libstdcxx-ng=9.1.0=hdf63c60_0
42 | - libtiff=4.1.0=h2733197_1
43 | - libuuid=1.0.3=h1bed415_2
44 | - libuv=1.40.0=h7b6447c_0
45 | - libxcb=1.14=h7b6447c_0
46 | - libxml2=2.9.10=hb55368b_3
47 | - lz4-c=1.9.2=heb0550a_3
48 | - matplotlib=3.3.2=h06a4308_0
49 | - matplotlib-base=3.3.2=py37h817c723_0
50 | - mkl=2019.4=243
51 | - mkl-service=2.3.0=py37he904b0f_0
52 | - mkl_fft=1.2.0=py37h23d657b_0
53 | - mkl_random=1.0.4=py37hd81dba3_0
54 | - ncurses=6.2=he6710b0_1
55 | - ninja=1.10.1=py37hfd86e86_0
56 | - numpy=1.19.1=py37hbc911f0_0
57 | - numpy-base=1.19.1=py37hfa32c7d_0
58 | - olefile=0.46=py37_0
59 | - openssl=1.1.1k=h27cfd23_0
60 | - pandas=1.1.3=py37he6710b0_0
61 | - parso=0.7.0=py_0
62 | - pcre=8.44=he6710b0_0
63 | - pexpect=4.8.0=py37_1
64 | - pickleshare=0.7.5=py37_1001
65 | - pillow=8.0.0=py37h9a89aac_0
66 | - pip=20.2.4=py37_0
67 | - prompt-toolkit=3.0.8=py_0
68 | - ptyprocess=0.6.0=py37_0
69 | - pygments=2.7.1=py_0
70 | - pyparsing=2.4.7=pyh9f0ad1d_0
71 | - pyqt=5.9.2=py37h05f1152_2
72 | - python=3.7.9=h7579374_0
73 | - python-dateutil=2.8.1=py_0
74 | - python_abi=3.7=1_cp37m
75 | - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0
76 | - pytz=2020.1=py_0
77 | - pyzmq=19.0.2=py37he6710b0_1
78 | - qt=5.9.7=h5867ecd_1
79 | - readline=8.0=h7b6447c_0
80 | - seaborn=0.11.0=py_0
81 | - setuptools=50.3.0=py37hb0f4dca_1
82 | - sip=4.19.8=py37hf484d3e_0
83 | - six=1.15.0=py_0
84 | - sqlite=3.33.0=h62c20be_0
85 | - tensorboardx=2.1=py_0
86 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
87 | - tk=8.6.10=hbc83047_0
88 | - torchvision=0.8.1=py37_cu102
89 | - tornado=6.0.4=py37h7b6447c_1
90 | - traitlets=5.0.5=py_0
91 | - typing_extensions=3.7.4.3=py_0
92 | - wcwidth=0.2.5=py_0
93 | - wheel=0.35.1=py_0
94 | - wordcloud=1.8.1=py37h4abf009_1
95 | - xz=5.2.5=h7b6447c_0
96 | - zeromq=4.3.3=he6710b0_3
97 | - zlib=1.2.11=h7b6447c_3
98 | - zstd=1.4.4=h0b5b093_3
99 | - pip:
100 | - argon2-cffi==20.1.0
101 | - async-generator==1.10
102 | - attrs==20.3.0
103 | - bleach==3.3.0
104 | - cffi==1.14.5
105 | - chardet==3.0.4
106 | - click==7.1.2
107 | - coverage==5.5
108 | - coveralls==3.0.1
109 | - cox==0.1.post3
110 | - dataclasses==0.6
111 | - datasets==1.1.3
112 | - defusedxml==0.7.1
113 | - dill==0.3.3
114 | - docopt==0.6.2
115 | - entrypoints==0.3
116 | - filelock==3.0.12
117 | - future==0.18.2
118 | - gitdb==4.0.5
119 | - gitpython==3.1.11
120 | - glm-saga==0.1.1
121 | - grpcio==1.34.0
122 | - idna==2.10
123 | - imageio==2.9.0
124 | - importlib-metadata==3.7.3
125 | - iniconfig==1.1.1
126 | - ipywidgets==7.6.3
127 | - jinja2==2.11.3
128 | - joblib==0.17.0
129 | - jsonschema==3.2.0
130 | - jupyterlab-pygments==0.1.2
131 | - jupyterlab-widgets==1.0.0
132 | - kornia==0.4.1
133 | - lime==0.2.0.1
134 | - markupsafe==1.1.1
135 | - mistune==0.8.4
136 | - multiprocess==0.70.11.1
137 | - nbclient==0.5.3
138 | - nbconvert==6.0.7
139 | - nbformat==5.1.2
140 | - nest-asyncio==1.5.1
141 | - networkx==2.5
142 | - notebook==6.2.0
143 | - packaging==20.7
144 | - pandocfilters==1.4.3
145 | - pluggy==0.13.1
146 | - prometheus-client==0.9.0
147 | - protobuf==3.14.0
148 | - psutil==5.7.3
149 | - py==1.10.0
150 | - py3nvml==0.2.6
151 | - pyarrow==2.0.0
152 | - pycparser==2.20
153 | - pyrsistent==0.17.3
154 | - pytest==6.2.4
155 | - pytest-mock==3.6.1
156 | - pywavelets==1.1.1
157 | - regex==2020.11.13
158 | - requests==2.25.0
159 | - robustness
160 | - sacremoses==0.0.43
161 | - scikit-image==0.17.2
162 | - scikit-learn==0.23.2
163 | - scipy==1.5.4
164 | - send2trash==1.5.0
165 | - sentencepiece==0.1.91
166 | - smmap==3.0.4
167 | - terminado==0.9.3
168 | - testpath==0.4.4
169 | - tifffile==2020.11.26
170 | - tokenizers==0.9.3
171 | - toml==0.10.2
172 | - torch-lucent==0.1.8
173 | - tqdm==4.49.0
174 | - transformers==3.5.1
175 | - urllib3==1.26.2
176 | - webencodings==0.5.1
177 | - widgetsnbextension==3.5.1
178 | - xmltodict==0.12.0
179 | - xxhash==2.0.0
180 | - zipp==3.4.1
181 |
--------------------------------------------------------------------------------
/helpers/data_helpers.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | sys.path.append('..')
3 | import numpy as np
4 | import torch as ch
5 | from torch.utils.data import TensorDataset
6 | from robustness.datasets import DATASETS as VISION_DATASETS
7 | from robustness.tools.label_maps import CLASS_DICT
8 | from language.datasets import DATASETS as LANGUAGE_DATASETS
9 | from language.models import LANGUAGE_MODEL_DICT
10 | from transformers import AutoTokenizer
11 |
12 | def get_label_mapping(dataset_name):
13 | if dataset_name == 'imagenet':
14 | return CLASS_DICT['ImageNet']
15 | elif dataset_name == 'places-10':
16 | return CD_PLACES
17 | elif dataset_name == 'sst':
18 | return {0: 'negative', 1: 'positive'}
19 | elif 'jigsaw' in dataset_name:
20 | category = dataset_name.split('jigsaw-')[1] if 'alt' not in dataset_name \
21 | else dataset_name.split('jigsaw-alt-')[1]
22 | return {0: f'not {category}', 1: f'{category}'}
23 | else:
24 | raise ValueError("Dataset not currently supported...")
25 |
26 | def load_dataset(dataset_name, dataset_path, dataset_type,
27 | batch_size, num_workers,
28 | maxlen_train=256, maxlen_val=256,
29 | shuffle=False, model_path=None, return_sentences=False):
30 |
31 |
32 | if dataset_type == 'vision':
33 | if dataset_name == 'places-10': dataset_name = 'places365'
34 | if dataset_name not in VISION_DATASETS:
35 | raise ValueError("Vision dataset not currently supported...")
36 | dataset = VISION_DATASETS[dataset_name](os.path.expandvars(dataset_path))
37 |
38 | if dataset_name == 'places365':
39 | dataset.num_classes = 10
40 |
41 | train_loader, test_loader = dataset.make_loaders(num_workers,
42 | batch_size,
43 | data_aug=False,
44 | shuffle_train=shuffle,
45 | shuffle_val=shuffle)
46 | return dataset, train_loader, test_loader
47 | else:
48 | if model_path is None:
49 | model_path = LANGUAGE_MODEL_DICT[dataset_name]
50 |
51 | tokenizer = AutoTokenizer.from_pretrained(model_path)
52 |
53 | kwargs = {} if 'jigsaw' not in dataset_name else \
54 | {'label': dataset_name[11:] if 'alt' in dataset_name \
55 | else dataset_name[7:]}
56 | kwargs['return_sentences'] = return_sentences
57 | train_set = LANGUAGE_DATASETS(dataset_name)(filename=f'{dataset_path}/train.tsv',
58 | maxlen=maxlen_train,
59 | tokenizer=tokenizer,
60 | **kwargs)
61 | test_set = LANGUAGE_DATASETS(dataset_name)(filename=f'{dataset_path}/test.tsv',
62 | maxlen=maxlen_val,
63 | tokenizer=tokenizer,
64 | **kwargs)
65 | train_loader = ch.utils.data.DataLoader(dataset=train_set,
66 | batch_size=batch_size,
67 | num_workers=num_workers)
68 | test_loader = ch.utils.data.DataLoader(dataset=test_set,
69 | batch_size=batch_size,
70 | num_workers=num_workers)
71 | #assert len(np.unique(train_set.df['label'].values)) == len(np.unique(test_set.df['label'].values))
72 | train_set.num_classes = 2
73 | # train_loader.dataset.targets = train_loader.dataset.df['label'].values
74 | # test_loader.dataset.targets = test_loader.dataset.df['label'].values
75 |
76 | return train_set, train_loader, test_loader
77 |
78 |
79 | class IndexedTensorDataset(ch.utils.data.TensorDataset):
80 | def __getitem__(self, index):
81 | val = super(IndexedTensorDataset, self).__getitem__(index)
82 | return val + (index,)
83 |
84 | class IndexedDataset(ch.utils.data.Dataset):
85 | def __init__(self, ds, sample_weight=None):
86 | super(IndexedDataset, self).__init__()
87 | self.dataset = ds
88 | self.sample_weight=sample_weight
89 |
90 | def __getitem__(self, index):
91 | val = self.dataset[index]
92 | if self.sample_weight is None:
93 | return val + (index,)
94 | else:
95 | weight = self.sample_weight[index]
96 | return val + (weight,index)
97 | def __len__(self):
98 | return len(self.dataset)
99 |
100 | def add_index_to_dataloader(loader, sample_weight=None):
101 | return ch.utils.data.DataLoader(
102 | IndexedDataset(loader.dataset, sample_weight=sample_weight),
103 | batch_size=loader.batch_size,
104 | sampler=loader.sampler,
105 | num_workers=loader.num_workers,
106 | collate_fn=loader.collate_fn,
107 | pin_memory=loader.pin_memory,
108 | drop_last=loader.drop_last,
109 | timeout=loader.timeout,
110 | worker_init_fn=loader.worker_init_fn,
111 | multiprocessing_context=loader.multiprocessing_context
112 | )
113 |
114 | class NormalizedRepresentation(ch.nn.Module):
115 | def __init__(self, loader, metadata, device='cuda', tol=1e-5):
116 | super(NormalizedRepresentation, self).__init__()
117 |
118 |
119 | assert metadata is not None
120 | self.device = device
121 | self.mu = metadata['X']['mean']
122 | self.sigma = ch.clamp(metadata['X']['std'], tol)
123 |
124 | def forward(self, X):
125 | return (X - self.mu.to(self.device))/self.sigma.to(self.device)
126 |
127 | CD_PLACES = {0: 'airport_terminal',
128 | 1: 'boat_deck',
129 | 2: 'bridge',
130 | 3: 'butchers_shop',
131 | 4: 'church-outdoor',
132 | 5: 'hotel_room',
133 | 6: 'laundromat',
134 | 7: 'river',
135 | 8: 'ski_slope',
136 | 9: 'volcano'}
137 |
--------------------------------------------------------------------------------
/helpers/decisionlayer_helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch as ch
4 | import pandas as pd
5 |
6 | def load_glm(result_dir):
7 |
8 | Nlambda = max([int(f.split('params')[1].split('.pth')[0])
9 | for f in os.listdir(result_dir) if 'params' in f]) + 1
10 |
11 | print(f"Loading regularization path of length {Nlambda}")
12 |
13 | params_dict = {i: ch.load(os.path.join(result_dir, f"params{i}.pth"),
14 | map_location=ch.device('cpu')) for i in range(Nlambda)}
15 |
16 | regularization_strengths = [params_dict[i]['lam'].item() for i in range(Nlambda)]
17 | weights = [params_dict[i]['weight'] for i in range(Nlambda)]
18 | biases = [params_dict[i]['bias'] for i in range(Nlambda)]
19 |
20 | metrics = {'acc_tr': [], 'acc_val': [], 'acc_test': []}
21 |
22 | for k in metrics.keys():
23 | for i in range(Nlambda):
24 | metrics[k].append(params_dict[i]['metrics'][k])
25 | metrics[k] = 100 * np.stack(metrics[k])
26 | metrics = pd.DataFrame(metrics)
27 | metrics = metrics.rename(columns={'acc_tr': 'acc_train'})
28 |
29 | weights_stacked = ch.stack(weights)
30 | sparsity = ch.sum(weights_stacked != 0, dim=2).numpy()
31 |
32 | return {'metrics': metrics,
33 | 'regularization_strengths': regularization_strengths,
34 | 'weights': weights,
35 | 'biases': biases,
36 | 'sparsity': sparsity,
37 | 'weight_dense': weights[-1],
38 | 'bias_dense': biases[-1]}
39 |
40 | def select_sparse_model(result_dict,
41 | selection_criterion='absolute',
42 | factor=6):
43 |
44 | assert selection_criterion in ['sparsity', 'absolute', 'relative', 'percentile']
45 |
46 | metrics, sparsity = result_dict['metrics'], result_dict['sparsity']
47 |
48 | acc_val, acc_test = metrics['acc_val'], metrics['acc_test']
49 |
50 | if factor == 0:
51 | sel_idx = -1
52 | elif selection_criterion == 'sparsity':
53 | sel_idx = np.argmin(np.abs(np.mean(sparsity, axis=1) - factor))
54 | elif selection_criterion == 'relative':
55 | sel_idx = np.argmin(np.abs(acc_val - factor * np.max(acc_val)))
56 | elif selection_criterion == 'absolute':
57 | delta = acc_val - (np.max(acc_val) - factor)
58 | lidx = np.where(delta <= 0)[0]
59 | sel_idx = lidx[np.argmin(-delta[lidx])]
60 | elif selection_criterion == 'percentile':
61 | diff = np.max(acc_val) - np.min(acc_val)
62 | sel_idx = np.argmax(acc_val > np.max(acc_val) - factor * diff)
63 |
64 | print(f"Test accuracy | Best: {max(acc_test): .2f},",
65 | f"Sparse: {acc_test[sel_idx]:.2f}",
66 | f"Sparsity: {np.mean(sparsity[sel_idx]):.2f}")
67 |
68 | result_dict.update({'weight_sparse': result_dict['weights'][sel_idx],
69 | 'bias_sparse': result_dict['biases'][sel_idx]})
70 | return result_dict
--------------------------------------------------------------------------------
/helpers/feature_helpers.py:
--------------------------------------------------------------------------------
1 | import os, math, sys
2 | sys.path.append('..')
3 | import numpy as np
4 | import torch as ch
5 | from torch._utils import _accumulate
6 | from torch.utils.data import Subset
7 | from tqdm import tqdm
8 | from robustness.model_utils import make_and_restore_model
9 | from robustness.loaders import LambdaLoader
10 | from transformers import AutoConfig
11 | import language.models as lm
12 |
13 | def load_model(model_root, arch, dataset, dataset_name,
14 | dataset_type, device='cuda'):
15 | """Loads existing vision/language models.
16 | Args:
17 | model_root (str): Path to model
18 | arch (str): Model architecture
19 | dataset (torch dataset): Dataset on which the model was
20 | trained
21 | dataset_name (str): Name of dataset
22 | dataset_type (str): One of vision or language
23 | device (str): Device on which to keep the model
24 | Returns:
25 | model: Torch model
26 | pooled_output (bool): Whether or not to pool outputs
27 | (only relevant for some language models)
28 | """
29 |
30 | if model_root is None and dataset_type == 'language':
31 | model_root = lm.LANGUAGE_MODEL_DICT[dataset_name]
32 |
33 | pooled_output = None
34 | if dataset_type == 'vision':
35 | model, _ = make_and_restore_model(arch=arch,
36 | dataset=dataset,
37 | resume_path=model_root,
38 | pytorch_pretrained=(model_root is None)
39 | )
40 | else:
41 | config = AutoConfig.from_pretrained(model_root)
42 | if config.model_type == 'bert':
43 | if model_root == 'barissayil/bert-sentiment-analysis-sst':
44 | model = lm.BertForSentimentClassification.from_pretrained(model_root)
45 | pooled_output = False
46 | else:
47 | model = lm.BertForSequenceClassification.from_pretrained(model_root)
48 | pooled_output = True
49 | elif config.model_type == 'roberta':
50 | model = lm.RobertaForSequenceClassification.from_pretrained(model_root)
51 | pooled_output = False
52 | else:
53 | raise ValueError('This transformer model is not supported yet.')
54 |
55 | model.eval()
56 | model = ch.nn.DataParallel(model.to(device))
57 | return model, pooled_output
58 |
59 | def get_features_batch(batch, model, dataset_type, pooled_output=None, device='cuda'):
60 |
61 | if dataset_type == 'vision':
62 | ims, targets = batch
63 | (_,latents), _ = model(ims.to(device), with_latent=True)
64 | else:
65 | (input_ids, attention_mask, targets) = batch
66 | mask = targets != -1
67 | input_ids, attention_mask, targets = [t[mask] for t in (input_ids, attention_mask, targets)]
68 |
69 | if hasattr(model, 'module'):
70 | model = model.module
71 | if hasattr(model, "roberta"):
72 | latents = model.roberta(input_ids=input_ids.to(device),
73 | attention_mask=attention_mask.to(device))[0]
74 | latents = model.classifier.dropout(latents[:,0,:])
75 | latents = model.classifier.dense(latents)
76 | latents = ch.tanh(latents)
77 | latents = model.classifier.dropout(latents)
78 | else:
79 | latents = model.bert(input_ids=input_ids.to(device),
80 | attention_mask=attention_mask.to(device))
81 | if pooled_output:
82 | latents = latents[1]
83 | else:
84 | latents = latents[0][:,0]
85 | return latents, targets
86 |
87 | def compute_features(loader, model, dataset_type, pooled_output,
88 | batch_size, num_workers,
89 | shuffle=False, device='cuda',
90 | filename=None, chunk_threshold=20000, balance=False):
91 |
92 | """Compute deep features for a given dataset using a modeln and returnss
93 | them as a pytorch dataset and loader.
94 | Args:
95 | loader : Torch data loader
96 | model: Torch model
97 | dataset_type (str): One of vision or language
98 | pooled_output (bool): Whether or not to pool outputs
99 | (only relevant for some language models)
100 | batch_size (int): Batch size for output loader
101 | num_workers (int): Number of workers to use for output loader
102 | shuffle (bool): Whether or not to shuffle output data loader
103 | device (str): Device on which to keep the model
104 | filename (str): Optional filepath to cache computed features. Recommended
105 | for large datasets like ImageNet.
106 | chunk_threshold (int): Size of shard while caching
107 | balance (bool): Whether or not to balance output data loader
108 | (only relevant for some language models)
109 | Returns:
110 | feature_dataset: Torch dataset with deep features
111 | feature_loader: Torch data loader with deep features
112 | """
113 |
114 | if filename is None or not os.path.exists(os.path.join(filename, f'0_features.npy')):
115 |
116 | all_latents, all_targets = [], []
117 | Nsamples, chunk_id = 0, 0
118 |
119 | for batch_idx, batch in tqdm(enumerate(loader), total=len(loader)):
120 |
121 | with ch.no_grad():
122 | latents, targets = get_features_batch(batch, model, dataset_type,
123 | pooled_output=pooled_output,
124 | device=device)
125 |
126 | if batch_idx == 0:
127 | print("Latents shape", latents.shape)
128 |
129 |
130 | Nsamples += latents.size(0)
131 |
132 | all_latents.append(latents.cpu())
133 | all_targets.append(targets.cpu())
134 |
135 | if filename is not None and Nsamples > chunk_threshold:
136 | if not os.path.exists(filename): os.makedirs(filename)
137 | np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy())
138 | np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy())
139 | all_latents, all_targets, Nsamples = [], [], 0
140 | chunk_id += 1
141 |
142 | if filename is not None and Nsamples > 0:
143 | if not os.path.exists(filename): os.makedirs(filename)
144 | np.save(os.path.join(filename, f'{chunk_id}_features.npy'), ch.cat(all_latents).numpy())
145 | np.save(os.path.join(filename, f'{chunk_id}_labels.npy'), ch.cat(all_targets).numpy())
146 |
147 |
148 | feature_dataset = load_features(filename) if filename is not None else \
149 | ch.utils.data.TensorDataset(ch.cat(all_latents), ch.cat(all_targets))
150 | if balance:
151 | feature_dataset = balance_dataset(feature_dataset)
152 |
153 | feature_loader = ch.utils.data.DataLoader(feature_dataset,
154 | num_workers=num_workers,
155 | batch_size=batch_size,
156 | shuffle=shuffle)
157 |
158 | return feature_dataset, feature_loader
159 |
160 | def balance_dataset(dataset):
161 | """Balances a given dataset to have the same number of samples/class.
162 | Args:
163 | dataset : Torch dataset
164 | Returns:
165 | Torch dataset with equal number of samples/class
166 | """
167 |
168 | print("Balancing dataset...")
169 | n = len(dataset)
170 | labels = ch.Tensor([dataset[i][1] for i in range(n)]).int()
171 | n0 = sum(labels).item()
172 | I_pos = labels == 1
173 |
174 | idx = ch.arange(n)
175 | idx_pos = idx[I_pos]
176 | ch.manual_seed(0)
177 | I = ch.randperm(n - n0)[:n0]
178 | idx_neg = idx[~I_pos][I]
179 | idx_bal = ch.cat([idx_pos, idx_neg],dim=0)
180 | return Subset(dataset, idx_bal)
181 |
182 | def load_features_mode(feature_path, mode='test',
183 | num_workers=10, batch_size=128):
184 | """Loads precomputed deep features corresponding to the
185 | train/test set along with normalization statistics.
186 | Args:
187 | feature_path (str): Path to precomputed deep features
188 | mode (str): One of train or test
189 | num_workers (int): Number of workers to use for output loader
190 | batch_size (int): Batch size for output loader
191 |
192 | Returns:
193 | features (np.array): Recovered deep features
194 | feature_mean: Mean of deep features
195 | feature_std: Standard deviation of deep features
196 | """
197 | feature_dataset = load_features(os.path.join(feature_path, f'features_{mode}'))
198 | feature_loader = ch.utils.data.DataLoader(feature_dataset,
199 | num_workers=num_workers,
200 | batch_size=batch_size,
201 | shuffle=False)
202 |
203 | feature_metadata = ch.load(os.path.join(feature_path, f'metadata_train.pth'))
204 | feature_mean, feature_std = feature_metadata['X']['mean'], feature_metadata['X']['std']
205 |
206 |
207 | features = []
208 |
209 | for _, (feature, _) in tqdm(enumerate(feature_loader), total=len(feature_loader)):
210 | features.append(feature)
211 |
212 | features = ch.cat(features).numpy()
213 | return features, feature_mean, feature_std
214 |
215 | def load_features(feature_path):
216 | """Loads precomputed deep features.
217 | Args:
218 | feature_path (str): Path to precomputed deep features
219 |
220 | Returns:
221 | Torch dataset with recovered deep features.
222 | """
223 | if not os.path.exists(os.path.join(feature_path, f"0_features.npy")):
224 | raise ValueError(f"The provided location {feature_path} does not contain any representation files")
225 |
226 | ds_list, chunk_id = [], 0
227 | while os.path.exists(os.path.join(feature_path, f"{chunk_id}_features.npy")):
228 | features = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_features.npy"))).float()
229 | labels = ch.from_numpy(np.load(os.path.join(feature_path, f"{chunk_id}_labels.npy"))).long()
230 | ds_list.append(ch.utils.data.TensorDataset(features, labels))
231 | chunk_id += 1
232 |
233 | print(f"==> loaded {chunk_id} files of representations...")
234 | return ch.utils.data.ConcatDataset(ds_list)
235 |
236 |
237 | def calculate_metadata(loader, num_classes=None, filename=None):
238 | """Calculates mean and standard deviation of the deep features over
239 | a given set of images.
240 | Args:
241 | loader : torch data loader
242 | num_classes (int): Number of classes in the dataset
243 | filename (str): Optional filepath to cache metadata. Recommended
244 | for large datasets like ImageNet.
245 |
246 | Returns:
247 | metadata (dict): Dictionary with desired statistics.
248 | """
249 |
250 | if filename is not None and os.path.exists(filename):
251 | return ch.load(filename)
252 |
253 | # Calculate number of classes if not given
254 | if num_classes is None:
255 | num_classes = 1
256 | for batch in loader:
257 | y = batch[1]
259 | num_classes = max(num_classes, y.max().item()+1)
260 |
261 | eye = ch.eye(num_classes)
262 |
263 | X_bar, y_bar, y_max, n = 0, 0, 0, 0
264 |
265 | # calculate means and maximum
266 | print("Calculating means")
267 | for X,y in tqdm(loader, total=len(loader)):
268 | X_bar += X.sum(0)
269 | y_bar += eye[y].sum(0)
270 | y_max = max(y_max, y.max())
271 | n += y.size(0)
272 | X_bar = X_bar.float()/n
273 | y_bar = y_bar.float()/n
274 |
275 | # calculate std
276 | X_std, y_std = 0, 0
277 | print("Calculating standard deviations")
278 | for X,y in tqdm(loader, total=len(loader)):
279 | X_std += ((X - X_bar)**2).sum(0)
280 | y_std += ((eye[y] - y_bar)**2).sum(0)
281 | X_std = ch.sqrt(X_std.float()/n)
282 | y_std = ch.sqrt(y_std.float()/n)
283 |
284 | # calculate maximum regularization
285 | inner_products = 0
286 | print("Calculating maximum lambda")
287 | for X,y in tqdm(loader, total=len(loader)):
288 | y_map = (eye[y] - y_bar)/y_std
289 | inner_products += X.t().mm(y_map)*y_std
290 |
291 | inner_products_group = inner_products.norm(p=2,dim=1)
292 |
293 | metadata = {
294 | "X": {
295 | "mean": X_bar,
296 | "std": X_std,
297 | "num_features": X.size()[1:],
298 | "num_examples": n
299 | },
300 | "y": {
301 | "mean": y_bar,
302 | "std": y_std,
303 | "num_classes": y_max+1
304 | },
305 | "max_reg": {
306 | "group": inner_products_group.abs().max().item()/n,
307 | "nongrouped": inner_products.abs().max().item()/n
308 | }
309 | }
310 |
311 | if filename is not None:
312 | ch.save(metadata, filename)
313 |
314 | return metadata
315 |
316 | def split_dataset(dataset, Ntotal, val_frac,
317 | batch_size, num_workers,
318 | random_seed=0, shuffle=True, balance=False):
319 | """Splits a given dataset into train and validation
320 | Args:
321 | dataset : Torch dataset
322 | Ntotal: Total number of dataset samples
323 | val_frac: Fraction to reserve for validation
324 | batch_size (int): Batch size for output loader
325 | num_workers (int): Number of workers to use for output loader
326 | random_seed (int): Random seed
327 | shuffle (bool): Whether or not to shuffle output data loader
328 | balance (bool): Whether or not to balance output data loader
329 | (only relevant for some language models)
330 |
331 | Returns:
332 | split_datasets (list): List of datasets (one each for train and val)
333 | split_loaders (list): List of loaders (one each for train and val)
334 | """
335 |
336 | Nval = math.floor(Ntotal*val_frac)
337 | train_ds, val_ds = ch.utils.data.random_split(dataset,
338 | [Ntotal - Nval, Nval],
339 | generator=ch.Generator().manual_seed(random_seed))
340 | if balance:
341 | val_ds = balance_dataset(val_ds)
342 | split_datasets = [train_ds, val_ds]
343 |
344 | split_loaders = []
345 | for ds in split_datasets:
346 | split_loaders.append(ch.utils.data.DataLoader(ds,
347 | num_workers=num_workers,
348 | batch_size=batch_size,
349 | shuffle=shuffle))
350 | return split_datasets, split_loaders
351 |
352 |
353 |
--------------------------------------------------------------------------------
/helpers/nlp_helpers.py:
--------------------------------------------------------------------------------
1 | import torch as ch
2 | import numpy as np
3 |
4 | from lime.lime_text import LimeTextExplainer
5 | from tqdm import tqdm
6 | from wordcloud import WordCloud
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 | import seaborn as sns
10 | sns.set_style('darkgrid')
11 |
12 | import os
13 | from collections import defaultdict
14 |
15 | def make_lime_fn(model,val_set, pooled_output, mu, std, bs=128):
16 | device = 'cuda'
17 | def classifier_fn(sentences):
18 | try:
19 | input_ids, attention_mask = zip(*[val_set.process_sentence(s) for s in sentences])
20 | except:
21 | input_ids, attention_mask = zip(*[val_set.dataset.process_sentence(s) for s in sentences])
22 | input_ids, attention_mask = ch.stack(input_ids), ch.stack(attention_mask)
23 | all_reps = []
24 | n = input_ids.size(0)
25 | # bs = args.batch_size
26 | for i in range(0,input_ids.size(0),bs):
27 | i0 = min(i+bs,n)
28 | # reps, _ =
29 | if hasattr(model, "roberta"):
30 | output = model.roberta(input_ids=input_ids[i:i0].to(device), attention_mask=attention_mask[i:i0].to(device))
31 | output = output[0]
32 |
33 | # do RobertA classification head minus last out_proj classifier
34 | # https://huggingface.co/transformers/_modules/transformers/models/roberta/modeling_roberta.html
35 | output = output[:,0,:]
36 | output = model.classifier.dropout(output)
37 | output = model.classifier.dense(output)
38 | output = ch.tanh(output)
39 | cls_reps = model.classifier.dropout(output)
40 | else:
41 | output = model.bert(input_ids=input_ids[i:i0].to(device), attention_mask=attention_mask[i:i0].to(device))
42 | if pooled_output:
43 | cls_reps = output[1]
44 | else:
45 | cls_reps = output[0][:,0]
46 | cls_reps = cls_reps.cpu()
47 | cls_reps = (cls_reps - mu)/std
48 | all_reps.append(cls_reps.cpu())
49 | return ch.cat(all_reps,dim=0).numpy()
50 | return classifier_fn
51 |
52 | def get_lime_features(model, val_set, val_loader, out_dir, pooled_output, mu, std):
53 | os.makedirs(out_dir,exist_ok=True)
54 | explainer = LimeTextExplainer()
55 | reps_size = 768
56 |
57 | clf_fn = make_lime_fn(model,val_set, pooled_output, mu, std)
58 | files = []
59 | with ch.no_grad():
60 | print('number of sentences', len(val_loader))
61 | for i,(sentences, labels) in enumerate(tqdm(val_loader, total=len(val_loader), desc="Generating LIME")):
62 | assert len(sentences) == 1
63 | sentence, label = sentences[0], labels[0]
64 | if label.item() == -1:
65 | continue
66 | out_file = f"{out_dir}/{i}.pth"
67 | try:
68 | files.append(ch.load(out_file))
69 | continue
70 | except:
71 | pass
72 | # if os.path.exists(out_file):
73 | # continue
74 | exp = explainer.explain_instance(sentence, clf_fn, labels=list(range(reps_size)))
75 | out = {
76 | "sentence": sentence,
77 | "explanation": exp
78 | }
79 | ch.save(out, out_file)
80 | files.append(out)
81 | return files
82 |
83 | def top_and_bottom_words(files, num_features=768):
84 | top_words = []
85 | bot_words = []
86 | for j in range(num_features):
87 | exps = [f['explanation'].as_list(label=j) for f in files]
88 | exps_collapsed = [a for e in exps for a in e]
89 |
90 | accumulator = defaultdict(lambda: [])
91 | for word,weight in exps_collapsed:
92 | accumulator[word].append(weight)
93 | exps_collapsed = [(k,np.array(accumulator[k]).mean()) for k in accumulator]
94 |
95 |
96 | exps_collapsed.sort(key=lambda a: a[1])
97 |
98 | weights = [a[1] for a in exps_collapsed]
99 | l = np.percentile(weights, q=1)
100 | u = np.percentile(weights, q=99)
101 | top_words.append([a for a in reversed(exps_collapsed) if a[1] > u])
102 | bot_words.append([a for a in exps_collapsed if a[1] < l])
103 | return top_words,bot_words
104 |
105 | def get_explanations(feature_idxs, sparse, top_words, bot_words):
106 | expl_dict = {}
107 | maxfreq = 0
108 | for i,idx in enumerate(feature_idxs):
109 |
110 | expl_dict[idx] = {}
111 |
112 | aligned = sparse[1,idx] > 0
113 | for words, j in zip([top_words, bot_words],[0,1]):
114 | if not aligned:
115 | j = 1-j
116 | expl_dict[idx][j] = {
117 | a[0]:abs(a[1]) for a in words[idx] ## ARE THESE SIGNS CORRECT
118 | }
119 | maxfreq = max(maxfreq, max(list(expl_dict[idx][j].values())))
120 |
121 | for k in expl_dict:
122 | for s in expl_dict[k]:
123 | expl_dict[k][s] = {kk: vv / maxfreq for kk, vv in expl_dict[k][s].items()}
124 | return expl_dict
125 |
126 | from matplotlib.colors import ListedColormap
127 |
128 | def grey_color_func(word, font_size, position, orientation, random_state=None,
129 | **kwargs):
130 | #cmap = sns.color_palette("RdYlGn_r", as_cmap=True)
131 | cmap = ListedColormap(sns.color_palette("RdYlGn_r").as_hex())
132 | fs = (font_size - 14) / (42 - 14)
133 |
134 | color_orig = cmap(fs)
135 | color = (255 * np.array(cmap(fs))).astype(np.uint8)
136 |
137 | return tuple(color[:3])
138 |
139 | def plot_wordcloud(expln_dict, weights, factor=3, transpose=False, labels=("Positive Sentiment", "Negative Sentiment")):
140 |
141 | if transpose:
142 | fig, axs = plt.subplots(2, len(expln_dict), figsize=(factor*3.5*len(expln_dict), factor*4), squeeze=False)
143 | else:
144 | fig, axs = plt.subplots(len(expln_dict), 2, figsize=(factor*7, factor*1.5*len(expln_dict)), squeeze=False)
145 |
146 | for i,idx in enumerate(expln_dict.keys()):
147 | if i == 0:
148 | if transpose:
149 | axs[0,0].set_ylabel(labels[0], fontsize=36)
150 | axs[1,0].set_ylabel(labels[1], fontsize=36)
151 | else:
152 | axs[i,0].set_title(labels[0], fontsize=36)
153 | axs[i,1].set_title(labels[1], fontsize=36)
154 |
155 | for j in [0, 1]:
156 | wc = WordCloud(background_color="white", max_words=1000, min_font_size=14,
157 | max_font_size=42)
158 | # generate word cloud
159 | d = {k[:20]:v for k,v in expln_dict[idx][j].items()}
160 | wc.generate_from_frequencies(d)
161 | default_colors = wc.to_array()
162 |
163 | if transpose:
164 | ax = axs[j,i]
165 | else:
166 | ax = axs[i,j]
167 | ax.imshow(wc.recolor(color_func=grey_color_func, random_state=3),
168 | interpolation="bilinear")
169 | ax.set_xticks([])
170 | ax.set_yticks([])
171 | ax.spines['right'].set_visible(False)
172 | ax.spines['top'].set_visible(False)
173 | ax.spines['left'].set_visible(False)
174 | ax.spines['bottom'].set_visible(False)
175 | # ax.set_axis_off()
176 | if not transpose:
177 | axs[i,0].set_ylabel(f"#{idx}\nW={weights[i]:.4f}", fontsize=36)
178 | else:
179 | axs[0,i].set_title(f"#{idx}", fontsize=36)
180 | plt.tight_layout()
181 | # plt.subplots_adjust(left=None, bottom=0.1, right=None, top=0.9, wspace=0.25, hspace=0.0)
182 | return fig, axs
--------------------------------------------------------------------------------
/helpers/vis_helpers.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./lucent')
3 | import torch as ch
4 | import torchvision
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import seaborn as sns
8 | from functools import partial
9 | from lucent.optvis import render, param, transform, objectives
10 | from lime import lime_image
11 |
12 | def plot_sparsity(results):
13 | """Function to visualize the sparsity-accuracy trade-off of regularized decision
14 | layers
15 | Args:
16 | results (dictionary): Appropriately formatted dictionary with regularization
17 | paths and logs of train/val/test accuracy.
18 | """
19 |
20 | if type(results['metrics']['acc_train'].values[0]) == list:
21 | all_tr = 100 * np.array(results['metrics']['acc_train'].values[0])
22 | all_val = 100 * np.array(results['metrics']['acc_val'].values[0])
23 | all_te = 100 * np.array(results['metrics']['acc_test'].values[0])
24 | else:
25 | all_tr = 100 * np.array(results['metrics']['acc_train'].values)
26 | all_val = 100 * np.array(results['metrics']['acc_val'].values)
27 | all_te = 100 * np.array(results['metrics']['acc_test'].values)
28 |
29 | fig, axarr = plt.subplots(1, 2, figsize=(14, 5))
30 | axarr[0].plot(all_tr)
31 | axarr[0].plot(all_val)
32 | axarr[0].plot(all_te)
33 | axarr[0].legend(['Train', 'Val', 'Test'], fontsize=16)
34 | axarr[0].set_ylabel("Accuracy (%)", fontsize=18)
35 | axarr[0].set_xlabel("Regularization index", fontsize=18)
36 |
37 | num_features = results['weights'][0].shape[1]
38 | total_sparsity = np.mean(results['sparsity'], axis=1) / num_features
39 | axarr[1].plot(total_sparsity, all_tr, 'o-')
40 | axarr[1].plot(total_sparsity, all_te, 'o-')
41 | axarr[1].legend(['Train', 'Test'], fontsize=16)
42 | axarr[1].set_ylabel("Accuracy (%)", fontsize=18)
43 | axarr[1].set_xlabel("1 - Sparsity", fontsize=18)
44 | axarr[1].set_xscale('log')
45 |
46 | plt.show()
47 |
48 | def normalize_weight(w):
49 | """Normalizes weights to a unit vector
50 | Args:
51 | w (tensor): Weight vector for a class.
52 | Returns:
53 | Normalized weight vector in the form of a numpy array.
54 | """
55 | return w.numpy() / np.linalg.norm(w.numpy())
56 |
57 | def get_feature_visualization(model, feature_idx, signs):
58 | """Performs feature visualization using Lucid.
59 | Args:
60 | model: deep network whose deep features are to be visualized.
61 | feature_idx: indices of features to visualize.
62 | signs: +/-1 array indicating whether a feature should be maximized/minimized.
63 | Returns:
64 | Batch of feature visualizations.
65 | """
66 | param_f = lambda: param.image(224, batch=len(feature_idx), fft=True, decorrelate=True)
67 | obj = 0
68 | for fi, (f, s) in enumerate(zip(feature_idx, signs)):
69 | obj += s * objectives.channel('avgpool', f, batch=fi)
70 | op = render.render_vis(model.model,
71 | show_inline=False,
72 | objective_f=obj,
73 | param_f=param_f,
74 | thresholds=(512,))[0]
75 | return ch.tensor(op).permute(0, 3, 1, 2)
76 |
77 | def latent_predict(images, model, mean=None, std=None):
78 | """LIME helper function that computes the deep feature representation
79 | for a given batch of images.
80 | Args:
81 | images (tensor): batch of images.
82 | model: deep network whose deep features are to be visualized.
83 | mean (tensor): mean of deep features.
84 | std (tensor): std deviation of deep features.
85 | Returns:
86 | Normalized deep features for batch of images.
87 | """
88 | preprocess_transform = torchvision.transforms.Compose([
89 | torchvision.transforms.ToTensor(),
90 | ])
91 | device = 'cuda' if next(model.parameters()).is_cuda else 'cpu'
92 | batch = ch.stack(tuple(preprocess_transform(i) for i in images), dim=0).to(device)
93 |
94 | (_, latents), _ = model(batch.to(ch.float), with_latent=True)
95 | scaled_latents = (latents.detach().cpu() - mean.to(ch.float)) / std.to(ch.float)
96 | return scaled_latents.numpy()
97 |
98 | def parse_lime_explanation(expln, f, sign, NLime=3):
99 | """LIME helper function that extracts a mask from a lime explanation
100 | Args:
101 | expln: LIME explanation from LIME library
102 | f: index of the feature to visualize
103 | sign: +/-1 indicating whether the feature should be maximized/minimized.
105 | NLime (int): Number of top-superpixels to visualize.
106 | Returns:
107 | Tensor where the first and second channels contains superpixels that cause the
108 | deep feature to activate and deactivate respectively.
109 | """
110 | segs = expln.segments
111 | vis_mask = np.zeros(segs.shape + (3,))
112 |
113 | weights = sorted([v for v in expln.local_exp[f]],
114 | key=lambda x: -np.abs(x[1]))
115 | weight_values = [w[1] for w in weights]
116 | pos_lim, neg_lim = np.max(weight_values), (1e-8 + np.min(weight_values))
117 |
118 | if NLime is not None:
119 | weights = weights[:NLime]
120 |
121 | for wi, w in enumerate(weights):
122 | if w[1] >= 0:
123 | si = (w[1] / pos_lim, 0, 0) if sign == 1 else (0, w[1] / pos_lim, 0)
124 | else:
125 | si = (0, w[1] / neg_lim, 0) if sign == 1 else (w[1] / neg_lim, 0, 0)
126 | vis_mask[segs == w[0]] = si
127 |
128 | return ch.tensor(vis_mask.transpose(2, 0, 1))
129 |
130 | def get_lime_explanation(model, feature_idx, signs,
131 | images, rep_mean, rep_std,
132 | num_samples=1000,
133 | NLime=3,
134 | background_color=0.6):
135 | """Computes LIME explanations for a given set of deep features. The LIME
136 | objective in this case is to identify the superpixels within the specified
137 | images that maximally/minimally activate the corresponding deep feature.
138 | Args:
139 | model: deep network whose deep features are to be visualized.
140 | feature_idx: indices of features to visualize
141 | signs: +/-1 array indicating whether a feature should be maximized/minimized.
142 | images (tensor): batch of images.
143 | rep_mean (tensor): mean of deep features.
144 | rep_std (tensor): std deviation of deep features.
145 | NLime (int): Number of top-superpixels to visualize
146 | background_color (float): Color to assign non-relevant super pixels
147 | Returns:
148 | Tensor comprising LIME explanations for the given set of deep features.
149 | """
150 | explainer = lime_image.LimeImageExplainer()
151 | lime_objective = partial(latent_predict, model=model, mean=rep_mean, std=rep_std)
152 |
153 | explanations = []
154 | for im, feature, sign in zip(images, feature_idx, signs):
155 | explanation = explainer.explain_instance(im.numpy().transpose(1, 2, 0),
156 | lime_objective,
157 | labels=np.array([feature]),
158 | top_labels=None,
159 | hide_color=0,
160 | num_samples=num_samples)
161 | explanation = parse_lime_explanation(explanation,
162 | feature,
163 | sign,
164 | NLime=NLime)
165 |
166 | if sign == 1:
167 | explanation = explanation[:1].unsqueeze(0).repeat(1, 3, 1, 1)
168 | else:
169 | explanation = explanation[1:2].unsqueeze(0).repeat(1, 3, 1, 1)
170 |
171 | interpolated = im * explanation + background_color * ch.ones_like(im) * (1 - explanation)
172 | explanations.append(interpolated)
173 |
174 | return ch.cat(explanations)
--------------------------------------------------------------------------------
/language/datasets.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | from torch.utils.data import Dataset
4 | from transformers import AutoTokenizer, AutoConfig
5 |
6 | from .jigsaw_loaders import JigsawDataOriginal
7 |
8 | def DATASETS(dataset_name):
9 | if dataset_name == 'sst': return SSTDataset
10 | elif dataset_name.startswith('jigsaw'): return JigsawDataset
11 | else:
12 | raise ValueError("Language dataset is not currently supported...")
13 |
14 | class SSTDataset(Dataset):
15 | """
16 | Stanford Sentiment Treebank V1.0
17 | Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank
18 | Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher Manning, Andrew Ng and Christopher Potts
19 | Conference on Empirical Methods in Natural Language Processing (EMNLP 2013)
20 | """
21 | def __init__(self, filename, maxlen, tokenizer, return_sentences=False):
22 | #Store the contents of the file in a pandas dataframe
23 | self.df = pd.read_csv(filename, delimiter = '\t')
24 | #Initialize the tokenizer for the desired transformer model
25 | self.tokenizer = tokenizer
26 | #Maximum length of the tokens list to keep all the sequences of fixed size
27 | self.maxlen = maxlen
28 | #whether to tokenize or return raw sentences
29 | self.return_sentences = return_sentences
30 |
31 | def __len__(self):
32 | return len(self.df)
33 |
34 | def __getitem__(self, index):
35 | #Select the sentence and label at the specified index in the data frame
36 | sentence = self.df.loc[index, 'sentence']
37 | label = self.df.loc[index, 'label']
38 | #Preprocess the text to be suitable for the transformer
39 | if self.return_sentences:
40 | return sentence, label
41 | else:
42 | input_ids, attention_mask = self.process_sentence(sentence)
43 | return input_ids, attention_mask, label
44 |
45 | def process_sentence(self, sentence):
46 | tokens = self.tokenizer.tokenize(sentence)
47 | tokens = ['[CLS]'] + tokens + ['[SEP]']
48 | if len(tokens) < self.maxlen:
49 | tokens = tokens + ['[PAD]' for _ in range(self.maxlen - len(tokens))]
50 | else:
51 | tokens = tokens[:self.maxlen-1] + ['[SEP]']
52 | #Obtain the indices of the tokens in the BERT Vocabulary
53 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
54 | input_ids = torch.tensor(input_ids)
55 | #Obtain the attention mask, i.e. a tensor containing 1s for non-padded tokens and 0s for padded ones
56 | attention_mask = (input_ids != 0).long()
57 | return input_ids, attention_mask
58 |
59 | class JigsawDataset(Dataset):
60 | def __init__(self, filename, maxlen, tokenizer, return_sentences=False, label="toxic"):
61 | classes=[label]
62 | if 'train' in filename:
63 | self.dataset = JigsawDataOriginal(
64 | train_csv_file=filename,
65 | test_csv_file=None,
66 | train=True,
67 | create_val_set=False,
68 | add_test_labels=False,
69 | classes=classes
70 | )
71 | elif 'test' in filename:
72 | self.dataset = JigsawDataOriginal(
73 | train_csv_file=None,
74 | test_csv_file=filename,
75 | train=False,
76 | create_val_set=False,
77 | add_test_labels=True,
78 | classes=classes
79 | )
80 | else:
81 | raise ValueError("Unknown filename {filename}")
82 | # #Store the contents of the file in a pandas dataframe
83 | # self.df = pd.read_csv(filename, header=None, names=['label', 'sentence'])
84 | #Initialize the tokenizer for the desired transformer model
85 | self.tokenizer = tokenizer
86 | #Maximum length of the tokens list to keep all the sequences of fixed size
87 | self.maxlen = maxlen
88 | #whether to tokenize or return raw sentences
89 | self.return_sentences = return_sentences
90 |
91 | def __len__(self):
92 | return len(self.dataset)
93 |
94 | def __getitem__(self, index):
95 | #Select the sentence and label at the specified index in the data frame
96 | sentence, meta = self.dataset[index]
97 | label = meta["multi_target"].squeeze()
98 |
99 | #Preprocess the text to be suitable for the transformer
100 | if self.return_sentences:
101 | return sentence, label
102 | else:
103 | input_ids, attention_mask = self.process_sentence(sentence)
104 | return input_ids, attention_mask, label
105 |
106 | def process_sentence(self, sentence):
107 | # print(sentence)
108 | d = self.tokenizer(sentence, padding='max_length', truncation=True)
109 | input_ids = torch.tensor(d["input_ids"])
110 | attention_mask = torch.tensor(d["attention_mask"])
111 | return input_ids, attention_mask
--------------------------------------------------------------------------------
/language/jigsaw_loaders.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import Dataset
2 | import torch
3 | import pandas as pd
4 | import numpy as np
5 | import datasets  # HuggingFace datasets library (provides Dataset.from_pandas, train_test_split)
6 | from tqdm import tqdm
7 |
8 |
9 | class JigsawData(Dataset):
10 | """Dataloader for the Jigsaw Toxic Comment Classification Challenges.
11 | If test_csv_file is None and create_val_set is True the train file
12 | specified gets split into a train and validation set according to
13 | train_fraction."""
14 |
15 | def __init__(
16 | self,
17 | train_csv_file,
18 | test_csv_file,
19 | train=True,
20 | val_fraction=0.9,
21 | add_test_labels=False,
22 | create_val_set=True,
23 | ):
24 |
25 | if train_csv_file is not None:
26 | if isinstance(train_csv_file, list):
27 | train_set_pd = self.load_data(train_csv_file)
28 | else:
29 | train_set_pd = pd.read_csv(train_csv_file)
30 | self.train_set_pd = train_set_pd
31 | if "toxicity" not in train_set_pd.columns:
32 | train_set_pd.rename(columns={"target": "toxicity"}, inplace=True)
33 | self.train_set = datasets.Dataset.from_pandas(train_set_pd)
34 |
35 | if create_val_set:
36 | data = self.train_set.train_test_split(val_fraction)
37 | self.train_set = data["train"]
38 | self.val_set = data["test"]
39 |
40 | if test_csv_file is not None:
41 | val_set = pd.read_csv(test_csv_file)
42 | if add_test_labels:
43 | data_labels = pd.read_csv(test_csv_file[:-4] + "_labels.csv")
44 | for category in data_labels.columns[1:]:
45 | val_set[category] = data_labels[category]
46 | val_set = datasets.Dataset.from_pandas(val_set)
47 | self.val_set = val_set
48 |
49 | if train:
50 | self.data = self.train_set
51 | else:
52 | self.data = self.val_set
53 |
54 | self.train = train
55 |
56 | def __len__(self):
57 | return len(self.data)
58 |
59 | def load_data(self, train_csv_file):
60 | files = []
61 | cols = ["id", "comment_text", "toxic"]
62 | for file in tqdm(train_csv_file):
63 | file_df = pd.read_csv(file)
64 | file_df = file_df[cols]
65 | file_df = file_df.astype({"id": "string", "toxic": "float64"})
66 | files.append(file_df)
67 | train = pd.concat(files)
68 | return train
69 |
70 |
71 | class JigsawDataOriginal(JigsawData):
72 | """Dataloader for the original Jigsaw Toxic Comment Classification Challenge.
73 | Source: https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
74 | """
75 |
76 | def __init__(
77 | self,
78 | train_csv_file="jigsaw_data/train.csv",
79 | test_csv_file="jigsaw_data/test.csv",
80 | train=True,
81 | val_fraction=0.1,
82 | create_val_set=True,
83 | add_test_labels=True,
84 | classes=["toxic"],
85 | ):
86 |
87 | super().__init__(
88 | train_csv_file=train_csv_file,
89 | test_csv_file=test_csv_file,
90 | train=train,
91 | val_fraction=val_fraction,
92 | add_test_labels=add_test_labels,
93 | create_val_set=create_val_set,
94 | )
95 | self.classes = classes
96 |
97 | def __getitem__(self, index):
98 | meta = {}
99 | entry = self.data[index]
100 | text_id = entry["id"]
101 | text = entry["comment_text"]
102 |
103 | target_dict = {
104 | label: value for label, value in entry.items() if label in self.classes
105 | }
106 |
107 | meta["multi_target"] = torch.tensor(
108 | list(target_dict.values()), dtype=torch.int32
109 | )
110 | meta["text_id"] = text_id
111 |
112 | return text, meta
113 |
114 |
115 | class JigsawDataBias(JigsawData):
116 | """Dataloader for the Jigsaw Unintended Bias in Toxicity Classification.
117 | Source: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/
118 | """
119 |
120 | def __init__(
121 | self,
122 | train_csv_file="jigsaw_data/train.csv",
123 | test_csv_file="jigsaw_data/test.csv",
124 | train=True,
125 | val_fraction=0.1,
126 | create_val_set=True,
127 | compute_bias_weights=True,
128 | loss_weight=0.75,
129 | classes=["toxic"],
130 | identity_classes=["female"],
131 | ):
132 |
133 | self.classes = classes
134 |
135 | self.identity_classes = identity_classes
136 |
137 | super().__init__(
138 | train_csv_file=train_csv_file,
139 | test_csv_file=test_csv_file,
140 | train=train,
141 | val_fraction=val_fraction,
142 | create_val_set=create_val_set,
143 | )
144 | if train:
145 | if compute_bias_weights:
146 | self.weights = self.compute_weigths(self.train_set_pd)
147 | else:
148 | self.weights = None
149 |
150 | self.train = train
151 | self.loss_weight = loss_weight
152 |
153 | def __getitem__(self, index):
154 | meta = {}
155 | entry = self.data[index]
156 | text_id = entry["id"]
157 | text = entry["comment_text"]
158 |
159 | target_dict = {label: 1 if entry[label] >= 0.5 else 0 for label in self.classes}
160 |
161 | identity_target = {
162 | label: -1 if entry[label] is None else entry[label]
163 | for label in self.identity_classes
164 | }
165 | identity_target.update(
166 | {label: 1 for label in identity_target if identity_target[label] >= 0.5}
167 | )
168 | identity_target.update(
169 | {label: 0 for label in identity_target if 0 <= identity_target[label] < 0.5}
170 | )
171 |
172 | target_dict.update(identity_target)
173 |
174 | meta["multi_target"] = torch.tensor(
175 | list(target_dict.values()), dtype=torch.float32
176 | )
177 | meta["text_id"] = text_id
178 |
179 | if self.train:
180 | meta["weights"] = self.weights[index]
181 | toxic_weight = (
182 | self.weights[index] * self.loss_weight * 1.0 / len(self.classes)
183 | )
184 | identity_weight = (1 - self.loss_weight) * 1.0 / len(self.identity_classes)
185 | meta["weights1"] = torch.tensor(
186 | [
187 | *[toxic_weight] * len(self.classes),
188 | *[identity_weight] * len(self.identity_classes),
189 | ]
190 | )
191 |
192 | return text, meta
193 |
194 | def compute_weigths(self, train_df):
195 | """Inspired from 2nd solution.
196 | Source: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/discussion/100661"""
197 | subgroup_bool = (train_df[self.identity_classes].fillna(0) >= 0.5).sum(
198 | axis=1
199 | ) > 0
200 | positive_bool = train_df["toxicity"] >= 0.5
201 | weights = np.ones(len(train_df)) * 0.25
202 |
203 | # Background Positive and Subgroup Negative
204 | weights[
205 | ((~subgroup_bool) & (positive_bool)) | ((subgroup_bool) & (~positive_bool))
206 | ] += 0.25
207 | weights[(subgroup_bool)] += 0.25
208 | return weights
209 |
210 |
211 | class JigsawDataMultilingual(JigsawData):
212 | """Dataloader for the Jigsaw Multilingual Toxic Comment Classification.
213 | Source: https://www.kaggle.com/c/jigsaw-multilingual-toxic-comment-classification/
214 | """
215 |
216 | def __init__(
217 | self,
218 | train_csv_file="jigsaw_data/multilingual_challenge/jigsaw-toxic-comment-train.csv",
219 | test_csv_file="jigsaw_data/multilingual_challenge/validation.csv",
220 | train=True,
221 | val_fraction=0.1,
222 | create_val_set=False,
223 | classes=["toxic"],
224 | ):
225 |
226 | self.classes = classes
227 | super().__init__(
228 | train_csv_file=train_csv_file,
229 | test_csv_file=test_csv_file,
230 | train=train,
231 | val_fraction=val_fraction,
232 | create_val_set=create_val_set,
233 | )
234 |
235 | def __getitem__(self, index):
236 | meta = {}
237 | entry = self.data[index]
238 | text_id = entry["id"]
239 | if "translated" in entry:
240 | text = entry["translated"]
241 | elif "comment_text_en" in entry:
242 | text = entry["comment_text_en"]
243 | else:
244 | text = entry["comment_text"]
245 |
246 | target_dict = {label: 1 if entry[label] >= 0.5 else 0 for label in self.classes}
247 | meta["target"] = torch.tensor(list(target_dict.values()), dtype=torch.int32)
248 | meta["text_id"] = text_id
249 |
250 | return text, meta
251 |
--------------------------------------------------------------------------------
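A minimal usage sketch for the loaders above (not part of the repository): it assumes the Kaggle "Unintended Bias" CSVs sit at the default `jigsaw_data/` paths and that the repository root is on `PYTHONPATH`; the class names and meta keys come from the code above, everything else is illustrative.

```python
# Illustrative sketch only; paths and batch size are assumptions, not repo defaults.
from torch.utils.data import DataLoader
from language.jigsaw_loaders import JigsawDataBias

dataset = JigsawDataBias(train=True, classes=["toxic"], identity_classes=["female"])

text, meta = dataset[0]
print(text[:80])              # raw comment text
print(meta["multi_target"])   # toxicity + identity targets (thresholded at 0.5)
print(meta["weights1"])       # per-label loss weights, present on the train split only

# Each item is (str, dict of scalars/tensors), so the default collate should batch it.
loader = DataLoader(dataset, batch_size=16, shuffle=True)
```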
/language/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import BertModel, BertPreTrainedModel, BertForSequenceClassification, RobertaForSequenceClassification
4 |
5 | LANGUAGE_MODEL_DICT = {
6 | 'sst': 'barissayil/bert-sentiment-analysis-sst',
7 | 'jigsaw-toxic': 'unitary/toxic-bert',
8 | 'jigsaw-severe_toxic': 'unitary/toxic-bert',
9 | 'jigsaw-obscene': 'unitary/toxic-bert',
10 | 'jigsaw-threat': 'unitary/toxic-bert',
11 | 'jigsaw-insult': 'unitary/toxic-bert',
12 | 'jigsaw-identity_hate': 'unitary/toxic-bert',
13 | 'jigsaw-alt-toxic': 'unitary/unbiased-toxic-roberta',
14 | 'jigsaw-alt-severe_toxic': 'unitary/unbiased-toxic-roberta',
15 | 'jigsaw-alt-obscene': 'unitary/unbiased-toxic-roberta',
16 | 'jigsaw-alt-threat': 'unitary/unbiased-toxic-roberta',
17 | 'jigsaw-alt-insult': 'unitary/unbiased-toxic-roberta',
18 | 'jigsaw-alt-identity_hate': 'unitary/unbiased-toxic-roberta'
19 | }
20 |
21 | class BertForSentimentClassification(BertPreTrainedModel):
22 | def __init__(self, config):
23 | super().__init__(config)
24 | self.bert = BertModel(config)
25 | # self.dropout = nn.Dropout(config.hidden_dropout_prob)
26 | #The classification layer that takes the [CLS] representation and outputs the logit
27 | self.cls_layer = nn.Linear(config.hidden_size, 1)
28 |
29 | def forward(self, input_ids, attention_mask):
30 | '''
31 | Inputs:
32 | -input_ids : Tensor of shape [B, T] containing token ids of sequences
33 |             -attention_mask : Tensor of shape [B, T] containing attention masks used to avoid the contribution of PAD tokens
34 | (where B is the batch size and T is the input length)
35 | '''
36 |         #Feed the input to BERT; indexing [0] keeps tuple/ModelOutput compatibility
37 |         reps = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
38 |         #Obtain the representation of the [CLS] token
39 | cls_reps = reps[:, 0]
40 | # cls_reps = self.dropout(cls_reps)
41 | logits = self.cls_layer(cls_reps)
42 | return logits
--------------------------------------------------------------------------------
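A short sketch (not part of the repository) of how the sentiment head above could be exercised end to end; the checkpoint name comes from `LANGUAGE_MODEL_DICT`, while the tokenizer choice and the example sentence are assumptions.

```python
# Illustrative sketch: load the SST checkpoint named in LANGUAGE_MODEL_DICT and score
# one sentence. Assumes the hub checkpoint's weights match this class's parameter names.
import torch
from transformers import AutoTokenizer
from language.models import LANGUAGE_MODEL_DICT, BertForSentimentClassification

name = LANGUAGE_MODEL_DICT['sst']  # 'barissayil/bert-sentiment-analysis-sst'
tokenizer = AutoTokenizer.from_pretrained(name)
model = BertForSentimentClassification.from_pretrained(name).eval()

enc = tokenizer("A thoughtful, well-acted film.", return_tensors="pt")
with torch.no_grad():
    logit = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(torch.sigmoid(logit))  # probability that the sentence is positive
```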
/main.py:
--------------------------------------------------------------------------------
1 | import os, math, time
2 | import torch as ch
3 | import torch.nn as nn
4 |
5 | from robustness.datasets import DATASETS
6 |
7 | # be sure to pip install glm_saga, or clone the repo from
8 | # https://github.com/madrylab/glm_saga
9 | from glm_saga.elasticnet import glm_saga
10 |
11 | import helpers.data_helpers as data_helpers
12 | import helpers.feature_helpers as feature_helpers
13 |
14 | from argparse import ArgumentParser
15 |
16 | ch.manual_seed(0)
17 | ch.set_grad_enabled(False)
18 |
19 |
20 | if __name__ == '__main__':
21 | parser = ArgumentParser()
22 | parser.add_argument('--dataset', type=str, help='dataset name')
23 | parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
24 | parser.add_argument('--dataset-path', type=str, help='path to dataset')
25 | parser.add_argument('--model-path', type=str, help='path to model checkpoint')
26 | parser.add_argument('--arch', type=str, help='model architecture type')
27 | parser.add_argument('--out-path', help='location for saving results')
28 | parser.add_argument('--cache', action='store_true', help='cache deep features')
29 | parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')
30 |
31 | parser.add_argument('--device', default='cuda')
32 | parser.add_argument('--random-seed', default=0)
33 | parser.add_argument('--num-workers', type=int, default=2)
34 | parser.add_argument('--batch-size', type=int, default=256)
35 | parser.add_argument('--val-frac', type=float, default=0.1)
36 | parser.add_argument('--lr-decay-factor', type=float, default=1)
37 | parser.add_argument('--lr', type=float, default=0.1)
38 | parser.add_argument('--alpha', type=float, default=0.99)
39 | parser.add_argument('--max-epochs', type=int, default=2000)
40 | parser.add_argument('--verbose', type=int, default=200)
41 | parser.add_argument('--tol', type=float, default=1e-4)
42 | parser.add_argument('--lookbehind', type=int, default=3)
43 | parser.add_argument('--lam-factor', type=float, default=0.001)
44 | parser.add_argument('--group', action='store_true')
45 | args = parser.parse_args()
46 |
47 | start_time = time.time()
48 |
49 | out_dir = args.out_path
50 | out_dir_ckpt = f'{out_dir}/checkpoint'
51 | out_dir_feats = f'{out_dir}/features'
52 | for path in [out_dir, out_dir_ckpt, out_dir_feats]:
53 | if not os.path.exists(path):
54 | os.makedirs(path)
55 |
56 | print("Initializing dataset and loader...")
57 |
58 | dataset, train_loader, test_loader = data_helpers.load_dataset(args.dataset,
59 | os.path.expandvars(args.dataset_path),
60 | args.dataset_type,
61 | args.batch_size,
62 | args.num_workers,
63 | shuffle=False,
64 | model_path=args.model_path)
65 |
66 | num_classes = dataset.num_classes
67 | Ntotal = len(train_loader.dataset)
68 |
69 | print("Loading model...")
70 | model, pooled_output = feature_helpers.load_model(args.model_path,
71 | args.arch,
72 | dataset,
73 | args.dataset,
74 | args.dataset_type,
75 | device=args.device)
76 |
77 | print("Computing/loading deep features...")
78 | feature_loaders = {}
79 | for mode, loader in zip(['train', 'test'], [train_loader, test_loader]):
80 | print(f"For {mode} set...")
81 |
82 | sink_path = f"{out_dir_feats}/features_{mode}" if args.cache else None
83 | metadata_path = f"{out_dir_feats}/metadata_{mode}.pth" if args.cache else None
84 |
85 | feature_ds, feature_loader = feature_helpers.compute_features(loader,
86 | model,
87 | dataset_type=args.dataset_type,
88 | pooled_output=pooled_output,
89 | batch_size=args.batch_size,
90 | num_workers=args.num_workers,
91 | shuffle=(mode == 'test'),
92 | device=args.device,
93 | filename=sink_path,
94 | balance=args.balance if mode == 'test' else False)
95 |
96 | if mode == 'train':
97 | metadata = feature_helpers.calculate_metadata(feature_loader,
98 | num_classes=num_classes,
99 | filename=metadata_path)
100 | split_datasets, split_loaders = feature_helpers.split_dataset(feature_ds,
101 | Ntotal,
102 | val_frac=args.val_frac,
103 | batch_size=args.batch_size,
104 | num_workers=args.num_workers,
105 | random_seed=args.random_seed,
106 | shuffle=True, balance=args.balance)
107 | feature_loaders.update({mm : data_helpers.add_index_to_dataloader(split_loaders[mi])
108 | for mi, mm in enumerate(['train', 'val'])})
109 |
110 | else:
111 | feature_loaders[mode] = feature_loader
112 |
113 |
114 | num_features = metadata["X"]["num_features"][0]
115 | assert metadata["y"]["num_classes"].numpy() == num_classes
116 |
117 | print("Initializing linear model...")
118 | linear = nn.Linear(num_features, num_classes).to(args.device)
119 | for p in [linear.weight, linear.bias]:
120 | p.data.zero_()
121 |
122 | print("Preparing normalization preprocess and indexed dataloader")
123 | preprocess = data_helpers.NormalizedRepresentation(feature_loaders['train'],
124 | metadata=metadata,
125 | device=linear.weight.device)
126 |
127 | print("Calculating the regularization path")
128 | params = glm_saga(linear,
129 | feature_loaders['train'],
130 | args.lr,
131 | args.max_epochs,
132 | args.alpha,
133 | val_loader=feature_loaders['val'],
134 | test_loader=feature_loaders['test'],
135 | n_classes=num_classes,
136 | checkpoint=out_dir_ckpt,
137 | verbose=args.verbose,
138 | tol=args.tol,
139 | lookbehind=args.lookbehind,
140 | lr_decay_factor=args.lr_decay_factor,
141 | group=args.group,
142 | epsilon=args.lam_factor,
143 | metadata=metadata,
144 | preprocess=preprocess)
145 |
146 | print(f"Total time: {time.time() - start_time}")
147 |
148 |
--------------------------------------------------------------------------------
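A hypothetical invocation of the pipeline above: the flags come from the argument parser in `main.py`, but every value in angle brackets is a placeholder rather than something the repository prescribes.

```
python main.py \
  --dataset <dataset-name> \
  --dataset-type vision \
  --dataset-path <path/to/dataset> \
  --model-path <path/to/checkpoint.pt> \
  --arch <architecture> \
  --out-path <path/to/output-dir> \
  --cache --balance
```

As the code above shows, `glm_saga` is pointed at `<path/to/output-dir>/checkpoint` for its regularization-path checkpoints, and cached deep features are written under `<path/to/output-dir>/features` when `--cache` is set.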
/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MadryLab/DebuggableDeepNetworks/a2491c72a0c9a07b05803a88e9936b4dcc8fd080/pipeline.png
--------------------------------------------------------------------------------
/pipeline_600x400.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MadryLab/DebuggableDeepNetworks/a2491c72a0c9a07b05803a88e9936b4dcc8fd080/pipeline_600x400.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | pandas
4 | numpy
5 | scipy
6 | GPUtil
7 | dill
8 | tensorboardX
9 | tables
10 | tqdm
11 | seaborn
12 | jupyter
13 | scikit-learn
14 | pillow
15 | transformers
16 | datasets
17 | kornia
18 | lime
19 | robustness
20 | cox
21 | glm_saga
--------------------------------------------------------------------------------