├── .gitignore ├── LICENSE ├── README.md ├── assets ├── ucate.png └── ucate_600x600.png ├── environment.yml ├── setup.py └── ucate ├── __init__.py ├── application ├── __init__.py ├── main.py └── workflows │ ├── __init__.py │ ├── bart.py │ ├── cevae.py │ ├── evaluation.py │ ├── tarnet.py │ └── tlearner.py └── library ├── __init__.py ├── data ├── __init__.py ├── acic.py ├── cemnist.py ├── core.py └── ihdp.py ├── evaluation.py ├── layers.py ├── models ├── __init__.py ├── cevae.py ├── cnn.py ├── core.py ├── mlp.py └── tarnet.py ├── modules ├── __init__.py ├── convolution.py ├── core.py ├── dense.py └── samplers.py ├── scratch.py └── utils ├── __init__.py ├── plotting.py └── prediction.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Dev 132 | .idea/ 133 | .vscode/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Andrew Jesson 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This repository contains the code to replicate the results reported in 3 | [Identifying Causal Effect Inference Failure with Uncertainty-Aware Models](https://arxiv.org/abs/2007.00163). 
4 | 5 | ![Intuitive Uncertainty for CATE](assets/ucate.png) 6 | 7 | ## Installation 8 | ``` 9 | $ git clone git@github.com:OATML/ucate.git 10 | $ cd ucate 11 | $ conda env create -f environment.yml 12 | $ conda activate ucate 13 | ``` 14 | 15 | ### Download data 16 | ``` 17 | $ mkdir data 18 | ``` 19 | 20 | To run the following experiments, the IHDP [train](http://www.fredjo.com/files/ihdp_npci_1-1000.train.npz.zip), 21 | [test](http://www.fredjo.com/files/ihdp_npci_1-1000.test.npz.zip), and 22 | [ACIC2016](https://jenniferhill7.wixsite.com/acic-2016/competition) datasets must be downloaded into the folder created 23 | above and then uncompressed. 24 | 25 | ## BTLearner Experiment 26 | ### IHDP 27 | ``` 28 | $ ucate train \ 29 | --job-dir ~/experiments/ucate \ 30 | --dataset-name ihdp \ 31 | --data-dir data/ \ 32 | --num-trials 1000 \ 33 | tlearner 34 | $ ucate evaluate \ 35 | --experiment-dir ~/experiments/ucate/ihdp/tlearner/bf-200_dp-5_dr-0.5_bs-100_lr-0.001_ep-False 36 | ``` 37 | 38 | ### IHDP Covariate Shift 39 | ``` 40 | $ ucate train \ 41 | --job-dir ~/experiments/ucate \ 42 | --dataset-name ihdp \ 43 | --data-dir data/ \ 44 | --num-trials 1000 \ 45 | --exclude-population True \ 46 | tlearner 47 | $ ucate evaluate \ 48 | --experiment-dir ~/experiments/ucate/ihdp/tlearner/bf-200_dp-5_dr-0.5_bs-100_lr-0.001_ep-True 49 | ``` 50 | 51 | ### CEMNIST 52 | ``` 53 | $ ucate train \ 54 | --job-dir ~/experiments/ucate \ 55 | --dataset-name cemnist \ 56 | --data-dir data/ \ 57 | --num-trials 20 \ 58 | tlearner 59 | $ ucate evaluate \ 60 | --experiment-dir ~/experiments/ucate/cemnist/tlearner/bf-200_dp-5_dr-0.5_bs-100_lr-0.001_ep-False 61 | ``` 62 | 63 | ### ACIC 64 | ``` 65 | $ ucate train \ 66 | --job-dir ~/experiments/ucate \ 67 | --dataset-name acic \ 68 | --data-dir data/data_cf_all/ \ 69 | --num-trials 77 \ 70 | tlearner 71 | $ ucate evaluate \ 72 | --experiment-dir ~/experiments/ucate/acic/tlearner/bf-200_dp-5_dr-0.5_bs-100_lr-0.001_ep-False 73 | ``` 74 | 75 | ## BTARNet Experiment 76 | ### IHDP 77 | ``` 78 | $ ucate train \ 79 | --job-dir ~/experiments/ucate \ 80 | --dataset-name ihdp \ 81 | --data-dir data/ \ 82 | --num-trials 1000 \ 83 | tarnet 84 | $ ucate evaluate \ 85 | --experiment-dir ~/experiments/ucate/ihdp/tarnet/md-tarnet_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 86 | ``` 87 | 88 | ### IHDP Covariate Shift 89 | ``` 90 | $ ucate train \ 91 | --job-dir ~/experiments/ucate \ 92 | --dataset-name ihdp \ 93 | --data-dir data/ \ 94 | --num-trials 1000 \ 95 | --exclude-population True \ 96 | tarnet 97 | $ ucate evaluate \ 98 | --experiment-dir ~/experiments/ucate/ihdp/tarnet/md-tarnet_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-True 99 | ``` 100 | 101 | ### CEMNIST 102 | ``` 103 | $ ucate train \ 104 | --job-dir ~/experiments/ucate \ 105 | --dataset-name cemnist \ 106 | --data-dir data/ \ 107 | --num-trials 20 \ 108 | tarnet 109 | $ ucate evaluate \ 110 | --experiment-dir ~/experiments/ucate/cemnist/tarnet/md-tarnet_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 111 | ``` 112 | 113 | ### ACIC 114 | ``` 115 | $ ucate train \ 116 | --job-dir ~/experiments/ucate \ 117 | --dataset-name acic \ 118 | --data-dir data/data_cf_all/ \ 119 | --num-trials 77 \ 120 | tarnet 121 | $ ucate evaluate \ 122 | --experiment-dir ~/experiments/ucate/acic/tarnet/md-tarnet_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 123 | ``` 124 | 125 | ## BCFR-MMD Experiment 126 | ### IHDP 127 | ``` 128 | $ ucate train \ 129 | --job-dir ~/experiments/ucate \ 130 | --dataset-name ihdp \ 131 | --data-dir
data/ \ 132 | --num-trials 1000 \ 133 | tarnet \ 134 | --mode mmd 135 | $ ucate evaluate \ 136 | --experiment-dir ~/experiments/ucate/ihdp/tarnet/md-mmd_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 137 | ``` 138 | 139 | ### IHDP Covariate Shift 140 | ``` 141 | $ ucate train \ 142 | --job-dir ~/experiments/ucate \ 143 | --dataset-name ihdp \ 144 | --data-dir data/ \ 145 | --num-trials 1000 \ 146 | --exclude-population True \ 147 | tarnet \ 148 | --mode mmd 149 | $ ucate evaluate \ 150 | --experiment-dir ~/experiments/ucate/ihdp/tarnet/md-mmd_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-True 151 | ``` 152 | 153 | ### CEMNIST 154 | ``` 155 | $ ucate train \ 156 | --job-dir ~/experiments/ucate \ 157 | --dataset-name cemnist \ 158 | --data-dir data/ \ 159 | --num-trials 20 \ 160 | tarnet \ 161 | --mode mmd 162 | $ ucate evaluate \ 163 | --experiment-dir ~/experiments/ucate/cemnist/tarnet/md-mmd_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 164 | ``` 165 | 166 | ### ACIC 167 | ``` 168 | $ ucate train \ 169 | --job-dir ~/experiments/ucate \ 170 | --dataset-name acic \ 171 | --data-dir data/data_cf_all/ \ 172 | --num-trials 77 \ 173 | tarnet \ 174 | --mode mmd 175 | $ ucate evaluate \ 176 | --experiment-dir ~/experiments/ucate/acic/tarnet/md-mmd_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 177 | ``` 178 | 179 | ## BDragonnet Experiment 180 | ### IHDP 181 | ``` 182 | $ ucate train \ 183 | --job-dir ~/experiments/ucate \ 184 | --dataset-name ihdp \ 185 | --data-dir data/ \ 186 | --num-trials 1000 \ 187 | tarnet \ 188 | --mode dragon 189 | $ ucate evaluate \ 190 | --experiment-dir ~/experiments/ucate/ihdp/tarnet/md-dragon_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 191 | ``` 192 | 193 | ### IHDP Covariate Shift 194 | ``` 195 | $ ucate train \ 196 | --job-dir ~/experiments/ucate \ 197 | --dataset-name ihdp \ 198 | --data-dir data/ \ 199 | --num-trials 1000 \ 200 | --exclude-population True \ 201 | tarnet \ 202 | --mode dragon 203 | $ ucate evaluate \ 204 | --experiment-dir ~/experiments/ucate/ihdp/tarnet/md-dragon_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-True 205 | ``` 206 | 207 | ### CEMNIST 208 | ``` 209 | $ ucate train \ 210 | --job-dir ~/experiments/ucate \ 211 | --dataset-name cemnist \ 212 | --data-dir data/ \ 213 | --num-trials 20 \ 214 | tarnet \ 215 | --mode dragon 216 | $ ucate evaluate \ 217 | --experiment-dir ~/experiments/ucate/cemnist/tarnet/md-dragon_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 218 | ``` 219 | 220 | ### ACIC 221 | ``` 222 | $ ucate train \ 223 | --job-dir ~/experiments/ucate \ 224 | --dataset-name acic \ 225 | --data-dir data/data_cf_all/ \ 226 | --num-trials 77 \ 227 | tarnet \ 228 | --mode dragon 229 | $ ucate evaluate \ 230 | --experiment-dir ~/experiments/ucate/acic/tarnet/md-dragon_bf-200_dr-0.5_beta-1.0_bs-100_lr-0.001_ep-False 231 | ``` 232 | 233 | ## BCEVAE Experiment 234 | ### IHDP 235 | ``` 236 | $ ucate train \ 237 | --job-dir ~/experiments/ucate \ 238 | --dataset-name ihdp \ 239 | --data-dir data/ \ 240 | --num-trials 1000 \ 241 | cevae 242 | $ ucate evaluate \ 243 | --experiment-dir ~/experiments/ucate/ihdp/cevae/dl-32_bf-200_dr-0.1_beta-0.1_ns-True_bs-100_lr-0.001_ep-False 244 | ``` 245 | 246 | ### IHDP Covariate Shift 247 | ``` 248 | $ ucate train \ 249 | --job-dir ~/experiments/ucate \ 250 | --dataset-name ihdp \ 251 | --data-dir data/ \ 252 | --num-trials 1000 \ 253 | --exclude-population True \ 254 | cevae 255 | $ ucate evaluate \ 256 | --experiment-dir ~/experiments/ucate/ihdp/cevae/dl-32_bf-200_dr-0.1_beta-0.1_ns-True_bs-100_lr-0.001_ep-True 257 | ```
258 | 259 | ### CEMNIST 260 | ``` 261 | $ ucate train \ 262 | --job-dir ~/experiments/ucate \ 263 | --dataset-name cemnist \ 264 | --data-dir data/ \ 265 | --num-trials 20 \ 266 | cevae \ 267 | --learning-rate 2e-4 268 | $ ucate evaluate \ 269 | --experiment-dir ~/experiments/ucate/cemnist/cevae/dl-32_bf-200_dr-0.1_beta-0.1_ns-True_bs-100_lr-0.0002_ep-False 270 | ``` 271 | 272 | ### ACIC 273 | ``` 274 | $ ucate train \ 275 | --job-dir ~/experiments/ucate \ 276 | --dataset-name acic \ 277 | --data-dir data/data_cf_all/ \ 278 | --num-trials 77 \ 279 | cevae 280 | $ ucate evaluate \ 281 | --experiment-dir ~/experiments/ucate/acic/cevae/dl-32_bf-200_dr-0.1_beta-0.1_ns-True_bs-100_lr-0.001_ep-False 282 | ``` 283 | 284 | ## BART Experiment 285 | ### IHDP 286 | ``` 287 | $ ucate train \ 288 | --job-dir ~/experiments/ucate \ 289 | --dataset-name ihdp \ 290 | --data-dir data/ \ 291 | --num-trials 1000 \ 292 | bart 293 | $ ucate evaluate \ 294 | --experiment-dir ~/experiments/ucate/ihdp/bart/ep-False 295 | ``` 296 | 297 | ### IHDP Covariate Shift 298 | ``` 299 | $ ucate train \ 300 | --job-dir ~/experiments/ucate \ 301 | --dataset-name ihdp \ 302 | --data-dir data/ \ 303 | --num-trials 1000 \ 304 | --exclude-population True \ 305 | bart 306 | $ ucate evaluate \ 307 | --experiment-dir ~/experiments/ucate/ihdp/bart/ep-True 308 | ``` 309 | 310 | ### CEMNIST 311 | ``` 312 | $ ucate train \ 313 | --job-dir ~/experiments/ucate \ 314 | --dataset-name cemnist \ 315 | --data-dir data/ \ 316 | --num-trials 20 \ 317 | bart 318 | $ ucate evaluate \ 319 | --experiment-dir ~/experiments/ucate/cemnist/bart/ep-False 320 | ``` 321 | 322 | ### ACIC 323 | ``` 324 | $ ucate train \ 325 | --job-dir ~/experiments/ucate \ 326 | --dataset-name acic \ 327 | --data-dir data/data_cf_all/ \ 328 | --num-trials 77 \ 329 | bart 330 | $ ucate evaluate \ 331 | --experiment-dir ~/experiments/ucate/acic/bart/ep-False 332 | ``` 333 | -------------------------------------------------------------------------------- /assets/ucate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/ucate/cddc23596a463e2ce9e270cf075b3e924137bf7a/assets/ucate.png -------------------------------------------------------------------------------- /assets/ucate_600x600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/ucate/cddc23596a463e2ce9e270cf075b3e924137bf7a/assets/ucate_600x600.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ucate 2 | channels: 3 | - defaults 4 | dependencies: 5 | - pip 6 | - r-base 7 | - r-rjava 8 | - python=3.7 9 | - cudatoolkit=10.1 10 | - pip: 11 | - ray 12 | - -e .
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | 4 | here = os.path.abspath(os.path.dirname(__file__)) 5 | 6 | setup( 7 | name="ucate", 8 | version="0.0.0", 9 | description="Exploring uncertainty for CATE inference", 10 | long_description_content_type="text/markdown", 11 | url="https://github.com/OATML/ucate", 12 | author="Andrew Jesson", 13 | author_email="andrew.jesson@cs.ox.ac.uk", 14 | license="Apache-2.0", 15 | packages=find_packages(), 16 | install_requires=[ 17 | "rpy2", 18 | "click", 19 | "numpy>=1.16.0,<1.19.0", 20 | "scipy", 21 | "pandas", 22 | "sklearn", 23 | "seaborn", 24 | "matplotlib", 25 | "tensorflow==2.3.1", 26 | "tensorflow-probability==0.11.1", 27 | ], 28 | entry_points={ 29 | "console_scripts": ["ucate=ucate.application.main:cli"], 30 | }, 31 | ) 32 | -------------------------------------------------------------------------------- /ucate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/ucate/cddc23596a463e2ce9e270cf075b3e924137bf7a/ucate/__init__.py -------------------------------------------------------------------------------- /ucate/application/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/ucate/cddc23596a463e2ce9e270cf075b3e924137bf7a/ucate/application/__init__.py -------------------------------------------------------------------------------- /ucate/application/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | import ray 5 | import json 6 | import click 7 | import tensorflow as tf 8 | 9 | from ucate.application import workflows 10 | 11 | 12 | @click.group(chain=True) 13 | @click.pass_context 14 | def cli(context): 15 | gpus = tf.config.experimental.list_physical_devices("GPU") 16 | context.obj = {"n_gpu": len(gpus)} 17 | 18 | 19 | @cli.command("train") 20 | @click.option( 21 | "--job-dir", 22 | type=str, 23 | required=True, 24 | help="local or GCS location for writing checkpoints and exporting models", 25 | ) 26 | @click.option("--dataset-name", type=str, default="ihdp", help="dataset name") 27 | @click.option( 28 | "--data-dir", 29 | type=str, 30 | default="./data/", 31 | help="location to write/read dataset, default=.data/", 32 | ) 33 | @click.option( 34 | "--exclude-population", 35 | default=False, 36 | type=bool, 37 | help="exclude population from training, default=False", 38 | ) 39 | @click.option("--num-trials", default=1, type=int, help="number of trials, default=1") 40 | @click.option( 41 | "--gpu-per-trial", 42 | default=0.0, 43 | type=float, 44 | help="number of gpus for each trial, default=0", 45 | ) 46 | @click.option( 47 | "--cpu-per-trial", 48 | default=1.0, 49 | type=float, 50 | help="number of cpus for each trial, default=1", 51 | ) 52 | @click.option("--verbose", default=False, type=bool, help="verbosity default=False") 53 | @click.pass_context 54 | def train( 55 | context, 56 | job_dir, 57 | dataset_name, 58 | data_dir, 59 | exclude_population, 60 | num_trials, 61 | gpu_per_trial, 62 | cpu_per_trial, 63 | verbose, 64 | ): 65 | ray.init( 66 | num_gpus=context.obj["n_gpu"], 67 | dashboard_host="127.0.0.1", 68 | ignore_reinit_error=True, 69 | ) 70 | gpu_per_trial = 0 if 
context.obj["n_gpu"] == 0 else gpu_per_trial 71 | context.obj.update( 72 | { 73 | "job_dir": job_dir, 74 | "dataset_name": dataset_name, 75 | "data_dir": data_dir, 76 | "exclude_population": exclude_population, 77 | "num_trials": num_trials, 78 | "gpu_per_trial": gpu_per_trial, 79 | "cpu_per_trial": cpu_per_trial, 80 | "verbose": verbose, 81 | } 82 | ) 83 | 84 | 85 | @cli.command("bart") 86 | @click.pass_context 87 | def bart(context): 88 | from ucate.application.workflows import bart 89 | 90 | config = context.obj 91 | dataset_name = config.get("dataset_name") 92 | exclude_population = config.get("exclude_population") 93 | experiment_name = f"ep-{exclude_population}" 94 | 95 | bart.install() 96 | 97 | results = [] 98 | for trial in range(config.get("num_trials")): 99 | output_dir = os.path.join( 100 | config.get("job_dir"), 101 | dataset_name, 102 | "bart", 103 | experiment_name, 104 | f"trial_{trial:03d}", 105 | ) 106 | os.makedirs(output_dir, exist_ok=True) 107 | config["output_dir"] = output_dir 108 | config["trial"] = trial 109 | config_file = os.path.join(output_dir, "config.json") 110 | with open(config_file, "w") as fp: 111 | json.dump(config, fp, indent=4, sort_keys=True) 112 | 113 | results.append( 114 | bart.train( 115 | output_dir=output_dir, 116 | dataset_name=dataset_name, 117 | data_dir=config.get("data_dir"), 118 | trial=trial, 119 | exclude_population=exclude_population, 120 | verbose=config.get("verbose"), 121 | ) 122 | ) 123 | 124 | 125 | @cli.command("tarnet") 126 | @click.pass_context 127 | @click.option( 128 | "--mode", default="tarnet", type=str, help="mode, one of tarnet, dragon, or mmd" 129 | ) 130 | @click.option( 131 | "--base-filters", default=200, type=int, help="base number of filters, default=200" 132 | ) 133 | @click.option( 134 | "--dropout-rate", default=0.5, type=float, help="dropout rate, default=0.0" 135 | ) 136 | @click.option( 137 | "--beta", default=1.0, type=float, help="dragonnet loss param, default=1.0" 138 | ) 139 | @click.option( 140 | "--epochs", type=int, default=2000, help="number of training epochs, default=750" 141 | ) 142 | @click.option( 143 | "--batch-size", 144 | default=100, 145 | type=int, 146 | help="number of examples to read during each training step, default=100", 147 | ) 148 | @click.option( 149 | "--learning-rate", 150 | default=1e-3, 151 | type=float, 152 | help="learning rate for gradient descent, default=.0001", 153 | ) 154 | @click.option( 155 | "--mc-samples", 156 | type=int, 157 | default=100, 158 | help="number of mc_samples at inference, default=100", 159 | ) 160 | def tarnet( 161 | context, 162 | mode, 163 | base_filters, 164 | dropout_rate, 165 | beta, 166 | epochs, 167 | batch_size, 168 | learning_rate, 169 | mc_samples, 170 | ): 171 | config = context.obj 172 | dataset_name = config.get("dataset_name") 173 | exclude_population = config.get("exclude_population") 174 | 175 | @ray.remote( 176 | num_gpus=config.get("gpu_per_trial"), num_cpus=config.get("cpu_per_trial") 177 | ) 178 | def trainer(**kwargs): 179 | func = workflows.train_tarnet(**kwargs) 180 | return func 181 | 182 | results = [] 183 | for trial in range(config.get("num_trials")): 184 | results.append( 185 | trainer.remote( 186 | job_dir=config["job_dir"], 187 | dataset_name=dataset_name, 188 | data_dir=config.get("data_dir"), 189 | trial=trial, 190 | exclude_population=exclude_population, 191 | verbose=config.get("verbose"), 192 | mode=mode, 193 | base_filters=base_filters, 194 | dropout_rate=dropout_rate, 195 | beta=beta, 196 | epochs=epochs, 197 | 
batch_size=batch_size, 198 | learning_rate=learning_rate, 199 | mc_samples=mc_samples, 200 | ) 201 | ) 202 | ray.get(results) 203 | 204 | 205 | @cli.command("tlearner") 206 | @click.pass_context 207 | @click.option( 208 | "--base-filters", default=200, type=int, help="base number of filters, default=200" 209 | ) 210 | @click.option("--depth", default=5, type=int, help="depth of neural network, default=5") 211 | @click.option( 212 | "--dropout-rate", default=0.5, type=float, help="dropout rate, default=0.0" 213 | ) 214 | @click.option( 215 | "--epochs", type=int, default=2000, help="number of training epochs, default=750" 216 | ) 217 | @click.option( 218 | "--batch-size", 219 | default=100, 220 | type=int, 221 | help="number of examples to read during each training step, default=100", 222 | ) 223 | @click.option( 224 | "--learning-rate", 225 | default=1e-3, 226 | type=float, 227 | help="learning rate for gradient descent, default=.0001", 228 | ) 229 | @click.option( 230 | "--mc-samples", 231 | type=int, 232 | default=100, 233 | help="number of mc_samples at inference, default=100", 234 | ) 235 | def tlearner( 236 | context, 237 | base_filters, 238 | depth, 239 | dropout_rate, 240 | epochs, 241 | batch_size, 242 | learning_rate, 243 | mc_samples, 244 | ): 245 | config = context.obj 246 | dataset_name = config.get("dataset_name") 247 | exclude_population = config.get("exclude_population") 248 | 249 | @ray.remote( 250 | num_gpus=config.get("gpu_per_trial"), num_cpus=config.get("cpu_per_trial") 251 | ) 252 | def trainer(**kwargs): 253 | func = workflows.train_tlearner(**kwargs) 254 | return func 255 | 256 | results = [] 257 | for trial in range(config.get("num_trials")): 258 | results.append( 259 | trainer.remote( 260 | job_dir=config["job_dir"], 261 | dataset_name=dataset_name, 262 | data_dir=config.get("data_dir"), 263 | trial=trial, 264 | exclude_population=exclude_population, 265 | verbose=config.get("verbose"), 266 | base_filters=base_filters, 267 | depth=depth, 268 | dropout_rate=dropout_rate, 269 | epochs=epochs, 270 | batch_size=batch_size, 271 | learning_rate=learning_rate, 272 | mc_samples=mc_samples, 273 | ) 274 | ) 275 | ray.get(results) 276 | 277 | 278 | @cli.command("cevae") 279 | @click.pass_context 280 | @click.option( 281 | "--dim-latent", default=32, type=int, help="dimension of latent z, default=32" 282 | ) 283 | @click.option( 284 | "--base-filters", default=200, type=int, help="base number of filters, default=200" 285 | ) 286 | @click.option( 287 | "--dropout-rate", default=0.1, type=float, help="dropout rate, default=0.0" 288 | ) 289 | @click.option( 290 | "--beta", default=0.1, type=float, help="dragonnet loss param, default=1.0" 291 | ) 292 | @click.option( 293 | "--negative-sampling", 294 | default=True, 295 | type=bool, 296 | help="Use negative sampling during training, default=True", 297 | ) 298 | @click.option( 299 | "--epochs", type=int, default=1000, help="number of training epochs, default=750" 300 | ) 301 | @click.option( 302 | "--batch-size", 303 | default=100, 304 | type=int, 305 | help="number of examples to read during each training step, default=100", 306 | ) 307 | @click.option( 308 | "--learning-rate", 309 | default=1e-3, 310 | type=float, 311 | help="learning rate for gradient descent, default=.0001", 312 | ) 313 | @click.option( 314 | "--mc-samples", 315 | type=int, 316 | default=100, 317 | help="number of mc_samples at inference, default=100", 318 | ) 319 | def cevae( 320 | context, 321 | dim_latent, 322 | base_filters, 323 | dropout_rate, 324 | beta, 
325 | negative_sampling, 326 | epochs, 327 | batch_size, 328 | learning_rate, 329 | mc_samples, 330 | ): 331 | config = context.obj 332 | dataset_name = config.get("dataset_name") 333 | exclude_population = config.get("exclude_population") 334 | 335 | @ray.remote( 336 | num_gpus=config.get("gpu_per_trial"), num_cpus=config.get("cpu_per_trial") 337 | ) 338 | def trainer(**kwargs): 339 | func = workflows.train_cevae(**kwargs) 340 | return func 341 | 342 | results = [] 343 | for trial in range(config.get("num_trials")): 344 | results.append( 345 | trainer.remote( 346 | job_dir=config["job_dir"], 347 | dataset_name=dataset_name, 348 | data_dir=config.get("data_dir"), 349 | trial=trial, 350 | exclude_population=exclude_population, 351 | verbose=config.get("verbose"), 352 | dim_latent=dim_latent, 353 | base_filters=base_filters, 354 | dropout_rate=dropout_rate, 355 | beta=beta, 356 | negative_sampling=negative_sampling, 357 | epochs=epochs, 358 | batch_size=batch_size, 359 | learning_rate=learning_rate, 360 | mc_samples=mc_samples, 361 | ) 362 | ) 363 | ray.get(results) 364 | 365 | 366 | @cli.command() 367 | @click.option( 368 | "--experiment-dir", 369 | type=str, 370 | required=True, 371 | help="Location of saved experiment files", 372 | ) 373 | def evaluate( 374 | experiment_dir, 375 | ): 376 | summary_path = os.path.join(experiment_dir, "summary.json") 377 | if not os.path.exists(summary_path): 378 | ray.init( 379 | dashboard_host="127.0.0.1", 380 | ignore_reinit_error=True, 381 | ) 382 | results = [] 383 | _, dirs, _ = list(os.walk(experiment_dir))[0] 384 | for trial_dir in dirs: 385 | output_dir = os.path.join(experiment_dir, trial_dir) 386 | if os.path.exists( 387 | os.path.join(output_dir, "predictions_train.npz") 388 | ) and os.path.exists(os.path.join(output_dir, "predictions_test.npz")): 389 | results.append(workflows.evaluate.remote(output_dir)) 390 | summary = workflows.build_summary( 391 | results=ray.get(results), experiment_dir=experiment_dir 392 | ) 393 | else: 394 | with open(summary_path) as summary_file: 395 | summary = json.load(summary_file) 396 | workflows.summarize(summary=summary, experiment_dir=experiment_dir) 397 | 398 | 399 | if __name__ == "__main__": 400 | cli() 401 | -------------------------------------------------------------------------------- /ucate/application/workflows/__init__.py: -------------------------------------------------------------------------------- 1 | from ucate.application.workflows.tarnet import train as train_tarnet 2 | 3 | from ucate.application.workflows.tlearner import train as train_tlearner 4 | 5 | from ucate.application.workflows.cevae import train as train_cevae 6 | 7 | from ucate.application.workflows.evaluation import evaluate 8 | from ucate.application.workflows.evaluation import summarize 9 | from ucate.application.workflows.evaluation import build_summary 10 | -------------------------------------------------------------------------------- /ucate/application/workflows/bart.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | import rpy2.robjects as robjects 7 | import rpy2.robjects.packages as rpackages 8 | from rpy2.robjects import numpy2ri 9 | from rpy2.robjects.packages import importr 10 | from rpy2.robjects.vectors import StrVector 11 | 12 | from ucate.library import data 13 | from ucate.library import models 14 | from ucate.library.utils import plotting 15 | 16 | 17 | def train( 18 | output_dir, 19 | 
dataset_name, 20 | data_dir, 21 | trial, 22 | exclude_population, 23 | verbose, 24 | ): 25 | gpus = tf.config.experimental.list_physical_devices("GPU") 26 | if gpus: 27 | try: 28 | for gpu in gpus: 29 | tf.config.experimental.set_memory_growth(gpu, True) 30 | logical_gpus = tf.config.experimental.list_logical_devices("GPU") 31 | except RuntimeError as e: 32 | print(e) 33 | checkpoint_dir = os.path.join(output_dir, "checkpoints") 34 | dbarts = importr("dbarts") 35 | # Instantiate data loaders 36 | dl = data.DATASETS[dataset_name]( 37 | path=data_dir, trial=trial, exclude_population=exclude_population 38 | ) 39 | x_train, y_train, t_train, examples_per_treatment = dl.get_training_data() 40 | x_test, cate = dl.get_test_data(test_set=True) 41 | num_train = len(x_train) 42 | num_test = len(x_test) 43 | x_train = np.reshape(x_train, (num_train, -1)) 44 | x_test = np.reshape(x_test, (num_test, -1)) 45 | xt_train = np.hstack([x_train, t_train[:, -1:]]) 46 | xt_test = np.vstack( 47 | [ 48 | np.hstack([x_train, np.zeros((num_train, 1), "float32")]), 49 | np.hstack([x_train, np.ones((num_train, 1), "float32")]), 50 | np.hstack([x_test, np.zeros((num_test, 1), "float32")]), 51 | np.hstack([x_test, np.ones((num_test, 1), "float32")]), 52 | ] 53 | ) 54 | # Instantiate models 55 | model = dbarts.bart(xt_train, y_train, xt_test, verbose=verbose) 56 | model_dict = dict(zip(model.names, map(list, list(model)))) 57 | model_prop = models.MODELS["cnn" if dataset_name == "cemnist" else "mlp"]( 58 | num_examples=sum(examples_per_treatment), 59 | dim_hidden=200, 60 | dropout_rate=0.5, 61 | regression=False, 62 | depth=2, 63 | ) 64 | model_prop.compile( 65 | optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), 66 | loss=tf.keras.losses.BinaryCrossentropy(), 67 | metrics=[tf.metrics.BinaryAccuracy()], 68 | loss_weights=[0.0, 0.0], 69 | ) 70 | model_prop_checkpoint = os.path.join(checkpoint_dir, "model_prop") 71 | # Instantiate trainer 72 | _ = model_prop.fit( 73 | [x_train, t_train[:, -1]], 74 | [t_train[:, -1], np.zeros_like(t_train[:, -1])], 75 | batch_size=100, 76 | epochs=2000, 77 | validation_split=0.3, 78 | shuffle=True, 79 | callbacks=[ 80 | tf.keras.callbacks.ModelCheckpoint( 81 | filepath=model_prop_checkpoint, 82 | save_best_only=True, 83 | save_weights_only=True, 84 | ), 85 | tf.keras.callbacks.EarlyStopping(patience=50), 86 | ], 87 | verbose=verbose, 88 | ) 89 | # Restore best models 90 | model_prop.load_weights(model_prop_checkpoint) 91 | # Predict ys 92 | y_hat = ( 93 | np.asarray(model_dict["yhat.test"]) 94 | .reshape((2 * num_train + 2 * num_test, -1)) 95 | .transpose() 96 | ) 97 | y_0_train = y_hat[:, :num_train] * dl.y_std + dl.y_mean 98 | y_1_train = y_hat[:, num_train : 2 * num_train] * dl.y_std + dl.y_mean 99 | y_0_test = y_hat[:, 2 * num_train : 2 * num_train + num_test] * dl.y_std + dl.y_mean 100 | y_1_test = ( 101 | y_hat[:, 2 * num_train + num_test : 2 * num_train + 2 * num_test] * dl.y_std 102 | + dl.y_mean 103 | ) 104 | # Predict propensity 105 | p_t_train, _ = model_prop.predict( 106 | [x_train, np.zeros((num_train,), "float32")], 107 | batch_size=200, 108 | workers=8, 109 | use_multiprocessing=True, 110 | ) 111 | p_t_test, _ = model_prop.predict( 112 | [x_test, np.zeros((num_test,), "float32")], 113 | batch_size=200, 114 | workers=8, 115 | use_multiprocessing=True, 116 | ) 117 | predictions_train = { 118 | "mu_0": y_0_train, 119 | "mu_1": y_1_train, 120 | "y_0": y_0_train, 121 | "y_1": y_1_train, 122 | "p_t": p_t_train, 123 | } 124 | predictions_test = { 125 | "mu_0": y_0_test, 126 
| "mu_1": y_1_test, 127 | "y_0": y_0_test, 128 | "y_1": y_1_test, 129 | "p_t": p_t_test, 130 | } 131 | 132 | np.savez(os.path.join(output_dir, "predictions_train.npz"), **predictions_train) 133 | np.savez(os.path.join(output_dir, "predictions_test.npz"), **predictions_test) 134 | 135 | plotting.error_bars( 136 | data={ 137 | "predictions (95% CI)": [ 138 | (y_1_test - y_0_test).mean(0), 139 | cate, 140 | 2 * (y_1_test - y_0_test).std(0), 141 | ] 142 | }, 143 | file_name=os.path.join(output_dir, "cate_scatter_test.png"), 144 | ) 145 | plotting.histogram( 146 | x={ 147 | "$t=0$ test": p_t_test[dl.get_t(True)[:, 0] > 0.5], 148 | "$t=1$ test": p_t_test[dl.get_t(True)[:, 1] >= 0.5], 149 | }, 150 | bins=128, 151 | alpha=0.5, 152 | x_label="$p(t=1 | \mathbf{x})$", 153 | y_label="Number of individuals", 154 | x_limit=(0.0, 1.0), 155 | file_name=os.path.join(output_dir, "cate_propensity_test.png"), 156 | ) 157 | shutil.rmtree(checkpoint_dir) 158 | return -1 159 | 160 | 161 | def install(): 162 | # Install BART 163 | robjects.r.options(download_file_method="curl") 164 | numpy2ri.activate() 165 | rj = importr("rJava", robject_translations={".env": "rj_env"}) 166 | rj._jinit(parameters="-Xmx16g", force_init=True) 167 | package_names = ["dbarts"] 168 | utils = rpackages.importr("utils") 169 | utils.chooseCRANmirror(ind=0) 170 | utils.chooseCRANmirror(ind=0) 171 | names_to_install = [x for x in package_names if not rpackages.isinstalled(x)] 172 | if len(names_to_install) > 0: 173 | utils.install_packages( 174 | StrVector(names_to_install), repos="http://cran.us.r-project.org" 175 | ) 176 | -------------------------------------------------------------------------------- /ucate/application/workflows/cevae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from ucate.library import data 8 | from ucate.library import models 9 | from ucate.library import evaluation 10 | from ucate.library.utils import plotting 11 | 12 | 13 | def train( 14 | job_dir, 15 | dataset_name, 16 | data_dir, 17 | trial, 18 | exclude_population, 19 | verbose, 20 | dim_latent, 21 | base_filters, 22 | dropout_rate, 23 | beta, 24 | negative_sampling, 25 | batch_size, 26 | epochs, 27 | learning_rate, 28 | mc_samples, 29 | ): 30 | gpus = tf.config.experimental.list_physical_devices("GPU") 31 | if gpus: 32 | try: 33 | for gpu in gpus: 34 | tf.config.experimental.set_memory_growth(gpu, True) 35 | logical_gpus = tf.config.experimental.list_logical_devices("GPU") 36 | except RuntimeError as e: 37 | print(e) 38 | print("TRIAL {:04d} ".format(trial)) 39 | experiment_name = f"dl-{dim_latent}_bf-{base_filters}_dr-{dropout_rate}_beta-{beta}_ns-{negative_sampling}_bs-{batch_size}_lr-{learning_rate}_ep-{exclude_population}" 40 | output_dir = os.path.join( 41 | job_dir, 42 | dataset_name, 43 | "cevae", 44 | experiment_name, 45 | f"trial_{trial:03d}", 46 | ) 47 | os.makedirs(output_dir, exist_ok=True) 48 | checkpoint_dir = os.path.join(output_dir, "checkpoints") 49 | config = { 50 | "job_dir": job_dir, 51 | "dataset_name": dataset_name, 52 | "data_dir": data_dir, 53 | "exclude_population": exclude_population, 54 | "trial": trial, 55 | "dim_latent": dim_latent, 56 | "base_filters": base_filters, 57 | "dropout_rate": dropout_rate, 58 | "beta": beta, 59 | "negative_sampling": negative_sampling, 60 | "batch_size": batch_size, 61 | "epochs": epochs, 62 | "learning_rate": learning_rate, 63 | "mc_samples": mc_samples, 64 
| } 65 | config_file = os.path.join(output_dir, "config.json") 66 | with open(config_file, "w") as fp: 67 | json.dump(config, fp, indent=4, sort_keys=True) 68 | # Instantiate data loaders 69 | dl = data.DATASETS[dataset_name]( 70 | path=data_dir, trial=trial, exclude_population=exclude_population, center=True 71 | ) 72 | x_train, y_train, t_train, examples_per_treatment = dl.get_training_data() 73 | if dataset_name in ["acic", "ihdp"]: 74 | regression = True 75 | model_name = "mlp" 76 | loss = tf.keras.losses.MeanSquaredError() 77 | error = tf.keras.metrics.MeanAbsoluteError() 78 | else: 79 | regression = False 80 | model_name = "cnn" 81 | loss = tf.keras.losses.BinaryCrossentropy() 82 | error = tf.keras.metrics.BinaryAccuracy() 83 | # Instantiate models 84 | model = models.BayesianCEVAE( 85 | dim_x=[dl.dim_x_cont, dl.dim_x_bin] if regression else dl.dim_x, 86 | dim_t=2, 87 | dim_y=1, 88 | regression=regression, 89 | dim_latent=dim_latent, 90 | num_examples=examples_per_treatment, 91 | dim_hidden=base_filters, 92 | dropout_rate=dropout_rate, 93 | beta=beta, 94 | negative_sampling=negative_sampling, 95 | do_convolution=not regression, 96 | ) 97 | model.compile( 98 | optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), 99 | loss=[loss, loss], 100 | metrics=[error], 101 | loss_weights=[0.0, 0.0], 102 | ) 103 | model_checkpoint = os.path.join(checkpoint_dir, "model_0") 104 | model_prop = models.MODELS[model_name]( 105 | num_examples=sum(examples_per_treatment), 106 | dim_hidden=base_filters, 107 | dropout_rate=dropout_rate, 108 | regression=False, 109 | depth=2, 110 | ) 111 | model_prop.compile( 112 | optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), 113 | loss=[ 114 | tf.keras.losses.BinaryCrossentropy(), 115 | tf.keras.losses.BinaryCrossentropy(), 116 | ], 117 | metrics=[tf.metrics.BinaryAccuracy()], 118 | loss_weights=[0.0, 0.0], 119 | ) 120 | model_prop_checkpoint = os.path.join(checkpoint_dir, "model_prop") 121 | # Fit models 122 | hist = model.fit( 123 | [x_train, t_train, y_train], 124 | [y_train, np.zeros_like(y_train)], 125 | batch_size=batch_size, 126 | epochs=epochs, 127 | validation_split=0.3, 128 | shuffle=True, 129 | callbacks=[ 130 | tf.keras.callbacks.ModelCheckpoint( 131 | filepath=model_checkpoint, 132 | save_best_only=True, 133 | save_weights_only=True, 134 | monitor="val_output_1_loss", 135 | ), 136 | tf.keras.callbacks.EarlyStopping(monitor="val_output_1_loss", patience=50), 137 | ], 138 | verbose=verbose, 139 | ) 140 | hist_prop = model_prop.fit( 141 | [x_train, t_train[:, -1]], 142 | [t_train[:, -1], np.zeros_like(t_train[:, -1])], 143 | batch_size=batch_size, 144 | epochs=epochs, 145 | validation_split=0.3, 146 | shuffle=True, 147 | callbacks=[ 148 | tf.keras.callbacks.ModelCheckpoint( 149 | filepath=model_prop_checkpoint, 150 | save_best_only=True, 151 | save_weights_only=True, 152 | ), 153 | tf.keras.callbacks.EarlyStopping(patience=50), 154 | ], 155 | verbose=verbose, 156 | ) 157 | # Restore best models 158 | model.load_weights(model_checkpoint) 159 | model_prop.load_weights(model_prop_checkpoint) 160 | 161 | predictions_train = evaluation.get_predictions( 162 | dl=dl, 163 | model_0=model, 164 | model_1=None, 165 | model_prop=model_prop, 166 | mc_samples=mc_samples, 167 | test_set=False, 168 | ) 169 | 170 | predictions_test = evaluation.get_predictions( 171 | dl=dl, 172 | model_0=model, 173 | model_1=None, 174 | model_prop=model_prop, 175 | mc_samples=mc_samples, 176 | test_set=True, 177 | ) 178 | 179 | np.savez(os.path.join(output_dir, 
"predictions_train.npz"), **predictions_train) 180 | np.savez(os.path.join(output_dir, "predictions_test.npz"), **predictions_test) 181 | 182 | _, cate = dl.get_test_data(test_set=True) 183 | plotting.error_bars( 184 | data={ 185 | "predictions (95% CI)": [ 186 | (predictions_test["mu_1"] - predictions_test["mu_0"]).mean(0).ravel(), 187 | cate, 188 | 2 189 | * (predictions_test["mu_1"] - predictions_test["mu_0"]).std(0).ravel(), 190 | ] 191 | }, 192 | file_name=os.path.join(output_dir, "cate_scatter_test.png"), 193 | ) 194 | shutil.rmtree(checkpoint_dir) 195 | -------------------------------------------------------------------------------- /ucate/application/workflows/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ray 3 | import json 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from scipy import stats 9 | 10 | from ucate.library import data 11 | from ucate.library import evaluation 12 | from ucate.library.utils import plotting 13 | 14 | 15 | REJECT_PCTS = [ 16 | 0.00, 17 | 0.05, 18 | 0.10, 19 | 0.15, 20 | 0.2, 21 | 0.25, 22 | 0.30, 23 | 0.35, 24 | 0.40, 25 | 0.45, 26 | 0.50, 27 | 0.55, 28 | 0.60, 29 | 0.65, 30 | 0.70, 31 | 0.75, 32 | 0.80, 33 | 0.85, 34 | 0.90, 35 | 0.95, 36 | ] 37 | 38 | POLICIES = [ 39 | "Propensity quantile", 40 | "Propensity trimming", 41 | "Epistemic Uncertainty", 42 | "Random", 43 | ] 44 | 45 | 46 | @ray.remote 47 | def evaluate(output_dir): 48 | predictions_train = np.load( 49 | os.path.join(output_dir, "predictions_train.npz"), allow_pickle=True 50 | ) 51 | predictions_test = np.load( 52 | os.path.join(output_dir, "predictions_test.npz"), allow_pickle=True 53 | ) 54 | with open(os.path.join(output_dir, "config.json")) as json_file: 55 | config = json.load(json_file) 56 | trial = config["trial"] 57 | dataset_name = config.get("dataset_name") 58 | regression = dataset_name in ["acic", "ihdp"] 59 | exclude_population = config.get("exclude_population") 60 | # Evaluate on training set 61 | print(f"TRIAL {trial:04d} ") 62 | print("Training Evaluation") 63 | dl = data.DATASETS[dataset_name]( 64 | path=config.get("data_dir"), 65 | trial=trial, 66 | exclude_population=exclude_population, 67 | ) 68 | pehe_train, error_train, quantiles = evaluation.evaluate_2( 69 | dl=dl, 70 | predictions=predictions_train, 71 | regression=regression, 72 | test_set=False, 73 | output_dir=output_dir, 74 | reject_pcts=REJECT_PCTS, 75 | exclude_population=exclude_population, 76 | ) 77 | print("\n") 78 | print("Test Evaluation") 79 | pehe_test, error_test, _ = evaluation.evaluate_2( 80 | dl=dl, 81 | predictions=predictions_test, 82 | regression=regression, 83 | test_set=True, 84 | output_dir=output_dir, 85 | reject_pcts=REJECT_PCTS, 86 | quantiles=quantiles, 87 | exclude_population=exclude_population, 88 | ) 89 | 90 | result = { 91 | "trial": trial, 92 | "train": {"pehe": pehe_train, "error": error_train}, 93 | "test": {"pehe": pehe_test, "error": error_test}, 94 | } 95 | 96 | with open(os.path.join(output_dir, "result.json"), "w") as outfile: 97 | json.dump(result, outfile, indent=4, sort_keys=True) 98 | 99 | return result 100 | 101 | 102 | def build_summary(results, experiment_dir): 103 | summary = { 104 | "train": { 105 | "error": { 106 | "Propensity quantile": [], 107 | "Propensity trimming": [], 108 | "Epistemic Uncertainty": [], 109 | "Random": [], 110 | }, 111 | "pehe": { 112 | "Propensity quantile": [], 113 | "Propensity trimming": [], 114 | "Epistemic Uncertainty": [], 115 | "Random": [], 116 | }, 117 
| }, 118 | "test": { 119 | "error": { 120 | "Propensity quantile": [], 121 | "Propensity trimming": [], 122 | "Epistemic Uncertainty": [], 123 | "Random": [], 124 | }, 125 | "pehe": { 126 | "Propensity quantile": [], 127 | "Propensity trimming": [], 128 | "Epistemic Uncertainty": [], 129 | "Random": [], 130 | }, 131 | }, 132 | } 133 | for result in results: 134 | for split in ["train", "test"]: 135 | for metric in ["error", "pehe"]: 136 | for policy in POLICIES: 137 | summary[split][metric][policy].append(result[split][metric][policy]) 138 | 139 | with open(os.path.join(experiment_dir, "summary.json"), "w") as outfile: 140 | json.dump(summary, outfile, indent=4, sort_keys=True) 141 | return summary 142 | 143 | 144 | def summarize(summary, experiment_dir): 145 | for split in ["train", "test"]: 146 | for metric in ["error", "pehe"]: 147 | ys = {} 148 | df = pd.DataFrame(index=POLICIES, columns=REJECT_PCTS) 149 | for policy in POLICIES: 150 | arr = np.asarray(summary[split][metric][policy]) 151 | mean_val = np.nanmean(arr, 0) 152 | ste_val = stats.sem(arr, 0, nan_policy="omit") 153 | ys.update({policy: (mean_val, ste_val)}) 154 | row = [f"{m:.03f}+-{s:.03f}" for m, s in zip(mean_val, ste_val)] 155 | df.loc[policy] = row 156 | plotting.sweep( 157 | x=REJECT_PCTS, 158 | ys=ys, 159 | y_label="$\sqrt{\epsilon_{PEHE}}$" 160 | if metric == "pehe" 161 | else "Number of errors / N", 162 | file_name=os.path.join(experiment_dir, f"{split}_{metric}_sweep.png"), 163 | ) 164 | df.to_csv(path_or_buf=os.path.join(experiment_dir, f"{split}_{metric}.csv")) 165 | -------------------------------------------------------------------------------- /ucate/application/workflows/tarnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from ucate.library import data 8 | from ucate.library import models 9 | from ucate.library import evaluation 10 | from ucate.library.utils import plotting 11 | 12 | 13 | def train( 14 | job_dir, 15 | dataset_name, 16 | data_dir, 17 | trial, 18 | exclude_population, 19 | verbose, 20 | mode, 21 | base_filters, 22 | dropout_rate, 23 | beta, 24 | batch_size, 25 | epochs, 26 | learning_rate, 27 | mc_samples, 28 | ): 29 | gpus = tf.config.experimental.list_physical_devices("GPU") 30 | if gpus: 31 | try: 32 | for gpu in gpus: 33 | tf.config.experimental.set_memory_growth(gpu, True) 34 | logical_gpus = tf.config.experimental.list_logical_devices("GPU") 35 | except RuntimeError as e: 36 | print(e) 37 | print("TRIAL {:04d} ".format(trial)) 38 | experiment_name = f"md-{mode}_bf-{base_filters}_dr-{dropout_rate}_beta-{beta}_bs-{batch_size}_lr-{learning_rate}_ep-{exclude_population}" 39 | output_dir = os.path.join( 40 | job_dir, 41 | dataset_name, 42 | "tarnet", 43 | experiment_name, 44 | f"trial_{trial:03d}", 45 | ) 46 | os.makedirs(output_dir, exist_ok=True) 47 | checkpoint_dir = os.path.join(output_dir, "checkpoints") 48 | config = { 49 | "job_dir": job_dir, 50 | "dataset_name": dataset_name, 51 | "data_dir": data_dir, 52 | "exclude_population": exclude_population, 53 | "trial": trial, 54 | "mode": mode, 55 | "base_filters": base_filters, 56 | "dropout_rate": dropout_rate, 57 | "beta": beta, 58 | "batch_size": batch_size, 59 | "epochs": epochs, 60 | "learning_rate": learning_rate, 61 | "mc_samples": mc_samples, 62 | } 63 | config_file = os.path.join(output_dir, "config.json") 64 | with open(config_file, "w") as fp: 65 | json.dump(config, fp, indent=4, 
sort_keys=True) 66 | # Instantiate data loaders 67 | dl = data.DATASETS[dataset_name]( 68 | path=data_dir, trial=trial, exclude_population=exclude_population 69 | ) 70 | x_train, y_train, t_train, examples_per_treatment = dl.get_training_data() 71 | if dataset_name in ["acic", "ihdp"]: 72 | regression = True 73 | model_name = "mlp" 74 | loss = tf.keras.losses.MeanSquaredError() 75 | error = tf.keras.metrics.MeanAbsoluteError() 76 | else: 77 | regression = False 78 | model_name = "cnn" 79 | loss = tf.keras.losses.BinaryCrossentropy() 80 | error = tf.keras.metrics.BinaryAccuracy() 81 | # Instantiate models 82 | model = models.TARNet( 83 | do_convolution=model_name == "cnn", 84 | num_examples=examples_per_treatment, 85 | dim_hidden=base_filters, 86 | regression=regression, 87 | dropout_rate=dropout_rate, 88 | beta=beta, 89 | mode=mode, 90 | ) 91 | model.compile( 92 | optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), 93 | loss=[loss, loss], 94 | metrics=[error], 95 | loss_weights=[0.0, 0.0], 96 | ) 97 | model_checkpoint = os.path.join(checkpoint_dir, "model_0") 98 | model_prop = models.MODELS[model_name]( 99 | num_examples=sum(examples_per_treatment), 100 | dim_hidden=base_filters, 101 | dropout_rate=dropout_rate, 102 | regression=False, 103 | depth=2, 104 | ) 105 | model_prop.compile( 106 | optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), 107 | loss=[ 108 | tf.keras.losses.BinaryCrossentropy(), 109 | tf.keras.losses.BinaryCrossentropy(), 110 | ], 111 | metrics=[tf.metrics.BinaryAccuracy()], 112 | loss_weights=[0.0, 0.0], 113 | ) 114 | model_prop_checkpoint = os.path.join(checkpoint_dir, "model_prop") 115 | # Instantiate trainer 116 | _ = model.fit( 117 | [x_train, t_train, y_train], 118 | [y_train, np.zeros_like(y_train)], 119 | batch_size=batch_size, 120 | epochs=epochs, 121 | validation_split=0.3, 122 | shuffle=True, 123 | callbacks=[ 124 | tf.keras.callbacks.ModelCheckpoint( 125 | filepath=model_checkpoint, save_best_only=True, save_weights_only=True 126 | ), 127 | tf.keras.callbacks.EarlyStopping(patience=50), 128 | ], 129 | verbose=verbose, 130 | ) 131 | _ = model_prop.fit( 132 | [x_train, t_train[:, -1]], 133 | [t_train[:, -1], np.zeros_like(t_train[:, -1])], 134 | batch_size=batch_size, 135 | epochs=epochs, 136 | validation_split=0.3, 137 | shuffle=True, 138 | callbacks=[ 139 | tf.keras.callbacks.ModelCheckpoint( 140 | filepath=model_prop_checkpoint, 141 | save_best_only=True, 142 | save_weights_only=True, 143 | ), 144 | tf.keras.callbacks.EarlyStopping(patience=50), 145 | ], 146 | verbose=verbose, 147 | ) 148 | # Restore best models 149 | model.load_weights(model_checkpoint) 150 | model_prop.load_weights(model_prop_checkpoint) 151 | 152 | predictions_train = evaluation.get_predictions( 153 | dl=dl, 154 | model_0=model, 155 | model_1=None, 156 | model_prop=model_prop, 157 | mc_samples=mc_samples, 158 | test_set=False, 159 | ) 160 | 161 | predictions_test = evaluation.get_predictions( 162 | dl=dl, 163 | model_0=model, 164 | model_1=None, 165 | model_prop=model_prop, 166 | mc_samples=mc_samples, 167 | test_set=True, 168 | ) 169 | 170 | np.savez(os.path.join(output_dir, "predictions_train.npz"), **predictions_train) 171 | np.savez(os.path.join(output_dir, "predictions_test.npz"), **predictions_test) 172 | 173 | _, cate = dl.get_test_data(test_set=True) 174 | plotting.error_bars( 175 | data={ 176 | "predictions (95% CI)": [ 177 | (predictions_test["mu_1"] - predictions_test["mu_0"]).mean(0).ravel(), 178 | cate, 179 | 2 180 | * (predictions_test["mu_1"] - 
predictions_test["mu_0"]).std(0).ravel(), 181 | ] 182 | }, 183 | file_name=os.path.join(output_dir, "cate_scatter_test.png"), 184 | ) 185 | shutil.rmtree(checkpoint_dir) 186 | -------------------------------------------------------------------------------- /ucate/application/workflows/tlearner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from ucate.library import data 8 | from ucate.library import models 9 | from ucate.library import evaluation 10 | from ucate.library.utils import plotting 11 | 12 | 13 | def train( 14 | job_dir, 15 | dataset_name, 16 | data_dir, 17 | trial, 18 | exclude_population, 19 | verbose, 20 | base_filters, 21 | depth, 22 | dropout_rate, 23 | batch_size, 24 | epochs, 25 | learning_rate, 26 | mc_samples, 27 | ): 28 | gpus = tf.config.experimental.list_physical_devices("GPU") 29 | if gpus: 30 | try: 31 | for gpu in gpus: 32 | tf.config.experimental.set_memory_growth(gpu, True) 33 | except RuntimeError as e: 34 | print(e) 35 | print("TRIAL {:04d} ".format(trial)) 36 | experiment_name = f"bf-{base_filters}_dp-{depth}_dr-{dropout_rate}_bs-{batch_size}_lr-{learning_rate}_ep-{exclude_population}" 37 | output_dir = os.path.join( 38 | job_dir, 39 | dataset_name, 40 | "tlearner", 41 | experiment_name, 42 | f"trial_{trial:03d}", 43 | ) 44 | os.makedirs(output_dir, exist_ok=True) 45 | checkpoint_dir = os.path.join(output_dir, "checkpoints") 46 | config = { 47 | "job_dir": job_dir, 48 | "dataset_name": dataset_name, 49 | "data_dir": data_dir, 50 | "exclude_population": exclude_population, 51 | "trial": trial, 52 | "base_filters": base_filters, 53 | "depth": depth, 54 | "dropout_rate": dropout_rate, 55 | "batch_size": batch_size, 56 | "epochs": epochs, 57 | "learning_rate": learning_rate, 58 | "mc_samples": mc_samples, 59 | } 60 | config_file = os.path.join(output_dir, "config.json") 61 | with open(config_file, "w") as fp: 62 | json.dump(config, fp, indent=4, sort_keys=True) 63 | # Instantiate data loaders 64 | dl = data.DATASETS[dataset_name]( 65 | path=data_dir, trial=trial, exclude_population=exclude_population 66 | ) 67 | if dataset_name in ["acic", "ihdp"]: 68 | regression = True 69 | model_name = "mlp" 70 | loss = tf.keras.losses.MeanSquaredError() 71 | error = tf.keras.metrics.MeanAbsoluteError() 72 | else: 73 | regression = False 74 | model_name = "cnn" 75 | loss = tf.keras.losses.BinaryCrossentropy() 76 | error = tf.keras.metrics.BinaryAccuracy() 77 | x_train, y_train, t_train, examples_per_treatment = dl.get_training_data() 78 | idx_0_train = np.where(t_train[:, 0])[0] 79 | idx_1_train = np.where(t_train[:, 1])[0] 80 | # Instantiate models 81 | model_0 = models.MODELS[model_name]( 82 | num_examples=examples_per_treatment[0], 83 | dim_hidden=base_filters, 84 | dropout_rate=dropout_rate, 85 | regression=regression, 86 | depth=depth, 87 | ) 88 | model_0.compile( 89 | optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), 90 | loss=loss, 91 | metrics=[error], 92 | loss_weights=[0.0, 0.0], 93 | ) 94 | model_0_checkpoint = os.path.join(checkpoint_dir, "model_0") 95 | model_1 = models.MODELS[model_name]( 96 | num_examples=examples_per_treatment[1], 97 | dim_hidden=base_filters, 98 | dropout_rate=dropout_rate, 99 | regression=regression, 100 | depth=depth, 101 | ) 102 | model_1.compile( 103 | optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), 104 | loss=loss, 105 | metrics=[error], 106 | loss_weights=[0.0, 
0.0], 107 | ) 108 | model_1_checkpoint = os.path.join(checkpoint_dir, "model_1") 109 | model_prop = models.MODELS[model_name]( 110 | num_examples=sum(examples_per_treatment), 111 | dim_hidden=base_filters, 112 | dropout_rate=dropout_rate, 113 | regression=False, 114 | depth=2, 115 | ) 116 | model_prop.compile( 117 | optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), 118 | loss=tf.keras.losses.BinaryCrossentropy(), 119 | metrics=[tf.metrics.BinaryAccuracy()], 120 | loss_weights=[0.0, 0.0], 121 | ) 122 | model_prop_checkpoint = os.path.join(checkpoint_dir, "model_prop") 123 | # Instantiate trainer 124 | _ = model_0.fit( 125 | [x_train[idx_0_train], y_train[idx_0_train]], 126 | [y_train[idx_0_train], np.zeros_like(y_train[idx_0_train])], 127 | batch_size=batch_size, 128 | epochs=epochs, 129 | validation_split=0.3, 130 | shuffle=True, 131 | callbacks=[ 132 | tf.keras.callbacks.ModelCheckpoint( 133 | filepath=model_0_checkpoint, save_best_only=True, save_weights_only=True 134 | ), 135 | tf.keras.callbacks.EarlyStopping(patience=50), 136 | ], 137 | verbose=verbose, 138 | ) 139 | _ = model_1.fit( 140 | [x_train[idx_1_train], y_train[idx_1_train]], 141 | [y_train[idx_1_train], np.zeros_like(y_train[idx_1_train])], 142 | batch_size=batch_size, 143 | epochs=epochs, 144 | validation_split=0.3, 145 | shuffle=True, 146 | callbacks=[ 147 | tf.keras.callbacks.ModelCheckpoint( 148 | filepath=model_1_checkpoint, save_best_only=True, save_weights_only=True 149 | ), 150 | tf.keras.callbacks.EarlyStopping(patience=50), 151 | ], 152 | verbose=verbose, 153 | ) 154 | _ = model_prop.fit( 155 | [x_train, t_train[:, -1]], 156 | [t_train[:, -1], np.zeros_like(t_train[:, -1])], 157 | batch_size=batch_size, 158 | epochs=epochs, 159 | validation_split=0.3, 160 | shuffle=True, 161 | callbacks=[ 162 | tf.keras.callbacks.ModelCheckpoint( 163 | filepath=model_prop_checkpoint, 164 | save_best_only=True, 165 | save_weights_only=True, 166 | ), 167 | tf.keras.callbacks.EarlyStopping(patience=50), 168 | ], 169 | verbose=verbose, 170 | ) 171 | # Restore best models 172 | model_0.load_weights(model_0_checkpoint) 173 | model_1.load_weights(model_1_checkpoint) 174 | model_prop.load_weights(model_prop_checkpoint) 175 | 176 | predictions_train = evaluation.get_predictions( 177 | dl=dl, 178 | model_0=model_0, 179 | model_1=model_1, 180 | model_prop=model_prop, 181 | mc_samples=mc_samples, 182 | test_set=False, 183 | ) 184 | 185 | predictions_test = evaluation.get_predictions( 186 | dl=dl, 187 | model_0=model_0, 188 | model_1=model_1, 189 | model_prop=model_prop, 190 | mc_samples=mc_samples, 191 | test_set=True, 192 | ) 193 | 194 | np.savez(os.path.join(output_dir, "predictions_train.npz"), **predictions_train) 195 | np.savez(os.path.join(output_dir, "predictions_test.npz"), **predictions_test) 196 | 197 | _, cate = dl.get_test_data(test_set=True) 198 | plotting.error_bars( 199 | data={ 200 | "predictions (95% CI)": [ 201 | (predictions_test["mu_1"] - predictions_test["mu_0"]).mean(0).ravel(), 202 | cate, 203 | 2 204 | * (predictions_test["mu_1"] - predictions_test["mu_0"]).std(0).ravel(), 205 | ] 206 | }, 207 | file_name=os.path.join(output_dir, "cate_scatter_test.png"), 208 | ) 209 | shutil.rmtree(checkpoint_dir) 210 | -------------------------------------------------------------------------------- /ucate/library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/ucate/cddc23596a463e2ce9e270cf075b3e924137bf7a/ucate/library/__init__.py 
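Editorial note: the tlearner and tarnet workflows above are normally dispatched from the CLI defined in ucate/application/main.py. As a quick orientation, the sketch below shows how the tlearner train() function listed above could be called directly; the hyperparameter values are illustrative assumptions, not the repository's defaults.

# Illustrative usage sketch (assumed values), matching the train() signature above
from ucate.application.workflows import tlearner

tlearner.train(
    job_dir="output",        # results land under output/ihdp/tlearner/<experiment>/trial_000
    dataset_name="ihdp",     # any key registered in ucate.library.data.DATASETS
    data_dir="data/IHDP",    # must contain ihdp_npci_1-1000.train.npz / .test.npz
    trial=0,                 # which of the 1000 IHDP realizations to use
    exclude_population=False,
    verbose=1,
    base_filters=200,        # hidden width of each BayesianNeuralNetwork
    depth=5,
    dropout_rate=0.5,
    batch_size=100,
    epochs=500,
    learning_rate=1e-3,
    mc_samples=100,          # MC dropout samples drawn at prediction time
)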
-------------------------------------------------------------------------------- /ucate/library/data/__init__.py: -------------------------------------------------------------------------------- 1 | from ucate.library.data.ihdp import IHDP 2 | 3 | from ucate.library.data.acic import ACIC 4 | 5 | from ucate.library.data.cemnist import CEMNIST 6 | 7 | DATASETS = { 8 | "acic": ACIC, 9 | "ihdp": IHDP, 10 | "cemnist": CEMNIST, 11 | } 12 | -------------------------------------------------------------------------------- /ucate/library/data/acic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from tensorflow.keras.utils import to_categorical 5 | from sklearn import model_selection 6 | 7 | 8 | class ACIC(object): 9 | def __init__( 10 | self, 11 | path, 12 | trial, 13 | center=False, 14 | exclude_population=False, 15 | ): 16 | self.trial = trial 17 | x_df, targets_df = load_data( 18 | path=path, 19 | trial=trial, 20 | subset='1' 21 | ) 22 | train_dataset, test_dataset = train_test_split( 23 | x_df=x_df, 24 | targets_df=targets_df, 25 | trial=trial, 26 | test_size=0.3 27 | ) 28 | self.train_data = get_trial( 29 | dataset=train_dataset 30 | ) 31 | self.x_mean = self.train_data['x_cont'].mean(0, keepdims=True) 32 | self.x_std = self.train_data['x_cont'].std(0, keepdims=True) + 1e-7 33 | self.y_mean = self.train_data['y'].mean(dtype='float32') 34 | self.y_std = self.train_data['y'].std(dtype='float32') + 1e-7 35 | self.test_data = get_trial( 36 | dataset=test_dataset 37 | ) 38 | self.dim_x_cont = self.train_data['x_cont'].shape[-1] 39 | self.dim_x_bin = self.train_data['x_bin'].shape[-1] 40 | self.dim_x = self.dim_x_cont + self.dim_x_bin 41 | 42 | def get_training_data(self): 43 | x, y, t = self.preprocess(self.train_data) 44 | examples_per_treatment = t.sum(0) 45 | return x, y, t, examples_per_treatment 46 | 47 | def get_test_data(self, test_set=True): 48 | _data = self.test_data if test_set else self.train_data 49 | x, _, _ = self.preprocess(_data) 50 | mu1 = _data['mu1'].astype('float32') 51 | mu0 = _data['mu0'].astype('float32') 52 | cate = mu1 - mu0 53 | return x, cate 54 | 55 | def get_subpop(self, test_set=True): 56 | _data = self.test_data if test_set else self.train_data 57 | return _data['ind_subpop'] 58 | 59 | def get_t(self, test_set=True): 60 | _data = self.test_data if test_set else self.train_data 61 | return _data['t'] 62 | 63 | def preprocess(self, dataset): 64 | x_cont = (dataset['x_cont'] - self.x_mean) / self.x_std 65 | x_bin = dataset['x_bin'] 66 | x = np.hstack([x_cont, x_bin]) 67 | y = (dataset['y'].astype('float32') - self.y_mean) / self.y_std 68 | t = dataset['t'].astype('float32') 69 | return x, y, t 70 | 71 | 72 | def load_data( 73 | path, 74 | trial, 75 | subset='1' 76 | ): 77 | x_path = os.path.join(path, 'x.csv') 78 | targets_dir = os.path.join(path, str(trial + 1)) 79 | targets_paths = os.listdir(targets_dir) 80 | targets_paths.sort() 81 | x_df = pd.read_csv( 82 | x_path 83 | ) 84 | x_df['x_2'] = [ord(x) - 65 for x in x_df['x_2']] 85 | x_df['x_21'] = [ord(x) - 65 for x in x_df['x_21']] 86 | x_df['x_24'] = [ord(x) - 65 for x in x_df['x_24']] 87 | targets_df = pd.read_csv( 88 | os.path.join( 89 | targets_dir, 90 | targets_paths[0] 91 | ) 92 | ) 93 | return x_df, targets_df 94 | 95 | 96 | def train_test_split( 97 | x_df, 98 | targets_df, 99 | trial, 100 | test_size=0.3, 101 | ): 102 | x_df_train, x_df_test, targets_df_train, targets_df_test = model_selection.train_test_split( 
103 | x_df, 104 | targets_df, 105 | test_size=test_size, 106 | random_state=trial, 107 | shuffle=True 108 | ) 109 | train_data = { 110 | 'x': x_df_train, 111 | 'targets': targets_df_train 112 | } 113 | test_data = { 114 | 'x': x_df_test, 115 | 'targets': targets_df_test 116 | } 117 | return train_data, test_data 118 | 119 | 120 | def get_trial( 121 | dataset 122 | ): 123 | cat_feats = {'x_2': 6, 'x_21': 16, 'x_24': 5} 124 | bin_feats = ['x_17', 'x_22', 'x_38', 'x_51', 'x_54'] 125 | cont_feats = [] 126 | for i in range(1, 59): 127 | feat_id = 'x_{}'.format(i) 128 | if (feat_id not in bin_feats) and (feat_id not in cat_feats.keys()): 129 | cont_feats.append(feat_id) 130 | x_df = dataset['x'] 131 | x_bin = x_df[bin_feats].to_numpy('float32') 132 | for k, v in cat_feats.items(): 133 | f = dataset['x'][k].to_numpy() 134 | f = to_categorical( 135 | f, 136 | num_classes=v, 137 | dtype='float32' 138 | ) 139 | x_bin = np.hstack([x_bin, f]) 140 | x_cont = x_df[cont_feats].to_numpy('float32') 141 | targets_df = dataset['targets'] 142 | t = targets_df['z'].to_numpy() 143 | y0 = targets_df['y0'].to_numpy() 144 | y1 = targets_df['y1'].to_numpy() 145 | y = np.zeros_like(t, 'float32') 146 | y[t > 0.5] = y1[t > 0.5] 147 | y[t < 0.5] = y0[t < 0.5] 148 | t_in = np.zeros((len(t), 2), 'float32') 149 | t_in[:, 0] = 1 - t 150 | t_in[:, 1] = t 151 | mu0 = targets_df['mu0'].to_numpy() 152 | mu1 = targets_df['mu1'].to_numpy() 153 | trial_data = { 154 | 'x_cont': x_cont, 155 | 'x_bin': x_bin, 156 | 'y': y.astype('float32'), 157 | 't': t_in.astype('float32'), 158 | 'mu0': mu0.astype('float32'), 159 | 'mu1': mu1.astype('float32') 160 | } 161 | return trial_data 162 | -------------------------------------------------------------------------------- /ucate/library/data/cemnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras.datasets import mnist 3 | 4 | 5 | class CEMNIST(object): 6 | def __init__( 7 | self, 8 | path, 9 | trial, 10 | center=False, 11 | exclude_population=False, 12 | ): 13 | self.trial = trial 14 | train_dataset, test_dataset = mnist.load_data() 15 | self.train_data = get_trial_new( 16 | dataset=train_dataset, 17 | trial=trial, 18 | center=center 19 | ) 20 | self.y_mean = 0.0 21 | self.y_std = 1.0 22 | self.test_data = get_trial_new( 23 | dataset=test_dataset, 24 | trial=trial, 25 | center=center 26 | ) 27 | self.dim_x = (28, 28, 1) 28 | 29 | def get_training_data(self): 30 | x, y, t = self.preprocess(self.train_data) 31 | examples_per_treatment = t.sum(0) 32 | return x, y, t, examples_per_treatment 33 | 34 | def get_test_data(self, test_set=True): 35 | _data = self.test_data if test_set else self.train_data 36 | x, _, _ = self.preprocess(_data) 37 | mu1 = _data['mu1'].astype('float32') 38 | mu0 = _data['mu0'].astype('float32') 39 | cate = mu1 - mu0 40 | return x, cate 41 | 42 | def get_t(self, test_set=True): 43 | _data = self.test_data if test_set else self.train_data 44 | return _data['t'] 45 | 46 | def preprocess(self, dataset): 47 | x = dataset['x'].astype('float32') 48 | y = (dataset['y'].astype('float32') - self.y_mean) / self.y_std 49 | t = dataset['t'].astype('float32') 50 | return x, y, t 51 | 52 | def get_pops(self, test_set=True): 53 | _data = self.test_data if test_set else self.train_data 54 | return { 55 | '9': _data['ind_9'], 56 | '2': _data['ind_2'], 57 | 'other': _data['ind_other'] 58 | } 59 | 60 | 61 | def get_trial( 62 | dataset, 63 | center=False, 64 | induce_spurrious_correlation=False 65 | ): 66 | x 
= np.expand_dims(dataset[0], -1) / 255. 67 | if center: 68 | x = 2. * x - 1 69 | y = dataset[1] 70 | p = 0.5 71 | t = np.random.choice([0, 1], size=y.shape, p=[1 - p, p]) 72 | ind = np.in1d(y, [2, 6, 9]) 73 | ind_9 = np.in1d(y, 9) 74 | ind_6 = np.in1d(y, 6) 75 | ind_2 = np.in1d(y, 2) 76 | t[ind_9] = 0 77 | t[ind_2] = 1 78 | t_in = np.zeros((len(t), 2), 'float32') 79 | t_in[:, 0] = 1 - t 80 | t_in[:, 1] = t 81 | if induce_spurrious_correlation: 82 | x[np.in1d(t, 1), :2, :2] = 1. 83 | y = (ind_9 * (1 - t) + ind_6 * t).astype('float32') 84 | ycf = (ind_9 * t + ind_6 * (1 - t)).astype('float32') 85 | mu0 = y * (1 - t) + ycf * t 86 | mu1 = y * t + ycf * (1 - t) 87 | x = x.astype('float32') 88 | trial_data = { 89 | 'x': x[ind], 90 | 'y': y[ind], 91 | 't': t_in[ind], 92 | 'ycf': ycf[ind], 93 | 'mu0': mu0[ind], 94 | 'mu1': mu1[ind], 95 | 'yadd': 0., 96 | 'ymul': 1. 97 | } 98 | return trial_data 99 | 100 | 101 | def get_trial_new( 102 | dataset, 103 | trial, 104 | center=False 105 | ): 106 | rng = np.random.RandomState(trial) 107 | x = np.expand_dims(dataset[0], -1) / 255. 108 | if center: 109 | x = 2. * x - 1 110 | y = dataset[1] 111 | p_t = 0.5 112 | p_t_9 = 1 / 9 113 | t = rng.choice( 114 | [0, 1], 115 | size=y.shape, p=[1 - p_t, p_t], 116 | 117 | ) 118 | ind_9 = np.in1d(y, 9) 119 | ind_2 = np.in1d(y, 2) 120 | ind_even = np.in1d(y, [0, 4, 6, 8]) 121 | ind_odd = np.in1d(y, [1, 3, 5, 7]) 122 | num_examples = 2 * ind_9.sum() 123 | t[ind_9] = rng.choice([0, 1], size=ind_9.sum(), p=[1 - p_t_9, p_t_9]) 124 | t[ind_2] = 1 125 | t_in = np.zeros((len(t), 2), 'float32') 126 | t_in[:, 0] = 1 - t 127 | t_in[:, 1] = t 128 | p_x = (0.5 / 9) * np.ones_like(y) 129 | p_x[ind_9] = 0.5 130 | p_x = p_x / p_x.sum() 131 | y = np.zeros_like(y, 'float32') 132 | y[ind_9] = 1 - t[ind_9] 133 | y[ind_2] = t[ind_2] 134 | y[ind_even] = t[ind_even] 135 | y[ind_odd] = 1 - t[ind_odd] 136 | ycf = np.zeros_like(y, 'float32') 137 | ycf[ind_9] = t[ind_9] 138 | ycf[ind_2] = 1 - t[ind_2] 139 | ycf[ind_even] = 1 - t[ind_even] 140 | ycf[ind_odd] = t[ind_odd] 141 | mu0 = y * (1 - t) + ycf * t 142 | mu1 = y * t + ycf * (1 - t) 143 | x = x.astype('float32') 144 | ind = rng.choice(np.arange(len(y)), size=num_examples, p=p_x) 145 | trial_data = { 146 | 'x': x[ind], 147 | 'y': y[ind], 148 | 't': t_in[ind], 149 | 'ycf': ycf[ind], 150 | 'mu0': mu0[ind], 151 | 'mu1': mu1[ind], 152 | 'yadd': 0., 153 | 'ymul': 1., 154 | 'ind_2': ind_2[ind], 155 | 'ind_9': ind_9[ind], 156 | 'ind_other': (ind_even + ind_odd)[ind] 157 | } 158 | return trial_data 159 | -------------------------------------------------------------------------------- /ucate/library/data/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.model_selection import train_test_split 3 | 4 | 5 | def make_train_tune_datasets( 6 | x, 7 | y, 8 | t, 9 | validation_pct=0.3, 10 | random_state=1331 11 | ): 12 | x = x.astype('float32') 13 | y = y.astype('float32') 14 | t = t.astype('float32') 15 | 16 | x_train, x_tune, y_train, y_tune, t_train, t_tune = train_test_split( 17 | x, y, t, 18 | test_size=validation_pct, 19 | random_state=random_state 20 | ) 21 | examples_per_treatment = np.sum(t_train, 0) 22 | return (x_train, y_train, t_train), (x_tune, y_tune, t_tune), examples_per_treatment 23 | -------------------------------------------------------------------------------- /ucate/library/data/ihdp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 
| 5 | class IHDP(object): 6 | def __init__( 7 | self, 8 | path, 9 | trial, 10 | center=False, 11 | exclude_population=False, 12 | ): 13 | self.trial = trial 14 | train_dataset = np.load( 15 | os.path.join(path, 'ihdp_npci_1-1000.train.npz') 16 | ) 17 | test_dataset = np.load( 18 | os.path.join(path, 'ihdp_npci_1-1000.test.npz') 19 | ) 20 | self.train_data = get_trial( 21 | dataset=train_dataset, 22 | trial=trial, 23 | training=True, 24 | exclude_population=exclude_population 25 | ) 26 | self.y_mean = self.train_data['y'].mean(dtype='float32') 27 | self.y_std = self.train_data['y'].std(dtype='float32') 28 | self.test_data = get_trial( 29 | dataset=test_dataset, 30 | trial=trial, 31 | training=False, 32 | exclude_population=exclude_population 33 | ) 34 | self.dim_x_cont = self.train_data['x_cont'].shape[-1] 35 | self.dim_x_bin = self.train_data['x_bin'].shape[-1] 36 | self.dim_x = self.dim_x_cont + self.dim_x_bin 37 | 38 | def get_training_data(self): 39 | x, y, t = self.preprocess(self.train_data) 40 | examples_per_treatment = t.sum(0) 41 | return x, y, t, examples_per_treatment 42 | 43 | def get_test_data(self, test_set=True): 44 | _data = self.test_data if test_set else self.train_data 45 | x, _, _ = self.preprocess(_data) 46 | mu1 = _data['mu1'].astype('float32') 47 | mu0 = _data['mu0'].astype('float32') 48 | cate = mu1 - mu0 49 | return x, cate 50 | 51 | def get_subpop(self, test_set=True): 52 | _data = self.test_data if test_set else self.train_data 53 | return _data['ind_subpop'] 54 | 55 | def get_t(self, test_set=True): 56 | _data = self.test_data if test_set else self.train_data 57 | return _data['t'] 58 | 59 | def preprocess(self, dataset): 60 | x = np.hstack([dataset['x_cont'], dataset['x_bin']]) 61 | y = (dataset['y'].astype('float32') - self.y_mean) / self.y_std 62 | t = dataset['t'].astype('float32') 63 | return x, y, t 64 | 65 | 66 | def get_trial( 67 | dataset, 68 | trial, 69 | training=True, 70 | exclude_population=False 71 | ): 72 | bin_feats = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] 73 | cont_feats = [i for i in range(25) if i not in bin_feats] 74 | ind_subpop = dataset['x'][:, bin_feats[2], trial].astype('bool') 75 | x = dataset['x'][:, :, trial] 76 | if exclude_population: 77 | x = np.delete(x, bin_feats[2], axis=-1) 78 | bin_feats.pop(2) 79 | if training: 80 | idx_included = np.where(ind_subpop)[0] 81 | else: 82 | idx_included = np.arange(dataset['x'].shape[0], dtype='int32') 83 | else: 84 | idx_included = np.arange(dataset['x'].shape[0], dtype='int32') 85 | x_bin = dataset['x'][:, bin_feats, trial][idx_included] 86 | x_bin[:, 7] -= 1. 
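# editor's note (not in the original source): the statement above shifts the one IHDP
# covariate that is stored as {1, 2} in the raw arrays down to {0, 1}, so that every
# column of x_bin is a proper binary indicator; this reading is inferred from the data
# and from other IHDP loaders, it is not documented in this repository.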
87 | t = dataset['t'][:, trial] 88 | t_in = np.zeros((len(t), 2), 'float32') 89 | t_in[:, 0] = 1 - t 90 | t_in[:, 1] = t 91 | trial_data = { 92 | 'x_bin': x_bin.astype('float32'), 93 | 'x_cont': dataset['x'][:, cont_feats, trial][idx_included].astype('float32'), 94 | 'y': dataset['yf'][:, trial][idx_included], 95 | 't': t_in[idx_included], 96 | 'ycf': dataset['ycf'][:, trial][idx_included], 97 | 'mu0': dataset['mu0'][:, trial][idx_included], 98 | 'mu1': dataset['mu1'][:, trial][idx_included], 99 | 'ate': dataset['ate'], 100 | 'yadd': dataset['yadd'], 101 | 'ymul': dataset['ymul'], 102 | 'ind_subpop': ind_subpop[idx_included] 103 | } 104 | return trial_data 105 | -------------------------------------------------------------------------------- /ucate/library/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from ucate.library.utils import plotting 4 | from ucate.library.utils import prediction 5 | from ucate.library import scratch 6 | 7 | 8 | def get_predictions( 9 | dl, 10 | model_0, 11 | model_1, 12 | model_prop, 13 | mc_samples, 14 | test_set, 15 | ): 16 | x, _ = dl.get_test_data(test_set=test_set) 17 | if model_1 is not None: 18 | (mu_0, y_0), (mu_1, y_1) = prediction.mc_sample_tl( 19 | x=x, 20 | model_0=model_0.mc_sample, 21 | model_1=model_1.mc_sample, 22 | mean=dl.y_mean, 23 | std=dl.y_std, 24 | mc_samples=mc_samples, 25 | ) 26 | else: 27 | t_0 = np.concatenate( 28 | [ 29 | np.ones((x.shape[0], 1), dtype="float32"), 30 | np.zeros((x.shape[0], 1), dtype="float32"), 31 | ], 32 | -1, 33 | ) 34 | t_1 = t_0[:, [1, 0]] 35 | (mu_0, y_0), (mu_1, y_1) = prediction.mc_sample_cevae( 36 | x=x, 37 | t=[t_0, t_1], 38 | model=model_0.mc_sample, 39 | mean=dl.y_mean, 40 | std=dl.y_std, 41 | mc_samples=mc_samples, 42 | ) 43 | p_t, _ = prediction.mc_sample_2(x=x, model=model_prop.mc_sample) 44 | p_t = p_t.mean(0) 45 | return {"mu_0": mu_0, "mu_1": mu_1, "y_0": y_0, "y_1": y_1, "p_t": p_t} 46 | 47 | 48 | def evaluate( 49 | dl, 50 | model_0, 51 | model_1, 52 | model_prop, 53 | mc_samples, 54 | regression, 55 | test_set, 56 | output_dir, 57 | reject_pcts, 58 | quantiles=None, 59 | exclude_population=False, 60 | ): 61 | predictions = get_predictions( 62 | dl=dl, 63 | model_0=model_0, 64 | model_1=model_1, 65 | model_prop=model_prop, 66 | mc_samples=mc_samples, 67 | test_set=test_set, 68 | ) 69 | pehe_stats, error_stats, quantiles = evaluate_2( 70 | dl=dl, 71 | predictions=predictions, 72 | regression=regression, 73 | test_set=test_set, 74 | output_dir=output_dir, 75 | reject_pcts=reject_pcts, 76 | quantiles=quantiles, 77 | exclude_population=exclude_population, 78 | ) 79 | return pehe_stats, error_stats, quantiles 80 | 81 | 82 | def evaluate_2( 83 | dl, 84 | predictions, 85 | regression, 86 | test_set, 87 | output_dir, 88 | reject_pcts, 89 | quantiles=None, 90 | exclude_population=False, 91 | ): 92 | tag = "test" if test_set else "train" 93 | _, cate = dl.get_test_data(test_set=test_set) 94 | cate_pred, predictive_uncrt, epistemic_unct = prediction.cate_measures( 95 | mu_0=predictions["mu_0"], 96 | mu_1=predictions["mu_1"], 97 | y_0=predictions["y_0"], 98 | y_1=predictions["y_1"], 99 | regression=regression, 100 | ) 101 | recommendation_pred = cate_pred.ravel() > 0.0 102 | recommendation_true = cate > 0.0 103 | errors = recommendation_pred != recommendation_true 104 | pehe_prop = [] 105 | errors_prop = [] 106 | pehe_prop_mag = [] 107 | errors_prop_mag = [] 108 | pehe_unct = [] 109 | errors_unct = [] 110 | pehe_unct_altr = 
[] 111 | errors_unct_altr = [] 112 | pehe_random = [] 113 | errors_random = [] 114 | num_examples = len(cate) 115 | p_t = predictions["p_t"] 116 | p_t_0 = p_t[dl.get_t(test_set)[:, 0] > 0.5] 117 | p_t_1 = p_t[dl.get_t(test_set)[:, 1] > 0.5] 118 | digit_counts = { 119 | "Propensity trimming": { 120 | "9": np.asarray([0] * len(reject_pcts)), 121 | "2": np.asarray([0] * len(reject_pcts)), 122 | "other": np.asarray([0] * len(reject_pcts)), 123 | }, 124 | "Epistemic Uncertainty": { 125 | "9": np.asarray([0] * len(reject_pcts)), 126 | "2": np.asarray([0] * len(reject_pcts)), 127 | "other": np.asarray([0] * len(reject_pcts)), 128 | }, 129 | } 130 | if not test_set: 131 | kde_0 = scratch.get_density_estimator(x=100.0 * p_t_0, bandwidth=2.0) 132 | kde_1 = scratch.get_density_estimator(x=100.0 * p_t_1, bandwidth=2.0) 133 | else: 134 | kde_0, kde_1 = None, None 135 | for i, pct in enumerate(reject_pcts): 136 | if not test_set: 137 | if quantiles is None: 138 | quantiles = {} 139 | quantiles["Propensity quantile"] = [ 140 | np.quantile(p_t, [pct / 2, 1.0 - (pct / 2)]) 141 | ] 142 | quantiles["kde"] = {"0": kde_0, "1": kde_1} 143 | overlap_score = get_overlap_score( 144 | kde_0=quantiles["kde"]["0"], 145 | kde_1=quantiles["kde"]["1"], 146 | p_t=p_t, 147 | num_0=len(p_t_0), 148 | num_1=len(p_t_1), 149 | ) 150 | quantiles["Propensity trimming"] = [np.quantile(overlap_score, pct)] 151 | quantiles["Epistemic Uncertainty"] = [ 152 | np.quantile(epistemic_unct, 1.0 - pct) 153 | ] 154 | else: 155 | quantiles["Propensity quantile"].append( 156 | np.quantile(p_t, [pct / 2, 1.0 - (pct / 2)]) 157 | ) 158 | overlap_score = get_overlap_score( 159 | kde_0=quantiles["kde"]["0"], 160 | kde_1=quantiles["kde"]["1"], 161 | p_t=p_t, 162 | num_0=len(p_t_0), 163 | num_1=len(p_t_1), 164 | ) 165 | quantiles["Propensity trimming"].append(np.quantile(overlap_score, pct)) 166 | quantiles["Epistemic Uncertainty"].append( 167 | np.quantile(epistemic_unct, 1.0 - pct) 168 | ) 169 | overlap_score = get_overlap_score( 170 | kde_0=quantiles["kde"]["0"], 171 | kde_1=quantiles["kde"]["1"], 172 | p_t=p_t, 173 | num_0=len(p_t_0), 174 | num_1=len(p_t_1), 175 | ) 176 | ind_prop = np.ravel( 177 | (p_t >= quantiles["Propensity quantile"][i][0]) 178 | * (p_t <= quantiles["Propensity quantile"][i][1]) 179 | ) 180 | ind_prop_mag = (overlap_score >= quantiles["Propensity trimming"][i]).ravel() 181 | ind_unct = (epistemic_unct <= quantiles["Epistemic Uncertainty"][i]).ravel() 182 | ind_random = np.random.choice([False, True], ind_unct.shape, p=[pct, 1.0 - pct]) 183 | pehe_prop.append( 184 | np.sqrt(np.square(cate[ind_prop] - cate_pred[ind_prop]).mean().ravel()) 185 | ) 186 | errors_prop.append(np.sum(errors[ind_prop]).ravel() / num_examples) 187 | pehe_prop_mag.append( 188 | np.sqrt( 189 | np.square(cate[ind_prop_mag] - cate_pred[ind_prop_mag]).mean().ravel() 190 | ) 191 | ) 192 | errors_prop_mag.append(np.sum(errors[ind_prop_mag]).ravel() / num_examples) 193 | pehe_unct.append( 194 | np.sqrt(np.square(cate[ind_unct] - cate_pred[ind_unct]).mean().ravel()) 195 | ) 196 | errors_unct.append(np.sum(errors[ind_unct]).ravel() / num_examples) 197 | pehe_random.append( 198 | np.sqrt(np.square(cate[ind_random] - cate_pred[ind_random]).mean().ravel()) 199 | ) 200 | errors_random.append(np.sum(errors[ind_random]).ravel() / num_examples) 201 | if not regression: 202 | digit_indices = dl.get_pops(test_set=test_set) 203 | for k, v in digit_indices.items(): 204 | digit_counts["Propensity trimming"][k][i] += ind_prop_mag[v].sum() 205 | digit_counts["Epistemic 
Uncertainty"][k][i] += ind_unct[v].sum() 206 | 207 | pehe_prop = np.asarray(pehe_prop).ravel() 208 | errors_prop = np.asarray(errors_prop).ravel() 209 | pehe_prop_mag = np.asarray(pehe_prop_mag).ravel() 210 | errors_prop_mag = np.asarray(errors_prop_mag).ravel() 211 | pehe_unct = np.asarray(pehe_unct).ravel() 212 | errors_unct = np.asarray(errors_unct).ravel() 213 | pehe_random = np.asarray(pehe_random).ravel() 214 | errors_random = np.asarray(errors_random).ravel() 215 | if regression: 216 | data = ( 217 | { 218 | "married mother (95% CI)": [ 219 | cate[dl.get_subpop(test_set)], 220 | cate_pred[dl.get_subpop(test_set)], 221 | epistemic_unct[dl.get_subpop(test_set)], 222 | ], 223 | "unmarried mother (95% CI)": [ 224 | cate[np.invert(dl.get_subpop(test_set))], 225 | cate_pred[np.invert(dl.get_subpop(test_set))], 226 | epistemic_unct[np.invert(dl.get_subpop(test_set))], 227 | ], 228 | } 229 | if exclude_population and test_set 230 | else {"predictions (95% CI)": [cate_pred, cate, epistemic_unct]} 231 | ) 232 | plotting.error_bars( 233 | data=data, 234 | file_name=os.path.join(output_dir, "cate_scatter_{}.png".format(tag)), 235 | ) 236 | x = ( 237 | { 238 | "married mother {}".format(tag): predictive_uncrt[dl.get_subpop(test_set)], 239 | "unmarried mother {}".format(tag): predictive_uncrt[ 240 | np.invert(dl.get_subpop(test_set)) 241 | ], 242 | } 243 | if exclude_population and test_set 244 | else {tag: predictive_uncrt} 245 | ) 246 | plotting.histogram( 247 | x=x, 248 | bins=64, 249 | alpha=0.5, 250 | x_label="$\widehat{Var}[Y_1(\mathbf{x}_i) - Y_0(\mathbf{x}_i)]$", 251 | y_label="Number of individuals", 252 | x_limit=(None, None), 253 | file_name=os.path.join(output_dir, "cate_variance_{}.png".format(tag)), 254 | ) 255 | x = ( 256 | { 257 | "married mother {}".format(tag): epistemic_unct[dl.get_subpop(test_set)], 258 | "unmarried mother {}".format(tag): epistemic_unct[ 259 | np.invert(dl.get_subpop(test_set)) 260 | ], 261 | } 262 | if exclude_population and test_set 263 | else {tag: epistemic_unct} 264 | ) 265 | plotting.histogram( 266 | x=x, 267 | bins=64, 268 | alpha=0.5, 269 | x_label="$\widehat{I}_{tot}[\mu_1(\mathbf{x}_i), \mu_0(\mathbf{x}_i)]$", 270 | y_label="Number of individuals", 271 | x_limit=(None, None), 272 | file_name=os.path.join(output_dir, "cate_mi_{}.png".format(tag)), 273 | ) 274 | x = { 275 | "$t=0$ {}".format(tag): p_t[dl.get_t(test_set)[:, 0] > 0.5], 276 | "$t=1$ {}".format(tag): p_t[dl.get_t(test_set)[:, 1] > 0.5], 277 | } 278 | plotting.histogram( 279 | x=x, 280 | bins=128, 281 | alpha=0.5, 282 | x_label="$p(t=1 | \mathbf{x})$", 283 | y_label="Number of individuals", 284 | x_limit=(0.0, 1.0), 285 | file_name=os.path.join(output_dir, "cate_propensity_{}.png".format(tag)), 286 | ) 287 | pehe_stats = { 288 | "Propensity quantile": [float(v) for v in pehe_prop], 289 | "Propensity trimming": [float(v) for v in pehe_prop_mag], 290 | "Epistemic Uncertainty": [float(v) for v in pehe_unct], 291 | "Random": [float(v) for v in pehe_random], 292 | } 293 | error_stats = { 294 | "Propensity quantile": [float(v) for v in errors_prop], 295 | "Propensity trimming": [float(v) for v in errors_prop_mag], 296 | "Epistemic Uncertainty": [float(v) for v in errors_unct], 297 | "Random": [float(v) for v in errors_random], 298 | } 299 | quantiles["counts"] = digit_counts 300 | return pehe_stats, error_stats, quantiles 301 | 302 | 303 | def get_overlap_score(kde_0, kde_1, p_t, num_0, num_1): 304 | s_0 = num_0 * np.exp(kde_0.score_samples(100.0 * p_t).ravel()) / len(p_t) 305 | s_1 = num_1 * 
np.exp(kde_1.score_samples(100.0 * p_t).ravel()) / len(p_t) 306 | return s_0 * s_1 / (s_0 + s_1 + 1e-7) 307 | -------------------------------------------------------------------------------- /ucate/library/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.keras as tfk 3 | from tensorflow.python.ops import math_ops 4 | from tensorflow.python.ops import array_ops 5 | from tensorflow.python.ops import gen_math_ops 6 | from tensorflow.python.framework import ops 7 | from tensorflow.python.framework import dtypes 8 | from tensorflow.python.framework import tensor_shape 9 | from tensorflow.python.keras import backend as K 10 | from tensorflow.python.keras.engine.input_spec import InputSpec 11 | 12 | 13 | class MultiBranchDense(tf.keras.layers.Layer): 14 | def __init__( 15 | self, 16 | units, 17 | num_branches, 18 | num_examples, 19 | activation=None, 20 | use_bias=True, 21 | kernel_initializer='glorot_uniform', 22 | bias_initializer='zeros', 23 | activity_regularizer=None, 24 | kernel_constraint=None, 25 | bias_constraint=None, 26 | **kwargs 27 | ): 28 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 29 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 30 | 31 | super(MultiBranchDense, self).__init__( 32 | activity_regularizer=tfk.regularizers.get(activity_regularizer), **kwargs) 33 | 34 | self.units = int(units) if not isinstance(units, int) else units 35 | self.num_groups = num_branches 36 | self.num_examples = num_examples 37 | self.activation = tfk.activations.get(activation) 38 | self.use_bias = use_bias 39 | self.kernel_initializer = tfk.initializers.get(kernel_initializer) 40 | self.bias_initializer = tfk.initializers.get(bias_initializer) 41 | self.kernel_constraint = tfk.constraints.get(kernel_constraint) 42 | self.bias_constraint = tfk.constraints.get(bias_constraint) 43 | 44 | self.supports_masking = True 45 | self.input_spec = [InputSpec(min_ndim=2), InputSpec(min_ndim=2)] 46 | 47 | def build(self, input_shape): 48 | dtype = dtypes.as_dtype(self.dtype or K.floatx()) 49 | if not (dtype.is_floating or dtype.is_complex): 50 | raise TypeError('Unable to build `Dense` layer with non-floating point ' 51 | 'dtype %s' % (dtype,)) 52 | input_shape_x, input_shape_g = tensor_shape.TensorShape(input_shape[0]), tensor_shape.TensorShape( 53 | input_shape[1]) 54 | if tensor_shape.dimension_value(input_shape_x[-1]) is None: 55 | raise ValueError('The last dimension of the inputs to `Dense` ' 56 | 'should be defined. 
Found `None`.') 57 | last_dim_x = tensor_shape.dimension_value(input_shape_x[-1]) 58 | last_dim_g = tensor_shape.dimension_value(input_shape_g[-1]) 59 | self.input_spec = [ 60 | InputSpec( 61 | min_ndim=2, 62 | axes={-1: last_dim_x} 63 | ), 64 | InputSpec( 65 | min_ndim=2, 66 | axes={-1: last_dim_g} 67 | ) 68 | ] 69 | num_examples = tf.cast(self.num_examples, self.dtype) 70 | self.kernel_shape = [last_dim_x, self.units] 71 | self.kernel = self.add_weight( 72 | 'kernel', 73 | shape=[last_dim_g, last_dim_x, self.units], 74 | initializer=self.kernel_initializer, 75 | regularizer=branch_l2(num_examples), 76 | constraint=self.kernel_constraint, 77 | dtype=self.dtype, 78 | trainable=True 79 | ) 80 | 81 | if self.use_bias: 82 | self.bias = self.add_weight( 83 | 'bias', 84 | shape=[last_dim_g, self.units], 85 | initializer=self.bias_initializer, 86 | regularizer=branch_bias_l2(num_examples), 87 | constraint=self.bias_constraint, 88 | dtype=self.dtype, 89 | trainable=True) 90 | else: 91 | self.bias = None 92 | self.built = True 93 | 94 | def call(self, inputs): 95 | x, g = inputs 96 | x = math_ops.cast(x, self._compute_dtype) 97 | g = math_ops.cast(g, self._compute_dtype) 98 | kernel = tf.reshape(self.kernel, [-1, tf.reduce_prod(self.kernel_shape)]) 99 | kernel = gen_math_ops.mat_mul(g, kernel) 100 | kernel = tf.reshape(kernel, [-1, ] + self.kernel_shape) 101 | outputs = tf.keras.backend.batch_dot(x, kernel) 102 | if self.use_bias: 103 | bias = gen_math_ops.mat_mul(g, self.bias) 104 | outputs = outputs + bias 105 | if self.activation is not None: 106 | return self.activation(outputs) # pylint: disable=not-callable 107 | return outputs 108 | 109 | def compute_output_shape(self, input_shape): 110 | input_shape = tensor_shape.TensorShape(input_shape) 111 | input_shape = input_shape.with_rank_at_least(2) 112 | if tensor_shape.dimension_value(input_shape[-1]) is None: 113 | raise ValueError( 114 | 'The innermost dimension of input_shape must be defined, but saw: %s' 115 | % input_shape) 116 | return input_shape[:-1].concatenate(self.units) 117 | 118 | def get_config(self): 119 | config = { 120 | 'units': self.units, 121 | 'activation': tfk.activations.serialize(self.activation), 122 | 'use_bias': self.use_bias, 123 | 'kernel_initializer': tfk.initializers.serialize(self.kernel_initializer), 124 | 'bias_initializer': tfk.initializers.serialize(self.bias_initializer), 125 | 'activity_regularizer': 126 | tfk.regularizers.serialize(self.activity_regularizer), 127 | 'kernel_constraint': tfk.constraints.serialize(self.kernel_constraint), 128 | 'bias_constraint': tfk.constraints.serialize(self.bias_constraint) 129 | } 130 | base_config = super(MultiBranchDense, self).get_config() 131 | return dict(list(base_config.items()) + list(config.items())) 132 | 133 | 134 | class BatchDropout(tf.keras.layers.Dropout): 135 | def _get_noise_shape(self, inputs): 136 | return ops.convert_to_tensor_v2([1, array_ops.shape(inputs)[-1]]) 137 | 138 | 139 | def branch_l2(rate): 140 | def func(x): 141 | l2 = tf.reduce_sum(tf.square(x), axis=[1, 2]) 142 | return 0.5 * tf.reduce_sum(l2 / rate) 143 | return func 144 | 145 | 146 | def branch_bias_l2(rate): 147 | def func(x): 148 | l2 = tf.reduce_sum(tf.square(x), axis=-1) 149 | return 0.5 * tf.reduce_sum(l2 / rate) 150 | return func 151 | -------------------------------------------------------------------------------- /ucate/library/models/__init__.py: -------------------------------------------------------------------------------- 1 | from ucate.library.models.mlp import 
BayesianNeuralNetwork 2 | 3 | from ucate.library.models.cnn import BayesianConvolutionalNeuralNetwork 4 | 5 | from ucate.library.models.cevae import BayesianCEVAE 6 | 7 | from ucate.library.models.tarnet import TARNet 8 | 9 | from ucate.library.models.core import BaseModel 10 | 11 | 12 | MODELS = { 13 | "mlp": BayesianNeuralNetwork, 14 | "cnn": BayesianConvolutionalNeuralNetwork, 15 | "cevae": BayesianCEVAE, 16 | "tarnet": TARNet 17 | } 18 | -------------------------------------------------------------------------------- /ucate/library/models/cevae.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ucate.library.models import core 3 | from ucate.library.modules import dense 4 | from ucate.library.modules import samplers 5 | from ucate.library.modules import convolution 6 | 7 | 8 | class BayesianCEVAE(core.BaseModel): 9 | def __init__( 10 | self, 11 | dim_x, 12 | dim_t, 13 | dim_y, 14 | regression, 15 | dim_latent, 16 | num_examples, 17 | dim_hidden, 18 | dropout_rate=0.1, 19 | beta=1.0, 20 | negative_sampling=True, 21 | do_convolution=False, 22 | *args, 23 | **kwargs 24 | ): 25 | super(BayesianCEVAE, self).__init__(*args, **kwargs) 26 | if isinstance(dim_x, list) and not do_convolution: 27 | self.dim_x = dim_x[0] 28 | self.dim_x_bin = dim_x[1] 29 | dim_x = sum(dim_x) 30 | else: 31 | self.dim_x = dim_x 32 | self.dim_x_bin = None 33 | self.regression = regression 34 | self.encoder = Encoder( 35 | do_convolution=do_convolution, 36 | dim_latent=dim_latent, 37 | num_examples=num_examples, 38 | dim_hidden=dim_hidden, 39 | dropout_rate=dropout_rate, 40 | beta=beta, 41 | negative_sampling=negative_sampling, 42 | ) 43 | self.decoder = Decoder( 44 | dim_x=dim_x, 45 | dim_t=dim_t, 46 | dim_y=dim_y, 47 | regression=regression, 48 | num_examples=num_examples, 49 | dim_hidden=dim_hidden, 50 | dropout_rate=dropout_rate, 51 | ) 52 | if do_convolution: 53 | self.x_sampler = samplers.ConvNormalSampler( 54 | dim_output=self.dim_x[-1], num_examples=sum(num_examples), beta=0.0 55 | ) 56 | else: 57 | self.x_sampler = samplers.NormalSampler( 58 | dim_output=self.dim_x, 59 | num_branches=1, 60 | num_examples=sum(num_examples), 61 | beta=0.0, 62 | ) 63 | if self.dim_x_bin is not None: 64 | self.x_sampler_bin = samplers.BernoulliSampler( 65 | dim_output=self.dim_x_bin, 66 | num_branches=1, 67 | num_examples=sum(num_examples), 68 | beta=0.0, 69 | ) 70 | 71 | def call(self, inputs, training=None): 72 | x, t, y = inputs 73 | qz = self.encoder([x, t], training=training) 74 | hx, pt, py = self.decoder([qz.sample(), t], training=training) 75 | px = self.x_sampler(hx, training=training) 76 | if self.dim_x_bin is not None: 77 | px_bin = self.x_sampler_bin(hx, training=training) 78 | x_cont, x_bin = tf.split(x, [self.dim_x, self.dim_x_bin], axis=-1) 79 | self.add_loss(tf.reduce_mean(-px.log_prob(x_cont))) 80 | self.add_loss(tf.reduce_mean(-px_bin.log_prob(x_bin))) 81 | else: 82 | self.add_loss(tf.reduce_mean(-px.log_prob(x))) 83 | self.add_loss(tf.reduce_mean(-pt.log_prob(t))) 84 | self.add_loss(tf.reduce_mean(-py.log_prob(y))) 85 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 86 | return mu, py.sample() 87 | 88 | def mc_sample_model(self, inputs): 89 | x, t = inputs 90 | qz = self.encoder([x, t], training=False) 91 | hx, pt, py = self.decoder([qz.loc, t], training=True) 92 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 93 | return mu, py.sample() 94 | 95 | def mc_sample_latent_step(self, inputs): 96 | x, t = inputs 97 | qz = 
self.encoder([x, t], training=False) 98 | hx, pt, py = self.decoder([qz.sample(), t], training=False) 99 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 100 | return mu, py.sample() 101 | 102 | def mc_sample_step(self, inputs): 103 | x, t = inputs 104 | qz = self.encoder([x, t], training=False) 105 | hx, pt, py = self.decoder([qz.sample(), t], training=True) 106 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 107 | return mu, py.sample() 108 | 109 | 110 | class Encoder(tf.keras.Model): 111 | def __init__( 112 | self, 113 | do_convolution, 114 | dim_latent, 115 | num_examples, 116 | dim_hidden, 117 | dropout_rate=0.1, 118 | beta=1.0, 119 | negative_sampling=True, 120 | *args, 121 | **kwargs 122 | ): 123 | super(Encoder, self).__init__(*args, **kwargs) 124 | self.conv = ( 125 | convolution.ConvHead( 126 | base_filters=32, 127 | num_examples=sum(num_examples), 128 | dropout_rate=dropout_rate, 129 | ) 130 | if do_convolution 131 | else dense.identity() 132 | ) 133 | self.hidden_1 = dense.Dense( 134 | units=dim_hidden, 135 | num_examples=sum(num_examples), 136 | dropout_rate=dropout_rate, 137 | activation="elu", 138 | name="encoder_hidden_1", 139 | ) 140 | self.hidden_2 = dense.Dense( 141 | units=dim_hidden, 142 | num_examples=sum(num_examples), 143 | dropout_rate=dropout_rate, 144 | activation="elu", 145 | name="encoder_hidden_2", 146 | ) 147 | self.hidden_3 = dense.Dense( 148 | units=dim_hidden, 149 | num_examples=num_examples, 150 | dropout_rate=dropout_rate, 151 | num_branches=2, 152 | activation="elu", 153 | name="encoder_hidden_3", 154 | ) 155 | self.hidden_4 = dense.Dense( 156 | units=dim_hidden, 157 | num_examples=num_examples, 158 | dropout_rate=dropout_rate, 159 | num_branches=2, 160 | activation="elu", 161 | name="encoder_hidden_4", 162 | ) 163 | self.sampler = samplers.NormalSampler( 164 | dim_output=dim_latent, 165 | num_branches=2, 166 | num_examples=num_examples, 167 | beta=beta / 2 if negative_sampling else beta, 168 | ) 169 | self.negative_sampling = negative_sampling 170 | 171 | def call(self, inputs, training=None): 172 | x, t = inputs 173 | q = self.forward([x, t], training=training) 174 | if self.negative_sampling: 175 | t_cf = 1.0 - t 176 | _ = self.forward([x, t_cf], training=training) 177 | return q 178 | 179 | def forward(self, inputs, training=None): 180 | x, t = inputs 181 | outputs = self.conv(x, training=training) 182 | outputs = self.hidden_1(outputs, training=training) 183 | outputs = self.hidden_2(outputs, training=training) 184 | outputs = self.hidden_3([outputs, t], training=training) 185 | outputs = self.hidden_4([outputs, t], training=training) 186 | return self.sampler([outputs, t], training=training) 187 | 188 | 189 | class Decoder(tf.keras.Model): 190 | def __init__( 191 | self, 192 | dim_x, 193 | dim_t, 194 | dim_y, 195 | regression, 196 | num_examples, 197 | dim_hidden, 198 | dropout_rate=0.1, 199 | *args, 200 | **kwargs 201 | ): 202 | super(Decoder, self).__init__(*args, **kwargs) 203 | do_convolution = isinstance(dim_x, (tuple, list)) 204 | self.x_hidden_1 = dense.Dense( 205 | units=dim_hidden, 206 | num_examples=sum(num_examples), 207 | dropout_rate=dropout_rate, 208 | num_branches=1, 209 | activation="elu", 210 | name="decoder_x_hidden_1", 211 | ) 212 | self.x_hidden_2 = dense.Dense( 213 | units=dim_hidden, 214 | num_examples=sum(num_examples), 215 | dropout_rate=dropout_rate if do_convolution else 0.1, 216 | num_branches=1, 217 | activation="elu", 218 | name="decoder_x_hidden_2", 219 | ) 220 | self.x_conv = ( 221 | 
convolution.ConvTail( 222 | base_filters=32, 223 | num_examples=sum(num_examples), 224 | dropout_rate=dropout_rate, 225 | ) 226 | if do_convolution 227 | else dense.identity() 228 | ) 229 | self.t_hidden_1 = dense.Dense( 230 | units=dim_hidden, 231 | num_examples=sum(num_examples), 232 | dropout_rate=dropout_rate, 233 | num_branches=1, 234 | activation="elu", 235 | name="decoder_t_hidden_1", 236 | ) 237 | self.t_hidden_2 = dense.Dense( 238 | units=dim_hidden, 239 | num_examples=sum(num_examples), 240 | dropout_rate=0.5, 241 | num_branches=1, 242 | activation="elu", 243 | name="decoder_t_hidden_2", 244 | ) 245 | self.t_sampler = samplers.CategoricalSampler( 246 | dim_output=dim_t, num_branches=1, num_examples=sum(num_examples), beta=0.0 247 | ) 248 | self.y_hidden_1 = dense.Dense( 249 | units=dim_hidden, 250 | num_examples=sum(num_examples), 251 | dropout_rate=dropout_rate, 252 | activation="elu", 253 | name="decoder_y_hidden_1", 254 | ) 255 | self.y_hidden_2 = dense.Dense( 256 | units=dim_hidden, 257 | num_examples=sum(num_examples), 258 | dropout_rate=dropout_rate, 259 | activation="elu", 260 | name="decoder_y_hidden_2", 261 | ) 262 | self.y_hidden_3 = dense.Dense( 263 | units=dim_hidden, 264 | num_examples=num_examples, 265 | dropout_rate=dropout_rate, 266 | num_branches=2, 267 | activation="elu", 268 | name="decoder_y_hidden_3", 269 | ) 270 | self.y_hidden_4 = dense.Dense( 271 | units=dim_hidden, 272 | num_examples=num_examples, 273 | dropout_rate=0.5, 274 | num_branches=2, 275 | activation="elu", 276 | name="decoder_y_hidden_4", 277 | ) 278 | y_sampler = samplers.NormalSampler if regression else samplers.BernoulliSampler 279 | self.y_sampler = y_sampler( 280 | dim_output=dim_y, num_branches=2, num_examples=num_examples, beta=0.0 281 | ) 282 | 283 | def call(self, inputs, training=None): 284 | z, t = inputs 285 | hx = self.x_hidden_1(z, training=training) 286 | hx = self.x_hidden_2(hx, training=training) 287 | hx = self.x_conv(hx, training=training) 288 | ht = self.t_hidden_1(z, training=training) 289 | ht = self.t_hidden_2(ht, training=training) 290 | pt = self.t_sampler(ht, training=training) 291 | hy = self.y_hidden_1(z, training=training) 292 | hy = self.y_hidden_2(hy, training=training) 293 | hy = self.y_hidden_3([hy, t], training=training) 294 | hy = self.y_hidden_4([hy, t], training=training) 295 | py = self.y_sampler([hy, t], training=training) 296 | return hx, pt, py 297 | -------------------------------------------------------------------------------- /ucate/library/models/cnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ucate.library.models import core 3 | from ucate.library.modules import samplers 4 | 5 | 6 | class BayesianConvolutionalNeuralNetwork(core.BaseModel): 7 | def __init__( 8 | self, 9 | num_examples, 10 | dim_hidden, 11 | regression, 12 | depth, 13 | dropout_rate=0.1, 14 | *args, 15 | **kwargs 16 | ): 17 | super(BayesianConvolutionalNeuralNetwork, self).__init__( 18 | *args, 19 | **kwargs 20 | ) 21 | self.conv_1 = tf.keras.layers.Conv2D( 22 | filters=32, 23 | kernel_size=5, 24 | strides=(1, 1), 25 | padding='same', 26 | activation='elu', 27 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples) * (1 / 25)), 28 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 29 | ) 30 | self.conv_1_pool = tf.keras.layers.MaxPool2D() 31 | self.conv_1_drop = tf.keras.layers.SpatialDropout2D(rate=dropout_rate) 32 | self.conv_2 = tf.keras.layers.Conv2D( 33 | filters=64, 34 
| kernel_size=5, 35 | strides=(1, 1), 36 | padding='same', 37 | activation='elu', 38 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1. - dropout_rate) * (1 / num_examples) * (1 / 25)), 39 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 40 | ) 41 | self.conv_2_pool = tf.keras.layers.MaxPool2D() 42 | self.flatten = tf.keras.layers.Flatten() 43 | self.conv_2_drop = tf.keras.layers.Dropout(rate=dropout_rate) 44 | self.hidden_1 = tf.keras.layers.Dense( 45 | units=dim_hidden, 46 | activation='elu', 47 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1. - dropout_rate) * (1 / num_examples)), 48 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 49 | ) 50 | self.dropout_1 = tf.keras.layers.Dropout(rate=dropout_rate) 51 | self.hidden_2 = tf.keras.layers.Dense( 52 | units=dim_hidden, 53 | activation='elu', 54 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1. - dropout_rate) * (1 / num_examples)), 55 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 56 | ) 57 | self.dropout_2 = tf.keras.layers.Dropout(rate=0.5) 58 | self.sampler = samplers.NormalSampler( 59 | dim_output=1, 60 | num_branches=1, 61 | num_examples=num_examples, 62 | beta=0.0 63 | ) if regression else samplers.BernoulliSampler( 64 | dim_output=1, 65 | num_branches=1, 66 | num_examples=num_examples, 67 | beta=0.0 68 | ) 69 | self.regression = regression 70 | self.mc_sample_function = None 71 | 72 | def stem( 73 | self, 74 | inputs, 75 | training 76 | ): 77 | outputs = self.conv_1(inputs) 78 | outputs = self.conv_1_pool(outputs) 79 | outputs = self.conv_1_drop(outputs, training=training) 80 | outputs = self.conv_2(outputs) 81 | outputs = self.conv_2_pool(outputs) 82 | outputs = self.flatten(outputs) 83 | outputs = self.conv_2_drop(outputs, training=training) 84 | outputs = self.hidden_1(outputs) 85 | outputs = self.dropout_1(outputs, training=training) 86 | outputs = self.hidden_2(outputs) 87 | outputs = self.dropout_2(outputs, training=training) 88 | return outputs 89 | 90 | def call( 91 | self, 92 | inputs, 93 | training=None 94 | ): 95 | x, y = inputs 96 | h = self.stem( 97 | x, 98 | training=training 99 | ) 100 | py = self.sampler(h) 101 | self.add_loss(tf.reduce_mean(-py.log_prob(y))) 102 | return py.sample(), py.entropy() 103 | 104 | def mc_sample_step( 105 | self, 106 | inputs 107 | ): 108 | h = self.stem( 109 | inputs, 110 | training=True 111 | ) 112 | py = self.sampler(h) 113 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 114 | return mu, py.sample() 115 | -------------------------------------------------------------------------------- /ucate/library/models/core.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.util import nest 3 | from tensorflow.python.eager import context 4 | from tensorflow.python.keras.utils import tf_utils 5 | from tensorflow.python.keras.engine import data_adapter 6 | from tensorflow.python.keras.engine.training import concat, reduce_per_replica 7 | 8 | eps = tf.keras.backend.epsilon() 9 | ln = tf.keras.backend.log 10 | 11 | 12 | class BaseModel(tf.keras.Model): 13 | def __init__( 14 | self, 15 | *args, 16 | **kwargs 17 | ): 18 | super(BaseModel, self).__init__( 19 | *args, 20 | **kwargs 21 | ) 22 | self.mc_sample_function = None 23 | 24 | def mc_sample_step( 25 | self, 26 | inputs 27 | ): 28 | raise NotImplementedError('mc_sample_step must be implemented') 29 | 30 | def make_mc_sample_function(self): 31 | if
self.mc_sample_function is not None: 32 | return self.mc_sample_function 33 | 34 | def predict_function(iterator): 35 | data = next(iterator) 36 | data = data if isinstance(data, (list, tuple)) else [data] 37 | outputs = self.distribute_strategy.run( 38 | self.mc_sample_step, 39 | args=data 40 | ) 41 | outputs = reduce_per_replica( 42 | outputs, 43 | self.distribute_strategy, 44 | reduction='concat' 45 | ) 46 | return outputs 47 | 48 | self.mc_sample_function = predict_function 49 | return self.mc_sample_function 50 | 51 | def mc_sample( 52 | self, 53 | x, 54 | batch_size=None, 55 | steps=None, 56 | max_queue_size=10, 57 | workers=1, 58 | use_multiprocessing=False 59 | ): 60 | outputs = None 61 | with self.distribute_strategy.scope(): 62 | data_handler = data_adapter.DataHandler( 63 | x=x, 64 | batch_size=batch_size, 65 | steps_per_epoch=steps, 66 | initial_epoch=0, 67 | epochs=1, 68 | max_queue_size=max_queue_size, 69 | workers=workers, 70 | use_multiprocessing=use_multiprocessing, 71 | model=self 72 | ) 73 | predict_function = self.make_mc_sample_function() 74 | for _, iterator in data_handler.enumerate_epochs(): 75 | with data_handler.catch_stop_iteration(): 76 | for step in data_handler.steps(): 77 | tmp_batch_outputs = predict_function(iterator) 78 | if not data_handler.inferred_steps: 79 | context.async_wait() 80 | batch_outputs = tmp_batch_outputs 81 | if outputs is None: 82 | outputs = nest.map_structure( 83 | lambda batch_output: [batch_output], 84 | batch_outputs 85 | ) 86 | else: 87 | nest.map_structure_up_to( 88 | batch_outputs, 89 | lambda output, batch_output: output.append(batch_output), 90 | outputs, 91 | batch_outputs 92 | ) 93 | all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs) 94 | return tf_utils.to_numpy_or_python_type(all_outputs) 95 | -------------------------------------------------------------------------------- /ucate/library/models/mlp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ucate.library.models import core 3 | from ucate.library.modules import samplers 4 | 5 | 6 | class BayesianNeuralNetwork(core.BaseModel): 7 | def __init__( 8 | self, 9 | num_examples, 10 | dim_hidden, 11 | regression, 12 | dropout_rate=0.1, 13 | depth=5, 14 | *args, 15 | **kwargs 16 | ): 17 | super(BayesianNeuralNetwork, self).__init__( 18 | *args, 19 | **kwargs 20 | ) 21 | self.blocks = [] 22 | for i in range(depth): 23 | self.blocks.append( 24 | tf.keras.layers.Dense( 25 | units=dim_hidden, 26 | activation='elu', 27 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)), 28 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 29 | ) 30 | ) 31 | self.blocks.append( 32 | tf.keras.layers.Dropout(rate=0.5 if i == depth - 1 else dropout_rate) 33 | ) 34 | self.sampler = samplers.NormalSampler( 35 | dim_output=1, 36 | num_branches=1, 37 | num_examples=num_examples, 38 | beta=0.0 39 | ) if regression else samplers.BernoulliSampler( 40 | dim_output=1, 41 | num_branches=1, 42 | num_examples=num_examples, 43 | beta=0.0 44 | ) 45 | self.regression = regression 46 | 47 | def call( 48 | self, 49 | inputs, 50 | training=None 51 | ): 52 | x, y = inputs 53 | for block in self.blocks: 54 | x = block(x, training=training) 55 | py = self.sampler(x) 56 | self.add_loss(tf.reduce_mean(-py.log_prob(y))) 57 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 58 | return mu, py.sample() 59 | 60 | def mc_sample_step( 61 | self, 62 | inputs 63 | ): 64 | x = inputs 65 | for 
block in self.blocks: 66 | x = block(x, training=True) 67 | py = self.sampler(x) 68 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 69 | return mu, py.sample() 70 | -------------------------------------------------------------------------------- /ucate/library/models/tarnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ucate.library.models import core 3 | from ucate.library.modules import dense 4 | from ucate.library.modules import samplers 5 | from ucate.library.modules import convolution 6 | 7 | 8 | class TARNet(core.BaseModel): 9 | def __init__( 10 | self, 11 | do_convolution, 12 | num_examples, 13 | dim_hidden, 14 | regression, 15 | dropout_rate=0.1, 16 | beta=1.0, 17 | mode="tarnet", 18 | *args, 19 | **kwargs 20 | ): 21 | super(TARNet, self).__init__(*args, **kwargs) 22 | self.conv = ( 23 | convolution.ConvHead( 24 | base_filters=32, 25 | num_examples=sum(num_examples), 26 | dropout_rate=0.1, 27 | ) 28 | if do_convolution 29 | else dense.identity() 30 | ) 31 | self.hidden_1 = dense.Dense( 32 | units=dim_hidden, 33 | num_examples=sum(num_examples), 34 | dropout_rate=dropout_rate, 35 | activation="elu", 36 | name="tarnet_hidden_1", 37 | ) 38 | self.hidden_2 = dense.Dense( 39 | units=dim_hidden, 40 | num_examples=sum(num_examples), 41 | dropout_rate=dropout_rate, 42 | activation="elu", 43 | name="tarnet_hidden_2", 44 | ) 45 | self.hidden_3 = dense.Dense( 46 | units=dim_hidden, 47 | num_examples=sum(num_examples), 48 | dropout_rate=dropout_rate, 49 | activation="elu", 50 | name="tarnet_hidden_3", 51 | ) 52 | self.hidden_4 = dense.Dense( 53 | units=dim_hidden, 54 | num_examples=num_examples, 55 | dropout_rate=dropout_rate, 56 | num_branches=2, 57 | activation="elu", 58 | name="encoder_hidden_4", 59 | ) 60 | self.hidden_5 = dense.Dense( 61 | units=dim_hidden, 62 | num_examples=num_examples, 63 | dropout_rate=0.5, 64 | num_branches=2, 65 | activation="elu", 66 | name="encoder_hidden_5", 67 | ) 68 | y_sampler = samplers.NormalSampler if regression else samplers.BernoulliSampler 69 | self.y_sampler = y_sampler( 70 | dim_output=1, num_branches=2, num_examples=num_examples, beta=0.0 71 | ) 72 | self.regression = regression 73 | self.beta = beta 74 | if mode == "dragon": 75 | self.t_sampler = samplers.CategoricalSampler( 76 | 2, num_branches=1, num_examples=sum(num_examples), beta=0.0 77 | ) 78 | self.mode = mode 79 | self.y_loss = ( 80 | tf.keras.losses.MeanSquaredError() 81 | if regression 82 | else tf.keras.losses.BinaryCrossentropy() 83 | ) 84 | 85 | def call(self, inputs, training=None): 86 | x, t, y = inputs 87 | py, pt = self.forward([x, t], training=training) 88 | self.add_loss(tf.reduce_mean(-py.log_prob(y))) 89 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 90 | if self.mode == "dragon": 91 | self.add_loss(tf.reduce_mean(-pt.log_prob(t))) 92 | self.add_loss( 93 | self.beta 94 | * propensity_loss( 95 | t_true=t, 96 | t_pred=tf.keras.backend.softmax(pt.logits), 97 | y_true=y, 98 | y_pred=mu, 99 | y_loss=self.y_loss, 100 | ) 101 | ) 102 | elif self.mode == "mmd": 103 | self.add_loss(self.beta * mmd(t, pt)) 104 | return mu, py.sample() 105 | 106 | def forward(self, inputs, training=None): 107 | x, t = inputs 108 | phi = self.conv(x, training=training) 109 | phi = self.hidden_1(phi, training=training) 110 | phi = self.hidden_2(phi, training=training) 111 | phi = self.hidden_3(phi, training=training) 112 | pt = self.t_sampler(phi) if self.mode == "dragon" else phi 113 | outputs = 
self.hidden_4([phi, t], training=training) 114 | outputs = self.hidden_5([outputs, t], training=training) 115 | return self.y_sampler([outputs, t], training=training), pt 116 | 117 | def mc_sample_step(self, inputs): 118 | x, t = inputs 119 | py, pt = self.forward([x, t], training=True) 120 | mu = py.loc if self.regression else tf.sigmoid(py.logits) 121 | return mu, py.sample() 122 | 123 | 124 | def propensity_loss(t_true, t_pred, y_true, y_pred, y_loss): 125 | eps = tf.keras.backend.epsilon() 126 | loss_value = tf.reduce_mean( 127 | tf.keras.losses.categorical_crossentropy(t_true, t_pred) 128 | ) 129 | t_pred = (t_pred + 0.001) / 1.002 130 | h = tf.reduce_sum(t_true / t_pred, -1, keepdims=True) 131 | q_tilde = y_pred + eps * h 132 | loss_value += tf.reduce_mean(y_loss(y_true, q_tilde)) 133 | return loss_value 134 | 135 | 136 | def mmd(t, phi): 137 | t_0 = tf.expand_dims(t[:, 0], -1) 138 | t_1 = tf.expand_dims(t[:, 1], -1) 139 | phi_0 = phi * t_0 140 | phi_1 = phi * t_1 141 | mu_0 = tf.reduce_sum(phi_0, 0) / (tf.reduce_sum(t_0) + tf.keras.backend.epsilon()) 142 | mu_1 = tf.reduce_sum(phi_1, 0) / (tf.reduce_sum(t_1) + tf.keras.backend.epsilon()) 143 | return tf.reduce_sum(tf.square(mu_0 - mu_1)) 144 | -------------------------------------------------------------------------------- /ucate/library/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/ucate/cddc23596a463e2ce9e270cf075b3e924137bf7a/ucate/library/modules/__init__.py -------------------------------------------------------------------------------- /ucate/library/modules/convolution.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class ConvHead(tf.keras.Model): 5 | def __init__( 6 | self, 7 | base_filters, 8 | num_examples=1000, 9 | dropout_rate=0.0, 10 | kernel_initializer='he_normal', 11 | *args, 12 | **kwargs 13 | ): 14 | super(ConvHead, self).__init__( 15 | *args, 16 | **kwargs 17 | ) 18 | self.conv_1 = tf.keras.layers.Conv2D( 19 | filters=base_filters, 20 | kernel_size=5, 21 | strides=(1, 1), 22 | padding='same', 23 | activation='elu', 24 | kernel_initializer=kernel_initializer, 25 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples) * (1 / 25)), 26 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 27 | ) 28 | self.conv_1_pool = tf.keras.layers.MaxPool2D() 29 | self.conv_1_drop = tf.keras.layers.SpatialDropout2D(rate=dropout_rate) 30 | self.conv_2 = tf.keras.layers.Conv2D( 31 | filters=base_filters * 2, 32 | kernel_size=5, 33 | strides=(1, 1), 34 | padding='same', 35 | activation='elu', 36 | kernel_initializer=kernel_initializer, 37 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1. 
- dropout_rate) * (1 / num_examples) * (1 / 25)), 38 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 39 | ) 40 | self.conv_2_pool = tf.keras.layers.MaxPool2D() 41 | self.flatten = tf.keras.layers.Flatten() 42 | self.conv_2_drop = tf.keras.layers.Dropout(rate=dropout_rate) 43 | 44 | def call( 45 | self, 46 | inputs, 47 | training=None 48 | ): 49 | outputs = self.conv_1(inputs) 50 | outputs = self.conv_1_pool(outputs) 51 | outputs = self.conv_1_drop(outputs, training=True) 52 | outputs = self.conv_2(outputs) 53 | outputs = self.conv_2_pool(outputs) 54 | outputs = self.flatten(outputs) 55 | outputs = self.conv_2_drop(outputs, training=True) 56 | return outputs 57 | 58 | 59 | class ConvTail(tf.keras.Model): 60 | def __init__( 61 | self, 62 | base_filters, 63 | num_examples=1000, 64 | dropout_rate=0.0, 65 | kernel_initializer='he_normal', 66 | *args, 67 | **kwargs 68 | ): 69 | super(ConvTail, self).__init__( 70 | *args, 71 | **kwargs 72 | ) 73 | self.conv_1 = tf.keras.layers.Conv2DTranspose( 74 | filters=base_filters * 2, 75 | kernel_size=7, 76 | activation='elu', 77 | kernel_initializer=kernel_initializer, 78 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1. - dropout_rate) * (1 / num_examples)), 79 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 80 | ) 81 | self.conv_1_pool = tf.keras.layers.UpSampling2D() 82 | self.conv_1_drop = tf.keras.layers.SpatialDropout2D(rate=dropout_rate) 83 | self.conv_2 = tf.keras.layers.Conv2DTranspose( 84 | filters=base_filters * 1, 85 | kernel_size=5, 86 | strides=(1, 1), 87 | padding='same', 88 | activation='elu', 89 | kernel_initializer=kernel_initializer, 90 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1. - dropout_rate) * (1 / num_examples) * (1 / 25)), 91 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)) 92 | ) 93 | self.conv_2_pool = tf.keras.layers.UpSampling2D() 94 | self.conv_2_drop = tf.keras.layers.Dropout(rate=dropout_rate) 95 | 96 | def call( 97 | self, 98 | inputs, 99 | training=None 100 | ): 101 | outputs = tf.keras.backend.expand_dims(inputs, 1) 102 | outputs = tf.keras.backend.expand_dims(outputs, 1) 103 | outputs = self.conv_1(outputs) 104 | outputs = self.conv_1_pool(outputs) 105 | outputs = self.conv_1_drop(outputs, training=True) 106 | outputs = self.conv_2(outputs) 107 | outputs = self.conv_2_pool(outputs) 108 | outputs = self.conv_2_drop(outputs, training=True) 109 | return outputs 110 | -------------------------------------------------------------------------------- /ucate/library/modules/core.py: -------------------------------------------------------------------------------- 1 | def identity(): 2 | def func(x, training=None): 3 | return x 4 | return func 5 | -------------------------------------------------------------------------------- /ucate/library/modules/dense.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ucate.library import layers 3 | from ucate.library.modules.core import identity 4 | 5 | 6 | class Dense(tf.keras.Model): 7 | def __init__( 8 | self, 9 | units, 10 | use_bias=True, 11 | num_examples=1000, 12 | dropout_rate=0.0, 13 | batch_norm=False, 14 | num_branches=1, 15 | activation='linear', 16 | kernel_initializer='he_normal', 17 | name=None, 18 | *args, 19 | **kwargs 20 | ): 21 | super(Dense, self).__init__( 22 | *args, 23 | **kwargs 24 | ) 25 | if num_branches > 1: 26 | self.dense = layers.MultiBranchDense( 27 | units=max(units // num_branches, 1), 28 | 
num_branches=num_branches, 29 | num_examples=num_examples, 30 | use_bias=use_bias, 31 | kernel_initializer=kernel_initializer, 32 | name='{}_dense'.format(name) 33 | ) 34 | else: 35 | weight_decay = 1 / num_examples 36 | self.dense = tf.keras.layers.Dense( 37 | units=units, 38 | use_bias=use_bias, 39 | kernel_initializer=kernel_initializer, 40 | kernel_regularizer=tf.keras.regularizers.l2( 41 | 0.5 * (1. - dropout_rate) * weight_decay 42 | ), 43 | bias_regularizer=tf.keras.regularizers.l2(0.5 * weight_decay), 44 | name='{}_dense'.format(name) 45 | ) 46 | self.norm = tf.keras.layers.BatchNormalization(name='{}_norm'.format(name)) if batch_norm else identity() 47 | self.activation = tf.keras.layers.Activation( 48 | activation, 49 | name='{}_act'.format(name) 50 | ) 51 | self.dropout = tf.keras.layers.Dropout( 52 | rate=dropout_rate, 53 | name='{}_drop'.format(name) 54 | ) 55 | 56 | def call(self, inputs, training=None): 57 | outputs = self.dense(inputs) 58 | outputs = self.norm(outputs, training=training) 59 | outputs = self.activation(outputs) 60 | return self.dropout(outputs, training=training) 61 | -------------------------------------------------------------------------------- /ucate/library/modules/samplers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow_probability import distributions 3 | from ucate.library.modules import dense 4 | 5 | 6 | class NormalSampler(tf.keras.Model): 7 | def __init__( 8 | self, 9 | dim_output, 10 | num_branches, 11 | num_examples, 12 | beta, 13 | *args, 14 | **kwargs 15 | ): 16 | super(NormalSampler, self).__init__( 17 | *args, 18 | **kwargs 19 | ) 20 | self.mu = dense.Dense( 21 | units=dim_output, 22 | num_examples=num_examples, 23 | dropout_rate=0.0, 24 | num_branches=num_branches, 25 | activation='linear', 26 | name='normal_mu' 27 | ) 28 | self.sigma = dense.Dense( 29 | units=dim_output, 30 | num_examples=num_examples, 31 | dropout_rate=0.0, 32 | num_branches=num_branches, 33 | activation='softplus', 34 | kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02), 35 | name='normal_sigma' 36 | ) 37 | self.beta = beta 38 | 39 | def call( 40 | self, 41 | inputs, 42 | training=None 43 | ): 44 | mu = self.mu(inputs) 45 | sigma = self.sigma(inputs) + tf.keras.backend.epsilon() 46 | q = distributions.MultivariateNormalDiag( 47 | loc=mu, 48 | scale_diag=sigma 49 | ) 50 | if self.beta > 0.0: 51 | kld = q.kl_divergence( 52 | distributions.MultivariateNormalDiag( 53 | loc=tf.zeros_like(mu) 54 | ) 55 | ) 56 | self.add_loss(self.beta * tf.reduce_mean(kld)) 57 | return q 58 | 59 | 60 | class BernoulliSampler(tf.keras.Model): 61 | def __init__( 62 | self, 63 | dim_output, 64 | num_branches, 65 | num_examples, 66 | beta, 67 | *args, 68 | **kwargs 69 | ): 70 | super(BernoulliSampler, self).__init__( 71 | *args, 72 | **kwargs 73 | ) 74 | self.logits = dense.Dense( 75 | units=dim_output, 76 | num_examples=num_examples, 77 | dropout_rate=0.0, 78 | num_branches=num_branches, 79 | activation='linear', 80 | name='encoder_mu' 81 | ) 82 | self.beta = beta 83 | 84 | def call( 85 | self, 86 | inputs, 87 | training=None 88 | ): 89 | logits = self.logits(inputs) 90 | q = distributions.Bernoulli( 91 | logits=logits, 92 | dtype=tf.float32, 93 | ) 94 | if self.beta > 0.0: 95 | kld = q.kl_divergence( 96 | distributions.Bernoulli( 97 | logits=tf.zeros_like(logits), 98 | dtype=tf.float32, 99 | ) 100 | ) 101 | self.add_loss(tf.reduce_mean(kld)) 102 | return q 103 | 104 | 105 | class 
CategoricalSampler(tf.keras.Model): 106 | def __init__( 107 | self, 108 | dim_output, 109 | num_branches, 110 | num_examples, 111 | beta, 112 | *args, 113 | **kwargs 114 | ): 115 | super(CategoricalSampler, self).__init__( 116 | *args, 117 | **kwargs 118 | ) 119 | self.logits = dense.Dense( 120 | units=dim_output, 121 | num_examples=num_examples, 122 | dropout_rate=0.0, 123 | num_branches=num_branches, 124 | activation='linear', 125 | name='categorical_logits' 126 | ) 127 | self.beta = beta 128 | 129 | def call( 130 | self, 131 | inputs, 132 | training=None 133 | ): 134 | logits = self.logits(inputs) 135 | q = distributions.OneHotCategorical( 136 | logits=logits, 137 | dtype=tf.float32, 138 | ) 139 | if self.beta > 0.0: 140 | kld = q.kl_divergence( 141 | distributions.OneHotCategorical( 142 | logits=tf.zeros_like(logits), 143 | dtype=tf.float32, 144 | ) 145 | ) 146 | self.add_loss(tf.reduce_mean(kld)) 147 | return q 148 | 149 | 150 | class ConvNormalSampler(tf.keras.Model): 151 | def __init__( 152 | self, 153 | dim_output, 154 | num_examples, 155 | beta, 156 | *args, 157 | **kwargs 158 | ): 159 | super(ConvNormalSampler, self).__init__( 160 | *args, 161 | **kwargs 162 | ) 163 | self.mu = tf.keras.layers.Conv2D( 164 | filters=dim_output, 165 | kernel_size=3, 166 | padding='same', 167 | activation='linear', 168 | kernel_initializer='he_normal', 169 | kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)), 170 | bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)), 171 | name='normal_mu' 172 | ) 173 | # self.sigma = tf.keras.layers.Conv2D( 174 | # filters=dim_output, 175 | # kernel_size=3, 176 | # padding='same', 177 | # activation='softplus', 178 | # kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02), 179 | # kernel_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)), 180 | # bias_regularizer=tf.keras.regularizers.l2(0.5 * (1 / num_examples)), 181 | # name='normal_sigma' 182 | # ) 183 | self.beta = beta 184 | 185 | def call( 186 | self, 187 | inputs, 188 | training=None 189 | ): 190 | mu = self.mu(inputs) 191 | # sigma = self.sigma(inputs) + tf.keras.backend.epsilon() 192 | q = distributions.Independent( 193 | distributions.MultivariateNormalDiag( 194 | loc=mu, 195 | scale_diag=tf.ones_like(mu) 196 | ) 197 | ) 198 | if self.beta > 0.0: 199 | kld = q.kl_divergence( 200 | distributions.MultivariateNormalDiag( 201 | loc=tf.zeros_like(mu) 202 | ) 203 | ) 204 | self.add_loss(0.5 * tf.reduce_mean(kld)) 205 | return q 206 | -------------------------------------------------------------------------------- /ucate/library/scratch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.neighbors import KernelDensity 3 | 4 | 5 | def get_density_estimator( 6 | x, 7 | bandwidth=1.0, 8 | kernel='gaussian' 9 | ): 10 | idx = np.random.choice(np.arange(len(x)), 2000) 11 | estimator = KernelDensity( 12 | kernel=kernel, 13 | bandwidth=bandwidth, 14 | ) 15 | return estimator.fit(x[idx]) 16 | -------------------------------------------------------------------------------- /ucate/library/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OATML/ucate/cddc23596a463e2ce9e270cf075b3e924137bf7a/ucate/library/utils/__init__.py -------------------------------------------------------------------------------- /ucate/library/utils/plotting.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | 5 | sns.set(style="whitegrid", palette="colorblind") 6 | 7 | _FONTSIZE = 18 8 | _LEGEND_FONTSIZE = 16 9 | _TICK_FONTSIZE = 16 10 | _MARKER_SIZE = 3.0 11 | _LINEWIDTH = 3.0 12 | _DPI = 300 13 | 14 | 15 | def error_bars(data, file_name): 16 | plt.rcParams["figure.constrained_layout.use"] = True 17 | plt.clf() 18 | plt.cla() 19 | fig = plt.figure(figsize=(8, 6)) 20 | min_y = 10000 21 | min_x = 10000 22 | max_y = -10000 23 | max_x = -10000 24 | max_sigma = -10000 25 | for k, v in data.items(): 26 | x = v[0] 27 | y = v[1] 28 | two_sigma = 2.0 * np.sqrt(v[2]) 29 | _ = plt.errorbar( 30 | x=x, 31 | y=y, 32 | yerr=two_sigma, 33 | alpha=0.9, 34 | linestyle="None", 35 | marker="o", 36 | elinewidth=1.0, 37 | capsize=2.0, 38 | markersize=_MARKER_SIZE, 39 | label=k, 40 | ) 41 | min_y = y.min() if y.min() < min_y else min_y 42 | min_x = x.min() if x.min() < min_x else min_x 43 | max_y = y.max() if y.max() > max_y else max_y 44 | max_x = x.max() if x.max() > max_x else max_x 45 | max_sigma = two_sigma.max() if two_sigma.max() > max_sigma else max_sigma 46 | _ = plt.xlabel("True CATE", fontsize=_FONTSIZE) 47 | _ = plt.ylabel("Predicted CATE", fontsize=_FONTSIZE) 48 | _ = plt.xticks(fontsize=_TICK_FONTSIZE) 49 | _ = plt.yticks(fontsize=_TICK_FONTSIZE) 50 | limits = (min(min_y, min_x) - max_sigma, max(max_y, max_x) + max_sigma) 51 | _ = plt.plot( 52 | np.arange(limits[0], limits[1] + 1), 53 | np.arange(limits[0], limits[1] + 1), 54 | linestyle="--", 55 | label="ideal prediction line", 56 | ) 57 | _ = plt.legend(frameon=True, fontsize=_LEGEND_FONTSIZE) 58 | _ = plt.savefig(file_name, dpi=_DPI) 59 | plt.close(fig) 60 | 61 | 62 | def sweep(x, ys, y_label, file_name): 63 | plt.rcParams["figure.constrained_layout.use"] = True 64 | linestyles = ["solid", "dashed", "dashdot", "dotted"] 65 | append = "error" in file_name 66 | x = x + [1.0] if append else x 67 | plt.clf() 68 | plt.cla() 69 | fig = plt.figure(figsize=(8, 6)) 70 | for k, v in ys.items(): 71 | val, error = v 72 | val = np.append(val, 0.0) if append else val 73 | error = np.append(error, 0.0) if append else error 74 | plot = plt.plot( 75 | x, val, label=k, linewidth=_LINEWIDTH, linestyle=linestyles.pop(0) 76 | ) 77 | fill = plt.fill_between(x, val - error, val + error, alpha=0.2) 78 | _ = plt.xlabel("Proportion of recommendations withheld", fontsize=_FONTSIZE) 79 | _ = plt.ylabel(y_label, fontsize=_FONTSIZE) 80 | _ = plt.xticks(fontsize=_TICK_FONTSIZE) 81 | _ = plt.yticks(fontsize=_TICK_FONTSIZE, rotation=45) 82 | leg = plt.legend( 83 | title="Rejection policy", 84 | loc="upper right", 85 | frameon=True, 86 | fontsize=_LEGEND_FONTSIZE, 87 | ) 88 | leg._legend_box.align = "left" 89 | plt.setp(leg.get_title(), fontsize=_LEGEND_FONTSIZE) 90 | _ = plt.savefig(file_name, dpi=_DPI) 91 | plt.close(fig) 92 | 93 | 94 | def histogram( 95 | x, 96 | bins=50, 97 | alpha=1.0, 98 | x_label=None, 99 | y_label=None, 100 | x_limit=(0.0, 1.0), 101 | file_name=None, 102 | ): 103 | plt.rcParams["figure.constrained_layout.use"] = True 104 | plt.clf() 105 | plt.cla() 106 | colors = ["C0", "C1", "C2", "C3"] 107 | fig = plt.figure(figsize=(8, 6)) 108 | values = [v.ravel() for v in x.values()] 109 | _ = plt.hist( 110 | values, 111 | bins=bins, 112 | alpha=1.0, 113 | color=colors[: len(values)], 114 | label=list(x.keys()), 115 | linewidth=0.0, 116 | ) 117 | 118 | _ = plt.legend(loc="upper right", frameon=True, 
fontsize=_LEGEND_FONTSIZE) 119 | _ = plt.xlabel(x_label, fontsize=_FONTSIZE) 120 | _ = plt.ylabel(y_label, fontsize=_FONTSIZE) 121 | _ = plt.xticks(fontsize=_TICK_FONTSIZE) 122 | _ = plt.yticks(fontsize=_TICK_FONTSIZE, rotation=45) 123 | _ = plt.xlim(x_limit) 124 | if file_name is None: 125 | _ = plt.show() 126 | else: 127 | _ = plt.savefig(file_name, dpi=_DPI) 128 | plt.close(fig) 129 | -------------------------------------------------------------------------------- /ucate/library/utils/prediction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def mc_sample_tl(x, model_0, model_1, mean=0.0, std=1.0, mc_samples=100): 5 | y_0, h_0 = mc_sample_2(x, model_0, mc_samples) 6 | y_0 = y_0 * std + mean 7 | y_1, h_1 = mc_sample_2(x, model_1, mc_samples) 8 | y_1 = y_1 * std + mean 9 | return (y_0, h_0), (y_1, h_1) 10 | 11 | 12 | def mc_sample_cevae(x, t, model, mean=0.0, std=1.0, mc_samples=100): 13 | y_0, h_0 = mc_sample_2([x, t[0]], model, mc_samples) 14 | y_0 = y_0 * std + mean 15 | y_1, h_1 = mc_sample_2([x, t[1]], model, mc_samples) 16 | y_1 = y_1 * std + mean 17 | return (y_0, h_0), (y_1, h_1) 18 | 19 | 20 | def mc_sample(x, model, mc_samples=100): 21 | return np.asarray( 22 | [model(x, training=True) for _ in range(mc_samples)], dtype="float32" 23 | ) 24 | # return np.asarray([model.predict(x, batch_size=200) for _ in range(mc_samples)], dtype='float32') 25 | 26 | 27 | def mc_sample_2(x, model, mc_samples=100): 28 | y, h = [], [] 29 | for _ in range(mc_samples): 30 | y_pred, h_pred = model(x, batch_size=200) 31 | y.append(y_pred) 32 | h.append(h_pred) 33 | return np.asarray(y, dtype="float32"), np.asarray(h, dtype="float32") 34 | 35 | 36 | def cate_measures(mu_0, mu_1, y_0, y_1, regression): 37 | cate_pred = (mu_1 - mu_0).mean(0).ravel() 38 | predictive_uncrt = np.var(y_1 - y_0, 0).ravel() 39 | epistemic_unct = np.var(mu_1 - mu_0, 0).ravel() 40 | return cate_pred, predictive_uncrt, epistemic_unct 41 | 42 | 43 | def total_mi(p_0, p_1): 44 | return mi(p_0) + mi(p_1) 45 | 46 | 47 | def mi(p): 48 | h = entropy(p.mean(0)) 49 | h_cond = entropy(p).mean(0) 50 | return h - h_cond 51 | 52 | 53 | def entropy(p): 54 | eps = 1e-7 55 | p = np.clip(p, eps, 1 - eps) 56 | return -p * np.log(p) - (1 - p) * np.log((1 - p)) 57 | 58 | 59 | def differential_entropy(sigma): 60 | eps = 1e-7 61 | return 0.5 * np.log(2.0 * np.pi) + np.log(sigma + eps) + 0.5 62 | --------------------------------------------------------------------------------
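
The pieces above fit together as a Monte Carlo dropout pipeline for CATE estimation: BaseModel.mc_sample runs one stochastic forward pass (dropout active) over a dataset, the model classes attach their negative log-likelihood with add_loss, and cate_measures in ucate/library/utils/prediction.py turns repeated passes from a control-arm model and a treated-arm model into a CATE estimate with per-unit predictive and epistemic variance. The sketch below shows that loop with BayesianNeuralNetwork used as a T-learner. It is a minimal illustration, not the repository's actual driver (those live under ucate/application/workflows/, not shown here); the synthetic data, hyperparameters, and the compile-without-loss training pattern are assumptions made for this example.

import numpy as np
import tensorflow as tf

from ucate.library.models.mlp import BayesianNeuralNetwork
from ucate.library.utils.prediction import cate_measures

# Synthetic observational data; shapes and the outcome process are illustrative only.
rng = np.random.RandomState(0)
num_units, num_covariates = 1000, 25
x = rng.normal(size=(num_units, num_covariates)).astype("float32")
t = rng.binomial(1, 0.5, size=(num_units, 1)).astype("float32")
y = (x[:, :1] + t + 0.1 * rng.normal(size=(num_units, 1))).astype("float32")


def fit_outcome_model(x_train, y_train):
    # One MC-dropout outcome model per treatment arm (T-learner style).
    model = BayesianNeuralNetwork(
        num_examples=len(x_train),
        dim_hidden=200,
        regression=True,
        dropout_rate=0.1,
        depth=5,
    )
    # The negative log-likelihood is attached inside call() via add_loss,
    # so compile without an external loss and feed (x, y) together as the inputs.
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3))
    model.fit([x_train, y_train], batch_size=100, epochs=20, verbose=0)
    return model


model_0 = fit_outcome_model(x[t[:, 0] == 0], y[t[:, 0] == 0])  # control arm
model_1 = fit_outcome_model(x[t[:, 0] == 1], y[t[:, 0] == 1])  # treated arm


def draw_mc_samples(model, x_query, num_samples=100):
    # Each mc_sample() call is one stochastic pass (dropout kept on) over x_query.
    mus, draws = [], []
    for _ in range(num_samples):
        mu, draw = model.mc_sample(x_query, batch_size=200)
        mus.append(mu)
        draws.append(draw)
    return np.asarray(mus, dtype="float32"), np.asarray(draws, dtype="float32")


mu_0, draw_0 = draw_mc_samples(model_0, x)
mu_1, draw_1 = draw_mc_samples(model_1, x)

# CATE point estimate plus predictive and epistemic variance per unit.
cate, predictive_var, epistemic_var = cate_measures(
    mu_0, mu_1, draw_0, draw_1, regression=True
)
print(cate[:5], predictive_var[:5], epistemic_var[:5])

Calling mc_sample repeatedly is what keeps dropout active at prediction time and distinguishes this loop from an ordinary predict pass; the helpers in ucate/library/utils/prediction.py (mc_sample_tl, mc_sample_cevae, mc_sample_2) package the same repeated-sampling idea for the workflow scripts, and utils/plotting.py consumes the resulting CATE and uncertainty arrays for the error-bar and rejection-sweep figures.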