├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── benchmarks.ipynb
├── poetry.lock
├── pyproject.toml
├── pytest.ini
└── pytorch_resample
├── __init__.py
├── hybrid.py
├── over.py
├── under.py
└── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb filter=nbstripout
2 | *.ipynb diff=ipynb
3 | *.ipynb linguist-detectable=false
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | .ipynb_checkpoints
3 | *.pyc
4 | .vscode
5 | dist/
6 | .DS_Store
7 | *.csv
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Max Halford
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Iterable dataset resampling in PyTorch
2 |
3 | - [Motivation](#motivation)
4 | - [Installation](#installation)
5 | - [Usage](#usage)
6 | - [Under-sampling](#under-sampling)
7 | - [Over-sampling](#over-sampling)
8 | - [Hybrid method](#hybrid-method)
9 | - [Expected number of samples](#expected-number-of-samples)
10 | - [Performance tip](#performance-tip)
11 | - [Benchmarks](#benchmarks)
12 | - [How does it work?](#how-does-it-work)
13 | - [Development](#development)
14 | - [License](#license)
15 |
16 | ## Motivation
17 |
18 | [Imbalanced learning](https://www.jeremyjordan.me/imbalanced-data/) is a machine learning paradigm whereby a classifier has to learn from a dataset that has a skewed class distribution. An imbalanced dataset may have a detrimental impact on the classifier's performance.
19 |
20 | Rebalancing a dataset is one way to deal with class imbalance. This can be done by:
21 |
22 | 1. under-sampling common classes.
23 | 2. over-sampling rare classes.
24 | 3. doing a mix of both.
25 |
26 | PyTorch provides [some utilities](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler) for rebalancing a dataset, but they are limited to batch datasets of known length (i.e., datasets that implement a `__len__` method). Community contributions such as [ufoym/imbalanced-dataset-sampler](https://github.com/ufoym/imbalanced-dataset-sampler) are cute, but they too only work with batch datasets (also called *map-style* datasets in PyTorch jargon). There's also an open [GitHub issue](https://github.com/pytorch/pytorch/issues/28743) on the PyTorch repository, but it doesn't seem very active.
27 |
28 | This repository implements data resamplers that wrap an [`IterableDataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset), which was added to PyTorch in [this pull request](https://github.com/pytorch/pytorch/pull/19228). Each data resampler also inherits from `IterableDataset`, so it can be used as a drop-in replacement for the dataset it wraps. In particular, the provided methods do not require you to know the size of your dataset in advance, and each method works for both binary and multi-class classification.
29 |
30 | ☝️ If you're looking to sample your data completely at random, without taking the class distribution into account, then we recommend that you do it yourself in your `IterableDataset` implementation: simply draw a random number between 0 and 1 for each sample, and keep the sample if the number falls under a given threshold. This library is meant to be used when you want resampling to balance your class distribution.
31 |
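Here is a minimal sketch of such a class-agnostic filter; the `RandomSubsample` name and implementation are purely illustrative and not part of this library:

```py
>>> import random
>>> import torch

>>> class RandomSubsample(torch.utils.data.IterableDataset):
...
...     def __init__(self, dataset, p, seed=None):
...         self.dataset = dataset
...         self.p = p  # fraction of samples to keep
...         self.seed = seed
...
...     def __iter__(self):
...         rng = random.Random(self.seed)
...         for sample in self.dataset:
...             if rng.random() < self.p:  # keep roughly a fraction p of the stream
...                 yield sample

```
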
32 | ## Installation
33 |
34 | ```sh
35 | $ pip install pytorch_resample
36 | ```
37 |
38 | ## Usage
39 |
40 | As a running example, we'll define an `IterableDataset` that iterates over the output of scikit-learn's [`make_classification`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html) function.
41 |
42 | ```py
43 | >>> from sklearn import datasets
44 | >>> import torch
45 |
46 | >>> class MakeClassificationStream(torch.utils.data.IterableDataset):
47 | ...
48 | ... def __init__(self, *args, **kwargs):
49 | ... self.X, self.y = datasets.make_classification(*args, **kwargs)
50 | ...
51 | ... def __iter__(self):
52 | ... yield from iter(zip(self.X, self.y))
53 |
54 | ```
55 |
56 | The above dataset can be provided to a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) in order to iterate over [`Tensor`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor) batches. For the sake of example, we'll generate 10,000 samples, with 50% 0s, 40% 1s, and 10% 2s. We can use a [`collections.Counter`](https://docs.python.org/3/library/collections.html#collections.Counter) to measure the effective class distribution.
57 |
58 | ```py
59 | >>> import collections
60 |
61 | >>> dataset = MakeClassificationStream(
62 | ... n_samples=10_000,
63 | ... n_classes=3,
64 | ... n_informative=6,
65 | ... weights=[.5, .4, .1],
66 | ... random_state=42
67 | ... )
68 |
69 | >>> y_dist = collections.Counter()
70 |
71 | >>> batches = torch.utils.data.DataLoader(dataset, batch_size=16)
72 | >>> for xb, yb in batches:
73 | ... y_dist.update(yb.numpy())
74 |
75 | >>> for label in sorted(y_dist):
76 | ... frequency = y_dist[label] / sum(y_dist.values())
77 | ... print(f'• {label}: {frequency:.2%} ({y_dist[label]})')
78 | • 0: 49.95% (4995)
79 | • 1: 39.88% (3988)
80 | • 2: 10.17% (1017)
81 |
82 | ```
83 |
84 | ### Under-sampling
85 |
86 | The data stream can be under-sampled with the `pytorch_resample.UnderSampler` class. The latter is a wrapper that has to be provided with an `IterableDataset` and a desired class distribution. It inherits from `IterableDataset`, and may thus be used instead of the wrapped dataset. As an example, let's make it so that the classes are equally represented.
87 |
88 | ```py
89 | >>> import pytorch_resample
90 | >>> import torch
91 |
92 | >>> sample = pytorch_resample.UnderSampler(
93 | ... dataset=dataset,
94 | ... desired_dist={0: .33, 1: .33, 2: .33},
95 | ... seed=42
96 | ... )
97 |
98 | >>> isinstance(sample, torch.utils.data.IterableDataset)
99 | True
100 |
101 | >>> y_dist = collections.Counter()
102 |
103 | >>> batches = torch.utils.data.DataLoader(sample, batch_size=16)
104 | >>> for xb, yb in batches:
105 | ... y_dist.update(yb.numpy())
106 |
107 | >>> for label in sorted(y_dist):
108 | ... frequency = y_dist[label] / sum(y_dist.values())
109 | ... print(f'• {label}: {frequency:.2%} ({y_dist[label]})')
110 | • 0: 33.30% (1007)
111 | • 1: 33.10% (1001)
112 | • 2: 33.60% (1016)
113 |
114 | ```
115 |
116 | As shown, the observed class distribution is close to the specified one: there are indeed fewer 0s and 1s than before. Note that the values of the `desired_dist` parameter are not required to sum up to 1, as the distribution is normalized automatically.
117 |
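For instance, since the values are normalized, the following specification is strictly equivalent to the one above:

```py
>>> sample = pytorch_resample.UnderSampler(
...     dataset=dataset,
...     desired_dist={0: 1, 1: 1, 2: 1},  # normalized to {0: 1/3, 1: 1/3, 2: 1/3}
...     seed=42
... )

```
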
118 | ### Over-sampling
119 |
120 | You may use `pytorch_resample.OverSampler` to instead oversample the data. It has the same signature as `pytorch_resample.UnderSampler`, and can thus be used in the exact same manner.
121 |
122 | ```py
123 | >>> sample = pytorch_resample.OverSampler(
124 | ... dataset=dataset,
125 | ... desired_dist={0: .33, 1: .33, 2: .33},
126 | ... seed=42
127 | ... )
128 |
129 | >>> isinstance(sample, torch.utils.data.IterableDataset)
130 | True
131 |
132 | >>> y_dist = collections.Counter()
133 |
134 | >>> batches = torch.utils.data.DataLoader(sample, batch_size=16)
135 | >>> for xb, yb in batches:
136 | ... y_dist.update(yb.numpy())
137 |
138 | >>> for label in sorted(y_dist):
139 | ... frequency = y_dist[label] / sum(y_dist.values())
140 | ... print(f'• {label}: {frequency:.2%} ({y_dist[label]})')
141 | • 0: 33.21% (4995)
142 | • 1: 33.01% (4965)
143 | • 2: 33.78% (5080)
144 |
145 | ```
146 |
147 | In this case, the 1s and 2s have been oversampled.
148 |
149 | ### Hybrid method
150 |
151 | The `pytorch_resample.HybridSampler` class can be used to compromise between under-sampling and over-sampling. It accepts an extra parameter called `sampling_rate`, which determines the fraction of the data to use. This makes it possible to control how much data is used for training, whilst ensuring that the class distribution follows the desired distribution.
152 |
153 | ```py
154 | >>> sample = pytorch_resample.HybridSampler(
155 | ... dataset=dataset,
156 | ... desired_dist={0: .33, 1: .33, 2: .33},
157 | ... sampling_rate=.5, # use 50% of the dataset
158 | ... seed=42
159 | ... )
160 |
161 | >>> isinstance(sample, torch.utils.data.IterableDataset)
162 | True
163 |
164 | >>> y_dist = collections.Counter()
165 |
166 | >>> batches = torch.utils.data.DataLoader(sample, batch_size=16)
167 | >>> for xb, yb in batches:
168 | ... y_dist.update(yb.numpy())
169 |
170 | >>> for label in sorted(y_dist):
171 | ... frequency = y_dist[label] / sum(y_dist.values())
172 | ... print(f'• {label}: {frequency:.2%} ({y_dist[label]})')
173 | • 0: 33.01% (1672)
174 | • 1: 32.91% (1667)
175 | • 2: 34.08% (1726)
176 |
177 | ```
178 |
179 | As can be observed, the number of streamed samples is close to 5,000, which is half the size of the dataset.
180 |
181 | ### Expected number of samples
182 |
183 | It's possible to determine in advance the expected number of samples each resampler will stream back, provided the class distribution of the data is known.
184 |
185 | ```py
186 | >>> n = 10_000
187 | >>> desired = {'cat': 1 / 3, 'mouse': 1 / 3, 'dog': 1 / 3}
188 | >>> actual = {'cat': .5, 'mouse': .4, 'dog': .1}
189 |
190 | >>> pytorch_resample.UnderSampler.expected_size(n, desired, actual)
191 | 3000
192 |
193 | >>> pytorch_resample.OverSampler.expected_size(n, desired, actual)
194 | 15000
195 |
196 | >>> pytorch_resample.HybridSampler.expected_size(n, .5)
197 | 5000
198 |
199 | ```
200 |
201 | ### Performance tip
202 |
203 | By design `OverSampler` and `HybridSampler` yield repeated samples one after the other. This might not be ideal, as it is usually desirable to diversify the samples within each batch. We therefore recommend that you use a [shuffling buffer](https://www.moderndescartes.com/essays/shuffle_viz/), such as the `ShuffleDataset` class proposed [here](https://discuss.pytorch.org/t/how-to-shuffle-an-iterable-dataset/64130/6).
204 |
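Here is a minimal sketch of such a shuffling buffer, in case you prefer not to depend on external code. The `ShuffleBuffer` name and implementation are illustrative, not part of this library:

```py
>>> import random
>>> import torch

>>> class ShuffleBuffer(torch.utils.data.IterableDataset):
...
...     def __init__(self, dataset, buffer_size, seed=None):
...         self.dataset = dataset
...         self.buffer_size = buffer_size
...         self.seed = seed
...
...     def __iter__(self):
...         rng = random.Random(self.seed)
...         buffer = []
...         for sample in self.dataset:
...             if len(buffer) < self.buffer_size:
...                 buffer.append(sample)  # fill the buffer first
...                 continue
...             i = rng.randrange(self.buffer_size)
...             buffer[i], sample = sample, buffer[i]  # swap out a random buffered sample
...             yield sample
...         rng.shuffle(buffer)  # flush whatever remains
...         yield from buffer

```

The larger the buffer, the better the shuffling, at the cost of memory usage and a delay before the first samples are yielded.
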
205 | ## Benchmarks
206 |
207 | I've written a [simple benchmark](benchmarks.ipynb) to verify that resampling brings a performance boost and can reduce computation time. It works, but take it with a grain of salt, as it is far from exhaustive. Feel free to contribute more sophisticated benchmarks.
208 |
209 | ## How does it work?
210 |
211 | As far as I know, the methods implemented in this package do not exist in the literature per se. I first [stumbled](https://maxhalford.github.io/blog/undersampling-ratios/) upon the under-sampling method myself, and it turned out to be equivalent to [rejection sampling](https://www.wikiwand.com/en/Rejection_sampling). I then worked out the necessary formulas for over-sampling and the hybrid method. Both of the latter are based on the idea of sampling from a Poisson distribution, which I took from the [*Online Bagging and Boosting* paper](https://ti.arc.nasa.gov/m/profile/oza/files/ozru01a.pdf) by Nikunj Oza and Stuart Russell. The novelty lies in determining the rate that satisfies the desired class distribution.
212 |
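To make this concrete, here is a worked example of the rate computations, using the 50%/40%/10% class distribution from the usage section. The formulas below mirror what `UnderSampler` and `OverSampler` compute on the fly from the running class counts:

```py
>>> f = {0: 1 / 3, 1: 1 / 3, 2: 1 / 3}  # desired distribution
>>> g = {0: 50, 1: 40, 2: 10}           # observed class counts

>>> # Under-sampling keeps a sample of class y with probability f[y] / (M * g[y]),
>>> # where the pivot class maximises f[y] / g[y] and is always kept
>>> M = max(f[y] / g[y] for y in g)
>>> [round(f[y] / (M * g[y]), 2) for y in sorted(g)]
[0.2, 0.25, 1.0]

>>> # Over-sampling repeats a sample of class y Poisson(M * f[y] / g[y]) times,
>>> # where the pivot class maximises g[y] / f[y] and is streamed as-is
>>> M = max(g[y] / f[y] for y in g)
>>> [round(M * f[y] / g[y], 2) for y in sorted(g)]
[1.0, 1.25, 5.0]

```

In both cases, the pivot is the class whose desired share can be met without rejecting or duplicating anything, and every other class is scaled relative to it. `HybridSampler` combines the two ideas by repeating each sample Poisson(`sampling_rate` × desired frequency ÷ observed frequency) times.
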
213 | ## Development
214 |
215 | ```sh
216 | $ git clone https://github.com/MaxHalford/pytorch-resample
217 | $ cd pytorch-resample
218 | $ python -m venv .env
219 | $ source .env/bin/activate
220 | $ pip install poetry
221 | $ poetry install
222 | $ poetry shell
223 | $ pytest
224 | ```
225 |
226 | ## License
227 |
228 | The MIT License (MIT). Please see the [license file](LICENSE) for more information.
229 |
--------------------------------------------------------------------------------
/benchmarks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Benchmarks"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "## Setup"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 |     "We'll use the [credit card dataset](https://datahub.io/machine-learning/creditcard), which is commonly used as an imbalanced binary classification benchmark."
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {},
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "--2020-08-12 18:20:10-- https://datahub.io/machine-learning/creditcard/r/creditcard.csv\n",
34 | "Resolving datahub.io (datahub.io)... 104.18.48.253, 104.18.49.253, 172.67.157.38\n",
35 | "Connecting to datahub.io (datahub.io)|104.18.48.253|:443... connected.\n",
36 | "HTTP request sent, awaiting response... 302 Found\n",
37 | "Location: https://pkgstore.datahub.io/machine-learning/creditcard/creditcard_csv/data/ebdc64b6837b3026238f3fcad3402337/creditcard_csv.csv [following]\n",
38 | "--2020-08-12 18:20:11-- https://pkgstore.datahub.io/machine-learning/creditcard/creditcard_csv/data/ebdc64b6837b3026238f3fcad3402337/creditcard_csv.csv\n",
39 | "Resolving pkgstore.datahub.io (pkgstore.datahub.io)... 104.18.49.253, 172.67.157.38, 104.18.48.253\n",
40 | "Connecting to pkgstore.datahub.io (pkgstore.datahub.io)|104.18.49.253|:443... connected.\n",
41 | "HTTP request sent, awaiting response... 200 OK\n",
42 | "Length: 151114991 (144M) [text/csv]\n",
43 | "Saving to: ‘creditcard.csv’\n",
44 | "\n",
45 | "creditcard.csv 100%[===================>] 144.11M 17.7MB/s in 8.6s \n",
46 | "\n",
47 | "2020-08-12 18:20:21 (16.7 MB/s) - ‘creditcard.csv’ saved [151114991/151114991]\n",
48 | "\n"
49 | ]
50 | }
51 | ],
52 | "source": [
53 | "!wget https://datahub.io/machine-learning/creditcard/r/creditcard.csv"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "Do a train/test split."
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 72,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "import pandas as pd\n",
70 | "from sklearn import preprocessing\n",
71 | "\n",
72 | "credit = pd.read_csv('creditcard.csv')\n",
73 | "credit = credit.drop(columns=['Time'])\n",
74 | "features = credit.columns.drop('Class')\n",
75 | "credit.loc[:, features] = preprocessing.scale(credit.loc[:, features])\n",
76 | "credit['Class'] = (credit['Class'] == \"'1'\").astype(int)\n",
77 | "\n",
78 | "n_test = 40_000\n",
79 | "credit[:-n_test].to_csv('train.csv', index=False)\n",
80 | "credit[-n_test:].to_csv('test.csv', index=False)"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "While we're at it, let's look at the class distribution."
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 94,
93 | "metadata": {},
94 | "outputs": [
95 | {
96 | "data": {
97 | "text/plain": [
98 | "0 0.998273\n",
99 | "1 0.001727\n",
100 | "Name: Class, dtype: float64"
101 | ]
102 | },
103 | "execution_count": 94,
104 | "metadata": {},
105 | "output_type": "execute_result"
106 | }
107 | ],
108 | "source": [
109 | "credit['Class'].value_counts(normalize=True)"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "Define a helper function to train a PyTorch model on an `IterableDataset`."
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 73,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "import torch\n",
126 | "\n",
127 | "def train(net, optimizer, criterion, train_batches):\n",
128 | "\n",
129 | " for x_batch, y_batch in train_batches:\n",
130 | "\n",
131 | " optimizer.zero_grad()\n",
132 | " y_pred = net(x_batch)\n",
133 | " loss = criterion(y_pred[:, 0], y_batch.float())\n",
134 | " loss.backward()\n",
135 | " optimizer.step()\n",
136 | " \n",
137 | " return net"
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {},
143 | "source": [
144 | "Define a helper function to score a PyTorch model on a test set."
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 84,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "import numpy as np\n",
154 | "\n",
155 | "def score(net, test_batches, metric):\n",
156 | " \n",
157 | " y_true = []\n",
158 | " y_pred = []\n",
159 | " \n",
160 | " for x_batch, y_batch in test_batches:\n",
161 | " y_true.extend(y_batch.detach().numpy())\n",
162 | " y_pred.extend(net(x_batch).detach().numpy()[:, 0])\n",
163 | " \n",
164 | " y_true = np.array(y_true)\n",
165 | " y_pred = np.array(y_pred)\n",
166 | " \n",
167 | " return metric(y_true, y_pred)"
168 | ]
169 | },
170 | {
171 | "cell_type": "markdown",
172 | "metadata": {},
173 | "source": [
174 |     "Let's also create an `IterableDataset` that reads from a CSV file. The following implementation is neither very generic nor flexible, but it will do for this notebook."
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 78,
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "import csv\n",
184 | "\n",
185 | "class IterableCSV(torch.utils.data.IterableDataset):\n",
186 | " \n",
187 | " def __init__(self, path):\n",
188 | " self.path = path\n",
189 | " \n",
190 | " def __iter__(self):\n",
191 | " \n",
192 | " with open(self.path) as file:\n",
193 | " reader = csv.reader(file)\n",
194 | " header = next(reader)\n",
195 | " for row in reader:\n",
196 | " x = [float(el) for el in row[:-1]]\n",
197 | " y = int(row[-1])\n",
198 | " yield torch.tensor(x), y"
199 | ]
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "metadata": {},
204 | "source": [
205 | "We can now define the training and test set loaders."
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 79,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "train_set = IterableCSV(path='train.csv')\n",
215 | "test_set = IterableCSV(path='test.csv')"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {},
221 | "source": [
222 | "## Vanilla"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 117,
228 | "metadata": {},
229 | "outputs": [],
230 | "source": [
231 | "def make_net(n_features):\n",
232 | " \n",
233 | " torch.manual_seed(0)\n",
234 | " \n",
235 | " return torch.nn.Sequential(\n",
236 | " torch.nn.Linear(n_features, 30),\n",
237 | " torch.nn.Linear(30, 1)\n",
238 | " )"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 118,
244 | "metadata": {},
245 | "outputs": [
246 | {
247 | "name": "stdout",
248 | "output_type": "stream",
249 | "text": [
250 | "CPU times: user 10.6 s, sys: 240 ms, total: 10.8 s\n",
251 | "Wall time: 10.8 s\n"
252 | ]
253 | }
254 | ],
255 | "source": [
256 | "%%time\n",
257 | "\n",
258 | "net = make_net(len(features))\n",
259 | "\n",
260 | "net = train(\n",
261 | " net,\n",
262 | " optimizer=torch.optim.SGD(net.parameters(), lr=1e-2),\n",
263 | " criterion=torch.nn.BCEWithLogitsLoss(),\n",
264 | " train_batches=torch.utils.data.DataLoader(train_set, batch_size=16)\n",
265 | ")"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": 119,
271 | "metadata": {},
272 | "outputs": [
273 | {
274 | "data": {
275 | "text/plain": [
276 | "0.9557763595155662"
277 | ]
278 | },
279 | "execution_count": 119,
280 | "metadata": {},
281 | "output_type": "execute_result"
282 | }
283 | ],
284 | "source": [
285 | "from sklearn import metrics\n",
286 | "\n",
287 | "score(\n",
288 | " net,\n",
289 | " test_batches=torch.utils.data.DataLoader(test_set, batch_size=16),\n",
290 | " metric=metrics.roc_auc_score\n",
291 | ")"
292 | ]
293 | },
294 | {
295 | "cell_type": "markdown",
296 | "metadata": {},
297 | "source": [
298 | "## Under-sampling"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 120,
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "import pytorch_resample\n",
308 | "\n",
309 | "train_sample = pytorch_resample.UnderSampler(\n",
310 | " train_set,\n",
311 | " desired_dist={0: .8, 1: .2},\n",
312 | " seed=42\n",
313 | ")"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 121,
319 | "metadata": {},
320 | "outputs": [
321 | {
322 | "name": "stdout",
323 | "output_type": "stream",
324 | "text": [
325 | "CPU times: user 5.98 s, sys: 54.9 ms, total: 6.04 s\n",
326 | "Wall time: 6.05 s\n"
327 | ]
328 | }
329 | ],
330 | "source": [
331 | "%%time\n",
332 | "\n",
333 | "net = make_net(len(features))\n",
334 | "\n",
335 | "net = train(\n",
336 | " net,\n",
337 | " optimizer=torch.optim.SGD(net.parameters(), lr=1e-2),\n",
338 | " criterion=torch.nn.BCEWithLogitsLoss(),\n",
339 | " train_batches=torch.utils.data.DataLoader(train_sample, batch_size=16)\n",
340 | ")"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": 122,
346 | "metadata": {},
347 | "outputs": [
348 | {
349 | "data": {
350 | "text/plain": [
351 | "0.9162552315799719"
352 | ]
353 | },
354 | "execution_count": 122,
355 | "metadata": {},
356 | "output_type": "execute_result"
357 | }
358 | ],
359 | "source": [
360 | "score(\n",
361 | " net,\n",
362 | " test_batches=torch.utils.data.DataLoader(test_set, batch_size=16),\n",
363 | " metric=metrics.roc_auc_score\n",
364 | ")"
365 | ]
366 | },
367 | {
368 | "cell_type": "markdown",
369 | "metadata": {},
370 | "source": [
371 | "## Over-sampling"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": 123,
377 | "metadata": {},
378 | "outputs": [],
379 | "source": [
380 | "train_sample = pytorch_resample.OverSampler(\n",
381 | " train_set,\n",
382 | " desired_dist={0: .8, 1: .2},\n",
383 | " seed=42\n",
384 | ")"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "execution_count": 124,
390 | "metadata": {},
391 | "outputs": [
392 | {
393 | "name": "stdout",
394 | "output_type": "stream",
395 | "text": [
396 | "CPU times: user 12.1 s, sys: 287 ms, total: 12.4 s\n",
397 | "Wall time: 12.4 s\n"
398 | ]
399 | }
400 | ],
401 | "source": [
402 | "%%time\n",
403 | "\n",
404 | "net = make_net(len(features))\n",
405 | "\n",
406 | "net = train(\n",
407 | " net,\n",
408 | " optimizer=torch.optim.SGD(net.parameters(), lr=1e-2),\n",
409 | " criterion=torch.nn.BCEWithLogitsLoss(),\n",
410 | " train_batches=torch.utils.data.DataLoader(train_sample, batch_size=16)\n",
411 | ")"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": 125,
417 | "metadata": {},
418 | "outputs": [
419 | {
420 | "data": {
421 | "text/plain": [
422 | "0.9642164101280608"
423 | ]
424 | },
425 | "execution_count": 125,
426 | "metadata": {},
427 | "output_type": "execute_result"
428 | }
429 | ],
430 | "source": [
431 | "score(\n",
432 | " net,\n",
433 | " test_batches=torch.utils.data.DataLoader(test_set, batch_size=16),\n",
434 | " metric=metrics.roc_auc_score\n",
435 | ")"
436 | ]
437 | },
438 | {
439 | "cell_type": "markdown",
440 | "metadata": {},
441 | "source": [
442 | "## Hybrid method"
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "execution_count": 126,
448 | "metadata": {},
449 | "outputs": [],
450 | "source": [
451 | "train_sample = pytorch_resample.HybridSampler(\n",
452 | " train_set,\n",
453 | " desired_dist={0: .8, 1: .2},\n",
454 | " sampling_rate=.5,\n",
455 | " seed=42\n",
456 | ")"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 127,
462 | "metadata": {},
463 | "outputs": [
464 | {
465 | "name": "stdout",
466 | "output_type": "stream",
467 | "text": [
468 | "CPU times: user 8.95 s, sys: 166 ms, total: 9.11 s\n",
469 | "Wall time: 9.14 s\n"
470 | ]
471 | }
472 | ],
473 | "source": [
474 | "%%time\n",
475 | "\n",
476 | "net = make_net(len(features))\n",
477 | "\n",
478 | "net = train(\n",
479 | " net,\n",
480 | " optimizer=torch.optim.SGD(net.parameters(), lr=1e-2),\n",
481 | " criterion=torch.nn.BCEWithLogitsLoss(),\n",
482 | " train_batches=torch.utils.data.DataLoader(train_sample, batch_size=16)\n",
483 | ")"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": 128,
489 | "metadata": {},
490 | "outputs": [
491 | {
492 | "data": {
493 | "text/plain": [
494 | "0.9687554053866155"
495 | ]
496 | },
497 | "execution_count": 128,
498 | "metadata": {},
499 | "output_type": "execute_result"
500 | }
501 | ],
502 | "source": [
503 | "score(\n",
504 | " net,\n",
505 | " test_batches=torch.utils.data.DataLoader(test_set, batch_size=16),\n",
506 | " metric=metrics.roc_auc_score\n",
507 | ")"
508 | ]
509 | }
510 | ],
511 | "metadata": {
512 | "kernelspec": {
513 | "display_name": "Python 3",
514 | "language": "python",
515 | "name": "python3"
516 | },
517 | "language_info": {
518 | "codemirror_mode": {
519 | "name": "ipython",
520 | "version": 3
521 | },
522 | "file_extension": ".py",
523 | "mimetype": "text/x-python",
524 | "name": "python",
525 | "nbconvert_exporter": "python",
526 | "pygments_lexer": "ipython3",
527 | "version": "3.7.4"
528 | }
529 | },
530 | "nbformat": 4,
531 | "nbformat_minor": 4
532 | }
533 |
--------------------------------------------------------------------------------
/poetry.lock:
--------------------------------------------------------------------------------
1 | [[package]]
2 | category = "dev"
3 | description = "Atomic file writes."
4 | name = "atomicwrites"
5 | optional = false
6 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
7 | version = "1.4.0"
8 |
9 | [[package]]
10 | category = "dev"
11 | description = "Classes Without Boilerplate"
12 | name = "attrs"
13 | optional = false
14 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
15 | version = "19.3.0"
16 |
17 | [package.extras]
18 | azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "pytest-azurepipelines"]
19 | dev = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "sphinx", "pre-commit"]
20 | docs = ["sphinx", "zope.interface"]
21 | tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"]
22 |
23 | [[package]]
24 | category = "dev"
25 | description = "Cross-platform colored terminal text."
26 | marker = "sys_platform == \"win32\""
27 | name = "colorama"
28 | optional = false
29 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
30 | version = "0.4.3"
31 |
32 | [[package]]
33 | category = "main"
34 | description = "Clean single-source support for Python 3 and 2"
35 | name = "future"
36 | optional = false
37 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
38 | version = "0.18.2"
39 |
40 | [[package]]
41 | category = "dev"
42 | description = "Read metadata from Python packages"
43 | marker = "python_version < \"3.8\""
44 | name = "importlib-metadata"
45 | optional = false
46 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
47 | version = "1.7.0"
48 |
49 | [package.dependencies]
50 | zipp = ">=0.5"
51 |
52 | [package.extras]
53 | docs = ["sphinx", "rst.linker"]
54 | testing = ["packaging", "pep517", "importlib-resources (>=1.3)"]
55 |
56 | [[package]]
57 | category = "dev"
58 | description = "Lightweight pipelining: using Python functions as pipeline jobs."
59 | name = "joblib"
60 | optional = false
61 | python-versions = ">=3.6"
62 | version = "0.16.0"
63 |
64 | [[package]]
65 | category = "dev"
66 | description = "More routines for operating on iterables, beyond itertools"
67 | name = "more-itertools"
68 | optional = false
69 | python-versions = ">=3.5"
70 | version = "8.4.0"
71 |
72 | [[package]]
73 | category = "main"
74 | description = "NumPy is the fundamental package for array computing with Python."
75 | name = "numpy"
76 | optional = false
77 | python-versions = ">=3.6"
78 | version = "1.19.1"
79 |
80 | [[package]]
81 | category = "dev"
82 | description = "plugin and hook calling mechanisms for python"
83 | name = "pluggy"
84 | optional = false
85 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
86 | version = "0.13.1"
87 |
88 | [package.dependencies]
89 | [package.dependencies.importlib-metadata]
90 | python = "<3.8"
91 | version = ">=0.12"
92 |
93 | [package.extras]
94 | dev = ["pre-commit", "tox"]
95 |
96 | [[package]]
97 | category = "dev"
98 | description = "library with cross-python path, ini-parsing, io, code, log facilities"
99 | name = "py"
100 | optional = false
101 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
102 | version = "1.9.0"
103 |
104 | [[package]]
105 | category = "dev"
106 | description = "pytest: simple powerful testing with Python"
107 | name = "pytest"
108 | optional = false
109 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
110 | version = "3.10.1"
111 |
112 | [package.dependencies]
113 | atomicwrites = ">=1.0"
114 | attrs = ">=17.4.0"
115 | colorama = "*"
116 | more-itertools = ">=4.0.0"
117 | pluggy = ">=0.7"
118 | py = ">=1.5.0"
119 | setuptools = "*"
120 | six = ">=1.10.0"
121 |
122 | [[package]]
123 | category = "dev"
124 | description = "A set of python modules for machine learning and data mining"
125 | name = "scikit-learn"
126 | optional = false
127 | python-versions = ">=3.6"
128 | version = "0.23.2"
129 |
130 | [package.dependencies]
131 | joblib = ">=0.11"
132 | numpy = ">=1.13.3"
133 | scipy = ">=0.19.1"
134 | threadpoolctl = ">=2.0.0"
135 |
136 | [package.extras]
137 | alldeps = ["numpy (>=1.13.3)", "scipy (>=0.19.1)"]
138 |
139 | [[package]]
140 | category = "dev"
141 | description = "SciPy: Scientific Library for Python"
142 | name = "scipy"
143 | optional = false
144 | python-versions = ">=3.6"
145 | version = "1.5.2"
146 |
147 | [package.dependencies]
148 | numpy = ">=1.14.5"
149 |
150 | [[package]]
151 | category = "dev"
152 | description = "Python 2 and 3 compatibility utilities"
153 | name = "six"
154 | optional = false
155 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
156 | version = "1.15.0"
157 |
158 | [[package]]
159 | category = "dev"
160 | description = "threadpoolctl"
161 | name = "threadpoolctl"
162 | optional = false
163 | python-versions = ">=3.5"
164 | version = "2.1.0"
165 |
166 | [[package]]
167 | category = "main"
168 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
169 | name = "torch"
170 | optional = false
171 | python-versions = ">=3.6.1"
172 | version = "1.6.0"
173 |
174 | [package.dependencies]
175 | future = "*"
176 | numpy = "*"
177 |
178 | [[package]]
179 | category = "dev"
180 | description = "Backport of pathlib-compatible object wrapper for zip files"
181 | marker = "python_version < \"3.8\""
182 | name = "zipp"
183 | optional = false
184 | python-versions = ">=3.6"
185 | version = "3.1.0"
186 |
187 | [package.extras]
188 | docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"]
189 | testing = ["jaraco.itertools", "func-timeout"]
190 |
191 | [metadata]
192 | content-hash = "7891adb5afde230411c466fb2348d211b354aba148dc2e411a9593e33401aaf5"
193 | lock-version = "1.0"
194 | python-versions = "^3.6.1"
195 |
196 | [metadata.files]
197 | atomicwrites = [
198 | {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"},
199 | {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"},
200 | ]
201 | attrs = [
202 | {file = "attrs-19.3.0-py2.py3-none-any.whl", hash = "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c"},
203 | {file = "attrs-19.3.0.tar.gz", hash = "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"},
204 | ]
205 | colorama = [
206 | {file = "colorama-0.4.3-py2.py3-none-any.whl", hash = "sha256:7d73d2a99753107a36ac6b455ee49046802e59d9d076ef8e47b61499fa29afff"},
207 | {file = "colorama-0.4.3.tar.gz", hash = "sha256:e96da0d330793e2cb9485e9ddfd918d456036c7149416295932478192f4436a1"},
208 | ]
209 | future = [
210 | {file = "future-0.18.2.tar.gz", hash = "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"},
211 | ]
212 | importlib-metadata = [
213 | {file = "importlib_metadata-1.7.0-py2.py3-none-any.whl", hash = "sha256:dc15b2969b4ce36305c51eebe62d418ac7791e9a157911d58bfb1f9ccd8e2070"},
214 | {file = "importlib_metadata-1.7.0.tar.gz", hash = "sha256:90bb658cdbbf6d1735b6341ce708fc7024a3e14e99ffdc5783edea9f9b077f83"},
215 | ]
216 | joblib = [
217 | {file = "joblib-0.16.0-py3-none-any.whl", hash = "sha256:d348c5d4ae31496b2aa060d6d9b787864dd204f9480baaa52d18850cb43e9f49"},
218 | {file = "joblib-0.16.0.tar.gz", hash = "sha256:8f52bf24c64b608bf0b2563e0e47d6fcf516abc8cfafe10cfd98ad66d94f92d6"},
219 | ]
220 | more-itertools = [
221 | {file = "more-itertools-8.4.0.tar.gz", hash = "sha256:68c70cc7167bdf5c7c9d8f6954a7837089c6a36bf565383919bb595efb8a17e5"},
222 | {file = "more_itertools-8.4.0-py3-none-any.whl", hash = "sha256:b78134b2063dd214000685165d81c154522c3ee0a1c0d4d113c80361c234c5a2"},
223 | ]
224 | numpy = [
225 | {file = "numpy-1.19.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b1cca51512299841bf69add3b75361779962f9cee7d9ee3bb446d5982e925b69"},
226 | {file = "numpy-1.19.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:c9591886fc9cbe5532d5df85cb8e0cc3b44ba8ce4367bd4cf1b93dc19713da72"},
227 | {file = "numpy-1.19.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:cf1347450c0b7644ea142712619533553f02ef23f92f781312f6a3553d031fc7"},
228 | {file = "numpy-1.19.1-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:ed8a311493cf5480a2ebc597d1e177231984c818a86875126cfd004241a73c3e"},
229 | {file = "numpy-1.19.1-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3673c8b2b29077f1b7b3a848794f8e11f401ba0b71c49fbd26fb40b71788b132"},
230 | {file = "numpy-1.19.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:56ef7f56470c24bb67fb43dae442e946a6ce172f97c69f8d067ff8550cf782ff"},
231 | {file = "numpy-1.19.1-cp36-cp36m-win32.whl", hash = "sha256:aaf42a04b472d12515debc621c31cf16c215e332242e7a9f56403d814c744624"},
232 | {file = "numpy-1.19.1-cp36-cp36m-win_amd64.whl", hash = "sha256:082f8d4dd69b6b688f64f509b91d482362124986d98dc7dc5f5e9f9b9c3bb983"},
233 | {file = "numpy-1.19.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e4f6d3c53911a9d103d8ec9518190e52a8b945bab021745af4939cfc7c0d4a9e"},
234 | {file = "numpy-1.19.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:5b6885c12784a27e957294b60f97e8b5b4174c7504665333c5e94fbf41ae5d6a"},
235 | {file = "numpy-1.19.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:1bc0145999e8cb8aed9d4e65dd8b139adf1919e521177f198529687dbf613065"},
236 | {file = "numpy-1.19.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:5a936fd51049541d86ccdeef2833cc89a18e4d3808fe58a8abeb802665c5af93"},
237 | {file = "numpy-1.19.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:ef71a1d4fd4858596ae80ad1ec76404ad29701f8ca7cdcebc50300178db14dfc"},
238 | {file = "numpy-1.19.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b9792b0ac0130b277536ab8944e7b754c69560dac0415dd4b2dbd16b902c8954"},
239 | {file = "numpy-1.19.1-cp37-cp37m-win32.whl", hash = "sha256:b12e639378c741add21fbffd16ba5ad25c0a1a17cf2b6fe4288feeb65144f35b"},
240 | {file = "numpy-1.19.1-cp37-cp37m-win_amd64.whl", hash = "sha256:8343bf67c72e09cfabfab55ad4a43ce3f6bf6e6ced7acf70f45ded9ebb425055"},
241 | {file = "numpy-1.19.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e45f8e981a0ab47103181773cc0a54e650b2aef8c7b6cd07405d0fa8d869444a"},
242 | {file = "numpy-1.19.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:667c07063940e934287993366ad5f56766bc009017b4a0fe91dbd07960d0aba7"},
243 | {file = "numpy-1.19.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:480fdd4dbda4dd6b638d3863da3be82873bba6d32d1fc12ea1b8486ac7b8d129"},
244 | {file = "numpy-1.19.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:935c27ae2760c21cd7354402546f6be21d3d0c806fffe967f745d5f2de5005a7"},
245 | {file = "numpy-1.19.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:309cbcfaa103fc9a33ec16d2d62569d541b79f828c382556ff072442226d1968"},
246 | {file = "numpy-1.19.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:7ed448ff4eaffeb01094959b19cbaf998ecdee9ef9932381420d514e446601cd"},
247 | {file = "numpy-1.19.1-cp38-cp38-win32.whl", hash = "sha256:de8b4a9b56255797cbddb93281ed92acbc510fb7b15df3f01bd28f46ebc4edae"},
248 | {file = "numpy-1.19.1-cp38-cp38-win_amd64.whl", hash = "sha256:92feb989b47f83ebef246adabc7ff3b9a59ac30601c3f6819f8913458610bdcc"},
249 | {file = "numpy-1.19.1-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:e1b1dc0372f530f26a03578ac75d5e51b3868b9b76cd2facba4c9ee0eb252ab1"},
250 | {file = "numpy-1.19.1.zip", hash = "sha256:b8456987b637232602ceb4d663cb34106f7eb780e247d51a260b84760fd8f491"},
251 | ]
252 | pluggy = [
253 | {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"},
254 | {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"},
255 | ]
256 | py = [
257 | {file = "py-1.9.0-py2.py3-none-any.whl", hash = "sha256:366389d1db726cd2fcfc79732e75410e5fe4d31db13692115529d34069a043c2"},
258 | {file = "py-1.9.0.tar.gz", hash = "sha256:9ca6883ce56b4e8da7e79ac18787889fa5206c79dcc67fb065376cd2fe03f342"},
259 | ]
260 | pytest = [
261 | {file = "pytest-3.10.1-py2.py3-none-any.whl", hash = "sha256:3f193df1cfe1d1609d4c583838bea3d532b18d6160fd3f55c9447fdca30848ec"},
262 | {file = "pytest-3.10.1.tar.gz", hash = "sha256:e246cf173c01169b9617fc07264b7b1316e78d7a650055235d6d897bc80d9660"},
263 | ]
264 | scikit-learn = [
265 | {file = "scikit-learn-0.23.2.tar.gz", hash = "sha256:20766f515e6cd6f954554387dfae705d93c7b544ec0e6c6a5d8e006f6f7ef480"},
266 | {file = "scikit_learn-0.23.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:98508723f44c61896a4e15894b2016762a55555fbf09365a0bb1870ecbd442de"},
267 | {file = "scikit_learn-0.23.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a64817b050efd50f9abcfd311870073e500ae11b299683a519fbb52d85e08d25"},
268 | {file = "scikit_learn-0.23.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:daf276c465c38ef736a79bd79fc80a249f746bcbcae50c40945428f7ece074f8"},
269 | {file = "scikit_learn-0.23.2-cp36-cp36m-win32.whl", hash = "sha256:cb3e76380312e1f86abd20340ab1d5b3cc46a26f6593d3c33c9ea3e4c7134028"},
270 | {file = "scikit_learn-0.23.2-cp36-cp36m-win_amd64.whl", hash = "sha256:0a127cc70990d4c15b1019680bfedc7fec6c23d14d3719fdf9b64b22d37cdeca"},
271 | {file = "scikit_learn-0.23.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2aa95c2f17d2f80534156215c87bee72b6aa314a7f8b8fe92a2d71f47280570d"},
272 | {file = "scikit_learn-0.23.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:6c28a1d00aae7c3c9568f61aafeaad813f0f01c729bee4fd9479e2132b215c1d"},
273 | {file = "scikit_learn-0.23.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:da8e7c302003dd765d92a5616678e591f347460ac7b53e53d667be7dfe6d1b10"},
274 | {file = "scikit_learn-0.23.2-cp37-cp37m-win32.whl", hash = "sha256:d9a1ce5f099f29c7c33181cc4386660e0ba891b21a60dc036bf369e3a3ee3aec"},
275 | {file = "scikit_learn-0.23.2-cp37-cp37m-win_amd64.whl", hash = "sha256:914ac2b45a058d3f1338d7736200f7f3b094857758895f8667be8a81ff443b5b"},
276 | {file = "scikit_learn-0.23.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7671bbeddd7f4f9a6968f3b5442dac5f22bf1ba06709ef888cc9132ad354a9ab"},
277 | {file = "scikit_learn-0.23.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:d0dcaa54263307075cb93d0bee3ceb02821093b1b3d25f66021987d305d01dce"},
278 | {file = "scikit_learn-0.23.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5ce7a8021c9defc2b75620571b350acc4a7d9763c25b7593621ef50f3bd019a2"},
279 | {file = "scikit_learn-0.23.2-cp38-cp38-win32.whl", hash = "sha256:0d39748e7c9669ba648acf40fb3ce96b8a07b240db6888563a7cb76e05e0d9cc"},
280 | {file = "scikit_learn-0.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea"},
281 | ]
282 | scipy = [
283 | {file = "scipy-1.5.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cca9fce15109a36a0a9f9cfc64f870f1c140cb235ddf27fe0328e6afb44dfed0"},
284 | {file = "scipy-1.5.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:1c7564a4810c1cd77fcdee7fa726d7d39d4e2695ad252d7c86c3ea9d85b7fb8f"},
285 | {file = "scipy-1.5.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:07e52b316b40a4f001667d1ad4eb5f2318738de34597bd91537851365b6c61f1"},
286 | {file = "scipy-1.5.2-cp36-cp36m-win32.whl", hash = "sha256:d56b10d8ed72ec1be76bf10508446df60954f08a41c2d40778bc29a3a9ad9bce"},
287 | {file = "scipy-1.5.2-cp36-cp36m-win_amd64.whl", hash = "sha256:8e28e74b97fc8d6aa0454989db3b5d36fc27e69cef39a7ee5eaf8174ca1123cb"},
288 | {file = "scipy-1.5.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6e86c873fe1335d88b7a4bfa09d021f27a9e753758fd75f3f92d714aa4093768"},
289 | {file = "scipy-1.5.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a0afbb967fd2c98efad5f4c24439a640d39463282040a88e8e928db647d8ac3d"},
290 | {file = "scipy-1.5.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:eecf40fa87eeda53e8e11d265ff2254729d04000cd40bae648e76ff268885d66"},
291 | {file = "scipy-1.5.2-cp37-cp37m-win32.whl", hash = "sha256:315aa2165aca31375f4e26c230188db192ed901761390be908c9b21d8b07df62"},
292 | {file = "scipy-1.5.2-cp37-cp37m-win_amd64.whl", hash = "sha256:ec5fe57e46828d034775b00cd625c4a7b5c7d2e354c3b258d820c6c72212a6ec"},
293 | {file = "scipy-1.5.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fc98f3eac993b9bfdd392e675dfe19850cc8c7246a8fd2b42443e506344be7d9"},
294 | {file = "scipy-1.5.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:a785409c0fa51764766840185a34f96a0a93527a0ff0230484d33a8ed085c8f8"},
295 | {file = "scipy-1.5.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0a0e9a4e58a4734c2eba917f834b25b7e3b6dc333901ce7784fd31aefbd37b2f"},
296 | {file = "scipy-1.5.2-cp38-cp38-win32.whl", hash = "sha256:dac09281a0eacd59974e24525a3bc90fa39b4e95177e638a31b14db60d3fa806"},
297 | {file = "scipy-1.5.2-cp38-cp38-win_amd64.whl", hash = "sha256:92eb04041d371fea828858e4fff182453c25ae3eaa8782d9b6c32b25857d23bc"},
298 | {file = "scipy-1.5.2.tar.gz", hash = "sha256:066c513d90eb3fd7567a9e150828d39111ebd88d3e924cdfc9f8ce19ab6f90c9"},
299 | ]
300 | six = [
301 | {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"},
302 | {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"},
303 | ]
304 | threadpoolctl = [
305 | {file = "threadpoolctl-2.1.0-py3-none-any.whl", hash = "sha256:38b74ca20ff3bb42caca8b00055111d74159ee95c4370882bbff2b93d24da725"},
306 | {file = "threadpoolctl-2.1.0.tar.gz", hash = "sha256:ddc57c96a38beb63db45d6c159b5ab07b6bced12c45a1f07b2b92f272aebfa6b"},
307 | ]
308 | torch = [
309 | {file = "torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:7669f4d923b5758e28b521ea749c795ed67ff24b45ba20296bc8cff706d08df8"},
310 | {file = "torch-1.6.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:728facb972a5952323c6d790c2c5922b2b35c44b0bc7bdfa02f8639727671a0c"},
311 | {file = "torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:87d65c01d1b70bb46070824f28bfd93c86d3c5c56b90cbbe836a3f2491d91c76"},
312 | {file = "torch-1.6.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:3838bd01af7dfb1f78573973f6842ce75b17e8e4f22be99c891dcb7c94bc13f5"},
313 | {file = "torch-1.6.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5357873e243bcfa804c32dc341f564e9a4c12addfc9baae4ee857fcc09a0a216"},
314 | {file = "torch-1.6.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:4f9a4ad7947cef566afb0a323d99009fe8524f0b0f2ca1fb7ad5de0400381a5b"},
315 | ]
316 | zipp = [
317 | {file = "zipp-3.1.0-py3-none-any.whl", hash = "sha256:aa36550ff0c0b7ef7fa639055d797116ee891440eac1a56f378e2d3179e0320b"},
318 | {file = "zipp-3.1.0.tar.gz", hash = "sha256:c599e4d75c98f6798c509911d08a22e6c021d074469042177c8c86fb92eefd96"},
319 | ]
320 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "pytorch_resample"
3 | version = "0.1.0"
4 | description = "Resampling methods for iterable datasets in PyTorch"
5 | authors = ["Max Halford "]
6 |
7 | [tool.poetry.dependencies]
8 | python = "^3.6.1"
9 | torch = "^1.6.0"
10 |
11 | [tool.poetry.dev-dependencies]
12 | pytest = "^3.4"
13 | scikit-learn = "^0.23.2"
14 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | addopts = --doctest-modules --doctest-glob=README.md --verbose
3 |
--------------------------------------------------------------------------------
/pytorch_resample/__init__.py:
--------------------------------------------------------------------------------
1 | from .hybrid import HybridSampler
2 | from .over import OverSampler
3 | from .under import UnderSampler
4 |
--------------------------------------------------------------------------------
/pytorch_resample/hybrid.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import random
3 |
4 | import torch
5 |
6 | from . import utils
7 |
8 |
9 | class HybridSampler(torch.utils.data.IterableDataset):
10 | """Dataset wrapper that uses both under-sampling and over-sampling.
11 |
12 | Parameters:
13 | dataset
14 | desired_dist: The desired class distribution. The keys are the classes whilst the
15 |             values are the desired class percentages. The values are normalised so that they
16 |             sum up to 1.
17 | sampling_rate: The fraction of data to use.
18 | seed: Random seed for reproducibility.
19 |
20 | Attributes:
21 | actual_dist: The counts of the observed sample labels.
22 | rng: A random number generator instance.
23 |
24 | """
25 |
26 | def __init__(self, dataset: torch.utils.data.IterableDataset, desired_dist: dict,
27 | sampling_rate: float, seed: int = None):
28 |
29 | self.dataset = dataset
30 | self.desired_dist = {c: p / sum(desired_dist.values()) for c, p in desired_dist.items()}
31 | self.sampling_rate = min(max(sampling_rate, 0), 1)
32 | self.seed = seed
33 |
34 | self.actual_dist = collections.Counter()
35 | self.rng = random.Random(seed)
36 | self._n = 0
37 |
38 | def __iter__(self):
39 |
40 | for x, y in self.dataset:
41 |
42 | self.actual_dist[y] += 1
43 | self._n += 1
44 |
45 |             f = self.desired_dist  # to ease notation
46 |             g = self.actual_dist
47 | 
48 |             rate = self.sampling_rate * f[y] / (g[y] / self._n)  # g[y] / self._n is the observed frequency of y
49 |
50 | for _ in range(utils.random_poisson(rate, rng=self.rng)):
51 | yield x, y
52 |
53 | @classmethod
54 | def expected_size(cls, n, sampling_rate):
55 | return int(sampling_rate * n)
56 |
--------------------------------------------------------------------------------
/pytorch_resample/over.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import random
3 |
4 | import torch
5 |
6 | from . import utils
7 |
8 |
9 | class OverSampler(torch.utils.data.IterableDataset):
10 | """Dataset wrapper for over-sampling.
11 |
12 | Parameters:
13 | dataset
14 | desired_dist: The desired class distribution. The keys are the classes whilst the
15 |             values are the desired class percentages. The values are normalised so that they
16 |             sum up to 1.
17 | seed: Random seed for reproducibility.
18 |
19 | Attributes:
20 | actual_dist: The counts of the observed sample labels.
21 | rng: A random number generator instance.
22 |
23 | """
24 |
25 | def __init__(self, dataset: torch.utils.data.IterableDataset, desired_dist: dict,
26 | seed: int = None):
27 |
28 | self.dataset = dataset
29 | self.desired_dist = {c: p / sum(desired_dist.values()) for c, p in desired_dist.items()}
30 | self.seed = seed
31 |
32 | self.actual_dist = collections.Counter()
33 | self.rng = random.Random(seed)
34 | self._pivot = None
35 |
36 | def __iter__(self):
37 |
38 | for x, y in self.dataset:
39 |
40 | self.actual_dist[y] += 1
41 |
42 | # To ease notation
43 | f = self.desired_dist
44 | g = self.actual_dist
45 |
46 | # Check if the pivot needs to be changed
47 | if y != self._pivot:
48 | self._pivot = max(g.keys(), key=lambda y: g[y] / f[y])
49 | else:
50 | yield x, y
51 | continue
52 |
53 |             # Determine the over-sampling rate, given that the observed label is not the pivot
54 | M = g[self._pivot] / f[self._pivot]
55 | rate = M * f[y] / g[y]
56 |
57 | for _ in range(utils.random_poisson(rate, rng=self.rng)):
58 | yield x, y
59 |
60 | @classmethod
61 | def expected_size(cls, n, desired_dist, actual_dist):
62 | M = max(
63 | actual_dist.get(k) / desired_dist.get(k)
64 | for k in set(desired_dist) | set(actual_dist)
65 | )
66 | return int(n * M)
67 |
--------------------------------------------------------------------------------
/pytorch_resample/under.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import random
3 |
4 | import torch
5 |
6 |
7 | class UnderSampler(torch.utils.data.IterableDataset):
8 | """Dataset wrapper for under-sampling.
9 |
10 | This method is based on rejection sampling.
11 |
12 | Parameters:
13 | dataset
14 | desired_dist: The desired class distribution. The keys are the classes whilst the
15 |             values are the desired class percentages. The values are normalised so that they
16 |             sum up to 1.
17 | seed: Random seed for reproducibility.
18 |
19 | Attributes:
20 | actual_dist: The counts of the observed sample labels.
21 | rng: A random number generator instance.
22 |
23 | References:
24 | - https://www.wikiwand.com/en/Rejection_sampling
25 |
26 | """
27 |
28 | def __init__(self, dataset: torch.utils.data.IterableDataset, desired_dist: dict,
29 | seed: int = None):
30 |
31 | self.dataset = dataset
32 | self.desired_dist = {c: p / sum(desired_dist.values()) for c, p in desired_dist.items()}
33 | self.seed = seed
34 |
35 | self.actual_dist = collections.Counter()
36 | self.rng = random.Random(seed)
37 | self._pivot = None
38 |
39 | def __iter__(self):
40 |
41 | for x, y in self.dataset:
42 |
43 | self.actual_dist[y] += 1
44 |
45 | # To ease notation
46 | f = self.desired_dist
47 | g = self.actual_dist
48 |
49 | # Check if the pivot needs to be changed
50 | if y != self._pivot:
51 | self._pivot = max(g.keys(), key=lambda y: f[y] / g[y])
52 | else:
53 | yield x, y
54 | continue
55 |
56 |             # Determine the probability of keeping the sample, given that its label is not the pivot
57 | M = f[self._pivot] / g[self._pivot]
58 | ratio = f[y] / (M * g[y])
59 |
60 | if ratio < 1 and self.rng.random() < ratio:
61 | yield x, y
62 |
63 | @classmethod
64 | def expected_size(cls, n, desired_dist, actual_dist):
65 | M = max(
66 | desired_dist.get(k) / actual_dist.get(k)
67 | for k in set(desired_dist) | set(actual_dist)
68 | )
69 | return int(n / M)
70 |
--------------------------------------------------------------------------------
/pytorch_resample/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 |
5 | __all__ = ['random_poisson']
6 |
7 |
8 | def random_poisson(rate, rng=random):
9 | """Sample a random value from a Poisson distribution.
10 |
11 | This implementation is done in pure Python. Using PyTorch would be much slower.
12 |
13 | References:
14 | - https://www.wikiwand.com/en/Poisson_distribution#/Generating_Poisson-distributed_random_variables
15 |
16 | """
17 |
18 |     L = math.exp(-rate)  # acceptance threshold, as per Knuth's algorithm
19 |     k = 0
20 |     p = 1
21 | 
22 |     while p > L:  # multiply uniform draws until the product falls under e^(-rate)
23 |         k += 1
24 |         p *= rng.random()
25 |
26 | return k - 1
27 |
--------------------------------------------------------------------------------