├── .gitignore
├── README.md
├── data
│   ├── README.md
│   ├── aircraft_image_folder_generator.py
│   ├── cars_image_folder_generator.py
│   ├── celeba_image_folder_generator.py
│   ├── cub_image_folder_generator.py
│   ├── dogs_image_folder_generator.py
│   ├── dtd_image_folder_generator.py
│   ├── flowers_image_folder_generator.py
│   ├── food11_image_folder_generator.py
│   ├── mit67_image_folder_generator.py
│   ├── pets_image_folder_generator.py
│   ├── places_image_folder_generator.py
│   ├── stanford40_image_folder_generator.py
│   ├── webfg_496_check.py
│   ├── webfg_496_image_folder_generator.py
│   └── webvision_image_folder_generator.py
├── downstream
│   ├── README.md
│   ├── active
│   │   ├── methods
│   │   │   ├── __init__.py
│   │   │   ├── badge.py
│   │   │   ├── coreset.py
│   │   │   ├── entropy.py
│   │   │   ├── random.py
│   │   │   └── strategy.py
│   │   ├── models
│   │   ├── run_active.sh
│   │   ├── train.py
│   │   └── util
│   │       ├── __init__.py
│   │       ├── dataset.py
│   │       └── misc.py
│   ├── detection
│   │   ├── dataloaders
│   │   │   ├── __init__.py
│   │   │   └── datasets
│   │   │       ├── aircraft_detection.py
│   │   │       └── cars_detection.py
│   │   ├── mAP
│   │   │   ├── README.md
│   │   │   ├── main.py
│   │   │   └── scripts
│   │   │       └── extra
│   │   │           ├── README.md
│   │   │           ├── class_list.txt
│   │   │           ├── convert_dr_darkflow_json.py
│   │   │           ├── convert_dr_yolo.py
│   │   │           ├── convert_gt_xml.py
│   │   │           ├── convert_gt_yolo.py
│   │   │           ├── convert_keras-yolo3.py
│   │   │           ├── find_class.py
│   │   │           ├── intersect-gt-and-dr.py
│   │   │           └── result.txt
│   │   ├── run_detection.sh
│   │   ├── train.py
│   │   └── util
│   │       └── misc.py
│   ├── mining
│   │   ├── methods
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── explicit_hard_sampling.py
│   │   │   ├── explicit_hard_sampling_memory.py
│   │   │   ├── hard_sampling.py
│   │   │   └── simclr.py
│   │   ├── models
│   │   ├── run_hard_neg_mining.sh
│   │   ├── train.py
│   │   └── util
│   │       ├── dataset.py
│   │       ├── misc.py
│   │       └── transform.py
│   ├── opensemi
│   │   ├── models
│   │   ├── run_openmatch.sh
│   │   ├── run_selftraining.sh
│   │   ├── train.py
│   │   └── util
│   │       ├── dataset.py
│   │       ├── methods.py
│   │       ├── misc.py
│   │       └── transform.py
│   ├── segmentation
│   │   ├── dataloaders
│   │   │   ├── __init__.py
│   │   │   ├── custom_transforms.py
│   │   │   └── datasets
│   │   │       ├── cub_segmentation.py
│   │   │       └── pets_segmentation.py
│   │   ├── modeling
│   │   │   ├── aspp.py
│   │   │   ├── backbone
│   │   │   │   ├── __init__.py
│   │   │   │   ├── drn.py
│   │   │   │   ├── mobilenet.py
│   │   │   │   ├── resnet.py
│   │   │   │   └── xception.py
│   │   │   ├── decoder.py
│   │   │   ├── deeplab.py
│   │   │   └── sync_batchnorm
│   │   │       ├── __init__.py
│   │   │       ├── batchnorm.py
│   │   │       ├── comm.py
│   │   │       ├── replicate.py
│   │   │       └── unittest.py
│   │   ├── run_segmentation.sh
│   │   ├── train.py
│   │   └── util
│   │       ├── calculate_weights.py
│   │       ├── loss.py
│   │       ├── lr_scheduler.py
│   │       ├── metrics.py
│   │       ├── saver.py
│   │       └── summaries.py
│   ├── semisup
│   │   ├── models
│   │   ├── run_semisup.sh
│   │   ├── train.py
│   │   └── util
│   │       ├── dataset.py
│   │       ├── methods.py
│   │       ├── misc.py
│   │       └── transform.py
│   └── weblysup
│       ├── models
│       ├── run_coteaching.sh
│       ├── run_dividemix.sh
│       ├── train.py
│       └── util
│           ├── dataset.py
│           ├── methods.py
│           ├── misc.py
│           └── transform.py
├── models
│   ├── __init__.py
│   ├── dino_vit.py
│   ├── efficientnet.py
│   ├── mae_vit.py
│   ├── resnet.py
│   ├── resnext.py
│   └── timm_vit.py
├── requirements.txt
├── run_selfsup.sh
├── run_sup.sh
├── ssl
│   ├── __init__.py
│   ├── base.py
│   ├── byol.py
│   ├── dino.py
│   ├── mae.py
│   ├── moco.py
│   ├── simclr.py
│   ├── simsiam.py
│   └── swav.py
├── train_selfsup.py
├── train_sup.py
└── util
    ├── imagenet_subset.py
    ├── knn_evaluation.py
    ├── merge_dataset.py
    ├── misc.py
    ├── sampling.py
    ├── semisup_dataset.py
    ├── subclass
    │   ├── imagenet_sub_aircraft.txt
    │   ├── imagenet_sub_cars.txt
    │   ├── imagenet_sub_cub.txt
    │   ├── imagenet_sub_pets.txt
    │   └── imagenet_to_label.txt
    └── transform.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | data/
4 | save/
5 | *.jpg
6 | *.ipynb
7 | *.pth
8 | *.json
9 | *.sh
10 | *.png
11 | .ipynb_checkpoints/
12 | downstream/segmentation/run/
13 | downstream/detection/mAP/input/
14 | downstream/detection/mAP/output/
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## How to set up datasets
2 | After downloading each dataset from the official links, arrange the image files in the `torchvision.datasets.ImageFolder` format.
3 | In other words, all images should be organized as follows:
4 | 
5 | ```
6 | data/aircraft
7 | ├── train
8 | │   ├── Boeing_717
9 | │   │   ├── 1801653.jpg
10 | │   │   ├── 1385089.jpg
11 | │   │   ├── 0181712.jpg
12 | │   │
13 | │   ├── CRJ-900
14 | │   ├── A380
15 | │
16 | ├── test
17 | │   ├── Boeing_717
18 | ```
19 | 
20 | We provide Python scripts to ease the ImageFolder generation. Please carefully modify the source and target paths in each Python file, and run `[DATASET_NAME]_image_folder_generator.py` to set the right paths for the train and test folders.
21 | 
22 | We also provide the full details of each dataset we used.
23 | 
24 | | Dataset name (in paper) | name (code) | # of data | Link |
25 | |-------------------------|----------------|-----------|------|
26 | | Aircraft (FGVC-Aircraft) | aircraft | 6,667 | https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/ |
27 | | Cars (Stanford Cars) | cars | 8,144 | https://ai.stanford.edu/~jkrause/cars/car_dataset.html |
28 | | Pet (Oxford-IIIT Pet) | pets | 3,680 | https://www.robots.ox.ac.uk/~vgg/data/pets/ |
29 | | Birds (Caltech-UCSD Birds) | cub | 5,990 | https://www.vision.caltech.edu/datasets/cub_200_2011/ |
30 | | Dogs (Stanford Dogs) | dogs | 12,000 | http://vision.stanford.edu/aditya86/ImageNetDogs/ |
31 | | Flowers (Oxford 102 Flower) | flowers | 2,040 | https://www.robots.ox.ac.uk/~vgg/data/flowers/ |
32 | | Actions (Stanford 40 Actions) | stanford40 | 4,000 | http://vision.stanford.edu/Datasets/40actions.html |
33 | | Indoor (MIT-67 Indoor Scene) | mit67 | 5,360 | https://web.mit.edu/torralba/www/indoor.html |
34 | | Textures (Describable Textures) | dtd | 3,760 | https://www.robots.ox.ac.uk/~vgg/data/dtd/ |
35 | | Faces (CelebAMask-HQ) | celeba | 4,263 | https://github.com/switchablenorms/CelebAMask-HQ |
36 | | Food (Food 11) | food11 | 13,296 | https://www.kaggle.com/datasets/trolukovich/food11-image-dataset |
37 | | | | | |
38 | | ImageNet | imagenet | 1,281,167 | https://www.image-net.org |
39 | | Microsoft COCO | coco | 118,287 | https://cocodataset.org/#home |
40 | | iNaturalist2021-mini | inaturalist | 500,000 | https://github.com/visipedia/inat_comp/tree/master/2021 |
41 | | Places365 | places | 8,026,628 | http://places2.csail.mit.edu |
42 | | ALL | everything | 9,926,082 | - |
43 | | WebVision | webvision | 2,446,037 | https://data.vision.ee.ethz.ch/cvl/webvision/dataset2017.html |
44 | | WebFG-496 | webfg | 53,339 | https://github.com/NUST-Machine-Intelligence-Laboratory/weblyFG-dataset |
45 | 
--------------------------------------------------------------------------------
/data/aircraft_image_folder_generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | 
4 | ROOT = "./fgvc-aircraft-2013b/data"
5 | TARGET_ROOT = "./"
6 | 
7 | LABEL = 'manufacturer' # family, manufacturer
8 | 
9 | with
open(os.path.join(ROOT, 'images_{}_trainval.txt'.format(LABEL))) as f: 10 | lines = f.readlines() 11 | 12 | image_to_class = dict() 13 | class_to_image = dict() 14 | for line in lines: 15 | word = line.split(' ') 16 | imagename = word[0] 17 | classname = '_'.join(word[1:]).replace('/', '_').strip() 18 | image_to_class[imagename] = classname 19 | 20 | SOURCE_DIR = os.path.join(ROOT, "./images") 21 | TARGET_DIR = os.path.join(TARGET_ROOT, "./aircraft/{}/train".format(LABEL)) 22 | 23 | print("Copying images...") 24 | for path, cls in image_to_class.items(): 25 | source_path = os.path.join(SOURCE_DIR, path+'.jpg') 26 | target_path = os.path.join(TARGET_DIR, cls, path+'.jpg') 27 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 28 | shutil.copy(source_path, target_path) 29 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 30 | 31 | with open(os.path.join(ROOT, 'images_{}_test.txt'.format(LABEL))) as f: 32 | lines = f.readlines() 33 | 34 | image_to_class = dict() 35 | class_to_image = dict() 36 | for line in lines: 37 | word = line.split(' ') 38 | imagename = word[0] 39 | classname = '_'.join(word[1:]).replace('/', '_').strip() 40 | image_to_class[imagename] = classname 41 | 42 | SOURCE_DIR = os.path.join(ROOT, "./images") 43 | # TARGET_DIR = "./aircraft/test" 44 | TARGET_DIR = os.path.join(TARGET_ROOT, "./aircraft/{}/test".format(LABEL)) 45 | 46 | print("Copying images...") 47 | for path, cls in image_to_class.items(): 48 | source_path = os.path.join(SOURCE_DIR, path+'.jpg') 49 | target_path = os.path.join(TARGET_DIR, cls, path+'.jpg') 50 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 51 | shutil.copy(source_path, target_path) 52 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 53 | -------------------------------------------------------------------------------- /data/cars_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | import os 4 | from collections import defaultdict 5 | from tqdm import tqdm 6 | import pandas as pd 7 | import numpy as np 8 | import json 9 | import shutil 10 | import copy 11 | import scipy.io 12 | 13 | ROOT = "./cars" 14 | 15 | LABEL = "type" # brand, type 16 | 17 | metadata = scipy.io.loadmat(os.path.join(ROOT, "raw/cars_annos.mat")) 18 | type_list = ['SUV', 'Sedan', 'Hatchback', 'Convertible', 'Coupe', 'Wagon', 'Cab', 'Van', 'Minivan'] 19 | UNK = {'Type-S': 'Sedan', 'R': 'Coupe', 'GS': 'Sedan', 'ZR1': '', 'Z06': '', 'SS': '', 'SRT-8': '', 'SRT8': '', 'Abarth': '', 'SuperCab': 'Cab', 20 | 'IPL': 'Coupe', 'XKR': 'Coupe', 'Superleggera': 'Coupe'} 21 | 22 | if LABEL in ['brand', 'type']: 23 | if LABEL == 'brand': 24 | cls_to_label = {i+1: cls[0].split(' ')[0] for i, cls in enumerate(metadata["class_names"][0])} 25 | elif LABEL == 'type': 26 | cls_to_label = {} 27 | for i, cls in enumerate(metadata["class_names"][0]): 28 | cartype = cls[0].split(' ')[-2] 29 | if cartype not in type_list: 30 | if UNK[cartype]: 31 | cartype = UNK[cartype] 32 | cls_to_label[i+1] = cartype 33 | # = {i+1: cls[0].split(' ')[-2] for i, cls in enumerate(metadata["class_names"][0])} 34 | 35 | metadata = metadata["annotations"][0] 36 | 37 | train_paths_by_class = defaultdict(list) 38 | test_paths_by_class = defaultdict(list) 39 | 40 | for m in tqdm(metadata): 41 | path, _, _, _, _, cls, is_test = m 42 | cls = cls.item() 43 | path = path.item() 44 | is_test = is_test.item() 45 | 46 | if LABEL == 'original': 47 | key = str(cls) 
48 | else: 49 | key = cls_to_label[cls] 50 | if LABEL == 'type' and key not in type_list: 51 | # print(key) 52 | continue 53 | 54 | if is_test: 55 | test_paths_by_class[key].append(path) 56 | else: 57 | train_paths_by_class[key].append(path) 58 | 59 | i = 0 60 | for key in train_paths_by_class: 61 | i += len(train_paths_by_class[key]) 62 | print(i) 63 | 64 | # SOURCE_DIR = os.path.join(ROOT, "car_ims") 65 | # TARGET_DIR = "./train" 66 | SOURCE_DIR = os.path.join(ROOT, "raw/car_ims") 67 | TARGET_DIR = os.path.join(ROOT, "./{}/train".format(LABEL)) 68 | 69 | print("Copying images...") 70 | for cls, paths in train_paths_by_class.items(): 71 | for source_path in paths: 72 | path = source_path.split('/')[1] 73 | target_path = os.path.join(TARGET_DIR, cls, path) 74 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 75 | shutil.copy(os.path.join(SOURCE_DIR, path), target_path) 76 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 77 | 78 | 79 | # SOURCE_DIR = os.path.join(ROOT, "car_ims") 80 | # TARGET_DIR = "./test" 81 | SOURCE_DIR = os.path.join(ROOT, "raw/car_ims") 82 | TARGET_DIR = os.path.join(ROOT, "./{}/test".format(LABEL)) 83 | 84 | print("Copying images...") 85 | for cls, paths in test_paths_by_class.items(): 86 | for source_path in paths: 87 | path = source_path.split('/')[1] 88 | target_path = os.path.join(TARGET_DIR, cls, path) 89 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 90 | shutil.copy(os.path.join(SOURCE_DIR, path), target_path) 91 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 92 | -------------------------------------------------------------------------------- /data/celeba_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import copyfile 3 | 4 | ROOT = './celeba_maskhq/' 5 | 6 | orig_identities = {} 7 | with open(os.path.join('./celeba', 'identity_CelebA.txt')) as f: 8 | lines = f.readlines() 9 | for line in lines: 10 | file_name, identity = line.strip().split() 11 | orig_identities[file_name] = identity 12 | 13 | identities = {} 14 | 15 | with open(os.path.join(ROOT, 'CelebA-HQ-to-CelebA-mapping.txt')) as f: 16 | lines = f.readlines() 17 | for idx, line in enumerate(lines): 18 | if idx == 0: continue 19 | file_name, _, orig_file_name = line.strip().split() 20 | identities['{}.jpg'.format(file_name)] = orig_identities[orig_file_name] 21 | 22 | print(f'There are {len(set(identities.values()))} identities.') 23 | print(f'There are {len(identities.keys())} images.') 24 | 25 | source_root = os.path.join(ROOT, 'CelebA-HQ-img') 26 | target_root = os.path.join(ROOT, 'identity') 27 | file_list = os.listdir(source_root) 28 | 29 | for file in file_list: 30 | identity = identities[file] 31 | source = os.path.join(source_root, file) 32 | target = os.path.join(target_root, str(identity), file) 33 | if not os.path.exists(os.path.join(target_root, str(identity))): 34 | os.makedirs(os.path.join(target_root, str(identity))) 35 | copyfile(source, target) 36 | 37 | 38 | # sample the identities with higher than 15 images 39 | 40 | folder_root = os.path.join(ROOT, 'identity') 41 | folder_list = os.listdir(folder_root) 42 | 43 | threshold = 15 44 | identity_cnt = 0 45 | 46 | train_images = 0 47 | test_images = 0 48 | train_ratio = 0.8 49 | 50 | for folder in folder_list: 51 | file_list = os.path.join(folder_root, folder) 52 | file_list = os.listdir(file_list) 53 | if len(file_list) >= threshold: 54 | identity_cnt += 1 55 | 
num_train = int(train_ratio * len(file_list)) 56 | for file in file_list[:num_train]: 57 | train_images += 1 58 | source = os.path.join(folder_root, folder, file) 59 | target = os.path.join(folder_root, 'train', folder, file) 60 | if not os.path.exists(os.path.join(folder_root, 'train', folder)): 61 | os.makedirs(os.path.join(folder_root, 'train', folder)) 62 | os.rename(source, target) 63 | for file in file_list[num_train:]: 64 | test_images += 1 65 | source = os.path.join(folder_root, folder, file) 66 | target = os.path.join(folder_root, 'test', folder, file) 67 | if not os.path.exists(os.path.join(folder_root, 'test', folder)): 68 | os.makedirs(os.path.join(folder_root, 'test', folder)) 69 | os.rename(source, target) 70 | 71 | print(f'There are {identity_cnt} identities that have more than {threshold} images.') 72 | print(f'There are {train_images} train images.') 73 | print(f'There are {test_images} test images.') 74 | 75 | 76 | for folder in os.listdir(os.path.join(ROOT, 'identity')): 77 | if folder not in ['train', 'test']: 78 | os.system("rm -rf {}".format(os.path.join(ROOT, 'identity', folder))) 79 | 80 | os.system("mv {} {}/".format(os.path.join(ROOT, 'identity', 'train'), ROOT)) 81 | os.system("mv {} {}/".format(os.path.join(ROOT, 'identity', 'test'), ROOT)) 82 | os.system("rm -rf {}".format(os.path.join(ROOT, 'identity'))) 83 | 84 | 85 | # For multi-attribute recognition for CelebA 86 | # only use "Male" (20) and "Smiling" (31) attribute 87 | 88 | multi_att = {} 89 | with open(os.path.join(ROOT, 'CelebAMask-HQ-attribute-anno.txt')) as f: 90 | lines = f.readlines() 91 | for idx, line in enumerate(lines): 92 | if idx == 0: continue 93 | elif idx == 1: 94 | cls_to_att = line.strip().split() 95 | cls_to_att = {k: i for i, k in enumerate(cls_to_att)} 96 | else: 97 | attribute = line.strip().split() 98 | file_name, att_lst = attribute[0], attribute[1:] 99 | att_lst = [int(a.replace('-1', '0')) for a in att_lst] 100 | 101 | multi_att[file_name] = (att_lst[20], att_lst[31]) 102 | 103 | male, smile = 0, 0 104 | for k, v in multi_att.items(): 105 | if v[0] == 1: male +=1 106 | if v[1] == 1: smile += 1 107 | 108 | print('whole CelebAMask-HQ dataset') 109 | 110 | ATT = 'male' 111 | 112 | os.makedirs(os.path.join(ROOT, '{}/train'.format(ATT))) 113 | os.makedirs(os.path.join(ROOT, '{}/test'.format(ATT))) 114 | 115 | for split in ['train', 'test']: 116 | source_root = os.path.join(ROOT, split) 117 | target_root = os.path.join(ROOT, ATT, split) 118 | folder_list = os.listdir(source_root) 119 | 120 | for folder in folder_list: 121 | for file in os.listdir(os.path.join(source_root, folder)): 122 | atts = multi_att[file] 123 | if ATT == 'male': att = atts[0] 124 | if ATT == 'smiling': att = atts[1] 125 | 126 | source = os.path.join(source_root, folder, file) 127 | target = os.path.join(target_root, str(att), file) 128 | 129 | if not os.path.exists(os.path.join(target_root, str(att))): 130 | os.makedirs(os.path.join(target_root, str(att))) 131 | copyfile(source, target) 132 | -------------------------------------------------------------------------------- /data/cub_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | import matplotlib.pyplot as plt 4 | import os 5 | import shutil 6 | from shutil import copyfile 7 | 8 | # Change the path to your dataset folder: 9 | base_folder = './CUB_200_2011/' 10 | 11 | # These path should be fine 12 | images_txt_path = 
base_folder+ 'images.txt' 13 | train_test_split_path = base_folder+ 'train_test_split.txt' 14 | images_path = base_folder+ 'images/' 15 | 16 | # Here declare where you want to place the train/test folders 17 | # You don't need to create them! 18 | test_folder = './data/cub/test/' 19 | train_folder = './data/cub/train/' 20 | 21 | 22 | def ignore_files(dir,files): return [f for f in files if os.path.isfile(os.path.join(dir,f))] 23 | 24 | shutil.copytree(images_path,test_folder,ignore=ignore_files) 25 | shutil.copytree(images_path,train_folder,ignore=ignore_files) 26 | 27 | with open(images_txt_path) as f: 28 | images_lines = f.readlines() 29 | 30 | with open(train_test_split_path) as f: 31 | split_lines = f.readlines() 32 | 33 | test_images, train_images = 0,0 34 | 35 | for image_line,split_line in zip(images_lines,split_lines): 36 | 37 | image_line = (image_line.strip()).split(' ') 38 | split_line = (split_line.strip()).split(' ') 39 | 40 | image = plt.imread(images_path + image_line[1]) 41 | 42 | # Use only RGB images, avoid grayscale 43 | if len(image.shape) == 3: 44 | 45 | # If test image 46 | if(int(split_line[1]) == 0): 47 | copyfile(images_path+image_line[1],test_folder+image_line[1]) 48 | test_images += 1 49 | else: 50 | # If train image 51 | copyfile(images_path+image_line[1],train_folder+image_line[1]) 52 | train_images += 1 53 | 54 | print(train_images,test_images) 55 | assert train_images == 5990 56 | assert test_images == 5790 57 | 58 | print('Dataset succesfully splitted!') 59 | -------------------------------------------------------------------------------- /data/dogs_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | import matplotlib.pyplot as plt 4 | import os 5 | import shutil 6 | import scipy.io 7 | from shutil import copyfile 8 | 9 | # Change the path to your dataset folder: 10 | base_folder = './Images/' 11 | 12 | train_path = './train_list.mat' 13 | test_path = './test_list.mat' 14 | train_metadata = scipy.io.loadmat(train_path)['file_list'] 15 | test_metadata = scipy.io.loadmat(test_path)['file_list'] 16 | 17 | # Here declare where you want to place the train/test folders 18 | # You don't need to create them! 
19 | test_folder = './data/dogs/test/' 20 | train_folder = './data/dogs/train/' 21 | 22 | 23 | def ignore_files(dir,files): return [f for f in files if os.path.isfile(os.path.join(dir,f))] 24 | 25 | shutil.copytree(base_folder,test_folder,ignore=ignore_files) 26 | shutil.copytree(base_folder,train_folder,ignore=ignore_files) 27 | 28 | train_lines, test_lines = [], [] 29 | for i in range(len(train_metadata)): 30 | train_lines.append(train_metadata[i][0][0]) 31 | for i in range(len(test_metadata)): 32 | test_lines.append(test_metadata[i][0][0]) 33 | 34 | test_images, train_images = 0,0 35 | 36 | for train_line in train_lines: 37 | if os.path.exists(base_folder+train_line): 38 | copyfile(base_folder+train_line, train_folder+train_line) 39 | train_images += 1 40 | 41 | for test_line in test_lines: 42 | if os.path.exists(base_folder+test_line): 43 | copyfile(base_folder+test_line, test_folder+test_line) 44 | test_images += 1 45 | 46 | print(train_images, test_images) 47 | assert train_images == 12000 48 | assert test_images == 8580 49 | 50 | print('Dataset succesfully splitted!') 51 | -------------------------------------------------------------------------------- /data/dtd_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io 3 | import shutil 4 | from shutil import copyfile 5 | 6 | 7 | metadata = scipy.io.loadmat("./imdb/imdb.mat") 8 | meta = metadata['images'][0,0] 9 | ROOT = './images' 10 | 11 | traindir = './train' 12 | testdir = './test' 13 | 14 | def ignore_files(dir,files): return [f for f in files if os.path.isfile(os.path.join(dir,f))] 15 | shutil.copytree(ROOT, traindir, ignore=ignore_files) 16 | shutil.copytree(ROOT, testdir, ignore=ignore_files) 17 | 18 | trainlist = list() 19 | vallist = list() 20 | testlist = list() 21 | with open('./labels/train1.txt') as f: 22 | for file in f.readlines(): 23 | trainlist.append(file.strip()) 24 | with open('./labels/val1.txt') as f: 25 | for file in f.readlines(): 26 | vallist.append(file.strip()) 27 | with open('./labels/test1.txt') as f: 28 | for file in f.readlines(): 29 | testlist.append(file.strip()) 30 | 31 | trainlist = trainlist + vallist 32 | 33 | train_images = 0 34 | for img_path in trainlist: 35 | copyfile(os.path.join(ROOT, img_path), os.path.join(traindir, img_path)) 36 | train_images += 1 37 | 38 | test_images = 0 39 | for img_path in testlist: 40 | copyfile(os.path.join(ROOT, img_path), os.path.join(testdir, img_path)) 41 | test_images += 1 42 | 43 | print(train_images,test_images) 44 | assert train_images == 3760 45 | assert test_images == 1880 46 | print('Dataset succesfully splitted!') 47 | -------------------------------------------------------------------------------- /data/flowers_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | import os 4 | from collections import defaultdict 5 | from tqdm import tqdm 6 | import pandas as pd 7 | import numpy as np 8 | import json 9 | import shutil 10 | import copy 11 | import scipy.io 12 | 13 | 14 | ROOT = "./data" 15 | metadata = scipy.io.loadmat(os.path.join(ROOT, "imagelabels.mat")) 16 | setid = scipy.io.loadmat(os.path.join(ROOT, "setid.mat")) 17 | 18 | metadata = metadata['labels'][0] 19 | train_setid = setid['trnid'][0] 20 | valid_setid = setid['valid'][0] 21 | test_setid = setid['tstid'][0] 22 | 23 | 24 | train_paths_by_class = defaultdict(list) 25 | valid_paths_by_class = defaultdict(list) 26 | 
test_paths_by_class = defaultdict(list) 27 | for idx, label in tqdm(enumerate(metadata, start=1)): 28 | path = os.path.join(ROOT, 'flowers', 'image_{:05d}.jpg'.format(idx)) 29 | if idx in train_setid: 30 | train_paths_by_class[str(label)].append(path) 31 | elif idx in valid_setid: 32 | valid_paths_by_class[str(label)].append(path) 33 | elif idx in test_setid: 34 | test_paths_by_class[str(label)].append(path) 35 | else: 36 | raise ValueError(idx) 37 | 38 | TARGET_DIR = './train' 39 | print("Copying images...") 40 | for cls, paths in train_paths_by_class.items(): 41 | for source_path in paths: 42 | path = source_path.split('/')[-1] 43 | target_path = os.path.join(TARGET_DIR, cls, path) 44 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 45 | shutil.copy(source_path, target_path) 46 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 47 | 48 | TARGET_DIR = './val' 49 | print("Copying images...") 50 | for cls, paths in valid_paths_by_class.items(): 51 | for source_path in paths: 52 | path = source_path.split('/')[-1] 53 | target_path = os.path.join(TARGET_DIR, cls, path) 54 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 55 | shutil.copy(source_path, target_path) 56 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 57 | 58 | TARGET_DIR = "./test" 59 | print("Copying images...") 60 | for cls, paths in test_paths_by_class.items(): 61 | for source_path in paths: 62 | path = source_path.split('/')[-1] 63 | target_path = os.path.join(TARGET_DIR, cls, path) 64 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 65 | shutil.copy(source_path, target_path) 66 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 67 | 68 | ### Option: use train + val instead of only train 69 | 70 | SOURCE_DIR = './flowers/val' 71 | TARGET_DIR = './flowers/train' 72 | 73 | for dirlist in os.listdir(SOURCE_DIR): 74 | for file in os.listdir(os.path.join(SOURCE_DIR, dirlist)): 75 | shutil.move(os.path.join(SOURCE_DIR, dirlist, file), os.path.join(TARGET_DIR, dirlist, file)) 76 | -------------------------------------------------------------------------------- /data/food11_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import copyfile 3 | 4 | 5 | ROOT = './food11' 6 | 7 | classes = list(os.listdir(os.path.join(ROOT, 'raw', 'train'))) 8 | 9 | if not os.path.isdir(os.path.join(ROOT, 'train')): 10 | os.makedirs(os.path.join(ROOT, 'train')) 11 | 12 | mode = ['train', 'val'] 13 | 14 | for m in mode: 15 | source_root = os.path.join(ROOT, 'raw', m) 16 | target_root = os.path.join(ROOT, 'train') 17 | 18 | for cls in classes: 19 | file_list = os.listdir(os.path.join(source_root, cls)) 20 | 21 | for file in file_list: 22 | source = os.path.join(source_root, cls, file) 23 | target = os.path.join(target_root, cls, '{}_{}'.format(m, file)) 24 | if not os.path.exists(os.path.join(target_root, cls)): 25 | os.makedirs(os.path.join(target_root, cls)) 26 | copyfile(source, target) 27 | -------------------------------------------------------------------------------- /data/mit67_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | import matplotlib.pyplot as plt 4 | import os 5 | import shutil 6 | from shutil import copyfile 7 | 8 | # Change the path to your dataset folder: 9 | base_folder = 
'./Images/' 10 | 11 | train_path = './TrainImages.txt' 12 | test_path = './TestImages.txt' 13 | 14 | # Here declare where you want to place the train/test folders 15 | # You don't need to create them! 16 | test_folder = './data/mit67/test/' 17 | train_folder = './data/mit67/train/' 18 | 19 | 20 | def ignore_files(dir,files): return [f for f in files if os.path.isfile(os.path.join(dir,f))] 21 | 22 | shutil.copytree(base_folder,test_folder,ignore=ignore_files) 23 | shutil.copytree(base_folder,train_folder,ignore=ignore_files) 24 | 25 | 26 | with open(train_path) as f: 27 | train_lines = f.readlines() 28 | 29 | with open(test_path) as f: 30 | test_lines = f.readlines() 31 | 32 | test_images, train_images = 0,0 33 | 34 | for train_line in train_lines: 35 | train_line = train_line.strip() 36 | if os.path.exists(base_folder+train_line): 37 | copyfile(base_folder+train_line, train_folder+train_line) 38 | train_images += 1 39 | 40 | for test_line in test_lines: 41 | test_line = test_line.strip() 42 | if os.path.exists(base_folder+test_line): 43 | copyfile(base_folder+test_line, test_folder+test_line) 44 | test_images += 1 45 | 46 | print(train_images, test_images) 47 | assert train_images == 67*80 48 | assert test_images == 67*20 49 | 50 | print('Dataset succesfully splitted!') 51 | -------------------------------------------------------------------------------- /data/pets_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | with open('./annotations/trainval.txt') as f: 5 | lines = f.readlines() 6 | 7 | image_to_class = dict() 8 | for line in lines: 9 | if line.startswith('#'): 10 | continue 11 | word = line.split(' ') 12 | image_to_class[word[0]] = word[1] 13 | # image_to_class 14 | 15 | SOURCE_DIR = "./images" 16 | TARGET_DIR = "./pets/train" 17 | print("Copying images...") 18 | for path, cls in image_to_class.items(): 19 | source_path = os.path.join(SOURCE_DIR, path+'.jpg') 20 | target_path = os.path.join(TARGET_DIR, cls, path+'.jpg') 21 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 22 | shutil.copy(source_path, target_path) 23 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 24 | 25 | with open('./annotations/test.txt') as f: 26 | lines = f.readlines() 27 | 28 | image_to_class = dict() 29 | for line in lines: 30 | if line.startswith('#'): 31 | continue 32 | word = line.split(' ') 33 | image_to_class[word[0]] = word[1] 34 | 35 | SOURCE_DIR = "./images" 36 | TARGET_DIR = "./pets/test" 37 | print("Copying images...") 38 | for path, cls in image_to_class.items(): 39 | source_path = os.path.join(SOURCE_DIR, path+'.jpg') 40 | target_path = os.path.join(TARGET_DIR, cls, path+'.jpg') 41 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 42 | shutil.copy(source_path, target_path) 43 | print("ImageFolder directory created at {}".format(os.path.abspath(TARGET_DIR))) 44 | 45 | -------------------------------------------------------------------------------- /data/places_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | # this notebook is for places365 challenge 2 | import os 3 | 4 | ROOT = './places/data_256' 5 | f = open(os.path.join('./places', 'categories_places365.txt'), 'r') 6 | 7 | cls_lists = [] 8 | while True: 9 | line = f.readline() 10 | if not line: break 11 | 12 | path = line.split(' ')[0][1:] 13 | tmp = path.split('/') 14 | if len(tmp) == 3: 15 | tmp = '{}_{}'.format(tmp[1], tmp[2]) 16 
| elif len(tmp) == 2: 17 | tmp = '{}'.format(tmp[1]) 18 | 19 | cls_lists.append((path, tmp)) 20 | f.close() 21 | 22 | for p, t in cls_lists: 23 | path = os.path.join(ROOT, p) 24 | os.system('mv {}/ {}/'.format(path, os.path.join(ROOT, t))) 25 | 26 | lst = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 27 | 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y', 'z'] 28 | 29 | for l in lst: 30 | os.system('rm -rf {}'.format(os.path.join(ROOT, l))) 31 | 32 | path = os.listdir(os.path.join(ROOT)) 33 | os.system('mkdir {}'.format(os.path.join(ROOT, 'train'))) 34 | 35 | for p in path: 36 | os.system('mv {}/ {}/'.format(os.path.join(ROOT, p), os.path.join(ROOT, 'train'))) 37 | 38 | # rm data_256/train/airport_terminal/00020865.jpg 39 | # rm data_256/train/airfield/00021054.jpg 40 | # rm data_256/train/airport_terminal/00021622.jpg 41 | # rm data_256/train/alley/00017667.jpg 42 | # rm data_256/train/alley/00017729.jpg 43 | # rm data_256/train/alley/00018289.jpg 44 | # rm data_256/train/alley/00018892.jpg 45 | # rm data_256/train/amusement_arcade/00007300.jpg 46 | 47 | -------------------------------------------------------------------------------- /data/stanford40_image_folder_generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | import matplotlib.pyplot as plt 4 | import os 5 | import shutil 6 | from shutil import copyfile 7 | 8 | # Change the path to your dataset folder: 9 | base_folder = './JPEGImages/' 10 | 11 | train_path = './ImageSplits/train.txt' 12 | test_path = './ImageSplits/test.txt' 13 | 14 | # Here declare where you want to place the train/test folders 15 | # You don't need to create them! 16 | test_folder = './data/stanford40/test/' 17 | train_folder = './data/stanford40/train/' 18 | 19 | os.makedirs(test_folder, exist_ok=True) 20 | os.makedirs(train_folder, exist_ok=True) 21 | 22 | with open(train_path) as f: 23 | train_lines = f.readlines() 24 | 25 | with open(test_path) as f: 26 | test_lines = f.readlines() 27 | 28 | test_images, train_images = 0,0 29 | 30 | for train_line in train_lines: 31 | file_name = train_line.strip() 32 | class_name = '_'.join(train_line.strip().split('_')[:-1]) 33 | if os.path.exists(base_folder+file_name): 34 | os.makedirs(train_folder+class_name, exist_ok=True) 35 | copyfile(base_folder+file_name, os.path.join(train_folder,class_name,file_name)) 36 | train_images += 1 37 | 38 | for test_line in test_lines: 39 | file_name = test_line.strip() 40 | class_name = '_'.join(test_line.strip().split('_')[:-1]) 41 | if os.path.exists(base_folder+file_name): 42 | os.makedirs(test_folder+class_name, exist_ok=True) 43 | copyfile(base_folder+file_name, os.path.join(test_folder,class_name,file_name)) 44 | test_images += 1 45 | 46 | print(train_images, test_images) 47 | assert train_images == 4000 48 | assert test_images == 5532 49 | 50 | print('Dataset succesfully splitted!') 51 | -------------------------------------------------------------------------------- /data/webfg_496_check.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | T_DIR = './WebFG-496/web-aircraft/train' 4 | S_DIR = './aircraft/train' 5 | 6 | num = 0 7 | for folder in os.listdir(T_DIR): 8 | if folder.replace(' ', '_') in os.listdir(S_DIR): 9 | num += 1 10 | else: 11 | print(folder) 12 | 13 | print(num) 14 | 15 | os.system('mv {} {}'.format(os.path.join(T_DIR, 'F-16AB'), os.path.join(T_DIR, 'F-16A_B'))) 
16 | os.system('mv {} {}'.format(os.path.join(T_DIR, 'FA-18'), os.path.join(T_DIR, 'F_A-18')))
17 | 
18 | 
19 | T_DIR = './WebFG-496/web-cub/train'
20 | S_DIR = './cub/train'
21 | 
22 | num = 0
23 | for folder in os.listdir(T_DIR):
24 |     if folder.replace(' ', '_') in os.listdir(S_DIR):
25 |         num += 1
26 |     else:
27 |         print(folder)
28 | 
29 | print(num)
30 | 
--------------------------------------------------------------------------------
/data/webfg_496_image_folder_generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | dirs = './WebFG-496'
4 | cat = ['web-aircraft', 'web-bird', 'web-car']
5 | 
6 | os.makedirs(os.path.join(dirs, 'train'), exist_ok=True)
7 | 
8 | for c in cat:
9 |     for split in ['train']:
10 |         path = os.listdir(os.path.join(dirs, c, split))
11 | 
12 |         for p in path:
13 |             p_sp = p.replace(' ', '\ ')
14 |             t = p.replace(' ', '_')
15 |             source_dir = os.path.join(dirs, c, split, p_sp)
16 |             target_dir = os.path.join(dirs, 'train')
17 | 
18 |             if t not in os.listdir(target_dir):
19 |                 os.system('mv {} {}'.format(source_dir, target_dir+'/{}'.format(t)))
20 |             else:
21 |                 os.system('mv {} {}'.format(source_dir+'/*', os.path.join(target_dir, t)+'/'))
22 | 
23 | 
24 | # there are 7 corrupted files
25 | 
26 | from PIL import Image
27 | import os
28 | 
29 | path = "./WebFG-496/train"
30 | folder = os.listdir(path)
31 | 
32 | for fol in folder:
33 |     checkdir = os.path.join(path, fol)
34 |     files = os.listdir(checkdir)
35 |     format = [".jpg", ".jpeg"]
36 | 
37 |     for (p, dirs, f) in os.walk(checkdir):
38 |         for file in f:
39 |             if file.endswith(tuple(format)):
40 |                 try:
41 |                     image = Image.open(p+"/"+file).load()
42 |                     # print(image)
43 |                 except Exception as e:
44 |                     print("An exception is raised:", e)
45 |                     print(file)
46 | 
--------------------------------------------------------------------------------
/data/webvision_image_folder_generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | dirs = 'dataset/WebVision'
4 | cat = ['google', 'flickr']
5 | 
6 | os.makedirs(os.path.join(dirs, 'train'), exist_ok=True)
7 | 
8 | for c in cat:
9 |     path = os.listdir(os.path.join(dirs, c))
10 |     for p in path:
11 |         if 'q' in p:
12 |             tmp = '{}_{}'.format(c, p)
13 |             os.system('mv {} {}'.format(os.path.join(dirs, c, p), os.path.join(dirs, 'train', tmp)))
14 | 
--------------------------------------------------------------------------------
/downstream/README.md:
--------------------------------------------------------------------------------
1 | ## Downstream tasks
2 | 
3 | We provide the training and evaluation code for the downstream tasks that we have experimented on.
4 | 
5 | ### Summary
6 | 
7 | These tasks are outlined as follows.
8 | **Task name: `path`, `reference in the main paper`**
9 | 
10 | - object detection: `downstream/detection/`, `Table 7c`
11 | - pixel-wise segmentation: `downstream/segmentation/`, `Table 7c`
12 | - open-set semi-supervised learning: `downstream/opensemi/`, `Table 6`
13 | - webly supervised learning: `downstream/weblysup/`, `Table 6`
14 | - semi-supervised learning: `downstream/semisup/`, `Table 15 in Appx. F.3`
15 | - active learning: `downstream/active/`, `Table 16 in Appx. F.3`
16 | - hard negative mining: `downstream/mining/`, `Table 17 in Appx. G`
17 | 
18 | 
19 | Note: _Semi-supervised learning here includes MixMatch, ReMixMatch, FixMatch, and FlexMatch, all of which utilize both labeled and unlabeled target data.
20 | On the other hand, semi-supervised learning in [train_sup.py](../train_sup.py) utilizes only partially labeled data (Table 7b in the paper),
21 | which follows the same protocol as other SSL works._
22 | 
23 | ### Run
24 | The run scripts are located in each task directory. For example, to run the OpenSemi experiment with the OpenMatch method, run
25 | ```sh
26 | $ cd downstream/opensemi/
27 | $ bash run_openmatch.sh
28 | ```
29 | All you have to do is set up the **dataset** and its **path**, and the **pretrained model checkpoint**.
30 | 
--------------------------------------------------------------------------------
/downstream/active/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from .random import RandomSampling
2 | from .entropy import EntropySampling
3 | from .coreset import CoreSet
4 | from .badge import BadgeSampling
5 | 
6 | 
7 | def query_samples(args, dataset, backbone, classifier, unlabel_idxs, label_idxs=None):
8 |     if label_idxs is None:
9 |         strategy = RandomSampling(args, dataset, backbone, classifier)
10 |         selected_indices = strategy.query(unlabel_idxs, args.n_query)
11 |     else:
12 |         if args.active_method == 'random':
13 |             strategy = RandomSampling(args, dataset, backbone, classifier)
14 |             selected_indices = strategy.query(unlabel_idxs, args.n_query)
15 |         elif args.active_method == 'entropy':
16 |             strategy = EntropySampling(args, dataset, backbone, classifier)
17 |             selected_indices = strategy.query(unlabel_idxs, args.n_query)
18 |         elif args.active_method == 'coreset':
19 |             strategy = CoreSet(args, dataset, backbone, classifier)
20 |             selected_indices = strategy.query(label_idxs, unlabel_idxs, args.n_query)
21 |         elif args.active_method == 'badge':
22 |             strategy = BadgeSampling(args, dataset, backbone, classifier)
23 |             selected_indices = strategy.query(unlabel_idxs, args.n_query)
24 | 
25 |         selected_indices.extend(label_idxs)
26 | 
27 |     return selected_indices
--------------------------------------------------------------------------------
/downstream/active/methods/badge.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import numpy as np
3 | from scipy import stats
4 | from sklearn.metrics import pairwise_distances
5 | 
6 | from .strategy import Strategy
7 | 
8 | 
9 | class BadgeSampling(Strategy):
10 |     # kmeans ++ initialization
11 |     def init_centers(self, X, K):
12 |         ind = np.argmax([np.linalg.norm(s, 2) for s in X])
13 |         mu = [X[ind]]
14 |         indsAll = [ind]
15 |         centInds = [0.]
* len(X) 16 | cent = 0 17 | # print('#Samps\tTotal Distance') 18 | while len(mu) < K: 19 | if len(mu) == 1: 20 | D2 = pairwise_distances(X, mu).ravel().astype(float) 21 | else: 22 | newD = pairwise_distances(X, [mu[-1]]).ravel().astype(float) 23 | for i in range(len(X)): 24 | if D2[i] > newD[i]: 25 | centInds[i] = cent 26 | D2[i] = newD[i] 27 | # print(str(len(mu)) + '\t' + str(sum(D2)), flush=True) 28 | if sum(D2) == 0.0: pdb.set_trace() 29 | D2 = D2.ravel().astype(float) 30 | Ddist = (D2 ** 2)/ sum(D2 ** 2) 31 | customDist = stats.rv_discrete(name='custm', values=(np.arange(len(D2)), Ddist)) 32 | ind = customDist.rvs(size=1)[0] 33 | while ind in indsAll: ind = customDist.rvs(size=1)[0] 34 | mu.append(X[ind]) 35 | indsAll.append(ind) 36 | cent += 1 37 | return indsAll 38 | 39 | def query(self, unlabel_idxs, n_query): 40 | gradEmbedding = self.get_grad_embedding(unlabel_idxs) 41 | gradEmbedding = gradEmbedding.numpy() 42 | 43 | chosen = self.init_centers(gradEmbedding, n_query), 44 | 45 | unlabel_idxs = np.array(unlabel_idxs) 46 | return list(unlabel_idxs[chosen]) -------------------------------------------------------------------------------- /downstream/active/methods/coreset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import pairwise_distances 3 | 4 | from .strategy import Strategy 5 | 6 | 7 | class CoreSet(Strategy): 8 | def furthest_first(self, X, X_set, n): 9 | m = np.shape(X)[0] 10 | if np.shape(X_set)[0] == 0: 11 | min_dist = np.tile(float("inf"), m) 12 | else: 13 | dist_ctr = pairwise_distances(X, X_set) 14 | min_dist = np.amin(dist_ctr, axis=1) 15 | 16 | idxs = [] 17 | 18 | for i in range(n): 19 | idx = min_dist.argmax() 20 | idxs.append(idx) 21 | dist_new_ctr = pairwise_distances(X, X[[idx], :]) 22 | for j in range(m): 23 | min_dist[j] = min(min_dist[j], dist_new_ctr[j, 0]) 24 | 25 | return idxs 26 | 27 | def query(self, label_idxs, unlabel_idxs, n_query): 28 | data_idxs = list(unlabel_idxs) + list(label_idxs) 29 | 30 | unlabel_idxs = np.array(unlabel_idxs) 31 | label_idxs = np.array(label_idxs) 32 | 33 | embedding = self.get_embedding(data_idxs) 34 | embedding = embedding.numpy() 35 | 36 | chosen = self.furthest_first(embedding[:len(unlabel_idxs), :], embedding[len(unlabel_idxs):, :], n_query) 37 | 38 | return list(unlabel_idxs[chosen]) -------------------------------------------------------------------------------- /downstream/active/methods/entropy.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from .strategy import Strategy 7 | 8 | 9 | class EntropySampling(Strategy): 10 | def query(self, unlabel_idxs, n_query): 11 | unlabel_idxs = np.array(unlabel_idxs) 12 | probs = self.predict_prob(unlabel_idxs) 13 | 14 | log_probs = torch.log(probs) 15 | 16 | log_probs[log_probs == float("-inf")] = 0 17 | log_probs[log_probs == float("inf")] = 0 18 | 19 | U = (probs*log_probs).sum(1) 20 | return list(unlabel_idxs[U.sort()[1][:n_query]]) -------------------------------------------------------------------------------- /downstream/active/methods/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from .strategy import Strategy 4 | 5 | 6 | class RandomSampling(Strategy): 7 | def query(self, unlabel_idxs, n_query): 8 | return random.sample(unlabel_idxs, n_query) 
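The sampling strategies above (random, entropy, coreset, BADGE) are dispatched through `query_samples` in `methods/__init__.py`, which falls back to `RandomSampling` when no labeled set exists yet (`label_idxs is None`). As a reference point, the entropy criterion used by `EntropySampling.query` can be reproduced in isolation. The following is a minimal, self-contained sketch — not repository code — that applies the same selection rule to synthetic softmax outputs; in the actual pipeline, `probs` comes from `classifier(backbone(x))` over the unlabeled pool via `Strategy.predict_prob`, and the pool/class/budget sizes below are made-up placeholders.

```python
# Illustrative sketch only (not part of the repo): entropy-based query selection,
# mirroring EntropySampling.query on randomly generated predictions.
import torch

torch.manual_seed(0)
n_unlabeled, n_cls, n_query = 1000, 10, 5      # hypothetical pool, class, and budget sizes

unlabel_idxs = torch.arange(n_unlabeled)        # indices into the unlabeled pool
probs = torch.softmax(torch.randn(n_unlabeled, n_cls), dim=1)

log_probs = torch.log(probs)
log_probs[torch.isinf(log_probs)] = 0           # guard against log(0), as in the repo code

# U = sum_c p_c * log(p_c) is the negative entropy, so the smallest values of U
# correspond to the most uncertain samples; the n_query smallest are selected.
U = (probs * log_probs).sum(dim=1)
selected = unlabel_idxs[U.sort()[1][:n_query]].tolist()
print(selected)
```

The returned indices are then extended with the already-labeled indices (`selected_indices.extend(label_idxs)`) in `query_samples` before training continues on the enlarged labeled set.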
-------------------------------------------------------------------------------- /downstream/active/methods/strategy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from torch.utils.data import Dataset, DataLoader 9 | 10 | 11 | class DatasetSplit(Dataset): 12 | def __init__(self, dataset, idxs): 13 | self.dataset = dataset 14 | self.idxs = list(idxs) 15 | 16 | def __len__(self): 17 | return len(self.idxs) 18 | 19 | def __getitem__(self, item): 20 | image, label = self.dataset[self.idxs[item]] 21 | return image, label, item 22 | 23 | 24 | class Strategy: 25 | def __init__(self, args, dataset, backbone, classifier): 26 | self.args = args 27 | self.dataset = dataset 28 | self.backbone = backbone 29 | self.classifier = classifier 30 | self.loss_func = nn.CrossEntropyLoss() 31 | 32 | def query(self, label_idxs, unlabel_idx, n_query): 33 | pass 34 | 35 | # Entropy Sampling 36 | def predict_prob(self, unlabel_idxs): 37 | loader_te = DataLoader(DatasetSplit(self.dataset, unlabel_idxs), shuffle=False) 38 | 39 | self.backbone.eval() 40 | probs = torch.zeros([len(unlabel_idxs), self.args.n_cls]) 41 | with torch.no_grad(): 42 | for x, y, idxs in loader_te: 43 | x, y = Variable(x.cuda(non_blocking=True)), Variable(y.cuda(non_blocking=True)) 44 | output = self.classifier(self.backbone(x)) 45 | probs[idxs] = torch.nn.functional.softmax(output, dim=1).cpu().data 46 | return probs 47 | 48 | # CoreSet 49 | def get_embedding(self, data_idxs): 50 | loader_te = DataLoader(DatasetSplit(self.dataset, data_idxs), shuffle=False) 51 | 52 | self.backbone.eval() 53 | embedding = torch.zeros([len(data_idxs), self.backbone.module.final_feat_dim]) 54 | with torch.no_grad(): 55 | for x, y, idxs in loader_te: 56 | x, y = Variable(x.cuda(non_blocking=True)), Variable(y.cuda(non_blocking=True)) 57 | emb = self.backbone(x) 58 | embedding[idxs] = emb.cpu().data 59 | 60 | return embedding 61 | 62 | # BADGE: gradient embedding (assumes cross-entropy loss) 63 | def get_grad_embedding(self, data_idxs): 64 | embDim = self.backbone.module.final_feat_dim 65 | self.backbone.eval() 66 | 67 | nLab = self.args.n_cls 68 | embedding = np.zeros([len(data_idxs), embDim * nLab]) 69 | loader_te = DataLoader(DatasetSplit(self.dataset, data_idxs), shuffle=False) 70 | 71 | with torch.no_grad(): 72 | for x, y, idxs in loader_te: 73 | x, y = Variable(x.cuda(non_blocking=True)), Variable(y.cuda(non_blocking=True)) 74 | out = self.backbone(x) 75 | cout = self.classifier(out) 76 | 77 | out = out.data.cpu().numpy() 78 | 79 | batchProbs = F.softmax(cout, dim=1).data.cpu().numpy() 80 | maxInds = np.argmax(batchProbs, 1) 81 | 82 | for j in range(len(y)): 83 | for c in range(nLab): 84 | if c == maxInds[j]: 85 | embedding[idxs[j]][embDim * c : embDim * (c+1)] = deepcopy(out[j]) * (1 - batchProbs[j][c]) 86 | else: 87 | embedding[idxs[j]][embDim * c : embDim * (c+1)] = deepcopy(out[j]) * (-1 * batchProbs[j][c]) 88 | return torch.Tensor(embedding) -------------------------------------------------------------------------------- /downstream/active/models: -------------------------------------------------------------------------------- 1 | ../../models -------------------------------------------------------------------------------- /downstream/active/run_active.sh: -------------------------------------------------------------------------------- 1 | 
# Random template -> for other methods, please change active_method 2 | SYM="SimCore" # X / OS 3 | DATA="pets" # cub / aircraft / cars 4 | TAG="bs256_lr3e-2_ep100_ar0_${SYM}" 5 | ALGO="random" 6 | 7 | python train.py --tag $TAG \ 8 | --dataset $DATA \ 9 | --data_folder /path/to/data \ 10 | --model resnet50 \ 11 | --method simclr \ 12 | --optimizer sgd \ 13 | --learning_rate 3e-2 \ 14 | --weight_decay 1e-4 \ 15 | --epochs 100 \ 16 | --label_ratio 0.1 \ 17 | --batch_size 256 \ 18 | --active_method $ALGO \ 19 | --active_round 0 \ 20 | --pretrained_ckpt /path/to/ckpt/last.pth 21 | 22 | LABEL="0.2 0.3 0.4" 23 | cnt=0 24 | for l in $LABEL 25 | do 26 | TAG="bs256_lr3e-2_ep100_ar${cnt}_${SYM}" 27 | CKPT="./save/${DATA}_resnet50_${ALGO}_active_${TAG}/best.pth" 28 | 29 | ((cnt++)) 30 | TAG="bs256_lr3e-2_ep100_ar${cnt}_${SYM}" 31 | 32 | python train.py --tag $TAG \ 33 | --dataset $DATA \ 34 | --data_folder /path/to/data \ 35 | --model resnet50 \ 36 | --method simclr \ 37 | --optimizer sgd \ 38 | --learning_rate 3e-2 \ 39 | --weight_decay 1e-4 \ 40 | --epochs 100 \ 41 | --label_ratio $l \ 42 | --batch_size 256 \ 43 | --active_method $ALGO \ 44 | --active_round $cnt \ 45 | --pretrained_ckpt ${CKPT} 46 | done -------------------------------------------------------------------------------- /downstream/active/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/openssl-simcore/97b4c9e6492a24d833d6391f203dad1793a08aa6/downstream/active/util/__init__.py -------------------------------------------------------------------------------- /downstream/active/util/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | 5 | import torch 6 | import torch.utils.data as data 7 | import torchvision.datasets as datasets 8 | 9 | 10 | class ImageFolderSemiSup(data.Dataset): 11 | def __init__(self, root='', transform=None, p=1.0, index=None, return_idx=False): 12 | super(ImageFolderSemiSup, self).__init__() 13 | 14 | self.root = root 15 | self.transform = transform 16 | self.p = p 17 | self.return_idx = return_idx 18 | 19 | self.train_dataset = datasets.ImageFolder(root=self.root, transform=self.transform) 20 | if index is None: 21 | # randomly select samples 22 | random.seed(0) 23 | self.dataset_len = int(self.p * len(self.train_dataset)) 24 | self.sampled_index = random.sample(range(len(self.train_dataset)), self.dataset_len) 25 | else: 26 | # samples selected by x_u_split function 27 | self.sampled_index = index 28 | self.dataset_len = len(self.sampled_index) 29 | 30 | def __len__(self): 31 | return self.dataset_len 32 | 33 | def __getitem__(self, index): 34 | if self.return_idx: 35 | return index, self.train_dataset[self.sampled_index[index]] 36 | else: 37 | return self.train_dataset[self.sampled_index[index]] -------------------------------------------------------------------------------- /downstream/active/util/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import json 5 | import math 6 | import copy 7 | import numpy as np 8 | 9 | import torch 10 | import torch.optim as optim 11 | 12 | __all__ = ['AverageMeter', 'AverageClassMeter', 'adjust_lr_wd', 'warmup_learning_rate', 'accuracy', 'update_metric', 'get_best_acc', 'save_model', 'update_json'] 13 | 14 | 15 | class AverageMeter(object): 16 | """Computes and stores the average and 
current value""" 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | 33 | class AverageClassMeter(object): 34 | def __init__(self, n_cls): 35 | self.meter = [] 36 | for _ in range(n_cls): self.meter.append(AverageMeter()) 37 | self.n_cls = n_cls 38 | 39 | def update(self, cls, val, n=1): 40 | self.meter[cls].update(val, n) 41 | 42 | self.val = self.meter[-1].val 43 | self.avg = torch.tensor(sum([m.avg for m in self.meter]) / self.n_cls) 44 | 45 | 46 | def adjust_lr_wd(args, optimizer, epoch): 47 | lr = args.learning_rate 48 | if args.cosine: 49 | eta_min = lr * (args.lr_decay_rate ** 3) 50 | lr = eta_min + (lr - eta_min) * ( 51 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 52 | else: 53 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 54 | if steps > 0: 55 | lr = lr * (args.lr_decay_rate ** steps) 56 | 57 | wd = args.weight_decay 58 | if args.wd_scheduler: 59 | wd_min = args.weight_decay_end 60 | wd = wd_min + (wd - wd_min) * ( 61 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 62 | 63 | for i, param_group in enumerate(optimizer.param_groups): 64 | param_group['lr'] = lr 65 | if i == 0: # in case of DINO-ViT and MAE-ViT, only wd for regularized params 66 | param_group['weight_decay'] = wd 67 | 68 | 69 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 70 | if args.warm and epoch <= args.warm_epochs: 71 | p = (batch_id + (epoch - 1) * total_batches) / \ 72 | (args.warm_epochs * total_batches) 73 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 74 | 75 | for param_group in optimizer.param_groups: 76 | param_group['lr'] = lr 77 | 78 | 79 | def accuracy(output, target, topk=(1,)): 80 | """Computes the accuracy over the k top predictions for the specified values of k""" 81 | with torch.no_grad(): 82 | n_cls = output.shape[1] 83 | valid_topk = [k for k in topk if k <= n_cls] 84 | 85 | maxk = max(valid_topk) 86 | bsz = target.size(0) 87 | 88 | _, pred = output.topk(maxk, 1, True, True) 89 | pred = pred.t() 90 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 91 | 92 | res = [] 93 | for k in topk: 94 | if k in valid_topk: 95 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 96 | res.append(correct_k.mul_(100.0 / bsz)) 97 | else: res.append(torch.tensor([0.])) 98 | return res, bsz 99 | 100 | 101 | def update_metric(output, labels, top1, top5, args): 102 | if top1.__class__.__name__ == 'AverageMeter': 103 | [acc1, acc5], bsz = accuracy(output, labels, topk=(1, 5)) 104 | top1.update(acc1[0], bsz) 105 | top5.update(acc5[0], bsz) 106 | else: # mean per-class accuracy 107 | for cls in range(args.n_cls): 108 | if not (labels==cls).sum(): continue 109 | [acc1, acc5], bsz = accuracy(output[labels==cls], labels[labels==cls], topk=(1, 5)) 110 | top1.update(cls, acc1[0], bsz) 111 | top5.update(cls, acc5[0], bsz) 112 | return top1, top5 113 | 114 | 115 | def get_best_acc(val_acc1, val_acc5, best_acc): 116 | best = False 117 | if val_acc1.item() > best_acc[0]: 118 | best_acc[0] = val_acc1.item() 119 | best_acc[1] = val_acc5.item() 120 | best = True 121 | return best_acc, best 122 | 123 | 124 | def save_model(model, optimizer, args, epoch, save_file, indices=None, classifier=None): 125 | print('==> Saving...') 126 | state = { 127 | 'args': args, 128 | 'model': model.state_dict(), 129 | 
'optimizer': optimizer.state_dict(), 130 | 'epoch': epoch, 131 | } 132 | if indices is not None: 133 | state['indices'] = indices 134 | if classifier is not None: # for active learning fine-tuning and open-set semi 135 | state['classifier'] = classifier.state_dict() 136 | torch.save(state, save_file) 137 | del state 138 | 139 | 140 | def update_json(exp_name, acc=[], path='./save/results.json'): 141 | acc = [round(a, 2) for a in acc] 142 | if not os.path.exists(path): 143 | with open(path, 'w') as f: 144 | json.dump({}, f) 145 | 146 | with open(path, 'r', encoding="UTF-8") as f: 147 | result_dict = json.load(f) 148 | result_dict[exp_name] = acc 149 | 150 | with open(path, 'w') as f: 151 | json.dump(result_dict, f) 152 | 153 | print('best accuracy: {} (Acc@1, Acc@5, Train Acc)'.format(acc)) 154 | print('results updated to %s' % path) 155 | -------------------------------------------------------------------------------- /downstream/detection/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | from albumentations.pytorch.transforms import ToTensorV2 3 | from torch.utils.data import DataLoader 4 | from dataloaders.datasets import aircraft_detection, cars_detection 5 | 6 | 7 | def make_data_loader(args, **kwargs): 8 | def collate_fn(batch): 9 | return tuple(zip(*batch)) 10 | 11 | train_transform = A.Compose([ 12 | # A.Resize(height=224, width=224, always_apply=True), 13 | A.BBoxSafeRandomCrop(), 14 | A.HorizontalFlip(p=0.5), 15 | ToTensorV2() 16 | ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']}) 17 | 18 | test_transform = A.Compose([ 19 | ToTensorV2() 20 | ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']}) 21 | 22 | if args.dataset == 'aircraft': 23 | # root = '~~~/aircraft/' 24 | train_set = aircraft_detection.AircraftDetectionDataset(args.root, 'trainval', transforms=train_transform, class_label=args.class_label) 25 | test_set = aircraft_detection.AircraftDetectionDataset(args.root, 'test', transforms=test_transform, class_label=args.class_label) 26 | 27 | # 1 class (aircraft) + background 28 | num_class = 2 if not args.class_label else 101 29 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=8, **kwargs) 30 | # test_loader = DataLoader(test_set, batch_size=1, shuffle=False, collate_fn=collate_fn, **kwargs) 31 | 32 | return train_loader, test_set, num_class 33 | 34 | elif args.dataset == 'cars': 35 | # root = '~~~/cars/' 36 | train_set = cars_detection.CarDetectionDataset(args.root, split='train', transforms=train_transform, class_label=args.class_label) 37 | test_set = cars_detection.CarDetectionDataset(args.root, split='test', transforms=test_transform, class_label=args.class_label) 38 | 39 | # class: cars (localization) 40 | num_class = 2 if not args.class_label else 197 41 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=8, **kwargs) 42 | # test_loader = DataLoader(test_set, batch_size=1, shuffle=False, collate_fn=collate_fn, **kwargs) 43 | 44 | return train_loader, test_set, num_class 45 | 46 | else: 47 | raise NotImplementedError 48 | 49 | -------------------------------------------------------------------------------- /downstream/detection/dataloaders/datasets/aircraft_detection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 
5 | import pandas as pd 6 | 7 | 8 | class AircraftDetectionDataset(object): 9 | def __init__(self, root, split, transforms=None, class_label=False): 10 | self.root = root 11 | self.split = split # 'trainval' or 'test' 12 | self.transforms = transforms 13 | self.class_label = class_label 14 | 15 | self.cls_to_idx = dict() 16 | self.img_to_cls = dict() 17 | num = 0 18 | with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'images_variant_{}.txt'.format(self.split)), 'r') as f: 19 | while True: 20 | line = f.readline() 21 | if not line: break 22 | line = line.strip().split(' ') 23 | img_id, cls_name = line[0]+'.jpg', '_'.join(line[1:]).replace('/', '_') 24 | self.img_to_cls[img_id] = cls_name 25 | if cls_name not in self.cls_to_idx: 26 | self.cls_to_idx[cls_name] = num 27 | num += 1 28 | f.close() 29 | 30 | if not os.path.exists(os.path.join(self.root, '{}_bbox.csv'.format(self.split))): 31 | print('bbox csv file not exists > make...') 32 | self.df = self._make_bbox() 33 | else: 34 | self.df = pd.read_csv(os.path.join(self.root, '{}_bbox.csv'.format(self.split))) 35 | 36 | self.image_ids = self.df['img_path'].unique().tolist() 37 | 38 | def _make_bbox(self): 39 | paths = self._split_data(self.split) 40 | 41 | images = [] 42 | with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'images_box.txt'), 'r') as f: 43 | while True: 44 | line = f.readline() 45 | if not line: break 46 | 47 | line = line.replace('\n', '').split(' ') 48 | line[0] += '.jpg' 49 | 50 | if line[0] in paths: 51 | line.append(int(self.cls_to_idx[self.img_to_cls[line[0]]])) 52 | images.append(line) 53 | 54 | df = pd.DataFrame(np.array(images), columns=['img_path', 'bbox_x1', 'bbox_y1', 'bbox_x2', 'bbox_y2', 'labels']) 55 | 56 | # bbox type should be np.int64 57 | for i, col in enumerate(df.columns): 58 | if i != 0: 59 | df[col] = df[col].astype(np.int64) 60 | 61 | # save dataframe for ground_truth bbox later 62 | df.to_csv(os.path.join(self.root, '{}_bbox.csv'.format(self.split)), index=False) 63 | 64 | return df 65 | 66 | def _split_data(self, split): 67 | # split dataset to 'trainval' or 'test' 68 | paths = [] 69 | with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'images_variant_{}.txt'.format(split)), 'r') as f: 70 | while True: 71 | line = f.readline() 72 | if not line: break 73 | 74 | line = line.split(' ')[0] + '.jpg' 75 | paths.append(line) 76 | f.close() 77 | return paths 78 | 79 | def __len__(self): 80 | return len(self.image_ids) 81 | 82 | def __getitem__(self, idx): 83 | image_id = self.image_ids[idx] 84 | records = self.df[self.df['img_path'] == image_id] 85 | image = cv2.imread(os.path.join(self.root, 'fgvc-aircraft-2013b/data/images', image_id), cv2.IMREAD_COLOR) 86 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) 87 | image /= 255.0 88 | 89 | boxes = records[['bbox_x1', 'bbox_y1', 'bbox_x2', 'bbox_y2']].to_numpy() 90 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 91 | if self.class_label: 92 | labels = records['labels'].to_numpy() 93 | labels = torch.tensor([labels[0]]).long() 94 | else: 95 | labels = torch.ones((records.shape[0],), dtype=torch.int64) 96 | 97 | # some images have wrong bbox labels, such as rotated bbox 98 | if not (boxes[:, 2] <= image.shape[1] and boxes[:, 0] >= 0 and boxes[:, 3] <= image.shape[0] and boxes[:, 1] >= 0): 99 | y1, x1, y2, x2 = boxes[:,0], boxes[:,1], boxes[:,2], boxes[:,3] 100 | boxes[:,0], boxes[:,1], boxes[:,2], boxes[:,3] = x1, y1, x2, y2 101 | assert boxes[:, 2] <= image.shape[1] and boxes[:, 0] >= 0 and 
boxes[:, 3] <= image.shape[0] and boxes[:, 1] >= 0 102 | 103 | target = {} 104 | target['boxes'] = boxes 105 | target['labels'] = labels 106 | target['image_id'] = torch.tensor([idx]) 107 | target['area'] = torch.as_tensor(area, dtype=torch.float32) 108 | target['iscrowd'] = torch.zeros((records.shape[0],), dtype=torch.int64) 109 | 110 | if self.transforms: 111 | sample = { 112 | 'image': image, 113 | 'bboxes': target['boxes'], 114 | 'labels': labels 115 | } 116 | sample = self.transforms(**sample) 117 | image = sample['image'] 118 | 119 | target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0) 120 | 121 | return image.clone().detach(), target, image_id 122 | -------------------------------------------------------------------------------- /downstream/detection/dataloaders/datasets/cars_detection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import pandas as pd 6 | import scipy.io 7 | 8 | 9 | class CarDetectionDataset(object): 10 | def __init__(self, root, split, transforms=None, class_label=False): 11 | self.root = root 12 | self.split = split 13 | self.transforms = transforms 14 | self.class_label = class_label 15 | if not os.path.exists(os.path.join(self.root, '{}_bbox.csv'.format(self.split))): 16 | print('bbox csv file not exists > make...') 17 | self.df = self._make_bbox() 18 | else: 19 | self.df = pd.read_csv(os.path.join(self.root, '{}_bbox.csv'.format(self.split))) 20 | 21 | self.image_ids = self.df['img_path'].unique().tolist() 22 | 23 | def _make_bbox(self): 24 | train_bbox_path = os.path.join(self.root, 'train_bbox.csv') 25 | test_bbox_path = os.path.join(self.root, 'test_bbox.csv') 26 | 27 | metadata = scipy.io.loadmat(os.path.join(self.root, 'cars_annos.mat')) 28 | metadata = metadata['annotations'][0] 29 | train_bbox_list, test_bbox_list = list(), list() 30 | for meta in metadata: 31 | img_path = meta[0][0] 32 | bbox_x1, bbox_y1, bbox_x2, bbox_y2, class_num, test = [meta[i][0,0] for i in range(1,7)] 33 | if not test: 34 | img_path = os.path.join(self.root, 'train', str(class_num), img_path.split('/')[1]) 35 | train_bbox_list.append([img_path, bbox_x1, bbox_y1, bbox_x2, bbox_y2, class_num - 1]) 36 | else: 37 | img_path = os.path.join(self.root, 'test', str(class_num), img_path.split('/')[1]) 38 | test_bbox_list.append([img_path, bbox_x1, bbox_y1, bbox_x2, bbox_y2, class_num - 1]) 39 | 40 | train_df = pd.DataFrame(train_bbox_list, columns=['img_path', 'bbox_x1', 'bbox_y1', 'bbox_x2', 'bbox_y2', 'labels']) 41 | test_df = pd.DataFrame(test_bbox_list, columns=['img_path', 'bbox_x1', 'bbox_y1', 'bbox_x2', 'bbox_y2', 'labels']) 42 | train_df.to_csv(train_bbox_path, index=False) 43 | test_df.to_csv(test_bbox_path, index=False) 44 | 45 | if self.split == 'train': 46 | return train_df 47 | elif self.split == 'test': 48 | return test_df 49 | 50 | def __len__(self): 51 | return len(self.image_ids) 52 | 53 | def __getitem__(self, idx): 54 | image_id = self.image_ids[idx] 55 | records = self.df[self.df['img_path'] == image_id] 56 | image = cv2.imread(image_id, cv2.IMREAD_COLOR) 57 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) 58 | image /= 255.0 59 | 60 | boxes = records[['bbox_x1', 'bbox_y1', 'bbox_x2', 'bbox_y2']].to_numpy() 61 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 62 | if self.class_label: 63 | labels = records['labels'].to_numpy() 64 | labels = torch.tensor([labels[0]]).long() 65 | else: 
66 | labels = torch.ones((records.shape[0],), dtype=torch.int64) 67 | 68 | # some images have wrong bbox labels, such as flipped bbox 69 | if not (boxes[:, 2] <= image.shape[1] and boxes[:, 0] >= 0 and boxes[:, 3] <= image.shape[0] and boxes[:, 1] >= 0): 70 | y1, x1, y2, x2 = boxes[:,0], boxes[:,1], boxes[:,2], boxes[:,3] 71 | boxes[:,0], boxes[:,1], boxes[:,2], boxes[:,3] = x1, y1, x2, y2 72 | assert boxes[:, 2] <= image.shape[1] and boxes[:, 0] >= 0 and boxes[:, 3] <= image.shape[0] and boxes[:, 1] >= 0 73 | 74 | target = {} 75 | target['boxes'] = boxes 76 | target['labels'] = labels 77 | target['image_id'] = torch.tensor([idx]) 78 | target['area'] = torch.as_tensor(area, dtype=torch.float32) 79 | target['iscrowd'] = torch.zeros((records.shape[0],), dtype=torch.int64) 80 | 81 | if self.transforms: 82 | sample = { 83 | 'image': image, 84 | 'bboxes': target['boxes'], 85 | 'labels': labels 86 | } 87 | sample = self.transforms(**sample) 88 | image = sample['image'] 89 | 90 | target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0) 91 | 92 | return image.clone().detach(), target, image_id 93 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/README.md: -------------------------------------------------------------------------------- 1 | # Extra 2 | 3 | ## ground-truth: 4 | - ### convert `xml` to our format: 5 | 6 | 1) Insert ground-truth xml files into **ground-truth/** 7 | 2) Run the python script: `python convert_gt_xml.py` 8 | 9 | - ### convert YOLO to our format: 10 | 11 | 1) Add the class list to the file `class_list.txt` 12 | 2) Insert ground-truth files into **ground-truth/** 13 | 3) Insert images into **images/** 14 | 4) Run the python script: `python convert_gt_yolo.py` 15 | 16 | - ### convert keras-yolo3 to our format: 17 | 18 | 1) Add or update the class list to the file `class_list.txt` 19 | 2) Use the parameter `--gt` to set the **ground-truth** source. 20 | 3) Run the python script: `python3 convert_keras-yolo3.py --gt ` 21 | 1) Supports only python 3. 22 | 2) This code can handle a recursive annotation structure. Just use the `-r` parameter. 23 | 3) The converted annotation is placed by default in a new from_kerasyolo3 folder. You can change that with the parameter `-o`. 24 | 4) The format is defined according to github.com/qqwweee/keras-yolo3 25 | 26 | ## detection-results: 27 | - ### convert darkflow `json` to our format: 28 | 29 | 1) Insert result json files into **detection-results/** 30 | 2) Run the python script: `python convert_dr_darkflow_json.py` 31 | 32 | - ### convert YOLO to our format: 33 | 34 | After running darknet on a list of images, e.g.: `darknet.exe detector test data/voc.data yolo-voc.cfg yolo-voc.weights -dont_show -ext_output < data/test.txt > result.txt` 35 | 36 | 1) Copy the file `result.txt` to the folder `extra/` 37 | 2) Run the python script: `python convert_dr_yolo.py` 38 | 39 | - ### convert keras-yolo3 to our format: 40 | 41 | 1) Add or update the class list to the file `class_list.txt` 42 | 2) Use the parameter `--dr` to set the **detection-results** source. 43 | 3) Run the python script: `python3 convert_keras-yolo3.py --dr ` 44 | 1) Supports only python 3. 45 | 2) This code can handle a recursive annotation structure. Just use the `-r` parameter. 46 | 3) The converted annotation is placed by default in a new from_kerasyolo3 folder. You can change that with the parameter `-o`.
47 | 4) The format is defined according to github.com/gustavovaliati/keras-yolo3 48 | 49 | ## Find the files that contain a specific class of objects 50 | 51 | 1) Run the `find_class.py` script and specify the **class** as an argument, e.g. 52 | `python find_class.py chair` 53 | 54 | ## Intersect ground-truth and detection-results files 55 | This script ensures the same number of files in the ground-truth and detection-results folders. 56 | When you encounter a file not found error, it's usually because you have 57 | mismatched numbers of ground-truth and detection-results files. 58 | You can use this script to move ground-truth and detection-results files that are 59 | not in the intersection into a backup folder (backup_no_matches_found). 60 | This will retain only files that have the same name in both folders. 61 | 62 | 1) Prepare `.txt` files in your `ground-truth` and `detection-results` folders. 63 | 2) Run the `intersect-gt-and-dr.py` script to move non-intersected files into a backup folder (default: `backup_no_matches_found`). 64 | 65 | `python intersect-gt-and-dr.py` 66 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/class_list.txt: -------------------------------------------------------------------------------- 1 | bed 2 | person 3 | pictureframe 4 | shirt 5 | lamp 6 | nightstand 7 | clock 8 | heater 9 | windowblind 10 | pillow 11 | robot 12 | cabinetry 13 | door 14 | doorhandle 15 | shelf 16 | pottedplant 17 | chair 18 | diningtable 19 | backpack 20 | whiteboard 21 | cup 22 | tvmonitor 23 | pen 24 | pencil 25 | wardrobe 26 | apple 27 | orange 28 | countertop 29 | tap 30 | banana 31 | bicyclehelmet 32 | book 33 | bookcase 34 | refrigerator 35 | wastecontainer 36 | tincan 37 | handbag 38 | sofa 39 | glasses 40 | vase 41 | coffeetable 42 | bowl 43 | remote 44 | candle 45 | bottle 46 | sink 47 | envelope 48 | doll 49 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/convert_dr_darkflow_json.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | import json 5 | 6 | # make sure that the cwd() in the beginning is the location of the python script (so that every path makes sense) 7 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 8 | 9 | # change directory to the one with the files to be changed 10 | parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) 11 | parent_path = os.path.abspath(os.path.join(parent_path, os.pardir)) 12 | DR_PATH = os.path.join(parent_path, 'input','detection-results') 13 | #print(DR_PATH) 14 | os.chdir(DR_PATH) 15 | 16 | # old files (darkflow json format) will be moved to a "backup" folder 17 | ## create the backup dir if it doesn't exist already 18 | if not os.path.exists("backup"): 19 | os.makedirs("backup") 20 | 21 | # create VOC format files 22 | json_list = glob.glob('*.json') 23 | if len(json_list) == 0: 24 | print("Error: no .json files found in detection-results") 25 | sys.exit() 26 | for tmp_file in json_list: 27 | #print(tmp_file) 28 | # 1.
create new file (VOC format) 29 | with open(tmp_file.replace(".json", ".txt"), "a") as new_f: 30 | data = json.load(open(tmp_file)) 31 | for obj in data: 32 | obj_name = obj['label'] 33 | conf = obj['confidence'] 34 | left = obj['topleft']['x'] 35 | top = obj['topleft']['y'] 36 | right = obj['bottomright']['x'] 37 | bottom = obj['bottomright']['y'] 38 | new_f.write(obj_name + " " + str(conf) + " " + str(left) + " " + str(top) + " " + str(right) + " " + str(bottom) + '\n') 39 | # 2. move old file (darkflow format) to backup 40 | os.rename(tmp_file, "backup/" + tmp_file) 41 | print("Conversion completed!") 42 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/convert_dr_yolo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | # make sure that the cwd() in the beginning is the location of the python script (so that every path makes sense) 5 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 6 | 7 | IN_FILE = 'result.txt' 8 | 9 | # change directory to the one with the files to be changed 10 | parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) 11 | parent_path = os.path.abspath(os.path.join(parent_path, os.pardir)) 12 | DR_PATH = os.path.join(parent_path, 'input','detection-results') 13 | #print(DR_PATH) 14 | os.chdir(DR_PATH) 15 | 16 | SEPARATOR_KEY = 'Enter Image Path:' 17 | IMG_FORMAT = '.jpg' 18 | 19 | outfile = None 20 | with open(IN_FILE) as infile: 21 | for line in infile: 22 | if SEPARATOR_KEY in line: 23 | if IMG_FORMAT not in line: 24 | break 25 | # get text between two substrings (SEPARATOR_KEY and IMG_FORMAT) 26 | image_path = re.search(SEPARATOR_KEY + '(.*)' + IMG_FORMAT, line) 27 | # get the image name (the final component of a image_path) 28 | # e.g., from 'data/horses_1' to 'horses_1' 29 | image_name = os.path.basename(image_path.group(1)) 30 | # close the previous file 31 | if outfile is not None: 32 | outfile.close() 33 | # open a new file 34 | outfile = open(os.path.join(DR_PATH, image_name + '.txt'), 'w') 35 | elif outfile is not None: 36 | # split line on first occurrence of the character ':' and '%' 37 | class_name, info = line.split(':', 1) 38 | confidence, bbox = info.split('%', 1) 39 | # get all the coordinates of the bounding box 40 | bbox = bbox.replace(')','') # remove the character ')' 41 | # go through each of the parts of the string and check if it is a digit 42 | left, top, width, height = [int(s) for s in bbox.split() if s.lstrip('-').isdigit()] 43 | right = left + width 44 | bottom = top + height 45 | outfile.write("{} {} {} {} {} {}\n".format(class_name, float(confidence)/100, left, top, right, bottom)) 46 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/convert_gt_xml.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | import xml.etree.ElementTree as ET 5 | 6 | # make sure that the cwd() in the beginning is the location of the python script (so that every path makes sense) 7 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 8 | 9 | # change directory to the one with the files to be changed 10 | parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) 11 | parent_path = os.path.abspath(os.path.join(parent_path, os.pardir)) 12 | GT_PATH = os.path.join(parent_path, 'input','ground-truth') 13 | #print(GT_PATH) 14 | os.chdir(GT_PATH) 15 
| 16 | # old files (xml format) will be moved to a "backup" folder 17 | ## create the backup dir if it doesn't exist already 18 | if not os.path.exists("backup"): 19 | os.makedirs("backup") 20 | 21 | # create VOC format files 22 | xml_list = glob.glob('*.xml') 23 | if len(xml_list) == 0: 24 | print("Error: no .xml files found in ground-truth") 25 | sys.exit() 26 | for tmp_file in xml_list: 27 | #print(tmp_file) 28 | # 1. create new file (VOC format) 29 | with open(tmp_file.replace(".xml", ".txt"), "a") as new_f: 30 | root = ET.parse(tmp_file).getroot() 31 | for obj in root.findall('object'): 32 | obj_name = obj.find('name').text 33 | bndbox = obj.find('bndbox') 34 | left = bndbox.find('xmin').text 35 | top = bndbox.find('ymin').text 36 | right = bndbox.find('xmax').text 37 | bottom = bndbox.find('ymax').text 38 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 39 | # 2. move old file (xml format) to backup 40 | os.rename(tmp_file, os.path.join("backup", tmp_file)) 41 | print("Conversion completed!") 42 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/convert_gt_yolo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | import cv2 5 | 6 | def convert_yolo_coordinates_to_voc(x_c_n, y_c_n, width_n, height_n, img_width, img_height): 7 | ## remove normalization given the size of the image 8 | x_c = float(x_c_n) * img_width 9 | y_c = float(y_c_n) * img_height 10 | width = float(width_n) * img_width 11 | height = float(height_n) * img_height 12 | ## compute half width and half height 13 | half_width = width / 2 14 | half_height = height / 2 15 | ## compute left, top, right, bottom 16 | ## in the official VOC challenge the top-left pixel in the image has coordinates (1;1) 17 | left = int(x_c - half_width) + 1 18 | top = int(y_c - half_height) + 1 19 | right = int(x_c + half_width) + 1 20 | bottom = int(y_c + half_height) + 1 21 | return left, top, right, bottom 22 | 23 | # make sure that the cwd() in the beginning is the location of the python script (so that every path makes sense) 24 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 25 | 26 | # read the class_list.txt to a list 27 | with open("class_list.txt") as f: 28 | obj_list = f.readlines() 29 | ## remove whitespace characters like `\n` at the end of each line 30 | obj_list = [x.strip() for x in obj_list] 31 | ## e.g. first object in the list 32 | #print(obj_list[0]) 33 | 34 | # change directory to the one with the files to be changed 35 | parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) 36 | parent_path = os.path.abspath(os.path.join(parent_path, os.pardir)) 37 | GT_PATH = os.path.join(parent_path, 'input','ground-truth') 38 | #print(GT_PATH) 39 | os.chdir(GT_PATH) 40 | 41 | # old files (YOLO format) will be moved to a new folder (backup/) 42 | ## create the backup dir if it doesn't exist already 43 | if not os.path.exists("backup"): 44 | os.makedirs("backup") 45 | 46 | # create VOC format files 47 | txt_list = glob.glob('*.txt') 48 | if len(txt_list) == 0: 49 | print("Error: no .txt files found in ground-truth") 50 | sys.exit() 51 | for tmp_file in txt_list: 52 | #print(tmp_file) 53 | # 1. 
check that there is an image with that name 54 | ## get name before ".txt" 55 | image_name = tmp_file.split(".txt",1)[0] 56 | #print(image_name) 57 | ## check if image exists 58 | for fname in os.listdir('../images'): 59 | if fname.startswith(image_name): 60 | ## image found 61 | #print(fname) 62 | img = cv2.imread('../images/' + fname) 63 | ## get image width and height 64 | img_height, img_width = img.shape[:2] 65 | break 66 | else: 67 | ## image not found 68 | print("Error: image not found, corresponding to " + tmp_file) 69 | sys.exit() 70 | # 2. open txt file lines to a list 71 | with open(tmp_file) as f: 72 | content = f.readlines() 73 | ## remove whitespace characters like `\n` at the end of each line 74 | content = [x.strip() for x in content] 75 | # 3. move old file (YOLO format) to backup 76 | os.rename(tmp_file, "backup/" + tmp_file) 77 | # 4. create new file (VOC format) 78 | with open(tmp_file, "a") as new_f: 79 | for line in content: 80 | ## split a line by spaces. 81 | ## "c" stands for center and "n" stands for normalized 82 | obj_id, x_c_n, y_c_n, width_n, height_n = line.split() 83 | obj_name = obj_list[int(obj_id)] 84 | left, top, right, bottom = convert_yolo_coordinates_to_voc(x_c_n, y_c_n, width_n, height_n, img_width, img_height) 85 | ## add new line to file 86 | #print(obj_name + " " + str(left) + " " + str(top) + " " + str(right) + " " + str(bottom)) 87 | new_f.write(obj_name + " " + str(left) + " " + str(top) + " " + str(right) + " " + str(bottom) + '\n') 88 | print("Conversion completed!") 89 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/convert_keras-yolo3.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ABOUT THIS SCRIPT: 3 | Converts ground-truth from the annotation files 4 | according to the https://github.com/qqwweee/keras-yolo3 5 | or https://github.com/gustavovaliati/keras-yolo3 format. 6 | 7 | And converts the detection-results from the annotation files 8 | according to the https://github.com/gustavovaliati/keras-yolo3 format. 9 | ''' 10 | 11 | import argparse 12 | import datetime 13 | import os 14 | 15 | ''' 16 | Each time this script runs, it saves the output in a different path 17 | controlled by the following folder suffix: annotation_version. 
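For example, with the default --output_path, a run started at 2023-01-01 12:00:00 (a hypothetical timestamp, formatted as %Y%m%d%H%M%S) would write its converted files under from_kerasyolo3/version_20230101120000/.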
18 | ''' 19 | annotation_version = datetime.datetime.now().strftime('%Y%m%d%H%M%S') 20 | 21 | ap = argparse.ArgumentParser() 22 | 23 | ap.add_argument("-o", "--output_path", 24 | required=False, 25 | default='from_kerasyolo3/version_{}'.format(annotation_version), 26 | type=str, 27 | help="The dataset root path location.") 28 | ap.add_argument("-r", "--gen_recursive", 29 | required=False, 30 | default=False, 31 | action="store_true", 32 | help="Define if the output txt files will be placed in a \ 33 | recursive folder tree or to direct txt files.") 34 | group = ap.add_mutually_exclusive_group(required=True) 35 | group.add_argument('--gt', 36 | type=str, 37 | default=None, 38 | help="The annotation file that refers to ground-truth in (keras-yolo3 format)") 39 | group.add_argument('--dr', 40 | type=str, 41 | default=None, 42 | help="The annotation file that refers to detection-results in (keras-yolo3 format)") 43 | 44 | ARGS = ap.parse_args() 45 | 46 | with open('class_list.txt', 'r') as class_file: 47 | class_map = class_file.readlines() 48 | print(class_map) 49 | annotation_file = ARGS.gt if ARGS.gt else ARGS.dr 50 | 51 | os.makedirs(ARGS.output_path, exist_ok=True) 52 | 53 | with open(annotation_file, 'r') as annot_f: 54 | for annot in annot_f: 55 | annot = annot.split(' ') 56 | img_path = annot[0].strip() 57 | if ARGS.gen_recursive: 58 | annotation_dir_name = os.path.dirname(img_path) 59 | # remove the root path to enable to path.join. 60 | if annotation_dir_name.startswith('/'): 61 | annotation_dir_name = annotation_dir_name.replace('/', '', 1) 62 | destination_dir = os.path.join(ARGS.output_path, annotation_dir_name) 63 | os.makedirs(destination_dir, exist_ok=True) 64 | # replace .jpg with your image format. 65 | file_name = os.path.basename(img_path).replace('.jpg', '.txt') 66 | output_file_path = os.path.join(destination_dir, file_name) 67 | else: 68 | file_name = img_path.replace('.jpg', '.txt').replace('/', '__') 69 | output_file_path = os.path.join(ARGS.output_path, file_name) 70 | os.path.dirname(output_file_path) 71 | 72 | with open(output_file_path, 'w') as out_f: 73 | for bbox in annot[1:]: 74 | if ARGS.gt: 75 | # Here we are dealing with ground-truth annotations 76 | # [] 77 | # todo: handle difficulty 78 | x_min, y_min, x_max, y_max, class_id = list(map(float, bbox.split(','))) 79 | out_box = '{} {} {} {} {}'.format( 80 | class_map[int(class_id)].strip(), x_min, y_min, x_max, y_max) 81 | else: 82 | # Here we are dealing with detection-results annotations 83 | # 84 | x_min, y_min, x_max, y_max, class_id, score = list(map(float, bbox.split(','))) 85 | out_box = '{} {} {} {} {} {}'.format( 86 | class_map[int(class_id)].strip(), score, x_min, y_min, x_max, y_max) 87 | 88 | out_f.write(out_box + "\n") 89 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/find_class.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | 5 | # make sure that the cwd() in the beginning is the location of the python script (so that every path makes sense) 6 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | if len(sys.argv) != 2: 9 | print("Error: wrong format.\nUsage: python find_class.py [class_name]") 10 | sys.exit(0) 11 | 12 | searching_class_name = sys.argv[1] 13 | 14 | def find_class(class_name): 15 | file_list = glob.glob('*.txt') 16 | file_list.sort() 17 | # iterate through the text files 18 | file_found = False 19 | for txt_file in 
file_list: 20 | # open txt file lines to a list 21 | with open(txt_file) as f: 22 | content = f.readlines() 23 | # remove whitespace characters like `\n` at the end of each line 24 | content = [x.strip() for x in content] 25 | # go through each line of each file 26 | for line in content: 27 | class_name = line.split()[0] 28 | if class_name == searching_class_name: 29 | print(" " + txt_file) 30 | file_found = True 31 | break 32 | if not file_found: 33 | print(" No file found with that class") 34 | 35 | parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) 36 | parent_path = os.path.abspath(os.path.join(parent_path, os.pardir)) 37 | GT_PATH = os.path.join(parent_path, 'input','ground-truth') 38 | DR_PATH = os.path.join(parent_path, 'input','detection-results') 39 | 40 | print("ground-truth folder:") 41 | os.chdir(GT_PATH) 42 | find_class(searching_class_name) 43 | print("detection-results folder:") 44 | os.chdir(DR_PATH) 45 | find_class(searching_class_name) 46 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/intersect-gt-and-dr.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import glob 4 | 5 | ## This script ensures the same number of files in the ground-truth and detection-results folders. 6 | ## When you encounter a file not found error, it's usually because you have 7 | ## mismatched numbers of ground-truth and detection-results files. 8 | ## You can use this script to move ground-truth and detection-results files that are 9 | ## not in the intersection into a backup folder (backup_no_matches_found). 10 | ## This will retain only files that have the same name in both folders. 11 | 12 | 13 | # make sure that the cwd() in the beginning is the location of the python script (so that every path makes sense) 14 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 15 | 16 | parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) 17 | parent_path = os.path.abspath(os.path.join(parent_path, os.pardir)) 18 | GT_PATH = os.path.join(parent_path, 'input','ground-truth') 19 | DR_PATH = os.path.join(parent_path, 'input','detection-results') 20 | 21 | backup_folder = 'backup_no_matches_found' # must end without slash 22 | 23 | os.chdir(GT_PATH) 24 | gt_files = glob.glob('*.txt') 25 | if len(gt_files) == 0: 26 | print("Error: no .txt files found in", GT_PATH) 27 | sys.exit() 28 | os.chdir(DR_PATH) 29 | dr_files = glob.glob('*.txt') 30 | if len(dr_files) == 0: 31 | print("Error: no .txt files found in", DR_PATH) 32 | sys.exit() 33 | 34 | gt_files = set(gt_files) 35 | dr_files = set(dr_files) 36 | print('total ground-truth files:', len(gt_files)) 37 | print('total detection-results files:', len(dr_files)) 38 | print() 39 | 40 | gt_backup = gt_files - dr_files 41 | dr_backup = dr_files - gt_files 42 | 43 | def backup(src_folder, backup_files, backup_folder): 44 | # non-intersection files (txt format) will be moved to a backup folder 45 | if not backup_files: 46 | print('No backup required for', src_folder) 47 | return 48 | os.chdir(src_folder) 49 | ## create the backup dir if it doesn't exist already 50 | if not os.path.exists(backup_folder): 51 | os.makedirs(backup_folder) 52 | for file in backup_files: 53 | os.rename(file, backup_folder + '/' + file) 54 | 55 | backup(GT_PATH, gt_backup, backup_folder) 56 | backup(DR_PATH, dr_backup, backup_folder) 57 | if gt_backup: 58 | print('total ground-truth backup files:', len(gt_backup)) 59 | if
dr_backup: 60 | print('total detection-results backup files:', len(dr_backup)) 61 | 62 | intersection = gt_files & dr_files 63 | print('total intersected files:', len(intersection)) 64 | print("Intersection completed!") 65 | -------------------------------------------------------------------------------- /downstream/detection/mAP/scripts/extra/result.txt: -------------------------------------------------------------------------------- 1 | Total BFLOPS 65.864 2 | 3 | seen 64 4 | Enter Image Path: data/horses.jpg: Predicted in 42.076185 seconds. 5 | horse: 88% (left_x: 3 top_y: 185 width: 150 height: 167) 6 | horse: 99% (left_x: 5 top_y: 198 width: 307 height: 214) 7 | horse: 96% (left_x: 236 top_y: 180 width: 215 height: 169) 8 | horse: 99% (left_x: 440 top_y: 209 width: 156 height: 142) 9 | Enter Image Path: data/person.jpg: Predicted in 41.767213 seconds. 10 | dog: 99% (left_x: 58 top_y: 262 width: 147 height: 89) 11 | person: 100% (left_x: 190 top_y: 95 width: 86 height: 284) 12 | horse: 100% (left_x: 394 top_y: 137 width: 215 height: 206) 13 | Enter Image Path: -------------------------------------------------------------------------------- /downstream/detection/run_detection.sh: -------------------------------------------------------------------------------- 1 | TAG="" 2 | DATA="" 3 | 4 | python train.py --tag $TAG \ 5 | --root /data/${DATA}/ \ 6 | --dataset $DATA \ 7 | --backbone resnet50 \ 8 | --pretrained_backbone /path/to/ckpt/last.pth \ 9 | --epochs 100 \ 10 | --precision \ 11 | --optimizer adam \ 12 | --learning_rate 0.0001 \ 13 | --cosine \ 14 | --iou_thresh 0.2 \ 15 | --class_label \ 16 | --predict_path ./mAP/input/${TAG}/detection-results \ 17 | --gt_bbox_path ./mAP/input/${TAG}/ground-truth 18 | 19 | # only for getting prediction bbox 20 | # --pretrained_ckpt /path/to/ckpt/last.pth 21 | 22 | IoU_start="0.5" 23 | IoU_end="0.95" 24 | Step="0.05" 25 | 26 | for i in $(seq $IoU_start $Step $IoU_end) 27 | do 28 | echo "IoU threshold is $i" 29 | python ./mAP/main.py --tag $TAG --dataset $DATA --iou_thresh $i 30 | done 31 | -------------------------------------------------------------------------------- /downstream/detection/util/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | __all__ = ['adjust_learning_rate', 'save_model'] 6 | 7 | 8 | def adjust_learning_rate(args, optimizer, epoch): 9 | lr = args.learning_rate 10 | if args.cosine: 11 | eta_min = lr * (args.lr_decay_rate ** 3) 12 | lr = eta_min + (lr - eta_min) * ( 13 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 14 | else: 15 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 16 | if steps > 0: 17 | lr = lr * (args.lr_decay_rate ** steps) 18 | 19 | for param_group in optimizer.param_groups: 20 | param_group['lr'] = lr 21 | 22 | 23 | def save_model(model, optimizer, args, epoch, save_file): 24 | print('==> Saving...') 25 | state = { 26 | 'args': args, 27 | 'model': model.state_dict(), 28 | 'optimizer': optimizer.state_dict(), 29 | 'epoch': epoch, 30 | } 31 | torch.save(state, save_file) 32 | del state -------------------------------------------------------------------------------- /downstream/mining/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .hard_sampling import HardSampling 2 | from .explicit_hard_sampling import ExpHardSampling 3 | from .explicit_hard_sampling_memory import ExpHardSamplingMemory 4 | 5 | _method_class_map = { 6 | 'sampling': 
HardSampling, 7 | 'explicit_sampling': ExpHardSampling, 8 | 'explicit_sampling_memory': ExpHardSamplingMemory 9 | } 10 | 11 | 12 | def get_method_class(key): 13 | if key in _method_class_map: 14 | return _method_class_map[key] 15 | else: 16 | raise ValueError('Invalid method: {}'.format(key)) 17 | -------------------------------------------------------------------------------- /downstream/mining/methods/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from argparse import Namespace 3 | from typing import Tuple 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class BaseModel(nn.Module): 10 | """ 11 | BaseModel subclasses self-contain all modules and losses required for pre-training. 12 | 13 | - supported_feature_selectors: Feature selectors (see `forward_features()`) are used during fine-tuning 14 | to select which features (from which layer the features should be extracted) should be used for downstream 15 | tasks. This class attribute should be set for subclasses to prevent mistakes regarding the feature_selector 16 | argument (see `params.ft_features`). 17 | """ 18 | supported_feature_selectors = [] 19 | 20 | def __init__(self, backbone: nn.Module, params: Namespace): 21 | super().__init__() 22 | self.backbone = backbone 23 | self.params = params 24 | self.classifier = nn.Linear(backbone.final_feat_dim, params.n_cls) 25 | self.classifier.bias.data.fill_(0) 26 | self.cls_loss_function = nn.CrossEntropyLoss() 27 | self.final_feat_dim = backbone.final_feat_dim 28 | 29 | def forward_features(self, x): 30 | """ 31 | You'll likely need to override this method for SSL models. 32 | """ 33 | return self.backbone(x) 34 | 35 | def forward(self, x): 36 | x = self.backbone(x) 37 | x = self.classifier(x) 38 | return x 39 | 40 | def compute_cls_loss_and_accuracy(self, x, y, return_predictions=False) -> Tuple: 41 | scores = self.forward(x) 42 | _, predicted = torch.max(scores.data, 1) 43 | accuracy = predicted.eq(y.data).cpu().sum() / x.shape[0] 44 | if return_predictions: 45 | return self.cls_loss_function(scores, y), accuracy, predicted 46 | else: 47 | return self.cls_loss_function(scores, y), accuracy 48 | 49 | def on_step_start(self): 50 | pass 51 | 52 | def on_step_end(self): 53 | pass 54 | 55 | def on_epoch_start(self, train_loader): 56 | pass 57 | 58 | def on_epoch_end(self): 59 | pass 60 | 61 | 62 | class BaseSelfSupervisedModel(BaseModel): 63 | @abstractmethod 64 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 65 | """ 66 | If SSL is based on paired input: 67 | By default: x1, x2 represent the input pair. 68 | If x2=None: x1 alone contains the full concatenated input pair. 69 | Else: 70 | x1 contains the input. 
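In the SimCLR-based subclasses in this folder, return_features=True additionally returns the backbone features: (loss, f) when a single concatenated input is given, or (loss, f1, f2) when a pair is given.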
71 | """ 72 | raise NotImplementedError() 73 | 74 | @abstractmethod 75 | def _data_parallel(self): 76 | raise NotImplementedError() 77 | -------------------------------------------------------------------------------- /downstream/mining/methods/explicit_hard_sampling.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from functools import lru_cache 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from .simclr import SimCLR, NTXentLoss 10 | 11 | 12 | class ExpHardNTXentLoss(NTXentLoss): 13 | def __init__(self, temperature, use_cosine_similarity, sampling_ratio): 14 | super(ExpHardNTXentLoss, self).__init__(temperature, use_cosine_similarity) 15 | self.sampling_ratio = sampling_ratio 16 | self.criterion = nn.CrossEntropyLoss(reduction='sum') 17 | 18 | def forward(self, zis, zjs): 19 | batch_size = zis.shape[0] 20 | representations = torch.cat([zjs, zis], dim=0) 21 | representations = torch.nn.functional.normalize(representations, dim=-1) 22 | device = representations.device 23 | 24 | similarity_matrix = self.similarity_function( 25 | representations, representations) 26 | 27 | # filter out the scores from the positive samples 28 | l_pos = torch.diag(similarity_matrix, batch_size) 29 | r_pos = torch.diag(similarity_matrix, -batch_size) 30 | positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1) 31 | 32 | mask = self._get_correlated_mask(batch_size).to(device) 33 | negatives = similarity_matrix[mask].view(2 * batch_size, -1) 34 | # compute logits (Einstein sum is more intuitive) 35 | # representations = torch.cat([zjs, zis], dim=0) 36 | # representations_ = torch.cat([zis, zjs], dim=0) 37 | # l_pos = torch.einsum('nc,nc->n', [representations, representations_]).unsqueeze(-1) # positive logits: 2Nx1 38 | # l_neg = torch.einsum('nc,ck->nk', [representations, queue]) # negative logits: 2NxK 39 | 40 | sampling_num = int(2 * batch_size * self.sampling_ratio) 41 | negatives = negatives.sort(descending=True, dim=1)[0][:, :sampling_num] # 2Nx(sample_num) 42 | logits = torch.cat((positives, negatives), dim=1) 43 | logits /= self.temperature 44 | 45 | labels = torch.zeros(2 * batch_size).to(device).long() 46 | loss = self.criterion(logits, labels) 47 | 48 | return (self.temperature / 0.07) * loss / (2 * batch_size) 49 | 50 | 51 | class ExpHardSampling(SimCLR): 52 | def __init__(self, backbone: nn.Module, params: Namespace): 53 | super().__init__(backbone, params) 54 | simclr_projection_dim = 128 55 | simclr_temperature = params.simclr_temperature 56 | sampling_ratio = params.sampling_ratio 57 | 58 | self.ssl_loss_fn = ExpHardNTXentLoss(temperature=simclr_temperature, use_cosine_similarity=False, 59 | sampling_ratio=sampling_ratio) 60 | 61 | @torch.no_grad() 62 | def on_step_start(self, train_loader): 63 | pass 64 | 65 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 66 | if x2 is None: 67 | x = x1 68 | else: 69 | x = torch.cat([x1, x2]) 70 | batch_size = int(x.shape[0] / 2) 71 | 72 | f = self.backbone(x) 73 | f1, f2 = f[:batch_size], f[batch_size:] 74 | p1 = self.head(f1) 75 | p2 = self.head(f2) 76 | loss = self.ssl_loss_fn(p1, p2) 77 | 78 | if return_features: 79 | if x2 is None: 80 | return loss, f 81 | else: 82 | return loss, f1, f2 83 | else: 84 | return loss 85 | -------------------------------------------------------------------------------- /downstream/mining/methods/explicit_hard_sampling_memory.py: 
-------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from functools import lru_cache 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from .simclr import SimCLR, NTXentLoss 10 | 11 | 12 | class ExpHardNTXentLoss(NTXentLoss): 13 | def __init__(self, temperature, use_cosine_similarity, sampling_ratio): 14 | super(ExpHardNTXentLoss, self).__init__(temperature, use_cosine_similarity) 15 | self.sampling_ratio = sampling_ratio 16 | self.criterion = nn.CrossEntropyLoss(reduction='sum') 17 | 18 | def forward(self, zis, zjs, queue): 19 | batch_size = zis.shape[0] 20 | K = queue.shape[1] 21 | 22 | representations = torch.cat([zjs, zis], dim=0) 23 | representations = torch.nn.functional.normalize(representations, dim=-1) 24 | device = representations.device 25 | 26 | similarity_matrix = self.similarity_function( 27 | representations, representations) 28 | 29 | similarity_matrix_queue = self.similarity_function( 30 | representations, queue.T) # (2N, K) 31 | 32 | # filter out the scores from the positive samples 33 | l_pos = torch.diag(similarity_matrix, batch_size) 34 | r_pos = torch.diag(similarity_matrix, -batch_size) 35 | positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1) 36 | 37 | mask = self._get_correlated_mask(batch_size).to(device) 38 | negatives = similarity_matrix[mask].view(2 * batch_size, -1) 39 | 40 | sampling_num = int(K * self.sampling_ratio) 41 | negatives_queue = similarity_matrix_queue.sort(descending=True, dim=1)[0][:, :sampling_num] 42 | neg = torch.cat([negatives, negatives_queue], dim=1) # (2N, 2N-2+sampling_num) 43 | logits = torch.cat((positives, neg), dim=1) 44 | logits /= self.temperature 45 | 46 | labels = torch.zeros(2 * batch_size).to(representations.device).long() 47 | loss = self.criterion(logits, labels) 48 | 49 | return loss / (2 * batch_size) 50 | 51 | 52 | class ExpHardSamplingMemory(SimCLR): 53 | def __init__(self, backbone: nn.Module, params: Namespace): 54 | super().__init__(backbone, params) 55 | simclr_projection_dim = 128 56 | simclr_temperature = 0.07 57 | sampling_ratio = params.sampling_ratio 58 | self.K = 65536 59 | # self.K = params.K 60 | 61 | self.ssl_loss_fn = ExpHardNTXentLoss(temperature=simclr_temperature, use_cosine_similarity=False, 62 | sampling_ratio=sampling_ratio) 63 | 64 | self.register_buffer("queue", torch.randn(simclr_projection_dim, self.K)) 65 | self.queue = F.normalize(self.queue, dim=0) 66 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 67 | 68 | @torch.no_grad() 69 | def _dequeue_and_enqueue(self, keys): 70 | batch_size = keys.shape[0] 71 | ptr = int(self.queue_ptr) 72 | assert self.K % batch_size == 0 # for simplicity 73 | 74 | # replace the keys at ptr (dequeue and enqueue) 75 | self.queue[:, ptr:ptr + batch_size] = keys.T 76 | ptr = (ptr + batch_size) % self.K # move pointer 77 | 78 | self.queue_ptr[0] = ptr 79 | 80 | @torch.no_grad() 81 | def on_step_start(self, train_loader): 82 | global train_iter 83 | try: 84 | images, _ = next(train_iter) 85 | except: 86 | print('Iterator restart!') 87 | train_iter = iter(train_loader) 88 | images, _ = next(train_iter) 89 | img = images.cuda(non_blocking=True) 90 | 91 | with torch.cuda.amp.autocast(): 92 | p1 = self.head(self.backbone(img)) 93 | p1 = F.normalize(p1, dim=1) 94 | self._dequeue_and_enqueue(p1.clone().detach()) 95 | 96 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 97 | if x2 is None: 98 | x = x1 99 | 
else: 100 | x = torch.cat([x1, x2]) 101 | batch_size = int(x.shape[0] / 2) 102 | 103 | f = self.backbone(x) 104 | f1, f2 = f[:batch_size], f[batch_size:] 105 | p1 = self.head(f1) 106 | p2 = self.head(f2) 107 | loss = self.ssl_loss_fn(p1, p2, self.queue.clone().detach()) 108 | 109 | if return_features: 110 | if x2 is None: 111 | return loss, f 112 | else: 113 | return loss, f1, f2 114 | else: 115 | return loss 116 | -------------------------------------------------------------------------------- /downstream/mining/methods/hard_sampling.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from functools import lru_cache 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from .simclr import SimCLR, NTXentLoss 10 | 11 | 12 | class HardNTXentLoss(NTXentLoss): 13 | def __init__(self, temperature, use_cosine_similarity, beta, tau_plus): 14 | super(HardNTXentLoss, self).__init__(temperature, use_cosine_similarity) 15 | self.beta = beta 16 | self.tau_plus = tau_plus 17 | 18 | def forward(self, zis, zjs): 19 | batch_size = zis.shape[0] 20 | representations = torch.cat([zjs, zis], dim=0) 21 | representations = torch.nn.functional.normalize(representations, dim=-1) 22 | device = representations.device 23 | 24 | similarity_matrix = self.similarity_function( 25 | representations, representations) 26 | 27 | # filter out the scores from the positive samples 28 | l_pos = torch.diag(similarity_matrix, batch_size) 29 | r_pos = torch.diag(similarity_matrix, -batch_size) 30 | positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1) 31 | 32 | mask = self._get_correlated_mask(batch_size).to(device) 33 | negatives = similarity_matrix[mask].view(2 * batch_size, -1) 34 | 35 | neg = (negatives / self.temperature).exp() # (2N, 2N-2) 36 | pos = (positives / self.temperature).exp() 37 | 38 | N = batch_size * 2 - 2 39 | imp = (self.beta* neg.log()).exp() 40 | reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1) 41 | Ng = (-self.tau_plus * N * pos + reweight_neg) / (1 - self.tau_plus) 42 | # constrain (optional) 43 | Ng = torch.clamp(Ng, min = N * np.e**(-1 / self.temperature)) 44 | 45 | loss = (- torch.log(pos / (pos + Ng) )).mean() 46 | 47 | return loss 48 | 49 | 50 | class HardSampling(SimCLR): 51 | def __init__(self, backbone: nn.Module, params: Namespace): 52 | super().__init__(backbone, params) 53 | simclr_projection_dim = 128 54 | simclr_temperature = 0.07 55 | beta = params.beta 56 | tau_plus = params.tau_plus 57 | 58 | self.ssl_loss_fn = HardNTXentLoss(temperature=simclr_temperature, use_cosine_similarity=False, 59 | beta=beta, tau_plus=tau_plus) 60 | 61 | @torch.no_grad() 62 | def on_step_start(self, train_loader): 63 | pass 64 | 65 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 66 | if x2 is None: 67 | x = x1 68 | else: 69 | x = torch.cat([x1, x2]) 70 | batch_size = int(x.shape[0] / 2) 71 | 72 | f = self.backbone(x) 73 | f1, f2 = f[:batch_size], f[batch_size:] 74 | p1 = self.head(f1) 75 | p2 = self.head(f2) 76 | loss = self.ssl_loss_fn(p1, p2) 77 | 78 | if return_features: 79 | if x2 is None: 80 | return loss, f 81 | else: 82 | return loss, f1, f2 83 | else: 84 | return loss -------------------------------------------------------------------------------- /downstream/mining/methods/simclr.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from functools import lru_cache 3 | 4 | 
import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | from .base import BaseSelfSupervisedModel 9 | 10 | 11 | class ProjectionHead(nn.Module): 12 | def __init__(self, in_dim, out_dim): 13 | super(ProjectionHead, self).__init__() 14 | self.in_dim = in_dim 15 | self.out_dim = out_dim 16 | 17 | self.fc1 = nn.Linear(in_dim, in_dim) 18 | self.relu = nn.ReLU() 19 | self.fc2 = nn.Linear(in_dim, out_dim) 20 | 21 | def forward(self, x): 22 | return self.fc2(self.relu(self.fc1(x))) 23 | 24 | 25 | class NTXentLoss(nn.Module): 26 | def __init__(self, temperature, use_cosine_similarity): 27 | super(NTXentLoss, self).__init__() 28 | self.temperature = temperature 29 | self.softmax = torch.nn.Softmax(dim=-1) 30 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 31 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 32 | 33 | def _get_similarity_function(self, use_cosine_similarity): 34 | if use_cosine_similarity: 35 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 36 | return self._cosine_simililarity 37 | else: 38 | return self._dot_simililarity 39 | 40 | @lru_cache(maxsize=4) 41 | def _get_correlated_mask(self, batch_size): 42 | diag = np.eye(2 * batch_size) 43 | l1 = np.eye((2 * batch_size), 2 * 44 | batch_size, k=-batch_size) 45 | l2 = np.eye((2 * batch_size), 2 * 46 | batch_size, k=batch_size) 47 | mask = torch.from_numpy((diag + l1 + l2)) 48 | mask = (1 - mask).type(torch.bool) 49 | return mask 50 | 51 | @staticmethod 52 | def _dot_simililarity(x, y): 53 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 54 | # x shape: (N, 1, C) 55 | # y shape: (1, C, 2N) 56 | # v shape: (N, 2N) 57 | return v 58 | 59 | def _cosine_simililarity(self, x, y): 60 | # x shape: (N, 1, C) 61 | # y shape: (1, 2N, C) 62 | # v shape: (N, 2N) 63 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 64 | return v 65 | 66 | def forward(self, zis, zjs): 67 | batch_size = zis.shape[0] 68 | representations = torch.cat([zjs, zis], dim=0) 69 | representations = torch.nn.functional.normalize(representations, dim=-1) 70 | device = representations.device 71 | 72 | similarity_matrix = self.similarity_function( 73 | representations, representations) 74 | 75 | # filter out the scores from the positive samples 76 | l_pos = torch.diag(similarity_matrix, batch_size) 77 | r_pos = torch.diag(similarity_matrix, -batch_size) 78 | positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1) 79 | 80 | mask = self._get_correlated_mask(batch_size).to(device) 81 | negatives = similarity_matrix[mask].view(2 * batch_size, -1) 82 | 83 | logits = torch.cat((positives, negatives), dim=1) 84 | logits /= self.temperature 85 | 86 | labels = torch.zeros(2 * batch_size).to(device).long() 87 | loss = self.criterion(logits, labels) 88 | 89 | return loss / (2 * batch_size) 90 | 91 | 92 | class SimCLR(BaseSelfSupervisedModel): 93 | def __init__(self, backbone: nn.Module, params: Namespace): 94 | super().__init__(backbone, params) 95 | simclr_projection_dim = 128 96 | simclr_temperature = 0.07 97 | 98 | self.head = ProjectionHead(backbone.final_feat_dim, out_dim=simclr_projection_dim) 99 | self.ssl_loss_fn = NTXentLoss(temperature=simclr_temperature, use_cosine_similarity=False) 100 | self.final_feat_dim = self.backbone.final_feat_dim 101 | 102 | def _data_parallel(self): 103 | self.backbone = nn.DataParallel(self.backbone) 104 | self.head = nn.DataParallel(self.head) 105 | 106 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 107 | if x2 is None: 108 | x = 
x1 109 | else: 110 | x = torch.cat([x1, x2]) 111 | batch_size = int(x.shape[0] / 2) 112 | 113 | f = self.backbone(x) 114 | f1, f2 = f[:batch_size], f[batch_size:] 115 | p1 = self.head(f1) 116 | p2 = self.head(f2) 117 | loss = self.ssl_loss_fn(p1, p2) 118 | 119 | if return_features: 120 | if x2 is None: 121 | return loss, f 122 | else: 123 | return loss, f1, f2 124 | else: 125 | return loss 126 | -------------------------------------------------------------------------------- /downstream/mining/models: -------------------------------------------------------------------------------- 1 | ../../models -------------------------------------------------------------------------------- /downstream/mining/run_hard_neg_mining.sh: -------------------------------------------------------------------------------- 1 | # HardSampling template -> for other methods, please modify method and related arguments to default values 2 | TAG="" 3 | DATA="" 4 | 5 | python train.py --tag $TAG \ 6 | --dataset1 $DATA \ 7 | --dataset2 imagenet \ 8 | --data_folder1 /path/to/data \ 9 | --data_folder2 /path/to/data/ILSVRC2015/ILSVRC2015/Data/CLS-LOC/ \ 10 | --model resnet18 \ 11 | --cosine \ 12 | --precision \ 13 | --method sampling \ 14 | --epochs 5000 \ 15 | --beta 1.0 \ 16 | --tau_plus 0.01 -------------------------------------------------------------------------------- /downstream/mining/util/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | 4 | 5 | class MergeDataset(data.Dataset): 6 | def __init__(self, dataset1, dataset2): 7 | super(MergeDataset, self).__init__() 8 | assert isinstance(dataset1, data.Dataset) and isinstance(dataset2, data.Dataset) 9 | self.dataset1 = dataset1 10 | self.dataset2 = dataset2 11 | self.n1 = len(self.dataset1) 12 | self.n2 = len(self.dataset2) 13 | 14 | def __len__(self): 15 | return self.n1 + self.n2 16 | 17 | def __getitem__(self, idx): 18 | if idx < self.n1: 19 | return self.dataset1[idx] 20 | else: 21 | return self.dataset2[idx - self.n1] -------------------------------------------------------------------------------- /downstream/mining/util/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim as optim 9 | 10 | __all__ = ['AverageMeter', 'adjust_lr_wd', 'warmup_learning_rate', 'save_model'] 11 | 12 | 13 | class AverageMeter(object): 14 | """Computes and stores the average and current value""" 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | 31 | def adjust_lr_wd(args, optimizer, epoch): 32 | lr = args.learning_rate 33 | if args.cosine: 34 | eta_min = lr * (args.lr_decay_rate ** 3) 35 | lr = eta_min + (lr - eta_min) * ( 36 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 37 | else: 38 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 39 | if steps > 0: 40 | lr = lr * (args.lr_decay_rate ** steps) 41 | 42 | wd = args.weight_decay 43 | if args.wd_scheduler: 44 | wd_min = args.weight_decay_end 45 | wd = wd_min + (wd - wd_min) * ( 46 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 47 | 48 | for i, param_group in enumerate(optimizer.param_groups): 49 | 
param_group['lr'] = lr 50 | if i == 0: # in case of DINO-ViT and MAE-ViT, only wd for regularized params 51 | param_group['weight_decay'] = wd 52 | 53 | 54 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 55 | if args.warm and epoch <= args.warm_epochs: 56 | p = (batch_id + (epoch - 1) * total_batches) / \ 57 | (args.warm_epochs * total_batches) 58 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 59 | 60 | for param_group in optimizer.param_groups: 61 | param_group['lr'] = lr 62 | 63 | 64 | def save_model(model, optimizer, args, epoch, save_file, indices=None, classifier=None): 65 | print('==> Saving...') 66 | state = { 67 | 'args': args, 68 | 'model': model.state_dict(), 69 | 'optimizer': optimizer.state_dict(), 70 | 'epoch': epoch, 71 | } 72 | if indices is not None: 73 | state['indices'] = indices 74 | if classifier is not None: # for active learning fine-tuning and open-set semi 75 | state['classifier'] = classifier.state_dict() 76 | torch.save(state, save_file) 77 | del state 78 | -------------------------------------------------------------------------------- /downstream/mining/util/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TwoCropTransform: 5 | """ Create two crops of the same image """ 6 | def __init__(self, transform): 7 | self.transform = transform 8 | 9 | def __call__(self, x): 10 | return [self.transform(x), self.transform(x)] -------------------------------------------------------------------------------- /downstream/opensemi/models: -------------------------------------------------------------------------------- 1 | ../../models -------------------------------------------------------------------------------- /downstream/opensemi/run_openmatch.sh: -------------------------------------------------------------------------------- 1 | TAG="" 2 | DATA="" 3 | 4 | # SimCore pretrained checkpoint 5 | python train.py --tag $TAG \ 6 | --dataset $DATA \ 7 | --data_folder /path/to/data \ 8 | --data_folder2 /path/to/data/WebFG-496 \ 9 | --model resnet50 \ 10 | --method simclr \ 11 | --semisup_method openmatch \ 12 | --optimizer sgd \ 13 | --learning_rate 0.03 \ 14 | --weight_decay 1e-4 \ 15 | --cosine \ 16 | --start_fix 5 \ 17 | --total_step 16384 \ 18 | --eval_step 512 \ 19 | --batch_size 64 \ 20 | --label_ratio 0.5 \ 21 | --mu 2 \ 22 | --T 1 \ 23 | --lambda_oem 0.1 \ 24 | --lambda_socr 0.5 \ 25 | --threshold 0.0 \ 26 | --save_freq 10 \ 27 | --pretrained \ 28 | --pretrained_ckpt /path/to/ckpt/last.pth 29 | 30 | # train from scratch 31 | python train.py --tag $TAG \ 32 | --dataset $DATA \ 33 | --data_folder /path/to/data \ 34 | --data_folder2 /path/to/data/WebFG-496 \ 35 | --model resnet50 \ 36 | --method simclr \ 37 | --semisup_method openmatch \ 38 | --optimizer sgd \ 39 | --learning_rate 0.03 \ 40 | --weight_decay 1e-4 \ 41 | --cosine \ 42 | --start_fix 20 \ 43 | --total_step 65536 \ 44 | --eval_step 512 \ 45 | --batch_size 64 \ 46 | --label_ratio 0.5 \ 47 | --mu 2 \ 48 | --T 1 \ 49 | --lambda_oem 0.1 \ 50 | --lambda_socr 0.5 \ 51 | --threshold 0.0 \ 52 | --save_freq 30 53 | -------------------------------------------------------------------------------- /downstream/opensemi/run_selftraining.sh: -------------------------------------------------------------------------------- 1 | TAG="" 2 | DATA="" 3 | 4 | # SimCore pretrained checkpoint 5 | python train.py --tag $TAG \ 6 | --dataset $DATA \ 7 | --data_folder /path/to/data \ 8 | --data_folder2 
/path/to/data/WebFG-496 \ 9 | --model resnet50 \ 10 | --method simclr \ 11 | --semisup_method self_training \ 12 | --optimizer sgd \ 13 | --learning_rate 0.03 \ 14 | --weight_decay 1e-4 \ 15 | --cosine \ 16 | --total_steps 16384 \ 17 | --eval_step 512 \ 18 | --batch_size 256 \ 19 | --label_ratio 0.5 \ 20 | --mu 1 \ 21 | --teacher_epochs 100 \ 22 | --teacher_batch_size 256 \ 23 | --T 1 \ 24 | --lambda_u 0.5 \ 25 | --save_freq 10 \ 26 | --pretrained \ 27 | --pretrained_ckpt /path/to/ckpt/last.pth 28 | 29 | # train from scratch 30 | python train.py --tag $TAG \ 31 | --dataset $DATA \ 32 | --data_folder /path/to/data \ 33 | --data_folder2 /path/to/data/WebFG-496 \ 34 | --model resnet50 \ 35 | --method simclr \ 36 | --semisup_method self_training \ 37 | --optimizer sgd \ 38 | --learning_rate 0.03 \ 39 | --weight_decay 1e-4 \ 40 | --cosine \ 41 | --total_steps 65536 \ 42 | --eval_step 512 \ 43 | --batch_size 64 \ 44 | --label_ratio 0.5 \ 45 | --mu 1 \ 46 | --teacher_epochs 500 \ 47 | --teacher_batch_size 64 \ 48 | --T 1 \ 49 | --lambda_u 0.5 \ 50 | --save_freq 30 -------------------------------------------------------------------------------- /downstream/opensemi/util/dataset.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import os 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | 10 | 11 | class ImageFolderSemiSup(data.Dataset): 12 | def __init__(self, root='', transform=None, p=1.0, index=None): 13 | super(ImageFolderSemiSup, self).__init__() 14 | 15 | self.root = root 16 | self.transform = transform 17 | self.p = p 18 | 19 | self.train_dataset = datasets.ImageFolder(root=self.root, transform=self.transform) 20 | if index is None: 21 | # randomly select samples 22 | random.seed(0) 23 | self.dataset_len = int(self.p * len(self.train_dataset)) 24 | self.sampled_index = random.sample(range(len(self.train_dataset)), self.dataset_len) 25 | else: 26 | # samples selected by x_u_split function 27 | self.sampled_index = index 28 | self.dataset_len = len(self.sampled_index) 29 | 30 | def __len__(self): 31 | return self.dataset_len 32 | 33 | def __getitem__(self, index): 34 | return self.train_dataset[self.sampled_index[index]] 35 | 36 | 37 | def x_u_split(args, include_lb_to_ulb=False): 38 | """ in TorchSSL, include_lb_to_ulb=True 39 | """ 40 | labeled_idx = random.sample(range(args.n_data), args.num_labeled) 41 | if include_lb_to_ulb: 42 | unlabeled_idx = list(range(args.n_data)) 43 | else: 44 | unlabeled_idx = list(set(range(args.n_data)) - set(labeled_idx)) 45 | 46 | return labeled_idx, unlabeled_idx 47 | 48 | 49 | class MergeDataset(data.Dataset): 50 | def __init__(self, dataset1, dataset2): 51 | super(MergeDataset, self).__init__() 52 | assert isinstance(dataset1, data.Dataset) and isinstance(dataset2, data.Dataset) 53 | self.dataset1 = dataset1 54 | self.dataset2 = dataset2 55 | self.n1 = len(self.dataset1) 56 | self.n2 = len(self.dataset2) 57 | 58 | def __len__(self): 59 | return self.n1 + self.n2 60 | 61 | def __getitem__(self, idx): 62 | if idx < self.n1: 63 | return self.dataset1[idx] 64 | else: 65 | return self.dataset2[idx - self.n1] -------------------------------------------------------------------------------- /downstream/opensemi/util/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | np.random.seed(0) 4 | 5 | import torch 6 | import 
torchvision 7 | from torch import nn 8 | from torchvision.transforms import transforms 9 | 10 | 11 | class SelfTrainingTransform: 12 | def __init__(self, size, mean, std): 13 | self.transform = transforms.Compose([ 14 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 15 | transforms.RandomHorizontalFlip(), 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean=mean, std=std) 18 | ]) 19 | 20 | def __call__(self, x): 21 | return self.transform(x) 22 | 23 | 24 | class OpenMatchTransform: 25 | def __init__(self, size, mean, std): 26 | self.weak = transforms.Compose([ 27 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 28 | transforms.RandomHorizontalFlip()]) 29 | 30 | self.weak2 = transforms.Compose([ 31 | transforms.Resize(256), 32 | transforms.CenterCrop(size), 33 | transforms.RandomHorizontalFlip()]) 34 | 35 | self.strong = transforms.Compose([ 36 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 37 | transforms.RandomHorizontalFlip(), 38 | torchvision.transforms.RandAugment(num_ops=2, magnitude=10)]) 39 | # RandAugmentMC(n=2, m=10)]) # refer to original code in https://github.com/kekmodel/FixMatch-pytorch/blob/master/dataset/randaugment.py 40 | 41 | self.normalize = transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=mean, std=std)]) 44 | 45 | def __call__(self, x): 46 | weak = self.weak(x) 47 | weak2 = self.weak2(x) 48 | strong = self.strong(x) 49 | return self.normalize(weak), self.normalize(strong), self.normalize(weak2) 50 | 51 | 52 | _transform_class_map = { 53 | 'self_training': SelfTrainingTransform, 54 | 'openmatch': OpenMatchTransform, 55 | } 56 | 57 | def get_semisup_transform_class(key): 58 | if key in _transform_class_map: 59 | return _transform_class_map[key] 60 | else: 61 | raise ValueError('Invalid method: {}'.format(key)) 62 | -------------------------------------------------------------------------------- /downstream/segmentation/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | # from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd 2 | from dataloaders.datasets import pets_segmentation, cub_segmentation 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | def make_data_loader(args, **kwargs): 7 | if args.dataset == 'pets': 8 | # root = '~~~/pets/' 9 | train_set = pets_segmentation.PetsSegmentationDataset(args.root, test=False, class_label=args.class_label) 10 | val_set = pets_segmentation.PetsSegmentationDataset(args.root, test=True, class_label=args.class_label) 11 | 12 | # class: foreground, background, unknown (boundary is not considered) 13 | num_class = 2 if not args.class_label else 38 14 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 15 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 16 | test_loader = None 17 | 18 | return train_loader, val_loader, test_loader, num_class 19 | 20 | elif args.dataset == 'cub': 21 | # root = '~~~/cub/' 22 | train_set = cub_segmentation.CUBSegmentationDataset(args.root, test=False, class_label=args.class_label) 23 | val_set = cub_segmentation.CUBSegmentationDataset(args.root, test=True, class_label=args.class_label) 24 | 25 | # class: background, foreground (5/5 experts coincide), unknown (below 4/5 experts coincide) 26 | num_class = 2 if not args.class_label else 201 27 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 28 | val_loader = DataLoader(val_set, 
batch_size=args.batch_size, shuffle=False, **kwargs) 29 | test_loader = None 30 | 31 | return train_loader, val_loader, test_loader, num_class 32 | 33 | # if args.dataset == 'pascal': 34 | # train_set = pascal.VOCSegmentation(args, split='train') 35 | # val_set = pascal.VOCSegmentation(args, split='val') 36 | # if args.use_sbd: 37 | # sbd_train = sbd.SBDSegmentation(args, split=['train', 'val']) 38 | # train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set]) 39 | 40 | # num_class = train_set.NUM_CLASSES 41 | # train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 42 | # val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 43 | # test_loader = None 44 | 45 | # return train_loader, val_loader, test_loader, num_class 46 | 47 | # elif args.dataset == 'cityscapes': 48 | # train_set = cityscapes.CityscapesSegmentation(args, split='train') 49 | # val_set = cityscapes.CityscapesSegmentation(args, split='val') 50 | # test_set = cityscapes.CityscapesSegmentation(args, split='val') 51 | # num_class = train_set.NUM_CLASSES 52 | # train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 53 | # val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 54 | # test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs) 55 | 56 | # return train_loader, val_loader, test_loader, num_class 57 | 58 | # elif args.dataset == 'coco': 59 | # train_set = coco.COCOSegmentation(args, split='train') 60 | # val_set = coco.COCOSegmentation(args, split='val') 61 | # test_set = coco.COCOSegmentation(args, split='val') 62 | # num_class = train_set.NUM_CLASSES 63 | # train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 64 | # val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 65 | # test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs) 66 | # return train_loader, val_loader, test_loader, num_class 67 | 68 | else: 69 | raise NotImplementedError 70 | 71 | -------------------------------------------------------------------------------- /downstream/segmentation/dataloaders/datasets/cub_segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import copy 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from dataloaders import custom_transforms as tr 8 | 9 | 10 | class CUBSegmentationDataset(Dataset): 11 | def __init__(self, root, test=False, class_label=False): 12 | ANNOT_PATH = 'segmentations' 13 | self.image_paths = list() 14 | self.label_paths = list() 15 | self.test = test 16 | self.class_label = class_label 17 | if self.test: 18 | split = 'test' 19 | else: 20 | split = 'train' 21 | datapath = os.path.join(root, split) 22 | for i, cls in enumerate(os.listdir(datapath)): 23 | for img_name in os.listdir(os.path.join(datapath, cls)): 24 | self.image_paths.append((os.path.join(datapath, cls, img_name), i+1)) 25 | self.label_paths.append(os.path.join(root, ANNOT_PATH, cls, img_name.replace('.jpg', '.png'))) 26 | 27 | def __len__(self): 28 | return len(self.image_paths) 29 | 30 | def __getitem__(self, idx): 31 | path, cls = self.image_paths[idx] 32 | image = Image.open(path).convert('RGB') 33 | label = Image.open(self.label_paths[idx]) 34 | sample = {'image': image, 'label': label} 35 | 36 | if not self.test: 37 
| ret = self.transform_tr(sample) 38 | else: 39 | ret = self.transform_val(sample) 40 | 41 | new_label = copy.deepcopy(ret['label']) 42 | new_label[ret['label'] > 204.] = 1. 43 | new_label[new_label >= 51.] = 255. 44 | if len(new_label.shape) != 2: 45 | new_label = new_label[:,:,0] 46 | if self.class_label: 47 | new_label[new_label == 1.] *= cls 48 | ret['label'] = new_label 49 | 50 | return ret 51 | 52 | def transform_tr(self, sample): 53 | composed_transforms = transforms.Compose([ 54 | tr.RandomHorizontalFlip(), 55 | tr.RandomScaleCrop(base_size=256, crop_size=224), 56 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 57 | tr.ToTensor()]) 58 | 59 | return composed_transforms(sample) 60 | 61 | def transform_val(self, sample): 62 | composed_transforms = transforms.Compose([ 63 | tr.FixScaleCrop(crop_size=224), 64 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 65 | tr.ToTensor()]) 66 | 67 | return composed_transforms(sample) 68 | 69 | -------------------------------------------------------------------------------- /downstream/segmentation/dataloaders/datasets/pets_segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import copy 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from dataloaders import custom_transforms as tr 8 | 9 | 10 | class PetsSegmentationDataset(Dataset): 11 | def __init__(self, root, test=False, class_label=False): 12 | ANNOT_PATH = 'annotations/trimaps' 13 | self.image_paths = list() 14 | self.label_paths = list() 15 | self.test = test 16 | self.class_label = class_label 17 | if self.test: 18 | split = 'test' 19 | else: 20 | split = 'train' 21 | datapath = os.path.join(root, split) 22 | for i, cls in enumerate(os.listdir(datapath)): 23 | for img_name in os.listdir(os.path.join(datapath, cls)): 24 | self.image_paths.append((os.path.join(datapath, cls, img_name), i+1)) 25 | self.label_paths.append(os.path.join(root, ANNOT_PATH, img_name.replace('.jpg', '.png'))) 26 | 27 | def __len__(self): 28 | return len(self.image_paths) 29 | 30 | def __getitem__(self, idx): 31 | path, cls = self.image_paths[idx] 32 | image = Image.open(path).convert('RGB') 33 | label = Image.open(self.label_paths[idx]) 34 | sample = {'image': image, 'label': label} 35 | 36 | # convert trimap to rgb array 37 | # label = np.array(label) 38 | # new_label = np.zeros(label.shape[0], label.shape[1], 3) 39 | # new_label[label==1, 0] = 1 40 | # new_label[label==2, 1] = 1 41 | # new_label[label==3, 2] = 1 42 | # mask = Image.fromarray(np.uint8(new_label*255)) 43 | 44 | if not self.test: 45 | ret = self.transform_tr(sample) 46 | else: 47 | ret = self.transform_val(sample) 48 | 49 | new_label = copy.deepcopy(ret['label']) 50 | new_label[ret['label']==0.] = 255. 51 | new_label[ret['label']==3.] = 255. 52 | new_label[ret['label']==2.] = 0. 53 | # class labels 54 | if self.class_label: 55 | new_label[ret['label']==1.] 
*= cls 56 | 57 | return ret 58 | 59 | def transform_tr(self, sample): 60 | composed_transforms = transforms.Compose([ 61 | tr.RandomHorizontalFlip(), 62 | tr.RandomScaleCrop(base_size=256, crop_size=224), 63 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 64 | tr.ToTensor()]) 65 | 66 | return composed_transforms(sample) 67 | 68 | def transform_val(self, sample): 69 | composed_transforms = transforms.Compose([ 70 | tr.FixScaleCrop(crop_size=224), 71 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 72 | tr.ToTensor()]) 73 | 74 | return composed_transforms(sample) 75 | 76 | -------------------------------------------------------------------------------- /downstream/segmentation/modeling/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class _ASPPModule(nn.Module): 9 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 10 | super(_ASPPModule, self).__init__() 11 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 12 | stride=1, padding=padding, dilation=dilation, bias=False) 13 | self.bn = BatchNorm(planes) 14 | self.relu = nn.ReLU() 15 | 16 | self._init_weight() 17 | 18 | def forward(self, x): 19 | x = self.atrous_conv(x) 20 | x = self.bn(x) 21 | 22 | return self.relu(x) 23 | 24 | def _init_weight(self): 25 | for m in self.modules(): 26 | if isinstance(m, nn.Conv2d): 27 | torch.nn.init.kaiming_normal_(m.weight) 28 | elif isinstance(m, SynchronizedBatchNorm2d): 29 | m.weight.data.fill_(1) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.BatchNorm2d): 32 | m.weight.data.fill_(1) 33 | m.bias.data.zero_() 34 | 35 | 36 | class ASPP(nn.Module): 37 | def __init__(self, backbone, output_stride, BatchNorm): 38 | super(ASPP, self).__init__() 39 | if backbone == 'drn': 40 | inplanes = 512 41 | elif backbone == 'mobilenet': 42 | inplanes = 320 43 | elif backbone == 'resnet34': 44 | inplanes = 512 45 | else: 46 | inplanes = 2048 47 | if output_stride == 16: 48 | dilations = [1, 6, 12, 18] 49 | elif output_stride == 8: 50 | dilations = [1, 12, 24, 36] 51 | else: 52 | raise NotImplementedError 53 | 54 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 55 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 56 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 57 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 58 | 59 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 60 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 61 | BatchNorm(256), 62 | nn.ReLU()) 63 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 64 | self.bn1 = BatchNorm(256) 65 | self.relu = nn.ReLU() 66 | self.dropout = nn.Dropout(0.5) 67 | self._init_weight() 68 | 69 | def forward(self, x): 70 | x1 = self.aspp1(x) 71 | x2 = self.aspp2(x) 72 | x3 = self.aspp3(x) 73 | x4 = self.aspp4(x) 74 | x5 = self.global_avg_pool(x) 75 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 76 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 77 | 78 | x = self.conv1(x) 79 | x = self.bn1(x) 80 | x = self.relu(x) 81 | 82 | return self.dropout(x) 83 | 84 | def 
_init_weight(self): 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 89 | torch.nn.init.kaiming_normal_(m.weight) 90 | elif isinstance(m, SynchronizedBatchNorm2d): 91 | m.weight.data.fill_(1) 92 | m.bias.data.zero_() 93 | elif isinstance(m, nn.BatchNorm2d): 94 | m.weight.data.fill_(1) 95 | m.bias.data.zero_() 96 | 97 | 98 | def build_aspp(backbone, output_stride, BatchNorm): 99 | return ASPP(backbone, output_stride, BatchNorm) -------------------------------------------------------------------------------- /downstream/segmentation/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from modeling.backbone import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm, pretrained=''): 4 | if backbone == 'resnet34': 5 | # return resnet.ResNet101(output_stride, BatchNorm) 6 | return resnet.ResNet34(pretrained=pretrained) 7 | elif backbone == 'resnet50': 8 | # return resnet.ResNet101(output_stride, BatchNorm) 9 | return resnet.ResNet50(pretrained=pretrained) 10 | elif backbone == 'xception': 11 | return xception.AlignedXception(output_stride, BatchNorm) 12 | elif backbone == 'drn': 13 | return drn.drn_d_54(BatchNorm) 14 | elif backbone == 'mobilenet': 15 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 16 | else: 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /downstream/segmentation/modeling/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class Decoder(nn.Module): 9 | def __init__(self, num_classes, backbone, BatchNorm): 10 | super(Decoder, self).__init__() 11 | if backbone in ['resnet34', 'drn']: 12 | low_level_inplanes = 64 13 | elif backbone == 'xception': 14 | low_level_inplanes = 128 15 | elif backbone == 'mobilenet': 16 | low_level_inplanes = 24 17 | elif backbone == 'resnet50': 18 | low_level_inplanes = 256 19 | else: 20 | raise NotImplementedError 21 | 22 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 23 | self.bn1 = BatchNorm(48) 24 | self.relu = nn.ReLU() 25 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 26 | BatchNorm(256), 27 | nn.ReLU(), 28 | nn.Dropout(0.5), 29 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 30 | BatchNorm(256), 31 | nn.ReLU(), 32 | nn.Dropout(0.1), 33 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 34 | self._init_weight() 35 | 36 | def forward(self, x, low_level_feat): 37 | low_level_feat = self.conv1(low_level_feat) 38 | low_level_feat = self.bn1(low_level_feat) 39 | low_level_feat = self.relu(low_level_feat) 40 | 41 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 42 | x = torch.cat((x, low_level_feat), dim=1) 43 | x = self.last_conv(x) 44 | 45 | return x 46 | 47 | def _init_weight(self): 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | torch.nn.init.kaiming_normal_(m.weight) 51 | elif isinstance(m, SynchronizedBatchNorm2d): 52 | m.weight.data.fill_(1) 53 | m.bias.data.zero_() 54 | elif isinstance(m, nn.BatchNorm2d): 55 | m.weight.data.fill_(1) 56 | m.bias.data.zero_() 
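# Note on the Decoder above (descriptive comment, not in the original source):
# `conv1`/`bn1` project the backbone's low-level features down to 48 channels, the
# 256-channel ASPP output is bilinearly upsampled to the same spatial size, and the
# two are concatenated to form the 304-channel (256 + 48) input of `last_conv`,
# which predicts per-pixel class logits as in the DeepLabV3+ decoder.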
57 | 58 | 59 | def build_decoder(num_classes, backbone, BatchNorm): 60 | return Decoder(num_classes, backbone, BatchNorm) 61 | -------------------------------------------------------------------------------- /downstream/segmentation/modeling/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | from modeling.aspp import build_aspp 6 | from modeling.decoder import build_decoder 7 | from modeling.backbone import build_backbone 8 | 9 | 10 | class DeepLab(nn.Module): 11 | def __init__(self, backbone='resnet50', output_stride=16, num_classes=21, 12 | sync_bn=True, freeze_bn=False, pretrained_backbone=''): 13 | super(DeepLab, self).__init__() 14 | if backbone == 'drn': 15 | output_stride = 8 16 | 17 | if sync_bn == True: 18 | BatchNorm = SynchronizedBatchNorm2d 19 | else: 20 | BatchNorm = nn.BatchNorm2d 21 | 22 | self.backbone = build_backbone(backbone, output_stride, BatchNorm, pretrained_backbone) 23 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 24 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 25 | 26 | self.freeze_bn = freeze_bn 27 | 28 | def forward(self, input): 29 | x, low_level_feat = self.backbone(input) 30 | x = self.aspp(x) 31 | x = self.decoder(x, low_level_feat) 32 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 33 | 34 | return x 35 | 36 | def freeze_bn(self): 37 | for m in self.modules(): 38 | if isinstance(m, SynchronizedBatchNorm2d): 39 | m.eval() 40 | elif isinstance(m, nn.BatchNorm2d): 41 | m.eval() 42 | 43 | def get_1x_lr_params(self): 44 | modules = [self.backbone] 45 | for i in range(len(modules)): 46 | for m in modules[i].named_modules(): 47 | if self.freeze_bn: 48 | if isinstance(m[1], nn.Conv2d): 49 | for p in m[1].parameters(): 50 | if p.requires_grad: 51 | yield p 52 | else: 53 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 54 | or isinstance(m[1], nn.BatchNorm2d): 55 | for p in m[1].parameters(): 56 | if p.requires_grad: 57 | yield p 58 | 59 | def get_10x_lr_params(self): 60 | modules = [self.aspp, self.decoder] 61 | for i in range(len(modules)): 62 | for m in modules[i].named_modules(): 63 | if self.freeze_bn: 64 | if isinstance(m[1], nn.Conv2d): 65 | for p in m[1].parameters(): 66 | if p.requires_grad: 67 | yield p 68 | else: 69 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 70 | or isinstance(m[1], nn.BatchNorm2d): 71 | for p in m[1].parameters(): 72 | if p.requires_grad: 73 | yield p 74 | 75 | if __name__ == "__main__": 76 | model = DeepLab(backbone='mobilenet', output_stride=16) 77 | model.eval() 78 | input = torch.rand(1, 3, 513, 513) 79 | output = model(input) 80 | print(output.size()) 81 | 82 | 83 | -------------------------------------------------------------------------------- /downstream/segmentation/modeling/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
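# Usage sketch (illustrative, not part of the original file): models containing
# SynchronizedBatchNorm layers are expected to be wrapped with
# DataParallelWithCallback so that batch statistics are aggregated across GPUs, e.g.
#
#   from modeling.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback
#   model = DataParallelWithCallback(model, device_ids=[0, 1])
#
# whereas plain torch.nn.DataParallel would keep per-replica BN statistics.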
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /downstream/segmentation/modeling/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 
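# If the master is still marked active from a previous forward pass, the lines
# below deactivate it and clear the stale registry before this slave is
# registered, so slave pipes are rebuilt for the next replication.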
91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /downstream/segmentation/modeling/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 
52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /downstream/segmentation/modeling/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
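# Illustrative usage (assumed example, not in the original file): test cases can
# subclass TorchTestCase below and compare tensors with a tolerance, e.g.
#
#   class BNTest(TorchTestCase):
#       def test_identity(self):
#           self.assertTensorClose(torch.zeros(4), torch.zeros(4) + 1e-4)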
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /downstream/segmentation/run_segmentation.sh: -------------------------------------------------------------------------------- 1 | # on pets or cub dataset 2 | python train.py \ 3 | --backbone resnet50 \ 4 | --lr 1e-1 \ 5 | --weight-decay 1e-4 \ 6 | --nesterov \ 7 | --epochs 30 \ 8 | --workers 16 \ 9 | --gpu-ids 0,1,2,3 \ 10 | --checkname deeplab-resnet50 \ 11 | --eval-interval 30 \ 12 | --dataset pets \ 13 | --root /path/to/data \ 14 | --pretrained-backbone /path/to/ckpt/last.pth 15 | -------------------------------------------------------------------------------- /downstream/segmentation/util/calculate_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from mypath import Path 5 | 6 | 7 | def calculate_weigths_labels(dataset, dataloader, num_classes): 8 | # Create an instance from the data loader 9 | z = np.zeros((num_classes,)) 10 | # Initialize tqdm 11 | tqdm_batch = tqdm(dataloader) 12 | print('Calculating classes weights') 13 | for sample in tqdm_batch: 14 | y = sample['label'] 15 | y = y.detach().cpu().numpy() 16 | mask = (y >= 0) & (y < num_classes) 17 | labels = y[mask].astype(np.uint8) 18 | count_l = np.bincount(labels, minlength=num_classes) 19 | z += count_l 20 | tqdm_batch.close() 21 | total_frequency = np.sum(z) 22 | class_weights = [] 23 | for frequency in z: 24 | class_weight = 1 / (np.log(1.02 + (frequency / total_frequency))) 25 | class_weights.append(class_weight) 26 | ret = np.array(class_weights) 27 | classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy') 28 | np.save(classes_weights_path, ret) 29 | 30 | return ret -------------------------------------------------------------------------------- /downstream/segmentation/util/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SegmentationLosses(object): 6 | def __init__(self, weight=None, size_average=True, batch_average=True, ignore_index=255, cuda=False): 7 | self.ignore_index = ignore_index 8 | self.weight = weight 9 | self.size_average = size_average 10 | self.batch_average = batch_average 11 | self.cuda = cuda 12 | 13 | def build_loss(self, mode='ce'): 14 | """Choices: ['ce' or 'focal']""" 15 | if mode == 'ce': 16 | return self.CrossEntropyLoss 17 | elif mode == 'focal': 18 | return self.FocalLoss 19 | else: 20 | raise NotImplementedError 21 | 22 | def CrossEntropyLoss(self, logit, target): 23 | n, c, h, w = logit.size() 24 | criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, 25 | reduction='mean') 26 | if self.cuda: 27 | criterion = criterion.cuda() 28 | 29 | loss = criterion(logit, target.long()) 30 | 31 | # if self.batch_average: 32 | # loss /= n 33 | 34 | return loss 35 | 36 | def 
FocalLoss(self, logit, target, gamma=2, alpha=0.5): 37 | n, c, h, w = logit.size() 38 | criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, 39 | size_average=self.size_average) 40 | if self.cuda: 41 | criterion = criterion.cuda() 42 | 43 | logpt = -criterion(logit, target.long()) 44 | pt = torch.exp(logpt) 45 | if alpha is not None: 46 | logpt *= alpha 47 | loss = -((1 - pt) ** gamma) * logpt 48 | 49 | if self.batch_average: 50 | loss /= n 51 | 52 | return loss 53 | 54 | if __name__ == "__main__": 55 | loss = SegmentationLosses(cuda=True) 56 | a = torch.rand(1, 3, 7, 7).cuda() 57 | b = torch.rand(1, 7, 7).cuda() 58 | print(loss.CrossEntropyLoss(a, b).item()) 59 | print(loss.FocalLoss(a, b, gamma=0, alpha=None).item()) 60 | print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item()) 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /downstream/segmentation/util/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import math 12 | 13 | 14 | class LR_Scheduler(object): 15 | """Learning Rate Scheduler 16 | 17 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 18 | 19 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` 20 | 21 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 22 | 23 | Args: 24 | args: 25 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 26 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 27 | :attr:`args.lr_step` 28 | 29 | iters_per_epoch: number of iterations per epoch 30 | """ 31 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, 32 | lr_step=0, warmup_epochs=0): 33 | self.mode = mode 34 | print('Using {} LR Scheduler!'.format(self.mode)) 35 | self.lr = base_lr 36 | if mode == 'step': 37 | assert lr_step 38 | self.lr_step = lr_step 39 | self.iters_per_epoch = iters_per_epoch 40 | self.N = num_epochs * iters_per_epoch 41 | self.epoch = -1 42 | self.warmup_iters = warmup_epochs * iters_per_epoch 43 | 44 | def __call__(self, optimizer, i, epoch, best_pred, best_classIoU): 45 | T = epoch * self.iters_per_epoch + i 46 | if self.mode == 'cos': 47 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 48 | elif self.mode == 'poly': 49 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 50 | elif self.mode == 'step': 51 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 52 | else: 53 | raise NotImplemented 54 | # warm up lr schedule 55 | if self.warmup_iters > 0 and T < self.warmup_iters: 56 | lr = lr * 1.0 * T / self.warmup_iters 57 | if epoch > self.epoch: 58 | print('\n=>Epoches %i, learning rate = %.4f, \ 59 | previous best = %.4f, (class IoU) %s' % (epoch, lr, best_pred, best_classIoU)) 60 | self.epoch = epoch 61 | assert lr >= 0 62 | self._adjust_learning_rate(optimizer, lr) 63 | 64 | def _adjust_learning_rate(self, optimizer, lr): 65 | if len(optimizer.param_groups) == 1: 66 | optimizer.param_groups[0]['lr'] = lr 67 | else: 68 | # enlarge the lr at the head 69 | optimizer.param_groups[0]['lr'] = lr 70 | for i in 
range(1, len(optimizer.param_groups)): 71 | optimizer.param_groups[i]['lr'] = lr * 10 72 | -------------------------------------------------------------------------------- /downstream/segmentation/util/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Evaluator(object): 5 | def __init__(self, num_class): 6 | self.num_class = num_class 7 | self.confusion_matrix = np.zeros((self.num_class,)*2) 8 | 9 | def Pixel_Accuracy(self): 10 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 11 | return Acc 12 | 13 | def Pixel_Accuracy_Class(self): 14 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 15 | Acc = np.nanmean(Acc) 16 | return Acc 17 | 18 | def Precision(self): 19 | prec = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=0) 20 | return prec 21 | 22 | def Recall(self): 23 | rec = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 24 | return rec 25 | 26 | def Intersection_over_Union(self): 27 | IoU = np.diag(self.confusion_matrix) / ( 28 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 29 | np.diag(self.confusion_matrix)) 30 | return IoU 31 | 32 | def Mean_Intersection_over_Union(self): 33 | MIoU = np.diag(self.confusion_matrix) / ( 34 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 35 | np.diag(self.confusion_matrix)) 36 | MIoU = np.nanmean(MIoU) 37 | return MIoU 38 | 39 | def Frequency_Weighted_Intersection_over_Union(self): 40 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 41 | iu = np.diag(self.confusion_matrix) / ( 42 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 43 | np.diag(self.confusion_matrix)) 44 | 45 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 46 | return FWIoU 47 | 48 | def _generate_matrix(self, gt_image, pre_image): 49 | mask = (gt_image >= 0) & (gt_image < self.num_class) 50 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] 51 | count = np.bincount(label, minlength=self.num_class**2) 52 | confusion_matrix = count.reshape(self.num_class, self.num_class) 53 | return confusion_matrix 54 | 55 | def add_batch(self, gt_image, pre_image): 56 | assert gt_image.shape == pre_image.shape 57 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 58 | 59 | def reset(self): 60 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /downstream/segmentation/util/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import glob 5 | from collections import OrderedDict 6 | 7 | 8 | class Saver(object): 9 | def __init__(self, args): 10 | self.args = args 11 | self.directory = os.path.join('run', args.dataset, args.checkname) 12 | self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) 13 | run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0 14 | 15 | self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) 16 | if not os.path.exists(self.experiment_dir): 17 | os.makedirs(self.experiment_dir) 18 | 19 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 20 | """Saves checkpoint to disk""" 21 | filename = os.path.join(self.experiment_dir, filename) 22 | torch.save(state, filename) 
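# Bookkeeping below: when this checkpoint is the best for the current experiment,
# its mIoU is written to best_pred.txt, and the checkpoint is promoted to
# run/<dataset>/<checkname>/model_best.pth.tar only if it also beats the best
# mIoU recorded by earlier experiment_* runs.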
23 | if is_best: 24 | best_pred = state['best_pred'] 25 | with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f: 26 | f.write(str(best_pred)) 27 | if self.runs: 28 | previous_miou = [0.0] 29 | for run in self.runs: 30 | run_id = run.split('_')[-1] 31 | path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt') 32 | if os.path.exists(path): 33 | with open(path, 'r') as f: 34 | miou = float(f.readline()) 35 | previous_miou.append(miou) 36 | else: 37 | continue 38 | max_miou = max(previous_miou) 39 | if best_pred > max_miou: 40 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 41 | else: 42 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 43 | 44 | def save_experiment_config(self): 45 | logfile = os.path.join(self.experiment_dir, 'parameters.txt') 46 | log_file = open(logfile, 'w') 47 | p = OrderedDict() 48 | p['datset'] = self.args.dataset 49 | p['backbone'] = self.args.backbone 50 | p['out_stride'] = self.args.out_stride 51 | p['lr'] = self.args.lr 52 | p['lr_scheduler'] = self.args.lr_scheduler 53 | p['loss_type'] = self.args.loss_type 54 | p['epoch'] = self.args.epochs 55 | p['base_size'] = self.args.base_size 56 | p['crop_size'] = self.args.crop_size 57 | 58 | for key, val in p.items(): 59 | log_file.write(key + ':' + str(val) + '\n') 60 | log_file.close() -------------------------------------------------------------------------------- /downstream/segmentation/util/summaries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.utils import make_grid 4 | from tensorboardX import SummaryWriter 5 | # from dataloaders.utils import decode_seg_map_sequence 6 | 7 | 8 | class TensorboardSummary(object): 9 | def __init__(self, directory): 10 | self.directory = directory 11 | 12 | def create_summary(self): 13 | writer = SummaryWriter(log_dir=os.path.join(self.directory)) 14 | return writer 15 | 16 | # def visualize_image(self, writer, dataset, image, target, output, global_step): 17 | # grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True) 18 | # writer.add_image('Image', grid_image, global_step) 19 | # grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(), 20 | # dataset=dataset), 3, normalize=False, range=(0, 255)) 21 | # writer.add_image('Predicted label', grid_image, global_step) 22 | # grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), 23 | # dataset=dataset), 3, normalize=False, range=(0, 255)) 24 | # writer.add_image('Groundtruth label', grid_image, global_step) 25 | -------------------------------------------------------------------------------- /downstream/semisup/models: -------------------------------------------------------------------------------- 1 | ../../models -------------------------------------------------------------------------------- /downstream/semisup/run_semisup.sh: -------------------------------------------------------------------------------- 1 | # MixMatch template -> for other methods, please change semisup_method and modify related arguments to default values 2 | TAG="" 3 | DATA="" 4 | 5 | python train.py --tag $TAG \ 6 | --cosine \ 7 | --total_step 16384 \ 8 | --eval_step 512 \ 9 | --dataset $DATA \ 10 | --data_folder /path/to/data \ 11 | --method mixmatch \ 12 | --pretrained \ 13 | --pretrained_ckpt /path/to/ckpt \ 14 | --learning_rate 3e-2 \ 15 | --weight_decay 
1e-4 \ 16 | --T 0.5 \ 17 | --mixup_beta 0.75 \ 18 | --lambda_u 75 -------------------------------------------------------------------------------- /downstream/semisup/util/dataset.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import os 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | 10 | 11 | class ImageFolderSemiSup(data.Dataset): 12 | def __init__(self, root='', transform=None, p=1.0, index=None, return_idx=False): 13 | super(ImageFolderSemiSup, self).__init__() 14 | 15 | self.root = root 16 | self.transform = transform 17 | self.p = p 18 | self.return_idx = return_idx 19 | 20 | self.train_dataset = datasets.ImageFolder(root=self.root, transform=self.transform) 21 | if index is None: 22 | # randomly select samples 23 | random.seed(0) 24 | self.dataset_len = int(self.p * len(self.train_dataset)) 25 | self.sampled_index = random.sample(range(len(self.train_dataset)), self.dataset_len) 26 | else: 27 | # samples selected by x_u_split function 28 | self.sampled_index = index 29 | self.dataset_len = len(self.sampled_index) 30 | 31 | def __len__(self): 32 | return self.dataset_len 33 | 34 | def __getitem__(self, index): 35 | if self.return_idx: 36 | return index, self.train_dataset[self.sampled_index[index]] 37 | else: 38 | return self.train_dataset[self.sampled_index[index]] 39 | 40 | 41 | def x_u_split(args, include_lb_to_ulb=False): 42 | """ in TorchSSL, include_lb_to_ulb=True 43 | """ 44 | labeled_idx = random.sample(range(args.n_data), args.num_labeled) 45 | if include_lb_to_ulb: 46 | unlabeled_idx = list(range(args.n_data)) 47 | else: 48 | unlabeled_idx = list(set(range(args.n_data)) - set(labeled_idx)) 49 | 50 | return labeled_idx, unlabeled_idx -------------------------------------------------------------------------------- /downstream/semisup/util/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | np.random.seed(0) 4 | 5 | import torch 6 | import torchvision 7 | from torch import nn 8 | from torchvision.transforms import transforms 9 | 10 | 11 | # for semi-supervised fine-tuning 12 | class MixMatchTransform: 13 | def __init__(self, size, mean, std): 14 | self.transform = transforms.Compose([ 15 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 16 | transforms.RandomHorizontalFlip()]) 17 | self.normalize = transforms.Compose([ 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=mean, std=std)]) 20 | 21 | def __call__(self, x): 22 | x1 = self.transform(x) 23 | x2 = self.transform(x) 24 | return self.normalize(x1), self.normalize(x2) 25 | 26 | 27 | class ReMixMatchTransform: 28 | def __init__(self, size, mean, std): 29 | self.weak = transforms.Compose([ 30 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 31 | transforms.RandomHorizontalFlip()]) 32 | self.strong = transforms.Compose([ 33 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 34 | transforms.RandomHorizontalFlip(), 35 | torchvision.transforms.RandAugment(num_ops=2, magnitude=10)]) # TorchSSL used 3, 5 36 | self.normalize = transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Normalize(mean=mean, std=std)]) 39 | 40 | def __call__(self, x): 41 | weak = self.normalize(self.weak(x)) 42 | strong_1 = self.normalize(self.strong(x)) 43 | strong_2 = self.normalize(self.strong(x)) 44 | 45 | rotate_v_list = [0, 90, 180, 270] 46 | 
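# Rotation self-supervision used by ReMixMatch: one of the four angles below is
# sampled, the first strongly augmented view is rotated by it, and the index of
# that angle is returned as an auxiliary rotation-prediction target.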
rotate_v1 = np.random.choice(rotate_v_list, 1).item() 47 | strong_1_rot = torchvision.transforms.functional.rotate(strong_1, rotate_v1) 48 | return weak, strong_1, strong_2, strong_1_rot, rotate_v_list.index(rotate_v1) 49 | 50 | 51 | class FixMatchTransform: 52 | def __init__(self, size, mean, std): 53 | self.weak = transforms.Compose([ 54 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 55 | transforms.RandomHorizontalFlip()]) 56 | self.strong = transforms.Compose([ 57 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 58 | transforms.RandomHorizontalFlip(), 59 | torchvision.transforms.RandAugment(num_ops=2, magnitude=10)]) 60 | # RandAugmentMC(n=2, m=10)]) # refer to original code in https://github.com/kekmodel/FixMatch-pytorch/blob/master/dataset/randaugment.py 61 | self.normalize = transforms.Compose([ 62 | transforms.ToTensor(), 63 | transforms.Normalize(mean=mean, std=std)]) 64 | 65 | def __call__(self, x): 66 | weak = self.weak(x) 67 | strong = self.strong(x) 68 | return self.normalize(weak), self.normalize(strong) 69 | 70 | 71 | class FlexMatchTransform: 72 | def __init__(self, size, mean, std): 73 | self.weak = transforms.Compose([ 74 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 75 | transforms.RandomHorizontalFlip()]) 76 | self.strong = transforms.Compose([ 77 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 78 | transforms.RandomHorizontalFlip(), 79 | torchvision.transforms.RandAugment(num_ops=2, magnitude=10)]) 80 | self.normalize = transforms.Compose([ 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean=mean, std=std)]) 83 | 84 | def __call__(self, x): 85 | weak = self.weak(x) 86 | strong = self.strong(x) 87 | return self.normalize(weak), self.normalize(strong) 88 | 89 | 90 | _transform_class_map = { 91 | 'mixmatch': MixMatchTransform, 92 | 'remixmatch': ReMixMatchTransform, 93 | 'fixmatch': FixMatchTransform, 94 | 'flexmatch': FlexMatchTransform 95 | } 96 | 97 | def get_semisup_transform_class(key): 98 | if key in _transform_class_map: 99 | return _transform_class_map[key] 100 | else: 101 | raise ValueError('Invalid method: {}'.format(key)) 102 | -------------------------------------------------------------------------------- /downstream/weblysup/models: -------------------------------------------------------------------------------- 1 | ../../models/ -------------------------------------------------------------------------------- /downstream/weblysup/run_coteaching.sh: -------------------------------------------------------------------------------- 1 | TAG="" 2 | DATA="" 3 | 4 | # SimCore pretrained checkpoint 5 | python train.py --tag $TAG \ 6 | --dataset $DATA \ 7 | --data_folder /path/to/data \ 8 | --data_folder2 /path/to/data/WebFG-496 \ 9 | --model resnet50 \ 10 | --method simclr \ 11 | --optimizer sgd \ 12 | --learning_rate 0.03 \ 13 | --weight_decay 1e-4 \ 14 | --cosine \ 15 | --epochs 200 \ 16 | --batch_size 256 \ 17 | --noise_method co_teaching \ 18 | --forget_rate 0.2 \ 19 | --num_gradual 10 \ 20 | --exponent 1 \ 21 | --pretrained \ 22 | --pretrained_ckpt /path/to/ckpt/last.pth 23 | 24 | # train from scratch 25 | python train.py --tag $TAG \ 26 | --dataset $DATA \ 27 | --data_folder /path/to/data \ 28 | --data_folder2 /path/to/data/WebFG-496 \ 29 | --model resnet50 \ 30 | --method simclr \ 31 | --optimizer sgd \ 32 | --learning_rate 0.03 \ 33 | --weight_decay 1e-4 \ 34 | --cosine \ 35 | --epochs 1000 \ 36 | --batch_size 256 \ 37 | --noise_method co_teaching \ 38 | --forget_rate 0.2 \ 39 | --num_gradual 10 \ 40 | 
--exponent 1 41 | -------------------------------------------------------------------------------- /downstream/weblysup/run_dividemix.sh: -------------------------------------------------------------------------------- 1 | TAG="" 2 | DATA="" 3 | 4 | # SimCore pretrained checkpoint 5 | python train.py --tag $TAG \ 6 | --dataset $DATA \ 7 | --data_folder /path/to/data \ 8 | --data_folder2 /path/to/data/WebFG-496 \ 9 | --model resnet50 \ 10 | --method simclr \ 11 | --optimizer sgd \ 12 | --learning_rate 0.03 \ 13 | --weight_decay 1e-4 \ 14 | --cosine \ 15 | --epochs 400 \ 16 | --batch_size 128 \ 17 | --noise_method dividemix \ 18 | --warmup 30 \ 19 | --alpha 0.75 \ 20 | --lambda_u 75 \ 21 | --p_threshold 0.2 \ 22 | --T 0.5 \ 23 | --pretrained \ 24 | --pretrained_ckpt /path/to/ckpt/last.pth 25 | 26 | # train from scratch 27 | python train.py --tag $TAG \ 28 | --dataset $DATA \ 29 | --data_folder /path/to/data \ 30 | --data_folder2 /path/to/data/WebFG-496 \ 31 | --model resnet50 \ 32 | --method simclr \ 33 | --optimizer sgd \ 34 | --learning_rate 0.03 \ 35 | --weight_decay 1e-4 \ 36 | --cosine \ 37 | --epochs 400 \ 38 | --batch_size 128 \ 39 | --noise_method dividemix \ 40 | --warmup 30 \ 41 | --alpha 0.75 \ 42 | --lambda_u 75 \ 43 | --p_threshold 0.2 \ 44 | --T 0.5 45 | -------------------------------------------------------------------------------- /downstream/weblysup/util/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | import torch 7 | import torch.utils.data as data 8 | from .transform import TwoCropTransform 9 | 10 | 11 | class IndexDataset(data.Dataset): 12 | def __init__(self, train_dataset): 13 | super(IndexDataset, self).__init__() 14 | 15 | self.train_dataset = train_dataset 16 | self.dataset_len = len(self.train_dataset) 17 | 18 | def __len__(self): 19 | return self.dataset_len 20 | 21 | def __getitem__(self, index): 22 | image, label = self.train_dataset[index] 23 | return image, label, index 24 | 25 | 26 | class MergeDataset(data.Dataset): 27 | def __init__(self, dataset1, dataset2): 28 | super(MergeDataset, self).__init__() 29 | assert isinstance(dataset1, data.Dataset) and isinstance(dataset2, data.Dataset) 30 | self.dataset1 = dataset1 31 | self.dataset2 = dataset2 32 | self.n1 = len(self.dataset1) 33 | self.n2 = len(self.dataset2) 34 | 35 | def __len__(self): 36 | return self.n1 + self.n2 37 | 38 | def __getitem__(self, idx): 39 | if idx < self.n1: 40 | return self.dataset1[idx] 41 | else: 42 | return self.dataset2[idx - self.n1] 43 | 44 | 45 | class NoisyMergeDataset(data.Dataset): 46 | def __init__(self, dataset1, dataset2, probability, index=None, return_prob=False): 47 | super(NoisyMergeDataset, self).__init__() 48 | self.dataset1 = deepcopy(dataset1) 49 | self.dataset2 = deepcopy(dataset2) 50 | self.n1 = len(self.dataset1) 51 | self.n2 = len(self.dataset2) 52 | 53 | self.dataset1.transform = TwoCropTransform(self.dataset1.transform) 54 | self.dataset2.transform = TwoCropTransform(self.dataset2.transform) 55 | 56 | self.sampled_index = index 57 | self.dataset_len = len(self.sampled_index) 58 | self.probability = probability 59 | self.return_prob = return_prob 60 | 61 | def __len__(self): 62 | return self.dataset_len 63 | 64 | def __getitem__(self, index): 65 | idx = self.sampled_index[index] 66 | if idx < self.n1: 67 | [image1, image2], label = self.dataset1[idx] 68 | else: 69 | [image1, image2], label = self.dataset2[idx - self.n1] 70 | 
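# When return_prob is set, the per-sample probability (e.g. the clean-sample
# posterior estimated by DivideMix.fit_gmm in util/methods.py) is returned along
# with the label; otherwise only the two augmented views are returned.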
71 | if self.return_prob: 72 | return image1, image2, label, self.probability[idx] 73 | else: 74 | return image1, image2 -------------------------------------------------------------------------------- /downstream/weblysup/util/methods.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.mixture import GaussianMixture 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | # Loss functions 10 | class CoTeaching(nn.Module): 11 | def forward(self, y_1, y_2, t, forget_rate): 12 | loss_1 = F.cross_entropy(y_1, t, reduce = False) 13 | ind_1_sorted = torch.argsort(loss_1.data) 14 | loss_1_sorted = loss_1[ind_1_sorted] 15 | 16 | loss_2 = F.cross_entropy(y_2, t, reduce = False) 17 | ind_2_sorted = torch.argsort(loss_2.data) 18 | loss_2_sorted = loss_2[ind_2_sorted] 19 | 20 | remember_rate = 1 - forget_rate 21 | num_remember = int(remember_rate * len(loss_1_sorted)) 22 | 23 | ind_1_update=ind_1_sorted[:num_remember] 24 | ind_2_update=ind_2_sorted[:num_remember] 25 | # exchange 26 | loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update]) 27 | loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update]) 28 | 29 | return loss_1_update, loss_2_update 30 | 31 | 32 | class DivideMix(nn.Module): 33 | """ we assume that WebFG dataset has moderate noise_ratio, based on SimCore sampling ratio, and assym noise_mode. 34 | """ 35 | def __init__(self, lambda_u): 36 | super(DivideMix, self).__init__() 37 | 38 | self.all_loss = [[],[]] # save the history of losses from two networks 39 | self.ce = nn.CrossEntropyLoss(reduction='none') 40 | self.lambda_u = lambda_u 41 | 42 | def fit_gmm(self, train_loader, model, model_idx): 43 | model.eval() 44 | 45 | losses = torch.zeros(len(train_loader.dataset)) 46 | with torch.no_grad(): 47 | for batch_idx, (inputs, targets, index) in enumerate(train_loader): 48 | inputs, targets = inputs.cuda(), targets.cuda() 49 | outputs = model(inputs) 50 | loss = self.ce(outputs, targets) 51 | for b in range(inputs.size(0)): 52 | losses[index[b]] = loss[b] 53 | losses = (losses - losses.min()) / (losses.max() - losses.min()) 54 | self.all_loss[model_idx].append(losses) 55 | 56 | history = torch.stack(self.all_loss[model_idx]) 57 | input_loss = history[-5:].mean(0) 58 | input_loss = input_loss.reshape(-1,1) 59 | # input_loss = losses.reshape(-1,1) 60 | 61 | # fit a two-component GMM to the loss 62 | gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4) 63 | gmm.fit(input_loss) 64 | prob = gmm.predict_proba(input_loss) 65 | prob = prob[:, gmm.means_.argmin()] 66 | 67 | return prob 68 | 69 | def linear_rampup(self, current, warm_up, rampup_length=16): 70 | current = np.clip((current - warm_up) / rampup_length, 0.0, 1.0) 71 | return self.lambda_u * float(current) 72 | 73 | def forward(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up): 74 | probs_u = torch.softmax(outputs_u, dim=1) 75 | 76 | Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1)) 77 | Lu = torch.mean((probs_u - targets_u)**2) 78 | 79 | return Lx, Lu, self.linear_rampup(epoch, warm_up) -------------------------------------------------------------------------------- /downstream/weblysup/util/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import json 5 | import math 6 | import copy 7 | import numpy as np 8 | 9 | import torch 10 | import 
torch.optim as optim 11 | 12 | __all__ = ['AverageMeter', 'AverageClassMeter', 'adjust_lr_wd', 'warmup_learning_rate', 'accuracy', 'update_metric', 'get_best_acc', 'save_model', 'update_json'] 13 | 14 | 15 | class AverageMeter(object): 16 | """Computes and stores the average and current value""" 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | 33 | class AverageClassMeter(object): 34 | def __init__(self, n_cls): 35 | self.meter = [] 36 | for _ in range(n_cls): self.meter.append(AverageMeter()) 37 | self.n_cls = n_cls 38 | 39 | def update(self, cls, val, n=1): 40 | self.meter[cls].update(val, n) 41 | 42 | self.val = self.meter[-1].val 43 | self.avg = torch.tensor(sum([m.avg for m in self.meter]) / self.n_cls) 44 | 45 | 46 | def adjust_lr_wd(args, optimizer, epoch): 47 | lr = args.learning_rate 48 | if args.cosine: 49 | eta_min = lr * (args.lr_decay_rate ** 3) 50 | lr = eta_min + (lr - eta_min) * ( 51 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 52 | else: 53 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 54 | if steps > 0: 55 | lr = lr * (args.lr_decay_rate ** steps) 56 | 57 | wd = args.weight_decay 58 | if args.wd_scheduler: 59 | wd_min = args.weight_decay_end 60 | wd = wd_min + (wd - wd_min) * ( 61 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 62 | 63 | for i, param_group in enumerate(optimizer.param_groups): 64 | param_group['lr'] = lr 65 | if i == 0: # in case of DINO-ViT and MAE-ViT, only wd for regularized params 66 | param_group['weight_decay'] = wd 67 | 68 | 69 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 70 | if args.warm and epoch <= args.warm_epochs: 71 | p = (batch_id + (epoch - 1) * total_batches) / \ 72 | (args.warm_epochs * total_batches) 73 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 74 | 75 | for param_group in optimizer.param_groups: 76 | param_group['lr'] = lr 77 | 78 | 79 | def accuracy(output, target, topk=(1,)): 80 | """Computes the accuracy over the k top predictions for the specified values of k""" 81 | with torch.no_grad(): 82 | n_cls = output.shape[1] 83 | valid_topk = [k for k in topk if k <= n_cls] 84 | 85 | maxk = max(valid_topk) 86 | bsz = target.size(0) 87 | 88 | _, pred = output.topk(maxk, 1, True, True) 89 | pred = pred.t() 90 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 91 | 92 | res = [] 93 | for k in topk: 94 | if k in valid_topk: 95 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 96 | res.append(correct_k.mul_(100.0 / bsz)) 97 | else: res.append(torch.tensor([0.])) 98 | return res, bsz 99 | 100 | 101 | def update_metric(output, labels, top1, top5, args): 102 | if top1.__class__.__name__ == 'AverageMeter': 103 | [acc1, acc5], bsz = accuracy(output, labels, topk=(1, 5)) 104 | top1.update(acc1[0], bsz) 105 | top5.update(acc5[0], bsz) 106 | else: # mean per-class accuracy 107 | for cls in range(args.n_cls): 108 | if not (labels==cls).sum(): continue 109 | [acc1, acc5], bsz = accuracy(output[labels==cls], labels[labels==cls], topk=(1, 5)) 110 | top1.update(cls, acc1[0], bsz) 111 | top5.update(cls, acc5[0], bsz) 112 | return top1, top5 113 | 114 | 115 | def get_best_acc(val_acc1, val_acc5, best_acc): 116 | best = False 117 | if val_acc1.item() > best_acc[0]: 118 | best_acc[0] = val_acc1.item() 119 | best_acc[1] = 
val_acc5.item() 120 | best = True 121 | return best_acc, best 122 | 123 | 124 | def save_model(model, optimizer, args, epoch, save_file, indices=None, classifier=None): 125 | print('==> Saving...') 126 | state = { 127 | 'args': args, 128 | 'model': model.state_dict(), 129 | 'optimizer': optimizer.state_dict(), 130 | 'epoch': epoch, 131 | } 132 | if indices is not None: 133 | state['indices'] = indices 134 | if classifier is not None: # for active learning fine-tuning and open-set semi 135 | state['classifier'] = classifier.state_dict() 136 | torch.save(state, save_file) 137 | del state 138 | 139 | 140 | def update_json(exp_name, acc=[], path='./save/results.json'): 141 | acc = [round(a, 2) for a in acc] 142 | if not os.path.exists(path): 143 | with open(path, 'w') as f: 144 | json.dump({}, f) 145 | 146 | with open(path, 'r', encoding="UTF-8") as f: 147 | result_dict = json.load(f) 148 | result_dict[exp_name] = acc 149 | 150 | with open(path, 'w') as f: 151 | json.dump(result_dict, f) 152 | 153 | print('best accuracy: {} (Acc@1, Acc@5, Train Acc)'.format(acc)) 154 | print('results updated to %s' % path) 155 | -------------------------------------------------------------------------------- /downstream/weblysup/util/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TwoCropTransform: 5 | """ Create two crops of the same image """ 6 | def __init__(self, transform): 7 | self.transform = transform 8 | 9 | def __call__(self, x): 10 | return [self.transform(x), self.transform(x)] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet10, ResNet18, ResNet50, ResNet101 2 | from .resnext import resnext50, resnext101 3 | from .efficientnet import EfficientNet_B0, EfficientNet_B1, EfficientNet_B2 4 | from .dino_vit import dino_vit_tiny, dino_vit_small, dino_vit_base 5 | from .mae_vit import mae_vit_base_patch16_dec512d8b, mae_vit_large_patch16_dec512d8b, mae_vit_huge_patch14_dec512d8b 6 | # from .timm_vit import vit_small_patch16_224, vit_small_patch32_224, vit_base_patch8_224, vit_base_patch16_224, vit_base_patch32_224, vit_large_patch16_224, vit_large_patch32_224 7 | # from .timm_vit import vit_small_patch8_224_dino, vit_small_patch16_224_dino, vit_base_patch8_224_dino, vit_base_patch16_224_dino 8 | 9 | _backbone_class_map = { 10 | 'resnet10': ResNet10, 11 | 'resnet18': ResNet18, 12 | 'resnet50': ResNet50, 13 | 'resnet101': ResNet101, 14 | 'resnext50': resnext50, 15 | 'resnext101': resnext101, 16 | 'efficientnet_b0': EfficientNet_B0, 17 | 'efficientnet_b1': EfficientNet_B1, 18 | 'efficientnet_b2': EfficientNet_B2, 19 | 'dino_vit_t16': dino_vit_tiny, 20 | 'dino_vit_s16': dino_vit_small, 21 | 'dino_vit_b16': dino_vit_base, 22 | 'mae_vit_base': mae_vit_base_patch16_dec512d8b, 23 | 'mae_vit_large': mae_vit_large_patch16_dec512d8b, 24 | 'mae_vit_huge': mae_vit_huge_patch14_dec512d8b 25 | } 26 | 27 | 28 | def get_backbone_class(key): 29 | if key in _backbone_class_map: 30 | return _backbone_class_map[key] 31 | else: 32 | raise ValueError('Invalid backbone: {}'.format(key)) -------------------------------------------------------------------------------- /models/efficientnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import Tensor 4 | import torch.nn as nn 5 | import math 6 | 
import numpy as np 7 | import torch.nn.functional as F 8 | from torch.hub import load_state_dict_from_url 9 | from torchvision.models.efficientnet import _efficientnet_conf 10 | 11 | 12 | # torchvision.__version__ == '0.11.0+cu113' (arguments for _efficientnet_conf are different in latest version) 13 | class EfficientNet_B0(torchvision.models.efficientnet.EfficientNet): 14 | def __init__(self): 15 | inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.0) 16 | super().__init__(inverted_residual_setting, 0.2) 17 | 18 | del self.classifier 19 | self.final_feat_dim = 1280 20 | 21 | def load_sl_official_weights(self, progress=True): 22 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", 23 | progress=progress) 24 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 25 | if len(missing) > 0: 26 | raise AssertionError('Model code may be incorrect') 27 | 28 | def load_ssl_official_weights(self, progress=True): 29 | raise NotImplementedError 30 | 31 | def _forward_impl(self, x: Tensor) -> Tensor: 32 | # See note [TorchScript super()] 33 | x = self.features(x) 34 | 35 | x = self.avgpool(x) 36 | x = torch.flatten(x, 1) 37 | 38 | # x = self.classifier(x) 39 | 40 | return x 41 | 42 | 43 | class EfficientNet_B1(torchvision.models.efficientnet.EfficientNet): 44 | def __init__(self): 45 | inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.1) 46 | super().__init__(inverted_residual_setting, 0.2) 47 | 48 | del self.classifier 49 | self.final_feat_dim = 1280 50 | 51 | def load_sl_official_weights(self, progress=True): 52 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", 53 | progress=progress) 54 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 55 | if len(missing) > 0: 56 | raise AssertionError('Model code may be incorrect') 57 | 58 | def load_ssl_official_weights(self, progress=True): 59 | raise NotImplementedError 60 | 61 | def _forward_impl(self, x: Tensor) -> Tensor: 62 | # See note [TorchScript super()] 63 | x = self.features(x) 64 | 65 | x = self.avgpool(x) 66 | x = torch.flatten(x, 1) 67 | 68 | # x = self.classifier(x) 69 | 70 | return x 71 | 72 | 73 | class EfficientNet_B2(torchvision.models.efficientnet.EfficientNet): 74 | def __init__(self): 75 | inverted_residual_setting = _efficientnet_conf(width_mult=1.1, depth_mult=1.2) 76 | super().__init__(inverted_residual_setting, 0.3) 77 | 78 | del self.classifier 79 | self.final_feat_dim = 1408 80 | 81 | def load_sl_official_weights(self, progress=True): 82 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", 83 | progress=progress) 84 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 85 | if len(missing) > 0: 86 | raise AssertionError('Model code may be incorrect') 87 | 88 | def load_ssl_official_weights(self, progress=True): 89 | raise NotImplementedError 90 | 91 | def _forward_impl(self, x: Tensor) -> Tensor: 92 | # See note [TorchScript super()] 93 | x = self.features(x) 94 | 95 | x = self.avgpool(x) 96 | x = torch.flatten(x, 1) 97 | 98 | # x = self.classifier(x) 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # image recognition 2 | ipdb==0.13.9 3 | numpy==1.22.3 4 | Pillow==9.2.0 5 | scikit-learn 6 |
timm 7 | pycocotools 8 | # if there is error msg of ImportError: cannot import name '_registerMatType' from 'cv2.cv2' 9 | opencv-python-headless<4.3 10 | 11 | # segmentation 12 | tensorboardX 13 | tqdm 14 | 15 | # detection 16 | pandas 17 | albumentations 18 | 19 | # for from_ssl_official install 20 | lightning-bolts 21 | setuptools==59.5.0 22 | 23 | # Pytorch 24 | # torch==1.10.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html 25 | # torchvision==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html -------------------------------------------------------------------------------- /run_selfsup.sh: -------------------------------------------------------------------------------- 1 | # self-supervised learning with coreset sampling 2 | TAG="" 3 | DATA="" 4 | 5 | python train_selfsup.py --tag $TAG \ 6 | --merge_dataset \ 7 | --model resnet50 \ 8 | --batch_size 512 \ 9 | --precision \ 10 | --dataset1 $DATA \ 11 | --dataset2 imagenet \ 12 | --data_folder1 /path/to/data \ 13 | --data_folder2 /path/to/data/ILSVRC2015/ILSVRC2015/Data/CLS-LOC/ \ 14 | --method simclr \ 15 | --epochs 5000 \ 16 | --cosine \ 17 | --optimizer sgd \ 18 | --learning_rate 1e-1 \ 19 | --weight_decay 1e-4 \ 20 | --sampling_method simcore \ 21 | --retrieval_ckpt /path/to/retrieval_ckpt/last.pth \ 22 | --cluster_num 100 \ 23 | --stop \ 24 | --stop_thresh 0.95 25 | 26 | 27 | # vanilla self-supervised learning without sampling 28 | TAG="" 29 | DATA="" 30 | 31 | python train_selfsup.py --tag $TAG \ 32 | --no_sampling \ 33 | --model resnet50 \ 34 | --batch_size 512 \ 35 | --precision \ 36 | --dataset $DATA \ 37 | --data_folder /path/to/data \ 38 | --method simclr \ 39 | --epochs 5000 \ 40 | --cosine \ 41 | --learning_rate 1e-1 \ 42 | --weight_decay 1e-4 43 | -------------------------------------------------------------------------------- /run_sup.sh: -------------------------------------------------------------------------------- 1 | TAG="" 2 | DATA="" 3 | 4 | python train_sup.py --tag $TAG \ 5 | --dataset $DATA \ 6 | --model resnet50 \ 7 | --data_folder /path/to/data \ 8 | --pretrained \ 9 | --pretrained_ckpt /path/to/ckpt/last.pth \ 10 | --method simclr \ 11 | --epochs 100 \ 12 | --learning_rate 10 \ 13 | --weight_decay 0 14 | 15 | # knn (Table 7a) 16 | # --knn \ 17 | # --topk 20 200 18 | 19 | # semisup (Table 7b) 20 | # --label_ratio 0.1 \ 21 | # --e2e -------------------------------------------------------------------------------- /ssl/__init__.py: -------------------------------------------------------------------------------- 1 | from ssl.base import BaseModel 2 | from ssl.simclr import SimCLR 3 | from ssl.moco import MoCo 4 | from ssl.byol import BYOL 5 | from ssl.simsiam import SimSiam 6 | from ssl.swav import SwAV 7 | from ssl.dino import DINO 8 | from ssl.mae import MAE 9 | 10 | _method_class_map = { 11 | 'base': BaseModel, 12 | 'simclr': SimCLR, 13 | 'moco': MoCo, 14 | 'byol': BYOL, 15 | 'simsiam': SimSiam, 16 | 'swav': SwAV, 17 | 'dino': DINO, 18 | 'mae': MAE 19 | } 20 | 21 | 22 | def get_method_class(key): 23 | if key in _method_class_map: 24 | return _method_class_map[key] 25 | else: 26 | raise ValueError('Invalid method: {}'.format(key)) 27 | -------------------------------------------------------------------------------- /ssl/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from argparse import Namespace 3 | from typing import Tuple 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class
BaseModel(nn.Module): 10 | """ 11 | BaseModel subclasses self-contain all modules and losses required for pre-training. 12 | 13 | - supported_feature_selectors: Feature selectors (see `forward_features()`) are used during fine-tuning 14 | to select which features (from which layer the features should be extracted) should be used for downstream 15 | tasks. This class attribute should be set for subclasses to prevent mistakes regarding the feature_selector 16 | argument (see `params.ft_features`). 17 | """ 18 | supported_feature_selectors = [] 19 | 20 | def __init__(self, backbone: nn.Module, params: Namespace): 21 | super().__init__() 22 | self.backbone = backbone 23 | self.params = params 24 | self.classifier = nn.Linear(backbone.final_feat_dim, params.n_cls) 25 | self.classifier.bias.data.fill_(0) 26 | self.cls_loss_function = nn.CrossEntropyLoss() 27 | self.final_feat_dim = backbone.final_feat_dim 28 | 29 | def forward_features(self, x): 30 | """ 31 | You'll likely need to override this method for SSL models. 32 | """ 33 | return self.backbone(x) 34 | 35 | def forward(self, x): 36 | x = self.backbone(x) 37 | x = self.classifier(x) 38 | return x 39 | 40 | def compute_cls_loss_and_accuracy(self, x, y, return_predictions=False) -> Tuple: 41 | scores = self.forward(x) 42 | _, predicted = torch.max(scores.data, 1) 43 | accuracy = predicted.eq(y.data).cpu().sum() / x.shape[0] 44 | if return_predictions: 45 | return self.cls_loss_function(scores, y), accuracy, predicted 46 | else: 47 | return self.cls_loss_function(scores, y), accuracy 48 | 49 | def on_step_start(self): 50 | pass 51 | 52 | def on_step_end(self): 53 | pass 54 | 55 | def on_epoch_start(self, cur_epoch): 56 | pass 57 | 58 | def on_epoch_end(self, cur_epoch, tot_epoch): 59 | pass 60 | 61 | 62 | class BaseSelfSupervisedModel(BaseModel): 63 | @abstractmethod 64 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 65 | """ 66 | If SSL is based on paired input: 67 | By default: x1, x2 represent the input pair. 68 | If x2=None: x1 alone contains the full concatenated input pair. 69 | Else: 70 | x1 contains the input. 
71 | """ 72 | raise NotImplementedError() 73 | 74 | @abstractmethod 75 | def _data_parallel(self): 76 | raise NotImplementedError() 77 | -------------------------------------------------------------------------------- /ssl/mae.py: -------------------------------------------------------------------------------- 1 | ### refer to https://github.com/facebookresearch/mae ### 2 | 3 | import copy 4 | import math 5 | import random 6 | import warnings 7 | from argparse import Namespace 8 | from functools import wraps 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | from ssl.base import BaseSelfSupervisedModel 15 | 16 | 17 | def _get_module_device(module): 18 | return next(module.parameters()).device 19 | 20 | 21 | class MAE(BaseSelfSupervisedModel): 22 | def __init__(self, backbone: nn.Module, params: Namespace): 23 | super().__init__(backbone, params) 24 | 25 | self.norm_pix_loss = params.norm_pix_loss 26 | 27 | backbone.set_mask_ratio(mask_ratio=params.mask_ratio) 28 | self.online_encoder = backbone # for consistency 29 | 30 | # get device of network and make wrapper same device 31 | device = _get_module_device(self.online_encoder) 32 | self.to(device) 33 | 34 | def _data_parallel(self): 35 | self.online_encoder = nn.DataParallel(self.online_encoder) 36 | 37 | def compute_ssl_loss(self, x, _, return_features=False): 38 | pred, target, mask = self.online_encoder(x, pretrain=True) # pred: [N, L, p*p*3] 39 | 40 | if self.norm_pix_loss: 41 | mean = target.mean(dim=-1, keepdim=True) 42 | var = target.var(dim=-1, keepdim=True) 43 | target = (target - mean) / (var + 1.e-6)**.5 44 | 45 | # MSE loss 46 | loss = (pred - target) ** 2 47 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 48 | 49 | # mask: [N, L], 0 is keep, 1 is remove 50 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 51 | 52 | return loss 53 | 54 | def forward_features(self, x): 55 | """ Only used in train_selfsup_sampling.py 56 | """ 57 | output = self.backbone(x, global_pool=True) 58 | return output -------------------------------------------------------------------------------- /ssl/moco.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from argparse import Namespace 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from ssl.base import BaseSelfSupervisedModel 9 | 10 | 11 | class MoCo(BaseSelfSupervisedModel): 12 | def __init__(self, backbone: nn.Module, params: Namespace): 13 | super().__init__(backbone, params) 14 | 15 | dim = 128 16 | mlp = False 17 | self.K = 1024 18 | self.m = 0.999 19 | self.T = 1.0 20 | 21 | self.encoder_q = self.backbone 22 | self.encoder_k = copy.deepcopy(self.backbone) 23 | 24 | if not mlp: 25 | self.projector_q = nn.Linear(self.encoder_q.final_feat_dim, dim) 26 | self.projector_k = nn.Linear(self.encoder_k.final_feat_dim, dim) 27 | else: 28 | mlp_dim = self.encoder_q.feature.final_feat_dim 29 | self.projector_q = nn.Sequential(nn.Linear(mlp_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim)) 30 | self.projector_k = nn.Sequential(nn.Linear(mlp_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim)) 31 | 32 | self.encoder_k.requires_grad_(False) 33 | self.projector_k.requires_grad_(False) 34 | # Just in case (copied from old code) 35 | for param_k in self.encoder_k.parameters(): 36 | param_k.requires_grad = False 37 | for param_k in self.projector_k.parameters(): 38 | param_k.requires_grad = False 39 | 40 | self.register_buffer("queue", 
torch.randn(dim, self.K)) 41 | self.queue = F.normalize(self.queue, dim=0) 42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 43 | 44 | self.ce_loss = nn.CrossEntropyLoss() 45 | 46 | @torch.no_grad() 47 | def _momentum_update_key_encoder(self): 48 | """ 49 | Momentum update of the key encoder 50 | """ 51 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 52 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 53 | for param_q_, param_k_ in zip(self.projector_q.parameters(), self.projector_k.parameters()): 54 | param_k_.data = param_k_.data * self.m + param_q_.data * (1. - self.m) 55 | 56 | @torch.no_grad() 57 | def _dequeue_and_enqueue(self, keys): 58 | batch_size = keys.shape[0] 59 | ptr = int(self.queue_ptr) 60 | assert self.K % batch_size == 0 # for simplicity 61 | 62 | # replace the keys at ptr (dequeue and enqueue) 63 | self.queue[:, ptr:ptr + batch_size] = keys.T 64 | ptr = (ptr + batch_size) % self.K # move pointer 65 | 66 | self.queue_ptr[0] = ptr 67 | 68 | def _data_parallel(self): 69 | self.encoder_q = nn.DataParallel(self.encoder_q) 70 | self.encoder_k = nn.DataParallel(self.encoder_k) 71 | self.projector_q = nn.DataParallel(self.projector_q) 72 | self.projector_k = nn.DataParallel(self.projector_k) 73 | 74 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 75 | if x2 is None: 76 | x = x1 77 | batch_size = int(x.shape[0] / 2) 78 | im_q = x[:batch_size] 79 | im_k = x[batch_size:] 80 | else: 81 | im_q = x1 82 | im_k = x2 83 | 84 | q_features = self.encoder_q(im_q) 85 | q = self.projector_q(q_features) # queries: NxC 86 | q = F.normalize(q, dim=1) 87 | 88 | # compute key features 89 | with torch.no_grad(): # no gradient to keys 90 | self._momentum_update_key_encoder() # update the key encoder 91 | 92 | k_features = self.encoder_k(im_k) 93 | k = self.projector_k(k_features) # keys: NxC 94 | k = F.normalize(k, dim=1) 95 | 96 | # compute logits (Einstein sum is more intuitive) 97 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) # positive logits: Nx1 98 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) # negative logits: NxK 99 | 100 | logits = torch.cat([l_pos, l_neg], dim=1) # logits: Nx(1+K) 101 | logits /= self.T # apply temperature 102 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() # labels: positive key indicators 103 | 104 | self._dequeue_and_enqueue(k) 105 | 106 | loss = self.ce_loss(logits, labels) 107 | 108 | if return_features: 109 | if x2 is None: 110 | return loss, torch.cat([q_features, k_features]) 111 | else: 112 | return loss, q_features, k_features 113 | else: 114 | return loss 115 | -------------------------------------------------------------------------------- /ssl/simclr.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from functools import lru_cache 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | from ssl.base import BaseSelfSupervisedModel 9 | 10 | 11 | class ProjectionHead(nn.Module): 12 | def __init__(self, in_dim, out_dim): 13 | super(ProjectionHead, self).__init__() 14 | self.in_dim = in_dim 15 | self.out_dim = out_dim 16 | 17 | self.fc1 = nn.Linear(in_dim, in_dim) 18 | self.relu = nn.ReLU() 19 | self.fc2 = nn.Linear(in_dim, out_dim) 20 | 21 | def forward(self, x): 22 | return self.fc2(self.relu(self.fc1(x))) 23 | 24 | 25 | class NTXentLoss(nn.Module): 26 | def __init__(self, temperature, use_cosine_similarity): 27 | 
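# NT-Xent loss: the two projected views of each image form the positive pair, the remaining 2N-2 in-batch views act as negatives, and all similarities are scaled by the temperature before the cross-entropy (see forward below).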
super(NTXentLoss, self).__init__() 28 | self.temperature = temperature 29 | self.softmax = torch.nn.Softmax(dim=-1) 30 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 31 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 32 | 33 | def _get_similarity_function(self, use_cosine_similarity): 34 | if use_cosine_similarity: 35 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 36 | return self._cosine_simililarity 37 | else: 38 | return self._dot_simililarity 39 | 40 | @lru_cache(maxsize=4) 41 | def _get_correlated_mask(self, batch_size): 42 | diag = np.eye(2 * batch_size) 43 | l1 = np.eye((2 * batch_size), 2 * 44 | batch_size, k=-batch_size) 45 | l2 = np.eye((2 * batch_size), 2 * 46 | batch_size, k=batch_size) 47 | mask = torch.from_numpy((diag + l1 + l2)) 48 | mask = (1 - mask).type(torch.bool) 49 | return mask 50 | 51 | @staticmethod 52 | def _dot_simililarity(x, y): 53 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 54 | # x shape: (N, 1, C) 55 | # y shape: (1, C, 2N) 56 | # v shape: (N, 2N) 57 | return v 58 | 59 | def _cosine_simililarity(self, x, y): 60 | # x shape: (N, 1, C) 61 | # y shape: (1, 2N, C) 62 | # v shape: (N, 2N) 63 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 64 | return v 65 | 66 | def forward(self, zis, zjs): 67 | batch_size = zis.shape[0] 68 | representations = torch.cat([zjs, zis], dim=0) 69 | representations = torch.nn.functional.normalize(representations, dim=-1) 70 | device = representations.device 71 | 72 | similarity_matrix = self.similarity_function( 73 | representations, representations) 74 | 75 | # filter out the scores from the positive samples 76 | l_pos = torch.diag(similarity_matrix, batch_size) 77 | r_pos = torch.diag(similarity_matrix, -batch_size) 78 | positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1) 79 | 80 | mask = self._get_correlated_mask(batch_size).to(device) 81 | negatives = similarity_matrix[mask].view(2 * batch_size, -1) 82 | 83 | logits = torch.cat((positives, negatives), dim=1) 84 | logits /= self.temperature 85 | 86 | labels = torch.zeros(2 * batch_size).to(device).long() 87 | loss = self.criterion(logits, labels) 88 | 89 | return loss / (2 * batch_size) 90 | 91 | 92 | class SimCLR(BaseSelfSupervisedModel): 93 | def __init__(self, backbone: nn.Module, params: Namespace): 94 | super().__init__(backbone, params) 95 | simclr_projection_dim = 128 96 | simclr_temperature = 0.07 97 | 98 | self.head = ProjectionHead(backbone.final_feat_dim, out_dim=simclr_projection_dim) 99 | self.ssl_loss_fn = NTXentLoss(temperature=simclr_temperature, use_cosine_similarity=False) 100 | self.final_feat_dim = self.backbone.final_feat_dim 101 | 102 | def _data_parallel(self): 103 | self.backbone = nn.DataParallel(self.backbone) 104 | self.head = nn.DataParallel(self.head) 105 | 106 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 107 | if x2 is None: 108 | x = x1 109 | else: 110 | x = torch.cat([x1, x2]) 111 | batch_size = int(x.shape[0] / 2) 112 | 113 | f = self.backbone(x) 114 | f1, f2 = f[:batch_size], f[batch_size:] 115 | p1 = self.head(f1) 116 | p2 = self.head(f2) 117 | loss = self.ssl_loss_fn(p1, p2) 118 | 119 | if return_features: 120 | if x2 is None: 121 | return loss, f 122 | else: 123 | return loss, f1, f2 124 | else: 125 | return loss 126 | -------------------------------------------------------------------------------- /ssl/simsiam.py: -------------------------------------------------------------------------------- 1 | from argparse import 
Namespace 2 | 3 | from torch import nn 4 | 5 | from ssl import BYOL 6 | 7 | 8 | class SimSiam(BYOL): 9 | def __init__(self, backbone: nn.Module, params: Namespace): 10 | super().__init__(backbone, params, use_momentum=False) 11 | -------------------------------------------------------------------------------- /util/imagenet_subset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cited from: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import torch 7 | import torch.utils.data as data 8 | from PIL import Image 9 | from torchvision import transforms as tf 10 | from glob import glob 11 | 12 | 13 | class ImageNetSubset(data.Dataset): 14 | def __init__(self, subset_file, root='', split='train', 15 | transform=None): 16 | super(ImageNetSubset, self).__init__() 17 | 18 | self.root = root 19 | self.transform = transform 20 | self.split = split 21 | 22 | # Read the subset of classes to include (sorted) 23 | with open(subset_file, 'r') as f: 24 | result = f.read().splitlines() 25 | subdirs = [] 26 | for line in result: 27 | subdirs.append(line) 28 | 29 | # Gather the files (sorted) 30 | imgs = [] 31 | for i, subdir in enumerate(subdirs): 32 | subdir_path = os.path.join(self.root, subdir) 33 | files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG'))) 34 | for f in files: 35 | imgs.append((f, i)) 36 | self.imgs = imgs 37 | 38 | # Resize 39 | self.resize = tf.Resize(256) 40 | 41 | def get_image(self, index): 42 | path, target = self.imgs[index] 43 | with open(path, 'rb') as f: 44 | img = Image.open(f).convert('RGB') 45 | img = self.resize(img) 46 | return img 47 | 48 | def __len__(self): 49 | return len(self.imgs) 50 | 51 | def __getitem__(self, index): 52 | path, target = self.imgs[index] 53 | with open(path, 'rb') as f: 54 | img = Image.open(f).convert('RGB') 55 | im_size = img.size 56 | img = self.resize(img) 57 | 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | 61 | return img, target 62 | -------------------------------------------------------------------------------- /util/knn_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import transforms, datasets 6 | 7 | 8 | # refer to https://github.com/Reza-Safdari/SimSiam-91.9-top1-acc-on-CIFAR10/blob/f7365684399b1e6895f81ff059eaa4e2b8c73608/simsiam/validation.py#L9 9 | class KNNValidation(object): 10 | DATASET_CONFIG = {'cars': 196, 'flowers': 102, 'pets': 37, 'aircraft': 100, 'cub': 200, 'dogs': 120, 11 | 'mit67': 67, 'stanford40': 40, 'dtd': 47, 'imagenet100': 100, 'imagenet': 1000} 12 | def __init__(self, args): 13 | self.args = args 14 | 15 | normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 16 | base_transform = transforms.Compose([transforms.Resize(256), 17 | transforms.CenterCrop(224), 18 | transforms.ToTensor(), 19 | normalize]) 20 | 21 | if args.dataset in self.DATASET_CONFIG: 22 | if args.dataset == 'imagenet': 23 | traindir = os.path.join(args.data_folder, 'train') # under ~~/Data/CLS-LOC 24 | valdir = os.path.join(args.data_folder, 'val') 25 | else: 26 | traindir = os.path.join(args.data_folder, args.dataset, 'train') 27 | valdir = os.path.join(args.data_folder, args.dataset, 'test') 28 | train_dataset = datasets.ImageFolder(root=traindir, transform=base_transform) 29 | val_dataset = 
datasets.ImageFolder(root=valdir, transform=base_transform) 30 | else: 31 | raise NotImplementedError 32 | 33 | # shuffle must be False 34 | self.train_loader = torch.utils.data.DataLoader( 35 | train_dataset, batch_size=args.batch_size, shuffle=False, 36 | num_workers=args.num_workers, pin_memory=True) 37 | self.val_loader = torch.utils.data.DataLoader( 38 | val_dataset, batch_size=args.batch_size, shuffle=False, 39 | num_workers=args.num_workers, pin_memory=True) 40 | 41 | def topk_retrieval(self, model): 42 | model.eval() 43 | 44 | n_data = len(self.train_loader.dataset) 45 | train_features = torch.zeros([model.module.final_feat_dim, n_data]) 46 | 47 | top1_list = [] 48 | for topk in self.args.topk: 49 | with torch.no_grad(): 50 | for idx, (images, _) in enumerate(self.train_loader): 51 | images = images.cuda(non_blocking=True) 52 | bsz = images.shape[0] 53 | 54 | # forward 55 | features = model(images) 56 | features = nn.functional.normalize(features) 57 | train_features[:, idx*self.args.batch_size : idx*self.args.batch_size+bsz] = features.data.t() 58 | 59 | train_labels = torch.LongTensor(self.train_loader.dataset.targets) 60 | 61 | total = 0 62 | correct = 0 63 | with torch.no_grad(): 64 | for idx, (images, labels) in enumerate(self.val_loader): 65 | images = images.cuda(non_blocking=True) 66 | # labels = labels.cuda(non_blocking=True) 67 | bsz = images.shape[0] 68 | 69 | features = model(images) 70 | features = nn.functional.normalize(features) 71 | dist = torch.mm(features.cpu(), train_features) 72 | 73 | # top-k 74 | yd, yi = dist.topk(topk, dim=1, largest=True, sorted=True) 75 | candidates = train_labels.view(1, -1).expand(bsz, -1) 76 | retrieval = torch.gather(candidates, 1, yi) 77 | # retrieval = retrieval.narrow(1, 0, 1).clone().view(-1) 78 | 79 | weight = torch.exp(yd / 0.07) # use temperature 0.07 80 | 81 | preds = [] 82 | for i, ret in enumerate(retrieval): 83 | unique = {cls.item(): 0 for cls in ret.unique()} 84 | for r, w in zip(ret, weight[i]): 85 | unique[r.item()] += w.item() 86 | pred, v = sorted(unique.items(), key=lambda item: item[1], reverse=True)[0] 87 | preds.append(pred) 88 | preds = torch.tensor(preds) 89 | 90 | total += labels.size(0) 91 | correct += preds.eq(labels.data).sum().item() 92 | top1 = round(correct / total * 100, 2) 93 | 94 | print(' * {topk}-NN Acc@1 {top1:.2f}'.format(topk=topk, top1=top1)) 95 | top1_list.append(top1) 96 | 97 | return top1_list 98 | -------------------------------------------------------------------------------- /util/merge_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torchvision import datasets 3 | 4 | 5 | class MergeDataset(data.Dataset): 6 | def __init__(self, dataset1, dataset2): 7 | super(MergeDataset, self).__init__() 8 | assert isinstance(dataset1, data.Dataset) and isinstance(dataset2, data.Dataset) 9 | self.dataset1 = dataset1 10 | self.dataset2 = dataset2 11 | self.n1 = len(self.dataset1) 12 | self.n2 = len(self.dataset2) 13 | 14 | def __len__(self): 15 | return self.n1 + self.n2 16 | 17 | def __getitem__(self, idx): 18 | if idx < self.n1: 19 | return self.dataset1[idx] 20 | else: 21 | return self.dataset2[idx - self.n1] 22 | 23 | 24 | class MergeAllDataset(data.Dataset): 25 | def __init__(self, transform): 26 | super(MergeAllDataset, self).__init__() 27 | self.transform = transform 28 | 29 | print('Merging COCO, iNaturalist, ImageNet, and Places365...') 30 | self.dataset1 = datasets.ImageFolder(root='/path/to/coco/train', 31 | transform=self.transform) 32 |
self.dataset2 = datasets.INaturalist(root='/path/to/iNaturalist', 33 | version='2021_train_mini', 34 | transform=self.transform) 35 | self.dataset3 = datasets.ImageFolder(root='/path/to/ILSVRC2015/ILSVRC2015/Data/CLS-LOC/', 36 | transform=self.transform) 37 | self.dataset4 = datasets.ImageFolder(root='/path/to/places/data_256/train', 38 | transform=self.transform) 39 | 40 | self.n1 = len(self.dataset1) 41 | self.n2 = len(self.dataset2) + self.n1 42 | self.n3 = len(self.dataset3) + self.n2 43 | self.n4 = len(self.dataset4) + self.n3 44 | 45 | def __len__(self): 46 | return self.n4 47 | 48 | def __getitem__(self, idx): 49 | if idx < self.n1: 50 | return self.dataset1[idx] 51 | elif idx >= self.n1 and idx < self.n2: 52 | return self.dataset2[idx - self.n1] 53 | elif idx >= self.n2 and idx < self.n3: 54 | return self.dataset3[idx - self.n2] 55 | else: 56 | return self.dataset4[idx - self.n3] 57 | -------------------------------------------------------------------------------- /util/semisup_dataset.py: -------------------------------------------------------------------------------- 1 | from cProfile import label 2 | import os 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | 10 | 11 | class ImageFolderSemiSup(data.Dataset): 12 | def __init__(self, root='', transform=None, p=1.0): 13 | super(ImageFolderSemiSup, self).__init__() 14 | 15 | self.root = root 16 | self.transform = transform 17 | self.p = p 18 | 19 | self.train_dataset = datasets.ImageFolder(root=self.root, transform=self.transform) 20 | 21 | # randomly select samples 22 | random.seed(0) 23 | self.dataset_len = int(self.p * len(self.train_dataset)) 24 | self.sampled_index = random.sample(range(len(self.train_dataset)), self.dataset_len) 25 | 26 | def __len__(self): 27 | return self.dataset_len 28 | 29 | def __getitem__(self, index): 30 | return self.train_dataset[self.sampled_index[index]] -------------------------------------------------------------------------------- /util/subclass/imagenet_sub_aircraft.txt: -------------------------------------------------------------------------------- 1 | n02690373 2 | n04552348 3 | n02692877 4 | n04266014 5 | n02009912 6 | n02002556 7 | n02012849 8 | n02006656 9 | -------------------------------------------------------------------------------- /util/subclass/imagenet_sub_cars.txt: -------------------------------------------------------------------------------- 1 | n02701002 2 | n02814533 3 | n02930766 4 | n03100240 5 | n03594945 6 | n03670208 7 | n03770679 8 | n03777568 9 | n04037443 10 | n04285008 11 | -------------------------------------------------------------------------------- /util/subclass/imagenet_sub_cub.txt: -------------------------------------------------------------------------------- 1 | n01855032 2 | n02027492 3 | n02028035 4 | n02058221 5 | n02051845 6 | n01537544 7 | n01530575 8 | n01531178 9 | n01532829 10 | n01534433 11 | n01560419 12 | n01558993 13 | n01580077 14 | n01582220 15 | n01592084 16 | n01608432 17 | n01833805 18 | n01843065 19 | n01828970 20 | n01806567 21 | -------------------------------------------------------------------------------- /util/subclass/imagenet_sub_pets.txt: -------------------------------------------------------------------------------- 1 | n02124075 2 | n02123394 3 | n02123597 4 | n02123159 5 | n02108089 6 | n02112350 7 | n02088238 8 | n02100735 9 | n02107312 10 | n02085620 11 | n02111500 12 | n02100236 13 | n02088364 14 | n02093256 15 | n02102318 16 | 
n02111277 17 | n02112018 18 | n02111129 19 | n02098105 20 | n02111889 21 | n02097298 22 | n02110958 23 | n02109525 24 | n02094433 25 | -------------------------------------------------------------------------------- /util/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | np.random.seed(0) 4 | 5 | import torch 6 | import torchvision 7 | from torch import nn 8 | from torchvision.transforms import transforms 9 | 10 | 11 | # for self-supervised pretraining 12 | class TwoCropTransform: 13 | """ Create two crops of the same image """ 14 | def __init__(self, transform): 15 | self.transform = transform 16 | 17 | def __call__(self, x): 18 | return [self.transform(x), self.transform(x)] 19 | 20 | 21 | class MultiCropTransform: 22 | """ Create multi crops of the same image """ 23 | def __init__(self, local_crops_number): 24 | # slightly modified transforms (e.g., Crop scale, ColorJitter, Blur, Solarize) 25 | global_crops_scale = (0.2, 1.) 26 | local_crops_scale = (0.05, 0.2) 27 | self.local_crops_number = local_crops_number 28 | 29 | self.global_transform = transforms.Compose([ 30 | transforms.RandomResizedCrop(size=224, scale=global_crops_scale), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8), 33 | transforms.RandomGrayscale(p=0.2), 34 | # transforms.GaussianBlur(kernel_size=(3,3)), # require too much cost 35 | transforms.ToTensor(), 36 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 37 | ]) 38 | self.local_transform = transforms.Compose([ 39 | transforms.RandomResizedCrop(size=96, scale=local_crops_scale), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8), 42 | transforms.RandomGrayscale(p=0.2), 43 | # transforms.GaussianBlur(kernel_size=(3,3)), # require too much cost 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 46 | ]) 47 | 48 | def __call__(self, x): 49 | crops = [] 50 | crops.append(self.global_transform(x)) 51 | crops.append(self.global_transform(x)) 52 | for _ in range(self.local_crops_number): 53 | crops.append(self.local_transform(x)) 54 | return crops 55 | 56 | 57 | class MAETransform: 58 | """ Single-view transform for MAE-ViT """ 59 | def __init__(self, img_size): 60 | self.transform = transforms.Compose([ 61 | transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ToTensor(), 64 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 65 | ]) 66 | 67 | def __call__(self, x): 68 | return self.transform(x) --------------------------------------------------------------------------------
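The snippet below is a minimal usage sketch, not a file from the repository: it shows how `TwoCropTransform` from `util/transform.py` could be paired with a `torchvision` `ImageFolder` to produce the two augmented views that SimCLR/MoCo-style pretraining expects. The base augmentation recipe and the data path are assumptions (the recipe is modeled on the global transform of `MultiCropTransform` above, and the path follows the ImageFolder layout described in `data/README.md`).

```
# Hypothetical sketch: each sample yields two augmented views of the same image.
from torchvision import datasets, transforms
from util.transform import TwoCropTransform

base_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),  # assumed crop scale
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# '/path/to/data/aircraft/train' is a placeholder path.
train_dataset = datasets.ImageFolder(root='/path/to/data/aircraft/train',
                                     transform=TwoCropTransform(base_transform))
view1, view2 = train_dataset[0][0]  # two tensors of shape [3, 224, 224]
```

A `TwoCropTransform`-wrapped dataset can then be fed to a regular `DataLoader`; the two views correspond to the paired input that `compute_ssl_loss(x1, x2)` in `ssl/base.py` describes.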