├── README.md
├── eval.sh
├── img
│   └── fig.png
├── linear.py
├── main.py
├── model.py
├── results
│   ├── 128_0.5_200_256_1000_statistics.csv
│   └── linear_statistics.csv
├── train.sh
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | BatchSampler: Sampling Mini-Batches for Contrastive Learning in Vision, Language, and Graphs
9 |
10 |
11 |
12 | Source code for KDD'23 paper: [BatchSampler: Sampling Mini-Batches for Contrastive Learning in Vision, Language, and Graphs](https://arxiv.org/abs/2306.03355).
13 |
14 |
15 | BatchSampler is a simple and general generative method to sample mini-batches of hard-to-distinguish (i.e., hard and true negatives to each other) instances, which can be directly plugged into in-batch contrastive models in vision, language, and graphs.
16 |
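As a rough illustration of the sampling idea, the sketch below draws one mini-batch by random walk with restart over a neighbor graph, loosely following `random_walk` in `main.py`. The function name `sample_batch_rwr`, the toy ring graph, and the `max_steps` cap are assumptions made for this example and are not part of the repository.

```python
import random


def sample_batch_rwr(adj, start, batch_size, restart_p=0.1, max_steps=10_000, seed=None):
    """Collect up to `batch_size` distinct nodes by random walk with restart.

    adj: dict mapping node -> list of neighbor nodes (hypothetical toy graph).
    """
    rng = random.Random(seed)
    batch = {start}
    cur = start
    for _ in range(max_steps):  # step cap is an assumption, added to guarantee termination
        if len(batch) >= batch_size:
            break
        neighbors = adj.get(cur, [])
        if not neighbors or rng.random() < restart_p:
            cur = start  # restart the walk at the anchor node
            continue
        cur = rng.choice(neighbors)  # move to a random neighbor
        batch.add(cur)
    return batch


# Toy ring graph: each node links to its two successors.
adj = {i: [(i + 1) % 8, (i + 2) % 8] for i in range(8)}
batch = sample_batch_rwr(adj, start=0, batch_size=4, restart_p=0.1, seed=0)
print(sorted(batch))
```

In the repository the neighbor lists come from the top-scoring candidates selected in `build_graph`, so nodes reached by the walk tend to be similar to the start node, which is what makes them hard negatives for each other.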
17 | Dependencies
18 |
19 | * Python >= 3.7
20 | * [Pytorch](https://pytorch.org/) >= 1.9.0
21 |
22 | Quick Start
23 |
24 | Taking the vision modality as an example, you can run the code on STL10:
25 |
26 | ```bash
27 | sh train.sh
28 | ```
29 |
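The checkpoint that `eval.sh` loads is named after the training hyperparameters. A minimal sketch of that naming, mirroring the `save_name_pre` format string in `main.py` (the helper name `run_name` is hypothetical):

```python
def run_name(feature_dim, temperature, k, batch_size, epochs):
    """Mirror the `save_name_pre` format string from main.py."""
    return '{}_{}_{}_{}_{}'.format(feature_dim, temperature, k, batch_size, epochs)


# Values chosen to match the file names that appear in this repository.
name = run_name(128, 0.5, 200, 256, 1000)
print(name + '_model.pth')       # 128_0.5_200_256_1000_model.pth
print(name + '_statistics.csv')  # 128_0.5_200_256_1000_statistics.csv
```

If you change `--batch_size` or `--epochs` during training, the `--model_path` passed to `linear.py` has to change accordingly.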
30 |
31 |
32 | Datasets
33 |
34 | We conduct experiments on five datasets across three modalities. For the vision modality, we use a large-scale dataset, [ImageNet](https://www.image-net.org/); two medium-scale datasets, [STL10](https://cs.stanford.edu/~acoates/stl10/) and [ImageNet-100](https://www.kaggle.com/datasets/ambityga/imagenet100); and two small-scale datasets, [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) and [CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html). For the language modality, we use 7 semantic textual similarity (STS) tasks. For the graph modality, we conduct graph-level classification experiments on 7 benchmark datasets: IMDB-B, IMDB-M, COLLAB, REDDIT-B, PROTEINS, MUTAG, and NCI1.
35 |
36 | Experimental Results
37 | Vision Modality
38 |
39 | | Method | 100ep | 400ep | 800ep |
40 | | ------------------ | ------------ | ------------ | ------------ |
41 | | SimCLR | 64.0 | 68.1 | 68.7 |
42 | | **w/BatchSampler** | **64.7** | **68.6** | **69.2** |
43 | | Moco v3 | 68.9 | 73.3 | 73.8 |
44 | | **w/BatchSampler** | **69.5** | **73.7** | **74.2** |
45 |
46 | Language Modality
47 |
48 | | Method | STS12 | STS13 | STS14 | STS15 | STS16 | STS-B | SICK-R | Avg. |
49 | | ------------------ | ------------ | ------------ | ------------ | -------------- | -------------- | -------------- | -------------- | -------------- |
50 | | SimCSE-BERT{BASE} | 68.62 | 80.89 | 73.74 | 80.88 | 77.66 | **77.79** | **69.64** | 75.60 |
51 | | w/kNN Sampler | 63.62 | 74.86 | 69.79 | 79.17 | 76.24 | 74.73 | 67.74 | 72.31 |
52 | | **w/BatchSampler** | **72.37** | **82.08** | **75.24** | **83.10** | **78.43** | 77.54 | 68.05 | **76.69** |
53 | | DCL-BERT{BASE} | 65.22 | 77.89 | 68.94 | 79.88 | **76.72** |73.89 | **69.54** | 73.15 |
54 | | w/kNN Sampler | 66.34 | 76.66 | 72.60 | 78.30 | 74.86 | 73.65 | 67.92 | 72.90 |
55 | | **w/BatchSampler** | **69.55** | **82.66** | **73.37** | **80.40** | 75.37 | **75.43** | 66.76 | **74.79** |
56 | | HCL-BERT{BASE} | 62.57 | 79.12 | 69.70 | 78.00 | 75.11 | 73.38 | 69.74 | 72.52|
57 | | w/kNN Sampler | 61.12 | 75.73 | 68.43 | 76.64 | 74.78 | 71.22 | 68.04 | 70.85 |
58 | | **w/BatchSampler** | **66.87** | **81.38** | **72.96** | **80.11** | **77.99** | **75.95** | 70.89 | **75.16** |
59 |
60 |
61 |
62 | Graphs Modality
63 |
64 | | Method | IMDB-B | IMDB-M | COLLAB | REDDIT-B | PROTEINS | MUTAG | NCI1 |
65 | | ------------------ | ------------ | ------------ | ------------ | -------------- | -------------- | -------------- | -------------- |
66 | | GraphCL | 70.90±0.53 | 48.48±0.38 | 70.62±0.23 | 90.54±0.25 | 74.39±0.45 | 86.80±1.34 | 77.87±0.41 |
67 | | w/kNN Sampler | 70.72±0.35 | 47.97±0.97 | 70.59±0.14 | 90.21±0.74 | 74.17±0.41 | 86.46±0.82 | 77.27±0.37 |
68 | | **w/BatchSampler** | **71.90±0.46** | **48.93±0.28** | **71.48±0.28** | **90.88±0.16** | **75.04±0.67** | **87.78±0.93** | **78.93±0.38** |
69 | | DCL| 71.07±0.36 | 48.93±0.32 | **71.06±0.51** | 90.66±0.29 | 74.64±0.48 | 88.09±0.93 | 78.49±0.48 |
70 | | w/kNN Sampler | 70.94±0.19 | 48.47±0.35 | 70.49±0.37 | 90.26±1.03 | 74.28±0.17 | 87.13±1.40 | 78.13±0.52 |
71 | | **w/BatchSampler** | **71.32±0.17** | **48.96±0.25** | 70.44±0.35 | **90.73±0.34** | **75.02±0.61** | **89.47±1.43** | **79.03±0.32** |
72 | | HCL| **71.24±0.36** | 48.54±0.51 | 71.03±0.45 | 90.40±0.42 | 74.69±0.42 | 87.79±1.10 | 78.83±0.67 |
73 | | w/kNN Sampler | 71.14±0.44 | 48.36±0.93 | 70.86±0.74 | 90.64±0.51 | 74.06±0.44 | 87.53±1.37 | 78.66±0.48 |
74 | | **w/BatchSampler** | 71.20±0.38 | **48.76±0.39** | **71.70±0.35** | **91.25±0.25** | **75.11±0.63** | **88.31±1.29** | **79.17±0.27** |
75 | | MVGRL| 74.20±0.70 | 51.20±0.50 |- | 84.50±0.60 | - | 89.70±1.10 | - |
76 | | w/kNN Sampler | 73.30±0.34 | 50.70±0.36 | - | 82.70±0.67 | - | 85.08±0.66 | - |
77 | | **w/BatchSampler** | **76.70±0.35** | **52.40±0.39** | - | **87.47±0.79** | - | **91.13±0.81** | - |
78 |
79 | Citing
80 | If you find our work helpful to your research, please consider citing our paper:
81 |
82 | ```
83 | @article{yang2023batchsampler,
84 | title={BatchSampler: Sampling Mini-Batches for Contrastive Learning in Vision, Language, and Graphs},
85 | author={Yang, Zhen and Huang, Tinglin and Ding, Ming and Dong, Yuxiao and Ying, Rex and Cen, Yukuo and Geng, Yangliao and Tang, Jie},
86 | journal={arXiv preprint arXiv:2306.03355},
87 | year={2023}
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=7
2 | python linear.py --dataset stl10 --batch_size 512 --epochs 100 --save_path test_stl10 --model_path ./test_stl10/128_0.5_200_256_1000_model.pth
3 |
--------------------------------------------------------------------------------
/img/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/BatchSampler/28017ec4b6f2c8eebcde16ea97cfa3bb217cfb4a/img/fig.png
--------------------------------------------------------------------------------
/linear.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import pandas as pd
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from thop import profile, clever_format
8 | from torch.utils.data import DataLoader
9 | from torchvision.datasets import CIFAR10, CIFAR100, STL10
10 | from tqdm import tqdm
11 |
12 | import utils
13 | from model import Model
14 |
15 |
16 | class Net(nn.Module):
17 |     def __init__(self, num_class, pretrained_path):
18 |         super(Net, self).__init__()
19 |
20 |         # encoder
21 |         self.f = Model().f
22 |         # classifier
23 |         self.fc = nn.Linear(2048, num_class, bias=True)
24 |         self.load_state_dict(torch.load(pretrained_path, map_location='cpu'), strict=False)
25 |
26 |     def forward(self, x):
27 |         x = self.f(x)
28 |         feature = torch.flatten(x, start_dim=1)
29 |         out = self.fc(feature)
30 |         return out
31 |
32 |
33 | # train or test for one epoch
34 | def train_val(net, data_loader, train_optimizer):
35 |     is_train = train_optimizer is not None
36 |     net.train() if is_train else net.eval()
37 |
38 |     total_loss, total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader)
39 |     with (torch.enable_grad() if is_train else torch.no_grad()):
40 |         for data, target in data_bar:
41 |             data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
42 |             out = net(data)
43 |             loss = loss_criterion(out, target)
44 |
45 |             if is_train:
46 |                 train_optimizer.zero_grad()
47 |                 loss.backward()
48 |                 train_optimizer.step()
49 |
50 |             total_num += data.size(0)
51 |             total_loss += loss.item() * data.size(0)
52 |             prediction = torch.argsort(out, dim=-1, descending=True)
53 |             total_correct_1 += torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
54 |             total_correct_5 += torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
55 |
56 |             data_bar.set_description('{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'
57 |                                      .format('Train' if is_train else 'Test', epoch, epochs, total_loss / total_num,
58 |                                              total_correct_1 / total_num * 100, total_correct_5 / total_num * 100))
59 |
60 |     return total_loss / total_num, total_correct_1 / total_num * 100, total_correct_5 / total_num * 100
61 |
62 |
63 | if __name__ == '__main__':
64 |     parser = argparse.ArgumentParser(description='Linear Evaluation')
65 |     parser.add_argument('--model_path', type=str, default='results/128_0.5_200_512_500_model.pth',
66 |                         help='The pretrained model path')
67 |     parser.add_argument('--batch_size', type=int, default=512, help='Number of images in each mini-batch')
68 |     parser.add_argument('--epochs', type=int, default=100, help='Number of sweeps over the dataset to train')
69 |     parser.add_argument('--save_path', default='results', type=str, help='save path')
70 |     parser.add_argument('--dataset', default='cifar10', type=str, help='train dataset')
71 |
72 |
73 |     args = parser.parse_args()
74 |     model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs
75 |     if args.dataset == 'cifar10':
76 |         train_data = CIFAR10(root='data', train=True, transform=utils.train_transform_cifar10, download=True)
77 |         train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
78 |         test_data = CIFAR10(root='data', train=False, transform=utils.test_transform_cifar10, download=True)
79 |         test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
80 |     elif args.dataset == 'cifar100':
81 |         train_data = CIFAR100(root='data', train=True, transform=utils.train_transform_cifar100, download=True)
82 |         train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
83 |         test_data = CIFAR100(root='data', train=False, transform=utils.test_transform_cifar100, download=True)
84 |         test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
85 |     elif args.dataset == 'stl10':
86 |         train_data = STL10(root='data', split='train', transform=utils.train_transform_stl10, download=True)
87 |         train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
88 |         test_data = STL10(root='data', split='test', transform=utils.test_transform_stl10, download=True)
89 |         test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
90 |     else:
91 |         raise Exception("invalid dataset")
92 |
93 |     model = Net(num_class=len(train_data.classes), pretrained_path=model_path).cuda()
94 |     for param in model.f.parameters():
95 |         param.requires_grad = False
96 |
97 |     flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
98 |     flops, params = clever_format([flops, params])
99 |     print('# Model Params: {} FLOPs: {}'.format(params, flops))
100 |     optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)
101 |     loss_criterion = nn.CrossEntropyLoss()
102 |     results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [],
103 |                'test_loss': [], 'test_acc@1': [], 'test_acc@5': []}
104 |
105 |     best_acc = 0.0
106 |     for epoch in range(1, epochs + 1):
107 |         train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer)
108 |         results['train_loss'].append(train_loss)
109 |         results['train_acc@1'].append(train_acc_1)
110 |         results['train_acc@5'].append(train_acc_5)
111 |         test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None)
112 |         results['test_loss'].append(test_loss)
113 |         results['test_acc@1'].append(test_acc_1)
114 |         results['test_acc@5'].append(test_acc_5)
115 |         # save statistics
116 |         data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
117 |         data_frame.to_csv(args.save_path+'/linear_statistics.csv', index_label='epoch')
118 |         if test_acc_1 > best_acc:
119 |             best_acc = test_acc_1
120 |             torch.save(model.state_dict(), args.save_path+'/linear_model.pth')
121 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import pandas as pd
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from thop import profile, clever_format
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | import time
13 | import random
14 | import numpy as np
15 |
16 | import utils
17 | from model import Model
18 |
19 | import json
20 |
21 |
22 | def get_negative_mask(batch_size):
23 |     negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
24 |     for i in range(batch_size):
25 |         negative_mask[i, i] = 0
26 |         negative_mask[i, i + batch_size] = 0
27 |
28 |     negative_mask = torch.cat((negative_mask, negative_mask), 0)
29 |     return negative_mask
30 |
31 | # train for one epoch to learn unique features
32 | def train(net, data_loader, train_optimizer, batch_save):
33 |     net.train()
34 |     total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
35 |     batch_num = 0
36 |     for idx, pos_1, pos_2, target in train_bar:
37 |         batch_size = pos_1.shape[0]
38 |         pos_1, pos_2 = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True)
39 |         feature_1, out_1 = net(pos_1)
40 |         feature_2, out_2 = net(pos_2)
41 |         # [2*B, D]
42 |         out = torch.cat([out_1, out_2], dim=0)
43 |         # [2*B, 2*B]
44 |         sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
45 |         # neg
46 |         mask = get_negative_mask(batch_size).cuda()
47 |         # [2*B, 2*B-2]
48 |
49 |
50 |         neg = sim_matrix.masked_select(mask).view(2 * batch_size, -1)
51 |
52 |         # compute loss
53 |         pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
54 |         # [2*B]
55 |         pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
56 |
57 |         if estimator=='hard':
58 |             N = batch_size * 2 - 2
59 |             imp = (beta* neg.log()).exp()
60 |             reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1)
61 |             Ng = (-tau_plus * N * pos_sim + reweight_neg) / (1 - tau_plus)
62 |             # constrain (optional)
63 |             Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))
64 |         elif estimator=='easy':
65 |             Ng = neg.sum(dim=-1)
66 |         elif estimator=='debias':
67 |             N = batch_size * 2 - 2
68 |             Ng = (-tau_plus * N * pos_sim + neg.sum(dim = -1)) / (1 - tau_plus)
69 |             # constrain (optional)
70 |             Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))
71 |         else:
72 |             raise Exception('Invalid estimator selected. Please use any of [hard, easy, debias]')
73 |
74 |         loss = (- torch.log(pos_sim / (pos_sim + Ng))).mean()
75 |
76 |
77 |         # get all image embeds
78 |         if batch_num == 0:
79 |             embeds_all = torch.cat((out_1, out_2), dim=1).detach()
80 |             index_all = idx.view(batch_size, -1).detach()
81 |         else:
82 |             embeds_all = torch.cat((embeds_all, torch.cat((out_1, out_2), dim=1)), dim=0).detach()
83 |             index_all = torch.cat((index_all, idx.view(batch_size, -1)), dim=0).detach()
84 |
85 |         train_optimizer.zero_grad()
86 |         loss.backward()
87 |         train_optimizer.step()
88 |
89 |         total_num += batch_size
90 |         total_loss += loss.item() * batch_size
91 |         train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num))
92 |         batch_num += 1
93 |         batch_save += 1
94 |
95 |         if batch_save != 0 and batch_save % args.save_batches == 0:
96 |             torch.save(model.state_dict(), args.save_path+'/model_batch_{}.pth'.format(batch_save))
97 |         if batch_save > batch_nums_all:
98 |             return total_loss / total_num, embeds_all, index_all, batch_save
99 |     return total_loss / total_num, embeds_all, index_all, batch_save
100 |
101 |
102 | # test for one epoch, use weighted knn to find the most similar images' label to assign the test image
103 | def test(net, memory_data_loader, test_data_loader):
104 |     net.eval()
105 |     total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
106 |     with torch.no_grad():
107 |         # generate feature bank
108 |         for _, data, _, target in tqdm(memory_data_loader, desc='Feature extracting'):
109 |             feature, out = net(data.cuda(non_blocking=True))
110 |             feature_bank.append(feature)
111 |         # [D, N]
112 |         feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
113 |         # [N]
114 |         feature_labels = torch.tensor(memory_data_loader.dataset.labels, device=feature_bank.device)
115 |         # loop test data to predict the label by weighted knn search
116 |         test_bar = tqdm(test_data_loader)
117 |         for _, data, _, target in test_bar:
118 |             data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
119 |             feature, out = net(data)
120 |
121 |             total_num += data.size(0)
122 |             # compute cos similarity between each feature vector and feature bank ---> [B, N]
123 |             sim_matrix = torch.mm(feature, feature_bank)
124 |             # [B, K]
125 |             sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
126 |             # [B, K]
127 |             sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices)
128 |             sim_weight = (sim_weight / temperature).exp()
129 |
130 |             # counts for each class
131 |             one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device)
132 |             # [B*K, C]
133 |             one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1).long(), value=1.0)
134 |             # weighted score ---> [B, C]
135 |             pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)
136 |
137 |             pred_labels = pred_scores.argsort(dim=-1, descending=True)
138 |             total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
139 |             total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
140 |             test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
141 |                                      .format(epoch, epochs, total_top1 / total_num * 100, total_top5 / total_num * 100))
142 |
143 |     return total_top1 / total_num * 100, total_top5 / total_num * 100
144 |
145 |
146 | # random -> topk
147 | def build_graph(images_embeds):
148 |     split_num = 20
149 |     fold_len = images_embeds.size()[0] // split_num
150 |
151 |     images_embeds = F.normalize(images_embeds, dim=-1)
152 |
153 |     for i in range(split_num):
154 |         start = i * fold_len
155 |         if i == split_num -1:
156 |             end = images_embeds.size()[0]
157 |         else:
158 |             end = (i + 1) * fold_len
159 |
160 |         # random
161 |         neg_indices = torch.randint(0, images_embeds.shape[0], (end-start, args.M)).cuda()
162 |         random_select_idx = neg_indices.view(1, -1)[0]
163 |
164 |         anchor = images_embeds[start:end].unsqueeze(1)
165 |
166 |         negs = images_embeds[random_select_idx].view(end-start, args.M, -1)
167 |
168 |         B_score = torch.matmul(anchor, negs.permute(0,2,1)).squeeze()
169 |         B_score = torch.exp(B_score / temperature)
170 |         neg = B_score
171 |
172 |         # top
173 |         B_score, B_idx = neg.topk(k=edge_nums, dim=-1)
174 |         B_indices = torch.gather(neg_indices, 1, B_idx)
175 |
176 |
177 |         if i == 0:
178 |             edge_score = B_score
179 |             adj_index = B_indices
180 |         else:
181 |             edge_score = torch.cat([edge_score, B_score], dim=0)
182 |             adj_index = torch.cat([adj_index, B_indices], dim=0)
183 |
184 |     edge_score = edge_score.detach().cpu()
185 |     adj_index = adj_index.detach().cpu()
186 |
187 |     return map_index, adj_index, edge_score
188 |
189 |
190 |
191 | def sampling(start_node, adj_index, map_index, visit, restart_p):
192 |     if args.sampling == 'RWR':
193 |         walks, visit = random_walk(start_node, adj_index, map_index, visit, restart_p)
194 |     else:
195 |         raise Exception("invalid sampling method")
196 |     return walks, visit
197 |
198 |
199 | def random_walk(start_node, adj_index, map_index, visit, restart_p):
200 |     # random walk + restart
201 |     walks = set()
202 |     walks.add(start_node)
203 |     paths = [start_node]
204 |     while len(walks) < batch_size:
205 |         cur = paths[-1]
206 |         nodes = adj_index[cur].tolist()
207 |         if len(nodes) > 0:
208 |             if torch.rand(1).item() < restart_p:
209 |                 paths.append(paths[0])
210 |             else:
211 |                 next_node = random.choice(nodes)
212 |                 paths.append(next_node)
213 |                 if next_node not in visit:
214 |                     walks.add(next_node)
215 |         else:
216 |             break
217 |     visit = visit | walks
218 |     assert len(set(walks)) == batch_size
219 |     return walks, visit
220 |
221 |
222 |
223 | def generate_indices(graph, walks):
224 |     map_index, adj_index, edge_weights = graph[0], graph[1].numpy(), graph[2]
225 |     indices = []
226 |     for _ in range(args.update_batchs):
227 |         visit = set()
228 |         remain_images = list(set(map_index.keys()) - set(walks))
229 |         if len(remain_images) == 0:
230 |             return indices, walks
231 |         start_node = random.choice(remain_images)
232 |         batch_idx, visit = sampling(start_node, adj_index, map_index, visit, restart_p=restart_p)
233 |         walks.extend(batch_idx)
234 |         indices.extend(batch_idx)
235 |     return indices, walks
236 |
237 | def update_graph(update_embeds, update_index, origin_embeds):
238 |     map_index = {k:i for i, k in enumerate(update_index.view(1,-1)[0].numpy().tolist())}
239 |     index = torch.tensor(list(map_index.keys()))
240 |     embeds_images = update_embeds[torch.tensor(list(map_index.values()))]
241 |     origin_embeds[index, :] = embeds_images
242 |     images_embeds = origin_embeds
243 |
244 |     graph = build_graph(images_embeds)
245 |     return graph, images_embeds
246 |
247 | def get_embeds(embeds_all, index_all):
248 |     embeds = torch.ones((len(train_data), embeds_all.shape[1]), dtype=torch.float).cuda()
249 |     map_index = {k:i for i, k in enumerate(index_all.view(1,-1)[0].numpy().tolist())}
250 |
251 |     index = torch.tensor(list(map_index.keys()))
252 |     embeds_images = embeds_all[torch.tensor(list(map_index.values()))]
253 |     embeds[index, :] = embeds_images
254 |     images_embeds = embeds
255 |     return map_index, images_embeds
256 |
257 | if __name__ == '__main__':
258 |     parser = argparse.ArgumentParser(description='Train SimCLR')
259 |     parser.add_argument('--dataset', default='stl10', type=str, help='train dataset')
260 |     parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for latent vector')
261 |     parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in softmax')
262 |     parser.add_argument('--k', default=200, type=int, help='Top k most similar images used to predict the label')
263 |     parser.add_argument('--batch_size', default=512, type=int, help='Number of images in each mini-batch')
264 |     parser.add_argument('--epochs', default=1000, type=int, help='Number of sweeps over the dataset to train')
265 |     parser.add_argument('--save_batches', default=1000, type=int, help='Number of batches to save checkpoint')
266 |     parser.add_argument('--edge_nums', default=100, type=int, help='Number of edges kept per node when building the graph')
267 |     parser.add_argument('--restart_p', default=0.1, type=float, help='Restart probability for random walk with restart')
268 |     parser.add_argument('--random_epochs', default=10, type=int, help='Number of randomly shuffled epochs at the beginning')
269 |     parser.add_argument('--save_path', default='results', type=str, help='save path')
270 |     # sampling() only accepts 'RWR', so the default must be 'RWR' for a run without this flag to work
271 |     parser.add_argument('--sampling', default='RWR', type=str, help='sampling method')
271 |     parser.add_argument('--M', default=1000, type=int, help='Number of random candidates per node in the first selection')
272 |     parser.add_argument('--update_batchs', default=100, type=int, help='Number of batches between graph updates')
273 |     parser.add_argument('--estimator', default='hard', type=str, help='Loss estimator: hard, easy, or debias')
274 |     parser.add_argument('--tau_plus', default=0.1, type=float, help='Positive class prior')
275 |     parser.add_argument('--beta', default=1.0, type=float, help='Exponent for reweighting negatives in the hard estimator')
276 |
277 |     # args parse
278 |     args = parser.parse_args()
279 |     print(args)
280 |     if not os.path.exists(args.save_path):
281 |         os.mkdir(args.save_path)
282 |     with open(args.save_path + '/args.txt', 'w') as f:
283 |         json.dump(args.__dict__, f, indent=2)
284 |
285 |     feature_dim, temperature, k, restart_p, edge_nums= args.feature_dim, args.temperature, args.k, args.restart_p, args.edge_nums
286 |     batch_size, epochs = args.batch_size, args.epochs
287 |     estimator = args.estimator
288 |     tau_plus = args.tau_plus
289 |     beta = args.beta
290 |
291 |     # data prepare
292 |     if args.dataset == 'stl10':
293 |         train_data = utils.STL10Pair(root='data', split='train+unlabeled', transform=utils.train_transform_stl10, download=True)
294 |         memory_data = utils.STL10Pair(root='data', split='train', transform=utils.test_transform_stl10, download=True)
295 |         memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
296 |         test_data = utils.STL10Pair(root='data', split='test', transform=utils.test_transform_stl10, download=True)
297 |         test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)
298 |     else:
299 |         raise Exception("invalid dataset")
300 |
301 |     args.save_batches = int(len(train_data)//batch_size+1) * 100
302 |
303 |     # model setup and optimizer config
304 |     model = Model(feature_dim).cuda()
305 |     flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
306 |     flops, params = clever_format([flops, params])
307 |     print('# Model Params: {} FLOPs: {}'.format(params, flops))
308 |     optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
309 |     c = len(memory_data.classes)
310 |
311 |     # training loop
312 |     results = {'train_loss': [], 'test_acc@1': [], 'test_acc@5': []}
313 |     save_name_pre = '{}_{}_{}_{}_{}'.format(feature_dim, temperature, k, batch_size, epochs)
314 |
315 |
316 |     batch_nums_all = int(len(train_data)//batch_size+1) * args.epochs
317 |     batch_save = 0
318 |     best_acc = 0.0
319 |     for epoch in range(1, epochs + 1):
320 |         walks = []
321 |
322 |         if epoch <= args.random_epochs:
323 |             train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True,
324 |                                       drop_last=False)
325 |             train_loss, embeds_all, index_all, batch_save = train(model, train_loader, optimizer, batch_save)
326 |             embeds_all, index_all = embeds_all.detach(), index_all.detach()
327 |             # build graph
328 |             if epoch == args.random_epochs:
329 |                 map_index, images_embeds = get_embeds(embeds_all, index_all)
330 |                 graph = build_graph(images_embeds)
331 |         else:
332 |             t_generate_indices = time.time()
333 |             batch_num = len(train_data) // batch_size
334 |             cur_batch = 0
335 |             # restart_p
336 |             restart_p = args.restart_p - (epoch/5)*0.1
337 |             if restart_p <= 0.05:
338 |                 restart_p = 0.05
339 |
340 |             while len(set(walks)) < batch_num * batch_size:
341 |                 if cur_batch != 0 and cur_batch % args.update_batchs == 0:
342 |                     # update graph
343 |                     graph, images_embeds = update_graph(embeds_all, index_all, images_embeds)
344 |                 indices, walks = generate_indices(graph, walks)
345 |                 sampler = utils.Sampler(indices)
346 |                 train_loader = DataLoader(train_data,
347 |                                           batch_size=batch_size,
348 |                                           shuffle=False, num_workers=1, pin_memory=True,
349 |                                           sampler=sampler,
350 |                                           )
351 |                 train_loss, embeds_all, index_all, batch_save = train(model, train_loader, optimizer, batch_save)
352 |                 embeds_all, index_all = embeds_all.detach(), index_all.detach()
353 |                 cur_batch += len(train_loader)
354 |                 if batch_save > batch_nums_all:
355 |                     break
356 |
357 |             if batch_save > batch_nums_all:
358 |                 break
359 |             remain_images = list(set(map_index.keys()) - set(walks))
360 |             walks.extend(remain_images)
361 |             assert len(set(walks)) == len(train_data)
362 |             if len(remain_images) > 1:
363 |                 sampler = utils.Sampler(remain_images)
364 |                 train_loader = DataLoader(train_data,
365 |                                           batch_size=batch_size,
366 |                                           shuffle=False, num_workers=1, pin_memory=True,
367 |                                           sampler=sampler,
368 |                                           )
369 |                 train_loss, embeds_all, index_all, batch_save = train(model, train_loader, optimizer, batch_save)
370 |                 embeds_all, index_all = embeds_all.detach(), index_all.detach()
371 |             print("t_generate_indices=========", time.time() - t_generate_indices)
372 |
373 |
374 |         results['train_loss'].append(train_loss)
375 |         test_acc_1, test_acc_5 = test(model, memory_loader, test_loader)
376 |         results['test_acc@1'].append(test_acc_1)
377 |         results['test_acc@5'].append(test_acc_5)
378 |         # save statistics
379 |         data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
380 |         data_frame.to_csv(args.save_path+'/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
381 |         if test_acc_1 > best_acc:
382 |             best_acc = test_acc_1
383 |             torch.save(model.state_dict(), args.save_path+'/{}_model.pth'.format(save_name_pre))
384 |
385 |
386 |
387 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models.resnet import resnet50
5 |
6 |
7 | class Model(nn.Module):
8 |     def __init__(self, feature_dim=128):
9 |         super(Model, self).__init__()
10 |
11 |         self.f = []
12 |         for name, module in resnet50().named_children():
13 |             if name == 'conv1':
14 |                 module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
15 |             if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
16 |                 self.f.append(module)
17 |         # encoder
18 |         self.f = nn.Sequential(*self.f)
19 |         # projection head
20 |         self.g = nn.Sequential(nn.Linear(2048, 512, bias=False), nn.BatchNorm1d(512),
21 |                                nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))
22 |
23 |     def forward(self, x):
24 |         x = self.f(x)
25 |         feature = torch.flatten(x, start_dim=1)
26 |         out = self.g(feature)
27 |         return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
28 |
--------------------------------------------------------------------------------
/results/128_0.5_200_256_1000_statistics.csv:
--------------------------------------------------------------------------------
1 | epoch,train_loss,test_acc@1,test_acc@5
2 | 1,5.293757577696301,42.225,93.675
3 | 2,4.826159853684275,58.6875,97.2625
4 | 3,4.771193599700927,61.762499999999996,97.6625
5 | 4,4.749537806913077,65.1125,98.0125
6 | 5,4.731312639183468,66.0875,98.0875
7 | 6,4.709558345833603,67.83749999999999,98.2625
8 | 7,4.7036038080851235,68.7,98.41250000000001
9 | 8,4.682034615107945,68.77499999999999,98.5
10 | 9,4.678532238006592,69.4625,98.41250000000001
11 | 10,4.673528847694397,70.46249999999999,98.5375
12 | 11,4.6616191946226975,71.0125,98.6625
13 | 12,4.6555983283303,71.45,98.7625
14 | 13,4.65258057364102,72.225,98.6875
15 | 14,4.643300998501662,73.825,98.9
16 | 15,4.64274603465818,71.75,98.7125
17 | 16,4.643919054667155,72.85000000000001,98.6
18 | 17,4.633812708932846,72.3625,98.6375
19 | 18,4.633907571752021,72.1375,98.8125
20 | 19,4.631996188365238,72.725,98.675
21 | 20,4.627749775940517,73.575,99.0125
22 | 21,4.623903426859114,73.7125,98.925
23 | 22,4.616939772730288,73.6125,98.875
24 | 23,4.617843597166,74.4,98.91250000000001
25 | 24,4.618095388412476,73.2,98.675
26 | 25,4.612618161671198,74.2625,98.88749999999999
27 | 26,4.610560584068298,74.6375,99.175
28 | 27,4.613450231223271,74.5125,98.85000000000001
29 | 28,4.616752988836739,74.9875,99.0125
30 | 29,4.607374534606934,74.5625,99.0
31 | 30,4.610040483474731,74.8,98.95
32 | 31,4.600977091789246,75.175,98.875
33 | 32,4.6121180125645225,75.1375,99.0375
34 | 33,4.6064458751678465,74.8375,98.7875
35 | 34,4.607028651237488,74.425,99.0375
36 | 35,4.605106512705485,75.775,99.125
37 | 36,4.600796276872808,75.6375,99.05000000000001
38 | 37,4.601950973272324,75.775,99.1375
39 | 38,4.600315594673157,75.5125,99.1
40 | 39,4.595778841972351,74.0125,99.075
41 | 40,4.592310833930969,75.0625,99.1375
42 | 41,4.599948711395264,75.35,99.0625
43 | 42,4.59720899105072,75.3375,98.97500000000001
44 | 43,4.599029782452161,76.03750000000001,98.96249999999999
45 | 44,4.5941962683072655,76.71249999999999,99.1125
46 | 45,4.592749371528625,76.5625,99.1375
47 | 46,4.591093096733093,75.44999999999999,99.0125
48 | 47,4.593859170612536,75.8875,99.1625
49 | 48,4.590588365282331,76.05,99.225
50 | 49,4.59339542388916,75.775,99.15
51 | 50,4.593822264671326,77.325,99.2
52 | 51,4.589333243370056,76.47500000000001,99.1625
53 | 52,4.588132557868957,75.3,99.05000000000001
54 | 53,4.586618657112122,76.6,99.2625
55 | 54,4.590204153060913,76.11250000000001,99.25
56 | 55,4.592039896517384,77.03750000000001,99.1625
57 | 56,4.584720191955566,75.7625,99.1125
58 | 57,4.589341929084377,76.625,99.25
59 | 58,4.5870132189924995,78.025,99.3625
60 | 59,4.59016788482666,75.6375,99.1
61 | 60,4.583746252059936,77.275,99.2125
62 | 61,4.5849903781762285,77.4625,99.25
63 | 62,4.579810733795166,77.525,99.1875
64 | 63,4.589144492149353,77.1375,99.3
65 | 64,4.585684504508972,78.4875,99.225
66 | 65,4.588148999214172,77.325,99.1625
67 | 66,4.580712671081225,78.4875,99.275
68 | 67,4.585197448730469,77.225,99.2125
69 | 68,4.582306609153748,76.3875,99.125
70 | 69,4.580277142524719,77.7875,99.41250000000001
71 | 70,4.585395255088806,76.7875,99.1125
72 | 71,4.579945621036348,77.5625,99.375
73 | 72,4.579310126304627,77.03750000000001,99.3
74 | 73,4.5810620307922365,77.0,99.3375
75 | 74,4.579955153465271,76.9375,99.2375
76 | 75,4.5884348335912675,76.9,99.2625
77 | 76,4.578743467330932,78.325,99.3
78 | 77,4.5799387121200565,78.0625,99.2125
79 | 78,4.573168625831604,77.9625,99.3125
80 | 79,4.574429011344909,78.73750000000001,99.3
81 | 80,4.577625751495361,79.0875,99.3625
82 | 81,4.576289272308349,77.9,99.2875
83 | 82,4.578504064348009,77.7,99.4
84 | 83,4.570481319816745,77.47500000000001,99.3625
85 | 84,4.57708838780721,77.875,99.3125
86 | 85,4.578318653106689,78.73750000000001,99.3125
87 | 86,4.576002221107483,77.8375,99.3
88 | 87,4.572186132957196,77.3875,99.2
89 | 88,4.568873965221902,78.5,99.2875
90 | 89,4.572124977111816,79.0875,99.3
91 | 90,4.571539549827576,78.03750000000001,99.275
92 | 91,4.578847978022192,78.5125,99.3625
93 | 92,4.577983069419861,78.1625,99.2
94 | 93,4.5709609790724155,78.9875,99.25
95 | 94,4.572255902290344,77.8625,99.2875
96 | 95,4.57702686018863,77.225,99.275
97 | 96,4.574975714683533,78.8125,99.25
98 | 97,4.57355366230011,78.025,99.3
99 | 98,4.573449125289917,78.6875,99.2125
100 | 99,4.5740386009216305,77.325,99.2125
101 | 100,4.5707975243622405,76.775,99.2375
102 | 101,4.568804507162056,78.9625,99.2375
103 | 102,4.570234065055847,77.7,99.3
104 | 103,4.570616459846496,77.8625,99.2625
105 | 104,4.571278429031372,77.425,99.2
106 | 105,4.5763217780901035,78.125,99.2625
107 | 106,4.571589671648466,75.9875,99.3125
108 | 107,4.572205200195312,78.35,99.325
109 | 108,4.566824765205383,77.14999999999999,99.1625
110 | 109,4.5730495187971325,77.3625,99.3125
111 | 110,4.566690254211426,78.3625,99.3
112 | 111,4.570611038208008,79.03750000000001,99.35000000000001
113 | 112,4.56855793697078,77.85,99.2375
114 | 113,4.569841089702788,78.625,99.2875
115 | 114,4.565202702175487,78.5875,99.3125
116 | 115,4.570380935668945,77.725,99.2125
117 | 116,4.56500020980835,79.25,99.4
118 | 117,4.567118740081787,78.075,99.46249999999999
119 | 118,4.565334788235751,78.825,99.25
120 | 119,4.56559255917867,80.4375,99.4375
121 | 120,4.571868437987107,78.1875,99.35000000000001
122 | 121,4.568742669146994,78.625,99.3625
123 | 122,4.572119749509371,77.25,99.2875
124 | 123,4.566438215119498,77.375,99.3
125 | 124,4.57454890953867,78.3625,99.2125
126 | 125,4.568522095680237,78.4875,99.325
127 | 126,4.5665170844172085,79.0125,99.35000000000001
128 | 127,4.564128122329712,79.5125,99.425
129 | 128,4.568067855834961,79.475,99.3625
130 | 129,4.566403017044068,77.725,99.3
131 | 130,4.562279386520386,78.2875,99.45
132 | 131,4.567161998748779,78.60000000000001,99.2625
133 | 132,4.566066150007577,78.8,99.47500000000001
134 | 133,4.56695170879364,79.4875,99.38749999999999
135 | 134,4.56440342040289,78.66250000000001,99.375
136 | 135,4.567775155368604,79.21249999999999,99.4
137 | 136,4.564255863428116,79.07499999999999,99.3125
138 | 137,4.557527716566876,77.8875,99.325
139 | 138,4.565986478328705,78.525,99.375
140 | 139,4.562467209994793,79.6875,99.3375
141 | 140,4.5640857672389545,79.95,99.4375
142 | 141,4.565324122647205,79.5875,99.3
143 | 142,4.556167395218559,79.5,99.3625
144 | 143,4.565654691060384,79.80000000000001,99.41250000000001
145 | 144,4.563221107210432,78.1125,99.1875
146 | 145,4.562008252014985,79.625,99.3125
147 | 146,4.563423084762861,80.0875,99.3375
148 | 147,4.561200676737605,78.4625,99.275
149 | 148,4.566325235366821,77.4375,99.3
150 | 149,4.565679378509522,79.16250000000001,99.3125
151 | 150,4.561311265112648,79.9875,99.46249999999999
152 | 151,4.564657626152038,78.45,99.2875
153 | 152,4.56643837802815,78.3375,99.38749999999999
154 | 153,4.565955352783203,79.4125,99.3375
155 | 154,4.565937538536227,77.925,99.3625
156 | 155,4.561128410926232,79.7375,99.3375
157 | 156,4.5652302503585815,80.7625,99.46249999999999
158 | 157,4.56415134522973,79.23750000000001,99.41250000000001
159 | 158,4.561256321383194,78.6875,99.325
160 | 159,4.562075548905593,81.075,99.45
161 | 160,4.565318050838652,78.7625,99.3375
162 | 161,4.56672986092106,78.875,99.3375
163 | 162,4.559463489346388,78.66250000000001,99.325
164 | 163,4.557186906987971,78.9,99.35000000000001
165 | 164,4.56389398097992,79.27499999999999,99.4
166 | 165,4.560159504413605,78.825,99.325
167 |
--------------------------------------------------------------------------------
/results/linear_statistics.csv:
--------------------------------------------------------------------------------
1 | epoch,train_loss,train_acc@1,train_acc@5,test_loss,test_acc@1,test_acc@5
2 | 1,2.23867652015686,47.32,85.56,2.2394468173980715,62.4875,98.1625
3 | 2,2.0928818286895754,76.32,98.52,2.175532444000244,75.75,99.1625
4 | 3,1.955794093132019,78.7,99.1,2.1144030265808107,78.95,99.2875
5 | 4,1.8307094875335694,80.2,99.14,2.0573978023529054,79.7375,99.325
6 | 5,1.712966128540039,80.74,99.16,2.0034180030822752,79.475,99.325
7 | 6,1.6091594816207886,80.82000000000001,99.18,1.9523463973999022,79.57499999999999,99.3
8 | 7,1.5115610223770142,80.94,99.26,1.9046716756820679,79.3875,99.2875
9 | 8,1.4153708534240723,81.26,99.32,1.8572627630233765,79.9375,99.2625
10 | 9,1.336933927154541,81.04,99.26,1.815097327232361,79.7,99.225
11 | 10,1.2672981122970581,81.69999999999999,99.18,1.7770221939086914,79.5875,99.175
12 | 11,1.2031101558685302,82.44,99.38,1.7408404541015625,79.675,99.2
13 | 12,1.1470788984298705,81.39999999999999,99.24,1.706960210800171,79.95,99.2
14 | 13,1.0925416803359986,82.24000000000001,99.32,1.6728482570648193,80.16250000000001,99.2
15 | 14,1.04535007686615,81.82000000000001,99.14,1.6424105081558227,80.22500000000001,99.2125
16 | 15,1.0020886152267456,82.44,99.36,1.6137052478790284,80.30000000000001,99.2
17 | 16,0.9616574898719787,82.58,99.48,1.5877059564590454,80.2875,99.2125
18 | 17,0.9207497631072998,82.72,99.36,1.5602429485321045,80.45,99.2125
19 | 18,0.8999906936645508,82.36,99.3,1.5393530073165893,80.4125,99.2
20 | 19,0.8730637706756592,83.0,99.32,1.5179570751190186,80.425,99.25
21 | 20,0.8397495670318603,82.98,99.24,1.4954745197296142,80.525,99.225
22 | 21,0.8153184818267822,83.44,99.2,1.475080030441284,80.60000000000001,99.2625
23 | 22,0.7863519223213196,83.34,99.42,1.45409801197052,80.875,99.325
24 | 23,0.7656220963478089,84.28,99.38,1.4374995250701905,80.9375,99.3375
25 | 24,0.745424748802185,84.2,99.42,1.4190076990127563,81.16250000000001,99.35000000000001
26 | 25,0.7366760729789734,83.52000000000001,99.33999999999999,1.4025026140213013,81.0875,99.3625
27 | 26,0.7170994510650635,83.39999999999999,99.24,1.3857117347717285,81.2625,99.3625
28 | 27,0.6940866159439087,84.7,99.5,1.370622797012329,81.2875,99.375
29 | 28,0.6919629066467285,84.06,99.28,1.357133222579956,81.6,99.38749999999999
30 | 29,0.6734722037315368,84.64,99.28,1.3448712224960326,81.3125,99.4
31 | 30,0.6642050458908081,84.17999999999999,99.28,1.332264588356018,81.5125,99.41250000000001
32 | 31,0.6622811601638794,83.52000000000001,99.36,1.3229667139053345,81.6,99.4
33 | 32,0.6436707684516907,84.2,99.38,1.307395610809326,81.75,99.425
34 | 33,0.6284469531059265,84.3,99.56,1.2958546113967895,81.7625,99.38749999999999
35 | 34,0.625880845451355,84.34,99.46000000000001,1.2874197111129762,81.625,99.4
36 | 35,0.6101801568031311,85.04,99.44,1.2744026861190796,81.875,99.375
37 | 36,0.594359105682373,85.02,99.53999999999999,1.2607554054260255,82.0375,99.38749999999999
38 | 37,0.5938352709770203,84.66,99.52,1.2540022830963136,82.0625,99.4
39 | 38,0.5917745595932007,84.86,99.4,1.246719612121582,82.025,99.38749999999999
40 | 39,0.5821160996437073,84.84,99.56,1.2324062976837158,81.9375,99.41250000000001
41 | 40,0.5748880784034729,84.82,99.32,1.222454221725464,82.175,99.38749999999999
42 | 41,0.5710795659065246,84.64,99.53999999999999,1.2201661710739136,82.15,99.3625
43 | 42,0.567002702999115,84.66,99.42,1.2068362731933593,82.5375,99.38749999999999
44 | 43,0.5608441215515136,84.78,99.32,1.201515947341919,82.375,99.3625
45 | 44,0.5359812728881836,85.82,99.66000000000001,1.1899815711975097,82.4125,99.375
46 | 45,0.5436428760528564,85.68,99.44,1.1851844396591187,82.4125,99.4
47 | 46,0.5403897270202637,85.08,99.3,1.1780760755538942,82.575,99.375
48 | 47,0.5392217435359955,84.86,99.4,1.1705868644714355,82.5,99.38749999999999
49 | 48,0.5359720094203949,84.58,99.58,1.1634352951049804,82.5625,99.41250000000001
50 | 49,0.5211987493515015,85.61999999999999,99.6,1.1571491575241089,82.575,99.4
51 | 50,0.5280253508090973,85.0,99.38,1.1525636119842528,82.78750000000001,99.425
52 | 51,0.5165058417320252,85.56,99.6,1.1425102405548095,82.7375,99.4375
53 | 52,0.5036860698699951,85.61999999999999,99.5,1.1345144500732423,82.8875,99.45
54 | 53,0.4980165102958679,85.74000000000001,99.46000000000001,1.1247358837127686,82.8125,99.45
55 | 54,0.503675221157074,85.11999999999999,99.56,1.122702067375183,82.825,99.425
56 | 55,0.5034103675842285,85.39999999999999,99.44,1.1186348485946656,82.8625,99.4375
57 | 56,0.5004815286159515,85.24000000000001,99.44,1.1121335020065308,82.9875,99.41250000000001
58 | 57,0.49233290033340454,85.64,99.52,1.106219783782959,83.05,99.41250000000001
59 | 58,0.4831447655677795,86.33999999999999,99.52,1.0997568349838256,82.9375,99.45
60 | 59,0.4949026760101318,85.82,99.6,1.1006921415328978,83.0,99.425
61 | 60,0.4855158269882202,86.0,99.58,1.0893584241867065,83.1125,99.4375
62 | 61,0.47568662967681885,86.58,99.5,1.0839851894378663,83.175,99.4375
63 | 62,0.4676999207496643,86.04,99.68,1.0798140392303466,83.0625,99.45
64 | 63,0.4779873286247253,85.74000000000001,99.66000000000001,1.0776025590896607,82.9375,99.45
65 | 64,0.4662654571056366,86.26,99.62,1.0716015605926514,83.125,99.4375
66 | 65,0.46651281657218935,86.14,99.62,1.0696470012664796,83.1625,99.46249999999999
67 | 66,0.4583287220001221,86.04,99.66000000000001,1.0656101865768433,83.22500000000001,99.45
68 | 67,0.4584460023403168,86.6,99.44,1.0567087631225587,83.25,99.45
69 | 68,0.47010446615219115,85.8,99.5,1.0547093677520751,83.26249999999999,99.47500000000001
70 | 69,0.4553971031188965,86.58,99.53999999999999,1.052298864364624,83.25,99.46249999999999
71 | 70,0.45542179169654845,86.02,99.52,1.042236050605774,83.6125,99.47500000000001
72 | 71,0.4689282648563385,85.98,99.48,1.045232045173645,83.42500000000001,99.425
73 | 72,0.44606041898727417,86.5,99.62,1.0355810298919679,83.42500000000001,99.4375
74 | 73,0.4590295832157135,85.48,99.53999999999999,1.0351566762924194,83.4375,99.45
75 | 74,0.44059539999961855,86.46000000000001,99.68,1.0291480960845947,83.45,99.4375
76 | 75,0.44589597697257993,85.96000000000001,99.76,1.0236962718963623,83.42500000000001,99.4375
77 | 76,0.4569112882614136,85.39999999999999,99.52,1.0250700330734253,83.7,99.47500000000001
78 | 77,0.44317653717994687,86.9,99.64,1.0191460037231446,83.76249999999999,99.47500000000001
79 | 78,0.4444390136241913,86.46000000000001,99.68,1.0160138721466065,83.825,99.45
80 | 79,0.43353584990501404,86.61999999999999,99.33999999999999,1.0110815677642822,83.75,99.47500000000001
81 | 80,0.43107402420043944,86.44,99.76,1.0067088937759399,83.875,99.5125
82 | 81,0.44140214128494265,86.0,99.46000000000001,1.0060942106246948,83.8,99.5125
83 | 82,0.42756440744400026,86.58,99.62,0.9985377082824707,83.92500000000001,99.5125
84 | 83,0.42294873113632203,87.12,99.53999999999999,1.0020516023635864,83.78750000000001,99.47500000000001
85 | 84,0.42216853313446046,86.98,99.52,0.9964222612380982,84.075,99.4875
86 | 85,0.4254521691799164,86.92,99.66000000000001,0.9886314897537232,84.1125,99.4875
87 | 86,0.4218989317417145,86.98,99.64,0.9874826307296753,84.1,99.4875
88 | 87,0.42631903719902037,86.6,99.53999999999999,0.9843551521301269,84.15,99.46249999999999
89 | 88,0.42168343453407287,86.72,99.68,0.981222113609314,83.975,99.5
90 | 89,0.4164175959587097,87.12,99.58,0.975565571308136,84.1625,99.5
91 | 90,0.41582111711502073,86.78,99.78,0.9738754243850708,84.2125,99.47500000000001
92 | 91,0.41423747925758364,86.68,99.68,0.9700007371902466,84.3375,99.4875
93 | 92,0.41603217372894286,86.94,99.68,0.9711467900276184,84.25,99.4875
94 | 93,0.4159442962169647,86.68,99.64,0.9729300622940064,84.25,99.4875
95 | 94,0.3976638189315796,87.64,99.72,0.9641343412399292,84.3125,99.4875
96 | 95,0.40926754364967344,87.0,99.58,0.9621022901535035,84.1375,99.45
97 | 96,0.40822780995368957,87.4,99.58,0.9597155032157898,84.275,99.4375
98 | 97,0.41371343169212343,86.58,99.56,0.9554791607856751,84.2,99.45
99 | 98,0.41597759351730346,86.58,99.6,0.9617002787590027,84.1375,99.4375
100 | 99,0.4186860776424408,86.64,99.58,0.9551908307075501,84.26249999999999,99.45
101 | 100,0.41007653741836547,87.36,99.64,0.9536493349075318,84.375,99.46249999999999
102 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python main.py --dataset stl10 --estimator easy --sampling RWR --random_epochs 1 --edge_nums 100 --restart_p 0.2 --batch_size 256 --epochs 1000 --update_batchs 100 --save_path test_stl10
2 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from torchvision import transforms
3 | import numpy as np
4 | from torchvision.datasets import STL10
5 |
6 |
7 | class Sampler(object):
8 |     def __init__(self, indices):
9 |         self.indices = indices
10 |
11 |     def __iter__(self):
12 |         return iter(self.indices)
13 |
14 |     def __len__(self):
15 |         return len(self.indices)
16 |
17 | class STL10Pair(STL10):
18 |     def __getitem__(self, index):
19 |         img, target = self.data[index], self.labels[index]
20 |         img = Image.fromarray(np.transpose(img, (1, 2, 0)))
21 |
22 |         # Build two views of the image; without a transform, both are the raw image.
23 |         transform = self.transform if self.transform is not None else (lambda x: x)
24 |         pos_1, pos_2 = transform(img), transform(img)
25 |
26 |         return index, pos_1, pos_2, target
27 |
28 |
29 | train_transform_stl10 = transforms.Compose([
30 |     transforms.RandomResizedCrop(32),
31 |     transforms.RandomHorizontalFlip(p=0.5),
32 |     transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
33 |     transforms.RandomGrayscale(p=0.2),
34 |     transforms.ToTensor(),
35 |     transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
36 |
37 | test_transform_stl10 = transforms.Compose([
38 |     transforms.ToTensor(),
39 |     transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
40 |
41 |
--------------------------------------------------------------------------------
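
The following is a minimal sketch, not code from this repository, of how the two classes in `utils.py` might fit together. `Sampler` simply iterates over whatever it is given, so if it holds a list of precomputed index batches, it can play the role that PyTorch's `DataLoader` expects from its `batch_sampler` argument. Whether `main.py` wires it up this way is an assumption. To stay runnable without PyTorch or a dataset download, the sketch uses only the standard library: `ToyPair`, `jitter`, `iterate_batches`, and the example `batches` are hypothetical stand-ins, and `iterate_batches` only approximates what a `DataLoader` does with a batch sampler.

```python
import random


class Sampler(object):
    """Same wrapper as in utils.py: iterates over whatever `indices` holds."""

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class ToyPair(object):
    """Hypothetical stand-in for STL10Pair that needs no download."""

    def __init__(self, data, labels, transform):
        self.data, self.labels, self.transform = data, labels, transform

    def __getitem__(self, index):
        x, target = self.data[index], self.labels[index]
        # Two independent augmentations of the same sample form a positive pair.
        return index, self.transform(x), self.transform(x), target

    def __len__(self):
        return len(self.data)


_rng = random.Random(0)  # fixed seed so the toy augmentation is reproducible


def jitter(x):
    """Toy augmentation: add small random noise to a scalar 'image'."""
    return x + _rng.uniform(-0.1, 0.1)


def iterate_batches(dataset, batch_sampler):
    """Fetch every index listed in each batch and group the fields column-wise."""
    for batch_indices in batch_sampler:
        samples = [dataset[i] for i in batch_indices]
        yield tuple(list(column) for column in zip(*samples))


dataset = ToyPair(
    data=[float(i) for i in range(8)],
    labels=[i % 2 for i in range(8)],
    transform=jitter,
)
# Hypothetical precomputed mini-batches (lists of dataset indices).
batches = [[0, 3, 5], [1, 2, 7], [4, 6]]
sampler = Sampler(batches)

collected = list(iterate_batches(dataset, sampler))
for idx, pos_1, pos_2, target in collected:
    print(idx, target)
```

Returning the sample index alongside the two views, as `STL10Pair` does, lets the training loop map each batch element back to its position in the dataset.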