├── README.md ├── SimCLR ├── balanced_cluster.py ├── kmeans_gpu.py ├── linear_classify.py ├── loss.py ├── main.py ├── resnet.py ├── sample.py ├── unbalance.py └── utils.py ├── image └── SBCL.jpg ├── moco ├── BaseDareLoader.py ├── balanced_cluster.py ├── imagenet_lt │ ├── test.txt │ └── train.txt ├── imagenet_lt_loader.py ├── kmean_gpu.py ├── linear_classify.py ├── loss.py ├── main.py ├── moco.py ├── sample.py └── utils.py └── sbcl.jpg /README.md: -------------------------------------------------------------------------------- 1 | # [ICCV 2023] Subclass-balancing contrastive learning for long-tailed recognition 2 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2306.15925) 3 | 4 | This repository provides the code for the paper:
5 | **Subclass-balancing Contrastive Learning for Long-tailed Recognition** 6 |

7 |
8 | 9 | 10 | ## Overview 11 | In this paper, we propose subclass-balancing contrastive learning (SBCL), 12 | a novel supervised contrastive learning method defined on subclasses, which are the clusters within each 13 | head class, have a size comparable to that of tail classes, and are adaptively updated during training. 14 | Instead of sacrificing instance-balance for class-balance, our method achieves both instance- and 15 | subclass-balance by exploring the head-class structure in the learned representation space of the 16 | model-in-training. In particular, we propose a bi-granularity contrastive loss that enforces a sample 17 | (1) to be closer to samples from the same subclass than all the other samples; and (2) to be closer to 18 | samples from a different subclass but the same class than samples from any other subclasses. While 19 | the former learns representations with balanced and compact subclasses, the latter preserves the class 20 | structure on subclass level by encouraging the same class’s subclasses to be closer to each other than 21 | to any different class’s subclasses. Hence, it can learn an accurate classifier distinguishing original 22 | classes while enjoying both instance- and subclass-balance. 23 | ## Requirements 24 | * ImageNet dataset 25 | * Python ≥ 3.6 26 | * PyTorch ≥ 1.4 27 | * scikit-learn 28 | ## CIFAR dataset 29 | The code downloads the CIFAR dataset automatically. To switch between CIFAR variants, change only `--dataset` and `--imb_factor`. 30 | ### First-stage train 31 | To perform SBCL on a 2-GPU machine, run: 32 |

python SimCLR/main.py \ 
33 |   --dataset 'cifar100' \ 
34 |   --imb_factor 0.01 \
35 |   --lr 0.5\
36 |   --batch_size 1024 \
37 |   --temperature 0.1 
38 | 
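For orientation, `--imb_factor` sets the ratio between the smallest and the largest class. The sketch below assumes the exponential profile commonly used for long-tailed CIFAR (the `exp` default of `--imb_type`); the helper name `long_tailed_counts` is hypothetical, and the authoritative rule is whatever `SimCLR/unbalance.py` implements, which may differ in detail.

```python
def long_tailed_counts(num_classes, max_per_class, imb_factor):
    """Per-class sample counts under an assumed exponential imbalance profile."""
    counts = []
    for cls_idx in range(num_classes):
        # Class 0 keeps max_per_class samples; the last class keeps
        # roughly max_per_class * imb_factor samples.
        frac = cls_idx / (num_classes - 1.0)
        counts.append(int(max_per_class * imb_factor ** frac))
    return counts


# CIFAR-100 has 500 training images per class before subsampling.
counts = long_tailed_counts(num_classes=100, max_per_class=500, imb_factor=0.01)
print(counts[0], counts[-1])
```

Under this assumption, `--imb_factor 0.01` leaves the rarest class with about 1% of the images of the most frequent class.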
39 | NOTE: 40 | For the CIFAR-10 datasets, we should increase the steps for updating the clusters. 41 | 42 | ### Second-stage train 43 | To evaluate the learned representation, our code supports [Classifier-Balancing](https://arxiv.org/abs/1910.09217) and [LDAM](https://arxiv.org/abs/1906.07413) to train the classifier. 44 | We report the accuracy of LDAM in the paper. 45 | #### LDAM training 46 |
python SimCLR/linear_classify.py  \
47 |   --dataset 'cifar100' \ 
48 |   --imb_factor 0.01 \
49 |   --train_rule 'DRW' \
50 |   --epochs 200 
51 | 
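With `--train_rule 'DRW'`, `SimCLR/linear_classify.py` passes per-class weights to the LDAM loss. The weights are derived from the effective number of samples with `beta = 0.9999` and normalized to sum to the number of classes. A minimal NumPy sketch of that computation follows; the function name `drw_class_weights` and the toy class counts are illustrative only.

```python
import numpy as np


def drw_class_weights(cls_num_list, beta=0.9999):
    """Class-balanced weights based on the effective number of samples."""
    # Effective number of samples per class: 1 - beta^n.
    effective_num = 1.0 - np.power(beta, cls_num_list)
    weights = (1.0 - beta) / effective_num
    # Rescale so the weights sum to the number of classes.
    return weights / weights.sum() * len(cls_num_list)


weights = drw_class_weights([500, 50, 5])
print(weights)
```

Rare classes receive larger weights than frequent ones, which is the intended effect of re-weighting on a long-tailed training set.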
52 | #### Classifier-Balancing training 53 |
python SimCLR/linear_classify.py  \
54 |   --dataset 'cifar100' \ 
55 |   --imb_factor 0.01 \
56 |   --train_rule 'CB' \
57 |   --epochs 45 
58 | 
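For the `CB` rule, `adjust_learning_rate` in `SimCLR/linear_classify.py` warms the learning rate up linearly over the first 5 epochs and then applies cosine decay to zero. The sketch below mirrors that schedule as a standalone function; the name `cb_learning_rate` and the `warmup_epochs` parameter are introduced here for illustration, and `math.pi` stands in for the literal constant used in the source.

```python
import math


def cb_learning_rate(epoch, base_lr, total_epochs, warmup_epochs=5):
    """Linear warmup followed by cosine decay, with epoch counted from 0."""
    epoch = epoch + 1  # the source shifts to 1-based epochs first
    if epoch <= warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1 + math.cos(progress * math.pi))


schedule = [cb_learning_rate(e, base_lr=0.1, total_epochs=45) for e in range(45)]
print(schedule[0], schedule[4], schedule[-1])
```

The rate peaks at `base_lr` at the end of warmup and reaches zero in the final epoch, so `--epochs` determines how fast the decay is.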
59 | 60 | ## ImageNet-LT dataset 61 | Download the [ImageNet-LT](http://image-net.org/download) dataset manually and place it in your `data_root`. The long-tailed version is created from the split files (`train.txt` and `test.txt`) under `moco/imagenet_lt`. 62 | Change `data_root` and `save_folder` in [`moco/sbcl.py`](moco/sbcl.py) and [`moco/linear_classify.py`](moco/linear_classify.py) accordingly for ImageNet-LT. 63 | ### First-stage train 64 | To perform SBCL on an 8-GPU machine, run: 65 |
python moco/main.py \ 
66 |   -a resnet50 \ 
67 |   --lr 0.1 \
68 |   --batch_size 256 \
69 |   --temperature 0.07\
70 |   --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 
71 | 
72 | 73 | ### Second-stage train 74 | To evaluate the learned representation, run: 75 |
python moco/linear_classify.py \
76 |   --pretrained [your pretrained model] \
77 |   -a resnet50 \ 
78 |   --lr 10 \
79 |   --batch_size 2048 \
80 |   --train_rule 'CB'\
81 |   --epochs 40 --schedule 20 30 --seed 0
82 | 
83 | NOTE: 84 | In this code, we also can use [LDAM](https://arxiv.org/abs/1906.07413) loss to train the linear classifier on top of the representation. Many/medium/minor classes accuracy could change significantly with different learning rate or batch size in the second stage while overall accuracy remains the same. 85 | 86 | ## Citation 87 | If you find SBCL helpful in your research, please consider citing: 88 | ```bibtex 89 | @article{hou2023subclass, 90 | title={Subclass-balancing contrastive learning for long-tailed recognition}, 91 | author={Hou, Chengkai and Zhang, Jieyu and Wang, Haonan and Zhou, Tianyi}, 92 | journal={arXiv preprint arXiv:2306.15925}, 93 | year={2023} 94 | } 95 | ``` 96 | ## Acknowledgement 97 | This code inherits some codes from [MoCo](https://github.com/facebookresearch/moco), [Classifier-Balancing](https://github.com/facebookresearch/classifier-balancing) and [LDAM](https://arxiv.org/abs/1906.07413). 98 | -------------------------------------------------------------------------------- /SimCLR/balanced_cluster.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | from tqdm import trange, tqdm 7 | 8 | 9 | def balanced_kmean( 10 | X, 11 | n_clusters, 12 | init='k-means++', 13 | device=torch.device('cpu'), 14 | tol=1e-4, 15 | iol=100, 16 | distance='cosine' 17 | 18 | ): 19 | ''' 20 | X: the clustered feature 21 | n_clusters: the cluster number 22 | ''' 23 | # convert to float 24 | X = X.float() 25 | 26 | # transfer to device 27 | X = X.to(device) 28 | 29 | if distance == 'euclidean': 30 | pairwise_similarity_function = partial(pairwise_euclidean, device=device) 31 | elif distance == 'cosine': 32 | pairwise_similarity_function = partial(pairwise_cosine, device=device) 33 | else: 34 | raise NotImplementedError 35 | # initialize 36 | if init == 'random': 37 | centroids = initialize(X, n_clusters) 38 | elif init 
== 'k-means++': 39 | centroids, _ = _kmeans_plusplus(X, 40 | n_clusters, 41 | random_state=0, 42 | pairwise_similarity=pairwise_similarity_function, 43 | n_local_trials=None) 44 | 45 | 46 | else: 47 | raise NotImplementedError 48 | 49 | N = len(X) 50 | n_per_cluster = N // n_clusters 51 | n_left = N % n_clusters 52 | for i in trange(iol): 53 | similarity_matrix = pairwise_similarity_function(centroids, X) 54 | similarity_matrix = similarity_matrix / similarity_matrix.sum(dim=1, keepdim=True) 55 | cluster_assignment = torch.zeros(N, dtype=torch.long) - 1 56 | cluster_size = {c: 0 for c in range(n_clusters)} 57 | 58 | idx = torch.argsort(similarity_matrix.flatten(), descending=True) 59 | #print(idx) 60 | 61 | if n_left == 0: 62 | for labels in idx: 63 | labels = labels.item() 64 | des = labels % N 65 | label = labels // N 66 | if cluster_assignment[des] == -1 and cluster_size[label] < n_per_cluster: 67 | cluster_assignment[des] = label 68 | cluster_size[label] += 1 69 | else: 70 | for labels in idx: 71 | labels = labels.item() 72 | des = labels % N 73 | label = labels // N 74 | if cluster_assignment[des] == -1 and cluster_size[label] < n_per_cluster: 75 | cluster_assignment[des] = label 76 | cluster_size[label] += 1 77 | similarity_matrix[:, des] = -100 78 | for _ in range(n_left): 79 | labels = torch.argmax(similarity_matrix).item() 80 | des = labels % N 81 | label = labels // N 82 | cluster_assignment[des] = label 83 | similarity_matrix[:, des] = -100 84 | cluster_size[label] += 1 85 | if cluster_size[label] >= n_per_cluster + 1: 86 | similarity_matrix[label, :] = -100 87 | 88 | assert torch.all(cluster_assignment != -1) 89 | 90 | last_centroids = centroids.clone() 91 | for index in range(n_clusters): 92 | centroids[index] = X[cluster_assignment == index].mean(dim=0) 93 | 94 | center_shift = torch.sum( 95 | torch.sqrt( 96 | torch.sum((centroids - last_centroids) ** 2, dim=1) 97 | )) 98 | 99 | # update tqdm meter 100 | if center_shift ** 2 < tol: 101 | break 102 | 
103 | return cluster_assignment.cpu(), centroids.cpu() 104 | 105 | 106 | # def balanced_kmeans1( 107 | # X, 108 | # n_clusters, 109 | # init='k-means++', 110 | # device=torch.device('cpu'), 111 | # tol=1e-4, 112 | # iol=100, 113 | # distance='cosine' 114 | # 115 | # ): 116 | # ''' 117 | # X: the clustered feature 118 | # n_clusters: the cluster number 119 | # ''' 120 | # # convert to float 121 | # X = X.float() 122 | # 123 | # # transfer to device 124 | # X = X.to(device) 125 | # 126 | # if distance == 'euclidean': 127 | # pairwise_similarity_function = partial(pairwise_euclidean, device=device) 128 | # elif distance == 'cosine': 129 | # pairwise_similarity_function = partial(pairwise_cosine, device=device) 130 | # else: 131 | # raise NotImplementedError 132 | # # initialize 133 | # if init == 'random': 134 | # centroids = initialize(X, n_clusters) 135 | # elif init == 'k-means++': 136 | # centroids, _ = _kmeans_plusplus(X, 137 | # n_clusters, 138 | # random_state=0, 139 | # pairwise_distance=pairwise_similarity_function, 140 | # n_local_trials=None) 141 | # 142 | # 143 | # else: 144 | # raise NotImplementedError 145 | # 146 | # # centroids = KMeans(n_clusters=n_clusters)._init_centroids(X.cpu().numpy(), x_squared_norms=None, 147 | # # init=init, random_state=np.random.RandomState(seed=0)) 148 | # # centroids = torch.from_numpy(centroids).to(device) 149 | # 150 | # N = len(X) 151 | # n_per_cluster = N // n_clusters 152 | # n_left = N % n_clusters 153 | # for i in trange(iol): 154 | # similarity_matrix = pairwise_similarity_function(centroids, X) 155 | # similarity_matrix = similarity_matrix / similarity_matrix.sum(dim=1, keepdim=True) 156 | # cluster_assignment = torch.zeros(N) - 1 157 | # cluster_size = {c: 0 for c in range(n_clusters)} 158 | # 159 | # if n_left == 0: 160 | # for _ in range(len(X)): 161 | # labels = torch.argmax(similarity_matrix).item() 162 | # label = labels // len(X) 163 | # des = labels % len(X) 164 | # cluster_assignment[des] = label 165 | # 
similarity_matrix[:, des] = -100 166 | # cluster_size[label] += 1 167 | # if cluster_size[label] >= n_per_cluster: 168 | # similarity_matrix[label, :] = -100 169 | # else: 170 | # similarity_matrix_clone = similarity_matrix.clone() 171 | # for _ in range(n_per_cluster * n_clusters): 172 | # labels = torch.argmax(similarity_matrix).item() 173 | # label = labels // len(X) 174 | # des = labels % len(X) 175 | # cluster_assignment[des] = label 176 | # similarity_matrix[:, des] = -100 177 | # similarity_matrix_clone[:, des] = -100 178 | # cluster_size[label] += 1 179 | # if cluster_size[label] >= n_per_cluster: 180 | # similarity_matrix[label, :] = -100 181 | # for _ in range(n_left): 182 | # labels = torch.argmax(similarity_matrix_clone).item() 183 | # label = labels // len(X) 184 | # des = labels % len(X) 185 | # cluster_assignment[des] = label 186 | # similarity_matrix_clone[:, des] = -100 187 | # cluster_size[label] += 1 188 | # if cluster_size[label] >= n_per_cluster + 1: 189 | # similarity_matrix_clone[label, :] = -100 190 | # 191 | # last_centroids = centroids.clone() 192 | # for index in range(n_clusters): 193 | # centroids[index] = X[cluster_assignment == index].mean(dim=0) 194 | # 195 | # center_shift = torch.sum( 196 | # torch.sqrt( 197 | # torch.sum((centroids - last_centroids) ** 2, dim=1) 198 | # )) 199 | # 200 | # # update tqdm meter 201 | # if center_shift ** 2 < tol: 202 | # break 203 | # 204 | # return cluster_assignment.cpu().numpy(), centroids.cpu().numpy() 205 | 206 | 207 | def pairwise_cosine(data1, data2, device=torch.device('cpu')): 208 | data1, data2 = data1.to(device), data2.to(device) 209 | A = data1.unsqueeze(dim=1) 210 | B = data2.unsqueeze(dim=0) 211 | A_normalized = A / A.norm(dim=-1, keepdim=True) 212 | B_normalized = B / B.norm(dim=-1, keepdim=True) 213 | cosine = A_normalized * B_normalized 214 | cosine_dis = cosine.sum(dim=-1).squeeze() 215 | return cosine_dis 216 | 217 | 218 | def pairwise_euclidean(data1, data2, 
device=torch.device('cpu')): 219 | data1, data2 = data1.to(device), data2.to(device) 220 | A = data1.unsqueeze(dim=1) 221 | B = data2.unsqueeze(dim=0) 222 | dis = (A - B) ** 2.0 223 | dis = 1 / (dis.sum(dim=-1).squeeze() + 1e-4) 224 | return dis 225 | 226 | 227 | def initialize(X, num_clusters): 228 | """ 229 | initialize cluster centers 230 | """ 231 | num_samples = X.shape[1] 232 | bs = X.shape[0] 233 | 234 | indices = torch.empty(X.shape[:-1], device=X.device, dtype=torch.long) 235 | for i in range(bs): 236 | indices[i] = torch.randperm(num_samples, device=X.device) 237 | initial_state = torch.gather(X, 1, indices.unsqueeze(-1).repeat(1, 1, X.shape[-1])).reshape(bs, num_clusters, -1, X.shape[-1]).mean(dim=-2) 238 | return initial_state 239 | 240 | 241 | def stable_cumsum(arr, dim=None, rtol=1e-05, atol=1e-08): 242 | """Use high precision for cumsum and check that final value matches sum. 243 | """ 244 | if dim is None: 245 | arr = arr.flatten() 246 | dim = 0 247 | out = torch.cumsum(arr, dim=dim, dtype=torch.float64) 248 | expected = torch.sum(arr, dim=dim, dtype=torch.float64) 249 | if not torch.all(torch.isclose(out.take(torch.Tensor([-1]).long().to(arr.device)), 250 | expected, rtol=rtol, 251 | atol=atol, equal_nan=True)): 252 | warnings.warn('cumsum was found to be unstable: ' 253 | 'its last element does not correspond to sum', 254 | RuntimeWarning) 255 | return out 256 | 257 | 258 | def _kmeans_plusplus(X, 259 | n_clusters, 260 | random_state, 261 | pairwise_similarity, 262 | n_local_trials=None): 263 | """Computational component for initialization of n_clusters by 264 | k-means++. Prior validation of data is assumed. 
265 | """ 266 | n_samples, n_features = X.shape 267 | 268 | generator = torch.Generator(device=str(X.device)) 269 | generator.manual_seed(random_state) 270 | 271 | centers = torch.empty((n_clusters, n_features), dtype=X.dtype, device=X.device) 272 | 273 | # Set the number of local seeding trials if none is given 274 | if n_local_trials is None: 275 | # This is what Arthur/Vassilvitskii tried, but did not report 276 | # specific results for other than mentioning in the conclusion 277 | # that it helped. 278 | n_local_trials = 2 + int(np.log(n_clusters)) 279 | 280 | # Pick first center randomly and track index of point 281 | # center_id = random_state.randint(n_samples) 282 | center_id = torch.randint(n_samples, (1,), generator=generator, device=X.device) 283 | 284 | indices = torch.full((n_clusters,), -1, dtype=torch.int, device=X.device) 285 | centers[0] = X[center_id] 286 | indices[0] = center_id 287 | 288 | # Initialize list of closest distances and calculate current potential 289 | closest_dist_sq = 1/pairwise_similarity( 290 | centers[0, None], X) 291 | current_pot = closest_dist_sq.sum() 292 | 293 | # Pick the remaining n_clusters-1 points 294 | for c in range(1, n_clusters): 295 | # Choose center candidates by sampling with probability proportional 296 | # to the squared distance to the closest existing center 297 | # rand_vals = random_state.random_sample(n_local_trials) * current_pot 298 | rand_vals = torch.rand(n_local_trials, generator=generator, device=X.device) * current_pot 299 | 300 | candidate_ids = torch.searchsorted(stable_cumsum(closest_dist_sq), 301 | rand_vals) 302 | # XXX: numerical imprecision can result in a candidate_id out of range 303 | torch.clip(candidate_ids, None, closest_dist_sq.numel() - 1, 304 | out=candidate_ids) 305 | 306 | # Compute distances to center candidates 307 | distance_to_candidates = 1/pairwise_similarity( 308 | X[candidate_ids], X) 309 | 310 | # update closest distances squared and potential for each candidate 311 | 
torch.minimum(closest_dist_sq, distance_to_candidates, 312 | out=distance_to_candidates) 313 | candidates_pot = distance_to_candidates.sum(dim=1) 314 | 315 | # Decide which candidate is the best 316 | best_candidate = torch.argmin(candidates_pot) 317 | current_pot = candidates_pot[best_candidate] 318 | closest_dist_sq = distance_to_candidates[best_candidate] 319 | best_candidate = candidate_ids[best_candidate] 320 | 321 | # Permanently add best center candidate found in local tries 322 | centers[c] = X[best_candidate] 323 | indices[c] = best_candidate 324 | 325 | return centers, indices 326 | if __name__ == '__main__': 327 | X = torch.randn(6, 3) 328 | cluster_label,_ = balanced_kmean(X,n_clusters=3,init='k-means++',device=torch.device('cpu'),tol=1e-4,iol=100, 329 | distance='euclidean') 330 | print(cluster_label) 331 | -------------------------------------------------------------------------------- /SimCLR/kmeans_gpu.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | #from .soft_dtw_cuda import SoftDTW 8 | 9 | 10 | def initialize(X, num_clusters, seed): 11 | """ 12 | initialize cluster centers 13 | :param X: (torch.tensor) matrix 14 | :param num_clusters: (int) number of clusters 15 | :param seed: (int) seed for kmeans 16 | :return: (np.array) initial state 17 | """ 18 | num_samples = len(X) 19 | if seed == None: 20 | indices = np.random.choice(num_samples, num_clusters, replace=False) 21 | else: 22 | np.random.seed(seed) ; indices = np.random.choice(num_samples, num_clusters, replace=False) 23 | initial_state = X[indices] 24 | return initial_state 25 | 26 | 27 | def kmeans( 28 | X, 29 | num_clusters, 30 | distance='euclidean', 31 | cluster_centers=[], 32 | tol=1e-4, 33 | tqdm_flag=True, 34 | iter_limit=0, 35 | device=torch.device('cpu'), 36 | gamma_for_soft_dtw=0.001, 37 | seed=None, 38 | ): 39 | """ 40 | perform kmeans 
41 | :param X: (torch.tensor) matrix 42 | :param num_clusters: (int) number of clusters 43 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 44 | :param seed: (int) seed for kmeans 45 | :param tol: (float) threshold [default: 0.0001] 46 | :param device: (torch.device) device [default: cpu] 47 | :param tqdm_flag: Allows to turn logs on and off 48 | :param iter_limit: hard limit for max number of iterations 49 | :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0 50 | :return: (torch.tensor, torch.tensor) cluster ids, cluster centers 51 | """ 52 | if tqdm_flag: 53 | print(f'running k-means on {device}..') 54 | 55 | if distance == 'euclidean': 56 | pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag) 57 | elif distance == 'cosine': 58 | pairwise_distance_function = partial(pairwise_cosine, device=device) 59 | else: 60 | raise NotImplementedError 61 | 62 | # convert to float 63 | X = X.float() 64 | 65 | # transfer to device 66 | X = X.to(device) 67 | 68 | # initialize 69 | if type(cluster_centers) == list: # ToDo: make this less annoyingly weird 70 | initial_state = initialize(X, num_clusters, seed=seed) 71 | else: 72 | if tqdm_flag: 73 | print('resuming') 74 | # find data point closest to the initial cluster center 75 | initial_state = cluster_centers 76 | dis = pairwise_distance_function(X, initial_state) 77 | choice_points = torch.argmin(dis, dim=0) 78 | initial_state = X[choice_points] 79 | initial_state = initial_state.to(device) 80 | 81 | iteration = 0 82 | if tqdm_flag: 83 | tqdm_meter = tqdm(desc='[running kmeans]') 84 | while True: 85 | 86 | dis = pairwise_distance_function(X, initial_state) 87 | 88 | choice_cluster = torch.argmin(dis, dim=1) 89 | 90 | initial_state_pre = initial_state.clone() 91 | 92 | for index in range(num_clusters): 93 | selected = torch.nonzero(choice_cluster == index).squeeze().to(device) 94 | 95 | selected = torch.index_select(X, 0, selected) 96 
| 97 | # https://github.com/subhadarship/kmeans_pytorch/issues/16 98 | if selected.shape[0] == 0: 99 | selected = X[torch.randint(len(X), (1,))] 100 | 101 | initial_state[index] = selected.mean(dim=0) 102 | 103 | center_shift = torch.sum( 104 | torch.sqrt( 105 | torch.sum((initial_state - initial_state_pre) ** 2, dim=1) 106 | )) 107 | 108 | # increment iteration 109 | iteration = iteration + 1 110 | 111 | # update tqdm meter 112 | if tqdm_flag: 113 | tqdm_meter.set_postfix( 114 | iteration=f'{iteration}', 115 | center_shift=f'{center_shift ** 2:0.6f}', 116 | tol=f'{tol:0.6f}' 117 | ) 118 | tqdm_meter.update() 119 | if center_shift ** 2 < tol: 120 | break 121 | if iter_limit != 0 and iteration >= iter_limit: 122 | break 123 | 124 | return choice_cluster.cpu(), initial_state.cpu() 125 | 126 | 127 | def kmeans_predict( 128 | X, 129 | cluster_centers, 130 | distance='euclidean', 131 | device=torch.device('cpu'), 132 | gamma_for_soft_dtw=0.001, 133 | tqdm_flag=True 134 | ): 135 | """ 136 | predict using cluster centers 137 | :param X: (torch.tensor) matrix 138 | :param cluster_centers: (torch.tensor) cluster centers 139 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 140 | :param device: (torch.device) device [default: 'cpu'] 141 | :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0 142 | :return: (torch.tensor) cluster ids 143 | """ 144 | if tqdm_flag: 145 | print(f'predicting on {device}..') 146 | 147 | if distance == 'euclidean': 148 | pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag) 149 | elif distance == 'cosine': 150 | pairwise_distance_function = partial(pairwise_cosine, device=device) 151 | else: 152 | raise NotImplementedError 153 | 154 | # convert to float 155 | X = X.float() 156 | 157 | # transfer to device 158 | X = X.to(device) 159 | 160 | dis = pairwise_distance_function(X, cluster_centers) 161 | choice_cluster = torch.argmin(dis, dim=1) 162 | 163 | return 
choice_cluster.cpu() 164 | 165 | 166 | def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True): 167 | if tqdm_flag: 168 | print(f'device is :{device}') 169 | 170 | # transfer to device 171 | data1, data2 = data1.to(device), data2.to(device) 172 | 173 | # N*1*M 174 | A = data1.unsqueeze(dim=1) 175 | 176 | # 1*N*M 177 | B = data2.unsqueeze(dim=0) 178 | 179 | dis = (A - B) ** 2.0 180 | # return N*N matrix for pairwise distance 181 | dis = dis.sum(dim=-1).squeeze() 182 | return dis 183 | 184 | 185 | def pairwise_cosine(data1, data2, device=torch.device('cpu')): 186 | # transfer to device 187 | data1, data2 = data1.to(device), data2.to(device) 188 | 189 | # N*1*M 190 | A = data1.unsqueeze(dim=1) 191 | 192 | # 1*N*M 193 | B = data2.unsqueeze(dim=0) 194 | 195 | # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5] 196 | A_normalized = A / A.norm(dim=-1, keepdim=True) 197 | B_normalized = B / B.norm(dim=-1, keepdim=True) 198 | 199 | cosine = A_normalized * B_normalized 200 | 201 | # return N*N matrix for pairwise distance 202 | cosine_dis = 1 - cosine.sum(dim=-1).squeeze() 203 | return cosine_dis 204 | 205 | -------------------------------------------------------------------------------- /SimCLR/linear_classify.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | import warnings 6 | import sys 7 | import math 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | import os, sys 19 | sys.path.append(os.getcwd()) 20 | from resnet import SupConResNet,LinearClassifier,NormedLinear 21 | from sklearn.metrics import 
confusion_matrix 22 | from unbalance import IMBALANCECIFAR10, IMBALANCECIFAR100 23 | from sample import ClassAwareSampler 24 | from loss import* 25 | from utils import* 26 | 27 | 28 | 29 | parser = argparse.ArgumentParser(description='PyTorch Cifar Training') 30 | parser.add_argument('--dataset', default='cifar100', help='dataset setting') 31 | parser.add_argument('--imb_type', default="exp", type=str, help='imbalance type') 32 | parser.add_argument('--imb_factor', default=0.01, type=float, help='imbalance factor') 33 | parser.add_argument('--feat_dim', default=128, type=int, help='feature dimenmion for model') 34 | parser.add_argument('--train_rule', default='DRW', type=str, help='data sampling strategy for train loader') 35 | parser.add_argument('--rand_number', default=0, type=int, help='fix random number for data sampling') 36 | parser.add_argument('--exp_str', default='0', type=str, help='number to indicate which experiment it is') 37 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 38 | help='number of data loading workers (default: 4)') 39 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 40 | help='number of total epochs to run') 41 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 42 | help='manual epoch number (useful on restarts)') 43 | parser.add_argument('-b', '--batch-size', default=128, type=int, 44 | metavar='N', 45 | help='mini-batch size') 46 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, 47 | help='decay rate for learning rate') 48 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 49 | metavar='LR', help='initial learning rate', dest='lr') 50 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 51 | help='momentum') 52 | parser.add_argument('--wd', '--weight-decay', default=0, type=float, 53 | metavar='W', help='weight decay (default: 0)', 54 | dest='weight_decay') 55 | parser.add_argument('--seed', default=None, 
type=int, 56 | help='seed for initializing training. ') 57 | parser.add_argument('--gpu', default=None, type=int, 58 | help='GPU id to use.') 59 | parser.add_argument('--root_log',type=str, default='log') 60 | parser.add_argument('--root_model', type=str, default='checkpoint') 61 | best_acc1 = 0 62 | 63 | 64 | def main(): 65 | args = parser.parse_args() 66 | 67 | if args.seed is not None: 68 | random.seed(args.seed) 69 | torch.manual_seed(args.seed) 70 | cudnn.deterministic = True 71 | warnings.warn('You have chosen to seed training. ' 72 | 'This will turn on the CUDNN deterministic setting, ' 73 | 'which can slow down your training considerably! ' 74 | 'You may see unexpected behavior when restarting ' 75 | 'from checkpoints.') 76 | 77 | if args.gpu is not None: 78 | warnings.warn('You have chosen a specific GPU. This will completely ' 79 | 'disable data parallelism.') 80 | 81 | ngpus_per_node = torch.cuda.device_count() 82 | main_worker(args.gpu, ngpus_per_node, args) 83 | def main_worker(gpu, ngpus_per_node, args): 84 | global best_acc1 85 | args.gpu = gpu 86 | if args.gpu is not None: 87 | print("Use GPU: {} for training".format(args.gpu)) 88 | 89 | 90 | num_classes = 100 if args.dataset == 'cifar100' else 10 91 | model = SupConResNet(feat_dim=args.feat_dim) 92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 93 | model = torch.nn.DataParallel(model).cuda() 94 | PATH = ''#your idea model path 95 | model.load_state_dict(torch.load(PATH)) 96 | if args.train_rule == 'Reweight' or args.train_rule == 'DRW' : 97 | classify = NormedLinear(64,num_classes) 98 | else: 99 | classify = LinearClassifier(num_classes=num_classes) 100 | classify = torch.nn.DataParallel(classify).cuda() 101 | optimizer = torch.optim.SGD(classify.parameters(), args.lr, 102 | momentum=args.momentum, 103 | weight_decay=args.weight_decay) 104 | cudnn.benchmark = True 105 | 106 | transform_train = transforms.Compose([ 107 | transforms.RandomCrop(32, padding=4), 108 | 
transforms.RandomHorizontalFlip(), 109 | transforms.ToTensor(), 110 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 111 | ]) 112 | transform_val = transforms.Compose([ 113 | transforms.ToTensor(), 114 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 115 | ]) 116 | if args.dataset == 'cifar10': 117 | train_dataset = IMBALANCECIFAR10(root='./dataset/data', imb_type=args.imb_type, imb_factor=args.imb_factor, rand_number=args.rand_number, train=True, download=True, transform=transform_train) 118 | val_dataset = datasets.CIFAR10(root='./dataset/data', train=False, download=True, transform=transform_val) 119 | elif args.dataset == 'cifar100': 120 | train_dataset = IMBALANCECIFAR100(root='./dataset/data100', imb_type=args.imb_type, imb_factor=args.imb_factor, rand_number=args.rand_number, train=True, download=True, transform=transform_train) 121 | val_dataset = datasets.CIFAR100(root='./dataset/data100', train=False, download=True, transform=transform_val) 122 | else: 123 | warnings.warn('Dataset is not listed') 124 | return 125 | cls_num_list = train_dataset.get_cls_num_list() 126 | print('cls num list:') 127 | print(cls_num_list) 128 | args.cls_num_list = cls_num_list 129 | if args.train_rule == 'DRW' or args.train_rule == 'CE' : 130 | train_sampler = None 131 | elif args.train_rule == 'CB': 132 | train_sampler = ClassAwareSampler(train_dataset) 133 | elif args.train_rule == 'Reweight': 134 | train_sampler = ImbalancedDatasetSampler(train_dataset) 135 | 136 | train_loader = torch.utils.data.DataLoader( 137 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 138 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 139 | 140 | val_loader = torch.utils.data.DataLoader( 141 | val_dataset, batch_size=100, shuffle=False, 142 | num_workers=args.workers, pin_memory=True) 143 | 144 | for epoch in range(args.start_epoch, args.epochs): 145 | adjust_learning_rate(optimizer, epoch, args) 
        if args.train_rule == 'DRW':
            train_sampler = None
            idx = 1
            betas = [0, 0.9999]
            # class-balanced weights from the effective number of samples per class
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
            criterion = LDAMLoss(cls_num_list=cls_num_list, max_m=0.5, s=30, weight=per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'Reweight':
            criterion = LDAMLoss(cls_num_list=cls_num_list, max_m=0.5, s=30).cuda(args.gpu)
        else:
            criterion = nn.CrossEntropyLoss().cuda(args.gpu)
        train(train_loader, model, classify, criterion, optimizer, epoch, args, flag='train')
        acc1 = validate(val_loader, model, classify, criterion, epoch, args, flag='val')
        best_acc1 = max(acc1, best_acc1)
        print('acc/test_top1_best', best_acc1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
        print(output_best)


def train(train_loader, model, classify, criterion, optimizer, epoch, args, flag='train'):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # only the linear classifier is trained; the encoder stays frozen in eval mode
    classify.train()
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    end = time.time()
    for i, (input, target, cluster_target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        input = input.to(device)
        target = target.to(device)
        with torch.no_grad():
            features = model.module.encoder(input)
            if isinstance(features, tuple):
                _, features = features
        output = classify(features.detach())
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        top5.update(acc5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    torch.cuda.empty_cache()
    output = ('{flag} Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
              .format(flag=flag, top1=top1, top5=top5, loss=losses))
    print(output)
    torch.cuda.empty_cache()


def validate(val_loader, model, classify, criterion, epoch, args, flag='val'):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # switch to evaluate mode
    model.eval()
    classify.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    all_preds = []
    all_targets = []
    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            input = input.to(device)
            target = target.to(device)
            # compute output
            features = model.module.encoder(input)
            if isinstance(features, tuple):
                _, features = features
            output = classify(features)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1[0], input.size(0))
            top5.update(acc5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # collect predictions for the per-class accuracy report
            _, pred = torch.max(output, 1)
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    cf = confusion_matrix(all_targets, all_preds).astype(float)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)
    cls_acc = cls_hit / cls_cnt
    output = ('{flag} Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
              .format(flag=flag, top1=top1, top5=top5, loss=losses))
    out_cls_acc = '%s Class Accuracy: %s' % (flag, (np.array2string(cls_acc, separator=',', formatter={'float_kind': lambda x: "%.3f" % x})))
    print(output)
    print(out_cls_acc)

    return top1.avg


def adjust_learning_rate(optimizer, epoch, args):
    """Linear warm-up over the first 5 epochs, then cosine annealing
    ('CB' / 'Reweight') or step decay at epochs 140, 180 and 190."""
    epoch = epoch + 1
    if args.train_rule == 'CB' or args.train_rule == 'Reweight':
        if epoch <= 5:
            lr = args.lr * epoch / 5
        else:
            lr_min = 0
            lr_max = args.lr
            lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos((epoch - 5) / (args.epochs - 5) * math.pi))
    else:
        # thresholds are tested from largest to smallest so every branch is reachable
        if epoch <= 5:
            lr = args.lr * epoch / 5
        elif epoch > 190:
            lr = args.lr * 0.001
        elif epoch > 180:
            lr = args.lr * 0.01
        elif epoch > 140:
            lr = args.lr * 0.1
        else:
            lr = args.lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/SimCLR/loss.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class LDAMLoss(nn.Module):

    def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
        super(LDAMLoss, self).__init__()
        # per-class margin proportional to n_j^(-1/4), rescaled so the largest margin is max_m
        m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
        m_list = m_list * (max_m / np.max(m_list))
        m_list = torch.cuda.FloatTensor(m_list)
        self.m_list = m_list
        assert s > 0
        self.s = s
        self.weight = weight

    def forward(self, x, target):
        # boolean one-hot mask of the target class; torch.where needs a bool condition
        index = torch.zeros_like(x, dtype=torch.bool)
        index.scatter_(1, target.data.view(-1, 1), 1)

        index_float = index.float()
        batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1))
        batch_m = batch_m.view((-1, 1))
        x_m = x - batch_m

        # subtract the margin from the target logit only
        output = torch.where(index, x_m, x)
        return F.cross_entropy(self.s * output, target, weight=self.weight)


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.1, contrast_mode='all',
                 base_temperature=0.1):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
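
Example (illustrative only, not part of the original repository): a minimal
pure-Python sketch of the same computation for `contrast_mode='all'` with
`labels` given, assuming L2-normalized features and at least two views per
sample. The helper name `supcon_reference` is hypothetical; it is meant as a
slow cross-check of the masking logic on tiny inputs, not as a replacement.

```python
import math

def supcon_reference(features, labels, temperature=0.1, base_temperature=0.1):
    # features[i][v] is the (already normalized) embedding of sample i, view v.
    bsz, n_views = len(features), len(features[0])
    # Stack views in the same order as torch.cat(torch.unbind(features, dim=1)).
    flat = [features[i][v] for v in range(n_views) for i in range(bsz)]
    flat_labels = [labels[i] for _ in range(n_views) for i in range(bsz)]
    total = 0.0
    for a, anchor in enumerate(flat):
        logits = [sum(x * y for x, y in zip(anchor, other)) / temperature
                  for other in flat]
        # The anchor itself is excluded from the denominator (logits_mask).
        others = [j for j in range(len(flat)) if j != a]
        log_denom = math.log(sum(math.exp(logits[j]) for j in others))
        # Positives share the anchor's label; the other view always qualifies.
        positives = [j for j in others if flat_labels[j] == flat_labels[a]]
        mean_log_prob_pos = sum(logits[j] - log_denom for j in positives) / len(positives)
        total += -(temperature / base_temperature) * mean_log_prob_pos
    return total / len(flat)
```

With two samples, two views, and identical embeddings everywhere, each anchor
sees three equally likely candidates, so the value equals log(3).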
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]

        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss


class SupConLoss_ccl(nn.Module):
    def __init__(self, temperature=0.1, grama=0.25,
                 base_temperature=0.07):
        super(SupConLoss_ccl, self).__init__()
        self.temperature = temperature
        self.base_temperature = temperature if base_temperature is None else base_temperature
        self.grama = grama

    def forward(self, features, labels, cluster_target, mask=None):
        # `mask` is ignored: it is rebuilt from `labels` below
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        labels = labels.contiguous().view(-1, 1)
        if labels.shape[0] != batch_size:
            raise ValueError('Num of labels does not match num of features')
        mask = torch.eq(labels, labels.T).float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor_feature = contrast_feature
        anchor_count = contrast_count
        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # mean log-likelihood over positives that share the class label
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        cluster_target = cluster_target.contiguous().view(-1, 1)
        cluster_mask = torch.eq(cluster_target, cluster_target.T).float().to(device)
        cluster_mask = cluster_mask.repeat(anchor_count, contrast_count)
        cluster_mask = cluster_mask * logits_mask

        # mean log-likelihood over positives that share the subclass (cluster)
        mean_log_label_pos = (cluster_mask * log_prob).sum(1) / cluster_mask.sum(1)

        # loss: subclass-level term plus the class-level term weighted by grama
        loss = - (self.temperature / self.base_temperature) * mean_log_label_pos - self.grama * mean_log_prob_pos
        loss = loss.mean()
        return loss


class SupConLoss_rank(nn.Module):
    """Subclass-level supervised contrastive loss with an additional class-level
    term that uses a per-class temperature. Builds on Supervised Contrastive
    Learning: https://arxiv.org/pdf/2004.11362.pdf."""
    def __init__(self, num_class, temperature=0.1, ranking_temperature=0.2, contrast_mode='all', grama=0.25,
                 base_temperature=0.07):
        super(SupConLoss_rank, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        self.ranking_temperature = ranking_temperature
        self.num_class = num_class
        self.grama = grama

    def forward(self, features, labels, cluster_target, mask=None):
        """Compute the subclass-level loss (positives share `cluster_target`)
        plus a class-level term (positives share `labels` but not the subclass)
        weighted by `grama`.

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            cluster_target: subclass (cluster) ids of shape [bsz].
            mask: ignored; it is overwritten by the subclass mask built from
                `cluster_target`.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        cluster_target = cluster_target.contiguous().view(-1, 1)
        if cluster_target.shape[0] != batch_size:
            raise ValueError('Num of labels does not match num of features')
        mask = torch.eq(cluster_target, cluster_target.T).float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor_feature = contrast_feature
        anchor_count = contrast_count

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        # look up one temperature per anchor from its class label;
        # self.ranking_temperature must be a NumPy array with one entry per class
        # (main.py passes the density returned by cluster())
        one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=self.num_class)
        one_hot_labels = one_hot_labels.repeat(anchor_count, 1).float().cuda()
        ranking_temperature = torch.from_numpy(self.ranking_temperature).float().cuda()
        ranking_temperature = torch.matmul(one_hot_labels, ranking_temperature)
        ranking_temperature = ranking_temperature.unsqueeze(0).T
        anchor_rank_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            ranking_temperature)
        # for numerical stability
        logits_rank_max, _ = torch.max(anchor_rank_contrast, dim=1,
                                       keepdim=True)
        logits_rank = anchor_rank_contrast - logits_rank_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # mean log-likelihood over positives that share the subclass (cluster)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        labels = labels.contiguous().view(-1, 1)

        label_mask = torch.eq(labels, labels.T).float().to(device)
        label_mask = label_mask.repeat(anchor_count, contrast_count)
        # entries that are neither the anchor itself nor in the anchor's subclass
        Bool = ~mask.bool()
        Inverse_cluster_target = Bool.float()
        logits_label_mask = Inverse_cluster_target * logits_mask
        # ad_mask marks, for each anchor, the other augmented view of the same sample
        pair_index = np.arange(batch_size, batch_size * anchor_count).tolist()
        pair_index.extend(np.arange(batch_size).tolist())
        ad_mask = torch.scatter(
            torch.zeros(mask.size(0), mask.size(1)),
            0,
            torch.tensor([pair_index]),
            1
        )
        ad_logits_label_mask = ad_mask.cuda() + logits_label_mask
        exp_logits_rank = torch.exp(logits_rank) * ad_logits_label_mask

        # class-level positives: same class but different subclass, plus the other view
        label_mask = logits_label_mask * label_mask
        label_mask = ad_mask.cuda() + label_mask
        log_prob_rank = logits_rank - torch.log(exp_logits_rank.sum(1, keepdim=True))

        # mean log-likelihood over the class-level positives
        mean_log_label_pos = (label_mask * log_prob_rank).sum(1) / label_mask.sum(1)
        '''
        label_mask = torch.eq(labels, labels.T).float().to(device)
        label_mask = label_mask.repeat(anchor_count, contrast_count)
        label_mask = label_mask * logits_mask
        exp_logits_rank = torch.exp(logits_rank) * logits_mask
        log_prob_rank = logits_rank - torch.log(exp_logits_rank.sum(1, keepdim=True))

        mean_log_label_pos = (label_mask * log_prob_rank).sum(1) /label_mask.sum(1)
        '''
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos - self.grama * mean_log_label_pos
        loss = loss.mean()

        return loss


--------------------------------------------------------------------------------
/SimCLR/main.py:
--------------------------------------------------------------------------------

import argparse
import os
import random
import time
import warnings
import sys
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
# make the sibling modules importable when the script is run from this directory
sys.path.append(os.getcwd())
from resnet import SupConResNet
from balanced_cluster import balanced_kmean
from kmeans_gpu import kmeans
from unbalance import IMBALANCECIFAR10, IMBALANCECIFAR100
from loss import SupConLoss_ccl, SupConLoss_rank, SupConLoss
from utils import *

parser = argparse.ArgumentParser(description='PyTorch Cifar Training')
parser.add_argument('--dataset', default='cifar100', help='dataset setting')
parser.add_argument('--imb_type', default="exp", type=str, help='imbalance type')
parser.add_argument('--imb_factor', default=0.01, type=float, help='imbalance factor')
parser.add_argument('--train_rule', default='Rank', type=str, help='loss function for contrastive learning')
parser.add_argument('--rand_number', default=0, type=int, help='fix random number for data sampling')
parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension for model')
parser.add_argument('--exp_str', default='0', type=str, help='number to indicate which experiment it is')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: 16)')
parser.add_argument('--epochs', default=1000, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=1024, type=int,
                    metavar='N',
                    help='mini-batch size')
parser.add_argument('--step', default=10, type=int,
                    metavar='N', help='steps for updating cluster')
parser.add_argument('--lr', '--learning-rate', default=0.5, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--lr_decay_rate', default=0.1, type=float,
                    help='decay rate for learning rate')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
# note: parsed as a string, so any non-empty value (even 'False') enables balanced clustering
parser.add_argument('--cluster_method', default=False, type=str,
                    help='choose to balance clusters')
parser.add_argument('--cluster', default=10, type=int,
                    metavar='N', help='the lower limit of cluster size')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--temperature', default=0.1, type=float,
                    help='softmax temperature')
# note: the value is a string, so any non-empty value (even 'False') is truthy
parser.add_argument('--cosine', default='True',
                    help='using cosine annealing')
parser.add_argument('--root_log', type=str, default='log')
parser.add_argument('--root_model', type=str, default='checkpoint')


def main():
    args = parser.parse_args()
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    ngpus_per_node = torch.cuda.device_count()
    main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    if args.batch_size > 256:
        args.warm = True
    else:
        args.warm = False
    # the training loop reads warm_epochs even when LR warm-up is off,
    # so it is set unconditionally
    args.warm_epochs = 10
    if args.warm:
        args.warmup_from = 0.01
        if args.cosine:
            eta_min = args.lr * (args.lr_decay_rate ** 3)
            args.warmup_to = eta_min + (args.lr - eta_min) * (
                    1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
        else:
            args.warmup_to = args.lr
    num_classes = 100 if args.dataset == 'cifar100' else 10
    args.num_classes = num_classes
    model = SupConResNet(feat_dim=args.feat_dim)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.DataParallel(model).cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./dataset/data', imb_type=args.imb_type, imb_factor=args.imb_factor, rand_number=args.rand_number, train=True, download=True, transform=TwoCropTransform(transform_train))
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./dataset/data100', imb_type=args.imb_type, imb_factor=args.imb_factor, rand_number=args.rand_number, train=True, download=True, transform=TwoCropTransform(transform_train))
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list
    print(sum(cls_num_list))
    train_sampler = None
    # unshuffled loader used only to extract features for clustering
    train_loader_cluster = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    # subclasses per class: class size // max(smallest class size, --cluster), at least 1
    cluster_number = [t // max(min(cls_num_list), args.cluster) for t in cls_num_list]
    for index, value in enumerate(cluster_number):
        if value == 0:
            cluster_number[index] = 1
    print(cluster_number)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(args, optimizer, epoch)
        if epoch < args.warm_epochs:
            # warm-up phase: plain supervised contrastive loss on class labels
            criterion = SupConLoss(temperature=args.temperature).cuda()
        else:
            # re-cluster every args.step epochs and rebuild the criterion
            if args.train_rule != 'Rank':
                if epoch % args.step == 0:
                    targets = cluster(train_loader_cluster, model, cluster_number, args)
                    train_dataset.new_labels = targets
                    criterion = SupConLoss_ccl(temperature=args.temperature).cuda()
            else:
                if epoch % args.step == 0:
                    targets, density = cluster(train_loader_cluster, model, cluster_number, args)
                    train_dataset.new_labels = targets
                    criterion = SupConLoss_rank(num_class=num_classes, ranking_temperature=density).cuda()
        train_loss = train(train_loader, model, criterion, optimizer, epoch, args)
        if ((epoch + 1) % 100 == 0 and 100 < epoch < 1000) or (epoch == 999):
            save_file = ''.format(epoch=epoch)  # set a checkpoint path here before running
            torch.save(model.state_dict(), save_file)


def cluster(train_loader_cluster, model, cluster_number, args):
    model.eval()
    features_sum = []
    for i, (input, target, cluster_target) in enumerate(train_loader_cluster):
        input = input[0].cuda()
        target = target.cuda()
        with torch.no_grad():
            features = model(input)
            features = features.detach()
            features_sum.append(features)
    features = torch.cat(features_sum, dim=0)
    # one chunk per class; this relies on the unshuffled dataset being ordered by class
    features = torch.split(features, args.cls_num_list, dim=0)
    if args.train_rule == 'Rank':
        feature_center = [torch.mean(t, dim=0) for t in features]
        feature_center = torch.cat(feature_center, dim=0)
        feature_center = feature_center.reshape(args.num_classes, args.feat_dim)
        # per-class density: mean distance to the class center, damped by log class size
        density = np.zeros(len(cluster_number))
        for i in range(len(cluster_number)):
            center_distance = F.pairwise_distance(features[i], feature_center[i], p=2).mean() / np.log(len(features[i]) + 10)
            density[i] = center_distance.cpu().numpy()
        density = density.clip(np.percentile(density, 20), np.percentile(density, 80))
        # density = args.temperature*np.exp(density/density.mean())
        density = args.temperature * (density / density.mean())
        for index, value in enumerate(cluster_number):
            if value == 1:
                density[index] = args.temperature
    target = [[] for i in range(len(cluster_number))]
    for i in range(len(cluster_number)):
        if cluster_number[i] > 1:
            if args.cluster_method:
                cluster_ids_x, _ = balanced_kmean(X=features[i], num_clusters=cluster_number[i], distance='cosine', init='k-means++', iol=50, tol=1e-3, device=torch.device("cuda"))
            else:
                cluster_ids_x, _ = kmeans(X=features[i], num_clusters=cluster_number[i], distance='cosine', tol=1e-3, iter_limit=35, device=torch.device("cuda"))
            # run faster for cluster
            target[i] = cluster_ids_x
        else:
            target[i] = torch.zeros(1, features[i].size()[0], dtype=torch.int).squeeze(0)
    # offset the per-class cluster ids so subclass ids are unique across classes
    cluster_number_sum = [sum(cluster_number[:i]) for i in range(len(cluster_number))]
    for i, k in enumerate(cluster_number_sum):
        target[i] = torch.add(target[i], k)
    targets = torch.cat(target, dim=0)
    targets = targets.numpy().tolist()
    if args.train_rule == 'Rank':
        return targets, density
    else:
        return targets


def train(train_loader, model, criterion, optimizer, epoch, args, flag='train'):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    # switch to train mode
    model.train()
    end = time.time()
    for idx, (input, target, cluster_target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        input = torch.cat([input[0], input[1]], dim=0)
        input = input.cuda()
        target = target.cuda()
        cluster_target = cluster_target.cuda()
        warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer)
        bsz = target.shape[0]
        # compute output
        features = model(input)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        # cluster_target is -1 until the first clustering pass has assigned subclasses
        if cluster_target[0] != -1:
            loss = criterion(features, target, cluster_target)
        else:
            loss = criterion(features, target)
        losses.update(loss.item(), bsz)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    torch.cuda.empty_cache()
    output = ('{flag} Results: Loss {loss.avg:.5f}'
              .format(flag=flag, loss=losses))
    print(output)
    print('epoch', epoch)
    return losses.avg


def adjust_learning_rate(args, optimizer, epoch):
    # cosine annealing from args.lr down to eta_min
    lr = args.lr
    eta_min = lr * (args.lr_decay_rate ** 4)
    lr = eta_min + (lr - eta_min) * (
            1 + math.cos(math.pi * epoch / args.epochs)) / 2

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    # epochs are 0-indexed in this script, so `epoch` is used directly;
    # with `epoch - 1` the factor p (and the learning rate) is negative in epoch 0
    if args.warm and epoch < args.warm_epochs:
        p = (batch_id + epoch * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/SimCLR/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Parameter

__all__ = ['ResNet_s',
           'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class NormedLinear(nn.Module):

    def __init__(self, in_features, out_features):
        super(NormedLinear, self).__init__()
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def forward(self, x):
        # cosine similarity between normalized inputs and normalized weight columns
        out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
        return out

class LambdaLayer(nn.Module):

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
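
Illustrative note (not part of the original code): option A is a parameter-free
shortcut that keeps every second pixel and zero-pads the channel axis. The
hypothetical helper below only does the shape arithmetic, assuming an
(N, C, H, W) input, to show why the padded shortcut can be added to the main
branch when the channel count doubles.

```python
def option_a_shortcut_shape(in_shape, planes):
    # Mirrors x[:, :, ::2, ::2] followed by zero-padding planes // 4 channels
    # on each side of the channel axis.
    n, c, h, w = in_shape
    pad = planes // 4
    return (n, c + 2 * pad, (h + 1) // 2, (w + 1) // 2)

# Going from 16 to 32 planes on a 32x32 feature map:
print(option_a_shortcut_shape((8, 16, 32, 32), planes=32))
```

The resulting channel count is C + 2 * (planes // 4), which equals `planes`
when the block doubles the width, as in the 16 -> 32 -> 64 stages used here.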
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet_s(nn.Module):

    def __init__(self, block, num_blocks, dropout=None, num_classes=10, use_norm=False, s=30):
        super(ResNet_s, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        if use_norm:
            self.linear = NormedLinear(64, num_classes)
        else:
            s = 1
            self.linear = nn.Linear(64, num_classes)
        self.apply(_weights_init)
        self.use_dropout = True if dropout else False
        if self.use_dropout:
            print('Using dropout.')
            self.dropout = nn.Dropout(p=dropout)
        self.s = s

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        # the pooled 64-d feature is returned directly; self.linear is not applied here
        feature = out
        return feature



def resnet20(): 111 | return ResNet_s(BasicBlock, [3, 3, 3]) 112 | 113 | 114 | def resnet32(num_classes=10, use_norm=False,dropout=None): 115 | return ResNet_s(BasicBlock, [5, 5, 5], num_classes=num_classes, use_norm=use_norm,dropout=dropout) 116 | 117 | 118 | def resnet44(): 119 | return ResNet_s(BasicBlock, [7, 7, 7]) 120 | 121 | 122 | def resnet56(): 123 | return ResNet_s(BasicBlock, [9, 9, 9]) 124 | 125 | 126 | def resnet110(): 127 | return ResNet_s(BasicBlock, [18, 18, 18]) 128 | 129 | 130 | def resnet1202(): 131 | return ResNet_s(BasicBlock, [200, 200, 200]) 132 | 133 | 134 | def test(net): 135 | import numpy as np 136 | total_params = 0 137 | 138 | for x in filter(lambda p: p.requires_grad, net.parameters()): 139 | total_params += np.prod(x.data.numpy().shape) 140 | print("Total number of params", total_params) 141 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 142 | 143 | 144 | 145 | if __name__ == "__main__": 146 | for net_name in __all__: 147 | if net_name.startswith('resnet'): 148 | print(net_name) 149 | test(globals()[net_name]()) 150 | print() 151 | 152 | 153 | 154 | 155 | class SupConResNet(nn.Module): 156 | """backbone + projection head""" 157 | def __init__(self, head='mlp', feat_dim=32): 158 | super(SupConResNet, self).__init__() 159 | self.encoder = resnet32() 160 | dim_in = 64 161 | if head == 'linear': 162 | self.head = nn.Linear(dim_in, feat_dim) 163 | elif head == 'mlp': 164 | self.head = nn.Sequential( 165 | nn.Linear(dim_in, dim_in), 166 | nn.ReLU(inplace=True), 167 | nn.Linear(dim_in, feat_dim) 168 | ) 169 | else: 170 | raise NotImplementedError( 171 | 'head not supported: {}'.format(head)) 172 | 173 | def forward(self, x): 174 | feat = self.encoder(x) 175 | feat = F.normalize(self.head(feat), dim=1) 176 | return feat 177 | 178 | class LinearClassifier(nn.Module): 179 | """Linear classifier""" 180 | def __init__(self, name='resnet32', num_classes=10): 181 | 
super(LinearClassifier, self).__init__() 182 | feat_dim=64 183 | self.fc = nn.Linear(feat_dim, num_classes) 184 | 185 | def forward(self, features): 186 | return self.fc(features) 187 | -------------------------------------------------------------------------------- /SimCLR/sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | import torch 5 | 6 | class BalancedDatasetSampler(torch.utils.data.sampler.Sampler): 7 | 8 | def __init__(self, dataset, indices=None, num_samples=None): 9 | 10 | # if indices is not provided, 11 | # all elements in the dataset will be considered 12 | self.indices = list(range(len(dataset))) \ 13 | if indices is None else indices 14 | 15 | # if num_samples is not provided, 16 | # draw `len(indices)` samples in each iteration 17 | self.num_samples = len(self.indices) \ 18 | if num_samples is None else num_samples 19 | 20 | # distribution of classes in the dataset 21 | label_to_count = [0] * len(np.unique(dataset.targets)) 22 | for idx in self.indices: 23 | label = self._get_label(dataset, idx) 24 | label_to_count[label] += 1 25 | 26 | 27 | 28 | 29 | per_cls_weights = 1 / np.array(label_to_count) 30 | 31 | # weight for each sample 32 | weights = [per_cls_weights[self._get_label(dataset, idx)] 33 | for idx in self.indices] 34 | 35 | 36 | self.weights = torch.DoubleTensor(weights) 37 | 38 | def _get_label(self, dataset, idx): 39 | return dataset.targets[idx] 40 | 41 | def __iter__(self): 42 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist()) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | class EffectNumSampler(torch.utils.data.sampler.Sampler): 48 | 49 | def __init__(self, dataset, indices=None, num_samples=None): 50 | 51 | # if indices is not provided, 52 | # all elements in the dataset will be considered 53 | self.indices = list(range(len(dataset))) \ 54 | if indices is None else indices 55 | 56 | # 
if num_samples is not provided, 57 | # draw `len(indices)` samples in each iteration 58 | self.num_samples = len(self.indices) \ 59 | if num_samples is None else num_samples 60 | 61 | # distribution of classes in the dataset 62 | label_to_count = [0] * len(np.unique(dataset.targets)) 63 | for idx in self.indices: 64 | label = self._get_label(dataset, idx) 65 | label_to_count[label] += 1 66 | 67 | 68 | 69 | beta = 0.9999 70 | effective_num = 1.0 - np.power(beta, label_to_count) 71 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 72 | 73 | # weight for each sample 74 | weights = [per_cls_weights[self._get_label(dataset, idx)] 75 | for idx in self.indices] 76 | 77 | 78 | self.weights = torch.DoubleTensor(weights) 79 | 80 | def _get_label(self, dataset, idx): 81 | return dataset.targets[idx] 82 | 83 | def __iter__(self): 84 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist()) 85 | 86 | def __len__(self): 87 | return self.num_samples 88 | 89 | class RandomCycleIter: 90 | 91 | def __init__ (self, data, test_mode=False): 92 | self.data_list = list(data) 93 | self.length = len(self.data_list) 94 | self.i = self.length - 1 95 | self.test_mode = test_mode 96 | 97 | def __iter__ (self): 98 | return self 99 | 100 | def __next__ (self): 101 | self.i += 1 102 | 103 | if self.i == self.length: 104 | self.i = 0 105 | if not self.test_mode: 106 | random.shuffle(self.data_list) 107 | 108 | return self.data_list[self.i] 109 | 110 | def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1): 111 | 112 | i = 0 113 | j = 0 114 | while i < n: 115 | 116 | # yield next(data_iter_list[next(cls_iter)]) 117 | 118 | if j >= num_samples_cls: 119 | j = 0 120 | 121 | if j == 0: 122 | temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls)) 123 | yield temp_tuple[j] 124 | else: 125 | yield temp_tuple[j] 126 | 127 | i += 1 128 | j += 1 129 | 130 | class ClassAwareSampler(torch.utils.data.sampler.Sampler): 131 | 
def __init__(self, data_source, num_samples_cls=4,): 132 | # pdb.set_trace() 133 | num_classes = len(np.unique(data_source.targets)) 134 | self.class_iter = RandomCycleIter(range(num_classes)) 135 | cls_data_list = [list() for _ in range(num_classes)] 136 | for i, label in enumerate(data_source.targets): 137 | cls_data_list[label].append(i) 138 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list] 139 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list) 140 | self.num_samples_cls = num_samples_cls 141 | 142 | def __iter__ (self): 143 | return class_aware_sample_generator(self.class_iter, self.data_iter_list, 144 | self.num_samples, self.num_samples_cls) 145 | 146 | def __len__ (self): 147 | return self.num_samples 148 | 149 | def get_sampler(): 150 | return ClassAwareSampler 151 | -------------------------------------------------------------------------------- /SimCLR/unbalance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | from PIL import Image 6 | 7 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10): 8 | cls_num = 10 9 | 10 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, num_expert=3, train=True, 11 | transform=None, target_transform=None, 12 | download=False): 13 | super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download) 14 | self.new_labels = [] 15 | np.random.seed(rand_number) 16 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) 17 | self.gen_imbalanced_data(img_num_list,num_expert) 18 | 19 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 20 | img_max = len(self.data) / cls_num 21 | img_num_per_cls = [] 22 | if imb_type == 'exp': 23 | for cls_idx in range(cls_num): 24 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 25 | img_num_per_cls.append(int(num)) 
26 | elif imb_type == 'step': 27 | for cls_idx in range(cls_num // 2): 28 | img_num_per_cls.append(int(img_max)) 29 | for cls_idx in range(cls_num // 2): 30 | img_num_per_cls.append(int(img_max * imb_factor)) 31 | else: 32 | img_num_per_cls.extend([int(img_max)] * cls_num) 33 | return img_num_per_cls 34 | 35 | def gen_imbalanced_data(self, img_num_per_cls,num_expert): 36 | new_data = [] 37 | new_targets = [] 38 | targets_np = np.array(self.targets, dtype=np.int64) 39 | classes = np.unique(targets_np) 40 | # np.random.shuffle(classes) 41 | self.num_per_cls_dict = dict() 42 | for the_class, the_img_num in zip(classes, img_num_per_cls): 43 | self.num_per_cls_dict[the_class] = the_img_num 44 | idx = np.where(targets_np == the_class)[0] 45 | np.random.shuffle(idx) 46 | selec_idx = idx[:the_img_num] 47 | new_data.append(self.data[selec_idx, ...]) 48 | new_targets.extend([the_class, ] * the_img_num) 49 | new_data = np.vstack(new_data) 50 | self.data = new_data 51 | self.targets = new_targets 52 | def get_cls_num_list(self): 53 | cls_num_list = [] 54 | for i in range(self.cls_num): 55 | cls_num_list.append(self.num_per_cls_dict[i]) 56 | return cls_num_list 57 | def __getitem__(self, index: int): 58 | """ 59 | Args: 60 | index (int): Index 61 | 62 | Returns: 63 | tuple: (image, target) where target is index of the target class. 64 | """ 65 | img, target = self.data[index], self.targets[index] 66 | if self.new_labels !=[]: 67 | new_labels = self.new_labels[index] 68 | else: 69 | new_labels =-1 70 | # doing this so that it is consistent with all other datasets 71 | # to return a PIL Image 72 | img = Image.fromarray(img) 73 | 74 | if self.transform is not None: 75 | img = self.transform(img) 76 | 77 | if self.target_transform is not None: 78 | target = self.target_transform(target) 79 | 80 | return img, target,new_labels 81 | 82 | 83 | class IMBALANCECIFAR100(IMBALANCECIFAR10): 84 | """`CIFAR100 `_ Dataset. 85 | This is a subclass of the `CIFAR10` Dataset. 
86 | """ 87 | base_folder = 'cifar-100-python' 88 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 89 | filename = "cifar-100-python.tar.gz" 90 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 91 | train_list = [ 92 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 93 | ] 94 | 95 | test_list = [ 96 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 97 | ] 98 | meta = { 99 | 'filename': 'meta', 100 | 'key': 'fine_label_names', 101 | 'md5': '7973b15100ade9c7d40fb424638fde48', 102 | } 103 | cls_num = 100 104 | 105 | 106 | if __name__ == '__main__': 107 | transform = transforms.Compose( 108 | [transforms.ToTensor(), 109 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 110 | trainset = IMBALANCECIFAR10(root='./data', train=True, 111 | download=True, transform=transform) 112 | num=0 113 | for i in range(len(trainset)): 114 | if trainset.targets[i]==0: 115 | num=num+1 116 | print(num) 117 | trainloader = iter(trainset) 118 | data, label, _ = next(trainloader)  # __getitem__ returns (img, target, new_labels) 119 | print(data) 120 | -------------------------------------------------------------------------------- /SimCLR/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import shutil 3 | import os 4 | import numpy as np 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | from sklearn.metrics import confusion_matrix 9 | from sklearn.utils.multiclass import unique_labels 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import random 14 | from PIL import ImageFilter 15 | import math 16 | 17 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 18 | 19 | def __init__(self, dataset, indices=None, num_samples=None): 20 | 21 | # if indices is not provided, 22 | # all elements in the dataset will be considered 23 | self.indices = list(range(len(dataset))) \ 24 | if indices is None else indices 25 | 26 | # if num_samples is not provided, 27 | # draw `len(indices)` 
samples in each iteration 28 | self.num_samples = len(self.indices) \ 29 | if num_samples is None else num_samples 30 | 31 | # distribution of classes in the dataset 32 | label_to_count = [0] * len(np.unique(dataset.targets)) 33 | for idx in self.indices: 34 | label = self._get_label(dataset, idx) 35 | label_to_count[label] += 1 36 | 37 | beta = 0.9999 38 | effective_num = 1.0 - np.power(beta, label_to_count) 39 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 40 | # weight for each sample 41 | weights = [per_cls_weights[self._get_label(dataset, idx)] 42 | for idx in self.indices] 43 | self.weights = torch.DoubleTensor(weights) 44 | 45 | def _get_label(self, dataset, idx): 46 | return dataset.targets[idx] 47 | 48 | def __iter__(self): 49 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist()) 50 | 51 | def __len__(self): 52 | return self.num_samples 53 | 54 | def calc_confusion_mat(val_loader, model, args): 55 | 56 | model.eval() 57 | all_preds = [] 58 | all_targets = [] 59 | with torch.no_grad(): 60 | for i, (input, target) in enumerate(val_loader): 61 | if args.gpu is not None: 62 | input = input.cuda(args.gpu, non_blocking=True) 63 | target = target.cuda(args.gpu, non_blocking=True) 64 | 65 | # compute output 66 | output = model(input) 67 | _, pred = torch.max(output, 1) 68 | all_preds.extend(pred.cpu().numpy()) 69 | all_targets.extend(target.cpu().numpy()) 70 | cf = confusion_matrix(all_targets, all_preds).astype(float) 71 | 72 | cls_cnt = cf.sum(axis=1) 73 | cls_hit = np.diag(cf) 74 | 75 | cls_acc = cls_hit / cls_cnt 76 | 77 | print('Class Accuracy : ') 78 | print(cls_acc) 79 | classes = [str(x) for x in args.cls_num_list] 80 | plot_confusion_matrix(all_targets, all_preds, classes) 81 | plt.savefig(os.path.join(args.root_log, args.store_name, 'confusion_matrix.png')) 82 | 83 | def plot_confusion_matrix(y_true, y_pred, classes, 84 | normalize=False, 85 | title=None, 86 | cmap=plt.cm.Blues): 87 | 88 | if not title: 
89 | if normalize: 90 | title = 'Normalized confusion matrix' 91 | else: 92 | title = 'Confusion matrix, without normalization' 93 | 94 | # Compute confusion matrix 95 | cm = confusion_matrix(y_true, y_pred) 96 | 97 | fig, ax = plt.subplots() 98 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 99 | ax.figure.colorbar(im, ax=ax) 100 | # We want to show all ticks... 101 | ax.set(xticks=np.arange(cm.shape[1]), 102 | yticks=np.arange(cm.shape[0]), 103 | # ... and label them with the respective list entries 104 | xticklabels=classes, yticklabels=classes, 105 | title=title, 106 | ylabel='True label', 107 | xlabel='Predicted label') 108 | 109 | # Rotate the tick labels and set their alignment. 110 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 111 | rotation_mode="anchor") 112 | 113 | # Loop over data dimensions and create text annotations. 114 | fmt = '.2f' if normalize else 'd' 115 | thresh = cm.max() / 2. 116 | for i in range(cm.shape[0]): 117 | for j in range(cm.shape[1]): 118 | ax.text(j, i, format(cm[i, j], fmt), 119 | ha="center", va="center", 120 | color="white" if cm[i, j] > thresh else "black") 121 | fig.tight_layout() 122 | return ax 123 | 124 | def prepare_folders(args): 125 | 126 | folders_util = [args.root_log, args.root_model, 127 | os.path.join(args.root_log, args.store_name), 128 | os.path.join(args.root_model, args.store_name)] 129 | for folder in folders_util: 130 | if not os.path.exists(folder): 131 | print('creating folder ' + folder) 132 | os.mkdir(folder) 133 | 134 | def save_checkpoint(args, state, is_best): 135 | 136 | filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name) 137 | torch.save(state, filename) 138 | if is_best: 139 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 140 | 141 | 142 | class AverageMeter(object): 143 | 144 | def __init__(self, name, fmt=':f'): 145 | self.name = name 146 | self.fmt = fmt 147 | self.reset() 148 | 149 | def reset(self): 150 | self.val = 0 151 | self.avg 
= 0 152 | self.sum = 0 153 | self.count = 0 154 | 155 | def update(self, val, n=1): 156 | self.val = val 157 | self.sum += val * n 158 | self.count += n 159 | self.avg = self.sum / self.count 160 | 161 | def __str__(self): 162 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 163 | return fmtstr.format(**self.__dict__) 164 | 165 | 166 | def accuracy(output, target, topk=(1,)): 167 | 168 | with torch.no_grad(): 169 | maxk = max(topk) 170 | batch_size = target.size(0) 171 | 172 | _, pred = output.topk(maxk, 1, True, True) 173 | pred = pred.t() 174 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 175 | 176 | res = [] 177 | for k in topk: 178 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 179 | res.append(correct_k.mul_(100.0 / batch_size)) 180 | return res 181 | 182 | 183 | class TwoCropTransform: 184 | """Create two crops of the same image""" 185 | def __init__(self, transform): 186 | self.transform = transform 187 | 188 | def __call__(self, x): 189 | return [self.transform(x), self.transform(x)] 190 | 191 | class GaussianBlur(object): 192 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 193 | 194 | def __init__(self, sigma=[.1, 2.]): 195 | self.sigma = sigma 196 | 197 | def __call__(self, x): 198 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 199 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 200 | return x 201 | 202 | -------------------------------------------------------------------------------- /image/SBCL.jpg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /moco/BaseDareLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 
5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 
41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | 51 | # turn off shuffle option which is mutually exclusive with sampler 52 | self.shuffle = False 53 | self.n_samples = len(train_idx) 54 | 55 | return train_sampler, valid_sampler 56 | 57 | def split_validation(self): 58 | if self.valid_sampler is None: 59 | return None 60 | else: 61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 62 | -------------------------------------------------------------------------------- /moco/balanced_cluster.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | from tqdm import trange, tqdm 7 | 8 | 9 | def balanced_kmean( 10 | X, 11 | n_clusters, 12 | init='k-means++', 13 | device=torch.device('cpu'), 14 | tol=1e-4, 15 | iol=100, 16 | distance='cosine' 17 | 18 | ): 19 | ''' 20 | X: the clustered feature 21 | n_clusters: the cluster number 22 | ''' 23 | # convert to float 24 | X = X.float() 25 | 26 | # transfer to device 27 | X = X.to(device) 28 | 29 | if distance == 'euclidean': 30 | pairwise_similarity_function = partial(pairwise_euclidean, device=device) 31 | elif distance == 'cosine': 32 | pairwise_similarity_function = partial(pairwise_cosine, device=device) 33 | else: 34 | raise NotImplementedError 35 | # initialize 36 | if init == 'random': 37 | centroids = initialize(X, n_clusters) 38 | elif init == 'k-means++': 39 | centroids, _ = _kmeans_plusplus(X, 40 | n_clusters, 41 | random_state=0, 42 | pairwise_similarity=pairwise_similarity_function, 43 | n_local_trials=None) 44 | 45 | 46 | else: 47 | raise NotImplementedError 48 | 49 | N = len(X) 50 | n_per_cluster = N // n_clusters 51 | 
n_left = N % n_clusters 52 | for i in trange(iol): 53 | similarity_matrix = pairwise_similarity_function(centroids, X) 54 | similarity_matrix = similarity_matrix / similarity_matrix.sum(dim=1, keepdim=True) 55 | cluster_assignment = torch.zeros(N, dtype=torch.long) - 1 56 | cluster_size = {c: 0 for c in range(n_clusters)} 57 | 58 | idx = torch.argsort(similarity_matrix.flatten(), descending=True) 59 | #print(idx) 60 | 61 | if n_left == 0: 62 | for labels in idx: 63 | labels = labels.item() 64 | des = labels % N 65 | label = labels // N 66 | if cluster_assignment[des] == -1 and cluster_size[label] < n_per_cluster: 67 | cluster_assignment[des] = label 68 | cluster_size[label] += 1 69 | else: 70 | for labels in idx: 71 | labels = labels.item() 72 | des = labels % N 73 | label = labels // N 74 | if cluster_assignment[des] == -1 and cluster_size[label] < n_per_cluster: 75 | cluster_assignment[des] = label 76 | cluster_size[label] += 1 77 | similarity_matrix[:, des] = -100 78 | for _ in range(n_left): 79 | labels = torch.argmax(similarity_matrix).item() 80 | des = labels % N 81 | label = labels // N 82 | cluster_assignment[des] = label 83 | similarity_matrix[:, des] = -100 84 | cluster_size[label] += 1 85 | if cluster_size[label] >= n_per_cluster + 1: 86 | similarity_matrix[label, :] = -100 87 | 88 | assert torch.all(cluster_assignment != -1) 89 | 90 | last_centroids = centroids.clone() 91 | for index in range(n_clusters): 92 | centroids[index] = X[cluster_assignment == index].mean(dim=0) 93 | 94 | center_shift = torch.sum( 95 | torch.sqrt( 96 | torch.sum((centroids - last_centroids) ** 2, dim=1) 97 | )) 98 | 99 | # update tqdm meter 100 | if center_shift ** 2 < tol: 101 | break 102 | 103 | return cluster_assignment.cpu(), centroids.cpu() 104 | 105 | 106 | # def balanced_kmeans1( 107 | # X, 108 | # n_clusters, 109 | # init='k-means++', 110 | # device=torch.device('cpu'), 111 | # tol=1e-4, 112 | # iol=100, 113 | # distance='cosine' 114 | # 115 | # ): 116 | # ''' 117 | # 
X: the clustered feature 118 | # n_clusters: the cluster number 119 | # ''' 120 | # # convert to float 121 | # X = X.float() 122 | # 123 | # # transfer to device 124 | # X = X.to(device) 125 | # 126 | # if distance == 'euclidean': 127 | # pairwise_similarity_function = partial(pairwise_euclidean, device=device) 128 | # elif distance == 'cosine': 129 | # pairwise_similarity_function = partial(pairwise_cosine, device=device) 130 | # else: 131 | # raise NotImplementedError 132 | # # initialize 133 | # if init == 'random': 134 | # centroids = initialize(X, n_clusters) 135 | # elif init == 'k-means++': 136 | # centroids, _ = _kmeans_plusplus(X, 137 | # n_clusters, 138 | # random_state=0, 139 | # pairwise_distance=pairwise_similarity_function, 140 | # n_local_trials=None) 141 | # 142 | # 143 | # else: 144 | # raise NotImplementedError 145 | # 146 | # # centroids = KMeans(n_clusters=n_clusters)._init_centroids(X.cpu().numpy(), x_squared_norms=None, 147 | # # init=init, random_state=np.random.RandomState(seed=0)) 148 | # # centroids = torch.from_numpy(centroids).to(device) 149 | # 150 | # N = len(X) 151 | # n_per_cluster = N // n_clusters 152 | # n_left = N % n_clusters 153 | # for i in trange(iol): 154 | # similarity_matrix = pairwise_similarity_function(centroids, X) 155 | # similarity_matrix = similarity_matrix / similarity_matrix.sum(dim=1, keepdim=True) 156 | # cluster_assignment = torch.zeros(N) - 1 157 | # cluster_size = {c: 0 for c in range(n_clusters)} 158 | # 159 | # if n_left == 0: 160 | # for _ in range(len(X)): 161 | # labels = torch.argmax(similarity_matrix).item() 162 | # label = labels // len(X) 163 | # des = labels % len(X) 164 | # cluster_assignment[des] = label 165 | # similarity_matrix[:, des] = -100 166 | # cluster_size[label] += 1 167 | # if cluster_size[label] >= n_per_cluster: 168 | # similarity_matrix[label, :] = -100 169 | # else: 170 | # similarity_matrix_clone = similarity_matrix.clone() 171 | # for _ in range(n_per_cluster * n_clusters): 172 | 
# labels = torch.argmax(similarity_matrix).item() 173 | # label = labels // len(X) 174 | # des = labels % len(X) 175 | # cluster_assignment[des] = label 176 | # similarity_matrix[:, des] = -100 177 | # similarity_matrix_clone[:, des] = -100 178 | # cluster_size[label] += 1 179 | # if cluster_size[label] >= n_per_cluster: 180 | # similarity_matrix[label, :] = -100 181 | # for _ in range(n_left): 182 | # labels = torch.argmax(similarity_matrix_clone).item() 183 | # label = labels // len(X) 184 | # des = labels % len(X) 185 | # cluster_assignment[des] = label 186 | # similarity_matrix_clone[:, des] = -100 187 | # cluster_size[label] += 1 188 | # if cluster_size[label] >= n_per_cluster + 1: 189 | # similarity_matrix_clone[label, :] = -100 190 | # 191 | # last_centroids = centroids.clone() 192 | # for index in range(n_clusters): 193 | # centroids[index] = X[cluster_assignment == index].mean(dim=0) 194 | # 195 | # center_shift = torch.sum( 196 | # torch.sqrt( 197 | # torch.sum((centroids - last_centroids) ** 2, dim=1) 198 | # )) 199 | # 200 | # # update tqdm meter 201 | # if center_shift ** 2 < tol: 202 | # break 203 | # 204 | # return cluster_assignment.cpu().numpy(), centroids.cpu().numpy() 205 | 206 | 207 | def pairwise_cosine(data1, data2, device=torch.device('cpu')): 208 | data1, data2 = data1.to(device), data2.to(device) 209 | A = data1.unsqueeze(dim=1) 210 | B = data2.unsqueeze(dim=0) 211 | A_normalized = A / A.norm(dim=-1, keepdim=True) 212 | B_normalized = B / B.norm(dim=-1, keepdim=True) 213 | cosine = A_normalized * B_normalized 214 | cosine_dis = cosine.sum(dim=-1).squeeze() 215 | return cosine_dis 216 | 217 | 218 | def pairwise_euclidean(data1, data2, device=torch.device('cpu')): 219 | data1, data2 = data1.to(device), data2.to(device) 220 | A = data1.unsqueeze(dim=1) 221 | B = data2.unsqueeze(dim=0) 222 | dis = (A - B) ** 2.0 223 | dis = 1 / (dis.sum(dim=-1).squeeze() + 1e-4) 224 | return dis 225 | 226 | 227 | def initialize(X, num_clusters): 228 | """ 229 | 
initialize cluster centers 230 | """ 231 | num_samples = X.shape[1] 232 | bs = X.shape[0] 233 | 234 | indices = torch.empty(X.shape[:-1], device=X.device, dtype=torch.long) 235 | for i in range(bs): 236 | indices[i] = torch.randperm(num_samples, device=X.device) 237 | initial_state = torch.gather(X, 1, indices.unsqueeze(-1).repeat(1, 1, X.shape[-1])).reshape(bs, num_clusters, -1, X.shape[-1]).mean(dim=-2) 238 | return initial_state 239 | 240 | 241 | def stable_cumsum(arr, dim=None, rtol=1e-05, atol=1e-08): 242 | """Use high precision for cumsum and check that final value matches sum. 243 | """ 244 | if dim is None: 245 | arr = arr.flatten() 246 | dim = 0 247 | out = torch.cumsum(arr, dim=dim, dtype=torch.float64) 248 | expected = torch.sum(arr, dim=dim, dtype=torch.float64) 249 | if not torch.all(torch.isclose(out.take(torch.Tensor([-1]).long().to(arr.device)), 250 | expected, rtol=rtol, 251 | atol=atol, equal_nan=True)): 252 | warnings.warn('cumsum was found to be unstable: ' 253 | 'its last element does not correspond to sum', 254 | RuntimeWarning) 255 | return out 256 | 257 | 258 | def _kmeans_plusplus(X, 259 | n_clusters, 260 | random_state, 261 | pairwise_similarity, 262 | n_local_trials=None): 263 | """Computational component for initialization of n_clusters by 264 | k-means++. Prior validation of data is assumed. 265 | """ 266 | n_samples, n_features = X.shape 267 | 268 | generator = torch.Generator(device=str(X.device)) 269 | generator.manual_seed(random_state) 270 | 271 | centers = torch.empty((n_clusters, n_features), dtype=X.dtype, device=X.device) 272 | 273 | # Set the number of local seeding trials if none is given 274 | if n_local_trials is None: 275 | # This is what Arthur/Vassilvitskii tried, but did not report 276 | # specific results for other than mentioning in the conclusion 277 | # that it helped. 
278 | n_local_trials = 2 + int(np.log(n_clusters)) 279 | 280 | # Pick first center randomly and track index of point 281 | # center_id = random_state.randint(n_samples) 282 | center_id = torch.randint(n_samples, (1,), generator=generator, device=X.device) 283 | 284 | indices = torch.full((n_clusters,), -1, dtype=torch.int, device=X.device) 285 | centers[0] = X[center_id] 286 | indices[0] = center_id 287 | 288 | # Initialize list of closest distances and calculate current potential 289 | closest_dist_sq = 1/pairwise_similarity( 290 | centers[0, None], X) 291 | current_pot = closest_dist_sq.sum() 292 | 293 | # Pick the remaining n_clusters-1 points 294 | for c in range(1, n_clusters): 295 | # Choose center candidates by sampling with probability proportional 296 | # to the squared distance to the closest existing center 297 | # rand_vals = random_state.random_sample(n_local_trials) * current_pot 298 | rand_vals = torch.rand(n_local_trials, generator=generator, device=X.device) * current_pot 299 | 300 | candidate_ids = torch.searchsorted(stable_cumsum(closest_dist_sq), 301 | rand_vals) 302 | # XXX: numerical imprecision can result in a candidate_id out of range 303 | torch.clip(candidate_ids, None, closest_dist_sq.numel() - 1, 304 | out=candidate_ids) 305 | 306 | # Compute distances to center candidates 307 | distance_to_candidates = 1/pairwise_similarity( 308 | X[candidate_ids], X) 309 | 310 | # update closest distances squared and potential for each candidate 311 | torch.minimum(closest_dist_sq, distance_to_candidates, 312 | out=distance_to_candidates) 313 | candidates_pot = distance_to_candidates.sum(dim=1) 314 | 315 | # Decide which candidate is the best 316 | best_candidate = torch.argmin(candidates_pot) 317 | current_pot = candidates_pot[best_candidate] 318 | closest_dist_sq = distance_to_candidates[best_candidate] 319 | best_candidate = candidate_ids[best_candidate] 320 | 321 | # Permanently add best center candidate found in local tries 322 | centers[c] = 
X[best_candidate] 323 | indices[c] = best_candidate 324 | 325 | return centers, indices 326 | if __name__ == '__main__': 327 | X = torch.randn(6, 3) 328 | cluster_label,_ = balanced_kmean(X,n_clusters=3,init='k-means++',device=torch.device('cpu'),tol=1e-4,iol=100, 329 | distance='euclidean') 330 | print(cluster_label) 331 | -------------------------------------------------------------------------------- /moco/imagenet_lt_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import os, sys 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader, Dataset, Sampler 7 | from BaseDareLoader import BaseDataLoader  # the module file is named BaseDareLoader.py 8 | from PIL import Image 9 | from utils import * 10 | 11 | import torch 12 | import numpy as np 13 | import torchvision 14 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 15 | from torchvision import transforms 16 | import os 17 | from PIL import Image 18 | import random 19 | 20 | class ImageNetLT_moco(Dataset): 21 | num_classes=1000 22 | def __init__(self, root, txt, transform=None, class_balance=False): 23 | self.img_path = [] 24 | self.labels = [] 25 | self.new_labels = [] 26 | self.transform = transform 27 | self.class_balance=class_balance 28 | with open(txt) as f: 29 | for line in f: 30 | self.img_path.append(os.path.join(root, line.split()[0])) 31 | self.labels.append(int(line.split()[1])) 32 | 33 | self.class_data=[[] for i in range(self.num_classes)] 34 | for i in range(len(self.labels)): 35 | y=self.labels[i] 36 | self.class_data[y].append(i) 37 | 38 | self.cls_num_list=[len(self.class_data[i]) for i in range(self.num_classes)] 39 | 40 | 41 | def __len__(self): 42 | return len(self.labels) 43 | 44 | def __getitem__(self, index): 45 | if self.class_balance: 46 | label=random.randint(0,self.num_classes-1) 47 | index=random.choice(self.class_data[label]) 48 | path1 = self.img_path[index] 49 | else: 50 | path1 = 
self.img_path[index] 51 | label = self.labels[index] 52 | if self.new_labels !=[]: 53 | new_labels = self.new_labels[index] 54 | else: 55 | new_labels =-1 56 | with open(path1, 'rb') as f: 57 | img = Image.open(f).convert('RGB') 58 | if len(self.transform)==2: 59 | sample1 = self.transform[0](img) 60 | sample2 = self.transform[1](img) 61 | return [sample1, sample2], label, new_labels 62 | else: 63 | sample1 = self.transform[0](img) 64 | return sample1, label, new_labels 65 | 66 | 67 | class ImageNetLT_val(Dataset): 68 | num_classes=1000 69 | def __init__(self, root, txt, transform=None, class_balance=False): 70 | self.img_path = [] 71 | self.labels = [] 72 | self.transform = transform 73 | self.class_balance=class_balance 74 | with open(txt) as f: 75 | for line in f: 76 | lst_url = line.split()[0].split('/') 77 | lst_url.pop(-2) 78 | str_url = '/'.join(lst_url) 79 | self.img_path.append(os.path.join(root, str_url)) 80 | self.labels.append(int(line.split()[1])) 81 | 82 | self.class_data=[[] for i in range(self.num_classes)] 83 | for i in range(len(self.labels)): 84 | y=self.labels[i] 85 | self.class_data[y].append(i) 86 | 87 | self.cls_num_list=[len(self.class_data[i]) for i in range(self.num_classes)] 88 | 89 | 90 | def __len__(self): 91 | return len(self.labels) 92 | 93 | def __getitem__(self, index): 94 | if self.class_balance: 95 | label=random.randint(0,self.num_classes-1) 96 | index=random.choice(self.class_data[label]) 97 | path1 = self.img_path[index] 98 | else: 99 | path1 = self.img_path[index] 100 | label = self.labels[index] 101 | 102 | with open(path1, 'rb') as f: 103 | img = Image.open(f).convert('RGB') 104 | if len(self.transform)==2: 105 | sample1 = self.transform[0](img) 106 | sample2 = self.transform[1](img) 107 | return [sample1, sample2], label 108 | else: 109 | sample1 = self.transform[0](img) 110 | return sample1, label 111 | -------------------------------------------------------------------------------- /moco/kmean_gpu.py: 
--------------------------------------------------------------------------------
from functools import partial

import numpy as np
import torch
from tqdm import tqdm

#from .soft_dtw_cuda import SoftDTW


def initialize(X, num_clusters, seed):
    """
    initialize cluster centers
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param seed: (int) seed for kmeans
    :return: (torch.tensor) initial state, i.e. num_clusters rows sampled from X
    """
    num_samples = len(X)
    # seed NumPy only when a seed is given, then sample without replacement
    if seed is not None:
        np.random.seed(seed)
    indices = np.random.choice(num_samples, num_clusters, replace=False)
    initial_state = X[indices]
    return initial_state


def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=[],
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        seed=None,
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param seed: (int) seed for kmeans
    :param tol: (float) threshold [default: 0.0001]
    :param device: (torch.device) device [default: cpu]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    #elif distance == 
'soft_dtw': 60 | # sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw) 61 | # pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device) 62 | else: 63 | raise NotImplementedError 64 | 65 | # convert to float 66 | X = X.float() 67 | 68 | # transfer to device 69 | X = X.to(device) 70 | 71 | # initialize 72 | if type(cluster_centers) == list: # ToDo: make this less annoyingly weird 73 | initial_state = initialize(X, num_clusters, seed=seed) 74 | else: 75 | if tqdm_flag: 76 | print('resuming') 77 | # find data point closest to the initial cluster center 78 | initial_state = cluster_centers 79 | dis = pairwise_distance_function(X, initial_state) 80 | choice_points = torch.argmin(dis, dim=0) 81 | initial_state = X[choice_points] 82 | initial_state = initial_state.to(device) 83 | 84 | iteration = 0 85 | if tqdm_flag: 86 | tqdm_meter = tqdm(desc='[running kmeans]') 87 | while True: 88 | 89 | dis = pairwise_distance_function(X, initial_state) 90 | 91 | choice_cluster = torch.argmin(dis, dim=1) 92 | 93 | initial_state_pre = initial_state.clone() 94 | 95 | for index in range(num_clusters): 96 | selected = torch.nonzero(choice_cluster == index).squeeze().to(device) 97 | 98 | selected = torch.index_select(X, 0, selected) 99 | 100 | # https://github.com/subhadarship/kmeans_pytorch/issues/16 101 | if selected.shape[0] == 0: 102 | selected = X[torch.randint(len(X), (1,))] 103 | 104 | initial_state[index] = selected.mean(dim=0) 105 | 106 | center_shift = torch.sum( 107 | torch.sqrt( 108 | torch.sum((initial_state - initial_state_pre) ** 2, dim=1) 109 | )) 110 | 111 | # increment iteration 112 | iteration = iteration + 1 113 | 114 | # update tqdm meter 115 | if tqdm_flag: 116 | tqdm_meter.set_postfix( 117 | iteration=f'{iteration}', 118 | center_shift=f'{center_shift ** 2:0.6f}', 119 | tol=f'{tol:0.6f}' 120 | ) 121 | tqdm_meter.update() 122 | if center_shift ** 2 < tol: 123 | break 124 | if iter_limit != 0 and iteration >= iter_limit: 125 
| break 126 | 127 | return choice_cluster.cpu(), initial_state.cpu() 128 | 129 | 130 | def kmeans_predict( 131 | X, 132 | cluster_centers, 133 | distance='euclidean', 134 | device=torch.device('cpu'), 135 | gamma_for_soft_dtw=0.001, 136 | tqdm_flag=True 137 | ): 138 | """ 139 | predict using cluster centers 140 | :param X: (torch.tensor) matrix 141 | :param cluster_centers: (torch.tensor) cluster centers 142 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 143 | :param device: (torch.device) device [default: 'cpu'] 144 | :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0 145 | :return: (torch.tensor) cluster ids 146 | """ 147 | if tqdm_flag: 148 | print(f'predicting on {device}..') 149 | 150 | if distance == 'euclidean': 151 | pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag) 152 | elif distance == 'cosine': 153 | pairwise_distance_function = partial(pairwise_cosine, device=device) 154 | #elif distance == 'soft_dtw': 155 | # sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw) 156 | # pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device) 157 | else: 158 | raise NotImplementedError 159 | 160 | # convert to float 161 | X = X.float() 162 | 163 | # transfer to device 164 | X = X.to(device) 165 | 166 | dis = pairwise_distance_function(X, cluster_centers) 167 | choice_cluster = torch.argmin(dis, dim=1) 168 | 169 | return choice_cluster.cpu() 170 | 171 | 172 | def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True): 173 | if tqdm_flag: 174 | print(f'device is :{device}') 175 | 176 | # transfer to device 177 | data1, data2 = data1.to(device), data2.to(device) 178 | 179 | # N*1*M 180 | A = data1.unsqueeze(dim=1) 181 | 182 | # 1*N*M 183 | B = data2.unsqueeze(dim=0) 184 | 185 | dis = (A - B) ** 2.0 186 | # return N*N matrix for pairwise distance 187 | dis = dis.sum(dim=-1).squeeze() 188 | return dis 189 | 190 
| 191 | def pairwise_cosine(data1, data2, device=torch.device('cpu')): 192 | # transfer to device 193 | data1, data2 = data1.to(device), data2.to(device) 194 | 195 | # N*1*M 196 | A = data1.unsqueeze(dim=1) 197 | 198 | # 1*N*M 199 | B = data2.unsqueeze(dim=0) 200 | 201 | # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5] 202 | A_normalized = A / A.norm(dim=-1, keepdim=True) 203 | B_normalized = B / B.norm(dim=-1, keepdim=True) 204 | 205 | cosine = A_normalized * B_normalized 206 | 207 | # return N*N matrix for pairwise distance 208 | cosine_dis = 1 - cosine.sum(dim=-1).squeeze() 209 | return cosine_dis 210 | -------------------------------------------------------------------------------- /moco/linear_classify.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from sample import ClassAwareSampler 22 | from loss import LDAMLoss 23 | from imagenet_lt_loader import ImageNetLT_moco, ImageNetLT_val 24 | from utils import* 25 | 26 | model_names = sorted(name for name in models.__dict__ 27 | if name.islower() and not name.startswith("__") 28 | and callable(models.__dict__[name])) 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 32 | choices=model_names, 33 | help='model architecture: ' + 34 | ' | '.join(model_names) + 35 | ' 
(default: resnet50)')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: 16)')
parser.add_argument('--epochs', default=40, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=2048, type=int,
                    metavar='N',
                    help='mini-batch size (default: 2048), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--train_rule', default='CB', type=str,
                    help='data sampling strategy for train loader: CB | CE | DRW (default: CB)')
parser.add_argument('--lr', '--learning-rate', default=10, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--schedule', default=[20, 30], nargs='*', type=int,
                    help='learning rate schedule (when to drop lr by a ratio)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=0, type=float,
                    metavar='W', help='weight decay (default: 0.)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. 
') 65 | parser.add_argument('--gpu', default=None, type=int, 66 | help='GPU id to use.') 67 | parser.add_argument('--multiprocessing-distributed', action='store_true', 68 | help='Use multi-processing distributed training to launch ' 69 | 'N processes per node, which has N GPUs. This is the ' 70 | 'fastest way to use PyTorch for either single node or ' 71 | 'multi node data parallel training') 72 | 73 | parser.add_argument('--pretrained', default='', type=str, 74 | help='path to moco pretrained checkpoint') 75 | best_acc1 = 0 76 | 77 | 78 | def main(): 79 | args = parser.parse_args() 80 | 81 | if args.seed is not None: 82 | random.seed(args.seed) 83 | torch.manual_seed(args.seed) 84 | cudnn.deterministic = True 85 | warnings.warn('You have chosen to seed training. ' 86 | 'This will turn on the CUDNN deterministic setting, ' 87 | 'which can slow down your training considerably! ' 88 | 'You may see unexpected behavior when restarting ' 89 | 'from checkpoints.') 90 | 91 | if args.gpu is not None: 92 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 93 | 'disable data parallelism.') 94 | 95 | global best_acc1 96 | # create model 97 | print("=> creating model '{}'".format(args.arch)) 98 | num_classes = 1000 99 | model = models.__dict__[args.arch](num_classes=num_classes) 100 | 101 | # freeze all layers but the last fc 102 | for name, param in model.named_parameters(): 103 | if name not in ['fc.weight', 'fc.bias']: 104 | param.requires_grad = False 105 | # init the fc layer 106 | if args.train_rule == 'DRW': 107 | model.fc=NormedLinear_Classifier() 108 | else: 109 | model.fc.weight.data.normal_(mean=0.0, std=0.01) 110 | model.fc.bias.data.zero_() 111 | 112 | # load from pre-trained, before DistributedDataParallel constructor 113 | if args.pretrained: 114 | if os.path.isfile(args.pretrained): 115 | print("=> loading checkpoint '{}'".format(args.pretrained)) 116 | checkpoint = torch.load(args.pretrained, map_location="cpu") 117 | 118 | # rename moco pre-trained keys 119 | state_dict = checkpoint['state_dict'] 120 | for k in list(state_dict.keys()): 121 | # retain only encoder_q up to before the embedding layer 122 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 123 | # remove prefix 124 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 125 | # delete renamed or unused k 126 | del state_dict[k] 127 | 128 | args.start_epoch = 0 129 | msg = model.load_state_dict(state_dict, strict=False) 130 | if args.train_rule == 'CB' or args.train_rule == 'CE': 131 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 132 | 133 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 134 | 135 | model = torch.nn.DataParallel(model).cuda() 136 | 137 | 138 | # optimize only the linear classifier 139 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 140 | if args.train_rule == 'CB' or args.train_rule == 'CE': 141 | assert len(parameters) == 2 # fc.weight, fc.bias 142 | optimizer = torch.optim.SGD(parameters, args.lr, 143 | 
momentum=args.momentum, 144 | weight_decay=args.weight_decay) 145 | 146 | # optionally resume from a checkpoint 147 | if args.resume: 148 | if os.path.isfile(args.resume): 149 | print("=> loading checkpoint '{}'".format(args.resume)) 150 | if args.gpu is None: 151 | checkpoint = torch.load(args.resume) 152 | else: 153 | # Map model to be loaded to specified single gpu. 154 | loc = 'cuda:{}'.format(args.gpu) 155 | checkpoint = torch.load(args.resume, map_location=loc) 156 | args.start_epoch = checkpoint['epoch'] 157 | best_acc1 = checkpoint['best_acc1'] 158 | if args.gpu is not None: 159 | # best_acc1 may be from a checkpoint from a different GPU 160 | best_acc1 = best_acc1.to(args.gpu) 161 | model.load_state_dict(checkpoint['state_dict']) 162 | optimizer.load_state_dict(checkpoint['optimizer']) 163 | print("=> loaded checkpoint '{}' (epoch {})" 164 | .format(args.resume, checkpoint['epoch'])) 165 | else: 166 | print("=> no checkpoint found at '{}'".format(args.resume)) 167 | 168 | cudnn.benchmark = True 169 | 170 | # Data loading code 171 | args.data = 'autodl-tmp/imagenet' 172 | txt_train = f'moco/ImageNet_LT_train.txt' 173 | txt_test = f'moco/ImageNet_LT_test.txt' 174 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 175 | train_transform = [ 176 | transforms.RandomResizedCrop(224), 177 | transforms.RandomHorizontalFlip(), 178 | transforms.ToTensor(), 179 | normalize, 180 | ] 181 | transform_train = [transforms.Compose(train_transform)] 182 | train_dataset = ImageNetLT_moco( 183 | root=args.data, 184 | txt=txt_train, 185 | transform=transform_train) 186 | args.cls_num_list = train_dataset.cls_num_list 187 | 188 | train_sampler = None 189 | 190 | if args.train_rule == 'DRW' or args.train_rule == 'CE' : 191 | train_loader = torch.utils.data.DataLoader( 192 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 193 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 194 | 
        print(len(train_loader))
    else:
        balance_sampler = ClassAwareSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True, sampler=balance_sampler, drop_last=True)
        print('CB', len(train_loader))

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    val_dataset = ImageNetLT_val(
        root=args.data,
        txt=txt_test,
        transform=[val_transform])

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=10, pin_memory=True)

    if args.evaluate:
        # criterion is otherwise only built inside the training loop,
        # so define one here before validating
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        if args.train_rule == 'CB' or args.train_rule == 'CE':
            per_cls_weights = None
            criterion = nn.CrossEntropyLoss().cuda(args.gpu)
        elif args.train_rule == 'DRW':
            # with the default 40 epochs idx stays 0, so betas[0] = 0 gives uniform weights
            idx = epoch // 60
            betas = [0, 0.9999]
            effective_num = 1.0 - np.power(betas[idx], args.cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(args.cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
            criterion = LDAMLoss(cls_num_list=args.cls_num_list, max_m=0.5, s=30, weight=per_cls_weights).cuda(args.gpu)
        else:
            # without a criterion the training step below cannot run
            raise ValueError('Sample rule is not listed')

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1, ece = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > 
best_acc1 249 | best_acc1 = max(acc1, best_acc1) 250 | if is_best: 251 | its_ece = ece 252 | output_best = 'Best Prec@1: %.3f\n' % (best_acc1) 253 | print(output_best) 254 | print(its_ece) 255 | 256 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 257 | and args.rank % ngpus_per_node == 0) : 258 | save_checkpoint({ 259 | 'epoch': epoch + 1, 260 | 'arch': args.arch, 261 | 'state_dict': model.state_dict(), 262 | 'best_acc1': best_acc1, 263 | 'optimizer' : optimizer.state_dict(), 264 | }, is_best=False,filename='Imagenet/liner_checkpoint.pth.tar') 265 | 266 | def train(train_loader, model, criterion, optimizer, epoch, args): 267 | batch_time = AverageMeter('Time', ':6.3f') 268 | data_time = AverageMeter('Data', ':6.3f') 269 | losses = AverageMeter('Loss', ':.4e') 270 | top1 = AverageMeter('Acc@1', ':6.2f') 271 | top5 = AverageMeter('Acc@5', ':6.2f') 272 | training_data_num = len(train_loader.dataset) 273 | epoch_steps = int(training_data_num / args.batch_size) 274 | 275 | progress = ProgressMeter( 276 | epoch_steps, 277 | [batch_time, data_time, losses, top1, top5], 278 | prefix="Epoch: [{}]".format(epoch)) 279 | 280 | model.train() 281 | 282 | end = time.time() 283 | for i, (images, target,index) in enumerate(train_loader): 284 | if i == epoch_steps: 285 | break 286 | # measure data loading time 287 | data_time.update(time.time() - end) 288 | if args.gpu is not None: 289 | images = images.cuda(args.gpu, non_blocking=True) 290 | target = target.cuda(args.gpu, non_blocking=True) 291 | 292 | # compute output 293 | output = model(images) 294 | loss = criterion(output, target) 295 | 296 | # measure accuracy and record loss 297 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 298 | losses.update(loss.item(), images.size(0)) 299 | top1.update(acc1[0], images.size(0)) 300 | top5.update(acc5[0], images.size(0)) 301 | 302 | # compute gradient and do SGD step 303 | optimizer.zero_grad() 304 | loss.backward() 305 | optimizer.step() 306 | 307 | # 
measure elapsed time 308 | batch_time.update(time.time() - end) 309 | end = time.time() 310 | 311 | if i % args.print_freq == 0: 312 | progress.display(i) 313 | 314 | 315 | def validate(val_loader, model, criterion, args): 316 | batch_time = AverageMeter('Time', ':6.3f') 317 | losses = AverageMeter('Loss', ':.4e') 318 | top1 = AverageMeter('Acc@1', ':6.2f') 319 | top5 = AverageMeter('Acc@5', ':6.2f') 320 | progress = ProgressMeter( 321 | len(val_loader), 322 | [batch_time, losses, top1, top5], 323 | prefix='Test: ') 324 | 325 | # switch to evaluate mode 326 | model.eval() 327 | num_classes =1000 328 | class_num = torch.zeros(num_classes).cuda() 329 | total_logits = torch.empty((0, num_classes)).cuda() 330 | total_labels = torch.empty(0, dtype=torch.long).cuda() 331 | 332 | 333 | with torch.no_grad(): 334 | end = time.time() 335 | for i, (images, target) in enumerate(val_loader): 336 | if args.gpu is not None: 337 | images = images.cuda(args.gpu, non_blocking=True) 338 | target = target.cuda(args.gpu, non_blocking=True) 339 | # compute output 340 | output = model(images) 341 | loss = criterion(output, target) 342 | 343 | # measure accuracy and record loss 344 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 345 | losses.update(loss.item(), images.size(0)) 346 | top1.update(acc1[0], images.size(0)) 347 | top5.update(acc5[0], images.size(0)) 348 | 349 | total_logits = torch.cat((total_logits, output)) 350 | total_labels = torch.cat((total_labels, target)) 351 | 352 | if i % args.print_freq == 0: 353 | progress.display(i) 354 | _, preds = F.softmax(total_logits.detach(), dim=1).max(dim=1) 355 | ece = shot_acc(preds[total_labels != -1], 356 | total_labels[total_labels != -1], 357 | args.cls_num_list) 358 | 359 | print('acc',top1.avg) 360 | return top1.avg, ece 361 | 362 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 363 | torch.save(state, filename) 364 | if is_best: 365 | shutil.copyfile(filename, 'model_best.pth.tar') 366 | 367 | 368 | def 
adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 at each milestone in args.schedule"""
    lr = args.lr
    epoch = epoch + 1
    for milestone in args.schedule:
        lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr



if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/moco/loss.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class LDAMLoss(nn.Module):

    def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
        super(LDAMLoss, self).__init__()
        # per-class margin proportional to n_j^(-1/4), rescaled so the largest margin is max_m
        m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
        m_list = m_list * (max_m / np.max(m_list))
        m_list = torch.cuda.FloatTensor(m_list)
        self.m_list = m_list
        assert s > 0
        self.s = s
        self.weight = weight

    def forward(self, x, target):
        # one-hot mask of the target class; bool dtype because recent
        # PyTorch versions expect a boolean condition in torch.where
        index = torch.zeros_like(x, dtype=torch.bool)
        index.scatter_(1, target.data.view(-1, 1), 1)

        index_float = index.type(torch.cuda.FloatTensor)
        batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1))
        batch_m = batch_m.view((-1, 1))
        x_m = x - batch_m

        # subtract the margin from the target logit only
        output = torch.where(index, x_m, x)
        return F.cross_entropy(self.s * output, target, weight=self.weight)

class KCL(nn.Module):
    def __init__(self, K, k=6, temperature=0.07):
        super(KCL, self).__init__()
        self.K = K
        self.k = k
        self.temperature = temperature

    def forward(self, logits, im_labels, queue_label):
        logits = logits / self.temperature
        im_labels = im_labels.contiguous().view(-1, 1)
        mask = torch.eq(im_labels, queue_label).float()
        mask_pos_view = torch.zeros_like(mask)
        # positive logits from queue
num_positive=self.k 44 | if num_positive > 0: 45 | for i in range(num_positive): 46 | all_pos_idxs = mask.view(-1).nonzero().view(-1) 47 | num_pos_per_anchor = mask.sum(1) 48 | num_pos_cum = num_pos_per_anchor.cumsum(0).roll(1) 49 | num_pos_cum[0] = 0 50 | rand = torch.rand(mask.size(0), device=mask.device) 51 | idxs = ((rand * num_pos_per_anchor).floor() + num_pos_cum).long() 52 | idxs = idxs[num_pos_per_anchor.nonzero().view(-1)] 53 | sampled_pos_idxs = all_pos_idxs[idxs.view(-1)] 54 | mask_pos_view.view(-1)[sampled_pos_idxs] = 1 55 | else: 56 | mask_pos_view = mask.clone() 57 | mask_pos_view_class = mask_pos_view.clone() 58 | #print(queue_label.size(1)) 59 | mask_pos_view_class[:, queue_label.size(1):] = 0 60 | 61 | mask_pos_view = torch.cat([torch.ones([mask_pos_view.shape[0], 1]).cuda(), mask_pos_view], dim=1) 62 | mask_pos_view_class = torch.cat([torch.ones([mask_pos_view_class.shape[0], 1]).cuda(), mask_pos_view_class], dim=1) 63 | log_prob = F.normalize(logits.exp(), dim=1, p=1).log() 64 | loss = - torch.sum((mask_pos_view_class * log_prob).sum(1) / mask_pos_view.sum(1)) / mask_pos_view.shape[0] 65 | return loss 66 | 67 | class SupConLoss_ccl(nn.Module): 68 | def __init__(self,K,gramma=0.2,temperature=0.07): 69 | super(SupConLoss_ccl, self).__init__() 70 | self.K = K 71 | self.gramma = gramma 72 | self.temperature = temperature 73 | def forward(self, logits, im_label,queue_label,im_cluster,queue_cluster): 74 | 75 | logits =logits /self.temperature 76 | log_prob = F.normalize(logits.exp(), dim=1, p=1).log() 77 | 78 | im_label = im_label.contiguous().view(-1, 1) 79 | label_mask = torch.eq(im_label, queue_label).float() 80 | label_mask = label_mask.clone() 81 | im_cluster = im_cluster.contiguous().view(-1, 1) 82 | cluster_mask = torch.eq(im_cluster, queue_cluster).float() 83 | cluster_mask = cluster_mask.clone() 84 | # compute logits 85 | label_mask = torch.cat([torch.ones([label_mask.shape[0], 1]).cuda(), label_mask], dim=1) 86 | cluster_mask = 
torch.cat([torch.ones([cluster_mask.shape[0], 1]).cuda(), cluster_mask], dim=1) 87 | loss_cluster = - torch.sum((cluster_mask * log_prob).sum(1) / cluster_mask.sum(1)) / cluster_mask.shape[0] 88 | loss_label= - torch.sum((label_mask * log_prob).sum(1) / label_mask.sum(1)) / label_mask.shape[0] 89 | # loss 90 | loss =loss_cluster + self.gramma*loss_label 91 | return loss 92 | 93 | class SupConLoss_rank(nn.Module): 94 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 95 | It also supports the unsupervised contrastive loss in SimCLR""" 96 | def __init__(self,K,temperature=0.07,ranking_temperature=0.12,grama=0.2): 97 | super(SupConLoss_rank, self).__init__() 98 | self.temperature = temperature 99 | self.ranking_temperature = ranking_temperature 100 | self.grama = grama 101 | self.K = K 102 | def forward(self, logits,im_label,queue_label,im_cluster,queue_cluster): 103 | 104 | im_label = im_label.contiguous().view(-1, 1) 105 | label_mask = torch.eq(im_label, queue_label).float() 106 | 107 | im_cluster = im_cluster.contiguous().view(-1, 1) 108 | cluster_mask = torch.eq(im_cluster, queue_cluster).float() 109 | cluster_mask_com = torch.cat([torch.ones([cluster_mask.shape[0], 1]).cuda(), cluster_mask], dim=1) 110 | logits_cluster = logits/self.temperature 111 | log_cluster_prob = F.normalize(logits_cluster.exp(), dim=1, p=1).log() 112 | loss_cluster = - torch.sum((cluster_mask_com * log_cluster_prob).sum(1) / cluster_mask_com.sum(1)) / cluster_mask_com.shape[0] 113 | 114 | Bool = ~cluster_mask.bool() 115 | Inverse_cluster =Bool.float() 116 | label_mask = Inverse_cluster * label_mask 117 | label_mask = torch.cat([torch.ones([label_mask.shape[0], 1]).cuda(), label_mask], dim=1) 118 | Inverse_cluster = torch.cat([torch.ones([Inverse_cluster.shape[0], 1]).cuda(), Inverse_cluster], dim=1) 119 | logits_label = logits/self.ranking_temperature 120 | log_inverse_cluster = torch.exp(logits_label) * Inverse_cluster 121 | log_label_prob = logits_label - 
torch.log(log_inverse_cluster.sum(1, keepdim=True))
        loss_label = - torch.sum((label_mask * log_label_prob).sum(1) / label_mask.sum(1)) / label_mask.shape[0]
        loss = loss_cluster + self.grama * loss_label

        return loss
--------------------------------------------------------------------------------
/moco/main.py:
--------------------------------------------------------------------------------
import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import moco
# the k-means module in this directory is moco/kmean_gpu.py
from kmean_gpu import kmeans
from imagenet_lt_loader import ImageNetLT_moco
from utils import*
from loss import*

model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet50)')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--epochs', default=400, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', 
'--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum of SGD solver')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('--cluster', default=20, type=int,
                    help='control cluster number')
parser.add_argument('--step', default=5, type=int,
                    help='step for updating cluster')
parser.add_argument('-p', '--print-freq', default=20, type=int,
                    metavar='N', help='print frequency (default: 20)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. 
This is the ' 79 | 'fastest way to use PyTorch for either single node or ' 80 | 'multi node data parallel training') 81 | 82 | # moco specific configs: 83 | parser.add_argument('--moco-dim', default=128, type=int, 84 | help='feature dimension (default: 128)') 85 | parser.add_argument('--moco-k', default=65536, type=int, 86 | help='queue size; number of negative keys (default: 65536)') 87 | parser.add_argument('--moco-m', default=0.999, type=float, 88 | help='moco momentum of updating key encoder (default: 0.999)') 89 | parser.add_argument('--moco-t', default=0.07, type=float, 90 | help='softmax temperature (default: 0.07)') 91 | 92 | # options for moco v2 (string flags: any non-empty value, even 'False', is truthy) 93 | parser.add_argument('--mlp', default='True', 94 | help='use mlp head (on by default; pass an empty string to disable)') 95 | parser.add_argument('--aug-plus', default='True', 96 | help='use moco v2 data augmentation (on by default; pass an empty string to disable)') 97 | parser.add_argument('--cos', default='True', 98 | help='use cosine lr schedule (on by default; pass an empty string to disable)') 99 | 100 | def main(): 101 | args = parser.parse_args() 102 | 103 | if args.seed is not None: 104 | random.seed(args.seed) 105 | torch.manual_seed(args.seed) 106 | cudnn.deterministic = True 107 | warnings.warn('You have chosen to seed training. ' 108 | 'This will turn on the CUDNN deterministic setting, ' 109 | 'which can slow down your training considerably! ' 110 | 'You may see unexpected behavior when restarting ' 111 | 'from checkpoints.') 112 | 113 | if args.gpu is not None: 114 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 115 | 'disable data parallelism.') 116 | 117 | if args.dist_url == "env://" and args.world_size == -1: 118 | args.world_size = int(os.environ["WORLD_SIZE"]) 119 | 120 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 121 | print('distributed',args.distributed) 122 | ngpus_per_node = torch.cuda.device_count() 123 | if args.multiprocessing_distributed: 124 | # Since we have ngpus_per_node processes per node, the total world_size 125 | # needs to be adjusted accordingly 126 | args.world_size = ngpus_per_node * args.world_size 127 | # Use torch.multiprocessing.spawn to launch distributed processes: the 128 | # main_worker process function 129 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 130 | else: 131 | # Simply call main_worker function 132 | main_worker(args.gpu, ngpus_per_node, args) 133 | 134 | 135 | def main_worker(gpu, ngpus_per_node, args): 136 | args.gpu = gpu 137 | 138 | # suppress printing if not master 139 | if args.multiprocessing_distributed and args.gpu != 0: 140 | def print_pass(*args): 141 | pass 142 | builtins.print = print_pass 143 | 144 | if args.gpu is not None: 145 | print("Use GPU: {} for training".format(args.gpu)) 146 | 147 | if args.distributed: 148 | if args.dist_url == "env://" and args.rank == -1: 149 | args.rank = int(os.environ["RANK"]) 150 | if args.multiprocessing_distributed: 151 | # For multiprocessing distributed training, rank needs to be the 152 | # global rank among all the processes 153 | args.rank = args.rank * ngpus_per_node + gpu 154 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 155 | world_size=args.world_size, rank=args.rank) 156 | 157 | args.data = 'autodl-tmp/imagenet' 158 | traindir = os.path.join(args.data, 'train') 159 | txt_train = f'moco/ImageNet_LT_train.txt' 160 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 161 | 162 | if args.aug_plus: 163 | # MoCo v2's aug: 
similar to SimCLR https://arxiv.org/abs/2002.05709 164 | augmentation = [ 165 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 166 | transforms.RandomApply([ 167 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 168 | ], p=0.8), 169 | transforms.RandomGrayscale(p=0.2), 170 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 171 | transforms.RandomHorizontalFlip(), 172 | transforms.ToTensor(), 173 | normalize 174 | ] 175 | else: 176 | # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978 177 | augmentation = [ 178 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 179 | transforms.RandomGrayscale(p=0.2), 180 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 181 | transforms.RandomHorizontalFlip(), 182 | transforms.ToTensor(), 183 | normalize 184 | ] 185 | transform_train = [transforms.Compose(augmentation), transforms.Compose(augmentation)] 186 | train_dataset = ImageNetLT_moco( 187 | root=args.data, 188 | txt=txt_train, 189 | transform=transform_train) 190 | args.num_class = 1000 191 | args.cls_num_list = train_dataset.cls_num_list 192 | cluster_number= [t//max(min(args.cls_num_list),args.cluster) for t in args.cls_num_list] 193 | for index, value in enumerate(cluster_number): 194 | if value==0: 195 | cluster_number[index]=1 196 | print(cluster_number) 197 | num_cluster = sum(cluster_number) 198 | print(num_cluster) 199 | # create model 200 | print("=> creating model '{}'".format(args.arch)) 201 | model = moco.MoCo( 202 | models.__dict__[args.arch], num_cluster, 203 | args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp) 204 | print(model) 205 | if args.distributed: 206 | if args.gpu is not None: 207 | torch.cuda.set_device(args.gpu) 208 | model.cuda(args.gpu) 209 | args.batch_size = int(args.batch_size / ngpus_per_node) 210 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 211 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 212 | else: 213 | model.cuda() 
214 | model = torch.nn.parallel.DistributedDataParallel(model) 215 | elif args.gpu is not None: 216 | torch.cuda.set_device(args.gpu) 217 | model = model.cuda(args.gpu) 218 | raise NotImplementedError("Only DistributedDataParallel is supported.") 219 | else: 220 | raise NotImplementedError("Only DistributedDataParallel is supported.") 221 | 222 | 223 | 224 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 225 | momentum=args.momentum, 226 | weight_decay=args.weight_decay) 227 | 228 | 229 | # optionally resume from a checkpoint 230 | if args.resume: 231 | if os.path.isfile(args.resume): 232 | print("=> loading checkpoint '{}'".format(args.resume)) 233 | if args.gpu is None: 234 | checkpoint = torch.load(args.resume) 235 | else: 236 | # Map model to be loaded to specified single gpu. 237 | loc = 'cuda:{}'.format(args.gpu) 238 | checkpoint = torch.load(args.resume, map_location=loc) 239 | args.start_epoch = checkpoint['epoch'] 240 | model.load_state_dict(checkpoint['state_dict']) 241 | optimizer.load_state_dict(checkpoint['optimizer']) 242 | print("=> loaded checkpoint '{}' (epoch {})" 243 | .format(args.resume, checkpoint['epoch'])) 244 | else: 245 | print("=> no checkpoint found at '{}'".format(args.resume)) 246 | 247 | cudnn.benchmark = True 248 | 249 | if args.distributed: 250 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 251 | else: 252 | train_sampler = None 253 | train_loader = torch.utils.data.DataLoader( 254 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 255 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 256 | 257 | train_loader_cluster = torch.utils.data.DataLoader( 258 | train_dataset, batch_size=args.batch_size*5, shuffle=False, 259 | num_workers=args.workers, pin_memory=True) 260 | pretrain_epochs = args.epochs // 2 261 | tsc_epochs = args.epochs - pretrain_epochs 262 | 263 | for epoch in range(args.start_epoch, pretrain_epochs): 264 | if 
args.distributed: 265 | train_sampler.set_epoch(epoch) 266 | adjust_learning_rate(optimizer, epoch, pretrain_epochs, args) 267 | criterion = KCL(K=args.moco_k,k=6,temperature=args.moco_t).cuda() 268 | train(train_loader, model, criterion, optimizer,epoch,args) 269 | 270 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 271 | and args.rank % ngpus_per_node == 0) : 272 | save_checkpoint({ 273 | 'epoch': epoch + 1, 274 | 'arch': args.arch, 275 | 'state_dict': model.state_dict(), 276 | 'optimizer': optimizer.state_dict(), 277 | }, is_best=False, filename='Imagenet/last.pth.tar') 278 | if (epoch + 1) % 50== 0: 279 | save_checkpoint({ 280 | 'epoch': epoch + 1, 281 | 'arch': args.arch, 282 | 'state_dict': model.state_dict(), 283 | 'optimizer': optimizer.state_dict(), 284 | }, is_best=False, filename='Imagenet/checkpoint_{:04d}.pth.tar'.format(epoch)) 285 | 286 | 287 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 288 | momentum=args.momentum, 289 | weight_decay=args.weight_decay) 290 | 291 | for epoch in range(args.start_epoch,tsc_epochs): 292 | if args.distributed: 293 | train_sampler.set_epoch(epoch) 294 | adjust_learning_rate(optimizer, epoch, tsc_epochs, args) 295 | criterion = SupConLoss_rank(K=args.moco_k,temperature=args.moco_t).cuda() 296 | if epoch % args.step == 0: 297 | targets=cluster(train_loader_cluster,model,cluster_number,args) 298 | train_dataset.new_labels = targets 299 | train(train_loader, model, criterion, optimizer,epoch+pretrain_epochs,args) 300 | 301 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 302 | and args.rank % ngpus_per_node == 0): 303 | save_checkpoint({ 304 | 'epoch': epoch + 1, 305 | 'arch': args.arch, 306 | 'state_dict': model.state_dict(), 307 | 'optimizer': optimizer.state_dict(), 308 | }, is_best=False, filename='Imagenet/last.pth.tar') 309 | if (epoch + 1) % 50== 0: 310 | save_checkpoint({ 311 | 'epoch': epoch + 1, 312 | 'arch': args.arch, 313 | 'state_dict': 
model.state_dict(), 314 | 'optimizer': optimizer.state_dict(), 315 | }, is_best=False, filename='Imagenet/cclcheckpoint_{:04d}.pth.tar'.format(epoch+pretrain_epochs)) 316 | def cluster (train_loader_cluster,model,cluster_number,args): 317 | model.eval() 318 | features_sum = [] 319 | print('cluster_process') 320 | for i, (images, target, index) in enumerate(train_loader_cluster): 321 | images = images[0].cuda(args.gpu, non_blocking=True) 322 | target = target.cuda(args.gpu, non_blocking=True) 323 | if i % 100 ==0: 324 | print(target) 325 | with torch.no_grad(): 326 | features = model(im_q=images) # key-encoder features (im_k is None) 327 | features = features.detach() 328 | features_sum.append(features) 329 | features= torch.cat(features_sum,dim=0) 330 | features = torch.split(features,args.cls_num_list, dim=0) # assumes samples are ordered by class 331 | target = [[] for i in range(len(cluster_number))] 332 | for i in range(len(cluster_number)): 333 | if cluster_number[i] >1: 334 | cluster_ids_x, cluster_centers = kmeans(X=features[i], num_clusters=cluster_number[i], distance='cosine', tol=1e-3, iter_limit=35, device=torch.device("cuda")) 335 | target[i]=cluster_ids_x 336 | else: 337 | target[i] = torch.zeros(1,features[i].size()[0], dtype=torch.int).squeeze(0) 338 | if i% 100 ==0: 339 | print(target[i]) 340 | cluster_number_sum=[sum(cluster_number[:i]) for i in range(len(cluster_number))] # per-class offsets 341 | for i ,k in enumerate(cluster_number_sum): 342 | target[i] = torch.add(target[i], k) # make subclass ids globally unique 343 | targets=torch.cat(target,dim=0) 344 | targets = targets.cpu().numpy().tolist() 345 | return targets 346 | 347 | 348 | def train(train_loader, model, criterion, optimizer, epoch, args): 349 | batch_time = AverageMeter('Time', ':6.3f') 350 | data_time = AverageMeter('Data', ':6.3f') 351 | losses = AverageMeter('Loss', ':.4e') 352 | progress = ProgressMeter( 353 | len(train_loader), 354 | [batch_time, data_time, losses], 355 | prefix="Epoch: [{}]".format(epoch)) 356 | 357 | # switch to train mode 358 | model.train() 359 | 360 | end = time.time() 361 | for i, 
(images,target,cluster_target) in enumerate(train_loader): 362 | # measure data loading time 363 | data_time.update(time.time() - end) 364 | 365 | if args.gpu is not None: 366 | images[0] = images[0].cuda(args.gpu, non_blocking=True) 367 | images[1] = images[1].cuda(args.gpu, non_blocking=True) 368 | target = target.cuda(args.gpu, non_blocking=True) 369 | cluster_target = cluster_target.cuda(args.gpu, non_blocking=True) 370 | 371 | 372 | # compute output 373 | logits,labels,true_labels = model(im_q=images[0], im_k=images[1], labels=cluster_target,true_labels=target) 374 | if epoch < args.epochs//2: # first half of training: KCL criterion 375 | loss = criterion(logits,target,true_labels) 376 | else: # second half: SupConLoss_rank on subclass labels 377 | loss = criterion(logits,target,true_labels,cluster_target,labels) 378 | losses.update(loss.item(), images[0].size(0)) 379 | 380 | # compute gradient and do SGD step 381 | optimizer.zero_grad() 382 | loss.backward() 383 | optimizer.step() 384 | 385 | # measure elapsed time 386 | batch_time.update(time.time() - end) 387 | end = time.time() 388 | if i % args.print_freq == 0: 389 | progress.display(i) 390 | 391 | 392 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 393 | torch.save(state, filename) 394 | if is_best: 395 | shutil.copyfile(filename, 'model_best.pth.tar') 396 | 397 | 398 | def adjust_learning_rate(optimizer, epoch, total_epochs, args): 399 | """Decay the learning rate based on schedule""" 400 | lr = args.lr 401 | if args.cos: # cosine lr schedule 402 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / total_epochs)) 403 | else: # stepwise lr schedule; no --schedule flag is defined, so default to no milestones 404 | for milestone in getattr(args, 'schedule', []): 405 | lr *= 0.1 if epoch >= milestone else 1. 
406 | for param_group in optimizer.param_groups: 407 | param_group['lr'] = lr 408 | return lr 409 | 410 | 411 | 412 | if __name__ == '__main__': 413 | main() 414 | -------------------------------------------------------------------------------- /moco/moco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | 8 | 9 | class MoCo(nn.Module): 10 | """ 11 | Build a MoCo model with: a query encoder, a key encoder, and a queue 12 | https://arxiv.org/abs/1911.05722 13 | """ 14 | def __init__(self, base_encoder,num_class=1000, dim=128, K=65536, m=0.999, T=0.07, mlp=False): 15 | """ 16 | dim: feature dimension (default: 128) 17 | K: queue size; number of negative keys (default: 65536) 18 | m: moco momentum of updating key encoder (default: 0.999) 19 | T: softmax temperature (default: 0.07) 20 | """ 21 | super(MoCo, self).__init__() 22 | 23 | self.K = K 24 | self.m = m 25 | self.T = T 26 | 27 | 28 | # create the encoders 29 | # num_classes is the output fc dimension 30 | self.encoder_q = base_encoder(num_classes=dim) 31 | self.encoder_k = base_encoder(num_classes=dim) 32 | 33 | if mlp: # hack: brute-force replacement 34 | dim_mlp = self.encoder_q.fc.weight.shape[1] 35 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 36 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 37 | 38 | 39 | 40 | 41 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 42 | param_k.data.copy_(param_q.data) # initialize 43 | param_k.requires_grad = False # not update by gradient 44 | 45 | 46 | # create the queue 47 | self.register_buffer("queue", torch.randn(dim,K)) 48 | self.queue = nn.functional.normalize(self.queue, dim=0) 49 | self.register_buffer("queue_l", -torch.ones(1, K).long()) 50 | self.register_buffer("queue_t", -torch.ones(1, K).long()) 51 | 
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 52 | 53 | @torch.no_grad() 54 | def _momentum_update_key_encoder(self): 55 | """ 56 | Momentum update of the key encoder 57 | """ 58 | 59 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 60 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 61 | 62 | @torch.no_grad() 63 | def _dequeue_and_enqueue(self, keys, labels,true_labels): 64 | # gather keys before updating queue 65 | keys = concat_all_gather(keys) 66 | labels = concat_all_gather(labels) 67 | true_labels = concat_all_gather(true_labels) 68 | 69 | batch_size = keys.shape[0] 70 | 71 | ptr = int(self.queue_ptr) 72 | assert self.K % batch_size == 0 73 | 74 | # replace the keys at ptr (dequeue and enqueue) 75 | self.queue[:, ptr:ptr + batch_size] = keys.T 76 | self.queue_l[:,ptr:ptr + batch_size] = labels.T 77 | self.queue_t[:,ptr:ptr + batch_size] = true_labels.T 78 | 79 | ptr = (ptr + batch_size) % self.K # move pointer 80 | 81 | self.queue_ptr[0] = ptr 82 | 83 | @torch.no_grad() 84 | def _batch_shuffle_ddp(self, x): 85 | """ 86 | Batch shuffle, for making use of BatchNorm. 87 | *** Only support DistributedDataParallel (DDP) model. 
*** 88 | """ 89 | # gather from all gpus 90 | batch_size_this = x.shape[0] 91 | x_gather = concat_all_gather(x) 92 | batch_size_all = x_gather.shape[0] 93 | 94 | num_gpus = batch_size_all // batch_size_this 95 | 96 | # random shuffle index 97 | idx_shuffle = torch.randperm(batch_size_all).cuda() 98 | 99 | # broadcast to all gpus 100 | torch.distributed.broadcast(idx_shuffle, src=0) 101 | 102 | # index for restoring 103 | idx_unshuffle = torch.argsort(idx_shuffle) 104 | 105 | # shuffled index for this gpu 106 | gpu_idx = torch.distributed.get_rank() 107 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 108 | 109 | return x_gather[idx_this], idx_unshuffle 110 | 111 | @torch.no_grad() 112 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 113 | """ 114 | Undo batch shuffle. 115 | *** Only supports DistributedDataParallel (DDP) model. *** 116 | """ 117 | # gather from all gpus 118 | batch_size_this = x.shape[0] 119 | x_gather = concat_all_gather(x) 120 | batch_size_all = x_gather.shape[0] 121 | 122 | num_gpus = batch_size_all // batch_size_this 123 | 124 | # restored index for this gpu 125 | gpu_idx = torch.distributed.get_rank() 126 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 127 | 128 | 129 | return x_gather[idx_this] 130 | 131 | def forward(self, im_q, im_k=None, labels=None,true_labels= None): 132 | """ 133 | Input: 134 | im_q: a batch of query images 135 | im_k: a batch of key images (if None, only normalized key-encoder features of im_q are returned) 136 | Output: 137 | logits, queue_label, queue_true_label 138 | """ 139 | # contrastive branch: query and key features 140 | if im_k is not None: 141 | q = self.encoder_q(im_q) # queries: NxC 142 | q = nn.functional.normalize(q, dim=1) 143 | with torch.no_grad(): # no gradient to keys 144 | self._momentum_update_key_encoder() # update the key encoder 145 | 146 | # shuffle for making use of BN 147 | im_k,idx_unshuffle = self._batch_shuffle_ddp(im_k) 148 | 149 | k = self.encoder_k(im_k) # keys: NxC 150 | k = nn.functional.normalize(k, dim=1) 151 | 152 | # undo shuffle 153 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 
154 | 155 | 156 | 157 | # compute logits 158 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 159 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 160 | logits = torch.cat([l_pos, l_neg], dim=1) 161 | queue_label=self.queue_l.clone().detach() 162 | queue_true_label = self.queue_t.clone().detach() 163 | self._dequeue_and_enqueue(k,labels,true_labels) 164 | 165 | return logits,queue_label,queue_true_label 166 | else: 167 | k = self.encoder_k(im_q) 168 | k = nn.functional.normalize(k, dim=1) 169 | return k 170 | 171 | 172 | 173 | 174 | 175 | # utils 176 | @torch.no_grad() 177 | def concat_all_gather(tensor): 178 | """ 179 | Performs all_gather operation on the provided tensors. 180 | *** Warning ***: torch.distributed.all_gather has no gradient. 181 | """ 182 | tensors_gather = [torch.ones_like(tensor) 183 | for _ in range(torch.distributed.get_world_size())] 184 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 185 | 186 | output = torch.cat(tensors_gather, dim=0) 187 | return output 188 | -------------------------------------------------------------------------------- /moco/sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | import torch 5 | 6 | class BalancedDatasetSampler(torch.utils.data.sampler.Sampler): 7 | 8 | def __init__(self, dataset, indices=None, num_samples=None): 9 | 10 | # if indices is not provided, 11 | # all elements in the dataset will be considered 12 | self.indices = list(range(len(dataset))) \ 13 | if indices is None else indices 14 | 15 | # if num_samples is not provided, 16 | # draw `len(indices)` samples in each iteration 17 | self.num_samples = len(self.indices) \ 18 | if num_samples is None else num_samples 19 | 20 | # distribution of classes in the dataset 21 | label_to_count = [0] * len(np.unique(dataset.targets)) 22 | for idx in self.indices: 23 | label = self._get_label(dataset, idx) 24 | 
label_to_count[label] += 1 25 | 26 | 27 | 28 | 29 | per_cls_weights = 1 / np.array(label_to_count) 30 | 31 | # weight for each sample 32 | weights = [per_cls_weights[self._get_label(dataset, idx)] 33 | for idx in self.indices] 34 | 35 | 36 | self.weights = torch.DoubleTensor(weights) 37 | 38 | def _get_label(self, dataset, idx): 39 | return dataset.targets[idx] 40 | 41 | def __iter__(self): 42 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist()) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | class EffectNumSampler(torch.utils.data.sampler.Sampler): 48 | 49 | def __init__(self, dataset, indices=None, num_samples=None): 50 | 51 | # if indices is not provided, 52 | # all elements in the dataset will be considered 53 | self.indices = list(range(len(dataset))) \ 54 | if indices is None else indices 55 | 56 | # if num_samples is not provided, 57 | # draw `len(indices)` samples in each iteration 58 | self.num_samples = len(self.indices) \ 59 | if num_samples is None else num_samples 60 | 61 | # distribution of classes in the dataset 62 | label_to_count = [0] * len(np.unique(dataset.targets)) 63 | for idx in self.indices: 64 | label = self._get_label(dataset, idx) 65 | label_to_count[label] += 1 66 | 67 | 68 | 69 | beta = 0.9999 70 | effective_num = 1.0 - np.power(beta, label_to_count) 71 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 72 | 73 | # weight for each sample 74 | weights = [per_cls_weights[self._get_label(dataset, idx)] 75 | for idx in self.indices] 76 | 77 | 78 | self.weights = torch.DoubleTensor(weights) 79 | 80 | def _get_label(self, dataset, idx): 81 | return dataset.targets[idx] 82 | 83 | def __iter__(self): 84 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist()) 85 | 86 | def __len__(self): 87 | return self.num_samples 88 | 89 | class RandomCycleIter: 90 | 91 | def __init__ (self, data, test_mode=False): 92 | self.data_list = list(data) 93 | 
self.length = len(self.data_list) 94 | self.i = self.length - 1 95 | self.test_mode = test_mode 96 | 97 | def __iter__ (self): 98 | return self 99 | 100 | def __next__ (self): 101 | self.i += 1 102 | 103 | if self.i == self.length: 104 | self.i = 0 105 | if not self.test_mode: 106 | random.shuffle(self.data_list) 107 | 108 | return self.data_list[self.i] 109 | 110 | def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1): 111 | 112 | i = 0 113 | j = 0 114 | while i < n: 115 | 116 | # yield next(data_iter_list[next(cls_iter)]) 117 | 118 | if j >= num_samples_cls: 119 | j = 0 120 | 121 | if j == 0: 122 | temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls)) 123 | yield temp_tuple[j] 124 | else: 125 | yield temp_tuple[j] 126 | 127 | i += 1 128 | j += 1 129 | 130 | class ClassAwareSampler(torch.utils.data.sampler.Sampler): 131 | def __init__(self, data_source, num_samples_cls=4,): 132 | # pdb.set_trace() 133 | num_classes = len(np.unique(data_source.labels)) 134 | self.class_iter = RandomCycleIter(range(num_classes)) 135 | cls_data_list = [list() for _ in range(num_classes)] 136 | for i, label in enumerate(data_source.labels): 137 | cls_data_list[label].append(i) 138 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list] 139 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list) 140 | self.num_samples_cls = num_samples_cls 141 | 142 | def __iter__ (self): 143 | return class_aware_sample_generator(self.class_iter, self.data_iter_list, 144 | self.num_samples, self.num_samples_cls) 145 | 146 | def __len__ (self): 147 | return self.num_samples 148 | 149 | def get_sampler(): 150 | return ClassAwareSampler 151 | -------------------------------------------------------------------------------- /moco/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import shutil 3 | import os 4 | import numpy as np 5 | import matplotlib 6 | 
matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | from sklearn.metrics import confusion_matrix 9 | from sklearn.utils.multiclass import unique_labels 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn import Parameter 14 | import random 15 | from PIL import ImageFilter 16 | 17 | 18 | class AverageMeter(object): 19 | 20 | def __init__(self, name, fmt=':f'): 21 | self.name = name 22 | self.fmt = fmt 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | def __str__(self): 38 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 39 | return fmtstr.format(**self.__dict__) 40 | 41 | 42 | def accuracy(output, target, topk=(1,)): 43 | 44 | with torch.no_grad(): 45 | maxk = max(topk) 46 | batch_size = target.size(0) 47 | 48 | _, pred = output.topk(maxk, 1, True, True) 49 | pred = pred.t() 50 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 51 | 52 | res = [] 53 | for k in topk: 54 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 55 | res.append(correct_k.mul_(100.0 / batch_size)) 56 | return res 57 | 58 | class AddGaussianNoise(object): 59 | 60 | def __init__(self, mean=0., std=1.): 61 | self.std = std 62 | self.mean = mean 63 | 64 | def __call__(self, tensor): 65 | return tensor + torch.randn(tensor.size()) * self.std + self.mean 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 69 | 70 | 71 | class TwoCropTransform: 72 | """Create two crops of the same image""" 73 | def __init__(self, transform): 74 | self.transform = transform 75 | 76 | def __call__(self, x): 77 | return [self.transform(x), self.transform(x)] 78 | 79 | class GaussianBlur(object): 80 | """Gaussian blur augmentation in 
SimCLR https://arxiv.org/abs/2002.05709""" 81 | 82 | def __init__(self, sigma=[.1, 2.]): 83 | self.sigma = sigma 84 | 85 | def __call__(self, x): 86 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 87 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 88 | return x 89 | 90 | 91 | class ProgressMeter(object): 92 | 93 | def __init__(self, num_batches, meters, prefix=""): 94 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 95 | self.meters = meters 96 | self.prefix = prefix 97 | 98 | def display(self, batch): 99 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 100 | entries += [str(meter) for meter in self.meters] 101 | print('\t'.join(entries)) 102 | 103 | def _get_batch_fmtstr(self, num_batches): 104 | num_digits = len(str(num_batches // 1)) 105 | fmt = '{:' + str(num_digits) + 'd}' 106 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 107 | 108 | 109 | 110 | def shot_acc (preds, labels, cls, many_shot_thr=100, low_shot_thr=20): 111 | if isinstance(preds, torch.Tensor): 112 | preds = preds.detach().cpu().numpy() 113 | labels = labels.detach().cpu().numpy() 114 | elif isinstance(preds, np.ndarray): 115 | pass 116 | else: 117 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 118 | train_class_count = cls 119 | test_class_count = [] 120 | class_correct = [] 121 | for l in np.unique(labels): 122 | test_class_count.append(len(labels[labels == l])) 123 | class_correct.append((preds[labels == l] == labels[labels == l]).sum()) 124 | 125 | many_shot = [] 126 | median_shot = [] 127 | low_shot = [] 128 | for i in range(len(train_class_count)): 129 | if train_class_count[i] > many_shot_thr: 130 | many_shot.append((class_correct[i] / test_class_count[i])) 131 | elif train_class_count[i] < low_shot_thr: 132 | low_shot.append((class_correct[i] / test_class_count[i])) 133 | else: 134 | median_shot.append((class_correct[i] / test_class_count[i])) 135 | 136 | if len(many_shot) == 0: 137 | many_shot.append(0) 138 | if 
len(median_shot) == 0: 139 | median_shot.append(0) 140 | if len(low_shot) == 0: 141 | low_shot.append(0) 142 | return [np.mean(many_shot), np.mean(median_shot), np.mean(low_shot)] 143 | 144 | class NormedLinear_Classifier(nn.Module): 145 | 146 | def __init__(self, num_classes=100, feat_dim=64): 147 | super(NormedLinear_Classifier, self).__init__() 148 | self.weight = Parameter(torch.Tensor(feat_dim, num_classes)) 149 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 150 | 151 | def forward(self, x, *args): 152 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 153 | return out 154 | 155 | 156 | def flatten(t): 157 | return t.reshape(t.shape[0], -1) 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /sbcl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JackHck/SBCL/3b9047a31d5c54a0ac14cde351ab557d2833611e/sbcl.jpg --------------------------------------------------------------------------------
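The subclass bookkeeping in `moco/main.py` (how many clusters each class receives in `main_worker`, and how `cluster()` turns per-class cluster ids into globally unique subclass labels) can be sketched in isolation. This is a minimal pure-Python sketch under stated assumptions: the helper names `subclass_counts` and `global_subclass_ids` are hypothetical and do not exist in the repository, the class sizes are made-up toy numbers, and the k-means step is replaced by hand-written local ids.

```python
from itertools import accumulate


def subclass_counts(cls_num_list, cluster):
    """Subclasses per class, mirroring the list comprehension in main_worker."""
    # Target subclass size: never smaller than the smallest class.
    size = max(min(cls_num_list), cluster)
    # Integer division can give 0 for small classes; they keep one subclass.
    return [max(n // size, 1) for n in cls_num_list]


def global_subclass_ids(local_ids_per_class, counts):
    """Shift each class's local cluster ids by the subclasses of earlier classes."""
    offsets = [0] + list(accumulate(counts))[:-1]
    return [cid + off
            for ids, off in zip(local_ids_per_class, offsets)
            for cid in ids]


# Toy long-tailed setup (hypothetical numbers): class sizes 100, 40, 5, --cluster 20.
counts = subclass_counts([100, 40, 5], cluster=20)
print(counts)  # [5, 2, 1]

# Stand-in for the per-class k-means assignments.
local = [[0, 4, 2], [1, 0], [0]]
print(global_subclass_ids(local, counts))  # [0, 4, 2, 6, 5, 7]
```

The offsets play the role of `cluster_number_sum` in `cluster()`: without them, cluster id 0 of one class would collide with cluster id 0 of every other class when the labels are used in the contrastive loss.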