├── .gitignore ├── README.md ├── cfg.py ├── data └── .gitkeep ├── datasets ├── __init__.py ├── base.py ├── cifar10.py ├── cifar100.py ├── imagenet.py ├── stl10.py ├── tiny_in.py └── transforms.py ├── docker └── Dockerfile ├── eval ├── get_data.py ├── knn.py ├── lbfgs.py └── sgd.py ├── methods ├── __init__.py ├── base.py ├── byol.py ├── contrastive.py ├── norm_mse.py ├── w_mse.py └── whitening.py ├── model.py ├── test.py ├── tf2 ├── README.md └── whitening.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.swp 2 | **/__pycache__ 3 | data/** 4 | !data/.gitkeep 5 | output/** 6 | wandb/ 7 | **/wandb/** 8 | docker/env.list 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Representation Learning 2 | 3 | Official repository of the paper **Whitening for Self-Supervised Representation Learning** 4 | 5 | ICML 2021 | [arXiv:2007.06346](https://arxiv.org/abs/2007.06346) 6 | 7 | It includes 3 types of losses: 8 | - W-MSE [arXiv](https://arxiv.org/abs/2007.06346) 9 | - Contrastive [SimCLR arXiv](https://arxiv.org/abs/2002.05709) 10 | - BYOL [arXiv](https://arxiv.org/abs/2006.07733) 11 | 12 | And 5 datasets: 13 | - CIFAR-10 and CIFAR-100 14 | - STL-10 15 | - Tiny ImageNet 16 | - ImageNet-100 17 | 18 | Checkpoints are stored in `data` each 100 epochs during training. 19 | 20 | The implementation is optimized for a single GPU, although multiple are also supported. It includes fast evaluation: we pre-compute embeddings for the entire dataset and then train a classifier on top. The evaluation of the ResNet-18 encoder takes about one minute. 21 | 22 | ## Installation 23 | 24 | The implementation is based on PyTorch. Logging works on [wandb.ai](https://wandb.ai/). See `docker/Dockerfile`. 25 | 26 | #### ImageNet-100 27 | To get this dataset, take the original ImageNet and filter out [this subset of classes](https://github.com/HobbitLong/CMC/blob/master/imagenet100.txt). We do not use augmentations during testing, and loading big images with resizing on the fly is slow, so we can preprocess classifier train and test images. We recommend [mogrify](https://imagemagick.org/script/mogrify.php) for it. First, you need to resize to 256 (just like `torchvision.transforms.Resize(256)`) and then crop to 224 (like `torchvision.transforms.CenterCrop(224)`). Finally, put the original images to `train`, and resized to `clf` and `test`. 28 | 29 | ## Usage 30 | 31 | Detailed settings are good by default, to see all options: 32 | ``` 33 | python -m train --help 34 | python -m test --help 35 | ``` 36 | 37 | To reproduce the results from [table 1](https://arxiv.org/abs/2007.06346): 38 | #### W-MSE 4 39 | ``` 40 | python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --num_samples 4 --bs 256 --emb 64 --w_size 128 41 | python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --num_samples 4 --bs 256 --emb 64 --w_size 128 42 | python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256 43 | python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256 44 | ``` 45 | 46 | #### W-MSE 2 47 | ``` 48 | python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --w_size 128 49 | python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --w_size 128 50 | python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --w_size 256 --w_iter 4 51 | python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --w_size 256 --w_iter 4 52 | ``` 53 | 54 | #### Contrastive 55 | ``` 56 | python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --method contrastive 57 | python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --method contrastive 58 | python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --method contrastive 59 | python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --method contrastive 60 | ``` 61 | 62 | #### BYOL 63 | ``` 64 | python -m train --dataset cifar10 --epoch 1000 --lr 3e-3 --emb 64 --method byol 65 | python -m train --dataset cifar100 --epoch 1000 --lr 3e-3 --emb 64 --method byol 66 | python -m train --dataset stl10 --epoch 2000 --lr 2e-3 --emb 128 --method byol 67 | python -m train --dataset tiny_in --epoch 1000 --lr 2e-3 --emb 128 --method byol 68 | ``` 69 | 70 | #### ImageNet-100 71 | ``` 72 | python -m train --dataset imagenet --epoch 240 --lr 2e-3 --emb 128 --w_size 256 --crop_s0 0.08 --cj0 0.8 --cj1 0.8 --cj2 0.8 --cj3 0.2 --gs_p 0.2 73 | python -m train --dataset imagenet --epoch 240 --lr 2e-3 --num_samples 4 --bs 256 --emb 128 --w_size 256 --crop_s0 0.08 --cj0 0.8 --cj1 0.8 --cj2 0.8 --cj3 0.2 --gs_p 0.2 74 | ``` 75 | 76 | Use `--no_norm` to disable normalization (for Euclidean distance). 77 | 78 | ## Citation 79 | ``` 80 | @inproceedings{ermolov2021whitening, 81 | title={Whitening for self-supervised representation learning}, 82 | author={Ermolov, Aleksandr and Siarohin, Aliaksandr and Sangineto, Enver and Sebe, Nicu}, 83 | booktitle={International Conference on Machine Learning}, 84 | pages={3015--3024}, 85 | year={2021}, 86 | organization={PMLR} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import argparse 3 | from torchvision import models 4 | import multiprocessing 5 | from datasets import DS_LIST 6 | from methods import METHOD_LIST 7 | 8 | 9 | def get_cfg(): 10 | """ generates configuration from user input in console """ 11 | parser = argparse.ArgumentParser(description="") 12 | parser.add_argument( 13 | "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type", 14 | ) 15 | parser.add_argument( 16 | "--wandb", 17 | type=str, 18 | default="self_supervised", 19 | help="name of the project for logging at https://wandb.ai", 20 | ) 21 | parser.add_argument( 22 | "--byol_tau", type=float, default=0.99, help="starting tau for byol loss" 23 | ) 24 | parser.add_argument( 25 | "--num_samples", 26 | type=int, 27 | default=2, 28 | help="number of samples (d) generated from each image", 29 | ) 30 | 31 | addf = partial(parser.add_argument, type=float) 32 | addf("--cj0", default=0.4, help="color jitter brightness") 33 | addf("--cj1", default=0.4, help="color jitter contrast") 34 | addf("--cj2", default=0.4, help="color jitter saturation") 35 | addf("--cj3", default=0.1, help="color jitter hue") 36 | addf("--cj_p", default=0.8, help="color jitter probability") 37 | addf("--gs_p", default=0.1, help="grayscale probability") 38 | addf("--crop_s0", default=0.2, help="crop size from") 39 | addf("--crop_s1", default=1.0, help="crop size to") 40 | addf("--crop_r0", default=0.75, help="crop ratio from") 41 | addf("--crop_r1", default=(4 / 3), help="crop ratio to") 42 | addf("--hf_p", default=0.5, help="horizontal flip probability") 43 | 44 | parser.add_argument( 45 | "--no_lr_warmup", 46 | dest="lr_warmup", 47 | action="store_false", 48 | help="do not use learning rate warmup", 49 | ) 50 | parser.add_argument( 51 | "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head" 52 | ) 53 | parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier") 54 | parser.add_argument("--fname", type=str, help="load model from file") 55 | parser.add_argument( 56 | "--lr_step", 57 | type=str, 58 | choices=["cos", "step", "none"], 59 | default="step", 60 | help="learning rate schedule type", 61 | ) 62 | parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") 63 | parser.add_argument( 64 | "--eta_min", type=float, default=0, help="min learning rate (for --lr_step cos)" 65 | ) 66 | parser.add_argument( 67 | "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)" 68 | ) 69 | parser.add_argument("--T0", type=int, help="period (for --lr_step cos)") 70 | parser.add_argument( 71 | "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)" 72 | ) 73 | parser.add_argument( 74 | "--w_eps", type=float, default=0, help="eps for stability for whitening" 75 | ) 76 | parser.add_argument( 77 | "--head_layers", type=int, default=2, help="number of FC layers in head" 78 | ) 79 | parser.add_argument( 80 | "--head_size", type=int, default=1024, help="size of FC layers in head" 81 | ) 82 | 83 | parser.add_argument( 84 | "--w_size", type=int, default=128, help="size of sub-batch for W-MSE loss" 85 | ) 86 | parser.add_argument( 87 | "--w_iter", 88 | type=int, 89 | default=1, 90 | help="iterations for whitening matrix estimation", 91 | ) 92 | 93 | parser.add_argument( 94 | "--no_norm", dest="norm", action="store_false", help="don't normalize latents", 95 | ) 96 | parser.add_argument( 97 | "--tau", type=float, default=0.5, help="contrastive loss temperature" 98 | ) 99 | 100 | parser.add_argument("--epoch", type=int, default=200, help="total epoch number") 101 | parser.add_argument( 102 | "--eval_every_drop", 103 | type=int, 104 | default=5, 105 | help="how often to evaluate after learning rate drop", 106 | ) 107 | parser.add_argument( 108 | "--eval_every", type=int, default=20, help="how often to evaluate" 109 | ) 110 | parser.add_argument("--emb", type=int, default=64, help="embedding size") 111 | parser.add_argument( 112 | "--bs", type=int, default=512, help="number of original images in batch N", 113 | ) 114 | parser.add_argument( 115 | "--drop", 116 | type=int, 117 | nargs="*", 118 | default=[50, 25], 119 | help="milestones for learning rate decay (0 = last epoch)", 120 | ) 121 | parser.add_argument( 122 | "--drop_gamma", 123 | type=float, 124 | default=0.2, 125 | help="multiplicative factor of learning rate decay", 126 | ) 127 | parser.add_argument( 128 | "--arch", 129 | type=str, 130 | choices=[x for x in dir(models) if "resn" in x], 131 | default="resnet18", 132 | help="encoder architecture", 133 | ) 134 | parser.add_argument("--dataset", type=str, choices=DS_LIST, default="cifar10") 135 | parser.add_argument( 136 | "--num_workers", 137 | type=int, 138 | default=multiprocessing.cpu_count(), 139 | help="dataset workers number", 140 | ) 141 | parser.add_argument( 142 | "--clf", 143 | type=str, 144 | default="sgd", 145 | choices=["sgd", "knn", "lbfgs"], 146 | help="classifier for test.py", 147 | ) 148 | parser.add_argument( 149 | "--eval_head", action="store_true", help="eval head output instead of model", 150 | ) 151 | parser.add_argument("--imagenet_path", type=str, default="~/IN100/") 152 | return parser.parse_args() 153 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htdt/self-supervised/d9662c8d07dafd194a9045375f4f6aa09f5b03e9/data/.gitkeep -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import CIFAR10 2 | from .cifar100 import CIFAR100 3 | from .stl10 import STL10 4 | from .tiny_in import TinyImageNet 5 | from .imagenet import ImageNet 6 | 7 | 8 | DS_LIST = ["cifar10", "cifar100", "stl10", "tiny_in", "imagenet"] 9 | 10 | 11 | def get_ds(name): 12 | assert name in DS_LIST 13 | if name == "cifar10": 14 | return CIFAR10 15 | elif name == "cifar100": 16 | return CIFAR100 17 | elif name == "stl10": 18 | return STL10 19 | elif name == "tiny_in": 20 | return TinyImageNet 21 | elif name == "imagenet": 22 | return ImageNet 23 | -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from functools import lru_cache 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | class BaseDataset(metaclass=ABCMeta): 7 | """ 8 | base class for datasets, it includes 3 types: 9 | - for self-supervised training, 10 | - for classifier training for evaluation, 11 | - for testing 12 | """ 13 | 14 | def __init__( 15 | self, bs_train, aug_cfg, num_workers, bs_clf=1000, bs_test=1000, 16 | ): 17 | self.aug_cfg = aug_cfg 18 | self.bs_train, self.bs_clf, self.bs_test = bs_train, bs_clf, bs_test 19 | self.num_workers = num_workers 20 | 21 | @abstractmethod 22 | def ds_train(self): 23 | raise NotImplementedError 24 | 25 | @abstractmethod 26 | def ds_clf(self): 27 | raise NotImplementedError 28 | 29 | @abstractmethod 30 | def ds_test(self): 31 | raise NotImplementedError 32 | 33 | @property 34 | @lru_cache() 35 | def train(self): 36 | return DataLoader( 37 | dataset=self.ds_train(), 38 | batch_size=self.bs_train, 39 | shuffle=True, 40 | num_workers=self.num_workers, 41 | pin_memory=True, 42 | drop_last=True, 43 | ) 44 | 45 | @property 46 | @lru_cache() 47 | def clf(self): 48 | return DataLoader( 49 | dataset=self.ds_clf(), 50 | batch_size=self.bs_clf, 51 | shuffle=True, 52 | num_workers=self.num_workers, 53 | pin_memory=True, 54 | drop_last=True, 55 | ) 56 | 57 | @property 58 | @lru_cache() 59 | def test(self): 60 | return DataLoader( 61 | dataset=self.ds_test(), 62 | batch_size=self.bs_test, 63 | shuffle=False, 64 | num_workers=self.num_workers, 65 | pin_memory=True, 66 | drop_last=False, 67 | ) 68 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 as C10 2 | import torchvision.transforms as T 3 | from .transforms import MultiSample, aug_transform 4 | from .base import BaseDataset 5 | 6 | 7 | def base_transform(): 8 | return T.Compose( 9 | [T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] 10 | ) 11 | 12 | 13 | class CIFAR10(BaseDataset): 14 | def ds_train(self): 15 | t = MultiSample( 16 | aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples 17 | ) 18 | return C10(root="./data", train=True, download=True, transform=t) 19 | 20 | def ds_clf(self): 21 | t = base_transform() 22 | return C10(root="./data", train=True, download=True, transform=t) 23 | 24 | def ds_test(self): 25 | t = base_transform() 26 | return C10(root="./data", train=False, download=True, transform=t) 27 | -------------------------------------------------------------------------------- /datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR100 as C100 2 | import torchvision.transforms as T 3 | from .transforms import MultiSample, aug_transform 4 | from .base import BaseDataset 5 | 6 | 7 | def base_transform(): 8 | return T.Compose( 9 | [T.ToTensor(), T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))] 10 | ) 11 | 12 | 13 | class CIFAR100(BaseDataset): 14 | def ds_train(self): 15 | t = MultiSample( 16 | aug_transform(32, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples 17 | ) 18 | return C100(root="./data", train=True, download=True, transform=t,) 19 | 20 | def ds_clf(self): 21 | t = base_transform() 22 | return C100(root="./data", train=True, download=True, transform=t) 23 | 24 | def ds_test(self): 25 | t = base_transform() 26 | return C100(root="./data", train=False, download=True, transform=t) 27 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torchvision.datasets import ImageFolder 3 | import torchvision.transforms as T 4 | from PIL import ImageFilter 5 | from .transforms import MultiSample, aug_transform 6 | from .base import BaseDataset 7 | 8 | 9 | class RandomBlur: 10 | def __init__(self, r0, r1): 11 | self.r0, self.r1 = r0, r1 12 | 13 | def __call__(self, image): 14 | r = random.uniform(self.r0, self.r1) 15 | return image.filter(ImageFilter.GaussianBlur(radius=r)) 16 | 17 | 18 | def base_transform(): 19 | return T.Compose( 20 | [T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 21 | ) 22 | 23 | 24 | class ImageNet(BaseDataset): 25 | def ds_train(self): 26 | aug_with_blur = aug_transform( 27 | 224, 28 | base_transform, 29 | self.aug_cfg, 30 | extra_t=[T.RandomApply([RandomBlur(0.1, 2.0)], p=0.5)], 31 | ) 32 | t = MultiSample(aug_with_blur, n=self.aug_cfg.num_samples) 33 | return ImageFolder(root=self.aug_cfg.imagenet_path + "train", transform=t) 34 | 35 | def ds_clf(self): 36 | t = base_transform() 37 | return ImageFolder(root=self.aug_cfg.imagenet_path + "clf", transform=t) 38 | 39 | def ds_test(self): 40 | t = base_transform() 41 | return ImageFolder(root=self.aug_cfg.imagenet_path + "test", transform=t) 42 | -------------------------------------------------------------------------------- /datasets/stl10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import STL10 as S10 2 | import torchvision.transforms as T 3 | from .transforms import MultiSample, aug_transform 4 | from .base import BaseDataset 5 | 6 | 7 | def base_transform(): 8 | return T.Compose( 9 | [T.ToTensor(), T.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))] 10 | ) 11 | 12 | 13 | def test_transform(): 14 | return T.Compose( 15 | [T.Resize(70, interpolation=3), T.CenterCrop(64), base_transform()] 16 | ) 17 | 18 | 19 | class STL10(BaseDataset): 20 | def ds_train(self): 21 | t = MultiSample( 22 | aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples 23 | ) 24 | return S10(root="./data", split="train+unlabeled", download=True, transform=t) 25 | 26 | def ds_clf(self): 27 | t = test_transform() 28 | return S10(root="./data", split="train", download=True, transform=t) 29 | 30 | def ds_test(self): 31 | t = test_transform() 32 | return S10(root="./data", split="test", download=True, transform=t) 33 | -------------------------------------------------------------------------------- /datasets/tiny_in.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | import torchvision.transforms as T 3 | from .transforms import MultiSample, aug_transform 4 | from .base import BaseDataset 5 | 6 | 7 | def base_transform(): 8 | return T.Compose( 9 | [T.ToTensor(), T.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282))] 10 | ) 11 | 12 | 13 | class TinyImageNet(BaseDataset): 14 | def ds_train(self): 15 | t = MultiSample( 16 | aug_transform(64, base_transform, self.aug_cfg), n=self.aug_cfg.num_samples 17 | ) 18 | return ImageFolder(root="data/tiny-imagenet-200/train", transform=t) 19 | 20 | def ds_clf(self): 21 | t = base_transform() 22 | return ImageFolder(root="data/tiny-imagenet-200/train", transform=t) 23 | 24 | def ds_test(self): 25 | t = base_transform() 26 | return ImageFolder(root="data/tiny-imagenet-200/test", transform=t) 27 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | 3 | 4 | def aug_transform(crop, base_transform, cfg, extra_t=[]): 5 | """ augmentation transform generated from config """ 6 | return T.Compose( 7 | [ 8 | T.RandomApply( 9 | [T.ColorJitter(cfg.cj0, cfg.cj1, cfg.cj2, cfg.cj3)], p=cfg.cj_p 10 | ), 11 | T.RandomGrayscale(p=cfg.gs_p), 12 | T.RandomResizedCrop( 13 | crop, 14 | scale=(cfg.crop_s0, cfg.crop_s1), 15 | ratio=(cfg.crop_r0, cfg.crop_r1), 16 | interpolation=3, 17 | ), 18 | T.RandomHorizontalFlip(p=cfg.hf_p), 19 | *extra_t, 20 | base_transform(), 21 | ] 22 | ) 23 | 24 | 25 | class MultiSample: 26 | """ generates n samples with augmentation """ 27 | 28 | def __init__(self, transform, n=2): 29 | self.transform = transform 30 | self.num = n 31 | 32 | def __call__(self, x): 33 | return tuple(self.transform(x) for _ in range(self.num)) 34 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime 2 | RUN pip install sklearn opencv-python 3 | RUN pip install matplotlib 4 | RUN pip install wandb 5 | RUN pip install ipdb 6 | ENTRYPOINT wandb login $WANDB_KEY && /bin/bash 7 | -------------------------------------------------------------------------------- /eval/get_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_data(model, loader, output_size, device): 5 | """ encodes whole dataset into embeddings """ 6 | xs = torch.empty( 7 | len(loader), loader.batch_size, output_size, dtype=torch.float32, device=device 8 | ) 9 | ys = torch.empty(len(loader), loader.batch_size, dtype=torch.long, device=device) 10 | with torch.no_grad(): 11 | for i, (x, y) in enumerate(loader): 12 | x = x.cuda() 13 | xs[i] = model(x).to(device) 14 | ys[i] = y.to(device) 15 | xs = xs.view(-1, output_size) 16 | ys = ys.view(-1) 17 | return xs, ys 18 | -------------------------------------------------------------------------------- /eval/knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def eval_knn(x_train, y_train, x_test, y_test, k=5): 5 | """ k-nearest neighbors classifier accuracy """ 6 | d = torch.cdist(x_test, x_train) 7 | topk = torch.topk(d, k=k, dim=1, largest=False) 8 | labels = y_train[topk.indices] 9 | pred = torch.empty_like(y_test) 10 | for i in range(len(labels)): 11 | x = labels[i].unique(return_counts=True) 12 | pred[i] = x[0][x[1].argmax()] 13 | 14 | acc = (pred == y_test).float().mean().cpu().item() 15 | del d, topk, labels, pred 16 | return acc 17 | -------------------------------------------------------------------------------- /eval/lbfgs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.linear_model import LogisticRegression 3 | 4 | 5 | def eval_lbfgs(x_train, y_train, x_test, y_test): 6 | """ linear classifier accuracy (lbfgs method) """ 7 | clf = LogisticRegression( 8 | random_state=1337, solver="lbfgs", max_iter=1000, n_jobs=-1 9 | ) 10 | clf.fit(x_train, y_train) 11 | pred = clf.predict(x_test) 12 | return (torch.tensor(pred) == y_test).float().mean() 13 | -------------------------------------------------------------------------------- /eval/sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | 6 | def eval_sgd(x_train, y_train, x_test, y_test, topk=[1, 5], epoch=500): 7 | """ linear classifier accuracy (sgd) """ 8 | lr_start, lr_end = 1e-2, 1e-6 9 | gamma = (lr_end / lr_start) ** (1 / epoch) 10 | output_size = x_train.shape[1] 11 | num_class = y_train.max().item() + 1 12 | clf = nn.Linear(output_size, num_class) 13 | clf.cuda() 14 | clf.train() 15 | optimizer = optim.Adam(clf.parameters(), lr=lr_start, weight_decay=5e-6) 16 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) 17 | criterion = nn.CrossEntropyLoss() 18 | 19 | for ep in range(epoch): 20 | perm = torch.randperm(len(x_train)).view(-1, 1000) 21 | for idx in perm: 22 | optimizer.zero_grad() 23 | criterion(clf(x_train[idx]), y_train[idx]).backward() 24 | optimizer.step() 25 | scheduler.step() 26 | 27 | clf.eval() 28 | with torch.no_grad(): 29 | y_pred = clf(x_test) 30 | pred_top = y_pred.topk(max(topk), 1, largest=True, sorted=True).indices 31 | acc = { 32 | t: (pred_top[:, :t] == y_test[..., None]).float().sum(1).mean().cpu().item() 33 | for t in topk 34 | } 35 | del clf 36 | return acc 37 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .contrastive import Contrastive 2 | from .w_mse import WMSE 3 | from .byol import BYOL 4 | 5 | 6 | METHOD_LIST = ["contrastive", "w_mse", "byol"] 7 | 8 | 9 | def get_method(name): 10 | assert name in METHOD_LIST 11 | if name == "contrastive": 12 | return Contrastive 13 | elif name == "w_mse": 14 | return WMSE 15 | elif name == "byol": 16 | return BYOL 17 | -------------------------------------------------------------------------------- /methods/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from model import get_model, get_head 3 | from eval.sgd import eval_sgd 4 | from eval.knn import eval_knn 5 | from eval.get_data import get_data 6 | 7 | 8 | class BaseMethod(nn.Module): 9 | """ 10 | Base class for self-supervised loss implementation. 11 | It includes encoder and head for training, evaluation function. 12 | """ 13 | 14 | def __init__(self, cfg): 15 | super().__init__() 16 | self.model, self.out_size = get_model(cfg.arch, cfg.dataset) 17 | self.head = get_head(self.out_size, cfg) 18 | self.knn = cfg.knn 19 | self.num_pairs = cfg.num_samples * (cfg.num_samples - 1) // 2 20 | self.eval_head = cfg.eval_head 21 | self.emb_size = cfg.emb 22 | 23 | def forward(self, samples): 24 | raise NotImplementedError 25 | 26 | def get_acc(self, ds_clf, ds_test): 27 | self.eval() 28 | if self.eval_head: 29 | model = lambda x: self.head(self.model(x)) 30 | out_size = self.emb_size 31 | else: 32 | model, out_size = self.model, self.out_size 33 | # torch.cuda.empty_cache() 34 | x_train, y_train = get_data(model, ds_clf, out_size, "cuda") 35 | x_test, y_test = get_data(model, ds_test, out_size, "cuda") 36 | 37 | acc_knn = eval_knn(x_train, y_train, x_test, y_test, self.knn) 38 | acc_linear = eval_sgd(x_train, y_train, x_test, y_test) 39 | del x_train, y_train, x_test, y_test 40 | self.train() 41 | return acc_knn, acc_linear 42 | 43 | def step(self, progress): 44 | pass 45 | -------------------------------------------------------------------------------- /methods/byol.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from model import get_model, get_head 7 | from .base import BaseMethod 8 | from .norm_mse import norm_mse_loss 9 | 10 | 11 | class BYOL(BaseMethod): 12 | """ implements BYOL loss https://arxiv.org/abs/2006.07733 """ 13 | 14 | def __init__(self, cfg): 15 | """ init additional target and predictor networks """ 16 | super().__init__(cfg) 17 | self.pred = nn.Sequential( 18 | nn.Linear(cfg.emb, cfg.head_size), 19 | nn.BatchNorm1d(cfg.head_size), 20 | nn.ReLU(), 21 | nn.Linear(cfg.head_size, cfg.emb), 22 | ) 23 | self.model_t, _ = get_model(cfg.arch, cfg.dataset) 24 | self.head_t = get_head(self.out_size, cfg) 25 | for param in chain(self.model_t.parameters(), self.head_t.parameters()): 26 | param.requires_grad = False 27 | self.update_target(0) 28 | self.byol_tau = cfg.byol_tau 29 | self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss 30 | 31 | def update_target(self, tau): 32 | """ copy parameters from main network to target """ 33 | for t, s in zip(self.model_t.parameters(), self.model.parameters()): 34 | t.data.copy_(t.data * tau + s.data * (1.0 - tau)) 35 | for t, s in zip(self.head_t.parameters(), self.head.parameters()): 36 | t.data.copy_(t.data * tau + s.data * (1.0 - tau)) 37 | 38 | def forward(self, samples): 39 | z = [self.pred(self.head(self.model(x))) for x in samples] 40 | with torch.no_grad(): 41 | zt = [self.head_t(self.model_t(x)) for x in samples] 42 | 43 | loss = 0 44 | for i in range(len(samples) - 1): 45 | for j in range(i + 1, len(samples)): 46 | loss += self.loss_f(z[i], zt[j]) + self.loss_f(z[j], zt[i]) 47 | loss /= self.num_pairs 48 | return loss 49 | 50 | def step(self, progress): 51 | """ update target network with cosine increasing schedule """ 52 | tau = 1 - (1 - self.byol_tau) * (math.cos(math.pi * progress) + 1) / 2 53 | self.update_target(tau) 54 | -------------------------------------------------------------------------------- /methods/contrastive.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .base import BaseMethod 6 | 7 | 8 | def contrastive_loss(x0, x1, tau, norm): 9 | # https://github.com/google-research/simclr/blob/master/objective.py 10 | bsize = x0.shape[0] 11 | target = torch.arange(bsize).cuda() 12 | eye_mask = torch.eye(bsize).cuda() * 1e9 13 | if norm: 14 | x0 = F.normalize(x0, p=2, dim=1) 15 | x1 = F.normalize(x1, p=2, dim=1) 16 | logits00 = x0 @ x0.t() / tau - eye_mask 17 | logits11 = x1 @ x1.t() / tau - eye_mask 18 | logits01 = x0 @ x1.t() / tau 19 | logits10 = x1 @ x0.t() / tau 20 | return ( 21 | F.cross_entropy(torch.cat([logits01, logits00], dim=1), target) 22 | + F.cross_entropy(torch.cat([logits10, logits11], dim=1), target) 23 | ) / 2 24 | 25 | 26 | class Contrastive(BaseMethod): 27 | """ implements contrastive loss https://arxiv.org/abs/2002.05709 """ 28 | 29 | def __init__(self, cfg): 30 | """ init additional BN used after head """ 31 | super().__init__(cfg) 32 | self.bn_last = nn.BatchNorm1d(cfg.emb) 33 | self.loss_f = partial(contrastive_loss, tau=cfg.tau, norm=cfg.norm) 34 | 35 | def forward(self, samples): 36 | bs = len(samples[0]) 37 | h = [self.model(x.cuda(non_blocking=True)) for x in samples] 38 | h = self.bn_last(self.head(torch.cat(h))) 39 | loss = 0 40 | for i in range(len(samples) - 1): 41 | for j in range(i + 1, len(samples)): 42 | x0 = h[i * bs : (i + 1) * bs] 43 | x1 = h[j * bs : (j + 1) * bs] 44 | loss += self.loss_f(x0, x1) 45 | loss /= self.num_pairs 46 | return loss 47 | -------------------------------------------------------------------------------- /methods/norm_mse.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def norm_mse_loss(x0, x1): 5 | x0 = F.normalize(x0) 6 | x1 = F.normalize(x1) 7 | return 2 - 2 * (x0 * x1).sum(dim=-1).mean() 8 | -------------------------------------------------------------------------------- /methods/w_mse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .whitening import Whitening2d 4 | from .base import BaseMethod 5 | from .norm_mse import norm_mse_loss 6 | 7 | 8 | class WMSE(BaseMethod): 9 | """ implements W-MSE loss """ 10 | 11 | def __init__(self, cfg): 12 | """ init whitening transform """ 13 | super().__init__(cfg) 14 | self.whitening = Whitening2d(cfg.emb, eps=cfg.w_eps, track_running_stats=False) 15 | self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss 16 | self.w_iter = cfg.w_iter 17 | self.w_size = cfg.bs if cfg.w_size is None else cfg.w_size 18 | 19 | def forward(self, samples): 20 | bs = len(samples[0]) 21 | h = [self.model(x.cuda(non_blocking=True)) for x in samples] 22 | h = self.head(torch.cat(h)) 23 | loss = 0 24 | for _ in range(self.w_iter): 25 | z = torch.empty_like(h) 26 | perm = torch.randperm(bs).view(-1, self.w_size) 27 | for idx in perm: 28 | for i in range(len(samples)): 29 | z[idx + i * bs] = self.whitening(h[idx + i * bs]) 30 | for i in range(len(samples) - 1): 31 | for j in range(i + 1, len(samples)): 32 | x0 = z[i * bs : (i + 1) * bs] 33 | x1 = z[j * bs : (j + 1) * bs] 34 | loss += self.loss_f(x0, x1) 35 | loss /= self.w_iter * self.num_pairs 36 | return loss 37 | -------------------------------------------------------------------------------- /methods/whitening.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import conv2d 4 | 5 | 6 | class Whitening2d(nn.Module): 7 | def __init__(self, num_features, momentum=0.01, track_running_stats=True, eps=0): 8 | super(Whitening2d, self).__init__() 9 | self.num_features = num_features 10 | self.momentum = momentum 11 | self.track_running_stats = track_running_stats 12 | self.eps = eps 13 | 14 | if self.track_running_stats: 15 | self.register_buffer( 16 | "running_mean", torch.zeros([1, self.num_features, 1, 1]) 17 | ) 18 | self.register_buffer("running_variance", torch.eye(self.num_features)) 19 | 20 | def forward(self, x): 21 | x = x.unsqueeze(2).unsqueeze(3) 22 | m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1) 23 | if not self.training and self.track_running_stats: # for inference 24 | m = self.running_mean 25 | xn = x - m 26 | 27 | T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1) 28 | f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1) 29 | 30 | eye = torch.eye(self.num_features).type(f_cov.type()) 31 | 32 | if not self.training and self.track_running_stats: # for inference 33 | f_cov = self.running_variance 34 | 35 | f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye 36 | 37 | inv_sqrt = torch.linalg.solve_triangular( 38 | torch.linalg.cholesky(f_cov_shrinked), 39 | eye, 40 | upper=False 41 | ) 42 | 43 | inv_sqrt = inv_sqrt.contiguous().view( 44 | self.num_features, self.num_features, 1, 1 45 | ) 46 | 47 | decorrelated = conv2d(xn, inv_sqrt) 48 | 49 | if self.training and self.track_running_stats: 50 | self.running_mean = torch.add( 51 | self.momentum * m.detach(), 52 | (1 - self.momentum) * self.running_mean, 53 | out=self.running_mean, 54 | ) 55 | self.running_variance = torch.add( 56 | self.momentum * f_cov.detach(), 57 | (1 - self.momentum) * self.running_variance, 58 | out=self.running_variance, 59 | ) 60 | 61 | return decorrelated.squeeze(2).squeeze(2) 62 | 63 | def extra_repr(self): 64 | return "features={}, eps={}, momentum={}".format( 65 | self.num_features, self.eps, self.momentum 66 | ) 67 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | def get_head(out_size, cfg): 6 | """ creates projection head g() from config """ 7 | x = [] 8 | in_size = out_size 9 | for _ in range(cfg.head_layers - 1): 10 | x.append(nn.Linear(in_size, cfg.head_size)) 11 | if cfg.add_bn: 12 | x.append(nn.BatchNorm1d(cfg.head_size)) 13 | x.append(nn.ReLU()) 14 | in_size = cfg.head_size 15 | x.append(nn.Linear(in_size, cfg.emb)) 16 | return nn.Sequential(*x) 17 | 18 | 19 | def get_model(arch, dataset): 20 | """ creates encoder E() by name and modifies it for dataset """ 21 | model = getattr(models, arch)(pretrained=False) 22 | if dataset != "imagenet": 23 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 24 | if dataset == "cifar10" or dataset == "cifar100": 25 | model.maxpool = nn.Identity() 26 | out_size = model.fc.in_features 27 | model.fc = nn.Identity() 28 | 29 | return nn.DataParallel(model), out_size 30 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import get_ds 3 | from cfg import get_cfg 4 | from methods import get_method 5 | 6 | from eval.sgd import eval_sgd 7 | from eval.knn import eval_knn 8 | from eval.lbfgs import eval_lbfgs 9 | from eval.get_data import get_data 10 | 11 | 12 | if __name__ == "__main__": 13 | cfg = get_cfg() 14 | 15 | model_full = get_method(cfg.method)(cfg) 16 | model_full.cuda().eval() 17 | if cfg.fname is None: 18 | print("evaluating random model") 19 | else: 20 | model_full.load_state_dict(torch.load(cfg.fname)) 21 | 22 | ds = get_ds(cfg.dataset)(None, cfg, cfg.num_workers) 23 | device = "cpu" if cfg.clf == "lbfgs" else "cuda" 24 | if cfg.eval_head: 25 | model = lambda x: model_full.head(model_full.model(x)) 26 | out_size = cfg.emb 27 | else: 28 | model = model_full.model 29 | out_size = model_full.out_size 30 | x_train, y_train = get_data(model, ds.clf, out_size, device) 31 | x_test, y_test = get_data(model, ds.test, out_size, device) 32 | 33 | if cfg.clf == "sgd": 34 | acc = eval_sgd(x_train, y_train, x_test, y_test) 35 | if cfg.clf == "knn": 36 | acc = eval_knn(x_train, y_train, x_test, y_test) 37 | elif cfg.clf == "lbfgs": 38 | acc = eval_lbfgs(x_train, y_train, x_test, y_test) 39 | print(acc) 40 | -------------------------------------------------------------------------------- /tf2/README.md: -------------------------------------------------------------------------------- 1 | `w_mse_loss()` from `whitening.py` is W-MSE loss implementation for TensorFlow 2, 2 | it can be used with other popular implementations, e.g. [SimCLRv2](https://github.com/google-research/simclr/tree/master/tf2). 3 | 4 | 5 | Method uses global flags mechanism as in SimCLRv2: 6 | - `FLAGS.num_samples` - number of samples (d) generated from each image 7 | - `FLAGS.train_batch_size` 8 | - `FLAGS.proj_out_dim` 9 | -------------------------------------------------------------------------------- /tf2/whitening.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v2 as tf 2 | from absl import flags 3 | 4 | FLAGS = flags.FLAGS 5 | 6 | 7 | class Whitening1D(tf.keras.layers.Layer): 8 | def __init__(self, eps=0, **kwargs): 9 | super(Whitening1D, self).__init__(**kwargs) 10 | self.eps = eps 11 | 12 | def call(self, x): 13 | bs, c = x.shape 14 | x_t = tf.transpose(x, (1, 0)) 15 | m = tf.reduce_mean(x_t, axis=1, keepdims=True) 16 | f = x_t - m 17 | ff_apr = tf.matmul(f, f, transpose_b=True) / (tf.cast(bs, tf.float32) - 1.0) 18 | ff_apr_shrinked = (1 - self.eps) * ff_apr + tf.eye(c) * self.eps 19 | sqrt = tf.linalg.cholesky(ff_apr_shrinked) 20 | inv_sqrt = tf.linalg.triangular_solve(sqrt, tf.eye(c)) 21 | f_hat = tf.matmul(inv_sqrt, f) 22 | decorelated = tf.transpose(f_hat, (1, 0)) 23 | return decorelated 24 | 25 | 26 | def w_mse_loss(x): 27 | """ input x shape = (batch size * num_samples, proj_out_dim) """ 28 | 29 | w = Whitening1D() 30 | num_samples = FLAGS.num_samples 31 | num_slice = num_samples * FLAGS.train_batch_size // (2 * FLAGS.proj_out_dim) 32 | x_split = tf.split(x, num_slice, 0) 33 | for i in range(num_slice): 34 | x_split[i] = w(x_split[i]) 35 | x = tf.concat(x_split, 0) 36 | x = tf.math.l2_normalize(x, -1) 37 | 38 | x_split = tf.split(x, num_samples, 0) 39 | loss = 0 40 | for i in range(num_samples - 1): 41 | for j in range(i + 1, num_samples): 42 | v = x_split[i] * x_split[j] 43 | loss += 2 - 2 * tf.reduce_mean(tf.reduce_sum(v, -1)) 44 | loss /= num_samples * (num_samples - 1) // 2 45 | return loss 46 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange, tqdm 2 | import numpy as np 3 | import wandb 4 | import torch 5 | import torch.optim as optim 6 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts 7 | import torch.backends.cudnn as cudnn 8 | 9 | from cfg import get_cfg 10 | from datasets import get_ds 11 | from methods import get_method 12 | 13 | 14 | def get_scheduler(optimizer, cfg): 15 | if cfg.lr_step == "cos": 16 | return CosineAnnealingWarmRestarts( 17 | optimizer, 18 | T_0=cfg.epoch if cfg.T0 is None else cfg.T0, 19 | T_mult=cfg.Tmult, 20 | eta_min=cfg.eta_min, 21 | ) 22 | elif cfg.lr_step == "step": 23 | m = [cfg.epoch - a for a in cfg.drop] 24 | return MultiStepLR(optimizer, milestones=m, gamma=cfg.drop_gamma) 25 | else: 26 | return None 27 | 28 | 29 | if __name__ == "__main__": 30 | cfg = get_cfg() 31 | wandb.init(project=cfg.wandb, config=cfg) 32 | 33 | ds = get_ds(cfg.dataset)(cfg.bs, cfg, cfg.num_workers) 34 | model = get_method(cfg.method)(cfg) 35 | model.cuda().train() 36 | if cfg.fname is not None: 37 | model.load_state_dict(torch.load(cfg.fname)) 38 | 39 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.adam_l2) 40 | scheduler = get_scheduler(optimizer, cfg) 41 | 42 | eval_every = cfg.eval_every 43 | lr_warmup = 0 if cfg.lr_warmup else 500 44 | cudnn.benchmark = True 45 | 46 | for ep in trange(cfg.epoch, position=0): 47 | loss_ep = [] 48 | iters = len(ds.train) 49 | for n_iter, (samples, _) in enumerate(tqdm(ds.train, position=1)): 50 | if lr_warmup < 500: 51 | lr_scale = (lr_warmup + 1) / 500 52 | for pg in optimizer.param_groups: 53 | pg["lr"] = cfg.lr * lr_scale 54 | lr_warmup += 1 55 | 56 | optimizer.zero_grad() 57 | loss = model(samples) 58 | loss.backward() 59 | optimizer.step() 60 | loss_ep.append(loss.item()) 61 | model.step(ep / cfg.epoch) 62 | if cfg.lr_step == "cos" and lr_warmup >= 500: 63 | scheduler.step(ep + n_iter / iters) 64 | 65 | if cfg.lr_step == "step": 66 | scheduler.step() 67 | 68 | if len(cfg.drop) and ep == (cfg.epoch - cfg.drop[0]): 69 | eval_every = cfg.eval_every_drop 70 | 71 | if (ep + 1) % eval_every == 0: 72 | acc_knn, acc = model.get_acc(ds.clf, ds.test) 73 | wandb.log({"acc": acc[1], "acc_5": acc[5], "acc_knn": acc_knn}, commit=False) 74 | 75 | if (ep + 1) % 100 == 0: 76 | fname = f"data/{cfg.method}_{cfg.dataset}_{ep}.pt" 77 | torch.save(model.state_dict(), fname) 78 | 79 | wandb.log({"loss": np.mean(loss_ep), "ep": ep}) 80 | --------------------------------------------------------------------------------