├── README.md
├── eval.sh
├── img
│   └── fig.png
├── linear.py
├── main.py
├── model.py
├── results
│   ├── 128_0.5_200_256_1000_statistics.csv
│   └── linear_statistics.csv
├── train.sh
└── utils.py


/README.md:
--------------------------------------------------------------------------------

BatchSampler: Sampling Mini-Batches for Contrastive Learning in Vision, Language, and Graphs

Source code for the KDD'23 paper [BatchSampler: Sampling Mini-Batches for Contrastive Learning in Vision, Language, and Graphs](https://arxiv.org/abs/2306.03355).

BatchSampler is a simple and general generative method for sampling mini-batches of hard-to-distinguish instances, i.e., instances that are hard and true negatives of each other. It can be plugged directly into in-batch contrastive models for vision, language, and graphs.
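The sketch below illustrates the sampling step. It collects a mini-batch with a random walk with restart over a proximity graph, which is the strategy `main.py` follows. In `main.py`, each instance is linked to its `--edge_nums` most similar instances among `--M` randomly drawn candidates.

This is a simplified, stdlib-only sketch under stated assumptions, not the repository's implementation:

- `sample_batch`, the adjacency-dict input, and `max_steps` are hypothetical.
- A node with no neighbors is treated as a restart.
- The walk returns fewer nodes if too few are reachable.

```python
import random


def sample_batch(adj, start, batch_size, restart_p=0.1, max_steps=10_000, rng=None):
    """Collect up to `batch_size` distinct nodes with a random walk with restart.

    adj: dict mapping each node to a list of neighbor nodes, e.g. the top-k
         most similar instances of each instance (hypothetical toy format).
    """
    rng = rng or random.Random()
    batch = {start}
    current = start
    for _ in range(max_steps):
        if len(batch) >= batch_size:
            break
        neighbors = adj.get(current, [])
        if not neighbors or rng.random() < restart_p:
            current = start  # restart: jump back to the start node
            continue
        current = rng.choice(neighbors)  # move to a random neighbor
        batch.add(current)
    return batch


# Toy similarity graph: each node lists its most similar nodes
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4],
       4: [3, 5], 5: [4, 6], 6: [5, 7], 7: [6]}
print(sorted(sample_batch(adj, start=0, batch_size=4, rng=random.Random(0))))
```

Because the walk keeps returning to its start node, a batch tends to stay within that node's neighborhood of similar instances. That locality is what makes the batch members likely hard negatives for one another.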

Dependencies

* Python >= 3.7
* [PyTorch](https://pytorch.org/) >= 1.9.0
* torchvision, NumPy, pandas, thop, and tqdm (imported by `main.py`, `linear.py`, and `model.py`)

Quick Start

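Taking the vision modality as an example, you can train on STL10 with:

```bash
sh train.sh
```

Then run linear evaluation of the saved checkpoint with `sh eval.sh`, adjusting `--model_path` and `--save_path` in `eval.sh` to match your training run.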
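`main.py` saves its best checkpoint as `<save_path>/<feature_dim>_<temperature>_<k>_<batch_size>_<epochs>_model.pth`. The `--model_path` in `eval.sh` (`./test_stl10/128_0.5_200_256_1000_model.pth`) therefore has to match the hyperparameters used for training.

The helper below reproduces that naming rule, with `main.py`'s argparse defaults as keyword defaults. `checkpoint_path` is a hypothetical name, not part of the repository.

```python
def checkpoint_path(save_path, feature_dim=128, temperature=0.5, k=200, batch_size=512, epochs=1000):
    """Rebuild the checkpoint path main.py writes (defaults mirror its argparse defaults)."""
    # Same pattern as save_name_pre in main.py
    save_name_pre = '{}_{}_{}_{}_{}'.format(feature_dim, temperature, k, batch_size, epochs)
    return save_path + '/{}_model.pth'.format(save_name_pre)


# Path expected by eval.sh for a run with batch size 256
print(checkpoint_path('./test_stl10', batch_size=256))
# ./test_stl10/128_0.5_200_256_1000_model.pth
```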

Datasets

We conduct experiments across three modalities.

- **Vision:** five datasets: one large-scale dataset, [ImageNet](https://www.image-net.org/); two medium-scale datasets, [STL10](https://cs.stanford.edu/~acoates/stl10/) and [ImageNet-100](https://www.kaggle.com/datasets/ambityga/imagenet100); and two small-scale datasets, [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) and [CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html).
- **Language:** 7 semantic textual similarity (STS) tasks.
- **Graphs:** graph-level classification on 7 benchmark datasets: IMDB-B, IMDB-M, COLLAB, REDDIT-B, PROTEINS, MUTAG, and NCI1.

Experimental Results

Vision Modality

| Method             | 100ep    | 400ep    | 800ep    |
| ------------------ | -------- | -------- | -------- |
| SimCLR             | 64.0     | 68.1     | 68.7     |
| **w/BatchSampler** | **64.7** | **68.6** | **69.2** |
| MoCo v3            | 68.9     | 73.3     | 73.8     |
| **w/BatchSampler** | **69.5** | **73.7** | **74.2** |

Language Modality

| Method                     | STS12     | STS13     | STS14     | STS15     | STS16     | STS-B     | SICK-R    | Avg.      |
| -------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |
| SimCSE-BERT<sub>BASE</sub> | 68.62     | 80.89     | 73.74     | 80.88     | 77.66     | **77.79** | **69.64** | 75.60     |
| w/kNN Sampler              | 63.62     | 74.86     | 69.79     | 79.17     | 76.24     | 74.73     | 67.74     | 72.31     |
| **w/BatchSampler**         | **72.37** | **82.08** | **75.24** | **83.10** | **78.43** | 77.54     | 68.05     | **76.69** |
| DCL-BERT<sub>BASE</sub>    | 65.22     | 77.89     | 68.94     | 79.88     | **76.72** | 73.89     | **69.54** | 73.15     |
| w/kNN Sampler              | 66.34     | 76.66     | 72.60     | 78.30     | 74.86     | 73.65     | 67.92     | 72.90     |
| **w/BatchSampler**         | **69.55** | **82.66** | **73.37** | **80.40** | 75.37     | **75.43** | 66.76     | **74.79** |
| HCL-BERT<sub>BASE</sub>    | 62.57     | 79.12     | 69.70     | 78.00     | 75.11     | 73.38     | 69.74     | 72.52     |
| w/kNN Sampler              | 61.12     | 75.73     | 68.43     | 76.64     | 74.78     | 71.22     | 68.04     | 70.85     |
| **w/BatchSampler**         | **66.87** | **81.38** | **72.96** | **80.11** | **77.99** | **75.95** | **70.89** | **75.16** |

Graph Modality

| Method             | IMDB-B         | IMDB-M         | COLLAB         | REDDIT-B       | PROTEINS       | MUTAG          | NCI1           |
| ------------------ | -------------- | -------------- | -------------- | -------------- | -------------- | -------------- | -------------- |
| GraphCL            | 70.90±0.53     | 48.48±0.38     | 70.62±0.23     | 90.54±0.25     | 74.39±0.45     | 86.80±1.34     | 77.87±0.41     |
| w/kNN Sampler      | 70.72±0.35     | 47.97±0.97     | 70.59±0.14     | 90.21±0.74     | 74.17±0.41     | 86.46±0.82     | 77.27±0.37     |
| **w/BatchSampler** | **71.90±0.46** | **48.93±0.28** | **71.48±0.28** | **90.88±0.16** | **75.04±0.67** | **87.78±0.93** | **78.93±0.38** |
| DCL                | 71.07±0.36     | 48.93±0.32     | **71.06±0.51** | 90.66±0.29     | 74.64±0.48     | 88.09±0.93     | 78.49±0.48     |
| w/kNN Sampler      | 70.94±0.19     | 48.47±0.35     | 70.49±0.37     | 90.26±1.03     | 74.28±0.17     | 87.13±1.40     | 78.13±0.52     |
| **w/BatchSampler** | **71.32±0.17** | **48.96±0.25** | 70.44±0.35     | **90.73±0.34** | **75.02±0.61** | **89.47±1.43** | **79.03±0.32** |
| HCL                | **71.24±0.36** | 48.54±0.51     | 71.03±0.45     | 90.40±0.42     | 74.69±0.42     | 87.79±1.10     | 78.83±0.67     |
| w/kNN Sampler      | 71.14±0.44     | 48.36±0.93     | 70.86±0.74     | 90.64±0.51     | 74.06±0.44     | 87.53±1.37     | 78.66±0.48     |
| **w/BatchSampler** | 71.20±0.38     | **48.76±0.39** | **71.70±0.35** | **91.25±0.25** | **75.11±0.63** | **88.31±1.29** | **79.17±0.27** |
| MVGRL              | 74.20±0.70     | 51.20±0.50     | -              | 84.50±0.60     | -              | 89.70±1.10     | -              |
| w/kNN Sampler      | 73.30±0.34     | 50.70±0.36     | -              | 82.70±0.67     | -              | 85.08±0.66     | -              |
| **w/BatchSampler** | **76.70±0.35** | **52.40±0.39** | -              | **87.47±0.79** | -              | **91.13±0.81** | -              |
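The `Avg.` column of the language-modality table is the unweighted mean of the seven STS scores in each row. The short check below recomputes it from the scores copied out of that table. The names `ROWS` and `sts_average` are illustrative.

```python
# (STS12, STS13, STS14, STS15, STS16, STS-B, SICK-R) scores and the reported Avg.,
# copied from the language-modality table
ROWS = {
    "SimCSE": ([68.62, 80.89, 73.74, 80.88, 77.66, 77.79, 69.64], 75.60),
    "SimCSE w/kNN Sampler": ([63.62, 74.86, 69.79, 79.17, 76.24, 74.73, 67.74], 72.31),
    "SimCSE w/BatchSampler": ([72.37, 82.08, 75.24, 83.10, 78.43, 77.54, 68.05], 76.69),
    "DCL": ([65.22, 77.89, 68.94, 79.88, 76.72, 73.89, 69.54], 73.15),
    "DCL w/kNN Sampler": ([66.34, 76.66, 72.60, 78.30, 74.86, 73.65, 67.92], 72.90),
    "DCL w/BatchSampler": ([69.55, 82.66, 73.37, 80.40, 75.37, 75.43, 66.76], 74.79),
    "HCL": ([62.57, 79.12, 69.70, 78.00, 75.11, 73.38, 69.74], 72.52),
    "HCL w/kNN Sampler": ([61.12, 75.73, 68.43, 76.64, 74.78, 71.22, 68.04], 70.85),
    "HCL w/BatchSampler": ([66.87, 81.38, 72.96, 80.11, 77.99, 75.95, 70.89], 75.16),
}


def sts_average(scores):
    """Unweighted mean over the seven STS tasks."""
    return sum(scores) / len(scores)


for name, (scores, reported) in ROWS.items():
    print(f"{name}: computed {sts_average(scores):.2f}, reported {reported:.2f}")
```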

Citing

If you find our work helpful in your research, please consider citing our paper:

```
@article{yang2023batchsampler,
  title={BatchSampler: Sampling Mini-Batches for Contrastive Learning in Vision, Language, and Graphs},
  author={Yang, Zhen and Huang, Tinglin and Ding, Ming and Dong, Yuxiao and Ying, Rex and Cen, Yukuo and Geng, Yangliao and Tang, Jie},
  journal={arXiv preprint arXiv:2306.03355},
  year={2023}
}
```

--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES=7
python linear.py --dataset stl10 --batch_size 512 --epochs 100 --save_path test_stl10 --model_path ./test_stl10/128_0.5_200_256_1000_model.pth

--------------------------------------------------------------------------------
/img/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/BatchSampler/28017ec4b6f2c8eebcde16ea97cfa3bb217cfb4a/img/fig.png

--------------------------------------------------------------------------------
/linear.py:
--------------------------------------------------------------------------------
import argparse

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from thop import profile, clever_format
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, STL10
from tqdm import tqdm

import utils
from model import Model


class Net(nn.Module):
    def __init__(self, num_class, pretrained_path):
        super(Net, self).__init__()

        # encoder
        self.f = Model().f
        # classifier
        self.fc = nn.Linear(2048, num_class, bias=True)
        # strict=False: loads the encoder `f`, ignores the checkpoint's projection head `g`,
        # and leaves the new classifier `fc` randomly initialized
        self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.fc(feature)
        return out


# train or test for one epoch
def train_val(net, data_loader, train_optimizer):
    is_train = train_optimizer is not None
    net.train() if is_train else net.eval()

    total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader)
    with (torch.enable_grad() if is_train else torch.no_grad()):
        for data, target in data_bar:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            out = net(data)
            loss = loss_criterion(out, target)

            if is_train:
                train_optimizer.zero_grad()
                loss.backward()
                train_optimizer.step()

            total_num += data.size(0)
            total_loss += loss.item() * data.size(0)
            prediction = torch.argsort(out, dim=-1, descending=True)
            total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()

            data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'
                                     .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num,
                                             total_correct_1 / total_num * 100, total_correct_5 / total_num * 100))

    return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Linear Evaluation')
    parser.add_argument('--model_path', type=str, default='results/128_0.5_200_512_500_model.pth',
                        help='The pretrained model path')
    parser.add_argument('--batch_size', type=int, default=512, help='Number of images in each mini-batch')
    parser.add_argument('--epochs', type=int, default=100, help='Number of sweeps over the dataset to train')
    parser.add_argument('--save_path',
                        default='results', type=str, help='save path')
    parser.add_argument('--dataset', default='cifar10', type=str, help='train dataset')

    args = parser.parse_args()
    model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs
    if args.dataset == 'cifar10':
        train_data = CIFAR10(root='data', train=True, transform=utils.train_transform_cifar10, download=True)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
        test_data = CIFAR10(root='data', train=False, transform=utils.test_transform_cifar10, download=True)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
    elif args.dataset == 'cifar100':
        train_data = CIFAR100(root='data', train=True, transform=utils.train_transform_cifar100, download=True)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
        test_data = CIFAR100(root='data', train=False, transform=utils.test_transform_cifar100, download=True)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
    elif args.dataset == 'stl10':
        train_data = STL10(root='data', split='train', transform=utils.train_transform_stl10, download=True)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
        test_data = STL10(root='data', split='test', transform=utils.test_transform_stl10, download=True)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
    else:
        raise Exception("invalid dataset")

    model = Net(num_class=len(train_data.classes), pretrained_path=model_path).cuda()
    # freeze the encoder; only the linear classifier `fc` is trained
    for param in model.f.parameters():
        param.requires_grad = False

    flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
    flops, params = clever_format([flops, params])
    print('# Model Params: {} FLOPs: {}'.format(params, flops))
    optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)
    loss_criterion = nn.CrossEntropyLoss()
    results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [],
               'test_loss': [], 'test_acc@1': [], 'test_acc@5': []}

    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer)
        results['train_loss'].append(train_loss)
        results['train_acc@1'].append(train_acc_1)
        results['train_acc@5'].append(train_acc_5)
        test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None)
        results['test_loss'].append(test_loss)
        results['test_acc@1'].append(test_acc_1)
        results['test_acc@5'].append(test_acc_5)
        # save statistics
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv(args.save_path + '/linear_statistics.csv', index_label='epoch')
        if test_acc_1 > best_acc:
            best_acc = test_acc_1
            torch.save(model.state_dict(), args.save_path + '/linear_model.pth')

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os

import pandas as pd
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from thop import profile, clever_format
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import random
import numpy as np

import utils
from model import Model

import json


# boolean mask [2B, 2B] that drops each sample's self-similarity and its positive pair
def get_negative_mask(batch_size):
    negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
    for i in range(batch_size):
        negative_mask[i, i] = 0
        negative_mask[i, i + batch_size] = 0

    negative_mask = torch.cat((negative_mask, negative_mask), 0)
    return negative_mask


# train for one epoch to learn unique features
def train(net, data_loader, train_optimizer, batch_save):
    net.train()
    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    batch_num = 0
    for idx, pos_1, pos_2, target in train_bar:
        batch_size = pos_1.shape[0]
        pos_1, pos_2 = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True)
        feature_1, out_1 = net(pos_1)
        feature_2, out_2 = net(pos_2)
        # [2*B, D]
        out = torch.cat([out_1, out_2], dim=0)
        # [2*B, 2*B]
        sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
        # neg
        mask = get_negative_mask(batch_size).cuda()
        # [2*B, 2*B-2]
        neg = sim_matrix.masked_select(mask).view(2 * batch_size, -1)

        # compute loss
        pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        # [2*B]
        pos_sim = torch.cat([pos_sim, pos_sim], dim=0)

        if estimator == 'hard':
            # hard-negative estimator: weight each negative by neg**beta, then debias with tau_plus
            N = batch_size * 2 - 2
            imp = (beta * neg.log()).exp()
            reweight_neg = (imp * neg).sum(dim=-1) / imp.mean(dim=-1)
            Ng = (-tau_plus * N * pos_sim + reweight_neg) / (1 - tau_plus)
            # constrain (optional)
            Ng = torch.clamp(Ng, min=N * np.e**(-1 / temperature))
        elif estimator == 'easy':
            # standard InfoNCE denominator: plain sum over all negatives
            Ng = neg.sum(dim=-1)
        elif estimator == 'debias':
            # debiased estimator: correct the negative sum for false negatives using tau_plus
            N = batch_size * 2 - 2
            Ng = (-tau_plus * N * pos_sim + neg.sum(dim=-1)) / (1 - tau_plus)
            # constrain (optional)
            Ng = torch.clamp(Ng, min=N * np.e**(-1 / temperature))
        else:
            raise Exception('Invalid estimator selected. Please use any of [hard, easy, debias]')

        loss = (- torch.log(pos_sim / (pos_sim + Ng))).mean()

        # get all image embeds
        if batch_num == 0:
            embeds_all = torch.cat((out_1, out_2), dim=1).detach()
            index_all = idx.view(batch_size, -1).detach()
        else:
            embeds_all = torch.cat((embeds_all, torch.cat((out_1, out_2), dim=1)), dim=0).detach()
            index_all = torch.cat((index_all, idx.view(batch_size, -1)), dim=0).detach()

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += batch_size
        total_loss += loss.item() * batch_size
        train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num))
        batch_num += 1
        batch_save += 1

        if batch_save != 0 and batch_save % args.save_batches == 0:
            torch.save(model.state_dict(), args.save_path + '/model_batch_{}.pth'.format(batch_save))
        if batch_save > batch_nums_all:
            return total_loss / total_num, embeds_all, index_all, batch_save
    return total_loss / total_num, embeds_all, index_all, batch_save


# test for one epoch, use weighted knn to find the most similar images' label to assign the test image
def test(net, memory_data_loader, test_data_loader):
    net.eval()
    total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
    with torch.no_grad():
        # generate feature bank
        for _, data, _, target in tqdm(memory_data_loader, desc='Feature extracting'):
            feature, out = net(data.cuda(non_blocking=True))
            feature_bank.append(feature)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor(memory_data_loader.dataset.labels, device=feature_bank.device)
        # loop test data to predict the label by weighted knn search
        test_bar = tqdm(test_data_loader)
        for _, data, _, target in test_bar:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            feature, out = net(data)

            total_num += data.size(0)
            # compute cos similarity between each feature vector and feature bank ---> [B, N]
            sim_matrix = torch.mm(feature, feature_bank)
            # [B, K]
            sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
            # [B, K]
            sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices)
            sim_weight = (sim_weight / temperature).exp()

            # counts for each class
            one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device)
            # [B*K, C]
            one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1).long(), value=1.0)
            # weighted score ---> [B, C]
            pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)

            pred_labels = pred_scores.argsort(dim=-1, descending=True)
            total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
                                     .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100))

    return total_top1 / total_num * 100, total_top5 / total_num * 100


# random -> topk: link each instance to its edge_nums most similar instances among M random candidates
def build_graph(images_embeds):
    split_num = 20
    fold_len = images_embeds.size()[0] // split_num

    images_embeds = F.normalize(images_embeds, dim=-1)

    for i in range(split_num):
        start = i * fold_len
        if i == split_num - 1:
            end = images_embeds.size()[0]
        else:
            end = (i + 1) * fold_len

        # random
        neg_indices = torch.randint(0, images_embeds.shape[0], (end - start, args.M)).cuda()
        random_select_idx = neg_indices.view(1, -1)[0]

        anchor = images_embeds[start:end].unsqueeze(1)

        negs = images_embeds[random_select_idx].view(end - start, args.M, -1)

        B_score = torch.matmul(anchor, negs.permute(0, 2, 1)).squeeze()
        B_score = torch.exp(B_score / temperature)
        neg = B_score

        # top
        B_score, B_idx = neg.topk(k=edge_nums, dim=-1)
        B_indices = torch.gather(neg_indices, 1, B_idx)

        if i == 0:
            edge_score = B_score
            adj_index = B_indices
        else:
            edge_score = torch.cat([edge_score, B_score], dim=0)
            adj_index = torch.cat([adj_index, B_indices], dim=0)

    edge_score = edge_score.detach().cpu()
    adj_index = adj_index.detach().cpu()

    # NOTE: map_index is the module-level mapping created in __main__ by get_embeds()
    return map_index, adj_index, edge_score


def sampling(start_node, adj_index, map_index, visit, restart_p):
    if args.sampling == 'RWR':
        walks, visit = random_walk(start_node, adj_index, map_index, visit, restart_p)
    else:
        raise Exception("invalid sampling method")
    return walks, visit


def random_walk(start_node, adj_index, map_index, visit, restart_p):
    # random walk + restart
    walks = set()
    walks.add(start_node)
    paths = [start_node]
    while len(walks) < batch_size:
        cur = paths[-1]
        nodes = adj_index[cur].tolist()
        if len(nodes) > 0:
            if torch.rand(1).item() < restart_p:
                paths.append(paths[0])
            else:
                next_node = random.choice(nodes)
                paths.append(next_node)
                if next_node not in visit:
                    walks.add(next_node)
        else:
            break
    visit = visit | walks
    assert len(set(walks)) == batch_size
    return walks, visit


def generate_indices(graph, walks):
    map_index, adj_index, edge_weights = graph[0], graph[1].numpy(), graph[2]
    indices = []
    for _ in range(args.update_batchs):
        visit = set()
        remain_images = list(set(map_index.keys()) - set(walks))
        if len(remain_images) == 0:
            return indices, walks
        start_node = random.choice(remain_images)
        batch_idx, visit = sampling(start_node, adj_index, map_index, visit, restart_p=restart_p)
        walks.extend(batch_idx)
        indices.extend(batch_idx)
    return indices, walks


def update_graph(update_embeds, update_index, origin_embeds):
    map_index = {k: i for i, k in enumerate(update_index.view(1, -1)[0].numpy().tolist())}
    index = torch.tensor(list(map_index.keys()))
    embeds_images = update_embeds[torch.tensor(list(map_index.values()))]
    origin_embeds[index, :] = embeds_images
    images_embeds = origin_embeds

    graph = build_graph(images_embeds)
    return graph, images_embeds


def get_embeds(embeds_all, index_all):
    embeds = torch.ones((len(train_data), embeds_all.shape[1]), dtype=torch.float).cuda()
    map_index = {k: i for i, k in enumerate(index_all.view(1, -1)[0].numpy().tolist())}

    index = torch.tensor(list(map_index.keys()))
    embeds_images = embeds_all[torch.tensor(list(map_index.values()))]
    embeds[index, :] = embeds_images
    images_embeds = embeds
    return map_index, images_embeds


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--dataset', default='stl10', type=str, help='train dataset')
    parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for latent vector')
    parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in softmax')
    parser.add_argument('--k', default=200, type=int, help='Top k most similar images used to predict the label')
    parser.add_argument('--batch_size', default=512, type=int, help='Number of images in each mini-batch')
    parser.add_argument('--epochs', default=1000, type=int, help='Number of sweeps over the dataset to train')
    parser.add_argument('--save_batches',
                        default=1000, type=int, help='Number of batches between checkpoints (overridden below)')
    parser.add_argument('--edge_nums', default=100, type=int, help='Number of edges kept per node when building the graph')
    parser.add_argument('--restart_p', default=0.1, type=float,
                        help='Base restart probability of the random walk (decayed per epoch, floored at 0.05)')
    parser.add_argument('--random_epochs', default=10, type=int, help='Number of initial epochs with random shuffling')
    parser.add_argument('--save_path', default='results', type=str, help='save path')
    parser.add_argument('--sampling', default='RWR', type=str,
                        help='Sampling method (only RWR, random walk with restart, is implemented)')
    parser.add_argument('--M', default=1000, type=int,
                        help='Number of random candidate neighbors per node before top-k edge selection')
    parser.add_argument('--update_batchs', default=100, type=int, help='Number of batches sampled between graph updates')
    parser.add_argument('--estimator', default='hard', type=str, help='Loss estimator: hard, easy, or debias')
    parser.add_argument('--tau_plus', default=0.1, type=float, help='Positive class prior')
    parser.add_argument('--beta', default=1.0, type=float,
                        help='Exponent for hard-negative reweighting (hard estimator only)')

    # args parse
    args = parser.parse_args()
    print(args)
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    with open(args.save_path + '/args.txt', 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    feature_dim, temperature, k, restart_p, edge_nums = args.feature_dim, args.temperature, args.k, args.restart_p, args.edge_nums
    batch_size, epochs = args.batch_size, args.epochs
    estimator = args.estimator
    tau_plus = args.tau_plus
    beta = args.beta

    # data prepare
    if args.dataset == 'stl10':
        train_data = utils.STL10Pair(root='data', split='train+unlabeled', transform=utils.train_transform_stl10, download=True)
        memory_data = utils.STL10Pair(root='data', split='train', transform=utils.test_transform_stl10, download=True)
        memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=16,
                                   pin_memory=True)
        test_data = utils.STL10Pair(root='data', split='test', transform=utils.test_transform_stl10, download=True)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
    else:
        raise Exception("invalid dataset")

    # checkpoint roughly every 100 epochs' worth of batches (overrides --save_batches)
    args.save_batches = int(len(train_data) // batch_size + 1) * 100

    # model setup and optimizer config
    model = Model(feature_dim).cuda()
    flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
    flops, params = clever_format([flops, params])
    print('# Model Params: {} FLOPs: {}'.format(params, flops))
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
    c = len(memory_data.classes)

    # training loop
    results = {'train_loss': [], 'test_acc@1': [], 'test_acc@5': []}
    save_name_pre = '{}_{}_{}_{}_{}'.format(feature_dim, temperature, k, batch_size, epochs)

    batch_nums_all = int(len(train_data) // batch_size + 1) * args.epochs
    batch_save = 0
    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        walks = []

        if epoch <= args.random_epochs:
            # warm-up: ordinary shuffled mini-batches; the graph is built after the last warm-up epoch
            train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True,
                                      drop_last=False)
            train_loss, embeds_all, index_all, batch_save = train(model, train_loader, optimizer, batch_save)
            embeds_all, index_all = embeds_all.detach(), index_all.detach()
            # build graph
            if epoch == args.random_epochs:
                map_index, images_embeds = get_embeds(embeds_all, index_all)
                graph = build_graph(images_embeds)
        else:
            # afterwards: mini-batches come from random walks on the graph, rebuilt every update_batchs batches
            t_generate_indices = time.time()
            batch_num = len(train_data) // batch_size
            cur_batch = 0
            # decay the restart probability with the epoch, with a floor of 0.05
            restart_p = args.restart_p - (epoch / 5) * 0.1
            if restart_p <= 0.05:
                restart_p = 0.05

            while len(set(walks)) < batch_num * batch_size:
                if cur_batch != 0 and cur_batch % args.update_batchs == 0:
                    # update graph
                    graph, images_embeds = update_graph(embeds_all, index_all, images_embeds)
                indices, walks = generate_indices(graph, walks)
                sampler = utils.Sampler(indices)
                train_loader = DataLoader(train_data,
                                          batch_size=batch_size,
                                          shuffle=False, num_workers=1, pin_memory=True,
                                          sampler=sampler,
                                          )
                train_loss, embeds_all, index_all, batch_save = train(model, train_loader, optimizer, batch_save)
                embeds_all, index_all = embeds_all.detach(), index_all.detach()
                cur_batch += len(train_loader)
                if batch_save > batch_nums_all:
                    break

            if batch_save > batch_nums_all:
                break
            remain_images = list(set(map_index.keys()) - set(walks))
            walks.extend(remain_images)
            assert len(set(walks)) == len(train_data)
            if len(remain_images) > 1:
                sampler = utils.Sampler(remain_images)
                train_loader = DataLoader(train_data,
                                          batch_size=batch_size,
                                          shuffle=False, num_workers=1, pin_memory=True,
                                          sampler=sampler,
                                          )
                train_loss, embeds_all, index_all, batch_save = train(model, train_loader, optimizer, batch_save)
                embeds_all, index_all = embeds_all.detach(), index_all.detach()
            print("t_generate_indices=========", time.time() - t_generate_indices)

        results['train_loss'].append(train_loss)
        test_acc_1, test_acc_5 = test(model, memory_loader, test_loader)
        results['test_acc@1'].append(test_acc_1)
        results['test_acc@5'].append(test_acc_5)
        # save statistics
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv(args.save_path + '/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
        if test_acc_1 > best_acc:
            best_acc = test_acc_1
            torch.save(model.state_dict(), args.save_path + '/{}_model.pth'.format(save_name_pre))

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50


class Model(nn.Module):
    def __init__(self, feature_dim=128):
        super(Model, self).__init__()

        # ResNet-50 backbone with a 3x3 stride-1 conv1, no max-pooling, and no final fc layer
        self.f = []
        for name, module in resnet50().named_children():
            if name == 'conv1':
                module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
                self.f.append(module)
        # encoder
        self.f = nn.Sequential(*self.f)
        # projection head
        self.g = nn.Sequential(nn.Linear(2048, 512, bias=False), nn.BatchNorm1d(512),
                               nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.g(feature)
        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)

--------------------------------------------------------------------------------
/results/128_0.5_200_256_1000_statistics.csv:
--------------------------------------------------------------------------------
epoch,train_loss,test_acc@1,test_acc@5
1,5.293757577696301,42.225,93.675
2,4.826159853684275,58.6875,97.2625
3,4.771193599700927,61.762499999999996,97.6625
4,4.749537806913077,65.1125,98.0125
5,4.731312639183468,66.0875,98.0875
6,4.709558345833603,67.83749999999999,98.2625
7,4.7036038080851235,68.7,98.41250000000001
8,4.682034615107945,68.77499999999999,98.5
9,4.678532238006592,69.4625,98.41250000000001
10,4.673528847694397,70.46249999999999,98.5375
11,4.6616191946226975,71.0125,98.6625
12,4.6555983283303,71.45,98.7625
13,4.65258057364102,72.225,98.6875
14,4.643300998501662,73.825,98.9
15,4.64274603465818,71.75,98.7125
16,4.643919054667155,72.85000000000001,98.6
17,4.633812708932846,72.3625,98.6375
18,4.633907571752021,72.1375,98.8125
19,4.631996188365238,72.725,98.675
20,4.627749775940517,73.575,99.0125
21,4.623903426859114,73.7125,98.925
22,4.616939772730288,73.6125,98.875
23,4.617843597166,74.4,98.91250000000001
24,4.618095388412476,73.2,98.675
25,4.612618161671198,74.2625,98.88749999999999
26,4.610560584068298,74.6375,99.175
27,4.613450231223271,74.5125,98.85000000000001
28,4.616752988836739,74.9875,99.0125
29,4.607374534606934,74.5625,99.0
30,4.610040483474731,74.8,98.95
31,4.600977091789246,75.175,98.875
32,4.6121180125645225,75.1375,99.0375
33,4.6064458751678465,74.8375,98.7875
34,4.607028651237488,74.425,99.0375
35,4.605106512705485,75.775,99.125
36,4.600796276872808,75.6375,99.05000000000001
37,4.601950973272324,75.775,99.1375
38,4.600315594673157,75.5125,99.1
39,4.595778841972351,74.0125,99.075
40,4.592310833930969,75.0625,99.1375
41,4.599948711395264,75.35,99.0625
42,4.59720899105072,75.3375,98.97500000000001
43,4.599029782452161,76.03750000000001,98.96249999999999
44,4.5941962683072655,76.71249999999999,99.1125
45,4.592749371528625,76.5625,99.1375
46,4.591093096733093,75.44999999999999,99.0125
47,4.593859170612536,75.8875,99.1625
48,4.590588365282331,76.05,99.225
49,4.59339542388916,75.775,99.15
50,4.593822264671326,77.325,99.2
51,4.589333243370056,76.47500000000001,99.1625
52,4.588132557868957,75.3,99.05000000000001
53,4.586618657112122,76.6,99.2625
54,4.590204153060913,76.11250000000001,99.25
55,4.592039896517384,77.03750000000001,99.1625
56,4.584720191955566,75.7625,99.1125
57,4.589341929084377,76.625,99.25
58,4.5870132189924995,78.025,99.3625
59,4.59016788482666,75.6375,99.1
60,4.583746252059936,77.275,99.2125
61,4.5849903781762285,77.4625,99.25
62,4.579810733795166,77.525,99.1875
63,4.589144492149353,77.1375,99.3
64,4.585684504508972,78.4875,99.225
65,4.588148999214172,77.325,99.1625
66,4.580712671081225,78.4875,99.275
67,4.585197448730469,77.225,99.2125
68,4.582306609153748,76.3875,99.125
69,4.580277142524719,77.7875,99.41250000000001
70,4.585395255088806,76.7875,99.1125
71,4.579945621036348,77.5625,99.375
72,4.579310126304627,77.03750000000001,99.3
73,4.5810620307922365,77.0,99.3375
74,4.579955153465271,76.9375,99.2375
75,4.5884348335912675,76.9,99.2625
76,4.578743467330932,78.325,99.3
77,4.5799387121200565,78.0625,99.2125
78,4.573168625831604,77.9625,99.3125
79,4.574429011344909,78.73750000000001,99.3
80,4.577625751495361,79.0875,99.3625
81,4.576289272308349,77.9,99.2875
82,4.578504064348009,77.7,99.4
83,4.570481319816745,77.47500000000001,99.3625
84,4.57708838780721,77.875,99.3125
85,4.578318653106689,78.73750000000001,99.3125
86,4.576002221107483,77.8375,99.3
87,4.572186132957196,77.3875,99.2
88,4.568873965221902,78.5,99.2875
89,4.572124977111816,79.0875,99.3
90,4.571539549827576,78.03750000000001,99.275
91,4.578847978022192,78.5125,99.3625
92,4.577983069419861,78.1625,99.2
93,4.5709609790724155,78.9875,99.25
94,4.572255902290344,77.8625,99.2875
95,4.57702686018863,77.225,99.275
96,4.574975714683533,78.8125,99.25
97,4.57355366230011,78.025,99.3
98,4.573449125289917,78.6875,99.2125
99,4.5740386009216305,77.325,99.2125
100,4.5707975243622405,76.775,99.2375
101,4.568804507162056,78.9625,99.2375
102,4.570234065055847,77.7,99.3
103,4.570616459846496,77.8625,99.2625
104,4.571278429031372,77.425,99.2
105,4.5763217780901035,78.125,99.2625
106,4.571589671648466,75.9875,99.3125
107,4.572205200195312,78.35,99.325
108,4.566824765205383,77.14999999999999,99.1625
109,4.5730495187971325,77.3625,99.3125
110,4.566690254211426,78.3625,99.3
111,4.570611038208008,79.03750000000001,99.35000000000001
112,4.56855793697078,77.85,99.2375
113,4.569841089702788,78.625,99.2875
114,4.565202702175487,78.5875,99.3125
115,4.570380935668945,77.725,99.2125
116,4.56500020980835,79.25,99.4
117,4.567118740081787,78.075,99.46249999999999
118,4.565334788235751,78.825,99.25
119,4.56559255917867,80.4375,99.4375
120,4.571868437987107,78.1875,99.35000000000001
121,4.568742669146994,78.625,99.3625
122,4.572119749509371,77.25,99.2875
123,4.566438215119498,77.375,99.3
124,4.57454890953867,78.3625,99.2125
125,4.568522095680237,78.4875,99.325
126,4.5665170844172085,79.0125,99.35000000000001
127,4.564128122329712,79.5125,99.425
128,4.568067855834961,79.475,99.3625
129,4.566403017044068,77.725,99.3
130,4.562279386520386,78.2875,99.45
131,4.567161998748779,78.60000000000001,99.2625
132,4.566066150007577,78.8,99.47500000000001
133,4.56695170879364,79.4875,99.38749999999999
134,4.56440342040289,78.66250000000001,99.375
135,4.567775155368604,79.21249999999999,99.4
136,4.564255863428116,79.07499999999999,99.3125
137,4.557527716566876,77.8875,99.325
138,4.565986478328705,78.525,99.375
139,4.562467209994793,79.6875,99.3375
140,4.5640857672389545,79.95,99.4375
141,4.565324122647205,79.5875,99.3
142,4.556167395218559,79.5,99.3625
143,4.565654691060384,79.80000000000001,99.41250000000001
144,4.563221107210432,78.1125,99.1875
145,4.562008252014985,79.625,99.3125
146,4.563423084762861,80.0875,99.3375
147,4.561200676737605,78.4625,99.275
148,4.566325235366821,77.4375,99.3
149,4.565679378509522,79.16250000000001,99.3125
150,4.561311265112648,79.9875,99.46249999999999
151,4.564657626152038,78.45,99.2875
152,4.56643837802815,78.3375,99.38749999999999
153,4.565955352783203,79.4125,99.3375
154,4.565937538536227,77.925,99.3625
155,4.561128410926232,79.7375,99.3375
156,4.5652302503585815,80.7625,99.46249999999999
157,4.56415134522973,79.23750000000001,99.41250000000001
158,4.561256321383194,78.6875,99.325
159,4.562075548905593,81.075,99.45
160,4.565318050838652,78.7625,99.3375
161,4.56672986092106,78.875,99.3375
162,4.559463489346388,78.66250000000001,99.325
163,4.557186906987971,78.9,99.35000000000001
164,4.56389398097992,79.27499999999999,99.4
165,4.560159504413605,78.825,99.325
--------------------------------------------------------------------------------
/results/linear_statistics.csv:
--------------------------------------------------------------------------------
epoch,train_loss,train_acc@1,train_acc@5,test_loss,test_acc@1,test_acc@5
1,2.23867652015686,47.32,85.56,2.2394468173980715,62.4875,98.1625
2,2.0928818286895754,76.32,98.52,2.175532444000244,75.75,99.1625
3,1.955794093132019,78.7,99.1,2.1144030265808107,78.95,99.2875
4,1.8307094875335694,80.2,99.14,2.0573978023529054,79.7375,99.325
5,1.712966128540039,80.74,99.16,2.0034180030822752,79.475,99.325
6,1.6091594816207886,80.82000000000001,99.18,1.9523463973999022,79.57499999999999,99.3
7,1.5115610223770142,80.94,99.26,1.9046716756820679,79.3875,99.2875
8,1.4153708534240723,81.26,99.32,1.8572627630233765,79.9375,99.2625
9,1.336933927154541,81.04,99.26,1.815097327232361,79.7,99.225
10,1.2672981122970581,81.69999999999999,99.18,1.7770221939086914,79.5875,99.175
11,1.2031101558685302,82.44,99.38,1.7408404541015625,79.675,99.2
12,1.1470788984298705,81.39999999999999,99.24,1.706960210800171,79.95,99.2
13,1.0925416803359986,82.24000000000001,99.32,1.6728482570648193,80.16250000000001,99.2
14,1.04535007686615,81.82000000000001,99.14,1.6424105081558227,80.22500000000001,99.2125
15,1.0020886152267456,82.44,99.36,1.6137052478790284,80.30000000000001,99.2
16,0.9616574898719787,82.58,99.48,1.5877059564590454,80.2875,99.2125
17,0.9207497631072998,82.72,99.36,1.5602429485321045,80.45,99.2125
18,0.8999906936645508,82.36,99.3,1.5393530073165893,80.4125,99.2
19,0.8730637706756592,83.0,99.32,1.5179570751190186,80.425,99.25
20,0.8397495670318603,82.98,99.24,1.4954745197296142,80.525,99.225
21,0.8153184818267822,83.44,99.2,1.475080030441284,80.60000000000001,99.2625
22,0.7863519223213196,83.34,99.42,1.45409801197052,80.875,99.325
23,0.7656220963478089,84.28,99.38,1.4374995250701905,80.9375,99.3375
24,0.745424748802185,84.2,99.42,1.4190076990127563,81.16250000000001,99.35000000000001
25,0.7366760729789734,83.52000000000001,99.33999999999999,1.4025026140213013,81.0875,99.3625
26,0.7170994510650635,83.39999999999999,99.24,1.3857117347717285,81.2625,99.3625
27,0.6940866159439087,84.7,99.5,1.370622797012329,81.2875,99.375
28,0.6919629066467285,84.06,99.28,1.357133222579956,81.6,99.38749999999999
29,0.6734722037315368,84.64,99.28,1.3448712224960326,81.3125,99.4
30,0.6642050458908081,84.17999999999999,99.28,1.332264588356018,81.5125,99.41250000000001
31,0.6622811601638794,83.52000000000001,99.36,1.3229667139053345,81.6,99.4
32,0.6436707684516907,84.2,99.38,1.307395610809326,81.75,99.425
33,0.6284469531059265,84.3,99.56,1.2958546113967895,81.7625,99.38749999999999
34,0.625880845451355,84.34,99.46000000000001,1.2874197111129762,81.625,99.4
35,0.6101801568031311,85.04,99.44,1.2744026861190796,81.875,99.375
36,0.594359105682373,85.02,99.53999999999999,1.2607554054260255,82.0375,99.38749999999999
37,0.5938352709770203,84.66,99.52,1.2540022830963136,82.0625,99.4
38,0.5917745595932007,84.86,99.4,1.246719612121582,82.025,99.38749999999999
39,0.5821160996437073,84.84,99.56,1.2324062976837158,81.9375,99.41250000000001
40,0.5748880784034729,84.82,99.32,1.222454221725464,82.175,99.38749999999999
41,0.5710795659065246,84.64,99.53999999999999,1.2201661710739136,82.15,99.3625
42,0.567002702999115,84.66,99.42,1.2068362731933593,82.5375,99.38749999999999
43,0.5608441215515136,84.78,99.32,1.201515947341919,82.375,99.3625
44,0.5359812728881836,85.82,99.66000000000001,1.1899815711975097,82.4125,99.375
45,0.5436428760528564,85.68,99.44,1.1851844396591187,82.4125,99.4
46,0.5403897270202637,85.08,99.3,1.1780760755538942,82.575,99.375
47,0.5392217435359955,84.86,99.4,1.1705868644714355,82.5,99.38749999999999
48,0.5359720094203949,84.58,99.58,1.1634352951049804,82.5625,99.41250000000001
49,0.5211987493515015,85.61999999999999,99.6,1.1571491575241089,82.575,99.4
50,0.5280253508090973,85.0,99.38,1.1525636119842528,82.78750000000001,99.425
51,0.5165058417320252,85.56,99.6,1.1425102405548095,82.7375,99.4375
52,0.5036860698699951,85.61999999999999,99.5,1.1345144500732423,82.8875,99.45
53,0.4980165102958679,85.74000000000001,99.46000000000001,1.1247358837127686,82.8125,99.45
54,0.503675221157074,85.11999999999999,99.56,1.122702067375183,82.825,99.425
55,0.5034103675842285,85.39999999999999,99.44,1.1186348485946656,82.8625,99.4375
56,0.5004815286159515,85.24000000000001,99.44,1.1121335020065308,82.9875,99.41250000000001
57,0.49233290033340454,85.64,99.52,1.106219783782959,83.05,99.41250000000001
58,0.4831447655677795,86.33999999999999,99.52,1.0997568349838256,82.9375,99.45
59,0.4949026760101318,85.82,99.6,1.1006921415328978,83.0,99.425
60,0.4855158269882202,86.0,99.58,1.0893584241867065,83.1125,99.4375
61,0.47568662967681885,86.58,99.5,1.0839851894378663,83.175,99.4375
62,0.4676999207496643,86.04,99.68,1.0798140392303466,83.0625,99.45
63,0.4779873286247253,85.74000000000001,99.66000000000001,1.0776025590896607,82.9375,99.45
64,0.4662654571056366,86.26,99.62,1.0716015605926514,83.125,99.4375
65,0.46651281657218935,86.14,99.62,1.0696470012664796,83.1625,99.46249999999999
66,0.4583287220001221,86.04,99.66000000000001,1.0656101865768433,83.22500000000001,99.45
67,0.4584460023403168,86.6,99.44,1.0567087631225587,83.25,99.45
68,0.47010446615219115,85.8,99.5,1.0547093677520751,83.26249999999999,99.47500000000001
69,0.4553971031188965,86.58,99.53999999999999,1.052298864364624,83.25,99.46249999999999
70,0.45542179169654845,86.02,99.52,1.042236050605774,83.6125,99.47500000000001
71,0.4689282648563385,85.98,99.48,1.045232045173645,83.42500000000001,99.425
72,0.44606041898727417,86.5,99.62,1.0355810298919679,83.42500000000001,99.4375
73,0.4590295832157135,85.48,99.53999999999999,1.0351566762924194,83.4375,99.45
74,0.44059539999961855,86.46000000000001,99.68,1.0291480960845947,83.45,99.4375
75,0.44589597697257993,85.96000000000001,99.76,1.0236962718963623,83.42500000000001,99.4375
76,0.4569112882614136,85.39999999999999,99.52,1.0250700330734253,83.7,99.47500000000001
77,0.44317653717994687,86.9,99.64,1.0191460037231446,83.76249999999999,99.47500000000001
78,0.4444390136241913,86.46000000000001,99.68,1.0160138721466065,83.825,99.45
79,0.43353584990501404,86.61999999999999,99.33999999999999,1.0110815677642822,83.75,99.47500000000001
80,0.43107402420043944,86.44,99.76,1.0067088937759399,83.875,99.5125
81,0.44140214128494265,86.0,99.46000000000001,1.0060942106246948,83.8,99.5125
82,0.42756440744400026,86.58,99.62,0.9985377082824707,83.92500000000001,99.5125
83,0.42294873113632203,87.12,99.53999999999999,1.0020516023635864,83.78750000000001,99.47500000000001
84,0.42216853313446046,86.98,99.52,0.9964222612380982,84.075,99.4875
85,0.4254521691799164,86.92,99.66000000000001,0.9886314897537232,84.1125,99.4875
86,0.4218989317417145,86.98,99.64,0.9874826307296753,84.1,99.4875
87,0.42631903719902037,86.6,99.53999999999999,0.9843551521301269,84.15,99.46249999999999
88,0.42168343453407287,86.72,99.68,0.981222113609314,83.975,99.5
89,0.4164175959587097,87.12,99.58,0.975565571308136,84.1625,99.5
90,0.41582111711502073,86.78,99.78,0.9738754243850708,84.2125,99.47500000000001
91,0.41423747925758364,86.68,99.68,0.9700007371902466,84.3375,99.4875
92,0.41603217372894286,86.94,99.68,0.9711467900276184,84.25,99.4875
93,0.4159442962169647,86.68,99.64,0.9729300622940064,84.25,99.4875
94,0.3976638189315796,87.64,99.72,0.9641343412399292,84.3125,99.4875
95,0.40926754364967344,87.0,99.58,0.9621022901535035,84.1375,99.45
96,0.40822780995368957,87.4,99.58,0.9597155032157898,84.275,99.4375
97,0.41371343169212343,86.58,99.56,0.9554791607856751,84.2,99.45
98,0.41597759351730346,86.58,99.6,0.9617002787590027,84.1375,99.4375
99,0.4186860776424408,86.64,99.58,0.9551908307075501,84.26249999999999,99.45
100,0.41007653741836547,87.36,99.64,0.9536493349075318,84.375,99.46249999999999
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
python main.py --dataset stl10 --estimator easy --sampling RWR --random_epochs 1 --edge_nums 100 --restart_p 0.2 --batch_size 256 --epochs 1000 --update_batchs 100 --save_path test_stl10
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from PIL import Image
from torchvision import transforms
import numpy as np
from torchvision.datasets import STL10


class Sampler(object):
    """Yields a fixed, precomputed sequence of dataset indices.

    Used as the `sampler` of a DataLoader, it makes the loader visit
    samples in exactly this order.
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class STL10Pair(STL10):
    """STL10 variant that returns two independently augmented views per image."""

    def __getitem__(self, index):
        img, target = self.data[index], self.labels[index]
        # STL10 stores images as (C, H, W) uint8 arrays; PIL expects (H, W, C).
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            # Two random augmentations of the same image form a positive pair.
            pos_1 = self.transform(img)
            pos_2 = self.transform(img)
        else:
            # Without a transform, both views are the unmodified PIL image.
            pos_1 = pos_2 = img

        if self.target_transform is not None:
            target = self.target_transform(target)

        # The dataset index is returned so each sample can be traced back to its instance.
        return index, pos_1, pos_2, target


# Contrastive pre-training augmentations; images are cropped and resized to 32x32.
# The normalisation constants are the channel statistics commonly used for CIFAR-10.
train_transform_stl10 = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

# Evaluation: no augmentation or resizing (STL10 images stay 96x96), only tensor conversion and normalisation.
test_transform_stl10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

--------------------------------------------------------------------------------
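The `train.sh` command selects `--sampling RWR` with `--edge_nums 100` and `--restart_p 0.2`, while `utils.py` only provides a `Sampler` that replays a fixed list of indices. As a minimal sketch of how such a list might be produced, the code that follows builds a k-nearest-neighbour proximity graph over instance embeddings and collects one mini-batch with a random walk with restart (RWR). It is not the repository's implementation: the helper names `knn_graph` and `rwr_batch`, the use of cosine similarity, the reading of `--edge_nums` as the neighbour count `k` and of `--restart_p` as the restart probability, and the top-up safeguard are all assumptions made for illustration.

```python
import numpy as np


def knn_graph(embeddings: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: (n, k) array of each instance's k most cosine-similar neighbours."""
    # Row-normalise so that the dot product equals cosine similarity.
    z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = z @ z.T
    np.fill_diagonal(sim, -np.inf)  # an instance is never its own neighbour
    # Sort each row by descending similarity and keep the first k columns.
    return np.argsort(-sim, axis=1)[:, :k]


def rwr_batch(neighbors: np.ndarray, batch_size: int, restart_p: float,
              rng: np.random.Generator, max_steps_per_node: int = 100) -> list:
    """Hypothetical helper: collect `batch_size` distinct ids by a random walk with restart."""
    n = neighbors.shape[0]
    if batch_size > n:
        raise ValueError("batch_size cannot exceed the number of instances")
    start = int(rng.integers(n))
    current, batch, seen = start, [start], {start}
    for _ in range(max_steps_per_node * batch_size):
        if len(batch) == batch_size:
            break
        if rng.random() < restart_p:
            current = start  # restart: jump back to the seed instance
        else:
            current = int(rng.choice(neighbors[current]))  # step to a random neighbour
        if current not in seen:
            seen.add(current)
            batch.append(current)
    # Assumed safeguard (not from the source): if the walk stalls in a small
    # neighbourhood, top the batch up with random unvisited instances.
    if len(batch) < batch_size:
        extra = [i for i in rng.permutation(n).tolist() if i not in seen]
        batch.extend(extra[: batch_size - len(batch)])
    return batch


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(1000, 128))  # stand-in for encoder outputs
    neighbors = knn_graph(embeddings, k=100)   # assumed meaning of --edge_nums 100
    batch = rwr_batch(neighbors, batch_size=256, restart_p=0.2, rng=rng)  # assumed meaning of --restart_p 0.2
    print(len(batch), len(set(batch)))  # prints: 256 256
```

A batch built this way could be wrapped as `Sampler(batch)` and passed to `torch.utils.data.DataLoader(dataset, batch_size=len(batch), sampler=...)`, so that the whole walk lands in a single in-batch contrastive mini-batch. In this sketch, a higher restart probability keeps the walk closer to the seed instance and therefore yields a tighter neighbourhood of mutually similar instances.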