├── .github
│   └── workflows
│       └── python-publish.yml
├── LICENSE
├── README.md
├── build
│   └── lib
│       └── GIO
│           ├── GIOKL.py
│           ├── GIO_super.py
│           ├── __init__.py
│           └── generate_text_embeddings.py
├── dist
│   └── grad_info_opt-0.1.2-py3-none-any.whl
├── experiments
│   ├── bm25_scripts
│   │   ├── get_less.py
│   │   ├── get_pairs.py
│   │   └── make_csv.py
│   ├── checks
│   │   ├── negative_consistency.py
│   │   ├── quantization_consistency.py
│   │   └── self_consistency.py
│   ├── fashion_mnist
│   │   ├── csv_to_image.py
│   │   ├── image_to_csv.py
│   │   └── train.py
│   └── speller
│       └── calculate_scores.py
├── images
│   ├── process.gif
│   ├── process_once.gif
│   ├── readme_ex1.png
│   └── readme_ex2.png
├── setup.py
└── src
    ├── GIO
    │   ├── GIOKL.py
    │   ├── GIO_super.py
    │   ├── __init__.py
    │   └── generate_text_embeddings.py
    └── grad_info_opt.egg-info
        ├── PKG-INFO
        ├── SOURCES.txt
        ├── dependency_links.txt
        ├── requires.txt
        └── top_level.txt
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GIO: Gradient Information Optimization
2 |
3 |
4 |
5 |
6 | GIO is a library that implements Gradient Information Optimization (GIO) at scale, from the paper GIO: Gradient Information Optimization for Training Dataset Selection. GIO is a data selection technique that can
7 | be used to select a subset of training data that gives similar or superior performance to a model trained on full data.
8 |
9 | **Paper Abstract**
10 |
11 | It is often advantageous to train models on a subset of the available train examples, because the examples are of variable quality or because one would like to train with fewer examples, without sacrificing performance. We present Gradient Information Optimization (GIO), a scalable, task-agnostic approach to this data selection problem that requires only a small set of (unlabeled) examples representing a target distribution. GIO begins from a natural, information-theoretic objective that is intractable in practice. Our contribution is in showing that it can be made highly scalable through a simple relaxation of the objective and a highly efficient implementation. In experiments with machine translation, spelling correction, and image recognition, we show that GIO delivers outstanding results with very small train sets. These findings are robust to different representation models and hyperparameters for GIO itself. GIO is task- and domain-agnostic and can be applied out-of-the-box to new datasets and domains.
12 |
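In terms of this implementation (a sketch, not the paper's derivation): each iteration runs gradient descent on a k-nearest-neighbor estimate of the KL divergence between the target set X and the selected set W augmented with a candidate point v, then adds the training point nearest to the optimized v:

$$v^{*} \approx \arg\min_{v}\; \hat{D}_{\mathrm{KL}}\!\left(p_X \,\middle\|\, p_{W \cup \{v\}}\right)$$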
13 |
14 | **Features**:
15 | - GIO with quantization using K-means.
16 | - Sentence embedding script to generate embeddings from data to use in GIO.
17 |
18 |
19 | ## Installation
20 |
21 | Installable via pip:
22 | ```bash
23 | pip install grad-info-opt
24 | ```
25 | Or install directly from the repository:
26 |
27 | ```bash
28 | git clone git@github.com:daeveraert/gradient-information-optimization.git
29 | cd gradient-information-optimization
30 | pip install -e .
31 | ```
32 |
33 | Direct installation requires the additional dependencies listed below. We welcome contributions to GIO.
34 |
35 | ## Requirements
36 | - `numpy>=1.21.6`
37 | - `jax>=0.3.25`
38 | - `pyspark>=2.4.8`
39 | - `sentence_transformers>=2.2.2`
40 | - `jaxlib>=0.3.2`
41 | - `pandas>=1.0.5`
42 |
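If installing directly from the repository, the dependencies above can be installed in one step (a sketch mirroring the version pins listed; adjust versions as needed):
```bash
pip install "numpy>=1.21.6" "jax>=0.3.25" "jaxlib>=0.3.2" "pyspark>=2.4.8" "sentence_transformers>=2.2.2" "pandas>=1.0.5"
```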
43 |
44 |
45 | ## Quick Start
46 | **Note:** GIO uses an existing Spark context, or creates a local one if none is found. Before the algorithm runs, you may encounter a Spark error complaining that it cannot find a free port. In this case, executing ```export SPARK_LOCAL_IP="127.0.0.1"``` should resolve the issue.
47 |
48 | Here is a simple 2D demonstration of how to use GIO with visualization:
49 | ```python
50 | from GIO import GIOKL
51 | import numpy as np
52 | import jax.numpy as jnp
53 | import matplotlib.pyplot as plt
54 |
55 | # Create some data
56 | def getX():
57 | mean = [3,4]
58 | cov = [[0.5,0],[0,0.5]]
59 | np.random.seed(1)
60 | x, y = np.random.multivariate_normal(mean, cov, 100).T
61 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
62 |
63 | def getXTest():
64 | mean = [3,4]
65 | cov = [[0.5,0],[0,0.5]]
66 | np.random.seed(5)
67 | x, y = np.random.multivariate_normal(mean, cov, 100).T
68 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
69 |
70 | X = getX()
71 | X_test = getXTest()
72 |
73 | # Initialize class
74 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2)
75 |
76 | # Perform the Algorithm
77 | W, kl_divs, _ = gio_kl.fit(X_test, X, normalize=False)
78 | W = W[100:] # Remove the uniform start
79 |
80 | # Plot results
81 | plt.plot(kl_divs)
82 | plt.title("KL Divergence vs. Iterations")
83 | plt.xlabel("Iterations")
84 | plt.ylabel("KL Divergence")
85 | plt.show()
86 | plt.clf()
87 | plt.scatter([each[0] for each in W], [each[1] for each in W], label='Selected Data')
88 | plt.scatter([each[0] for each in X], [each[1] for each in X], label='Target Data')
89 | plt.title("Target Data and Selected Data")
90 | plt.xlabel("Dimension 1")
91 | plt.ylabel("Dimension 2")
92 | plt.legend()
93 | plt.show()
94 | ```
95 |
96 |
97 |
98 |
99 |
100 | Here is a more complex example for at-scale applications: reading a CSV that stores embeddings and data, and using quantization-explosion with Spark:
101 | ```python
102 | from GIO import GIOKL
103 | import jax.numpy as jnp
104 | import matplotlib.pyplot as plt
105 | import pyspark.sql.functions as F
106 |
107 | # Initialize class
108 | gio_kl = GIOKL.GIOKL(uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768)
109 |
110 | # Read data
111 | train_df, target_df = gio_kl.read_data_from_csv(PATH_TO_TRAIN, PATH_TO_TARGET)
112 |
113 | # Quantize data
114 | model_train, model_X, transformed_train, transformed_X = gio_kl.quantize(train_df, target_df)
115 |
116 | X = jnp.array(model_X.clusterCenters())
117 | train = jnp.array(model_train.clusterCenters())
118 | centroids_df = gio_kl.spark.createDataFrame(data=[(i, each.tolist()) for i, each in enumerate(model_train.clusterCenters())], schema=["id", "centroid"])
119 |
120 | # Perform the Algorithm
121 | W, kl_divs, _ = gio_kl.fit(train, X, max_iter=300, stop_criterion='sequential_increase_tolerance', v_init='jump')
122 | W = W[20:] # Remove the uniform start
123 |
124 | # Explode back to original data and write resulting data
125 | full_selections_df = gio_kl.explode(W, transformed_train, centroids_df)
126 | full_selections_df.select(F.col("_c0"), F.col("_c1")).write.option("delimiter", "\t").csv(OUTPUT_PATH)
127 |
128 |
129 | # Plot results
130 | plt.plot(kl_divs)
131 | plt.title("KL Divergence vs. Iterations")
132 | plt.xlabel("Iterations")
133 | plt.ylabel("KL Divergence")
134 | plt.show()
135 | ```
136 | **Note:** Quantization requires a large Spark RPC message size. It is recommended to place ```gio_kl.spark.conf.set("spark.rpc.message.maxSize", "500")``` (or any large number) in the code before calling `quantize`, if the default hasn't already been increased.
137 |
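The embedding CSVs read above can be produced with the bundled sentence-embedding helper. A minimal sketch, assuming the module is importable as `GIO.generate_text_embeddings`; the model name and file paths are illustrative. Note that the helper writes one embedding per line, so assembling the `_c0`/`_c1`/`_c2` CSV layout expected by `read_data_from_csv` is a separate step:
```python
from GIO.generate_text_embeddings import GenerateEmbeddings

# Illustrative model and paths; any sentence-transformers model can be used
# ("all-mpnet-base-v2" produces 768-dimensional embeddings, matching dim=768 above)
generator = GenerateEmbeddings("all-mpnet-base-v2", device="cuda")
generator.generate_embeddings("train_sentences.txt", "train_embeddings.txt")
generator.generate_embeddings("target_sentences.txt", "target_embeddings.txt")
```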
138 | ## Available Options
139 | `GIOKL.fit` takes the following arguments (a short usage sketch follows the list):
140 | - `train`: training data as a jnp array (jnp is almost identical to numpy) of shape [M, D]
141 | - `X`: target data as a jnp array of shape [N, D]
142 | - `D`: initial data as a jnp array, default None. Use None to initialize from 0 (uniform) or a subset of training data
143 | - `k`: kth nearest neighbor to use in the KL divergence estimation, default 5
144 | - `max_iter`: maximum iterations for the algorithm. One iteration adds one point (cluster)
145 | - `stop_criterion`: a string for the stopping criterion, one of the following: 'increase', 'max_resets', 'min_difference', 'sequential_increase_tolerance', 'min_kl', 'data_size'. Default is 'increase'
146 | - `min_difference`: the minimum difference between prior and current KL divergence for 'min_difference' stop criterion only. Default is 0
147 | - `resets_allowed`: whether resetting G to the full train set is allowed when the KL divergence increases (this allows the algorithm to pick duplicates). Must be set to True if the stop criterion is 'max_resets'. Default is False
148 | - `max_resets`: the number of resets allowed for the 'max_resets' stop criterion only (a reset resets G to the full train set and allows the algorithm to pick duplicates). Default is 2
149 | - `max_data_size`: the maximum size of data to be selected for the 'data_size' stop criterion only, as a percentage (of total data) between 0 and 1. Default is 1
150 | - `min_kl`: the minimum kl divergence for the 'min_kl' stop criterion only. Default is 0
151 | - `max_sequential_increases`: the maximum number of sequential KL divergence increases for the 'sequential_increase_tolerance' stop criterion only. Default is 3
152 | - `random_init_pct`: the percent of training data to initialize the algorithm from. Default is 0
153 | - `random_restart_prob`: probability at any given iteration to extend the gradient descent iterations by 3x, to find potentially better extrema. Higher values come at the cost of efficiency. Default is 0
154 | - `scale_factor`: factor to scale the gradient by, or 'auto'. Default is 'auto', which is recommended
155 | - `v_init`: how to initialize v in gradient descent, one of the following: 'mean', 'prev_opt', 'jump'. Default is 'mean'
156 | - `grad_desc_iter`: the number of iterations to use in gradient descent. Default is 50
157 | - `discard_nearest_for_xy`: discard the nearest neighbor in the x-y part of the KL divergence calculation, for use when X and the train set are the same; comes at the cost of efficiency. Default is False
158 | - `lr`: Learning rate for gradient descent. Default is 0.01
159 |
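A minimal sketch of a `fit` call that exercises several of these options (the arrays `train` and `X` and the chosen values are illustrative; `gio_kl` is a `GIOKL` instance as in the examples above):
```python
# train: jnp array of shape [M, D]; X: jnp array of shape [N, D]
W, kl_divs, _ = gio_kl.fit(
    train, X,
    k=5,                                             # kth nearest neighbor for the KL estimate
    max_iter=300,                                    # each iteration adds one point (cluster)
    stop_criterion='sequential_increase_tolerance',  # stopping criterion (see list above)
    max_sequential_increases=3,
    v_init='jump',                                   # re-initialize v from a random target point each iteration
    grad_desc_iter=50,
    lr=0.01,
)
```
`fit` also accepts a `normalize` flag (default True) that controls whether the uniform start points are normalized; the 2D examples above pass `normalize=False`.
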
160 | ## Citing GIO
161 | If you use GIO in a publication, blog or software project, please cite the paper:
162 | ```
163 | @misc{everaert2023gio,
164 | title={GIO: Gradient Information Optimization for Training Dataset Selection},
165 | author={Dante Everaert and Christopher Potts},
166 | year={2023},
167 | eprint={2306.11670},
168 | archivePrefix={arXiv},
169 | primaryClass={cs.LG}
170 | }
171 | ```
172 |
--------------------------------------------------------------------------------
/build/lib/GIO/GIOKL.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import grad
3 | import random
4 | import jax
5 |
6 | import numpy as np
7 | import pyspark.sql.functions as F
8 |
9 | from pyspark.ml.clustering import KMeans
10 | from pyspark.sql import SparkSession
11 | from pyspark.sql.types import *
12 | from .GIO_super import GIO_super
13 |
14 |
15 | class GIOKL(GIO_super):
16 | def __init__(self, uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768):
17 | super().__init__()
18 | self.spark = SparkSession.builder.getOrCreate()
19 | self.uniform_low = uniform_low
20 | self.uniform_high = uniform_high
21 | self.uniform_start_size = uniform_start_size
22 | self.dim = dim
23 | self.random_init = False
24 |
25 | def _get_nearest(self, sample, point):
26 |         """Find the index of the nearest point by Euclidean distance (sample and point are broadcast against each other).
27 |         :param sample: a point or set of points
28 |         :param point: a point or set of points to compare against
29 |         :return: the index of the nearest point
30 | """
31 | norms = jnp.linalg.norm(sample - point, axis=1)
32 | return jnp.argsort(norms)[0]
33 |
34 | def _knn(self, x, y, k, last_only, discard_nearest, avg):
35 |         """Find k-nearest-neighbor distances from y for each example in a minibatch x.
36 |         :param x: tensor of shape [N_1, D]
37 |         :param y: tensor of shape [N_2, D]
38 |         :param k: the (k+1)-th nearest neighbor
39 |         :param last_only: use only the last knn distance vs. all of them
40 |         :param discard_nearest: if True, discard the nearest neighbor (for when x and y may share points); when avg is True a full sort is used instead of top-k
41 |         :return: knn distances of shape [N_1, k] or [N_1, 1] if last_only
42 | """
43 |
44 | dist_x = jnp.sum((x ** 2), axis=-1)[:, jnp.newaxis]
45 | dist_y = jnp.sum((y ** 2), axis=-1)[:, jnp.newaxis].T
46 | cross = - 2 * jnp.matmul(x, y.T)
47 | distmat = dist_x + cross + dist_y
48 | distmat = jnp.clip(distmat, 1e-10, 1e+20)
49 |
50 | if discard_nearest:
51 | if not avg:
52 | knn, _ = jax.lax.top_k(-distmat, k + 1)
53 | else:
54 | knn = -jnp.sort(distmat)
55 | knn = knn[:, 1:]
56 | else:
57 | knn = -distmat
58 |
59 | if last_only:
60 | knn = knn[:, -1:]
61 |
62 | return jnp.sqrt(-knn)
63 |
64 | def _kl_divergence_knn(self, x, y, k, eps, discard_nearest_for_xy):
65 | """KL divergence estimator for D(x~p || y~q).
66 | :param x: x~p
67 | :param y: y~q
68 | :param k: kth nearest neighbor
69 | :param discard_nearest_for_xy: discard nearest in the xy calculation
70 | :param eps: small epsilon to pass to log
71 | :return: scalar
72 | """
73 | n, d = x.shape
74 | m, _ = y.shape
75 | nns_xx = self._knn(x, x, k=k, last_only=True, discard_nearest=True, avg=False)
76 | nns_xy = self._knn(x, y, k=m, last_only=False, discard_nearest=discard_nearest_for_xy, avg=discard_nearest_for_xy)
77 |
78 | divergence = jnp.mean(d*jnp.log(nns_xy + eps) - d*jnp.log(nns_xx + eps)) + jnp.mean(jnp.log((k*m)/(jnp.arange(1, m+1) * (n-1))))
79 |
80 | return divergence
81 |
82 | def calculate_statistical_distance(self, x, y, k=5, eps=1e-8, discard_nearest_for_xy=False):
83 | """Calculate statistical distance d(p,q) based on x~p and y~q.
84 | :param x: x~p
85 | :param y: y~q
86 | :param k: kth nearest neighbor
87 | :param eps: small epsilon to pass to log
88 | :return: scalar
89 | """
90 | return self._kl_divergence_knn(x, y, k, eps, discard_nearest_for_xy)
91 |
92 | def gradient_descend(self, X, W, v, scaling_factor, max_iterations, lr=0.01, k=5, discard_nearest_for_xy=False):
93 |         """Perform gradient descent on the statistical distance between X and W+v
94 | :param X: target data
95 | :param W: current selected data
96 | :param v: initial v
97 | :param scaling_factor: scale the gradient
98 | :param max_iterations: iterations in the gradient descent
99 | :param lr: learning rate
100 | :param k: kth nearest neighbor
101 | :param discard_nearest_for_xy: discard nearest in the xy calculation
102 |         :return: the optimized vector v
103 | """
104 | i = 0
105 | while i < max_iterations:
106 | gradient = grad(lambda v: self.calculate_statistical_distance(X, jnp.concatenate((W, v[jnp.newaxis, :])), k, discard_nearest_for_xy=discard_nearest_for_xy))(v)
107 | v = v - lr * scaling_factor * gradient
108 | i += 1
109 | return v
110 |
111 | def _get_uniform_start(self, do_normalize):
112 | """Get a uniform start for D.
113 | :return: jnp array of uniform points
114 | """
115 | def normalize(v):
116 | norm = np.linalg.norm(v)
117 | if norm == 0:
118 | return v
119 | return v / norm
120 | if do_normalize:
121 | return jnp.array([normalize(each) for each in np.random.uniform(low=self.uniform_low,high=self.uniform_high,size=(self.uniform_start_size,self.dim))])
122 | else:
123 | return jnp.array([each for each in np.random.uniform(low=self.uniform_low,high=self.uniform_high,size=(self.uniform_start_size,self.dim))])
124 |
125 | def fit(self, train, X, D=None, k=5, max_iter=100, stop_criterion="increase", min_difference=0, resets_allowed=False, max_resets=2, max_data_size=1, min_kl=0, max_sequential_increases=3, random_init_pct=0, random_restart_prob=0, scale_factor="auto", v_init='mean', grad_desc_iter=50, discard_nearest_for_xy=False, normalize=True, lr=0.01):
126 | """Perform GIO
127 | :param train: training data
128 | :param X: target data
129 | :param D: initial data
130 | :param k: kth nearest neighbor
131 | :param max_iter: max iterations for the algorithm
132 | :param stop_criterion: a string for the stopping criterion, one of the following: 'increase', 'max_resets', 'min_difference', 'sequential_increase_tolerance', 'min_kl', 'data_size'
133 | :param min_difference: the minimum difference between prior and current KL divergence for 'min_difference' stop criterion
134 |         :param resets_allowed: whether resetting G to the full train set is allowed when the KL divergence increases (allows the algorithm to pick duplicates). Must be set to True if the stop criterion is 'max_resets'
135 | :param max_resets: the number of resets allowed for the 'max_resets' stop criterion
136 | :param max_data_size: the maximum size of data for the 'data_size' stop criterion, as a percentage
137 | :param min_kl: the minimum kl divergence for the 'min_kl' stop criterion
138 | :param max_sequential_increases: the maximum number of sequential KL divergence increases for the 'sequential_increase_tolerance' stop criterion
139 | :param random_init_pct: the percent of training data to initialize the algorithm from
140 | :param random_restart_prob: probability to extend the gradient descent iterations by 3x to find potentially better extrema. Higher values come at the cost of efficiency
141 | :param scale_factor: factor to scale the gradient by or 'auto'
142 |         :param v_init: how to initialize v in gradient descent, one of the following: 'mean', 'prev_opt', 'jump'
143 | :param grad_desc_iter: the number of iterations in gradient descent
144 | :param discard_nearest_for_xy: discard nearest in the xy calculation
145 | :param lr: Learning rate for gradient descent
146 | :return: selected data, kl divergences, (v, scale_factor, just_reset, num_resets, increases, adder, kl_divs)
147 | """
148 | if not random_init_pct and D is None:
149 | W = self._get_uniform_start(normalize)
150 | self.random_init = True
151 | elif D is None:
152 | amount = int(random_init_pct * len(train))
153 | W = jnp.array(random.sample(train.tolist(), amount))
154 | else:
155 | W = D[:]
156 |
157 | kl_dist_prev = self.calculate_statistical_distance(X, W, k, discard_nearest_for_xy=discard_nearest_for_xy)
158 |
159 | print("Starting KL: " + str(kl_dist_prev))
160 | if v_init == 'mean' or v_init == 'prev_opt':
161 | v = jnp.mean(X, axis=0)
162 | elif v_init == 'jump':
163 | v = jnp.array(random.sample(X.tolist(), 1)).squeeze()
164 | adder = train[:]
165 | kl_divs = []
166 |
167 | scale_factor = jnp.linalg.norm(v)/jnp.linalg.norm(grad(lambda v: self.calculate_statistical_distance(X, jnp.concatenate((W, v[jnp.newaxis, :])), k, discard_nearest_for_xy=discard_nearest_for_xy))(v)) if scale_factor == "auto" else scale_factor
168 |
169 | i = 0
170 | just_reset = False
171 | num_resets = 0
172 | total_iter = 0
173 | increases = 0
174 | while True:
175 | # Warmup, reset or random restart
176 | if i == 0 or just_reset or random.random() < random_restart_prob:
177 | v = self.gradient_descend(X, W, v, scale_factor, grad_desc_iter * 3, lr=lr, k=k, discard_nearest_for_xy=discard_nearest_for_xy)
178 | else:
179 | v = self.gradient_descend(X, W, v, scale_factor, grad_desc_iter, lr=lr, k=k, discard_nearest_for_xy=discard_nearest_for_xy)
180 | idx = self._get_nearest(v, adder)
181 | minvals = adder[idx]
182 | adder = jnp.delete(adder, idx, axis=0)
183 |
184 | W_tmp = jnp.concatenate((W, jnp.array(minvals)[jnp.newaxis, :]))
185 |
186 | kl_dist = self.calculate_statistical_distance(X, W_tmp, k, discard_nearest_for_xy=discard_nearest_for_xy)
187 | print("KL Divergence at iteration " + str(i) + ": " + str(kl_dist))
188 |
189 | # STOPPING CRITERIA
190 | if total_iter > max_iter:
191 | break
192 |
193 | if v_init == 'mean':
194 | v = jnp.mean(X, axis=0)
195 | elif v_init == 'jump':
196 | v = jnp.array(random.sample(X.tolist(), 1)).squeeze()
197 |
198 | adder, i, just_reset, stop, v, increases, num_resets = self._test_stop_criterion(v_init, stop_criterion, kl_dist, kl_dist_prev, num_resets, max_resets, min_difference, increases, max_sequential_increases, min_kl, max_data_size, train, X, i, v, just_reset, resets_allowed, adder)
199 |
200 | if stop:
201 | break
202 | if not just_reset:
203 | W = W_tmp
204 | kl_divs += [kl_dist]
205 | kl_dist_prev = kl_dist
206 | i += 1
207 | total_iter += 1
208 | return W, kl_divs, (v, scale_factor, just_reset, num_resets, increases, adder, kl_divs)
209 |
210 | def _test_stop_criterion(self, v_init, stop_criterion, kl_dist, kl_dist_prev, num_resets, max_resets, min_difference, increases, max_sequential_increases, min_kl, max_data_size, train, X, i, v, just_reset, resets_allowed, adder):
211 | stop = False
212 | if stop_criterion == "increase" and kl_dist - kl_dist_prev > 0:
213 | stop = True
214 | elif stop_criterion == "max_resets" and kl_dist - kl_dist_prev > 0 and num_resets == max_resets:
215 | stop = True
216 | elif stop_criterion == "min_difference" and kl_dist_prev - kl_dist < min_difference:
217 | stop = True
218 | elif stop_criterion == 'sequential_increase_tolerance' and kl_dist - kl_dist_prev > 0 and increases == max_sequential_increases:
219 | stop = True
220 | elif stop_criterion == 'min_kl' and kl_dist < min_kl:
221 | stop = True
222 | elif stop_criterion == 'data_size' and i > int(max_data_size * len(train)):
223 | stop = True
224 | if stop:
225 | if just_reset:
226 | increases += 1
227 | if resets_allowed and num_resets < max_resets:
228 | num_resets += 1
229 | if v_init == 'prev_opt':
230 | v = jnp.mean(X, axis=0)
231 | print("KL Div Increase, Resetting G")
232 | adder = train[:]
233 | i -= 1
234 | stop = False
235 | just_reset = True
236 | else:
237 | just_reset = False
238 | increases = 0
239 | return adder, i, just_reset, stop, v, increases, num_resets
240 |
241 | def _return_kmeans(self, df, k, rseed):
242 | """Use Spark to perform K-Means
243 | :param df: dataframe to perform K-Means with
244 | :param k: number of clusters to compute
245 | :param rseed: random seed
246 | :return: k-means model, transformed df
247 | """
248 | kmeans = KMeans().setK(k).setSeed(rseed)
249 | model = kmeans.fit(df.select("features"))
250 | transformed_df = model.transform(df)
251 | return model, transformed_df
252 |
253 | def quantize(self, df_train, df_x, k=1500, rseed='auto', rseed1=234, rseed2=456):
254 |         """Quantize the train and target dataframes using Spark K-Means
255 | :param df_train: train dataframe to quantize
256 | :param df_x: target dataframe to quantize
257 | :param k: number of clusters to compute
258 | :param rseed: 'auto' or 'manual'
259 | :param rseed1: first random seed
260 | :param rseed2: second random seed
261 |         :return: train k-means model, target k-means model, transformed train df, transformed target df
262 | """
263 | if rseed == 'auto':
264 | rseed1 = random.randint(-1000,1000)
265 | rseed2 = random.randint(-1000,1000)
266 | model_train, transformed_train = self._return_kmeans(df_train, k, rseed1)
267 | model_X, transformed_X = self._return_kmeans(df_x, k, rseed2)
268 | return model_train, model_X, transformed_train, transformed_X
269 |
270 | def read_data_from_csv(self, path, path_X, delim="\t"):
271 | """Read in and process data stored in a csv. Data must be of the format: _c0, _c1, _c2 where _c2 contains the
272 | string representation of the vector, like "[0.1, 0.23, 0.45 ...]"
273 | :param path: path to training data
274 | :param path_X: path to target data
275 | :param delim: delimiter for csv file
276 | :return: train df, target df
277 | """
278 | new_schema = ArrayType(DoubleType(), containsNull=False)
279 | udf_json_to_arr = F.udf(lambda x: x, new_schema)
280 |
281 | df_read = self.spark.read.option("delimiter", delim).csv(path)
282 | df_with_embeddings = df_read.withColumn("features", udf_json_to_arr(F.from_json(F.col("_c2"), "array")))
283 |
284 | df_X_read = self.spark.read.option("delimiter", delim).csv(path_X)
285 | df_X_with_embeddings = df_X_read.withColumn("features", udf_json_to_arr(F.from_json(F.col("_c2"), "array")))
286 |
287 | return df_with_embeddings, df_X_with_embeddings
288 |
289 | def read_data_from_parquet(self, path, path_X):
290 | """Read in and process data stored in a parquet format. Data must contain a column "features" that stores an array
291 | of the vectors and be non-nullable.
292 | :param path: path to training data
293 | :param path_X: path to target data
294 | :return: train df, target df
295 | """
296 | new_schema = ArrayType(DoubleType(), containsNull=False)
297 | udf_no_null = F.udf(lambda x: x, new_schema)
298 |
299 | df_with_embeddings = self.spark.read.parquet(path).withColumn("features", udf_no_null(F.col("features")))
300 | df_X_with_embeddings = self.spark.read.parquet(path_X).withColumn("features", udf_no_null(F.col("features")))
301 | return df_with_embeddings, df_X_with_embeddings
302 |
303 | def explode(self, chosen_centroids, kmeans_transformed_df, kmeans_centroids_df):
304 |         """Map the selected centroids back to the full rows of the quantized training data.
305 |         :param chosen_centroids: centroids selected by fit
306 |         :param kmeans_transformed_df: train dataframe transformed by the k-means model (contains a "prediction" column)
307 |         :param kmeans_centroids_df: dataframe of (id, centroid) pairs for the train k-means model
308 |         :return: dataframe of all original rows belonging to the selected centroids
309 | """
310 | pre_existing_centroids = jnp.array([f[1] for f in sorted([[each[0], each[1]] for each in kmeans_centroids_df.collect()], key=lambda x: x[0])])
311 | paired = []
312 | for each in chosen_centroids:
313 | for i, x in enumerate(pre_existing_centroids.tolist()):
314 | if each.tolist() in [x]:
315 | paired += [i]
316 | print("Found " + str(len(paired)) + " centroids out of " + str(len(chosen_centroids)) + " selected centroids")
317 | full_selections_df = self.spark.createDataFrame(data=[(i, each) for i, each in enumerate(paired)], schema=["i", "id"]).join(kmeans_transformed_df, F.col("id") == F.col("prediction"))
318 | return full_selections_df
319 |
320 |
--------------------------------------------------------------------------------
/build/lib/GIO/GIO_super.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import grad
3 | import random
4 | import jax
5 |
6 | import numpy as np
7 | import pyspark.sql.functions as F
8 |
9 | from pyspark.ml.clustering import KMeans
10 | from pyspark.sql import SparkSession
11 | from pyspark.sql.types import *
12 |
13 |
14 | class GIO_super:
15 | def __init__(self):
16 | pass
17 |
18 | def calculate_statistical_distance(self, x, y):
19 | pass
20 |
21 | def gradient_descend(self, X, W, v, factor, max_iterations, lr, *arg):
22 | pass
23 |
24 | def fit(self, train, X, *arg):
25 | pass
26 |
27 | def quantize(self, df_train, df_x, quantize_into):
28 | pass
29 |
30 | def _get_nearest(self, sample, point):
31 | pass
32 |
33 | def explode(self, chosen_centroids, kmeans_transformed_df, kmeans_centroids_df):
34 | pass
35 |
--------------------------------------------------------------------------------
/build/lib/GIO/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/build/lib/GIO/__init__.py
--------------------------------------------------------------------------------
/build/lib/GIO/generate_text_embeddings.py:
--------------------------------------------------------------------------------
1 | from sentence_transformers import SentenceTransformer
2 | import time
3 |
4 |
5 |
6 | class GenerateEmbeddings:
7 | def __init__(self, model_name, device='cuda'):
8 | self.model = SentenceTransformer(model_name)
9 | self.device = device
10 |
11 | def generate_embeddings(self, input_file_path, output_file_path):
12 | """Generate Embeddings from a text file
13 | :param input_file_path: path to input text, one sentence per line
14 | :param output_file_path: path to desired output file
15 | """
16 | print('Reading File...')
17 | with open(input_file_path, 'r') as fp:
18 | sentences = fp.readlines()
19 | print('Generating Embeddings... This May Take a While')
20 | start = time.time()
21 | embeddings = self.model.encode(sentences, device=self.device)
22 | end = time.time()
23 |
24 | print("Time Taken (s): " + str(end - start))
25 |
26 |         print("Writing Embeddings... This May Take a While")
27 | with open(output_file_path, 'w') as op:
28 | for i, each in enumerate(embeddings):
29 | op.write(str(each.tolist()).strip() + "\n")
30 |
--------------------------------------------------------------------------------
/dist/grad_info_opt-0.1.2-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/dist/grad_info_opt-0.1.2-py3-none-any.whl
--------------------------------------------------------------------------------
/experiments/bm25_scripts/get_less.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import csv
3 |
4 | bad_data_counter = 0
5 | written = 0
6 |
7 | with open(sys.argv[1], "r", newline='') as fp, open(sys.argv[2], "w") as op:
8 | csv_fp = csv.reader(fp, delimiter=',')
9 | for i, each in enumerate(csv_fp):
10 | # Skip header
11 | if i == 0:
12 | continue
13 |
14 | # Expect length of 3; determine how many are bad data
15 | if len(each) < 3:
16 | bad_data_counter += 1
17 |
18 | # Filter if the topK ID is above the specified filter value
19 | if int(each[2]) > int(sys.argv[3]):
20 | continue
21 | else:
22 | written += 1
23 | op.write(each[0].strip() + "\n")
24 |
25 | print("Bad Data: " + str(bad_data_counter))
26 | print("Total Written: " + str(written))
27 |
--------------------------------------------------------------------------------
/experiments/bm25_scripts/get_pairs.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | dct = {}
4 | bad_data_counter = 0
5 | with open(sys.argv[1], "r") as fp:
6 | for each in fp:
7 | sepped = each.strip().split("\t")
8 | # Expect length of 2
9 | if len(sepped) != 2:
10 | bad_data_counter += 1
11 | else:
12 | dct[sepped[0].strip()] = sepped[1] # Create dictionary of input-output pairs to later look up and match
13 |
14 | print("Weird parses: " + str(bad_data_counter))
15 |
16 | lines = []
17 | not_in_dct = 0
18 | with open(sys.argv[2], "r") as fp:
19 | for each in fp:
20 | cleaned = each.strip()
21 | # Expect all data to be in the dictionary
22 | if cleaned not in dct:
23 | not_in_dct += 1
24 | else:
25 | lines += [cleaned + '\t' + dct[cleaned] + "\n"] # Write input \t output
26 |
27 | print("Not in Dict: " + str(not_in_dct))
28 |
29 | with open(sys.argv[3], "w") as op:
30 | op.writelines(lines)
31 |
32 |
33 |
--------------------------------------------------------------------------------
/experiments/bm25_scripts/make_csv.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import sys
3 |
4 | with open(sys.argv[1], "r") as fp, open(sys.argv[2], "w", newline='') as op:
5 | csv_fp = fp
6 | csv_op = csv.writer(op, delimiter=',')
7 | csv_op.writerow(["text","id"]) # Give column names
8 | total = 0
9 | for i, each in enumerate(csv_fp):
10 | csv_op.writerow([each.strip()] + [i]) # Write value and ID in CSV format
11 | total += 1
12 |
--------------------------------------------------------------------------------
/experiments/checks/negative_consistency.py:
--------------------------------------------------------------------------------
1 | from GIO import GIOKL
2 | import numpy as np
3 | import jax.numpy as jnp
4 | import matplotlib.pyplot as plt
5 |
6 | # Create some data
7 | def getX():
8 | mean = [3,4]
9 | cov = [[0.5,0],[0,0.5]]
10 | np.random.seed(1)
11 | x, y = np.random.multivariate_normal(mean, cov, 100).T
12 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
13 |
14 | def getXTest():
15 | mean = [300,400]
16 | cov = [[0.5,0],[0,0.5]]
17 | np.random.seed(5)
18 | x, y = np.random.multivariate_normal(mean, cov, 100).T
19 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
20 |
21 | X = getX()
22 | X_test = getXTest()
23 |
24 | # Initialize class
25 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2)
26 |
27 | # Perform the Algorithm
28 | W, kl_divs, _ = gio_kl.fit(X_test, X, normalize=False)
29 | W = W[100:] # Remove the uniform start
30 |
31 | # Plot results
32 | plt.plot(kl_divs)
33 | plt.title("KL Divergence vs. Iterations")
34 | plt.xlabel("Iterations")
35 | plt.ylabel("KL Divergence")
36 | plt.show()
37 | plt.clf()
38 | plt.scatter([each[0] for each in W], [each[1] for each in W], label='Selected Data')
39 | plt.scatter([each[0] for each in X], [each[1] for each in X], label='Target Data')
40 | plt.title("Target Data and Selected Data")
41 | plt.xlabel("Dimension 1")
42 | plt.ylabel("Dimension 2")
43 | plt.legend()
44 | plt.show()
45 |
--------------------------------------------------------------------------------
/experiments/checks/quantization_consistency.py:
--------------------------------------------------------------------------------
1 | from GIO import GIOKL
2 | import numpy as np
3 | import jax.numpy as jnp
4 | import matplotlib.pyplot as plt
5 | from pyspark.sql.types import *
6 | import pyspark.sql.functions as F
7 |
8 | # Create some data
9 | def getX():
10 | mean = [3,4]
11 | cov = [[0.5,0],[0,0.5]]
12 | np.random.seed(1)
13 | x, y = np.random.multivariate_normal(mean, cov, 100).T
14 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
15 |
16 | X = getX()
17 |
18 | new_schema = ArrayType(DoubleType(), containsNull=False)
19 | udf_no_null = F.udf(lambda x: x, new_schema)
20 |
21 | # Initialize class
22 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2)
23 | X_df = gio_kl.spark.createDataFrame(data=[(i, each.tolist()) for i, each in enumerate(X)], schema=["id", "features"]).withColumn("features", udf_no_null(F.col("features")))
24 |
25 | # Quantize data
26 | model_train, model_X, transformed_train, transformed_X = gio_kl.quantize(X_df, X_df)
27 | quantized_X = jnp.array(model_X.clusterCenters())
28 |
29 | # Calculate KL
30 | kl = gio_kl.calculate_statistical_distance(X, quantized_X)
31 |
32 | print("KL Divergence: " + str(kl))
33 |
--------------------------------------------------------------------------------
/experiments/checks/self_consistency.py:
--------------------------------------------------------------------------------
1 | from GIO import GIOKL
2 | import numpy as np
3 | import jax.numpy as jnp
4 | import matplotlib.pyplot as plt
5 |
6 | # Create some data
7 | def getX():
8 | mean = [3,4]
9 | cov = [[0.5,0],[0,0.5]]
10 | np.random.seed(1)
11 | x, y = np.random.multivariate_normal(mean, cov, 100).T
12 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
13 |
14 | def getXTest():
15 | mean = [3,4]
16 | cov = [[0.5,0],[0,0.5]]
17 | np.random.seed(5)
18 | x, y = np.random.multivariate_normal(mean, cov, 100).T
19 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
20 |
21 | X = getX()
22 | X_test = getXTest()
23 |
24 | # Initialize class
25 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2)
26 |
27 | # Perform the Algorithm
28 | W, kl_divs, _ = gio_kl.fit(X_test, X, normalize=False)
29 | W = W[100:] # Remove the uniform start
30 |
31 | # Plot results
32 | plt.plot(kl_divs)
33 | plt.title("KL Divergence vs. Iterations")
34 | plt.xlabel("Iterations")
35 | plt.ylabel("KL Divergence")
36 | plt.show()
37 | plt.clf()
38 | plt.scatter([each[0] for each in W], [each[1] for each in W], label='Selected Data')
39 | plt.scatter([each[0] for each in X], [each[1] for each in X], label='Target Data')
40 | plt.title("Target Data and Selected Data")
41 | plt.xlabel("Dimension 1")
42 | plt.ylabel("Dimension 2")
43 | plt.legend()
44 | plt.show()
45 |
--------------------------------------------------------------------------------
/experiments/fashion_mnist/csv_to_image.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | to_write = ""
4 |
5 | with open(sys.argv[1], "r") as fp, open(sys.argv[2], "w") as op:
6 | for i, each in enumerate(fp):
7 | # Skip header
8 | if i == 0:
9 | continue
10 | sepped = each.strip().split(",") # Split into label and image pixels
11 | to_write += "[" + ", ".join(sepped[1:]) + "]\t" + sepped[0] + "\n" # Make into array of image \t label format
12 | op.write(to_write)
13 |
--------------------------------------------------------------------------------
/experiments/fashion_mnist/image_to_csv.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | to_write = "label,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8,pixel9,pixel10,pixel11,pixel12,pixel13,pixel14,pixel15,pixel16,pixel17,pixel18,pixel19,pixel20,pixel21,pixel22,pixel23,pixel24,pixel25,pixel26,pixel27,pixel28,pixel29,pixel30,pixel31,pixel32,pixel33,pixel34,pixel35,pixel36,pixel37,pixel38,pixel39,pixel40,pixel41,pixel42,pixel43,pixel44,pixel45,pixel46,pixel47,pixel48,pixel49,pixel50,pixel51,pixel52,pixel53,pixel54,pixel55,pixel56,pixel57,pixel58,pixel59,pixel60,pixel61,pixel62,pixel63,pixel64,pixel65,pixel66,pixel67,pixel68,pixel69,pixel70,pixel71,pixel72,pixel73,pixel74,pixel75,pixel76,pixel77,pixel78,pixel79,pixel80,pixel81,pixel82,pixel83,pixel84,pixel85,pixel86,pixel87,pixel88,pixel89,pixel90,pixel91,pixel92,pixel93,pixel94,pixel95,pixel96,pixel97,pixel98,pixel99,pixel100,pixel101,pixel102,pixel103,pixel104,pixel105,pixel106,pixel107,pixel108,pixel109,pixel110,pixel111,pixel112,pixel113,pixel114,pixel115,pixel116,pixel117,pixel118,pixel119,pixel120,pixel121,pixel122,pixel123,pixel124,pixel125,pixel126,pixel127,pixel128,pixel129,pixel130,pixel131,pixel132,pixel133,pixel134,pixel135,pixel136,pixel137,pixel138,pixel139,pixel140,pixel141,pixel142,pixel143,pixel144,pixel145,pixel146,pixel147,pixel148,pixel149,pixel150,pixel151,pixel152,pixel153,pixel154,pixel155,pixel156,pixel157,pixel158,pixel159,pixel160,pixel161,pixel162,pixel163,pixel164,pixel165,pixel166,pixel167,pixel168,pixel169,pixel170,pixel171,pixel172,pixel173,pixel174,pixel175,pixel176,pixel177,pixel178,pixel179,pixel180,pixel181,pixel182,pixel183,pixel184,pixel185,pixel186,pixel187,pixel188,pixel189,pixel190,pixel191,pixel192,pixel193,pixel194,pixel195,pixel196,pixel197,pixel198,pixel199,pixel200,pixel201,pixel202,pixel203,pixel204,pixel205,pixel206,pixel207,pixel208,pixel209,pixel210,pixel211,pixel212,pixel213,pixel214,pixel215,pixel216,pixel217,pixel218,pixel219,pixel220,pixel221,pixel222,pixel223,pixel224,pixel225,pixel226,pixel227,pixel228,pixel229,pixel230,pixel231,pixel232,pixel233,pixel234,pixel235,pixel236,pixel237,pixel238,pixel239,pixel240,pixel241,pixel242,pixel243,pixel244,pixel245,pixel246,pixel247,pixel248,pixel249,pixel250,pixel251,pixel252,pixel253,pixel254,pixel255,pixel256,pixel257,pixel258,pixel259,pixel260,pixel261,pixel262,pixel263,pixel264,pixel265,pixel266,pixel267,pixel268,pixel269,pixel270,pixel271,pixel272,pixel273,pixel274,pixel275,pixel276,pixel277,pixel278,pixel279,pixel280,pixel281,pixel282,pixel283,pixel284,pixel285,pixel286,pixel287,pixel288,pixel289,pixel290,pixel291,pixel292,pixel293,pixel294,pixel295,pixel296,pixel297,pixel298,pixel299,pixel300,pixel301,pixel302,pixel303,pixel304,pixel305,pixel306,pixel307,pixel308,pixel309,pixel310,pixel311,pixel312,pixel313,pixel314,pixel315,pixel316,pixel317,pixel318,pixel319,pixel320,pixel321,pixel322,pixel323,pixel324,pixel325,pixel326,pixel327,pixel328,pixel329,pixel330,pixel331,pixel332,pixel333,pixel334,pixel335,pixel336,pixel337,pixel338,pixel339,pixel340,pixel341,pixel342,pixel343,pixel344,pixel345,pixel346,pixel347,pixel348,pixel349,pixel350,pixel351,pixel352,pixel353,pixel354,pixel355,pixel356,pixel357,pixel358,pixel359,pixel360,pixel361,pixel362,pixel363,pixel364,pixel365,pixel366,pixel367,pixel368,pixel369,pixel370,pixel371,pixel372,pixel373,pixel374,pixel375,pixel376,pixel377,pixel378,pixel379,pixel380,pixel381,pixel382,pixel383,pixel384,pixel385,pixel386,pixel387,pixel388,pixel389,pixel390,pixel391,pixel392,pixel393,pixel394,pixel395,pixel396,pixel397,pixel398,pixel399,pixel400,pixel401,pixel402,pixel403,pixel404,pixel
405,pixel406,pixel407,pixel408,pixel409,pixel410,pixel411,pixel412,pixel413,pixel414,pixel415,pixel416,pixel417,pixel418,pixel419,pixel420,pixel421,pixel422,pixel423,pixel424,pixel425,pixel426,pixel427,pixel428,pixel429,pixel430,pixel431,pixel432,pixel433,pixel434,pixel435,pixel436,pixel437,pixel438,pixel439,pixel440,pixel441,pixel442,pixel443,pixel444,pixel445,pixel446,pixel447,pixel448,pixel449,pixel450,pixel451,pixel452,pixel453,pixel454,pixel455,pixel456,pixel457,pixel458,pixel459,pixel460,pixel461,pixel462,pixel463,pixel464,pixel465,pixel466,pixel467,pixel468,pixel469,pixel470,pixel471,pixel472,pixel473,pixel474,pixel475,pixel476,pixel477,pixel478,pixel479,pixel480,pixel481,pixel482,pixel483,pixel484,pixel485,pixel486,pixel487,pixel488,pixel489,pixel490,pixel491,pixel492,pixel493,pixel494,pixel495,pixel496,pixel497,pixel498,pixel499,pixel500,pixel501,pixel502,pixel503,pixel504,pixel505,pixel506,pixel507,pixel508,pixel509,pixel510,pixel511,pixel512,pixel513,pixel514,pixel515,pixel516,pixel517,pixel518,pixel519,pixel520,pixel521,pixel522,pixel523,pixel524,pixel525,pixel526,pixel527,pixel528,pixel529,pixel530,pixel531,pixel532,pixel533,pixel534,pixel535,pixel536,pixel537,pixel538,pixel539,pixel540,pixel541,pixel542,pixel543,pixel544,pixel545,pixel546,pixel547,pixel548,pixel549,pixel550,pixel551,pixel552,pixel553,pixel554,pixel555,pixel556,pixel557,pixel558,pixel559,pixel560,pixel561,pixel562,pixel563,pixel564,pixel565,pixel566,pixel567,pixel568,pixel569,pixel570,pixel571,pixel572,pixel573,pixel574,pixel575,pixel576,pixel577,pixel578,pixel579,pixel580,pixel581,pixel582,pixel583,pixel584,pixel585,pixel586,pixel587,pixel588,pixel589,pixel590,pixel591,pixel592,pixel593,pixel594,pixel595,pixel596,pixel597,pixel598,pixel599,pixel600,pixel601,pixel602,pixel603,pixel604,pixel605,pixel606,pixel607,pixel608,pixel609,pixel610,pixel611,pixel612,pixel613,pixel614,pixel615,pixel616,pixel617,pixel618,pixel619,pixel620,pixel621,pixel622,pixel623,pixel624,pixel625,pixel626,pixel627,pixel628,pixel629,pixel630,pixel631,pixel632,pixel633,pixel634,pixel635,pixel636,pixel637,pixel638,pixel639,pixel640,pixel641,pixel642,pixel643,pixel644,pixel645,pixel646,pixel647,pixel648,pixel649,pixel650,pixel651,pixel652,pixel653,pixel654,pixel655,pixel656,pixel657,pixel658,pixel659,pixel660,pixel661,pixel662,pixel663,pixel664,pixel665,pixel666,pixel667,pixel668,pixel669,pixel670,pixel671,pixel672,pixel673,pixel674,pixel675,pixel676,pixel677,pixel678,pixel679,pixel680,pixel681,pixel682,pixel683,pixel684,pixel685,pixel686,pixel687,pixel688,pixel689,pixel690,pixel691,pixel692,pixel693,pixel694,pixel695,pixel696,pixel697,pixel698,pixel699,pixel700,pixel701,pixel702,pixel703,pixel704,pixel705,pixel706,pixel707,pixel708,pixel709,pixel710,pixel711,pixel712,pixel713,pixel714,pixel715,pixel716,pixel717,pixel718,pixel719,pixel720,pixel721,pixel722,pixel723,pixel724,pixel725,pixel726,pixel727,pixel728,pixel729,pixel730,pixel731,pixel732,pixel733,pixel734,pixel735,pixel736,pixel737,pixel738,pixel739,pixel740,pixel741,pixel742,pixel743,pixel744,pixel745,pixel746,pixel747,pixel748,pixel749,pixel750,pixel751,pixel752,pixel753,pixel754,pixel755,pixel756,pixel757,pixel758,pixel759,pixel760,pixel761,pixel762,pixel763,pixel764,pixel765,pixel766,pixel767,pixel768,pixel769,pixel770,pixel771,pixel772,pixel773,pixel774,pixel775,pixel776,pixel777,pixel778,pixel779,pixel780,pixel781,pixel782,pixel783,pixel784\n"
4 |
5 | with open(sys.argv[1], "r") as fp, open(sys.argv[2], "w") as op:
6 | for each in fp:
7 | sepped = each.strip().split("\t") # Split into label, image
8 | to_write += sepped[1].strip() + "," # Label
9 | to_write += sepped[0].strip().replace("[", "").replace("]", "").replace(" ", "").replace(".0", "").strip() + "\n" # Remove parentheses and decimals
10 | op.write(to_write)
11 |
--------------------------------------------------------------------------------
/experiments/fashion_mnist/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import sys
4 | from tqdm import tqdm
5 | import torch
6 | import torch.nn as nn
7 | from skimage.transform import resize
8 | import torchvision.transforms as transforms
9 | from torch.utils.data import Dataset, DataLoader
10 | from torchvision import models
11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
12 | import time
13 | from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
14 |
15 | train_csv = pd.read_csv(sys.argv[1])
16 | test_csv = pd.read_csv(sys.argv[2])
17 | valid_csv = pd.read_csv(sys.argv[3])
18 |
19 | class FashionDataset(Dataset):
20 | """Class to build a dataset from FashionMNIST using Pytorch class Dataset."""
21 |
22 | def __init__(self, data, transform=None):
23 | self.fashion_MNIST = list(data.values)
24 | self.transform = transform
25 |
26 | label = []
27 | image = []
28 |
29 | for i in self.fashion_MNIST:
30 | # first column is the label; the remaining columns are pixel values.
31 | label.append(i[0])
32 | image.append(i[1:])
33 | self.labels = np.asarray(label)
34 | m = np.mean(image)
35 | std = np.std(image)
36 | # Dimension of images = 28 * 28 * 1, where height = width = 28 and color_channels = 1.
37 | self.images = np.asarray((np.array(image)-m)/std).reshape(-1, 28, 28, 1).astype('float32')
38 |
39 | def __getitem__(self, index):
40 | label = self.labels[index]
41 | image = self.images[index]
42 | image = self.transform(resize(image, (224,224))/255)
43 |
44 | return image, label
45 |
46 | def __len__(self):
47 | return len(self.images)
48 |
49 |
50 | class MnistResNet(nn.Module):
51 | """Class that edits Resnet50 to do FashionMNIST classification."""
52 |
53 | def __init__(self):
54 | super(MnistResNet, self).__init__()
55 |
56 | # Load a pretrained resnet model from torchvision.models in Pytorch
57 | self.model = models.resnet50(pretrained=True)
58 |
59 | # Change the input layer to take grayscale images instead of RGB images,
60 | # hence in_channels is set to 1
61 | self.model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
62 |
63 | # Change the output layer to output 10 classes instead of 1000 classes
64 | num_ftrs = self.model.fc.in_features
65 | self.model.fc = nn.Linear(num_ftrs, 10)
66 |
67 | def forward(self, x):
68 | return self.model(x)
69 |
70 |
71 | my_resnet = MnistResNet()
72 |
73 |
74 | def calculate_metric(metric_fn, true_y, pred_y):
75 | """Calculate a metric on the outputs and ground truth
76 | :param metric_fn: metric function to calculate
77 | :param true_y: ground truth labels
78 | :param pred_y: predicted labels
79 | :return: the output of the applied metric_fn
80 | """
81 | if metric_fn == accuracy_score:
82 | return metric_fn(true_y, pred_y)
83 | else:
84 | return metric_fn(true_y, pred_y, average="macro")
85 |
86 |
87 | def print_scores(p, r, f1, a, batch_size):
88 | """Print the P/R/F1 and Accuracy in a readable format
89 | :param p: precision
90 | :param r: recall
91 | :param f1: F1 between precision and recall
92 | :param a: accuracy
93 | :param batch_size: number of batches to average the scores over
94 | """
95 | for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
96 | print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")
97 |
98 |
99 | # Transform data into Tensor that has a range from 0 to 1
100 | train_set = FashionDataset(train_csv, transform=transforms.Compose([transforms.ToTensor()]))
101 | test_set = FashionDataset(test_csv, transform=transforms.Compose([transforms.ToTensor()]))
102 | valid_set = FashionDataset(valid_csv, transform=transforms.Compose([transforms.ToTensor()]))
103 |
104 | train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
105 | test_loader = DataLoader(test_set, batch_size=100, shuffle=True)
106 | valid_loader = DataLoader(valid_set, batch_size=100, shuffle=True)
107 |
108 |
109 | # model
110 | model = MnistResNet().to(device)
111 |
112 | # params
113 | epochs = 5
114 | batch_size = 100
115 |
116 |
117 | # loss function and optimizer
118 | loss_function = nn.CrossEntropyLoss()
119 |
120 | # optimizer
121 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
122 |
123 | start_ts = time.time()
124 |
125 | losses = []
126 | batches = len(train_loader)
127 | val_batches = len(valid_loader)
128 | test_batches = len(test_loader)
129 |
130 | # loop for every epoch (training + evaluation)
131 | for epoch in range(epochs):
132 | total_loss = 0
133 |
134 | # progress bar (works in Jupyter notebook too!)
135 | progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
136 |
137 | # ----------------- TRAINING --------------------
138 | # set model to training
139 | model.train()
140 |
141 | for i, data in progress:
142 | X, y = data[0].to(device), data[1].to(device)
143 |
144 | # training step for single batch
145 | model.zero_grad()
146 | outputs = model(X)
147 | loss = loss_function(outputs, y)
148 | loss.backward()
149 | optimizer.step()
150 |
151 | # getting training quality data
152 | current_loss = loss.item()
153 | total_loss += current_loss
154 |
155 | # updating progress bar
156 | progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
157 |
158 | # releasing unnecessary memory in GPU
159 | if torch.cuda.is_available():
160 | torch.cuda.empty_cache()
161 |
162 | # ----------------- VALIDATION -----------------
163 | val_losses = 0
164 | precision, recall, f1, accuracy = [], [], [], []
165 |
166 | # set model to evaluating (testing)
167 | model.eval()
168 | with torch.no_grad():
169 | for i, data in enumerate(valid_loader):
170 | X, y = data[0].to(device), data[1].to(device)
171 |
172 | outputs = model(X) # this gets the prediction from the network
173 |
174 | val_losses += loss_function(outputs, y)
175 |
176 | predicted_classes = torch.max(outputs, 1)[1] # get class from network's prediction
177 |
178 | # calculate P/R/F1/A metrics for batch
179 | for acc, metric in zip((precision, recall, f1, accuracy),
180 | (precision_score, recall_score, f1_score, accuracy_score)):
181 | acc.append(
182 | calculate_metric(metric, y.cpu(), predicted_classes.cpu())
183 | )
184 | test_losses = 0
185 | precision2, recall2, f12, accuracy2 = [], [], [], []
186 | for i, data in enumerate(test_loader):
187 | X, y = data[0].to(device), data[1].to(device)
188 |
189 | outputs = model(X) # this gets the prediction from the network
190 |
191 | test_losses += loss_function(outputs, y)
192 |
193 | predicted_classes = torch.max(outputs, 1)[1] # get class from network's prediction
194 |
195 | # calculate P/R/F1/A metrics for batch
196 | for acc, metric in zip((precision2, recall2, f12, accuracy2),
197 | (precision_score, recall_score, f1_score, accuracy_score)):
198 | acc.append(
199 | calculate_metric(metric, y.cpu(), predicted_classes.cpu())
200 | )
201 |
202 | print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}, test loss: {test_losses/test_batches}")
203 | print_scores(precision, recall, f1, accuracy, val_batches)
204 | print_scores(precision2, recall2, f12, accuracy2, test_batches)
205 | losses.append(total_loss/batches) # for plotting learning curve
206 | print(f"Training time: {time.time()-start_ts}s")
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
--------------------------------------------------------------------------------
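A minimal invocation sketch for the training script above; the CSV file names are hypothetical, and each file is expected to follow the `label,pixel1,...,pixel784` format that `FashionDataset` parses:

```bash
# Arguments: training CSV, test CSV, validation CSV (hypothetical file names)
python experiments/fashion_mnist/train.py fashion_train.csv fashion_test.csv fashion_valid.csv
```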
/experiments/speller/calculate_scores.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | # Preprocess hypotheses and split into words
4 | with open(sys.argv[1], 'r') as fp:
5 | hypotheses = [each.strip().split(" ") for each in fp.readlines()]
6 |
7 | # Preprocess targets and split into words
8 | with open(sys.argv[2], 'r') as fp:
9 | targets = [each.strip().split(" ") for each in fp.readlines()]
10 |
11 | # Preprocess sources and split into words
12 | with open(sys.argv[3], 'r') as fp:
13 | sources = [each.strip().split(" ") for each in fp.readlines()]
14 |
15 | acc_numerator = 0
16 | acc_denominator = 0
17 | corr_numerator = 0
18 | corr_denominator = 0
19 | len_uneven = 0
20 | for i in range(len(hypotheses)):
21 | hypothesis = hypotheses[i]
22 | target = targets[i]
23 | source = sources[i]
24 |
25 | # Skip mismatches
26 | if len(hypothesis) != len(target):
27 | len_uneven += 1
28 | continue
29 |
30 | # Compute word-level correction rate
31 | for x in range(len(target)):
32 | if len(target) == len(source) and target[x] != source[x]:
33 | corr_denominator += 1
34 | if source[x] != hypothesis[x]:
35 | corr_numerator += 1
36 |
37 | # Compute word-level accuracy
38 | for x in range(len(target)):
39 | if target[x] == hypothesis[x]:
40 | acc_numerator += 1
41 | acc_denominator += 1
42 |
43 | print("Mismatch: " + str(len_uneven))
44 | print("Word-Level Accuracy: " + str(acc_numerator / acc_denominator))
45 | print("Correction Rate: " + str(corr_numerator / corr_denominator))
46 |
47 |
48 |
--------------------------------------------------------------------------------
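A minimal invocation sketch for the scoring script above; the file names are hypothetical, and each file is expected to hold one whitespace-tokenized sentence per line, aligned across the three files:

```bash
# Arguments: hypotheses, targets, sources (hypothetical file names)
python experiments/speller/calculate_scores.py hypotheses.txt targets.txt sources.txt
```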
/images/process.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/images/process.gif
--------------------------------------------------------------------------------
/images/process_once.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/images/process_once.gif
--------------------------------------------------------------------------------
/images/readme_ex1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/images/readme_ex1.png
--------------------------------------------------------------------------------
/images/readme_ex2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/images/readme_ex2.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="grad-info-opt",
8 | version="0.1.2",
9 | author="Dante Everaert",
10 | author_email="dante.everaert@berkeley.edu",
11 | description="Implementation of Gradient Information Optimization for efficient and scalable training data selection",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | url="https://github.com/daeveraert/gradient-information-optimization",
15 | project_urls={
16 | "Bug Tracker": "https://github.com/daeveraert/gradient-information-optimization/issues",
17 | },
18 | classifiers=[
19 | "Programming Language :: Python :: 3",
20 | "License :: OSI Approved :: Apache Software License",
21 | "Operating System :: OS Independent",
22 | ],
23 | package_dir={"": "src"},
24 | packages=setuptools.find_packages(where="src"),
25 | python_requires=">=3.6",
26 | install_requires=[
27 | 'jax>=0.3.25',
28 | 'pyspark>=2.4.8',
29 | 'numpy>=1.21.6',
30 | 'sentence_transformers>=2.2.2',
31 | 'jaxlib>=0.3.2',
32 | 'pandas>=1.0.5']
33 | )
--------------------------------------------------------------------------------
/src/GIO/GIOKL.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import grad
3 | import random
4 | import jax
5 |
6 | import numpy as np
7 | import pyspark.sql.functions as F
8 |
9 | from pyspark.ml.clustering import KMeans
10 | from pyspark.sql import SparkSession
11 | from pyspark.sql.types import *
12 | from .GIO_super import GIO_super
13 |
14 |
15 | class GIOKL(GIO_super):
16 | def __init__(self, uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768):
17 | super().__init__()
18 | self.spark = SparkSession.builder.getOrCreate()
19 | self.uniform_low = uniform_low
20 | self.uniform_high = uniform_high
21 | self.uniform_start_size = uniform_start_size
22 | self.dim = dim
23 | self.random_init = False
24 |
25 | def _get_nearest(self, sample, point):
26 | """Euclidean distance from point to it's nearest point in sample.
27 | :param sample: a set of points to compute the nearest distance to
28 | :param point: point to retrieve the nearest point in sample from
29 | :return: the index of the nearest point
30 | """
31 | norms = jnp.linalg.norm(sample - point, axis=1)
32 | return jnp.argsort(norms)[0]
33 |
34 | def _knn(self, x, y, k, last_only, discard_nearest, avg):
35 | """Find k_neighbors-nearest neighbor distances from y for each example in a minibatch x.
36 | :param x: tensor of shape [N_1, D]
37 | :param y: tensor of shape [N_2, D]
38 | :param k: the (k_neighbors+1):th nearest neighbor
39 | :param last_only: use only the last knn vs. all of them
40 | :param discard_nearest: whether to discard the nearest neighbor (used when x and y may contain the same points)
41 | :return: knn distances of shape [N, k_neighbors] or [N, 1] if last_only
42 | """
43 |
44 | dist_x = jnp.sum((x ** 2), axis=-1)[:, jnp.newaxis]
45 | dist_y = jnp.sum((y ** 2), axis=-1)[:, jnp.newaxis].T
46 | cross = - 2 * jnp.matmul(x, y.T)
47 | distmat = dist_x + cross + dist_y
48 | distmat = jnp.clip(distmat, 1e-10, 1e+20)
49 |
50 | if discard_nearest:
51 | if not avg:
52 | knn, _ = jax.lax.top_k(-distmat, k + 1)
53 | else:
54 | knn = -jnp.sort(distmat)
55 | knn = knn[:, 1:]
56 | else:
57 | knn = -distmat
58 |
59 | if last_only:
60 | knn = knn[:, -1:]
61 |
62 | return jnp.sqrt(-knn)
63 |
64 | def _kl_divergence_knn(self, x, y, k, eps, discard_nearest_for_xy):
65 | """KL divergence estimator for D(x~p || y~q).
66 | :param x: x~p
67 | :param y: y~q
68 | :param k: kth nearest neighbor
69 | :param discard_nearest_for_xy: discard nearest in the xy calculation
70 | :param eps: small epsilon to pass to log
71 | :return: scalar
72 | """
73 | n, d = x.shape
74 | m, _ = y.shape
75 | nns_xx = self._knn(x, x, k=k, last_only=True, discard_nearest=True, avg=False)
76 | nns_xy = self._knn(x, y, k=m, last_only=False, discard_nearest=discard_nearest_for_xy, avg=discard_nearest_for_xy)
77 |
78 | divergence = jnp.mean(d*jnp.log(nns_xy + eps) - d*jnp.log(nns_xx + eps)) + jnp.mean(jnp.log((k*m)/(jnp.arange(1, m+1) * (n-1))))
79 |
80 | return divergence
81 |
82 | def calculate_statistical_distance(self, x, y, k=5, eps=1e-8, discard_nearest_for_xy=False):
83 | """Calculate statistical distance d(p,q) based on x~p and y~q.
84 | :param x: x~p
85 | :param y: y~q
86 | :param k: kth nearest neighbor
87 | :param eps: small epsilon to pass to log
88 | :return: scalar
89 | """
90 | return self._kl_divergence_knn(x, y, k, eps, discard_nearest_for_xy)
91 |
92 | def gradient_descend(self, X, W, v, scaling_factor, max_iterations, lr=0.01, k=5, discard_nearest_for_xy=False):
93 | """Perform gradient descent on the statistical distance bwteen X and W+v
94 | :param X: target data
95 | :param W: current selected data
96 | :param v: initial v
97 | :param scaling_factor: scale the gradient
98 | :param max_iterations: iterations in the gradient descent
99 | :param lr: learning rate
100 | :param k: kth nearest neighbor
101 | :param discard_nearest_for_xy: discard nearest in the xy calculation
102 | :return: the optimized vector v
103 | """
104 | i = 0
105 | while i < max_iterations:
106 | gradient = grad(lambda v: self.calculate_statistical_distance(X, jnp.concatenate((W, v[jnp.newaxis, :])), k, discard_nearest_for_xy=discard_nearest_for_xy))(v)
107 | v = v - lr * scaling_factor * gradient
108 | i += 1
109 | return v
110 |
111 | def _get_uniform_start(self, do_normalize):
112 | """Get a uniform start for D.
113 | :return: jnp array of uniform points
114 | """
115 | def normalize(v):
116 | norm = np.linalg.norm(v)
117 | if norm == 0:
118 | return v
119 | return v / norm
120 | if do_normalize:
121 | return jnp.array([normalize(each) for each in np.random.uniform(low=self.uniform_low,high=self.uniform_high,size=(self.uniform_start_size,self.dim))])
122 | else:
123 | return jnp.array([each for each in np.random.uniform(low=self.uniform_low,high=self.uniform_high,size=(self.uniform_start_size,self.dim))])
124 |
125 | def fit(self, train, X, D=None, k=5, max_iter=100, stop_criterion="increase", min_difference=0, resets_allowed=False, max_resets=2, max_data_size=1, min_kl=0, max_sequential_increases=3, random_init_pct=0, random_restart_prob=0, scale_factor="auto", v_init='mean', grad_desc_iter=50, discard_nearest_for_xy=False, normalize=True, lr=0.01):
126 | """Perform GIO
127 | :param train: training data
128 | :param X: target data
129 | :param D: initial data
130 | :param k: kth nearest neighbor
131 | :param max_iter: max iterations for the algorithm
132 | :param stop_criterion: a string for the stopping criterion, one of the following: 'increase', 'max_resets', 'min_difference', 'sequential_increase_tolerance', 'min_kl', 'data_size'
133 | :param min_difference: the minimum difference between prior and current KL divergence for 'min_difference' stop criterion
134 | :param resets_allowed: whether, if the KL divergence increases, resetting G to the full train set is allowed (allows the algorithm to pick duplicates). Must be set to True if the stop criterion is 'max_resets'
135 | :param max_resets: the number of resets allowed for the 'max_resets' stop criterion
136 | :param max_data_size: the maximum size of data for the 'data_size' stop criterion, as a percentage
137 | :param min_kl: the minimum kl divergence for the 'min_kl' stop criterion
138 | :param max_sequential_increases: the maximum number of sequential KL divergence increases for the 'sequential_increase_tolerance' stop criterion
139 | :param random_init_pct: the percent of training data to initialize the algorithm from
140 | :param random_restart_prob: probability to extend the gradient descent iterations by 3x to find potentially better extrema. Higher values come at the cost of efficiency
141 | :param scale_factor: factor to scale the gradient by or 'auto'
142 | :param v_init: how to initialize v in gradient descent, one of the following: 'mean', 'prev_opt', 'jump'
143 | :param grad_desc_iter: the number of iterations in gradient descent
144 | :param discard_nearest_for_xy: discard nearest in the xy calculation
145 | :param lr: Learning rate for gradient descent
146 | :return: selected data, kl divergences, (v, scale_factor, just_reset, num_resets, increases, adder, kl_divs)
147 | """
148 | if not random_init_pct and D is None:
149 | W = self._get_uniform_start(normalize)
150 | self.random_init = True
151 | elif D is None:
152 | amount = int(random_init_pct * len(train))
153 | W = jnp.array(random.sample(train.tolist(), amount))
154 | else:
155 | W = D[:]
156 |
157 | kl_dist_prev = self.calculate_statistical_distance(X, W, k, discard_nearest_for_xy=discard_nearest_for_xy)
158 |
159 | print("Starting KL: " + str(kl_dist_prev))
160 | if v_init == 'mean' or v_init == 'prev_opt':
161 | v = jnp.mean(X, axis=0)
162 | elif v_init == 'jump':
163 | v = jnp.array(random.sample(X.tolist(), 1)).squeeze()
164 | adder = train[:]
165 | kl_divs = []
166 |
167 | scale_factor = jnp.linalg.norm(v)/jnp.linalg.norm(grad(lambda v: self.calculate_statistical_distance(X, jnp.concatenate((W, v[jnp.newaxis, :])), k, discard_nearest_for_xy=discard_nearest_for_xy))(v)) if scale_factor == "auto" else scale_factor
168 |
169 | i = 0
170 | just_reset = False
171 | num_resets = 0
172 | total_iter = 0
173 | increases = 0
174 | while True:
175 | # Warmup, reset or random restart
176 | if i == 0 or just_reset or random.random() < random_restart_prob:
177 | v = self.gradient_descend(X, W, v, scale_factor, grad_desc_iter * 3, lr=lr, k=k, discard_nearest_for_xy=discard_nearest_for_xy)
178 | else:
179 | v = self.gradient_descend(X, W, v, scale_factor, grad_desc_iter, lr=lr, k=k, discard_nearest_for_xy=discard_nearest_for_xy)
180 | idx = self._get_nearest(v, adder)
181 | minvals = adder[idx]
182 | adder = jnp.delete(adder, idx, axis=0)
183 |
184 | W_tmp = jnp.concatenate((W, jnp.array(minvals)[jnp.newaxis, :]))
185 |
186 | kl_dist = self.calculate_statistical_distance(X, W_tmp, k, discard_nearest_for_xy=discard_nearest_for_xy)
187 | print("KL Divergence at iteration " + str(i) + ": " + str(kl_dist))
188 |
189 | # STOPPING CRITERIA
190 | if total_iter > max_iter:
191 | break
192 |
193 | if v_init == 'mean':
194 | v = jnp.mean(X, axis=0)
195 | elif v_init == 'jump':
196 | v = jnp.array(random.sample(X.tolist(), 1)).squeeze()
197 |
198 | adder, i, just_reset, stop, v, increases, num_resets = self._test_stop_criterion(v_init, stop_criterion, kl_dist, kl_dist_prev, num_resets, max_resets, min_difference, increases, max_sequential_increases, min_kl, max_data_size, train, X, i, v, just_reset, resets_allowed, adder)
199 |
200 | if stop:
201 | break
202 | if not just_reset:
203 | W = W_tmp
204 | kl_divs += [kl_dist]
205 | kl_dist_prev = kl_dist
206 | i += 1
207 | total_iter += 1
208 | return W, kl_divs, (v, scale_factor, just_reset, num_resets, increases, adder, kl_divs)
209 |
210 | def _test_stop_criterion(self, v_init, stop_criterion, kl_dist, kl_dist_prev, num_resets, max_resets, min_difference, increases, max_sequential_increases, min_kl, max_data_size, train, X, i, v, just_reset, resets_allowed, adder):
211 | stop = False
212 | if stop_criterion == "increase" and kl_dist - kl_dist_prev > 0:
213 | stop = True
214 | elif stop_criterion == "max_resets" and kl_dist - kl_dist_prev > 0 and num_resets == max_resets:
215 | stop = True
216 | elif stop_criterion == "min_difference" and kl_dist_prev - kl_dist < min_difference:
217 | stop = True
218 | elif stop_criterion == 'sequential_increase_tolerance' and kl_dist - kl_dist_prev > 0 and increases == max_sequential_increases:
219 | stop = True
220 | elif stop_criterion == 'min_kl' and kl_dist < min_kl:
221 | stop = True
222 | elif stop_criterion == 'data_size' and i > int(max_data_size * len(train)):
223 | stop = True
224 | if stop:
225 | if just_reset:
226 | increases += 1
227 | if resets_allowed and num_resets < max_resets:
228 | num_resets += 1
229 | if v_init == 'prev_opt':
230 | v = jnp.mean(X, axis=0)
231 | print("KL Div Increase, Resetting G")
232 | adder = train[:]
233 | i -= 1
234 | stop = False
235 | just_reset = True
236 | else:
237 | just_reset = False
238 | increases = 0
239 | return adder, i, just_reset, stop, v, increases, num_resets
240 |
241 | def _return_kmeans(self, df, k, rseed):
242 | """Use Spark to perform K-Means
243 | :param df: dataframe to perform K-Means with
244 | :param k: number of clusters to compute
245 | :param rseed: random seed
246 | :return: k-means model, transformed df
247 | """
248 | kmeans = KMeans().setK(k).setSeed(rseed)
249 | model = kmeans.fit(df.select("features"))
250 | transformed_df = model.transform(df)
251 | return model, transformed_df
252 |
253 | def quantize(self, df_train, df_x, k=1500, rseed='auto', rseed1=234, rseed2=456):
254 | """Use Spark to perform K-Means
255 | :param df_train: train dataframe to quantize
256 | :param df_x: target dataframe to quantize
257 | :param k: number of clusters to compute
258 | :param rseed: 'auto' or 'manual'
259 | :param rseed1: first random seed
260 | :param rseed2: second random seed
261 | :return: train k-means model, target k-means model, transformed train df, transformed target df
262 | """
263 | if rseed == 'auto':
264 | rseed1 = random.randint(-1000,1000)
265 | rseed2 = random.randint(-1000,1000)
266 | model_train, transformed_train = self._return_kmeans(df_train, k, rseed1)
267 | model_X, transformed_X = self._return_kmeans(df_x, k, rseed2)
268 | return model_train, model_X, transformed_train, transformed_X
269 |
270 | def read_data_from_csv(self, path, path_X, delim="\t"):
271 | """Read in and process data stored in a csv. Data must be of the format: _c0, _c1, _c2 where _c2 contains the
272 | string representation of the vector, like "[0.1, 0.23, 0.45 ...]"
273 | :param path: path to training data
274 | :param path_X: path to target data
275 | :param delim: delimiter for csv file
276 | :return: train df, target df
277 | """
278 | new_schema = ArrayType(DoubleType(), containsNull=False)
279 | udf_json_to_arr = F.udf(lambda x: x, new_schema)
280 |
281 | df_read = self.spark.read.option("delimiter", delim).csv(path)
282 | df_with_embeddings = df_read.withColumn("features", udf_json_to_arr(F.from_json(F.col("_c2"), "array")))
283 |
284 | df_X_read = self.spark.read.option("delimiter", delim).csv(path_X)
285 | df_X_with_embeddings = df_X_read.withColumn("features", udf_json_to_arr(F.from_json(F.col("_c2"), "array")))
286 |
287 | return df_with_embeddings, df_X_with_embeddings
288 |
289 | def read_data_from_parquet(self, path, path_X):
290 | """Read in and process data stored in a parquet format. Data must contain a column "features" that stores an array
291 | of the vectors and be non-nullable.
292 | :param path: path to training data
293 | :param path_X: path to target data
294 | :return: train df, target df
295 | """
296 | new_schema = ArrayType(DoubleType(), containsNull=False)
297 | udf_no_null = F.udf(lambda x: x, new_schema)
298 |
299 | df_with_embeddings = self.spark.read.parquet(path).withColumn("features", udf_no_null(F.col("features")))
300 | df_X_with_embeddings = self.spark.read.parquet(path_X).withColumn("features", udf_no_null(F.col("features")))
301 | return df_with_embeddings, df_X_with_embeddings
302 |
303 | def explode(self, chosen_centroids, kmeans_transformed_df, kmeans_centroids_df):
304 | """Read in and process data stored in a parquet format. Data must contain a column "features" that stores an array
305 | of the vectors.
306 | :param path: path to training data
307 | :param path_X: path to target data
308 | :return: train df, target df
309 | """
310 | pre_existing_centroids = jnp.array([f[1] for f in sorted([[each[0], each[1]] for each in kmeans_centroids_df.collect()], key=lambda x: x[0])])
311 | paired = []
312 | for each in chosen_centroids:
313 | for i, x in enumerate(pre_existing_centroids.tolist()):
314 | if each.tolist() in [x]:
315 | paired += [i]
316 | print("Found " + str(len(paired)) + " centroids out of " + str(len(chosen_centroids)) + " selected centroids")
317 | full_selections_df = self.spark.createDataFrame(data=[(i, each) for i, each in enumerate(paired)], schema=["i", "id"]).join(kmeans_transformed_df, F.col("id") == F.col("prediction"))
318 | return full_selections_df
319 |
320 |
--------------------------------------------------------------------------------
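A small sketch (not from the repository) of calling the kNN-based KL estimator in `GIOKL` above directly on toy data; it assumes a working local Spark/Java environment, since `GIOKL.__init__` creates a SparkSession:

```python
# Toy illustration of GIOKL.calculate_statistical_distance; the sample sizes,
# seed and k value here are arbitrary assumptions.
import numpy as np
import jax.numpy as jnp
from GIO import GIOKL

gio_kl = GIOKL.GIOKL(dim=2)

np.random.seed(0)
x = jnp.array(np.random.normal(loc=0.0, scale=1.0, size=(50, 2)))  # x ~ p
y = jnp.array(np.random.normal(loc=0.5, scale=1.0, size=(50, 2)))  # y ~ q

# Estimate D(p || q) from the two samples using the k-th nearest neighbor estimator
kl = gio_kl.calculate_statistical_distance(x, y, k=3)
print("Estimated KL divergence:", float(kl))
```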
/src/GIO/GIO_super.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import grad
3 | import random
4 | import jax
5 |
6 | import numpy as np
7 | import pyspark.sql.functions as F
8 |
9 | from pyspark.ml.clustering import KMeans
10 | from pyspark.sql import SparkSession
11 | from pyspark.sql.types import *
12 |
13 |
14 | class GIO_super:
15 | def __init__(self):
16 | pass
17 |
18 | def calculate_statistical_distance(self, x, y):
19 | pass
20 |
21 | def gradient_descend(self, X, W, v, factor, max_iterations, lr, *arg):
22 | pass
23 |
24 | def fit(self, train, X, *arg):
25 | pass
26 |
27 | def quantize(self, df_train, df_x, quantize_into):
28 | pass
29 |
30 | def _get_nearest(self, sample, point):
31 | pass
32 |
33 | def explode(self, chosen_centroids, kmeans_transformed_df, kmeans_centroids_df):
34 | pass
35 |
--------------------------------------------------------------------------------
/src/GIO/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daeveraert/gradient-information-optimization/2424d7a50a2aa8575b181f757df25826cdf24dc7/src/GIO/__init__.py
--------------------------------------------------------------------------------
/src/GIO/generate_text_embeddings.py:
--------------------------------------------------------------------------------
1 | from sentence_transformers import SentenceTransformer
2 | import time
3 |
4 |
5 |
6 | class GenerateEmbeddings:
7 | def __init__(self, model_name, device='cuda'):
8 | self.model = SentenceTransformer(model_name)
9 | self.device = device
10 |
11 | def generate_embeddings(self, input_file_path, output_file_path):
12 | """Generate Embeddings from a text file
13 | :param input_file_path: path to input text, one sentence per line
14 | :param output_file_path: path to desired output file
15 | """
16 | print('Reading File...')
17 | with open(input_file_path, 'r') as fp:
18 | sentences = fp.readlines()
19 | print('Generating Embeddings... This May Take a While')
20 | start = time.time()
21 | embeddings = self.model.encode(sentences, device=self.device)
22 | end = time.time()
23 |
24 | print("Time Taken (s): " + str(end - start))
25 |
26 | print("Writing Embeddings.. This May Take a While")
27 | with open(output_file_path, 'w') as op:
28 | for i, each in enumerate(embeddings):
29 | op.write(str(each.tolist()).strip() + "\n")
30 |
--------------------------------------------------------------------------------
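A hedged usage sketch for `GenerateEmbeddings` above; the model name, device and file paths are assumptions, and any sentence-transformers model name should work:

```python
# Hypothetical model name and file paths; input is one sentence per line,
# output is one embedding (written as a Python-style list) per line.
from GIO.generate_text_embeddings import GenerateEmbeddings

gen = GenerateEmbeddings("all-MiniLM-L6-v2", device="cpu")
gen.generate_embeddings("sentences.txt", "embeddings.txt")
```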
/src/grad_info_opt.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: grad-info-opt
3 | Version: 0.1.2
4 | Summary: Implementation of Gradient Information Optimization for efficient and scalable training data selection
5 | Home-page: https://github.com/daeveraert/gradient-information-optimization
6 | Author: Dante Everaert
7 | Author-email: dante.everaert@berkeley.edu
8 | License: UNKNOWN
9 | Project-URL: Bug Tracker, https://github.com/daeveraert/gradient-information-optimization/issues
10 | Platform: UNKNOWN
11 | Classifier: Programming Language :: Python :: 3
12 | Classifier: License :: OSI Approved :: Apache Software License
13 | Classifier: Operating System :: OS Independent
14 | Requires-Python: >=3.6
15 | Description-Content-Type: text/markdown
16 | License-File: LICENSE
17 |
18 | # GIO: Gradient Information Optimization
19 |
20 |
21 |
22 |
23 | GIO is a library that implements Gradient Information Optimization (GIO) at scale, from the paper GIO: Gradient Information Optimization for Training Dataset Selection. GIO is a data selection technique that can
24 | be used to select a subset of training data that gives similar or superior performance to a model trained on full data.
25 |
26 | **Paper Abstract**
27 |
28 | It is often advantageous to train models on a subset of the available train examples, because the examples are of variable quality or because one would like to train with fewer examples, without sacrificing performance. We present Gradient Information Optimization (GIO), a scalable, task-agnostic approach to this data selection problem that requires only a small set of (unlabeled) examples representing a target distribution. GIO begins from a natural, information-theoretic objective that is intractable in practice. Our contribution is in showing that it can be made highly scalable through a simple relaxation of the objective and a highly efficient implementation. In experiments with machine translation, spelling correction, and image recognition, we show that GIO delivers outstanding results with very small train sets. These findings are robust to different representation models and hyperparameters for GIO itself. GIO is task- and domain-agnostic and can be applied out-of-the-box to new datasets and domains.
29 |
30 |
31 | **Features**:
32 | - GIO with quantization using K-means.
33 | - Sentence embedding script to generate embeddings from data to use in GIO
34 |
35 |
36 | ## Installation
37 |
38 | Installable via pip:
39 | ```bash
40 | pip install grad-info-opt
41 | ```
42 | Or install directly from the repository:
43 |
44 | ```bash
45 | git clone git@github.com:daeveraert/gradient-information-optimization.git
46 | cd gradient-information-optimization
47 | pip install -e .
48 | ```
49 |
50 | Direct installation will require you to install the additional dependencies listed below. We welcome contributions to GIO.
51 |
52 | ## Requirements
53 | - `numpy>=1.21.6`
54 | - `jax>=0.3.25`
55 | - `pyspark>=2.4.8`
56 | - `sentence_transformers>=2.2.2`
57 | - `jaxlib>=0.3.2`
58 | - `pandas>=1.0.5`
59 |
60 |
61 |
62 | ## Quick Start
63 | **Note:** GIO uses a Spark context; if it can't find one, it will create a local one. Before the algorithm runs, you may encounter a Spark error complaining that it cannot find a free port. In this case, executing ```export SPARK_LOCAL_IP="127.0.0.1"``` should resolve the issue.
64 |
65 | Here is a simple 2D demonstration of how to use GIO with visualization:
66 | ```python
67 | from GIO import GIOKL
68 | import numpy as np
69 | import jax.numpy as jnp
70 | import matplotlib.pyplot as plt
71 |
72 | # Create some data
73 | def getX():
74 | mean = [3,4]
75 | cov = [[0.5,0],[0,0.5]]
76 | np.random.seed(1)
77 | x, y = np.random.multivariate_normal(mean, cov, 100).T
78 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
79 |
80 | def getXTest():
81 | mean = [3,4]
82 | cov = [[0.5,0],[0,0.5]]
83 | np.random.seed(5)
84 | x, y = np.random.multivariate_normal(mean, cov, 100).T
85 | return jnp.array([[x[i],y[i]] for i in range(len(x))])
86 |
87 | X = getX()
88 | X_test = getXTest()
89 |
90 | # Initialize class
91 | gio_kl = GIOKL.GIOKL(uniform_low=0, uniform_high=8, uniform_start_size=100, dim=2)
92 |
93 | # Perform the Algorithm
94 | W, kl_divs, _ = gio_kl.fit(X_test, X, normalize=False)
95 | W = W[100:] # Remove the uniform start
96 |
97 | # Plot results
98 | plt.plot(kl_divs)
99 | plt.title("KL Divergence vs. Iterations")
100 | plt.xlabel("Iterations")
101 | plt.ylabel("KL Divergence")
102 | plt.show()
103 | plt.clf()
104 | plt.scatter([each[0] for each in W], [each[1] for each in W], label='Selected Data')
105 | plt.scatter([each[0] for each in X], [each[1] for each in X], label='Target Data')
106 | plt.title("Target Data and Selected Data")
107 | plt.xlabel("Dimension 1")
108 | plt.ylabel("Dimension 2")
109 | plt.legend()
110 | plt.show()
111 | ```
112 |
113 |
114 |
115 |
116 |
117 | Here is a more complex example for at-scale applications, which reads a CSV that stores embeddings and data and uses quantization-explosion with Spark:
118 | ```python
119 | from GIO import GIOKL
120 | import jax.numpy as jnp
121 | import matplotlib.pyplot as plt
122 | import pyspark.sql.functions as F
123 |
124 | # Initialize class
125 | gio_kl = GIOKL.GIOKL(uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768)
126 |
127 | # Read data
128 | train_df, target_df = gio_kl.read_data_from_csv(PATH_TO_TRAIN, PATH_TO_TARGET)
129 |
130 | # Quantize data
131 | model_train, model_X, transformed_train, transformed_X = gio_kl.quantize(train_df, target_df)
132 |
133 | X = jnp.array(model_X.clusterCenters())
134 | train = jnp.array(model_train.clusterCenters())
135 | centroids_df = gio_kl.spark.createDataFrame(data=[(i, each.tolist()) for i, each in enumerate(model_train.clusterCenters())], schema=["id", "centroid"])
136 |
137 | # Perform the Algorithm
138 | W, kl_divs, _ = gio_kl.fit(train, X, max_iter=300, stop_criterion='sequential_increase_tolerance', v_init='jump')
139 | W = W[20:] # Remove the uniform start
140 |
141 | # Explode back to original data and write resulting data
142 | full_selections_df = gio_kl.explode(W, transformed_train, centroids_df)
143 | full_selections_df.select(F.col("_c0"), F.col("_c1")).write.option("delimiter", "\t").csv(OUTPUT_PATH)
144 |
145 |
146 | # Plot results
147 | plt.plot(kl_divs)
148 | plt.title("KL Divergence vs. Iterations")
149 | plt.xlabel("Iterations")
150 | plt.ylabel("KL Divergence")
151 | plt.show()
152 | ```
153 | **Note:** For quantization, Spark requires a large rpc message size. It is recommended to place ```gio_kl.spark.conf.set("spark.rpc.message.maxSize", "500")``` (or any large number) in the code before calling quantize, if the defaults haven't already been increased.
154 |
155 | ## Available Options
156 | `GIOKL.fit` takes the following arguments:
157 | - `train`: training data as a jnp array of shape [M, D] (jnp is almost identical to numpy)
158 | - `X`: target data as a jnp array of shape [N, D]
159 | - `D`: initial data as a jnp array, default None. Use None to initialize from 0 (uniform) or a subset of training data
160 | - `k`: kth nearest neighbor to use in the KL divergence estimation, default 5
161 | - `max_iter`: maximum iterations for the algorithm. One iteration adds one point (cluster). Default is 100
162 | - `stop_criterion`: a string for the stopping criterion, one of the following: 'increase', 'max_resets', 'min_difference', 'sequential_increase_tolerance', 'min_kl', 'data_size'. Default is 'increase'
163 | - `min_difference`: the minimum difference between prior and current KL divergence for 'min_difference' stop criterion only. Default is 0
164 | - `resets_allowed`: whether, if the KL divergence increases, resetting G to the full train set is allowed (allows the algorithm to pick duplicates). Must be set to True if the stop criterion is 'max_resets'. Default is False
165 | - `max_resets`: the number of resets allowed for the 'max_resets' stop criterion only (a reset resets G to the full train set and allows the algorithm to pick duplicates). Default is 2
166 | - `max_data_size`: the maximum size of data to be selected for the 'data_size' stop criterion only, as a percentage (of total data) between 0 and 1. Default is 1
167 | - `min_kl`: the minimum kl divergence for the 'min_kl' stop criterion only. Default is 0
168 | - `max_sequential_increases`: the maximum number of sequential KL divergence increases for the 'sequential_increase_tolerance' stop criterion only. Default is 3
169 | - `random_init_pct`: the percent of training data to initialize the algorithm from. Default is 0
170 | - `random_restart_prob`: probability at any given iteration to extend the gradient descent iterations by 3x, to find potentially better extrema. Higher values come at the cost of efficiency. Default is 0
171 | - `scale_factor`: factor to scale the gradient by, or 'auto'. Default is 'auto', which is recommended
172 | - `v_init`: how to initialize v in gradient descent, one of the following: 'mean', 'prev_opt', 'jump'. Default is 'mean'
173 | - `grad_desc_iter`: the number of iterations to use in gradient descent. Default is 50
174 | - `discard_nearest_for_xy`: discard the nearest neighbor in the xy calculation of the KL divergence, for use when X and the train set are the same; comes at the cost of efficiency. Default is False
175 | - `lr`: Learning rate for gradient descent. Default is 0.01
176 |
177 | ## Citing GIO
178 | If you use GIO in a publication, blog or software project, please cite the paper:
179 | ```
180 | @misc{everaert2023gio,
181 | title={GIO: Gradient Information Optimization for Training Dataset Selection},
182 | author={Dante Everaert and Christopher Potts},
183 | year={2023},
184 | eprint={2306.11670},
185 | archivePrefix={arXiv},
186 | primaryClass={cs.LG}
187 | }
188 | ```
189 |
190 |
191 |
--------------------------------------------------------------------------------
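A hedged sketch combining several of the Available Options documented in the README above; `train` and `X` are assumed to be jnp arrays prepared as in the Quick Start examples (for instance, quantized cluster centers), and the specific option values are illustrative only:

```python
# Illustrative fit call; train ([M, D]) and X ([N, D]) are assumed to exist already.
from GIO import GIOKL

gio_kl = GIOKL.GIOKL(uniform_low=-1, uniform_high=1, uniform_start_size=20, dim=768)

W, kl_divs, _ = gio_kl.fit(
    train, X,
    max_iter=500,                  # hard cap on iterations
    stop_criterion="data_size",    # stop once a fraction of train has been selected
    max_data_size=0.3,             # ... here, 30% of the (quantized) train points
    v_init="jump",                 # re-initialize v from a random target point each iteration
    grad_desc_iter=50,
    lr=0.01,
)
W = W[20:]  # remove the uniform start (uniform_start_size=20)
```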
/src/grad_info_opt.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | setup.py
4 | src/GIO/GIOKL.py
5 | src/GIO/GIO_super.py
6 | src/GIO/__init__.py
7 | src/GIO/generate_text_embeddings.py
8 | src/grad_info_opt.egg-info/PKG-INFO
9 | src/grad_info_opt.egg-info/SOURCES.txt
10 | src/grad_info_opt.egg-info/dependency_links.txt
11 | src/grad_info_opt.egg-info/requires.txt
12 | src/grad_info_opt.egg-info/top_level.txt
--------------------------------------------------------------------------------
/src/grad_info_opt.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/grad_info_opt.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | jax>=0.3.25
2 | pyspark>=2.4.8
3 | numpy>=1.21.6
4 | sentence_transformers>=2.2.2
5 | jaxlib>=0.3.2
6 | pandas>=1.0.5
7 |
--------------------------------------------------------------------------------
/src/grad_info_opt.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | GIO
2 |
--------------------------------------------------------------------------------