├── .gitignore ├── CONTRIBUTING.md ├── Generalized_Empirical_Likelihood.ipynb ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution; 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to <https://cla.developers.google.com/> to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /Generalized_Empirical_Likelihood.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "pSlhEBZNRP8a" 7 | }, 8 | "source": [ 9 | "# **Understanding Deep Generative Models with Generalized Empirical Likelihoods**\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "QGn8KmmlRNje" 16 | }, 17 | "source": [ 18 | "Copyright 2023 DeepMind Technologies Limited\n", 19 | "\n", 20 | "All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0\n", 21 | "\n", 22 | "All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode\n", 23 | "\n", 24 | "Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.\n", 25 | "\n", 26 | "This is not an official Google product."
27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "6KF-NotYRX4v" 33 | }, 34 | "source": [ 35 | "# Imports" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "id": "nK4M6jWpQ4sO" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "!pip install ml_collections\n", 47 | "\n", 48 | "import functools\n", 49 | "import io\n", 50 | "import os\n", 51 | "\n", 52 | "from absl import flags\n", 53 | "from absl import logging\n", 54 | "from ml_collections import config_dict\n", 55 | "import numpy as np\n", 56 | "from scipy.optimize import linprog\n", 57 | "from scipy.stats import entropy\n", 58 | "import timeit\n", 59 | "\n", 60 | "from sklearn.metrics import pairwise_distances\n", 61 | "from enum import Enum\n", 62 | "from google.colab import auth as google_auth\n", 63 | "\n", 64 | "import matplotlib.pyplot as plt\n", 65 | "\n", 66 | "logging.set_verbosity(logging.INFO)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": { 72 | "id": "-vMQM7BxRgL7" 73 | }, 74 | "source": [ 75 | "# Copy Data from Google Cloud Bucket" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "qLnnXm4TRjIb" 82 | }, 83 | "source": [ 84 | "Authenticate user, and list data in GCP bucket" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": { 91 | "id": "-HQyBR7yRhMv" 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "google_auth.authenticate_user()\n", 96 | "!gsutil ls gs://dm_gel_metric/cifar10_mode_drop_data/" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": { 102 | "id": "EuKJLj8RRuKv" 103 | }, 104 | "source": [ 105 | "Copy to local disk and list what is copied" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "id": "NpgSsgQ_RvsJ" 113 | }, 114 | "outputs": [], 115 | "source": [ 116 | "!mkdir -p cifar10_mode_drop_data\n", 117 | "!gsutil cp gs://dm_gel_metric/cifar10_mode_drop_data/*.npz cifar10_mode_drop_data/\n", 118 | "!ls cifar10_mode_drop_data" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": { 124 | "id": "JqF3NzcWR46Y" 125 | }, 126 | "source": [ 127 | "# GEL Code" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": { 133 | "id": "bfXAcU4o4E6K" 134 | }, 135 | "source": [ 136 | "A note on speed: GEL calculations typically take 3-5 minutes to complete." 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": { 142 | "id": "1gAlAWA3R7XR" 143 | }, 144 | "source": [ 145 | "## Helper Functions and Classes" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "id": "vPy-fA8KR5pc" 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "def approx_in_cvx_hull(\n", 157 | " hull_points: np.array, test_point: np.array, eps: float=0.01):\n", 158 | " \"\"\"Triangle alg. 
to see if test_point is in the convex hull of hull_points.\n", 159 | "\n", 160 | " Implementation of Kalantari et al.,\n", 161 | " \"Randomized triangle algorithms for convex hull membership\"\n", 162 | "\n", 163 | " Args:\n", 164 | " hull_points: (n_points, n_dim) matrix of hull points.\n", 165 | " test_point: (n_dim,) vector which is the test point.\n", 166 | " eps: epsilon tolerance\n", 167 | " Returns:\n", 168 | " in_hull: boolean (True) if test point is in the convex hull of hull_points.\n", 169 | " pivot_point: pivot point of the algorithm.\n", 170 | " \"\"\"\n", 171 | " mean_hull_point = np.mean(hull_points, axis=0)\n", 172 | " aug_hull_points = np.vstack([hull_points, mean_hull_point])\n", 173 | "\n", 174 | " def calc_dists(hull_points, point):\n", 175 | " return np.sqrt(np.sum((hull_points-point) ** 2, axis=1))\n", 176 | "\n", 177 | " def nearest_point(test_point, pivot_point, hull_point):\n", 178 | " alpha = (np.dot(test_point-pivot_point, hull_point-pivot_point) /\n", 179 | " np.sum((hull_point-pivot_point) ** 2))\n", 180 | " if alpha \u003e= 0.0 and alpha \u003c 1.0:\n", 181 | " return (1.0 - alpha) * pivot_point + alpha * hull_point\n", 182 | " return hull_point\n", 183 | "\n", 184 | " def get_new_hull_point(sub_hull_points, pivot_point, test_point):\n", 185 | " norm_vec = ((test_point - pivot_point) /\n", 186 | " np.linalg.norm(test_point - pivot_point))\n", 187 | " a_mat = sub_hull_points - pivot_point\n", 188 | " a_mat /= np.linalg.norm(a_mat, axis=1)[:, np.newaxis]\n", 189 | " return sub_hull_points[np.argmax(np.dot(a_mat, norm_vec)), :]\n", 190 | "\n", 191 | " dist_vec = calc_dists(aug_hull_points, test_point)\n", 192 | " radius = np.max(dist_vec)\n", 193 | " pivot_point = aug_hull_points[np.argmin(dist_vec), :]\n", 194 | " i = 0\n", 195 | " while np.linalg.norm(pivot_point-test_point) \u003e eps * radius:\n", 196 | " dist_vec_pivot = calc_dists(aug_hull_points, pivot_point)\n", 197 | " if np.all(dist_vec_pivot \u003c dist_vec):\n", 198 | " # convex hull condition failed\n", 199 | " logging.info('Witness found at iter %d', i)\n", 200 | " return False, pivot_point\n", 201 | " sub_hull_points = aug_hull_points[dist_vec_pivot \u003e dist_vec, :]\n", 202 | " hull_point = get_new_hull_point(sub_hull_points, pivot_point, test_point)\n", 203 | " pivot_point = nearest_point(test_point, pivot_point, hull_point)\n", 204 | " i += 1\n", 205 | "\n", 206 | " logging.info('It is probably in the convex hull. 
Took %d iterations', i)\n", 207 | "  logging.info('Distance for iter %d is %.5f, radius %.5f', i,\n", 208 | "               np.linalg.norm(pivot_point-test_point) / radius, radius)\n", 209 | "  return True, pivot_point\n", 210 | "\n", 211 | "\n", 212 | "def is_in_cvx_hull(hull_points: np.array, test_point: np.array):\n", 213 | "  \"\"\"Checks to see if test_point is in the convex hull of hull_points.\n", 214 | "\n", 215 | "  Args:\n", 216 | "    hull_points: (n_points, n_dim) matrix of hull points.\n", 217 | "    test_point: (n_dim,) vector which is the test point.\n", 218 | "  Returns:\n", 219 | "    in_hull: boolean (True) if test point is in the convex hull of hull_points.\n", 220 | "  \"\"\"\n", 221 | "\n", 222 | "  n_points, n_dims = hull_points.shape\n", 223 | "  c = np.zeros((n_points,))\n", 224 | "  matrix_a_ub = -np.eye(n_points)\n", 225 | "  b_ub = np.zeros((n_points,))\n", 226 | "\n", 227 | "  matrix_a_eq = np.ones((n_dims+1, n_points))\n", 228 | "  matrix_a_eq[:n_dims, :] = hull_points.T\n", 229 | "\n", 230 | "  b_eq = np.ones((n_dims+1))\n", 231 | "  b_eq[:n_dims] = test_point\n", 232 | "\n", 233 | "  res = linprog(c, A_ub=matrix_a_ub, b_ub=b_ub, A_eq=matrix_a_eq, b_eq=b_eq)\n", 234 | "  logging.info(res.status)\n", 235 | "  if res.success:\n", 236 | "    log_prob = np.sum(np.log2(res.x))\n", 237 | "    logging.info('Initial log prob is %.5f', log_prob)\n", 238 | "  in_hull = res.success\n", 239 | "  return in_hull\n", 240 | "\n", 241 | "\n", 242 | "def convert_space_to_dims(features: np.array):\n", 243 | "  \"\"\"Flatten features to dimensions.\"\"\"\n", 244 | "  out = features.reshape([features.shape[0], -1])\n", 245 | "  return out\n", 246 | "\n", 247 | "\n", 248 | "def do_pca(features: np.array, pca_dim: int):\n", 249 | "  logging.info('Whitening...')\n", 250 | "  cov_mat = np.cov(features, rowvar=False)\n", 251 | "  eig_vals, v = np.linalg.eigh(cov_mat)  # eig does not sort eigenvalues\n", 252 | "  features = np.dot(features, v[:, np.argsort(eig_vals)[::-1]])[:, :pca_dim]\n", 253 | "  logging.info('Done')\n", 254 | "\n", 255 | "  return features\n", 256 | "\n", 257 | "def one_sample_emp_lik_iteration(feature_diffs: np.array, params: np.array):\n", 258 | "  \"\"\"Perform newton iteration for empirical likelihood.\n", 259 | "\n", 260 | "  Args:\n", 261 | "    feature_diffs: per-sample features for moment conditions.\n", 262 | "    params: current parameters for calculating empirical likelihood.\n", 263 | "  Returns:\n", 264 | "    params: parameters after a Newton step.\n", 265 | "    output_stats: dictionary of output statistics.\n", 266 | "  \"\"\"\n", 267 | "  num_examples = feature_diffs.shape[0]\n", 268 | "  z = 1.0 + np.dot(feature_diffs, params)\n", 269 | "  inv_n = 1.0 / num_examples\n", 270 | "\n", 271 | "  # positive part of the modified logarithm\n", 272 | "  w_pos = 1.0 / z[z \u003e= inv_n]\n", 273 | "  f_diff_pos = feature_diffs[z \u003e= inv_n, :] * w_pos[:, np.newaxis]\n", 274 | "\n", 275 | "  # negative part of the modified logarithm\n", 276 | "  w_neg = (2.0 - num_examples * z[z \u003c inv_n]) * num_examples\n", 277 | "  f_diff_neg = feature_diffs[z \u003c inv_n, :]\n", 278 | "  num_egs2 = num_examples ** 2\n", 279 | "\n", 280 | "  neg_hess = (np.dot(f_diff_pos.T, f_diff_pos)\n", 281 | "              + np.dot(f_diff_neg.T, f_diff_neg) * num_egs2)\n", 282 | "  sc_f_diff_neg = f_diff_neg * w_neg[:, np.newaxis]\n", 283 | "  log_grad = np.sum(f_diff_pos, axis=0) + np.sum(sc_f_diff_neg, axis=0)\n", 284 | "  log_grad_norm = np.linalg.norm(log_grad)\n", 285 | "\n", 286 | "  direction = np.linalg.solve(neg_hess, log_grad)\n", 287 | "  params += 1.0 * direction\n", 288 | "  n_out_of_domain = f_diff_neg.shape[0]\n", 289 | "\n", 
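"  # The Newton step above uses a modified logarithm in the style of Owen: the\n", "  # exact log where z_i \u003e= inv_n and a quadratic extension where z_i \u003c inv_n\n", "  # (those points are counted in n_out_of_domain). Recompute the implied\n", "  # per-sample probabilities p_i = 1 / (n * z_i) under the updated params.\n", 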
290 | " probs = 1.0 / (num_examples * (1.0 + np.dot(feature_diffs, params)))\n", 291 | " log_lik = np.sum(np.log(probs))\n", 292 | " output_stats = dict(\n", 293 | " probs=probs, obj=log_lik,\n", 294 | " n_out_of_domain=n_out_of_domain, log_grad_norm=log_grad_norm,)\n", 295 | " return params, output_stats\n", 296 | "\n", 297 | "\n", 298 | "def hellinger_dist(p: np.array, q: np.array):\n", 299 | " assert np.all(p \u003e= 0.0)\n", 300 | " assert np.all(q \u003e= 0.0)\n", 301 | " norm_p = p / np.sum(p)\n", 302 | " norm_q = q / np.sum(q)\n", 303 | " return np.sqrt(1. - np.sum(np.sqrt(norm_p * norm_q)))\n", 304 | "\n", 305 | "\n", 306 | "class GELStatus(Enum):\n", 307 | " SOLVED = \"solved\"\n", 308 | " NOT_IN_CONVEX_HULL = \"not_in_convex_hull\"\n", 309 | " BOUNDARY = \"boundary\"\n", 310 | " OPTIMIZATION_FAILURE = \"optimization_failure\"\n", 311 | " RUNNING = \"running\"\n", 312 | "\n", 313 | "\n", 314 | "class GELObjective(Enum):\n", 315 | " EMPIRICAL_LIKELIHOOD = 1\n", 316 | " EXPONENTIAL_TILTING = 2\n", 317 | " EUCLIDEAN_LIKELIHOOD = 3" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": { 323 | "id": "la4ayxabSCZy" 324 | }, 325 | "source": [ 326 | "## One-Sample GEL code" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": { 333 | "id": "0lWkkRacSDGQ" 334 | }, 335 | "outputs": [], 336 | "source": [ 337 | "class OneSampleGEL(object):\n", 338 | " def __init__(self, config, model_unprocessed_feats: np.array,\n", 339 | " test_unprocessed_feats: np.array):\n", 340 | " self.config = config\n", 341 | " feature_diffs, test_features, model_features = self._preprocess_features(\n", 342 | " model_unprocessed_feats, test_unprocessed_feats)\n", 343 | " self.feature_diffs = feature_diffs\n", 344 | " self.test_features = test_features\n", 345 | " self.model_features = model_features\n", 346 | "\n", 347 | " self._current_loss = np.inf\n", 348 | " iter_func, output_stats, gel_status, params = self._init_optimizer()\n", 349 | " self._iter_func = iter_func\n", 350 | " self._output_stats = output_stats\n", 351 | " self._status = gel_status # termination conditions\n", 352 | " self._params = params\n", 353 | "\n", 354 | " def _init_optimizer(self):\n", 355 | " num_examples, ndims = self.feature_diffs.shape\n", 356 | " if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:\n", 357 | " iter_func = self.emp_lik_iteration\n", 358 | " elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:\n", 359 | " iter_func = self.exp_tilted_iteration\n", 360 | " elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 361 | " iter_func = self.euc_lik_iteration\n", 362 | " else:\n", 363 | " raise ValueError('Objective type %s not valid', self.config.obj_type)\n", 364 | " probs = np.empty((num_examples,))\n", 365 | " probs[:] = np.nan\n", 366 | " output_stats = dict(probs=probs, obj=np.inf,\n", 367 | " n_out_of_domain=0, log_grad_norm=np.inf)\n", 368 | " params = np.zeros((ndims,))\n", 369 | " gel_status = GELStatus.RUNNING\n", 370 | " # Check if the convex hull condition is satisfied.\n", 371 | " # This condition does not apply to the Euc. 
Likelihood\n", 372 | " if not self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 373 | " if not approx_in_cvx_hull(self.feature_diffs, np.zeros((ndims,)))[0]:\n", 374 | " gel_status = GELStatus.NOT_IN_CONVEX_HULL\n", 375 | " return iter_func, output_stats, gel_status, params\n", 376 | "\n", 377 | " def euc_lik_iteration(self):\n", 378 | " \"\"\"Perform newton step for euclidean likelihood (which solves the problem).\n", 379 | " \"\"\"\n", 380 | " num_examples = self.feature_diffs.shape[0]\n", 381 | " sample_cov = np.cov(self.test_features, rowvar=False, bias=True)\n", 382 | " self._params = np.linalg.solve(\n", 383 | " sample_cov, np.mean(self.feature_diffs, axis=0))\n", 384 | " demeaned_test_features = self.test_features - np.mean(\n", 385 | " self.test_features, axis=0)\n", 386 | " probs = (1 - np.dot(demeaned_test_features, self._params)) / num_examples\n", 387 | "\n", 388 | " # square is probably \"more\" correct but easier for per point\n", 389 | " euc_dist = np.sum((probs - 1.0 / num_examples) ** 2) * num_examples\n", 390 | " log_grad_norm, n_out_of_domain = 0.0, 0\n", 391 | " self._output_stats = dict(\n", 392 | " probs=probs, loss=euc_dist,\n", 393 | " n_out_of_domain=n_out_of_domain, log_grad_norm=log_grad_norm,)\n", 394 | "\n", 395 | " def exp_tilted_iteration(self):\n", 396 | " \"\"\"Perform half newton step for exponential tilting objective.\"\"\"\n", 397 | " num_examples = self.feature_diffs.shape[0]\n", 398 | " w_exp_tilt = np.exp(np.dot(self.feature_diffs, self._params)) / num_examples\n", 399 | " sc_f_diff = self.feature_diffs * w_exp_tilt[:, np.newaxis]\n", 400 | "\n", 401 | " hess = np.dot(sc_f_diff.T, self.feature_diffs)\n", 402 | " log_grad = np.sum(sc_f_diff, axis=0)\n", 403 | " newton_step = np.linalg.solve(hess, log_grad)\n", 404 | " log_grad_norm = np.linalg.norm(log_grad)\n", 405 | "\n", 406 | " self._params -= 0.5 * newton_step\n", 407 | " exp_weights = np.exp(np.dot(self.feature_diffs, self._params))\n", 408 | " probs = exp_weights / np.sum(exp_weights)\n", 409 | " n_out_of_domain = 0\n", 410 | " ent = entropy(probs, base=2)\n", 411 | " self._output_stats = dict(\n", 412 | " probs=probs, loss=-ent,\n", 413 | " n_out_of_domain=n_out_of_domain, log_grad_norm=log_grad_norm,)\n", 414 | "\n", 415 | " def emp_lik_iteration(self):\n", 416 | " \"\"\"Perform newton iteration for empirical likelihood objective.\n", 417 | "\n", 418 | " Returns:\n", 419 | " params: parameters after a Newton step.\n", 420 | " output_stats: dictionary of output statistics.\n", 421 | " \"\"\"\n", 422 | " # The empirical likelihood iteration is in the helper function section\n", 423 | " # since we use it for both the one-sample and two-sample versions\n", 424 | " self._params, output_stats = one_sample_emp_lik_iteration(\n", 425 | " self.feature_diffs, self._params)\n", 426 | " del output_stats['obj']\n", 427 | " output_stats['loss'] = -np.mean(np.log(self._output_stats['probs']))\n", 428 | " self._output_stats = output_stats\n", 429 | "\n", 430 | " def _print_stats(self, iter_i: int):\n", 431 | " \"\"\"Prints statistics at the given iteration.\"\"\"\n", 432 | " logging.info('Loss is at iter %d is %.8f',\n", 433 | " iter_i, self._output_stats['loss'])\n", 434 | " logging.info('minimum probability is %.8f',\n", 435 | " np.min(self._output_stats['probs']))\n", 436 | " logging.info('maximum probability is %.8f',\n", 437 | " np.max(self._output_stats['probs']))\n", 438 | " logging.info('sum of probability is %.8f',\n", 439 | " np.sum(self._output_stats['probs']))\n", 440 | " 
logging.info('no. not in domain is %d',\n", 441 | "                 self._output_stats['n_out_of_domain'])\n", 442 | "    logging.info('log grad norm is %.15f',\n", 443 | "                 self._output_stats['log_grad_norm'])\n", 444 | "\n", 445 | "  def _preprocess_features(self, pre_model_features: np.ndarray,\n", 446 | "                           pre_test_features: np.ndarray):\n", 447 | "    \"\"\"Loads and converts features for calculating gen. empirical likelihood.\n", 448 | "\n", 449 | "    Args:\n", 450 | "      pre_model_features: unpreprocessed model features.\n", 451 | "      pre_test_features: unpreprocessed test features.\n", 452 | "    Returns:\n", 453 | "      feature_diffs: per-sample moment conditions.\n", 454 | "      test_features: per-sample test features.\n", 455 | "      model_features: per-sample model features.\n", 456 | "    \"\"\"\n", 457 | "    if self.config.num_model_examples \u003e 0:\n", 458 | "      pre_model_features = pre_model_features[:self.config.num_model_examples]\n", 459 | "    assert len(pre_model_features.shape) == len(pre_test_features.shape)\n", 460 | "\n", 461 | "    # convert features of size [bs, h, w, c] -\u003e [bs, dim_new]\n", 462 | "    pre_model_features = convert_space_to_dims(pre_model_features)\n", 463 | "    pre_test_features = convert_space_to_dims(pre_test_features)\n", 464 | "\n", 465 | "    # Use fewer dimensions if flags ask us to.\n", 466 | "    cut_dim = min(self.config.cut_dim, np.prod(pre_test_features.shape[1:]))\n", 467 | "    ndims = pre_model_features.shape[1]\n", 468 | "    assert cut_dim \u003c= ndims, f'cut_dim {cut_dim} and ndims {ndims}'\n", 469 | "    model_features = pre_model_features[:, :cut_dim]\n", 470 | "    test_features = pre_test_features[:, :cut_dim]\n", 471 | "    assert test_features.shape[1] == model_features.shape[1]\n", 472 | "\n", 473 | "    model_means = np.mean(model_features, axis=0)\n", 474 | "    feature_diffs = test_features - model_means[np.newaxis, :]\n", 475 | "\n", 476 | "    if self.config.whiten:\n", 477 | "      assert self.config.obj_type != GELObjective.EUCLIDEAN_LIKELIHOOD\n", 478 | "      feature_diffs = do_pca(feature_diffs, self.config.pca_dim)\n", 479 | "\n", 480 | "    return feature_diffs, test_features, model_features\n", 481 | "\n", 482 | "  def _num_out_of_domain(self) -\u003e bool:\n", 483 | "    if np.any(self._output_stats['probs'] \u003c 0.0):\n", 484 | "      logging.info(\n", 485 | "          'Encountered negative probability, likely due to optimization. '\n", 486 | "          'This will likely be fixed in an iteration or two.')\n", 487 | "    if np.any(self._output_stats['probs'] \u003e 1.0):\n", 488 | "      logging.info(\n", 489 | "          'Encountered probability \u003e 1.0, likely due to optimization. 
'\n", 490 | " 'This will likely be fixed in an iteration or two.')\n", 491 | " return self._output_stats['n_out_of_domain'] \u003e 0\n", 492 | "\n", 493 | " def _check_norm_param_condition(self):\n", 494 | " norm_params = np.linalg.norm(self._params)\n", 495 | " if (self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD\n", 496 | " and norm_params \u003e self.config.norm_param_tol):\n", 497 | " logging.info('Mean likely near boundary since norm is high '\n", 498 | " '%.5f, so likelihood is -Inf (\u003c%.5f)',\n", 499 | " norm_params, np.sum(np.log(self._output_stats['probs'])))\n", 500 | " self._status = GELStatus.BOUNDARY\n", 501 | "\n", 502 | " def _check_if_solved(self, iteration: int):\n", 503 | " if iteration % 10 == 1: self._print_stats(iteration)\n", 504 | " if not np.isfinite(self._output_stats['loss']):\n", 505 | " return\n", 506 | " elif not np.isfinite(self._current_loss):\n", 507 | " return\n", 508 | " if (np.abs(self._output_stats['loss']-self._current_loss) \u003c self.config.tol\n", 509 | " or self._output_stats['log_grad_norm'] \u003c self.config.grad_norm_tol):\n", 510 | " logging.info('GEL calculation converged at iteration %d. Terminating.',\n", 511 | " iteration)\n", 512 | " self._print_stats(iteration)\n", 513 | " self._status = GELStatus.SOLVED\n", 514 | "\n", 515 | " def _calculate_objective(self):\n", 516 | " probs = self._output_stats['probs']\n", 517 | " num_examples = probs.shape[0]\n", 518 | " if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:\n", 519 | " objective = np.mean(np.log(probs))\n", 520 | " best_objective = -np.log(num_examples)\n", 521 | " worst_objective = -np.inf\n", 522 | " elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:\n", 523 | " objective = entropy(probs, base=2)\n", 524 | " best_objective = np.log2(num_examples)\n", 525 | " worst_objective = 0.0\n", 526 | " elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 527 | " objective = np.sum((probs - 1. / num_examples) ** 2) * num_examples\n", 528 | " best_objective = 0.0\n", 529 | " worst_objective = np.inf\n", 530 | " else:\n", 531 | " raise ValueError('Objective type %s not valid', self.config.obj_type)\n", 532 | "\n", 533 | " return objective, best_objective, worst_objective\n", 534 | "\n", 535 | " def calculate_divergence(self):\n", 536 | " probs = self._output_stats['probs']\n", 537 | " num_examples = probs.shape[0]\n", 538 | " if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:\n", 539 | " divergence = -np.mean(np.log(num_examples * probs))\n", 540 | " elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:\n", 541 | " divergence = np.sum(probs * np.log(probs * num_examples))\n", 542 | " elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 543 | " divergence = np.sum((probs - 1. / num_examples) ** 2) * num_examples\n", 544 | " else:\n", 545 | " raise ValueError('Objective type %s not valid', self.config.obj_type)\n", 546 | "\n", 547 | " return divergence\n", 548 | "\n", 549 | " def _evaluate_gel_solution(self, elapsed_time: float = 0.0):\n", 550 | " objective, best_objective, worst_objective = self._calculate_objective()\n", 551 | " if self._status == GELStatus.SOLVED:\n", 552 | " logging.info('solved... 
in %f seconds', elapsed_time)\n", 553 | "      logging.info('Final objective is %.5f', objective)\n", 554 | "      logging.info('Best objective is %.5f', best_objective)\n", 555 | "    elif self._status == GELStatus.BOUNDARY:\n", 556 | "      assert self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD\n", 557 | "      logging.info('final log lik is -Inf_boundary')\n", 558 | "    elif self._status == GELStatus.NOT_IN_CONVEX_HULL:\n", 559 | "      logging.info('Convex Hull condition not satisfied')\n", 560 | "      err = ('model mean not in convex hull of features... distributions not '\n", 561 | "             'close enough to GEL')\n", 562 | "      logging.info(err)\n", 563 | "    else:\n", 564 | "      logging.info('failed to solve... checking to see if model mean '\n", 565 | "                   'is in convex hull')\n", 566 | "      logging.info('calculating convex hull')\n", 567 | "      if not is_in_cvx_hull(self.test_features, np.mean(self.model_features, axis=0)):\n", 568 | "        err = ('model mean not in convex hull of features... distributions not '\n", 569 | "               'close enough to GEL')\n", 570 | "        logging.info(err)\n", 571 | "        self._status = GELStatus.NOT_IN_CONVEX_HULL\n", 572 | "      else:\n", 573 | "        logging.info('optimization failed')\n", 574 | "        self._status = GELStatus.OPTIMIZATION_FAILURE\n", 575 | "\n", 576 | "    if self._status != GELStatus.SOLVED:\n", 577 | "      objective = worst_objective\n", 578 | "      self._output_stats['probs'][:] = np.nan\n", 579 | "    out_dict = dict()\n", 580 | "    out_dict.update(self.config.to_dict())\n", 581 | "    out_dict['termination_status'] = self._status\n", 582 | "    out_dict['probs'] = self._output_stats['probs']\n", 583 | "    out_dict['objective'] = objective\n", 584 | "    out_dict['elapsed_time'] = elapsed_time\n", 585 | "    out_dict['model_feat_dims'] = self.model_features.shape\n", 586 | "    out_dict['test_feat_dims'] = self.test_features.shape\n", 587 | "\n", 588 | "    return out_dict\n", 589 | "\n", 590 | "  def calculate_gel(self):\n", 591 | "    \"\"\"Calculates empirical likelihood and outputs results to a dictionary.\"\"\"\n", 592 | "    if self._status != GELStatus.RUNNING:\n", 593 | "      return self._evaluate_gel_solution()\n", 594 | "\n", 595 | "    start_time = timeit.default_timer()\n", 596 | "    for i in range(self.config.num_iterations):\n", 597 | "      try:\n", 598 | "        self._iter_func()\n", 599 | "      except np.linalg.LinAlgError:  # This error indicates EL hit the boundary\n", 600 | "        logging.info('Encountered LinAlg Error, means we hit boundary cond')\n", 601 | "        self._status = GELStatus.BOUNDARY\n", 602 | "        break\n", 603 | "      self._check_norm_param_condition()\n", 604 | "      if self._status == GELStatus.BOUNDARY:\n", 605 | "        break\n", 606 | "      if self._num_out_of_domain(): continue\n", 607 | "      self._check_if_solved(i)\n", 608 | "      if self._status == GELStatus.SOLVED:\n", 609 | "        break\n", 610 | "      self._current_loss = self._output_stats['loss']\n", 611 | "\n", 612 | "    elapsed_time = timeit.default_timer() - start_time\n", 613 | "    return self._evaluate_gel_solution(elapsed_time)" 614 | ] 615 | }, 616 | { 617 | "cell_type": "markdown", 618 | "metadata": { 619 | "id": "gfrU7abYSHDZ" 620 | }, 621 | "source": [ 622 | "## Two-Sample GEL Code" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": null, 628 | "metadata": { 629 | "id": "F6UKtBdMSHhN" 630 | }, 631 | "outputs": [], 632 | "source": [ 633 | "class TwoSampleGEL(object):\n", 634 | "  def __init__(self, config, model_unprocessed_feats: np.array,\n", 635 | "               test_unprocessed_feats: np.array):\n", 636 | "    self.config = config\n", 637 | "    res = self._preprocess_features(\n", 638 | "        model_unprocessed_feats, 
test_unprocessed_feats)\n", 639 | " model_feats, test_feats, aug_model_feats, aug_test_feats, aug_feats = res\n", 640 | " self.model_features = model_feats\n", 641 | " self.test_features = test_feats\n", 642 | " self.aug_features = aug_feats\n", 643 | " self.aug_model_features = aug_model_feats\n", 644 | " self.aug_test_features = aug_test_feats\n", 645 | "\n", 646 | " self._current_loss = np.inf\n", 647 | " iter_func, output_stats, gel_status, params = self._init_optimizer()\n", 648 | " self._iter_func = iter_func\n", 649 | " self._output_stats = output_stats\n", 650 | " self._status = gel_status # termination conditions\n", 651 | " self._params = params\n", 652 | "\n", 653 | " def _init_optimizer(self):\n", 654 | " num_model_examples = self.model_features.shape[0]\n", 655 | " num_test_examples = self.test_features.shape[0]\n", 656 | " if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:\n", 657 | " iter_func = self.emp_lik_iteration\n", 658 | " ndims = self.aug_features.shape[1]\n", 659 | " elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:\n", 660 | " iter_func = self.exp_tilted_iteration\n", 661 | " ndims = self.aug_features.shape[1]\n", 662 | " elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 663 | " iter_func = self.euc_likelihood_iteration\n", 664 | " ndims = self.model_features.shape[1]\n", 665 | " else:\n", 666 | " raise ValueError('Objective type %s not valid', self.config.obj_type)\n", 667 | " model_probs = np.empty((num_model_examples,))\n", 668 | " model_probs[:] = np.nan\n", 669 | " test_probs = np.empty((num_test_examples,))\n", 670 | " test_probs[:] = np.nan\n", 671 | " output_stats = dict(\n", 672 | " model_probs=model_probs, test_probs=test_probs,\n", 673 | " model_loss=np.inf, test_loss=np.inf,\n", 674 | " n_out_of_domain=0, log_grad_norm=np.nan)\n", 675 | "\n", 676 | " params = np.zeros((ndims,))\n", 677 | "\n", 678 | " gel_status = GELStatus.RUNNING\n", 679 | " # Check if the convex hull condition is satisfied.\n", 680 | " # This condition does not apply to the Euc. Likelihood\n", 681 | " if not self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 682 | " if not approx_in_cvx_hull(\n", 683 | " self.aug_features, np.zeros((self.aug_features.shape[1],)))[0]:\n", 684 | " gel_status = GELStatus.NOT_IN_CONVEX_HULL\n", 685 | " return iter_func, output_stats, gel_status, params\n", 686 | "\n", 687 | " def euc_likelihood_iteration(self):\n", 688 | " \"\"\"Perform one (and only) newton iteration for two-sample\n", 689 | " euclidean likelihood.\n", 690 | " \"\"\"\n", 691 | " model_egs = self.model_features.shape[0]\n", 692 | " test_egs = self.test_features.shape[0]\n", 693 | " if model_egs != test_egs:\n", 694 | " raise ValueError(\n", 695 | " \"Different number of model and test examples currently not supported\")\n", 696 | " model_mean = np.mean(self.model_features, axis=0)\n", 697 | " test_mean = np.mean(self.test_features, axis=0)\n", 698 | " model_cov = np.cov(self.model_features, rowvar=False, bias=True)\n", 699 | " test_cov = np.cov(self.test_features, rowvar=False, bias=True)\n", 700 | " sample_cov = (model_egs * model_cov + test_egs * test_cov) / (\n", 701 | " model_egs + test_egs)\n", 702 | " self._params = np.linalg.solve(sample_cov, model_mean - test_mean)\n", 703 | " model_probs = 1. / model_egs - np.dot(\n", 704 | " self.model_features - model_mean, self._params)\n", 705 | " test_probs = 1. 
/ test_egs + np.dot(\n", 706 | "        self.test_features - test_mean, self._params)\n", 707 | "\n", 708 | "    model_obj = 0.5 * np.mean((model_probs - 1.0 / model_egs) ** 2)\n", 709 | "    test_obj = 0.5 * np.mean((test_probs - 1.0 / test_egs) ** 2)\n", 710 | "\n", 711 | "    self._output_stats = dict(\n", 712 | "        model_probs=model_probs,\n", 713 | "        test_probs=test_probs,\n", 714 | "        model_loss=model_obj,\n", 715 | "        test_loss=test_obj,\n", 716 | "        n_out_of_domain=0,\n", 717 | "        log_grad_norm=0.0,  # optimized with one step\n", 718 | "    )\n", 719 | "\n", 720 | "  def emp_lik_iteration(self):\n", 721 | "    \"\"\"Perform newton step for two-sample empirical likelihood.\n", 722 | "\n", 723 | "    Standard two-sample GEL methods are in the form:\n", 724 | "      sum_i log(p_i) + sum_j log(q_j)\n", 725 | "      s.t.\n", 726 | "      sum_i p_i = 1\n", 727 | "      sum_j q_j = 1\n", 728 | "      sum_i p_i X_i = sum_j q_j Y_j\n", 729 | "\n", 730 | "    This can be recast in the form:\n", 731 | "      sum_k log(w_k)\n", 732 | "      s.t.\n", 733 | "      sum_k w_k = 1\n", 734 | "      sum_k w_k Z_k = 0\n", 735 | "\n", 736 | "    where w_k = 0.5 * p_k for k=1,...,N\n", 737 | "    and w_{k+N} = 0.5 * q_k for k=1,...,M\n", 738 | "\n", 739 | "    Z_k = [X_k] for k=1,...,N\n", 740 | "          [1  ]\n", 741 | "    and\n", 742 | "    Z_{k+N} = [-Y_k] for k=1,...,M\n", 743 | "              [-1  ]\n", 744 | "    \"\"\"\n", 745 | "    self._params, output_stats = one_sample_emp_lik_iteration(\n", 746 | "        self.aug_features, self._params)\n", 747 | "    model_egs = self.model_features.shape[0]\n", 748 | "    test_egs = self.test_features.shape[0]\n", 749 | "    model_probs = 2.0 * output_stats['probs'][:model_egs]\n", 750 | "    test_probs = 2.0 * output_stats['probs'][model_egs:]\n", 751 | "    model_log_lik = np.mean(np.log(model_probs))\n", 752 | "    test_log_lik = np.mean(np.log(test_probs))\n", 753 | "\n", 754 | "    self._output_stats = dict(\n", 755 | "        model_probs=model_probs,\n", 756 | "        test_probs=test_probs,\n", 757 | "        model_loss=-model_log_lik,\n", 758 | "        test_loss=-test_log_lik,\n", 759 | "        n_out_of_domain=output_stats['n_out_of_domain'],\n", 760 | "        log_grad_norm=output_stats['log_grad_norm'],\n", 761 | "    )\n", 762 | "\n", 763 | "  def exp_tilted_iteration(self):\n", 764 | "    \"\"\"Performs half Newton step on two-sample exp. 
tilting objective.\n", 765 | "\n", 766 | " See Appendix C.2 for details of the implementation.\n", 767 | " \"\"\"\n", 768 | " model_egs = self.model_features.shape[0]\n", 769 | " test_egs = self.test_features.shape[0]\n", 770 | " if model_egs != test_egs:\n", 771 | " raise ValueError(\n", 772 | " \"Different number of model and test examples currently not supported\")\n", 773 | " num_examples = model_egs + test_egs\n", 774 | "\n", 775 | " w_exp_tilt = np.exp(np.dot(self.aug_features, self._params)) / num_examples\n", 776 | " sc_f_diff = self.aug_features * w_exp_tilt[:, np.newaxis]\n", 777 | "\n", 778 | " hess = np.dot(sc_f_diff.T, self.aug_features)\n", 779 | " log_grad = np.sum(sc_f_diff, axis=0)\n", 780 | " log_grad_norm = np.linalg.norm(log_grad)\n", 781 | " newton_step = np.linalg.solve(hess, log_grad)\n", 782 | " self._params -= 0.5 * newton_step\n", 783 | " exp_weights = np.concatenate(\n", 784 | " [np.exp(np.dot(self.aug_model_features, self._params)) * test_egs,\n", 785 | " np.exp(np.dot(-self.aug_test_features, self._params)) * model_egs])\n", 786 | " probs = exp_weights / np.sum(exp_weights)\n", 787 | " model_probs = 2 * probs[:model_egs]\n", 788 | " test_probs = 2 * probs[model_egs:]\n", 789 | " self._output_stats = dict(\n", 790 | " model_probs=model_probs,\n", 791 | " test_probs=test_probs,\n", 792 | " model_loss=-entropy(model_probs, base=2),\n", 793 | " test_loss=-entropy(test_probs, base=2),\n", 794 | " n_out_of_domain=0,\n", 795 | " log_grad_norm=log_grad_norm,\n", 796 | " )\n", 797 | "\n", 798 | " def _preprocess_features(self, model_features, test_features):\n", 799 | " \"\"\"Loads and converts features for calculate_two_sample_gel.\"\"\"\n", 800 | " model_egs = model_features.shape[0]\n", 801 | " test_egs = test_features.shape[0]\n", 802 | "\n", 803 | " two_sample_features = np.concatenate(\n", 804 | " [model_features, -test_features], axis=0)\n", 805 | " if self.config.whiten:\n", 806 | " assert self.config.obj_type != GELObjective.EUCLIDEAN_LIKELIHOOD\n", 807 | " two_sample_features = do_pca(two_sample_features, self.config.pca_dim)\n", 808 | "\n", 809 | " out_model_feats = two_sample_features[:model_egs]\n", 810 | " out_test_feats = -two_sample_features[-test_egs:]\n", 811 | " aug_model_feats, aug_test_feats, aug_feats = self._make_augmented_features(\n", 812 | " out_model_feats, out_test_feats)\n", 813 | "\n", 814 | " return (out_model_feats, out_test_feats, aug_model_feats, aug_test_feats,\n", 815 | " aug_feats)\n", 816 | "\n", 817 | " def _make_augmented_features(self, model_features, test_features):\n", 818 | " \"\"\"Make two-sample features for use with empirical likelihood iteration.\"\"\"\n", 819 | " model_egs = model_features.shape[0]\n", 820 | " test_egs = test_features.shape[0]\n", 821 | "\n", 822 | " aug_model_feats = np.concatenate(\n", 823 | " [model_features, np.ones((model_egs, 1))], axis=1)\n", 824 | " aug_test_feats = np.concatenate(\n", 825 | " [test_features, np.ones((test_egs, 1))], axis=1)\n", 826 | "\n", 827 | " aug_features = np.concatenate([aug_model_feats, -aug_test_feats], axis=0)\n", 828 | "\n", 829 | " return aug_model_feats, aug_test_feats, aug_features\n", 830 | "\n", 831 | " def _print_stats(self, iter_i):\n", 832 | " \"\"\"Print statistics during optimization.\"\"\"\n", 833 | " logging.info('model loss is at iter %d is %.8f',\n", 834 | " iter_i, self._output_stats['model_loss'])\n", 835 | " logging.info('test loss is at iter %d is %.8f',\n", 836 | " iter_i, self._output_stats['test_loss'])\n", 837 | " logging.info('minimum model 
probability is %.8f',\n", 838 | "                 np.min(self._output_stats['model_probs']))\n", 839 | "    logging.info('maximum model probability is %.8f',\n", 840 | "                 np.max(self._output_stats['model_probs']))\n", 841 | "    logging.info('sum of model probabilities is %.8f',\n", 842 | "                 np.sum(self._output_stats['model_probs']))\n", 843 | "    logging.info('minimum test probability is %.8f',\n", 844 | "                 np.min(self._output_stats['test_probs']))\n", 845 | "    logging.info('maximum test probability is %.8f',\n", 846 | "                 np.max(self._output_stats['test_probs']))\n", 847 | "    logging.info('sum of test probabilities is %.8f',\n", 848 | "                 np.sum(self._output_stats['test_probs']))\n", 849 | "    logging.info(\n", 850 | "        'number not in domain is %d', self._output_stats['n_out_of_domain'])\n", 851 | "    logging.info(\n", 852 | "        'log grad norm is %.15f', self._output_stats['log_grad_norm'])\n", 853 | "\n", 854 | "\n", 855 | "  def _num_out_of_domain(self) -\u003e bool:\n", 856 | "    if np.any(self._output_stats['model_probs'] \u003c 0.0):\n", 857 | "      logging.info(\n", 858 | "          'Encountered negative model probability, likely due to '\n", 859 | "          'optimization. This will likely be fixed in an iteration or two.')\n", 860 | "    if np.any(self._output_stats['model_probs'] \u003e 1.0):\n", 861 | "      logging.info(\n", 862 | "          'Encountered model probability \u003e 1.0, likely due to optimization.'\n", 863 | "          ' This will likely be fixed in an iteration or two.')\n", 864 | "    if np.any(self._output_stats['test_probs'] \u003c 0.0):\n", 865 | "      logging.info(\n", 866 | "          'Encountered negative test probability, likely due to '\n", 867 | "          'optimization. This will likely be fixed in an iteration or two.')\n", 868 | "    if np.any(self._output_stats['test_probs'] \u003e 1.0):\n", 869 | "      logging.info(\n", 870 | "          'Encountered test probability \u003e 1.0, likely due to optimization. '\n", 871 | "          'This will likely be fixed in an iteration or two.')\n", 872 | "\n", 873 | "    return self._output_stats['n_out_of_domain'] \u003e 0\n", 874 | "\n", 875 | "  def _check_norm_param_condition(self):\n", 876 | "    norm_params = np.linalg.norm(self._params)\n", 877 | "    if (self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD\n", 878 | "        and norm_params \u003e self.config.norm_param_tol):\n", 879 | "      model_ll = np.sum(np.log(self._output_stats['model_probs']))\n", 880 | "      test_ll = np.sum(np.log(self._output_stats['test_probs']))\n", 881 | "      logging.info(\n", 882 | "          'Mean likely near boundary since norm is high %.5f, '\n", 883 | "          'so model likelihood is -Inf, \u003c%.5f, as is test likelihood, \u003c%.5f.',\n", 884 | "          norm_params, model_ll, test_ll)\n", 885 | "      self._status = GELStatus.BOUNDARY\n", 886 | "\n", 887 | "  def _check_if_solved(self, iteration: int):\n", 888 | "    if iteration % 10 == 1: self._print_stats(iteration)\n", 889 | "    obj = self._output_stats['model_loss'] + self._output_stats['test_loss']\n", 890 | "    if not (np.isfinite(obj) and np.isfinite(self._current_loss)): return\n", 891 | "    if (np.abs(obj-self._current_loss) \u003c self.config.tol\n", 892 | "        or self._output_stats['log_grad_norm'] \u003c self.config.grad_norm_tol):\n", 893 | "      logging.info('GEL calculation converged '\n", 894 | "                   'at iteration %d... 
terminating', iteration)\n", 895 | "      self._print_stats(iteration)\n", 896 | "      self._status = GELStatus.SOLVED\n", 897 | "\n", 898 | "  def _calculate_objective(self):\n", 899 | "    model_probs = self._output_stats['model_probs']\n", 900 | "    test_probs = self._output_stats['test_probs']\n", 901 | "    num_model_examples = model_probs.shape[0]\n", 902 | "    num_test_examples = test_probs.shape[0]\n", 903 | "    if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:\n", 904 | "      model_objective = np.mean(np.log(model_probs))\n", 905 | "      test_objective = np.mean(np.log(test_probs))\n", 906 | "      best_objective = -np.log(num_model_examples) - np.log(num_test_examples)\n", 907 | "      worst_objective = -np.inf\n", 908 | "    elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:\n", 909 | "      model_objective = entropy(model_probs, base=2)\n", 910 | "      test_objective = entropy(test_probs, base=2)\n", 911 | "      best_objective = np.log2(num_test_examples) + np.log2(num_model_examples)\n", 912 | "      worst_objective = 0.0\n", 913 | "    elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 914 | "      model_objective = self._output_stats['model_loss']\n", 915 | "      test_objective = self._output_stats['test_loss']\n", 916 | "      best_objective = 0.0\n", 917 | "      worst_objective = np.inf\n", 918 | "    else:\n", 919 | "      raise ValueError(f'Objective type {self.config.obj_type} not valid')\n", 920 | "\n", 921 | "    objective = model_objective + test_objective\n", 922 | "    return objective, best_objective, worst_objective\n", 923 | "\n", 924 | "  def calculate_divergence(self):\n", 925 | "    model_probs = self._output_stats['model_probs']\n", 926 | "    test_probs = self._output_stats['test_probs']\n", 927 | "    num_model_examples = model_probs.shape[0]\n", 928 | "    num_test_examples = test_probs.shape[0]\n", 929 | "    if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:\n", 930 | "      model_divergence = -np.mean(np.log(num_model_examples * model_probs))\n", 931 | "      test_divergence = -np.mean(np.log(num_test_examples * test_probs))\n", 932 | "      divergence = model_divergence + test_divergence\n", 933 | "    elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:\n", 934 | "      model_divergence = np.sum(\n", 935 | "          model_probs * np.log(model_probs * num_model_examples))\n", 936 | "      test_divergence = np.sum(\n", 937 | "          test_probs * np.log(test_probs * num_test_examples))\n", 938 | "    elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 939 | "      model_divergence = 0.5 * np.mean(\n", 940 | "          (1. / num_model_examples - model_probs) ** 2)\n", 941 | "      test_divergence = 0.5 * np.mean(\n", 942 | "          (1. / num_test_examples - test_probs) ** 2)\n", 943 | "    else:\n", 944 | "      raise ValueError(f'Objective type {self.config.obj_type} not valid')\n", 945 | "\n", 946 | "    divergence = model_divergence + test_divergence\n", 947 | "    return divergence, model_divergence, test_divergence\n", 948 | "\n", 949 | "  def _evaluate_gel_solution(self, elapsed_time: float = 0.0):\n", 950 | "    if self._status == GELStatus.SOLVED:\n", 951 | "      logging.info('solved... in %f seconds', elapsed_time)\n", 952 | "    elif self._status == GELStatus.BOUNDARY:\n", 953 | "      assert self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD\n", 954 | "      logging.info('final log lik is -Inf_boundary')\n", 955 | "    elif self._status == GELStatus.NOT_IN_CONVEX_HULL:\n", 956 | "      logging.info('Convex hull condition not satisfied')\n", 957 | "      err = ('model mean not in convex hull of features... 
distributions not '\n", 958 | "             'close enough to GEL')\n", 959 | "      logging.info(err)\n", 960 | "    else:\n", 961 | "      logging.info('failed to solve... checking to see if model mean '\n", 962 | "                   'is in convex hull')\n", 963 | "      logging.info('calculating convex hull')\n", 964 | "      if not is_in_cvx_hull(self.aug_features,\n", 965 | "                            np.zeros((self.aug_features.shape[1],))):\n", 966 | "        err = ('model mean not in convex hull of features... distributions not '\n", 967 | "               'close enough to GEL')\n", 968 | "        logging.info(err)\n", 969 | "        self._status = GELStatus.NOT_IN_CONVEX_HULL\n", 970 | "      else:\n", 971 | "        logging.info('optimization failed')\n", 972 | "        self._status = GELStatus.OPTIMIZATION_FAILURE\n", 973 | "\n", 974 | "    if self._status != GELStatus.SOLVED:\n", 975 | "      _, _, worst_objective = self._calculate_objective()\n", 976 | "      objective = worst_objective\n", 977 | "      self._output_stats['model_probs'][:] = np.nan\n", 978 | "      self._output_stats['test_probs'][:] = np.nan\n", 979 | "    else:\n", 980 | "      objective, _, _ = self._calculate_objective()\n", 981 | "    out_dict = dict()\n", 982 | "    out_dict.update(self.config.to_dict())\n", 983 | "    out_dict['termination_status'] = self._status\n", 984 | "    out_dict['model_probs'] = self._output_stats['model_probs']\n", 985 | "    out_dict['test_probs'] = self._output_stats['test_probs']\n", 986 | "    out_dict['objective'] = objective\n", 987 | "    out_dict['elapsed_time'] = elapsed_time\n", 988 | "    out_dict['model_feat_dims'] = self.model_features.shape\n", 989 | "    out_dict['test_feat_dims'] = self.test_features.shape\n", 990 | "\n", 991 | "    return out_dict\n", 992 | "\n", 993 | "  def calculate_gel(self):\n", 994 | "    \"\"\"Calculates empirical likelihood and outputs results to a dictionary.\"\"\"\n", 995 | "    if self._status != GELStatus.RUNNING:\n", 996 | "      return self._evaluate_gel_solution()\n", 997 | "\n", 998 | "    start_time = timeit.default_timer()\n", 999 | "    for i in range(self.config.num_iterations):\n", 1000 | "      try:\n", 1001 | "        self._iter_func()\n", 1002 | "      except np.linalg.LinAlgError:  # This error indicates EL hit the boundary\n", 1003 | "        logging.info('Encountered LinAlg Error, means we hit boundary cond')\n", 1004 | "        self._status = GELStatus.BOUNDARY\n", 1005 | "        break\n", 1006 | "      self._check_norm_param_condition()\n", 1007 | "      if self._status == GELStatus.BOUNDARY:\n", 1008 | "        break\n", 1009 | "      if self._num_out_of_domain(): continue\n", 1010 | "      self._check_if_solved(i)\n", 1011 | "      if self._status == GELStatus.SOLVED:\n", 1012 | "        break\n", 1013 | "      self._current_loss = (self._output_stats['model_loss'] +\n", 1014 | "                            self._output_stats['test_loss'])\n", 1015 | "\n", 1016 | "    elapsed_time = timeit.default_timer() - start_time\n", 1017 | "    return self._evaluate_gel_solution(elapsed_time)" 1018 | ] 1019 | }, 1020 | { 1021 | "cell_type": "markdown", 1022 | "metadata": { 1023 | "id": "ibeE17-4SN_P" 1024 | }, 1025 | "source": [ 1026 | "## Kernel Code" 1027 | ] 1028 | }, 1029 | { 1030 | "cell_type": "code", 1031 | "execution_count": null, 1032 | "metadata": { 1033 | "id": "7OFpbJiDSQM7" 1034 | }, 1035 | "outputs": [], 1036 | "source": [ 1037 | "def get_kernel_config():\n", 1038 | "  \"\"\"Flags for the kernel config.\"\"\"\n", 1039 | "  config = config_dict.ConfigDict()\n", 1040 | "  config.kernel_type = 'exponential'\n", 1041 | "  config.polynomial_params = config_dict.create(order=3, const=1.0)\n", 1042 | "  config.exponential_params = config_dict.create(sigma=1.0)\n", 1043 | "  config.laplacian_params = config_dict.create(sigma=1.0)\n", 1044 | "  
config.rbf_params = config_dict.create(sigma=1.0)\n", 1045 | "  config.rational_quadratic_params = config_dict.create(order=2.0, const=1.0)\n", 1046 | "\n", 1047 | "  return config\n", 1048 | "\n", 1049 | "\n", 1050 | "def kernel_matrix(feat1: np.ndarray, feat2: np.ndarray,\n", 1051 | "                  config: config_dict.ConfigDict):\n", 1052 | "  \"\"\"Generate kernel matrix from two sets of features.\"\"\"\n", 1053 | "  kernel_type = config.kernel_type.lower()\n", 1054 | "  ndim = feat1.shape[1]\n", 1055 | "\n", 1056 | "  if kernel_type == 'linear':\n", 1057 | "    kernel_mat = np.dot(feat1, feat2.T) / ndim\n", 1058 | "  elif kernel_type == 'exponential':\n", 1059 | "    sigma = config.exponential_params.sigma\n", 1060 | "    assert sigma \u003e 0.0\n", 1061 | "    kernel_mat = np.exp(sigma * np.dot(feat1, feat2.T) / ndim)\n", 1062 | "  elif kernel_type == 'polynomial':\n", 1063 | "    order = config.polynomial_params.order\n", 1064 | "    const = config.polynomial_params.const\n", 1065 | "    assert order \u003e 0.0 and const \u003e 0.0\n", 1066 | "    ndim = feat1.shape[1]\n", 1067 | "    kernel_mat = (np.dot(feat1, feat2.T) / ndim + const) ** order\n", 1068 | "  elif kernel_type in ('laplacian', 'laplace'):\n", 1069 | "    sigma = config.laplacian_params.sigma\n", 1070 | "    assert sigma \u003e 0.0\n", 1071 | "    dist_mat = pairwise_distances(feat1, feat2, metric='l1')\n", 1072 | "    kernel_mat = np.exp(-sigma * dist_mat)\n", 1073 | "  elif kernel_type in ('rbf', 'gaussian'):\n", 1074 | "    sigma = config.rbf_params.sigma\n", 1075 | "    assert sigma \u003e 0.0\n", 1076 | "    dist_mat = pairwise_distances(feat1, feat2, metric='l2')\n", 1077 | "    kernel_mat = np.exp(-sigma * (dist_mat ** 2))\n", 1078 | "  elif kernel_type == 'rational_quadratic':\n", 1079 | "    const = config.rational_quadratic_params.const\n", 1080 | "    order = config.rational_quadratic_params.order\n", 1081 | "    assert order \u003e 0.0 and const \u003e 0.0\n", 1082 | "    squared_dist = pairwise_distances(feat1, feat2, metric='l2') ** 2\n", 1083 | "    kernel_mat = (1.0 + squared_dist / (2.0 * order * const ** 2)) ** -order\n", 1084 | "  else:\n", 1085 | "    raise ValueError(f'kernel_type {kernel_type} not supported')\n", 1086 | "\n", 1087 | "  return kernel_mat" 1088 | ] 1089 | }, 1090 | { 1091 | "cell_type": "markdown", 1092 | "metadata": { 1093 | "id": "hwhZP9QnSYmc" 1094 | }, 1095 | "source": [ 1096 | "## Config Flags" 1097 | ] 1098 | }, 1099 | { 1100 | "cell_type": "code", 1101 | "execution_count": null, 1102 | "metadata": { 1103 | "id": "gLTlJPMQSa0A" 1104 | }, 1105 | "outputs": [], 1106 | "source": [ 1107 | "def get_gel_config():\n", 1108 | "  \"\"\"Flags for GEL calculation.\n", 1109 | "\n", 1110 | "  Returns:\n", 1111 | "    config: ConfigDict for Flags\n", 1112 | "  \"\"\"\n", 1113 | "  config = config_dict.ConfigDict()\n", 1114 | "  config.cut_dim = 1024\n", 1115 | "  config.pca_dim = 1024\n", 1116 | "  config.whiten = True\n", 1117 | "  config.num_model_examples = 0\n", 1118 | "  config.obj_type = GELObjective.EXPONENTIAL_TILTING\n", 1119 | "\n", 1120 | "  config.num_iterations = 10000\n", 1121 | "  config.norm_param_tol = 1E8\n", 1122 | "  config.tol = 1E-8\n", 1123 | "  config.grad_norm_tol = 1E-8\n", 1124 | "\n", 1125 | "  return config" 1126 | ] 1127 | }, 1128 | { 1129 | "cell_type": "markdown", 1130 | "metadata": { 1131 | "id": "ALHI9NmxSixs" 1132 | }, 1133 | "source": [ 1134 | "# Unit Tests" 1135 | ] 1136 | }, 1137 | { 1138 | "cell_type": "markdown", 1139 | "metadata": { 1140 | "id": "583aBhFWSkjk" 1141 | }, 1142 | "source": [ 1143 | "## Helper Functions" 1144 | ] 1145 | }, 1146 | { 1147 | "cell_type": "code", 1148 | 
"execution_count": null, 1149 | "metadata": { 1150 | "id": "cC5LT9dNSg5T" 1151 | }, 1152 | "outputs": [], 1153 | "source": [ 1154 | "def calculate_gel_unit_tests(\n", 1155 | " shifts, obj_type, is_one_sample_test: bool = True):\n", 1156 | " config_flags = get_gel_config()\n", 1157 | " config_flags.obj_type = obj_type\n", 1158 | " if obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:\n", 1159 | " config_flags.whiten = False\n", 1160 | " else:\n", 1161 | " config_flags.whiten = True\n", 1162 | " if is_one_sample_test:\n", 1163 | " test_class = OneSampleGEL\n", 1164 | " else:\n", 1165 | " test_class = TwoSampleGEL\n", 1166 | " gel_one_sample_tests = list()\n", 1167 | " for shift in shifts:\n", 1168 | " gel_one_sample_test = test_class(\n", 1169 | " config_flags, model_features + shift, test_features)\n", 1170 | " gel_one_sample_test.calculate_gel()\n", 1171 | " gel_one_sample_tests.append(gel_one_sample_test)\n", 1172 | "\n", 1173 | " return gel_one_sample_tests, config_flags\n", 1174 | "\n", 1175 | "\n", 1176 | "def print_one_sample_unit_test_stats(\n", 1177 | " shifts, gel_one_sample_tests, config: config_dict.ConfigDict):\n", 1178 | " logging.info('Divergences for %s objective:', config.obj_type)\n", 1179 | " for shift, gel_one_sample_test in zip(shifts, gel_one_sample_tests):\n", 1180 | " logging.info('Shift is %f', shift)\n", 1181 | " solution_status = gel_one_sample_test._status\n", 1182 | " logging.info('Solution Status: %s', solution_status)\n", 1183 | " if solution_status == GELStatus.SOLVED:\n", 1184 | " logging.info(\n", 1185 | " 'Divergence is %f', gel_one_sample_test.calculate_divergence())\n", 1186 | " logging.info('---------------------------------------------')\n", 1187 | "\n", 1188 | "\n", 1189 | "def print_two_sample_unit_test_stats(\n", 1190 | " shifts, gel_two_sample_tests, config: config_dict.ConfigDict):\n", 1191 | " logging.info('Divergences for %s objective:', config.obj_type)\n", 1192 | " for shift, gel_two_sample_test in zip(shifts, gel_two_sample_tests):\n", 1193 | " logging.info('Shift is %f', shift)\n", 1194 | " solution_status = gel_two_sample_test._status\n", 1195 | " logging.info('Solution Status: %s', solution_status)\n", 1196 | " res = gel_two_sample_test.calculate_divergence()\n", 1197 | " divergence, model_divergence, test_divergence = res\n", 1198 | " if solution_status == GELStatus.SOLVED:\n", 1199 | " logging.info('Divergence is %f', divergence)\n", 1200 | " logging.info('Model divergence is %f', model_divergence)\n", 1201 | " logging.info('Test divergence is %f', test_divergence)\n", 1202 | " logging.info('---------------------------------------------')" 1203 | ] 1204 | }, 1205 | { 1206 | "cell_type": "markdown", 1207 | "metadata": { 1208 | "id": "u1ZSQPXySsjv" 1209 | }, 1210 | "source": [ 1211 | "## Load features and mean shift hyperparameters" 1212 | ] 1213 | }, 1214 | { 1215 | "cell_type": "code", 1216 | "execution_count": null, 1217 | "metadata": { 1218 | "id": "jq6c1Rm0StvX" 1219 | }, 1220 | "outputs": [], 1221 | "source": [ 1222 | "model_features = np.random.randn(*(50000, 128))\n", 1223 | "test_features = np.random.randn(*(50000, 128))\n", 1224 | "\n", 1225 | "shifts = [0.0, 0.1, 0.3]" 1226 | ] 1227 | }, 1228 | { 1229 | "cell_type": "markdown", 1230 | "metadata": { 1231 | "id": "PyMcdEjISvWH" 1232 | }, 1233 | "source": [ 1234 | "## One-Sample GEL" 1235 | ] 1236 | }, 1237 | { 1238 | "cell_type": "markdown", 1239 | "metadata": { 1240 | "id": "ofgxe2FnSxoG" 1241 | }, 1242 | "source": [ 1243 | "Empirical Likelihood" 1244 | ] 1245 | }, 1246 | { 1247 | 
"cell_type": "code", 1248 | "execution_count": null, 1249 | "metadata": { 1250 | "id": "QK6gPXQsS0K1" 1251 | }, 1252 | "outputs": [], 1253 | "source": [ 1254 | "gel_one_sample_tests, config_flags = calculate_gel_unit_tests(\n", 1255 | " shifts, GELObjective.EMPIRICAL_LIKELIHOOD)" 1256 | ] 1257 | }, 1258 | { 1259 | "cell_type": "code", 1260 | "execution_count": null, 1261 | "metadata": { 1262 | "id": "VZqt0qPqS2ci" 1263 | }, 1264 | "outputs": [], 1265 | "source": [ 1266 | "print_one_sample_unit_test_stats(shifts, gel_one_sample_tests, config_flags)" 1267 | ] 1268 | }, 1269 | { 1270 | "cell_type": "markdown", 1271 | "metadata": { 1272 | "id": "lfelAGaKS3uk" 1273 | }, 1274 | "source": [ 1275 | "Exponential Tilting" 1276 | ] 1277 | }, 1278 | { 1279 | "cell_type": "code", 1280 | "execution_count": null, 1281 | "metadata": { 1282 | "id": "F0jWxdMnS6MI" 1283 | }, 1284 | "outputs": [], 1285 | "source": [ 1286 | "gel_one_sample_tests, config_flags = calculate_gel_unit_tests(\n", 1287 | " shifts, GELObjective.EXPONENTIAL_TILTING)" 1288 | ] 1289 | }, 1290 | { 1291 | "cell_type": "code", 1292 | "execution_count": null, 1293 | "metadata": { 1294 | "id": "41wn41WRS9Mg" 1295 | }, 1296 | "outputs": [], 1297 | "source": [ 1298 | "print_one_sample_unit_test_stats(shifts, gel_one_sample_tests, config_flags)" 1299 | ] 1300 | }, 1301 | { 1302 | "cell_type": "markdown", 1303 | "metadata": { 1304 | "id": "pGABwRt1S-HK" 1305 | }, 1306 | "source": [ 1307 | "Euclidean Likelihood" 1308 | ] 1309 | }, 1310 | { 1311 | "cell_type": "code", 1312 | "execution_count": null, 1313 | "metadata": { 1314 | "id": "rQkkYQGUTAMs" 1315 | }, 1316 | "outputs": [], 1317 | "source": [ 1318 | "gel_one_sample_tests, config_flags = calculate_gel_unit_tests(\n", 1319 | " shifts, GELObjective.EUCLIDEAN_LIKELIHOOD)" 1320 | ] 1321 | }, 1322 | { 1323 | "cell_type": "code", 1324 | "execution_count": null, 1325 | "metadata": { 1326 | "id": "Q5vfd9teTC1Y" 1327 | }, 1328 | "outputs": [], 1329 | "source": [ 1330 | "print_one_sample_unit_test_stats(shifts, gel_one_sample_tests, config_flags)" 1331 | ] 1332 | }, 1333 | { 1334 | "cell_type": "markdown", 1335 | "metadata": { 1336 | "id": "FQzMkwXDTMYF" 1337 | }, 1338 | "source": [ 1339 | "## Two-Sample GEL" 1340 | ] 1341 | }, 1342 | { 1343 | "cell_type": "markdown", 1344 | "metadata": { 1345 | "id": "m9jZa8v4TQTF" 1346 | }, 1347 | "source": [ 1348 | "Empirical Likelihood" 1349 | ] 1350 | }, 1351 | { 1352 | "cell_type": "code", 1353 | "execution_count": null, 1354 | "metadata": { 1355 | "id": "QxIhZQ8eTHtT" 1356 | }, 1357 | "outputs": [], 1358 | "source": [ 1359 | "gel_two_sample_tests, config_flags = calculate_gel_unit_tests(\n", 1360 | " shifts, GELObjective.EMPIRICAL_LIKELIHOOD, False)" 1361 | ] 1362 | }, 1363 | { 1364 | "cell_type": "code", 1365 | "execution_count": null, 1366 | "metadata": { 1367 | "id": "Tei2QXeSTTYl" 1368 | }, 1369 | "outputs": [], 1370 | "source": [ 1371 | "print_two_sample_unit_test_stats(shifts, gel_two_sample_tests, config_flags)" 1372 | ] 1373 | }, 1374 | { 1375 | "cell_type": "markdown", 1376 | "metadata": { 1377 | "id": "MfyUYq6xTVpg" 1378 | }, 1379 | "source": [ 1380 | "Exponential Tilting" 1381 | ] 1382 | }, 1383 | { 1384 | "cell_type": "code", 1385 | "execution_count": null, 1386 | "metadata": { 1387 | "id": "_RYkOIJ7TXmC" 1388 | }, 1389 | "outputs": [], 1390 | "source": [ 1391 | "gel_two_sample_tests, config_flags = calculate_gel_unit_tests(\n", 1392 | " shifts, GELObjective.EXPONENTIAL_TILTING, False)" 1393 | ] 1394 | }, 1395 | { 1396 | "cell_type": "code", 1397 | 
"execution_count": null, 1398 | "metadata": { 1399 | "id": "E5l_QNQpTmHC" 1400 | }, 1401 | "outputs": [], 1402 | "source": [ 1403 | "print_two_sample_unit_test_stats(shifts, gel_two_sample_tests, config_flags)" 1404 | ] 1405 | }, 1406 | { 1407 | "cell_type": "markdown", 1408 | "metadata": { 1409 | "id": "KQ01W0w6TpHr" 1410 | }, 1411 | "source": [ 1412 | "Euclidean Likelihood" 1413 | ] 1414 | }, 1415 | { 1416 | "cell_type": "code", 1417 | "execution_count": null, 1418 | "metadata": { 1419 | "id": "xBk7t5UGTn3R" 1420 | }, 1421 | "outputs": [], 1422 | "source": [ 1423 | "gel_two_sample_tests, config_flags = calculate_gel_unit_tests(\n", 1424 | " shifts, GELObjective.EUCLIDEAN_LIKELIHOOD, False)" 1425 | ] 1426 | }, 1427 | { 1428 | "cell_type": "code", 1429 | "execution_count": null, 1430 | "metadata": { 1431 | "id": "2XVXZtmJTr3H" 1432 | }, 1433 | "outputs": [], 1434 | "source": [ 1435 | "print_two_sample_unit_test_stats(shifts, gel_two_sample_tests, config_flags)" 1436 | ] 1437 | }, 1438 | { 1439 | "cell_type": "markdown", 1440 | "metadata": { 1441 | "id": "MCLi_aJHTuuD" 1442 | }, 1443 | "source": [ 1444 | "# A couple motivating examples" 1445 | ] 1446 | }, 1447 | { 1448 | "cell_type": "markdown", 1449 | "metadata": { 1450 | "id": "cxb7_a0PTw90" 1451 | }, 1452 | "source": [ 1453 | "## Helper Functions" 1454 | ] 1455 | }, 1456 | { 1457 | "cell_type": "code", 1458 | "execution_count": null, 1459 | "metadata": { 1460 | "id": "6NzMIq74Tx7F" 1461 | }, 1462 | "outputs": [], 1463 | "source": [ 1464 | "def make_mode_probs(probs, test_labels, num_classes=10):\n", 1465 | " mode_probs = np.empty((num_classes,))\n", 1466 | " for i in range(num_classes):\n", 1467 | " mode_probs[i] = np.sum(probs[test_labels == i])\n", 1468 | "\n", 1469 | " return mode_probs" 1470 | ] 1471 | }, 1472 | { 1473 | "cell_type": "markdown", 1474 | "metadata": { 1475 | "id": "CU8GZBpAT5Gl" 1476 | }, 1477 | "source": [ 1478 | "## Evaluating mode dropping with one-sample exponential tilting\n", 1479 | "\n", 1480 | "Here, we recreate Figure 3(a) of the paper.\n", 1481 | "\n", 1482 | "The \"generative model\" here is 40k examples from the CIFAR10 training set, with up to 8 classes missing.\n", 1483 | "\n", 1484 | "For the mode dropping experiments, we remove examples from the last n classes. For example, if two modes are dropped, we remove labels 9 and 10. We use pool3 features." 
1485 | ] 1486 | }, 1487 | { 1488 | "cell_type": "markdown", 1489 | "metadata": { 1490 | "id": "5xaiTIEzT9KU" 1491 | }, 1492 | "source": [ 1493 | "### Load features" 1494 | ] 1495 | }, 1496 | { 1497 | "cell_type": "code", 1498 | "execution_count": null, 1499 | "metadata": { 1500 | "id": "KxzT9U_lT_Oc" 1501 | }, 1502 | "outputs": [], 1503 | "source": [ 1504 | "data_dirn = 'cifar10_mode_drop_data/'\n", 1505 | "cifar10_dropped_mode_data = dict()\n", 1506 | "cifar10_mode_drop_gold_probs = dict()\n", 1507 | "gel_one_sample_tests = dict()\n", 1508 | "\n", 1509 | "test_data = np.load(os.path.join(data_dirn, 'cifar10_test_pool3.npz'))\n", 1510 | "test_feats = test_data['features']\n", 1511 | "test_labels = test_data['labels']\n", 1512 | "\n", 1513 | "witness_data = np.load(os.path.join(\n", 1514 | " data_dirn, 'cifar10_train_valid_10k_pool3.npz'))\n", 1515 | "witness_feats = witness_data['features'][:1024]\n", 1516 | "\n", 1517 | "all_train_data = np.load(\n", 1518 | " os.path.join(data_dirn, 'cifar10_train_valid_40k_pool3.npz'))\n", 1519 | "\n", 1520 | "for num_present_modes in range(2, 11, 2):\n", 1521 | " cifar10_dropped_mode_data[num_present_modes] = (\n", 1522 | " all_train_data['features'][all_train_data['labels'] \u003c num_present_modes])\n", 1523 | " cifar10_mode_drop_gold_probs[num_present_modes] = np.zeros((10,))\n", 1524 | " cifar10_mode_drop_gold_probs[num_present_modes][:num_present_modes] = (\n", 1525 | " 1. / num_present_modes)" 1526 | ] 1527 | }, 1528 | { 1529 | "cell_type": "markdown", 1530 | "metadata": { 1531 | "id": "QX1DSzanUDkS" 1532 | }, 1533 | "source": [ 1534 | "### Calculate kernel features" 1535 | ] 1536 | }, 1537 | { 1538 | "cell_type": "code", 1539 | "execution_count": null, 1540 | "metadata": { 1541 | "id": "xwUJmpsNUH86" 1542 | }, 1543 | "outputs": [], 1544 | "source": [ 1545 | "kernel_config_flags = get_kernel_config()\n", 1546 | "kernel_features = dict()\n", 1547 | "test_kernel_feats = kernel_matrix(\n", 1548 | " test_feats, witness_feats, kernel_config_flags)\n", 1549 | "for num_present_modes in range(2, 11, 2):\n", 1550 | " kernel_features[num_present_modes] = kernel_matrix(\n", 1551 | " cifar10_dropped_mode_data[num_present_modes], witness_feats,\n", 1552 | " kernel_config_flags)" 1553 | ] 1554 | }, 1555 | { 1556 | "cell_type": "markdown", 1557 | "metadata": { 1558 | "id": "6ZRvLdjFULdA" 1559 | }, 1560 | "source": [ 1561 | "### Calculate KGEL\n", 1562 | "\n", 1563 | "This code block recreates the results of Figure 3(a)." 
1564 | ] 1565 | }, 1566 | { 1567 | "cell_type": "code", 1568 | "execution_count": null, 1569 | "metadata": { 1570 | "id": "kSL2mvjxUQ9T" 1571 | }, 1572 | "outputs": [], 1573 | "source": [ 1574 | "config_flags = get_gel_config()\n", 1575 | "hellinger_distances = dict()\n", 1576 | "for num_present_modes in range(2, 11, 2):\n", 1577 | " gel_one_sample_tests[num_present_modes] = OneSampleGEL(\n", 1578 | " config_flags, kernel_features[num_present_modes], test_kernel_feats)\n", 1579 | " outputs = gel_one_sample_tests[num_present_modes].calculate_gel()\n", 1580 | " per_sample_probs = outputs['probs']\n", 1581 | " mode_probs = make_mode_probs(per_sample_probs, test_labels)\n", 1582 | " num_missing_modes = 10 - num_present_modes\n", 1583 | " hellinger_distances[num_missing_modes] = hellinger_dist(\n", 1584 | " mode_probs, cifar10_mode_drop_gold_probs[num_present_modes])" 1585 | ] 1586 | }, 1587 | { 1588 | "cell_type": "markdown", 1589 | "metadata": { 1590 | "id": "oEsAznA8UUK1" 1591 | }, 1592 | "source": [ 1593 | "### Calculate Hellinger Distances\n", 1594 | "\n", 1595 | "Hellinger distances calculated here are *slightly* different than what is reported in the paper, likely due to numerical precision" 1596 | ] 1597 | }, 1598 | { 1599 | "cell_type": "code", 1600 | "execution_count": null, 1601 | "metadata": { 1602 | "id": "JocNFyFaUVrQ" 1603 | }, 1604 | "outputs": [], 1605 | "source": [ 1606 | "for num_present_modes in range(2, 11, 2)[::-1]:\n", 1607 | " num_missing_modes = 10 - num_present_modes\n", 1608 | " print(\"%d missing modes, Hellinger distance: %.4f\"\n", 1609 | " % (num_missing_modes, hellinger_distances[num_missing_modes]))" 1610 | ] 1611 | }, 1612 | { 1613 | "cell_type": "markdown", 1614 | "metadata": { 1615 | "id": "dIsad0JHUYv7" 1616 | }, 1617 | "source": [ 1618 | "## Identifying missing and extraneous modes with two-sample exponential tilting\n", 1619 | "\n", 1620 | "In this example, we perform an experiment where the model distribution (the first half of the CIFAR10 training+validation sets) only has samples from class labels 0 and 1, while the test distribution (the second half of the CIFAR10 training+validation sets) only has samples from class labels 1 and 2. In the ideal scenario, the model and test probabilities for class 1 should sum to 1.0, while the other probabilities should sum to 0.0.\n", 1621 | "\n", 1622 | "The witness features are from the CIFAR10 test set." 1623 | ] 1624 | }, 1625 | { 1626 | "cell_type": "markdown", 1627 | "metadata": { 1628 | "id": "rUKeOQAPUaFO" 1629 | }, 1630 | "source": [ 1631 | "### Helper Functions" 1632 | ] 1633 | }, 1634 | { 1635 | "cell_type": "code", 1636 | "execution_count": null, 1637 | "metadata": { 1638 | "id": "_htQIO2vUda-" 1639 | }, 1640 | "outputs": [], 1641 | "source": [ 1642 | "def make_two_modes(\n", 1643 | " model_data: dict, test_data: dict,\n", 1644 | " common_mode: int = 0, disjoint_mode1: int = 1, disjoint_mode2: int = 2,\n", 1645 | " num_feats_per_class = None):\n", 1646 | " \"\"\"Given two sets of features, construct two for use with two-sample tests.\n", 1647 | "\n", 1648 | " The structure of the outputs have two sets of features. One has examples from\n", 1649 | " common_mode (class A) and disjoint_mode1 (class B). 
The second set has\n", 1650 | " examples from common_mode (class A) and disjoint_mode2 (class C).\"\"\"\n", 1651 | " assert disjoint_mode1 != disjoint_mode2\n", 1652 | " out1_data = dict()\n", 1653 | " model_feats = model_data['features']\n", 1654 | " model_labels = model_data['labels']\n", 1655 | " feats_common = model_feats[model_labels == common_mode][:num_feats_per_class]\n", 1656 | " assert feats_common.shape[0] == num_feats_per_class\n", 1657 | " feats_disjoint1 = model_feats[model_labels == disjoint_mode1][\n", 1658 | " :num_feats_per_class]\n", 1659 | " assert feats_disjoint1.shape[0] == num_feats_per_class\n", 1660 | " out1_data['features'] = np.concatenate(\n", 1661 | " [feats_common, feats_disjoint1], axis=0)\n", 1662 | " labels_list = [common_mode] * num_feats_per_class\n", 1663 | " labels_list += [disjoint_mode1] * num_feats_per_class\n", 1664 | " out1_data['labels'] = np.array(labels_list, dtype=np.int32)\n", 1665 | "\n", 1666 | " out2_data = dict()\n", 1667 | " test_feats = test_data['features']\n", 1668 | " test_labels = test_data['labels']\n", 1669 | " feats_common2 = test_feats[test_labels == common_mode][:num_feats_per_class]\n", 1670 | " assert feats_common2.shape[0] == num_feats_per_class, feats_common2.shape[0]\n", 1671 | " feats_disjoint2 = test_feats[test_labels == disjoint_mode2][\n", 1672 | " :num_feats_per_class]\n", 1673 | " assert feats_disjoint2.shape[0] == num_feats_per_class\n", 1674 | "\n", 1675 | " out2_data['features'] = np.concatenate(\n", 1676 | " [feats_common2, feats_disjoint2], axis=0)\n", 1677 | " labels_list = [common_mode] * num_feats_per_class\n", 1678 | " labels_list += [disjoint_mode2] * num_feats_per_class\n", 1679 | " out2_data['labels'] = np.array(labels_list, dtype=np.int32)\n", 1680 | "\n", 1681 | " return out1_data, out2_data\n", 1682 | "\n", 1683 | "def make_label_balanced_model_and_test_data():\n", 1684 | " \"\"\"Make features with 2500 egs per class for data with 5000 egs per class.\"\"\"\n", 1685 | " data_dirn = 'cifar10_mode_drop_data/'\n", 1686 | " all_data = np.load(\n", 1687 | " os.path.join(data_dirn, 'cifar10_train_valid_50k_pool3.npz'))\n", 1688 | " all_data_feats = all_data['features']\n", 1689 | " all_data_labels = all_data['labels']\n", 1690 | " out_feats_model = list()\n", 1691 | " out_labels_model = list()\n", 1692 | " out_feats_test = list()\n", 1693 | " out_labels_test = list()\n", 1694 | " num_examples_per_sample_set = 2500\n", 1695 | " for label in range(10):\n", 1696 | " per_class_feats = all_data_feats[all_data_labels == label]\n", 1697 | " per_class_labels = all_data_labels[all_data_labels == label]\n", 1698 | " out_feats_model.append(per_class_feats[:num_examples_per_sample_set])\n", 1699 | " out_labels_model.append(per_class_labels[:num_examples_per_sample_set])\n", 1700 | " out_feats_test.append(per_class_feats[num_examples_per_sample_set:])\n", 1701 | " out_labels_test.append(per_class_labels[num_examples_per_sample_set:])\n", 1702 | "\n", 1703 | " out_model_data = dict(\n", 1704 | " features=np.concatenate(out_feats_model, axis=0),\n", 1705 | " labels=np.concatenate(out_labels_model))\n", 1706 | " out_test_data = dict(\n", 1707 | " features=np.concatenate(out_feats_test, axis=0),\n", 1708 | " labels=np.concatenate(out_labels_test))\n", 1709 | "\n", 1710 | " return out_model_data, out_test_data" 1711 | ] 1712 | }, 1713 | { 1714 | "cell_type": "markdown", 1715 | "metadata": { 1716 | "id": "IhlQaJr2Ug2k" 1717 | }, 1718 | "source": [ 1719 | "### Load features" 1720 | ] 1721 | }, 1722 | { 1723 | "cell_type": 
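"markdown",
"metadata": {},
"source": [
"Before loading the real CIFAR10 features, a quick toy check of the make_two_modes contract (an added sketch using synthetic data, not part of the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added illustration: make_two_modes on tiny synthetic data. The model-side\n",
"# output keeps classes {0, 1}; the test-side output keeps classes {0, 2}.\n",
"rng = np.random.default_rng(0)\n",
"toy = dict(features=rng.normal(size=(30, 4)),\n",
"           labels=np.repeat(np.arange(3), 10))\n",
"out1, out2 = make_two_modes(toy, toy, num_feats_per_class=10)\n",
"print(out1['features'].shape, np.unique(out1['labels']))  # (20, 4) [0 1]\n",
"print(out2['features'].shape, np.unique(out2['labels']))  # (20, 4) [0 2]"
]
},
{
"cell_type":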
"markdown", 1724 | "metadata": { 1725 | "id": "Q2jcJE-YUjU-" 1726 | }, 1727 | "source": [ 1728 | "Make two sets of features\n", 1729 | "\n", 1730 | "1. For the \"model\", keep features from classes 0 and 1\n", 1731 | "2. For the \"test\", keep features from classes 1 and 2" 1732 | ] 1733 | }, 1734 | { 1735 | "cell_type": "code", 1736 | "execution_count": null, 1737 | "metadata": { 1738 | "id": "_EFrN9j2Ulqn" 1739 | }, 1740 | "outputs": [], 1741 | "source": [ 1742 | "model_data, test_data = make_label_balanced_model_and_test_data()\n", 1743 | "num_feats_per_class = 2500 # using 50k set, 2500 egs/class\n", 1744 | "model_two_classes, test_two_classes = make_two_modes(\n", 1745 | " model_data, test_data, num_feats_per_class=num_feats_per_class)" 1746 | ] 1747 | }, 1748 | { 1749 | "cell_type": "markdown", 1750 | "metadata": { 1751 | "id": "gXLoj7xFUnig" 1752 | }, 1753 | "source": [ 1754 | "### Calculate kernel features" 1755 | ] 1756 | }, 1757 | { 1758 | "cell_type": "code", 1759 | "execution_count": null, 1760 | "metadata": { 1761 | "id": "-kssIgrXUrai" 1762 | }, 1763 | "outputs": [], 1764 | "source": [ 1765 | "kernel_config_flags = get_kernel_config()\n", 1766 | "witness_data = np.load(os.path.join(\n", 1767 | " data_dirn, 'cifar10_test_pool3.npz'))\n", 1768 | "witness_feats = witness_data['features'][:1024]\n", 1769 | "model_kgel_data = dict(labels=model_two_classes['labels'])\n", 1770 | "test_kgel_data = dict(labels=test_two_classes['labels'])\n", 1771 | "model_kgel_data['features'] = kernel_matrix(\n", 1772 | " model_two_classes['features'], witness_feats, kernel_config_flags)\n", 1773 | "test_kgel_data['features'] = kernel_matrix(\n", 1774 | " test_two_classes['features'], witness_feats, kernel_config_flags)" 1775 | ] 1776 | }, 1777 | { 1778 | "cell_type": "markdown", 1779 | "metadata": { 1780 | "id": "liLC7t4KUuQR" 1781 | }, 1782 | "source": [ 1783 | "### Calculate KGEL" 1784 | ] 1785 | }, 1786 | { 1787 | "cell_type": "code", 1788 | "execution_count": null, 1789 | "metadata": { 1790 | "id": "-7bJxFHBUxhr" 1791 | }, 1792 | "outputs": [], 1793 | "source": [ 1794 | "config_flags = get_gel_config()\n", 1795 | "gel_two_sample_test = TwoSampleGEL(\n", 1796 | " config_flags, model_kgel_data['features'], test_kgel_data['features'])\n", 1797 | "out_dict = gel_two_sample_test.calculate_gel()" 1798 | ] 1799 | }, 1800 | { 1801 | "cell_type": "markdown", 1802 | "metadata": { 1803 | "id": "ZBuZQ5N6UyfN" 1804 | }, 1805 | "source": [ 1806 | "### Extract Results" 1807 | ] 1808 | }, 1809 | { 1810 | "cell_type": "code", 1811 | "execution_count": null, 1812 | "metadata": { 1813 | "id": "CWM8bX_VU0nK" 1814 | }, 1815 | "outputs": [], 1816 | "source": [ 1817 | "model_probs = np.array(\n", 1818 | " [out_dict['model_probs'][:num_feats_per_class].sum(),\n", 1819 | " out_dict['model_probs'][num_feats_per_class:].sum()])\n", 1820 | "\n", 1821 | "test_probs = np.array(\n", 1822 | " [out_dict['test_probs'][:num_feats_per_class].sum(),\n", 1823 | " out_dict['test_probs'][num_feats_per_class:].sum()])\n", 1824 | "\n", 1825 | "print('Ideal probability of the common mode is 1.0')\n", 1826 | "print('Ideal probability of the disjoint mode is 0.0')\n", 1827 | "\n", 1828 | "print('Model probability of the common mode is', model_probs[0])\n", 1829 | "print('Model probability of the disjoint mode is', model_probs[1])\n", 1830 | "\n", 1831 | "print('Test probability of the common mode is', test_probs[0])\n", 1832 | "print('Test probability of the disjoint mode is', test_probs[1])" 1833 | ] 1834 | } 1835 | ], 1836 | "metadata": 
{ 1837 | "colab": { 1838 | "private_outputs": true, 1839 | "provenance": [ 1840 | { 1841 | "file_id": "1q21iNK4Uhkdehj3FrCfmeW2lzPGksQJm", 1842 | "timestamp": 1687020660673 1843 | } 1844 | ], 1845 | "toc_visible": true 1846 | }, 1847 | "kernelspec": { 1848 | "display_name": "Python 3", 1849 | "name": "python3" 1850 | }, 1851 | "language_info": { 1852 | "name": "python" 1853 | } 1854 | }, 1855 | "nbformat": 4, 1856 | "nbformat_minor": 0 1857 | } 1858 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Understanding Deep Generative Models with Generalized Empirical Likelihoods 2 | 3 | This repository includes an implementation of the generalized empirical 4 | likelihood metric proposed in: 5 | S. Ravuri, M. Rey, S. Mohamed, M. Deisenroth, "Understanding Deep Generative 6 | Models with Generalized Empirical Likelihoods." Computer Vision and Pattern 7 | Recognition, 2023. 
8 | 9 | The metric is implemented in the "Generalized_Empirical_Likelihood" colab. 10 | The colab also includes some sample use cases, which you can find here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/understanding_deep_generative_models_with_generalized_empirical_likelihood/blob/main/Generalized_Empirical_Likelihood.ipynb) 11 | 12 | ## Citing this work 13 | 14 | If you use this work, please cite the following: 15 | 16 | @InProceedings{Ravuri_2023_CVPR, 17 | author = {Ravuri, Suman and Rey, M\'elanie and Mohamed, Shakir and Deisenroth, Marc Peter}, 18 | title = {Understanding Deep Generative Models With Generalized Empirical Likelihoods}, 19 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 20 | month = {June}, 21 | year = {2023}, 22 | pages = {24395-24405} 23 | } 24 | 25 | ## License and disclaimer 26 | 27 | Copyright 2023 DeepMind Technologies Limited 28 | 29 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 30 | you may not use this file except in compliance with the Apache 2.0 license. 31 | You may obtain a copy of the Apache 2.0 license at: 32 | https://www.apache.org/licenses/LICENSE-2.0 33 | 34 | All other materials are licensed under the Creative Commons Attribution 4.0 35 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 36 | https://creativecommons.org/licenses/by/4.0/legalcode 37 | 38 | Unless required by applicable law or agreed to in writing, all software and 39 | materials distributed here under the Apache 2.0 or CC-BY licenses are 40 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 41 | either express or implied. See the licenses for the specific language governing 42 | permissions and limitations under those licenses. 43 | 44 | This is not an official Google product. 45 | --------------------------------------------------------------------------------