├── .gitignore ├── README.md ├── mixmatch_pytorch ├── __init__.py ├── get_mixmatch_loss.py ├── get_unlabeled_loader.py ├── guess_targets.py ├── k_batch_sampler.py ├── mixmatch_batch.py ├── mixmatch_loader.py ├── mixup_samples.py ├── sharpen.py └── tile_adjacent.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | tests.py 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mixmatch-pytorch 2 | An implementation of MixMatch (https://arxiv.org/pdf/1905.02249.pdf) with PyTorch 3 | 4 | 5 | ## Installation 6 | `pip install git+https://github.com/FelixAbrahamsson/mixmatch-pytorch` 7 | 8 | ## Instructions 9 | The package provides a class `mixmatch_pytorch.MixMatchLoader` that works like a normal PyTorch DataLoader, as well as a loss function that is constructed from `mixmatch_pytorch.get_mixmatch_loss`. For example uses, see below. 10 | 11 | You must provide a data loader that functions as an iterable yielding dictionaries with keys `'features'` and `'targets'` that hold augmented (!) features and targets for the labeled dataset. A dataset must also be provided for the unlabeled data, that can be wrapped in a PyTorch DataLoader. The dataset must return dictionaries with key `'features'` that hold augmented features. 12 | 13 | A model used for guessing targets for unlabeled data must be provided, as well as an output transform that converts the logits to probabilities. 14 | 15 | Your targets may be single class or multiclass, though for a multiclass task take care to use one-hot encoding with a float datatype for your targets. If you want to use this package for a regression task, it should work out of the box with a simple change of input hyperparameters (losses etc.). You would also need to set T=1 to remove sharpening. 16 | 17 | For a description of the hyperparameters, please refer to the author's article. 18 | 19 | ## Example use 20 | ```python 21 | from mixmatch_pytorch import MixMatchLoader, get_mixmatch_loss 22 | 23 | 24 | loader_mixmatch = MixMatchLoader( 25 | loader_labeled, 26 | dataset_unlabeled, 27 | model, 28 | output_transform=torch.sigmoid, 29 | K=2, 30 | T=0.5, 31 | alpha=0.75 32 | ) 33 | 34 | criterion = get_mixmatch_loss( 35 | criterion_labeled=nn.BCEWithLogitsLoss(), 36 | output_transform=torch.sigmoid, 37 | K=2, 38 | weight_unlabeled=100., 39 | criterion_unlabeled=nn.MSELoss() 40 | ) 41 | 42 | for batch in loader_mixmatch: 43 | logits = model(batch['features'].to(device)) 44 | loss = criterion(logits, batch['targets']) 45 | ``` 46 | -------------------------------------------------------------------------------- /mixmatch_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .mixmatch_loader import MixMatchLoader 2 | from .get_mixmatch_loss import get_mixmatch_loss 3 | -------------------------------------------------------------------------------- /mixmatch_pytorch/get_mixmatch_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def get_mixmatch_loss( 5 | criterion_labeled, output_transform, K=2, weight_unlabeled=100., 6 | criterion_unlabeled=nn.MSELoss() 7 | ): 8 | 9 | def loss_function(logits, targets): 10 | 11 | batch_size = len(logits) // (K + 1) 12 | loss_labeled = criterion_labeled( 13 | logits[:batch_size], targets[:batch_size] 14 | ) 15 | loss_unlabeled = criterion_unlabeled( 16 | output_transform(logits[batch_size:]), targets[batch_size:] 17 | ) 18 | return loss_labeled + weight_unlabeled * loss_unlabeled 19 | 20 | return loss_function 21 | -------------------------------------------------------------------------------- /mixmatch_pytorch/get_unlabeled_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from .k_batch_sampler import KBatchSampler 3 | 4 | 5 | def get_unlabeled_loader(dataset, batch_size, K, num_workers=1): 6 | 7 | return DataLoader( 8 | dataset, 9 | batch_sampler=KBatchSampler(dataset, batch_size, K), 10 | num_workers=num_workers, 11 | ) 12 | -------------------------------------------------------------------------------- /mixmatch_pytorch/guess_targets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .sharpen import sharpen 3 | from .tile_adjacent import tile_adjacent 4 | 5 | 6 | def guess_targets(features, model, output_transform, K, T): 7 | 8 | original_device = features.device 9 | with torch.no_grad(): 10 | features = features.to(next(model.parameters()).device) 11 | probabilities = output_transform(model(features)) 12 | probabilities = ( 13 | probabilities 14 | .view(-1, K, *probabilities.shape[1:]) 15 | .mean(dim=1) 16 | ) 17 | probabilities = sharpen(probabilities, T).to(original_device) 18 | 19 | return tile_adjacent(probabilities, K) 20 | -------------------------------------------------------------------------------- /mixmatch_pytorch/k_batch_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Sampler 3 | from .tile_adjacent import tile_adjacent 4 | 5 | 6 | class KBatchSampler(Sampler): 7 | 8 | def __init__(self, dataset, batch_size, K): 9 | 10 | self.num_samples = len(dataset) 11 | self.batch_size = batch_size 12 | self.K = K 13 | 14 | def __iter__(self): 15 | 16 | while True: 17 | indices = torch.randint(0, self.num_samples, (self.batch_size,)) 18 | yield tile_adjacent(indices, self.K) 19 | -------------------------------------------------------------------------------- /mixmatch_pytorch/mixmatch_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .guess_targets import guess_targets 3 | from .mixup_samples import mixup_samples 4 | 5 | 6 | def mixmatch_batch( 7 | batch, batch_unlabeled, model, output_transform, K, T, beta 8 | ): 9 | 10 | features_labeled = batch['features'] 11 | targets_labeled = batch['targets'] 12 | features_unlabeled = batch_unlabeled['features'] 13 | targets_unlabeled = guess_targets( 14 | features_unlabeled, model, output_transform, K, T 15 | ) 16 | 17 | indices = torch.randperm(len(features_labeled) + len(features_unlabeled)) 18 | features_W = torch.cat((features_labeled, features_unlabeled), dim=0)[indices] 19 | targets_W = torch.cat((targets_labeled, targets_unlabeled), dim=0)[indices] 20 | 21 | features_X, targets_X = mixup_samples( 22 | features_labeled, 23 | targets_labeled, 24 | features_W[:len(features_labeled)], 25 | targets_W[:len(features_labeled)], 26 | beta 27 | ) 28 | features_U, targets_U = mixup_samples( 29 | features_unlabeled, 30 | targets_unlabeled, 31 | features_W[len(features_labeled):], 32 | targets_W[len(features_labeled):], 33 | beta 34 | ) 35 | 36 | return dict( 37 | features=torch.cat((features_X, features_U), dim=0), 38 | targets=torch.cat((targets_X, targets_U), dim=0), 39 | ) 40 | -------------------------------------------------------------------------------- /mixmatch_pytorch/mixmatch_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .mixmatch_batch import mixmatch_batch 3 | from .get_unlabeled_loader import get_unlabeled_loader 4 | 5 | 6 | class MixMatchLoader: 7 | 8 | def __init__( 9 | self, loader_labeled, dataset_unlabeled, model, output_transform, 10 | K=2, T=0.5, alpha=0.75 11 | ): 12 | 13 | self.loader_labeled = loader_labeled 14 | self.loader_unlabeled = get_unlabeled_loader( 15 | dataset_unlabeled, 16 | loader_labeled.batch_size, 17 | K, 18 | loader_labeled.num_workers, 19 | ) 20 | self.model = model 21 | self.output_transform = output_transform 22 | self.K = K 23 | self.T = T 24 | alpha = torch.tensor(alpha, dtype=torch.float64) 25 | self.beta = torch.distributions.beta.Beta(alpha, alpha) 26 | 27 | def __iter__(self): 28 | 29 | zipped_loaders = zip(self.loader_labeled, self.loader_unlabeled) 30 | for batch_labeled, batch_unlabeled in zipped_loaders: 31 | yield mixmatch_batch( 32 | batch_labeled, batch_unlabeled, self.model, 33 | self.output_transform, self.K, self.T, self.beta 34 | ) 35 | 36 | def __len__(self): 37 | 38 | return min(len(self.loader_labeled), len(self.loader_unlabeled)) 39 | -------------------------------------------------------------------------------- /mixmatch_pytorch/mixup_samples.py: -------------------------------------------------------------------------------- 1 | def mixup_samples(x1, y1, x2, y2, beta): 2 | 3 | weight = beta.sample().item() 4 | weight = max(weight, 1 - weight) 5 | 6 | x = x1 * weight + (1 - weight) * x2 7 | y = y1 * weight + (1 - weight) * y2 8 | 9 | return x, y 10 | -------------------------------------------------------------------------------- /mixmatch_pytorch/sharpen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def sharpen(probabilities, T): 5 | 6 | if probabilities.ndim == 1: 7 | tempered = torch.pow(probabilities, 1 / T) 8 | tempered = ( 9 | tempered 10 | / (torch.pow((1 - probabilities), 1 / T) + tempered) 11 | ) 12 | 13 | else: 14 | tempered = torch.pow(probabilities, 1 / T) 15 | tempered = tempered / tempered.sum(dim=-1, keepdim=True) 16 | 17 | return tempered 18 | -------------------------------------------------------------------------------- /mixmatch_pytorch/tile_adjacent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def tile_adjacent(tensor, K): 5 | ''' Examples: 6 | in: tensor=[0, 1, 2, 3], K=2 7 | out: [0, 0, 1, 1, 2, 2, 3, 3] 8 | 9 | in: tensor=[ 10 | [1, 2], 11 | [3, 4] 12 | ], K=2 13 | out: [ 14 | [1, 2], 15 | [1, 2], 16 | [3, 4], 17 | [3, 4] 18 | ] 19 | ''' 20 | 21 | return ( 22 | torch.stack( 23 | tensor 24 | .repeat(K, *tuple(1 for i in range(tensor.ndim - 1))) 25 | .split(tensor.size(0), dim=0) 26 | ) 27 | .transpose(1, 0) 28 | .contiguous() 29 | .view(-1, *tensor.shape[1:]) 30 | ) 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch<=1.3 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name='mixmatch-pytorch', 6 | version='0.2.1', 7 | description='MixMatch for PyTorch', 8 | author='Felix Abrahamsson', 9 | author_email='FelixAbrahamsson@github.com', 10 | keywords='mixmatch holistic approach pytorch torch', 11 | packages=['mixmatch_pytorch'], 12 | install_requires=[ 13 | 'torch<=1.3', 14 | ], 15 | ) 16 | --------------------------------------------------------------------------------