├── .gitignore ├── LICENSE ├── README.md ├── demo.ipynb ├── figures │   └── algorithm.gif ├── multilayer_perceptron.py ├── neural_interaction_detection.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | temp/ 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Interaction Detection (NID) 2 | 3 | 4 | ![Algorithm](figures/algorithm.gif) 5 | 6 | 7 | [M. Tsang](http://www-scf.usc.edu/~tsangm/), [D. Cheng](http://www-scf.usc.edu/~dehuache/), [Y. Liu](http://www-bcf.usc.edu/~liu32/). Detecting Statistical Interactions from Neural Network Weights, ICLR 2018. [[pdf]](https://openreview.net/pdf?id=ByOfBggRZ) 8 | 9 | 10 | ## Usage 11 | 12 | 13 | - Run the demo in `demo.ipynb` 14 | * The demo trains a multilayer perceptron (MLP) on synthetic data containing nonlinear interactions. At the end of the notebook, the interactions are recovered by decoding the MLP's learned weights. 15 | - Requires Python 3.6+ and Jupyter Notebook; tested with PyTorch 1.3.1, scikit-learn 0.21.3, and NumPy 1.17.1 16 | 17 | 18 | ## Reproducibility 19 | If you need to reproduce the paper's results, please contact me so I can share the original code used for the experiments (in TensorFlow). Email: tsangm at usc dot edu 20 | 21 | 22 | ## Reference 23 | If you use NID in your research, please cite the following: 24 | 25 | ``` 26 | @article{tsang2017detecting, 27 | title={Detecting statistical interactions from neural network weights}, 28 | author={Tsang, Michael and Cheng, Dehua and Liu, Yan}, 29 | journal={arXiv preprint arXiv:1705.04977}, 30 | year={2017} 31 | } 32 | ``` 33 | 34 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import numpy as np\n", 11 | "from neural_interaction_detection import get_interactions\n", 12 | "from multilayer_perceptron import MLP, train, get_weights\n", 13 | "from utils import preprocess_data, get_pairwise_auc, get_anyorder_R_precision, set_seed, print_rankings" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "use_main_effect_nets = True # toggle this to use \"main effect\" nets\n", 23 | "num_samples = 30000\n", 24 | "num_features = 10" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Generate synthetic data with ground truth interactions" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "def synth_func(X):\n", 41 | " X1, X2, X3, X4, X5, X6, X7, X8, X9, X10 = X.transpose()\n", 42 | "\n", 43 | " interaction1 = np.exp(np.abs(X1-X2)) \n", 44 | " interaction2 = np.abs(X2*X3) \n", 45 | " interaction3 = -1*(X3**2)**np.abs(X4) \n", 46 | " interaction4 = (X1*X4)**2\n", 47 | " interaction5 = np.log(X4**2 + X5**2 + X7**2 + X8**2)\n", 48 | " main_effects = X9 + 1/(1 + X10**2)\n", 49 | "\n", 50 | " Y = interaction1 + interaction2 + interaction3 + interaction4 + interaction5 + main_effects\n", 51 | " ground_truth = [ {1,2}, {2,3}, {3,4}, {1,4}, {4,5,7,8} ]\n", 52 | " \n", 53 | " return Y, ground_truth" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "set_seed(42)\n", 63 | "X = np.random.uniform(low=-1, high=1, size=(num_samples,num_features))\n", 64 | "Y, ground_truth = synth_func(X)\n", 65 | 
"data_loaders = preprocess_data(X, Y, valid_size=10000, test_size=10000, std_scale=True, get_torch_loaders=True)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Train a multilayer perceptron (MLP)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 5, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "device = torch.device(\"cpu\")\n", 82 | "model = MLP(num_features, [140, 100, 60, 20], use_main_effect_nets=use_main_effect_nets).to(device)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 6, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "starting to train\n", 95 | "early stopping enabled\n", 96 | "[epoch 1, total 100] train loss: 0.1921, val loss: 0.0548\n", 97 | "[epoch 3, total 100] train loss: 0.0290, val loss: 0.0283\n", 98 | "[epoch 5, total 100] train loss: 0.0239, val loss: 0.0557\n", 99 | "[epoch 7, total 100] train loss: 0.0151, val loss: 0.0168\n", 100 | "[epoch 9, total 100] train loss: 0.0143, val loss: 0.0184\n", 101 | "[epoch 11, total 100] train loss: 0.0116, val loss: 0.0083\n", 102 | "[epoch 13, total 100] train loss: 0.0123, val loss: 0.0117\n", 103 | "[epoch 15, total 100] train loss: 0.0104, val loss: 0.0094\n", 104 | "[epoch 17, total 100] train loss: 0.0077, val loss: 0.0137\n", 105 | "[epoch 19, total 100] train loss: 0.0083, val loss: 0.0139\n", 106 | "[epoch 21, total 100] train loss: 0.0070, val loss: 0.0054\n", 107 | "[epoch 23, total 100] train loss: 0.0091, val loss: 0.0063\n", 108 | "[epoch 25, total 100] train loss: 0.0111, val loss: 0.0099\n", 109 | "[epoch 27, total 100] train loss: 0.0064, val loss: 0.0068\n", 110 | "early stopping!\n", 111 | "Finished Training. 
Test loss: 0.005764756351709366\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "model, mlp_loss = train(model, data_loaders, device=device, learning_rate=1e-2, l1_const = 5e-5, verbose=True)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## Get the MLP's learned weights" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 7, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "model_weights = get_weights(model)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## Detect interactions from the weights" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 8, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Pairwise interactions Arbitrary-order interactions\n", 152 | "(1, 2) 7.8430 (1, 2) 6.8951 \n", 153 | "(4, 8) 3.1959 (2, 3) 2.0953 \n", 154 | "(5, 8) 3.0521 (7, 8) 1.7971 \n", 155 | "(7, 8) 3.0290 (4, 5, 8) 1.6026 \n", 156 | "(4, 5) 2.8506 (1, 4) 1.5912 \n", 157 | "(2, 3) 2.6294 (5, 7) 1.5261 \n", 158 | "(1, 4) 2.5037 (3, 4) 1.3500 \n", 159 | "(5, 7) 2.4460 (4, 7) 1.0580 \n", 160 | "(4, 7) 2.2369 (4, 7, 8) 0.7727 \n", 161 | "(3, 4) 1.8870 (4, 5, 7, 8) 0.5467 \n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "anyorder_interactions = get_interactions(model_weights, one_indexed=True)\n", 167 | "pairwise_interactions = get_interactions(model_weights, pairwise=True, one_indexed=True)\n", 168 | "\n", 169 | " \n", 170 | "print_rankings(pairwise_interactions, anyorder_interactions, top_k=10, spacing=14)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "## Evaluate the interactions" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 9, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "Pairwise AUC 1.0 , Any-order R-Precision 1.0\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "auc = get_pairwise_auc(pairwise_interactions, ground_truth)\n", 195 | "r_prec = get_anyorder_R_precision(anyorder_interactions, ground_truth)\n", 196 | "\n", 197 | "print(\"Pairwise AUC\", auc, \", Any-order R-Precision\", r_prec)" 198 | ] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "Python [conda env:torch]", 204 | "language": "python", 205 | "name": "conda-env-torch-py" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 3 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython3", 217 | "version": "3.6.2" 218 | } 219 | }, 220 | "nbformat": 4, 221 | "nbformat_minor": 2 222 | } 223 | -------------------------------------------------------------------------------- /figures/algorithm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/neural-interaction-detection/ae974e16816da3e34e62a259376fef00d97f93e4/figures/algorithm.gif -------------------------------------------------------------------------------- /multilayer_perceptron.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import copy 6 | 7 | 8 | def 
get_weights(model): 9 | weights = [] 10 | for name, param in model.named_parameters(): 11 | if "interaction_mlp" in name and "weight" in name: 12 | weights.append(param.cpu().detach().numpy()) 13 | return weights 14 | 15 | 16 | class MLP(nn.Module): 17 | def __init__( 18 | self, 19 | num_features, 20 | hidden_units, 21 | use_main_effect_nets=False, 22 | main_effect_net_units=[10, 10, 10], 23 | ): 24 | super(MLP, self).__init__() 25 | 26 | self.hidden_units = hidden_units 27 | self.use_main_effect_nets = use_main_effect_nets 28 | self.interaction_mlp = create_mlp([num_features] + hidden_units + [1]) 29 | 30 | if main_effect_net_units == [1]: 31 | use_linear = True 32 | else: 33 | use_linear = False 34 | self.use_linear = use_linear 35 | 36 | if self.use_main_effect_nets: 37 | 38 | if use_linear: 39 | self.linear = nn.Linear(num_features, 1, bias=False) 40 | else: 41 | self.univariate_mlps = self.create_main_effect_nets( 42 | num_features, main_effect_net_units, False, "uni" 43 | ) 44 | 45 | def forward(self, x): 46 | y = self.interaction_mlp(x) 47 | 48 | if self.use_main_effect_nets: 49 | if self.use_linear: 50 | y += self.linear(x) 51 | else: 52 | y += self.forward_main_effect_nets(x, self.univariate_mlps) 53 | return y 54 | 55 | def create_main_effect_nets(self, num_features, hidden_units, out_bias, name): 56 | mlp_list = [ 57 | create_mlp([1] + hidden_units + [1], out_bias=out_bias) 58 | for _ in range(num_features) 59 | ] 60 | for i in range(num_features): 61 | setattr(self, name + "_" + str(i), mlp_list[i]) 62 | return mlp_list 63 | 64 | def forward_main_effect_nets(self, x, mlps): 65 | forwarded_mlps = [] 66 | for i, mlp in enumerate(mlps): 67 | forwarded_mlps.append(mlp(x[:, [i]])) 68 | forwarded_mlp = sum(forwarded_mlps) 69 | return forwarded_mlp 70 | 71 | 72 | def create_mlp(layer_sizes, out_bias=True): 73 | ls = list(layer_sizes) 74 | layers = nn.ModuleList() 75 | for i in range(1, len(ls) - 1): 76 | layers.append(nn.Linear(int(ls[i - 1]), int(ls[i]))) 77 | layers.append(nn.ReLU()) 78 | layers.append(nn.Linear(int(ls[-2]), int(ls[-1]), bias=out_bias)) 79 | return nn.Sequential(*layers) 80 | 81 | 82 | def train( 83 | net, 84 | data_loaders, 85 | criterion=nn.MSELoss(reduction="mean"), 86 | nepochs=100, 87 | verbose=False, 88 | early_stopping=True, 89 | patience=5, 90 | l1_const=1e-4, 91 | l2_const=0, 92 | learning_rate=0.01, 93 | opt_func=optim.Adam, 94 | device=torch.device("cpu"), 95 | ): 96 | optimizer = opt_func(net.parameters(), lr=learning_rate, weight_decay=l2_const) 97 | 98 | def evaluate(net, data_loader, criterion, device): 99 | losses = [] 100 | for inputs, labels in data_loader: 101 | inputs = inputs.to(device) 102 | labels = labels.to(device) 103 | loss = criterion(net(inputs), labels).cpu().data 104 | losses.append(loss) 105 | return torch.stack(losses).mean() 106 | 107 | best_loss = float("inf") 108 | best_net = None 109 | 110 | if "val" not in data_loaders: 111 | early_stopping = False 112 | 113 | patience_counter = 0 114 | 115 | if verbose: 116 | print("starting to train") 117 | if early_stopping: 118 | print("early stopping enabled") 119 | 120 | for epoch in range(nepochs): 121 | running_loss = 0.0 122 | run_count = 0 123 | for i, data in enumerate(data_loaders["train"], 0): 124 | inputs, labels = data 125 | inputs = inputs.to(device) 126 | labels = labels.to(device) 127 | optimizer.zero_grad() 128 | outputs = net(inputs) 129 | loss = criterion(outputs, labels).mean() 130 | 131 | reg_loss = 0 132 | for name, param in net.named_parameters(): 133 | if 
"interaction_mlp" in name and "weight" in name: 134 | reg_loss += torch.sum(torch.abs(param)) 135 | (loss + reg_loss * l1_const).backward() 136 | optimizer.step() 137 | running_loss += loss.item() 138 | run_count += 1 139 | 140 | if epoch % 1 == 0: 141 | key = "val" if "val" in data_loaders else "train" 142 | val_loss = evaluate(net, data_loaders[key], criterion, device) 143 | 144 | if epoch % 2 == 0: 145 | if verbose: 146 | print( 147 | "[epoch %d, total %d] train loss: %.4f, val loss: %.4f" 148 | % (epoch + 1, nepochs, running_loss / run_count, val_loss) 149 | ) 150 | if early_stopping: 151 | if val_loss < best_loss: 152 | best_loss = val_loss 153 | best_net = copy.deepcopy(net) 154 | patience_counter = 0 155 | else: 156 | patience_counter += 1 157 | if patience_counter > patience: 158 | net = best_net 159 | val_loss = best_loss 160 | if verbose: 161 | print("early stopping!") 162 | break 163 | 164 | prev_loss = running_loss 165 | running_loss = 0.0 166 | 167 | if "test" in data_loaders: 168 | key = "test" 169 | elif "val" in data_loaders: 170 | key = "val" 171 | else: 172 | key = "train" 173 | test_loss = evaluate(net, data_loaders[key], criterion, device).item() 174 | 175 | if verbose: 176 | print("Finished Training. Test loss: ", test_loss) 177 | 178 | return net, test_loss 179 | -------------------------------------------------------------------------------- /neural_interaction_detection.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import operator 3 | import numpy as np 4 | import torch 5 | from torch.utils import data 6 | from multilayer_perceptron import * 7 | from utils import * 8 | 9 | 10 | def preprocess_weights(weights): 11 | w_later = np.abs(weights[-1]) 12 | w_input = np.abs(weights[0]) 13 | 14 | for i in range(len(weights) - 2, 0, -1): 15 | w_later = np.matmul(w_later, np.abs(weights[i])) 16 | 17 | return w_input, w_later 18 | 19 | 20 | def make_one_indexed(interaction_ranking): 21 | return [(tuple(np.array(i) + 1), s) for i, s in interaction_ranking] 22 | 23 | 24 | def interpret_interactions(w_input, w_later, get_main_effects=False): 25 | interaction_strengths = {} 26 | for i in range(w_later.shape[1]): 27 | sorted_hweights = sorted( 28 | enumerate(w_input[i]), key=lambda x: x[1], reverse=True 29 | ) 30 | interaction_candidate = [] 31 | candidate_weights = [] 32 | for j in range(w_input.shape[1]): 33 | bisect.insort(interaction_candidate, sorted_hweights[j][0]) 34 | candidate_weights.append(sorted_hweights[j][1]) 35 | 36 | if not get_main_effects and len(interaction_candidate) == 1: 37 | continue 38 | interaction_tup = tuple(interaction_candidate) 39 | if interaction_tup not in interaction_strengths: 40 | interaction_strengths[interaction_tup] = 0 41 | interaction_strength = (min(candidate_weights)) * (np.sum(w_later[:, i])) 42 | interaction_strengths[interaction_tup] += interaction_strength 43 | 44 | interaction_ranking = sorted( 45 | interaction_strengths.items(), key=operator.itemgetter(1), reverse=True 46 | ) 47 | 48 | return interaction_ranking 49 | 50 | 51 | def interpret_pairwise_interactions(w_input, w_later): 52 | p = w_input.shape[1] 53 | 54 | interaction_ranking = [] 55 | for i in range(p): 56 | for j in range(p): 57 | if i < j: 58 | strength = (np.minimum(w_input[:, i], w_input[:, j]) * w_later).sum() 59 | interaction_ranking.append(((i, j), strength)) 60 | 61 | interaction_ranking.sort(key=lambda x: x[1], reverse=True) 62 | return interaction_ranking 63 | 64 | 65 | def get_interactions(weights, 
pairwise=False, one_indexed=False): 66 | w_input, w_later = preprocess_weights(weights) 67 | 68 | if pairwise: 69 | interaction_ranking = interpret_pairwise_interactions(w_input, w_later) 70 | else: 71 | interaction_ranking = interpret_interactions(w_input, w_later) 72 | interaction_ranking = prune_redundant_interactions(interaction_ranking) 73 | 74 | if one_indexed: 75 | return make_one_indexed(interaction_ranking) 76 | else: 77 | return interaction_ranking 78 | 79 | 80 | def prune_redundant_interactions(interaction_ranking, max_interactions=100): 81 | interaction_ranking_pruned = [] 82 | current_superset_inters = [] 83 | for inter, strength in interaction_ranking: 84 | set_inter = set(inter) 85 | if len(interaction_ranking_pruned) >= max_interactions: 86 | break 87 | subset_inter_skip = False 88 | update_superset_inters = [] 89 | for superset_inter in current_superset_inters: 90 | if set_inter < superset_inter: 91 | subset_inter_skip = True 92 | break 93 | elif not (set_inter > superset_inter): 94 | update_superset_inters.append(superset_inter) 95 | if subset_inter_skip: 96 | continue 97 | current_superset_inters = update_superset_inters 98 | current_superset_inters.append(set_inter) 99 | interaction_ranking_pruned.append((inter, strength)) 100 | 101 | return interaction_ranking_pruned 102 | 103 | 104 | def detect_interactions( 105 | Xd, 106 | Yd, 107 | arch=[256, 128, 64], 108 | batch_size=100, 109 | device=torch.device("cpu"), 110 | seed=None, 111 | **kwargs 112 | ): 113 | 114 | if seed is not None: 115 | set_seed(seed) 116 | 117 | data_loaders = convert_to_torch_loaders(Xd, Yd, batch_size) 118 | num_features = Xd["train"].shape[1] if isinstance(Xd, dict) else Xd.shape[1] # input dimension 119 | model = create_mlp([num_features] + arch + [1]).to(device) 120 | 121 | model, mlp_loss = train(model, data_loaders, device=device, **kwargs) 122 | inters = get_interactions(get_weights(model)) 123 | 124 | return inters, mlp_loss 125 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | import sklearn 5 | from sklearn.preprocessing import StandardScaler 6 | from sklearn.metrics import roc_auc_score 7 | 8 | 9 | def set_seed(seed=42): 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | if torch.cuda.is_available(): 13 | torch.cuda.manual_seed(seed) 14 | 15 | 16 | def force_float(X_numpy): 17 | return torch.from_numpy(X_numpy.astype(np.float32)) 18 | 19 | 20 | def convert_to_torch_loaders(Xd, Yd, batch_size): 21 | if not isinstance(Xd, dict) and not isinstance(Yd, dict): 22 | Xd = {"train": Xd} 23 | Yd = {"train": Yd} 24 | 25 | data_loaders = {} 26 | for k in Xd: 27 | if k == "scaler": 28 | continue 29 | feats = force_float(Xd[k]) 30 | targets = force_float(Yd[k]) 31 | dataset = data.TensorDataset(feats, targets) 32 | data_loaders[k] = data.DataLoader(dataset, batch_size, shuffle=(k == "train")) 33 | 34 | return data_loaders 35 | 36 | 37 | def preprocess_data( 38 | X, 39 | Y, 40 | valid_size=500, 41 | test_size=500, 42 | std_scale=False, 43 | get_torch_loaders=False, 44 | batch_size=100, 45 | ): 46 | 47 | n, p = X.shape 48 | ## Make dataset splits 49 | ntrain, nval, ntest = n - valid_size - test_size, valid_size, test_size 50 | 51 | Xd = { 52 | "train": X[:ntrain], 53 | "val": X[ntrain : ntrain + nval], 54 | "test": X[ntrain + nval : ntrain + nval + ntest], 55 | } 56 | Yd = { 57 | "train": np.expand_dims(Y[:ntrain], axis=1), 58 | "val": np.expand_dims(Y[ntrain : ntrain + nval], axis=1), 59 | "test": 
np.expand_dims(Y[ntrain + nval : ntrain + nval + ntest], axis=1), 60 | } 61 | 62 | for k in list(Xd): # iterate over a copy of the keys so empty splits can be deleted safely 63 | if len(Xd[k]) == 0: 64 | assert k != "train" 65 | del Xd[k] 66 | del Yd[k] 67 | 68 | if std_scale: 69 | scaler_x = StandardScaler() 70 | scaler_y = StandardScaler() 71 | 72 | scaler_x.fit(Xd["train"]) 73 | scaler_y.fit(Yd["train"]) 74 | 75 | for k in Xd: 76 | Xd[k] = scaler_x.transform(Xd[k]) 77 | Yd[k] = scaler_y.transform(Yd[k]) 78 | 79 | Xd["scaler"] = scaler_x 80 | Yd["scaler"] = scaler_y 81 | 82 | if get_torch_loaders: 83 | return convert_to_torch_loaders(Xd, Yd, batch_size) 84 | 85 | else: 86 | return Xd, Yd 87 | 88 | 89 | def get_pairwise_auc(interactions, ground_truth): 90 | strengths = [] 91 | gt_binary_list = [] 92 | for inter, strength in interactions: 93 | inter_set = set(inter) # assume 1-indexed 94 | strengths.append(strength) 95 | if any(inter_set <= gt for gt in ground_truth): 96 | gt_binary_list.append(1) 97 | else: 98 | gt_binary_list.append(0) 99 | 100 | auc = roc_auc_score(gt_binary_list, strengths) 101 | return auc 102 | 103 | 104 | def get_anyorder_R_precision(interactions, ground_truth): 105 | 106 | R = len(ground_truth) 107 | recovered_gt = [] 108 | counter = 0 109 | 110 | for inter, strength in interactions: 111 | if counter == R: 112 | break 113 | 114 | inter_set = set(inter) # assume 1-indexed 115 | 116 | if any(inter_set < gt for gt in ground_truth): 117 | continue 118 | counter += 1 119 | if inter_set in ground_truth: 120 | recovered_gt.append(inter_set) 121 | 122 | R_precision = len(recovered_gt) / R 123 | 124 | return R_precision 125 | 126 | 127 | def print_rankings(pairwise_interactions, anyorder_interactions, top_k=10, spacing=14): 128 | print( 129 | justify(["Pairwise interactions", "", "Arbitrary-order interactions"], spacing) 130 | ) 131 | for i in range(top_k): 132 | p_inter, p_strength = pairwise_interactions[i] 133 | a_inter, a_strength = anyorder_interactions[i] 134 | print( 135 | justify( 136 | [ 137 | p_inter, 138 | "{0:.4f}".format(p_strength), 139 | "", 140 | a_inter, 141 | "{0:.4f}".format(a_strength), 142 | ], 143 | spacing, 144 | ) 145 | ) 146 | 147 | 148 | def justify(row, spacing=14): 149 | return "".join(str(item).ljust(spacing) for item in row) 150 | --------------------------------------------------------------------------------
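
For reference, the full pipeline exercised in `demo.ipynb` can also be run as a plain Python script. The following is a minimal sketch: the feature matrix `X` and target `Y` are illustrative placeholders, and the architecture and training constants are copied from the demo rather than tuned recommendations.

```python
# Minimal NID pipeline sketch (X and Y are placeholder data).
import numpy as np
import torch

from multilayer_perceptron import MLP, train, get_weights
from neural_interaction_detection import get_interactions
from utils import preprocess_data, set_seed

set_seed(42)
X = np.random.uniform(-1, 1, size=(5000, 10))  # placeholder features
Y = X[:, 0] * X[:, 1] + X[:, 2]                # placeholder target with one pairwise interaction

# Split, standardize, and wrap the data into torch DataLoaders.
data_loaders = preprocess_data(
    X, Y, valid_size=1000, test_size=1000, std_scale=True, get_torch_loaders=True
)

# Train an MLP whose weights will later be decoded for interactions.
device = torch.device("cpu")
model = MLP(X.shape[1], [140, 100, 60, 20], use_main_effect_nets=True).to(device)
model, test_loss = train(
    model, data_loaders, device=device, learning_rate=1e-2, l1_const=5e-5
)

# Decode the learned weights into ranked interaction candidates.
weights = get_weights(model)
pairwise_interactions = get_interactions(weights, pairwise=True, one_indexed=True)
anyorder_interactions = get_interactions(weights, one_indexed=True)
print(pairwise_interactions[:5])
print(anyorder_interactions[:5])
```

Setting `pairwise=True` ranks feature pairs only; the default decodes arbitrary-order interactions and prunes redundant subsets via `prune_redundant_interactions`.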