├── poster └── poster.pdf ├── data └── part_2_kmeans.pth.tar ├── nets ├── .combclassifier.py.swp ├── __pycache__ │ └── combclassifier.cpython-37.pyc ├── parallel.py ├── l2norm.py ├── weight_predictor.py └── combclassifier.py ├── utils ├── __pycache__ │ ├── eval.cpython-37.pyc │ ├── misc.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ └── visualize.cpython-37.pyc ├── __init__.py ├── eval.py ├── early_stopping.py ├── visualize.py ├── logger.py └── misc.py ├── configs ├── __pycache__ │ └── config.cpython-37.pyc └── config.py ├── datasets ├── __pycache__ │ ├── cub200.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ └── cub200_intranoisy.cpython-37.pyc ├── __init__.py ├── cub200_intranoisy.py └── cub200.py ├── README.md ├── train_flat_classifier.py ├── train_combinatorial_classifiers.py └── partition_generation.ipynb /poster/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/poster/poster.pdf -------------------------------------------------------------------------------- /data/part_2_kmeans.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/data/part_2_kmeans.pth.tar -------------------------------------------------------------------------------- /nets/.combclassifier.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/nets/.combclassifier.py.swp -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/utils/__pycache__/eval.cpython-37.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /configs/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/configs/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cub200.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/datasets/__pycache__/cub200.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visualize.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/utils/__pycache__/visualize.cpython-37.pyc -------------------------------------------------------------------------------- 
/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/combclassifier.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/nets/__pycache__/combclassifier.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cub200_intranoisy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geehokim/Combinatorial-Inference/HEAD/datasets/__pycache__/cub200_intranoisy.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cub200 import CUB200 as CUB200 2 | from .cub200_intranoisy import CUB200 as IntraNoisyCUB200 3 | 4 | 5 | __all__ = ('CUB200', 'IntraNoisyCUB200', 6 | ) 7 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os 10 | import sys 11 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /nets/parallel.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Parellel(nn.Module): 8 | def 
class L2NormLayer(nn.Module):
    """Normalise features to unit L2 norm along dim 1, with an optional learnable scale.

    Args:
        scale: when true, the normalised output is multiplied by the learnable
            scalar ``s`` (initialised to 10).
        eps: small constant added to the squared norm before the square root so
            that all-zero feature vectors do not divide by zero.
    """

    def __init__(self, scale=True, eps=1e-10):
        super(L2NormLayer, self).__init__()
        self.scale = scale
        self.eps = eps
        # The parameter is created even when ``scale`` is false so that
        # checkpoints saved with either setting keep the same keys.
        self.s = nn.Parameter(torch.FloatTensor([10]))

    def forward(self, input):
        """Return ``input`` L2-normalised along dim 1 (times ``s`` if scaling is on)."""
        output = self.l2_norm(input)
        if self.scale:
            output = output * self.s
        return output

    def l2_norm(self, input):
        """Divide ``input`` by its L2 norm taken over dim 1.

        ``keepdim=True`` lets the norm broadcast against inputs of any rank
        >= 2 (e.g. N x C x H x W feature maps), not only N x C matrices.
        """
        norm = input.pow(2).sum(dim=1, keepdim=True).add(self.eps).sqrt()
        return input / norm
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: N x C tensor of class scores.
        target: N-element tensor of ground-truth class indices.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one scalar tensor per entry of ``topk``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) layout of
        # ``pred``, so ``view(-1)`` fails for k > 1; ``reshape`` copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def accuracy2(output, target, topk=(1,)):
    """Compute a tie-tolerant top-1 accuracy, in percent.

    A sample counts as correct when the score of its target class equals the
    maximum score of its row, so ties with the best class are accepted.  The
    same top-1 figure is reported once per entry of ``topk``.

    Args:
        output: N x C tensor of class scores.
        target: N-element tensor of ground-truth class indices.
        topk: iterable whose length decides how many entries are returned.

    Returns:
        A list of independent 1-element tensors, one per entry of ``topk``.
    """
    batch_size = target.size(0)

    hits = sum(
        1 for i in range(batch_size)
        if output[i, target[i]] == output[i].max()
    )
    correct = output.new_zeros(1) + hits

    # Scale out of place: an in-place ``mul_`` inside the loop would rescale
    # the same tensor once per k and make every list entry alias it.
    scale = 100.0 / batch_size
    return [correct * scale for _ in topk]
4 | 5 | [Combinatorial Inference against Label Noise](https://papers.nips.cc/paper/2019/hash/0cb929eae7a499e50248a3a78f7acfc7-Abstract.html) 6 | 7 | [Paul Hongsuck Seo](https://phseo.github.io/), [Geeho Kim](https://geehokim.github.io./), [Bohyung Han](https://cv.snu.ac.kr/index.php/bhhan/) 8 | 9 | ### Dependencies 10 | This repository is implemented based on [PyTorch](http://pytorch.org/) with Anaconda.
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Args:
        mode: ``'min'`` if lower metric values are better, ``'max'`` if higher
            values are better.
        min_delta: margin by which a value must beat the best one seen so far
            to count as an improvement.
        patience: number of consecutive non-improving calls that are tolerated;
            stopping is signalled once this number is exceeded.  ``0`` disables
            early stopping entirely (``step`` always returns ``False``).

    Raises:
        ValueError: if ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, mode='min', min_delta=0, patience=10):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self._init_is_better(mode, min_delta)

    def is_better(self, a, best):
        """Return whether ``a`` improves on ``best`` by more than ``min_delta``.

        Implemented as a method rather than a lambda stored on the instance so
        that the object can be pickled.
        """
        if self.patience == 0:
            return True
        if self.mode == 'min':
            return a < best - self.min_delta
        return a > best + self.min_delta

    def step(self, metrics):
        """Record one metric value and return ``True`` when training should stop."""
        if self.patience == 0:
            return False

        # Check for NaN before anything else: a NaN stored as ``best`` can never
        # be beaten, because every comparison against NaN is false.
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs > self.patience

    def _init_is_better(self, mode, min_delta):
        """Validate ``mode``; the comparison itself lives in ``is_better``."""
        if mode not in ('min', 'max'):
            raise ValueError('mode {} is unknown!'.format(mode))
class WeightPredictor(nn.Module):
    """Combine per-partitioning class scores with learned, input-independent weights.

    A learnable vector ``h`` is mapped by a linear layer to one sigmoid weight
    per (partitioning, class) pair.  The weighted scores are summed over the
    partitioning axis and re-normalised into log-probabilities over classes.

    Args:
        num_classes: size of the class axis of the input.
        num_partitions: size of the axis that is weighted and summed (dim 1).
        hidden_dim: length of the learnable vector ``h``.
    """

    def __init__(self, num_classes, num_partitions, hidden_dim=1024):
        super(WeightPredictor, self).__init__()
        self.h = nn.Parameter(torch.zeros(1, hidden_dim))
        self.predictor = nn.Linear(hidden_dim, num_classes*num_partitions)

        self.num_classes = num_classes
        self.num_partitions = num_partitions

    def forward(self, input):
        """Return log-probabilities for an input broadcastable to N x num_partitions x num_classes."""
        normalized_weight = torch.sigmoid(
            self.predictor(self.h).view(1, self.num_partitions, self.num_classes))

        output = (input * normalized_weight).sum(1)
        # logsumexp stays finite for strongly negative scores, where
        # log(sum(exp(x)) + eps) underflows and leaves the output unnormalised.
        return output - torch.logsumexp(output, dim=1, keepdim=True)


class WeightPredictorFromFeature(nn.Module):
    """Combine per-partitioning class scores with weights predicted from a feature.

    Same weighting scheme as ``WeightPredictor``, except that the weights are
    computed per sample from a feature vector instead of a shared parameter.

    Args:
        num_classes: size of the class axis of the scores.
        num_partitions: size of the axis that is weighted and summed (dim 1).
        hidden_dim: dimensionality of the feature fed to the linear layer.
    """

    def __init__(self, num_classes, num_partitions, hidden_dim=2048):
        super(WeightPredictorFromFeature, self).__init__()
        self.predictor = nn.Linear(hidden_dim, num_classes*num_partitions)

        self.num_classes = num_classes
        self.num_partitions = num_partitions

    def forward(self, input):
        """Return log-probabilities given ``input = (feature, scores)``."""
        feature = input[0]
        probs = input[1]

        normalized_weight = torch.sigmoid(
            self.predictor(feature).view(-1, self.num_partitions, self.num_classes))

        output = (probs * normalized_weight).sum(1)
        # See WeightPredictor.forward for why logsumexp is used here.
        return output - torch.logsumexp(output, dim=1, keepdim=True)
# functions to show an image
def make_image(img, mean=(0, 0, 0), std=(1, 1, 1)):
    """Un-normalise a C x H x W image tensor and return it as an H x W x C numpy array.

    The first three channels are mapped through ``x * std + mean``.  The input
    tensor is cloned first so the caller's data is left untouched.
    """
    img = img.clone()
    for i in range(0, 3):
        img[i] = img[i] * std[i] + mean[i]     # unnormalize
    npimg = img.numpy()
    return np.transpose(npimg, (1, 2, 0))


def gauss(x, a, b, c):
    """Evaluate the Gaussian bump ``a * exp(-(x - b)^2 / (2 c^2))`` element-wise."""
    return torch.exp(-torch.pow(torch.add(x, -b), 2).div(2*c*c)).mul(a)


def colorize(x):
    """Convert a one-channel grayscale image to a color heatmap image.

    Accepts an H x W map, a 1 x H x W map, or an N x 1 x H x W batch and returns
    a tensor in which the singleton channel axis is replaced by three colour
    channels.  Values are clamped to at most 1.

    Raises:
        ValueError: if ``x`` has an unsupported rank or more than one channel.
    """
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if x.dim() not in (3, 4):
        raise ValueError('colorize expects a 2-D, 3-D or 4-D tensor')

    channel_dim = x.dim() - 3
    if x.size(channel_dim) != 1:
        raise ValueError('colorize expects a single-channel input')

    channels = (
        gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3),
        gauss(x, 1, .5, .3),
        gauss(x, 1, .2, .3),
    )

    shape = list(x.size())
    shape[channel_dim] = 3
    cl = torch.zeros(shape)
    for index, channel in enumerate(channels):
        cl.select(channel_dim, index).copy_(channel.squeeze(channel_dim))
    # The first channel is a sum of two bumps and can exceed 1.
    cl.clamp_(max=1)
    return cl


def show_batch(images, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Tile a batch of images into one grid, un-normalise it and display it."""
    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.imshow(images)
    plt.show()


def _unnormalize_batch(images, Mean, Std):
    """Return an un-normalised copy of an N x 3 x H x W batch."""
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * Std[i] + Mean[i]
    return im_data


def _blend_mask(im_data, mask):
    """Resize ``mask`` to the image size and blend it 70/30 over ``im_data``."""
    # Resizing to the exact spatial size of the images guarantees that
    # ``expand_as`` below succeeds.
    mask = nn.functional.interpolate(mask, size=im_data.shape[-2:])
    blended = 0.3 * im_data + 0.7 * mask.expand_as(im_data)
    return make_image(torchvision.utils.make_grid(blended))


def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot a batch of images and, below it, the same batch overlaid with ``mask``."""
    # save for adding mask
    im_data = _unnormalize_batch(images, Mean, Std)

    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    plt.subplot(2, 1, 2)
    plt.imshow(_blend_mask(im_data, mask))
    plt.axis('off')


def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot a batch of images followed by one overlay row per mask in ``masklist``."""
    # save for adding mask
    im_data = _unnormalize_batch(images, Mean, Std)

    rows = 1 + len(masklist)
    images = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(rows, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for i, entry in enumerate(masklist):
        mask = entry.data.cpu()
        plt.subplot(rows, 1, i + 2)
        plt.imshow(_blend_mask(im_data, mask))
        plt.axis('off')
class CombinatorialClassifier(nn.Module):
    """Classify by combining several coarse "meta-class" classifiers.

    One linear layer produces ``num_partitionings`` groups of
    ``num_partitions`` logits.  Each group is turned into log-probabilities
    with a softmax, and the score of a class is the (optionally weighted) sum,
    over all groups, of the log-probability of the partition that the class is
    assigned to in that group.

    Args:
        num_classes: number of original classes.
        num_partitionings: number of independent groupings of the classes.
        num_partitions: number of meta-classes in each grouping.
        feature_dim: dimensionality of the input feature.
        additive: stored on the module; not used by this class.
    """

    # Weight tensor passed to the most recent weighted forward call; read by
    # ``rescale_grad``.
    partition_weight = None

    def __init__(self, num_classes, num_partitionings, num_partitions, feature_dim, additive=False):
        super(CombinatorialClassifier, self).__init__()
        self.classifiers = nn.Linear(feature_dim, num_partitions * num_partitionings)
        self.num_classes = num_classes
        self.num_partitionings = num_partitionings
        self.num_partitions = num_partitions

        # A buffer is part of the module's persistent state (saved in the
        # state_dict, moved with the module) without being a trainable
        # parameter.  It is filled with -1 until ``set_partitionings`` is called.
        self.register_buffer('partitionings', -torch.ones(num_partitionings, num_classes).long())

        self.additive = additive

    def set_partitionings(self, partitionings_map):
        """Install the class-to-partition assignment.

        Args:
            partitionings_map: ``num_classes x num_partitionings`` table of
                partition indices in ``[0, num_partitions)``.
        """
        self.partitionings.copy_(torch.LongTensor(partitionings_map).t())
        arange = torch.arange(self.num_partitionings).view(-1, 1).type_as(self.partitionings)
        # Offset row p by p * num_partitions so every entry indexes the
        # flattened logits directly, e.g. rows 0 1 1 1 0 / 0 1 1 1 0 become
        # 0 1 1 1 0 / 2 3 3 3 2 when num_partitions is 2.
        self.partitionings.add_(arange * self.num_partitions)

    def rescale_grad(self):
        """Multiply the classifier gradients by the total partitioning weight.

        Uses ``num_partitionings`` when no weighted forward pass has been run.
        Parameters that have no gradient yet are skipped.
        """
        if self.partition_weight is None:
            factor = self.num_partitionings
        else:
            factor = self.partition_weight.sum()
        for params in self.classifiers.parameters():
            if params.grad is not None:
                params.grad.mul_(factor)

    def forward(self, input, weight=None, output_sum=True, return_meta_dist=False, with_feat=False):
        """Score the classes for a batch of features.

        Args:
            input: N x feature_dim features.
            weight: optional per-partitioning weights (``num_partitionings``
                values) applied before summation.
            output_sum: sum over partitionings when true; otherwise keep the
                N x num_partitionings x num_classes tensor.
            return_meta_dist: return the N x num_partitionings x num_partitions
                log-probabilities of the meta-classes instead of class scores.
            with_feat: additionally return ``input`` as the first element.

        Raises:
            AssertionError: if ``set_partitionings`` has not been called.
        """
        # Test for the -1 sentinel itself: a sum-based test would also reject a
        # legitimate assignment whose indices happen to sum to zero.
        assert bool((self.partitionings >= 0).all()), 'Partitionings is never given to the module.'

        all_output = self.classifiers(input)
        all_output = all_output.view(-1, self.num_partitionings, self.num_partitions)
        all_output = F.log_softmax(all_output, dim=2)
        if return_meta_dist:
            return all_output
        all_output = all_output.view(-1, self.num_partitionings * self.num_partitions)
        output = all_output.index_select(1, self.partitionings.view(-1))
        output = output.view(-1, self.num_partitionings, self.num_classes)

        if weight is not None:
            weight = weight.view(1, -1, 1)
            output = output * weight
            # Detach: the stored copy only rescales gradients and must neither
            # keep the autograd graph alive nor attach a graph to ``.grad``.
            self.partition_weight = weight.detach()

        if output_sum:
            output = output.sum(1)

        if with_feat:
            return input, output
        return output
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to ``fname`` (150 dpi unless ``dpi`` is given)."""
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)


def plot_overlap(logger, names=None):
    """Plot the selected series of ``logger`` on the current axes.

    Args:
        logger: object exposing ``names``, ``numbers`` and ``title``.
        names: series to draw; defaults to all of the logger's series.

    Returns:
        The legend labels, one ``title(name)`` string per plotted series.
    """
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]


class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: a header row with the series names
    followed by one row of numbers per ``append`` call.

    Args:
        fpath: path of the log file, or ``None`` to keep the values in memory
            only.
        title: prefix used in plot legends.
        resume: when true, load the existing file at ``fpath`` and keep
            appending to it; otherwise the file is truncated.
    '''
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                self._load(fpath)
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def _load(self, fpath):
        '''Read the header and all numeric rows of an existing log file.'''
        # ``with`` releases the read handle even if a row fails to parse.
        with open(fpath, 'r') as log_file:
            header = log_file.readline()
            self.names = header.rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}

            for row in log_file:
                row = row.rstrip()
                if not row:
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                for index, field in enumerate(row.split('\t')):
                    self.numbers[self.names[index]].append(float(field))

    def _write_row(self, fields):
        '''Write one tab-terminated row and flush; no-op without a file.'''
        if self.file is None:
            return
        self.file.write(''.join(field + '\t' for field in fields))
        self.file.write('\n')
        self.file.flush()

    def set_names(self, names):
        '''Start a fresh set of series and write the header row.

        NOTE(review): this also runs when the logger was opened with
        ``resume=True``, in which case the loaded history is discarded and a
        second header row is appended to the file -- confirm callers never
        call it after resuming.
        '''
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.numbers[name] = []
        self._write_row(self.names)

    def append(self, numbers):
        '''Record one value per series and write them as a new row.'''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        fields = []
        for index, num in enumerate(numbers):
            # Values below 1e-4 would lose their digits with six fixed decimals.
            if num < 1e-4:
                fields.append("{0:e}".format(num))
            else:
                fields.append("{0:.6f}".format(num))
            self.numbers[self.names[index]].append(num)
        self._write_row(fields)

    def plot(self, names=None):
        '''Draw the selected series in a new figure with legend and grid.'''
        plt.figure()
        plt.legend(plot_overlap(self, names))
        plt.grid(True)

    def close(self):
        '''Close the underlying file, if any.'''
        if self.file is not None:
            self.file.close()


class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''
    def __init__(self, paths):
        '''paths is a dictionary of {name: filepath} pairs'''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # The monitor only reads the loaded numbers; release the append
            # handle that resume mode opens so it is not leaked.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Overlay the selected series of every loaded log in one subplot.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text)
        plt.grid(True)
104 | plt.grid(True) 105 | 106 | if __name__ == '__main__': 107 | # # Example 108 | # logger = Logger('test.txt') 109 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 110 | 111 | # length = 100 112 | # t = np.arange(length) 113 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 114 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 115 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 116 | 117 | # for i in range(0, length): 118 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 119 | # logger.plot() 120 | 121 | # Example: logger monitor 122 | paths = { 123 | 'resadvnet20':'/home/paul/ssd_data/pytorch-classification/checkpoint/log.txt', 124 | } 125 | 126 | field = ['Valid Acc.'] 127 | 128 | monitor = LoggerMonitor(paths) 129 | monitor.plot(names=field) 130 | savefig('test.pdf') -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the mean and std value of dataset.

    The dataset is read one sample at a time.  For each of the first three
    channels the per-sample mean and per-sample standard deviation are
    averaged over all samples, so the returned ``std`` is the average of the
    per-image deviations rather than the deviation of the pooled pixels.

    Args:
        dataset: dataset yielding ``(image, target)`` pairs with C x H x W images.
        num_workers: number of loader worker processes (0 loads in-process).

    Returns:
        ``(mean, std)``, two 3-element tensors.
    '''
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Init layer parameters.

    Conv2d weights get Kaiming-normal (fan-out) initialisation, BatchNorm2d
    gets weight 1 / bias 0, Linear weights get a normal with std 1e-3, and all
    existing biases are zeroed.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Compare with None: the truth value of a multi-element tensor is
            # ambiguous and raises.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight and bias are None when the layer was built with affine=False
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


def mkdir_p(path):
    '''make dir if not exist'''
    try:
        os.makedirs(path)
    except OSError as exc:
        # An already existing directory is fine; anything else (including an
        # existing non-directory at ``path``) is a real error.
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise
class ConfusionMeter(object):
    """Maintains a confusion matrix for a given classification problem.

    The ConfusionMeter constructs a confusion matrix for a multi-class
    classification problem. It does not support multi-label, multi-class
    problems: for such problems, please use MultiLabelConfusionMeter.

    Args:
        k (int): number of classes in the classification problem
        normalized (boolean): Determines whether or not the confusion matrix
            is normalized or not
    """

    def __init__(self, k, normalized=False):
        super(ConfusionMeter, self).__init__()
        self.conf = np.zeros((k, k), dtype=np.int32)
        self.normalized = normalized
        self.k = k
        self.reset()

    def reset(self):
        """Set every count back to zero."""
        self.conf.fill(0)

    def update(self, predicted, target):
        """Accumulate one batch into the K x K confusion matrix.

        Args:
            predicted (tensor): Can be an N x K tensor of predicted scores
                obtained from the model for N examples and K classes, or an
                N-tensor of integer values between 0 and K-1.
            target (tensor): Can be an N-tensor of integer values between 0
                and K-1, or an N x K tensor where targets are provided as
                one-hot vectors.

        Raises:
            AssertionError: if shapes disagree or a value is out of range.
        """
        predicted = predicted.detach().cpu().numpy()
        target = target.detach().cpu().numpy()

        assert predicted.shape[0] == target.shape[0], \
            'number of targets and predicted outputs do not match'

        if predicted.shape[0] == 0:
            # Nothing to count; max()/min() below are undefined on empty arrays.
            return

        if np.ndim(predicted) != 1:
            assert predicted.shape[1] == self.k, \
                'number of predictions does not match size of confusion matrix'
            predicted = np.argmax(predicted, 1)
        else:
            assert (predicted.max() < self.k) and (predicted.min() >= 0), \
                'predicted values are not between 0 and k-1'

        onehot_target = np.ndim(target) != 1
        if onehot_target:
            assert target.shape[1] == self.k, \
                'Onehot target does not match size of confusion matrix'
            assert (target >= 0).all() and (target <= 1).all(), \
                'in one-hot encoding, target values should be 0 or 1'
            assert (target.sum(1) == 1).all(), \
                'multi-label setting is not supported'
            target = np.argmax(target, 1)
        else:
            # Validate the targets themselves (not the predictions again), so a
            # bad label is reported before it reaches the bincount below.
            assert (target.max() < self.k) and (target.min() >= 0), \
                'target values are not between 0 and k-1'

        # Encode each (target, predicted) pair as one flat index so a single
        # bincount tallies the whole matrix: row = target, column = predicted.
        x = predicted + self.k * target
        bincount_2d = np.bincount(x.astype(np.int64),
                                  minlength=self.k ** 2)
        assert bincount_2d.size == self.k ** 2
        conf = bincount_2d.reshape((self.k, self.k))

        self.conf += conf

    def value(self):
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets. When ``normalized`` is set, each row is divided by its
            sum (rows without samples stay zero).
        """
        if self.normalized:
            conf = self.conf.astype(np.float32)
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf
Available options are 2010 and 2011 48 | train (bool, optional): If true, returns the list pertaining to training images and labels, else otherwise 49 | Returns: 50 | return_list: list of 236_comb_fromZeroNoise-tuples with 1st location specifying path and 2nd location specifying the class 51 | """ 52 | if year == 2010: 53 | images_file_path = os.path.join(root, 'images/') 54 | 55 | if train: 56 | lists_path = os.path.join(root, 'lists/train.txt') 57 | else: 58 | lists_path = os.path.join(root, 'lists/test.txt') 59 | 60 | files = np.genfromtxt(lists_path, dtype=str) 61 | 62 | imgs = [] 63 | classes = [] 64 | class_to_idx = [] 65 | 66 | for fname in files: 67 | full_path = os.path.join(images_file_path, fname) 68 | imgs.append((full_path, int(fname[0:3]) - 1)) 69 | if os.path.split(fname)[0][4:] not in classes: 70 | classes.append(os.path.split(fname)[0][4:]) 71 | class_to_idx.append(int(fname[0:3]) - 1) 72 | 73 | return imgs, classes, class_to_idx 74 | 75 | elif year == 2011: 76 | images_file_path = os.path.join(root, 'CUB_200_2011/images/') 77 | 78 | all_images_list_path = os.path.join(root, 'CUB_200_2011/images.txt') 79 | all_images_list = np.genfromtxt(all_images_list_path, dtype=str) 80 | train_test_list_path = os.path.join(root, 'CUB_200_2011/train_test_split.txt') 81 | train_test_list = np.genfromtxt(train_test_list_path, dtype=int) 82 | 83 | imgs = [] 84 | classes = [] 85 | class_to_idx = [] 86 | 87 | for i in range(0, len(all_images_list)): 88 | fname = all_images_list[i, 1] 89 | full_path = os.path.join(images_file_path, fname) 90 | if train_test_list[i, 1] == 1 and train: 91 | imgs.append((full_path, int(fname[0:3]) - 1)) 92 | elif train_test_list[i, 1] == 0 and not train: 93 | imgs.append((full_path, int(fname[0:3]) - 1)) 94 | if os.path.split(fname)[0][4:] not in classes: 95 | classes.append(os.path.split(fname)[0][4:]) 96 | class_to_idx.append(int(fname[0:3]) - 1) 97 | 98 | if train: 99 | with temp_seed(0): 100 | if noise_type == 'pairflip': 101 | 
label_noises = np.random.multinomial(1, [0.55, 0.45], len(imgs)).argmax(1) 102 | else: 103 | noise_rate_per_class = noise_rate / (200-1) 104 | noise_dist = np.ones(200) * noise_rate_per_class 105 | noise_dist[0] = 1 - noise_rate 106 | label_noises = np.random.multinomial(1, noise_dist, len(imgs)).argmax(1) 107 | 108 | imgs = [(x, (y + label_noises[i]) % 200) for i, (x, y) in enumerate(imgs)] 109 | 110 | return imgs, classes, class_to_idx 111 | 112 | 113 | class CUB200(data.Dataset): 114 | """`CUB200 `_ Dataset. 115 | `CUB200 `_ Dataset. 116 | Args: 117 | root (string): Root directory of dataset the images and corresponding lists exist 118 | inside raw folder 119 | train (bool, optional): If True, creates dataset from ``training.pt``, 120 | otherwise from ``test.pt``. 121 | download (bool, optional): If true, downloads the dataset from the internet and 122 | puts it in root directory. If dataset is already downloaded, it is not 123 | downloaded again. 124 | transform (callable, optional): A function/transform that takes in an PIL image 125 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 126 | target_transform (callable, optional): A function/transform that takes in the 127 | target and transforms it. 128 | year (int): Year/version of the dataset. 
def __getitem__(self, index):
    """Load one sample.

    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
        Any further elements stored with the sample are appended unchanged.
    """
    path, target, *extras = self.imgs[index]
    img = self.loader(path)

    if self.transform is not None:
        img = self.transform(img)

    # target_transform operates on the label; it used to be applied to the
    # already transformed image, overwriting it and leaving the label as is.
    if self.target_transform is not None:
        target = self.target_transform(target)

    return (img, target, *extras)
def build_set(root, year, train, split_by, subsample, return_unsup=False, return_flags=False):
    """Collect (image path, label) pairs for CUB-200, optionally split into labeled/unlabeled parts.

    Args:
        root (string): Root directory of the extracted dataset.
        year (int): Year/version of the dataset. Available options are 2010 and 2011.
        train (bool): If true, return the training split, otherwise the test split.
        split_by (int): When greater than 1, only samples whose label is below
            ``len(classes) // split_by`` are flagged as labeled.
        subsample (int): When greater than 1, only every ``subsample``-th
            sample (by position) stays flagged as labeled.
        return_unsup (bool): If true, return the samples that are NOT flagged
            as labeled instead of the labeled ones.
        return_flags (bool): If true, keep all samples and return
            ``(path, label, flag)`` triples instead of filtering.

    Returns:
        tuple: ``(imgs, classes, class_to_idx)`` where ``imgs`` is the list of
        samples, ``classes`` lists the class names in order of first appearance
        and ``class_to_idx`` holds the matching zero-based labels. Nothing is
        returned (``None``) for an unsupported ``year``.

    Note:
        ``split_by``, ``subsample``, ``return_unsup`` and ``return_flags`` only
        affect the 2011 branch.
    """
    if year == 2010:
        images_file_path = os.path.join(root, 'images/')

        if train:
            lists_path = os.path.join(root, 'lists/train.txt')
        else:
            lists_path = os.path.join(root, 'lists/test.txt')

        files = np.genfromtxt(lists_path, dtype=str)

        imgs = []
        classes = []
        class_to_idx = []

        for fname in files:
            # File names start with a three-digit, one-based class id.
            label = int(fname[0:3]) - 1
            imgs.append((os.path.join(images_file_path, fname), label))
            class_name = os.path.split(fname)[0][4:]
            if class_name not in classes:
                classes.append(class_name)
                class_to_idx.append(label)

        return imgs, classes, class_to_idx

    elif year == 2011:
        images_file_path = os.path.join(root, 'CUB_200_2011/images/')

        all_images_list_path = os.path.join(root, 'CUB_200_2011/images.txt')
        all_images_list = np.genfromtxt(all_images_list_path, dtype=str)
        train_test_list_path = os.path.join(root, 'CUB_200_2011/train_test_split.txt')
        train_test_list = np.genfromtxt(train_test_list_path, dtype=int)

        imgs = []
        classes = []
        class_to_idx = []

        for i in range(len(all_images_list)):
            fname = all_images_list[i, 1]
            label = int(fname[0:3]) - 1
            is_train_image = train_test_list[i, 1] == 1
            is_test_image = train_test_list[i, 1] == 0
            if (is_train_image and train) or (is_test_image and not train):
                imgs.append((os.path.join(images_file_path, fname), label))
            # Class names are collected from every listed image, whatever its split.
            class_name = os.path.split(fname)[0][4:]
            if class_name not in classes:
                classes.append(class_name)
                class_to_idx.append(label)

        if split_by > 1 or subsample > 1:
            n_classes = len(classes) // split_by if split_by > 1 else None

            # Flag EVERY sample. Previously, with split_by == 1 only the
            # subsampled-out entries received a flag, leaving a mix of 2- and
            # 3-tuples that made the filtering below raise IndexError.
            flagged = []
            for i, (path, label) in enumerate(imgs):
                labeled = True
                if split_by > 1 and label >= n_classes:
                    labeled = False
                if subsample > 1 and i % subsample != 0:
                    labeled = False
                flagged.append((path, label, labeled))

            if return_flags:
                imgs = flagged
            elif return_unsup:
                imgs = [(path, label) for path, label, labeled in flagged if not labeled]
            else:
                imgs = [(path, label) for path, label, labeled in flagged if labeled]

        return imgs, classes, class_to_idx
def __getitem__(self, index):
    """Load one sample.

    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
        When the sample was stored with an extra flag element, that element
        is appended unchanged after the target.
    """
    path, target, *extras = self.imgs[index]
    img = self.loader(path)

    if self.transform is not None:
        img = self.transform(img)

    # target_transform operates on the label; it used to be applied to the
    # already transformed image, overwriting it and leaving the label as is.
    if self.target_transform is not None:
        target = self.target_transform(target)

    return (img, target, *extras)
| import time 5 | import random 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim as optim 14 | import torch.utils.data as data 15 | import torchvision.transforms as transforms 16 | import torchvision.models as models 17 | import datasets 18 | 19 | from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig 20 | 21 | import configs.config as cf 22 | 23 | parser = argparse.ArgumentParser(description='PyTorch CIFAR-100 Training') 24 | # Datasets 25 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 26 | help='number of data loading workers (default: 4)') 27 | 28 | # Optimization options 29 | parser.add_argument('--epochs', default=40, type=int, metavar='N', 30 | help='number of total epochs to run') 31 | parser.add_argument('--train-batch', default=32, type=int, metavar='N', 32 | help='train batchsize') 33 | parser.add_argument('--test-batch', default=32, type=int, metavar='N', 34 | help='test batchsize') 35 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 36 | metavar='LR', help='initial learning rate') 37 | parser.add_argument('--schedule', type=int, nargs='+', default=[20, 30], 38 | help='Decrease learning rate at these epochs.') 39 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 40 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 41 | help='momentum') 42 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 43 | metavar='W', help='weight decay (default: 1e-4)') 44 | 45 | # Network architecture 46 | parser.add_argument('--net-type', default='resnet', type=str, help='model') 47 | 48 | # Experiment options 49 | parser.add_argument('--dataset', default='cub200', type=str) 50 | parser.add_argument('--method', default='intranoisyset/baseline', type=str) 51 | parser.add_argument('--seed', 
def main():
    """Train a flat classifier on noisy CUB-200, or only evaluate it with ``--evaluate``.

    All settings come from the module-level ``args``. Results (checkpoint,
    log and loss/accuracy plots) are written below a ``results/...`` directory
    whose name encodes dataset, method, initialisation, network, noise type,
    noise rate and seed.

    Raises:
        ValueError: If the dataset or network type is not supported.
    """
    exp_desc = 'pretrained' if args.pretrained else 'scratch'
    model_path = 'results/%s/%s/%s/%s/%s/noise_rate_%.2f/seed_%d' % (args.dataset, args.method, exp_desc, args.net_type,
                                                                     args.noise_type, args.noise_rate, args.seed)
    if not os.path.isdir(os.path.join(model_path, 'graphs')):
        mkdir_p(os.path.join(model_path, 'graphs'))

    # Data
    print('==> Preparing %s dataset' % args.dataset)
    transform_train = transforms.Compose([
        transforms.Resize(int(cf.imresize[args.net_type])),
        transforms.RandomRotation(10),
        transforms.RandomCrop(cf.imsize[args.net_type]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(cf.imresize[args.net_type]),
        transforms.CenterCrop(cf.imsize[args.net_type]),
        transforms.ToTensor(),
        transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
    ])

    if args.dataset == 'cub200':
        # The noise settings are only passed for the training set.
        trainset = datasets.IntraNoisyCUB200(root='data/'+args.dataset, year=2011, train=True, download=True,
                                             transform=transform_train, noise_type=args.noise_type,
                                             noise_rate=args.noise_rate)
        testset = datasets.IntraNoisyCUB200(root='data/'+args.dataset, year=2011, train=False, download=False,
                                            transform=transform_test)
    else:
        # Explicit error instead of ``assert False``, which is stripped under -O.
        raise ValueError('Unsupported dataset: %s' % args.dataset)

    # Build dataloader
    train_loader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True,
                                   num_workers=args.workers, pin_memory=use_cuda, drop_last=True)
    test_loader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False,
                                  num_workers=args.workers, pin_memory=use_cuda)

    # Construct the model: a torchvision backbone whose last layer is replaced
    # by a linear layer with one output per class.
    print("==> creating model %s" % args.net_type)
    if args.net_type == 'resnet':
        model = models.resnet50(pretrained=args.pretrained)
        model.fc = nn.Linear(model.fc.in_features, cf.num_classes[args.dataset])
    elif args.net_type == 'inception':
        model = models.inception_v3(pretrained=args.pretrained)
        model.aux_logits = False
        model.fc = nn.Linear(model.fc.in_features, cf.num_classes[args.dataset])
    elif args.net_type == 'densenet':
        model = models.densenet161(pretrained=args.pretrained)
        model.classifier = nn.Linear(model.classifier.in_features, cf.num_classes[args.dataset])
    elif args.net_type == 'vgg':
        model = models.vgg16(pretrained=args.pretrained)
        model.classifier[-1] = nn.Linear(model.classifier[-1].in_features,
                                         cf.num_classes[args.dataset])
    else:
        raise ValueError('Unsupported network type: %s' % args.net_type)

    # Only move the model to the GPU when one is available; train() and test()
    # guard their tensors with the same flag, so an unconditional .cuda() here
    # would crash on CPU-only machines.
    if use_cuda:
        model = model.cuda()
    print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()

    # Evaluation Only
    if args.evaluate:
        print('\nEvaluation only')
        checkpoint_path = os.path.join(model_path, 'checkpoint.pth.tar')
        assert os.path.isfile(checkpoint_path), 'Error: no checkpoint found!'
        # Remap GPU-saved tensors to the CPU when CUDA is unavailable.
        checkpoint = torch.load(checkpoint_path, map_location=None if use_cuda else 'cpu')
        model.load_state_dict(checkpoint['state_dict'])
        test_loss, test_acc = test(test_loader, model, criterion, use_cuda)
        print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc))

        return

    # Train model
    best_acc = 0
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Open logger
    logger_path = os.path.join(model_path, 'log.txt')
    logger = Logger(logger_path)
    logger.set_names(['Learning Rate', 'Train Loss', 'Test Loss', 'Train Acc.', 'Test Acc.'])

    # Train and test
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(train_loader, model, criterion, optimizer, use_cuda)
        test_loss, test_acc = test(test_loader, model, criterion, use_cuda)

        # Logging for current iteration
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

        # Save the model whenever the test accuracy improves.
        if test_acc > best_acc:
            best_acc = test_acc
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, dir=model_path, filename='checkpoint.pth.tar')

    # Final log row carries only the best test accuracy.
    logger.append([0, 0, 0, 0, best_acc])
    logger.close()
    print('Best acc:', best_acc)

    # Draw plots
    logger.plot(names=['Train Loss', 'Test Loss'])
    savefig(os.path.join(model_path, 'graphs/loss.png'))
    logger.plot(names=['Train Acc.', 'Test Acc.'])
    savefig(os.path.join(model_path, 'graphs/acc.png'))
def adjust_learning_rate(optimizer, epoch):
    """Decay the learning rate when ``epoch`` is one of the scheduled epochs.

    On a scheduled epoch the rate kept in the module-level ``state`` dict is
    multiplied by ``args.gamma`` and written to every parameter group of
    ``optimizer``. On any other epoch nothing is changed.
    """
    if epoch not in args.schedule:
        return

    # ``state`` is mutated in place, so no ``global`` declaration is needed.
    state['lr'] *= args.gamma
    decayed_lr = state['lr']
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
import argparse
import os
import time
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
import datasets

from nets.combclassifier import CombinatorialClassifier

from utils import Logger, AverageMeter, accuracy, mkdir_p, savefig
from progress.bar import Bar

import configs.config as cf

# Command-line interface of the training script. The options are registered
# at import time on this module-level parser.
parser = argparse.ArgumentParser(description='PyTorch CombLern Training')
# Datasets
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')

# Optimization options
parser.add_argument('--epochs', default=40, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--train-batch', default=32, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--test-batch', default=32, type=int, metavar='N',
                    help='test batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate')

# Classifier-specific optimization options (SGD)
parser.add_argument('--schedule', type=int, nargs='+', default=[20, 30],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
# The help text states the real default (5e-4); it previously claimed 1e-4.
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')

parser.add_argument('--num-partitionings', default=100, type=int, metavar='N',
                    help='number of partitionings')
parser.add_argument('--num-partitions', default=2, type=int, metavar='N', 53 | help='number of partitions') 54 | 55 | # Network architecture 56 | parser.add_argument('--net-type', default='resnet', type=str, help='model') 57 | parser.add_argument('--depth', default=50, type=int, help='depth of model') 58 | 59 | # Experiment options 60 | parser.add_argument('--dataset', default='cub200', type=str) 61 | parser.add_argument('--method', default='intranoisyset/comblearn', type=str) 62 | parser.add_argument('--seed', type=int, default=0, help='manual seed') 63 | parser.add_argument('--pretrained', action='store_true', help='use pretrained model') 64 | parser.add_argument('--exp-name', default='kmeans', type=str) 65 | parser.add_argument('--noise-rate', type=float, default=0.0, help='') 66 | parser.add_argument('--noise-type', type= str, default='symmetric', help='') 67 | 68 | # Miscs 69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 70 | help='evaluate model on validation set') 71 | parser.add_argument('--gpu-id', default='0', type=str, 72 | help='id(s) for CUDA_VISIBLE_DEVICES') 73 | 74 | args = parser.parse_args() 75 | 76 | assert args.exp_name in ['kmeans'] 77 | args.partitionings_path = './data/part_%d_%s.pth.tar' % \ 78 | ( args.num_partitions, 79 | args.exp_name) 80 | 81 | state = {k: v for k, v in args._get_kwargs()} 82 | 83 | # Validate dataset 84 | valid_datasets = ['cub200'] 85 | assert args.dataset in valid_datasets, 'Invalid dataset' 86 | assert args.noise_type in ['pairflip', 'symmetric'] 87 | 88 | # Use CUDA 89 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 90 | use_cuda = torch.cuda.is_available() 91 | 92 | np.random.seed(args.seed) 93 | random.seed(args.seed) 94 | torch.manual_seed(args.seed) 95 | if use_cuda: 96 | torch.cuda.manual_seed_all(args.seed) 97 | cudnn.benchmark = True 98 | 99 | def main(): 100 | exp_desc = 'pretrained' if args.pretrained else 'scratch' 101 | model_path = 
'results/%s/%s/%s/%s/%s/part_%d_%d/%s/noise_rate_%.2f/seed_%d' % \ 102 | (args.dataset, args.method, exp_desc, args.exp_name, args.net_type, args.num_partitionings, 103 | args.num_partitions, args.noise_type, args.noise_rate, args.seed) 104 | if not os.path.isdir(os.path.join(model_path, 'graphs')): 105 | mkdir_p(os.path.join(model_path, 'graphs')) 106 | 107 | # Data 108 | print('==> Preparing %s dataset' % args.dataset) 109 | transform_train = transforms.Compose([ 110 | transforms.Resize(int(cf.imresize[args.net_type])), 111 | transforms.RandomRotation(10), 112 | transforms.RandomCrop(cf.imsize[args.net_type]), 113 | transforms.RandomHorizontalFlip(), 114 | transforms.ToTensor(), 115 | transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]), 116 | ]) 117 | 118 | transform_test = transforms.Compose([ 119 | transforms.Resize(cf.imresize[args.net_type]), 120 | transforms.CenterCrop(cf.imsize[args.net_type]), 121 | transforms.ToTensor(), 122 | transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]), 123 | ]) 124 | 125 | if args.dataset == 'cub200': 126 | trainset = datasets.IntraNoisyCUB200(root='data/'+args.dataset, year=2011, train=True, download=True, 127 | transform=transform_train, noise_type=args.noise_type, 128 | noise_rate=args.noise_rate) 129 | testset = datasets.IntraNoisyCUB200(root='data/'+args.dataset, year=2011, train=False, download=False, 130 | transform=transform_test) 131 | else: 132 | assert False 133 | 134 | train_loader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, 135 | num_workers=args.workers, pin_memory=use_cuda, drop_last=True) 136 | test_loader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, 137 | num_workers=args.workers, pin_memory=use_cuda) 138 | 139 | # Construct the model 140 | print("==> creating model %s" % args.net_type) 141 | if args.net_type == 'resnet': 142 | _model = models.resnet50(pretrained=args.pretrained) 143 | feat_dim = _model.fc.in_features 144 | _model.fc 
= nn.Sequential() 145 | else: 146 | assert False 147 | 148 | print(args.partitionings_path) 149 | assert os.path.isfile(args.partitionings_path), 'Error: no partitionings found!' 150 | partitionings = torch.load(args.partitionings_path)[args.seed * args.num_partitionings 151 | :(args.seed+1)*args.num_partitionings].t() 152 | 153 | model = nn.Sequential() 154 | model.add_module('feature_extractor', _model) 155 | 156 | comb_classifier = CombinatorialClassifier(cf.num_classes[args.dataset], args.num_partitionings, 157 | args.num_partitions, feat_dim) 158 | comb_classifier.set_partitionings(partitionings) 159 | model.add_module('combinatorial_classifier', comb_classifier) 160 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 161 | 162 | if use_cuda: 163 | model.cuda() 164 | 165 | criterion = nn.NLLLoss() 166 | 167 | # Evaluation Only 168 | if args.evaluate: 169 | print('\nEvaluation only') 170 | checkpoint_path = os.path.join(model_path, 'checkpoint.pth.tar') 171 | assert os.path.isfile(checkpoint_path), 'Error: no checkpoint found!' 
def check_combinations(partitionings, num_cls):
    """Count classes whose partition signature is shared with another class.

    Each element of ``partitionings`` is one partitioning: an iterable of
    partitions, where ``partition[j]`` equals 1 when class ``j`` belongs to
    that partition. Entries may be tensor/NumPy scalars (read through
    ``.item()``) or plain Python numbers.

    Args:
        partitionings: iterable of partitionings as described above.
        num_cls: number of classes; class indices ``0 .. num_cls - 1`` are read.

    Returns:
        ``num_cls`` minus the number of distinct signatures, i.e. ``0`` when
        every class is identified uniquely by the partitions it falls into.
    """
    # One signature per class: for every partitioning, the tuple of partition
    # indices that contain the class. Nested tuples keep indices separated;
    # concatenating their decimal strings made (1, 12) and (11, 2) both read
    # "112" and reported false duplicates once there were 10+ partitions.
    signatures = [[] for _ in range(num_cls)]

    for partitioning in partitionings:
        memberships = [[] for _ in range(num_cls)]
        for part_idx, partition in enumerate(partitioning):
            for cls_idx in range(num_cls):
                entry = partition[cls_idx]
                value = entry.item() if hasattr(entry, 'item') else entry
                if value == 1:
                    memberships[cls_idx].append(part_idx)
        for cls_idx in range(num_cls):
            signatures[cls_idx].append(tuple(memberships[cls_idx]))

    return num_cls - len({tuple(signature) for signature in signatures})
def test(test_loader, model, criterion, use_cuda):
    """Evaluate ``model`` on every batch of ``test_loader``.

    The model is switched to eval mode, each batch is forwarded under
    ``torch.no_grad()``, and the criterion output is divided by
    ``args.num_partitionings`` before being recorded. Top-1 / top-5 precision
    comes from ``accuracy``; progress is shown on a ``Bar``.

    Args:
        test_loader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        model: network to evaluate.
        criterion: loss callable applied to ``(outputs, targets)``.
        use_cuda: when true, inputs and targets are moved with ``.cuda()``.

    Returns:
        ``(average loss, average top-1 precision)`` taken from the meters.
    """
    model.eval()

    batch_meter = AverageMeter()
    load_meter = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    tick = time.time()
    total_batches = len(test_loader)
    progress = Bar('Processing', max=total_batches)
    for step, (inputs, targets) in enumerate(test_loader, start=1):
        # time spent waiting on the data loader
        load_meter.update(time.time() - tick)

        if use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()

        # forward pass only; no autograd bookkeeping is needed here
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, targets) / args.num_partitionings

        batch_size = inputs.size(0)
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        loss_meter.update(loss.data.item(), batch_size)
        top1_meter.update(prec1.item(), batch_size)
        top5_meter.update(prec5.item(), batch_size)

        # wall-clock time for the whole batch
        batch_meter.update(time.time() - tick)
        tick = time.time()

        # refresh the progress line
        template = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'
        progress.suffix = template.format(
            batch=step,
            size=len(test_loader),
            data=load_meter.avg,
            bt=batch_meter.avg,
            total=progress.elapsed_td,
            eta=progress.eta_td,
            loss=loss_meter.avg,
            top1=top1_meter.avg,
            top5=top5_meter.avg,
        )
        progress.next()
    progress.finish()

    return loss_meter.avg, top1_meter.avg
def detach_partitionings(partitionings):
    """Return a nested-list copy of ``partitionings`` with each partition detached.

    Args:
        partitionings: iterable of partitionings, each an iterable of objects
            exposing ``.detach()`` (e.g. ``torch.Tensor``).

    Returns:
        A list of lists mirroring the input structure, where every element is
        the result of calling ``.detach()`` on the corresponding partition.
        The input objects themselves are not modified.
    """
    # The previous version built this structure but never returned it,
    # so every caller received None.
    return [
        [partition.detach() for partition in partitioning]
        for partitioning in partitionings
    ]
None" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 4, 43 | "metadata": { 44 | "pycharm": {} 45 | }, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Training set size: 3000\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "if dataset \u003d\u003d \u0027cub200\u0027:\n", 57 | " trainset\u003d NoisyCUB200(root\u003d\u0027../data/\u0027+dataset, year\u003d2011, train\u003dTrue, noise_rate\u003dnoise_rate)\n", 58 | "elif dataset \u003d\u003d \u0027sop\u0027:\n", 59 | " trainset \u003d datasets.StanfordOnlineProduct(root\u003d\u0027../data/\u0027+dataset, train\u003dTrue)\n", 60 | "print(\u0027Training set size:\u0027, len(trainset))" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": { 67 | "pycharm": {} 68 | }, 69 | "outputs": [], 70 | "source": [] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": { 76 | "pycharm": {} 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "if net_type \u003d\u003d \u0027resnet\u0027:\n", 81 | " model \u003d models.resnet50()\n", 82 | " model.fc \u003d nn.Linear(model.fc.in_features, n_cls)\n", 83 | "elif net_type \u003d\u003d \u0027densenet\u0027:\n", 84 | " model \u003d models.densenet161()\n", 85 | " model.classifier \u003d nn.Linear(model.classifier.in_features, n_cls // num_splits)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 6, 91 | "metadata": { 92 | "pycharm": {} 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "checkpoint \u003d torch.load(checkpoint_path)\n", 97 | "model.load_state_dict(checkpoint[\u0027state_dict\u0027])" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": { 103 | "pycharm": {} 104 | }, 105 | "source": [ 106 | "# partitioning generation" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 7, 112 | "metadata": { 113 | "pycharm": {} 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "# 
settings\n", 118 | "n_parts \u003d 2\n", 119 | "n_cls_per_part_threshold \u003d 10\n", 120 | "\n", 121 | "pool_size \u003d 1000\n", 122 | "\n", 123 | "save_path \u003d \u0027../data/%s_intranoisy_partitions/%s/%s/noise_rate_%.2f\u0027 % (dataset, net_type, noise_type, noise_rate)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 8, 129 | "metadata": { 130 | "pycharm": {} 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "seed\u003d0\n", 135 | "\n", 136 | "np.random.seed(seed)\n", 137 | "torch.manual_seed(seed)\n", 138 | "torch.cuda.manual_seed_all(seed)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 9, 144 | "metadata": { 145 | "pycharm": {} 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "if net_type \u003d\u003d \u0027resnet\u0027:\n", 150 | " cls_feats \u003d model.fc.weight.data.cpu().numpy()\n", 151 | "elif net_type \u003d\u003d \u0027densenet\u0027:\n", 152 | " cls_feats \u003d model.classifier.weight.data.cpu().numpy()" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 10, 158 | "metadata": { 159 | "pycharm": {} 160 | }, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJztnW2MnVd17//LzlwYQ5UBZQjJhKndKhgF0tjKKPeD1arxRXVaqsYJSoBKLVJRXaTmA1FkaVIqNRdaYZHm5l6hXlpzL4IPBWwVYtKGNhASNZJV1I5lp4khVkMSaMYWMcTTl3hIZsarH+ac5PjM8/7sl7Wf5/+TLNtnzpxnPy/nv9f+77XXFlUFIYSQ7rMpdgMIIYSEgYJPCCE9gYJPCCE9gYJPCCE9gYJPCCE9gYJPCCE9gYJPCCE9gYJPCCE9gYJPCCE94ZLYDRjlsssu061bt8ZuBiGEJMWxY8d+rKrTZe8zJfhbt27FwsJC7GYQQkhSiMgPqryPlg4hhPQECj4hhPQECj4hhPQECj4hhPQECj4hhPQEU1k6hNThyPFF3PvwKZxeWsaVU5PYv2c79u6cid0sQsxCwSdJcuT4Iu7+2pNYXlkDACwuLePurz0JABR9QnKgpUOS5N6HT70m9kOWV9Zw78OnIrWIEPtQ8EmSnF5arvU6IYSCTxLlyqnJzNc3ieDI8cVGn3nk+CJ2HXgU2+Yfwq4Djzb+HEKsQsEnSbJ/z3ZMTmze8PqaKu7+2pO1xXo4J7C4tAzF63MCFH3SJSj4JEn27pzBp269FptFNvysiZfPOQHSByj4JAg+7JK9O2dwQTXzZ3W9fM4JkD7gJC1TRD4P4NcBvKiq7xm8dg+A3wVwdvC2P1DVb7g4HkkLVymUWXn3V05NYjFDlPM8/jxcfQ4hlnEV4X8BwE0Zr9+vqjsGfyj2PcWFXZLnsd/4rukNXv7kxGbs37O9Vhuz5gSafA4hlnEi+Kr6OICXXHwW6R4u7JK8TuOxp8/iU7dei5mpSQiAmalJfOrWa2svvhrOCbT9HEIs43ul7R0i8tsAFgDcparnxt8gIvsA7AOA2dlZz80hMXBhl+R1DotLy9i7c8aJMLv6HEKs4nPS9rMAfh7ADgBnANyX9SZVPaiqc6o6Nz1dukMXSRAXdkle5yAAUycJqYg3wVfVH6nqmqpeAPA5ADf4OhaxjQu7ZP+e7diYgAkoYCJ1kou2SAp4s3RE5ApVPTP47y0AnvJ1LGKftnbJ3p0z+NihE5k/y7N7QlXTZCE3kgpOInwR+TKAfwCwXUReEJGPAPi0iDwpIv8M4EYAd7o4FukvMzm2TpbdE3LlLBdtkVRwlaXzIVW9QlUnVPUqVf3/qvpbqnqtqv6Cqv7GSLRPSCPqzAWEFOGiCWVf1g4tJNIE1sMnyTC0R6rYNCFXzuZlIQFwZu2M2lNTWybwnz9dxcqF9VXGtJBIVSj4xCR5/nvVuYCQK2f379l+kYc/ynBU0UaIx+cIzp1f8XIc0n1YS4e0xrW94MJ/D7lydpiFlEfbUUWWPeXjOKT7UPBJK3xMjrrw30OvnN27c6bWpHIdqgo56/6QMmjpkFYUiXNTcXXlv4deOZtl7bgYVRTNEbg8Duk+jPBJK3xMjuZFqtYjWF+jihvflb0CfXJiE+v+kFowwietmNoykTmJOLVlovFn+oqUQ+BjVPHY02czX3/rm96Ao/O7nR6LdBsKPmlFzv4jua9XoU76ZR/g5izEFRR80op/W94Y3Re9XhVWrnwdbs5CXEEPn7QiVb89BYbprotLyxsKx6VicRFbUPBJK7hTlB9G012B9aqgQ9HnJG19WIpiHVo6pBVW/fZQlTJ9kZXuqlgXe07U1oPVTF+Hgk9aY81vj/kFd9XRcKLWHT7WiqQKLR3SOWKVK3a56rjPcyOu7Rd2nq9DwSed4sjxxdxVqb6/4C47mr7Ojfgo1dHnznMcCj5phaXJsKFY5OH
7C+4ykgxdC8gKPkZnfe08s6CHTxpjbTKsqKpkiC+463x5a3MjIfBhv1hNLIgBBZ80Ji8au+vwEwDCi36RKISIjlMuCeGDJhPYvhaZ9bHzzIKWDmlMnsCuqXrbP7aIPFGYmZoM8mXvqw2TRVMvnvaLXxjhk8YUle2NkfZmIcJmJLlO01RI2i9+oeCTxhRt7QeET3ujWNihjRfPTtMfTgRfRD4P4NcBvKiq7xm89lYAhwBsBfA8gNtV9ZyL4xEbDL+Udx1+AmsZ5THHLZYQq18pFjZgwTebuPLwvwDgprHX5gF8W1WvBvDtwf9Jx9i7cwb33X5dqe/qI7+a2IVevE2cCL6qPg7gpbGXbwbwxcG/vwhgr4tjEXtUmayMtfqVxCHvmQBgZt1GH/Hp4V+uqmcAQFXPiMjbPB6LRKbMSuHy9v4x/ky0XbeRekE8C0RPyxSRfSKyICILZ89mb+VG0ofL20mbUR4tQTf4FPwficgVADD4+8WsN6nqQVWdU9W56enszZpJ+lj3dC2ViOgqbUZ5tATd4NPSeRDAhwEcGPz9dY/HIsapkzLpa+ie97nWSkRYwuW9yMvcUaz7+kWfTUvQDa7SMr8M4JcBXCYiLwD4I6wL/WER+QiAHwK4zcWxSLpUSZn0Jb5Fn8t66dm4vhdF6zbKPptpnm5wIviq+qGcH/0PF59PbOMyCvQlvkWfayV6tDYp6fpejI7yssS76LPbrKK2dl1jEn3SlqSN68k0X+Jb9LkWJpQtTkr6qlx5dH73hk3Zyz67aZ0ii9c1JhR80grXk2m+xLfocy1MKFuclMy7ZpdOTnj77KL7POwsnjvwPhyd311J7O86/ITz65ryBD8Fn7TCdRToS3yLPtdClUsrttIo+/dsx8SmjbH4y6+uthY5353sMLLPKvkBNL+uqY8YWDyNtMLHph+A+wJoZZ8buwaPxUnJvTtn8D//+iTOnV+56PWVNW09p+K70F3RZjhA8+ua+gQ/BZ+0wkdJYl/iG1vUi7BQ2jmLpTGxH+Ji5OHzfhS1r811tTgSqwMtHdIKC3ZIF7B6HS1MaDchr32bRVpd11SvxxDRHI8rBnNzc7qwsBC7GcQgTK2Lw3guPrAeIVvojIrw1W6r10NEjqnqXNn7aOl0gK6LIVfCxiPVTWVizQVZhxF+4liNOFyy68CjmROaM1OTODq/O0KL0qLrAQFhhN8bLGQN+BaU1CfKYsLRUTl96hAp+IkTWwxDCIrFlMVUsBAQWKbo+QXStW7yYJZO4sTOGgixQtTCSthUiR0QhKDNyte85/eeB08mvcAqDwp+4uzfsx0Tmy9eDTmxWYKJYZ5wLC4tO1t2bjVlMQViBwS+abvyNe/5XVpeMVfqwgW0dLrA+Lx7wHn4PLsFcGvvWF40ZRmrC7pc0dayKnp+s0h9ZMQIP2GGxaFWLlys8CsXNFgkkmW3jJIXFbUtQJVyAauQpDo6qnp/21pWeXbhW7ZkF4hLfWTECD9RfBWHqktZjfOstrjYzLrO7/cpCyOLrNGR5WtS5/62ndDPy6sH0MmRESP8RPFVHKoJw7K1MxX94rYTvXV+P/Xqhj6wfk3q3N+2E/p5HV/VkVFqI01G+IniqzhUG6r6xW2H4XV+v6rHGzrijRlhW0/VrHN/26x8LRtJlM0bpbjGgYKfKHlD2bbFodpQ9cvXdhhe5/eriEfoL25sobCeqln3+Wg6od+247PecWZBSydR8oay991+XdSHrcquRG2H4XV+v0paYujdpmLvblXlmsS0KkKtuwg50rQCBT9RUs2+ANq3vc7vVxGP0F/c2EJRdk1ie/yhnu22axRSXONASydhUs5Nb9v2qr9fxWYKXbohdqmIsmtiwaoI8Wy3XaOQ4hoH74IvIs8D+A8AawBWq1R0I8QlZeIR+otrQSiKrknsEUgIhpPmyytr2CyCNVXM1Jw8T7FUcqgI/0ZV/XGgYxFSi9BfXOtCEXsE4pvxSfM11Ys2tC/73fH7llKJbu/
18AcR/lwVwWc9fELq4SO9M/QeC6FTVJvur2B574mq9fBDTNoqgG+KyDER2RfgeIT0Al+TqyETAmJMEDe1rGJnV7kghKWzS1VPi8jbAHxLRJ5W1ceHPxx0AvsAYHZ2NkBzCOkGPidXQyUExJggbmpZdWFuw3uEr6qnB3+/COABADeM/fygqs6p6tz09LTv5hDSGcpKU6ew3D+GiDbN808xDXMcr4IvIm8SkZ8Z/hvArwB4yucxCekLeUIjgNk6OePEENGmllUXNuLxbelcDuABERke60uq+neej0lIL8hK7xRs3A5h3CKxVCkzVopqE8vKenZVFbwKvqo+C+A6n8cgpK9kCVBZierYdXzGKRNRS53TsL0pCfw43tMy68C0TELaUZZy2DQlMQaW0yCtUTUtk6UVyGtYi6ZIfcoskpQyTfIyeO46/ATuPHTC+zMa6vsQ8ntHwe8xow/apZMTePnVVaysrY/4Yg/1STPKLJKUVtHmdULDXd58PqOhrK/QFhsFv6eMP2hLyysb3mO9tndMLI+GinzmqpOkFs6vygbjvp7RUOsDQq9DoOD3lLItEoeUDfUtCENorE181qFKpomV88vqnLLwYUeFsr5CW2wU/J5S9YEqGupbEYYsfHZEFsoHt6Es08TK+Y13TpsGVS3H8WFHhbK+Qlts3AClp1R5oMryoa3WFvFdnyWlic8mWDq/0R3U7rv9umALn7IWWQ0XtLlcvRx6MRcF3wGp7VwPZD9oE5sEb9kyUXn1YWxhyLvuvjuiLiyxL8Lq+YUs6jZ6LODiBW0uA4jQO9fR0mlJW1sjlgfuYtVgzIyPouvuuyOysIGJTyyfX8iFT8NjZa1dcGlxhTwnCn5LyvzOIkGP7YGXPWhlnVFoYRhtT5afO7zuvjuiLiyxL6Lr51eX2CNZl1DwW1L0MJQJupXJsSyqdEYhhSFrl6IsTi8t4/4P7PDeEaW+xL6M4fkNO9k7D53AvQ+fcnZ/U8ruSmntQhkU/JYUPQxlgm45cqjaGcWsm57FlVOTnYxQYwikrxFo7JFtXXyMZGN1eJy0zaHqRGzRLHuZoFudHAPsDWOrHHf0Szia3XF0frdJIalKjF2hAH+T31azu/JwPbGadT8/dugEdn7im97vKSP8DOpEIEXR5L0PnyocClqeHLM2jM1rz2YRXFDtRBSfRyzrz1enby2YqILLkWzeaPXc+RXvIx0KfgZ1v2B5D0OZoFu2Hqx1Rnnt6UPlxFgCmdfJTm2Z8PK5Fka2ISi6b747cgp+Bq6+YFUE3erkn7XOyFp7QjD0efMKmPsWyP17tmP/Xz3xWkG9If/501UcOb7Y+NpbCyZCU1YjyGdHTsHPwGUEYlXQq2Ct7dba45OsWvCjhNoV6p4HT24orLdyQVtFoX3svEcpqxHksyOn4GfQ9wiExKcoK2kmoED+W0YVVaB9FNqnznuc4Xlndaa+dYaCn0HfI5C2pJRj7QIf55snqAIE3Zmq7367L8bXOYT6rlDwc+hzBNKG1HKs2+LrfK0ILUe7fgmtM8zDJ04py7FOsdBcEb5yykNXUcwjdHEv4hdG+MQpbUpNpEje+Q7L6DYdotexFX3bAhztdgfvgi8iNwH4PwA2A/h/qnrA9zFJPNqUmkiRohS7th1aFaHtYidK/OHV0hGRzQD+DMCvArgGwIdE5BqfxyTV8WGvtCk1kSJZ5zuK75IBqZUp8EXXrEJf+I7wbwDwjKo+CwAi8hUANwP4rufjJkfo2XpfkWGbUhNV2mwt+2f0fPMifZ8dWhc70bpwlFMd34I/A+BfR/7/AoD/7vmYyRHjgfVprzQtNVFE2TWK2RkUbZQB+M2ssZLNE5MuWoW+8J2lIxmvXbROW0T2iciCiCycPXvWc3NsEmNYHiMybJPxUXSNYlWTHCdGZo2VbJ6YcJRTHd8R/gsA3jHy/6sAnB59g6oeBHAQAObm5vLKhnQa3w9sVvQbKzJsmvFRdI2sRHgxFuxxkWD+KOfSyQn
sOvBob69LFr4F/58AXC0i2wAsAvgggN/0fMzk8Cm+eVbI+6+fwVePLSazoKboGlmK8GKkMMZMm7Qwr5JlFU5sErz86uprpQvo66/j1dJR1VUAdwB4GMD3ABxW1ZM+j5kiPofledHvY0+fTWpBTdE1sryRTB5dyCqxYqVlWYVvfuMlG6p89jF7aRzvefiq+g0A3/B9nJTxOSwvin5TWlBTdo1SWv7vY5I+RqRtxUoDNo5yts0/lPm+vvv6nV9pa2HIWQVf4tulLI68a5Saj+1aKGOlJVqy0sYJ+dynojFAxwWf+bn9KX6V0mjFtVDmdSB3HX4CgL9nPU9UFcDOT3wTquvllWOIYKjnPjWN6XTxNK5CZPEri7iec8jrKNZUvXrqRauMz51fwdLySjRvP9Rzn5rGdDrCtzzkHBJiODge/Q4nDFMYgnYR19FnUT0fn556lVXGIdqRh8tRX973NAWNGaXTgm/dv44xHMw65p2HTuBjh04E3Ulp2BYL3mfoduTNOQBo1BGXbZnnezHd3p0z2Db/UO7euyHa4ZOi76l1jRmn05aO9VWIMYaDWcccflFDDr2zUvruPHQCf3jkSe/HLmtHiGuwd+cMjs7vxnMH3vfaDlZN2zG0LzZL1sL2MOJT5RhWRbCMou9pHY2xkIrbacG37l+X1VL38UCURVmh/Me8jucvv/PDoF8EKx5s23bs3TmD+26/LlqAU1Y11FKgVZey1OYqGmNlzUKnLR3AdvaGz1rqTY45JMTQO+8YCgT1eqt6sL5tHxdecMz01PFjT22ZiJql45Iy26aKxuR16Pc8eJJ72vaFMu/Vx0RX2TGBcBZAjHLCVdsxeg1CzLW48oJjBjiWg6sq5HXqLibZ857ppeWVoOUfOm3pWGd0OJiHa/EbP+a46xvSAsh2nMN6vVU82BC2j/X5pq5TZLm4sIarPtO+7URG+BmEzNqIUUt9NBKLlSmzd+cMFn7wEv7yOz+8KLsjtMhVsUFCpN6ltlq4a5Stfm47eqkysh7ic4RLwR8j1sq5WCtiYw7D/3jvtZj72bdGF7myaxAq9S51SyRlfHfqWR36+VdXce78yob3+hzhUvDHiFUQqi8RXtaIYpiWaJW+lKfoMyE69awFkKGfq04IvktbIubKua5HeKnVHRnSl864z8To1GM8V8kLvmsRSW3lXEpYKqdbl5CdsZUVyH0iVqceOshLXvBdi0jXhu+WxCO1uiMxSHUU1AW6PsIGOiD4ZatVqwrdqDBObZnAGy7ZlPyiEWviwdFTOSmPgoh9ks/DzxMLASovYx7PwT13fgWvrF7A/R/YgaPzu5P9ouWJx8cOnYhSy4O55uVwFER8krzgZ4mIABsq9xUtaLBST6UJRQWZikSiyzXKY9K2QFaK+/OSdEje0smabKm7ZD/VqKrMsimrm5N6jXJrNLHQxudYbnzXNL56bLEzc0jEFslH+MDGUrN5pQrqRk/Wo6qykUlZBUPAfqfmGp8lauuOFLOW83/12CLef/1Mp0dBJB7JR/hZ1M20STUzp2xkUmVHIledmqVsoDx8T2LXHSnmdRCPPX3W/GI0kibeInwRuUdEFkXkxODPr/k61jh1veJUveUqI5Ph6Od/f2CHtwlTK7W+y/A9V5N3PzaJ1JpjSW3U1XbUZGFjkL7gO8K/X1X/1PMxMqnrFafoLdcZmfhcWJJKKqFvgc0rkLWm6ykEVedYrFuJo7QdNVlLHe46nbR0+kJdEffVqbkSUt+2kG+BHb8fm0ReE/shox1hqlbiKG07+1SCha7gW/DvEJHfBrAA4C5VPTf+BhHZB2AfAMzOznpuTvewMDJxIaQhIr0QAjt6P7bNP5T5nqw5FgtzH0063LadfVdsrVRoJfgi8giAt2f86OMAPgvgk1hPif8kgPsA/M74G1X1IICDADA3N1e28T0xiAshDRHphRbYKh2hhQ4baN7htu3su2BrpUQrwVfV91Z5n4h8DsDftDkWsYsLIQ0V6YUU2JQ
sm6YdbttzTOkadQFvlo6IXKGqZwb/vQXAU76O1ZYUUgqt01ZIuxjpWbNsimja4bY9x5SuURfw6eF/WkR2YN3SeR7A73k8VmOYJWCDrkZ6ViybMtp0uG3Psc3vM1irhzfBV9Xf8vXZLulzloClLwsjvbik1OEOn9vFpeWL6mYxWCun92mZfc0ScFH3ZSgGrkQ6lWi4i6TS4Y4/t3lFEq212wq9F/wuesdVqDuyyeog9v/VE4ACKxeyFxaRtEihw816bsfperDWhk4UT2tDX2u0u6j7srKmr4n9kFTKSpM4tC2jUEXMux6staH3Eb6VoWxoP73uyKZO1MQIi2ThIkGirOR3H4K1NvRe8IH4Q9kYmUJ1J+nKvmjj7yVkHBcJElnP7XDidrPIRSNM6/ZUDCj4kRiN6Mtqrvig7sgm64s2sVku8vCB/kVYljKdrOMiQSLruR3fNIZzSflQ8CMwHtGPi/0Q39ZInZFNXgeR9VrsL1koEeYajnq4SpAYf253HXi0t6nVdaHgR6BKpgHgxxppI4Z5HYSrL5ULoQ4pwn1dw9H0PvnK9e9ranUTep+lE4MqD6IPa8TyRiWu2hZyQ/o+Ck2b++Rro6FUtyiNASP8COQNbTeL4IKqNxvCckTqqm0hRTjUGg5L8wRt75OPBImUVgnHhoIfgbwHtE20U0UULEekeW1YXFrGtvmHKgtdVRF2IaIhhKauReW7cwj1DNU5j6L5pV0HHjXRUVqBgp+B7y+N69z/qqJgeVVxUdrnqHUAFHvxVUTYlc8fYg1HnYg6xPxFiGeoyXmMjxw4oZ4NPfwxQvnce3euby7+3IH34ej87lYPYVXf2vKq4qy2jVPFi6/iE7v0+V3exyyqRtRHji/irsNPeJ+/CPEMubg/IedyUoIR/hiWfe48qoqClVXFWYy3LW/rsyrWQZ5PPFplselnh6ZKRD0MUkKk94Z4hlzYRpbty5hQ8MdI8UGpM8yOvaq4iNG27TrwqFPrYHyIn4UFa2ucKhZVWZqv6/Py/Qy5sI0s25cxoaUzRlmKV9viTz6wbNU0xfU5lYmi1etVxaIqCkasnlcReff+xndNV/7udfE74QJG+GMURVRWJ4IsWzVl5E2Quz6nIlGcMX69yiLqojRfF3nubWiSAOGifELK3wmfiOb4fjGYm5vThYWF2M3IfUjzbIaZqUkcnd8doaVpk2WztE1PzaPL9y7kdYzVri7fPxeIyDFVnSt7HyP8DLIiqiPHF5Oa7MvC0gIeIOwEeZcX51iNZl3e3xTn1ixCwa/AMFLJI4WJIIt2VMgvsVVRdIXFyfiixXR14SSsGyj4FSia8EslSrSYbhr6S2xRFLtM3v0VrAcgde5Fl0doIWGWTgWKIs7YPmlVLA6JmUnRbfbv2Q7JeF2B2gugxrOVpiYn8MaJTbjz0Akz2XIp0ErwReQ2ETkpIhdEZG7sZ3eLyDMickpE9rRrZlwunZzIfH1majIJsQdsVhT0VT2R2GDvzplWC+iyPu/o/G7c/4EdeGX1As6dXzFX9dU6bS2dpwDcCuAvRl8UkWsAfBDAuwFcCeAREXmnqpYXgTfGkeOLePnV1Q2vT2ySpCJRq0Ni2izdZsaDbWfRnkyFVhG+qn5PVbPGZjcD+IqqvqKqzwF4BsANbY4Vi3sfPoWVtY1xypvfeElSD1fIaNri4jQSBx+2nUV7MhV8TdrOAPjOyP9fGLy2ARHZB2AfAMzOznpqTnPyHqKl8yuBW1KdosVMvjspi9lAJB4+sqOYsdOcUsEXkUcAvD3jRx9X1a/n/VrGa5l2nqoeBHAQWF94Vdae0KT2cMUW3K4Nt62tXUgR14GGVXsyBUoFX1Xf2+BzXwDwjpH/XwXgdIPPiU5qD1dMwXW5OM2C0MbuPEk2XV9T4RNfls6DAL4kIv8L65O2VwP4R0/H8kpqD1csf9Pl4jQrQtu10UqX4GR/M1oJvojcAuAzAKYBPCQiJ1R1j6qeFJHDAL4LYBXA76eYoTM
k5sNVN9KNZUG5XJxmRWg5OUi6RtssnQdU9SpVfYOqXq6qe0Z+9ieq+vOqul1V/7Z9U/tHk923Yi1mcrk4zYrQWly7QEgbuNLWME22aYu1mClPBJssTqsqtL7TP7kSmHQN1tIxSNut+GJYUC4nt0NuRF5Ek/kbC5PNhORBwTdGqlvxuZzcrvJZoXz+Op2nlclmQvKg4Bsj1a34ALcji7LPsuLzj2JlspmQPCj4xmi6FV9VK6ErloPFBXEWOyFCRqHgGyNPyIq2cqtqJXTJcnC9IM5FR2ixEyJkFGbpGKNJZkjVbJ4mWT9WcZmN1CT9NQtm9RDrMMI3RpPJz6pWQgqWQ1GknfUzFxtYu/LeQ63K7ootR8JDwTdI3cnPqlaCdcuhyHIC4M2OctkR+k6J7ZItR8JDS6cDVLUSmlgOIWvbF0XaPu2olFbUdsmWI+FhhN8BqloJdS2HKtGkS3uhSaTtwo5KqSJq0TWi1UPKEFU7Jejn5uZ0YWEhdjPIgF0HHi3MGMpaJDY5sbnx5GnR8QDUzl6qQypimXeN3rJlAj9duXDRvZjYJHjzGy/B0vkV0+dE2iMix1R1rux9jPBJLmURt+uFRmWRts8oPJVyu3nXSBUb7sXKBcW5wc5s9PoJQMEnBZRN8rrO+qlaUsF6FO6a8dHH+6+fwWNPn73oOtx56ETp53DVL6Hgk1zKIm4fWT9FkXYqUbhLsuZRvnpscYNtVlRsbxRLKbgkPMzS6Rgus2rKFjdxoZF/qmblZN2LLCxmHpFwMMLvEE1ztIsmLMsibqCfNksoqtpm4/fi0skJvPzqKlbWXk/KYGdMKPgdoskkatuFPH20WUJSxzYbvxepZB6RcFDwO0STSdSulPTtqri1WSPAzpiMQ8HvEE0mUVOor1NGl8sN0DYjLqHgd4gm0aD1+jpV6MooJQ9G6sQVrbJ0ROQ2ETkpIhdEZG7k9a0isiwiJwZ//rx9U0kZTUoGW860qZpx1IVRCiEhaBvhPwXgVgB/kfGz76vqjpafT2pSNxpsUl8nhL1Qx6bpwiiFkBC0EnxV/R4AiIib1pAoVO0kQnrldWyalIqfERITnwuvtonIcRH5exH5RY/HIYEIWZq3jk3jcveLvr8TAAAEsUlEQVQrQrpMaYQvIo8AeHvGjz6uql/P+bUzAGZV9Scicj2AIyLyblX994zP3wdgHwDMzs5WbzkJTkivvK5Nw4lNQsopjfBV9b2q+p6MP3liD1V9RVV/Mvj3MQDfB/DOnPceVNU5VZ2bnp5ueh4kACE3CrE8mUxIqnixdERkWkQ2D/79cwCuBvCsj2ORcFQVYRf1fGjTEOKeVpO2InILgM8AmAbwkIicUNU9AH4JwCdEZBXAGoCPqupLrVtLolIlo8flxC5tGkLcwh2viFPKdskihLin6o5XLI9MnMJFUITYhYJPnBJyYpcQUg8KPnEKs2sIsQuLpxGnsLojIXah4BPnMLuGEJvQ0iGEkJ5AwSeEkJ5AwSeEkJ5AwSeEkJ5AwSeEkJ5gqrSCiJwF8AMAlwH4ceTmhITn2214vt0n9jn/rKqWlhs2JfhDRGShSl2IrsDz7TY83+6TyjnT0iGEkJ5AwSeEkJ5gVfAPxm5AYHi+3Ybn232SOGeTHj4hhBD3WI3wCSGEOMaU4IvIbSJyUkQuiMjcyOtbRWRZRE4M/vx5zHa6Iu98Bz+7W0SeEZFTIrInVht9ISL3iMjiyD39tdht8oGI3DS4h8+IyHzs9vhGRJ4XkScH97Rz29eJyOdF5EUReWrktbeKyLdE5F8Gf78lZhuLMCX4AJ4CcCuAxzN+9n1V3TH489HA7fJF5vmKyDUAPgjg3QBuAvB/h5vCd4z7R+7pN2I3xjWDe/ZnAH4VwDUAPjS4t13nxsE9NZ+m2IAvYP07Oco8gG+r6tUAvj34v0lMCb6qfk9VT8VuRygKzvdmAF9R1VdU9TkAzwC4IWzriANuAPCMqj6rqq8C+ArW7y1JFFV9HMBLYy/
fDOCLg39/EcDeoI2qgSnBL2GbiBwXkb8XkV+M3RjPzAD415H/vzB4rWvcISL/PBgmmx0Gt6Av93EUBfBNETkmIvtiNyYQl6vqGQAY/P22yO3JJfgGKCLyCIC3Z/zo46r69ZxfOwNgVlV/IiLXAzgiIu9W1X/31lBHNDxfyXgtuXSqonMH8FkAn8T6eX0SwH0Afidc64LQiftYk12qelpE3gbgWyLy9CAqJgYILviq+t4Gv/MKgFcG/z4mIt8H8E4A5ieFmpwv1iPBd4z8/yoAp920KBxVz11EPgfgbzw3JwaduI91UNXTg79fFJEHsG5rdV3wfyQiV6jqGRG5AsCLsRuURxKWjohMDyctReTnAFwN4Nm4rfLKgwA+KCJvEJFtWD/ff4zcJqcMvhhDbsH6BHbX+CcAV4vINhH5b1ifiH8wcpu8ISJvEpGfGf4bwK+gm/d1nAcBfHjw7w8DyBu5R8fUnrYicguAzwCYBvCQiJxQ1T0AfgnAJ0RkFcAagI+q6vjESXLkna+qnhSRwwC+C2AVwO+r6lrMtnrg0yKyA+sWx/MAfi9uc9yjqqsicgeAhwFsBvB5VT0ZuVk+uRzAAyICrGvLl1T17+I2yS0i8mUAvwzgMhF5AcAfATgA4LCIfATADwHcFq+FxXClLSGE9IQkLB1CCCHtoeATQkhPoOATQkhPoOATQkhPoOATQkhPoOATQkhPoOATQkhPoOATQkhP+C90pUUy1BVeGwAAAABJRU5ErkJggg\u003d\u003d\n", 165 | "text/plain": [ 166 | "\u003cFigure size 432x288 with 1 Axes\u003e" 167 | ] 168 | }, 169 | "metadata": { 170 | "needs_background": "light" 171 | }, 172 | "output_type": "display_data" 173 | } 174 | ], 175 | "source": [ 176 | "cls_reduced_feats \u003d TSNE(learning_rate\u003d100, metric\u003d\u0027cosine\u0027).fit_transform(cls_feats)\n", 177 | "plt.scatter(cls_reduced_feats[:, 0], cls_reduced_feats[:, 1])\n", 178 | "plt.show()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 11, 184 | "metadata": { 185 | "pycharm": {} 186 | }, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "500\n", 193 | "1000\n" 194 | ] 195 | } 196 | ], 197 | "source": [ 198 | "clusters \u003d []\n", 199 | "kmeans \u003d KMeans(n_clusters\u003dn_parts, n_init\u003d1, init\u003d\u0027random\u0027)\n", 200 | "for i in range(pool_size):\n", 201 | " if (i+1) % 500 \u003d\u003d 0:\n", 202 | " print(i+1)\n", 203 | " kmeans.fit(cls_feats[:, np.random.permutation(2048)[:50]])\n", 204 | " clusters.append(kmeans.labels_)\n", 205 | "\n", 206 | "if not os.path.isdir(save_path):\n", 207 | " os.makedirs(save_path)\n", 208 | "partitions \u003d 
torch.from_numpy(np.stack(clusters)).long()\n", 209 | "torch.save(partitions, os.path.join(save_path, \u0027part_%d_kmeans.pth.tar\u0027 % n_parts))" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": { 216 | "pycharm": {} 217 | }, 218 | "outputs": [], 219 | "source": [] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": { 225 | "pycharm": {} 226 | }, 227 | "outputs": [], 228 | "source": [] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": { 234 | "pycharm": {} 235 | }, 236 | "outputs": [], 237 | "source": [] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": { 243 | "pycharm": {} 244 | }, 245 | "outputs": [], 246 | "source": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": { 252 | "pycharm": {} 253 | }, 254 | "outputs": [], 255 | "source": [] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": { 261 | "pycharm": {} 262 | }, 263 | "outputs": [], 264 | "source": [] 265 | } 266 | ], 267 | "metadata": { 268 | "kernelspec": { 269 | "display_name": "Python 3", 270 | "language": "python", 271 | "name": "python3" 272 | }, 273 | "language_info": { 274 | "codemirror_mode": { 275 | "name": "ipython", 276 | "version": 3 277 | }, 278 | "file_extension": ".py", 279 | "mimetype": "text/x-python", 280 | "name": "python", 281 | "nbconvert_exporter": "python", 282 | "pygments_lexer": "ipython3", 283 | "version": "3.7.0" 284 | } 285 | }, 286 | "nbformat": 4, 287 | "nbformat_minor": 1 288 | } --------------------------------------------------------------------------------