├── LICENSE ├── README.md ├── linear.py ├── main.py ├── misc │   ├── fig1.png │   └── neurips_fig.png ├── model.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ching-Yao Chuang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Debiased Contrastive Learning 2 | 3 |

4 | 5 |

6 | 7 | A prominent technique for self-supervised representation learning has been to contrast semantically similar and dissimilar pairs of samples. Without access to labels, dissimilar (negative) points are typically taken to be randomly sampled datapoints, implicitly accepting that these points may in fact have the same label. Perhaps unsurprisingly, in a synthetic setting where labels are available, we observe that sampling negative examples from truly different labels improves performance. Motivated by this observation, we develop a debiased contrastive objective that corrects for the sampling of same-label datapoints, even without knowledge of the true labels. 8 | 9 | 10 | **Debiased Contrastive Learning** NeurIPS 2020 [[paper]](https://arxiv.org/abs/2007.00224) 11 |
12 | [Ching-Yao Chuang](https://chingyaoc.github.io/), 13 | [Joshua Robinson](https://joshrobinson.mit.edu/), 14 | [Lin Yen-Chen](https://yenchenlin.me/), 15 | [Antonio Torralba](http://web.mit.edu/torralba/www/), and 16 | [Stefanie Jegelka](https://people.csail.mit.edu/stefje/) 17 |
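
## Debiased Objective at a Glance

The correction happens in the negative term of the InfoNCE denominator: for class prior probability `tau_plus` and `N = 2 * batch_size - 2` negatives per anchor, the debiased estimator subtracts the expected contribution of false negatives and rescales. The sketch below mirrors `train()` in `main.py`; the standalone function name and signature are ours, added for readability (the repo itself switches on a `--debiased` flag rather than on `tau_plus > 0`):

```python
import numpy as np
import torch

def debiased_negative_loss(pos, neg, tau_plus, temperature, batch_size):
    # pos: [2B] exp-similarity of each anchor to its augmented view
    # neg: [2B, N] exp-similarities to the N = 2B - 2 "negatives",
    #      some of which secretly share the anchor's label
    N = batch_size * 2 - 2
    if tau_plus > 0:
        # debiased estimator g(): remove the expected false-negative mass,
        # then rescale by 1 / (1 - tau_plus)
        Ng = (-tau_plus * N * pos + neg.sum(dim=-1)) / (1 - tau_plus)
        # clamp to the theoretical minimum N * e^{-1/t}, since similarities are >= -1
        Ng = torch.clamp(Ng, min=N * np.e ** (-1 / temperature))
    else:
        Ng = neg.sum(dim=-1)  # standard (biased) SimCLR negative term
    return (-torch.log(pos / (pos + Ng))).mean()
```

Setting `tau_plus = 0` recovers the standard SimCLR loss.
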
18 | 19 | 20 | ## Prerequisites 21 | - Python 3.7 22 | - PyTorch 1.3.1 23 | - PIL 24 | - OpenCV 25 | 26 | ## Contrastive Representation Learning 27 | We can train the standard (biased) or the debiased (M=1, i.e. a single positive example in the estimator) version of [SimCLR](https://arxiv.org/abs/2002.05709) with `main.py` on the STL10 dataset. 28 | 29 | flags: 30 | - `--debiased`: use the debiased objective (True) or the standard objective (False) 31 | - `--tau_plus`: class prior probability, i.e. the probability that a randomly drawn negative actually shares the anchor's label (see the estimator sketch above) 32 | - `--batch_size`: batch size for SimCLR 33 | 34 | For instance, run the following command to train a debiased encoder. 35 | ``` 36 | python main.py --tau_plus 0.1 37 | ``` 38 | 39 | #### *Due to the implementation of `nn.DataParallel()`, training with at most 2 GPUs gives the best results.* 40 | 41 | ## Linear evaluation 42 | The model is evaluated by training a linear classifier on top of the frozen learned embedding. 43 | 44 | path flags: 45 | - `--model_path`: path to the saved model 46 | ``` 47 | python linear.py --model_path results/model_400.pth 48 | ``` 49 | 50 | #### Pretrained Models 51 | | Objective | tau_plus | Arch | Latent Dim | Batch Size | Accuracy (%) | Download | 52 | |----------|:---:|:----:|:---:|:---:|:---:|:---:| 53 | | Biased | 0.0 | ResNet50 | 128 | 256 | 80.15 | [model](https://drive.google.com/file/d/1qQE03ztnQCK4dtG-GPwCvF66nq_Mk_mo/view?usp=sharing)| 54 | | Debiased | 0.05 | ResNet50 | 128 | 256 | 81.85 | [model](https://drive.google.com/file/d/1pA4Hpcug8tbgH9O6PCu-447vJzxbbR5I/view?usp=sharing)| 55 | | Debiased | 0.1 | ResNet50 | 128 | 256 | 84.26 | [model](https://drive.google.com/file/d/1d8nfGHsHIuJYjU7mHtCtSXf98IbWMFAa/view?usp=sharing)| 56 | See the end of this README for a sketch of how to load these checkpoints. 57 | ## Citation 58 | 59 | If you find this repo useful for your research, please consider citing the paper: 60 | 61 | ``` 62 | @article{chuang2020debiased, 63 | title={Debiased contrastive learning}, 64 | author={Chuang, Ching-Yao and Robinson, Joshua and Lin, Yen-Chen and Torralba, Antonio and Jegelka, Stefanie}, 65 | journal={Advances in Neural Information Processing Systems}, 66 | volume={33}, 67 | year={2020} 68 | } 69 | ``` 70 | For any questions, please contact Ching-Yao Chuang (cychuang@mit.edu). 71 | 72 | ## Acknowledgements 73 | 74 | Part of this code is inspired by [leftthomas/SimCLR](https://github.com/leftthomas/SimCLR).
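
#### Loading a Pretrained Checkpoint

The checkpoints above (and those written by `main.py`) are saved from an `nn.DataParallel`-wrapped model, so they should be loaded the same way. A minimal sketch following `linear.py`; the checkpoint path is an example:

```python
import torch
import torch.nn as nn

from model import Model

model = nn.DataParallel(Model().cuda())  # wrap exactly as during training
model.load_state_dict(torch.load('results/model_400.pth'))
encoder = model.module.f  # modified ResNet-50 backbone; flatten its output for 2048-d features
```
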
75 | -------------------------------------------------------------------------------- /linear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | from torchvision.datasets import STL10 8 | from tqdm import tqdm 9 | 10 | import utils 11 | from model import Model 12 | 13 | class Net(nn.Module): 14 | def __init__(self, num_class, pretrained_path): 15 | super(Net, self).__init__() 16 | 17 | # encoder 18 | model = Model().cuda() 19 | model = nn.DataParallel(model) 20 | model.load_state_dict(torch.load(pretrained_path)) 21 | 22 | self.f = model.module.f 23 | # classifier 24 | self.fc = nn.Linear(2048, num_class, bias=True) 25 | 26 | def forward(self, x): 27 | x = self.f(x) 28 | feature = torch.flatten(x, start_dim=1) 29 | out = self.fc(feature) 30 | return out 31 | 32 | 33 | # train or test for one epoch 34 | def train_val(net, data_loader, train_optimizer): 35 | is_train = train_optimizer is not None 36 | net.train() if is_train else net.eval()  # note: train() also switches the frozen encoder's BatchNorm layers to training mode, so their running stats still update 37 | 38 | total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader) 39 | with (torch.enable_grad() if is_train else torch.no_grad()): 40 | for data, target in data_bar: 41 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 42 | out = net(data) 43 | loss = loss_criterion(out, target) 44 | 45 | if is_train: 46 | train_optimizer.zero_grad() 47 | loss.backward() 48 | train_optimizer.step() 49 | 50 | total_num += data.size(0) 51 | total_loss += loss.item() * data.size(0) 52 | prediction = torch.argsort(out, dim=-1, descending=True) 53 | total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 54 | total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 55 | 56 | data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%' 57 | .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num, 58 | total_correct_1 / total_num * 100, total_correct_5 / total_num * 100)) 59 | 60 | return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser(description='Linear Evaluation') 65 | parser.add_argument('--model_path', type=str, default='results/model_400.pth', 66 | help='The pretrained model path') 67 | parser.add_argument('--batch_size', type=int, default=512, help='Number of images in each mini-batch') 68 | parser.add_argument('--epochs', type=int, default=100, help='Number of sweeps over the dataset to train') 69 | 70 | args = parser.parse_args() 71 | model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs 72 | train_data = STL10(root='data', split='train', transform=utils.train_transform, download=True) 73 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) 74 | test_data = STL10(root='data', split='test', transform=utils.test_transform, download=True) 75 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True) 76 | 77 | model = Net(num_class=len(train_data.classes), pretrained_path=model_path).cuda() 78 | for param in model.f.parameters(): 79 | param.requires_grad = False 80 | model = nn.DataParallel(model) 81 | 82 | optimizer = 
optim.Adam(model.module.fc.parameters(), lr=1e-3, weight_decay=1e-6) 83 | loss_criterion = nn.CrossEntropyLoss() 84 | results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [], 85 | 'test_loss': [], 'test_acc@1': [], 'test_acc@5': []} 86 | 87 | for epoch in range(1, epochs + 1): 88 | train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer) 89 | test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None) 90 | # record metrics so the results dict defined above is actually populated 91 | results['train_loss'].append(train_loss) 92 | results['train_acc@1'].append(train_acc_1) 93 | results['train_acc@5'].append(train_acc_5) 94 | results['test_loss'].append(test_loss) 95 | results['test_acc@1'].append(test_acc_1) 96 | results['test_acc@5'].append(test_acc_5) 97 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | import utils 13 | from model import Model 14 | 15 | # mask out each anchor's two positive entries (itself and its augmented view) in the 2B x 2B similarity matrix 16 | def get_negative_mask(batch_size): 17 | negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool) 18 | for i in range(batch_size): 19 | negative_mask[i, i] = 0 20 | negative_mask[i, i + batch_size] = 0 21 | 22 | negative_mask = torch.cat((negative_mask, negative_mask), 0) 23 | return negative_mask 24 | 25 | 26 | def train(net, data_loader, train_optimizer, temperature, debiased, tau_plus): 27 | net.train() 28 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader) 29 | for pos_1, pos_2, target in train_bar: 30 | pos_1, pos_2 = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True) 31 | feature_1, out_1 = net(pos_1) 32 | feature_2, out_2 = net(pos_2) 33 | 34 | # neg score 35 | out = torch.cat([out_1, out_2], dim=0) 36 | neg = torch.exp(torch.mm(out, out.t().contiguous()) / temperature) 37 | mask = get_negative_mask(batch_size).cuda() 38 | neg = neg.masked_select(mask).view(2 * batch_size, -1) 39 | 40 | # pos score 41 | pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature) 42 | pos = torch.cat([pos, pos], dim=0) 43 | 44 | # estimator g() 45 | if debiased: 46 | N = batch_size * 2 - 2 47 | Ng = (-tau_plus * N * pos + neg.sum(dim=-1)) / (1 - tau_plus) 48 | # clamp to the theoretical minimum N * e^{-1/t} (optional but stabilizing) 49 | Ng = torch.clamp(Ng, min=N * np.e ** (-1 / temperature)) 50 | else: 51 | Ng = neg.sum(dim=-1) 52 | 53 | # contrastive loss 54 | loss = (-torch.log(pos / (pos + Ng))).mean() 55 | 56 | train_optimizer.zero_grad() 57 | loss.backward() 58 | train_optimizer.step() 59 | 60 | total_num += batch_size 61 | total_loss += loss.item() * batch_size 62 | 63 | train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num)) 64 | 65 | return total_loss / total_num 66 | 67 | 68 | # test for one epoch: predict each test image's label with a weighted kNN vote over the labeled memory bank 69 | def test(net, memory_data_loader, test_data_loader): 70 | net.eval() 71 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] 72 | with torch.no_grad(): 73 | # generate feature bank 74 | for data, _, target in tqdm(memory_data_loader, desc='Feature extracting'): 75 | feature, out = net(data.cuda(non_blocking=True)) 76 | feature_bank.append(feature) 77 | # [D, N] 78 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 79 | # [N] 80 | feature_labels = torch.tensor(memory_data_loader.dataset.labels, device=feature_bank.device) 81 | # loop test data to predict the label by weighted knn search 82 | test_bar = tqdm(test_data_loader) 83 | for data, _, target in test_bar: 84 | data, target = data.cuda(non_blocking=True), 
target.cuda(non_blocking=True) 85 | feature, out = net(data) 86 | 87 | total_num += data.size(0) 88 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 89 | sim_matrix = torch.mm(feature, feature_bank) 90 | # [B, K] 91 | sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1) 92 | # [B, K] 93 | sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices) 94 | sim_weight = (sim_weight / temperature).exp() 95 | 96 | # counts for each class 97 | one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device) 98 | # [B*K, C] 99 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1).long(), value=1.0) 100 | # weighted score ---> [B, C] 101 | pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1) 102 | 103 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 104 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 105 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 106 | test_bar.set_description('KNN Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%' 107 | .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100)) 108 | 109 | return total_top1 / total_num * 100, total_top5 / total_num * 100 110 | 111 | 112 | if __name__ == '__main__': 113 | parser = argparse.ArgumentParser(description='Train SimCLR') 114 | parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for latent vector') 115 | parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in softmax') 116 | parser.add_argument('--tau_plus', default=0.1, type=float, help='Positive class prior (tau_plus)') 117 | parser.add_argument('--k', default=200, type=int, help='Top k most similar images used to predict the label') 118 | parser.add_argument('--batch_size', default=256, type=int, help='Number of images in each mini-batch') 119 | parser.add_argument('--epochs', default=500, type=int, help='Number of sweeps over the dataset to train') 120 | parser.add_argument('--debiased', default=True, type=lambda s: str(s).lower() in ('true', '1'), help='Use the debiased contrastive loss (True) or the standard loss (False)')  # note: a bare type=bool would parse "False" as True 121 | 122 | # args parse 123 | args = parser.parse_args() 124 | feature_dim, temperature, tau_plus, k = args.feature_dim, args.temperature, args.tau_plus, args.k 125 | batch_size, epochs, debiased = args.batch_size, args.epochs, args.debiased 126 | 127 | # data prepare (download=True fetches STL10 on first run) 128 | train_data = utils.STL10Pair(root='data', split='train+unlabeled', transform=utils.train_transform, download=True) 129 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, 130 | drop_last=True) 131 | memory_data = utils.STL10Pair(root='data', split='train', transform=utils.test_transform, download=True) 132 | memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) 133 | test_data = utils.STL10Pair(root='data', split='test', transform=utils.test_transform, download=True) 134 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) 135 | 136 | # model setup and optimizer config 137 | model = Model(feature_dim).cuda() 138 | model = nn.DataParallel(model) 139 | 140 | optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6) 141 | c = len(memory_data.classes) 142 | print('# Classes: {}'.format(c)) 143 | 144 | # training loop 145 | if not os.path.exists('results'): 146 | os.mkdir('results') 
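# The loop below optimizes the chosen contrastive objective for `epochs` sweeps;
# every 25 epochs it runs the weighted-kNN probe on the labeled split and writes
# results/model_<epoch>.pth, which linear.py can then evaluate.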
147 | for epoch in range(1, epochs + 1): 148 | train_loss = train(model, train_loader, optimizer, temperature, debiased, tau_plus) 149 | if epoch % 25 == 0: 150 | test_acc_1, test_acc_5 = test(model, memory_loader, test_loader) 151 | torch.save(model.state_dict(), 'results/model_{}.pth'.format(epoch)) 152 | -------------------------------------------------------------------------------- /misc/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chingyaoc/DCL/7d6345f6a2e4e2615e63f39966ad0dad3f54f954/misc/fig1.png -------------------------------------------------------------------------------- /misc/neurips_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chingyaoc/DCL/7d6345f6a2e4e2615e63f39966ad0dad3f54f954/misc/neurips_fig.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models.resnet import resnet50 5 | 6 | 7 | class Model(nn.Module): 8 | def __init__(self, feature_dim=128): 9 | super(Model, self).__init__() 10 | 11 | self.f = [] 12 | for name, module in resnet50().named_children(): 13 | if name == 'conv1': 14 | module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # small-image stem: 3x3 stride-1 conv instead of the original 7x7 stride-2 15 | if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):  # also drop the classification head and the first max-pool to keep spatial resolution 16 | self.f.append(module) 17 | # encoder 18 | self.f = nn.Sequential(*self.f) 19 | # projection head 20 | self.g = nn.Sequential(nn.Linear(2048, 512, bias=False), nn.BatchNorm1d(512), 21 | nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True)) 22 | 23 | def forward(self, x): 24 | x = self.f(x) 25 | feature = torch.flatten(x, start_dim=1) 26 | out = self.g(feature) 27 | return F.normalize(feature, dim=-1), F.normalize(out, dim=-1) 28 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | from torchvision.datasets import STL10 4 | import cv2 5 | import numpy as np 6 | 7 | np.random.seed(0) 8 | 9 | 10 | class STL10Pair(STL10): 11 | def __getitem__(self, index): 12 | img, target = self.data[index], self.labels[index] 13 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 14 | 15 | if self.transform is not None: 16 | pos_1 = self.transform(img) 17 | pos_2 = self.transform(img) 18 | else: pos_1 = pos_2 = img  # fall back to the raw image so the pair is always defined 19 | return pos_1, pos_2, target 20 | 21 | 22 | class GaussianBlur(object): 23 | # Implements Gaussian blur as described in the SimCLR paper 24 | def __init__(self, kernel_size, min=0.1, max=2.0): 25 | self.min = min 26 | self.max = max 27 | # kernel size is set to be 10% of the image height/width 28 | self.kernel_size = kernel_size 29 | 30 | def __call__(self, sample): 31 | sample = np.array(sample) 32 | 33 | # blur the image with a 50% chance 34 | prob = np.random.random_sample() 35 | 36 | if prob < 0.5: 37 | sigma = (self.max - self.min) * np.random.random_sample() + self.min 38 | sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) 39 | 40 | return sample 41 | 42 | train_transform = transforms.Compose([ 43 | transforms.RandomResizedCrop(32), 44 | transforms.RandomHorizontalFlip(p=0.5), 45 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], 
p=0.8), 46 | transforms.RandomGrayscale(p=0.2), 47 | GaussianBlur(kernel_size=int(0.1 * 32)), 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) 50 | 51 | test_transform = transforms.Compose([ 52 | transforms.ToTensor(), 53 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) 54 | --------------------------------------------------------------------------------