├── requirements.txt ├── sigmoid_design.sh ├── non_linear_experiments.sh ├── linear_experiments.sh ├── LICENSE ├── README.md ├── src ├── kiv.py ├── data.py ├── utils.py ├── plotting.py └── run.py └── .gitignore /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.10.0 2 | jax==0.1.76 3 | jaxlib==0.1.55 4 | matplotlib==3.3.1 5 | numpy==1.22.0 6 | scikit-learn==0.23.2 7 | scipy==1.5.2 8 | statsmodels==0.12.0 9 | tqdm==4.49.0 10 | -------------------------------------------------------------------------------- /sigmoid_design.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 src/run.py --dataset=synthetic --equations=np --output_dir=results/ --response_type=poly --dim_theta=4 --slack_abs=0.1 --output_name=sigmoid-cubic 4 | python3 src/run.py --dataset=synthetic --equations=np --output_dir=results/ --response_type=gp --dim_theta=7 --slack_abs=0.1 --output_name=sigmoid-gp 5 | python3 src/run.py --dataset=synthetic --equations=np --output_dir=results/ --response_type=mlp --dim_theta=7 --slack_abs=0.1 --output_name=sigmoid-mlp 6 | 7 | -------------------------------------------------------------------------------- /non_linear_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 src/run.py --dataset=synthetic --equations=quad1 --response_type=poly --dim_theta=2 --output_name=quad1-linear 4 | python3 src/run.py --dataset=synthetic --equations=quad1 --response_type=poly --dim_theta=3 --output_name=quad1-quadratic 5 | python3 src/run.py --dataset=synthetic --equations=quad1 --response_type=mlp --dim_theta=7 --output_name=quad1-mlp 6 | 7 | python3 src/run.py --dataset=synthetic --equations=quad2 --response_type=poly --dim_theta=2 --output_name=quad2-linear 8 | python3 src/run.py --dataset=synthetic --equations=quad2 --response_type=poly --dim_theta=3 --output_name=quad2-quadratic 9 | python3 src/run.py --dataset=synthetic --equations=quad2 --response_type=mlp --dim_theta=7 --output_name=quad2-mlp 10 | 11 | -------------------------------------------------------------------------------- /linear_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 src/run.py --dataset=synthetic --equations=lin1 --output_dir=results/ --response_type=poly --dim_theta=2 --output_name=lin1-linear 4 | python3 src/run.py --dataset=synthetic --equations=lin1 --output_dir=results/ --response_type=poly --dim_theta=3 --output_name=lin1-quadratic 5 | python3 src/run.py --dataset=synthetic --equations=lin1 --output_dir=results/ --response_type=mlp --dim_theta=7 --output_name=lin1-mlp 6 | 7 | python3 src/run.py --dataset=synthetic --equations=lin2 --output_dir=results/ --response_type=poly --dim_theta=2 --output_name=lin2-linear 8 | python3 src/run.py --dataset=synthetic --equations=lin2 --output_dir=results/ --response_type=poly --dim_theta=3 --output_name=lin2-quadratic 9 | python3 src/run.py --dataset=synthetic --equations=lin2 --output_dir=results/ --response_type=mlp --dim_theta=7 --output_name=lin2-mlp 10 | 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Niki Kilbertus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of 
this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A class of algorithms for general instrumental variable models 2 | 3 | This is the Python code accompanying the paper 4 | 5 | [A Class of General Instrumental Variable Models](https://arxiv.org/abs/2006.06366)\ 6 | [Niki Kilbertus](https://sites.google.com/view/nikikilbertus), [Matt J. Kusner](https://mkusner.github.io/), [Ricardo Silva](http://www.homepages.ucl.ac.uk/~ucgtrbd/)\ 7 | Neural Information Processing Systems (NeurIPS) 2020 8 | 9 | ## Setup 10 | 11 | First clone this repository and navigate to the main directory 12 | 13 | ```sh 14 | git clone git@github.com:nikikilbertus/general-iv-models.git 15 | cd general-iv-models 16 | ``` 17 | 18 | To run the code, please first create a new Python3 environment (Python version >= 3.6 should work). 19 | For example, if you are using [virtualenvwrapper](https://virtualenvwrapper.readthedocs.io/) run 20 | 21 | ```sh 22 | mktmpenv -n 23 | ``` 24 | 25 | Then install the required packages into your newly created environment via 26 | 27 | ```sh 28 | python -m pip install -r requirements.txt 29 | ``` 30 | 31 | ## Run experiments 32 | 33 | There are three executable scripts in the `general-iv-models` directory to run different subsets of the experiments in the paper: 34 | 35 | * `linear_experiments.sh` 36 | * `non_linear_experiments.sh` 37 | * `sigmoid_design.sh` 38 | 39 | Running any of them will create a directory `general-iv-models/results` where all results (plots) will be stored. 
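Each of these scripts is a thin wrapper that calls `src/run.py` with different flags, so a single configuration can also be launched directly. For example (this command is copied verbatim from `linear_experiments.sh`):

```sh
python3 src/run.py --dataset=synthetic --equations=lin1 --output_dir=results/ --response_type=poly --dim_theta=2 --output_name=lin1-linear
```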
40 | 41 | To run all experiments, simply run 42 | 43 | ```sh 44 | ./linear_experiments.sh 45 | ./non_linear_experiments.sh 46 | ./sigmoid_design.sh 47 | ``` -------------------------------------------------------------------------------- /src/kiv.py: -------------------------------------------------------------------------------- 1 | """Stripped down re-implementation kernel instrumental variable (KIV).""" 2 | 3 | from absl import logging 4 | 5 | import numpy as np 6 | from scipy.optimize import minimize 7 | 8 | from sklearn.model_selection import train_test_split 9 | 10 | from typing import Tuple, Text, Dict 11 | 12 | Kerneldict = Dict[Text, np.ndarray] 13 | 14 | 15 | def median_inter(x: np.ndarray) -> float: 16 | A = np.repeat(x[:, np.newaxis], len(x), -1) 17 | dist = np.abs(A - A.T).ravel() 18 | return float(np.median(dist)) 19 | 20 | 21 | def mse(x: np.ndarray, y: np.ndarray) -> float: 22 | """Mean squared error.""" 23 | return float(np.mean((x - y) ** 2)) 24 | 25 | 26 | def make_psd(k: np.ndarray, eps: float = 1e-10) -> np.ndarray: 27 | """Make matrix positive semi-definite.""" 28 | n = k.shape[0] 29 | return (k + k.T) / 2 + eps * np.eye(n) 30 | 31 | 32 | def get_k(x: np.ndarray, 33 | y: np.ndarray, 34 | z: np.ndarray, 35 | x_vis: np.ndarray) -> Kerneldict: 36 | """Setup all required matrices from input data.""" 37 | vx = median_inter(x) 38 | vz = median_inter(z) 39 | 40 | x1, x2, y1, y2, z1, z2 = train_test_split( 41 | x, y, z, shuffle=False, test_size=0.5) 42 | 43 | results = { 44 | 'y1': y1, 45 | 'y2': y2, 46 | 'y': y, 47 | 'K_XX': get_k_matrix(x1, x1, vx), 48 | 'K_xx': get_k_matrix(x2, x2, vx), 49 | 'K_xX': get_k_matrix(x2, x1, vx), 50 | 'K_Xtest': get_k_matrix(x1, x_vis, vx), 51 | 'K_ZZ': get_k_matrix(z1, z1, vz), 52 | 'K_Zz': get_k_matrix(z1, z2, vz), 53 | } 54 | return results 55 | 56 | 57 | def get_k_matrix(x1: np.ndarray, x2: np.ndarray, v: float) -> np.ndarray: 58 | """Construct rbf kernel matrix with parameter v.""" 59 | m = len(x1) 60 | n = len(x2) 61 | x1 = np.repeat(x1[:, np.newaxis], n, 1) 62 | x2 = np.repeat(x2[:, np.newaxis].T, m, 0) 63 | return np.exp(- ((x1 - x2) ** 2) / (2. * v ** 2)) 64 | 65 | 66 | def kiv1_loss(df: Kerneldict, lam: float) -> float: 67 | """Loss for tuning hyperparameter lambda.""" 68 | n = len(df["y1"]) 69 | m = len(df["y2"]) 70 | 71 | brac = make_psd(df["K_ZZ"]) + lam * np.eye(n) 72 | gamma = np.linalg.solve(brac, df["K_Zz"]) 73 | return np.trace(df["K_xx"] - 2. 
* df["K_xX"] @ gamma + 74 | gamma.T @ df["K_XX"] @ gamma) / m 75 | 76 | 77 | def kiv2_loss(df: Kerneldict, lam: float, xi: float) -> float: 78 | """Loss for tuning hyperparameter xi.""" 79 | y1_pred = kiv_pred(df, lam, xi, 2) 80 | return mse(df["y1"], y1_pred) 81 | 82 | 83 | def kiv_pred(df: Kerneldict, lam: float, xi: float, stage: int) -> np.ndarray: 84 | """Kernel instrumental variable prediction.""" 85 | n = len(df["y1"]) 86 | m = len(df["y2"]) 87 | 88 | brac = make_psd(df["K_ZZ"]) + lam * np.eye(n) 89 | W = np.linalg.solve(brac, df["K_XX"]).T @ df["K_Zz"] 90 | brac2 = make_psd(W @ W.T) + m * xi * make_psd(df["K_XX"]) 91 | alpha = np.linalg.solve(brac2, W @ df["y2"]) 92 | 93 | if stage == 2: 94 | k_xtest = df["K_XX"] 95 | elif stage == 3: 96 | k_xtest = df["K_Xtest"] 97 | else: 98 | raise ValueError(f"Stage must be 2 or 3, not {stage}") 99 | return (alpha.T @ k_xtest).T 100 | 101 | 102 | def fit_kiv(z: np.ndarray, 103 | x: np.ndarray, 104 | y: np.ndarray, 105 | num_xstar: int = 500, 106 | lambda_guess: float = None, 107 | xi_guess: float = None, 108 | fix_hyper: bool = False) -> Tuple[np.ndarray, np.ndarray]: 109 | """Fit kernel instrumental variable regression. 110 | 111 | Args: 112 | z: Instrument 113 | x: Treatment 114 | y: Outcome 115 | num_xstar: Number of points to put on the x grid 116 | lambda_guess: Guess for lambda. Either starting point for optimization or 117 | fixed if `fix_hyper` is True. 118 | xi_guess: Guess for xi. Either starting point for optimization or fixed 119 | if `fix_hyper` is True. 120 | fix_hyper: Whether to use fixed hyperparameters instead of optimizing. 121 | 122 | Returns: 123 | xstar: linear grid over range of provided x values 124 | ystar: predicted treatment effect evaluated at x_star 125 | """ 126 | xstar = np.linspace(np.min(x), np.max(x), num_xstar) 127 | logging.info("Setup matrices...") 128 | df = get_k(x, y, z, xstar) 129 | 130 | if not fix_hyper: 131 | lambda_0 = lambda_guess or np.log(0.05) 132 | 133 | def kiv1_obj(lam: float): 134 | return kiv1_loss(df, np.exp(lam)) 135 | 136 | logging.info("Optimize lambda...") 137 | print("Optimize lambda...") 138 | res = minimize(kiv1_obj, x0=lambda_0, method='BFGS') 139 | if not res.success: 140 | logging.info("KIV1 minimization did not succeed.") 141 | lambda_star = res.x 142 | logging.info(f"Optimal lambda {lambda_star}...") 143 | print(f"Optimal lambda {lambda_star}...") 144 | else: 145 | if lambda_guess is None: 146 | raise ValueError("lambda_guess required for fixed hyperparams.") 147 | lambda_star = lambda_guess 148 | logging.info(f"Fixed lambda {lambda_star}...") 149 | print(f"Fixed lambda {lambda_star}...") 150 | 151 | if not fix_hyper: 152 | xi_0 = xi_guess or np.log(0.05) 153 | 154 | def kiv2_obj(xi: float): 155 | return kiv2_loss(df, np.exp(lambda_star), np.exp(xi)) 156 | 157 | logging.info("Optimize xi...") 158 | print("Optimize xi...") 159 | # res = minimize_scalar(kiv2_obj) 160 | res = minimize(kiv2_obj, x0=xi_0, method='BFGS') 161 | if not res.success: 162 | logging.info("KIV2 minimization did not succeed.") 163 | xi_star = res.x 164 | logging.info(f"Optimal xi {xi_star}...") 165 | print(f"Optimal xi {xi_star}...") 166 | else: 167 | if xi_guess is None: 168 | raise ValueError("xi_guess required for fixed hyperparams.") 169 | xi_star = xi_guess 170 | logging.info(f"Fixed xi {xi_star}...") 171 | print(f"Fixed xi {xi_star}...") 172 | 173 | logging.info("Predict treatment effect...") 174 | ystar = kiv_pred(df, np.exp(lambda_star), np.exp(xi_star), 3) 175 | return xstar, ystar 176 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,macos,linux,vim,pycharm+all 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,macos,linux,vim,pycharm+all 4 | 5 | ### Linux ### 6 | *~ 7 | 8 | # temporary files which can be created if a process still has a handle open of a deleted file 9 | .fuse_hidden* 10 | 11 | # KDE directory preferences 12 | .directory 13 | 14 | # Linux trash folder which might appear on any partition or disk 15 | .Trash-* 16 | 17 | # .nfs files are created when an open file is removed but is still being accessed 18 | .nfs* 19 | 20 | ### macOS ### 21 | # General 22 | .DS_Store 23 | .AppleDouble 24 | .LSOverride 25 | 26 | # Icon must end with two \r 27 | Icon 28 | 29 | # Thumbnails 30 | ._* 31 | 32 | # Files that might appear in the root of a volume 33 | .DocumentRevisions-V100 34 | .fseventsd 35 | .Spotlight-V100 36 | .TemporaryItems 37 | .Trashes 38 | .VolumeIcon.icns 39 | .com.apple.timemachine.donotpresent 40 | 41 | # Directories potentially created on remote AFP share 42 | .AppleDB 43 | .AppleDesktop 44 | Network Trash Folder 45 | Temporary Items 46 | .apdisk 47 | 48 | ### PyCharm+all ### 49 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 50 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 51 | 52 | # User-specific stuff 53 | .idea/**/workspace.xml 54 | .idea/**/tasks.xml 55 | .idea/**/usage.statistics.xml 56 | .idea/**/dictionaries 57 | .idea/**/shelf 58 | 59 | # Generated files 60 | .idea/**/contentModel.xml 61 | 62 | # Sensitive or high-churn files 63 | .idea/**/dataSources/ 64 | .idea/**/dataSources.ids 65 | .idea/**/dataSources.local.xml 66 | .idea/**/sqlDataSources.xml 67 | .idea/**/dynamic.xml 68 | .idea/**/uiDesigner.xml 69 | .idea/**/dbnavigator.xml 70 | 71 | # Gradle 72 | .idea/**/gradle.xml 73 | .idea/**/libraries 74 | 75 | # Gradle and Maven with auto-import 76 | # When using Gradle or Maven with auto-import, you should exclude module files, 77 | # since they will be recreated, and may cause churn. Uncomment if using 78 | # auto-import. 
79 | # .idea/artifacts 80 | # .idea/compiler.xml 81 | # .idea/jarRepositories.xml 82 | # .idea/modules.xml 83 | # .idea/*.iml 84 | # .idea/modules 85 | # *.iml 86 | # *.ipr 87 | 88 | # CMake 89 | cmake-build-*/ 90 | 91 | # Mongo Explorer plugin 92 | .idea/**/mongoSettings.xml 93 | 94 | # File-based project format 95 | *.iws 96 | 97 | # IntelliJ 98 | out/ 99 | 100 | # mpeltonen/sbt-idea plugin 101 | .idea_modules/ 102 | 103 | # JIRA plugin 104 | atlassian-ide-plugin.xml 105 | 106 | # Cursive Clojure plugin 107 | .idea/replstate.xml 108 | 109 | # Crashlytics plugin (for Android Studio and IntelliJ) 110 | com_crashlytics_export_strings.xml 111 | crashlytics.properties 112 | crashlytics-build.properties 113 | fabric.properties 114 | 115 | # Editor-based Rest Client 116 | .idea/httpRequests 117 | 118 | # Android studio 3.1+ serialized cache file 119 | .idea/caches/build_file_checksums.ser 120 | 121 | ### PyCharm+all Patch ### 122 | # Ignores the whole .idea folder and all .iml files 123 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 124 | 125 | .idea/ 126 | 127 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 128 | 129 | *.iml 130 | modules.xml 131 | .idea/misc.xml 132 | *.ipr 133 | 134 | # Sonarlint plugin 135 | .idea/sonarlint 136 | 137 | ### Python ### 138 | # Byte-compiled / optimized / DLL files 139 | __pycache__/ 140 | *.py[cod] 141 | *$py.class 142 | 143 | # C extensions 144 | *.so 145 | 146 | # Distribution / packaging 147 | .Python 148 | build/ 149 | develop-eggs/ 150 | dist/ 151 | downloads/ 152 | eggs/ 153 | .eggs/ 154 | lib/ 155 | lib64/ 156 | parts/ 157 | sdist/ 158 | var/ 159 | wheels/ 160 | pip-wheel-metadata/ 161 | share/python-wheels/ 162 | *.egg-info/ 163 | .installed.cfg 164 | *.egg 165 | MANIFEST 166 | 167 | # PyInstaller 168 | # Usually these files are written by a python script from a template 169 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 170 | *.manifest 171 | *.spec 172 | 173 | # Installer logs 174 | pip-log.txt 175 | pip-delete-this-directory.txt 176 | 177 | # Unit test / coverage reports 178 | htmlcov/ 179 | .tox/ 180 | .nox/ 181 | .coverage 182 | .coverage.* 183 | .cache 184 | nosetests.xml 185 | coverage.xml 186 | *.cover 187 | *.py,cover 188 | .hypothesis/ 189 | .pytest_cache/ 190 | pytestdebug.log 191 | 192 | # Translations 193 | *.mo 194 | *.pot 195 | 196 | # Django stuff: 197 | *.log 198 | local_settings.py 199 | db.sqlite3 200 | db.sqlite3-journal 201 | 202 | # Flask stuff: 203 | instance/ 204 | .webassets-cache 205 | 206 | # Scrapy stuff: 207 | .scrapy 208 | 209 | # Sphinx documentation 210 | docs/_build/ 211 | doc/_build/ 212 | 213 | # PyBuilder 214 | target/ 215 | 216 | # Jupyter Notebook 217 | .ipynb_checkpoints 218 | 219 | # IPython 220 | profile_default/ 221 | ipython_config.py 222 | 223 | # pyenv 224 | .python-version 225 | 226 | # pipenv 227 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 228 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 229 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 230 | # install all needed dependencies. 231 | #Pipfile.lock 232 | 233 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 234 | __pypackages__/ 235 | 236 | # Celery stuff 237 | celerybeat-schedule 238 | celerybeat.pid 239 | 240 | # SageMath parsed files 241 | *.sage.py 242 | 243 | # Environments 244 | .env 245 | .venv 246 | env/ 247 | venv/ 248 | ENV/ 249 | env.bak/ 250 | venv.bak/ 251 | pythonenv* 252 | 253 | # Spyder project settings 254 | .spyderproject 255 | .spyproject 256 | 257 | # Rope project settings 258 | .ropeproject 259 | 260 | # mkdocs documentation 261 | /site 262 | 263 | # mypy 264 | .mypy_cache/ 265 | .dmypy.json 266 | dmypy.json 267 | 268 | # Pyre type checker 269 | .pyre/ 270 | 271 | # pytype static type analyzer 272 | .pytype/ 273 | 274 | # profiling data 275 | .prof 276 | 277 | ### Vim ### 278 | # Swap 279 | [._]*.s[a-v][a-z] 280 | !*.svg # comment out if you don't need vector files 281 | [._]*.sw[a-p] 282 | [._]s[a-rt-v][a-z] 283 | [._]ss[a-gi-z] 284 | [._]sw[a-p] 285 | 286 | # Session 287 | Session.vim 288 | Sessionx.vim 289 | 290 | # Temporary 291 | .netrwhist 292 | # Auto-generated tag files 293 | tags 294 | # Persistent undo 295 | [._]*.un~ 296 | 297 | # End of https://www.toptal.com/developers/gitignore/api/python,macos,linux,vim,pycharm+all 298 | 299 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | """Data loading and pre-processing utilities.""" 2 | 3 | from typing import Tuple, Callable, Sequence, Text, Dict, Union 4 | 5 | import os 6 | 7 | from absl import logging 8 | 9 | import jax.numpy as np 10 | from jax import random 11 | 12 | import numpy as onp 13 | import pandas as pd 14 | from scipy.stats import norm 15 | 16 | import utils 17 | 18 | 19 | DataSynth = Tuple[Dict[Text, Union[np.ndarray, float, None]], 20 | np.ndarray, np.ndarray] 21 | DataReal = Dict[Text, Union[np.ndarray, float, None]] 22 | ArrayTup = Tuple[np.ndarray, np.ndarray] 23 | 24 | Equations = Dict[Text, Callable[..., np.ndarray]] 25 | 26 | 27 | # ============================================================================= 28 | # NOISE SOURCES 29 | # ============================================================================= 30 | 31 | 32 | def std_normal_1d(key: np.ndarray, num: int) -> np.ndarray: 33 | """Generate a Gaussian for the confounder.""" 34 | return random.normal(key, (num,)) 35 | 36 | 37 | def std_normal_2d(key: np.ndarray, num: int) -> ArrayTup: 38 | """Generate a multivariate Gaussian for the noises e_X, e_Y.""" 39 | key1, key2 = random.split(key) 40 | return random.normal(key1, (num,)), random.normal(key2, (num,)) 41 | 42 | 43 | # ============================================================================= 44 | # SYNTHETIC STRUCTURAL EQUATIONS 45 | # ============================================================================= 46 | 47 | 48 | structural_equations = { 49 | "lin1": { 50 | "noise": std_normal_2d, 51 | "confounder": std_normal_1d, 52 | "f_z": std_normal_1d, 53 | "f_x": lambda z, c, ex: 0.5 * z + 3 * c + ex, 54 | "f_y": lambda x, c, ey: x - 6 * c + ey, 55 | }, 56 | "lin2": { 57 | "noise": std_normal_2d, 58 | "confounder": std_normal_1d, 59 | "f_z": std_normal_1d, 60 | "f_x": lambda z, c, ex: 3.0 * z + 0.5 * c + ex, 61 | "f_y": lambda x, c, ey: x - 6 * c + ey, 62 | }, 63 | "quad1": { 64 | "noise": std_normal_2d, 65 | "confounder": std_normal_1d, 66 | "f_z": std_normal_1d, 67 | "f_x": lambda z, c, ex: 0.5 * z + 3 * c + ex, 68 | "f_y": lambda x, c, ey: 0.3 * x ** 2 - 1.5 * x * c + ey, 69 | }, 70 | "quad2": { 71 | "noise": 
std_normal_2d, 72 | "confounder": std_normal_1d, 73 | "f_z": std_normal_1d, 74 | "f_x": lambda z, c, ex: 3.0 * z + 0.5 * c + ex, 75 | "f_y": lambda x, c, ey: 0.3 * x ** 2 - 1.5 * x * c + ey, 76 | }, 77 | } 78 | 79 | 80 | # ============================================================================= 81 | # DATA GENERATORS 82 | # ============================================================================= 83 | 84 | def whiten( 85 | inputs: Dict[Text, np.ndarray] 86 | ) -> Dict[Text, Union[float, np.ndarray, None]]: 87 | """Whiten each input.""" 88 | res = {} 89 | for k, v in inputs.items(): 90 | if v is not None: 91 | mu = np.mean(v, 0) 92 | std = np.maximum(np.std(v, 0), 1e-7) 93 | res[k + "_mu"] = mu 94 | res[k + "_std"] = std 95 | res[k] = (v - mu) / std 96 | else: 97 | res[k] = v 98 | return res 99 | 100 | 101 | def whiten_with_mu_std(val: np.ndarray, mu: float, std: float) -> np.ndarray: 102 | return (val - mu) / std 103 | 104 | 105 | def get_synth_data( 106 | key: np.ndarray, 107 | num: int, 108 | equations: Text, 109 | num_xstar: int = 100, 110 | external_equations: Equations = None, 111 | disconnect_instrument: bool = False 112 | ) -> DataSynth: 113 | """Generate some synthetic data. 114 | 115 | Args: 116 | key: A JAX random key. 117 | num: The number of examples to generate. 118 | equations: Which structural equations to choose for x and y. Default: 1 119 | num_xstar: Size of grid for interventions on x. 120 | external_equations: A dictionary that must contain the keys 'f_x' and 121 | 'f_y' mapping to callables as values that take two np.ndarrays as 122 | arguments and produce another np.ndarray. These are the structural 123 | equations for X and Y in the graph Z -> X -> Y. 124 | If this argument is not provided, the `equation` argument selects 125 | structural equations from the pre-defined dict `structural_equations`. 126 | disconnect_instrument: Whether to regenerate random (standard Gaussian) 127 | values for the instrument after the data has been generated. This 128 | serves for diagnostic purposes, i.e., looking at the same x, y data, 129 | 130 | Returns: 131 | A 3-tuple (values, xstar, ystar) consisting a dictionary `values` 132 | containing values for x, y, z, confounder, ex, ey as well as two 133 | array xstar, ystar containing values for the true cause-effect. 
134 | """ 135 | if external_equations is not None: 136 | eqs = external_equations 137 | elif equations == "np": 138 | return get_newey_powell(key, num, num_xstar) 139 | else: 140 | eqs = structural_equations[equations] 141 | 142 | key, subkey = random.split(key) 143 | ex, ey = eqs["noise"](subkey, num) 144 | key, subkey = random.split(key) 145 | confounder = eqs["confounder"](subkey, num) 146 | key, subkey = random.split(key) 147 | z = eqs["f_z"](subkey, num) 148 | x = eqs["f_x"](z, confounder, ex) 149 | y = eqs["f_y"](x, confounder, ey) 150 | 151 | values = whiten({'x': x, 'y': y, 'z': z, 'confounder': confounder, 152 | 'ex': ex, 'ey': ey}) 153 | 154 | # Evaluate E[ Y | do(x^*)] empirically 155 | xmin, xmax = np.min(x), np.max(x) 156 | xstar = np.linspace(xmin, xmax, num_xstar) 157 | ystar = [] 158 | for _ in range(500): 159 | key, subkey = random.split(key) 160 | tmpey = eqs["noise"](subkey, num_xstar)[1] 161 | key, subkey = random.split(key) 162 | tmpconf = eqs["confounder"](subkey, num_xstar) 163 | tmp_ystar = whiten_with_mu_std( 164 | eqs["f_y"](xstar, tmpconf, tmpey), values["y_mu"], values["y_std"]) 165 | ystar.append(tmp_ystar) 166 | ystar = np.array(ystar) 167 | xstar = whiten_with_mu_std(xstar, values["x_mu"], values["x_std"]) 168 | if disconnect_instrument: 169 | key, subkey = random.split(key) 170 | values['z'] = random.normal(subkey, shape=z.shape) 171 | return values, xstar, ystar 172 | 173 | 174 | def get_colonial_origins(data_dir: Text = "../data") -> DataReal: 175 | """Load data from colonial origins paper of Acemoglu.""" 176 | stata_path = os.path.join(data_dir, "colonial_origins", "data.dta") 177 | df = pd.read_stata(stata_path) 178 | ycol = 'logpgp95' 179 | zcol = 'logem4' 180 | xcol = 'avexpr' 181 | df = df[[zcol, xcol, ycol]].dropna() 182 | z, x, y = df[zcol].values, df[xcol].values, df[ycol].values 183 | data = {'x': x, 'y': y, 'z': z, 'confounder': None, 'ex': None, 'ey': None} 184 | return whiten(data) 185 | 186 | 187 | def get_newey_powell(key: np.ndarray, 188 | num: int, 189 | num_xstar: int = 100) -> DataSynth: 190 | """Get simulated Newey Powell (sigmoid design) data from KIV paper.""" 191 | def np_true(vals: np.ndarray): 192 | return np.log(np.abs(16. 
* vals - 8) + 1) * np.sign(vals - 0.5) 193 | xstar = np.linspace(0, 1, num_xstar) 194 | ystar = np_true(xstar) 195 | 196 | mu = np.zeros(3) 197 | sigma = np.array([[1., 0.5, 0.], [0.5, 1., 0.], [0., 0., 1.]]) 198 | r = random.multivariate_normal(key, mu, sigma, shape=(num,)) 199 | u, t, w = r[:, 0], r[:, 1], r[:, 2] 200 | x = w + t 201 | x = norm.cdf(x / np.sqrt(2.)) 202 | z = norm.cdf(w) 203 | e = u 204 | y = np_true(x) + e 205 | values = whiten({'x': x, 'y': y, 'z': z, 'ex': e, 'ey': e}) 206 | xstar = whiten_with_mu_std(xstar, values['x_mu'], values['x_std']) 207 | ystar = whiten_with_mu_std(ystar, values['y_mu'], values['y_std']) 208 | values['confounder'] = None 209 | return values, xstar, ystar 210 | 211 | 212 | # ============================================================================= 213 | # DISCRETIZATION AND CDF HANDLING 214 | # ============================================================================= 215 | 216 | 217 | def ecdf(vals: np.ndarray, num_points: int = None) -> ArrayTup: 218 | """Evaluate the empirical distribution function on fixed number of points.""" 219 | if num_points is None: 220 | num_points = len(vals) 221 | cdf = np.linspace(0, 1, num_points) 222 | t = np.quantile(vals, cdf) 223 | return t, cdf 224 | 225 | 226 | def cdf_inv(vals: np.ndarray, 227 | num_points: int = None) -> Callable[..., np.ndarray]: 228 | """Compute an interpolation function of the (empirical) inverse cdf.""" 229 | t, cdf = ecdf(vals, num_points) 230 | return lambda x: onp.interp(x, cdf, t) 231 | 232 | 233 | def get_cdf_invs(val: np.ndarray, 234 | bin_ids: np.ndarray, 235 | num_z: int) -> Sequence[Callable[..., np.ndarray]]: 236 | """Compute a list of interpolated inverse CDFs of val at each z in Z grid.""" 237 | cdf_invs = [] 238 | for i in range(num_z): 239 | cdf_invs.append(cdf_inv(val[bin_ids == i])) 240 | return cdf_invs 241 | 242 | 243 | def get_z_bin_assigment(z: np.ndarray, z_grid: np.ndarray) -> np.ndarray: 244 | """Assignment of values in z to the respective bin in z_grid.""" 245 | bins = np.concatenate((np.array([-np.inf]), 246 | z_grid[1:-1], 247 | np.array([np.inf]))) 248 | hist = onp.digitize(z, bins=bins, right=True) - 1 249 | return hist 250 | 251 | 252 | def get_x_samples(x: np.ndarray, 253 | bin_ids: np.ndarray, 254 | num_z: int, 255 | num_sample: int) -> ArrayTup: 256 | """Pre-compute samples from p(x | z^{(i)}) for each gridpoint zi.""" 257 | x_cdf_invs = get_cdf_invs(x, bin_ids, num_z) 258 | tmp = np.linspace(0, 1, num_sample + 2)[1:-1] 259 | tmp0 = utils.normal_cdf_inv(tmp, np.array([0]), np.array([0])) 260 | return tmp0, np.array([x_cdf_inv(tmp) for x_cdf_inv in x_cdf_invs]) 261 | 262 | 263 | def get_y_pre(y: np.ndarray, 264 | bin_ids: np.ndarray, 265 | num_z: int, 266 | num_points: int) -> np.ndarray: 267 | """Compute the grid of y points for constraint approach y.""" 268 | y_cdf_invs = get_cdf_invs(y, bin_ids, num_z) 269 | grid = np.linspace(0, 1, num_points + 2)[1:-1] 270 | return np.array([y_cdf_inv(grid) for y_cdf_inv in y_cdf_invs]) 271 | 272 | 273 | def make_zgrid_and_binids(z: np.ndarray, num_z: int) -> ArrayTup: 274 | """Discretize instrument Z and assign all z points to corresponding bins.""" 275 | if num_z <= 0: 276 | logging.info("Discrete instrument specified, checking for values.") 277 | z_grid = np.sort(onp.unique(z)) 278 | if len(z_grid) > 50: 279 | logging.info("Found more than 50 unique values for z. This is not a " 280 | "discrete instrument. 
Aborting!") 281 | raise RuntimeError("Discrete instrument specified, but not found.") 282 | logging.info(f"Found {len(z_grid)} unique values for discrete instrument.") 283 | bin_ids = - onp.ones_like(z) 284 | for i, zpoint in enumerate(z_grid): 285 | bin_ids[z == zpoint] = i 286 | if onp.any(bin_ids < 0): 287 | raise ValueError(f"Found negative value in bin_ids. " 288 | "Couldn't discretize instrument.") 289 | bin_ids = np.array(bin_ids).astype(int) 290 | else: 291 | z_grid = ecdf(z, num_z + 1)[0] 292 | bin_ids = get_z_bin_assigment(z, z_grid) 293 | z_grid = (z_grid[:-1] + z_grid[1:]) / 2 294 | return z_grid, bin_ids 295 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """Helper functions.""" 2 | 3 | import itertools 4 | from typing import Iterable 5 | 6 | from absl import logging 7 | import jax.numpy as np 8 | import jax.scipy as sp 9 | import numpy as onp 10 | from jax import jit, vmap, grad 11 | from jax.experimental import stax, optimizers 12 | from jax.nn import softmax 13 | from jax.scipy.special import logsumexp 14 | from sklearn.gaussian_process import GaussianProcessRegressor 15 | from sklearn.gaussian_process.kernels import RBF, PairwiseKernel 16 | from statsmodels.api import add_constant 17 | from statsmodels.sandbox.regression.gmm import IV2SLS 18 | 19 | 20 | # ============================================================================= 21 | # CDF and inverse CDF of Gaussians 22 | # ============================================================================= 23 | 24 | 25 | @jit 26 | def std_normal_cdf(x: np.ndarray) -> np.ndarray: 27 | """CDF of the standard normal.""" 28 | return 0.5 * (1. + sp.special.erf(x / np.sqrt(2.))) 29 | 30 | 31 | @jit 32 | def normal_cdf_inv(x: np.ndarray, 33 | mu: np.ndarray, 34 | log_sigma: np.ndarray) -> np.ndarray: 35 | """Inverse CDF of a Gaussian with given mean and log standard deviation.""" 36 | num = x.shape[-1] 37 | sigma = np.repeat(np.exp(log_sigma)[:, None], num, axis=-1) 38 | mu = np.repeat(mu[:, None], num, axis=-1) 39 | xx = np.clip(2 * x - 1, -0.999999, 0.999999) 40 | return np.sqrt(2.) * sigma * sp.special.erfinv(xx) + mu 41 | 42 | 43 | # ============================================================================= 44 | # Two stage least squares (2SLS) 45 | # ============================================================================= 46 | 47 | 48 | def two_stage_least_squares(z: np.ndarray, 49 | x: np.ndarray, 50 | y: np.ndarray) -> np.ndarray: 51 | """Fit 2sls model to data. 52 | 53 | Args: 54 | z: Instrument 55 | x: Treatment 56 | y: Outcome 57 | 58 | Returns: 59 | coeff: The coefficients of the estimated linear cause-effect relation. 60 | """ 61 | x = add_constant(onp.array(x)) 62 | z = add_constant(onp.array(z)) 63 | y = onp.array(y) 64 | iv2sls = IV2SLS(y, x, z).fit() 65 | logging.info(iv2sls.summary()) 66 | return np.array(iv2sls.params) 67 | 68 | 69 | # ============================================================================= 70 | # Basis functions: neural basis functions and GPs 71 | # ============================================================================= 72 | 73 | 74 | def interp_regular_1d(x: np.ndarray, 75 | xmin: float, 76 | xmax: float, 77 | yp: np.ndarray) -> np.ndarray: 78 | """One-dimensional linear interpolation. 79 | 80 | Returns the one-dimensional piecewise linear interpolation of the data points 81 | (xp, yp) evaluated at x. 
We extrapolate with the constants xmin and xmax 82 | outside the range [xmin, xmax]. 83 | 84 | Args: 85 | x: The x-coordinates at which to evaluate the interpolated values. 86 | xmin: The lower bound of the regular input x-coordinate grid. 87 | xmax: The upper bound of the regular input x-coordinate grid. 88 | yp: The y coordinates of the data points. 89 | 90 | Returns: 91 | y: The interpolated values, same shape as x. 92 | """ 93 | ny = len(yp) 94 | fractional_idx = (x - xmin) / (xmax - xmin) 95 | x_idx_unclipped = fractional_idx * (ny - 1) 96 | x_idx = np.clip(x_idx_unclipped, 0, ny - 1) 97 | idx_below = np.floor(x_idx) 98 | idx_above = np.minimum(idx_below + 1, ny - 1) 99 | idx_below = np.maximum(idx_above - 1, 0) 100 | y_ref_below = yp[idx_below.astype(np.int32)] 101 | y_ref_above = yp[idx_above.astype(np.int32)] 102 | t = x_idx - idx_below 103 | y = t * y_ref_above + (1 - t) * y_ref_below 104 | return y 105 | 106 | 107 | interp1d = jit(vmap(interp_regular_1d, in_axes=(None, None, None, 0))) 108 | 109 | 110 | def get_gp_prediction(x: np.ndarray, 111 | y: np.ndarray, 112 | n_samples: int, 113 | n_points: int = 100): 114 | """Fit a GP to observed P( Y | X ) and sample some functions from it. 115 | 116 | Args: 117 | x: x-values (features) 118 | y: y-values (targets/labels) 119 | n_samples: The number of GP samples to use as basis functions. 120 | n_points: The number of points to subsample form x and y to fit each GP. 121 | 122 | Returns: 123 | a function that takes as input an array either of shape (n,) or (k, n) 124 | and outputs: 125 | if input is 1D -> output: (n, n_samples) 126 | if input is 2D -> output: (k, n, n_samples) 127 | """ 128 | kernel = PairwiseKernel(metric='poly') + RBF() 129 | gp = GaussianProcessRegressor(kernel=kernel, 130 | alpha=0.4, 131 | n_restarts_optimizer=0, 132 | normalize_y=True) 133 | xmin = np.min(x) 134 | xmax = np.max(x) 135 | xx = np.linspace(xmin, xmax, n_points) 136 | y_samples = [] 137 | rng = onp.random.RandomState(0) 138 | for i in range(n_samples): 139 | logging.info("Subsample 200 points and fit GP to P(Y|X)...") 140 | idx = rng.choice(len(x), 200, replace=False) 141 | gp.fit(x[idx, np.newaxis], y[idx]) 142 | logging.info(f"Get a sample functions from the GP") 143 | y_samples.append(gp.sample_y(xx[:, np.newaxis], 1)) 144 | y_samples = np.array(y_samples).squeeze() 145 | logging.info(f"Shape of samples: {y_samples.shape}") 146 | 147 | def predict(inputs: np.ndarray) -> np.ndarray: 148 | return interp1d(inputs, xmin, xmax, y_samples).T 149 | return jit(vmap(predict)) 150 | 151 | 152 | def fit_mlp(key: np.ndarray, 153 | x: np.ndarray, 154 | y: np.ndarray, 155 | n_samples: np.ndarray, 156 | epochs: int = 100, 157 | batch_size: int = 256, 158 | learning_rate: float = 0.001, 159 | layers: Iterable[int] = (64, 64), 160 | return_basis=True): 161 | """Fit a small MLP to data. 162 | 163 | Args: 164 | key: Key for randomness. 165 | x: x-values (features) 166 | y: y-values (targets/labels) 167 | n_samples: The number of neurons in the additional last hidden layer which 168 | is used as basis functions. 169 | epochs: Number of epochs to train for. 170 | batch_size: batch size in MLP training 171 | learning_rate: Initial learning rate for Adam optimizer. 172 | layers: The hidden layer sizes. To this list, two dense layers of size 173 | n_samples and 1 are added. 174 | return_basis: Whether to return a function that returns the activations of 175 | the last layer instead of simply the full MLP model itself. 
176 | 177 | Returns: 178 | a function that takes as input an array either of shape (k, n) 179 | and outputs and array of shape (k, n, n_samples) 180 | """ 181 | logging.info(f"Fit small mlp with {layers} neurons, batchsize {batch_size} " 182 | f"for {epochs} epochs with Adam and lr={learning_rate}") 183 | seq = [] 184 | for i in layers: 185 | seq.append(stax.Dense(i)) 186 | seq.append(stax.Relu) 187 | # the final hidden layer used as basis functions 188 | seq.append(stax.Dense(n_samples)) 189 | seq.append(stax.Relu) 190 | seq.append(stax.Dense(1)) 191 | init_fun, mlp = stax.serial(*seq) 192 | opt_init, opt_update, get_params = optimizers.adam(learning_rate) 193 | _, init_params = init_fun(key, (-1, x.shape[-1])) 194 | opt_state = opt_init(init_params) 195 | 196 | n_train = x.shape[0] 197 | num_complete_batches, leftover = divmod(n_train, batch_size) 198 | num_batches = num_complete_batches + bool(leftover) 199 | 200 | def data_stream(): 201 | rng = onp.random.RandomState(0) 202 | while True: 203 | perm = rng.permutation(n_train) 204 | for j in range(num_batches): 205 | batch_idx = perm[j * batch_size:(j + 1) * batch_size] 206 | yield x[batch_idx], y[batch_idx] 207 | batches = data_stream() 208 | itercount = itertools.count() 209 | 210 | @jit 211 | def loss(_params, _batch): 212 | inputs, targets = _batch 213 | preds = mlp(_params, inputs).squeeze() 214 | return np.mean((preds - targets) ** 2) 215 | 216 | @jit 217 | def update(_i, _opt_state, _batch): 218 | _params = get_params(_opt_state) 219 | return opt_update(_i, grad(loss)(_params, _batch), _opt_state) 220 | 221 | for epoch in range(epochs): 222 | for _ in range(num_batches): 223 | opt_state = update(next(itercount), opt_state, next(batches)) 224 | params = get_params(opt_state) 225 | train_loss = loss(params, (x, y)) 226 | logging.info(f"Epoch: {epoch + 1} / {epochs}") 227 | logging.info(f" Loss: {train_loss}") 228 | 229 | # Copy of the original MLP without the last layer 230 | if return_basis: 231 | _, out = stax.serial(*seq[:-1]) 232 | 233 | def predict(inputs: np.ndarray): 234 | return out(params[:-1], inputs) 235 | else: 236 | def predict(inputs: np.ndarray): 237 | return mlp(params, inputs) 238 | 239 | return jit(vmap(predict)) 240 | 241 | 242 | # ============================================================================= 243 | # Train mixture density network 244 | # ============================================================================= 245 | def fit_mdn(key: np.ndarray, 246 | x: np.ndarray, 247 | y: np.ndarray, 248 | n_hidden: int = 20, 249 | n_mixture: int = 5, 250 | learning_rate: float = 0.001, 251 | batch_size: int = 256, 252 | n_epochs: int = 300): 253 | """Fit a mixture density network to the data. 254 | 255 | Args: 256 | x: Input values. 257 | y: Output values. 258 | n_hidden: Number of hidden neurons in the (only) hidden layer. 259 | n_mixture: Number of mixture components. 260 | learning_rate: The (fixed) learning rate for fitting with SGD. 261 | batch_size: Training batch size. 262 | n_epochs: Number of epochs to train for. 
263 | 264 | Returns: 265 | the fitted predictor (as a callable function: x -> y) 266 | """ 267 | log_sqrt_2pi = onp.log(onp.sqrt(2.0 * onp.pi)) 268 | 269 | init_fun, network = stax.serial(stax.Dense(n_hidden), 270 | stax.Tanh, 271 | stax.Dense(n_mixture * 3)) 272 | 273 | _, params = init_fun(key, (batch_size, x.shape[1])) 274 | 275 | opt_init, opt_update, get_params = optimizers.adam(learning_rate) 276 | opt_state = opt_init(params) 277 | 278 | def lognormal(_y, mean, logstd): 279 | return - 0.5 * ((_y - mean) / np.exp(logstd)) ** 2 - logstd - log_sqrt_2pi 280 | 281 | def get_mdn_coef(output): 282 | """Extract MDN coefficients.""" 283 | logmix, mean, logstd = output.split(3, axis=1) 284 | logmix = logmix - logsumexp(logmix, 1, keepdims=True) 285 | return logmix, mean, logstd 286 | 287 | def mdn_loss_func(logmix, mean, logstd, _y): 288 | """MDN loss function.""" 289 | v = logmix + lognormal(_y, mean, logstd) 290 | v = logsumexp(v, axis=1) 291 | return - np.mean(v) 292 | 293 | def loss_fn(_params, batch): 294 | """ MDN Loss function for training loop. """ 295 | inputs, targets = batch 296 | outputs = network(_params, inputs) 297 | logmix, mean, logstd = get_mdn_coef(outputs) 298 | return mdn_loss_func(logmix, mean, logstd, targets) 299 | 300 | @jit 301 | def update(step, _opt_state, batch): 302 | """ Compute the gradient for a batch and update the parameters.""" 303 | _params = get_params(_opt_state) 304 | grads = grad(loss_fn)(_params, batch) 305 | _opt_state = opt_update(step, grads, _opt_state) 306 | return _opt_state 307 | 308 | n_train = x.shape[0] 309 | num_complete_batches, leftover = divmod(n_train, batch_size) 310 | num_batches = num_complete_batches + bool(leftover) 311 | 312 | def data_stream(): 313 | rng = onp.random.RandomState(342) 314 | while True: 315 | perm = rng.permutation(n_train) 316 | for j in range(num_batches): 317 | batch_idx = perm[j * batch_size:(j + 1) * batch_size] 318 | yield x[batch_idx], y[batch_idx] 319 | batches = data_stream() 320 | itercount = itertools.count() 321 | 322 | for epoch in range(n_epochs): 323 | for _ in range(num_batches): 324 | opt_state = update(next(itercount), opt_state, next(batches)) 325 | params = get_params(opt_state) 326 | train_loss = loss_fn(params, (x, y)) 327 | if epoch % 10 == 0: 328 | logging.info(f"Epoch: {epoch + 1} / {n_epochs}") 329 | logging.info(f" Loss: {train_loss}") 330 | 331 | def predict(_x: np.ndarray): 332 | """Predict new values from MDN.""" 333 | logmix, mu_data, logstd = get_mdn_coef(network(params, _x)) 334 | pi_data = softmax(logmix) 335 | sigma_data = np.exp(logstd) 336 | z = onp.random.gumbel(loc=0, scale=1, size=pi_data.shape) 337 | k = (onp.log(pi_data) + z).argmax(axis=1) 338 | indices = (onp.arange(_x.shape[0]), k) 339 | rn = onp.random.randn(_x.shape[0]) 340 | sampled = rn * sigma_data[indices] + mu_data[indices] 341 | return sampled 342 | 343 | return predict 344 | -------------------------------------------------------------------------------- /src/plotting.py: -------------------------------------------------------------------------------- 1 | """Plotting functionality.""" 2 | 3 | from typing import Text, Optional, Dict 4 | 5 | import os 6 | 7 | from collections import Counter 8 | import numpy as onp 9 | 10 | import jax.numpy as np 11 | 12 | import matplotlib 13 | from matplotlib import rc 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | # ============================================================================= 18 | # MATPLOTLIB STYLING SETTINGS 19 | # 
============================================================================= 20 | 21 | matplotlib.rcdefaults() 22 | 23 | rc('text', usetex=True) 24 | rc('font', size='16', family='serif', serif=['Palatino']) 25 | rc('figure', titlesize='20') # fontsize of the figure title 26 | rc('axes', titlesize='20') # fontsize of the axes title 27 | rc('axes', labelsize='18') # fontsize of the x and y labels 28 | rc('legend', fontsize='18') # legend fontsize 29 | rc('xtick', labelsize='18') # fontsize of the tick labels 30 | rc('ytick', labelsize='18') # fontsize of the tick labels 31 | 32 | rc('axes', xmargin=0) 33 | rc('lines', linewidth=3) 34 | rc('lines', markersize=10) 35 | rc('grid', color='grey', linestyle='solid', linewidth=0.5) 36 | titlekws = dict(y=1.0) 37 | 38 | FIGSIZE = (9, 6) 39 | data_kwargs = dict(alpha=0.5, s=5, marker='.', c='grey', label="data") 40 | 41 | 42 | # ============================================================================= 43 | # DECORATORS & GENERAL FUNCTIONALITY 44 | # ============================================================================= 45 | 46 | def empty_fig_on_failure(func): 47 | """Decorator for individual plot functions to return empty fig on failure.""" 48 | def applicator(*args, **kwargs): 49 | # noinspection PyBroadException 50 | try: 51 | return func(*args, **kwargs) 52 | except Exception: # pylint: disable=bare-except 53 | return plt.figure() 54 | return applicator 55 | 56 | 57 | def save_plot(figure: plt.Figure, path: Text): 58 | """Store a figure in a given location on disk.""" 59 | if path is not None: 60 | figure.savefig(path, bbox_inches="tight", format="pdf") 61 | plt.close(figure) 62 | 63 | 64 | # ============================================================================= 65 | # FINAL AGGREGATE RESULTS 66 | # ============================================================================= 67 | 68 | @empty_fig_on_failure 69 | def plot_final_max_abs_diff(xstar: np.ndarray, maxabsdiff: np.ndarray): 70 | fig = plt.figure() 71 | plt.semilogy(xstar, maxabsdiff[:, 0], 'g--x', label="lower", lw=2) 72 | plt.semilogy(xstar, maxabsdiff[:, 1], 'r--x', label="upper", lw=2) 73 | plt.xlabel("x") 74 | plt.ylabel(f"$\max |LHS - RHS|$") 75 | plt.title(f"Final maximum violation of constraints") 76 | plt.legend() 77 | return fig 78 | 79 | 80 | @empty_fig_on_failure 81 | def plot_final_bounds(x: np.ndarray, 82 | y: np.ndarray, 83 | xstar: np.ndarray, 84 | bounds: np.ndarray, 85 | data_xstar: np.ndarray, 86 | data_ystar: np.ndarray, 87 | coeff_2sls: np.ndarray = None, 88 | x_kiv: np.ndarray = None, 89 | y_kiv: np.ndarray = None) -> plt.Figure: 90 | fig = plt.figure() 91 | plt.scatter(x, y, **data_kwargs) 92 | plt.plot(xstar, bounds[:, 0], 'g--x', label="lower", lw=2, markersize=10) 93 | plt.plot(xstar, bounds[:, 1], 'r--x', label="upper", lw=2, markersize=10) 94 | if data_xstar is not None and data_ystar is not None: 95 | if data_ystar.ndim > 1: 96 | data_ystar = data_ystar.mean(0) 97 | plt.plot(data_xstar, data_ystar, label=f"$E[Y | do(x^*)]$", lw=2) 98 | if coeff_2sls is not None: 99 | tt = np.linspace(np.min(x), np.max(x), 10) 100 | y_2sls = coeff_2sls[0] + coeff_2sls[1] * tt 101 | plt.plot(tt, y_2sls, ls='dotted', c="tab:purple", lw=2, label="2sls") 102 | if x_kiv is not None and y_kiv is not None: 103 | plt.plot(x_kiv, y_kiv, ls='dashdot', c="tab:olive", lw=2, label="KIV") 104 | 105 | def get_limits(vals): 106 | lo = np.min(vals) 107 | hi = np.max(vals) 108 | extend = (hi - lo) / 15. 
109 | return lo - extend, hi + extend 110 | 111 | plt.xlim(get_limits(x)) 112 | plt.ylim(get_limits(y)) 113 | plt.xlabel('x') 114 | plt.ylabel('y') 115 | plt.title("Lower and upper bound on actual effect") 116 | plt.legend() 117 | return fig 118 | 119 | 120 | # ============================================================================= 121 | # INDIVIDUAL RUN RESULTS 122 | # ============================================================================= 123 | 124 | @empty_fig_on_failure 125 | def plot_lagrangian(values: np.ndarray) -> plt.Figure: 126 | fig = plt.figure() 127 | plt.plot(values) 128 | plt.xlabel("update steps") 129 | plt.ylabel("Lagrangian") 130 | plt.title(f"Overall Lagrangian") 131 | return fig 132 | 133 | 134 | @empty_fig_on_failure 135 | def plot_max_sq_lhs_rhs(lhs: np.ndarray, rhs: np.ndarray) -> plt.Figure: 136 | fig = plt.figure() 137 | tt = np.array([np.max((lhs - r)**2) for r in rhs]) 138 | plt.semilogy(tt) 139 | plt.xlabel("optimization rounds") 140 | plt.ylabel("(LHS - RHS)^2") 141 | plt.title(f"(LHS - RHS)^2") 142 | return fig 143 | 144 | 145 | @empty_fig_on_failure 146 | def plot_max_abs_lhs_rhs(lhs: np.ndarray, rhs: np.ndarray) -> plt.Figure: 147 | fig = plt.figure() 148 | tt = np.array([np.max(np.abs(lhs - r)) for r in rhs]) 149 | plt.semilogy(tt) 150 | plt.xlabel("optimization rounds") 151 | plt.ylabel("max(|LHS - RHS|)") 152 | plt.title(f"max(|LHS - RHS|)") 153 | return fig 154 | 155 | 156 | @empty_fig_on_failure 157 | def plot_max_rel_abs_lhs_rhs(lhs: np.ndarray, rhs: np.ndarray) -> plt.Figure: 158 | fig = plt.figure() 159 | tt = np.array([np.max(np.abs((lhs - r) / lhs)) for r in rhs]) 160 | plt.semilogy(tt) 161 | plt.xlabel("optimization rounds") 162 | plt.ylabel("max(|LHS - RHS| / |LHS|)") 163 | plt.title(f"max(|LHS - RHS| / |LHS|)") 164 | return fig 165 | 166 | 167 | @empty_fig_on_failure 168 | def plot_abs_lhs_rhs(lhs: np.ndarray, rhs: np.ndarray) -> plt.Figure: 169 | fig = plt.figure() 170 | tt = np.array([np.abs(lhs - r) for r in rhs]) 171 | for i in range(len(lhs)): 172 | plt.semilogy(tt[:, i], label=f'{i + 1}') 173 | plt.xlabel("optimization rounds") 174 | plt.ylabel("|LHS - RHS|") 175 | plt.title(f"individual |LHS - RHS|") 176 | plt.legend(loc='center left', bbox_to_anchor=(1.05, 0.5)) 177 | return fig 178 | 179 | 180 | @empty_fig_on_failure 181 | def plot_min_max_rhs(rhs: np.ndarray) -> plt.Figure: 182 | fig = plt.figure() 183 | tt = np.array([(np.min(r), np.max(r)) for r in rhs]) 184 | plt.plot(tt) 185 | plt.xlabel("optimization rounds") 186 | plt.ylabel("RHS min and max") 187 | plt.title(f"min and max of RHS") 188 | return fig 189 | 190 | 191 | @empty_fig_on_failure 192 | def plot_grad_norms(grad_norms: np.ndarray) -> plt.Figure: 193 | fig = plt.figure() 194 | grad_norms = np.array(grad_norms) 195 | plt.semilogy(grad_norms) 196 | plt.xlabel("update steps") 197 | plt.ylabel("norm of gradients") 198 | plt.legend(["L", "mu", "log_sigma"]) 199 | plt.title(f"Gradient norms (w.r.t. 
$L$, $\mu$, $\log(\sigma)$)") 200 | return fig 201 | 202 | 203 | @empty_fig_on_failure 204 | def plot_mu(mus: np.ndarray) -> plt.Figure: 205 | fig = plt.figure() 206 | plt.plot(np.array(mus), '-x') 207 | plt.xlabel("optimization rounds") 208 | plt.ylabel(f"$\mu$") 209 | plt.title(f"Means $\mu$ of $\\theta$s") 210 | return fig 211 | 212 | 213 | @empty_fig_on_failure 214 | def plot_sigma(sigmas: np.ndarray) -> plt.Figure: 215 | fig = plt.figure() 216 | plt.plot(np.array(sigmas), '-x') 217 | plt.xlabel("optimization rounds") 218 | plt.ylabel(f"$\sigma$") 219 | plt.title(f"Stddevs $\sigma$ of $\\theta$s") 220 | return fig 221 | 222 | 223 | @empty_fig_on_failure 224 | def plot_mu_and_sigma(mus: np.ndarray, sigmas: np.ndarray) -> plt.Figure: 225 | fig = plt.figure() 226 | epochs = np.arange(mus.shape[0]) 227 | mus = np.array(mus) 228 | sigmas = np.array(sigmas) 229 | for i in range(mus.shape[1]): 230 | mu = mus[:, i] 231 | sigma = sigmas[:, i] 232 | plt.fill_between(epochs, mu - sigma, mu + sigma, alpha=0.3) 233 | plt.plot(epochs, mu, '-x') 234 | plt.xlabel("optimization rounds") 235 | plt.ylabel(f"$\mu$") 236 | plt.title(f"Means $\mu$ of $\\theta$s") 237 | return fig 238 | 239 | 240 | @empty_fig_on_failure 241 | def plot_cholesky_factor(ls: np.ndarray) -> plt.Figure: 242 | fig = plt.figure() 243 | plt.plot(ls, '-x') 244 | plt.xlabel("optimization rounds") 245 | plt.ylabel(f"entries of $L$") 246 | plt.title(f"entries of the Cholesky factor $L$") 247 | return fig 248 | 249 | 250 | @empty_fig_on_failure 251 | def plot_tau(taus: np.ndarray) -> plt.Figure: 252 | fig = plt.figure() 253 | plt.plot(taus, "-x") 254 | plt.xlabel("optimization rounds") 255 | plt.ylabel(f"temperature $\\tau$") 256 | plt.title(f"temperature parameter $\\tau$") 257 | return fig 258 | 259 | 260 | @empty_fig_on_failure 261 | def plot_rho(rhos: np.ndarray) -> plt.Figure: 262 | fig = plt.figure() 263 | plt.plot(rhos, "-x") 264 | plt.xlabel("optimization rounds") 265 | plt.ylabel(f"annealing $\\rho$") 266 | plt.title(f"annealing parameter $\\rho$") 267 | return fig 268 | 269 | 270 | @empty_fig_on_failure 271 | def plot_lambda(lmbdas: np.ndarray) -> plt.Figure: 272 | fig = plt.figure() 273 | plt.semilogy(lmbdas, "-x") 274 | plt.xlabel("optimization rounds") 275 | plt.ylabel(f"multipliers $\lambda$") 276 | plt.title(f"Lagrange multipliers $\lambda$") 277 | return fig 278 | 279 | 280 | @empty_fig_on_failure 281 | def plot_objective(objectives: np.ndarray) -> plt.Figure: 282 | fig = plt.figure() 283 | plt.plot(objectives) 284 | plt.xlabel("optimization rounds") 285 | plt.ylabel(f"objective value") 286 | plt.title("Objective") 287 | return fig 288 | 289 | 290 | @empty_fig_on_failure 291 | def plot_constraint_term(constrs: np.ndarray) -> plt.Figure: 292 | fig = plt.figure() 293 | plt.semilogy(constrs) 294 | plt.xlabel("optimization rounds") 295 | plt.ylabel(f"constraint term") 296 | plt.title("Constraint term") 297 | return fig 298 | 299 | 300 | @empty_fig_on_failure 301 | def plot_mean_response(mus: np.ndarray, 302 | x: np.ndarray, 303 | y: np.ndarray, 304 | response) -> plt.Figure: 305 | fig = plt.figure() 306 | plt.plot(x, y, '.', alpha=0.3, label='data') 307 | xx = np.linspace(np.min(x), np.max(x), 100) 308 | yy = [] 309 | for x in xx: 310 | yy.append(response(mus[-1, :], x)) 311 | yy = np.array(yy).squeeze() 312 | plt.plot(xx, yy, label='mean response') 313 | plt.xlabel("x") 314 | plt.ylabel("y") 315 | plt.title("Mean response") 316 | return fig 317 | 318 | 319 | # 
============================================================================= 320 | # DATA AND PREPROCESSING 321 | # ============================================================================= 322 | 323 | # @empty_fig_on_failure 324 | def plot_data(z: np.ndarray, 325 | x: np.ndarray, 326 | y: np.ndarray, 327 | confounder: np.ndarray, 328 | ex: np.ndarray, 329 | ey: np.ndarray) -> plt.Figure: 330 | 331 | def corr_label(_x, _y): 332 | return f'$\\rho = $ {onp.corrcoef(_x, _y)[0, 1]:.02f}' 333 | 334 | fig, axs = plt.subplots(3, 3, figsize=(15, 10)) 335 | if ex is not None: 336 | axs[0, 0].plot(ex, x, '.', label=corr_label(ex, x)) 337 | axs[0, 0].set_xlabel("noise ex") 338 | axs[0, 0].set_ylabel("treatment x") 339 | axs[0, 0].legend() 340 | axs[0, 1].plot(z, x, '.', label=corr_label(z, x)) 341 | axs[0, 1].set_xlabel("instrument z") 342 | axs[0, 1].set_ylabel("treatment x") 343 | axs[0, 1].legend() 344 | if confounder is not None: 345 | axs[0, 2].plot(confounder, x, '.', label=corr_label(confounder, x)) 346 | axs[0, 2].set_xlabel("confounder") 347 | axs[0, 2].set_ylabel("treatment x") 348 | axs[0, 2].legend() 349 | 350 | if ey is not None: 351 | axs[1, 0].plot(ey, y, '.', label=corr_label(ey, y)) 352 | axs[1, 0].set_xlabel("noise ey") 353 | axs[1, 0].set_ylabel("outcome y") 354 | axs[1, 0].legend() 355 | axs[1, 1].plot(x, y, '.', label=corr_label(x, y)) 356 | axs[1, 1].set_xlabel("treatment x") 357 | axs[1, 1].set_ylabel("outcome y") 358 | axs[1, 1].legend() 359 | if confounder is not None: 360 | axs[1, 2].plot(confounder, y, '.', label=corr_label(confounder, y)) 361 | axs[1, 2].set_xlabel("confounder") 362 | axs[1, 2].set_ylabel("outcome y") 363 | axs[1, 2].legend() 364 | 365 | if ey is not None and ex is not None: 366 | axs[2, 0].plot(ex, ey, '.', label=corr_label(ex, ey)) 367 | axs[2, 0].set_xlabel("noise ex") 368 | axs[2, 0].set_ylabel("noise ey") 369 | axs[2, 0].legend() 370 | axs[2, 1].plot(z, y, '.', label=corr_label(z, y)) 371 | axs[2, 1].set_xlabel("instrument z") 372 | axs[2, 1].set_ylabel("outcome y") 373 | axs[2, 1].legend() 374 | 375 | return fig 376 | 377 | 378 | @empty_fig_on_failure 379 | def plot_bin_hist(bin_ids: np.ndarray) -> plt.Figure: 380 | fig = plt.figure() 381 | tt = np.array(list(Counter(bin_ids).items())) 382 | plt.bar(tt[:, 0], tt[:, 1]) 383 | plt.xlabel("bins") 384 | plt.ylabel("Number of data points") 385 | plt.title("Distribution of datapoints into z-bins") 386 | return fig 387 | 388 | 389 | @empty_fig_on_failure 390 | def plot_bin_assignment(z: np.ndarray, 391 | val: np.ndarray, 392 | z_grid: np.ndarray, 393 | bin_ids: np.ndarray, 394 | ylabel: Text) -> plt.Figure: 395 | fig = plt.figure() 396 | num_z = len(z_grid) 397 | for i in range(num_z): 398 | plt.plot(z[bin_ids == i], val[bin_ids == i], '.') 399 | for zi in z_grid: 400 | plt.axvline(zi, c='k', lw=0.5) 401 | plt.xlabel('z') 402 | plt.ylabel(ylabel) 403 | plt.title("Bin assignment and z-grid lines") 404 | return fig 405 | 406 | 407 | @empty_fig_on_failure 408 | def plot_hist_at_z(y: np.ndarray, bin_ids: np.ndarray, idx: int) -> plt.Figure: 409 | fig = plt.figure() 410 | plt.hist(y[bin_ids == idx], bins=30) 411 | plt.xlabel('y') 412 | mean = np.mean(y[bin_ids == idx]) 413 | var = np.var(y[bin_ids == idx]) 414 | plt.title( 415 | f"mean {mean:.2f} and variance {var:.2f} for z bin {idx}") 416 | return fig 417 | 418 | 419 | @empty_fig_on_failure 420 | def plot_y_with_constraints(z: np.ndarray, 421 | y: np.ndarray, 422 | z_grid: np.ndarray, 423 | lhs: np.ndarray) -> plt.Figure: 424 | fig = plt.figure() 425 | lo 
= lhs[:, 0] - lhs[:, 1] 426 | hi = lhs[:, 0] + lhs[:, 1] 427 | plt.fill_between(z_grid, lo, hi, alpha=0.5, color='r') 428 | plt.plot(z, y, '.', alpha=0.3) 429 | plt.plot(z_grid, lhs[:, 0]) 430 | plt.xlabel('z') 431 | plt.ylabel('y') 432 | plt.title("Datapoints with mean and variance from LHS constraints") 433 | return fig 434 | 435 | 436 | @empty_fig_on_failure 437 | def plot_inverse_cdfs(x_cdf_invs) -> plt.Figure: 438 | fig = plt.figure() 439 | t = np.linspace(0, 1, 50) 440 | for i, invcdf in enumerate(x_cdf_invs): 441 | plt.plot(t, invcdf(t), label=f"i: {i}") 442 | plt.ylabel("x") 443 | plt.xlabel("CDF") 444 | plt.title("Inverse CDFs of x for different z in grid") 445 | return fig 446 | 447 | 448 | @empty_fig_on_failure 449 | def plot_xhats_distr(x: np.ndarray, xhats: np.ndarray) -> plt.Figure: 450 | fig = plt.figure() 451 | plt.hist(xhats.ravel(), bins=50, density=True, alpha=0.3, label="sampled x") 452 | plt.hist(x, bins=50, density=True, alpha=0.3, label="actual x (data)") 453 | plt.legend() 454 | plt.xlabel("x") 455 | plt.ylabel("density") 456 | plt.title("Distribution of pre-sampled and actual x") 457 | return fig 458 | 459 | 460 | @empty_fig_on_failure 461 | def plot_discrepancy_x(z: np.ndarray, 462 | x: np.ndarray, 463 | xhats: np.ndarray, 464 | z_grid: np.ndarray) -> plt.Figure: 465 | # Check where the discrepancy between real x and sampled x comes from 466 | fig = plt.figure() 467 | middle = np.mean(xhats, axis=-1) 468 | delta = np.std(xhats, axis=-1) 469 | plt.plot(z, x, '.', alpha=0.2, label='data') 470 | plt.fill_between(z_grid, middle - delta, middle + delta, 471 | alpha=0.5, color='r', label='var samples') 472 | plt.plot(z_grid, middle, label='mean samples') 473 | plt.xlabel('z') 474 | plt.ylabel('x') 475 | plt.legend() 476 | return fig 477 | 478 | 479 | # @empty_fig_on_failure 480 | def plot_basis_samples(basis_func, x: np.ndarray, y: np.ndarray) -> plt.Figure: 481 | fig = plt.figure() 482 | xx = np.linspace(np.min(x), np.max(x), 200) 483 | ys = basis_func(xx).squeeze() 484 | plt.plot(x, y, '.', alpha=0.2, label='data') 485 | plt.plot(xx, ys, label='basis funcs') 486 | plt.xlabel('x') 487 | plt.ylabel('y') 488 | plt.legend() 489 | return fig 490 | 491 | 492 | # ============================================================================= 493 | # PLOT ALL 494 | # ============================================================================= 495 | 496 | def plot_all_init(z: np.ndarray, 497 | x: np.ndarray, 498 | y: np.ndarray, 499 | confounder: np.ndarray, 500 | ex: np.ndarray, 501 | ey: np.ndarray, 502 | xhats: np.ndarray, 503 | z_grid: np.ndarray, 504 | bin_ids: np.ndarray, 505 | lhs: np.ndarray, 506 | basis_func=None, 507 | base_dir: Optional[Text] = None): 508 | """Call all relevant plotting functions for initialization and data.""" 509 | if base_dir is not None and not os.path.exists(base_dir): 510 | os.makedirs(base_dir) 511 | 512 | num_z = len(z_grid) 513 | 514 | path = os.path.join(base_dir, f"data.pdf") 515 | save_plot(plot_data(z, x, y, confounder, ex, ey), path) 516 | 517 | path = os.path.join(base_dir, f"bin_histogram.pdf") 518 | save_plot(plot_bin_hist(bin_ids), path) 519 | 520 | path = os.path.join(base_dir, f"bin_assignment_x.pdf") 521 | save_plot(plot_bin_assignment(z, x, z_grid, bin_ids, 'x'), path) 522 | 523 | path = os.path.join(base_dir, f"bin_assignment_y.pdf") 524 | save_plot(plot_bin_assignment(z, y, z_grid, bin_ids, 'y'), path) 525 | 526 | path = os.path.join(base_dir, f"y_hist_last_z.pdf") 527 | save_plot(plot_hist_at_z(y, bin_ids, num_z - 1), path) 
528 | 529 | path = os.path.join(base_dir, f"y_with_constraints.pdf") 530 | save_plot(plot_y_with_constraints(z, y, z_grid, lhs), path) 531 | 532 | path = os.path.join(base_dir, f"xhat_distribution.pdf") 533 | save_plot(plot_xhats_distr(x, xhats), path) 534 | 535 | path = os.path.join(base_dir, f"discrepancy_x.pdf") 536 | save_plot(plot_discrepancy_x(z, x, xhats, z_grid), path) 537 | 538 | # path = os.path.join(base_dir, f"inverse_cdfs.pdf") 539 | # save_plot(plot_inverse_cdfs(x_cdf_invs), path) 540 | 541 | if basis_func is not None: 542 | path = os.path.join(base_dir, f"basis_functions.pdf") 543 | save_plot(plot_basis_samples(basis_func, x, y), path) 544 | 545 | 546 | def plot_all_final(final: Dict[Text, np.ndarray], 547 | satisfied: Dict[Text, np.ndarray], 548 | x: np.ndarray, 549 | y: np.ndarray, 550 | xstar_grid: np.ndarray, 551 | data_xstar: np.ndarray, 552 | data_ystar: np.ndarray, 553 | coeff_2sls: np.ndarray = None, 554 | x_kiv: np.ndarray = None, 555 | y_kiv: np.ndarray = None, 556 | base_dir: Optional[Text] = None): 557 | """Call all relevant plotting functions for final aggregate results.""" 558 | if base_dir is not None and not os.path.exists(base_dir): 559 | os.makedirs(base_dir) 560 | 561 | # To also show the last valid (non-nan) bounds regardless of whether they 562 | # satisfied the constraints, use mode="non-nan", results=final instead. 563 | mode = "satisfied" 564 | results = satisfied 565 | result_path = os.path.join(base_dir, f"final_{mode}_bounds.pdf") 566 | save_plot(plot_final_bounds(x, y, xstar_grid, results["objective"], 567 | data_xstar, data_ystar, coeff_2sls, 568 | x_kiv, y_kiv), 569 | result_path) 570 | 571 | # Uncomment to show maximum absolute violation of exact constraints 572 | # result_path = os.path.join(base_dir, f"final_{mode}_maxabsdiff.pdf") 573 | # save_plot(plot_final_max_abs_diff(xstar_grid, results["maxabsdiff"]), 574 | # result_path) 575 | 576 | 577 | def plot_all(results, 578 | x: np.ndarray, 579 | y: np.ndarray, 580 | response, 581 | base_dir: Optional[Text] = None, suffix: Text = ""): 582 | """Call all relevant plotting functions. 583 | 584 | Args: 585 | results: The results dictionary. 586 | x: The x values of the original data. 587 | y: The y values of the original data. 588 | response: The response function. 589 | base_dir: The path where to store the figures. If `None`, don't save the 590 | figures to disk. 591 | suffix: An optional suffix to each filename stored by this function. 
592 | """ 593 | if base_dir is not None and not os.path.exists(base_dir): 594 | os.makedirs(base_dir) 595 | 596 | def get_filename(base: Text, fname: Text): 597 | return None if base is None else os.path.join(base, fname) 598 | 599 | suff = "_" + suffix if suffix else suffix 600 | 601 | name = "lagrangian{}.pdf".format(suff) 602 | save_plot(plot_lagrangian(results["lagrangian"]), 603 | get_filename(base_dir, name)) 604 | 605 | name = "grad_norms{}.pdf".format(suff) 606 | save_plot(plot_grad_norms(results["grad_norms"]), 607 | get_filename(base_dir, name)) 608 | 609 | name = "mu{}.pdf".format(suff) 610 | save_plot(plot_mu(results["mu"]), 611 | get_filename(base_dir, name)) 612 | 613 | name = "sigma{}.pdf".format(suff) 614 | save_plot(plot_sigma(results["sigma"]), 615 | get_filename(base_dir, name)) 616 | 617 | name = "mu_and_sigma{}.pdf".format(suff) 618 | save_plot(plot_mu_and_sigma(results["mu"], results["sigma"]), 619 | get_filename(base_dir, name)) 620 | 621 | name = "cholesky_factor{}.pdf".format(suff) 622 | save_plot(plot_cholesky_factor(results["cholesky_factor"]), 623 | get_filename(base_dir, name)) 624 | 625 | name = "tau{}.pdf".format(suff) 626 | save_plot(plot_tau(results["tau"]), 627 | get_filename(base_dir, name)) 628 | 629 | name = "rho{}.pdf".format(suff) 630 | save_plot(plot_rho(results["rho"]), 631 | get_filename(base_dir, name)) 632 | 633 | name = "lambda{}.pdf".format(suff) 634 | save_plot(plot_lambda(results["lambda"]), 635 | get_filename(base_dir, name)) 636 | 637 | name = "objective{}.pdf".format(suff) 638 | save_plot(plot_objective(results["objective"]), 639 | get_filename(base_dir, name)) 640 | 641 | name = "constraint_term{}.pdf".format(suff) 642 | save_plot(plot_constraint_term(results["constraint_term"]), 643 | get_filename(base_dir, name)) 644 | 645 | name = "mean_response{}.pdf".format(suff) 646 | save_plot(plot_mean_response(results["mu"], x, y, response), 647 | get_filename(base_dir, name)) 648 | 649 | name = "max_abs_lhs_rhs{}.pdf".format(suff) 650 | save_plot(plot_max_abs_lhs_rhs(results["lhs"], results["rhs"]), 651 | get_filename(base_dir, name)) 652 | 653 | name = "max_rel_abs_lhs_rhs{}.pdf".format(suff) 654 | save_plot(plot_max_rel_abs_lhs_rhs(results["lhs"], results["rhs"]), 655 | get_filename(base_dir, name)) 656 | 657 | name = "abs_lhs_rhs{}.pdf".format(suff) 658 | save_plot(plot_abs_lhs_rhs(results["lhs"], results["rhs"]), 659 | get_filename(base_dir, name)) 660 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | """Main entry point.""" 2 | 3 | from typing import Tuple, Text 4 | 5 | import json 6 | import os 7 | 8 | from absl import app 9 | from absl import flags 10 | from absl import logging 11 | 12 | from datetime import datetime 13 | 14 | import jax.numpy as np 15 | from jax import random, value_and_grad, jit 16 | import jax.experimental.optimizers as optim 17 | from jax.ops import index_update 18 | import numpy as onp 19 | from scipy.interpolate import UnivariateSpline 20 | 21 | from tqdm import tqdm 22 | 23 | import data 24 | import kiv 25 | import plotting 26 | import utils 27 | 28 | Params = Tuple[np.ndarray, np.ndarray, np.ndarray] 29 | 30 | # ------------------------------- DATASET ------------------------------------- 31 | flags.DEFINE_enum("dataset", "synthetic", 32 | ("synthetic", "nlsym", "colonial_origins"), 33 | "Which dataset to use.") 34 | flags.DEFINE_string("equations", "lin1", 35 | "Which structural 
equations to use in synthetic setting.") 36 | flags.DEFINE_bool("disconnect_instrument", False, 37 | "Whether to resample independent values for instrument. " 38 | "(Mostly for debugging.)") 39 | flags.DEFINE_integer("num_data", 5_000, 40 | "The number of observations in the synthetic dataset.") 41 | # ---------------------------- APPROXIMATIONS --------------------------------- 42 | flags.DEFINE_enum("response_type", "poly", ("poly", "gp", "mlp"), 43 | "Basis response functions (polynomials or GP samples).") 44 | flags.DEFINE_integer("num_xstar", 15, 45 | "Number of x values at which to evaluate the objective.") 46 | flags.DEFINE_integer("dim_theta", 2, 47 | "The dimension of the parameter theta. This is also the " 48 | "number of response basis functions to use.") 49 | flags.DEFINE_integer("num_z", 20, 50 | "The number of grid points for the instrument Z.") 51 | # ---------------------------- OPTIMIZATION ----------------------------------- 52 | flags.DEFINE_integer("num_rounds", 150, 53 | "Number of rounds in the augmented Lagrangian.") 54 | flags.DEFINE_integer("opt_steps", 30, 55 | "Number of gradient updates per optimization subproblem.") 56 | flags.DEFINE_integer("bs", 1024, 57 | "Number of examples for MC estimates of the objective.") 58 | flags.DEFINE_integer("bs_constr", 4096, 59 | "Number of examples for MC estimates of the constraints.") 60 | # ---------------------------- CONSTRAINT ------------------------------------- 61 | flags.DEFINE_float("slack", 0.2, 62 | "Fractional tolerance for the constraints.") 63 | flags.DEFINE_float("slack_abs", 0.2, 64 | "Additional absolute tolerance for the constraints.") 65 | # --------------------- LEARNING RATE & MOMENTUM ------------------------------ 66 | flags.DEFINE_float("lr", 0.001, 67 | "The (initial) learning rate for the optimization.") 68 | flags.DEFINE_integer("decay_steps", 1000, 69 | "Number of decay steps for the learning rate schedule.") 70 | flags.DEFINE_float("decay_rate", 1.0, 71 | "The decay rate for the learning rate schedule.") 72 | flags.DEFINE_bool("staircase", False, 73 | "Whether to use staircases in the learning rate schedule") 74 | flags.DEFINE_float("momentum", 0.9, 75 | "The momentum parameter for the SGD optimizer.") 76 | # ---------------------------- SCHEDULES -------------------------------------- 77 | flags.DEFINE_float("tau_init", 0.1, 78 | "The initial value of the temperature parameter.") 79 | flags.DEFINE_float("tau_factor", 1.08, 80 | "The factor by which tau is multiplied each round.") 81 | flags.DEFINE_float("tau_max", 10.0, 82 | "The maximum temperature tau.") 83 | # ---------------------------- INPUT/OUTPUT ----------------------------------- 84 | flags.DEFINE_string("data_dir", "../data/", 85 | "Directory of the input data.") 86 | flags.DEFINE_string("output_dir", "../results/", 87 | "Path to the output directory (for results).") 88 | flags.DEFINE_string("output_name", "", 89 | "Name for result folder. 
Use timestamp if empty.") 90 | flags.DEFINE_bool("plot_init", False, 91 | "Whether to plot data and initialization visuals.") 92 | flags.DEFINE_bool("plot_intermediate", False, 93 | "Whether to plot results from individual runs.") 94 | flags.DEFINE_bool("plot_final", True, 95 | "Whether to plot final aggregate results.") 96 | flags.DEFINE_bool("store_data", False, 97 | "Whether to store data, intermediate and final results and " 98 | "baseline results.") 99 | # ---------------------------- COMPARISONS ------------------------------------ 100 | flags.DEFINE_bool("run_2sls", True, 101 | "Whether to run two stage least squares as comparison.") 102 | flags.DEFINE_bool("run_kiv", True, 103 | "Whether to run kernel instrumental variable as comparison.") 104 | # ------------------------------ MISC ----------------------------------------- 105 | flags.DEFINE_integer("seed", 0, "The random seed.") 106 | FLAGS = flags.FLAGS 107 | 108 | 109 | # ============================================================================= 110 | # RHS CONSTRAINT FUNCTIONS THAT MUST BE OVERWRITTEN 111 | # ============================================================================= 112 | 113 | @jit 114 | def get_phi(y: np.ndarray) -> np.ndarray: 115 | """The phis for the constraints.""" 116 | return np.array([np.mean(y, axis=-1), np.var(y, axis=-1)]).T 117 | 118 | 119 | @jit 120 | def get_rhs(thetahat: np.ndarray, xhats_pre: np.ndarray) -> np.ndarray: 121 | """Construct the RHS for the second approach (unsing basis functions phi).""" 122 | return get_phi(response(thetahat, xhats_pre)) 123 | 124 | 125 | def make_constraint_lhs(y: np.ndarray, 126 | bin_ids: np.ndarray, 127 | z_grid: np.ndarray) -> np.ndarray: 128 | """Get the LHS of the constraints.""" 129 | # Use indicator annealing approach 130 | logging.info(f"Setup {FLAGS.num_z * 2} constraints...") 131 | tmp = [] 132 | for i in range(FLAGS.num_z): 133 | tmp.append(get_phi(y[bin_ids == i])) 134 | tmp = np.array(tmp) 135 | # Smoothen LHS constraints with UnivariateSpline smoothing 136 | logging.info(f"Smoothen constraints with splines. 
Fixed factor: 0.2 ...") 137 | lhs = [] 138 | for i in range(tmp.shape[-1]): 139 | spl = UnivariateSpline(z_grid, tmp[:, i], s=0.2) 140 | lhs.append(spl(z_grid)) 141 | lhs = np.array(lhs).T 142 | return lhs 143 | 144 | 145 | @jit 146 | def response_poly(theta: np.ndarray, x: np.ndarray) -> np.ndarray: 147 | """The response function.""" 148 | return np.polyval(theta, x) 149 | 150 | 151 | # Must be overwritten with one of the available response functions 152 | # noinspection PyUnusedLocal 153 | @jit 154 | def response(theta: np.ndarray, x: np.ndarray) -> np.ndarray: 155 | """The response function.""" 156 | return np.empty((0,)) 157 | 158 | 159 | # ============================================================================= 160 | # OPTIMIZATION (AUGMENTED LAGRANGIAN) 161 | # ============================================================================= 162 | 163 | @jit 164 | def get_constraint_term(constr: np.ndarray, 165 | lmbda: np.ndarray, 166 | tau: float) -> float: 167 | """Compute the sum of \psi(c_i, \lambda, \tau) for the Lagrangian.""" 168 | case1 = - lmbda * constr + 0.5 * tau * constr**2 169 | case2 = - 0.5 * lmbda**2 / tau 170 | psi = np.where(tau * constr <= lmbda, case1, case2) 171 | return np.sum(psi) 172 | 173 | 174 | @jit 175 | def update_lambda(constr: np.ndarray, 176 | lmbda: np.ndarray, 177 | tau: float) -> np.ndarray: 178 | """Update Lagrangian parameters lambda.""" 179 | return np.maximum(lmbda - tau * constr, 0) 180 | 181 | 182 | @jit 183 | def make_cholesky_factor(l_param: np.ndarray) -> np.ndarray: 184 | """Get the actual cholesky factor from our parameterization of L.""" 185 | lmask = np.tri(l_param.shape[0]) 186 | lmask = index_update(lmask, (0, 0), 0) 187 | tmp = l_param * lmask 188 | idx = np.diag_indices(l_param.shape[0]) 189 | return index_update(tmp, idx, np.exp(tmp[idx])) 190 | 191 | 192 | @jit 193 | def make_correlation_matrix(l_param: np.ndarray) -> np.ndarray: 194 | """Get correlation matrix from our parameterization of L.""" 195 | chol = make_cholesky_factor(l_param) 196 | return chol @ chol.T 197 | 198 | 199 | @jit 200 | def objective_rhs_psisum_constr( 201 | key: np.ndarray, 202 | params: Params, 203 | lmbda: np.ndarray, 204 | tau: float, 205 | lhs: np.ndarray, 206 | slack: np.ndarray, 207 | xstar: float, 208 | tmp_pre: np.ndarray, 209 | xhats_pre: np.ndarray, 210 | ) -> Tuple[float, np.ndarray, float, np.ndarray]: 211 | """Estimate the objective, RHS, psisum (constraint term), and constraints. 212 | 213 | Refer to the docstring of `lagrangian` for a description of the arguments. 
214 | """ 215 | # (k+1, k+1), (k,), (k,) 216 | L, mu, log_sigma = params 217 | n = tmp_pre.shape[-1] 218 | # (k, n) 219 | tmp = random.normal(key, (FLAGS.dim_theta, n)) 220 | # (k+1, n) 221 | tmp = np.concatenate((tmp_pre, tmp), axis=0) 222 | # (k+1, n) add initial dependence 223 | tmp = utils.std_normal_cdf(make_cholesky_factor(L) @ tmp) 224 | # (k, n) get thetas with current means and variances 225 | thetahat = utils.normal_cdf_inv(tmp[1:, :], mu, log_sigma) 226 | # (1,) main objective <- (n,) <- (k, n), () 227 | obj = np.mean(response(thetahat, np.array(xstar))) 228 | # (m, l) computes rhs for all z 229 | rhs = get_rhs(thetahat, xhats_pre) 230 | # (m * l,) constraints (with tolerances) 231 | constr = slack - np.ravel(np.abs(lhs - rhs)) 232 | # (1,) constraint term of lagrangian 233 | psisum = get_constraint_term(constr, lmbda, tau) 234 | return obj, rhs, psisum, constr 235 | 236 | 237 | @jit 238 | def lagrangian(key: np.ndarray, 239 | params: Params, 240 | lmbda: np.ndarray, 241 | tau: float, 242 | lhs: np.ndarray, 243 | slack: np.ndarray, 244 | xstar: float, 245 | tmp_pre: np.ndarray, 246 | xhats_pre: np.ndarray, 247 | sign: float = 1.) -> float: 248 | """Estimate the Lagrangian at a given \eta. 249 | 250 | For given $\eta$ compute MC estimate of the Lagrangian with samples from 251 | $p(\theta | x, z)$, which are used for the constraints, but also 252 | (marginalized) for the main objective. 253 | 254 | Args: 255 | key: Key for the random number generator. 256 | params: A 3-tuple with the parameters to optimize consisting of 257 | L: Lower triangular matrix from which we compute the Cholesky factor. 258 | (Not the Cholesky factor itself!). 259 | Dimension: (DIM_THETA + 1, DIM_THETA + 1) 260 | mu: The means of the (Gaussian) marginals of the thetas. 261 | Dimension: (DIM_THETA, ) 262 | log_sigma: The log of the standard deviations of the (Gaussian) 263 | marginals of the thetas. (Use log to ensure they're positive). 264 | Dimension: (DIM_THETA, ) 265 | lmbda: The Lagrangian multipliers lambda. Dimension: (NUM_Z * NUM_PHI, ) 266 | tau: The temperature parameter for the augmented Lagrangian approach. 267 | lhs: The LHS of the constraints. Dimension: (NUM_Z, NUM_PHI) 268 | slack: The tolerance for how well the constraints must be satisfied. 269 | Dimension: (NUM_Z * NUM_PHI, ) 270 | xstar: The interventional value of x in the objective. 271 | tmp_pre: Precomputed standard Gaussian distributed values (for x). 272 | Dimension: (1, num_sample) 273 | xhats_pre: Precomputed samples following p(x | zi) for the zi in the 274 | Z grid (corresponding to the values in tmp_pre). 275 | Dimension: (NUM_Z, num_sample) 276 | sign: Either -1 or 1. If sign == 1, we are computing a lower bound. 277 | If sign == -1, we are computing an upper bound. 
278 | 279 | Returns: 280 | a scalar estimate of the Lagrangian at the given eta and xstar 281 | """ 282 | obj, _, psisum, _ = objective_rhs_psisum_constr( 283 | key, params, lmbda, tau, lhs, slack, xstar, tmp_pre, xhats_pre) 284 | return sign * obj + psisum 285 | 286 | 287 | def init_params(key: np.ndarray) -> Params: 288 | """Initialize the optimization parameters.""" 289 | key, subkey = random.split(key) 290 | # init diagonal at 0, because it will be exponentiated 291 | L = 0.05 * np.tri(FLAGS.dim_theta + 1, k=-1) 292 | L *= random.normal(subkey, (FLAGS.dim_theta + 1, FLAGS.dim_theta + 1)) 293 | corr = make_correlation_matrix(L) 294 | assert np.all(np.isclose(np.linalg.cholesky(corr), 295 | make_cholesky_factor(L))), "not PSD" 296 | key, subkey = random.split(key) 297 | if FLAGS.response_type == "poly": 298 | mu = 0.001 * random.normal(subkey, (FLAGS.dim_theta,)) 299 | log_sigma = np.array([np.log(1. / (i + 1)) 300 | for i in range(FLAGS.dim_theta)]) 301 | elif FLAGS.response_type == "gp": 302 | mu = np.ones(FLAGS.dim_theta) / FLAGS.dim_theta 303 | log_sigma = 0.5 * np.ones(FLAGS.dim_theta) 304 | else: 305 | mu = 0.01 * random.normal(subkey, (FLAGS.dim_theta,)) 306 | log_sigma = 0.5 * np.ones(FLAGS.dim_theta) 307 | params = (L, mu, log_sigma) 308 | return params 309 | 310 | 311 | lagrangian_value_and_grad = jit(value_and_grad(lagrangian, argnums=1)) 312 | 313 | 314 | def run_optim(key: np.ndarray, 315 | lhs: np.ndarray, 316 | tmp: np.ndarray, 317 | xhats: np.ndarray, 318 | tmp_c: np.ndarray, 319 | xhats_c: np.ndarray, 320 | xstar: float, 321 | bound: Text, 322 | out_dir: Text, 323 | x: np.ndarray, 324 | y: np.ndarray) -> Tuple[int, float, float, int, float, float]: 325 | """Run optimization (either lower or upper) for a single xstar.""" 326 | # Directory setup 327 | # --------------------------------------------------------------------------- 328 | out_dir = os.path.join(out_dir, f"{bound}-xstar_{xstar}") 329 | if FLAGS.store_data: 330 | logging.info(f"Current run output directory: {out_dir}...") 331 | if not os.path.exists(out_dir): 332 | os.makedirs(out_dir) 333 | 334 | # Init optim params 335 | # --------------------------------------------------------------------------- 336 | logging.info(f"Initialize parameters L, mu, log_sigma, lmbda, tau, slack...") 337 | key, subkey = random.split(key) 338 | params = init_params(subkey) 339 | 340 | for parname, param in zip(['L', 'mu', 'log_sigma'], params): 341 | logging.info(f"Parameter {parname}: {param.shape}") 342 | logging.info(f" -> {parname}: {param}") 343 | 344 | tau = FLAGS.tau_init 345 | logging.info(f"Initial tau = {tau}") 346 | fin_tau = np.minimum(FLAGS.tau_factor**FLAGS.num_rounds * tau, FLAGS.tau_max) 347 | logging.info(f"Final tau = {fin_tau}") 348 | 349 | # Set constraint approach and slacks 350 | # --------------------------------------------------------------------------- 351 | slack = FLAGS.slack * np.ones(FLAGS.num_z * 2) 352 | lmbda = np.zeros(FLAGS.num_z * 2) 353 | logging.info(f"Lambdas: {lmbda.shape}") 354 | 355 | logging.info(f"Fractional tolerance (slack) for constraints = {FLAGS.slack}") 356 | logging.info(f"Set relative slack variables...") 357 | slack *= np.abs(lhs.ravel()) 358 | logging.info(f"Set minimum slack to {FLAGS.slack_abs}...") 359 | slack = np.maximum(FLAGS.slack_abs, slack) 360 | logging.info(f"Slack {slack.shape}") 361 | logging.info(f"Actual slack min: {np.min(slack)}, max: {np.max(slack)}") 362 | 363 | # Setup optimizer 364 | # 
--------------------------------------------------------------------------- 365 | logging.info(f"Vanilla SGD with init_lr={FLAGS.lr}...") 366 | logging.info(f"Set learning rate schedule") 367 | step_size = optim.inverse_time_decay( 368 | FLAGS.lr, FLAGS.decay_steps, FLAGS.decay_rate, FLAGS.staircase) 369 | init_fun, update_fun, get_params = optim.sgd(step_size) 370 | 371 | logging.info(f"Init state for JAX optimizer (including L, mu, log_sigma)...") 372 | state = init_fun(params) 373 | 374 | # Setup result dict 375 | # --------------------------------------------------------------------------- 376 | logging.info(f"Initialize dictionary for results...") 377 | results = { 378 | "mu": [], 379 | "sigma": [], 380 | "cholesky_factor": [], 381 | "tau": [], 382 | "lambda": [], 383 | "objective": [], 384 | "constraint_term": [], 385 | "rhs": [] 386 | } 387 | if FLAGS.plot_intermediate: 388 | results["grad_norms"] = [] 389 | results["lagrangian"] = [] 390 | 391 | logging.info(f"Evaluate at xstar={xstar}...") 392 | 393 | logging.info(f"Evaluate {bound} bound...") 394 | sign = 1 if bound == "lower" else -1 395 | 396 | # =========================================================================== 397 | # OPTIMIZATION LOOP 398 | # =========================================================================== 399 | # One-time logging before first step 400 | # --------------------------------------------------------------------------- 401 | key, subkey = random.split(key) 402 | obj, rhs, psisum, constr = objective_rhs_psisum_constr( 403 | subkey, get_params(state), lmbda, tau, lhs, slack, xstar, tmp_c, xhats_c) 404 | results["objective"].append(obj) 405 | results["constraint_term"].append(psisum) 406 | results["rhs"].append(rhs) 407 | 408 | logging.info(f"Objective: scalar") 409 | logging.info(f"RHS: {rhs.shape}") 410 | logging.info(f"Sum over Psis: scalar") 411 | logging.info(f"Constraint: {constr.shape}") 412 | 413 | tril_idx = np.tril_indices(FLAGS.dim_theta + 1) 414 | count = 0 415 | logging.info(f"Start optimization loop...") 416 | for _ in tqdm(range(FLAGS.num_rounds)): 417 | # log current parameters 418 | # ------------------------------------------------------------------------- 419 | results["lambda"].append(lmbda) 420 | results["tau"].append(tau) 421 | cur_L, cur_mu, cur_logsigma = get_params(state) 422 | cur_chol = make_cholesky_factor(cur_L)[tril_idx].ravel()[1:] 423 | results["mu"].append(cur_mu) 424 | results["sigma"].append(np.exp(cur_logsigma)) 425 | results["cholesky_factor"].append(cur_chol) 426 | 427 | subkeys = random.split(key, num=FLAGS.opt_steps + 1) 428 | key = subkeys[0] 429 | # inner optimization for subproblem 430 | # ------------------------------------------------------------------------- 431 | for j in range(FLAGS.opt_steps): 432 | v, grads = lagrangian_value_and_grad( 433 | subkeys[j + 1], get_params(state), lmbda, tau, lhs, slack, xstar, 434 | tmp, xhats, sign) 435 | state = update_fun(count, grads, state) 436 | count += 1 437 | if FLAGS.plot_intermediate: 438 | results["lagrangian"].append(v) 439 | results["grad_norms"].append([np.linalg.norm(grad) for grad in grads]) 440 | 441 | # post inner optimization logging 442 | # ------------------------------------------------------------------------- 443 | key, subkey = random.split(key) 444 | obj, rhs, psisum, constr = objective_rhs_psisum_constr( 445 | subkey, get_params(state), lmbda, tau, lhs, slack, xstar, tmp_c, xhats_c) 446 | results["objective"].append(obj) 447 | results["constraint_term"].append(psisum) 448 | 
results["rhs"].append(rhs) 449 | 450 | # update lambda, tau 451 | # ------------------------------------------------------------------------- 452 | lmbda = update_lambda(constr, lmbda, tau) 453 | tau = np.minimum(tau * FLAGS.tau_factor, FLAGS.tau_max) 454 | 455 | # Convert and store results 456 | # --------------------------------------------------------------------------- 457 | logging.info(f"Finished optimization loop...") 458 | 459 | logging.info(f"Convert all results to numpy arrays...") 460 | results = {k: np.array(v) for k, v in results.items()} 461 | 462 | logging.info(f"Add final parameters and lhs to results...") 463 | L, mu, log_sigma = get_params(state) 464 | results["final_L"] = L 465 | results["final_mu"] = mu 466 | results["final_log_sigma"] = log_sigma 467 | results["lhs"] = lhs 468 | 469 | if FLAGS.store_data: 470 | logging.info(f"Save result data to...") 471 | result_path = os.path.join(out_dir, "results.npz") 472 | onp.savez(result_path, **results) 473 | 474 | # Generate and store plots 475 | # --------------------------------------------------------------------------- 476 | if FLAGS.plot_intermediate: 477 | fig_dir = os.path.join(out_dir, "figures") 478 | logging.info(f"Generate and save all plots at {fig_dir}...") 479 | plotting.plot_all(results, x, y, response, fig_dir) 480 | 481 | # Compute last valid and last satisfied 482 | # --------------------------------------------------------------------------- 483 | maxabsdiff = np.array([np.max(np.abs(lhs - r)) for r in results["rhs"]]) 484 | fin_i = np.sum(~np.isnan(results["objective"])) - 1 485 | logging.info(f"Final non-nan objective at {fin_i}.") 486 | fin_obj = results["objective"][fin_i] 487 | fin_maxabsdiff = maxabsdiff[fin_i] 488 | 489 | sat_i = [np.all((np.abs((lhs - r) / lhs) < FLAGS.slack) | 490 | (np.abs(lhs - r) < FLAGS.slack_abs)) 491 | for r in results["rhs"]] 492 | sat_i = np.where(sat_i)[0] 493 | 494 | if len(sat_i) > 0: 495 | sat_i = sat_i[-1] 496 | logging.info(f"Final satisfied constraint at {sat_i}.") 497 | sat_obj = results["objective"][sat_i] 498 | sat_maxabsdiff = maxabsdiff[sat_i] 499 | else: 500 | sat_i = -1 501 | logging.info(f"Constraints were never satisfied.") 502 | sat_obj, sat_maxabsdiff = np.nan, np.nan 503 | 504 | logging.info("Finished run.") 505 | return fin_i, fin_obj, fin_maxabsdiff, sat_i, sat_obj, sat_maxabsdiff 506 | 507 | 508 | # ============================================================================= 509 | # MAIN 510 | # ============================================================================= 511 | 512 | def main(_): 513 | # --------------------------------------------------------------------------- 514 | # Directory setup, save flags, set random seed 515 | # --------------------------------------------------------------------------- 516 | FLAGS.alsologtostderr = True 517 | 518 | if FLAGS.output_name == "": 519 | dir_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 520 | else: 521 | dir_name = FLAGS.output_name 522 | out_dir = os.path.join(os.path.abspath(FLAGS.output_dir), dir_name) 523 | logging.info(f"Save all output to {out_dir}...") 524 | if not os.path.exists(out_dir): 525 | os.makedirs(out_dir) 526 | 527 | FLAGS.log_dir = out_dir 528 | logging.get_absl_handler().use_absl_log_file(program_name="run") 529 | 530 | logging.info("Save FLAGS (arguments)...") 531 | with open(os.path.join(out_dir, 'flags.json'), 'w') as fp: 532 | json.dump(FLAGS.flag_values_dict(), fp, sort_keys=True, indent=2) 533 | 534 | logging.info(f"Set random seed {FLAGS.seed}...") 535 | key = 
random.PRNGKey(FLAGS.seed) 536 | 537 | # --------------------------------------------------------------------------- 538 | # Load and store data 539 | # --------------------------------------------------------------------------- 540 | logging.info(f"Get dataset: {FLAGS.dataset}") 541 | if FLAGS.dataset == "synthetic": 542 | logging.info(f"Generate synthetic data (n={FLAGS.num_data}) " 543 | f"using equations {FLAGS.equations}...") 544 | key, subkey = random.split(key) 545 | dat, data_xstar, data_ystar = data.get_synth_data( 546 | subkey, FLAGS.num_data, FLAGS.equations, 547 | disconnect_instrument=FLAGS.disconnect_instrument) 548 | elif FLAGS.dataset == "colonial_origins": 549 | dat = data.get_colonial_origins(FLAGS.data_dir) 550 | data_xstar, data_ystar = None, None 551 | else: 552 | raise ValueError(f"Unknown dataset {FLAGS.dataset}") 553 | 554 | for k, v in dat.items(): 555 | if v is not None: 556 | logging.info(f'{k}: {v.shape}') 557 | 558 | if FLAGS.store_data: 559 | logging.info(f"Store data...") 560 | result_path = os.path.join(out_dir, "data.npz") 561 | onp.savez(result_path, **dat, xstar=data_xstar, ystar=data_ystar) 562 | 563 | x, y, z, ex, ey = dat['x'], dat['y'], dat['z'], dat['ex'], dat['ey'] 564 | confounder = dat['confounder'] 565 | 566 | # --------------------------------------------------------------------------- 567 | # Discretize z, generate LHS and reusable x samples 568 | # --------------------------------------------------------------------------- 569 | logging.info(f"Discretize Z and bin datapoints (num_z={FLAGS.num_z})...") 570 | z_grid, bin_ids = data.make_zgrid_and_binids(z, FLAGS.num_z) 571 | if len(z_grid) != FLAGS.num_z: 572 | FLAGS.num_z = len(z_grid) 573 | logging.info(f"Updated num_z to {FLAGS.num_z}") 574 | 575 | # Set global functions depending on FLAGS 576 | # --------------------------------------------------------------------------- 577 | logging.info(f"Use response type {FLAGS.response_type}...") 578 | basis_predict = None 579 | global response 580 | if FLAGS.response_type == "poly": 581 | response = response_poly 582 | elif FLAGS.response_type == "gp": 583 | basis_predict = utils.get_gp_prediction(x, y, n_samples=FLAGS.dim_theta) 584 | 585 | @jit 586 | def response_gp(theta: np.ndarray, _x: np.ndarray) -> np.ndarray: 587 | _x = np.atleast_1d(_x) 588 | if _x.ndim == 1: 589 | # (n,) <- (1, k) @ (k, n) 590 | return (basis_predict(_x) @ theta).squeeze() 591 | else: 592 | # (n_constr, n) <- (n_constr, n, k) @ (k, n) 593 | return np.einsum('ijk,kj->ij', basis_predict(_x), theta) 594 | 595 | response = response_gp 596 | elif FLAGS.response_type == "mlp": 597 | key, subkey = random.split(key) 598 | basis_predict = utils.fit_mlp(subkey, x[:, np.newaxis], y, 599 | n_samples=FLAGS.dim_theta) 600 | 601 | @jit 602 | def response_mlp(theta: np.ndarray, _x: np.ndarray) -> np.ndarray: 603 | _x = np.atleast_2d(_x) 604 | if _x.shape[0] == 1: 605 | # (n,) <- (1, k) @ (k, n) 606 | return (basis_predict(_x) @ theta).squeeze() 607 | else: 608 | # (n_constr, n) <- (n_constr, n, k) @ (k, n) 609 | return np.einsum('ijk,kj->ij', basis_predict(_x[:, :, None]), theta) 610 | 611 | response = response_mlp 612 | else: 613 | raise NotImplementedError(f"Unknown response_type {FLAGS.response_type}.") 614 | 615 | logging.info(f"Make LHS of constraints ...") 616 | lhs = make_constraint_lhs(y, bin_ids, z_grid) 617 | logging.info(f"LHS: {lhs.shape} ...") 618 | 619 | logging.info(f"Generate fixed x samples for objective {FLAGS.bs}...") 620 | tmp, xhats = data.get_x_samples(x, bin_ids, 
FLAGS.num_z, FLAGS.bs) 621 | logging.info(f"tmp: {tmp.shape}...") 622 | logging.info(f"xhats: {xhats.shape}...") 623 | 624 | logging.info(f"Generate fixed x samples for constraint {FLAGS.bs_constr}...") 625 | tmp_c, xhats_c = data.get_x_samples(x, bin_ids, FLAGS.num_z, FLAGS.bs_constr) 626 | logging.info(f"tmp_c: {tmp_c.shape}...") 627 | logging.info(f"xhats_c: {xhats_c.shape}...") 628 | 629 | xmin, xmax = np.min(x), np.max(x) 630 | xstar_grid = np.linspace(xmin, xmax, FLAGS.num_xstar + 1) 631 | xstar_grid = (xstar_grid[:-1] + xstar_grid[1:]) / 2 632 | 633 | # --------------------------------------------------------------------------- 634 | # Plot data and initialization 635 | # --------------------------------------------------------------------------- 636 | if FLAGS.plot_init: 637 | logging.info(f"Plot data and discretization visuals...") 638 | plotting.plot_all_init(z, x, y, confounder, ex, ey, xhats_c, z_grid, 639 | bin_ids, lhs, basis_func=basis_predict, 640 | base_dir=out_dir) 641 | else: 642 | logging.info(f"Skip plots of data and discretization visuals...") 643 | 644 | # --------------------------------------------------------------------------- 645 | # Allocate memory for aggregate results 646 | # --------------------------------------------------------------------------- 647 | final = { 648 | "indices": np.zeros((FLAGS.num_xstar, 2), dtype=np.int32), 649 | "objective": np.zeros((FLAGS.num_xstar, 2)), 650 | "maxabsdiff": np.zeros((FLAGS.num_xstar, 2)), 651 | } 652 | satis = { 653 | "indices": np.zeros((FLAGS.num_xstar, 2), dtype=np.int32), 654 | "objective": np.zeros((FLAGS.num_xstar, 2)), 655 | "maxabsdiff": np.zeros((FLAGS.num_xstar, 2)), 656 | } 657 | 658 | # --------------------------------------------------------------------------- 659 | # Main loops over xstar and bounds 660 | # --------------------------------------------------------------------------- 661 | for i, xstar in enumerate(xstar_grid): 662 | for j, bound in enumerate(["lower", "upper"]): 663 | logging.info(f"Run xstar={xstar}, bound={bound}...") 664 | vis = "=" * 10 665 | logging.info(f"{vis} {i * 2 + j + 1}/{2 * FLAGS.num_xstar} {vis}") 666 | fin_i, fin_obj, fin_diff, sat_i, sat_obj, sat_diff = run_optim( 667 | key, lhs, tmp, xhats, tmp_c, xhats_c, xstar, bound, out_dir, 668 | x, y) 669 | final["indices"] = index_update(final["indices"], (i, j), fin_i) 670 | final["objective"] = index_update(final["objective"], (i, j), fin_obj) 671 | final["maxabsdiff"] = index_update(final["maxabsdiff"], (i, j), fin_diff) 672 | satis["indices"] = index_update(satis["indices"], (i, j), sat_i) 673 | satis["objective"] = index_update(satis["objective"], (i, j), sat_obj) 674 | satis["maxabsdiff"] = index_update(satis["maxabsdiff"], (i, j), sat_diff) 675 | 676 | # --------------------------------------------------------------------------- 677 | # Comparison methods 678 | # --------------------------------------------------------------------------- 679 | if FLAGS.run_2sls: 680 | logging.info(f"Compute 2SLS regression...") 681 | coeff_2sls = utils.two_stage_least_squares(z, x, y) 682 | if FLAGS.store_data: 683 | result_path = os.path.join(out_dir, "coeff_2sls.npz") 684 | onp.savez(result_path, coeff_2sls=coeff_2sls) 685 | else: 686 | coeff_2sls = None 687 | 688 | if FLAGS.run_kiv: 689 | logging.info(f"Compute KIV regression...") 690 | x_kiv, y_kiv = kiv.fit_kiv(z, x, y) 691 | if FLAGS.store_data: 692 | result_path = os.path.join(out_dir, "kiv_results.npz") 693 | onp.savez(result_path, x_star=x_kiv, y_star=y_kiv) 694 | else: 695 | 
x_kiv, y_kiv = None, None 696 | 697 | # --------------------------------------------------------------------------- 698 | # Store basis functions 699 | # --------------------------------------------------------------------------- 700 | if FLAGS.response_type != "poly": 701 | basis_x = np.linspace(xmin, xmax, 100) 702 | basis_y = basis_predict(basis_x).squeeze() 703 | if FLAGS.store_data: 704 | logging.info(f"Store the basis functions...") 705 | result_path = os.path.join(out_dir, "basis_functions.npz") 706 | onp.savez(result_path, x=basis_x, y=basis_y) 707 | 708 | # --------------------------------------------------------------------------- 709 | # Store aggregate results 710 | # --------------------------------------------------------------------------- 711 | if FLAGS.store_data: 712 | logging.info(f"Store indices, bounds, constraint diffs at final non-nan.") 713 | result_path = os.path.join(out_dir, "final_nonnan.npz") 714 | onp.savez(result_path, xstar_grid=xstar_grid, **final) 715 | 716 | logging.info(f"Store indices, bounds, constraint diffs at last satisfied.") 717 | result_path = os.path.join(out_dir, "final_satisfied.npz") 718 | onp.savez(result_path, xstar_grid=xstar_grid, **satis) 719 | 720 | # --------------------------------------------------------------------------- 721 | # Plot aggregate results 722 | # --------------------------------------------------------------------------- 723 | if FLAGS.plot_final: 724 | logging.info(f"Plot final aggregate results...") 725 | plotting.plot_all_final( 726 | final, satis, x, y, xstar_grid, data_xstar, data_ystar, 727 | coeff_2sls=coeff_2sls, x_kiv=x_kiv, y_kiv=y_kiv, base_dir=out_dir) 728 | 729 | logging.info(f"DONE") 730 | 731 | 732 | if __name__ == "__main__": 733 | app.run(main) 734 | --------------------------------------------------------------------------------
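The LHS of the constraints in make_constraint_lhs (src/run.py) is built from the per-bin mean and variance of y on the z grid and then smoothed column-wise with scipy.interpolate.UnivariateSpline at a fixed smoothing factor of 0.2. The following is a minimal, self-contained sketch of that smoothing step; only the spline call mirrors the repository code, the per-bin statistics here are fabricated.

import numpy as np
from scipy.interpolate import UnivariateSpline

rng = np.random.default_rng(0)
num_z = 20
z_grid = np.linspace(-2.0, 2.0, num_z)

# Pretend these are the per-bin [mean, variance] of y, shape (num_z, 2).
tmp = np.stack([np.sin(z_grid) + 0.05 * rng.standard_normal(num_z),
                0.5 + 0.05 * rng.standard_normal(num_z)], axis=1)

# Smooth each moment separately over the z grid, as make_constraint_lhs does.
lhs = np.stack([UnivariateSpline(z_grid, tmp[:, i], s=0.2)(z_grid)
                for i in range(tmp.shape[1])], axis=1)
print(lhs.shape)  # (20, 2): one smoothed [mean, var] pair per z grid point
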
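The augmented Lagrangian in src/run.py penalizes each inequality constraint c = slack - |lhs - rhs| >= 0 with the piecewise term psi(c, lambda, tau) = -lambda*c + tau*c^2/2 when tau*c <= lambda and -lambda^2/(2*tau) otherwise, and updates the multipliers as lambda <- max(lambda - tau*c, 0); this is what get_constraint_term and update_lambda compute. A minimal plain-NumPy sketch with hypothetical numbers, just to make the two branches concrete:

import numpy as np

def psi(c, lmbda, tau):
    # Same piecewise penalty as get_constraint_term in src/run.py.
    quadratic = -lmbda * c + 0.5 * tau * c ** 2  # active or violated branch
    constant = -0.5 * lmbda ** 2 / tau           # comfortably satisfied branch
    return np.where(tau * c <= lmbda, quadratic, constant)

def update_lmbda(c, lmbda, tau):
    # Same clipped multiplier update as update_lambda in src/run.py.
    return np.maximum(lmbda - tau * c, 0.0)

# Hypothetical constraint values c = slack - |lhs - rhs|: negative means violated.
c = np.array([-0.3, 0.0, 0.5])
print(psi(c, lmbda=1.0, tau=2.0))           # approx. [0.39, 0.0, -0.25]
print(update_lmbda(c, lmbda=1.0, tau=2.0))  # approx. [1.6, 1.0, 0.0]
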
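make_cholesky_factor in src/run.py turns an unconstrained lower-triangular parameter into a Cholesky factor of a positive definite matrix: it masks out the upper triangle, pins the (0, 0) entry before exponentiation (so the first row stays (1, 0, ..., 0) and the precomputed x-noise channel keeps a standard normal marginal), and exponentiates the diagonal to make it positive. A plain-NumPy re-implementation for illustration only; the repository uses jax.numpy and index_update.

import numpy as np

def make_cholesky_factor_np(l_param):
    """Plain-NumPy rendition of make_cholesky_factor from src/run.py."""
    k = l_param.shape[0]
    lmask = np.tri(k)
    lmask[0, 0] = 0.0             # the (0, 0) entry is pinned to exp(0) = 1
    tmp = l_param * lmask
    idx = np.diag_indices(k)
    tmp[idx] = np.exp(tmp[idx])   # exponentiate the diagonal -> strictly positive
    return tmp

rng = np.random.default_rng(0)
L = 0.05 * np.tri(4, k=-1) * rng.standard_normal((4, 4))  # same init as init_params
chol = make_cholesky_factor_np(L)
corr = chol @ chol.T
# A lower-triangular factor with positive diagonal yields a positive definite
# product, so numpy's Cholesky recovers exactly this factor (cf. the assert
# in init_params).
assert np.allclose(np.linalg.cholesky(corr), chol)
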
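Inside objective_rhs_psisum_constr, theta is drawn through a Gaussian copula: fresh standard normals are stacked under the precomputed x-noise tmp_pre, correlated with the Cholesky factor, mapped to uniforms with the standard normal CDF, and then pushed through Gaussian marginals with means mu and standard deviations exp(log_sigma). The helpers utils.std_normal_cdf and utils.normal_cdf_inv live in src/utils.py; the sketch below substitutes scipy.stats.norm and an identity factor for them, so it mirrors the intended transformation rather than the repository's exact helpers.

import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
dim_theta, n = 3, 1_000                       # hypothetical sizes

chol = np.eye(dim_theta + 1)                  # stand-in for make_cholesky_factor(L)
mu = np.zeros(dim_theta)                      # marginal means of theta
sigma = np.exp(np.zeros(dim_theta))           # marginal stds, i.e. exp(log_sigma)

tmp_pre = rng.standard_normal((1, n))         # noise shared with the x samples
fresh = rng.standard_normal((dim_theta, n))   # new noise for theta
u = norm.cdf(chol @ np.concatenate((tmp_pre, fresh), axis=0))   # -> uniforms
thetahat = norm.ppf(u[1:, :]) * sigma[:, None] + mu[:, None]    # Gaussian marginals

print(thetahat.shape)                         # (3, 1000): one theta sample per column
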
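The tolerances in run_optim combine the fractional --slack (relative to |LHS|) with the absolute floor --slack_abs, and a constraint counts as satisfied when slack - |lhs - rhs| is non-negative. A small numeric illustration with made-up LHS/RHS moments:

import numpy as np

slack_frac, slack_abs = 0.2, 0.2              # defaults of --slack and --slack_abs
lhs = np.array([[1.5, 0.1], [2.0, 0.3]])      # hypothetical per-bin [mean, var] of y
rhs = np.array([[1.4, 0.4], [2.3, 0.3]])      # hypothetical model-implied moments

slack = np.maximum(slack_abs, slack_frac * np.abs(lhs.ravel()))
constr = slack - np.ravel(np.abs(lhs - rhs))
print(constr)        # [ 0.2 -0.1  0.1  0.2]; negative entries are violations
print(constr >= 0)   # [ True False  True  True]
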
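At the end of run_optim the bound is read off twice: at the last round with a non-NaN objective, and at the last round whose RHS matches the LHS within either the fractional or the absolute tolerance. A compact plain-NumPy sketch of that selection logic on fabricated per-round results:

import numpy as np

slack, slack_abs = 0.2, 0.2
lhs = np.array([[1.5, 0.1]])                   # hypothetical constraint targets
objective = np.array([0.1, 0.5, 0.8, np.nan])  # per-round objective estimates
rhs_per_round = [np.array([[2.0, 0.5]]),       # far off
                 np.array([[1.55, 0.15]]),     # within tolerance
                 np.array([[1.2, 0.6]]),       # violated again
                 np.array([[1.5, 0.1]])]       # exact match

fin_i = np.sum(~np.isnan(objective)) - 1       # last non-NaN round
sat = [np.all((np.abs((lhs - r) / lhs) < slack) |
              (np.abs(lhs - r) < slack_abs)) for r in rhs_per_round]
sat_i = np.where(sat)[0]
sat_i = sat_i[-1] if len(sat_i) > 0 else -1    # last satisfied round, or -1
print(fin_i, sat_i)                            # 2 3
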
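For the GP and MLP response types, a batch of constraint inputs yields a basis tensor of shape (num_z, n, k) that is contracted with the (k, n) matrix of theta samples via einsum, pairing the j-th theta sample with the j-th x sample. A shape-only sketch with a random stand-in for basis_predict:

import numpy as np

num_z, n, k = 20, 1_000, 7                    # hypothetical sizes
rng = np.random.default_rng(0)
basis = rng.standard_normal((num_z, n, k))    # stand-in for basis_predict(x)
theta = rng.standard_normal((k, n))           # one theta sample per column

out = np.einsum('ijk,kj->ij', basis, theta)   # (num_z, n): response per bin and sample
assert out.shape == (num_z, n)
# Entry (i, j) is the dot product of basis[i, j, :] with theta[:, j].
assert np.allclose(out[0, 0], basis[0, 0] @ theta[:, 0])
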
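The 2SLS baseline is computed by utils.two_stage_least_squares, which is not part of this file. As a point of reference, a textbook two-stage least squares estimate can be sketched with two ordinary least squares fits; this is a generic illustration on simulated data, not the repository's implementation.

import numpy as np

rng = np.random.default_rng(0)
n = 5_000
z = rng.standard_normal(n)                     # instrument
u = rng.standard_normal(n)                     # unobserved confounder
x = 0.8 * z + u + 0.1 * rng.standard_normal(n)
y = 1.5 * x - 2.0 * u + 0.1 * rng.standard_normal(n)

Z = np.column_stack([np.ones(n), z])
x_fit = Z @ np.linalg.lstsq(Z, x, rcond=None)[0]      # stage 1: project x on z
Xhat = np.column_stack([np.ones(n), x_fit])
coef_2sls = np.linalg.lstsq(Xhat, y, rcond=None)[0]   # stage 2: regress y on fitted x
print(coef_2sls)  # close to [0.0, 1.5]; plain OLS of y on x gives roughly 0.3 here
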