├── .github └── workflows │ └── python-publish.yml ├── LICENSE ├── README.md ├── build └── lib │ └── GIO │ ├── GIOKL.py │ ├── GIO_super.py │ ├── __init__.py │ └── generate_text_embeddings.py ├── dist └── grad_info_opt-0.1.2-py3-none-any.whl ├── experiments ├── bm25_scripts │ ├── get_less.py │ ├── get_pairs.py │ └── make_csv.py ├── checks │ ├── negative_consistency.py │ ├── quantization_consistency.py │ └── self_consistency.py ├── fashion_mnist │ ├── csv_to_image.py │ ├── image_to_csv.py │ └── train.py └── speller │ └── calculate_scores.py ├── images ├── process.gif ├── process_once.gif ├── readme_ex1.png └── readme_ex2.png ├── setup.py └── src ├── GIO ├── GIOKL.py ├── GIO_super.py ├── __init__.py └── generate_text_embeddings.py └── grad_info_opt.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GIO: Gradient Information Optimization 2 |

3 | 4 |

5 | 6 | GIO is a library that implements Gradient Information Optimization (GIO) at scale, from the paper GIO: Gradient Information Optimization for Training Dataset Selection. GIO is a data selection technique that can 7 | be used to select a subset of training data that gives similar or superior performance to a model trained on full data. 8 | 9 | **Paper Abstract** 10 | 11 | It is often advantageous to train models on a subset of the available train examples, because the examples are of variable quality or because one would like to train with fewer examples, without sacrificing performance. We present Gradient Information Optimization (GIO), a scalable, task-agnostic approach to this data selection problem that requires only a small set of (unlabeled) examples representing a target distribution. GIO begins from a natural, information-theoretic objective that is intractable in practice. Our contribution is in showing that it can be made highly scalable through a simple relaxation of the objective and a highly efficient implementation. In experiments with machine translation, spelling correction, and image recognition, we show that GIO delivers outstanding results with very small train sets. These findings are robust to different representation models and hyperparameters for GIO itself. GIO is task- and domain-agnostic and can be applied out-of-the-box to new datasets and domains. 12 | 13 | 14 | **Features**: 15 | - GIO with quantization using K-means. 16 | - Sentence embedding script to generate embeddings from data to use in GIO. 17 | 18 | 19 | ## Installation 20 | 21 | Installable via pip: 22 | ```bash 23 | pip install grad-info-opt 24 | ``` 25 | Or install directly from the repository: 26 | 27 | ```bash 28 | git clone git@github.com:daeveraert/gradient-information-optimization.git 29 | cd gradient-information-optimization 30 | pip install -e . 31 | ``` 32 | 33 | Direct installation will require you to install the additional dependencies listed below. We welcome contributions to GIO. 34 | 35 | ## Requirements 36 | - `numpy>=1.21.6` 37 | - `jax>=0.3.25` 38 | - `pyspark>=2.4.8` 39 | - `sentence_transformers>=2.2.2` 40 | - `jaxlib>=0.3.2` 41 | - `pandas>=1.0.5` 42 | 43 | 44 | 45 | ## Quick Start 46 | **Note:** GIO uses an existing Spark context or, if it can't find one, creates a local one. Before the algorithm runs, you may encounter a Spark error complaining that it cannot find a free port. In this case, executing ```export SPARK_LOCAL_IP="127.0.0.1"``` should resolve the issue.
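If you prefer to keep everything in Python, the same workaround can be applied by setting the environment variable before the Spark session is created. This is only a sketch of the standard Spark workaround (not part of the GIO API); the variable and value are the same ones shown above:

```python
import os

# Set this before instantiating GIOKL, since the SparkSession is created in its constructor
os.environ["SPARK_LOCAL_IP"] = "127.0.0.1"

from GIO import GIOKL
gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2)
```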
47 | 48 | Here is a simple 2D demonstration of how to use GIO with visualization: 49 | ```python 50 | from GIO import GIOKL 51 | import numpy as np 52 | import jax.numpy as jnp 53 | import matplotlib.pyplot as plt 54 | 55 | # Create some data 56 | def getX(): 57 | mean = [3,4] 58 | cov = [[0.5,0],[0,0.5]] 59 | np.random.seed(1) 60 | x, y = np.random.multivariate_normal(mean, cov, 100).T 61 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 62 | 63 | def getXTest(): 64 | mean = [3,4] 65 | cov = [[0.5,0],[0,0.5]] 66 | np.random.seed(5) 67 | x, y = np.random.multivariate_normal(mean, cov, 100).T 68 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 69 | 70 | X = getX() 71 | X_test = getXTest() 72 | 73 | # Initialize class 74 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2) 75 | 76 | # Perform the Algorithm 77 | W, kl_divs, _ = gio_kl.fit(X_test, X, normalize=False) 78 | W = W[100:] # Remove the uniform start 79 | 80 | # Plot results 81 | plt.plot(kl_divs) 82 | plt.title("KL Divergence vs. Iterations") 83 | plt.xlabel("Iterations") 84 | plt.ylabel("KL Divergence") 85 | plt.show() 86 | plt.clf() 87 | plt.scatter([each[0] for each in W], [each[1] for each in W], label='Selected Data') 88 | plt.scatter([each[0] for each in X], [each[1] for each in X], label='Target Data') 89 | plt.title("Target Data and Selected Data") 90 | plt.xlabel("Dimension 1") 91 | plt.ylabel("Dimension 2") 92 | plt.legend() 93 | plt.show() 94 | ``` 95 |
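After the loop finishes, you can sanity-check the selection by computing the divergence between the target data and the selected subset directly. This is a minimal sketch that reuses the `calculate_statistical_distance` method already available on the `gio_kl` instance from the example above (with its default `k=5`):

```python
# KL divergence estimate between the target X and the selected set W; smaller is better
final_kl = gio_kl.calculate_statistical_distance(X, W, k=5)
print("KL divergence of target vs. selected subset:", final_kl)
```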

96 | 97 | 98 |

99 | 100 | Here is a more complex example for applications at scale: it reads a CSV that stores embeddings and data, uses quantization-explosion, and runs on Spark: 101 | ```python 102 | from GIO import GIOKL 103 | import jax.numpy as jnp 104 | import matplotlib.pyplot as plt 105 | import pyspark.sql.functions as F 106 | 107 | # Initialize class 108 | gio_kl = GIOKL.GIOKL(uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768) 109 | 110 | # Read data 111 | train_df, target_df = gio_kl.read_data_from_csv(PATH_TO_TRAIN, PATH_TO_TARGET) 112 | 113 | # Quantize data 114 | model_train, model_X, transformed_train, transformed_X = gio_kl.quantize(train_df, target_df) 115 | 116 | X = jnp.array(model_X.clusterCenters()) 117 | train = jnp.array(model_train.clusterCenters()) 118 | centroids_df = gio_kl.spark.createDataFrame(data=[(i, each.tolist()) for i, each in enumerate(model_train.clusterCenters())], schema=["id", "centroid"]) 119 | 120 | # Perform the Algorithm 121 | W, kl_divs, _ = gio_kl.fit(train, X, max_iter=300, stop_criterion='sequential_increase_tolerance', v_init='jump') 122 | W = W[20:] # Remove the uniform start 123 | 124 | # Explode back to original data and write resulting data 125 | full_selections_df = gio_kl.explode(W, transformed_train, centroids_df) 126 | full_selections_df.select(F.col("_c0"), F.col("_c1")).write.option("delimiter", "\t").csv(OUTPUT_PATH) 127 | 128 | 129 | # Plot results 130 | plt.plot(kl_divs) 131 | plt.title("KL Divergence vs. Iterations") 132 | plt.xlabel("Iterations") 133 | plt.ylabel("KL Divergence") 134 | plt.show() 135 | ``` 136 | **Note:** For quantization, Spark requires a large RPC message size. It is recommended to place ```gio_kl.spark.conf.set("spark.rpc.message.maxSize", "500")``` (or any large number) in the code before calling quantize, if the defaults haven't already been increased. 137 | 138 | ## Available Options 139 | `GIOKL.fit` takes the following arguments: 140 | - `train`: training data as a jnp array (jnp is almost identical to numpy), [M, D] shape 141 | - `X`: target data as a jnp array, [N, D] shape 142 | - `D`: initial data as a jnp array, default None. Use None to initialize from 0 (uniform) or a subset of training data 143 | - `k`: kth nearest neighbor to use in the KL divergence estimation, default 5 144 | - `max_iter`: maximum iterations for the algorithm. One iteration adds one point (cluster) 145 | - `stop_criterion`: a string for the stopping criterion, one of the following: 'increase', 'max_resets', 'min_difference', 'sequential_increase_tolerance', 'min_kl', 'data_size'. Default is 'increase' 146 | - `min_difference`: the minimum difference between prior and current KL divergence for the 'min_difference' stop criterion only. Default is 0 147 | - `resets_allowed`: whether resetting G to the full train set is allowed if the KL divergence increases (this allows the algorithm to pick duplicates). Must be set to True if the stop criterion is 'max_resets'. Default is False 148 | - `max_resets`: the number of resets allowed for the 'max_resets' stop criterion only (a reset resets G to the full train set and allows the algorithm to pick duplicates). Default is 2 149 | - `max_data_size`: the maximum size of data to be selected for the 'data_size' stop criterion only, as a percentage (of total data) between 0 and 1. Default is 1 150 | - `min_kl`: the minimum KL divergence for the 'min_kl' stop criterion only.
Default is 0 151 | - `max_sequential_increases`: the maximum number of sequential KL divergence increases for the 'sequential_increase_tolerance' stop criterion only. Default is 3 152 | - `random_init_pct`: the percent of training data to initialize the algorithm from. Default is 0 153 | - `random_restart_prob`: probability at any given iteration to extend the gradient descent iterations by 3x, to find potentially better extrema. Higher values come at the cost of efficiency. Default is 0 154 | - `scale_factor`: factor to scale the gradient by, or 'auto'. Default is 'auto', which is recommended 155 | - `v_init`: how to initialize v in gradient descent, one of the following: 'mean', 'prev_opt', 'jump'. Default is 'mean' 156 | - `grad_desc_iter`: the number of iterations to use in gradient descent. Default is 50 157 | - `discard_nearest_for_xy`: discard the nearest neighbor in the xy calculation of KL divergence, for use when X and the train set are the same; comes at the cost of efficiency. Default is False 158 | - `lr`: learning rate for gradient descent. Default is 0.01 159 | 160 | ## Citing GIO 161 | If you use GIO in a publication, blog, or software project, please cite the paper: 162 | ``` 163 | @misc{everaert2023gio, 164 | title={GIO: Gradient Information Optimization for Training Dataset Selection}, 165 | author={Dante Everaert and Christopher Potts}, 166 | year={2023}, 167 | eprint={2306.11670}, 168 | archivePrefix={arXiv}, 169 | primaryClass={cs.LG} 170 | } 171 | ``` 172 | -------------------------------------------------------------------------------- /build/lib/GIO/GIOKL.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import grad 3 | import random 4 | import jax 5 | 6 | import numpy as np 7 | import pyspark.sql.functions as F 8 | 9 | from pyspark.ml.clustering import KMeans 10 | from pyspark.sql import SparkSession 11 | from pyspark.sql.types import * 12 | from .GIO_super import GIO_super 13 | 14 | 15 | class GIOKL(GIO_super): 16 | def __init__(self, uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768): 17 | super().__init__() 18 | self.spark = SparkSession.builder.getOrCreate() 19 | self.uniform_low = uniform_low 20 | self.uniform_high = uniform_high 21 | self.uniform_start_size = uniform_start_size 22 | self.dim = dim 23 | self.random_init = False 24 | 25 | def _get_nearest(self, sample, point): 26 | """Euclidean distance from point to its nearest point in sample. 27 | :param sample: a set of points to compute the nearest distance to 28 | :param point: point to retrieve the nearest point in sample from 29 | :return: the index of the nearest point 30 | """ 31 | norms = jnp.linalg.norm(sample - point, axis=1) 32 | return jnp.argsort(norms)[0] 33 | 34 | def _knn(self, x, y, k, last_only, discard_nearest, avg): 35 | """Find k_neighbors-nearest neighbor distances from y for each example in a minibatch x. 36 | :param x: tensor of shape [N_1, D] 37 | :param y: tensor of shape [N_2, D] 38 | :param k: the (k_neighbors+1):th nearest neighbor 39 | :param last_only: use only the last knn vs.
all of them 40 | :param discard_nearest: whether to discard the nearest neighbor distance (use when x and y are the same set) 41 | :return: knn distances of shape [N, k_neighbors] or [N, 1] if last_only 42 | """ 43 | 44 | dist_x = jnp.sum((x ** 2), axis=-1)[:, jnp.newaxis] 45 | dist_y = jnp.sum((y ** 2), axis=-1)[:, jnp.newaxis].T 46 | cross = - 2 * jnp.matmul(x, y.T) 47 | distmat = dist_x + cross + dist_y 48 | distmat = jnp.clip(distmat, 1e-10, 1e+20) 49 | 50 | if discard_nearest: 51 | if not avg: 52 | knn, _ = jax.lax.top_k(-distmat, k + 1) 53 | else: 54 | knn = -jnp.sort(distmat) 55 | knn = knn[:, 1:] 56 | else: 57 | knn = -distmat 58 | 59 | if last_only: 60 | knn = knn[:, -1:] 61 | 62 | return jnp.sqrt(-knn) 63 | 64 | def _kl_divergence_knn(self, x, y, k, eps, discard_nearest_for_xy): 65 | """KL divergence estimator for D(x~p || y~q). 66 | :param x: x~p 67 | :param y: y~q 68 | :param k: kth nearest neighbor 69 | :param discard_nearest_for_xy: discard nearest in the xy calculation 70 | :param eps: small epsilon to pass to log 71 | :return: scalar 72 | """ 73 | n, d = x.shape 74 | m, _ = y.shape 75 | nns_xx = self._knn(x, x, k=k, last_only=True, discard_nearest=True, avg=False) 76 | nns_xy = self._knn(x, y, k=m, last_only=False, discard_nearest=discard_nearest_for_xy, avg=discard_nearest_for_xy) 77 | 78 | divergence = jnp.mean(d*jnp.log(nns_xy + eps) - d*jnp.log(nns_xx + eps)) + jnp.mean(jnp.log((k*m)/(jnp.arange(1, m+1) * (n-1)))) 79 | 80 | return divergence 81 | 82 | def calculate_statistical_distance(self, x, y, k=5, eps=1e-8, discard_nearest_for_xy=False): 83 | """Calculate statistical distance d(p,q) based on x~p and y~q. 84 | :param x: x~p 85 | :param y: y~q 86 | :param k: kth nearest neighbor 87 | :param eps: small epsilon to pass to log 88 | :return: scalar 89 | """ 90 | return self._kl_divergence_knn(x, y, k, eps, discard_nearest_for_xy) 91 | 92 | def gradient_descend(self, X, W, v, scaling_factor, max_iterations, lr=0.01, k=5, discard_nearest_for_xy=False): 93 | """Perform gradient descent on the statistical distance between X and W+v 94 | :param X: target data 95 | :param W: current selected data 96 | :param v: initial v 97 | :param scaling_factor: scale the gradient 98 | :param max_iterations: iterations in the gradient descent 99 | :param lr: learning rate 100 | :param k: kth nearest neighbor 101 | :param discard_nearest_for_xy: discard nearest in the xy calculation 102 | :return: the optimized vector v 103 | """ 104 | i = 0 105 | while i < max_iterations: 106 | gradient = grad(lambda v: self.calculate_statistical_distance(X, jnp.concatenate((W, v[jnp.newaxis, :])), k, discard_nearest_for_xy=discard_nearest_for_xy))(v) 107 | v = v - lr * scaling_factor * gradient 108 | i += 1 109 | return v 110 | 111 | def _get_uniform_start(self, do_normalize): 112 | """Get a uniform start for D.
113 | :return: jnp array of uniform points 114 | """ 115 | def normalize(v): 116 | norm = np.linalg.norm(v) 117 | if norm == 0: 118 | return v 119 | return v / norm 120 | if do_normalize: 121 | return jnp.array([normalize(each) for each in np.random.uniform(low=self.uniform_low,high=self.uniform_high,size=(self.uniform_start_size,self.dim))]) 122 | else: 123 | return jnp.array([each for each in np.random.uniform(low=self.uniform_low,high=self.uniform_high,size=(self.uniform_start_size,self.dim))]) 124 | 125 | def fit(self, train, X, D=None, k=5, max_iter=100, stop_criterion="increase", min_difference=0, resets_allowed=False, max_resets=2, max_data_size=1, min_kl=0, max_sequential_increases=3, random_init_pct=0, random_restart_prob=0, scale_factor="auto", v_init='mean', grad_desc_iter=50, discard_nearest_for_xy=False, normalize=True, lr=0.01): 126 | """Perform GIO 127 | :param train: training data 128 | :param X: target data 129 | :param D: initial data 130 | :param k: kth nearest neighbor 131 | :param max_iter: max iterations for the algorithm 132 | :param stop_criterion: a string for the stopping criterion, one of the following: 'increase', 'max_resets', 'min_difference', 'sequential_increase_tolerance', 'min_kl', 'data_size' 133 | :param min_difference: the minimum difference between prior and current KL divergence for 'min_difference' stop criterion 134 | :param resets_allowed: whether resetting G to the full train set is allowed if the KL divergence increases (allows the algorithm to pick duplicates). Must be set to True if the stop criterion is 'max_resets' 135 | :param max_resets: the number of resets allowed for the 'max_resets' stop criterion 136 | :param max_data_size: the maximum size of data for the 'data_size' stop criterion, as a percentage 137 | :param min_kl: the minimum kl divergence for the 'min_kl' stop criterion 138 | :param max_sequential_increases: the maximum number of sequential KL divergence increases for the 'sequential_increase_tolerance' stop criterion 139 | :param random_init_pct: the percent of training data to initialize the algorithm from 140 | :param random_restart_prob: probability to extend the gradient descent iterations by 3x to find potentially better extrema.
Higher values come at the cost of efficiency 141 | :param scale_factor: factor to scale the gradient by or 'auto' 142 | :param v_init: how to initialize v in gradients descent, one of the following: 'mean', 'prev_opt', 'jump' 143 | :param grad_desc_iter: the number of iterations in gradient descent 144 | :param discard_nearest_for_xy: discard nearest in the xy calculation 145 | :param lr: Learning rate for gradient descent 146 | :return: selected data, kl divergences, (v, scale_factor, just_reset, num_resets, increases, adder, kl_divs) 147 | """ 148 | if not random_init_pct and D is None: 149 | W = self._get_uniform_start(normalize) 150 | self.random_init = True 151 | elif D is None: 152 | amount = int(random_init_pct * len(train)) 153 | W = jnp.array(random.sample(train.tolist(), amount)) 154 | else: 155 | W = D[:] 156 | 157 | kl_dist_prev = self.calculate_statistical_distance(X, W, k, discard_nearest_for_xy=discard_nearest_for_xy) 158 | 159 | print("Starting KL: " + str(kl_dist_prev)) 160 | if v_init == 'mean' or v_init == 'prev_opt': 161 | v = jnp.mean(X, axis=0) 162 | elif v_init == 'jump': 163 | v = jnp.array(random.sample(X.tolist(), 1)).squeeze() 164 | adder = train[:] 165 | kl_divs = [] 166 | 167 | scale_factor = jnp.linalg.norm(v)/jnp.linalg.norm(grad(lambda v: self.calculate_statistical_distance(X, jnp.concatenate((W, v[jnp.newaxis, :])), k, discard_nearest_for_xy=discard_nearest_for_xy))(v)) if scale_factor == "auto" else scale_factor 168 | 169 | i = 0 170 | just_reset = False 171 | num_resets = 0 172 | total_iter = 0 173 | increases = 0 174 | while True: 175 | # Warmup, reset or random restart 176 | if i == 0 or just_reset or random.random() < random_restart_prob: 177 | v = self.gradient_descend(X, W, v, scale_factor, grad_desc_iter * 3, lr=lr, k=k, discard_nearest_for_xy=discard_nearest_for_xy) 178 | else: 179 | v = self.gradient_descend(X, W, v, scale_factor, grad_desc_iter, lr=lr, k=k, discard_nearest_for_xy=discard_nearest_for_xy) 180 | idx = self._get_nearest(v, adder) 181 | minvals = adder[idx] 182 | adder = jnp.delete(adder, idx, axis=0) 183 | 184 | W_tmp = jnp.concatenate((W, jnp.array(minvals)[jnp.newaxis, :])) 185 | 186 | kl_dist = self.calculate_statistical_distance(X, W_tmp, k, discard_nearest_for_xy=discard_nearest_for_xy) 187 | print("KL Divergence at iteration " + str(i) + ": " + str(kl_dist)) 188 | 189 | # STOPPING CRITERIA 190 | if total_iter > max_iter: 191 | break 192 | 193 | if v_init == 'mean': 194 | v = jnp.mean(X, axis=0) 195 | elif v_init == 'jump': 196 | v = jnp.array(random.sample(X.tolist(), 1)).squeeze() 197 | 198 | adder, i, just_reset, stop, v, increases, num_resets = self._test_stop_criterion(v_init, stop_criterion, kl_dist, kl_dist_prev, num_resets, max_resets, min_difference, increases, max_sequential_increases, min_kl, max_data_size, train, X, i, v, just_reset, resets_allowed, adder) 199 | 200 | if stop: 201 | break 202 | if not just_reset: 203 | W = W_tmp 204 | kl_divs += [kl_dist] 205 | kl_dist_prev = kl_dist 206 | i += 1 207 | total_iter += 1 208 | return W, kl_divs, (v, scale_factor, just_reset, num_resets, increases, adder, kl_divs) 209 | 210 | def _test_stop_criterion(self, v_init, stop_criterion, kl_dist, kl_dist_prev, num_resets, max_resets, min_difference, increases, max_sequential_increases, min_kl, max_data_size, train, X, i, v, just_reset, resets_allowed, adder): 211 | stop = False 212 | if stop_criterion == "increase" and kl_dist - kl_dist_prev > 0: 213 | stop = True 214 | elif stop_criterion == "max_resets" and kl_dist - 
kl_dist_prev > 0 and num_resets == max_resets: 215 | stop = True 216 | elif stop_criterion == "min_difference" and kl_dist_prev - kl_dist < min_difference: 217 | stop = True 218 | elif stop_criterion == 'sequential_increase_tolerance' and kl_dist - kl_dist_prev > 0 and increases == max_sequential_increases: 219 | stop = True 220 | elif stop_criterion == 'min_kl' and kl_dist < min_kl: 221 | stop = True 222 | elif stop_criterion == 'data_size' and i > int(max_data_size * len(train)): 223 | stop = True 224 | if stop: 225 | if just_reset: 226 | increases += 1 227 | if resets_allowed and num_resets < max_resets: 228 | num_resets += 1 229 | if v_init == 'prev_opt': 230 | v = jnp.mean(X, axis=0) 231 | print("KL Div Increase, Resetting G") 232 | adder = train[:] 233 | i -= 1 234 | stop = False 235 | just_reset = True 236 | else: 237 | just_reset = False 238 | increases = 0 239 | return adder, i, just_reset, stop, v, increases, num_resets 240 | 241 | def _return_kmeans(self, df, k, rseed): 242 | """Use Spark to perform K-Means 243 | :param df: dataframe to perform K-Means with 244 | :param k: number of clusters to compute 245 | :param rseed: random seed 246 | :return: k-means model, transformed df 247 | """ 248 | kmeans = KMeans().setK(k).setSeed(rseed) 249 | model = kmeans.fit(df.select("features")) 250 | transformed_df = model.transform(df) 251 | return model, transformed_df 252 | 253 | def quantize(self, df_train, df_x, k=1500, rseed='auto', rseed1=234, rseed2=456): 254 | """Use Spark to perform K-Means 255 | :param df_train: train dataframe to quantize 256 | :param df_x: target dataframe to quantize 257 | :param k: number of clusters to compute 258 | :param rseed: 'auto' or 'manual' 259 | :param rseed1: first random seed 260 | :param rseed2: second random seed 261 | :return: k-means model, transformed df 262 | """ 263 | if rseed == 'auto': 264 | rseed1 = random.randint(-1000,1000) 265 | rseed2 = random.randint(-1000,1000) 266 | model_train, transformed_train = self._return_kmeans(df_train, k, rseed1) 267 | model_X, transformed_X = self._return_kmeans(df_x, k, rseed2) 268 | return model_train, model_X, transformed_train, transformed_X 269 | 270 | def read_data_from_csv(self, path, path_X, delim="\t"): 271 | """Read in and process data stored in a csv. Data must be of the format: _c0, _c1, _c2 where _c2 contains the 272 | string representation of the vector, like "[0.1, 0.23, 0.45 ...]" 273 | :param path: path to training data 274 | :param path_X: path to target data 275 | :param delim: delimiter for csv file 276 | :return: train df, target df 277 | """ 278 | new_schema = ArrayType(DoubleType(), containsNull=False) 279 | udf_json_to_arr = F.udf(lambda x: x, new_schema) 280 | 281 | df_read = self.spark.read.option("delimiter", delim).csv(path) 282 | df_with_embeddings = df_read.withColumn("features", udf_json_to_arr(F.from_json(F.col("_c2"), "array"))) 283 | 284 | df_X_read = self.spark.read.option("delimiter", delim).csv(path_X) 285 | df_X_with_embeddings = df_X_read.withColumn("features", udf_json_to_arr(F.from_json(F.col("_c2"), "array"))) 286 | 287 | return df_with_embeddings, df_X_with_embeddings 288 | 289 | def read_data_from_parquet(self, path, path_X): 290 | """Read in and process data stored in a parquet format. Data must contain a column "features" that stores an array 291 | of the vectors and be non-nullable. 
292 | :param path: path to training data 293 | :param path_X: path to target data 294 | :return: train df, target df 295 | """ 296 | new_schema = ArrayType(DoubleType(), containsNull=False) 297 | udf_no_null = F.udf(lambda x: x, new_schema) 298 | 299 | df_with_embeddings = self.spark.read.parquet(path).withColumn("features", udf_no_null(F.col("features"))) 300 | df_X_with_embeddings = self.spark.read.parquet(path_X).withColumn("features", udf_no_null(F.col("features"))) 301 | return df_with_embeddings, df_X_with_embeddings 302 | 303 | def explode(self, chosen_centroids, kmeans_transformed_df, kmeans_centroids_df): 304 | """Map the chosen (quantized) centroids back to the full rows of the original training data (the "explosion" step). 305 | :param chosen_centroids: centroids selected by fit (the returned W) 306 | :param kmeans_transformed_df: training dataframe with a cluster "prediction" column, as returned by quantize 307 | :param kmeans_centroids_df: dataframe of (id, centroid) rows for the train k-means clusters 308 | :return: dataframe of all original training rows whose cluster centroid was selected 309 | """ 310 | pre_existing_centroids = jnp.array([f[1] for f in sorted([[each[0], each[1]] for each in kmeans_centroids_df.collect()], key=lambda x: x[0])]) 311 | paired = [] 312 | for each in chosen_centroids: 313 | for i, x in enumerate(pre_existing_centroids.tolist()): 314 | if each.tolist() in [x]: 315 | paired += [i] 316 | print("Found " + str(len(paired)) + " centroids out of " + str(len(chosen_centroids)) + " selected centroids") 317 | full_selections_df = self.spark.createDataFrame(data=[(i, each) for i, each in enumerate(paired)], schema=["i", "id"]).join(kmeans_transformed_df, F.col("id") == F.col("prediction")) 318 | return full_selections_df 319 | 320 | -------------------------------------------------------------------------------- /build/lib/GIO/GIO_super.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import grad 3 | import random 4 | import jax 5 | 6 | import numpy as np 7 | import pyspark.sql.functions as F 8 | 9 | from pyspark.ml.clustering import KMeans 10 | from pyspark.sql import SparkSession 11 | from pyspark.sql.types import * 12 | 13 | 14 | class GIO_super: 15 | def __init__(self): 16 | pass 17 | 18 | def calculate_statistical_distance(self, x, y): 19 | pass 20 | 21 | def gradient_descend(self, X, W, v, factor, max_iterations, lr, *arg): 22 | pass 23 | 24 | def fit(self, train, X, *arg): 25 | pass 26 | 27 | def quantize(self, df_train, df_x, quantize_into): 28 | pass 29 | 30 | def _get_nearest(self, sample, point): 31 | pass 32 | 33 | def explode(self, chosen_centroids, kmeans_transformed_df, kmeans_centroids_df): 34 | pass 35 | -------------------------------------------------------------------------------- /build/lib/GIO/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/build/lib/GIO/__init__.py -------------------------------------------------------------------------------- /build/lib/GIO/generate_text_embeddings.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer 2 | import time 3 | 4 | 5 | 6 | class GenerateEmbeddings: 7 | def __init__(self, model_name, device='cuda'): 8 | self.model = SentenceTransformer(model_name) 9 | self.device = device 10 | 11 | def generate_embeddings(self, input_file_path, output_file_path): 12 | """Generate Embeddings from a text file 13 | :param input_file_path: path to input text, one sentence per line 14 | :param output_file_path: path to
desired output file 15 | """ 16 | print('Reading File...') 17 | with open(input_file_path, 'r') as fp: 18 | sentences = fp.readlines() 19 | print('Generating Embeddings... This May Take a While') 20 | start = time.time() 21 | embeddings = self.model.encode(sentences, device=self.device) 22 | end = time.time() 23 | 24 | print("Time Taken (s): " + str(end - start)) 25 | 26 | print("Writing Embeddings.. This May Take a While") 27 | with open(output_file_path, 'w') as op: 28 | for i, each in enumerate(embeddings): 29 | op.write(str(each.tolist()).strip() + "\n") 30 | -------------------------------------------------------------------------------- /dist/grad_info_opt-0.1.2-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/dist/grad_info_opt-0.1.2-py3-none-any.whl -------------------------------------------------------------------------------- /experiments/bm25_scripts/get_less.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import csv 3 | 4 | bad_data_counter = 0 5 | written = 0 6 | 7 | with open(sys.argv[1], "r", newline='') as fp, open(sys.argv[2], "w") as op: 8 | csv_fp = csv.reader(fp, delimiter=',') 9 | for i, each in enumerate(csv_fp): 10 | # Skip header 11 | if i == 0: 12 | continue 13 | 14 | # Expect length of 3; determine how many are bad data 15 | if len(each) < 3: 16 | bad_data_counter += 1 17 | 18 | # Filter if the topK ID is above the specified filter value 19 | if int(each[2]) > int(sys.argv[3]): 20 | continue 21 | else: 22 | written += 1 23 | op.write(each[0].strip() + "\n") 24 | 25 | print("Bad Data: " + str(bad_data_counter)) 26 | print("Total Written: " + str(written)) 27 | -------------------------------------------------------------------------------- /experiments/bm25_scripts/get_pairs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | dct = {} 4 | bad_data_counter = 0 5 | with open(sys.argv[1], "r") as fp: 6 | for each in fp: 7 | sepped = each.strip().split("\t") 8 | # Expect length of 2 9 | if len(sepped) != 2: 10 | bad_data_counter += 1 11 | else: 12 | dct[sepped[0].strip()] = sepped[1] # Create dictionary of input-output pairs to later look up and match 13 | 14 | print("Weird parses: " + str(bad_data_counter)) 15 | 16 | lines = [] 17 | not_in_dct = 0 18 | with open(sys.argv[2], "r") as fp: 19 | for each in fp: 20 | cleaned = each.strip() 21 | # Expect all data to be in the dictionary 22 | if cleaned not in dct: 23 | not_in_dct += 1 24 | else: 25 | lines += [cleaned + '\t' + dct[cleaned] + "\n"] # Write input \t output 26 | 27 | print("Not in Dict: " + str(not_in_dct)) 28 | 29 | with open(sys.argv[3], "w") as op: 30 | op.writelines(lines) 31 | 32 | 33 | -------------------------------------------------------------------------------- /experiments/bm25_scripts/make_csv.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | 4 | with open(sys.argv[1], "r") as fp, open(sys.argv[2], "w", newline='') as op: 5 | csv_fp = fp 6 | csv_op = csv.writer(op, delimiter=',') 7 | csv_op.writerow(["text","id"]) # Give column names 8 | total = 0 9 | for i, each in enumerate(csv_fp): 10 | csv_op.writerow([each.strip()] + [i]) # Write value and ID in CSV format 11 | total += 1 12 | 
-------------------------------------------------------------------------------- /experiments/checks/negative_consistency.py: -------------------------------------------------------------------------------- 1 | from GIO import GIOKL 2 | import numpy as np 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | 6 | # Create some data 7 | def getX(): 8 | mean = [3,4] 9 | cov = [[0.5,0],[0,0.5]] 10 | np.random.seed(1) 11 | x, y = np.random.multivariate_normal(mean, cov, 100).T 12 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 13 | 14 | def getXTest(): 15 | mean = [300,400] 16 | cov = [[0.5,0],[0,0.5]] 17 | np.random.seed(5) 18 | x, y = np.random.multivariate_normal(mean, cov, 100).T 19 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 20 | 21 | X = getX() 22 | X_test = getXTest() 23 | 24 | # Initialize class 25 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2) 26 | 27 | # Perform the Algorithm 28 | W, kl_divs, _ = gio_kl.fit(X_test, X, normalize=False) 29 | W = W[100:] # Remove the uniform start 30 | 31 | # Plot results 32 | plt.plot(kl_divs) 33 | plt.title("KL Divergence vs. Iterations") 34 | plt.xlabel("Iterations") 35 | plt.ylabel("KL Divergence") 36 | plt.show() 37 | plt.clf() 38 | plt.scatter([each[0] for each in W], [each[1] for each in W], label='Selected Data') 39 | plt.scatter([each[0] for each in X], [each[1] for each in X], label='Target Data') 40 | plt.title("Target Data and Selected Data") 41 | plt.xlabel("Dimension 1") 42 | plt.ylabel("Dimension 2") 43 | plt.legend() 44 | plt.show() 45 | -------------------------------------------------------------------------------- /experiments/checks/quantization_consistency.py: -------------------------------------------------------------------------------- 1 | from GIO import GIOKL 2 | import numpy as np 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | from pyspark.sql.types import * 6 | import pyspark.sql.functions as F 7 | 8 | # Create some data 9 | def getX(): 10 | mean = [3,4] 11 | cov = [[0.5,0],[0,0.5]] 12 | np.random.seed(1) 13 | x, y = np.random.multivariate_normal(mean, cov, 100).T 14 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 15 | 16 | X = getX() 17 | 18 | new_schema = ArrayType(DoubleType(), containsNull=False) 19 | udf_no_null = F.udf(lambda x: x, new_schema) 20 | 21 | # Initialize class 22 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2) 23 | X_df = gio_kl.spark.createDataFrame(data=[(i, each.tolist()) for i, each in enumerate(X)], schema=["id", "features"]).withColumn("features", udf_no_null(F.col("features"))) 24 | 25 | # Quantize data 26 | model_train, model_X, transformed_train, transformed_X = gio_kl.quantize(X_df, X_df) 27 | quantized_X = jnp.array(model_X.clusterCenters()) 28 | 29 | # Calculate KL 30 | kl = gio_kl.calculate_statistical_distance(X, quantized_X) 31 | 32 | print("KL Divergence: " + str(kl)) 33 | -------------------------------------------------------------------------------- /experiments/checks/self_consistency.py: -------------------------------------------------------------------------------- 1 | from GIO import GIOKL 2 | import numpy as np 3 | import jax.numpy as jnp 4 | import matplotlib.pyplot as plt 5 | 6 | # Create some data 7 | def getX(): 8 | mean = [3,4] 9 | cov = [[0.5,0],[0,0.5]] 10 | np.random.seed(1) 11 | x, y = np.random.multivariate_normal(mean, cov, 100).T 12 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 13 | 14 | def getXTest(): 15 | mean = [3,4] 
16 | cov = [[0.5,0],[0,0.5]] 17 | np.random.seed(5) 18 | x, y = np.random.multivariate_normal(mean, cov, 100).T 19 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 20 | 21 | X = getX() 22 | X_test = getXTest() 23 | 24 | # Initialize class 25 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2) 26 | 27 | # Perform the Algorithm 28 | W, kl_divs, _ = gio_kl.fit(X_test, X, normalize=False) 29 | W = W[100:] # Remove the uniform start 30 | 31 | # Plot results 32 | plt.plot(kl_divs) 33 | plt.title("KL Divergence vs. Iterations") 34 | plt.xlabel("Iterations") 35 | plt.ylabel("KL Divergence") 36 | plt.show() 37 | plt.clf() 38 | plt.scatter([each[0] for each in W], [each[1] for each in W], label='Selected Data') 39 | plt.scatter([each[0] for each in X], [each[1] for each in X], label='Target Data') 40 | plt.title("Target Data and Selected Data") 41 | plt.xlabel("Dimension 1") 42 | plt.ylabel("Dimension 2") 43 | plt.legend() 44 | plt.show() 45 | -------------------------------------------------------------------------------- /experiments/fashion_mnist/csv_to_image.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | to_write = "" 4 | 5 | with open(sys.argv[1], "r") as fp, open(sys.argv[2], "w") as op: 6 | for i, each in enumerate(fp): 7 | # Skip header 8 | if i == 0: 9 | continue 10 | sepped = each.strip().split(",") # Split into label and image pixels 11 | to_write += "[" + ", ".join(sepped[1:]) + "]\t" + sepped[0] + "\n" # Make into array of image \t label format 12 | op.write(to_write) 13 | -------------------------------------------------------------------------------- /experiments/fashion_mnist/image_to_csv.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | to_write = 
"label,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,pixel9,pixel10,pixel11,pixel12,pixel13,pixel14,pixel15,pixel16,pixel17,pixel18,pixel19,pixel20,pixel21,pixel22,pixel23,pixel24,pixel25,pixel26,pixel27,pixel28,pixel29,pixel30,pixel31,pixel32,pixel33,pixel34,pixel35,pixel36,pixel37,pixel38,pixel39,pixel40,pixel41,pixel42,pixel43,pixel44,pixel45,pixel46,pixel47,pixel48,pixel49,pixel50,pixel51,pixel52,pixel53,pixel54,pixel55,pixel56,pixel57,pixel58,pixel59,pixel60,pixel61,pixel62,pixel63,pixel64,pixel65,pixel66,pixel67,pixel68,pixel69,pixel70,pixel71,pixel72,pixel73,pixel74,pixel75,pixel76,pixel77,pixel78,pixel79,pixel80,pixel81,pixel82,pixel83,pixel84,pixel85,pixel86,pixel87,pixel88,pixel89,pixel90,pixel91,pixel92,pixel93,pixel94,pixel95,pixel96,pixel97,pixel98,pixel99,pixel100,pixel101,pixel102,pixel103,pixel104,pixel105,pixel106,pixel107,pixel108,pixel109,pixel110,pixel111,pixel112,pixel113,pixel114,pixel115,pixel116,pixel117,pixel118,pixel119,pixel120,pixel121,pixel122,pixel123,pixel124,pixel125,pixel126,pixel127,pixel128,pixel129,pixel130,pixel131,pixel132,pixel133,pixel134,pixel135,pixel136,pixel137,pixel138,pixel139,pixel140,pixel141,pixel142,pixel143,pixel144,pixel145,pixel146,pixel147,pixel148,pixel149,pixel150,pixel151,pixel152,pixel153,pixel154,pixel155,pixel156,pixel157,pixel158,pixel159,pixel160,pixel161,pixel162,pixel163,pixel164,pixel165,pixel166,pixel167,pixel168,pixel169,pixel170,pixel171,pixel172,pixel173,pixel174,pixel175,pixel176,pixel177,pixel178,pixel179,pixel180,pixel181,pixel182,pixel183,pixel184,pixel185,pixel186,pixel187,pixel188,pixel189,pixel190,pixel191,pixel192,pixel193,pixel194,pixel195,pixel196,pixel197,pixel198,pixel199,pixel200,pixel201,pixel202,pixel203,pixel204,pixel205,pixel206,pixel207,pixel208,pixel209,pixel210,pixel211,pixel212,pixel213,pixel214,pixel215,pixel216,pixel217,pixel218,pixel219,pixel220,pixel221,pixel222,pixel223,pixel224,pixel225,pixel226,pixel227,pixel228,pixel229,pixel230,pixel231,pixel232,pixel233,pixel234,pixel235,pixel236,pixel237,pixel238,pixel239,pixel240,pixel241,pixel242,pixel243,pixel244,pixel245,pixel246,pixel247,pixel248,pixel249,pixel250,pixel251,pixel252,pixel253,pixel254,pixel255,pixel256,pixel257,pixel258,pixel259,pixel260,pixel261,pixel262,pixel263,pixel264,pixel265,pixel266,pixel267,pixel268,pixel269,pixel270,pixel271,pixel272,pixel273,pixel274,pixel275,pixel276,pixel277,pixel278,pixel279,pixel280,pixel281,pixel282,pixel283,pixel284,pixel285,pixel286,pixel287,pixel288,pixel289,pixel290,pixel291,pixel292,pixel293,pixel294,pixel295,pixel296,pixel297,pixel298,pixel299,pixel300,pixel301,pixel302,pixel303,pixel304,pixel305,pixel306,pixel307,pixel308,pixel309,pixel310,pixel311,pixel312,pixel313,pixel314,pixel315,pixel316,pixel317,pixel318,pixel319,pixel320,pixel321,pixel322,pixel323,pixel324,pixel325,pixel326,pixel327,pixel328,pixel329,pixel330,pixel331,pixel332,pixel333,pixel334,pixel335,pixel336,pixel337,pixel338,pixel339,pixel340,pixel341,pixel342,pixel343,pixel344,pixel345,pixel346,pixel347,pixel348,pixel349,pixel350,pixel351,pixel352,pixel353,pixel354,pixel355,pixel356,pixel357,pixel358,pixel359,pixel360,pixel361,pixel362,pixel363,pixel364,pixel365,pixel366,pixel367,pixel368,pixel369,pixel370,pixel371,pixel372,pixel373,pixel374,pixel375,pixel376,pixel377,pixel378,pixel379,pixel380,pixel381,pixel382,pixel383,pixel384,pixel385,pixel386,pixel387,pixel388,pixel389,pixel390,pixel391,pixel392,pixel393,pixel394,pixel395,pixel396,pixel397,pixel398,pixel399,pixel400,pixel401,pixel402,pixel403,pixel404,pixel405,pixel406,pi
xel407,pixel408,pixel409,pixel410,pixel411,pixel412,pixel413,pixel414,pixel415,pixel416,pixel417,pixel418,pixel419,pixel420,pixel421,pixel422,pixel423,pixel424,pixel425,pixel426,pixel427,pixel428,pixel429,pixel430,pixel431,pixel432,pixel433,pixel434,pixel435,pixel436,pixel437,pixel438,pixel439,pixel440,pixel441,pixel442,pixel443,pixel444,pixel445,pixel446,pixel447,pixel448,pixel449,pixel450,pixel451,pixel452,pixel453,pixel454,pixel455,pixel456,pixel457,pixel458,pixel459,pixel460,pixel461,pixel462,pixel463,pixel464,pixel465,pixel466,pixel467,pixel468,pixel469,pixel470,pixel471,pixel472,pixel473,pixel474,pixel475,pixel476,pixel477,pixel478,pixel479,pixel480,pixel481,pixel482,pixel483,pixel484,pixel485,pixel486,pixel487,pixel488,pixel489,pixel490,pixel491,pixel492,pixel493,pixel494,pixel495,pixel496,pixel497,pixel498,pixel499,pixel500,pixel501,pixel502,pixel503,pixel504,pixel505,pixel506,pixel507,pixel508,pixel509,pixel510,pixel511,pixel512,pixel513,pixel514,pixel515,pixel516,pixel517,pixel518,pixel519,pixel520,pixel521,pixel522,pixel523,pixel524,pixel525,pixel526,pixel527,pixel528,pixel529,pixel530,pixel531,pixel532,pixel533,pixel534,pixel535,pixel536,pixel537,pixel538,pixel539,pixel540,pixel541,pixel542,pixel543,pixel544,pixel545,pixel546,pixel547,pixel548,pixel549,pixel550,pixel551,pixel552,pixel553,pixel554,pixel555,pixel556,pixel557,pixel558,pixel559,pixel560,pixel561,pixel562,pixel563,pixel564,pixel565,pixel566,pixel567,pixel568,pixel569,pixel570,pixel571,pixel572,pixel573,pixel574,pixel575,pixel576,pixel577,pixel578,pixel579,pixel580,pixel581,pixel582,pixel583,pixel584,pixel585,pixel586,pixel587,pixel588,pixel589,pixel590,pixel591,pixel592,pixel593,pixel594,pixel595,pixel596,pixel597,pixel598,pixel599,pixel600,pixel601,pixel602,pixel603,pixel604,pixel605,pixel606,pixel607,pixel608,pixel609,pixel610,pixel611,pixel612,pixel613,pixel614,pixel615,pixel616,pixel617,pixel618,pixel619,pixel620,pixel621,pixel622,pixel623,pixel624,pixel625,pixel626,pixel627,pixel628,pixel629,pixel630,pixel631,pixel632,pixel633,pixel634,pixel635,pixel636,pixel637,pixel638,pixel639,pixel640,pixel641,pixel642,pixel643,pixel644,pixel645,pixel646,pixel647,pixel648,pixel649,pixel650,pixel651,pixel652,pixel653,pixel654,pixel655,pixel656,pixel657,pixel658,pixel659,pixel660,pixel661,pixel662,pixel663,pixel664,pixel665,pixel666,pixel667,pixel668,pixel669,pixel670,pixel671,pixel672,pixel673,pixel674,pixel675,pixel676,pixel677,pixel678,pixel679,pixel680,pixel681,pixel682,pixel683,pixel684,pixel685,pixel686,pixel687,pixel688,pixel689,pixel690,pixel691,pixel692,pixel693,pixel694,pixel695,pixel696,pixel697,pixel698,pixel699,pixel700,pixel701,pixel702,pixel703,pixel704,pixel705,pixel706,pixel707,pixel708,pixel709,pixel710,pixel711,pixel712,pixel713,pixel714,pixel715,pixel716,pixel717,pixel718,pixel719,pixel720,pixel721,pixel722,pixel723,pixel724,pixel725,pixel726,pixel727,pixel728,pixel729,pixel730,pixel731,pixel732,pixel733,pixel734,pixel735,pixel736,pixel737,pixel738,pixel739,pixel740,pixel741,pixel742,pixel743,pixel744,pixel745,pixel746,pixel747,pixel748,pixel749,pixel750,pixel751,pixel752,pixel753,pixel754,pixel755,pixel756,pixel757,pixel758,pixel759,pixel760,pixel761,pixel762,pixel763,pixel764,pixel765,pixel766,pixel767,pixel768,pixel769,pixel770,pixel771,pixel772,pixel773,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783,pixel784\n" 4 | 5 | with open(sys.argv[1], "r") as fp, open(sys.argv[2], "w") as op: 6 | for each in fp: 7 | sepped = each.strip().split("\t") # Split into label, 
image 8 | to_write += sepped[1].strip() + "," # Label 9 | to_write += sepped[0].strip().replace("[", "").replace("]", "").replace(" ", "").replace(".0", "").strip() + "\n" # Remove parentheses and decimals 10 | op.write(to_write) 11 | -------------------------------------------------------------------------------- /experiments/fashion_mnist/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import sys 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | from skimage.transform import resize 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import Dataset, DataLoader 10 | from torchvision import models 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | import time 13 | from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score 14 | 15 | train_csv = pd.read_csv(sys.argv[1]) 16 | test_csv = pd.read_csv(sys.argv[2]) 17 | valid_csv = pd.read_csv(sys.argv[3]) 18 | 19 | class FashionDataset(Dataset): 20 | """Class to build a dataset from FashionMNIST using Pytorch class Dataset.""" 21 | 22 | def __init__(self, data, transform=None): 23 | self.fashion_MNIST = list(data.values) 24 | self.transform = transform 25 | 26 | label = [] 27 | image = [] 28 | 29 | for i in self.fashion_MNIST: 30 | # first column is of labels. 31 | label.append(i[0]) 32 | image.append(i[1:]) 33 | self.labels = np.asarray(label) 34 | m = np.mean(image) 35 | std = np.std(image) 36 | # Dimension of Images = 28 * 28 * 1. where height = width = 28 and color_channels = 1. 37 | self.images = np.asarray((np.array(image)-m)/std).reshape(-1, 28, 28, 1).astype('float32') 38 | 39 | def __getitem__(self, index): 40 | label = self.labels[index] 41 | image = self.images[index] 42 | image = self.transform(resize(image, (224,224))/255) 43 | 44 | return image, label 45 | 46 | def __len__(self): 47 | return len(self.images) 48 | 49 | 50 | class MnistResNet(nn.Module): 51 | """Class that edits Resnet50 to do FashionMNIST classification.""" 52 | 53 | def __init__(self): 54 | super(MnistResNet, self).__init__() 55 | 56 | # Load a pretrained resnet model from torchvision.models in Pytorch 57 | self.model = models.resnet50(pretrained=True) 58 | 59 | # Change the input layer to take Grayscale image, instead of RGB images. 
60 | # Hence in_channels is set as 1 61 | self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 62 | 63 | # Change the output layer to output 10 classes instead of 1000 classes 64 | num_ftrs = self.model.fc.in_features 65 | self.model.fc = nn.Linear(num_ftrs, 10) 66 | 67 | def forward(self, x): 68 | return self.model(x) 69 | 70 | 71 | my_resnet = MnistResNet() 72 | 73 | 74 | def calculate_metric(metric_fn, true_y, pred_y): 75 | """Calculate a metric on the outputs and ground truth 76 | :param metric_fn: metric function to calculate 77 | :param true_y: ground truth labels 78 | :param pred_y: predicted labels 79 | :return: the output of the applie dmetric_fn 80 | """ 81 | if metric_fn == accuracy_score: 82 | return metric_fn(true_y, pred_y) 83 | else: 84 | return metric_fn(true_y, pred_y, average="macro") 85 | 86 | 87 | def print_scores(p, r, f1, a, batch_size): 88 | """Print the P/R/F1 and Accuracy in a readable format 89 | :param p: precision 90 | :param r: recall 91 | :param f1: F1 between precision and recall 92 | :param a: accuracy 93 | :param batch_size: batch size used 94 | """ 95 | for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)): 96 | print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}") 97 | 98 | 99 | # Transform data into Tensor that has a range from 0 to 1 100 | train_set = FashionDataset(train_csv, transform=transforms.Compose([transforms.ToTensor()])) 101 | test_set = FashionDataset(test_csv, transform=transforms.Compose([transforms.ToTensor()])) 102 | valid_set = FashionDataset(valid_csv, transform=transforms.Compose([transforms.ToTensor()])) 103 | 104 | train_loader = DataLoader(train_set, batch_size=100, shuffle=True) 105 | test_loader = DataLoader(test_set, batch_size=100, shuffle=True) 106 | valid_loader = DataLoader(valid_set, batch_size=100, shuffle=True) 107 | 108 | 109 | # model 110 | model = MnistResNet().to(device) 111 | 112 | # params 113 | epochs = 5 114 | batch_size = 100 115 | 116 | 117 | # loss function and optimizer 118 | loss_function = nn.CrossEntropyLoss() 119 | 120 | # optimizer 121 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 122 | 123 | start_ts = time.time() 124 | 125 | losses = [] 126 | batches = len(train_loader) 127 | val_batches = len(valid_loader) 128 | test_batches = len(test_loader) 129 | 130 | # loop for every epoch (training + evaluation) 131 | for epoch in range(epochs): 132 | total_loss = 0 133 | 134 | # progress bar (works in Jupyter notebook too!) 
135 | progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches) 136 | 137 | # ----------------- TRAINING -------------------- 138 | # set model to training 139 | model.train() 140 | 141 | for i, data in progress: 142 | X, y = data[0].to(device), data[1].to(device) 143 | 144 | # training step for single batch 145 | model.zero_grad() 146 | outputs = model(X) 147 | loss = loss_function(outputs, y) 148 | loss.backward() 149 | optimizer.step() 150 | 151 | # getting training quality data 152 | current_loss = loss.item() 153 | total_loss += current_loss 154 | 155 | # updating progress bar 156 | progress.set_description("Loss: {:.4f}".format(total_loss/(i+1))) 157 | 158 | # releasing unnecessary memory in GPU 159 | if torch.cuda.is_available(): 160 | torch.cuda.empty_cache() 161 | 162 | # ----------------- VALIDATION ----------------- 163 | val_losses = 0 164 | precision, recall, f1, accuracy = [], [], [], [] 165 | 166 | # set model to evaluating (testing) 167 | model.eval() 168 | with torch.no_grad(): 169 | for i, data in enumerate(valid_loader): 170 | X, y = data[0].to(device), data[1].to(device) 171 | 172 | outputs = model(X) # this gets the prediction from the network 173 | 174 | val_losses += loss_function(outputs, y) 175 | 176 | predicted_classes = torch.max(outputs, 1)[1] # get class from network's prediction 177 | 178 | # calculate P/R/F1/A metrics for batch 179 | for acc, metric in zip((precision, recall, f1, accuracy), 180 | (precision_score, recall_score, f1_score, accuracy_score)): 181 | acc.append( 182 | calculate_metric(metric, y.cpu(), predicted_classes.cpu()) 183 | ) 184 | test_losses = 0 185 | precision2, recall2, f12, accuracy2 = [], [], [], [] 186 | for i, data in enumerate(test_loader): 187 | X, y = data[0].to(device), data[1].to(device) 188 | 189 | outputs = model(X) # this get's the prediction from the network 190 | 191 | test_losses += loss_function(outputs, y) 192 | 193 | predicted_classes = torch.max(outputs, 1)[1] # get class from network's prediction 194 | 195 | # calculate P/R/F1/A metrics for batch 196 | for acc, metric in zip((precision2, recall2, f12, accuracy2), 197 | (precision_score, recall_score, f1_score, accuracy_score)): 198 | acc.append( 199 | calculate_metric(metric, y.cpu(), predicted_classes.cpu()) 200 | ) 201 | 202 | print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}, test loss: {test_losses/test_batches}") 203 | print_scores(precision, recall, f1, accuracy, val_batches) 204 | print_scores(precision2, recall2, f12, accuracy2, test_batches) 205 | losses.append(total_loss/batches) # for plotting learning curve 206 | print(f"Training time: {time.time()-start_ts}s") 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | -------------------------------------------------------------------------------- /experiments/speller/calculate_scores.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | # Preprocess hypotheses and split into words 4 | with open(sys.argv[1], 'r') as fp: 5 | hypotheses = [each.strip().split(" ") for each in fp.readlines()] 6 | 7 | # Preprocess targets and split into words 8 | with open(sys.argv[2], 'r') as fp: 9 | targets = [each.strip().split(" ") for each in fp.readlines()] 10 | 11 | # Preprocess sources and split into words 12 | with open(sys.argv[3], 'r') as fp: 13 | sources = [each.strip().split(" ") for each in fp.readlines()] 14 | 15 | acc_numerator = 0 16 | acc_denominator = 0 
17 | corr_numerator = 0 18 | corr_denominator = 0 19 | len_uneven = 0 20 | for i in range(len(hypotheses)): 21 | hypothesis = hypotheses[i] 22 | target = targets[i] 23 | source = sources[i] 24 | 25 | # Skip mismatches 26 | if len(hypothesis) != len(target): 27 | len_uneven += 1 28 | continue 29 | 30 | # Compute word-level correction rate 31 | for x in range(len(target)): 32 | if len(target) == len(source) and target[x] != source[x]: 33 | corr_denominator += 1 34 | if source[x] != hypothesis[x]: 35 | corr_numerator += 1 36 | 37 | # Compute word-level accuracy 38 | for x in range(len(target)): 39 | if target[x] == hypothesis[x]: 40 | acc_numerator += 1 41 | acc_denominator += 1 42 | 43 | print("Mismatch: " + str(len_uneven)) 44 | print("Word-Level Accuracy: " + str(acc_numerator / acc_denominator)) 45 | print("Correction Rate: " + str(corr_numerator / corr_denominator)) 46 | 47 | 48 | -------------------------------------------------------------------------------- /images/process.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/images/process.gif -------------------------------------------------------------------------------- /images/process_once.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/images/process_once.gif -------------------------------------------------------------------------------- /images/readme_ex1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/images/readme_ex1.png -------------------------------------------------------------------------------- /images/readme_ex2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/images/readme_ex2.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="grad-info-opt", 8 | version="0.1.2", 9 | author="Dante Everaert", 10 | author_email="dante.everaert@berkeley.edu", 11 | description="Implementation of Gradient Information Optimization for efficient and scalable training data selection", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/daeveraert/gradient-information-optimization", 15 | project_urls={ 16 | "Bug Tracker": "https://github.com/daeveraert/gradient-information-optimization/issues", 17 | }, 18 | classifiers=[ 19 | "Programming Language :: Python :: 3", 20 | "License :: OSI Approved :: Apache Software License", 21 | "Operating System :: OS Independent", 22 | ], 23 | package_dir={"": "src"}, 24 | packages=setuptools.find_packages(where="src"), 25 | python_requires=">=3.6", 26 | install_requires=[ 27 | 'jax>=0.3.25', 28 | 'pyspark>=2.4.8', 29 | 'numpy>=1.21.6', 30 | 'sentence_transformers>=2.2.2', 31 | 'jaxlib>=0.3.2', 32 | 'pandas>=1.0.5'] 33 | ) 
-------------------------------------------------------------------------------- /src/GIO/GIOKL.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import grad 3 | import random 4 | import jax 5 | 6 | import numpy as np 7 | import pyspark.sql.functions as F 8 | 9 | from pyspark.ml.clustering import KMeans 10 | from pyspark.sql import SparkSession 11 | from pyspark.sql.types import * 12 | from .GIO_super import GIO_super 13 | 14 | 15 | class GIOKL(GIO_super): 16 | def __init__(self, uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768): 17 | super().__init__() 18 | self.spark = SparkSession.builder.getOrCreate() 19 | self.uniform_low = uniform_low 20 | self.uniform_high = uniform_high 21 | self.uniform_start_size = uniform_start_size 22 | self.dim = dim 23 | self.random_init = False 24 | 25 | def _get_nearest(self, sample, point): 26 | """Euclidean distance from point to it's nearest point in sample. 27 | :param sample: a set of points to compute the nearest distance to 28 | :param point: point to retrieve the nearest point in sample from 29 | :return: the index of the nearest point 30 | """ 31 | norms = jnp.linalg.norm(sample - point, axis=1) 32 | return jnp.argsort(norms)[0] 33 | 34 | def _knn(self, x, y, k, last_only, discard_nearest, avg): 35 | """Find k_neighbors-nearest neighbor distances from y for each example in a minibatch x. 36 | :param x: tensor of shape [N_1, D] 37 | :param y: tensor of shape [N_2, D] 38 | :param k: the (k_neighbors+1):th nearest neighbor 39 | :param last_only: use only the last knn vs. all of them 40 | :param discard_nearest: 41 | :return: knn distances of shape [N, k_neighbors] or [N, 1] if last_only 42 | """ 43 | 44 | dist_x = jnp.sum((x ** 2), axis=-1)[:, jnp.newaxis] 45 | dist_y = jnp.sum((y ** 2), axis=-1)[:, jnp.newaxis].T 46 | cross = - 2 * jnp.matmul(x, y.T) 47 | distmat = dist_x + cross + dist_y 48 | distmat = jnp.clip(distmat, 1e-10, 1e+20) 49 | 50 | if discard_nearest: 51 | if not avg: 52 | knn, _ = jax.lax.top_k(-distmat, k + 1) 53 | else: 54 | knn = -jnp.sort(distmat) 55 | knn = knn[:, 1:] 56 | else: 57 | knn = -distmat 58 | 59 | if last_only: 60 | knn = knn[:, -1:] 61 | 62 | return jnp.sqrt(-knn) 63 | 64 | def _kl_divergence_knn(self, x, y, k, eps, discard_nearest_for_xy): 65 | """KL divergence estimator for D(x~p || y~q). 66 | :param x: x~p 67 | :param y: y~q 68 | :param k: kth nearest neighbor 69 | :param discard_nearest_for_xy: discard nearest in the xy calculation 70 | :param eps: small epsilon to pass to log 71 | :return: scalar 72 | """ 73 | n, d = x.shape 74 | m, _ = y.shape 75 | nns_xx = self._knn(x, x, k=k, last_only=True, discard_nearest=True, avg=False) 76 | nns_xy = self._knn(x, y, k=m, last_only=False, discard_nearest=discard_nearest_for_xy, avg=discard_nearest_for_xy) 77 | 78 | divergence = jnp.mean(d*jnp.log(nns_xy + eps) - d*jnp.log(nns_xx + eps)) + jnp.mean(jnp.log((k*m)/(jnp.arange(1, m+1) * (n-1)))) 79 | 80 | return divergence 81 | 82 | def calculate_statistical_distance(self, x, y, k=5, eps=1e-8, discard_nearest_for_xy=False): 83 | """Calculate statistical distance d(p,q) based on x~p and y~q. 
84 | :param x: x~p 85 | :param y: y~q 86 | :param k: kth nearest neighbor 87 | :param eps: small epsilon to pass to log 88 | :return: scalar 89 | """ 90 | return self._kl_divergence_knn(x, y, k, eps, discard_nearest_for_xy) 91 | 92 | def gradient_descend(self, X, W, v, scaling_factor, max_iterations, lr=0.01, k=5, discard_nearest_for_xy=False): 93 | """Perform gradient descent on the statistical distance between X and W+v 94 | :param X: target data 95 | :param W: current selected data 96 | :param v: initial v 97 | :param scaling_factor: scale the gradient 98 | :param max_iterations: iterations in the gradient descent 99 | :param lr: learning rate 100 | :param k: kth nearest neighbor 101 | :param discard_nearest_for_xy: discard nearest in the xy calculation 102 | :return: the optimized vector v 103 | """ 104 | i = 0 105 | while i < max_iterations: 106 | gradient = grad(lambda v: self.calculate_statistical_distance(X, jnp.concatenate((W, v[jnp.newaxis, :])), k, discard_nearest_for_xy=discard_nearest_for_xy))(v) 107 | v = v - lr * scaling_factor * gradient 108 | i += 1 109 | return v 110 | 111 | def _get_uniform_start(self, do_normalize): 112 | """Get a uniform start for D. 113 | :return: jnp array of uniform points 114 | """ 115 | def normalize(v): 116 | norm = np.linalg.norm(v) 117 | if norm == 0: 118 | return v 119 | return v / norm 120 | if do_normalize: 121 | return jnp.array([normalize(each) for each in np.random.uniform(low=self.uniform_low,high=self.uniform_high,size=(self.uniform_start_size,self.dim))]) 122 | else: 123 | return jnp.array([each for each in np.random.uniform(low=self.uniform_low,high=self.uniform_high,size=(self.uniform_start_size,self.dim))]) 124 | 125 | def fit(self, train, X, D=None, k=5, max_iter=100, stop_criterion="increase", min_difference=0, resets_allowed=False, max_resets=2, max_data_size=1, min_kl=0, max_sequential_increases=3, random_init_pct=0, random_restart_prob=0, scale_factor="auto", v_init='mean', grad_desc_iter=50, discard_nearest_for_xy=False, normalize=True, lr=0.01): 126 | """Perform GIO 127 | :param train: training data 128 | :param X: target data 129 | :param D: initial data 130 | :param k: kth nearest neighbor 131 | :param max_iter: max iterations for the algorithm 132 | :param stop_criterion: a string for the stopping criterion, one of the following: 'increase', 'max_resets', 'min_difference', 'sequential_increase_tolerance', 'min_kl', 'data_size' 133 | :param min_difference: the minimum difference between prior and current KL divergence for 'min_difference' stop criterion 134 | :param resets_allowed: whether resetting G to the full train set is allowed when the KL divergence increases (this allows the algorithm to pick duplicates). Must be set to True if the stop criterion is 'max_resets' 135 | :param max_resets: the number of resets allowed for the 'max_resets' stop criterion 136 | :param max_data_size: the maximum size of data for the 'data_size' stop criterion, as a percentage 137 | :param min_kl: the minimum kl divergence for the 'min_kl' stop criterion 138 | :param max_sequential_increases: the maximum number of sequential KL divergence increases for the 'sequential_increase_tolerance' stop criterion 139 | :param random_init_pct: the percent of training data to initialize the algorithm from 140 | :param random_restart_prob: probability to extend the gradient descent iterations by 3x to find potentially better extrema. 
Higher values come at the cost of efficiency 141 | :param scale_factor: factor to scale the gradient by or 'auto' 142 | :param v_init: how to initialize v in gradients descent, one of the following: 'mean', 'prev_opt', 'jump' 143 | :param grad_desc_iter: the number of iterations in gradient descent 144 | :param discard_nearest_for_xy: discard nearest in the xy calculation 145 | :param lr: Learning rate for gradient descent 146 | :return: selected data, kl divergences, (v, scale_factor, just_reset, num_resets, increases, adder, kl_divs) 147 | """ 148 | if not random_init_pct and D is None: 149 | W = self._get_uniform_start(normalize) 150 | self.random_init = True 151 | elif D is None: 152 | amount = int(random_init_pct * len(train)) 153 | W = jnp.array(random.sample(train.tolist(), amount)) 154 | else: 155 | W = D[:] 156 | 157 | kl_dist_prev = self.calculate_statistical_distance(X, W, k, discard_nearest_for_xy=discard_nearest_for_xy) 158 | 159 | print("Starting KL: " + str(kl_dist_prev)) 160 | if v_init == 'mean' or v_init == 'prev_opt': 161 | v = jnp.mean(X, axis=0) 162 | elif v_init == 'jump': 163 | v = jnp.array(random.sample(X.tolist(), 1)).squeeze() 164 | adder = train[:] 165 | kl_divs = [] 166 | 167 | scale_factor = jnp.linalg.norm(v)/jnp.linalg.norm(grad(lambda v: self.calculate_statistical_distance(X, jnp.concatenate((W, v[jnp.newaxis, :])), k, discard_nearest_for_xy=discard_nearest_for_xy))(v)) if scale_factor == "auto" else scale_factor 168 | 169 | i = 0 170 | just_reset = False 171 | num_resets = 0 172 | total_iter = 0 173 | increases = 0 174 | while True: 175 | # Warmup, reset or random restart 176 | if i == 0 or just_reset or random.random() < random_restart_prob: 177 | v = self.gradient_descend(X, W, v, scale_factor, grad_desc_iter * 3, lr=lr, k=k, discard_nearest_for_xy=discard_nearest_for_xy) 178 | else: 179 | v = self.gradient_descend(X, W, v, scale_factor, grad_desc_iter, lr=lr, k=k, discard_nearest_for_xy=discard_nearest_for_xy) 180 | idx = self._get_nearest(v, adder) 181 | minvals = adder[idx] 182 | adder = jnp.delete(adder, idx, axis=0) 183 | 184 | W_tmp = jnp.concatenate((W, jnp.array(minvals)[jnp.newaxis, :])) 185 | 186 | kl_dist = self.calculate_statistical_distance(X, W_tmp, k, discard_nearest_for_xy=discard_nearest_for_xy) 187 | print("KL Divergence at iteration " + str(i) + ": " + str(kl_dist)) 188 | 189 | # STOPPING CRITERIA 190 | if total_iter > max_iter: 191 | break 192 | 193 | if v_init == 'mean': 194 | v = jnp.mean(X, axis=0) 195 | elif v_init == 'jump': 196 | v = jnp.array(random.sample(X.tolist(), 1)).squeeze() 197 | 198 | adder, i, just_reset, stop, v, increases, num_resets = self._test_stop_criterion(v_init, stop_criterion, kl_dist, kl_dist_prev, num_resets, max_resets, min_difference, increases, max_sequential_increases, min_kl, max_data_size, train, X, i, v, just_reset, resets_allowed, adder) 199 | 200 | if stop: 201 | break 202 | if not just_reset: 203 | W = W_tmp 204 | kl_divs += [kl_dist] 205 | kl_dist_prev = kl_dist 206 | i += 1 207 | total_iter += 1 208 | return W, kl_divs, (v, scale_factor, just_reset, num_resets, increases, adder, kl_divs) 209 | 210 | def _test_stop_criterion(self, v_init, stop_criterion, kl_dist, kl_dist_prev, num_resets, max_resets, min_difference, increases, max_sequential_increases, min_kl, max_data_size, train, X, i, v, just_reset, resets_allowed, adder): 211 | stop = False 212 | if stop_criterion == "increase" and kl_dist - kl_dist_prev > 0: 213 | stop = True 214 | elif stop_criterion == "max_resets" and kl_dist - 
kl_dist_prev > 0 and num_resets == max_resets: 215 | stop = True 216 | elif stop_criterion == "min_difference" and kl_dist_prev - kl_dist < min_difference: 217 | stop = True 218 | elif stop_criterion == 'sequential_increase_tolerance' and kl_dist - kl_dist_prev > 0 and increases == max_sequential_increases: 219 | stop = True 220 | elif stop_criterion == 'min_kl' and kl_dist < min_kl: 221 | stop = True 222 | elif stop_criterion == 'data_size' and i > int(max_data_size * len(train)): 223 | stop = True 224 | if stop: 225 | if just_reset: 226 | increases += 1 227 | if resets_allowed and num_resets < max_resets: 228 | num_resets += 1 229 | if v_init == 'prev_opt': 230 | v = jnp.mean(X, axis=0) 231 | print("KL Div Increase, Resetting G") 232 | adder = train[:] 233 | i -= 1 234 | stop = False 235 | just_reset = True 236 | else: 237 | just_reset = False 238 | increases = 0 239 | return adder, i, just_reset, stop, v, increases, num_resets 240 | 241 | def _return_kmeans(self, df, k, rseed): 242 | """Use Spark to perform K-Means 243 | :param df: dataframe to perform K-Means with 244 | :param k: number of clusters to compute 245 | :param rseed: random seed 246 | :return: k-means model, transformed df 247 | """ 248 | kmeans = KMeans().setK(k).setSeed(rseed) 249 | model = kmeans.fit(df.select("features")) 250 | transformed_df = model.transform(df) 251 | return model, transformed_df 252 | 253 | def quantize(self, df_train, df_x, k=1500, rseed='auto', rseed1=234, rseed2=456): 254 | """Use Spark to perform K-Means 255 | :param df_train: train dataframe to quantize 256 | :param df_x: target dataframe to quantize 257 | :param k: number of clusters to compute 258 | :param rseed: 'auto' or 'manual' 259 | :param rseed1: first random seed 260 | :param rseed2: second random seed 261 | :return: k-means model, transformed df 262 | """ 263 | if rseed == 'auto': 264 | rseed1 = random.randint(-1000,1000) 265 | rseed2 = random.randint(-1000,1000) 266 | model_train, transformed_train = self._return_kmeans(df_train, k, rseed1) 267 | model_X, transformed_X = self._return_kmeans(df_x, k, rseed2) 268 | return model_train, model_X, transformed_train, transformed_X 269 | 270 | def read_data_from_csv(self, path, path_X, delim="\t"): 271 | """Read in and process data stored in a csv. Data must be of the format: _c0, _c1, _c2 where _c2 contains the 272 | string representation of the vector, like "[0.1, 0.23, 0.45 ...]" 273 | :param path: path to training data 274 | :param path_X: path to target data 275 | :param delim: delimiter for csv file 276 | :return: train df, target df 277 | """ 278 | new_schema = ArrayType(DoubleType(), containsNull=False) 279 | udf_json_to_arr = F.udf(lambda x: x, new_schema) 280 | 281 | df_read = self.spark.read.option("delimiter", delim).csv(path) 282 | df_with_embeddings = df_read.withColumn("features", udf_json_to_arr(F.from_json(F.col("_c2"), "array"))) 283 | 284 | df_X_read = self.spark.read.option("delimiter", delim).csv(path_X) 285 | df_X_with_embeddings = df_X_read.withColumn("features", udf_json_to_arr(F.from_json(F.col("_c2"), "array"))) 286 | 287 | return df_with_embeddings, df_X_with_embeddings 288 | 289 | def read_data_from_parquet(self, path, path_X): 290 | """Read in and process data stored in a parquet format. Data must contain a column "features" that stores an array 291 | of the vectors and be non-nullable. 
292 | :param path: path to training data 293 | :param path_X: path to target data 294 | :return: train df, target df 295 | """ 296 | new_schema = ArrayType(DoubleType(), containsNull=False) 297 | udf_no_null = F.udf(lambda x: x, new_schema) 298 | 299 | df_with_embeddings = self.spark.read.parquet(path).withColumn("features", udf_no_null(F.col("features"))) 300 | df_X_with_embeddings = self.spark.read.parquet(path_X).withColumn("features", udf_no_null(F.col("features"))) 301 | return df_with_embeddings, df_X_with_embeddings 302 | 303 | def explode(self, chosen_centroids, kmeans_transformed_df, kmeans_centroids_df): 304 | """Explode the selected centroids back to the full original data by matching each chosen centroid to its cluster. 305 | :param chosen_centroids: centroids selected by fit 306 | :param kmeans_transformed_df: train dataframe with cluster assignments, as returned by quantize 307 | :param kmeans_centroids_df: dataframe of (id, centroid) rows for the train clusters 308 | :return: dataframe of all original rows belonging to the selected centroids 309 | """ 310 | pre_existing_centroids = jnp.array([f[1] for f in sorted([[each[0], each[1]] for each in kmeans_centroids_df.collect()], key=lambda x: x[0])]) 311 | paired = [] 312 | for each in chosen_centroids: 313 | for i, x in enumerate(pre_existing_centroids.tolist()): 314 | if each.tolist() in [x]: 315 | paired += [i] 316 | print("Found " + str(len(paired)) + " centroids out of " + str(len(chosen_centroids)) + " selected centroids") 317 | full_selections_df = self.spark.createDataFrame(data=[(i, each) for i, each in enumerate(paired)], schema=["i", "id"]).join(kmeans_transformed_df, F.col("id") == F.col("prediction")) 318 | return full_selections_df 319 | 320 | -------------------------------------------------------------------------------- /src/GIO/GIO_super.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import grad 3 | import random 4 | import jax 5 | 6 | import numpy as np 7 | import pyspark.sql.functions as F 8 | 9 | from pyspark.ml.clustering import KMeans 10 | from pyspark.sql import SparkSession 11 | from pyspark.sql.types import * 12 | 13 | 14 | class GIO_super: 15 | def __init__(self): 16 | pass 17 | 18 | def calculate_statistical_distance(self, x, y): 19 | pass 20 | 21 | def gradient_descend(self, X, W, v, factor, max_iterations, lr, *arg): 22 | pass 23 | 24 | def fit(self, train, X, *arg): 25 | pass 26 | 27 | def quantize(self, df_train, df_x, quantize_into): 28 | pass 29 | 30 | def _get_nearest(self, sample, point): 31 | pass 32 | 33 | def explode(self, chosen_centroids, kmeans_transformed_df, kmeans_centroids_df): 34 | pass 35 | -------------------------------------------------------------------------------- /src/GIO/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/src/GIO/__init__.py -------------------------------------------------------------------------------- /src/GIO/generate_text_embeddings.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer 2 | import time 3 | 4 | 5 | 6 | class GenerateEmbeddings: 7 | def __init__(self, model_name, device='cuda'): 8 | self.model = SentenceTransformer(model_name) 9 | self.device = device 10 | 11 | def generate_embeddings(self, input_file_path, output_file_path): 12 | """Generate Embeddings from a text file 13 | :param input_file_path: path to input text, one sentence per line 14 | :param output_file_path: path to desired output file 15 | 
""" 16 | print('Reading File...') 17 | with open(input_file_path, 'r') as fp: 18 | sentences = fp.readlines() 19 | print('Generating Embeddings... This May Take a While') 20 | start = time.time() 21 | embeddings = self.model.encode(sentences, device=self.device) 22 | end = time.time() 23 | 24 | print("Time Taken (s): " + str(end - start)) 25 | 26 | print("Writing Embeddings.. This May Take a While") 27 | with open(output_file_path, 'w') as op: 28 | for i, each in enumerate(embeddings): 29 | op.write(str(each.tolist()).strip() + "\n") 30 | -------------------------------------------------------------------------------- /src/grad_info_opt.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: grad-info-opt 3 | Version: 0.1.2 4 | Summary: Implementation of Gradient Information Optimization for efficient and scalable training data selection 5 | Home-page: https://github.com/daeveraert/gradient-information-optimization 6 | Author: Dante Everaert 7 | Author-email: dante.everaert@berkeley.edu 8 | License: UNKNOWN 9 | Project-URL: Bug Tracker, https://github.com/daeveraert/gradient-information-optimization/issues 10 | Platform: UNKNOWN 11 | Classifier: Programming Language :: Python :: 3 12 | Classifier: License :: OSI Approved :: Apache Software License 13 | Classifier: Operating System :: OS Independent 14 | Requires-Python: >=3.6 15 | Description-Content-Type: text/markdown 16 | License-File: LICENSE 17 | 18 | # GIO: Gradient Information Optimization 19 |
22 | 23 | GIO is a library that implements Gradient Information Optimization (GIO) at scale, from the paper GIO: Gradient Information Optimization for Training Dataset Selection. GIO is a data selection technique that can 24 | be used to select a subset of training data that gives similar or superior performance to a model trained on the full data. 25 | 26 | **Paper Abstract** 27 | 28 | It is often advantageous to train models on a subset of the available train examples, because the examples are of variable quality or because one would like to train with fewer examples, without sacrificing performance. We present Gradient Information Optimization (GIO), a scalable, task-agnostic approach to this data selection problem that requires only a small set of (unlabeled) examples representing a target distribution. GIO begins from a natural, information-theoretic objective that is intractable in practice. Our contribution is in showing that it can be made highly scalable through a simple relaxation of the objective and a highly efficient implementation. In experiments with machine translation, spelling correction, and image recognition, we show that GIO delivers outstanding results with very small train sets. These findings are robust to different representation models and hyperparameters for GIO itself. GIO is task- and domain-agnostic and can be applied out-of-the-box to new datasets and domains. 29 | 30 | 31 | **Features**: 32 | - GIO with quantization using K-means. 33 | - Sentence embedding script to generate embeddings from data to use in GIO 34 | 35 | 36 | ## Installation 37 | 38 | Installable via pip: 39 | ```bash 40 | pip install grad-info-opt 41 | ``` 42 | Or install directly from the repository: 43 | 44 | ```bash 45 | git clone git@github.com:daeveraert/gradient-information-optimization.git 46 | cd gradient-information-optimization 47 | pip install -e . 48 | ``` 49 | 50 | Direct installation will require you to install additional dependencies listed below. We welcome contributions to GIO. 51 | 52 | ## Requirements 53 | - `numpy>=1.21.6` 54 | - `jax>=0.3.25` 55 | - `pyspark>=2.4.8` 56 | - `sentence_transformers>=2.2.2` 57 | - `jaxlib>=0.3.2` 58 | - `pandas>=1.0.5` 59 | 60 | 61 | 62 | ## Quick Start 63 | **Note:** GIO uses a Spark context, or if it can't find one, it will create a local one. You may encounter a Spark error before the algorithm runs complaining it cannot find a free port. In this case, executing ```export SPARK_LOCAL_IP="127.0.0.1"``` should resolve the issue. 
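If you prefer to do this configuration from Python rather than the shell, a minimal sketch is below (an illustrative addition, not part of the original examples; it only relies on `GIOKL` reusing an existing session via `SparkSession.builder.getOrCreate()`):

```python
import os
from pyspark.sql import SparkSession

# Optional workaround for the "cannot find a free port" error,
# equivalent to the shell export above.
os.environ.setdefault("SPARK_LOCAL_IP", "127.0.0.1")

# Create (or reuse) a session before constructing GIOKL; GIOKL.__init__ calls
# SparkSession.builder.getOrCreate(), so it picks up this session and its settings.
spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.rpc.message.maxSize", "500")  # also useful later, before calling quantize
    .getOrCreate()
)
```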
64 | 65 | Here is a simple 2D demonstration of how to use GIO with visualization: 66 | ```python 67 | from GIO import GIOKL 68 | import numpy as np 69 | import jax.numpy as jnp 70 | import matplotlib.pyplot as plt 71 | 72 | # Create some data 73 | def getX(): 74 | mean = [3,4] 75 | cov = [[0.5,0],[0,0.5]] 76 | np.random.seed(1) 77 | x, y = np.random.multivariate_normal(mean, cov, 100).T 78 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 79 | 80 | def getXTest(): 81 | mean = [3,4] 82 | cov = [[0.5,0],[0,0.5]] 83 | np.random.seed(5) 84 | x, y = np.random.multivariate_normal(mean, cov, 100).T 85 | return jnp.array([[x[i],y[i]] for i in range(len(x))]) 86 | 87 | X = getX() 88 | X_test = getXTest() 89 | 90 | # Initialize class 91 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2) 92 | 93 | # Perform the Algorithm 94 | W, kl_divs, _ = gio_kl.fit(X_test, X, normalize=False) 95 | W = W[100:] # Remove the uniform start 96 | 97 | # Plot results 98 | plt.plot(kl_divs) 99 | plt.title("KL Divergence vs. Iterations") 100 | plt.xlabel("Iterations") 101 | plt.ylabel("KL Divergence") 102 | plt.show() 103 | plt.clf() 104 | plt.scatter([each[0] for each in W], [each[1] for each in W], label='Selected Data') 105 | plt.scatter([each[0] for each in X], [each[1] for each in X], label='Target Data') 106 | plt.title("Target Data and Selected Data") 107 | plt.xlabel("Dimension 1") 108 | plt.ylabel("Dimension 2") 109 | plt.legend() 110 | plt.show() 111 | ``` 112 |
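As a quick, optional sanity check (not in the original README; it assumes the `gio_kl`, `W`, `X` and `X_test` objects created by the script above), you can re-estimate the divergence between the target and the selected subset directly:

```python
# Re-estimate the KL divergence between the target X and the selected subset W
# (the uniform start was already removed above). Lower is better; this uses the
# same estimator that fit calls internally.
final_kl = gio_kl.calculate_statistical_distance(X, W)
print("Final KL divergence: " + str(final_kl))
print("Selected " + str(len(W)) + " of " + str(len(X_test)) + " candidate points")
```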
116 | 117 | Here is a more complex example for scale applications, reading and using a CSV that stores embeddings and data, using quantization-explosion, and Spark: 118 | ```python 119 | from GIO import GIOKL 120 | import jax.numpy as jnp 121 | import matplotlib.pyplot as plt 122 | import pyspark.sql.functions as F 123 | 124 | # Initialize class 125 | gio_kl = GIOKL.GIOKL(uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768) 126 | 127 | # Read data 128 | train_df, target_df = gio_kl.read_data_from_csv(PATH_TO_TRAIN, PATH_TO_TARGET) 129 | 130 | # Quantize data 131 | model_train, model_X, transformed_train, transformed_X = gio_kl.quantize(train_df, target_df) 132 | 133 | X = jnp.array(model_X.clusterCenters()) 134 | train = jnp.array(model_train.clusterCenters()) 135 | centroids_df = gio_kl.spark.createDataFrame(data=[(i, each.tolist()) for i, each in enumerate(model_train.clusterCenters())], schema=["id", "centroid"]) 136 | 137 | # Perform the Algorithm 138 | W, kl_divs, _ = gio_kl.fit(train, X, max_iter=300, stop_criterion='sequential_increase_tolerance', v_init='jump') 139 | W = W[20:] # Remove the uniform start 140 | 141 | # Explode back to original data and write resulting data 142 | full_selections_df = gio_kl.explode(W, transformed_train, centroids_df) 143 | full_selections_df.select(F.col("_c0"), F.col("_c1")).write.option("delimiter", "\t").csv(OUTPUT_PATH) 144 | 145 | 146 | # Plot results 147 | plt.plot(kl_divs) 148 | plt.title("KL Divergence vs. Iterations") 149 | plt.xlabel("Iterations") 150 | plt.ylabel("KL Divergence") 151 | plt.show() 152 | ``` 153 | **Note:** For quantization, Spark requires a large rpc message size. It is recommended to place ```gio_kl.spark.conf.set("spark.rpc.message.maxSize", "500")``` (or any large number) in the code before calling quantize, if the defaults haven't already been increased. 154 | 155 | ## Available Options 156 | `GIOKL.fit` takes the following arguments: 157 | - `train`: training data as a jnp array (jnp is almost identical to numpy) [M, D] shape 158 | - `X`: target data as a jnp array [N, D] shape 159 | - `D`: initial data as a jnp array, default None. Use None to initialize from 0 (uniform) or a subset of training data 160 | - `k`: kth nearest neighbor to use in the KL divergence estimation, default 5 161 | - `max_iter`: maximum iterations for the algorithm. One iteration adds one point (cluster) 162 | - `stop_criterion`: a string for the stopping criterion, one of the following: 'increase', 'max_resets', 'min_difference', 'sequential_increase_tolerance', 'min_kl', 'data_size'. Default is 'increase' 163 | - `min_difference`: the minimum difference between prior and current KL divergence for 'min_difference' stop criterion only. Default is 0 164 | - `resets_allowed`: whether resetting G to the full train set is allowed when the KL divergence increases (this allows the algorithm to pick duplicates). Must be set to True if the stop criterion is 'max_resets'. Default is False 165 | - `max_resets`: the number of resets allowed for the 'max_resets' stop criterion only (a reset resets G to the full train set and allows the algorithm to pick duplicates). Default is 2 166 | - `max_data_size`: the maximum size of data to be selected for the 'data_size' stop criterion only, as a percentage (of total data) between 0 and 1. Default is 1 167 | - `min_kl`: the minimum kl divergence for the 'min_kl' stop criterion only. 
Default is 0 168 | - `max_sequential_increases`: the maximum number of sequential KL divergence increases for the 'sequential_increase_tolerance' stop criterion only. Default is 3 169 | - `random_init_pct`: the percent of training data to initialize the algorithm from. Default is 0 170 | - `random_restart_prob`: probability at any given iteration to extend the gradient descent iterations by 3x, to find potentially better extrema. Higher values come at the cost of efficiency. Default is 0 171 | - `scale_factor`: factor to scale the gradient by, or 'auto'. Default is 'auto', which is recommended 172 | - `v_init`: how to initialize v in gradients descent, one of the following: 'mean', 'prev_opt', 'jump'. Default is 'mean' 173 | - `grad_desc_iter`: the number of iterations to use in gradient descent. Default is 50 174 | - `discard_nearest_for_xy`: discard nearest in the xy calculation of KL divergence, for use when X and the train set are the same, comes at the cost of efficiency. Default is False 175 | - `lr`: Learning rate for gradient descent. Default is 0.01 176 | 177 | ## Citing GIO 178 | If you use GIO in a publication, blog or software project, please cite the paper: 179 | ``` 180 | @misc{everaert2023gio, 181 | title={GIO: Gradient Information Optimization for Training Dataset Selection}, 182 | author={Dante Everaert and Christopher Potts}, 183 | year={2023}, 184 | eprint={2306.11670}, 185 | archivePrefix={arXiv}, 186 | primaryClass={cs.LG} 187 | } 188 | ``` 189 | 190 | 191 | -------------------------------------------------------------------------------- /src/grad_info_opt.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.py 4 | src/GIO/GIOKL.py 5 | src/GIO/GIO_super.py 6 | src/GIO/__init__.py 7 | src/GIO/generate_text_embeddings.py 8 | src/grad_info_opt.egg-info/PKG-INFO 9 | src/grad_info_opt.egg-info/SOURCES.txt 10 | src/grad_info_opt.egg-info/dependency_links.txt 11 | src/grad_info_opt.egg-info/requires.txt 12 | src/grad_info_opt.egg-info/top_level.txt -------------------------------------------------------------------------------- /src/grad_info_opt.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/grad_info_opt.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | jax>=0.3.25 2 | pyspark>=2.4.8 3 | numpy>=1.21.6 4 | sentence_transformers>=2.2.2 5 | jaxlib>=0.3.2 6 | pandas>=1.0.5 7 | -------------------------------------------------------------------------------- /src/grad_info_opt.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | GIO 2 | --------------------------------------------------------------------------------