├── .gitignore ├── LICENSE ├── README.md ├── SIC_imports.py ├── datasets ├── __init__.py ├── ccle_dataset.py ├── dataset_builder.py ├── datasets.py └── toy_dataset.py ├── modules ├── __init__.py ├── generators.py ├── linregression.py └── models.py ├── output └── SINEXP_250.png ├── plot_results.py ├── requirements.txt ├── run_baselines.py ├── run_sic.py ├── run_sic_supervised.py ├── slides ├── sic_neurips_slides.key └── sic_neurips_slides.pdf ├── stattests.py ├── test ├── __init__.py ├── test_cbatchnorm.py ├── test_generators.py └── test_toydataset.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | ~* 3 | .* 4 | *~ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sobolev Independence Criterion 2 | PyTorch source code for the paper 3 | > Mroueh, Sercu, Rigotti, Padhi, dos Santos, "Sobolev Independence Criterion", NeurIPS 2019 [[arXiv:1910.14212](https://arxiv.org/abs/1910.14212)] [[NeurIPS 2019 Proceedings]](https://papers.nips.cc/paper/9147-sobolev-independence-criterion) 4 | 5 | 6 | ## Requirements 7 | * Python 3.6 or above 8 | * PyTorch 1.1.0 9 | * Torchvision 0.3.0 10 | * Scikit-learn 0.21 11 | * Pandas 0.25 (for CCLE dataset) 12 | 13 | These can be installed using `pip` by running: 14 | 15 | ```bash 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Usage 20 | 21 | As a running example, we perform *feature selection* on one of the toy datasets examined in [Zhang et al., arXiv:1606.07892](https://arxiv.org/abs/1606.07892) (see sections 5.1 and 5.2), which we will call `SinExp`. 22 | 23 | * **Baseline models:** 24 | * To train an *elastic net* (one of the implemented baseline models) on 250 samples from `SinExp` execute: 25 | ```bash 26 | python run_baselines.py --model en --dataset sinexp --numSamples 250 --do-hrt 27 | ``` 28 | * Analogously, to train a *random forest* on 250 samples from `SinExp` execute: 29 | ```bash 30 | python run_baselines.py --model rf --dataset sinexp --numSamples 250 --do-hrt 31 | ``` 32 | The flag `--do-hrt` tells the script to use the Holdout Randomization Test by [Tansey et al., arXiv:1811.00645](https://arxiv.org/abs/1811.00645) to rank the important features in the data and control the False Discovery Rate (FDR). 33 | 34 | * **Multi-layer neural network regression with Sobolev penalty:** To train a multilayer neural network on the prediction problem of regressing the responses `y` on the inputs `X`, subject to a gradient penalty (Sobolev penalty), again on 250 samples from `SinExp`, execute: 35 | ```bash 36 | python run_sic_supervised.py --dataset sinexp --numSamples 250 --do-hrt 37 | ``` 38 | 39 | * **Sobolev Independence Criterion:** To train a multilayer discriminator network using the Sobolev Independence Criterion (SIC) between the responses `y` and the inputs `X` on 250 samples from `SinExp`, execute: 40 | ```bash 41 | python run_sic.py --dataset sinexp --numSamples 250 --do-hrt 42 | ``` 43 | 44 | * The results can be plotted using the script `plot_results.py`, which will generate the following figure: 45 | 46 | ![figure](/output/SINEXP_250.png) 47 | Visualization of the results of executing the previous commands. We plot the True Positive Rate (TPR, i.e. Power) and the False Discovery Rate (FDR) for the three algorithms, indicating when FDR is controlled with HRT. Higher is better for TPR (blue bars), and lower is better for FDR (red bars). The red horizontal dashed line indicates an FDR of 10%, which is what was used as the target FDR for HRT. In this case SIC combined with HRT (bars on the right) has the highest TPR, while maintaining a low FDR.
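The toy datasets can also be instantiated directly in Python. A minimal sketch (based on the `SinExpDataset` class in `datasets/toy_dataset.py`; the constructor arguments follow that file and the printed shapes are only illustrative):

```python
from torch.utils.data import DataLoader
from datasets.toy_dataset import SinExpDataset

# 250 samples with 50 features, as in the commands above
ds = SinExpDataset(n_samples=250, n_features=50, seed=31)
loader = DataLoader(ds, batch_size=50, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # e.g. torch.Size([50, 50]) and torch.Size([50])
```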
48 | 49 | 50 | ## Citation 51 | > Youssef Mroueh, Tom Sercu, Mattia Rigotti, Inkit Padhi, Cicero Dos Santos, "Sobolev Independence Criterion", NeurIPS, 2019 [[arXiv](https://arxiv.org/abs/1910.14212)] [[NeurIPS Proceedings]](https://papers.nips.cc/paper/9147-sobolev-independence-criterion) 52 | 53 | ``` 54 | @incollection{NIPS2019_9147, 55 | title = {Sobolev Independence Criterion}, 56 | author = {Mroueh, Youssef and Sercu, Tom and Rigotti, Mattia and Padhi, Inkit and Nogueira dos Santos, Cicero}, 57 | booktitle = {Advances in Neural Information Processing Systems 32}, 58 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, 59 | pages = {9505--9515}, 60 | year = {2019}, 61 | publisher = {Curran Associates, Inc.}, 62 | url = {http://papers.nips.cc/paper/9147-sobolev-independence-criterion.pdf} 63 | } 64 | ``` 65 | 66 | -------------------------------------------------------------------------------- /SIC_imports.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils import log_current_variables 4 | 5 | 6 | def minibatch(data, batch_size, requires_grad=True): 7 | x, y = data 8 | if not batch_size: 9 | x, y = x.clone(), y.clone() 10 | return x.requires_grad_(requires_grad), y.requires_grad_(requires_grad) 11 | else: 12 | indx = torch.LongTensor(batch_size).random_(0, x.size(0)) 13 | return x[indx].requires_grad_(requires_grad), y[indx].requires_grad_(requires_grad) 14 | 15 | 16 | def compute_objective_terms(data, net, need_penalty_terms=False): 17 | """Construct minibatch of (x,y): compute main objective term and optionally penalty terms. 18 | Returns expectations over minibatch (last 2 terms only if penalty_terms=True) 19 | * E[ f(x,y) ] 20 | * E[ f(x,y)**2 ] 21 | * [ E[ |df/dx_j|^2 ], E[ |df/dy_j|^2 ] ] 22 | """ 23 | device = next(net.parameters()).device 24 | x = data[0].requires_grad_(need_penalty_terms).to(device) 25 | y = data[1].requires_grad_(False).to(device) 26 | f = net(x,y) 27 | E_f = f.mean(0) 28 | 29 | E_f2, E_grad2 = None, None 30 | if need_penalty_terms: 31 | E_f2 = (f**2).sum() 32 | gradx = torch.autograd.grad(f.sum(), x, create_graph=True)[0] 33 | E_grad2 = (gradx**2).mean(0) # expectation, keep x_j coordinates separate. 34 | return E_f, E_f2, E_grad2 35 | 36 | 37 | def compute_mse(dataloader, net, targets_mu=0.0, targets_sd=1.0): 38 | device = next(net.parameters()).device 39 | net.eval() 40 | mse_loss = 0 41 | n_samples = 0 42 | with torch.no_grad(): 43 | for data, targets in dataloader: 44 | data, targets = data.to(device), targets.to(device) 45 | targets_norm = (targets.view(-1) - targets_mu) / targets_sd 46 | mse_loss += F.mse_loss(net(data).view(-1), targets_norm, reduction='sum') 47 | n_samples += data.shape[0] 48 | return mse_loss / n_samples 49 | 50 | 51 | def sobolev_forward(D, eta_x, dataP, dataQ, mu, selectSobo='L1^2'): 52 | ETA_EPS = 1e-6 # stabilize denominator in eta constraints 53 | 54 | Ep_f, Ep_f2, Ep_grad2 = compute_objective_terms(dataP, D, need_penalty_terms='P' in mu) 55 | Eq_f, Eq_f2, Eq_grad2 = compute_objective_terms(dataQ, D, need_penalty_terms='Q' in mu) 56 | 57 | # mu: dominant measure on which to compute expectations. 
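    # Background on the eta-weighted penalties computed below (with b_j = E_mu[ |df/dx_j|^2 ], i.e. Emu_grad2[j]):
    #   'L1^2': min over eta on the simplex of  sum_j b_j / eta_j          equals  ( sum_j sqrt(b_j) )^2
    #   'L1':   min over eta_j > 0 of  sum_j ( b_j / eta_j + eta_j )       equals  2 * sum_j sqrt(b_j)
    # so optimizing eta jointly with the network recovers (squared) L1-type gradient penalties,
    # whose direct, non-variational form is the 'L1-biased' option (Emu_grad2.sqrt().sum()).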
58 | if mu == 'P': 59 | Emu_f2, Emu_grad2 = Ep_f2, Ep_grad2 60 | elif mu == 'Q': 61 | Emu_f2, Emu_grad2 = Eq_f2, Eq_grad2 62 | elif mu == 'P+Q/2': 63 | Emu_f2, Emu_grad2 = (Ep_f2 + Eq_f2) / 2, (Ep_grad2 + Eq_grad2) / 2 64 | 65 | sobo_dist = (Ep_f - Eq_f) 66 | constraint_f2 = Emu_f2 67 | 68 | if selectSobo == 'L2': 69 | constraint_Sobo = Emu_grad2.sum() # L2 norm: sum_j (E |df / dx_j |^2) 70 | elif selectSobo == 'L1-biased': 71 | constraint_Sobo = Emu_grad2.sqrt().sum() # L1 norm: sum_j sqrt(E |df / dx_j |^2) 72 | elif selectSobo == 'L1': 73 | constraint_Sobo = (Emu_grad2 / (eta_x + ETA_EPS)).sum() + eta_x.sum() 74 | elif selectSobo == 'L1^2': 75 | constraint_Sobo = (Emu_grad2 / (eta_x + ETA_EPS)).sum() 76 | else: 77 | raise KeyError('Unrecognized selectSobo argument == {}'.format(selectSobo)) 78 | 79 | return sobo_dist, constraint_f2, constraint_Sobo 80 | 81 | 82 | def Ep_D(D, test_loader): 83 | device = next(D.parameters()).device 84 | D.eval() 85 | Ep_f = [] 86 | with torch.no_grad(): 87 | for batch_idx, (data, targets) in enumerate(test_loader): 88 | data, targets = data.to(device), targets.to(device) 89 | Ep_f.append(D(data, targets)) 90 | return torch.cat(Ep_f).mean(0).item() 91 | 92 | 93 | def avg_sobolev_dist(D, dl_P, dl_Q): 94 | D.eval() 95 | sobo_dist = [] 96 | n_samples = 0 97 | with torch.no_grad(): 98 | for (dataP, dataQ) in zip(dl_P, dl_Q): 99 | 100 | Ep_f, Ep_f2, Ep_grad2 = compute_objective_terms(dataP, D) 101 | Eq_f, Eq_f2, Eq_grad2 = compute_objective_terms(dataQ, D) 102 | 103 | sobo_dist.append((Ep_f - Eq_f) * dataP[0].shape[0]) 104 | n_samples += dataP[0].shape[0] 105 | 106 | return torch.cat(sobo_dist).sum(0).item() / n_samples 107 | 108 | 109 | def compute_objective_supervised(inputs, targets, net): 110 | """Construct minibatch of (x,y): compute main objective term and optionally penalty terms. 111 | Returns expectations over minibatch (last 2 terms only if penalty_terms=True) 112 | * E[ |df/dx_j|^2 ] 113 | """ 114 | device = next(net.parameters()).device 115 | inputs, targets = inputs.requires_grad_(True).to(device), targets.requires_grad_(False).to(device) 116 | 117 | outputs = net(inputs).view(-1) 118 | mse_loss = F.mse_loss(outputs, targets) 119 | 120 | gradx = torch.autograd.grad(outputs.sum(), inputs, create_graph=True)[0] 121 | E_grad2 = (gradx**2).mean(0) # expectation, keep x_j coordinates separate. 122 | return mse_loss, E_grad2 123 | 124 | 125 | def supervised_forward_sobolev_penalty(net, inputs, targets, eta_x, selectSobo='L1^2'): 126 | ETA_EPS = 1e-6 # stabilize denominator in eta constraints 127 | 128 | mse_loss, Ep_grad2 = compute_objective_supervised(inputs, targets, net) 129 | 130 | if selectSobo == 'L2': 131 | constraint_Sobo = Ep_grad2.sum() # L2 norm: sum_j (E |df / dx_j |^2) 132 | elif selectSobo == 'L1-biased': 133 | constraint_Sobo = Ep_grad2.sqrt().sum() # L1 norm: sum_j sqrt(E |df / dx_j |^2) 134 | elif selectSobo == 'L1': 135 | constraint_Sobo = (Ep_grad2 / (eta_x + ETA_EPS)).sum() + eta_x.sum() 136 | elif selectSobo == 'L1^2': 137 | constraint_Sobo = (Ep_grad2 / (eta_x + ETA_EPS)).sum() 138 | else: 139 | raise KeyError('Unrecognized selectSobo argument == {}'.format(selectSobo)) 140 | 141 | return mse_loss, constraint_Sobo 142 | 143 | 144 | def recompute_etas_P(D, dataloader): 145 | """Recomputes etas integrating dD/dx over P 146 | """ 147 | etas = 0 148 | n_samples = 0 149 | for data in dataloader: 150 | _, _, Ep_grad2 = compute_objective_terms(data, D, need_penalty_terms=True) 151 | # Only 1 term, i.e. 
integrate over P ((P+Q)/2 would work also) 152 | etas += Ep_grad2.detach() * data[0].shape[0] 153 | n_samples += data[0].shape[0] 154 | return etas / n_samples 155 | 156 | 157 | def normalize_etas(eta): 158 | EPS = 1e-6 159 | logits = (eta.data + EPS).log() 160 | eta.data.copy_(torch.softmax(logits, 0)) 161 | return eta 162 | 163 | 164 | def log_eta_stats(tbw, t, eta_x, eta_lr, tb_trunc_tensor_size): 165 | ETA_EPS = 1e-6 # stabilize denominator in eta constraints 166 | 167 | eta_x_grad = eta_x.grad 168 | # Linf norm: max eta update, how far from 1 on average 169 | eta_x_update = (torch.exp(-eta_lr * eta_x.grad) - 1).abs() 170 | eta_x_update_L1 = eta_x_update.sum() 171 | eta_x_update_Linf = eta_x_update.max() 172 | if torch.isclose(eta_x.data.sum(), torch.tensor(1.0)): 173 | # entropy if etas are on the simplex 174 | eta_entropy = - (eta_x.data * eta_x.data.log()).sum() 175 | eta_sparse_count = (eta_x < 10 * ETA_EPS).sum() 176 | log_current_variables(tbw, t, locals(), 177 | keys_to_log=['eta_x', 'eta_x_grad', 'eta_x_update_L1', 'eta_x_update_Linf', 178 | 'eta_entropy', 'eta_sparse_count'], 179 | key_prefix='eta/', 180 | tb_trunc_tensor_size=tb_trunc_tensor_size) 181 | 182 | 183 | def logstab_mirror_descent_step_(eta, lr): 184 | # V3: stabilized in log domain, based on Youssef chat in channel on May 1st. 185 | EPS = 1e-6 186 | logits = (eta.data + EPS).log() 187 | logits.add_(-lr, eta.grad.data) 188 | eta.data.copy_(torch.softmax(logits, 0)) 189 | 190 | 191 | def reduced_gradient_step_(eta, lr): 192 | """Reduced gradient for projecting on the simplex 193 | See Bonnans, used in SimpleMKL paper: http://www.jmlr.org/papers/volume9/rakotomamonjy08a/rakotomamonjy08a.pdf 194 | Note: initialize eta to uniform for this to work 195 | """ 196 | # get the maximum value of eta 197 | eta_max,index_max = torch.max(eta.data,0) 198 | 199 | # define reduced gradient 200 | reduced_grad = eta.grad.data 201 | grad_eta_max = eta.grad.data[index_max] 202 | reduced_grad = reduced_grad - grad_eta_max.expand_as(reduced_grad) 203 | reduced_grad[index_max] = - torch.sum(reduced_grad) 204 | ## find if any eta = 0 and its reduced gardient positive 205 | intersection = (eta == 0)*(reduced_grad > 0) 206 | indices_intersection = intersection.nonzero() 207 | ## reduced gradient update 208 | reduced_grad_corrected = reduced_grad 209 | reduced_grad_corrected[indices_intersection] = 0 210 | reduced_grad_corrected[index_max] = 0 211 | reduced_grad_corrected[index_max] = - torch.sum(reduced_grad_corrected) 212 | 213 | ### apply gradient descent with reduced gradient descent 214 | x = -lr * reduced_grad_corrected 215 | eta.data.add_(x) 216 | # End of training - evaluate etas 217 | 218 | 219 | def eta_optim_step_(eta, eta_step_type, lr): 220 | """Optimization step for etas 221 | Note: this only implements 'L1^2' 222 | """ 223 | if eta_step_type == 'mirror': 224 | logstab_mirror_descent_step_(eta, lr) 225 | elif eta_step_type == 'reduced': 226 | reduced_gradient_step_(eta, lr) 227 | else: 228 | raise KeyError("eta_step_type must be one of the following values: mirror | reduced" ) 229 | return eta 230 | 231 | 232 | def heldout_eval(tbw, t, te_P, te_Q, D, logger=None): 233 | Ep_f, Ep_f2, Ep_grad2 = compute_objective_terms(te_P, D, need_penalty_terms=True) 234 | Eq_f, Eq_f2, Eq_grad2 = compute_objective_terms(te_Q, D, need_penalty_terms=True) 235 | sobo_dist = Ep_f.item() - Eq_f.item() 236 | # NOTE were taking mu=Q here, should be passed in from opt.mu 237 | betas = Eq_grad2 238 | constraint_f2 = Eq_f2 239 | constraint_L2 = 
Eq_grad2.sum() # L2 norm: sum_j (E |df / dx_j |^2) 240 | constraint_L1 = Eq_grad2.sqrt().sum() # L1 norm: sum_j sqrt(E |df / dx_j |^2) 241 | log_current_variables(tbw, t, locals(), 242 | keys_to_log=['sobo_dist', 'betas', 'constraint_f2', 'constraint_L2', 'constraint_L1'], 243 | key_prefix='hld/') 244 | 245 | msg = '[{:5d}] sobo_dist={:.4f} constraint_L2={:.4f} constraint_L1={:.4f} constraint_f2={:.4f} Ep_f={:.4f} Eq_f={:.4f}'.format( 246 | t, sobo_dist, constraint_L2.item(), constraint_L1.item(), constraint_f2.item(), Ep_f.item(), Eq_f.item()) 247 | 248 | if logger: logger.info(msg) 249 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SIC/c4e45d7736da6e6faabdc56bfc1336445df99204/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/ccle_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torchvision import transforms 5 | from torchvision.datasets.utils import download_url 6 | import os 7 | import errno 8 | import numpy as np 9 | import pandas as pd 10 | 11 | 12 | class CCLE_Dataset(torch.utils.data.Dataset): 13 | """`CCLE` dataset from the paper: 14 | 15 | Barretina, J., Caponigro, G., Stransky, N., Venkatesan, K., Margolin, A. A., Kim, S., ... & Reddy, A. (2012). 16 | The Cancer Cell Line Encyclopedia enables predictive modelling of anticancer drug sensitivity. 17 | Nature, 483(7391), 603. 18 | 19 | Note: 20 | The X dataset is z-scored, which means that if it is partitioned into a training and a test split, these have 21 | to be re-z-scored according to the training split. 22 | 23 | Args: 24 | root (string): Root directory where the dataset files (``mutation.txt``, ``expression.txt``, 25 | ``response.csv``) are stored or downloaded to. 26 | task (string): Name of the drug whose response ('Amax' in ``response.csv``) is used as the 27 | regression target, e.g. the default 'PLX4720'. 28 | feature_type (string): Which features to use, among ['mutation', 'expression', 'both']. 29 | train (bool, optional): If True, creates dataset from the training split, 30 | otherwise from the test split. The two splits come from the same partition if the 31 | random seed `seed` is the same. 32 | test_size (int, float): how much data has to be reserved for test. 33 | If test_size is an int it indicates the number of samples. If it's a float, it's the fraction 34 | of samples over the total. 35 | shuffle_targets (bool): If True, the targets (Y) are shuffled with respect to the data (X), breaking the dependence between 36 | X and Y, such that P(X,Y) = P(X)P(Y). If False, X and Y are sampled together from P(X,Y). 37 | seed (int): seed of the random number generator 38 | z_score (bool): whether to z-score X features or not. z-score statistics are always computed on the training split. 39 | Also, note that the whole X dataset is already z-scored (see note above). 40 | download (bool, optional): If true, downloads the dataset from the internet and 41 | puts it in the root directory. If the dataset is already downloaded, it is not 42 | downloaded again.
43 | """ 44 | urls = [ 45 | 'https://www.dropbox.com/s/7iy0ght31hxhn7d/mutation.txt', 46 | 'https://www.dropbox.com/s/bplwquwbc7zleck/expression.txt', 47 | 'https://www.dropbox.com/s/78mp3ebnb4h6jsy/response.csv', 48 | ] 49 | download_option = '?dl=1' 50 | 51 | files = [ 52 | 'mutation.txt', 53 | 'expression.txt', 54 | 'response.csv' 55 | ] 56 | 57 | def __init__(self, root, task = 'PLX4720', feature_type ='both', train=True, test_size=0.1, 58 | shuffle_targets=False, seed=1, z_score=True, download=True, verbose = False, parent_dataset = None): 59 | self.root = os.path.expanduser(root) 60 | 61 | if not isinstance(feature_type, str) or feature_type not in ['mutation', 'expression', 'both']: 62 | raise ValueError('task must be one of the following task descriptors: ' + str(['mutation', 'expression', 'both'])) 63 | else: 64 | self.feature_type = feature_type 65 | 66 | self.fea_groundtruth = ['C11orf85', 'FXYD4', 'SLC28A2', 'MAML3_MUT', 'RAD51L1_MUT', 'GAPDHS', 'BRAF_MUT'] 67 | 68 | self.task = task # drug target 69 | 70 | self.train = train 71 | self.shuffle_targets = shuffle_targets 72 | self.verbose = verbose 73 | 74 | # Random number generator 75 | self.rng = np.random.RandomState(seed) 76 | 77 | if parent_dataset is None: 78 | if download: 79 | self.download() 80 | 81 | if not self._check_exists(): 82 | raise RuntimeError('Dataset not found. You can use download=True to download it') 83 | 84 | self.full_data, self.full_targets, self.features, all_features = self.load_data() 85 | 86 | if isinstance(test_size, float): 87 | if test_size > 1.0 or test_size < 0.0: 88 | raise ValueError('test_size must be integer or a float between 0.0 and 1.0') 89 | else: 90 | self.test_size = int(len(self.full_data) * test_size) 91 | elif isinstance(test_size, int): 92 | if test_size >= len(self.full_data) or test_size < 0: 93 | raise ValueError('integer test_size must be between 0 and {}'.format(len(self.full_data))) 94 | else: 95 | self.test_size = test_size 96 | 97 | # Permutation indices: 98 | perm = self.rng.permutation(len(self.full_data)) 99 | self.ind_train = perm[self.test_size:] 100 | self.ind_test = perm[:self.test_size] 101 | 102 | else: 103 | self.full_data = parent_dataset.full_data 104 | self.full_targets = parent_dataset.full_targets 105 | self.features = parent_dataset.features 106 | self.ind_train = parent_dataset.ind_train 107 | self.ind_test = parent_dataset.ind_test 108 | 109 | # get feature indexes 110 | self.fea_groundtruth_idx = [self.features.get_loc(ftr) for ftr in self.fea_groundtruth] 111 | 112 | if self.train: 113 | self.data, self.targets = self.full_data[self.ind_train], self.full_targets[self.ind_train] 114 | else: 115 | self.data, self.targets = self.full_data[self.ind_test], self.full_targets[self.ind_test] 116 | 117 | # z-score according to training split 118 | if z_score: 119 | mu = np.mean(self.full_data[self.ind_train], 0) 120 | sd = np.std(self.full_data[self.ind_train], 0) + 1e-6 121 | self.data = (self.data - mu) / sd 122 | 123 | mu = np.mean(self.full_targets[self.ind_train], 0) 124 | sd = np.std(self.full_targets[self.ind_train], 0) + 1e-6 125 | self.targets = (self.targets - mu) / sd 126 | 127 | self.z_score = z_score 128 | 129 | self.data, self.targets = torch.FloatTensor(self.data), torch.FloatTensor(self.targets) 130 | 131 | 132 | def get_feature_names(self): 133 | return self.features.values 134 | 135 | def get_groundtruth_features(self): 136 | return self.fea_groundtruth_idx 137 | 138 | def __getitem__(self, index): 139 | """ 140 | Args: 141 | index (int): 
Index 142 | Returns: 143 | tuple: (image, target) where target is index of the target class. 144 | """ 145 | if self.shuffle_targets: 146 | y_index = self.rng.randint(len(self.targets)) 147 | else: 148 | y_index = index 149 | 150 | return self.data[index], self.targets[y_index] 151 | 152 | def __len__(self): 153 | return len(self.data) 154 | 155 | def _check_exists(self): 156 | return all(map(lambda f: os.path.exists(os.path.join(self.root, f)), self.files)) 157 | 158 | def download(self): 159 | """Download the olfaction data if it doesn't exist in processed_folder already.""" 160 | 161 | if self._check_exists(): 162 | return 163 | 164 | # download files 165 | try: 166 | os.makedirs(os.path.join(self.root)) 167 | except OSError as e: 168 | if e.errno == errno.EEXIST: 169 | pass 170 | else: 171 | raise 172 | 173 | for url in self.urls: 174 | filename = url.rpartition('/')[2] 175 | download_url(url + self.download_option, root=self.root, filename=filename, md5=None) 176 | 177 | def ccle_feature_filter(self, X, y, threshold=0.1): 178 | # Remove all features that do not have at least pearson correlation at threshold with y 179 | corrs = np.array([np.abs(np.corrcoef(x, y)[0,1]) if x.std() > 0 else 0 for x in X.T]) 180 | selected = corrs >= threshold 181 | print(selected.sum(), selected.shape, corrs[34758]) 182 | return selected, corrs 183 | 184 | def load_data(self): 185 | X_drugs, y_drugs, drugs, cells, features = self.load_ccle() 186 | drug_idx = drugs.get_loc(self.task) 187 | 188 | if self.verbose: 189 | print('Drug {}'.format(drugs[drug_idx])) 190 | 191 | X_drug, y_drug = X_drugs[drug_idx], y_drugs[drug_idx] 192 | 193 | # Specific to PLX4720. Filters out all features with pearson correlation less than 0.1 in magnitude 194 | if self.verbose: 195 | print('Filtering by correlation with signal first') 196 | ccle_selected, corrs = self.ccle_feature_filter(X_drug, y_drug) 197 | # keeps the ground truth features 198 | for plx4720_feat in self.fea_groundtruth: 199 | idx = features.get_loc(plx4720_feat) 200 | ccle_selected[idx] = True 201 | if self.verbose: 202 | print('Correlation for {}: {:.4f}'.format(plx4720_feat, corrs[idx])) 203 | ccle_features = features[ccle_selected] 204 | 205 | # uses data from filtered features only 206 | X_drug = X_drug[:, np.nonzero(ccle_selected)[0]] 207 | 208 | return X_drug, y_drug, ccle_features, features 209 | 210 | def load_ccle(self): 211 | r"""Load CCLE dataset 212 | This method is based on the code in https://github.com/tansey/hrt/blob/master/examples/ccle/main.py 213 | published together with the paper Tansey et al. (http://arxiv.org/abs/1811.00645) 214 | and is subject to the following license: 215 | 216 | The MIT License (MIT) 217 | 218 | Copyright (c) 2018 Wesley Tansey 219 | 220 | Permission is hereby granted, free of charge, to any person obtaining a copy 221 | of this software and associated documentation files (the "Software"), to deal 222 | in the Software without restriction, including without limitation the rights 223 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 224 | copies of the Software, and to permit persons to whom the Software is 225 | furnished to do so, subject to the following conditions: 226 | 227 | The above copyright notice and this permission notice shall be included in all 228 | copies or substantial portions of the Software. 
229 | 230 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 231 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 232 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 233 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 234 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 235 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 236 | SOFTWARE. 237 | """ 238 | 239 | if self.feature_type in ['expression', 'both']: 240 | # Load gene expression 241 | expression = pd.read_csv(os.path.join(self.root, self.files[1]), delimiter='\t', header=2, index_col=1).iloc[:,1:] 242 | expression.columns = [c.split(' (ACH')[0] for c in expression.columns] 243 | features = expression 244 | if self.feature_type in ['mutation', 'both']: 245 | # Load gene mutation 246 | mutations = pd.read_csv(os.path.join(self.root, self.files[0]), delimiter='\t', header=2, index_col=1).iloc[:,1:] 247 | mutations = mutations.iloc[[c.endswith('_MUT') for c in mutations.index]] 248 | features = mutations 249 | if self.feature_type == 'both': 250 | both_cells = set(expression.columns) & set(mutations.columns) 251 | z = {} 252 | for c in both_cells: 253 | exp = expression[c].values 254 | if len(exp.shape) > 1: 255 | exp = exp[:,0] 256 | z[c] = np.concatenate([exp, mutations[c].values]) 257 | both_df = pd.DataFrame(z, index=[c for c in expression.index] + [c for c in mutations.index]) 258 | features = both_df 259 | response = pd.read_csv(os.path.join(self.root, self.files[2]), header=0, index_col=[0,2]) 260 | 261 | # Get per-drug X and y regression targets 262 | cells = response.index.levels[0] 263 | drugs = response.index.levels[1] 264 | X_drugs = [[] for _ in drugs] 265 | y_drugs = [[] for _ in drugs] 266 | for j, drug in enumerate(drugs): 267 | if self.task is not None and drug != self.task: 268 | continue 269 | for i,cell in enumerate(cells): 270 | if cell not in features.columns or (cell, drug) not in response.index: 271 | continue 272 | X_drugs[j].append(features[cell].values) 273 | y_drugs[j].append(response.loc[(cell,drug), 'Amax']) 274 | print('{}: {}'.format(drug, len(y_drugs[j]))) 275 | 276 | X_drugs = [np.array(x_i) for x_i in X_drugs] 277 | y_drugs = [np.array(y_i) for y_i in y_drugs] 278 | 279 | return X_drugs, y_drugs, drugs, cells, features.index 280 | 281 | def __repr__(self): 282 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 283 | fmt_str += ' CCLE Task: {}\n'.format(self.task) 284 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 285 | tmp = 'train' if self.train is True else 'test' 286 | fmt_str += ' Split: {}\n'.format(tmp) 287 | fmt_str += ' Z-Score: {}\n'.format(self.z_score) 288 | fmt_str += ' Root Location: {}\n'.format(self.root) 289 | return fmt_str 290 | 291 | 292 | if __name__=="__main__": 293 | 294 | DIR_DATASET = '~/data/ccle' 295 | 296 | # Common random seed to all datasets: 297 | random_seed = 123 298 | 299 | # P(X,X) distribution: 300 | trainset = CCLE_Dataset(DIR_DATASET, train = True) 301 | print (trainset) 302 | tr_P = DataLoader(trainset, batch_size=50, shuffle=True, num_workers=1) 303 | 304 | trainset_t = CCLE_Dataset(DIR_DATASET, train = False, parent_dataset = trainset) 305 | print (trainset_t) 306 | tr_P_t = DataLoader(trainset_t, batch_size=50, shuffle=True, num_workers=1) 307 | -------------------------------------------------------------------------------- /datasets/dataset_builder.py: 
-------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import sys 3 | 4 | from datasets.toy_dataset import ToyDataset, LiangSwitchingDataset, LiangDataset, SinExpDataset 5 | from datasets.ccle_dataset import CCLE_Dataset 6 | 7 | 8 | def build_dataset(opt): 9 | if opt.dataset in ['toy', 'liang', 'liang_switch']: 10 | # initialization of all toy datasets is the same 11 | DSCLASS = {'toy': ToyDataset, 12 | 'liang': LiangDataset, 'liang_switch': LiangSwitchingDataset} 13 | 14 | DSCLASS = DSCLASS[opt.dataset] 15 | 16 | ds_train_P = DSCLASS(opt.numSamples, opt.Xdim, seed=opt.dataseed) 17 | ds_train_Q = DSCLASS(opt.numSamples, opt.Xdim, seed=opt.dataseed, 18 | betas=ds_train_P.betas, 19 | data=ds_train_P.data, targets=ds_train_P.targets, shuffle_targets=True) 20 | 21 | ds_test_P = DSCLASS(opt.numSamples, opt.Xdim, seed=opt.dataseed+3, 22 | betas=ds_train_P.betas) 23 | ds_test_Q = DSCLASS(opt.numSamples, opt.Xdim, seed=opt.dataseed+3, 24 | betas=ds_train_P.betas, 25 | data=ds_test_P.data, targets=ds_test_P.targets, shuffle_targets=True) 26 | 27 | elif opt.dataset == 'sinexp': 28 | ds_train_P = SinExpDataset(opt.numSamples, opt.Xdim, seed=opt.dataseed, 29 | rho = opt.sinexp_rho, gaussian = opt.sinexp_gaussian) 30 | ds_train_Q = SinExpDataset(opt.numSamples, opt.Xdim, seed=opt.dataseed, 31 | betas=ds_train_P.betas, rho = opt.sinexp_rho, gaussian = opt.sinexp_gaussian, 32 | data=ds_train_P.data, targets=ds_train_P.targets, shuffle_targets=True) 33 | 34 | ds_test_P = SinExpDataset(opt.numSamples, opt.Xdim, seed=opt.dataseed+3, 35 | betas=ds_train_P.betas, rho = opt.sinexp_rho, gaussian = opt.sinexp_gaussian) 36 | ds_test_Q = SinExpDataset(opt.numSamples, opt.Xdim, seed=opt.dataseed+3, 37 | betas=ds_train_P.betas, rho = opt.sinexp_rho, gaussian = opt.sinexp_gaussian, 38 | data=ds_test_P.data, targets=ds_test_P.targets, shuffle_targets=True) 39 | 40 | elif opt.dataset == 'ccle': 41 | ds_train_P = CCLE_Dataset(opt.dataroot, task=opt.task, train=True, test_size=opt.test_size, shuffle_targets=False, 42 | seed=opt.dataseed, z_score=True, download=True) 43 | ds_train_Q = CCLE_Dataset(opt.dataroot, task=opt.task, train=True, test_size=opt.test_size, shuffle_targets=True, 44 | seed=opt.dataseed, z_score=True, download=True, parent_dataset = ds_train_P) 45 | 46 | ds_test_P = CCLE_Dataset(opt.dataroot, task=opt.task, train=False, test_size=opt.test_size, shuffle_targets=False, 47 | seed=opt.dataseed, z_score=True, download=True, parent_dataset = ds_train_P) 48 | ds_test_Q = CCLE_Dataset(opt.dataroot, task=opt.task, train=False, test_size=opt.test_size, shuffle_targets=True, 49 | seed=opt.dataseed, z_score=True, download=True, parent_dataset = ds_train_P) 50 | else: 51 | raise ValueError('Please use one of the following for dataset: toy | liang | liang_switch | sinexp | ccle.') 52 | 53 | tr_P = DataLoader(ds_train_P, batch_size=opt.batchSize, shuffle=True, drop_last = True) 54 | tr_Q = DataLoader(ds_train_Q, batch_size=opt.batchSize, shuffle=True, drop_last = True) 55 | 56 | # for the test phase, a single batch containing the whole test set is created (batch_size=sys.maxsize) 57 | te_P = DataLoader(ds_test_P, batch_size=sys.maxsize, shuffle=True) 58 | te_Q = DataLoader(ds_test_Q, batch_size=sys.maxsize, shuffle=True) 59 | 60 | # resets to the correct dimension 61 | opt.Xdim = ds_train_P.data.size(1) 62 | 63 | return [tr_P, tr_Q, te_P, te_Q], ds_train_P.get_feature_names(), ds_train_P.get_groundtruth_features() 64 | 65 |
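For reference, a minimal sketch of how `build_dataset` can be called outside of the run scripts (the option fields below are the ones read by the function above; the values are only illustrative):

```python
from types import SimpleNamespace
from datasets.dataset_builder import build_dataset

# argparse-style options object with the fields build_dataset expects for 'sinexp'
opt = SimpleNamespace(dataset='sinexp', numSamples=250, Xdim=50, dataseed=7,
                      sinexp_rho=0.5, sinexp_gaussian=False, batchSize=50)

(tr_P, tr_Q, te_P, te_Q), feature_names, groundtruth_idx = build_dataset(opt)
```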
-------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | from torch.utils.data import DataLoader 4 | import numpy as np 5 | 6 | 7 | def load_mnist(batch_size=200, conv_net=False, num_workers=1): 8 | '''Load the MNIST dataset 9 | 10 | Args: 11 | conv_net: set to `True` if the dataset is being used with a conv net (i.e. the inputs have to be 3d tensors and not flattened) 12 | ''' 13 | DIR_DATASET = '~/data/mnist' 14 | 15 | transform_list = [ 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.1307,), (0.3081,))] 18 | 19 | if not conv_net: 20 | transform_list.append(transforms.Lambda(lambda x: x.view(x.size(1) * x.size(2)))) 21 | 22 | transform = transforms.Compose(transform_list) 23 | 24 | trainset = datasets.MNIST(DIR_DATASET, train=True, download=True, transform=transform) 25 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 26 | 27 | testset = datasets.MNIST(DIR_DATASET, train=False, transform=transform) 28 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 29 | 30 | input_shape = trainset.train_data[0].shape 31 | 32 | if conv_net: 33 | # Channels, width, height 34 | input_shape = tuple(input_shape[-1:] + input_shape[:-1]) 35 | else: 36 | input_shape = np.prod(input_shape) 37 | 38 | return train_loader, test_loader, input_shape 39 | 40 | 41 | def load_fashion(batch_size=200, conv_net=False, num_workers=1): 42 | '''Load the fashion MNIST dataset 43 | 44 | Args: 45 | conv_net: set to `True` if the dataset is being used with a conv net (i.e. the inputs have to be 3d tensors and not flattened) 46 | ''' 47 | DIR_DATASET = '~/data/fashion' 48 | 49 | transform_list = [ 50 | transforms.ToTensor(), 51 | transforms.Normalize((0.1307,), (0.3081,))] 52 | 53 | if not conv_net: 54 | transform_list.append(transforms.Lambda(lambda x: x.view(x.size(1) * x.size(2)))) 55 | 56 | transform = transforms.Compose(transform_list) 57 | 58 | trainset = datasets.FashionMNIST(DIR_DATASET, train=True, download=True, transform=transform) 59 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 60 | 61 | testset = datasets.FashionMNIST(DIR_DATASET, train=False, transform=transform) 62 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 63 | 64 | input_shape = trainset.train_data[0].shape 65 | 66 | if conv_net: 67 | # Channels, width, height 68 | input_shape = tuple(input_shape[-1:] + input_shape[:-1]) 69 | else: 70 | input_shape = np.prod(input_shape) 71 | 72 | return train_loader, test_loader, input_shape 73 | 74 | 75 | def load_sklearndata(namedataset='iris', batch_size=100, test_size=0.1, seed=123, z_score=True, xname=None, yname=None, **kwargs): 76 | r"""Loads any of the sklearn standard datasets, which includes several UCI datasets 77 | 78 | Args: 79 | datase_name (string): the name of the dataset to load 80 | xname (strin): name of inputs x (if it's None, guess `data`) 81 | yname (strin): name of targets y (if it's None, guess `target`) 82 | kwargs (dict): arguments to pass to sklearn dataset loader 83 | """ 84 | from sklearn import datasets as skdatasets 85 | 86 | if hasattr(skdatasets, namedataset): 87 | name = namedataset 88 | elif hasattr(skdatasets, 'load_' + namedataset): 89 | name = 'load_' + namedataset 90 | elif 
hasattr(skdatasets, 'make_' + namedataset): 91 | name = 'make_' + namedataset 92 | else: 93 | raise ValueError('Dataset {} not recognized'.format(namedataset)) 94 | 95 | print('Loading scikit learn dataset {}'.format(name)) 96 | dataset = getattr(skdatasets, name)(**kwargs) 97 | 98 | if xname is None: 99 | xname = 'data' 100 | if yname is None: 101 | yname = 'target' 102 | 103 | data = dataset[xname] 104 | targets = dataset[yname] 105 | 106 | # Split in trainind and test 107 | if isinstance(test_size, float): 108 | if test_size > 1.0 or test_size < 0.0: 109 | raise ValueError('test_size must be integer or a float between 0.0 and 1.0') 110 | else: 111 | test_size = int(len(data) * test_size) 112 | elif isinstance(test_size, int): 113 | if test_size >= len(data) or test_size < 0: 114 | raise ValueError('integer test_size must be between 0 and {}'.format(len(data))) 115 | 116 | # Random number generator 117 | rng = np.random.RandomState(seed) 118 | 119 | # Permutation indices: 120 | perm = rng.permutation(len(data)) 121 | ind_train = perm[test_size:] 122 | ind_test = perm[:test_size] 123 | 124 | train_data, train_targets = data[ind_train], targets[ind_train] 125 | test_data, test_targets = data[ind_test], targets[ind_test] 126 | 127 | # z-score according to training split 128 | if z_score: 129 | mu = np.mean(data[ind_train], 0) 130 | sd = np.std(data[ind_train], 0) + 1e-6 131 | train_data = (train_data - mu) / sd 132 | test_data = (test_data - mu) / sd 133 | 134 | # Convert to torch.Tensor 135 | train_data = torch.FloatTensor(train_data) 136 | test_data = torch.FloatTensor(test_data) 137 | 138 | if np.issubdtype(train_targets.reshape(-1, 1)[0][0], np.integer): 139 | train_targets = torch.LongTensor(train_targets) 140 | test_targets = torch.LongTensor(test_targets) 141 | else: 142 | train_targets = torch.FloatTensor(train_targets) 143 | test_targets = torch.FloatTensor(test_targets) 144 | 145 | trainset = torch.utils.data.TensorDataset(train_data, train_targets) 146 | testset = torch.utils.data.TensorDataset(test_data, test_targets) 147 | 148 | train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) 149 | test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True) 150 | 151 | input_shape = data.shape[-1] 152 | 153 | return train_loader, test_loader, input_shape 154 | -------------------------------------------------------------------------------- /datasets/toy_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.utils.data 5 | 6 | 7 | class ToyDataset(torch.utils.data.Dataset): 8 | """Implementation of a toy dataset as described in https://arxiv.org/pdf/1606.07892.pdf sec 5.1 5.2. 
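    Covariates X are drawn i.i.d. from a standard normal distribution; the response is either
    y = x_0 + Z (``opt.Yfunction == 'linear'``) or y = 5 * sin(4*pi*(x_0**2 + x_1**2)) + 0.25 * Z
    (``opt.Yfunction == 'sine'``), with Z ~ N(0, 1), so only the first one or two features carry
    signal (see ``__init__`` below).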
9 | """ 10 | 11 | def __init__(self, opt, shuffle_targets = False, data = None, targets = None, seed = 31): 12 | # Random number generator 13 | self.rng = np.random.RandomState(seed) 14 | 15 | self.shuffle_targets = shuffle_targets 16 | self.data = data 17 | self.targets = targets 18 | self.features = [ 'x%d'%i for i in range(opt.Xdim)] 19 | self.fea_groundtruth = [0, 1] 20 | 21 | if self.data is None: 22 | self.data = torch.randn(opt.numSamples, opt.Xdim) 23 | 24 | if self.targets is None: 25 | Z = torch.randn(opt.numSamples) 26 | if opt.Yfunction == 'linear': 27 | self.targets = self.data[:,0] + Z 28 | elif opt.Yfunction == 'sine': 29 | self.targets = 5.0 * torch.sin(4*math.pi*(self.data[:,0]**2 + self.data[:,1]**2)) + 0.25 * Z 30 | self.targets = self.targets.view(opt.numSamples,1) 31 | 32 | def __getitem__(self, index): 33 | """ 34 | Args: 35 | index (int): Index 36 | 37 | Returns: 38 | tuple: (image, target) where target is index of the target class. 39 | """ 40 | if self.shuffle_targets: 41 | y_index = self.rng.randint(len(self.targets)) 42 | else: 43 | y_index = index 44 | return self.data[index], self.targets[y_index] 45 | 46 | def __repr__(self): 47 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 48 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 49 | 50 | fmt_str += '=== X === \n' 51 | t = self.data.data if isinstance(self.data, torch.Tensor) else self.data 52 | s = '{:8s} [{:.4f} , {:.4f}] m+-s = {:.4f} +- {:.4f}' 53 | si = 'x'.join(map(str, t.shape if isinstance(t, np.ndarray) else t.size())) 54 | fmt_str += s.format(si, t.min(), t.max(), t.mean(), t.std()) + '\n' 55 | 56 | fmt_str += '=== Y === \n' 57 | t = self.targets.data if isinstance(self.targets, torch.Tensor) else self.targets 58 | s = '{:8s} [{:.4f} , {:.4f}] m+-s = {:.4f} +- {:.4f}' 59 | si = 'x'.join(map(str, t.shape if isinstance(t, np.ndarray) else t.size())) 60 | fmt_str += s.format(si, t.min(), t.max(), t.mean(), t.std()) + '\n' 61 | 62 | return fmt_str 63 | 64 | def __len__(self): 65 | return len(self.data) 66 | 67 | def get_feature_names(self): 68 | return self.features 69 | 70 | def get_groundtruth_features(self): 71 | return self.fea_groundtruth 72 | 73 | 74 | class LiangDataset(torch.utils.data.Dataset): 75 | """Generates data from the simulation study in Liang et al, JASA 2017 76 | Sourced from : https://github.com/tansey/hrt/blob/master/benchmarks/liang/sim_liang.py 77 | N = 500 # total number of samples 78 | P = 500 # number of features 79 | S = 40 # number of signal features 80 | T = 100 # test sample size 81 | """ 82 | def __init__(self, numSamples, Xdim, seed, S=40, betas=None, data=None, targets=None, shuffle_targets=False, targets_mu=1.0, targets_sd=1.0): 83 | # Random number generator 84 | self.rng = np.random.RandomState(seed) 85 | 86 | self.shuffle_targets = shuffle_targets 87 | self.data = data 88 | self.betas = betas 89 | self.targets = targets 90 | self.features = [ 'x%d'%i for i in range(Xdim)] 91 | self.fea_groundtruth = [i for i in range(S)] 92 | 93 | if betas is None: 94 | self.betas = self.compute_betas(S) 95 | 96 | if self.data is None and self.targets is None: 97 | self.data = self.make_X(numSamples, Xdim) 98 | self.targets = self.make_targets(self.data, S) # may modify X in-place 99 | self.data = torch.from_numpy(self.data).type('torch.FloatTensor') 100 | self.targets = torch.from_numpy(self.targets).type('torch.FloatTensor') 101 | self.targets = self.targets.view(numSamples,1) 102 | 103 | self.targets_mu, self.targets_sd = targets_mu, targets_sd 104 | 105 | 
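    # Targets are generated by make_targets below:
    #   y = sum_k [ w0_k * X[:, 4k] + w1_k * X[:, 4k+1] + w2_k * tanh(w21_k * X[:, 4k+2] + w22_k * X[:, 4k+3]) ] + N(0, 0.5^2),
    # summed over the S//4 signal blocks, so only the first S features are informative.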
def compute_betas(self, S): 106 | w0 = self.rng.normal(1, size=S//4) 107 | w1 = self.rng.normal(2, size=S//4) 108 | w2 = self.rng.normal(2, size=S//4) 109 | w21 = self.rng.normal(1, size=(1,S//4)) 110 | w22 = self.rng.normal(2, size=(1,S//4)) 111 | return [w0, w1, w2, w21, w22] 112 | 113 | def make_X(self, N, P): 114 | X = (self.rng.normal(size=(N,1)) + self.rng.normal(size=(N,P))) / 2. 115 | return X 116 | 117 | def make_targets(self, X, S): 118 | N, P = X.shape 119 | w0, w1, w2, w21, w22 = self.betas 120 | y = X[:,0:S:4].dot(w0) + X[:,1:S:4].dot(w1) + np.tanh(w21*X[:,2:S:4] + w22*X[:,3:S:4]).dot(w2) + self.rng.normal(0, 0.5, size=N) 121 | return y 122 | 123 | def __getitem__(self, index): 124 | """ 125 | Args: 126 | index (int): Index 127 | Returns: 128 | tuple: (image, target) where target is index of the target class. 129 | """ 130 | if self.shuffle_targets: 131 | y_index = self.rng.randint(len(self.targets)) 132 | else: 133 | y_index = index 134 | return self.data[index], (self.targets[y_index] - self.targets_mu) / self.targets_sd 135 | 136 | def __len__(self): 137 | return len(self.data) 138 | def get_feature_names(self): 139 | return self.features 140 | 141 | def get_groundtruth_features(self): 142 | return self.fea_groundtruth 143 | 144 | class LiangSwitchingDataset(LiangDataset): 145 | def make_targets(self, X, S): 146 | # NOTE modifies X in-place 147 | N, P = X.shape 148 | R = S//4 # number of regions to switch between 149 | w0, w1, w2, w21, w22 = self.betas 150 | w21, w22 = w21.squeeze(), w22.squeeze() 151 | r = self.rng.choice(R, replace=True, size=N) 152 | y = np.zeros(N) 153 | Z = np.zeros((N,R)) # one hot indication of region 154 | truth = np.zeros((R,P-R), dtype=int) 155 | for i in range(R): 156 | y[r == i] = (X[r == i,i*4] * w0[i] + 157 | X[r == i,i*4+1] * w1[i] + 158 | w2[i] * np.tanh(w21[i]*X[r==i,i*4+2] + w22[i]*X[r==i,i*4+3])) 159 | Z[r==i, i] = 1 160 | truth[i] = np.concatenate([np.zeros(i*4), np.ones(4), np.zeros((R-i-1)*4), np.zeros(P-5*R)]) 161 | y += self.rng.normal(0, 0.5, size=N) 162 | assert P >= S+R, 'Need high enough X dimension to have S used features, and R one-hots' 163 | # overwrite (unused) last R features in X with one hot region indicators. 
164 | X[:, -R:] = Z 165 | # Return just y 166 | return y 167 | 168 | 169 | class SinExpDataset(torch.utils.data.Dataset): 170 | """Complex multivariate model from https://www.padl.ws/papers/Paper%2012.pdf section 5.3 171 | y = sin(x_1 * (x_1 + x_2)) * cos(x_3 + x_4 * x_5) * sin(exp(x_5) + exp(x_6) - x_2) + eps 172 | with P = 50, in the original paper, sampled from uniform distribution, 173 | and eps sampled from a zero-centered Gaussian with variance such that the SNR = 2 174 | 175 | Args: 176 | opt.Xdim (int): number of features 177 | opt.numSamples (int): number of samples 178 | shuffle_targets (bool): whether y should be reshuffled to be decorrelated from X 179 | snr (float): signal to noise ratio of y 180 | gaussian (bool): sample covariates from gaussian or uniform 181 | rho (float): correlation coefficient between pairs of covariates: 182 | x_i = sqrt(rho) * z + sqrt(1 - rho) * randn(0, 1), with z common to all x_i 183 | sigma (float): noise amplitude 184 | seed (int): random seed 185 | """ 186 | def __init__(self, n_samples=125, n_features=50, shuffle_targets=False, data=None, targets=None, betas=None, 187 | rho=0.5, gaussian=False, normalize=True, sigma=None, seed=31): 188 | 189 | self.n_samples = n_samples 190 | self.n_features = n_features 191 | self.shuffle_targets = shuffle_targets 192 | self.rho = rho 193 | self.sigma = sigma 194 | self.gaussian = gaussian 195 | self.normalize = normalize 196 | self.seed = seed 197 | 198 | # Interface stuff 199 | self.fea_groundtruth = list(range(6)) 200 | self.features = ['x{}'.format(i) for i in range(self.n_features)] 201 | self.betas = None 202 | 203 | # Random number generator 204 | self.rng = np.random.RandomState(seed) 205 | 206 | # Fix sigma so that SNR = s^2 / sigma^2 = 2.0 --> sigma = s / sqrt(2.0) 207 | if sigma is None: 208 | if gaussian: 209 | self.sigma = 0.2043 210 | else: 211 | self.sigma = 0.1840 212 | 213 | # Get data and targets 214 | if data is not None: 215 | self.data = data 216 | self.targets = targets 217 | 218 | else: 219 | data, targets = self._sample_dataset(self.n_samples, self.n_features, self.rho, self.sigma, self.gaussian) 220 | self.data, self.targets = torch.FloatTensor(data), torch.FloatTensor(targets) 221 | 222 | # z-score data and targets (normalization computed at SNR = 2.0) 223 | if normalize: 224 | if gaussian: 225 | self.data_mu, self.data_sd = 0.0, 1.0 226 | self.targets_mu, self.targets_sd = 0.03441, 0.3547 227 | 228 | else: # uniform distribution 229 | self.data_mu = 0.5 * np.sqrt(self.rho) + 0.5 * np.sqrt(1 - self.rho) 230 | self.data_sd = 1 / np.sqrt(12) 231 | self.targets_mu, self.targets_sd = 0.08565, 0.3187 232 | 233 | self.data = (self.data - self.data_mu) / self.data_sd 234 | self.targets = (self.targets - self.targets_mu) / self.targets_sd 235 | 236 | def __getitem__(self, index): 237 | """ 238 | Args: 239 | index (int): Index 240 | Returns: 241 | tuple: (image, target) where target is index of the target class. 
242 | """ 243 | if self.shuffle_targets: 244 | y_index = self.rng.randint(len(self.targets)) 245 | else: 246 | y_index = index 247 | return self.data[index], self.targets[y_index] 248 | 249 | def __len__(self): 250 | return len(self.data) 251 | 252 | def _sample_dataset(self, N, P, rho, sigma, gaussian): 253 | assert P >= 6, "n_features must be at least 6 for this dataset" 254 | 255 | # In the original paper X are sampled from a uniform distribution 256 | if gaussian: 257 | X = np.sqrt(rho) * self.rng.randn(N, 1) + np.sqrt(1 - rho) * self.rng.randn(N, P) 258 | else: 259 | X = np.sqrt(rho) * self.rng.rand(N, 1) + np.sqrt(1 - rho) * self.rng.rand(N, P) 260 | 261 | y = np.sin(X[:, 0] * (X[:, 0] + X[:, 1])) * np.cos(X[:, 2] + X[:, 3] * X[:, 4]) *\ 262 | np.sin(np.exp(X[:, 4]) + np.exp(X[:, 5]) - X[:, 1]) 263 | 264 | y += sigma * self.rng.randn(len(y)) 265 | 266 | return X, y 267 | 268 | def get_feature_names(self): 269 | return self.features 270 | 271 | def get_groundtruth_features(self): 272 | return self.fea_groundtruth 273 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SIC/c4e45d7736da6e6faabdc56bfc1336445df99204/modules/__init__.py -------------------------------------------------------------------------------- /modules/generators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.utils.data.dataloader import DataLoader 6 | from torch.utils.data.dataset import Dataset 7 | from torch.distributions import Categorical 8 | import numpy as np 9 | import types 10 | from copy import copy 11 | from utils import DDICT 12 | 13 | 14 | class ConditionalBatchNorm1d(nn.Module): 15 | r"""Conditional BatchNorm 16 | 17 | Args: 18 | num_features (int): number of features in input 19 | num_classes (int): number of classe among for labels in input 20 | 21 | Attributes: 22 | embed (Tensor): embedding of labels to scaling matrix `gamma` and bias matrix `beta` 23 | 24 | Shape: 25 | - Inputs: 26 | inputs: 2-tuple such that inputs = (x, labels), with 27 | x: FloatTensor of shape `(batch_size, num_features)` 28 | labels: LongTensor of size `(batch_size)` 29 | - Outputs: FloatTensor of shape `(batch_size, num_features)`, i.e. same as x 30 | """ 31 | def __init__(self, num_features, num_classes): 32 | super().__init__() 33 | self.num_features = num_features 34 | self.num_classes = num_classes 35 | 36 | self.bn = nn.BatchNorm1d(num_features, affine=False) 37 | self.embed = nn.Embedding(num_classes, num_features * 2) 38 | 39 | # Scale and biases 40 | self.embed.weight.data[:, :num_features].normal_(1, 1.0 / np.sqrt(num_features)) 41 | self.embed.weight.data[:, num_features:].zero_() 42 | 43 | def forward(self, inputs): 44 | x, labels = inputs 45 | outputs = self.bn(x) 46 | gamma, beta = self.embed(labels).chunk(2, 1) 47 | outputs = gamma.view(-1, self.num_features) * outputs + beta.view(-1, self.num_features) 48 | return outputs 49 | 50 | def __repr__(self): 51 | fmt_str = self.__class__.__name__ 52 | fmt_str += '(num_features={num_features}, num_classes={num_classes})'.format(**self.__dict__) 53 | return fmt_str 54 | 55 | 56 | class Generator(nn.Module): 57 | r"""Generic generator that will be inherited by all generators. 
58 | 59 | Args: 60 | num_features (int): number of input features 61 | n_layers (int): number of hidden layers (each one with ConditionalBatchNorm1d) 62 | n_hiddens (int): number of hidden neurons 63 | p_dropout (float): dropout rate 64 | 65 | Shape: 66 | - Inputs: 67 | inputs: 2-tuple such that inputs = (x, labels), with 68 | x: FloatTensor of size `(batch_size, in_features)` 69 | labels: LongTensor of size `(batch_size, 1)` indicating which feature has to be predicted 70 | - Output: Not implemented 71 | """ 72 | def __init__(self, num_features, n_layers, n_hiddens, p_dropout=0.0): 73 | super().__init__() 74 | self.num_features = num_features 75 | self.n_layers = n_layers 76 | self.n_hiddens = n_hiddens 77 | self.p_dropout = p_dropout 78 | 79 | # Linear layers and ConditionalBatchNorm 80 | self.w, self.cb, self.dr = nn.ModuleList(), nn.ModuleList(), nn.ModuleList() 81 | for k in range(n_layers): 82 | self.w.append(nn.Linear(num_features if k == 0 else n_hiddens, n_hiddens)) 83 | self.cb.append(ConditionalBatchNorm1d(n_hiddens, num_features)) 84 | if p_dropout > 0.0: 85 | self.dr.append(nn.Dropout(p=self.p_dropout)) 86 | 87 | self.w_out = None 88 | self.criterion = None 89 | 90 | def forward(self, inputs): 91 | x, labels = inputs 92 | x = x.view(-1, self.num_features) 93 | 94 | # mask with a zero in correspondence to the label in `labels` 95 | mask = torch.ones_like(x) 96 | mask.scatter_(-1, labels.view(-1, 1), 0.0) 97 | 98 | x = x * mask 99 | for l, (w, cb) in enumerate(zip(self.w, self.cb)): 100 | x = F.relu(cb((w(x), labels))) 101 | if self.p_dropout > 0.0: 102 | x = self.dr[l](x) 103 | return self._sample_outputs(x) 104 | 105 | def get_targets(self, x, labels): 106 | """Returns the targets, i.e. the features corresponding to the labels 107 | """ 108 | idx = torch.arange(0, len(labels)) 109 | return x[idx, labels].view(-1, 1) 110 | 111 | def get_dataloader(self, dl, idx_feature): 112 | """Returns a dataloader that is the same as `dl` except that feature `idx_feature` is sampled from the generator 113 | 114 | Args: 115 | dl (DataLoader): base data loader 116 | idx_feature (int): index of the feature that is replaced 117 | """ 118 | new_loader = DataLoader(dataset=GenDataset(self, dl.dataset, idx_feature), batch_size=dl.batch_size) 119 | return new_loader 120 | 121 | def _sample_outputs(self, x): 122 | raise NotImplementedError 123 | 124 | def sample_features(self, inputs): 125 | """Samples with no noise. Used by `test_generator`. 126 | """ 127 | raise NotImplementedError 128 | 129 | def training_loss(self, x, labels): 130 | raise NotImplementedError 131 | 132 | 133 | class GeneratorAvg(Generator): 134 | r"""Generator that predicts a feature as the average of all other features 135 | 136 | Args: 137 | num_features (int): number of input features 138 | 139 | Shape: 140 | - Inputs: 141 | inputs: 2-tuple such that inputs = (x, labels), with 142 | x: FloatTensor of size `(batch_size, in_features)` 143 | labels: LongTensor of size `(batch_size, 1)` indicating which feature has to be predicted 144 | - Output: FloatTensor of size `(batch_size, 1)` 145 | """ 146 | def __init__(self, num_features): 147 | super().__init__(num_features, 0, 0, 0) 148 | self.id = nn.Parameter(torch.randn(1)) # needed for backward compatibility of Generator interface (so that gen.parameters() is not empty) 149 | 150 | def _sample_outputs(self, x): 151 | """The features corresponding to `labels` have been masked out. 152 | Predicts the average of the other features.
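Concretely: mu = x.sum(-1) / (num_features - 1) over the masked input, plus a small Gaussian perturbation of scale sd / 20, where sd is the standard error of the mean.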
153 | """ 154 | mu = x.sum(-1) / (self.num_features - 1) 155 | sd = x.std(-1) / np.sqrt(self.num_features - 1) 156 | eps = torch.randn(mu.size()) / 20 157 | return mu + sd * eps 158 | 159 | def sample_features(self, inputs): 160 | """This is redundant wrt forward. It's just here to have a coherent interface with GeneratorClassify 161 | """ 162 | self.eval() 163 | with torch.no_grad(): 164 | outputs = self(inputs) 165 | return outputs 166 | 167 | 168 | class GeneratorOracle(Generator): 169 | r"""Generator that from a sample of a set of correlated features with known correlation `rho`, predicts an output with same correlation 170 | 171 | Args: 172 | num_features (int): number of inputs features 173 | gaussian (bool): whether the variables are assumed to be gaussian or uniform 174 | rho (float): pairwise correlation between variables (features) 175 | normalize (True): whether the features are normalized 176 | 177 | Shape: 178 | - Inputs: 179 | inputs: 2-tuple such that inputs = (x, labels), with 180 | x: FloatTensor of size `(batch_size, in_feaures)` 181 | labels: LongTensor of size `(batchsize, 1)` indicating which feature has to be predicted 182 | - Output: FloatTensor of size `(batch_size, 1)` 183 | """ 184 | def __init__(self, num_features, gaussian=False, rho=0.5, normalize=True): 185 | super().__init__(num_features, 0, 0, 0) 186 | self.id = nn.Parameter(torch.randn(1)) # needed for backward compatibility of Generator interface (so that gen.parameters() is not empty) 187 | 188 | self.gaussian = gaussian 189 | self.rho = rho 190 | self.normalize = normalize 191 | 192 | # Normalizations 193 | if gaussian: 194 | self.X_mu, self.X_sd = 0.0, 1.0 195 | else: 196 | self.X_mu = 0.5 * np.sqrt(self.rho) + 0.5 * np.sqrt(1 - self.rho) 197 | self.X_sd = 1 / np.sqrt(12) 198 | 199 | def _sample_outputs(self, inputs): 200 | """The features corresponding to `labels` have been masked out. 201 | Predicts the average of the other features. 202 | """ 203 | N, P = inputs.size() 204 | 205 | # Estimate of correlation factor 206 | corrcoef = inputs.sum(-1) / (P - 1) 207 | if self.normalize: 208 | corrcoef = self.X_sd * corrcoef + self.X_mu 209 | 210 | # Generate feature 211 | if self.gaussian: 212 | X = corrcoef * np.sqrt(self.rho) + np.sqrt(1 - self.rho) * torch.randn(N) 213 | else: 214 | X = corrcoef * np.sqrt(self.rho) + np.sqrt(1 - self.rho) * torch.rand(N) 215 | 216 | if self.normalize: 217 | X = (X - self.X_mu) / self.X_sd 218 | 219 | return X 220 | 221 | def sample_features(self, inputs): 222 | """This is redundant wrt forward. 
It's just here to have a coherent interface with GeneratorClassify 223 | """ 224 | self.eval() 225 | with torch.no_grad(): 226 | outputs = self(inputs) 227 | return outputs 228 | 229 | 230 | class GeneratorRegress(Generator): 231 | r"""Generator that outputs regression prediction 232 | Architecture is a 2-layer network, training loss is Huber loss 233 | 234 | Args: 235 | num_features (int): number of inputs features 236 | n_layers (int): number of hidden layers (each one with ConditionalBatchNorm1d) 237 | n_hiddens (int): number of hidden neurons 238 | p_dropout (float): dropout rate 239 | 240 | Shape: 241 | - Inputs: 242 | inputs: 2-tuple such that inputs = (x, labels), with 243 | x: FloatTensor of size `(batch_size, in_feaures)` 244 | labels: LongTensor of size `(batchsize, 1)` indicating which feature has to be predicted 245 | - Output: FloatTensor of size `(batch_size, 1)` 246 | """ 247 | def __init__(self, num_features, n_layers, n_hiddens, p_dropout=0.0): 248 | super().__init__(num_features, n_layers, n_hiddens, p_dropout) 249 | 250 | self.w_out = nn.Linear(n_hiddens, 2) # mu, log_var 251 | self.criterion = nn.SmoothL1Loss() # Huber loss 252 | 253 | def _sample_outputs(self, x): 254 | """Samples from a Gaussian distribution with mu=inputs[:,0] and log_var=inputs[:,1] 255 | using the reparametrization trick 256 | """ 257 | inputs = self.w_out(x) 258 | 259 | mu, log_var = inputs.chunk(2, dim=-1) 260 | eps = torch.randn(mu.size()) 261 | return mu + torch.exp(log_var / 2) * eps 262 | 263 | def sample_features(self, inputs): 264 | """This is redundant wrt forward. It's just here to have a coherent interface with GeneratorClassify 265 | """ 266 | self.eval() 267 | with torch.no_grad(): 268 | outputs = self(inputs) 269 | return outputs 270 | 271 | def training_loss(self, x, labels): 272 | """Huber loss 273 | """ 274 | outputs = self((x, labels)) 275 | targets = self.get_targets(x, labels) 276 | return self.criterion(outputs, targets) 277 | 278 | 279 | class GeneratorClassify(Generator): 280 | r"""Generator that outputs classification prediction (softmax) 281 | Architecture is a 2-layer network, training loss is cross-entropy 282 | 283 | Args: 284 | num_features (int): number of inputs features 285 | n_layers (int): number of hidden layers (each one with ConditionalBatchNorm1d) 286 | n_hiddens (int): number of hidden neurons 287 | num_bins (int): number of bins, i.e. 
number of softmax outputs 288 | init_dataset (torch.FloatTensor, Dataset or DataLoader): a set of input samples representative of the dataset, 289 | necessary to compute the bins for quantization 290 | beta (float): pseudo-temperature for sampling over bins 291 | p_dropout (float): dropout rate 292 | 293 | Shape: 294 | - Inputs: 295 | inputs: 2-tuple such that inputs = (x, labels), with 296 | x: FloatTensor of size `(batch_size, in_features)` 297 | labels: LongTensor of size `(batch_size, 1)` indicating which feature has to be predicted 298 | - Output: FloatTensor of size `(batch_size, num_bins)` 299 | """ 300 | def __init__(self, num_features, n_layers, n_hiddens, num_bins, init_dataset, beta=1.0, p_dropout=0.0): 301 | super().__init__(num_features, n_layers, n_hiddens, p_dropout) 302 | 303 | self.num_bins = num_bins 304 | self.beta = beta 305 | self.w_out = nn.Linear(n_hiddens, num_bins) 306 | self.criterion = nn.CrossEntropyLoss() 307 | 308 | # Initialize binning 309 | if isinstance(init_dataset, DataLoader): 310 | self.init_dataset = init_dataset.dataset[:][0] 311 | elif isinstance(init_dataset, Dataset): 312 | self.init_dataset = init_dataset[:][0] 313 | else: 314 | self.init_dataset = init_dataset 315 | 316 | self.bin_edges, self.bin_centers, self.bin_widths = self._quantization_binning(self.init_dataset, num_bins) 317 | self.bin_centers = torch.FloatTensor(self.bin_centers) 318 | self.bin_widths = torch.FloatTensor(self.bin_widths) 319 | 320 | def _quantization_binning(self, data, num_bins): 321 | """Quantizes the inputs and computes the binning, assuming that all input features have the same distribution 322 | 323 | Shape: 324 | - Outputs: 325 | bin_edges: array of size `(num_bins + 1, num_features)`, edges of bins for each feature 326 | bin_centers: array of size `(num_bins, num_features)`, centers of bins for each feature 327 | """ 328 | qtls = np.arange(0.0, 1.0 + 1 / num_bins, 1 / num_bins) 329 | bin_edges = np.quantile(data, qtls, axis=0) # (num_bins + 1, num_features) 330 | bin_widths = np.diff(bin_edges, axis=0) 331 | bin_centers = bin_edges[:-1, :] + bin_widths / 2 # (num_bins, num_features) 332 | return bin_edges, bin_centers, bin_widths 333 | 334 | def _quantize(self, inputs, labels): 335 | quant_inputs = np.zeros(inputs.shape[0]) 336 | for i, (x, l) in enumerate(zip(inputs.cpu(), labels)): 337 | quant_inputs[i] = np.digitize(x, self.bin_edges[:, l]) 338 | quant_inputs = quant_inputs.clip(1, self.num_bins) - 1 # Clip edges 339 | return torch.LongTensor(quant_inputs).to(inputs.device) 340 | 341 | def sample_features(self, inputs): 342 | """Samples a value for the masked feature from the softmax over bins (with uniform jitter within the sampled bin) 343 | """ 344 | self.eval() 345 | with torch.no_grad(): 346 | x, labels = inputs 347 | logits = self(inputs) 348 | sampled_bins = Categorical(logits=self.beta * logits).sample() 349 | samples = self.bin_centers[sampled_bins, labels] + (torch.rand(len(sampled_bins)) - 0.5) * self.bin_widths[sampled_bins, labels] 350 | return samples.to(x.device).view(-1, 1) 351 | 352 | def _sample_outputs(self, x): 353 | return self.w_out(x) 354 | 355 | def training_loss(self, x, labels): 356 | """Computes cross-entropy loss from classification output 357 | """ 358 | outputs = self((x, labels)) 359 | targets = self.get_targets(x, labels) 360 | 361 | quant_targets = self._quantize(targets, labels).view(-1) 362 | return self.criterion(outputs, quant_targets) 363 | 364 | 365 | def train_generator(generator, dataloader, optimizer, features_list=None, log_times=5): 366 | '''Trains a generator model 367 | 368 | Args: 369 | generator (Generator): a generator object 370 |
dataloader (Dataloader): a dataloader object 371 | optimizer (optim.optimizer): optimizer used to train `generator` 372 | features_list (list): list of features to which training has to be restricted 373 | log_times (int): how many times the training will be logged 374 | ''' 375 | generator.train() 376 | device = next(generator.parameters()).device 377 | 378 | # In case we're using nn.DataParallel 379 | if isinstance(generator, nn.DataParallel): 380 | generator_loss = generator.module.training_loss 381 | else: 382 | generator_loss = generator.training_loss 383 | 384 | # Subset of features to train on 385 | if features_list is None: 386 | # Train to output all features 387 | features_list = list(range(generator.num_features)) 388 | features_list = torch.LongTensor(features_list) 389 | 390 | mean_loss = 0.0 391 | for batch_idx, data in enumerate(dataloader): 392 | if isinstance(data, tuple) or isinstance(data, list): 393 | data = data[0] 394 | data = data.to(device) 395 | 396 | # Generate labels at random 397 | rIdx = torch.randint(0, len(features_list), (data.shape[0],)).to(device) 398 | labels = features_list.index_select(0, rIdx) 399 | 400 | optimizer.zero_grad() 401 | loss = generator_loss(data, labels) 402 | loss.backward() 403 | optimizer.step() 404 | 405 | mean_loss += loss.item() / len(data) 406 | if log_times > 0 and batch_idx % (len(dataloader) // log_times) == 0: 407 | print(' training progress: {}/{} ({:.0f}%)\tloss: {:.6f}'.format( 408 | batch_idx * len(data), len(dataloader.dataset), 100. * batch_idx / len(dataloader), loss.item())) 409 | 410 | return mean_loss 411 | 412 | 413 | def test_generator(generator, criterion, dataloader, test_all_features=False): 414 | '''Tests a generator model 415 | 416 | Args: 417 | test_all_features (bool): if True, test all features for all samples as outputs, 418 | otherwise only test one feature at random per sample 419 | ''' 420 | generator.eval() 421 | device = next(generator.parameters()).device 422 | 423 | # In case we're using nn.DataParallel 424 | if isinstance(generator, nn.DataParallel): 425 | generator_get_targets = generator.module.get_targets 426 | generator_sample_features = generator.module.sample_features 427 | else: 428 | generator_get_targets = generator.get_targets 429 | generator_sample_features = generator.sample_features 430 | 431 | test_loss = 0.0 432 | with torch.no_grad(): 433 | for data in dataloader: 434 | data = data[0].to(device) 435 | 436 | if test_all_features: 437 | # Generate labels by running all features for all samples 438 | labels = torch.arange(0, data.shape[-1]).repeat(data.shape[0], 1).view(-1).to(device) 439 | idx = torch.arange(0, data.shape[0]).repeat(data.shape[-1], 1).t().reshape(-1).to(device) 440 | else: 441 | # Generate labels at random 442 | labels = torch.randint(0, data.shape[-1], (data.shape[0],)).to(device) 443 | idx = torch.arange(0, len(labels)).to(device) 444 | 445 | targets = generator_get_targets(data[idx], labels) 446 | outputs = generator_sample_features((data[idx], labels)) 447 | test_loss += criterion(outputs, targets).item() 448 | 449 | test_loss /= len(dataloader) # loss function already averages over batch size 450 | return test_loss 451 | 452 | 453 | class GenDataset(Dataset): 454 | r"""Generator dataset: replaces one feature of the input dataset with one generated by a generator 455 | Args: 456 | generator: generator trained to generate input features 457 | dataset (Dataset): dataset whose features are going to be replaced 458 | idx_feature (int): feature that will be replaced by the
generator 459 | 460 | Attributes: 461 | idx_feature: feature which is being replaced by the generator 462 | 463 | Notes: 464 | - the replacement feature is sampled only once (at initialization) 465 | - if you want to resample replacement features, call `resample_replaced_feature()` 466 | 467 | """ 468 | def __init__(self, generator, dataset, idx_feature): 469 | self.generator = generator 470 | self.dataset = dataset 471 | self.idx_feature = idx_feature 472 | self.resample_replaced_feature() 473 | 474 | def resample_replaced_feature(self): 475 | device = next(self.generator.parameters()).device 476 | replaced_feature = [] 477 | for data in DataLoader(self.dataset, batch_size=256, shuffle=False): 478 | data = data[0].to(device) 479 | labels = data.new_full((data.shape[0],), self.idx_feature, dtype=torch.long) 480 | replaced_feature.append(self.generator.sample_features((data, labels)).view(-1)) 481 | 482 | self._replaced_features = torch.cat(replaced_feature).to(data.device) 483 | 484 | def __getitem__(self, index): 485 | data = self.dataset.__getitem__(index) 486 | 487 | if len(data) > 1: 488 | data, target = data 489 | else: 490 | data, target = data[0], None 491 | 492 | r_data = data.new_empty(data.size()) 493 | r_data.copy_(data) 494 | r_data[..., self.idx_feature] = self._replaced_features[index] 495 | return r_data, target 496 | 497 | def __len__(self): 498 | return len(self.dataset) 499 | 500 | 501 | def generator_from_data(dataset, generator_type='regress', features_list=None, n_epochs=100, n_layers=3, n_hiddens=200, p_dropout=0, num_bins=100, training_args=None): 502 | """NOTE: Training epochs `n_epochs` should scale with the number of features. 503 | """ 504 | if generator_type == 'oracle': 505 | n_features = dataset[0][0].shape[-1] 506 | generator = GeneratorOracle(n_features, gaussian=dataset.gaussian, rho=dataset.rho, normalize=dataset.normalize) 507 | 508 | return generator, None 509 | 510 | else: # Generator needs to be trained 511 | 512 | # All default training arguments are hidden here 513 | default_args = DDICT( 514 | optimizer='Adam', 515 | batch_size=128, 516 | lr=0.003, 517 | lr_step_size=20, 518 | lr_decay=0.5, 519 | num_bins=10, 520 | ) 521 | 522 | # Custom training arguments 523 | args = default_args 524 | if training_args is not None: 525 | for k in training_args: 526 | args[k] = training_args[k] 527 | 528 | # Data 529 | n_features = dataset[0][0].shape[-1] 530 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) 531 | 532 | if generator_type == 'classify': 533 | generator = GeneratorClassify(n_features, n_layers, n_hiddens, num_bins=num_bins, init_dataset=dataset) 534 | elif generator_type == 'regress': 535 | generator = GeneratorRegress(n_features, n_layers, n_hiddens) 536 | else: 537 | raise ValueError('generator_type has to be classify or regress') 538 | 539 | optimizer = getattr(optim, args.optimizer)(generator.parameters(), lr=args.lr) 540 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_decay) 541 | 542 | tr_loss = [] 543 | for epoch in range(n_epochs): 544 | tr_loss += [train_generator(generator, dataloader, optimizer, features_list, log_times=0)] 545 | scheduler.step() 546 | 547 | return generator, tr_loss 548 | -------------------------------------------------------------------------------- /modules/linregression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LinRegression(nn.Linear): 6 | 
r"""Linear Regression computed in closed-form 7 | 8 | Args: 9 | X (Tensor): input tensor of size `(batch_size, num_features)` 10 | Y (Tensor): output tensor of size `(batch_size, num_outputs)` 11 | eps (float): regularization parameter 12 | """ 13 | def __init__(self, X, Y, eps=1e-6): 14 | super().__init__(X.shape[-1], Y.shape[-1]) 15 | # Center data 16 | mX = X.mean(0) 17 | mY = Y.mean(0) 18 | 19 | self.weight.data = self._lin_reg(X - mX, Y - mY, eps).t() 20 | with torch.no_grad(): 21 | self.bias.data = mY - self.weight.mm(mX.view(-1, 1)).view(-1) 22 | 23 | def _lin_reg(self, X, Y, eps): 24 | CC = X.t().mm(X) / X.shape[0] 25 | XC = X.t().mm(Y) / X.shape[0] 26 | return (CC + eps * torch.eye(CC.shape[0])).inverse().mm(XC) 27 | 28 | 29 | def linreg_reconstruct(X, idx_feature, eps=1e-6): 30 | """Linear regression reconstructing an input feature 31 | It returns a `LinRegression` object trained on reconstructing feature `idx_feature` from the others 32 | 33 | Args: 34 | X (Tensor): input tensor of size `(batch_size, num_features)` 35 | idx_feature (int): feature that needs to be linearly recunstructed from the rest 36 | eps (float): regularization parameter 37 | """ 38 | idx_rest = list(set(range(X.shape[1])) - set([idx_feature])) 39 | return LinRegression(X[:, idx_rest], X[:, idx_feature].view(-1, 1), eps) 40 | -------------------------------------------------------------------------------- /modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def add_layer(seq, ix, n_inputs, n_outputs, nonlin, normalization): 6 | seq.add_module('L'+str(ix), nn.Linear(n_inputs, n_outputs)) 7 | if ix > 0 and normalization: # only LN/IN after first layer. 8 | if normalization == 'LN': 9 | seq.main.add_module('A'+str(ix), nn.LayerNorm(n_outputs)) 10 | else: 11 | raise ValueError('Unknown normalization: {}'.format(normalization)) 12 | if nonlin == 'LeakyReLU': 13 | seq.add_module('N'+str(ix), nn.LeakyReLU(0.2, inplace=True)) 14 | elif nonlin == 'ReLU': 15 | seq.add_module('N'+str(ix), nn.ReLU(inplace=True)) 16 | elif nonlin == 'Sigmoid': 17 | seq.add_module('N'+str(ix), nn.Sigmoid()) 18 | 19 | 20 | class D_phiVpsi(nn.Module): 21 | def __init__(self, insizes=[1,1], layerSizes=[ [32,32,16] ]*2, nonlin='LeakyReLU', normalization=None): 22 | super(D_phiVpsi, self).__init__() 23 | self.phi_x, self.psi_y = nn.Sequential(), nn.Sequential() 24 | # phi_x and psi_y same arch (by layerSizes) 25 | for seq, insize, layerSize in [(self.phi_x, insizes[0], layerSizes[0]), (self.psi_y, insizes[1], layerSizes[1])]: 26 | for ix, n_inputs, n_outputs in zip(range(len(layerSize)), [insize]+layerSize[:-1], layerSize): 27 | add_layer(seq, ix, n_inputs, n_outputs, nonlin, normalization) 28 | self.phiD, self.psiD = layerSizes[0][-1], layerSizes[1][-1] 29 | # inner matrix in bilinear form 30 | self.W = nn.Parameter(torch.randn(self.phiD, self.psiD)) 31 | 32 | def forward(self, x, y): 33 | x = x.view(x.size(0), -1) # bs x D with D >=1 34 | y = y.view(x.size(0), 1) # bs x 1 35 | phi_x = self.phi_x(x) 36 | psi_y = self.psi_y(y) 37 | out = (torch.mm(phi_x, self.W) * psi_y).sum(1, keepdim=True) 38 | return out 39 | 40 | class D_concat(nn.Module): 41 | def __init__(self, insizes=[1,1], layerSizes=[32,32,16], nonlin='LeakyReLU', normalization=None): 42 | super(D_concat, self).__init__() 43 | insize = sum(insizes) 44 | self.main = nn.Sequential() 45 | for ix, n_inputs, n_outputs in zip(range(len(layerSizes)), [insize]+layerSizes[:-1], layerSizes): 
46 | add_layer(self.main, ix, n_inputs, n_outputs, nonlin, normalization) 47 | self.PhiD = n_outputs 48 | self.V = nn.Linear(self.PhiD, 1, bias=False) 49 | self.V.weight.data *= 100 50 | def forward(self, x, y): 51 | x = x.view(x.size(0), -1) # bs x D with D >=1 52 | y = y.view(x.size(0), 1) # bs x 1 53 | inp = torch.cat( [x,y], dim=1) 54 | phi = self.main(inp) 55 | return self.V(phi) 56 | 57 | 58 | class D_concat2(nn.Module): 59 | def __init__(self, insizes=[1,1], layerSize=100): 60 | super(D_concat2, self).__init__() 61 | self.branchx = nn.Sequential( 62 | nn.Linear(insizes[0], layerSize), 63 | nn.LeakyReLU(), 64 | nn.Linear(layerSize, layerSize), 65 | nn.LeakyReLU(), 66 | ) 67 | self.branchy = nn.Sequential( 68 | nn.Linear(insizes[1], layerSize), 69 | nn.LeakyReLU(), 70 | nn.Linear(layerSize, layerSize), 71 | nn.LeakyReLU(), 72 | ) 73 | self.branchxy = nn.Sequential( 74 | nn.Linear(2*layerSize, layerSize), 75 | nn.LeakyReLU(), 76 | nn.Linear(layerSize, layerSize), 77 | nn.LeakyReLU(), 78 | nn.Linear(layerSize, 1), 79 | ) 80 | def forward(self, x, y): 81 | x = x.view(x.size(0), -1) # bs x D with D >=1 82 | y = y.view(x.size(0), 1) # bs x 1 83 | xy = torch.cat([self.branchx(x), self.branchy(y)], dim=1) 84 | return self.branchxy(xy) 85 | 86 | 87 | class D_concat_first(nn.Module): 88 | def __init__(self, insize=2, layerSize=100, dropout=0.0): 89 | super(D_concat_first, self).__init__() 90 | self.branchxy = nn.Sequential( 91 | nn.Linear(insize, layerSize, bias=False), 92 | nn.ReLU(), 93 | nn.Dropout(p=dropout), 94 | nn.Linear(layerSize, layerSize, bias=False), 95 | nn.ReLU(), 96 | nn.Dropout(p=dropout) 97 | ) 98 | self.last_linear = nn.Linear(layerSize, 1, bias=False) 99 | def forward(self, x, y): 100 | x = x.view(x.size(0), -1) 101 | y = y.view(x.size(0), 1) 102 | 103 | xy = torch.cat([x,y], dim=1) 104 | return self.last_linear(self.branchxy(xy)) 105 | 106 | 107 | class D_supervised_nobias(nn.Module): 108 | def __init__(self, n_inputs, n_outputs, layerSize=100, dropout=0.0, bias=False): 109 | super().__init__() 110 | 111 | self.n_inputs = n_inputs 112 | self.net = nn.Sequential( 113 | nn.Linear(n_inputs, layerSize, bias=bias), 114 | nn.ReLU(), 115 | nn.Dropout(p=dropout), 116 | nn.Linear(layerSize, layerSize, bias=bias), 117 | nn.ReLU(), 118 | nn.Dropout(p=dropout), 119 | nn.Linear(layerSize, 1, bias=bias) 120 | ) 121 | 122 | def forward(self, x): 123 | x = x.view(x.size(0), -1) 124 | return self.net(x) 125 | 126 | 127 | class D_supervised(nn.Module): 128 | def __init__(self, n_inputs, n_outputs, layerSize=100, dropout=0.0, bias=True): 129 | super(D_supervised, self).__init__() 130 | 131 | self.n_inputs = n_inputs 132 | self.net = nn.Sequential( 133 | nn.Linear(n_inputs, layerSize, bias=bias), 134 | nn.ReLU(), 135 | nn.Dropout(p=dropout), 136 | nn.Linear(layerSize, layerSize, bias=bias), 137 | nn.ReLU(), 138 | nn.Dropout(p=dropout), 139 | nn.Linear(layerSize, 1, bias=bias) 140 | ) 141 | self.mask = None # to set after creation 142 | 143 | def set_mask(self, mask): 144 | assert mask.dim() == 1 and mask.size(0) ==self.n_inputs 145 | self.mask = mask.detach().clone().unsqueeze(0) # (1, D) 146 | 147 | def forward(self, x): 148 | x = x.view(x.size(0), -1) 149 | x = x * self.mask # broadcast (1, D) 150 | return self.net(x) 151 | 152 | 153 | def init_D(opt, device): 154 | """Initialize discriminator 155 | """ 156 | if opt.DiscArch == 'phiVpsi': 157 | D = D_phiVpsi([opt.Xdim,1], [opt.layerSizeX, opt.layerSizeY], opt.nonlin, opt.normalization).to(device) 158 | elif opt.DiscArch == 'concat': 159 | D 
= D_concat([opt.Xdim,1], opt.layerSizeX, opt.nonlin, opt.normalization).to(device) # no separate x,y branches 160 | elif opt.DiscArch == 'concat2': 161 | D = D_concat2([opt.Xdim, 1], opt.layerSize).to(device) # separate x,y branches then merge 162 | elif opt.DiscArch == 'concat_first': 163 | D = D_concat_first(sum([opt.Xdim, 1]), opt.layerSize, opt.dropout).to(device) 164 | elif opt.DiscArch == 'supervised': 165 | D = D_supervised(opt.Xdim, 1, opt.layerSize, opt.dropout).to(device) 166 | elif opt.DiscArch == 'supervised_nobias': 167 | D = D_supervised_nobias(opt.Xdim, 1, opt.layerSize, opt.dropout).to(device) 168 | return D 169 | 170 | 171 | def init_optimizerD(opt, D, train_last_layer_only=False): 172 | """Initialize optimizer for discriminator D 173 | """ 174 | params_to_train = D.parameters() 175 | if train_last_layer_only and opt.DiscArch == 'concat_first': 176 | params_to_train = D.last_linear.parameters() 177 | optimizerD = torch.optim.Adam(params_to_train, lr=opt.lrD, betas=(opt.beta1, opt.beta2), weight_decay=opt.wdecay) 178 | return optimizerD 179 | -------------------------------------------------------------------------------- /output/SINEXP_250.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SIC/c4e45d7736da6e6faabdc56bfc1336445df99204/output/SINEXP_250.png -------------------------------------------------------------------------------- /plot_results.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from os.path import isdir, isfile, join 3 | from itertools import chain 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from utils import shelf 7 | 8 | 9 | def dlist(key, dat): 10 | r"""Runs over a list of dictionaries and outputs a list of values corresponding to `key` 11 | Short version (no checks): return np.array([d[key] for d in dat]) 12 | """ 13 | ret = [] 14 | for i, d in enumerate(dat): 15 | if key in d: 16 | ret.append(d[key]) 17 | else: 18 | print('key {} is not in dat[{}]. 
Skip.'.format(key, i)) 19 | return np.array(ret) 20 | 21 | 22 | def get_data(select_dict, ARGS, key_list, DAT): 23 | data = [] 24 | for sel, key in zip(select_dict, key_list): 25 | # Select DAT 26 | k, v = next(iter(sel.items())) 27 | dat = [da[0] for da in zip(DAT, ARGS) if k in da[1] and da[1][k] == v][0] 28 | data.append(dlist(key, dat)) 29 | return data 30 | 31 | 32 | def color_bplot(bplot, colors): 33 | r"""Color the boxplots""" 34 | for patch, color in zip(bplot['boxes'], colors): 35 | patch.set_facecolor(color) 36 | for median in bplot['medians']: 37 | median.set(color='k', linewidth=1.5,) 38 | 39 | 40 | def label_axis(ax, labels, xpos, ypos, fontsize=16, target_fdr=0.1): 41 | # Partially remove frame 42 | ax.spines['top'].set_visible(False) 43 | ax.spines['right'].set_visible(False) 44 | 45 | # y label 46 | ax.set_ylabel('Power and FDR', fontsize=fontsize) 47 | ax.set_ylim([-0.05, 1.05]) 48 | 49 | # Horizontal line for target fdr 50 | if target_fdr: 51 | ax.plot(ax.get_xlim(), [target_fdr, target_fdr], '--r') 52 | 53 | # New Axis 54 | new_ax = ax.twiny() 55 | new_ax.set_xticks(xpos) 56 | new_ax.set_xticklabels(labels) 57 | 58 | new_ax.xaxis.set_ticks_position('bottom') # set the position of the second x-axis to bottom 59 | new_ax.xaxis.set_label_position('bottom') # set the position of the second x-axis to bottom 60 | new_ax.spines['bottom'].set_position(('outward', ypos)) # positions below 61 | 62 | # Remove frame for new_ax 63 | new_ax.spines['bottom'].set_visible(False) 64 | new_ax.spines['top'].set_visible(False) 65 | new_ax.spines['left'].set_visible(False) 66 | new_ax.spines['right'].set_visible(False) 67 | 68 | new_ax.tick_params(length=0, labelsize=fontsize) 69 | new_ax.set_xlim(ax.get_xlim()) 70 | 71 | return new_ax 72 | 73 | 74 | if __name__ == "__main__": 75 | # Load data 76 | PATH = 'output/' 77 | DIRS = [d for d in listdir(PATH) if isdir(join(PATH, d))] 78 | FILES = [join(PATH, d, f) for d in DIRS for f in listdir(join(PATH, d)) 79 | if isfile(join(PATH, d, f)) and f[-3:]=='.pt'] 80 | 81 | ARGS, DAT, MODELS = [], [], [] 82 | for f in FILES: 83 | sh = shelf()._load(f) 84 | ARGS.append(sh.args) 85 | if 'd' in sh: 86 | DAT.append(sh['d']) 87 | MODELS.append(sh.args['model']) 88 | else: 89 | print("WARNING: There is no data field 'd' in file {}.
Skip.".format(f)) 90 | continue 91 | 92 | # --------------------------- 93 | # Process data 94 | # --------------------------- 95 | select_dict, key_list, labels, positions, ax_labels, ax_positions = [], [], [], [-2], [], [-2] 96 | # Baseline models 97 | for m, l in zip(['en', 'rf'], ['Elastic Net', 'Random Forest']): 98 | if m in MODELS: 99 | select_dict += 4*[{'model': m}] 100 | key_list += ['tpr_selected', 'fdr_selected', 'hrt_tpr_selected', 'hrt_fdr_selected'] 101 | labels += ['TPR', 'FDR', 'TPR\nHRT', 'FDR\nHRT'] 102 | p = positions[-1] + 2 103 | positions += [1+p, 2+p, 4+p, 5+p] 104 | ax_labels += [l] 105 | ax_positions += [ax_positions[-1] + len(l)/2] 106 | 107 | # Our models 108 | for m, l, pos in zip(['sic_supervised', 'sic'], ['Sobolev Penalty', 'SIC'], [5.5, 4]): 109 | if m in MODELS: 110 | select_dict += 2*[{'model': m}] 111 | key_list += ['hrt_tpr_selected', 'hrt_fdr_selected'] 112 | labels += ['TPR\nHRT', 'FDR\nHRT'] 113 | p = positions[-1] + 2 114 | positions += [1+p, 2+p] 115 | ax_labels += [l] 116 | ax_positions += [ax_positions[-1] + pos] 117 | 118 | positions.pop(0); 119 | ax_positions.pop(0); 120 | 121 | data = get_data(select_dict, ARGS, key_list, DAT) 122 | 123 | # --------------------------- 124 | # Plot 125 | # --------------------------- 126 | dataset = ARGS[0]['dataset'].upper() 127 | n_samples = ARGS[0]['numSamples'] 128 | 129 | fig = plt.figure(figsize=(8, 3)) 130 | ax = plt.subplot(111) 131 | 132 | bplot = plt.boxplot(data, positions=positions, labels=labels, patch_artist=True) 133 | label_axis(ax, ax_labels, ax_positions, 32, fontsize=13) 134 | color_bplot(bplot, len(positions)//2*['lightblue', 'orange']) 135 | 136 | fig.suptitle(f'Dataset {dataset}, N={n_samples}'); 137 | 138 | fig.tight_layout() 139 | fig.savefig(f"output/{dataset}_{n_samples}.png", bbox_inches='tight') 140 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.1.0 2 | torchvision==0.3.0 3 | scikit-learn==0.21 4 | pandas==0.25 5 | -------------------------------------------------------------------------------- /run_baselines.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import torch 4 | import numpy as np 5 | 6 | from utils import DDICT, shelf 7 | from stattests import hrt, log_metrics, log_metrics_selected 8 | from datasets.dataset_builder import build_dataset 9 | 10 | from sklearn.linear_model import ElasticNetCV 11 | from sklearn.ensemble import RandomForestRegressor 12 | 13 | # configure event logging 14 | import logging 15 | logging.basicConfig(format='%(asctime)s - %(message)s', 16 | datefmt='%Y-%m-%d %H:%M:%S', 17 | level=logging.INFO, stream=sys.stdout) 18 | logger = logging.getLogger('main') 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--dataset', default='sinexp', help='toy | sinexp | liang | liang_switch | ccle | olfaction | biobank | hiv ') 23 | parser.add_argument('--sinexp-rho', default=0.5, type=float, help='Correlation coefficient between pairs of covariates in SinExp dataset') 24 | parser.add_argument('--sinexp-gaussian', action='store_true', help='Sample covariates from gaussian or uniform in SinExp dataset') 25 | parser.add_argument('--generator-type', default='classify', help='classify | regress') 26 | parser.add_argument('--numSamples', type=int, default=250, help='(for toy) num Samples in train & heldout') 27 | 
parser.add_argument('--Xdim', type=int, default=50, help='(for toy) X dimensionality') 28 | parser.add_argument('--ftdr-cutoff', default=6, type=int, help='fdr / tpr cutoff') 29 | parser.add_argument('--model', default='en', help='en (elastic net) | rf (random forest)') 30 | parser.add_argument('--do-hrt', action='store_true', help='perform HRT') 31 | parser.add_argument('--hrt-cutoff', type=int, default=20, help='maximal number of features for HRT to evaluate') 32 | parser.add_argument('--target-fdr', default=0.1, type=float, help='target FDR for HRT') 33 | parser.add_argument('--n-runs', type=int, default=100, help='number of repetitons over everything') 34 | parser.add_argument('--data-seed', type=int, default=0, help='initial random seed for data') 35 | 36 | 37 | def train_sklearn_model(dataloaders, sklearn_model=ElasticNetCV, init_kwargs={'cv': 5}, importance_attr='coef_', logger=True): 38 | dl_train_P, dl_train_Q, dl_test_P, dl_test_Q = dataloaders # unpack 39 | X, y = dl_train_P.dataset[:] 40 | y = y.view(-1) 41 | y = np.array(y.view(-1), dtype=np.float) 42 | 43 | model = sklearn_model(**init_kwargs) 44 | model_name = model.__repr__().split('(')[0] 45 | 46 | if logger: logger.info('Start training {}'.format(model_name)) 47 | 48 | model.fit(X, y) 49 | 50 | # heldout eval / printing 51 | X_te, y_te = dl_test_P.dataset[:] 52 | y_te = y_te.view(-1) 53 | y_te = np.array(y_te.view(-1), dtype=np.float) 54 | 55 | loss_tr = ((model.predict(X) - y)**2).mean() 56 | loss_te = ((model.predict(X_te) - y_te)**2).mean() 57 | 58 | if logger: logger.info('{}, training MSE: {:.3f}, test MSE: {:.3f}'.format(model_name, loss_tr, loss_te)) 59 | 60 | etas = torch.FloatTensor(np.abs(getattr(model, importance_attr))) 61 | return model, etas, loss_tr, loss_te 62 | 63 | 64 | # HRT risk function 65 | def risk_model_fn(dataloader): 66 | global model 67 | X, y = dataloader.dataset[:] 68 | X = np.array(X, dtype=np.float) 69 | y = np.array(y.view(-1), dtype=np.float) 70 | mse_loss = ((model.predict(X) - y)**2).mean() 71 | return mse_loss 72 | 73 | 74 | if __name__ == "__main__": 75 | args = parser.parse_args() 76 | 77 | # Initialize input & output dirs 78 | args.outdir = os.path.join('output', 'baselines') 79 | os.makedirs(args.outdir, exist_ok=True) 80 | logger.info(str(args).replace(', ', '\n')) 81 | 82 | # options to pass to data builder 83 | data_opt = DDICT( 84 | dataset=args.dataset, 85 | sinexp_gaussian=args.sinexp_gaussian, 86 | sinexp_rho=args.sinexp_rho, 87 | numSamples=args.numSamples, 88 | Xdim=args.Xdim, 89 | batchSize=100, 90 | num_workers=1, 91 | dataseed=args.data_seed, 92 | ) 93 | 94 | # sklearn model 95 | if args.model == 'en': 96 | sk_model = ElasticNetCV 97 | sk_init_kwargs = {'cv': 5, 'n_jobs': 1} 98 | sk_importance_attr = 'coef_' 99 | 100 | elif args.model == 'rf': 101 | sk_model = RandomForestRegressor 102 | sk_init_kwargs = {'n_estimators': 10, 'n_jobs': 1} 103 | sk_importance_attr = 'feature_importances_' 104 | 105 | else: 106 | raise ValueError("Sklearn model {} not recognized".format(args.model)) 107 | 108 | # Save everything in SHELF 109 | save_filename = os.path.join(args.outdir, args.model) + '_' + args.dataset 110 | SH = shelf(args=args.__dict__) 111 | SH._save(save_filename, date=True) 112 | SH.d = [] 113 | 114 | for n_iter in range(args.n_runs): 115 | logger.info("\n\n* Repetition {} of {}\n".format(n_iter + 1, args.n_runs)) 116 | 117 | RES = dict(n_iter=n_iter) # Results to save 118 | 119 | # Reload dataset (with fresh random seed) 120 | data_opt.dataseed += 1 121 | dataloaders, 
fea_names, fea_groundtruth = build_dataset(data_opt) 122 | 123 | logger.info(str(dataloaders[0].dataset)) 124 | 125 | # Train model 126 | model, eta_x, loss_tr, loss_te = train_sklearn_model(dataloaders, sk_model, sk_init_kwargs, sk_importance_attr, logger=logger) 127 | RES.update({'loss_tr': loss_tr, 'loss_te': loss_te}) 128 | 129 | # metrics: 130 | RES.update(log_metrics(eta_x, fea_groundtruth, args.ftdr_cutoff, key_prefix='', logger=logger)) 131 | 132 | # HRT 133 | if args.do_hrt: 134 | hrt_sorted_selected_features, hrt_pvals = hrt(eta_x, risk_model_fn, dataloaders, hrt_cutoff=args.hrt_cutoff, target_fdr=args.target_fdr, 135 | generator_type=args.generator_type, n_rounds=1000, logger=logger) 136 | RES.update({'hrt_pvals': hrt_pvals}) 137 | RES.update(log_metrics_selected(hrt_sorted_selected_features, fea_groundtruth, key_prefix='hrt_', logger=logger)) 138 | 139 | SH.d += [RES] 140 | SH._save() 141 | -------------------------------------------------------------------------------- /run_sic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import torch 4 | import torch.nn as nn 5 | 6 | # configure event logging 7 | import logging 8 | logging.basicConfig(format='%(asctime)s - %(message)s', 9 | datefmt='%Y-%m-%d %H:%M:%S', 10 | level=logging.INFO, stream=sys.stdout) 11 | logger = logging.getLogger('main') 12 | 13 | from modules.models import init_D, init_optimizerD 14 | from SIC_imports import eta_optim_step_, sobolev_forward, avg_sobolev_dist, Ep_D 15 | from utils import DDICT, shelf, avg_iterable 16 | from datasets.dataset_builder import build_dataset 17 | from stattests import compute_fdr, compute_tpr, hrt_sobolev, log_metrics_selected, log_metrics 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset', default='sinexp', help='toy | bleitoy | liang | liang_switch | ccle | olfaction | biobank | hiv ') 22 | parser.add_argument('--data-seed', type=int, default=0, help='initial random seed for data') 23 | parser.add_argument('--Yfunction', default='sine', help='sine | linear ') 24 | parser.add_argument('--task', default='PLX4720', help='(ccle): PLX4720 | (olfaction): Bakery') 25 | parser.add_argument('--Xdim', type=int, default=50, help='(for toy) X dimensionality') 26 | parser.add_argument('--numSamples', type=int, default=250, help='(for toy) num Samples in train & heldout') 27 | parser.add_argument('--batchSize', type=int, default=50, help='input batch size') 28 | parser.add_argument('--DiscArch', default='concat_first', help='phiVpsi | concat | concat2') 29 | parser.add_argument('--layerSize', type=int, default=100, help='') 30 | parser.add_argument('--nonlin', default='ReLU', help='') 31 | parser.add_argument('--normalization', default='', help='None | LN layernorm | TODO more options?') 32 | parser.add_argument('--wdecay', type=float, default=1e-4, help='') 33 | parser.add_argument('--lrD', type=float, default=1e-3, help='learning rate for D = Sobolev Mut Info neural estimator') 34 | parser.add_argument('--beta1', type=float, default=0.5, help='Adam optimizer for D: (beta1, beta2)') 35 | parser.add_argument('--beta2', type=float, default=0.999, help='Adam optimizer for D: (beta1, beta2)') 36 | parser.add_argument('--lambdaFisher', type=float, default=0.01, help='lambda on Fisher constraint term E_Q[f^2]') 37 | parser.add_argument('--lambdaSobolev', type=float, default=0.01, help='lambda on Sobolev constraint term') 38 | parser.add_argument('--mu', default='Q', help='Q | P | P+Q. 
mu: dominant measure on which to constrain expectations.') 39 | parser.add_argument('--eta-lr', type=float, default=0.1, help='lr for eta; in case of L1^2 this is mirror descent scale') 40 | parser.add_argument('--T', type=int, default=200, help='number of updates to D, training duration') 41 | parser.add_argument('--log-every', type=int, default=10, help='interval to log, compute metrics on heldout') 42 | parser.add_argument('--eta-step_type', default='mirror', help='mirror | reduced') 43 | parser.add_argument('--seed', type=int, default=1238, help='random seed') 44 | parser.add_argument('--dataseed', type=int, default=1258, help='random seed for toy datasets') 45 | parser.add_argument('--ftdr-cutoff', default=6, type=int, help='fdr / tpr cutoff') 46 | parser.add_argument('--dropout', default=0.3, type=float, help='Discriminator/Critic dropout') 47 | parser.add_argument('--n-critic', default=1, type=int, help='No. of critic updates before eta update') 48 | parser.add_argument('--do-hrt', action='store_true', help='perform HRT') 49 | parser.add_argument('--hrt-cutoff', type=int, default=20, help='maximal number of features for HRT to evaluate') 50 | parser.add_argument('--target-fdr', default=0.1, type=float, help='target FDR for HRT') 51 | parser.add_argument('--sinexp-rho', default=0.5, type=float, help='Correlation coefficient between pairs of covariates in SinExp dataset') 52 | parser.add_argument('--sinexp-gaussian', action='store_true', help='Sample covariates from gaussian or uniform in SinExp dataset') 53 | parser.add_argument('--generator-type', default='classify', help='regress | classify') 54 | parser.add_argument('--n-runs', type=int, default=100, help='number of repetitions over everything') 55 | parser.add_argument('--nocuda', action='store_true', help='disables cuda') 56 | 57 | 58 | def train_sobo_critic(opt, dataloaders, D, groundtruth_feat=None, n_epochs=1, train_last_layer_only=False, logger=None): 59 | """ 60 | Args: 61 | opt (DDICT): parameters for training 62 | dataloaders (tuple): list of dataloaders 63 | D (nn.module): Discriminator architecture 64 | """ 65 | dl_train_P, dl_train_Q, dl_test_P, dl_test_Q = dataloaders # unpack 66 | 67 | # Optimizer 68 | optimizerD = init_optimizerD(opt, D, train_last_layer_only=train_last_layer_only) 69 | 70 | # etas are initialized uniformly 71 | eta_x = torch.tensor([1 / opt.Xdim] * opt.Xdim, device=next(D.parameters()).device, requires_grad=True) 72 | 73 | # Log architecture 74 | if logger: logger.info(D) 75 | 76 | # Train 77 | if logger: logger.info('Start training') 78 | 79 | for epoch in range(n_epochs): 80 | for batch_idx, (dataP, dataQ) in enumerate(zip(dl_train_P, dl_train_Q)): 81 | n_iter = epoch * len(dl_train_P) + batch_idx 82 | 83 | optimizerD.zero_grad() 84 | if hasattr(eta_x.grad, 'zero_'): 85 | eta_x.grad.zero_() 86 | 87 | sobo_dist, constraint_f2, constraint_Sobo = sobolev_forward(D, eta_x, dataP, dataQ, opt.mu) 88 | 89 | obj_D = - sobo_dist \ 90 | + opt.lambdaFisher * constraint_f2 \ 91 | + (opt.lambdaSobolev / 2) * constraint_Sobo 92 | 93 | obj_D.backward() 94 | optimizerD.step() 95 | 96 | if (n_iter + 1) % opt.n_critic == 0: 97 | eta_optim_step_(eta_x, opt.eta_step_type, opt.eta_lr) 98 | 99 | # eval / logging 100 | if logger and epoch % opt.log_every == 0: 101 | # Average test sobolev distance 102 | sobo_dist_te, constraint_f2_te, constraint_Sobo_te = avg_iterable( 103 | zip(dl_test_P, dl_test_Q), lambda PQ: sobolev_forward(D, eta_x, PQ[0], PQ[1], opt.mu)) 104 | 105 | obj_D_te = - sobo_dist_te \ 106 | + opt.lambdaFisher *
constraint_f2_te \ 107 | + (opt.lambdaSobolev / 2) * constraint_Sobo_te 108 | 109 | msg = '[{:5d}] TRAIN: obj_D={:.4f}, sobo-dist={:.4f}, constr_Sobo={:.4f} TEST: obj_D={:.4f}, sobo-dist={:.4f}, constr_Sobo={:.4f}'\ 110 | .format(epoch, obj_D.item(), sobo_dist.item(), constraint_Sobo.item(), obj_D_te.item(), sobo_dist_te.item(), constraint_Sobo_te.item()) 111 | 112 | # fdr and tpr 113 | if groundtruth_feat: 114 | _, eta_sortix = torch.sort(eta_x, descending=True) 115 | fdr = compute_fdr(eta_sortix.clone().detach().cpu(), groundtruth_feat, eta_sortix.size(0), cut_off=opt.ftdr_cutoff) 116 | tpr = compute_tpr(eta_sortix.clone().detach().cpu(), groundtruth_feat, eta_sortix.size(0), cut_off=opt.ftdr_cutoff) 117 | msg += ' FDR={:.3f}, TPR={:.3f}'.format(fdr, tpr) 118 | 119 | logger.info(msg) 120 | 121 | sobo_dist_tr = avg_sobolev_dist(D, dl_train_P, dl_train_Q) 122 | sobo_dist_te = avg_sobolev_dist(D, dl_test_P, dl_test_Q) 123 | return D, eta_x, sobo_dist_tr, sobo_dist_te 124 | 125 | 126 | # HRT risk function 127 | def sobolev_dist_fn(dl_P): 128 | global D 129 | return Ep_D(D, dl_P) 130 | 131 | 132 | if __name__ == "__main__": 133 | args = parser.parse_args() 134 | 135 | args.model = 'sic' 136 | 137 | ## Initialize input & output dirs 138 | args.outdir = os.path.join('output', 'sic') 139 | os.makedirs(args.outdir, exist_ok=True) 140 | 141 | # options to pass to data builder 142 | data_opt = DDICT( 143 | dataset=args.dataset, 144 | sinexp_gaussian=args.sinexp_gaussian, 145 | sinexp_rho=args.sinexp_rho, 146 | numSamples=args.numSamples, 147 | Xdim=args.Xdim, 148 | batchSize=args.batchSize, 149 | dataseed=args.data_seed, 150 | ) 151 | 152 | if args.nocuda: 153 | device = torch.device("cpu") 154 | else: 155 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 156 | print('device:', device) 157 | 158 | logger.info(str(args).replace(', ', '\n')) 159 | torch.manual_seed(args.seed) 160 | 161 | # Save everything in SHELF 162 | save_filename = os.path.join(args.outdir, 'sic') + '_' + args.dataset 163 | SH = shelf(args=args.__dict__) 164 | SH._save(save_filename, date=True) 165 | SH.d = [] 166 | 167 | for n_iter in range(args.n_runs): 168 | logger.info("\n\n* Repetition {} of {}\n".format(n_iter + 1, args.n_runs)) 169 | 170 | RES = dict(n_iter=n_iter) # Results to save 171 | 172 | # Reload dataset (with fresh random seed) 173 | data_opt.dataseed += 1 174 | dataloaders, fea_names, groundtruth_feat = build_dataset(data_opt) 175 | 176 | logger.info(str(dataloaders[0].dataset)) 177 | 178 | # Init and train discriminator 179 | D = init_D(args, device) 180 | D, eta_x, sobo_dist_tr, sobo_dist_te = train_sobo_critic(args, dataloaders, D, groundtruth_feat=groundtruth_feat, n_epochs=args.T, logger=logger) 181 | 182 | RES.update({'sobo_dist_tr': sobo_dist_tr, 'sobo_dist_te': sobo_dist_te}) 183 | 184 | # metrics: 185 | eta_x = eta_x.detach() 186 | RES.update(log_metrics(eta_x, groundtruth_feat, args.ftdr_cutoff, key_prefix='', logger=logger)) 187 | 188 | # HRT 189 | if args.do_hrt: 190 | hrt_sorted_selected_features, hrt_pvals = hrt_sobolev(eta_x, sobolev_dist_fn, dataloaders, hrt_cutoff=args.hrt_cutoff, target_fdr=args.target_fdr, 191 | generator_type=args.generator_type, n_rounds=1000, logger=logger) 192 | RES.update({'hrt_pvals': hrt_pvals}) 193 | RES.update(log_metrics_selected(hrt_sorted_selected_features, groundtruth_feat, key_prefix='hrt_', logger=logger)) 194 | 195 | SH.d += [RES] 196 | SH._save() 197 | -------------------------------------------------------------------------------- /run_sic_supervised.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import torch 4 | 5 | from modules.models import init_D, init_optimizerD 6 | from SIC_imports import compute_mse, eta_optim_step_, supervised_forward_sobolev_penalty 7 | from utils import DDICT, shelf, avg_iterable 8 | from datasets.dataset_builder import build_dataset 9 | 10 | from stattests import hrt, compute_fdr, compute_tpr, log_metrics, log_metrics_selected 11 | 12 | # configure event logging 13 | import logging 14 | logging.basicConfig(format='%(asctime)s - %(message)s', 15 | datefmt='%Y-%m-%d %H:%M:%S', 16 | level=logging.INFO, stream=sys.stdout) 17 | logger = logging.getLogger('main') 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset', default='sinexp', help='toy | bleitoy | liang | liang_switch | ccle | olfaction | biobank | hiv ') 22 | parser.add_argument('--sinexp-rho', default=0.5, type=float, help='Correlation coefficient between pairs of covariates in SinExp dataset') 23 | parser.add_argument('--sinexp-gaussian', action='store_true', help='Sample covariates from gaussian or uniform in SinExp dataset') 24 | parser.add_argument('--data-seed', type=int, default=0, help='initial random seed for data') 25 | parser.add_argument('--generator-type', default='classify', help='classify | regress') 26 | parser.add_argument('--batchSize', type=int, default=50, help='input batch size') 27 | parser.add_argument('--numSamples', type=int, default=250, help='(for toy) num Samples in train & heldout') 28 | parser.add_argument('--Xdim', type=int, default=50, help='(for toy) X dimensionality') 29 | parser.add_argument('--layerSize', type=int, default=100, help='') 30 | parser.add_argument('--wdecay', type=float, default=1e-3, help='') 31 | parser.add_argument('--lrD', type=float, default=1e-3, help='learning rate for D = Sobolev Mut Info neural estimator') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='Adam optimizer for D: (beta1, beta2)') 33 | parser.add_argument('--beta2', type=float, default=0.999, help='Adam optimizer for D: (beta1, beta2)') 34 | parser.add_argument('--dropout', default=0.3, type=float, help='Discriminator/Critic dropout') 35 | parser.add_argument('--lambdaSobolev', type=float, default=0.5, help='lambda on Sobolev constraint term') 36 | parser.add_argument('--eta-lr', type=float, default=0.1, help='lr for eta; in case of L1^2 this is mirror descent scale') 37 | parser.add_argument('--T', type=int, default=201, help='number of updates to D, training duration') 38 | parser.add_argument('--log-every', type=int, default=10, help='interval to log, compute metrics on heldout') 39 | parser.add_argument('--eta-step-type', default='mirror', help='mirror | reduced') 40 | parser.add_argument('--seed', type=int, default=1238, help='random seed') 41 | parser.add_argument('--ftdr-cutoff', default=6, type=int, help='fdr / tpr cutoff') 42 | parser.add_argument('--do-hrt', action='store_true', help='perform HRT') 43 | parser.add_argument('--target-fdr', default=0.1, type=float, help='target FDR for HRT') 44 | parser.add_argument('--hrt-cutoff', type=int, default=20, help='maximal number of features for HRT to evaluate') 45 | parser.add_argument('--n-critic', default=1, type=int, help='No. 
of critic updates before eta update') 46 | parser.add_argument('--n-runs', type=int, default=100, help='number of repetitions over everything') 47 | parser.add_argument('--nocuda', action='store_true', help='disables cuda') 48 | 49 | 50 | def train_supervised(opt, dataloaders, net, groundtruth_feat=None, n_epochs=1, logger=None): 51 | """ 52 | Args: 53 | opt (DDICT): parameters for training 54 | dataloaders (tuple): list of dataloaders 55 | net (nn.module): initialization for network `net` 56 | """ 57 | dl_train_P, dl_train_Q, dl_test_P, dl_test_Q = dataloaders # unpack 58 | 59 | # Optimizer 60 | optimizerD = init_optimizerD(opt, net) 61 | 62 | # etas are initialized uniformly 63 | eta_x = torch.tensor([1 / opt.Xdim] * opt.Xdim, device=next(net.parameters()).device, requires_grad=True) 64 | 65 | if logger: logger.info(net) 66 | 67 | # Train 68 | if logger: logger.info('Start training') 69 | 70 | for epoch in range(n_epochs): 71 | for batch_idx, dataP in enumerate(dl_train_P): 72 | n_iter = epoch * len(dl_train_P) + batch_idx 73 | 74 | optimizerD.zero_grad() 75 | if hasattr(eta_x.grad, 'zero_'): 76 | eta_x.grad.zero_() 77 | 78 | mse_loss, constraint_Sobo = supervised_forward_sobolev_penalty(net, dataP[0], dataP[1], eta_x) 79 | 80 | obj_net = mse_loss + (opt.lambdaSobolev / 2) * constraint_Sobo 81 | 82 | obj_net.backward() 83 | optimizerD.step() 84 | 85 | if (n_iter + 1) % opt.n_critic == 0: 86 | eta_optim_step_(eta_x, opt.eta_step_type, opt.eta_lr) 87 | 88 | # eval / logging 89 | if logger and epoch % opt.log_every == 0: 90 | # Average test performance 91 | mse_loss_te, constraint_Sobo_te = avg_iterable(dl_test_P, lambda d: supervised_forward_sobolev_penalty(net, d[0], d[1], eta_x)) 92 | 93 | msg = '[{:5d}] TRAIN: mse={:.3f}, constr_Sobo={:.3f} TEST: mse={:.3f}, constr_Sobo={:.3f}'\ 94 | .format(epoch, mse_loss, constraint_Sobo, mse_loss_te, constraint_Sobo_te.item()) 95 | 96 | # fdr and tpr 97 | if groundtruth_feat: 98 | _, eta_sortix = torch.sort(eta_x, descending=True) 99 | fdr = compute_fdr(eta_sortix.clone().detach().cpu(), groundtruth_feat, eta_sortix.size(0), cut_off=opt.ftdr_cutoff) 100 | tpr = compute_tpr(eta_sortix.clone().detach().cpu(), groundtruth_feat, eta_sortix.size(0), cut_off=opt.ftdr_cutoff) 101 | msg += ' FDR={:.3f}, TPR={:.3f}'.format(fdr, tpr) 102 | 103 | logger.info(msg) 104 | 105 | # End of training - saving and logging 106 | loss_tr = compute_mse(dl_train_P, net).item() 107 | loss_te = compute_mse(dl_test_P, net).item() 108 | 109 | if logger: logger.info('training MSE: {:.3f}, test MSE: {:.3f}'.format(loss_tr, loss_te)) 110 | 111 | return net, eta_x, loss_tr, loss_te 112 | 113 | 114 | # HRT risk function 115 | def risk_model_fn(dataloader): 116 | global net 117 | return compute_mse(dataloader, net).item() 118 | 119 | 120 | if __name__ == "__main__": 121 | args = parser.parse_args() 122 | 123 | args.model = 'sic_supervised' 124 | args.DiscArch = 'supervised_nobias' 125 | 126 | # Initialize input & output dirs 127 | args.outdir = os.path.join('output', 'sic_supervised') 128 | os.makedirs(args.outdir, exist_ok=True) 129 | 130 | # options to pass to data builder 131 | data_opt = DDICT( 132 | dataset=args.dataset, 133 | sinexp_gaussian=args.sinexp_gaussian, 134 | sinexp_rho=args.sinexp_rho, 135 | numSamples=args.numSamples, 136 | Xdim=args.Xdim, 137 | batchSize=args.batchSize, 138 | dataseed=args.data_seed, 139 | ) 140 | 141 | if args.nocuda: 142 | device = torch.device("cpu") 143 | else: 144 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 145 |
print('device:', device) 146 | 147 | logger.info(str(args).replace(', ', '\n')) 148 | torch.manual_seed(args.seed) 149 | 150 | # Save everything in SHELF 151 | save_filename = os.path.join(args.outdir, 'mse_sobo') + '_' + args.dataset 152 | SH = shelf(args=args.__dict__) 153 | SH._save(save_filename, date=True) 154 | SH.d = [] 155 | 156 | for n_iter in range(args.n_runs): 157 | logger.info("\n\n* Repetition {} of {}\n".format(n_iter + 1, args.n_runs)) 158 | 159 | RES = dict(n_iter=n_iter) # Results to save 160 | 161 | # Reload dataset (with fresh random seed) 162 | data_opt.dataseed += 1 163 | dataloaders, fea_names, groundtruth_feat = build_dataset(data_opt) 164 | 165 | logger.info(str(dataloaders[0].dataset)) 166 | 167 | # Init and train model 168 | net = init_D(args, device) 169 | net, eta_x, loss_tr, loss_te = train_supervised(args, dataloaders, net, groundtruth_feat=groundtruth_feat, n_epochs=args.T, logger=logger) 170 | 171 | RES.update({'loss_tr': loss_tr, 'loss_te': loss_te}) 172 | 173 | # metrics: 174 | eta_x = eta_x.detach() 175 | RES.update(log_metrics(eta_x, groundtruth_feat, args.ftdr_cutoff, key_prefix='', logger=logger)) 176 | 177 | # HRT 178 | if args.do_hrt: 179 | hrt_sorted_selected_features, hrt_pvals = hrt(eta_x, risk_model_fn, dataloaders, hrt_cutoff=args.hrt_cutoff, target_fdr=args.target_fdr, 180 | generator_type=args.generator_type, n_rounds=1000, logger=logger) 181 | RES.update({'hrt_pvals': hrt_pvals}) 182 | RES.update(log_metrics_selected(hrt_sorted_selected_features, groundtruth_feat, key_prefix='hrt_', logger=logger)) 183 | 184 | SH.d += [RES] 185 | SH._save() 186 | -------------------------------------------------------------------------------- /slides/sic_neurips_slides.key: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SIC/c4e45d7736da6e6faabdc56bfc1336445df99204/slides/sic_neurips_slides.key -------------------------------------------------------------------------------- /slides/sic_neurips_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SIC/c4e45d7736da6e6faabdc56bfc1336445df99204/slides/sic_neurips_slides.pdf -------------------------------------------------------------------------------- /stattests.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from modules.generators import generator_from_data 5 | from sklearn.metrics import roc_auc_score 6 | from utils import log_to_dict 7 | 8 | 9 | def hrt_sobolev(etas, Ep_D, dataloaders, hrt_cutoff=None, target_fdr=0.1, generator_type='classify', n_rounds=1000, logger=None): 10 | r"""Performs Holdout Randomization Test from Tansey et al. 
(http://arxiv.org/abs/1811.00645) 11 | using Sobolev distance as risk model 12 | 13 | Args: 14 | etas (torch.Tensor): sequence of etas 15 | Ep_D (functin): integration of discriminator 16 | dataloaders (list): list of dataloaders 17 | hrt_cutoff (int): number of top etas to consider 18 | """ 19 | dl_train_P, dl_train_Q, dl_test_P, dl_test_Q = dataloaders # unpack 20 | 21 | _, eta_sortix = torch.sort(etas, descending=True) 22 | 23 | # Only consider top `hrt_cutoff` features 24 | if hrt_cutoff is not None: 25 | eta_sortix = eta_sortix[:hrt_cutoff] 26 | 27 | # Compute first term of Sobolev distance (notice the '-', since it's a sup) 28 | t_h = -Ep_D(dl_test_P) 29 | 30 | # Instantiate generator 31 | if logger: logger.info('HRT: training generator') 32 | generator, _ = generator_from_data(dl_train_P.dataset, generator_type) 33 | 34 | # Loop over features 35 | p_vals = [] 36 | if logger: logger.info('HRT: analyzing {} features'.format(len(eta_sortix))) 37 | 38 | for i, j in enumerate(eta_sortix): 39 | if logger: logger.info(' HRT testing feature {}\t({} of {})'.format(j, i + 1, len(eta_sortix))) 40 | t_j = [] 41 | gen_j_data = generator.get_dataloader(dl_test_P, j) 42 | for r in range(n_rounds): 43 | # Sample from P_{j|-j} 44 | gen_j_data.dataset.resample_replaced_feature() 45 | # Compute empirical risk (notice the '-') 46 | t_j.append(-Ep_D(gen_j_data)) 47 | p_vals.append(get_pval(t_h, t_j)) 48 | if logger: logger.info(' eta={:.3f} (p={:.3f})'.format(etas[j], p_vals[-1])) 49 | 50 | # Use Benjamini-Hochberg to calibrate FDR 51 | idx_sorted_selected = bh(p_vals, target_fdr) 52 | sorted_selected_features = np.array([eta_sortix[j].item() for j in idx_sorted_selected]) 53 | 54 | return sorted_selected_features, np.array(p_vals) 55 | 56 | 57 | def hrt(etas, risk_model, dataloaders, hrt_cutoff=None, target_fdr=0.1, generator_type='regress', n_rounds=1000, logger=None): 58 | r"""Performs Holdout Randomization Test from Tansey et al. 
(http://arxiv.org/abs/1811.00645)
 59 | 
 60 |     Args:
 61 |         etas (torch.Tensor): etas to test
 62 |         risk_model (function): returns the risk of the model given a dataloader
 63 |         dataloaders (list): list of dataloaders
 64 | 
 65 |     Returns: array of selected feature indices (sorted by decreasing importance) and array of p-values
 66 |     """
 67 |     dl_train_P, dl_train_Q, dl_test_P, dl_test_Q = dataloaders  # unpack
 68 | 
 69 |     eta_sortix = torch.argsort(etas, descending=True)
 70 |     if hrt_cutoff is None:
 71 |         hrt_cutoff = len(etas)
 72 | 
 73 |     # Cut off all etas that are zero by setting cut_off at first zero eta
 74 |     if hrt_cutoff < len(etas) and etas[eta_sortix[hrt_cutoff]] == 0.0:
 75 |         hrt_cutoff = torch.nonzero(etas[eta_sortix] <= 0.0)[0].item()
 76 | 
 77 |     eta_sortix = eta_sortix[:hrt_cutoff]
 78 | 
 79 |     # Compute risk
 80 |     t_h = risk_model(dl_test_P)
 81 | 
 82 |     # Instantiate generator
 83 |     if logger: logger.info('HRT: training generator')
 84 |     generator, _ = generator_from_data(dl_train_P.dataset, generator_type)
 85 | 
 86 |     # Loop over features
 87 |     p_vals = []
 88 |     if logger: logger.info('HRT: analyzing {} features'.format(len(eta_sortix)))
 89 | 
 90 |     for i, j in enumerate(eta_sortix):
 91 |         if logger: logger.info('    HRT testing feature {}\t({} of {})'.format(j, i + 1, len(eta_sortix)))
 92 |         t_j = []
 93 |         gen_j_data = generator.get_dataloader(dl_test_P, j)
 94 |         for r in range(n_rounds):
 95 |             # Sample from P_{j|-j}
 96 |             gen_j_data.dataset.resample_replaced_feature()
 97 |             # Compute empirical risk
 98 |             t_j.append(risk_model(gen_j_data))
 99 |         p_vals.append(get_pval(t_h, t_j))
100 |         if logger: logger.info('    eta={:.3f} (p={:.3f})'.format(etas[j], p_vals[-1]))
101 | 
102 |     # Use Benjamini-Hochberg to calibrate FDR
103 |     idx_sorted_selected = bh(p_vals, target_fdr)
104 |     sorted_selected_features = np.array([eta_sortix[j].item() for j in idx_sorted_selected])
105 | 
106 |     return sorted_selected_features, np.array(p_vals)
107 | 
108 | 
109 | def compute_fdr(eta_sortix, ground_truth, total_features, axis=None, cut_off=40):
110 |     r"""Computes FDR
111 |     """
112 |     truth = np.zeros(total_features)
113 |     pred = np.zeros(total_features)
114 | 
115 |     assert len(ground_truth) <= cut_off, "Cut-off is smaller than the number of ground_truth features"
116 | 
117 |     eta_sortix = eta_sortix[:cut_off]
118 |     truth[ground_truth] = 1
119 |     pred[eta_sortix] = 1
120 | 
121 |     return ((pred==1) & (truth==0)).sum(axis=axis) / pred.sum(axis=axis).astype(float).clip(1,np.inf)
122 | 
123 | 
124 | def compute_tpr(eta_sortix, ground_truth, total_features, axis=None, cut_off=40):
125 |     r"""Computes TPR
126 |     """
127 |     truth = np.zeros(total_features)
128 |     pred = np.zeros(total_features)
129 | 
130 |     assert len(ground_truth) <= cut_off, "Cut-off is smaller than the number of ground_truth features"
131 | 
132 |     eta_sortix = eta_sortix[:cut_off]
133 |     truth[ground_truth] = 1
134 |     pred[eta_sortix] = 1
135 | 
136 |     return ((pred==1) & (truth==1)).sum(axis=axis) / truth.sum(axis=axis).astype(float).clip(1,np.inf)
137 | 
138 | 
139 | 
140 | def get_pval(t, t_list, invert=False):
141 |     r"""Computes p-values by counting how often elements of `t_list` are below `t` (or above, if `invert` is True)
142 | 
143 |     Args:
144 |         t (float): observed statistic
145 |         t_list (list): samples from null-distribution
146 |         invert (bool): whether to count t >= t_list or t <= t_list
147 |             `invert` should be `False` if `t` is a cost, and should be `True` if it is an importance weight
148 |     """
149 |     K = len(t_list)
150 |     if invert:
151 |         total = 1 + (t <= np.array(t_list)).sum()
152 |     else:
153 |         total = 1 + (t >= np.array(t_list)).sum()
154 |     return total / (K + 1)
155 | 
156 | 
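
get_pval above is the usual empirical p-value with add-one smoothing, and bh (defined a little further down in this file) converts a list of such p-values into discoveries at a target FDR. A small self-contained illustration with toy numbers, mirroring the two functions rather than calling into a trained model:

    import numpy as np

    # Toy null statistics and an observed statistic (a cost, so smaller is "better"):
    t_null = [0.8, 0.9, 1.1, 1.2, 1.0]   # risks under resampled (null) features
    t_obs = 0.75                          # risk on the unperturbed held-out data

    # Empirical p-value with add-one smoothing, as in get_pval(t_obs, t_null):
    p = (1 + np.sum(t_obs >= np.array(t_null))) / (len(t_null) + 1)
    # t_obs is below every null sample, so p = 1/6 ~ 0.167 (the smallest value possible with 5 null samples)

    # Benjamini-Hochberg at target FDR 0.1 over several such p-values, mirroring bh():
    p_vals = [0.01, 0.02, 0.20, 0.65]
    target_fdr = 0.1
    order = np.argsort(p_vals)
    m = float(len(p_vals))
    discoveries = []
    for k, s in enumerate(order):
        if p_vals[s] <= (k + 1) / m * target_fdr:
            discoveries.append(s)
        else:
            break
    print(discoveries)  # -> [0, 1]: only the two smallest p-values are kept
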
157 | def auc_score(scores, ground_truth):
158 |     r"""Computes the roc auc score from scores and a list of the ground truth discoveries
159 |     """
160 |     y_scores = np.array(scores).reshape(-1)
161 |     y_true = np.zeros(len(y_scores))
162 |     y_true[ground_truth] = 1
163 |     auc = roc_auc_score(y_true, y_scores)
164 |     return auc
165 | 
166 | 
167 | def fdr_tpr(selected_features, ground_truth):
168 |     r"""Computes FDR and TPR at the number of discoveries
169 |     """
170 |     n_discoveries = len(selected_features)
171 |     P = len(ground_truth)
172 |     TP = len(set(selected_features).intersection(set(ground_truth)))
173 |     tpr = TP / P
174 |     fdr = len(set(selected_features) - set(ground_truth)) / n_discoveries
175 |     return fdr, tpr
176 | 
177 | 
178 | def auc_fdr_tpr_curves(scores, ground_truth):
179 |     if isinstance(scores, torch.Tensor):
180 |         scores = np.array(scores.cpu().detach(), dtype=np.float64)
181 |     scores = scores.reshape(-1)
182 | 
183 |     fdr_curve, tpr_curve = np.zeros(len(scores)), np.zeros(len(scores))
184 |     scores_sortix = np.argsort(-scores)
185 |     for i in range(len(scores_sortix)):
186 |         fdr_curve[i], tpr_curve[i] = fdr_tpr(scores_sortix[:i + 1], ground_truth)
187 |     auc = auc_score(scores, ground_truth)
188 |     return auc, fdr_curve, tpr_curve
189 | 
190 | 
191 | def selected_fdr_tpr_curves(sorted_selected_features, ground_truth):
192 |     r"""Takes a list of sorted selected features and the ground truth discoveries, and builds fdr and tpr curves
193 | 
194 |     Args:
195 |         sorted_selected_features (list): list of selected features sorted from higher to lower importance
196 |         ground_truth (list): list of ground truth discoveries
197 |     """
198 |     sorted_selected_features = np.array(sorted_selected_features)
199 | 
200 |     fdr_curve, tpr_curve = np.zeros(len(sorted_selected_features)), np.zeros(len(sorted_selected_features))
201 |     for i in range(len(fdr_curve)):
202 |         fdr_curve[i], tpr_curve[i] = fdr_tpr(sorted_selected_features[:i + 1], ground_truth)
203 |     return fdr_curve, tpr_curve
204 | 
205 | 
206 | def bh(p, fdr):
207 |     r"""Performs the Benjamini-Hochberg procedure at target FDR level `fdr`
208 |     """
209 |     p_orders = np.argsort(p)
210 |     discoveries = []
211 |     m = float(len(p_orders))
212 |     for k, s in enumerate(p_orders):
213 |         if p[s] <= (k + 1) / m * fdr:
214 |             discoveries.append(s)
215 |         else:
216 |             break
217 |     return np.array(discoveries)
218 | 
219 | 
220 | def log_metrics(etas, fea_groundtruth, ftdr_cutoff, key_prefix='', logger=None):
221 |     if isinstance(etas, torch.Tensor):
222 |         etas = np.array(etas.cpu().detach(), dtype=np.float64)
223 |     etas = etas.reshape(-1)
224 | 
225 |     _, fdr_curve, tpr_curve = auc_fdr_tpr_curves(etas, fea_groundtruth)
226 |     fdr_at, tpr_at = fdr_curve[ftdr_cutoff - 1], tpr_curve[ftdr_cutoff - 1]
227 | 
228 |     eta_sortix = np.argsort(-etas)
229 |     selected_features = eta_sortix[np.nonzero(etas[eta_sortix] > 0)[0]]
230 |     if logger:
231 |         logger.info("    {} selected features {}: {}".format(key_prefix, len(selected_features), selected_features))
232 |         logger.info("    {}FDR @ {}: {:.3f}".format(key_prefix, ftdr_cutoff, fdr_at))
233 |         logger.info("    {}TPR @ {}: {:.3f}".format(key_prefix, ftdr_cutoff, tpr_at))
234 | 
235 |     if len(selected_features) > 0:
236 |         fdr_selected = fdr_curve[len(selected_features) - 1]
237 |         tpr_selected = tpr_curve[len(selected_features) - 1]
238 |     else:
239 |         fdr_selected, tpr_selected = 0.0, 0.0
240 | 
241 |     if logger:
242 |         logger.info("    {}FDR @ selected: {:.3f}".format(key_prefix, fdr_selected))
243 |         logger.info("    {}TPR @ selected: {:.3f}".format(key_prefix, tpr_selected))
244 | 
245 |     return log_to_dict(['selected_features', 'fdr_curve', 'tpr_curve', 'fdr_at', 'tpr_at',
'fdr_selected', 'tpr_selected'], 246 | locals(), key_prefix=key_prefix) 247 | 248 | 249 | def log_metrics_selected(selected_features, fea_groundtruth, key_prefix='', logger=None): 250 | r"""FDR and TPR are computed at len(sorted_selected_features) 251 | """ 252 | if hasattr(selected_features, 'reshape'): 253 | selected_features = selected_features.reshape(-1).tolist() 254 | 255 | n_features = len(selected_features) 256 | 257 | fdr_curve, tpr_curve = selected_fdr_tpr_curves(selected_features, fea_groundtruth) 258 | 259 | if n_features > 0: 260 | fdr_selected, tpr_selected = fdr_curve[-1], tpr_curve[-1] 261 | else: 262 | fdr_selected, tpr_selected = 0.0, 0.0 263 | 264 | if logger: 265 | logger.info(" {} selected features {}: {}".format(key_prefix, n_features, selected_features)) 266 | logger.info(" {}FDR @ selected: {:.3f}".format(key_prefix, fdr_selected)) 267 | logger.info(" {}TPR @ selected: {:.3f}".format(key_prefix, tpr_selected)) 268 | 269 | return log_to_dict(['selected_features', 'fdr_curve', 'tpr_curve', 'fdr_selected', 'tpr_selected'], 270 | locals(), key_prefix=key_prefix) 271 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | def run(verbose=False): 2 | import nose 3 | from os import path 4 | currentdir = path.dirname(__file__) 5 | updir = path.join(currentdir, '..') 6 | argv = ['', '--exe', '-w', updir] 7 | if verbose: 8 | argv.append('--verbose') 9 | nose.run('SIC', argv=argv) 10 | -------------------------------------------------------------------------------- /test/test_cbatchnorm.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import unittest 3 | import torch 4 | from torch.autograd import gradcheck 5 | 6 | 7 | class test_modules(unittest.TestCase): 8 | 9 | def test_conditional_batchnorm(self): 10 | from modules.generators import ConditionalBatchNorm1d 11 | 12 | torch.manual_seed(1) 13 | 14 | num_classes = 10 15 | num_features = 20 16 | num_samples = 100 17 | 18 | cbn = ConditionalBatchNorm1d(num_features, num_classes) 19 | 20 | # cast parameters to double precision 21 | cbn.embed.weight.data = cbn.embed.weight.data.double() 22 | cbn.bn.running_mean = cbn.bn.running_mean.double() 23 | cbn.bn.running_var = cbn.bn.running_var.double() 24 | 25 | for _ in range(10): 26 | x = torch.randn(num_samples, num_features, dtype=torch.float64, requires_grad=True) 27 | labels = torch.randint(0, num_classes, (num_samples,)) 28 | func = lambda x: cbn.forward((x, labels)) 29 | self.assertTrue(gradcheck(func, (x,), eps=1e-4, atol=1e-3)) 30 | 31 | self.assertEqual(x.shape, cbn((x, labels)).shape) 32 | 33 | 34 | if __name__ == '__main__': 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /test/test_generators.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import unittest 3 | import torch 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import TensorDataset 8 | from torch.utils.data.dataloader import DataLoader 9 | from modules.linregression import linreg_reconstruct 10 | 11 | from utils import DDICT 12 | from datasets.dataset_builder import build_dataset 13 | 14 | 15 | args = DDICT( 16 | batch_size=128, 17 | lr=0.003, 18 | n_layers=4, 19 | n_features=10, 20 | n_hiddens=200, 21 | 
num_bins=200, 22 | p_dropout=0.5, 23 | epochs=100, 24 | n_samples=1000 25 | ) 26 | 27 | def generate_correlated_data(args, corr=0.5): 28 | # Correlated Gaussians 29 | CC = (1 - corr) * torch.eye(args.n_features) + corr * torch.ones(args.n_features, args.n_features) 30 | 31 | gauss_data = torch.distributions.MultivariateNormal(torch.zeros(args.n_features), CC) 32 | tr_loader = DataLoader(TensorDataset(gauss_data.sample((args.n_samples,))), batch_size=args.batch_size, shuffle=False) 33 | 34 | gauss_data = torch.distributions.MultivariateNormal(torch.zeros(args.n_features), CC) 35 | te_loader = DataLoader(TensorDataset(gauss_data.sample((args.n_samples,))), batch_size=args.batch_size, shuffle=False) 36 | 37 | return tr_loader, te_loader 38 | 39 | 40 | def compare_generator_to_linreg(gen, dataloader, idx_feature): 41 | gen.eval() 42 | 43 | with torch.no_grad(): 44 | # True data 45 | X = dataloader.dataset[:][0] 46 | Xj = X[:, idx_feature] 47 | 48 | lin_reg = linreg_reconstruct(X, idx_feature) 49 | 50 | # Generator output 51 | out = gen.get_dataloader(dataloader, idx_feature).dataset[:][0] 52 | 53 | # check that non-generated features are the same 54 | idx = list(set(range(X.shape[1])) - set([idx_feature])) 55 | assert (X[:, idx] - out[:, idx]).abs().sum().item() == 0, "Non generated features are different" 56 | 57 | # Cut out predicted feature 58 | out_pred = out[:, idx_feature] 59 | err_gen = (out_pred - Xj).abs().mean().item() 60 | 61 | # Linear regression 62 | out_linreg = lin_reg(X[:, idx]).view(-1) 63 | err_linreg = (out_linreg - Xj).abs().mean().item() 64 | 65 | return err_gen, err_linreg 66 | 67 | 68 | class test_generators(unittest.TestCase): 69 | 70 | def test_gerator_oracle(self): 71 | from modules.generators import generator_from_data 72 | 73 | torch.manual_seed(1) 74 | 75 | # options to pass to data builder 76 | data_opt = DDICT( 77 | dataset='sinexp', 78 | sinexp_gaussian=False, 79 | numSamples=500, 80 | Xdim=50, 81 | batchSize=100, 82 | num_workers=1, 83 | dataseed=0, 84 | ) 85 | 86 | # rho = 1.0 87 | data_opt.sinexp_rho = 1.0 88 | dataloaders, _, _ = build_dataset(data_opt) 89 | dl_train_P, dl_train_Q, dl_test_P, dl_test_Q = dataloaders # unpack 90 | 91 | gen, _ = generator_from_data(dl_train_P.dataset, 'oracle') 92 | 93 | gen_dl = gen.get_dataloader(dl_test_P, 0) 94 | 95 | self.assertAlmostEqual((gen_dl.dataset[:][0] - dl_test_P.dataset[:][0]).abs().mean().item(), 0.0, places=6) 96 | 97 | # rho = 0.5 98 | data_opt.sinexp_rho = 0.5 99 | dataloaders, _, _ = build_dataset(data_opt) 100 | dl_train_P, dl_train_Q, dl_test_P, dl_test_Q = dataloaders # unpack 101 | 102 | gen, _ = generator_from_data(dl_train_P.dataset, 'oracle') 103 | 104 | gen_dl = gen.get_dataloader(dl_test_P, 0) 105 | 106 | self.assertEqual((gen_dl.dataset[:][0][:, 1:] - dl_test_P.dataset[:][0][:, 1:]).abs().mean().item(), 0.0) 107 | 108 | 109 | def test_gerator_classify(self): 110 | from modules.generators import GeneratorClassify 111 | from modules.generators import train_generator, test_generator 112 | 113 | tr_loader, te_loader = generate_correlated_data(args) 114 | gen_cl = GeneratorClassify(args.n_features, args.n_layers, args.n_hiddens, num_bins=args.num_bins, init_dataset=tr_loader, p_dropout=args.p_dropout) 115 | optimizer = optim.Adam(gen_cl.parameters(), lr=args.lr) 116 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 117 | 118 | for epoch in range(args.epochs): 119 | tr_loss = train_generator(gen_cl, tr_loader, optimizer, log_times=0) 120 | te_loss = test_generator(gen_cl, 
nn.MSELoss(), te_loader) 121 | scheduler.step() 122 | print('{}: training loss = {:.3}\t test loss (mse) = {:.3}'.format(epoch, tr_loss, te_loss)) 123 | 124 | for idx_feature in range(args.n_features): 125 | err_gen_cl, err_linreg = compare_generator_to_linreg(gen_cl, tr_loader, idx_feature) 126 | self.assertGreater(2.0 * err_linreg, err_gen_cl, 'Error of generator_classify is substantially higher than linear regression') 127 | 128 | 129 | def test_gerator_regress(self): 130 | from modules.generators import GeneratorRegress 131 | from modules.generators import train_generator, test_generator 132 | 133 | tr_loader, te_loader = generate_correlated_data(args) 134 | gen_reg = GeneratorRegress(args.n_features, args.n_layers, args.n_hiddens, p_dropout=args.p_dropout) 135 | optimizer = optim.Adam(gen_reg.parameters(), lr=args.lr) 136 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 137 | 138 | for epoch in range(args.epochs): 139 | tr_loss = train_generator(gen_reg, tr_loader, optimizer, log_times=0) 140 | te_loss = test_generator(gen_reg, nn.MSELoss(), te_loader) 141 | scheduler.step() 142 | print('{}: training loss = {:.3}\t test loss (mse) = {:.3}'.format(epoch, tr_loss, te_loss)) 143 | 144 | for idx_feature in range(args.n_features): 145 | err_gen_reg, err_linreg = compare_generator_to_linreg(gen_reg, tr_loader, idx_feature) 146 | self.assertGreater(2.0 * err_linreg, err_gen_reg, 'Error of generator_regress is substantially higher than linear regression') 147 | 148 | 149 | if __name__ == '__main__': 150 | unittest.main() 151 | -------------------------------------------------------------------------------- /test/test_toydataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import unittest 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class test_toydataset(unittest.TestCase): 8 | 9 | def test_sinexp(self): 10 | from datasets.toy_dataset import SinExpDataset 11 | 12 | torch.manual_seed(1) 13 | 14 | # Check consistency across random seeds 15 | ds1 = SinExpDataset(n_samples=1000, n_features=50, gaussian=False, seed=123) 16 | ds2 = SinExpDataset(n_samples=1000, n_features=50, gaussian=False, seed=123) 17 | 18 | self.assertAlmostEqual((ds1.data - ds2.data).abs().mean().item(), 0.0) 19 | self.assertAlmostEqual((ds1.targets - ds2.targets).abs().mean().item(), 0.0) 20 | 21 | # Check normalization for rho=0 22 | for gaussian in [True, False]: 23 | for n_features in [6, 20]: 24 | ds = SinExpDataset(n_samples=10000000, rho=0.0, n_features=n_features, gaussian=gaussian, seed=1) 25 | 26 | self.assertAlmostEqual(ds.data.mean().item(), 0.0, places=2) 27 | self.assertAlmostEqual(ds.data.std().item(), 1.0, places=2) 28 | 29 | self.assertAlmostEqual(ds.targets.mean().item(), 0.0, places=2) 30 | self.assertAlmostEqual(ds.targets.std().item(), 1.0, places=2) 31 | 32 | # Check normalization of only features for rho=0.5 33 | for gaussian in [True, False]: 34 | for n_features in [6, 20]: 35 | ds = SinExpDataset(n_samples=10000000, rho=0.0, n_features=n_features, gaussian=gaussian, seed=1) 36 | 37 | self.assertAlmostEqual(ds.data.mean().item(), 0.0, places=2) 38 | self.assertAlmostEqual(ds.data.std().item(), 1.0, places=2) 39 | 40 | # Check snr of y = f(X) relation 41 | def fn(X): 42 | y = np.sin(X[:, 0] * (X[:, 0] + X[:, 1])) * np.cos(X[:, 2] + X[:, 3] * X[:, 4]) *\ 43 | np.sin(np.exp(X[:, 4]) + np.exp(X[:, 5]) - X[:, 1]) 44 | return y 45 | 46 | for rho in [0.0, 0.5]: 47 | for n_features in [6, 20]: 48 | 
ds = SinExpDataset(n_samples=1000, rho=rho, n_features=n_features, gaussian=True, seed=1) 49 | 50 | X = ds.data * ds.data_sd + ds.data_mu 51 | y = ds.targets * ds.targets_sd + ds.targets_mu 52 | 53 | noise = ((y - fn(X))**2).mean().item() 54 | signal = (fn(X)**2).mean().item() 55 | 56 | self.assertGreater(signal, 1.5 * noise, msg="signal-to-noise smaller than 1.5, for Gaussian features and rho={}".format(rho)) 57 | self.assertGreater(2.5 * noise, signal, msg="signal-to-noise greater than 2.5, for Gaussian features and rho={}".format(rho)) 58 | 59 | for rho in [0.0, 0.5]: 60 | for n_features in [6, 20]: 61 | ds = SinExpDataset(n_samples=1000, rho=rho, n_features=n_features, gaussian=False, seed=1) 62 | 63 | X = ds.data * ds.data_sd + ds.data_mu 64 | y = ds.targets * ds.targets_sd + ds.targets_mu 65 | 66 | noise = ((y - fn(X))**2).mean().item() 67 | signal = (fn(X)**2).mean().item() 68 | 69 | self.assertGreater(signal, 1.0 * noise, msg="signal-to-noise smaller than 1.0, for Uniform features and rho={}".format(rho)) 70 | self.assertGreater(3.0 * noise, signal, msg="signal-to-noise greater than 3.0, for Uniform features and rho={}".format(rho)) 71 | 72 | 73 | if __name__ == '__main__': 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from random import randint 4 | from functools import reduce 5 | import torch 6 | from torch import nn 7 | import numpy as np 8 | 9 | 10 | def describe(t): 11 | """Returns a string describing an array 12 | Args 13 | t (numpy.array or torch.tensor): array of data 14 | 15 | Returns 16 | string describing array t 17 | """ 18 | t = t.data if isinstance(t, torch.Tensor) else t 19 | s = '{:8s} [{:.4f} , {:.4f}] m+-s = {:.4f} +- {:.4f}' 20 | si = 'x'.join(map(str, t.shape if isinstance(t, np.ndarray) else t.size())) 21 | return s.format(si, t.min(), t.max(), t.mean(), t.std()) 22 | 23 | 24 | def log_current_variables(tbw, n_iter, all_data, keys_to_log, key_prefix='', tb_trunc_tensor_size=10): 25 | for k in keys_to_log: 26 | v = all_data[k] 27 | logkey = key_prefix + k 28 | if isinstance(v, torch.Tensor): 29 | v = v.detach().cpu() # get out of autograd 30 | if v.numel() == 1: 31 | tbw.add_scalar(logkey, v.item(), n_iter) 32 | elif v.dim() == 1 and v.numel() <= tb_trunc_tensor_size: 33 | # log as scalar group [0, D[. make dictionary. 34 | v = {str(i): v[i].item() for i in range(len(v))} 35 | tbw.add_scalars(logkey, v, n_iter) 36 | else: 37 | vtrunc = {str(i): v[i].item() for i in range(tb_trunc_tensor_size)} 38 | tbw.add_scalars(logkey, vtrunc, n_iter) 39 | tbw.add_histogram(logkey, v.numpy(), n_iter) 40 | else: 41 | tbw.add_scalar(logkey, v, n_iter) 42 | 43 | 44 | class DDICT: 45 | """DotDictionary, dictionary whose items can be accesses with the dot operator 46 | 47 | E.g. 
48 | >> args = DDICT(batch_size=128, epochs=10) 49 | >> print(args.batch_size) 50 | """ 51 | def __init__(self, **kwds): 52 | self.__dict__.update(kwds) 53 | 54 | def __repr__(self): 55 | return str(self.__dict__) 56 | 57 | def __iter__(self): 58 | return self.__dict__.__iter__() 59 | 60 | def __len__(self): 61 | return len(self.__dict__) 62 | 63 | def __setitem__(self, key, value): 64 | self.__dict__[key] = value 65 | 66 | def __getitem__(self, key): 67 | return self.__dict__[key] 68 | 69 | 70 | def get_devices(cuda_device="cuda:0", seed=1): 71 | """Gets cuda devices 72 | """ 73 | device = torch.device(cuda_device) 74 | torch.manual_seed(seed) 75 | # Multi GPU? 76 | num_gpus = torch.cuda.device_count() 77 | if device.type != 'cpu': 78 | print('\033[93m' + 'Using CUDA,', num_gpus, 'GPUs\033[0m') 79 | torch.cuda.manual_seed(seed) 80 | return device, num_gpus 81 | 82 | 83 | def make_data_parallel(module, expose_methods=None): 84 | """Wraps `nn.Module object` into `nn.DataParallel` and links methods whose name is listed in `expose_methods` 85 | """ 86 | dp_module = nn.DataParallel(module) 87 | 88 | if expose_methods is None: 89 | if hasattr(module, 'expose_methods'): 90 | expose_methods = module.expose_methods 91 | 92 | if expose_methods is not None: 93 | for mt in expose_methods: 94 | setattr(dp_module, mt, getattr(dp_module.module, mt)) 95 | return dp_module 96 | 97 | 98 | class shelf(object): 99 | '''Shelf to save stuff to disk. Basically a DDICT which can save to disk. 100 | 101 | Example: 102 | SH = shelf(lr=[0.1, 0.2], n_hiddens=[100, 500, 1000], n_layers=2) 103 | SH._extend(['lr', 'n_hiddens'], [[0.3, 0.4], [2000]]) 104 | # Save to file: 105 | SH._save('my_file', date=False) 106 | # Load shelf from file: 107 | new_dd = shelf()._load('my_file') 108 | ''' 109 | def __init__(self, **kwargs): 110 | self.__dict__.update(kwargs) 111 | 112 | def __add__(self, other): 113 | if isinstance(other, type(self)): 114 | sum_dct = copy.copy(self.__dict__) 115 | for k, v in other.__dict__.items(): 116 | if k not in sum_dct: 117 | sum_dct[k] = v 118 | else: 119 | if type(v) is list and type(sum_dct[k]) is list: 120 | sum_dct[k] = sum_dct[k] + v 121 | elif type(v) is not list and type(sum_dct[k]) is list: 122 | sum_dct[k] = sum_dct[k] + [v] 123 | elif type(v) is list and type(sum_dct[k]) is not list: 124 | sum_dct[k] = [sum_dct[k]] + v 125 | else: 126 | sum_dct[k] = [sum_dct[k]] + [v] 127 | return shelf(**sum_dct) 128 | 129 | elif isinstance(other, dict): 130 | return self.__add__(shelf(**other)) 131 | else: 132 | raise ValueError("shelf or dict is required") 133 | 134 | def __radd__(self, other): 135 | return self.__add__(other) 136 | 137 | def __repr__(self): 138 | items = ("{}={!r}".format(k, self.__dict__[k]) for k in self._keys()) 139 | return "{}({})".format(type(self).__name__, ", ".join(items)) 140 | 141 | def __eq__(self, other): 142 | return self.__dict__ == other.__dict__ 143 | 144 | def __iter__(self): 145 | return self.__dict__.__iter__() 146 | 147 | def __len__(self): 148 | return len(self.__dict__) 149 | 150 | def __setitem__(self, key, value): 151 | self.__dict__[key] = value 152 | 153 | def __getitem__(self, key): 154 | return self.__dict__[key] 155 | 156 | @staticmethod 157 | def _flatten_dict(d, parent_key='', sep='_'): 158 | "Recursively flattens nested dicts" 159 | items = [] 160 | for k, v in d.items(): 161 | new_key = parent_key + sep + k if parent_key else k 162 | if isinstance(v, MutableMapping): 163 | items.extend(shelf._flatten_dict(v, new_key, sep=sep).items()) 164 | else: 165 
| items.append((new_key, v)) 166 | return dict(items) 167 | 168 | def _extend(self, keys, values_list): 169 | if type(keys) not in (tuple, list): # Individual key 170 | if keys not in self._keys(): 171 | self[keys] = values_list 172 | else: 173 | self[keys] += values_list 174 | else: 175 | for key, val in zip(keys, values_list): 176 | if type(val) is list: 177 | self._extend(key, val) 178 | else: 179 | self._extend(key, [val]) 180 | return self 181 | 182 | def _keys(self): 183 | return tuple(sorted([k for k in self.__dict__ if not k.startswith('_')])) 184 | 185 | def _values(self): 186 | return tuple([self.__dict__[k] for k in self._keys()]) 187 | 188 | def _items(self): 189 | return tuple(zip(self._keys(), self._values())) 190 | 191 | def _save(self, filename=None, date=True): 192 | if filename is None: 193 | if not hasattr(self, '_filename'): # First save 194 | raise ValueError("filename must be provided the first time you call _save()") 195 | else: # Already saved 196 | torch.save(self, self._filename + '.pt') 197 | else: # New filename 198 | if date: 199 | filename += '_' + time.strftime("%Y%m%d-%H:%M:%S") 200 | # Check if filename does not already exist. If it does, change name. 201 | while os.path.exists(filename + '.pt') and len(filename) < 100: 202 | filename += str(randint(0, 9)) 203 | self._filename = filename 204 | torch.save(self, self._filename + '.pt') 205 | return self 206 | 207 | def _load(self, filename): 208 | try: 209 | self = torch.load(filename) 210 | except FileNotFoundError: 211 | self = torch.load(filename + '.pt') 212 | return self 213 | 214 | def _to_dict(self): 215 | "Returns a dict (it's recursive)" 216 | return_dict = {} 217 | for k, v in self.__dict__.items(): 218 | if isinstance(v, type(self)): 219 | return_dict[k] = v._to_dict() 220 | else: 221 | return_dict[k] = v 222 | return return_dict 223 | 224 | def _flatten(self, parent_key='', sep='_'): 225 | "Recursively flattens nested ddicts" 226 | d = self._to_dict() 227 | return shelf._flatten_dict(d) 228 | 229 | 230 | def log_to_dict(keys_to_log, scope, key_prefix=''): 231 | """ 232 | Examples:: 233 | >>> a,b = 1.0, 2.0 234 | >>> d = log_to_dict(['a', 'b'], d, locals()) 235 | >>> d 236 | >>> {'a': 1.0, 'b': 2.0} 237 | """ 238 | d = dict() 239 | for k in keys_to_log: 240 | v = scope[k] 241 | if isinstance(v, torch.Tensor): 242 | v = v.detach().cpu() # get out of autograd 243 | v = np.array(v, dtype=np.float) 244 | d[key_prefix + k] = v 245 | return d 246 | 247 | 248 | def avg_iterable(iterable, func): 249 | '''Applies function `func` to each element of `iterable` and averages the results 250 | 251 | Args: 252 | iterable: an iterable 253 | func: function being applied on each element of `iterable` 254 | 255 | Returns: 256 | Average of `func` applied on `iterable` 257 | ''' 258 | lst = [func(it) for it in iterable] 259 | return [sum(x) / len(lst) for x in zip(*lst)] 260 | --------------------------------------------------------------------------------
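
End-to-end, run_sic_supervised.py stores one result dict per repetition in SH.d and saves the shelf under output/sic_supervised/. A small sketch of loading such a file back and averaging a few of the logged metrics; the filename below is hypothetical (the actual name depends on the dataset and the timestamp appended by _save), and the keys follow the RES dict built in run_sic_supervised.py:

    import numpy as np
    from utils import shelf

    # _load falls back to appending '.pt' if the bare filename is not found on disk.
    SH = shelf()._load('output/sic_supervised/mse_sobo_sinexp_20200101-12:00:00')  # hypothetical filename

    # SH.d is a list of per-repetition result dicts.
    loss_te = [res['loss_te'] for res in SH.d]
    fdr_at = [res['fdr_at'] for res in SH.d]
    tpr_at = [res['tpr_at'] for res in SH.d]

    print('test MSE: {:.3f} +- {:.3f}'.format(np.mean(loss_te), np.std(loss_te)))
    print('FDR@cutoff: {:.3f}   TPR@cutoff: {:.3f}'.format(np.mean(fdr_at), np.mean(tpr_at)))
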