├── CONTRIBUTING-ARCHIVED.md ├── DataLoader ├── Places_dataset.py ├── VOC_dataset.py ├── __pycache__ │ ├── Places_dataset.cpython-36.pyc │ ├── VOC_dataset.cpython-36.pyc │ ├── dataloader.cpython-36.pyc │ └── dataloader_balance.cpython-36.pyc ├── dataloader.py └── dataloader_balance.py ├── README.md ├── classifier_retrain.py ├── detection ├── README.md ├── configs │ ├── Base-RCNN-FPN-BN.yaml │ └── coco_R_50_FPN_1x.yaml ├── convert-pretrain-to-detectron2.py └── train_net.py ├── finetune_imagenet.py ├── img └── blog.png ├── lowshot_svm.py ├── model.py ├── noise_cleaning.py ├── resnet.py └── train.py /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | We are not accepting contributions or Pull Requests. 5 | 6 | -------------------------------------------------------------------------------- /DataLoader/Places_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torchvision.datasets as datasets 3 | import torchvision.transforms as transforms 4 | import csv 5 | from PIL import Image 6 | import os 7 | import random 8 | 9 | 10 | class Places205(data.Dataset): 11 | def __init__(self, root, split, transform=None, target_transform=None): 12 | self.root = os.path.expanduser(root) 13 | self.data_folder = os.path.join(self.root, 'data', 'vision', 'torralba', 'deeplearning', 'images256') 14 | self.split_folder = os.path.join(self.root, 'trainvalsplit_places205') 15 | assert(split=='train' or split=='val') 16 | split_csv_file = os.path.join(self.split_folder, split+'_places205.csv') 17 | 18 | self.low_shot = False 19 | self.transform = transform 20 | self.target_transform = target_transform 21 | with open(split_csv_file, 'r') as f: 22 | reader = csv.reader(f, delimiter=' ') 23 | self.img_files = [] 24 | self.labels = [] 25 | for row in reader: 26 | self.img_files.append(row[0]) 27 | self.labels.append(int(row[1])) 28 | 29 | def convert_low_shot(self, k): 30 | 31 | label2img = {c:[] for c in range(205)} 32 | 33 | for n in range(len(self.labels)): 34 | label2img[self.labels[n]].append(self.img_files[n]) 35 | 36 | self.img_files_lowshot = [] 37 | self.labels_lowshot = [] 38 | 39 | for c,imlist in label2img.items(): 40 | random.shuffle(imlist) 41 | self.labels_lowshot += [c]*k 42 | self.img_files_lowshot += imlist[:k] 43 | self.low_shot = True 44 | 45 | def __getitem__(self, index): 46 | """ 47 | Args: 48 | index (int): Index 49 | Returns: 50 | tuple: (image, target) where target is index of the target class. 
51 | """ 52 | if self.low_shot: 53 | image_path = os.path.join(self.data_folder, self.img_files_lowshot[index]) 54 | img = Image.open(image_path).convert('RGB') 55 | target = self.labels_lowshot[index] 56 | else: 57 | image_path = os.path.join(self.data_folder, self.img_files[index]) 58 | img = Image.open(image_path).convert('RGB') 59 | target = self.labels[index] 60 | 61 | if self.transform is not None: 62 | img = self.transform(img) 63 | if self.target_transform is not None: 64 | target = self.target_transform(target) 65 | return img, target 66 | 67 | def __len__(self): 68 | if self.low_shot: 69 | return len(self.labels_lowshot) 70 | else: 71 | return len(self.labels) -------------------------------------------------------------------------------- /DataLoader/VOC_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | import csv 3 | import os 4 | import os.path 5 | import tarfile 6 | from six.moves.urllib.parse import urlparse 7 | 8 | import numpy as np 9 | import torch 10 | import torch.utils.data as data 11 | from PIL import Image 12 | import random 13 | 14 | from tqdm import tqdm 15 | from six.moves.urllib.request import urlretrieve 16 | 17 | object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 18 | 'bottle', 'bus', 'car', 'cat', 'chair', 19 | 'cow', 'diningtable', 'dog', 'horse', 20 | 'motorbike', 'person', 'pottedplant', 21 | 'sheep', 'sofa', 'train', 'tvmonitor'] 22 | 23 | urls = { 24 | 'devkit': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar', 25 | 'trainval_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 26 | 'test_images_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', 27 | 'test_anno_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar', 28 | } 29 | 30 | 31 | def read_image_label(file): 32 | print('[dataset] read ' + file) 33 | data = dict() 34 | with open(file, 'r') as f: 35 | for line in f: 36 | tmp = line.split(' ') 37 | name = tmp[0] 38 | label = int(tmp[-1]) 39 | data[name] = label 40 | # data.append([name, label]) 41 | # print('%s %d' % (name, label)) 42 | return data 43 | 44 | 45 | def read_object_labels(root, dataset, set): 46 | path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main') 47 | labeled_data = dict() 48 | num_classes = len(object_categories) 49 | 50 | for i in range(num_classes): 51 | file = os.path.join(path_labels, object_categories[i] + '_' + set + '.txt') 52 | data = read_image_label(file) 53 | 54 | if i == 0: 55 | for (name, label) in data.items(): 56 | labels = np.zeros(num_classes) 57 | labels[i] = label 58 | labeled_data[name] = labels 59 | else: 60 | for (name, label) in data.items(): 61 | labeled_data[name][i] = label 62 | 63 | return labeled_data 64 | 65 | 66 | def write_object_labels_csv(file, labeled_data): 67 | # write a csv file 68 | print('[dataset] write file %s' % file) 69 | with open(file, 'w') as csvfile: 70 | fieldnames = ['name'] 71 | fieldnames.extend(object_categories) 72 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 73 | 74 | writer.writeheader() 75 | for (name, labels) in labeled_data.items(): 76 | example = {'name': name} 77 | for i in range(20): 78 | example[fieldnames[i + 1]] = int(labels[i]) 79 | writer.writerow(example) 80 | 81 | csvfile.close() 82 | 83 | 84 | def read_object_labels_csv(file, header=True): 85 | images = [] 86 | num_categories = 0 87 | 
print('[dataset] read', file) 88 | with open(file, 'r') as f: 89 | reader = csv.reader(f) 90 | rownum = 0 91 | for row in reader: 92 | if header and rownum == 0: 93 | header = row 94 | else: 95 | if num_categories == 0: 96 | num_categories = len(row) - 1 97 | name = row[0] 98 | labels = (np.asarray(row[1:num_categories + 1])).astype(np.float32) 99 | labels = torch.from_numpy(labels) 100 | item = (name, labels) 101 | images.append(item) 102 | rownum += 1 103 | return images 104 | 105 | 106 | def find_images_classification(root, dataset, set): 107 | path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main') 108 | images = [] 109 | file = os.path.join(path_labels, set + '.txt') 110 | with open(file, 'r') as f: 111 | for line in f: 112 | images.append(line) 113 | return images 114 | 115 | 116 | def download_voc2007(root): 117 | path_devkit = os.path.join(root, 'VOCdevkit') 118 | path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages') 119 | tmpdir = os.path.join(root, 'tmp') 120 | 121 | # create directory 122 | if not os.path.exists(root): 123 | os.makedirs(root) 124 | 125 | if not os.path.exists(path_devkit): 126 | 127 | if not os.path.exists(tmpdir): 128 | os.makedirs(tmpdir) 129 | 130 | parts = urlparse(urls['devkit']) 131 | filename = os.path.basename(parts.path) 132 | cached_file = os.path.join(tmpdir, filename) 133 | 134 | if not os.path.exists(cached_file): 135 | print('Downloading: "{}" to {}\n'.format(urls['devkit'], cached_file)) 136 | download_url(urls['devkit'], cached_file) 137 | 138 | # extract file 139 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 140 | cwd = os.getcwd() 141 | tar = tarfile.open(cached_file, "r") 142 | os.chdir(root) 143 | tar.extractall() 144 | tar.close() 145 | os.chdir(cwd) 146 | print('[dataset] Done!') 147 | 148 | # train/val images/annotations 149 | if not os.path.exists(path_images): 150 | 151 | # download train/val images/annotations 152 | parts = urlparse(urls['trainval_2007']) 153 | filename = os.path.basename(parts.path) 154 | cached_file = os.path.join(tmpdir, filename) 155 | 156 | if not os.path.exists(cached_file): 157 | print('Downloading: "{}" to {}\n'.format(urls['trainval_2007'], cached_file)) 158 | download_url(urls['trainval_2007'], cached_file) 159 | 160 | # extract file 161 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 162 | cwd = os.getcwd() 163 | tar = tarfile.open(cached_file, "r") 164 | os.chdir(root) 165 | tar.extractall() 166 | tar.close() 167 | os.chdir(cwd) 168 | print('[dataset] Done!') 169 | 170 | # test annotations 171 | test_anno = os.path.join(path_devkit, 'VOC2007/ImageSets/Main/aeroplane_test.txt') 172 | if not os.path.exists(test_anno): 173 | 174 | # download test annotations 175 | parts = urlparse(urls['test_images_2007']) 176 | filename = os.path.basename(parts.path) 177 | cached_file = os.path.join(tmpdir, filename) 178 | 179 | if not os.path.exists(cached_file): 180 | print('Downloading: "{}" to {}\n'.format(urls['test_images_2007'], cached_file)) 181 | download_url(urls['test_images_2007'], cached_file) 182 | 183 | # extract file 184 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 185 | cwd = os.getcwd() 186 | tar = tarfile.open(cached_file, "r") 187 | os.chdir(root) 188 | tar.extractall() 189 | tar.close() 190 | os.chdir(cwd) 191 | print('[dataset] Done!') 192 | 193 | # test images 194 | test_image = os.path.join(path_devkit, 
'VOC2007/JPEGImages/000001.jpg') 195 | if not os.path.exists(test_image): 196 | 197 | # download test images 198 | parts = urlparse(urls['test_anno_2007']) 199 | filename = os.path.basename(parts.path) 200 | cached_file = os.path.join(tmpdir, filename) 201 | 202 | if not os.path.exists(cached_file): 203 | print('Downloading: "{}" to {}\n'.format(urls['test_anno_2007'], cached_file)) 204 | download_url(urls['test_anno_2007'], cached_file) 205 | 206 | # extract file 207 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 208 | cwd = os.getcwd() 209 | tar = tarfile.open(cached_file, "r") 210 | os.chdir(root) 211 | tar.extractall() 212 | tar.close() 213 | os.chdir(cwd) 214 | print('[dataset] Done!') 215 | 216 | def download_url(url, destination=None, progress_bar=True): 217 | """Download a URL to a local file. 218 | Parameters 219 | ---------- 220 | url : str 221 | The URL to download. 222 | destination : str, None 223 | The destination of the file. If None is given the file is saved to a temporary directory. 224 | progress_bar : bool 225 | Whether to show a command-line progress bar while downloading. 226 | Returns 227 | ------- 228 | filename : str 229 | The location of the downloaded file. 230 | Notes 231 | ----- 232 | Progress bar use/example adapted from tqdm documentation: https://github.com/tqdm/tqdm 233 | """ 234 | 235 | def my_hook(t): 236 | last_b = [0] 237 | 238 | def inner(b=1, bsize=1, tsize=None): 239 | if tsize is not None: 240 | t.total = tsize 241 | if b > 0: 242 | t.update((b - last_b[0]) * bsize) 243 | last_b[0] = b 244 | 245 | return inner 246 | 247 | if progress_bar: 248 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t: 249 | filename, _ = urlretrieve(url, filename=destination, reporthook=my_hook(t)) 250 | else: 251 | filename, _ = urlretrieve(url, filename=destination) 252 | 253 | class Voc2007Classification(data.Dataset): 254 | 255 | def __init__(self, root, set, transform=None, target_transform=None): 256 | self.root = root 257 | self.path_devkit = os.path.join(root, 'VOCdevkit') 258 | self.path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages') 259 | self.set = set 260 | self.transform = transform 261 | self.target_transform = target_transform 262 | self.low_shot = False 263 | 264 | # download dataset 265 | download_voc2007(self.root) 266 | 267 | # define path of csv file 268 | path_csv = os.path.join(self.root, 'files', 'VOC2007') 269 | # define filename of csv file 270 | file_csv = os.path.join(path_csv, 'classification_' + set + '.csv') 271 | 272 | # create the csv file if necessary 273 | if not os.path.exists(file_csv): 274 | if not os.path.exists(path_csv): # create dir if necessary 275 | os.makedirs(path_csv) 276 | # generate csv file 277 | labeled_data = read_object_labels(self.root, 'VOC2007', self.set) 278 | # write csv file 279 | write_object_labels_csv(file_csv, labeled_data) 280 | 281 | self.classes = object_categories 282 | self.images = read_object_labels_csv(file_csv) 283 | 284 | print('[dataset] VOC 2007 classification set=%s number of classes=%d number of images=%d' % ( 285 | set, len(self.classes), len(self.images))) 286 | 287 | def __getitem__(self, index): 288 | if self.low_shot: 289 | path, target = self.images_lowshot[index] 290 | else: 291 | path, target = self.images[index] 292 | img = Image.open(os.path.join(self.path_images, path + '.jpg')).convert('RGB') 293 | if self.transform is not None: 294 | img = self.transform(img) 295 | if self.target_transform is 
not None: 296 | target = self.target_transform(target) 297 | return img, target 298 | 299 | def __len__(self): 300 | if self.low_shot: 301 | return len(self.images_lowshot) 302 | else: 303 | return len(self.images) 304 | 305 | def get_number_classes(self): 306 | return len(self.classes) 307 | 308 | 309 | def convert_low_shot(self, k): 310 | label2img = {c:[] for c in range(len(self.classes))} 311 | for img in self.images: 312 | label = img[1] 313 | label_classes = torch.where(label>0)[0] 314 | for c in label_classes: 315 | label2img[c.item()].append(img) 316 | 317 | self.images_lowshot = [] 318 | for c,imlist in label2img.items(): 319 | random.shuffle(imlist) 320 | self.images_lowshot += imlist[:k] 321 | self.low_shot = True -------------------------------------------------------------------------------- /DataLoader/__pycache__/Places_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/MoPro/c02afed3640ac89dadaa368b2ea5dc3180a31d7e/DataLoader/__pycache__/Places_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /DataLoader/__pycache__/VOC_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/MoPro/c02afed3640ac89dadaa368b2ea5dc3180a31d7e/DataLoader/__pycache__/VOC_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /DataLoader/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/MoPro/c02afed3640ac89dadaa368b2ea5dc3180a31d7e/DataLoader/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /DataLoader/__pycache__/dataloader_balance.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/MoPro/c02afed3640ac89dadaa368b2ea5dc3180a31d7e/DataLoader/__pycache__/dataloader_balance.cpython-36.pyc -------------------------------------------------------------------------------- /DataLoader/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torchvision.transforms as transforms 3 | import random 4 | import numpy as np 5 | from PIL import Image 6 | import torch 7 | import os 8 | from PIL import ImageFile 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | 12 | 13 | class webvision_dataset(Dataset): 14 | def __init__(self, root_dir, transform, mode, num_class, transform_strong=None): 15 | self.root = root_dir 16 | self.transform = transform 17 | self.mode = mode 18 | 19 | if self.mode=='test': 20 | self.val_imgs = [] 21 | self.val_labels = {} 22 | with open(self.root+'info/val_filelist.txt') as f: 23 | lines=f.readlines() 24 | for line in lines: 25 | img, target = line.split() 26 | target = int(target) 27 | if target repeating dataset') 39 | labels = np.array(self.labels) 40 | uniq,freq = np.unique(labels,return_counts=True) 41 | inv = (1/freq)**0.5 42 | p = inv/inv.sum() 43 | weight = 10*p/p.min() 44 | weight = weight.astype(int) 45 | weight = {u:w for u,w in zip(uniq,weight)} 46 | 47 | train_imgs=[] 48 | train_labels=[] 49 | for im,lab in zip(self.images,self.labels): 50 | train_imgs += [im]*weight[lab] 51 | train_labels += [lab]*weight[lab] 
52 | 53 | self.train_imgs=[] 54 | self.train_labels=[] 55 | index_shuf = list(range(len(train_imgs))) 56 | random.shuffle(index_shuf) 57 | for i in index_shuf[:len(self.labels)]: 58 | self.train_imgs.append(train_imgs[i]) 59 | self.train_labels.append(train_labels[i]) 60 | print('=> done') 61 | 62 | def __getitem__(self, index): 63 | if self.mode=='train': 64 | target = self.train_labels[index] 65 | image = Image.open(self.root+self.train_imgs[index]).convert('RGB') 66 | image = self.transform(image) 67 | return image, target 68 | 69 | elif self.mode=='test': 70 | img_path = self.val_imgs[index] 71 | target = self.val_labels[img_path] 72 | image = Image.open(self.root+'val_images_256/'+img_path).convert('RGB') 73 | img = self.transform(image) 74 | return img, target 75 | 76 | def __len__(self): 77 | if self.mode!='test': 78 | return len(self.images) 79 | else: 80 | return len(self.val_imgs) 81 | 82 | 83 | class webvision_dataloader(): 84 | def __init__(self, batch_size, annotation, num_workers, root_dir, imagenet_dir, distributed, crop_size=0.2): 85 | 86 | self.batch_size = batch_size 87 | self.annotation = annotation 88 | self.num_workers = num_workers 89 | self.root_dir = root_dir 90 | self.imagenet_dir = imagenet_dir 91 | self.distributed = distributed 92 | 93 | 94 | self.transform_train = transforms.Compose([ 95 | transforms.RandomResizedCrop(224, scale=(crop_size, 1.0)), 96 | transforms.RandomHorizontalFlip(), 97 | transforms.ToTensor(), 98 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 99 | ]) 100 | 101 | self.transform_test = transforms.Compose([ 102 | transforms.Resize(256), 103 | transforms.CenterCrop(224), 104 | transforms.ToTensor(), 105 | transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 106 | ]) 107 | 108 | def run(self): 109 | 110 | train_dataset = webvision_dataset(root_dir=self.root_dir, transform=self.transform_train, mode="train", annotation = self.annotation) 111 | test_dataset = webvision_dataset(root_dir=self.root_dir, transform=self.transform_test, mode='test', annotation = self.annotation) 112 | imagenet_val = datasets.ImageFolder(os.path.join(self.imagenet_dir, 'val'), self.transform_test) 113 | 114 | if self.distributed: 115 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 116 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset,shuffle=False) 117 | imagenet_sampler = torch.utils.data.distributed.DistributedSampler(imagenet_val,shuffle=False) 118 | else: 119 | self.train_sampler = None 120 | eval_sampler = None 121 | test_sampler = None 122 | imagenet_sampler = None 123 | 124 | train_loader = DataLoader( 125 | dataset=train_dataset, 126 | batch_size=self.batch_size, 127 | shuffle=(self.train_sampler is None), 128 | num_workers=self.num_workers, 129 | pin_memory=True, 130 | sampler=self.train_sampler, 131 | drop_last=True) 132 | 133 | test_loader = DataLoader( 134 | dataset=test_dataset, 135 | batch_size=self.batch_size, 136 | shuffle=False, 137 | num_workers=self.num_workers, 138 | pin_memory=True, 139 | sampler=test_sampler) 140 | 141 | imagenet_loader = DataLoader( 142 | dataset=imagenet_val, 143 | batch_size=self.batch_size, 144 | shuffle=False, 145 | num_workers=self.num_workers, 146 | pin_memory=True, 147 | sampler=imagenet_sampler) 148 | 149 | 150 | return train_loader,test_loader,imagenet_loader 151 | 152 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ## MoPro: Webly Supervised Learning with Momentum Prototypes (Salesforce Research) 2 | 3 | 4 | This is a PyTorch implementation of the MoPro paper (Blog post): 5 |
 6 | @article{MoPro,
 7 | 	title={MoPro: Webly Supervised Learning with Momentum Prototypes},
 8 | 	author={Junnan Li and Caiming Xiong and Steven C.H. Hoi},
 9 | 	journal={ICLR},
10 | 	year={2021}
11 | }
12 | 13 | 14 | ### Requirements: 15 | * WebVision dataset 16 | * ImageNet dataset (for evaluation) 17 | * Python ≥ 3.6 18 | * PyTorch ≥ 1.4 19 | 20 | 21 | ### Training 22 | This implementation currently supports only multi-GPU DistributedDataParallel training, which is faster and simpler than DataParallel. 23 | 24 | To perform webly-supervised training of a ResNet-50 model on WebVision V1.0 using a 4-GPU or 8-GPU machine, run: 25 |
python train.py \ 
26 |   --data [WebVision folder] \ 
27 |   --exp-dir experiment/MoPro\
28 |   --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0
29 | 
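The command above assumes a single machine. A sketch of two-node training with the same flags (the address below is a placeholder for the rank-0 node; following the convention of classifier_retrain.py and finetune_imagenet.py shown later, the script multiplies `--world-size` by the number of GPUs per node internally):

python train.py \
  --data [WebVision folder] \
  --exp-dir experiment/MoPro \
  --dist-url 'tcp://[rank-0 node address]:10001' --multiprocessing-distributed --world-size 2 --rank 0

Run the same command with `--rank 1` on the second node.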
30 | 31 | 32 | ### Download MoPro Pre-trained ResNet-50 Models 33 | WebVision V1 | WebVision V2 34 | ------ | ------ 35 | 36 | 37 | ### Noise Cleaning 38 |
python noise_cleaning.py --data [WebVision folder] --resume [pre-trained model path] --annotation pseudo_label.json
39 | 
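noise_cleaning.py presumably writes the cleaned pseudo-labels to the path given by `--annotation` (pseudo_label.json above); the classifier-retraining step below reads the same file through its own `--annotation` flag.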
40 | 41 | 42 | ### Classifier Retraining on WebVision 43 |
python classifier_retrain.py --data [WebVision folder] --imagenet [ImageNet folder]\ 
44 |   --resume [pre-trained model path] --annotation pseudo_label.json --exp-dir experiment/cRT\
45 |   --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 
46 | 
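By default this step keeps the encoder frozen (eval mode, features computed under `torch.no_grad()`) and trains only the linear classifier on the pseudo-labels; passing `--finetune` additionally updates the encoder backbone via a separate SGD optimizer (see classifier_retrain.py below).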
47 | 48 | ### Fine-tuning on ImageNet (1% of labeled data) 49 |
python finetune_imagenet.py \
50 |   --data [ImageNet path] \
51 |   --model-path [pre-trained model path] \
52 |   --exp-dir experiment/Finetune \
53 |   --low-resource 0.01 \
54 |   --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 
55 | 
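`--low-resource` is the fraction of ImageNet training samples that finetune_imagenet.py randomly keeps, so the 10% column in the table below presumably corresponds to `--low-resource 0.1`.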
56 | 57 | Results for the WebVision-V1 pre-trained model: 58 | Percentage | 1% | 10% 59 | ------ | ------ | ------ 60 | Accuracy | 71.2 | 74.8 61 | 62 | 63 | ### Linear SVM Evaluation on VOC or Places 64 |
python lowshot_svm.py --model_path [pre-trained model path] --dataset voc --voc-path [VOC data path]
65 | 
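For Places205, the analogous command would be (the dataset names are the lowercase argparse choices defined in lowshot_svm.py):

python lowshot_svm.py --model_path [pre-trained model path] --dataset places --places-path [Places205 data path]

The tables below sweep k, the number of labeled images per class. A minimal sketch of one such sweep for the single-label Places case, assuming a hypothetical extract_features helper (not part of this repo) that runs the frozen encoder over a dataset and returns feature/label arrays; VOC is multi-label, so the script would instead fit one binary SVM per class and report mAP:

from sklearn.svm import LinearSVC

X_val, y_val = extract_features(encoder, val_dataset)    # hypothetical helper, not in this repo
for k in [1, 2, 4, 8, 16]:
    train_dataset.convert_low_shot(k)                    # keep k random images per class (method shown above)
    X_train, y_train = extract_features(encoder, train_dataset)
    svm = LinearSVC(C=args.cost)                         # cost parameter from the --cost flag
    svm.fit(X_train, y_train)
    print('k=%d acc=%.1f' % (k, 100 * svm.score(X_val, y_val)))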
66 | 67 | Results for the WebVision-V1 pre-trained model: 68 | VOC| k=1 | k=2 | k=4 | k=8 | k=16 69 | --- | --- | --- | --- | --- | --- 70 | mAP| 59.5| 71.3| 76.5| 81.4| 83.7 71 | 72 | Places| k=1 | k=2 | k=4 | k=8 | k=16 73 | --- | --- | --- | --- | --- | --- 74 | Acc| 16.9| 23.2| 29.2| 34.5| 38.7 75 | -------------------------------------------------------------------------------- /classifier_retrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import builtins 4 | import math 5 | import os 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | import torch.optim 17 | import torch.multiprocessing as mp 18 | import torch.utils.data 19 | import torch.utils.data.distributed 20 | import torchvision.transforms as transforms 21 | import torchvision.datasets as datasets 22 | import torch.nn.functional as F 23 | 24 | from resnet import * 25 | import DataLoader.dataloader_balance as dataloader 26 | 27 | import tensorboard_logger as tb_logger 28 | 29 | import numpy as np 30 | 31 | 32 | parser = argparse.ArgumentParser(description='PyTorch WebVision Classifier Retraining') 33 | parser.add_argument('--data', default='../WebVision/dataset/', 34 | help='path to WebVision dataset') 35 | parser.add_argument('--imagenet', default='', 36 | help='path to ImageNet validation set') 37 | 38 | parser.add_argument('--annotation', default='./pseudo_label.json', 39 | help='path to pseudo-label annotation') 40 | 41 | parser.add_argument('--exp-dir', default='experiment/cRT', type=str, 42 | help='experiment directory') 43 | 44 | parser.add_argument('-j', '--workers', default=32, type=int, 45 | help='number of data loading workers (default: 32)') 46 | parser.add_argument('--epochs', default=15, type=int, 47 | help='number of total epochs to run') 48 | parser.add_argument('--start-epoch', default=0, type=int, 49 | help='manual epoch number (useful on restarts)') 50 | parser.add_argument('-b', '--batch-size', default=256, type=int, 51 | help='mini-batch size (default: 256), this is the total ' 52 | 'batch size of all GPUs on the current node when ' 53 | 'using Data Parallel or Distributed Data Parallel') 54 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 55 | metavar='LR', help='initial learning rate', dest='lr') 56 | parser.add_argument('--schedule', default=[5, 10], nargs='*', type=int, 57 | help='learning rate schedule (when to drop lr by 10x)') 58 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 59 | help='momentum of SGD solver') 60 | parser.add_argument('--wd', '--weight-decay', default=0, type=float, 61 | metavar='W', help='weight decay (default: 0)', 62 | dest='weight_decay') 63 | parser.add_argument('-p', '--print-freq', default=50, type=int, 64 | help='print frequency (default: 50)') 65 | parser.add_argument('--resume', default='', type=str, 66 | help='path to latest checkpoint (default: none)') 67 | parser.add_argument('--world-size', default=-1, type=int, 68 | help='number of nodes for distributed training') 69 | parser.add_argument('--rank', default=-1, type=int, 70 | help='node rank for distributed training') 71 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 72 | help='url used to set up distributed training') 73 | parser.add_argument('--dist-backend', 
default='nccl', type=str, 74 | help='distributed backend') 75 | parser.add_argument('--seed', default=None, type=int, 76 | help='seed for initializing training. ') 77 | parser.add_argument('--gpu', default=None, type=int, 78 | help='GPU id to use.') 79 | parser.add_argument('--multiprocessing-distributed', action='store_true', 80 | help='Use multi-processing distributed training to launch ' 81 | 'N processes per node, which has N GPUs. This is the ' 82 | 'fastest way to use PyTorch for either single node or ' 83 | 'multi node data parallel training') 84 | 85 | parser.add_argument('--num-class', default=1000, type=int) 86 | parser.add_argument('--cos', action='store_true', default=False, 87 | help='use cosine lr schedule') 88 | parser.add_argument('--finetune', action='store_true', default=False, 89 | help='finetune encoder') 90 | 91 | 92 | def main(): 93 | args = parser.parse_args() 94 | if args.seed is not None: 95 | random.seed(args.seed) 96 | torch.manual_seed(args.seed) 97 | cudnn.deterministic = True 98 | warnings.warn('You have chosen to seed training. ' 99 | 'This will turn on the CUDNN deterministic setting, ' 100 | 'which can slow down your training considerably! ' 101 | 'You may see unexpected behavior when restarting ' 102 | 'from checkpoints.') 103 | 104 | if args.gpu is not None: 105 | warnings.warn('You have chosen a specific GPU. This will completely ' 106 | 'disable data parallelism.') 107 | 108 | if args.dist_url == "env://" and args.world_size == -1: 109 | args.world_size = int(os.environ["WORLD_SIZE"]) 110 | 111 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 112 | 113 | if not os.path.exists(args.exp_dir): 114 | os.makedirs(args.exp_dir) 115 | 116 | ngpus_per_node = torch.cuda.device_count() 117 | if args.multiprocessing_distributed: 118 | # Since we have ngpus_per_node processes per node, the total world_size 119 | # needs to be adjusted accordingly 120 | args.world_size = ngpus_per_node * args.world_size 121 | # Use torch.multiprocessing.spawn to launch distributed processes: the 122 | # main_worker process function 123 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 124 | else: 125 | # Simply call main_worker function 126 | main_worker(args.gpu, ngpus_per_node, args) 127 | 128 | 129 | def main_worker(gpu, ngpus_per_node, args): 130 | args.gpu = gpu 131 | if args.gpu is not None: 132 | print("Use GPU: {} for training".format(args.gpu)) 133 | 134 | # suppress printing if not master 135 | 136 | if args.multiprocessing_distributed and args.gpu != 0: 137 | def print_pass(*args): 138 | pass 139 | builtins.print = print_pass 140 | 141 | if args.distributed: 142 | if args.dist_url == "env://" and args.rank == -1: 143 | args.rank = int(os.environ["RANK"]) 144 | if args.multiprocessing_distributed: 145 | # For multiprocessing distributed training, rank needs to be the 146 | # global rank among all the processes 147 | args.rank = args.rank * ngpus_per_node + gpu 148 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 149 | world_size=args.world_size, rank=args.rank) 150 | 151 | # create model 152 | print("=> creating resnet model") 153 | encoder = resnet50(encoder=True) 154 | classifier = nn.Linear(2048,1000) 155 | classifier.weight.data.normal_(mean=0.0, std=0.01) 156 | classifier.bias.data.zero_() 157 | 158 | if args.distributed: 159 | # For multiprocessing distributed, DistributedDataParallel constructor 160 | # should always set the single device scope, otherwise, 161 | # 
DistributedDataParallel will use all available devices. 162 | if args.gpu is not None: 163 | torch.cuda.set_device(args.gpu) 164 | encoder.cuda(args.gpu) 165 | classifier.cuda(args.gpu) 166 | # When using a single GPU per process and per 167 | # DistributedDataParallel, we need to divide the batch size 168 | # ourselves based on the total number of GPUs we have 169 | args.batch_size = int(args.batch_size / ngpus_per_node) 170 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 171 | encoder = torch.nn.parallel.DistributedDataParallel(encoder, device_ids=[args.gpu]) 172 | classifier = torch.nn.parallel.DistributedDataParallel(classifier, device_ids=[args.gpu]) 173 | 174 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 175 | 176 | # resume from a checkpoint 177 | if args.resume: 178 | if os.path.isfile(args.resume): 179 | print("=> loading checkpoint '{}'".format(args.resume)) 180 | if args.gpu is None: 181 | checkpoint = torch.load(args.resume) 182 | else: 183 | # Map model to be loaded to specified single gpu. 184 | loc = 'cuda:{}'.format(args.gpu) 185 | checkpoint = torch.load(args.resume, map_location=loc) 186 | state_dict = checkpoint['state_dict'] 187 | 188 | for k in list(state_dict.keys()): 189 | if k.startswith('module.encoder_q'): 190 | # remove prefix 191 | state_dict[k.replace('.encoder_q','')] = state_dict[k] 192 | # delete renamed or unused k 193 | del state_dict[k] 194 | for k in list(state_dict.keys()): 195 | if k.startswith('module.classifier'): 196 | # remove prefix 197 | state_dict[k.replace('.classifier','')] = state_dict[k] 198 | # delete renamed k 199 | del state_dict[k] 200 | print("=> loaded checkpoint '{}' (epoch {})" 201 | .format(args.resume, checkpoint['epoch'])) 202 | classifier.load_state_dict(state_dict,strict=False) 203 | encoder.load_state_dict(state_dict,strict=False) 204 | else: 205 | print("=> no checkpoint found at '{}'".format(args.resume)) 206 | 207 | 208 | cudnn.benchmark = True 209 | 210 | if args.finetune: 211 | optimizer_encoder = torch.optim.SGD(encoder.parameters(), args.lr, 212 | momentum=args.momentum, 213 | weight_decay=args.weight_decay) 214 | 215 | optimizer = torch.optim.SGD(classifier.parameters(), args.lr, 216 | momentum=args.momentum, 217 | weight_decay=args.weight_decay) 218 | 219 | # Data loading code 220 | loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_workers=args.workers,root_dir=args.data,\ 221 | imagenet_dir=args.imagenet,distributed=args.distributed,annotation=args.annotation) 222 | train_loader,test_loader,imagenet_loader = loader.run() 223 | 224 | if args.gpu==0: 225 | logger = tb_logger.Logger(logdir=os.path.join(args.exp_dir,'tensorboard'), flush_secs=2) 226 | else: 227 | logger = None 228 | 229 | for epoch in range(args.start_epoch, args.epochs+1): 230 | if args.distributed: 231 | loader.train_sampler.set_epoch(epoch) 232 | adjust_learning_rate(optimizer, epoch, args) 233 | 234 | train(train_loader, encoder, classifier, criterion, optimizer, epoch, args, logger) 235 | 236 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 237 | and args.rank % ngpus_per_node == 0): 238 | save_checkpoint({ 239 | 'epoch': epoch + 1, 240 | 'state_dict': classifier.state_dict(), 241 | 'optimizer' : optimizer.state_dict(), 242 | }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.exp_dir,epoch)) 243 | test(encoder, classifier, test_loader, imagenet_loader, args, epoch, logger) 244 | 245 | def train(train_loader, encoder, classifier, criterion, 
optimizer, epoch, args, tb_logger): 246 | train_loader.dataset.repeat() 247 | 248 | batch_time = AverageMeter('Time', ':1.2f') 249 | data_time = AverageMeter('Data', ':1.2f') 250 | losses_cls = AverageMeter('Loss@Cls', ':2.2f') 251 | acc_cls = AverageMeter('Acc@Cls', ':4.2f') 252 | 253 | 254 | progress = ProgressMeter( 255 | len(train_loader), 256 | [batch_time, data_time, losses_cls, acc_cls], 257 | prefix="Epoch: [{}]".format(epoch)) 258 | 259 | if args.finetune: 260 | # finetune encoder backbone 261 | encoder.train() 262 | else: 263 | encoder.eval() 264 | classifier.train() 265 | 266 | end = time.time() 267 | for i, (img,target) in enumerate(train_loader): 268 | # measure data loading time 269 | data_time.update(time.time() - end) 270 | img = img.cuda(args.gpu, non_blocking=True) 271 | target = target.cuda(args.gpu, non_blocking=True) 272 | if args.finetune: 273 | feature = encoder(img) 274 | else: 275 | with torch.no_grad(): 276 | feature = encoder(img) 277 | output = classifier(feature) 278 | 279 | loss = criterion(output, target) 280 | 281 | losses_cls.update(loss.item()) 282 | 283 | acc = accuracy(output, target)[0] 284 | acc_cls.update(acc[0]) 285 | 286 | # compute gradient and do SGD step 287 | if args.finetune: 288 | optimizer_encoder.zero_grad() 289 | optimizer.zero_grad() 290 | loss.backward() 291 | optimizer.step() 292 | if args.finetune: 293 | optimizer_encoder.step() 294 | 295 | # measure elapsed time 296 | batch_time.update(time.time() - end) 297 | end = time.time() 298 | if i % args.print_freq == 0: 299 | progress.display(i) 300 | 301 | if args.gpu == 0: 302 | tb_logger.log_value('Train Acc', acc_cls.avg, epoch) 303 | 304 | 305 | def test(encoder, classifier, test_loader, imagenet_loader, args, epoch, tb_logger): 306 | with torch.no_grad(): 307 | encoder.eval() 308 | classifier.eval() 309 | top1_webvision = AverageMeter('Top1@webvision', ':4.2f') 310 | top5_webvision = AverageMeter('Top5@webvision', ':4.2f') 311 | top1_imagenet = AverageMeter('Top1@imagenet', ':4.2f') 312 | top5_imagenet = AverageMeter('Top5@imagenet', ':4.2f') 313 | print('==> Evaluation...') 314 | 315 | # evaluate on webvision val set 316 | for batch_idx, (img,target) in enumerate(test_loader): 317 | img = img.cuda(args.gpu, non_blocking=True) 318 | target = target.cuda(args.gpu, non_blocking=True) 319 | feature = encoder(img) 320 | outputs = classifier(feature) 321 | acc1, acc5 = accuracy(outputs, target, topk=(1, 5)) 322 | top1_webvision.update(acc1[0]) 323 | top5_webvision.update(acc5[0]) 324 | 325 | # evaluate on imagenet val set 326 | for batch_idx, (img,target) in enumerate(imagenet_loader): 327 | img = img.cuda(args.gpu, non_blocking=True) 328 | target = target.cuda(args.gpu, non_blocking=True) 329 | feature = encoder(img) 330 | outputs = classifier(feature) 331 | acc1, acc5 = accuracy(outputs, target, topk=(1, 5)) 332 | top1_imagenet.update(acc1[0]) 333 | top5_imagenet.update(acc5[0]) 334 | 335 | acc_tensors = torch.Tensor([top1_webvision.avg,top5_webvision.avg,top1_imagenet.avg,top5_imagenet.avg]).cuda(args.gpu) 336 | dist.all_reduce(acc_tensors) 337 | 338 | acc_tensors /= args.world_size 339 | 340 | print('Webvision Accuracy is %.2f%% (%.2f%%)'%(acc_tensors[0],acc_tensors[1])) 341 | print('ImageNet Accuracy is %.2f%% (%.2f%%)'%(acc_tensors[2],acc_tensors[3])) 342 | if args.gpu ==0: 343 | tb_logger.log_value('WebVision top1 Acc', acc_tensors[0], epoch) 344 | tb_logger.log_value('WebVision top5 Acc', acc_tensors[1], epoch) 345 | tb_logger.log_value('ImageNet top1 Acc', acc_tensors[2], epoch) 346 
| tb_logger.log_value('ImageNet top5 Acc', acc_tensors[3], epoch) 347 | return 348 | 349 | 350 | 351 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 352 | torch.save(state, filename) 353 | if is_best: 354 | shutil.copyfile(filename, 'model_best.pth.tar') 355 | 356 | 357 | class AverageMeter(object): 358 | """Computes and stores the average and current value""" 359 | def __init__(self, name, fmt=':f'): 360 | self.name = name 361 | self.fmt = fmt 362 | self.reset() 363 | 364 | def reset(self): 365 | self.val = 0 366 | self.avg = 0 367 | self.sum = 0 368 | self.count = 0 369 | 370 | def update(self, val, n=1): 371 | self.val = val 372 | self.sum += val * n 373 | self.count += n 374 | self.avg = self.sum / self.count 375 | 376 | def __str__(self): 377 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 378 | return fmtstr.format(**self.__dict__) 379 | 380 | 381 | class ProgressMeter(object): 382 | def __init__(self, num_batches, meters, prefix=""): 383 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 384 | self.meters = meters 385 | self.prefix = prefix 386 | 387 | def display(self, batch): 388 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 389 | entries += [str(meter) for meter in self.meters] 390 | print('\t'.join(entries)) 391 | 392 | def _get_batch_fmtstr(self, num_batches): 393 | num_digits = len(str(num_batches // 1)) 394 | fmt = '{:' + str(num_digits) + 'd}' 395 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 396 | 397 | 398 | def adjust_learning_rate(optimizer, epoch, args): 399 | """Decay the learning rate based on schedule""" 400 | lr = args.lr 401 | if args.cos: # cosine lr schedule 402 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 403 | else: # stepwise lr schedule 404 | for milestone in args.schedule: 405 | lr *= 0.1 if epoch >= milestone else 1. 406 | for param_group in optimizer.param_groups: 407 | param_group['lr'] = lr 408 | 409 | 410 | def accuracy(output, target, topk=(1,)): 411 | """Computes the accuracy over the k top predictions for the specified values of k""" 412 | with torch.no_grad(): 413 | maxk = max(topk) 414 | batch_size = target.size(0) 415 | 416 | _, pred = output.topk(maxk, 1, True, True) 417 | pred = pred.t() 418 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 419 | 420 | res = [] 421 | for k in topk: 422 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 423 | res.append(correct_k.mul_(100.0 / batch_size)) 424 | return res 425 | 426 | 427 | if __name__ == '__main__': 428 | main() 429 | -------------------------------------------------------------------------------- /detection/README.md: -------------------------------------------------------------------------------- 1 | ## Finetuning MoPro pre-trained model for object detection on COCO 2 | 3 | 1. Install detectron2. 4 | 2. Convert model into detectron 2 format 5 |
 6 | python convert-pretrain-to-detectron2.py [pre-trained model path] mopro.pkl
 7 | 
8 | 3. Put the dataset under the "./datasets" directory, following the directory structure required by detectron2. 9 | 4. Run training: 10 |
11 | python train_net.py --config-file configs/coco_R_50_FPN_1x.yaml \
12 |  --num-gpus 8 MODEL.WEIGHTS ./mopro.pkl
13 |  
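To evaluate a trained detector without re-training, the standard detectron2 flags apply (train_net.py handles `--eval-only` in `main`); `model_final.pth` is the checkpoint name detectron2 writes to OUTPUT_DIR (./coco_R50_FPN_mopro in the config):

python train_net.py --config-file configs/coco_R_50_FPN_1x.yaml \
 --num-gpus 8 --eval-only MODEL.WEIGHTS ./coco_R50_FPN_mopro/model_final.pth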
14 | 15 | ## Results using Faster R-CNN with a R50-FPN backbone, 1x schedule: 16 | Pre-train dataset| AP | AP50 | AP75 17 | --- | --- | --- | --- 18 | WebVision V1 | 39.7 | 60.9 | 43.1 19 | WebVision V2 | 40.1 | 61.7 | 44.5 20 | -------------------------------------------------------------------------------- /detection/configs/Base-RCNN-FPN-BN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | FREEZE_AT: 0 5 | NAME: "build_resnet_fpn_backbone" 6 | RESNETS: 7 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 8 | NORM: "SyncBN" 9 | FPN: 10 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 11 | NORM: "SyncBN" 12 | ANCHOR_GENERATOR: 13 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 14 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 15 | RPN: 16 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 17 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 18 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 19 | # Detectron1 uses 2000 proposals per-batch, 20 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 21 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 22 | POST_NMS_TOPK_TRAIN: 1000 23 | POST_NMS_TOPK_TEST: 1000 24 | ROI_HEADS: 25 | NAME: "StandardROIHeads" 26 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 27 | ROI_BOX_HEAD: 28 | NAME: "FastRCNNConvFCHead" 29 | NUM_FC: 2 30 | POOLER_RESOLUTION: 7 31 | ROI_MASK_HEAD: 32 | NAME: "MaskRCNNConvUpsampleHead" 33 | NUM_CONV: 4 34 | POOLER_RESOLUTION: 14 35 | TEST: 36 | PRECISE_BN: 37 | ENABLED: True 38 | EVAL_PERIOD: 10000 39 | SOLVER: 40 | IMS_PER_BATCH: 16 41 | BASE_LR: 0.02 42 | VERSION: 2 43 | -------------------------------------------------------------------------------- /detection/configs/coco_R_50_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-FPN-BN.yaml" 2 | MODEL: 3 | MASK_ON: True 4 | WEIGHTS: "See Instructions" 5 | PIXEL_MEAN: [123.675, 116.280, 103.530] 6 | PIXEL_STD: [58.395, 57.120, 57.375] 7 | RESNETS: 8 | STRIDE_IN_1X1: False 9 | DEPTH: 50 10 | INPUT: 11 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 12 | MIN_SIZE_TEST: 800 13 | FORMAT: "RGB" 14 | DATASETS: 15 | TRAIN: ("coco_2017_train",) 16 | TEST: ("coco_2017_val",) 17 | SOLVER: 18 | STEPS: (60000, 80000) 19 | MAX_ITER: 90000 20 | BASE_LR: 0.02 21 | OUTPUT_DIR: "./coco_R50_FPN_mopro" -------------------------------------------------------------------------------- /detection/convert-pretrain-to-detectron2.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import sys 3 | import torch 4 | 5 | if __name__ == "__main__": 6 | input = sys.argv[1] 7 | 8 | obj = torch.load(input, map_location="cpu") 9 | obj = obj["state_dict"] 10 | 11 | newmodel = {} 12 | for k, v in obj.items(): 13 | 14 | if not k.startswith("module.encoder_q.") or k.startswith("module.encoder_q.fc") or k.startswith("module.encoder_q.classifier"): 15 | continue 16 | 17 | old_k = k 18 | k = k.replace("module.encoder_q.", "") 19 | 20 | if "layer" not in k: 21 | k = "stem." 
+ k 22 | for t in [1, 2, 3, 4]: 23 | k = k.replace("layer{}".format(t), "res{}".format(t + 1)) 24 | for t in [1, 2, 3]: 25 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) 26 | k = k.replace("downsample.0", "shortcut") 27 | k = k.replace("downsample.1", "shortcut.norm") 28 | print(old_k, "->", k) 29 | newmodel[k] = v.numpy() 30 | 31 | res = {"model": newmodel, "__author__": 'mopro', "matching_heuristics": True} 32 | 33 | with open(sys.argv[2], "wb") as f: 34 | pkl.dump(res, f) -------------------------------------------------------------------------------- /detection/train_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from detectron2.checkpoint import DetectionCheckpointer 4 | from detectron2.config import get_cfg 5 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch 6 | from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator 7 | from detectron2.layers import get_norm 8 | from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads 9 | 10 | 11 | @ROI_HEADS_REGISTRY.register() 12 | class Res5ROIHeadsExtraNorm(Res5ROIHeads): 13 | """ 14 | As described in the MOCO paper, there is an extra BN layer 15 | following the res5 stage. 16 | """ 17 | def _build_res5_block(self, cfg): 18 | seq, out_channels = super()._build_res5_block(cfg) 19 | norm = cfg.MODEL.RESNETS.NORM 20 | norm = get_norm(norm, out_channels) 21 | seq.add_module("norm", norm) 22 | return seq, out_channels 23 | 24 | 25 | class Trainer(DefaultTrainer): 26 | @classmethod 27 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 28 | if output_folder is None: 29 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 30 | if "coco" in dataset_name: 31 | return COCOEvaluator(dataset_name, cfg, True, output_folder) 32 | else: 33 | assert "voc" in dataset_name 34 | return PascalVOCDetectionEvaluator(dataset_name) 35 | 36 | 37 | def setup(args): 38 | cfg = get_cfg() 39 | cfg.merge_from_file(args.config_file) 40 | cfg.merge_from_list(args.opts) 41 | cfg.freeze() 42 | default_setup(cfg, args) 43 | return cfg 44 | 45 | 46 | def main(args): 47 | cfg = setup(args) 48 | 49 | if args.eval_only: 50 | model = Trainer.build_model(cfg) 51 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 52 | cfg.MODEL.WEIGHTS, resume=args.resume 53 | ) 54 | res = Trainer.test(cfg, model) 55 | return res 56 | 57 | trainer = Trainer(cfg) 58 | trainer.resume_or_load(resume=args.resume) 59 | return trainer.train() 60 | 61 | 62 | if __name__ == "__main__": 63 | args = default_argument_parser().parse_args() 64 | print("Command Line Args:", args) 65 | launch( 66 | main, 67 | args.num_gpus, 68 | num_machines=args.num_machines, 69 | machine_rank=args.machine_rank, 70 | dist_url=args.dist_url, 71 | args=(args,), 72 | ) 73 | -------------------------------------------------------------------------------- /finetune_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import builtins 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import 
torchvision.datasets as datasets 20 | import torchvision.models as models 21 | import tensorboard_logger as tb_logger 22 | 23 | model_names = sorted(name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name])) 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 28 | parser.add_argument('--data', metavar='dir', default='/export/share/datasets/vision/imagenet', 29 | help='path to dataset') 30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 31 | choices=model_names, help='model architecture') 32 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 33 | help='number of data loading workers (default: 32)') 34 | parser.add_argument('--epochs', default=40, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('-b', '--batch-size', default=256, type=int, 39 | help='mini-batch size (default: 256), this is the total ' 40 | 'batch size of all GPUs on the current node when ' 41 | 'using Data Parallel or Distributed Data Parallel') 42 | parser.add_argument('--lr', '--learning-rate', default=0.005, type=float, 43 | metavar='LR', help='initial learning rate', dest='lr') 44 | parser.add_argument('--decay-rate', default=0.2, type=float) 45 | parser.add_argument('--schedule', default=[20, 30], nargs='*', type=int, 46 | help='learning rate schedule') 47 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 48 | help='momentum') 49 | parser.add_argument('--wd', '--weight-decay', default=0, type=float) 50 | parser.add_argument('-p', '--print-freq', default=10, type=int, 51 | metavar='N', help='print frequency (default: 10)') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--world-size', default=-1, type=int, 55 | help='number of nodes for distributed training') 56 | parser.add_argument('--rank', default=-1, type=int, 57 | help='node rank for distributed training') 58 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 59 | help='url used to set up distributed training') 60 | parser.add_argument('--dist-backend', default='nccl', type=str, 61 | help='distributed backend') 62 | parser.add_argument('--seed', default=None, type=int, 63 | help='seed for initializing training. ') 64 | parser.add_argument('--gpu', default=None, type=int, 65 | help='GPU id to use.') 66 | parser.add_argument('--multiprocessing-distributed', action='store_true', 67 | help='Use multi-processing distributed training to launch ' 68 | 'N processes per node, which has N GPUs. 
This is the ' 69 | 'fastest way to use PyTorch for either single node or ' 70 | 'multi node data parallel training') 71 | 72 | parser.add_argument('--exp-dir', default='./Finetune_1percent', type=str, 73 | help='experiment directory') 74 | parser.add_argument('--model-path', type=str, default='', 75 | help='the model to test') 76 | parser.add_argument('--low-resource', default=0.01, type=float, 77 | help='percentage of training data') 78 | best_acc1 = 0 79 | 80 | 81 | def main(): 82 | args = parser.parse_args() 83 | 84 | if args.seed is not None: 85 | random.seed(args.seed) 86 | torch.manual_seed(args.seed) 87 | 88 | if args.dist_url == "env://" and args.world_size == -1: 89 | args.world_size = int(os.environ["WORLD_SIZE"]) 90 | 91 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 92 | 93 | ngpus_per_node = torch.cuda.device_count() 94 | if args.multiprocessing_distributed: 95 | # Since we have ngpus_per_node processes per node, the total world_size 96 | # needs to be adjusted accordingly 97 | args.world_size = ngpus_per_node * args.world_size 98 | # Use torch.multiprocessing.spawn to launch distributed processes: the 99 | # main_worker process function 100 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 101 | else: 102 | # Simply call main_worker function 103 | main_worker(args.gpu, ngpus_per_node, args) 104 | 105 | 106 | def main_worker(gpu, ngpus_per_node, args): 107 | global best_acc1 108 | args.gpu = gpu 109 | 110 | if args.gpu is not None: 111 | print("Use GPU: {} for training".format(args.gpu)) 112 | if args.multiprocessing_distributed and args.gpu != 0: 113 | def print_pass(*args): 114 | pass 115 | builtins.print = print_pass 116 | 117 | if args.distributed: 118 | if args.dist_url == "env://" and args.rank == -1: 119 | args.rank = int(os.environ["RANK"]) 120 | if args.multiprocessing_distributed: 121 | # For multiprocessing distributed training, rank needs to be the 122 | # global rank among all the processes 123 | args.rank = args.rank * ngpus_per_node + gpu 124 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 125 | world_size=args.world_size, rank=args.rank) 126 | # create model 127 | if args.pretrained: 128 | print("=> using pre-trained model '{}'".format(args.arch)) 129 | model = models.__dict__[args.arch](pretrained=True) 130 | else: 131 | print("=> creating model '{}'".format(args.arch)) 132 | model = models.__dict__[args.arch]() 133 | 134 | if args.distributed: 135 | # For multiprocessing distributed, DistributedDataParallel constructor 136 | # should always set the single device scope, otherwise, 137 | # DistributedDataParallel will use all available devices. 
138 | if args.gpu is not None: 139 | torch.cuda.set_device(args.gpu) 140 | model.cuda(args.gpu) 141 | # When using a single GPU per process and per 142 | # DistributedDataParallel, we need to divide the batch size 143 | # ourselves based on the total number of GPUs we have 144 | args.batch_size = int(args.batch_size / ngpus_per_node) 145 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 146 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 147 | else: 148 | model.cuda() 149 | # DistributedDataParallel will divide and allocate batch_size to all 150 | # available GPUs if device_ids are not set 151 | model = torch.nn.parallel.DistributedDataParallel(model) 152 | elif args.gpu is not None: 153 | torch.cuda.set_device(args.gpu) 154 | model = model.cuda(args.gpu) 155 | else: 156 | # DataParallel will divide and allocate batch_size to all available GPUs 157 | model = torch.nn.DataParallel(model).cuda() 158 | 159 | # define loss function (criterion) and optimizer 160 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 161 | 162 | 163 | if not args.pretrained: 164 | print('==> loading pre-trained model') 165 | ckpt = torch.load(args.model_path) 166 | state_dict = ckpt['state_dict'] 167 | # rename pre-trained keys 168 | for k in list(state_dict.keys()): 169 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 170 | # remove prefix 171 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 172 | # delete renamed or unused k 173 | del state_dict[k] 174 | state_dict['fc.weight'] = state_dict['classifier.weight'] 175 | state_dict['fc.bias'] = state_dict['classifier.bias'] 176 | del state_dict['classifier.weight'] 177 | del state_dict['classifier.bias'] 178 | 179 | model.module.load_state_dict(state_dict) 180 | print("==> loaded checkpoint '{}' (epoch {})".format(args.model_path, ckpt['epoch'])) 181 | else: 182 | print("==> use supervised pre-trained model") 183 | 184 | 185 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 186 | momentum=args.momentum, 187 | weight_decay=args.weight_decay) 188 | 189 | cudnn.benchmark = True 190 | 191 | # Data loading code 192 | traindir = os.path.join(args.data, 'train') 193 | valdir = os.path.join(args.data, 'val') 194 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 195 | std=[0.229, 0.224, 0.225]) 196 | 197 | train_dataset = datasets.ImageFolder( 198 | traindir, 199 | transforms.Compose([ 200 | transforms.RandomResizedCrop(224), 201 | transforms.RandomHorizontalFlip(), 202 | transforms.ToTensor(), 203 | normalize, 204 | ])) 205 | 206 | if args.low_resource is not None: 207 | # randomly sample training data 208 | random.shuffle(train_dataset.samples) 209 | train_dataset.samples = train_dataset.samples[:int(args.low_resource*len(train_dataset))] 210 | print('Training dataset has %d samples.'%len(train_dataset)) 211 | 212 | if args.distributed: 213 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 214 | else: 215 | train_sampler = None 216 | 217 | train_loader = torch.utils.data.DataLoader( 218 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 219 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 220 | 221 | val_loader = torch.utils.data.DataLoader( 222 | datasets.ImageFolder(valdir, transforms.Compose([ 223 | transforms.Resize(256), 224 | transforms.CenterCrop(224), 225 | transforms.ToTensor(), 226 | normalize, 227 | ])), 228 | batch_size=args.batch_size, shuffle=False, 229 | 
num_workers=args.workers, pin_memory=True) 230 | 231 | if args.rank % ngpus_per_node == 0: 232 | logger = tb_logger.Logger(logdir=args.exp_dir, flush_secs=2) 233 | else: 234 | logger = None 235 | 236 | for epoch in range(args.start_epoch, args.epochs): 237 | if args.distributed: 238 | train_sampler.set_epoch(epoch) 239 | adjust_learning_rate(optimizer, epoch, args) 240 | 241 | # train for one epoch 242 | train(train_loader, model, criterion, optimizer, epoch, args) 243 | 244 | # evaluate on validation set 245 | acc1 = validate(val_loader, model, criterion, args, epoch, logger) 246 | 247 | # remember best acc@1 and save checkpoint 248 | is_best = acc1 > best_acc1 249 | best_acc1 = max(acc1, best_acc1) 250 | 251 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 252 | and args.rank % ngpus_per_node == 0): 253 | save_checkpoint({ 254 | 'epoch': epoch + 1, 255 | 'arch': args.arch, 256 | 'state_dict': model.state_dict(), 257 | 'best_acc1': best_acc1, 258 | 'optimizer' : optimizer.state_dict(), 259 | }, is_best, epoch, args.exp_dir) 260 | 261 | 262 | def train(train_loader, model, criterion, optimizer, epoch, args): 263 | batch_time = AverageMeter('Time', ':6.3f') 264 | data_time = AverageMeter('Data', ':6.3f') 265 | losses = AverageMeter('Loss', ':.4e') 266 | top1 = AverageMeter('Acc@1', ':6.2f') 267 | top5 = AverageMeter('Acc@5', ':6.2f') 268 | progress = ProgressMeter( 269 | len(train_loader), 270 | [batch_time, data_time, losses, top1, top5], 271 | prefix="Epoch: [{}]".format(epoch)) 272 | 273 | # switch to train mode 274 | model.train() 275 | 276 | end = time.time() 277 | for i, (images, target) in enumerate(train_loader): 278 | # measure data loading time 279 | data_time.update(time.time() - end) 280 | 281 | if args.gpu is not None: 282 | images = images.cuda(args.gpu, non_blocking=True) 283 | target = target.cuda(args.gpu, non_blocking=True) 284 | 285 | # compute output 286 | output = model(images) 287 | loss = criterion(output, target) 288 | 289 | # measure accuracy and record loss 290 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 291 | losses.update(loss.item(), images.size(0)) 292 | top1.update(acc1[0], images.size(0)) 293 | top5.update(acc5[0], images.size(0)) 294 | 295 | # compute gradient and do SGD step 296 | optimizer.zero_grad() 297 | loss.backward() 298 | optimizer.step() 299 | 300 | # measure elapsed time 301 | batch_time.update(time.time() - end) 302 | end = time.time() 303 | 304 | if i % args.print_freq == 0: 305 | progress.display(i) 306 | 307 | 308 | def validate(val_loader, model, criterion, args, epoch, tb_logger): 309 | batch_time = AverageMeter('Time', ':6.3f') 310 | losses = AverageMeter('Loss', ':.4e') 311 | top1 = AverageMeter('Acc@1', ':6.2f') 312 | top5 = AverageMeter('Acc@5', ':6.2f') 313 | progress = ProgressMeter( 314 | len(val_loader), 315 | [batch_time, losses, top1, top5], 316 | prefix='Test: ') 317 | 318 | # switch to evaluate mode 319 | model.eval() 320 | 321 | with torch.no_grad(): 322 | end = time.time() 323 | for i, (images, target) in enumerate(val_loader): 324 | if args.gpu is not None: 325 | images = images.cuda(args.gpu, non_blocking=True) 326 | target = target.cuda(args.gpu, non_blocking=True) 327 | 328 | # compute output 329 | output = model(images) 330 | loss = criterion(output, target) 331 | 332 | # measure accuracy and record loss 333 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 334 | losses.update(loss.item(), images.size(0)) 335 | top1.update(acc1[0], images.size(0)) 336 | top5.update(acc5[0], 
images.size(0)) 337 | 338 | # measure elapsed time 339 | batch_time.update(time.time() - end) 340 | end = time.time() 341 | 342 | if i % args.print_freq == 0: 343 | progress.display(i) 344 | 345 | # TODO: this should also be done with the ProgressMeter 346 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 347 | .format(top1=top1, top5=top5)) 348 | 349 | if args.gpu == 0: 350 | tb_logger.log_value('Top1 Acc', top1.avg, epoch) 351 | tb_logger.log_value('Top5 Acc', top5.avg, epoch) 352 | return top1.avg 353 | 354 | 355 | def save_checkpoint(state, is_best, epoch, path='./'): 356 | filename = os.path.join(path,'checkpoint_%d.pth.tar'%epoch) 357 | torch.save(state, filename) 358 | if is_best: 359 | shutil.copyfile(filename, os.path.join(path,'model_best.pth.tar')) 360 | 361 | 362 | class AverageMeter(object): 363 | """Computes and stores the average and current value""" 364 | def __init__(self, name, fmt=':f'): 365 | self.name = name 366 | self.fmt = fmt 367 | self.reset() 368 | 369 | def reset(self): 370 | self.val = 0 371 | self.avg = 0 372 | self.sum = 0 373 | self.count = 0 374 | 375 | def update(self, val, n=1): 376 | self.val = val 377 | self.sum += val * n 378 | self.count += n 379 | self.avg = self.sum / self.count 380 | 381 | def __str__(self): 382 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 383 | return fmtstr.format(**self.__dict__) 384 | 385 | 386 | class ProgressMeter(object): 387 | def __init__(self, num_batches, meters, prefix=""): 388 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 389 | self.meters = meters 390 | self.prefix = prefix 391 | 392 | def display(self, batch): 393 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 394 | entries += [str(meter) for meter in self.meters] 395 | print('\t'.join(entries)) 396 | 397 | def _get_batch_fmtstr(self, num_batches): 398 | num_digits = len(str(num_batches)) 399 | fmt = '{:' + str(num_digits) + 'd}' 400 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 401 | 402 | 403 | def adjust_learning_rate(optimizer, epoch, args): 404 | """Decay the learning rate based on schedule""" 405 | lr = args.lr 406 | for milestone in args.schedule: 407 | lr *= args.decay_rate if epoch >= milestone else 1.
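# e.g. with --schedule 30 60 and a decay rate of 0.1 (illustrative values; the
# actual defaults live in this script's argparse section earlier in the file),
# epoch 65 has passed both milestones, so lr = 0.1 * 0.1 * args.lr = 0.01 * args.lr
# before it is written back to the optimizer below.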
408 | for param_group in optimizer.param_groups: 409 | param_group['lr'] = lr 410 | 411 | def accuracy(output, target, topk=(1,)): 412 | """Computes the accuracy over the k top predictions for the specified values of k""" 413 | with torch.no_grad(): 414 | maxk = max(topk) 415 | batch_size = target.size(0) 416 | 417 | _, pred = output.topk(maxk, 1, True, True) 418 | pred = pred.t() 419 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 420 | 421 | res = [] 422 | for k in topk: 423 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 424 | res.append(correct_k.mul_(100.0 / batch_size)) 425 | return res 426 | 427 | 428 | if __name__ == '__main__': 429 | main() -------------------------------------------------------------------------------- /img/blog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/MoPro/c02afed3640ac89dadaa368b2ea5dc3180a31d7e/img/blog.png -------------------------------------------------------------------------------- /lowshot_svm.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import time 6 | import argparse 7 | import torch 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn.functional as F 11 | import torch.nn as nn 12 | from torchvision import transforms, datasets 13 | 14 | from resnet import * 15 | 16 | from DataLoader.Places_dataset import Places205 17 | from DataLoader.VOC_dataset import Voc2007Classification 18 | 19 | from sklearn.svm import LinearSVC 20 | import numpy as np 21 | import random 22 | 23 | 24 | parser = argparse.ArgumentParser('argument for training') 25 | 26 | parser.add_argument('--batch_size', type=int, default=128, help='batch size') 27 | parser.add_argument('--num_workers', type=int, default=20, help='number of workers to use') 28 | 29 | # model definition 30 | parser.add_argument('--model_path', type=str, default='', help='the model to test') 31 | parser.add_argument('--supervise', default=False, action='store_true', help='whether to use the supervised pretrained model') 32 | parser.add_argument('--cost', type=float, default=0.5, help='cost parameter for SVM') 33 | 34 | # dataset 35 | parser.add_argument('--dataset', type=str, default='voc', choices=['places', 'voc']) 36 | parser.add_argument('--voc-path', type=str, default='') 37 | parser.add_argument('--places-path', type=str, default='') 38 | 39 | # seed 40 | parser.add_argument('--seed', default=0, type=int) 41 | 42 | 43 | 44 | def main(): 45 | 46 | args = parser.parse_args() 47 | random.seed(args.seed) 48 | np.random.seed(args.seed) 49 | 50 | mean = [0.485, 0.456, 0.406] 51 | std = [0.229, 0.224, 0.225] 52 | normalize = transforms.Normalize(mean=mean, std=std) 53 | transform = transforms.Compose([ 54 | transforms.Resize(256), 55 | transforms.CenterCrop(224), 56 | transforms.ToTensor(), 57 | normalize, 58 | ]) 59 | 60 | if args.dataset=='voc': 61 | train_dataset = Voc2007Classification(args.voc_path,set='trainval',transform = transform) 62 | val_dataset = Voc2007Classification(args.voc_path,set='test',transform = transform) 63 | 64 | elif args.dataset=='places': 65 | train_dataset = Places205(args.places_path, 'train', transform = transform) 66 | val_dataset = Places205(args.places_path, 'val', transform = transform) 67 | 68 | val_loader = torch.utils.data.DataLoader( 69 | val_dataset, batch_size=args.batch_size, shuffle=False, 70 | num_workers=args.num_workers,
pin_memory=True) 71 | 72 | model = resnet50(encoder=True) 73 | 74 | if not args.supervise: 75 | print('==> loading pre-trained model') 76 | ckpt = torch.load(args.model_path) 77 | state_dict = ckpt['state_dict'] 78 | 79 | # rename pre-trained keys 80 | for k in list(state_dict.keys()): 81 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc') and not k.startswith('module.encoder_q.classifier'): 82 | # remove prefix 83 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 84 | # delete renamed or unused k 85 | del state_dict[k] 86 | 87 | model.load_state_dict(state_dict,strict=False) 88 | print("==> loaded checkpoint '{}' (epoch {})".format(args.model_path, ckpt['epoch'])) 89 | else: 90 | print("==> use supervised pre-trained model") 91 | 92 | model = model.cuda() 93 | model.eval() 94 | 95 | test_feats = [] 96 | test_labels = [] 97 | print('==> calculate test features') 98 | for idx, (inputs, target) in enumerate(val_loader): 99 | inputs = inputs.cuda() 100 | feat = model(inputs) 101 | feat = feat.detach().cpu() 102 | test_feats.append(feat) 103 | test_labels.append(target) 104 | 105 | test_feats = torch.cat(test_feats,0).numpy() 106 | test_labels = torch.cat(test_labels,0).numpy() 107 | 108 | test_feats_norm = np.linalg.norm(test_feats, axis=1) 109 | test_feats = test_feats / (test_feats_norm + 1e-5)[:, np.newaxis] 110 | 111 | result={} 112 | 113 | for k in [1,2,4,8,16]: #number of samples per-class 114 | 115 | avg_map = [] 116 | for run in range(5): # 5 runs 117 | print('==> re-sampling training data') 118 | train_dataset.convert_low_shot(k) 119 | print(len(train_dataset)) 120 | 121 | train_loader = torch.utils.data.DataLoader( 122 | train_dataset, batch_size=args.batch_size, shuffle=False, 123 | num_workers=args.num_workers, pin_memory=True) 124 | 125 | train_feats = [] 126 | train_labels = [] 127 | print('==> calculate train features') 128 | for idx, (inputs, target) in enumerate(train_loader): 129 | inputs = inputs.cuda() 130 | feat = model(inputs) 131 | feat = feat.detach() 132 | 133 | train_feats.append(feat) 134 | train_labels.append(target) 135 | 136 | train_feats = torch.cat(train_feats,0).cpu().numpy() 137 | train_labels = torch.cat(train_labels,0).cpu().numpy() 138 | 139 | train_feats_norm = np.linalg.norm(train_feats, axis=1) 140 | train_feats = train_feats / (train_feats_norm + 1e-5)[:, np.newaxis] 141 | 142 | print('==> training SVM Classifier') 143 | if args.dataset=='places': 144 | clf = LinearSVC(random_state=0, tol=1e-4, C=args.cost, dual=True, max_iter=2000) 145 | clf.fit(train_feats, train_labels) 146 | 147 | prediction = clf.predict(test_feats) 148 | print('==> testing SVM Classifier') 149 | accuracy = 100.0*(prediction==test_labels).sum()/len(test_labels) 150 | avg_map.append(accuracy) 151 | print('==> Run%d accuracy is %.2f: '%(run,accuracy)) 152 | 153 | elif args.dataset=='voc': 154 | cls_ap = np.zeros((20, 1)) 155 | test_labels[test_labels==0] = -1 156 | train_labels[train_labels==0] = -1 157 | for cls in range(20): 158 | clf = LinearSVC( 159 | C=args.cost, class_weight={1: 2, -1: 1}, intercept_scaling=1.0, 160 | penalty='l2', loss='squared_hinge', tol=1e-4, 161 | dual=True, max_iter=2000,random_state=0) 162 | clf.fit(train_feats, train_labels[:,cls]) 163 | 164 | prediction = clf.decision_function(test_feats) 165 | P, R, score, ap = get_precision_recall(test_labels[:,cls], prediction) 166 | cls_ap[cls][0] = ap*100 167 | mean_ap = np.mean(cls_ap, axis=0) 168 | 169 | print('==> Run%d mAP is %.2f: '%(run,mean_ap)) 170 | 
avg_map.append(mean_ap) 171 | 172 | avg_map = np.asarray(avg_map) 173 | print('Average ap is: %.2f' %(avg_map.mean())) 174 | print('Std is: %.2f' %(avg_map.std())) 175 | 176 | result[k] = avg_map.mean() 177 | print(result) 178 | 179 | 180 | def calculate_ap(rec, prec): 181 | """ 182 | Computes the AP under the precision recall curve. 183 | """ 184 | rec, prec = rec.reshape(rec.size, 1), prec.reshape(prec.size, 1) 185 | z, o = np.zeros((1, 1)), np.ones((1, 1)) 186 | mrec, mpre = np.vstack((z, rec, o)), np.vstack((z, prec, z)) 187 | for i in range(len(mpre) - 2, -1, -1): 188 | mpre[i] = max(mpre[i], mpre[i + 1]) 189 | 190 | indices = np.where(mrec[1:] != mrec[0:-1])[0] + 1 191 | ap = 0 192 | for i in indices: 193 | ap = ap + (mrec[i] - mrec[i - 1]) * mpre[i] 194 | return ap 195 | 196 | 197 | def get_precision_recall(targets, preds): 198 | """ 199 | [P, R, score, ap] = get_precision_recall(targets, preds) 200 | Input : 201 | targets : number of occurrences of this class in the ith image 202 | preds : score for this image 203 | Output : 204 | P, R : precision and recall 205 | score : score which corresponds to the particular precision and recall 206 | ap : average precision 207 | """ 208 | # binarize targets 209 | targets = np.array(targets > 0, dtype=np.float32) 210 | tog = np.hstack(( 211 | targets[:, np.newaxis].astype(np.float64), 212 | preds[:, np.newaxis].astype(np.float64) 213 | )) 214 | ind = np.argsort(preds) 215 | ind = ind[::-1] 216 | score = np.array([tog[i, 1] for i in ind]) 217 | sortcounts = np.array([tog[i, 0] for i in ind]) 218 | 219 | tp = sortcounts 220 | fp = sortcounts.copy() 221 | for i in range(sortcounts.shape[0]): 222 | if sortcounts[i] >= 1: 223 | fp[i] = 0. 224 | elif sortcounts[i] < 1: 225 | fp[i] = 1. 226 | P = np.cumsum(tp) / (np.cumsum(tp) + np.cumsum(fp)) 227 | numinst = np.sum(targets) 228 | R = np.cumsum(tp) / numinst 229 | ap = calculate_ap(R, P) 230 | return P, R, score, ap 231 | 232 | 233 | if __name__ == '__main__': 234 | main() 235 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from random import sample 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | class MoPro(nn.Module): 8 | 9 | def __init__(self, base_encoder, args, width): 10 | super(MoPro, self).__init__() 11 | 12 | #encoder 13 | self.encoder_q = base_encoder(num_class=args.num_class,low_dim=args.low_dim,width=width) 14 | #momentum encoder 15 | self.encoder_k = base_encoder(num_class=args.num_class,low_dim=args.low_dim,width=width) 16 | 17 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 18 | param_k.data.copy_(param_q.data) # initialize 19 | param_k.requires_grad = False # not update by gradient 20 | 21 | # create the queue 22 | self.register_buffer("queue", torch.randn(args.low_dim, args.moco_queue)) 23 | self.queue = F.normalize(self.queue, dim=0) 24 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 25 | self.register_buffer("prototypes", torch.zeros(args.num_class,args.low_dim)) 26 | 27 | @torch.no_grad() 28 | def _momentum_update_key_encoder(self, args): 29 | """ 30 | update momentum encoder 31 | """ 32 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 33 | param_k.data = param_k.data * args.moco_m + param_q.data * (1. 
- args.moco_m) # EMA update: theta_k <- m*theta_k + (1-m)*theta_q 34 | 35 | @torch.no_grad() 36 | def _dequeue_and_enqueue(self, keys, args): 37 | # gather keys before updating queue 38 | keys = concat_all_gather(keys) 39 | 40 | batch_size = keys.shape[0] 41 | 42 | ptr = int(self.queue_ptr) 43 | assert args.moco_queue % batch_size == 0 # for simplicity 44 | 45 | # replace the keys at ptr (dequeue and enqueue) 46 | self.queue[:, ptr:ptr + batch_size] = keys.T 47 | ptr = (ptr + batch_size) % args.moco_queue # move pointer 48 | 49 | self.queue_ptr[0] = ptr 50 | 51 | @torch.no_grad() 52 | def _batch_shuffle_ddp(self, x): 53 | """ 54 | Batch shuffle, for making use of BatchNorm. 55 | *** Only support DistributedDataParallel (DDP) model. *** 56 | """ 57 | # gather from all gpus 58 | batch_size_this = x.shape[0] 59 | x_gather = concat_all_gather(x) 60 | batch_size_all = x_gather.shape[0] 61 | 62 | num_gpus = batch_size_all // batch_size_this 63 | 64 | # random shuffle index 65 | idx_shuffle = torch.randperm(batch_size_all).cuda() 66 | 67 | # broadcast to all gpus 68 | torch.distributed.broadcast(idx_shuffle, src=0) 69 | 70 | # index for restoring 71 | idx_unshuffle = torch.argsort(idx_shuffle) 72 | 73 | # shuffled index for this gpu 74 | gpu_idx = torch.distributed.get_rank() 75 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 76 | 77 | return x_gather[idx_this], idx_unshuffle 78 | 79 | @torch.no_grad() 80 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 81 | """ 82 | Undo batch shuffle. 83 | *** Only support DistributedDataParallel (DDP) model. *** 84 | """ 85 | # gather from all gpus 86 | batch_size_this = x.shape[0] 87 | x_gather = concat_all_gather(x) 88 | batch_size_all = x_gather.shape[0] 89 | 90 | num_gpus = batch_size_all // batch_size_this 91 | 92 | # restored index for this gpu 93 | gpu_idx = torch.distributed.get_rank() 94 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 95 | 96 | return x_gather[idx_this] 97 | 98 | 99 | def forward(self, batch, args, is_eval=False, is_proto=False, is_clean=False): 100 | 101 | img = batch[0].cuda(args.gpu, non_blocking=True) 102 | target = batch[1].cuda(args.gpu, non_blocking=True) 103 | 104 | output,q = self.encoder_q(img) 105 | if is_eval: 106 | return output, q, target 107 | 108 | img_aug = batch[2].cuda(args.gpu, non_blocking=True) 109 | # compute key features for the augmented view with the momentum encoder 110 | with torch.no_grad(): # no gradient 111 | self._momentum_update_key_encoder(args) # update the momentum encoder 112 | # shuffle for making use of BN 113 | img_aug, idx_unshuffle = self._batch_shuffle_ddp(img_aug) 114 | _, k = self.encoder_k(img_aug) 115 | # undo shuffle 116 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 117 | 118 | # compute instance logits 119 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 120 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 121 | logits = torch.cat([l_pos, l_neg], dim=1) 122 | # apply temperature 123 | logits /= args.temperature 124 | inst_labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 125 | 126 | # dequeue and enqueue 127 | self._dequeue_and_enqueue(k, args) 128 | 129 | if is_proto: 130 | # compute prototypical logits 131 | prototypes = self.prototypes.clone().detach() 132 | logits_proto = torch.mm(q,prototypes.t())/args.temperature 133 | else: 134 | logits_proto = 0 135 | 136 | if is_clean: 137 | # noise cleaning 138 | soft_label = args.alpha*F.softmax(output, dim=1) + (1-args.alpha)*F.softmax(logits_proto, dim=1) 139 | 140 | # keep ground truth label 141 | gt_score = soft_label[target>=0,target] 142 | clean_idx = gt_score>(1/args.num_class)
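# a label survives the check above when its blended score beats the uniform
# prior 1/num_class (1e-3 for the default 1000-class WebVision configuration)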
143 | 144 | # assign a new pseudo label 145 | max_score, hard_label = soft_label.max(1) 146 | correct_idx = max_score>args.pseudo_th 147 | target[correct_idx] = hard_label[correct_idx] 148 | 149 | # confident sample index 150 | clean_idx = clean_idx | correct_idx 151 | clean_idx_all = concat_all_gather(clean_idx.long()) 152 | 153 | # aggregate features and (pseudo) labels across all gpus 154 | targets = concat_all_gather(target) 155 | features = concat_all_gather(q) 156 | 157 | if is_clean: 158 | clean_idx_all = clean_idx_all.bool() 159 | # update momentum prototypes with pseudo-labels 160 | for feat,label in zip(features[clean_idx_all],targets[clean_idx_all]): 161 | self.prototypes[label] = self.prototypes[label]*args.proto_m + (1-args.proto_m)*feat 162 | # select only the confident samples to return 163 | q = q[clean_idx] 164 | target = target[clean_idx] 165 | logits_proto = logits_proto[clean_idx] 166 | output = output[clean_idx] 167 | else: 168 | # update momentum prototypes with original labels 169 | for feat,label in zip(features,targets): 170 | self.prototypes[label] = self.prototypes[label]*args.proto_m + (1-args.proto_m)*feat 171 | 172 | # normalize prototypes 173 | self.prototypes = F.normalize(self.prototypes, p=2, dim=1) 174 | 175 | return output, target, logits, inst_labels, logits_proto 176 | 177 | 178 | # utils 179 | @torch.no_grad() 180 | def concat_all_gather(tensor): 181 | """ 182 | Performs all_gather operation on the provided tensor. 183 | *** Warning ***: torch.distributed.all_gather has no gradient. 184 | """ 185 | tensors_gather = [torch.ones_like(tensor) 186 | for _ in range(torch.distributed.get_world_size())] 187 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 188 | 189 | output = torch.cat(tensors_gather, dim=0) 190 | return output 191 | -------------------------------------------------------------------------------- /noise_cleaning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os 4 | import random 5 | import tensorboard_logger as tb_logger 6 | import json 7 | import numpy as np 8 | from PIL import Image, ImageFile 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.backends.cudnn as cudnn 14 | import torch.optim 15 | import torch.utils.data 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | import torch.nn.functional as F 19 | from torch.utils.data import Dataset, DataLoader 20 | 21 | from model import MoPro 22 | from resnet import * 23 | 24 | ImageFile.LOAD_TRUNCATED_IMAGES = True 25 | 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch WebVision Noise Correction') 28 | parser.add_argument('--data', metavar='dir', default='../WebVision/dataset/', 29 | help='path to webvision dataset') 30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',choices=['resnet50',]) 31 | parser.add_argument('-j', '--workers', default=16, type=int, 32 | help='number of data loading workers (default: 16)') 33 | parser.add_argument('-b', '--batch-size', default=256, type=int, 34 | help='mini-batch size (default: 256)') 35 | parser.add_argument('--resume', default='', type=str, 36 | help='path to latest checkpoint (default: none)') 37 | 38 | parser.add_argument('--gpu', default=0, type=int, 39 | help='GPU id to use.') 40 | 41 | parser.add_argument('--num-class', default=1000, type=int) 42 |
parser.add_argument('--low-dim', default=128, type=int, 43 | help='embedding dimension') 44 | parser.add_argument('--moco_queue', default=8192, type=int, 45 | help='queue size; number of negative samples') 46 | 47 | parser.add_argument('--pseudo_th', default=0.8, type=float, 48 | help='threshold for pseudo labels') 49 | parser.add_argument('--alpha', default=0.5, type=float, 50 | help='weight to combine model prediction and prototype prediction') 51 | parser.add_argument('--temperature', default=0.1, type=float, 52 | help='contrastive temperature') 53 | parser.add_argument('--annotation', default='./pseudo_label.json', 54 | help='path to pseudo-label annotation') 55 | 56 | class webvision_dataset(Dataset): 57 | def __init__(self, root_dir): 58 | self.root = root_dir 59 | self.transform = transforms.Compose([ 60 | transforms.Resize(256), 61 | transforms.CenterCrop(224), 62 | transforms.ToTensor(), 63 | transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)), 64 | ]) 65 | self.train_imgs = [] 66 | self.train_labels = {} 67 | with open(self.root+'info/train_filelist_google.txt') as f: 68 | lines=f.readlines() 69 | for line in lines: 70 | img, target = line.split() 71 | target = int(target) 72 | self.train_imgs.append(img) 73 | self.train_labels[img]=target 74 | 75 | with open(self.root+'info/train_filelist_flickr.txt') as f: 76 | lines=f.readlines() 77 | for line in lines: 78 | img, target = line.split() 79 | target = int(target) 80 | self.train_imgs.append(img) 81 | self.train_labels[img]=target 82 | 83 | def __getitem__(self, index): 84 | img_path = self.train_imgs[index] 85 | target = self.train_labels[img_path] 86 | image = Image.open(self.root+img_path).convert('RGB') 87 | img = self.transform(image) 88 | return img, target, img_path 89 | def __len__(self): 90 | return len(self.train_imgs) 91 | 92 | 93 | def main(): 94 | args = parser.parse_args() 95 | 96 | if args.gpu is not None: 97 | print("Use GPU: {} for training".format(args.gpu)) 98 | 99 | print("=> creating model '{}'".format(args.arch)) 100 | if args.arch == 'resnet50': 101 | model = MoPro(resnet50,args,width=1) 102 | elif args.arch == 'resnet50x2': 103 | model = MoPro(resnet50,args,width=2) 104 | elif args.arch == 'resnet50x4': 105 | model = MoPro(resnet50,args,width=4) 106 | else: 107 | raise NotImplementedError('model not supported {}'.format(args.arch)) 108 | 109 | model = model.cuda(args.gpu) 110 | model.eval() 111 | 112 | # resume from a checkpoint 113 | if args.resume: 114 | if os.path.isfile(args.resume): 115 | print("=> loading checkpoint '{}'".format(args.resume)) 116 | checkpoint = torch.load(args.resume) 117 | state_dict = checkpoint['state_dict'] 118 | for k in list(state_dict.keys()): 119 | if k.startswith('module'): 120 | # remove prefix 121 | state_dict[k[len("module."):]] = state_dict[k] 122 | # delete renamed or unused k 123 | del state_dict[k] 124 | model.load_state_dict(state_dict) 125 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 126 | else: 127 | print("=> no checkpoint found at '{}'".format(args.resume)) 128 | 129 | # Data loading code 130 | dataset = webvision_dataset(root_dir=args.data) 131 | loader = DataLoader(dataset=dataset, 132 | batch_size=args.batch_size, 133 | shuffle=False, 134 | num_workers=args.workers, 135 | pin_memory=True) 136 | 137 | images = [] 138 | labels = [] 139 | print("=> performing noise cleaning on the training data") 140 | with torch.no_grad(): 141 | for (img, target, img_path) in tqdm(loader): 142 | img = img.cuda(args.gpu, 
non_blocking=True) # async H2D copy; effective because the DataLoader above pins memory 143 | target = target.cuda(args.gpu, non_blocking=True) 144 | output,feat = model.encoder_q(img) 145 | 146 | logits = torch.mm(feat,model.prototypes.t())/args.temperature 147 | soft_label = (F.softmax(output, dim=1)+F.softmax(logits,dim=1))/2 # average classifier and prototype predictions 148 | 149 | gt_score = soft_label[target>=0,target] # score of each sample's original (noisy) label 150 | clean_idx = gt_score>(1/args.num_class) # keep labels scoring above the uniform prior 151 | 152 | max_score, hard_label = soft_label.max(1) 153 | correct_idx = max_score>args.pseudo_th # confident predictions become pseudo-labels 154 | target[correct_idx] = hard_label[correct_idx] 155 | clean_idx = clean_idx | correct_idx 156 | for clean,label,path in zip(clean_idx,target.cpu(),img_path): 157 | if clean: 158 | images.append(path) 159 | labels.append(label.item()) 160 | with open(args.annotation, 'w') as f: 161 | json.dump({'images':images,'labels':labels}, f) 162 | print("=> pseudo-label annotation saved to {}".format(args.annotation)) 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | import torch.utils.model_zoo as model_zoo 6 | import torch.nn.functional as F 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class Normalize(nn.Module): 25 | """Lp normalization along dim 1: out = x / ||x||_p; power=2 matches F.normalize(x, p=2, dim=1).""" 26 | def __init__(self, power=2): 27 | super(Normalize, self).__init__() 28 | self.power = power 29 | 30 | def forward(self, x): 31 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1.
/ self.power) 32 | out = x.div(norm) 33 | return out 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None): 40 | super(BasicBlock, self).__init__() 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None): 72 | super(Bottleneck, self).__init__() 73 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(planes) 75 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 76 | padding=1, bias=False) 77 | self.bn2 = nn.BatchNorm2d(planes) 78 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 79 | self.bn3 = nn.BatchNorm2d(planes * 4) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.downsample = downsample 82 | self.stride = stride 83 | 84 | def forward(self, x): 85 | residual = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv3(out) 96 | out = self.bn3(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class ResNet(nn.Module): 108 | 109 | def __init__(self, block, layers, low_dim=128, in_channel=3, width=1, num_class=1000): 110 | self.inplanes = 64 111 | super(ResNet, self).__init__() 112 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, 113 | bias=False) 114 | self.bn1 = nn.BatchNorm2d(64) 115 | self.relu = nn.ReLU(inplace=True) 116 | 117 | self.base = int(64 * width) 118 | 119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 120 | self.layer1 = self._make_layer(block, self.base, layers[0]) 121 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 122 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 123 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 124 | self.avgpool = nn.AvgPool2d(7, stride=1) 125 | 126 | self.classifier = nn.Linear(self.base * 8 * block.expansion, num_class) 127 | self.l2norm = Normalize(2) 128 | 129 | #projection MLP 130 | self.fc1 = nn.Linear(self.base * 8 * block.expansion, 2048) 131 | self.fc2 = nn.Linear(2048, low_dim) 132 | 133 | for m in self.modules(): 134 | if isinstance(m, nn.Conv2d): 135 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 136 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 137 | elif isinstance(m, nn.BatchNorm2d): 138 | m.weight.data.fill_(1) 139 | m.bias.data.zero_() 140 | 141 | def _make_layer(self, block, planes, blocks, stride=1): 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | nn.Conv2d(self.inplanes, planes * block.expansion, 146 | kernel_size=1, stride=stride, bias=False), 147 | nn.BatchNorm2d(planes * block.expansion), 148 | ) 149 | 150 | layers = [] 151 | layers.append(block(self.inplanes, planes, stride, downsample)) 152 | self.inplanes = planes * block.expansion 153 | for i in range(1, blocks): 154 | layers.append(block(self.inplanes, planes)) 155 | 156 | return nn.Sequential(*layers) 157 | 158 | def forward(self, x): 159 | x = self.conv1(x) 160 | x = self.bn1(x) 161 | x = self.relu(x) 162 | x = self.maxpool(x) 163 | 164 | x = self.layer1(x) 165 | x = self.layer2(x) 166 | x = self.layer3(x) 167 | x = self.layer4(x) 168 | 169 | x = self.avgpool(x) 170 | feat = x.view(x.size(0), -1) 171 | 172 | out = self.classifier(feat) 173 | 174 | feat = F.relu(self.fc1(feat)) 175 | feat = self.fc2(feat) 176 | feat = self.l2norm(feat) 177 | return out,feat 178 | 179 | 180 | class ResNet_Encoder(nn.Module): 181 | 182 | def __init__(self, block, layers, in_channel=3, width=1, num_class=1000): 183 | self.inplanes = 64 184 | super(ResNet_Encoder, self).__init__() 185 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, 186 | bias=False) 187 | self.bn1 = nn.BatchNorm2d(64) 188 | self.relu = nn.ReLU(inplace=True) 189 | 190 | self.base = int(64 * width) 191 | 192 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 193 | self.layer1 = self._make_layer(block, self.base, layers[0]) 194 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 195 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 196 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 197 | self.avgpool = nn.AvgPool2d(7, stride=1) 198 | 199 | for m in self.modules(): 200 | if isinstance(m, nn.Conv2d): 201 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 202 | m.weight.data.normal_(0, math.sqrt(2. / n)) 203 | elif isinstance(m, nn.BatchNorm2d): 204 | m.weight.data.fill_(1) 205 | m.bias.data.zero_() 206 | 207 | def _make_layer(self, block, planes, blocks, stride=1): 208 | downsample = None 209 | if stride != 1 or self.inplanes != planes * block.expansion: 210 | downsample = nn.Sequential( 211 | nn.Conv2d(self.inplanes, planes * block.expansion, 212 | kernel_size=1, stride=stride, bias=False), 213 | nn.BatchNorm2d(planes * block.expansion), 214 | ) 215 | 216 | layers = [] 217 | layers.append(block(self.inplanes, planes, stride, downsample)) 218 | self.inplanes = planes * block.expansion 219 | for i in range(1, blocks): 220 | layers.append(block(self.inplanes, planes)) 221 | 222 | return nn.Sequential(*layers) 223 | 224 | def forward(self, x, is_train=False): 225 | 226 | x = self.conv1(x) 227 | x = self.bn1(x) 228 | x = self.relu(x) 229 | x = self.maxpool(x) 230 | 231 | x = self.layer1(x) 232 | x = self.layer2(x) 233 | x = self.layer3(x) 234 | x = self.layer4(x) 235 | 236 | x = self.avgpool(x) 237 | feat = x.view(x.size(0), -1) 238 | 239 | return feat 240 | 241 | def resnet18(pretrained=False, encoder=False, **kwargs): 242 | """Constructs a ResNet-18 model. 
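Torchvision weights are loaded non-strictly because this architecture swaps the stock fc head for a classifier plus a projection MLP (and drops the heads entirely in the encoder variant).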
243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | """ 246 | if encoder: 247 | model = ResNet_Encoder(BasicBlock, [2, 2, 2, 2], **kwargs) 248 | else: 249 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 250 | if pretrained: 251 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) 252 | return model 253 | 254 | 255 | def resnet34(pretrained=False, encoder=False, **kwargs): 256 | """Constructs a ResNet-34 model. 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | """ 260 | if encoder: 261 | model = ResNet_Encoder(BasicBlock, [3, 4, 6, 3], **kwargs) 262 | else: 263 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 264 | if pretrained: 265 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False) 266 | return model 267 | 268 | 269 | def resnet50(pretrained=False, encoder=False, **kwargs): 270 | """Constructs a ResNet-50 model. 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | """ 274 | if encoder: 275 | model = ResNet_Encoder(Bottleneck, [3, 4, 6, 3], **kwargs) 276 | else: 277 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 278 | if pretrained: 279 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']),strict=False) 280 | return model 281 | 282 | 283 | def resnet101(pretrained=False, encoder=False, **kwargs): 284 | """Constructs a ResNet-101 model. 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | """ 288 | if encoder: 289 | model = ResNet_Encoder(Bottleneck, [3, 4, 23, 3], **kwargs) 290 | else: 291 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 292 | if pretrained: 293 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False) 294 | return model 295 | 296 | 297 | 298 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import builtins 3 | import math 4 | import os 5 | import random 6 | import shutil 7 | import time 8 | import warnings 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torch.nn.functional as F 22 | 23 | from model import MoPro 24 | from resnet import * 25 | import DataLoader.dataloader as dataloader 26 | 27 | import tensorboard_logger as tb_logger 28 | 29 | import numpy as np 30 | 31 | parser = argparse.ArgumentParser(description='PyTorch WebVision Training') 32 | parser.add_argument('--data', default='../WebVision/dataset/', 33 | help='path to WebVision dataset') 34 | 35 | parser.add_argument('--exp-dir', default='experiment/MoPro_V1', type=str, 36 | help='experiment directory') 37 | 38 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',choices=['resnet50',]) 39 | parser.add_argument('-j', '--workers', default=32, type=int, 40 | help='number of data loading workers (default: 32)') 41 | parser.add_argument('--epochs', default=90, type=int, 42 | help='number of total epochs to run') 43 | parser.add_argument('--start-epoch', default=0, type=int, 44 | help='manual epoch number (useful on restarts)') 45 | parser.add_argument('-b', '--batch-size', default=256, type=int,
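# NOTE: this is the total batch size for the node; main_worker() divides it by
# ngpus_per_node, so each DistributedDataParallel process steps with
# batch_size / ngpus_per_node samples.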
46 | help='mini-batch size (default: 256), this is the total ' 47 | 'batch size of all GPUs on the current node when ' 48 | 'using Data Parallel or Distributed Data Parallel') 49 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 50 | metavar='LR', help='initial learning rate', dest='lr') 51 | parser.add_argument('--schedule', default=[40, 80], nargs='*', type=int, 52 | help='learning rate schedule (when to drop lr by 10x)') 53 | parser.add_argument('--cos', action='store_true', default=False, 54 | help='use cosine lr schedule') 55 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 56 | help='momentum of SGD solver') 57 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 58 | metavar='W', help='weight decay (default: 1e-4)', 59 | dest='weight_decay') 60 | parser.add_argument('-p', '--print-freq', default=50, type=int, 61 | help='print frequency (default: 50)') 62 | parser.add_argument('--resume', default='', type=str, 63 | help='path to latest checkpoint (default: none)') 64 | parser.add_argument('--world-size', default=-1, type=int, 65 | help='number of nodes for distributed training') 66 | parser.add_argument('--rank', default=-1, type=int, 67 | help='node rank for distributed training') 68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 69 | help='url used to set up distributed training') 70 | parser.add_argument('--dist-backend', default='nccl', type=str, 71 | help='distributed backend') 72 | parser.add_argument('--seed', default=None, type=int, 73 | help='seed for initializing training') 74 | parser.add_argument('--gpu', default=None, type=int, 75 | help='GPU id to use.') 76 | parser.add_argument('--multiprocessing-distributed', action='store_true', 77 | help='Use multi-processing distributed training to launch ' 78 | 'N processes per node, which has N GPUs. This is the ' 79 | 'fastest way to use PyTorch for either single node or ' 80 | 'multi node data parallel training') 81 | 82 | parser.add_argument('--num-class', default=1000, type=int) 83 | parser.add_argument('--low-dim', default=128, type=int, 84 | help='embedding dimension') 85 | parser.add_argument('--moco_queue', default=8192, type=int, 86 | help='queue size; number of negative samples') 87 | parser.add_argument('--moco_m', default=0.999, type=float, 88 | help='momentum for updating momentum encoder') 89 | parser.add_argument('--proto_m', default=0.999, type=float, 90 | help='momentum for computing the moving average of prototypes') 91 | 92 | parser.add_argument('--temperature', default=0.1, type=float, 93 | help='contrastive temperature') 94 | 95 | parser.add_argument('--w-inst', default=1, type=float, 96 | help='weight for instance contrastive loss') 97 | parser.add_argument('--w-proto', default=1, type=float, 98 | help='weight for prototype contrastive loss') 99 | 100 | parser.add_argument('--start_clean_epoch', default=11, type=int, 101 | help='epoch to start noise cleaning') 102 | parser.add_argument('--pseudo_th', default=0.8, type=float, 103 | help='threshold for pseudo labels') 104 | parser.add_argument('--alpha', default=0.5, type=float, 105 | help='weight to combine model prediction and prototype prediction') 106 | 107 | 108 | def main(): 109 | args = parser.parse_args() 110 | if args.seed is not None: 111 | random.seed(args.seed) 112 | torch.manual_seed(args.seed) 113 | cudnn.deterministic = True 114 | warnings.warn('You have chosen to seed training. 
' 115 | 'This will turn on the CUDNN deterministic setting, ' 116 | 'which can slow down your training considerably! ' 117 | 'You may see unexpected behavior when restarting ' 118 | 'from checkpoints.') 119 | 120 | if args.gpu is not None: 121 | warnings.warn('You have chosen a specific GPU. This will completely ' 122 | 'disable data parallelism.') 123 | 124 | if args.dist_url == "env://" and args.world_size == -1: 125 | args.world_size = int(os.environ["WORLD_SIZE"]) 126 | 127 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 128 | 129 | if not os.path.exists(args.exp_dir): 130 | os.makedirs(args.exp_dir) 131 | 132 | ngpus_per_node = torch.cuda.device_count() 133 | if args.multiprocessing_distributed: 134 | # Since we have ngpus_per_node processes per node, the total world_size 135 | # needs to be adjusted accordingly 136 | args.world_size = ngpus_per_node * args.world_size 137 | # Use torch.multiprocessing.spawn to launch distributed processes: the 138 | # main_worker process function 139 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 140 | else: 141 | # Simply call main_worker function 142 | main_worker(args.gpu, ngpus_per_node, args) 143 | 144 | 145 | def main_worker(gpu, ngpus_per_node, args): 146 | args.gpu = gpu 147 | 148 | if args.gpu is not None: 149 | print("Use GPU: {} for training".format(args.gpu)) 150 | 151 | # suppress printing if not master 152 | if args.multiprocessing_distributed and args.gpu != 0: 153 | def print_pass(*args): 154 | pass 155 | builtins.print = print_pass 156 | 157 | if args.distributed: 158 | if args.dist_url == "env://" and args.rank == -1: 159 | args.rank = int(os.environ["RANK"]) 160 | if args.multiprocessing_distributed: 161 | # For multiprocessing distributed training, rank needs to be the 162 | # global rank among all the processes 163 | args.rank = args.rank * ngpus_per_node + gpu 164 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 165 | world_size=args.world_size, rank=args.rank) 166 | # create model 167 | print("=> creating model '{}'".format(args.arch)) 168 | if args.arch == 'resnet50': 169 | model = MoPro(resnet50,args,width=1) 170 | elif args.arch == 'resnet50x2': 171 | model = MoPro(resnet50,args,width=2) 172 | elif args.arch == 'resnet50x4': 173 | model = MoPro(resnet50,args,width=4) 174 | else: 175 | raise NotImplementedError('model not supported {}'.format(args.arch)) 176 | 177 | if args.distributed: 178 | # For multiprocessing distributed, DistributedDataParallel constructor 179 | # should always set the single device scope, otherwise, 180 | # DistributedDataParallel will use all available devices. 
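# A rough sketch of the per-process arithmetic applied in the branch below
# (assuming one node with 8 GPUs and the defaults -b 256 -j 32; the GPU count
# is illustrative):
#   batch_size = 256 // 8 = 32 per process
#   workers    = (32 + 8 - 1) // 8 = 4 per process (ceiling division)
# so the effective global batch size is still 256.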
181 | if args.gpu is not None: 182 | torch.cuda.set_device(args.gpu) 183 | model.cuda(args.gpu) 184 | # When using a single GPU per process and per 185 | # DistributedDataParallel, we need to divide the batch size 186 | # ourselves based on the total number of GPUs we have 187 | args.batch_size = int(args.batch_size / ngpus_per_node) 188 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 189 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 190 | else: 191 | model.cuda() 192 | # DistributedDataParallel will divide and allocate batch_size to all 193 | # available GPUs if device_ids are not set 194 | model = torch.nn.parallel.DistributedDataParallel(model) 195 | elif args.gpu is not None: 196 | torch.cuda.set_device(args.gpu) 197 | model = model.cuda(args.gpu) 198 | # comment out the following line for debugging 199 | raise NotImplementedError("Only DistributedDataParallel is supported.") 200 | else: 201 | # AllGather implementation (batch shuffle, queue update, etc.) in 202 | # this code only supports DistributedDataParallel. 203 | raise NotImplementedError("Only DistributedDataParallel is supported.") 204 | 205 | # define loss function (criterion) and optimizer 206 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 207 | 208 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 209 | momentum=args.momentum, 210 | weight_decay=args.weight_decay) 211 | 212 | # optionally resume from a checkpoint 213 | if args.resume: 214 | if os.path.isfile(args.resume): 215 | print("=> loading checkpoint '{}'".format(args.resume)) 216 | if args.gpu is None: 217 | checkpoint = torch.load(args.resume) 218 | else: 219 | # Map model to be loaded to specified single gpu. 220 | loc = 'cuda:{}'.format(args.gpu) 221 | checkpoint = torch.load(args.resume, map_location=loc) 222 | args.start_epoch = checkpoint['epoch'] 223 | model.load_state_dict(checkpoint['state_dict']) 224 | optimizer.load_state_dict(checkpoint['optimizer']) 225 | print("=> loaded checkpoint '{}' (epoch {})" 226 | .format(args.resume, checkpoint['epoch'])) 227 | else: 228 | print("=> no checkpoint found at '{}'".format(args.resume)) 229 | cudnn.benchmark = True 230 | 231 | # Data loading code 232 | loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_workers=args.workers,\ 233 | root_dir=args.data,num_class=args.num_class,distributed=args.distributed) 234 | train_loader,test_loader = loader.run() 235 | 236 | if args.gpu==0: 237 | logger = tb_logger.Logger(logdir=os.path.join(args.exp_dir,'tensorboard'), flush_secs=2) 238 | else: 239 | logger = None 240 | 241 | for epoch in range(args.start_epoch, args.epochs): 242 | 243 | if args.distributed: 244 | loader.train_sampler.set_epoch(epoch) 245 | adjust_learning_rate(optimizer, epoch, args) 246 | 247 | train(train_loader, model, criterion, optimizer, epoch, args, logger) 248 | 249 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 250 | and args.rank % ngpus_per_node == 0): 251 | save_checkpoint({ 252 | 'epoch': epoch + 1, 253 | 'arch': args.arch, 254 | 'state_dict': model.state_dict(), 255 | 'optimizer' : optimizer.state_dict(), 256 | }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.exp_dir,epoch)) 257 | test(model, test_loader, args, epoch, logger) 258 | 259 | 260 | def train(train_loader, model, criterion, optimizer, epoch, args, tb_logger): 261 | batch_time = AverageMeter('Time', ':1.2f') 262 | data_time = AverageMeter('Data', ':1.2f') 263 | acc_cls = 
AverageMeter('Acc@Cls', ':2.2f') 264 | acc_proto = AverageMeter('Acc@Proto', ':2.2f') 265 | acc_inst = AverageMeter('Acc@Inst', ':2.2f') 266 | 267 | progress = ProgressMeter( 268 | len(train_loader), 269 | [batch_time, data_time, acc_cls, acc_inst, acc_proto], 270 | prefix="Epoch: [{}]".format(epoch)) 271 | 272 | # switch to train mode 273 | model.train() 274 | 275 | end = time.time() 276 | for i, batch in enumerate(train_loader): 277 | # measure data loading time 278 | data_time.update(time.time() - end) 279 | 280 | loss = 0 281 | 282 | # compute model output 283 | cls_out, target, logits, inst_labels, logits_proto = \ 284 | model(batch,args,is_proto=(epoch>0),is_clean=(epoch>=args.start_clean_epoch)) 285 | 286 | if epoch>0: 287 | # prototypical contrastive loss 288 | loss_proto = criterion(logits_proto, target) 289 | loss += args.w_proto*loss_proto 290 | acc = accuracy(logits_proto, target)[0] 291 | acc_proto.update(acc[0]) 292 | 293 | # classification loss 294 | loss_cls = criterion(cls_out, target) 295 | # instance contrastive loss 296 | loss_inst = criterion(logits, inst_labels) 297 | 298 | loss += (loss_cls+args.w_inst*loss_inst) 299 | 300 | # log accuracy 301 | acc = accuracy(cls_out, target)[0] 302 | acc_cls.update(acc[0]) 303 | acc = accuracy(logits, inst_labels)[0] 304 | acc_inst.update(acc[0]) 305 | 306 | # compute gradient and do SGD step 307 | optimizer.zero_grad() 308 | loss.backward() 309 | optimizer.step() 310 | # measure elapsed time 311 | batch_time.update(time.time() - end) 312 | end = time.time() 313 | if i % args.print_freq == 0: 314 | progress.display(i) 315 | 316 | if args.gpu == 0: 317 | tb_logger.log_value('Train Acc', acc_cls.avg, epoch) 318 | tb_logger.log_value('Instance Acc', acc_inst.avg, epoch) 319 | tb_logger.log_value('Prototype Acc', acc_proto.avg, epoch) 320 | 321 | 322 | def test(model, test_loader, args, epoch, tb_logger): 323 | with torch.no_grad(): 324 | print('==> Evaluation...') 325 | model.eval() 326 | top1_acc = AverageMeter("Top1") 327 | top5_acc = AverageMeter("Top5") 328 | 329 | # evaluate on webvision val set 330 | for batch_idx, batch in enumerate(test_loader): 331 | outputs,_,target = model(batch, args, is_eval=True) 332 | acc1, acc5 = accuracy(outputs, target, topk=(1, 5)) 333 | top1_acc.update(acc1[0]) 334 | top5_acc.update(acc5[0]) 335 | 336 | # average across all processes 337 | acc_tensors = torch.Tensor([top1_acc.avg,top5_acc.avg]).cuda(args.gpu) 338 | dist.all_reduce(acc_tensors) 339 | acc_tensors /= args.world_size 340 | 341 | print('Webvision Accuracy is %.2f%% (%.2f%%)'%(acc_tensors[0],acc_tensors[1])) 342 | if args.gpu ==0: 343 | tb_logger.log_value('WebVision top1 Acc', acc_tensors[0], epoch) 344 | tb_logger.log_value('WebVision top5 Acc', acc_tensors[1], epoch) 345 | return 346 | 347 | 348 | 349 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 350 | torch.save(state, filename) 351 | if is_best: 352 | shutil.copyfile(filename, 'model_best.pth.tar') 353 | 354 | 355 | class AverageMeter(object): 356 | """Computes and stores the average and current value""" 357 | def __init__(self, name, fmt=':f'): 358 | self.name = name 359 | self.fmt = fmt 360 | self.reset() 361 | 362 | def reset(self): 363 | self.val = 0 364 | self.avg = 0 365 | self.sum = 0 366 | self.count = 0 367 | 368 | def update(self, val, n=1): 369 | self.val = val 370 | self.sum += val * n 371 | self.count += n 372 | self.avg = self.sum / self.count 373 | 374 | def __str__(self): 375 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 
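# e.g. AverageMeter('Acc@Cls', ':2.2f') with val=97.5 and avg=92.3 renders as
# 'Acc@Cls 97.50 (92.30)': the latest batch value followed by the running average.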
376 | return fmtstr.format(**self.__dict__) 377 | 378 | 379 | class ProgressMeter(object): 380 | def __init__(self, num_batches, meters, prefix=""): 381 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 382 | self.meters = meters 383 | self.prefix = prefix 384 | 385 | def display(self, batch): 386 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 387 | entries += [str(meter) for meter in self.meters] 388 | print('\t'.join(entries)) 389 | 390 | def _get_batch_fmtstr(self, num_batches): 391 | num_digits = len(str(num_batches)) 392 | fmt = '{:' + str(num_digits) + 'd}' 393 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 394 | 395 | 396 | def adjust_learning_rate(optimizer, epoch, args): 397 | """Decay the learning rate based on schedule""" 398 | lr = args.lr 399 | if args.cos: # cosine lr schedule 400 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 401 | else: # stepwise lr schedule 402 | for milestone in args.schedule: 403 | lr *= 0.1 if epoch >= milestone else 1. 404 | for param_group in optimizer.param_groups: 405 | param_group['lr'] = lr 406 | 407 | 408 | def accuracy(output, target, topk=(1,)): 409 | """Computes the accuracy over the k top predictions for the specified values of k""" 410 | with torch.no_grad(): 411 | maxk = max(topk) 412 | batch_size = target.size(0) 413 | 414 | _, pred = output.topk(maxk, 1, True, True) 415 | pred = pred.t() 416 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 417 | 418 | res = [] 419 | for k in topk: 420 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 421 | res.append(correct_k.mul_(100.0 / batch_size)) 422 | return res 423 | 424 | 425 | if __name__ == '__main__': 426 | main() 427 | --------------------------------------------------------------------------------
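A closing usage note: the snippet below is a minimal sketch (it is not a file in the repository) of how a checkpoint written by train.py can be loaded into the bare ResNet-50 encoder for downstream feature extraction. It mirrors the key-renaming logic in lowshot_svm.py; the checkpoint filename is hypothetical.

import torch

from resnet import resnet50

# checkpoint produced by train.py with the default --exp-dir (path is illustrative)
ckpt = torch.load('experiment/MoPro_V1/checkpoint_0089.pth.tar', map_location='cpu')
state_dict = ckpt['state_dict']

# keep only the query-encoder backbone, as in lowshot_svm.py
for k in list(state_dict.keys()):
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc') \
            and not k.startswith('module.encoder_q.classifier'):
        state_dict[k[len('module.encoder_q.'):]] = state_dict[k]
    # drop the original key (momentum encoder, queue and prototypes included)
    del state_dict[k]

model = resnet50(encoder=True)  # ResNet_Encoder: forward() returns pooled 2048-d features
model.load_state_dict(state_dict, strict=False)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 2048])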