├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── benchmarks.ipynb ├── poetry.lock ├── pyproject.toml ├── pytest.ini └── pytorch_resample ├── __init__.py ├── hybrid.py ├── over.py ├── under.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb filter=nbstripout 2 | *.ipynb diff=ipynb 3 | *.ipynb linguist-detectable=false 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .ipynb_checkpoints 3 | *.pyc 4 | .vscode 5 | dist/ 6 | .DS_Store 7 | *.csv 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Max Halford 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Iterable dataset resampling in PyTorch

- [Motivation](#motivation)
- [Installation](#installation)
- [Usage](#usage)
    - [Under-sampling](#under-sampling)
    - [Over-sampling](#over-sampling)
    - [Hybrid method](#hybrid-method)
    - [Expected number of samples](#expected-number-of-samples)
    - [Performance tip](#performance-tip)
- [Benchmarks](#benchmarks)
- [How does it work?](#how-does-it-work)
- [Development](#development)
- [License](#license)

## Motivation

[Imbalanced learning](https://www.jeremyjordan.me/imbalanced-data/) is a machine learning paradigm whereby a classifier has to learn from a dataset that has a skewed class distribution. An imbalanced dataset may have a detrimental impact on the classifier's performance.

Rebalancing a dataset is one way to deal with class imbalance. This can be done by:

1. under-sampling common classes.
2. over-sampling rare classes.
3. doing a mix of both.

PyTorch provides [some utilities](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler) for rebalancing a dataset, but they are limited to batch datasets of known length (i.e., they require a dataset to have a `__len__` method). Community contributions such as [ufoym/imbalanced-dataset-sampler](https://github.com/ufoym/imbalanced-dataset-sampler) are cute, but they also only work with batch datasets (also called *map-style* datasets in PyTorch jargon). There's also an open [GitHub issue](https://github.com/pytorch/pytorch/issues/28743) on the PyTorch repository, but it doesn't seem very active.

This repository implements data resamplers that wrap an [`IterableDataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset), which was added to PyTorch in [this pull request](https://github.com/pytorch/pytorch/pull/19228). Each data resampler also inherits from `IterableDataset`. In particular, the provided methods do not require you to know the size of your dataset in advance. Each method works for both binary and multi-class classification.

☝️ If you're looking to sample your data completely at random, without taking the class distribution into account, then we recommend that you do it yourself in your `IterableDataset` implementation: generate a random number between 0 and 1 and keep the sample if the number falls under a given threshold, as in the sketch below. This library is meant for cases where you want to use resampling to balance your class distribution.
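
As a minimal, hypothetical sketch (the `RandomSubsampleStream` name is ours, not part of this package), such a class-agnostic filter could look like this:

```py
import random

import torch


class RandomSubsampleStream(torch.utils.data.IterableDataset):
    """Keeps each sample with probability `rate`, regardless of its class."""

    def __init__(self, dataset, rate, seed=None):
        self.dataset = dataset
        self.rate = rate  # fraction of samples to keep, between 0 and 1
        self.rng = random.Random(seed)

    def __iter__(self):
        for x, y in self.dataset:
            if self.rng.random() < self.rate:
                yield x, y
```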

## Installation

```sh
$ pip install pytorch_resample
```

## Usage

As a running example, we'll define an `IterableDataset` that iterates over the output of scikit-learn's [`make_classification`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html) function.

```py
>>> from sklearn import datasets
>>> import torch

>>> class MakeClassificationStream(torch.utils.data.IterableDataset):
...
...     def __init__(self, *args, **kwargs):
...         self.X, self.y = datasets.make_classification(*args, **kwargs)
...
...     def __iter__(self):
...         yield from zip(self.X, self.y)

```

The above dataset can be provided to a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) in order to iterate over [`Tensor`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor) batches. For the sake of example, we'll generate 10,000 samples, of which 50% are 0s, 40% are 1s, and 10% are 2s. We can use a [`collections.Counter`](https://docs.python.org/3/library/collections.html#collections.Counter) to measure the effective class distribution.

```py
>>> import collections

>>> dataset = MakeClassificationStream(
...     n_samples=10_000,
...     n_classes=3,
...     n_informative=6,
...     weights=[.5, .4, .1],
...     random_state=42
... )

>>> y_dist = collections.Counter()

>>> batches = torch.utils.data.DataLoader(dataset, batch_size=16)
>>> for xb, yb in batches:
...     y_dist.update(yb.numpy())

>>> for label in sorted(y_dist):
...     frequency = y_dist[label] / sum(y_dist.values())
...     print(f'• {label}: {frequency:.2%} ({y_dist[label]})')
• 0: 49.95% (4995)
• 1: 39.88% (3988)
• 2: 10.17% (1017)

```

### Under-sampling

The data stream can be under-sampled with the `pytorch_resample.UnderSampler` class. It is a wrapper that has to be provided with an `IterableDataset` and a desired class distribution. It inherits from `IterableDataset`, and may thus be used in place of the wrapped dataset. As an example, let's make it so that the classes are equally represented.

```py
>>> import pytorch_resample
>>> import torch

>>> sample = pytorch_resample.UnderSampler(
...     dataset=dataset,
...     desired_dist={0: .33, 1: .33, 2: .33},
...     seed=42
... )

>>> isinstance(sample, torch.utils.data.IterableDataset)
True

>>> y_dist = collections.Counter()

>>> batches = torch.utils.data.DataLoader(sample, batch_size=16)
>>> for xb, yb in batches:
...     y_dist.update(yb.numpy())

>>> for label in sorted(y_dist):
...     frequency = y_dist[label] / sum(y_dist.values())
...     print(f'• {label}: {frequency:.2%} ({y_dist[label]})')
• 0: 33.30% (1007)
• 1: 33.10% (1001)
• 2: 33.60% (1016)

```

As shown, the observed class distribution is close to the specified distribution; indeed, there are fewer 0s and 1s than above. Note that the values of the `desired_dist` parameter are not required to sum to 1, as the distribution is normalized automatically.

### Over-sampling

You may use `pytorch_resample.OverSampler` to over-sample the data instead. It has the same signature as `pytorch_resample.UnderSampler`, and can thus be used in the exact same manner.

```py
>>> sample = pytorch_resample.OverSampler(
...     dataset=dataset,
...     desired_dist={0: .33, 1: .33, 2: .33},
...     seed=42
... )

>>> isinstance(sample, torch.utils.data.IterableDataset)
True

>>> y_dist = collections.Counter()

>>> batches = torch.utils.data.DataLoader(sample, batch_size=16)
>>> for xb, yb in batches:
...     y_dist.update(yb.numpy())

>>> for label in sorted(y_dist):
...     frequency = y_dist[label] / sum(y_dist.values())
...     print(f'• {label}: {frequency:.2%} ({y_dist[label]})')
• 0: 33.21% (4995)
• 1: 33.01% (4965)
• 2: 33.78% (5080)

```

In this case, the 1s and 2s have been over-sampled.

### Hybrid method

The `pytorch_resample.HybridSampler` class can be used to compromise between under-sampling and over-sampling.
It accepts an extra parameter called `sampling_rate`, which determines the fraction of the data to use. This makes it possible to control how much data is used for training, whilst ensuring that the class distribution matches the desired one.

```py
>>> sample = pytorch_resample.HybridSampler(
...     dataset=dataset,
...     desired_dist={0: .33, 1: .33, 2: .33},
...     sampling_rate=.5, # use 50% of the dataset
...     seed=42
... )

>>> isinstance(sample, torch.utils.data.IterableDataset)
True

>>> y_dist = collections.Counter()

>>> batches = torch.utils.data.DataLoader(sample, batch_size=16)
>>> for xb, yb in batches:
...     y_dist.update(yb.numpy())

>>> for label in sorted(y_dist):
...     frequency = y_dist[label] / sum(y_dist.values())
...     print(f'• {label}: {frequency:.2%} ({y_dist[label]})')
• 0: 33.01% (1672)
• 1: 32.91% (1667)
• 2: 34.08% (1726)

```

As can be observed, the number of streamed samples is close to 5,000, which is half the size of the dataset.

### Expected number of samples

It's possible to determine in advance the exact number of samples each resampler will stream back, provided the class distribution of the data is known.

```py
>>> n = 10_000
>>> desired = {'cat': 1 / 3, 'mouse': 1 / 3, 'dog': 1 / 3}
>>> actual = {'cat': .5, 'mouse': .4, 'dog': .1}

>>> pytorch_resample.UnderSampler.expected_size(n, desired, actual)
3000

>>> pytorch_resample.OverSampler.expected_size(n, desired, actual)
15000

>>> pytorch_resample.HybridSampler.expected_size(n, .5)
5000

```

### Performance tip

By design, `OverSampler` and `HybridSampler` yield repeated samples one after the other. This might not be ideal, as it is usually desirable to diversify the samples within each batch. We therefore recommend that you use a [shuffling buffer](https://www.moderndescartes.com/essays/shuffle_viz/), such as the `ShuffleDataset` class proposed [here](https://discuss.pytorch.org/t/how-to-shuffle-an-iterable-dataset/64130/6).

## Benchmarks

I've written a [simple benchmark](benchmarks.ipynb) to verify that resampling brings a performance boost and can reduce computation time. It works, but take it with a grain of salt, as it is far from exhaustive. Feel free to contribute more sophisticated benchmarks.

## How does it work?

As far as I know, the methods implemented in this package do not exist in the literature per se. I first [stumbled](https://maxhalford.github.io/blog/undersampling-ratios/) upon the under-sampling method myself; it turned out to be equivalent to [rejection sampling](https://www.wikiwand.com/en/Rejection_sampling). I then worked out the necessary formulas for over-sampling and the hybrid method. The latter two are based on the idea of sampling from a Poisson distribution, which I took from the [*Online Bagging and Boosting* paper](https://ti.arc.nasa.gov/m/profile/oza/files/ozru01a.pdf) by Nikunj Oza and Stuart Russell. The innovation lies in determining the rate that satisfies the desired class distribution.
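
In concrete terms, each sampler tracks the empirical class distribution `actual_dist` as the stream unfolds, and compares it with `desired_dist`. The following sketch condenses the rate computations from `under.py` and `over.py`, with `f` the desired class probabilities and `g` the observed class counts (or frequencies, as the ratios are scale-invariant). It is an illustration, not the library's public API, and the helper names below are ours.

```py
def undersampling_keep_probability(y, f, g):
    # The pivot is the class whose desired/observed ratio is highest; it is
    # always kept, and every other class is kept with a probability that
    # preserves the desired proportions.
    pivot = max(g, key=lambda c: f[c] / g[c])
    M = f[pivot] / g[pivot]
    return f[y] / (M * g[y])


def oversampling_poisson_rate(y, f, g):
    # The pivot is the most over-represented class; it is streamed once,
    # whilst every other class y is replicated Poisson(rate) times, so that
    # it catches up with the pivot in expectation.
    pivot = max(g, key=lambda c: g[c] / f[c])
    M = g[pivot] / f[pivot]
    return M * f[y] / g[y]
```

The hybrid method combines both ideas: each sample is replicated `Poisson(sampling_rate * f[y] / g[y])` times, where `g[y]` is the observed frequency of class `y`.
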
212 | 213 | ## Development 214 | 215 | ```sh 216 | $ git clone https://github.com/MaxHalford/pytorch-resample 217 | $ cd pytorch-resample 218 | $ python -m venv .env 219 | $ source .env/bin/activate 220 | $ pip install poetry 221 | $ poetry install 222 | $ poetry shell 223 | $ pytest 224 | ``` 225 | 226 | ## License 227 | 228 | The MIT License (MIT). Please see the [license file](LICENSE) for more information. 229 | -------------------------------------------------------------------------------- /benchmarks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Benchmarks" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Setup" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "We'll use the [credit card dataset](https://datahub.io/machine-learning/creditcard) that is commonly used as an imbalanced binary classification task." 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "--2020-08-12 18:20:10-- https://datahub.io/machine-learning/creditcard/r/creditcard.csv\n", 34 | "Resolving datahub.io (datahub.io)... 104.18.48.253, 104.18.49.253, 172.67.157.38\n", 35 | "Connecting to datahub.io (datahub.io)|104.18.48.253|:443... connected.\n", 36 | "HTTP request sent, awaiting response... 302 Found\n", 37 | "Location: https://pkgstore.datahub.io/machine-learning/creditcard/creditcard_csv/data/ebdc64b6837b3026238f3fcad3402337/creditcard_csv.csv [following]\n", 38 | "--2020-08-12 18:20:11-- https://pkgstore.datahub.io/machine-learning/creditcard/creditcard_csv/data/ebdc64b6837b3026238f3fcad3402337/creditcard_csv.csv\n", 39 | "Resolving pkgstore.datahub.io (pkgstore.datahub.io)... 104.18.49.253, 172.67.157.38, 104.18.48.253\n", 40 | "Connecting to pkgstore.datahub.io (pkgstore.datahub.io)|104.18.49.253|:443... connected.\n", 41 | "HTTP request sent, awaiting response... 200 OK\n", 42 | "Length: 151114991 (144M) [text/csv]\n", 43 | "Saving to: ‘creditcard.csv’\n", 44 | "\n", 45 | "creditcard.csv 100%[===================>] 144.11M 17.7MB/s in 8.6s \n", 46 | "\n", 47 | "2020-08-12 18:20:21 (16.7 MB/s) - ‘creditcard.csv’ saved [151114991/151114991]\n", 48 | "\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "!wget https://datahub.io/machine-learning/creditcard/r/creditcard.csv" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "Do a train/test split." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 72, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "import pandas as pd\n", 70 | "from sklearn import preprocessing\n", 71 | "\n", 72 | "credit = pd.read_csv('creditcard.csv')\n", 73 | "credit = credit.drop(columns=['Time'])\n", 74 | "features = credit.columns.drop('Class')\n", 75 | "credit.loc[:, features] = preprocessing.scale(credit.loc[:, features])\n", 76 | "credit['Class'] = (credit['Class'] == \"'1'\").astype(int)\n", 77 | "\n", 78 | "n_test = 40_000\n", 79 | "credit[:-n_test].to_csv('train.csv', index=False)\n", 80 | "credit[-n_test:].to_csv('test.csv', index=False)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "While we're at it, let's look at the class distribution." 
88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 94, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "0 0.998273\n", 99 | "1 0.001727\n", 100 | "Name: Class, dtype: float64" 101 | ] 102 | }, 103 | "execution_count": 94, 104 | "metadata": {}, 105 | "output_type": "execute_result" 106 | } 107 | ], 108 | "source": [ 109 | "credit['Class'].value_counts(normalize=True)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "Define a helper function to train a PyTorch model on an `IterableDataset`." 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 73, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "import torch\n", 126 | "\n", 127 | "def train(net, optimizer, criterion, train_batches):\n", 128 | "\n", 129 | " for x_batch, y_batch in train_batches:\n", 130 | "\n", 131 | " optimizer.zero_grad()\n", 132 | " y_pred = net(x_batch)\n", 133 | " loss = criterion(y_pred[:, 0], y_batch.float())\n", 134 | " loss.backward()\n", 135 | " optimizer.step()\n", 136 | " \n", 137 | " return net" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "Define a helper function to score a PyTorch model on a test set." 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 84, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "import numpy as np\n", 154 | "\n", 155 | "def score(net, test_batches, metric):\n", 156 | " \n", 157 | " y_true = []\n", 158 | " y_pred = []\n", 159 | " \n", 160 | " for x_batch, y_batch in test_batches:\n", 161 | " y_true.extend(y_batch.detach().numpy())\n", 162 | " y_pred.extend(net(x_batch).detach().numpy()[:, 0])\n", 163 | " \n", 164 | " y_true = np.array(y_true)\n", 165 | " y_pred = np.array(y_pred)\n", 166 | " \n", 167 | " return metric(y_true, y_pred)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "Let's also create an `IterableDataset` that reads from a CSV file. The following implementation is not very generic nor flexible, but it will do for this notebook." 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 78, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "import csv\n", 184 | "\n", 185 | "class IterableCSV(torch.utils.data.IterableDataset):\n", 186 | " \n", 187 | " def __init__(self, path):\n", 188 | " self.path = path\n", 189 | " \n", 190 | " def __iter__(self):\n", 191 | " \n", 192 | " with open(self.path) as file:\n", 193 | " reader = csv.reader(file)\n", 194 | " header = next(reader)\n", 195 | " for row in reader:\n", 196 | " x = [float(el) for el in row[:-1]]\n", 197 | " y = int(row[-1])\n", 198 | " yield torch.tensor(x), y" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "We can now define the training and test set loaders." 
206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 79, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "train_set = IterableCSV(path='train.csv')\n", 215 | "test_set = IterableCSV(path='test.csv')" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "## Vanilla" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 117, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "def make_net(n_features):\n", 232 | " \n", 233 | " torch.manual_seed(0)\n", 234 | " \n", 235 | " return torch.nn.Sequential(\n", 236 | " torch.nn.Linear(n_features, 30),\n", 237 | " torch.nn.Linear(30, 1)\n", 238 | " )" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 118, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "CPU times: user 10.6 s, sys: 240 ms, total: 10.8 s\n", 251 | "Wall time: 10.8 s\n" 252 | ] 253 | } 254 | ], 255 | "source": [ 256 | "%%time\n", 257 | "\n", 258 | "net = make_net(len(features))\n", 259 | "\n", 260 | "net = train(\n", 261 | " net,\n", 262 | " optimizer=torch.optim.SGD(net.parameters(), lr=1e-2),\n", 263 | " criterion=torch.nn.BCEWithLogitsLoss(),\n", 264 | " train_batches=torch.utils.data.DataLoader(train_set, batch_size=16)\n", 265 | ")" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 119, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "text/plain": [ 276 | "0.9557763595155662" 277 | ] 278 | }, 279 | "execution_count": 119, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "from sklearn import metrics\n", 286 | "\n", 287 | "score(\n", 288 | " net,\n", 289 | " test_batches=torch.utils.data.DataLoader(test_set, batch_size=16),\n", 290 | " metric=metrics.roc_auc_score\n", 291 | ")" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "## Under-sampling" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 120, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "import pytorch_resample\n", 308 | "\n", 309 | "train_sample = pytorch_resample.UnderSampler(\n", 310 | " train_set,\n", 311 | " desired_dist={0: .8, 1: .2},\n", 312 | " seed=42\n", 313 | ")" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 121, 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "CPU times: user 5.98 s, sys: 54.9 ms, total: 6.04 s\n", 326 | "Wall time: 6.05 s\n" 327 | ] 328 | } 329 | ], 330 | "source": [ 331 | "%%time\n", 332 | "\n", 333 | "net = make_net(len(features))\n", 334 | "\n", 335 | "net = train(\n", 336 | " net,\n", 337 | " optimizer=torch.optim.SGD(net.parameters(), lr=1e-2),\n", 338 | " criterion=torch.nn.BCEWithLogitsLoss(),\n", 339 | " train_batches=torch.utils.data.DataLoader(train_sample, batch_size=16)\n", 340 | ")" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 122, 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "data": { 350 | "text/plain": [ 351 | "0.9162552315799719" 352 | ] 353 | }, 354 | "execution_count": 122, 355 | "metadata": {}, 356 | "output_type": "execute_result" 357 | } 358 | ], 359 | "source": [ 360 | "score(\n", 361 | " net,\n", 362 | " test_batches=torch.utils.data.DataLoader(test_set, batch_size=16),\n", 363 | " 
metric=metrics.roc_auc_score\n", 364 | ")" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "## Over-sampling" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 123, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "train_sample = pytorch_resample.OverSampler(\n", 381 | " train_set,\n", 382 | " desired_dist={0: .8, 1: .2},\n", 383 | " seed=42\n", 384 | ")" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 124, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "name": "stdout", 394 | "output_type": "stream", 395 | "text": [ 396 | "CPU times: user 12.1 s, sys: 287 ms, total: 12.4 s\n", 397 | "Wall time: 12.4 s\n" 398 | ] 399 | } 400 | ], 401 | "source": [ 402 | "%%time\n", 403 | "\n", 404 | "net = make_net(len(features))\n", 405 | "\n", 406 | "net = train(\n", 407 | " net,\n", 408 | " optimizer=torch.optim.SGD(net.parameters(), lr=1e-2),\n", 409 | " criterion=torch.nn.BCEWithLogitsLoss(),\n", 410 | " train_batches=torch.utils.data.DataLoader(train_sample, batch_size=16)\n", 411 | ")" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 125, 417 | "metadata": {}, 418 | "outputs": [ 419 | { 420 | "data": { 421 | "text/plain": [ 422 | "0.9642164101280608" 423 | ] 424 | }, 425 | "execution_count": 125, 426 | "metadata": {}, 427 | "output_type": "execute_result" 428 | } 429 | ], 430 | "source": [ 431 | "score(\n", 432 | " net,\n", 433 | " test_batches=torch.utils.data.DataLoader(test_set, batch_size=16),\n", 434 | " metric=metrics.roc_auc_score\n", 435 | ")" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": {}, 441 | "source": [ 442 | "## Hybrid method" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 126, 448 | "metadata": {}, 449 | "outputs": [], 450 | "source": [ 451 | "train_sample = pytorch_resample.HybridSampler(\n", 452 | " train_set,\n", 453 | " desired_dist={0: .8, 1: .2},\n", 454 | " sampling_rate=.5,\n", 455 | " seed=42\n", 456 | ")" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 127, 462 | "metadata": {}, 463 | "outputs": [ 464 | { 465 | "name": "stdout", 466 | "output_type": "stream", 467 | "text": [ 468 | "CPU times: user 8.95 s, sys: 166 ms, total: 9.11 s\n", 469 | "Wall time: 9.14 s\n" 470 | ] 471 | } 472 | ], 473 | "source": [ 474 | "%%time\n", 475 | "\n", 476 | "net = make_net(len(features))\n", 477 | "\n", 478 | "net = train(\n", 479 | " net,\n", 480 | " optimizer=torch.optim.SGD(net.parameters(), lr=1e-2),\n", 481 | " criterion=torch.nn.BCEWithLogitsLoss(),\n", 482 | " train_batches=torch.utils.data.DataLoader(train_sample, batch_size=16)\n", 483 | ")" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 128, 489 | "metadata": {}, 490 | "outputs": [ 491 | { 492 | "data": { 493 | "text/plain": [ 494 | "0.9687554053866155" 495 | ] 496 | }, 497 | "execution_count": 128, 498 | "metadata": {}, 499 | "output_type": "execute_result" 500 | } 501 | ], 502 | "source": [ 503 | "score(\n", 504 | " net,\n", 505 | " test_batches=torch.utils.data.DataLoader(test_set, batch_size=16),\n", 506 | " metric=metrics.roc_auc_score\n", 507 | ")" 508 | ] 509 | } 510 | ], 511 | "metadata": { 512 | "kernelspec": { 513 | "display_name": "Python 3", 514 | "language": "python", 515 | "name": "python3" 516 | }, 517 | "language_info": { 518 | "codemirror_mode": { 519 | "name": "ipython", 520 | "version": 3 521 | }, 522 | "file_extension": ".py", 523 | 
"mimetype": "text/x-python", 524 | "name": "python", 525 | "nbconvert_exporter": "python", 526 | "pygments_lexer": "ipython3", 527 | "version": "3.7.4" 528 | } 529 | }, 530 | "nbformat": 4, 531 | "nbformat_minor": 4 532 | } 533 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | category = "dev" 3 | description = "Atomic file writes." 4 | name = "atomicwrites" 5 | optional = false 6 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 7 | version = "1.4.0" 8 | 9 | [[package]] 10 | category = "dev" 11 | description = "Classes Without Boilerplate" 12 | name = "attrs" 13 | optional = false 14 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 15 | version = "19.3.0" 16 | 17 | [package.extras] 18 | azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "pytest-azurepipelines"] 19 | dev = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "sphinx", "pre-commit"] 20 | docs = ["sphinx", "zope.interface"] 21 | tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] 22 | 23 | [[package]] 24 | category = "dev" 25 | description = "Cross-platform colored terminal text." 26 | marker = "sys_platform == \"win32\"" 27 | name = "colorama" 28 | optional = false 29 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 30 | version = "0.4.3" 31 | 32 | [[package]] 33 | category = "main" 34 | description = "Clean single-source support for Python 3 and 2" 35 | name = "future" 36 | optional = false 37 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 38 | version = "0.18.2" 39 | 40 | [[package]] 41 | category = "dev" 42 | description = "Read metadata from Python packages" 43 | marker = "python_version < \"3.8\"" 44 | name = "importlib-metadata" 45 | optional = false 46 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" 47 | version = "1.7.0" 48 | 49 | [package.dependencies] 50 | zipp = ">=0.5" 51 | 52 | [package.extras] 53 | docs = ["sphinx", "rst.linker"] 54 | testing = ["packaging", "pep517", "importlib-resources (>=1.3)"] 55 | 56 | [[package]] 57 | category = "dev" 58 | description = "Lightweight pipelining: using Python functions as pipeline jobs." 59 | name = "joblib" 60 | optional = false 61 | python-versions = ">=3.6" 62 | version = "0.16.0" 63 | 64 | [[package]] 65 | category = "dev" 66 | description = "More routines for operating on iterables, beyond itertools" 67 | name = "more-itertools" 68 | optional = false 69 | python-versions = ">=3.5" 70 | version = "8.4.0" 71 | 72 | [[package]] 73 | category = "main" 74 | description = "NumPy is the fundamental package for array computing with Python." 
75 | name = "numpy" 76 | optional = false 77 | python-versions = ">=3.6" 78 | version = "1.19.1" 79 | 80 | [[package]] 81 | category = "dev" 82 | description = "plugin and hook calling mechanisms for python" 83 | name = "pluggy" 84 | optional = false 85 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 86 | version = "0.13.1" 87 | 88 | [package.dependencies] 89 | [package.dependencies.importlib-metadata] 90 | python = "<3.8" 91 | version = ">=0.12" 92 | 93 | [package.extras] 94 | dev = ["pre-commit", "tox"] 95 | 96 | [[package]] 97 | category = "dev" 98 | description = "library with cross-python path, ini-parsing, io, code, log facilities" 99 | name = "py" 100 | optional = false 101 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 102 | version = "1.9.0" 103 | 104 | [[package]] 105 | category = "dev" 106 | description = "pytest: simple powerful testing with Python" 107 | name = "pytest" 108 | optional = false 109 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 110 | version = "3.10.1" 111 | 112 | [package.dependencies] 113 | atomicwrites = ">=1.0" 114 | attrs = ">=17.4.0" 115 | colorama = "*" 116 | more-itertools = ">=4.0.0" 117 | pluggy = ">=0.7" 118 | py = ">=1.5.0" 119 | setuptools = "*" 120 | six = ">=1.10.0" 121 | 122 | [[package]] 123 | category = "dev" 124 | description = "A set of python modules for machine learning and data mining" 125 | name = "scikit-learn" 126 | optional = false 127 | python-versions = ">=3.6" 128 | version = "0.23.2" 129 | 130 | [package.dependencies] 131 | joblib = ">=0.11" 132 | numpy = ">=1.13.3" 133 | scipy = ">=0.19.1" 134 | threadpoolctl = ">=2.0.0" 135 | 136 | [package.extras] 137 | alldeps = ["numpy (>=1.13.3)", "scipy (>=0.19.1)"] 138 | 139 | [[package]] 140 | category = "dev" 141 | description = "SciPy: Scientific Library for Python" 142 | name = "scipy" 143 | optional = false 144 | python-versions = ">=3.6" 145 | version = "1.5.2" 146 | 147 | [package.dependencies] 148 | numpy = ">=1.14.5" 149 | 150 | [[package]] 151 | category = "dev" 152 | description = "Python 2 and 3 compatibility utilities" 153 | name = "six" 154 | optional = false 155 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 156 | version = "1.15.0" 157 | 158 | [[package]] 159 | category = "dev" 160 | description = "threadpoolctl" 161 | name = "threadpoolctl" 162 | optional = false 163 | python-versions = ">=3.5" 164 | version = "2.1.0" 165 | 166 | [[package]] 167 | category = "main" 168 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 169 | name = "torch" 170 | optional = false 171 | python-versions = ">=3.6.1" 172 | version = "1.6.0" 173 | 174 | [package.dependencies] 175 | future = "*" 176 | numpy = "*" 177 | 178 | [[package]] 179 | category = "dev" 180 | description = "Backport of pathlib-compatible object wrapper for zip files" 181 | marker = "python_version < \"3.8\"" 182 | name = "zipp" 183 | optional = false 184 | python-versions = ">=3.6" 185 | version = "3.1.0" 186 | 187 | [package.extras] 188 | docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] 189 | testing = ["jaraco.itertools", "func-timeout"] 190 | 191 | [metadata] 192 | content-hash = "7891adb5afde230411c466fb2348d211b354aba148dc2e411a9593e33401aaf5" 193 | lock-version = "1.0" 194 | python-versions = "^3.6.1" 195 | 196 | [metadata.files] 197 | atomicwrites = [ 198 | {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, 199 | {file 
= "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, 200 | ] 201 | attrs = [ 202 | {file = "attrs-19.3.0-py2.py3-none-any.whl", hash = "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c"}, 203 | {file = "attrs-19.3.0.tar.gz", hash = "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72"}, 204 | ] 205 | colorama = [ 206 | {file = "colorama-0.4.3-py2.py3-none-any.whl", hash = "sha256:7d73d2a99753107a36ac6b455ee49046802e59d9d076ef8e47b61499fa29afff"}, 207 | {file = "colorama-0.4.3.tar.gz", hash = "sha256:e96da0d330793e2cb9485e9ddfd918d456036c7149416295932478192f4436a1"}, 208 | ] 209 | future = [ 210 | {file = "future-0.18.2.tar.gz", hash = "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"}, 211 | ] 212 | importlib-metadata = [ 213 | {file = "importlib_metadata-1.7.0-py2.py3-none-any.whl", hash = "sha256:dc15b2969b4ce36305c51eebe62d418ac7791e9a157911d58bfb1f9ccd8e2070"}, 214 | {file = "importlib_metadata-1.7.0.tar.gz", hash = "sha256:90bb658cdbbf6d1735b6341ce708fc7024a3e14e99ffdc5783edea9f9b077f83"}, 215 | ] 216 | joblib = [ 217 | {file = "joblib-0.16.0-py3-none-any.whl", hash = "sha256:d348c5d4ae31496b2aa060d6d9b787864dd204f9480baaa52d18850cb43e9f49"}, 218 | {file = "joblib-0.16.0.tar.gz", hash = "sha256:8f52bf24c64b608bf0b2563e0e47d6fcf516abc8cfafe10cfd98ad66d94f92d6"}, 219 | ] 220 | more-itertools = [ 221 | {file = "more-itertools-8.4.0.tar.gz", hash = "sha256:68c70cc7167bdf5c7c9d8f6954a7837089c6a36bf565383919bb595efb8a17e5"}, 222 | {file = "more_itertools-8.4.0-py3-none-any.whl", hash = "sha256:b78134b2063dd214000685165d81c154522c3ee0a1c0d4d113c80361c234c5a2"}, 223 | ] 224 | numpy = [ 225 | {file = "numpy-1.19.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b1cca51512299841bf69add3b75361779962f9cee7d9ee3bb446d5982e925b69"}, 226 | {file = "numpy-1.19.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:c9591886fc9cbe5532d5df85cb8e0cc3b44ba8ce4367bd4cf1b93dc19713da72"}, 227 | {file = "numpy-1.19.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:cf1347450c0b7644ea142712619533553f02ef23f92f781312f6a3553d031fc7"}, 228 | {file = "numpy-1.19.1-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:ed8a311493cf5480a2ebc597d1e177231984c818a86875126cfd004241a73c3e"}, 229 | {file = "numpy-1.19.1-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:3673c8b2b29077f1b7b3a848794f8e11f401ba0b71c49fbd26fb40b71788b132"}, 230 | {file = "numpy-1.19.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:56ef7f56470c24bb67fb43dae442e946a6ce172f97c69f8d067ff8550cf782ff"}, 231 | {file = "numpy-1.19.1-cp36-cp36m-win32.whl", hash = "sha256:aaf42a04b472d12515debc621c31cf16c215e332242e7a9f56403d814c744624"}, 232 | {file = "numpy-1.19.1-cp36-cp36m-win_amd64.whl", hash = "sha256:082f8d4dd69b6b688f64f509b91d482362124986d98dc7dc5f5e9f9b9c3bb983"}, 233 | {file = "numpy-1.19.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e4f6d3c53911a9d103d8ec9518190e52a8b945bab021745af4939cfc7c0d4a9e"}, 234 | {file = "numpy-1.19.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:5b6885c12784a27e957294b60f97e8b5b4174c7504665333c5e94fbf41ae5d6a"}, 235 | {file = "numpy-1.19.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:1bc0145999e8cb8aed9d4e65dd8b139adf1919e521177f198529687dbf613065"}, 236 | {file = "numpy-1.19.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:5a936fd51049541d86ccdeef2833cc89a18e4d3808fe58a8abeb802665c5af93"}, 237 | {file = "numpy-1.19.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = 
"sha256:ef71a1d4fd4858596ae80ad1ec76404ad29701f8ca7cdcebc50300178db14dfc"}, 238 | {file = "numpy-1.19.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b9792b0ac0130b277536ab8944e7b754c69560dac0415dd4b2dbd16b902c8954"}, 239 | {file = "numpy-1.19.1-cp37-cp37m-win32.whl", hash = "sha256:b12e639378c741add21fbffd16ba5ad25c0a1a17cf2b6fe4288feeb65144f35b"}, 240 | {file = "numpy-1.19.1-cp37-cp37m-win_amd64.whl", hash = "sha256:8343bf67c72e09cfabfab55ad4a43ce3f6bf6e6ced7acf70f45ded9ebb425055"}, 241 | {file = "numpy-1.19.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e45f8e981a0ab47103181773cc0a54e650b2aef8c7b6cd07405d0fa8d869444a"}, 242 | {file = "numpy-1.19.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:667c07063940e934287993366ad5f56766bc009017b4a0fe91dbd07960d0aba7"}, 243 | {file = "numpy-1.19.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:480fdd4dbda4dd6b638d3863da3be82873bba6d32d1fc12ea1b8486ac7b8d129"}, 244 | {file = "numpy-1.19.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:935c27ae2760c21cd7354402546f6be21d3d0c806fffe967f745d5f2de5005a7"}, 245 | {file = "numpy-1.19.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:309cbcfaa103fc9a33ec16d2d62569d541b79f828c382556ff072442226d1968"}, 246 | {file = "numpy-1.19.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:7ed448ff4eaffeb01094959b19cbaf998ecdee9ef9932381420d514e446601cd"}, 247 | {file = "numpy-1.19.1-cp38-cp38-win32.whl", hash = "sha256:de8b4a9b56255797cbddb93281ed92acbc510fb7b15df3f01bd28f46ebc4edae"}, 248 | {file = "numpy-1.19.1-cp38-cp38-win_amd64.whl", hash = "sha256:92feb989b47f83ebef246adabc7ff3b9a59ac30601c3f6819f8913458610bdcc"}, 249 | {file = "numpy-1.19.1-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:e1b1dc0372f530f26a03578ac75d5e51b3868b9b76cd2facba4c9ee0eb252ab1"}, 250 | {file = "numpy-1.19.1.zip", hash = "sha256:b8456987b637232602ceb4d663cb34106f7eb780e247d51a260b84760fd8f491"}, 251 | ] 252 | pluggy = [ 253 | {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, 254 | {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, 255 | ] 256 | py = [ 257 | {file = "py-1.9.0-py2.py3-none-any.whl", hash = "sha256:366389d1db726cd2fcfc79732e75410e5fe4d31db13692115529d34069a043c2"}, 258 | {file = "py-1.9.0.tar.gz", hash = "sha256:9ca6883ce56b4e8da7e79ac18787889fa5206c79dcc67fb065376cd2fe03f342"}, 259 | ] 260 | pytest = [ 261 | {file = "pytest-3.10.1-py2.py3-none-any.whl", hash = "sha256:3f193df1cfe1d1609d4c583838bea3d532b18d6160fd3f55c9447fdca30848ec"}, 262 | {file = "pytest-3.10.1.tar.gz", hash = "sha256:e246cf173c01169b9617fc07264b7b1316e78d7a650055235d6d897bc80d9660"}, 263 | ] 264 | scikit-learn = [ 265 | {file = "scikit-learn-0.23.2.tar.gz", hash = "sha256:20766f515e6cd6f954554387dfae705d93c7b544ec0e6c6a5d8e006f6f7ef480"}, 266 | {file = "scikit_learn-0.23.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:98508723f44c61896a4e15894b2016762a55555fbf09365a0bb1870ecbd442de"}, 267 | {file = "scikit_learn-0.23.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a64817b050efd50f9abcfd311870073e500ae11b299683a519fbb52d85e08d25"}, 268 | {file = "scikit_learn-0.23.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:daf276c465c38ef736a79bd79fc80a249f746bcbcae50c40945428f7ece074f8"}, 269 | {file = "scikit_learn-0.23.2-cp36-cp36m-win32.whl", hash = "sha256:cb3e76380312e1f86abd20340ab1d5b3cc46a26f6593d3c33c9ea3e4c7134028"}, 270 | {file = 
"scikit_learn-0.23.2-cp36-cp36m-win_amd64.whl", hash = "sha256:0a127cc70990d4c15b1019680bfedc7fec6c23d14d3719fdf9b64b22d37cdeca"}, 271 | {file = "scikit_learn-0.23.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2aa95c2f17d2f80534156215c87bee72b6aa314a7f8b8fe92a2d71f47280570d"}, 272 | {file = "scikit_learn-0.23.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:6c28a1d00aae7c3c9568f61aafeaad813f0f01c729bee4fd9479e2132b215c1d"}, 273 | {file = "scikit_learn-0.23.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:da8e7c302003dd765d92a5616678e591f347460ac7b53e53d667be7dfe6d1b10"}, 274 | {file = "scikit_learn-0.23.2-cp37-cp37m-win32.whl", hash = "sha256:d9a1ce5f099f29c7c33181cc4386660e0ba891b21a60dc036bf369e3a3ee3aec"}, 275 | {file = "scikit_learn-0.23.2-cp37-cp37m-win_amd64.whl", hash = "sha256:914ac2b45a058d3f1338d7736200f7f3b094857758895f8667be8a81ff443b5b"}, 276 | {file = "scikit_learn-0.23.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7671bbeddd7f4f9a6968f3b5442dac5f22bf1ba06709ef888cc9132ad354a9ab"}, 277 | {file = "scikit_learn-0.23.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:d0dcaa54263307075cb93d0bee3ceb02821093b1b3d25f66021987d305d01dce"}, 278 | {file = "scikit_learn-0.23.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5ce7a8021c9defc2b75620571b350acc4a7d9763c25b7593621ef50f3bd019a2"}, 279 | {file = "scikit_learn-0.23.2-cp38-cp38-win32.whl", hash = "sha256:0d39748e7c9669ba648acf40fb3ce96b8a07b240db6888563a7cb76e05e0d9cc"}, 280 | {file = "scikit_learn-0.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:1b8a391de95f6285a2f9adffb7db0892718950954b7149a70c783dc848f104ea"}, 281 | ] 282 | scipy = [ 283 | {file = "scipy-1.5.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cca9fce15109a36a0a9f9cfc64f870f1c140cb235ddf27fe0328e6afb44dfed0"}, 284 | {file = "scipy-1.5.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:1c7564a4810c1cd77fcdee7fa726d7d39d4e2695ad252d7c86c3ea9d85b7fb8f"}, 285 | {file = "scipy-1.5.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:07e52b316b40a4f001667d1ad4eb5f2318738de34597bd91537851365b6c61f1"}, 286 | {file = "scipy-1.5.2-cp36-cp36m-win32.whl", hash = "sha256:d56b10d8ed72ec1be76bf10508446df60954f08a41c2d40778bc29a3a9ad9bce"}, 287 | {file = "scipy-1.5.2-cp36-cp36m-win_amd64.whl", hash = "sha256:8e28e74b97fc8d6aa0454989db3b5d36fc27e69cef39a7ee5eaf8174ca1123cb"}, 288 | {file = "scipy-1.5.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6e86c873fe1335d88b7a4bfa09d021f27a9e753758fd75f3f92d714aa4093768"}, 289 | {file = "scipy-1.5.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a0afbb967fd2c98efad5f4c24439a640d39463282040a88e8e928db647d8ac3d"}, 290 | {file = "scipy-1.5.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:eecf40fa87eeda53e8e11d265ff2254729d04000cd40bae648e76ff268885d66"}, 291 | {file = "scipy-1.5.2-cp37-cp37m-win32.whl", hash = "sha256:315aa2165aca31375f4e26c230188db192ed901761390be908c9b21d8b07df62"}, 292 | {file = "scipy-1.5.2-cp37-cp37m-win_amd64.whl", hash = "sha256:ec5fe57e46828d034775b00cd625c4a7b5c7d2e354c3b258d820c6c72212a6ec"}, 293 | {file = "scipy-1.5.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fc98f3eac993b9bfdd392e675dfe19850cc8c7246a8fd2b42443e506344be7d9"}, 294 | {file = "scipy-1.5.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:a785409c0fa51764766840185a34f96a0a93527a0ff0230484d33a8ed085c8f8"}, 295 | {file = "scipy-1.5.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0a0e9a4e58a4734c2eba917f834b25b7e3b6dc333901ce7784fd31aefbd37b2f"}, 296 | {file = "scipy-1.5.2-cp38-cp38-win32.whl", hash = 
"sha256:dac09281a0eacd59974e24525a3bc90fa39b4e95177e638a31b14db60d3fa806"}, 297 | {file = "scipy-1.5.2-cp38-cp38-win_amd64.whl", hash = "sha256:92eb04041d371fea828858e4fff182453c25ae3eaa8782d9b6c32b25857d23bc"}, 298 | {file = "scipy-1.5.2.tar.gz", hash = "sha256:066c513d90eb3fd7567a9e150828d39111ebd88d3e924cdfc9f8ce19ab6f90c9"}, 299 | ] 300 | six = [ 301 | {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"}, 302 | {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"}, 303 | ] 304 | threadpoolctl = [ 305 | {file = "threadpoolctl-2.1.0-py3-none-any.whl", hash = "sha256:38b74ca20ff3bb42caca8b00055111d74159ee95c4370882bbff2b93d24da725"}, 306 | {file = "threadpoolctl-2.1.0.tar.gz", hash = "sha256:ddc57c96a38beb63db45d6c159b5ab07b6bced12c45a1f07b2b92f272aebfa6b"}, 307 | ] 308 | torch = [ 309 | {file = "torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:7669f4d923b5758e28b521ea749c795ed67ff24b45ba20296bc8cff706d08df8"}, 310 | {file = "torch-1.6.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:728facb972a5952323c6d790c2c5922b2b35c44b0bc7bdfa02f8639727671a0c"}, 311 | {file = "torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:87d65c01d1b70bb46070824f28bfd93c86d3c5c56b90cbbe836a3f2491d91c76"}, 312 | {file = "torch-1.6.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:3838bd01af7dfb1f78573973f6842ce75b17e8e4f22be99c891dcb7c94bc13f5"}, 313 | {file = "torch-1.6.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5357873e243bcfa804c32dc341f564e9a4c12addfc9baae4ee857fcc09a0a216"}, 314 | {file = "torch-1.6.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:4f9a4ad7947cef566afb0a323d99009fe8524f0b0f2ca1fb7ad5de0400381a5b"}, 315 | ] 316 | zipp = [ 317 | {file = "zipp-3.1.0-py3-none-any.whl", hash = "sha256:aa36550ff0c0b7ef7fa639055d797116ee891440eac1a56f378e2d3179e0320b"}, 318 | {file = "zipp-3.1.0.tar.gz", hash = "sha256:c599e4d75c98f6798c509911d08a22e6c021d074469042177c8c86fb92eefd96"}, 319 | ] 320 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pytorch_resample" 3 | version = "0.1.0" 4 | description = "Resampling methods for iterable datasets in PyTorch" 5 | authors = ["Max Halford "] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.6.1" 9 | torch = "^1.6.0" 10 | 11 | [tool.poetry.dev-dependencies] 12 | pytest = "^3.4" 13 | scikit-learn = "^0.23.2" 14 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --doctest-modules --doctest-glob=README.md --verbose 3 | -------------------------------------------------------------------------------- /pytorch_resample/__init__.py: -------------------------------------------------------------------------------- 1 | from .hybrid import HybridSampler 2 | from .over import OverSampler 3 | from .under import UnderSampler 4 | -------------------------------------------------------------------------------- /pytorch_resample/hybrid.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import random 3 | 4 | import torch 5 | 6 | from . 
import utils


class HybridSampler(torch.utils.data.IterableDataset):
    """Dataset wrapper that uses both under-sampling and over-sampling.

    Parameters:
        dataset: The dataset to resample.
        desired_dist: The desired class distribution. The keys are the classes whilst the
            values are the desired class percentages. The values are normalised so that they
            sum up to 1.
        sampling_rate: The fraction of data to use.
        seed: Random seed for reproducibility.

    Attributes:
        actual_dist: The counts of the observed sample labels.
        rng: A random number generator instance.

    """

    def __init__(self, dataset: torch.utils.data.IterableDataset, desired_dist: dict,
                 sampling_rate: float, seed: int = None):

        self.dataset = dataset
        self.desired_dist = {c: p / sum(desired_dist.values()) for c, p in desired_dist.items()}
        self.sampling_rate = min(max(sampling_rate, 0), 1)
        self.seed = seed

        self.actual_dist = collections.Counter()
        self.rng = random.Random(seed)
        self._n = 0

    def __iter__(self):

        for x, y in self.dataset:

            self.actual_dist[y] += 1
            self._n += 1

            # To ease notation
            f = self.desired_dist
            g = self.actual_dist

            # The expected number of copies is the ratio between the desired
            # and the empirical frequency of the label, scaled by the
            # sampling rate
            rate = self.sampling_rate * f[y] / (g[y] / self._n)

            for _ in range(utils.random_poisson(rate, rng=self.rng)):
                yield x, y

    @classmethod
    def expected_size(cls, n, sampling_rate):
        return int(sampling_rate * n)
--------------------------------------------------------------------------------
/pytorch_resample/over.py:
--------------------------------------------------------------------------------
import collections
import random

import torch

from . import utils


class OverSampler(torch.utils.data.IterableDataset):
    """Dataset wrapper for over-sampling.

    Parameters:
        dataset: The dataset to resample.
        desired_dist: The desired class distribution. The keys are the classes whilst the
            values are the desired class percentages. The values are normalised so that they
            sum up to 1.
        seed: Random seed for reproducibility.

    Attributes:
        actual_dist: The counts of the observed sample labels.
        rng: A random number generator instance.
22 | 23 | """ 24 | 25 | def __init__(self, dataset: torch.utils.data.IterableDataset, desired_dist: dict, 26 | seed: int = None): 27 | 28 | self.dataset = dataset 29 | self.desired_dist = {c: p / sum(desired_dist.values()) for c, p in desired_dist.items()} 30 | self.seed = seed 31 | 32 | self.actual_dist = collections.Counter() 33 | self.rng = random.Random(seed) 34 | self._pivot = None 35 | 36 | def __iter__(self): 37 | 38 | for x, y in self.dataset: 39 | 40 | self.actual_dist[y] += 1 41 | 42 | # To ease notation 43 | f = self.desired_dist 44 | g = self.actual_dist 45 | 46 | # Check if the pivot needs to be changed 47 | if y != self._pivot: 48 | self._pivot = max(g.keys(), key=lambda y: g[y] / f[y]) 49 | else: 50 | yield x, y 51 | continue 52 | 53 | # Determine the sampling ratio if the observed label is not the pivot 54 | M = g[self._pivot] / f[self._pivot] 55 | rate = M * f[y] / g[y] 56 | 57 | for _ in range(utils.random_poisson(rate, rng=self.rng)): 58 | yield x, y 59 | 60 | @classmethod 61 | def expected_size(cls, n, desired_dist, actual_dist): 62 | M = max( 63 | actual_dist.get(k) / desired_dist.get(k) 64 | for k in set(desired_dist) | set(actual_dist) 65 | ) 66 | return int(n * M) 67 | -------------------------------------------------------------------------------- /pytorch_resample/under.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import random 3 | 4 | import torch 5 | 6 | 7 | class UnderSampler(torch.utils.data.IterableDataset): 8 | """Dataset wrapper for under-sampling. 9 | 10 | This method is based on rejection sampling. 11 | 12 | Parameters: 13 | dataset 14 | desired_dist: The desired class distribution. The keys are the classes whilst the 15 | values are the desired class percentages. The values are normalised so that sum up 16 | to 1. 17 | seed: Random seed for reproducibility. 18 | 19 | Attributes: 20 | actual_dist: The counts of the observed sample labels. 21 | rng: A random number generator instance. 

    References:
        - https://www.wikiwand.com/en/Rejection_sampling

    """

    def __init__(self, dataset: torch.utils.data.IterableDataset, desired_dist: dict,
                 seed: int = None):

        self.dataset = dataset
        self.desired_dist = {c: p / sum(desired_dist.values()) for c, p in desired_dist.items()}
        self.seed = seed

        self.actual_dist = collections.Counter()
        self.rng = random.Random(seed)
        self._pivot = None

    def __iter__(self):

        for x, y in self.dataset:

            self.actual_dist[y] += 1

            # To ease notation
            f = self.desired_dist
            g = self.actual_dist

            # Check if the pivot needs to be changed. The pivot is the class
            # whose desired to observed frequency ratio is highest; it
            # determines the rejection rate of the other classes.
            if y != self._pivot:
                self._pivot = max(g.keys(), key=lambda c: f[c] / g[c])
            else:
                yield x, y
                continue

            # Determine the sampling ratio if the observed label is not the pivot
            M = f[self._pivot] / g[self._pivot]
            ratio = f[y] / (M * g[y])

            if ratio < 1 and self.rng.random() < ratio:
                yield x, y

    @classmethod
    def expected_size(cls, n, desired_dist, actual_dist):
        M = max(
            desired_dist.get(k) / actual_dist.get(k)
            for k in set(desired_dist) | set(actual_dist)
        )
        return int(n / M)
--------------------------------------------------------------------------------
/pytorch_resample/utils.py:
--------------------------------------------------------------------------------
import math
import random


__all__ = ['random_poisson']


def random_poisson(rate, rng=random):
    """Sample a random value from a Poisson distribution.

    This implementation is done in pure Python, using Knuth's multiplication
    method. Using PyTorch would be much slower.

    References:
        - https://www.wikiwand.com/en/Poisson_distribution#/Generating_Poisson-distributed_random_variables

    """

    L = math.exp(-rate)
    k = 0
    p = 1

    while p > L:
        k += 1
        p *= rng.random()

    return k - 1
--------------------------------------------------------------------------------