├── project_utils
│   ├── __init__.py
│   ├── k_means_utils.py
│   ├── pos_embed.py
│   ├── slurm_out_parser.py
│   ├── cluster_and_log_utils.py
│   ├── schedulers.py
│   ├── cluster_utils.py
│   └── general_utils.py
├── assets
│   └── overview.png
├── data
│   ├── ssb_splits
│   │   ├── cub_osr_splits.pkl
│   │   ├── scars_osr_splits.pkl
│   │   ├── aircraft_osr_splits.pkl
│   │   └── herbarium_19_class_splits.pkl
│   ├── data_utils.py
│   ├── augmentations
│   │   ├── cut_out.py
│   │   ├── __init__.py
│   │   └── randaugment.py
│   ├── stanford_cars.py
│   ├── imagenet.py
│   ├── cifar.py
│   ├── get_datasets.py
│   ├── cub.py
│   ├── food.py
│   ├── pets.py
│   ├── flower.py
│   └── fgvc_aircraft.py
├── requirements.txt
├── bash_scripts
│   ├── extract_features.sh
│   ├── estimate_k.sh
│   ├── ssk_means.sh
│   ├── get_subset_len.sh
│   ├── get_kmeans_subset.sh
│   └── representation_learning.sh
├── config.py
├── methods
│   ├── clustering
│   │   ├── feature_vector_dataset.py
│   │   ├── ssk_means.py
│   │   ├── extract_features.py
│   │   └── faster_mix_k_means_pytorch.py
│   ├── partitioning
│   │   ├── subset_len.py
│   │   └── kmeans_subset.py
│   └── estimate_k
│       └── estimate_k.py
├── README.md
└── model
    └── vision_transformer.py

/project_utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/assets/overview.png
--------------------------------------------------------------------------------
/data/ssb_splits/cub_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/data/ssb_splits/cub_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/scars_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/data/ssb_splits/scars_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/aircraft_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/data/ssb_splits/aircraft_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/herbarium_19_class_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/data/ssb_splits/herbarium_19_class_splits.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
easydict==1.9
matplotlib==3.5.2
numpy==1.17.3
pandas==1.0.3
path.py==12.5.0
Pillow==9.2.0
scikit_learn==1.1.1
scipy==1.7.3
six==1.12.0
tensorboard==2.4.0
timm==0.4.12
torch==1.10.0
torchvision==0.11.1
tqdm==4.36.1
--------------------------------------------------------------------------------
/bash_scripts/extract_features.sh:
--------------------------------------------------------------------------------
PYTHON='python'

hostname
nvidia-smi

export CUDA_VISIBLE_DEVICES=0

${PYTHON} -m methods.clustering.extract_features --dataset cub --use_best_model 'True' \
    --warmup_model_dir './XCon_outputs/metric_learn_gcd/log/(02.08.2022_|_25.076)/checkpoints/model.pt'
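Note how extract_features.sh pairs with FeatureVectorDataset (see /methods/clustering/feature_vector_dataset.py below): features are written one tensor per image under feature_root/<label>/<uq_idx>.npy and read back with torch.load. A minimal sketch of that on-disk layout — the helper name and example paths are illustrative, not part of the repo:

```
import os
import torch

def save_features(feats, labels, uq_idxs, feature_root):
    # One torch-serialized tensor per image: feature_root/<label>/<uq_idx>.npy
    # (the .npy suffix matches FeatureVectorDataset, but the files are written
    # with torch.save, not numpy)
    for feat, label, uq_idx in zip(feats, labels, uq_idxs):
        class_dir = os.path.join(feature_root, str(int(label)))
        os.makedirs(class_dir, exist_ok=True)
        torch.save(feat.detach().cpu(), os.path.join(class_dir, f'{int(uq_idx)}.npy'))

# e.g. save_features(torch.randn(4, 768), [0, 0, 1, 1], [10, 11, 12, 13],
#                    './XCon_outputs/extracted_features/cub')
```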
--------------------------------------------------------------------------------
/bash_scripts/estimate_k.sh:
--------------------------------------------------------------------------------
PYTHON='python'

hostname
export CUDA_VISIBLE_DEVICES=0

# Get unique log file
SAVE_DIR=./XCon_outputs/outputs/

EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
EXP_NUM=$((${EXP_NUM}+1))
echo $EXP_NUM

${PYTHON} -m methods.estimate_k.estimate_k --max_classes 1000 --dataset_name cub --search_mode other \
    --warmup_model_exp_id '(02.08.2022_|_25.076)' --use_best_model 'True' \
> ${SAVE_DIR}logfile_${EXP_NUM}.out
--------------------------------------------------------------------------------
/bash_scripts/ssk_means.sh:
--------------------------------------------------------------------------------
PYTHON='python'

hostname
nvidia-smi

export CUDA_VISIBLE_DEVICES=0

# Get unique log file
SAVE_DIR=./XCon_outputs/outputs/

EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
EXP_NUM=$((${EXP_NUM}+1))
echo $EXP_NUM

${PYTHON} -m methods.clustering.ssk_means --dataset 'cub' --semi_sup 'True' --use_ssb_splits 'True' \
    --use_best_model 'True' --max_kmeans_iter 200 --k_means_init 100 --warmup_model_exp_id '(02.08.2022_|_25.076)' \
> ${SAVE_DIR}logfile_${EXP_NUM}.out
--------------------------------------------------------------------------------
/bash_scripts/get_subset_len.sh:
--------------------------------------------------------------------------------
PYTHON='python'

hostname
nvidia-smi

export CUDA_VISIBLE_DEVICES=0

# Get unique log file
SAVE_DIR=./XCon_outputs/outputs/

EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
EXP_NUM=$((${EXP_NUM}+1))
echo $EXP_NUM

${PYTHON} -m methods.partitioning.subset_len \
    --dataset_name 'cub' \
    --batch_size 256 \
    --num_workers 4 \
    --use_ssb_splits 'True' \
    --transform 'imagenet' \
    --eval_funcs 'v2' \
    --experts_num 8
--------------------------------------------------------------------------------
/bash_scripts/get_kmeans_subset.sh:
--------------------------------------------------------------------------------
PYTHON='python'

hostname
nvidia-smi

export CUDA_VISIBLE_DEVICES=0

# Get unique log file
SAVE_DIR=./XCon_outputs/outputs/

EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
EXP_NUM=$((${EXP_NUM}+1))
echo $EXP_NUM

${PYTHON} -m methods.partitioning.kmeans_subset \
    --dataset_name 'cub' \
    --batch_size 256 \
    --grad_from_block 11 \
    --epochs 100 \
    --base_model vit_dino \
    --num_workers 4 \
    --use_ssb_splits 'True' \
    --weight_decay 5e-5 \
    --transform 'imagenet' \
    --lr 0.1 \
    --eval_funcs 'v2' \
    --pretrain_model 'dino' \
    --experts_num 8
--------------------------------------------------------------------------------
/bash_scripts/representation_learning.sh:
--------------------------------------------------------------------------------
PYTHON='python'

hostname
nvidia-smi

export CUDA_VISIBLE_DEVICES=0

# Get unique log file
SAVE_DIR=./XCon_outputs/outputs/

EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
EXP_NUM=$((${EXP_NUM}+1))
echo $EXP_NUM

${PYTHON} -m methods.representation_learning.representation_learning \
    --dataset_name 'cub' \
    --batch_size 256 \
    --grad_from_block 11 \
    --epochs 200 \
    --base_model vit_dino \
    --num_workers 4 \
    --use_ssb_splits 'True' \
    --sup_con_weight 0.35 \
    --weight_decay 5e-5 \
    --contrast_unlabel_only 'False' \
    --transform 'imagenet' \
    --lr 0.1 \
    --eval_funcs 'v2' \
    --val_epoch_size 10 \
    --use_best_model 'True' \
    --use_global_con 'True' \
    --expert_weight 0.1 \
    --max_kmeans_iter 200 \
    --k_means_init 100 \
    --best_new 'False' \
    --pretrain_model 'dino' \
> ${SAVE_DIR}logfile_${EXP_NUM}.out
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# -----------------
# DATASET ROOTS
# -----------------
cifar_10_root = '/data/dataset/cifar10'
cifar_100_root = '/data/dataset/cifar100'
cub_root = '/data/dataset/cub200'
aircraft_root = '/data/user-data/fgvc/aircraft/fgvc-aircraft-2013b'
imagenet_root = '/data/user-data/imagenet'
cars_root = '/data/user-data/fgvc/cars'

pets_root = '/data/user-data/fgvc/pets'
flower_root = '/data/user-data/fgvc/flower102'
food_root = '/data/user-data/fgvc/food-101'

# OSR Split dir
osr_split_dir = './data/ssb_splits'

# -----------------
# PRETRAIN PATHS
# -----------------
dino_pretrain_path = './pretrained_models/dino/dino_vitbase16_pretrain.pth'
moco_pretrain_path = './pretrained_models/moco/vit-b-300ep.pth.tar'
mae_pretrain_path = './pretrained_models/mae/mae_pretrain_vit_base.pth'

# Dataset partitioning paths
km_label_path = './partition_out/km_labels'
subset_len_path = './partition_out/subset_len'

# -----------------
# OTHER PATHS
# -----------------
feature_extract_dir = './XCon_outputs/extracted_features'   # Extract features to this directory
exp_root = './XCon_outputs'                                 # All logs and checkpoints will be saved here
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.utils.data import Dataset


def subsample_instances(dataset, prop_indices_to_subsample=0.8):

    np.random.seed(0)
    subsample_indices = np.random.choice(range(len(dataset)), replace=False,
                                         size=(int(prop_indices_to_subsample * len(dataset)),))

    return subsample_indices


class MergedDataset(Dataset):

    """
    Takes two datasets (labelled_dataset, unlabelled_dataset) and merges them.
    Allows you to iterate over them in parallel.
    """

    def __init__(self, labelled_dataset, unlabelled_dataset):

        self.labelled_dataset = labelled_dataset
        self.unlabelled_dataset = unlabelled_dataset
        self.target_transform = None

    def __getitem__(self, item):

        if item < len(self.labelled_dataset):
            img, label, uq_idx = self.labelled_dataset[item]
            labeled_or_not = 1

        else:
            img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)]
            labeled_or_not = 0

        return img, label, uq_idx, np.array([labeled_or_not])

    def __len__(self):
        return len(self.unlabelled_dataset) + len(self.labelled_dataset)
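A minimal sketch of how MergedDataset indexes into its two halves — indices below len(labelled_dataset) come from the labelled split (flag 1), the rest from the unlabelled split (flag 0). The toy dataset here is illustrative:

```
import numpy as np
from torch.utils.data import Dataset
from data.data_utils import MergedDataset

class ToyDataset(Dataset):
    # Mimics the (img, label, uq_idx) triplet returned by the repo's datasets
    def __init__(self, labels, start_idx=0):
        self.labels = labels
        self.uq_idxs = np.arange(start_idx, start_idx + len(labels))
    def __getitem__(self, item):
        return np.zeros(3), self.labels[item], self.uq_idxs[item]
    def __len__(self):
        return len(self.labels)

merged = MergedDataset(ToyDataset([0, 1]), ToyDataset([5, 6, 7], start_idx=2))
print(len(merged))           # 5 = 2 labelled + 3 unlabelled
_, label, uq_idx, mask = merged[0]
print(label, uq_idx, mask)   # 0 0 [1]  -> labelled instance
_, label, uq_idx, mask = merged[4]
print(label, uq_idx, mask)   # 7 4 [0]  -> unlabelled instance
```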
--------------------------------------------------------------------------------
/data/augmentations/cut_out.py:
--------------------------------------------------------------------------------
"""
https://github.com/hysts/pytorch_cutout
"""

import torch
import numpy as np


def cutout(mask_size, p, cutout_inside, mask_color=(0, 0, 0)):
    mask_size_half = mask_size // 2
    offset = 1 if mask_size % 2 == 0 else 0

    def _cutout(image):
        image = np.asarray(image).copy()

        if np.random.random() > p:
            return image

        h, w = image.shape[:2]

        if cutout_inside:
            cxmin, cxmax = mask_size_half, w + offset - mask_size_half
            cymin, cymax = mask_size_half, h + offset - mask_size_half
        else:
            cxmin, cxmax = 0, w + offset
            cymin, cymax = 0, h + offset

        cx = np.random.randint(cxmin, cxmax)
        cy = np.random.randint(cymin, cymax)
        xmin = cx - mask_size_half
        ymin = cy - mask_size_half
        xmax = xmin + mask_size
        ymax = ymin + mask_size
        xmin = max(0, xmin)
        ymin = max(0, ymin)
        xmax = min(w, xmax)
        ymax = min(h, ymax)
        image[ymin:ymax, xmin:xmax] = mask_color
        return image

    return _cutout


def to_tensor():
    def _to_tensor(image):
        if len(image.shape) == 3:
            return torch.from_numpy(
                image.transpose(2, 0, 1).astype(float))
        else:
            return torch.from_numpy(image[None, :, :].astype(float))

    return _to_tensor


def normalize(mean, std):

    mean = np.array(mean)
    std = np.array(std)

    def _normalize(image):
        image = np.asarray(image).astype(float) / 255.
        image = (image - mean) / std
        return image

    return _normalize
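The returned _cutout closure operates on PIL images or HxWxC numpy arrays. A quick sketch with a toy array (sizes and seed are illustrative):

```
import numpy as np
from data.augmentations.cut_out import cutout

np.random.seed(0)
aug = cutout(mask_size=8, p=1.0, cutout_inside=True)   # always apply, mask fully inside
img = np.full((32, 32, 3), 255, dtype=np.uint8)        # white toy image
out = aug(img)
print(out.shape, (out == 0).any())                     # (32, 32, 3) True -> an 8x8 patch was zeroed
```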
--------------------------------------------------------------------------------
/methods/clustering/feature_vector_dataset.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset

import os
from copy import deepcopy

from data.data_utils import MergedDataset


class FeatureVectorDataset(Dataset):

    def __init__(self, base_dataset, feature_root):

        """
        Dataset which loads feature vectors instead of images
        :param base_dataset: Dataset from which the images would come
        :param feature_root: Root directory of features

        feature_root should be structured as:
        feature_root/class_label/uq_idx.npy (tensors saved with torch.save)
        """

        self.base_dataset = base_dataset
        self.target_transform = deepcopy(base_dataset.target_transform)

        self.base_dataset.target_transform = None
        self.base_dataset.transform = None

        self.feature_root = feature_root

        # If uncommented, the corresponding feature path cannot be found
        # (only when using expert models)
        if isinstance(self.base_dataset, MergedDataset):
            self.base_dataset.labelled_dataset.target_transform = None
            self.base_dataset.unlabelled_dataset.target_transform = None

    def __getitem__(self, item):

        if isinstance(self.base_dataset, MergedDataset):

            # Get meta info for this instance
            _, label, uq_idx, mask_lab = self.base_dataset[item]

            # Load feature vector
            feat_path = os.path.join(self.feature_root, f'{label}', f'{uq_idx}.npy')
            feature_vector = torch.load(feat_path)

            if self.target_transform is not None:
                label = self.target_transform(label)

            return feature_vector, label, uq_idx, mask_lab[0]

        else:

            # Get meta info for this instance
            _, label, uq_idx = self.base_dataset[item]

            # Load feature vector
            feat_path = os.path.join(self.feature_root, f'{label}', f'{uq_idx}.npy')
            feature_vector = torch.load(feat_path)

            if self.target_transform is not None:
                label = self.target_transform(label)

            return feature_vector, label, uq_idx

    def __len__(self):
        return len(self.base_dataset)
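A self-contained usage sketch: pre-extracted features on disk, a stand-in base dataset, and a DataLoader over the resulting vectors. The toy dataset and temp directory are illustrative:

```
import os
import tempfile

import torch
from torch.utils.data import Dataset, DataLoader

from methods.clustering.feature_vector_dataset import FeatureVectorDataset

class ToyImageDataset(Dataset):
    # Stand-in for e.g. CarsDataset: returns (img, label, uq_idx)
    def __init__(self):
        self.transform, self.target_transform = None, None
    def __getitem__(self, item):
        return torch.zeros(3, 224, 224), item % 2, item   # img, label, uq_idx
    def __len__(self):
        return 4

root = tempfile.mkdtemp()
for uq_idx in range(4):                                   # fake pre-extracted features
    os.makedirs(os.path.join(root, str(uq_idx % 2)), exist_ok=True)
    torch.save(torch.randn(768), os.path.join(root, str(uq_idx % 2), f'{uq_idx}.npy'))

loader = DataLoader(FeatureVectorDataset(ToyImageDataset(), root), batch_size=2)
feats, labels, uq_idxs = next(iter(loader))
print(feats.shape)                                        # torch.Size([2, 768])
```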
--------------------------------------------------------------------------------
/project_utils/k_means_utils.py:
--------------------------------------------------------------------------------
import argparse
import os

from torch.utils.data import DataLoader
import numpy as np
from sklearn.cluster import KMeans
import torch
from project_utils.cluster_and_log_utils import log_accs_from_preds

from methods.clustering.faster_mix_k_means_pytorch import K_Means as SemiSupKMeans

from tqdm import tqdm

# TODO: Debug
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


def test_kmeans_semi_sup(model, test_loader, epoch, save_name, args, K=None):

    """
    In this case, the test loader needs to have the labelled and unlabelled subsets of the training data
    """

    if K is None:
        K = args.num_labeled_classes + args.num_unlabeled_classes

    model.eval()
    device = torch.device('cuda:0')

    all_feats = []
    targets = np.array([])
    mask_lab = np.array([])     # From all the data, which instances belong to the labelled set
    mask_cls = np.array([])     # From all the data, which instances belong to Old classes

    print('Collating features...')
    # First extract all features
    for batch_idx, (images, label, _, mask_lab_) in enumerate(tqdm(test_loader)):

        images = images.to(device)

        feats = model(images)

        feats = torch.nn.functional.normalize(feats, dim=-1)

        all_feats.append(feats.cpu().numpy())
        targets = np.append(targets, label.cpu().numpy())
        mask_cls = np.append(mask_cls, np.array([True if x.item() in range(len(args.train_classes))
                                                 else False for x in label]))
        mask_lab = np.append(mask_lab, mask_lab_.cpu().bool().numpy())

    # -----------------------
    # K-MEANS
    # -----------------------
    mask_lab = mask_lab.astype(bool)
    mask_cls = mask_cls.astype(bool)

    all_feats = np.concatenate(all_feats)

    l_feats = all_feats[mask_lab]       # Get labelled set
    u_feats = all_feats[~mask_lab]      # Get unlabelled set
    l_targets = targets[mask_lab]       # Get labelled targets
    u_targets = targets[~mask_lab]      # Get unlabelled targets

    print('Fitting Semi-Supervised K-Means...')
    kmeans = SemiSupKMeans(k=K, tolerance=1e-4, max_iterations=args.max_kmeans_iter, init='k-means++',
                           n_init=args.k_means_init, random_state=None, n_jobs=None, pairwise_batch_size=1024, mode=None)

    l_feats, u_feats, l_targets, u_targets = (torch.from_numpy(x).to(device) for
                                              x in (l_feats, u_feats, l_targets, u_targets))

    kmeans.fit_mix(u_feats, l_feats, l_targets)
    all_preds = kmeans.labels_.cpu().numpy()
    u_targets = u_targets.cpu().numpy()

    # -----------------------
    # EVALUATE
    # -----------------------
    # Get preds corresponding to unlabelled set
    preds = all_preds[~mask_lab]

    # Get portion of mask_cls which corresponds to the unlabelled set
    mask = mask_cls[~mask_lab]
    mask = mask.astype(bool)

    all_acc, old_acc, new_acc = log_accs_from_preds(y_true=u_targets, y_pred=preds, mask=mask, eval_funcs=args.eval_funcs,
                                                    save_name='SS-K-Means Train ACC Unlabelled', T=epoch, print_output=True)

    return all_acc, old_acc, new_acc, kmeans
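The two masks are easy to mix up: mask_lab marks which instances are labelled, mask_cls marks which instances belong to 'Old' (seen) classes, and the evaluation only looks at the unlabelled portion. A toy illustration with plain numpy:

```
import numpy as np

targets  = np.array([0, 1, 2, 0, 3])               # ground-truth classes; train_classes = {0, 1}
mask_lab = np.array([1, 1, 0, 0, 0], dtype=bool)   # first two instances are labelled
mask_cls = np.isin(targets, [0, 1])                # Old-class instances (labelled or not)

u_targets = targets[~mask_lab]                     # only unlabelled instances are evaluated
mask      = mask_cls[~mask_lab]                    # Old vs New split within the unlabelled set
print(u_targets, mask)                             # [2 0 3] [False  True False]
```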
--------------------------------------------------------------------------------
/project_utils/pos_embed.py:
--------------------------------------------------------------------------------
import numpy as np

import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float is deprecated; use an explicit dtype
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
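A quick shape check of the sine-cosine embedding, matching a ViT-B/16 at 224x224 (14x14 patch grid, 768-dim):

```
from project_utils.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)   # (197, 768) -> 1 cls token + 14*14 patch positions
```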
--------------------------------------------------------------------------------
/project_utils/slurm_out_parser.py:
--------------------------------------------------------------------------------
import re
import pandas as pd
import os

pd.options.display.width = 0

rx_dict = {
    'model_dir': re.compile(r'model_dir=\'(.*?)\''),
    'dataset': re.compile(r'dataset_name=\'(.*?)\''),
    'm': re.compile(r'rand_aug_m=([-+]?\d*)'),
    'n': re.compile(r'rand_aug_n=([-+]?\d*)'),
    'wd': re.compile(r"weight_decay=(.*?),"),
    'sc_weight': re.compile(r"sup_con_weight=(.*?),"),
    'split_idx': re.compile(r'split_idx=(\d)'),
    'epochs': re.compile(r'Train Epoch: (\d)'),
    'part_loss_mode': re.compile(r'part_loss_mode=(\d)'),
    'consistency_weight': re.compile(r'consistency_weight=(.*?),'),
    'lr': re.compile(r"lr=([-+]?\d*\.\d+|\d+)"),
    'Train Accs': re.compile(r"Train Accuracies: ([-+]?\d*\.\d+|\d+)")
}

save_root_dir = '/work/sagar/open_set_recognition/sweep_summary_files/ensemble_pkls'


def get_file(path):

    file = []
    with open(path, 'rt') as myfile:
        for myline in myfile:   # Read each line into a string
            file.append(myline)

    return file


def parse_out_file(path, rx_dict, root_dir=save_root_dir, save_name='test.pkl', save=True, verbose=True):

    file = get_file(path=path)
    for i, line in enumerate(file):
        if line.find('Namespace') != -1:

            model = {}
            s = rx_dict['model_dir'].search(line).group(1)
            exp_id = s[s.find("("):s.find(")") + 1]
            model['exp_id'] = exp_id
            model['dataset'] = rx_dict['dataset'].search(line).group(1)
            model['lr'] = rx_dict['lr'].search(line).group(1)

            break

    reverse_file = file[::-1]
    for i, line in enumerate(reverse_file):
        if line.find('Train Accuracies') != -1:

            model['Train Mean'] = re.findall(r"\d+\.\d+", line)[0]
            model['Train Old'] = re.findall(r"\d+\.\d+", line)[1]
            model['Train New'] = re.findall(r"\d+\.\d+", line)[2]

            for i_, line_ in enumerate(reverse_file[i:]):
                if 'Train Epoch' in line_:
                    model['Last Epoch'] = re.findall(r'Train Epoch: (\d+)', line_)[0]
                    break

            break

    for i, line in enumerate(reverse_file):
        if line.find('Best Train Accuracies') != -1:

            model['Best Train Mean'] = re.findall(r"\d+\.\d+", line)[0]
            model['Best Train Old'] = re.findall(r"\d+\.\d+", line)[1]
            model['Best Train New'] = re.findall(r"\d+\.\d+", line)[2]

            for i_, line_ in enumerate(reverse_file[i:]):
                if 'Train Epoch' in line_:
                    model['Best Epoch'] = re.findall(r'Train Epoch: (\d+)', line_)[0]
                    break

            break

    data = pd.DataFrame([model])

    if verbose:
        print(data)

    if save:

        save_path = os.path.join(root_dir, save_name)
        data.to_pickle(save_path)

    else:

        return data


def parse_multiple_files(all_paths, rx_dict, root_dir=save_root_dir, save_name='test.pkl', verbose=True, save=False):

    all_data = []
    for path in all_paths:

        data = parse_out_file(path, rx_dict, save=False, verbose=False)
        all_data.append(data)

    all_data = pd.concat(all_data)
    save_path = os.path.join(root_dir, save_name)

    if save:
        all_data.to_pickle(save_path)

    if verbose:
        print(all_data)

    return all_data


save_dir = '/work/sagar/open_set_recognition/sweep_summary_files/ensemble_pkls'
base_path = '/work/sagar/osr_novel_categories/slurm_outputs/myLog-{}.out'

all_paths = [base_path.format(i) for i in ['407814_{}'.format(j) for j in range(11)]]

data = parse_multiple_files(all_paths, rx_dict, verbose=True, save=False, save_name='test.pkl')
--------------------------------------------------------------------------------
/project_utils/cluster_and_log_utils.py:
--------------------------------------------------------------------------------
from project_utils.cluster_utils import cluster_acc, np, linear_assignment
from torch.utils.tensorboard import SummaryWriter
from typing import List


def split_cluster_acc_v1(y_true, y_pred, mask):

    """
    Evaluate clustering metrics on two subsets of data, as defined by the mask 'mask'
    (mask usually corresponding to `Old' and `New' classes in the GCD setting)
    :param y_true: All ground truth labels
    :param y_pred: All predictions
    :param mask: Mask defining the two subsets
    :return:
    """

    mask = mask.astype(bool)
    y_true = y_true.astype(int)
    y_pred = y_pred.astype(int)
    weight = mask.mean()

    old_acc = cluster_acc(y_true[mask], y_pred[mask])
    new_acc = cluster_acc(y_true[~mask], y_pred[~mask])
    total_acc = weight * old_acc + (1 - weight) * new_acc

    return total_acc, old_acc, new_acc


def split_cluster_acc_v2(y_true, y_pred, mask):
    """
    Calculate clustering accuracy. Requires scipy (for the linear assignment).
    First compute the linear assignment on all data, then look at how good the accuracy is on the subsets.

    # Arguments
        mask: Which instances come from old classes (True) and which ones come from new classes (False)
        y_true: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`

    # Return
        accuracy, in [0,1]
    """
    y_true = y_true.astype(int)

    old_classes_gt = set(y_true[mask])
    new_classes_gt = set(y_true[~mask])

    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1

    ind = linear_assignment(w.max() - w)
    ind = np.vstack(ind).T

    ind_map = {j: i for i, j in ind}
    total_acc = sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size

    old_acc = 0
    total_old_instances = 0
    for i in old_classes_gt:
        old_acc += w[ind_map[i], i]
        total_old_instances += sum(w[:, i])
    old_acc /= total_old_instances

    new_acc = 0
    total_new_instances = 0
    for i in new_classes_gt:
        new_acc += w[ind_map[i], i]
        total_new_instances += sum(w[:, i])
    new_acc /= total_new_instances

    return total_acc, old_acc, new_acc


EVAL_FUNCS = {
    'v1': split_cluster_acc_v1,
    'v2': split_cluster_acc_v2,
}


def log_accs_from_preds(y_true, y_pred, mask, eval_funcs: List[str], save_name: str, T: int = None,
                        writer: SummaryWriter = None, print_output=False):

    """
    Given a list of evaluation functions to use (e.g. ['v1', 'v2']), evaluate and log ACC results

    :param y_true: GT labels
    :param y_pred: Predicted indices
    :param mask: Which instances belong to Old and New classes
    :param T: Epoch
    :param eval_funcs: Which evaluation functions to use
    :param save_name: What we are evaluating ACC on
    :param writer: Tensorboard logger
    :return:
    """

    mask = mask.astype(bool)
    y_true = y_true.astype(int)
    y_pred = y_pred.astype(int)

    for i, f_name in enumerate(eval_funcs):

        acc_f = EVAL_FUNCS[f_name]
        all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask)
        log_name = f'{save_name}_{f_name}'

        if writer is not None:
            writer.add_scalars(log_name,
                               {'Old': old_acc, 'New': new_acc,
                                'All': all_acc}, T)

        if i == 0:
            to_return = (all_acc, old_acc, new_acc)

        if print_output:
            print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}'
            print(print_str)

    return to_return
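A tiny worked example: with a permuted-but-consistent prediction, the Hungarian matching in split_cluster_acc_v2 recovers 100% accuracy, where naive label equality would score 0:

```
import numpy as np
from project_utils.cluster_and_log_utils import split_cluster_acc_v2

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([2, 2, 0, 0, 1, 1])        # same clustering, permuted cluster ids
mask   = np.array([True, True, True, True, False, False])  # classes 0,1 are Old; 2 is New

all_acc, old_acc, new_acc = split_cluster_acc_v2(y_true, y_pred, mask)
print(all_acc, old_acc, new_acc)             # 1.0 1.0 1.0
```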
--------------------------------------------------------------------------------
/project_utils/schedulers.py:
--------------------------------------------------------------------------------
import torch
import math
from torch.optim.lr_scheduler import _LRScheduler


def get_scheduler(optimizer, args):

    if args.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    gamma=0.1,
                                                    step_size=150)

    elif args.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=50)

    elif args.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=args.epochs, eta_min=args.lr * 1e-3)

    elif args.scheduler == 'cosine_warm_restarts':

        try:
            num_restarts = args.num_restarts
        except AttributeError:
            print('Warning: Num restarts not specified...using 2')
            num_restarts = 2

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                                         T_0=int(args.epochs / (num_restarts + 1)),
                                                                         eta_min=args.lr * 1e-3)

    elif args.scheduler == 'cosine_warm_restarts_warmup':

        try:
            num_restarts = args.num_restarts
        except AttributeError:
            print('Warning: Num restarts not specified...using 2')
            num_restarts = 2

        scheduler = CosineAnnealingWarmupRestarts_New(warmup_epochs=10, optimizer=optimizer,
                                                      T_0=int(args.epochs / (num_restarts + 1)),
                                                      eta_min=args.lr * 1e-3)

    elif args.scheduler == 'warm_restarts_plateau':
        scheduler = WarmRestartPlateau(T_restart=120, optimizer=optimizer, threshold_mode='abs', threshold=0.5,
                                       mode='min', patience=100)

    elif args.scheduler == 'multi_step':

        try:
            steps = args.steps
        except AttributeError:
            print('Warning: No step list for Multi-Step Scheduler, using constant step of 30 epochs')
            steps = [30 * i for i in range(1, 5)]

        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps)

    else:

        raise NotImplementedError

    return scheduler


class WarmRestartPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):

    """
    Reduce learning rate on plateau and reset every T_restart epochs
    """

    def __init__(self, T_restart, *args, **kwargs):

        super().__init__(*args, **kwargs)

        self.T_restart = T_restart
        self.base_lrs = [group['lr'] for group in self.optimizer.param_groups]

    def step(self, *args, **kwargs):

        super().step(*args, **kwargs)

        if self.last_epoch > 0 and self.last_epoch % self.T_restart == 0:

            for group, lr in zip(self.optimizer.param_groups, self.base_lrs):
                group['lr'] = lr

            self._reset()


class CosineAnnealingWarmupRestarts_New(torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):

    def __init__(self, warmup_epochs, *args, **kwargs):

        super(CosineAnnealingWarmupRestarts_New, self).__init__(*args, **kwargs)

        # Init optimizer with low learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.eta_min

        self.warmup_epochs = warmup_epochs

        # Get target LR after warmup is complete
        target_lr = self.eta_min + (self.base_lrs[0] - self.eta_min) * (1 + math.cos(math.pi * warmup_epochs / self.T_i)) / 2

        # Linearly interpolate between minimum lr and target_lr
        linear_step = (target_lr - self.eta_min) / self.warmup_epochs
        self.warmup_lrs = [self.eta_min + linear_step * (n + 1) for n in range(warmup_epochs)]

    def step(self, epoch=None):

        # Called on super class init
        if epoch is None:
            super(CosineAnnealingWarmupRestarts_New, self).step(epoch=epoch)

        else:
            if epoch < self.warmup_epochs:
                lr = self.warmup_lrs[epoch]
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr

                # Fulfil misc super() bookkeeping
                self.last_epoch = math.floor(epoch)
                self.T_cur = epoch
                self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

            else:

                super(CosineAnnealingWarmupRestarts_New, self).step(epoch=epoch)
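A minimal sketch of driving get_scheduler with an argparse-style namespace (the values are illustrative):

```
import argparse

import torch
from project_utils.schedulers import get_scheduler

args = argparse.Namespace(scheduler='cosine', epochs=200, lr=0.1)
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
scheduler = get_scheduler(optimizer, args)

for epoch in range(3):
    optimizer.step()       # (training step would go here)
    scheduler.step()       # anneals from 0.1 towards lr * 1e-3
    print(optimizer.param_groups[0]['lr'])
```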
--------------------------------------------------------------------------------
/data/augmentations/__init__.py:
--------------------------------------------------------------------------------
from torchvision import transforms

# NOTE: the star import below also brings torch and np into this namespace
from data.augmentations.cut_out import *
from data.augmentations.randaugment import RandAugment


def get_transform(transform_type='default', image_size=32, args=None):

    if transform_type == 'imagenet':

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        interpolation = args.interpolation
        crop_pct = args.crop_pct

        train_transform = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

        test_transform = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

    elif transform_type == 'pytorch-cifar':

        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)

        train_transform = transforms.Compose([
            transforms.RandomCrop(image_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        test_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    elif transform_type == 'herbarium_default':

        train_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomResizedCrop(image_size, scale=(args.resize_lower_bound, 1)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        test_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    elif transform_type == 'cutout':

        mean = np.array([0.4914, 0.4822, 0.4465])
        std = np.array([0.2470, 0.2435, 0.2616])

        train_transform = transforms.Compose([
            transforms.RandomCrop(image_size, padding=4),
            transforms.RandomHorizontalFlip(),
            normalize(mean, std),
            cutout(mask_size=int(image_size / 2),
                   p=1,
                   cutout_inside=False),
            to_tensor(),
        ])
        test_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    elif transform_type == 'rand-augment':

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        train_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomCrop(image_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        train_transform.transforms.insert(0, RandAugment(args.rand_aug_n, args.rand_aug_m, args=None))

        test_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    elif transform_type == 'random_affine':

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        interpolation = args.interpolation
        crop_pct = args.crop_pct

        train_transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation),
            transforms.RandomAffine(degrees=(-45, 45),
                                    translate=(0.1, 0.1), shear=(-15, 15), scale=(0.7, args.crop_pct)),
            transforms.ColorJitter(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

        test_transform = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

    else:

        raise NotImplementedError

    return (train_transform, test_transform)
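A minimal sketch of calling get_transform — the 'imagenet' branch reads interpolation and crop_pct off the args namespace, matching how subset_len.py sets them:

```
import argparse

from data.augmentations import get_transform

args = argparse.Namespace(interpolation=3, crop_pct=0.875)   # 3 = PIL bicubic
train_tf, test_tf = get_transform('imagenet', image_size=224, args=args)
print(train_tf)   # Resize(256) -> RandomCrop(224) -> flip -> jitter -> tensor -> normalize
```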
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# XCon: Learning with Experts for Fine-grained Category Discovery
This repo contains the implementation of our paper: "XCon: Learning with Experts for Fine-grained Category Discovery". ([arXiv](https://arxiv.org/abs/2208.01898))
## Abstract
We address the problem of generalized category discovery (GCD) in this paper, i.e. clustering the unlabeled images leveraging the information from a set of seen classes, where the unlabeled images could contain both seen classes and unseen classes. The seen classes can be viewed as an implicit criterion of classes, which makes this setting different from unsupervised clustering, where the cluster criteria may be ambiguous. We mainly focus on the problem of discovering categories within a fine-grained dataset, since it is one of the most direct applications of category discovery, i.e. helping experts discover novel concepts within an unlabeled dataset using the implicit criterion set forth by the seen classes. State-of-the-art methods for generalized category discovery leverage contrastive learning to learn the representations, but the large inter-class similarity and intra-class variance pose a challenge for these methods, because the negative examples may contain irrelevant cues for recognizing a category, so the algorithms may converge to a local minimum. We present a novel method called Expert-Contrastive Learning (XCon) to help the model mine useful information from the images by first partitioning the dataset into sub-datasets using k-means clustering and then performing contrastive learning on each of the sub-datasets to learn fine-grained discriminative features. Experiments on fine-grained datasets show a clear improvement in performance over the previous best methods, indicating the effectiveness of our method.

![image](https://github.com/YiXXin/XCon/blob/master/assets/overview.png)

## Requirements
- Python 3.8
- PyTorch 1.10.0
- torchvision 0.11.1
```
pip install -r requirements.txt
```

## Datasets
In our experiments, we use generic image classification datasets including [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html) and [ImageNet](https://image-net.org/download.php).

We also use fine-grained image classification datasets including [CUB-200](https://www.kaggle.com/datasets/coolerextreme/cub-200-2011/versions/1), [Stanford-Cars](http://ai.stanford.edu/~jkrause/cars/car_dataset.html), [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) and [Oxford-Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/).

## Pretrained Checkpoints
Our model is initialized with the parameters pretrained by DINO on ImageNet.
The DINO checkpoint of ViT-B-16 is available [here](https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth).

## Training and Evaluation Instructions
### Step 1. Set config
Set the paths to the datasets and the directory for saving outputs in ```config.py```.
### Step 2. Dataset partitioning
- Get the k-means labels for partitioning the dataset.
```
bash bash_scripts/get_kmeans_subset.sh
```
- Get the lengths of the k expert sub-datasets.
```
bash bash_scripts/get_subset_len.sh
```
### Step 3. Representation learning
Fine-tune the model with the evaluation of semi-supervised k-means.
```
bash bash_scripts/representation_learning.sh
```
### Step 4. Semi-supervised k-means
To run the semi-supervised k-means alone, first run
```
bash bash_scripts/extract_features.sh
```
and then run
```
bash bash_scripts/ssk_means.sh
```
### Step 5. Estimate the number of classes
To estimate the number of classes in the unlabeled dataset, first run
```
bash bash_scripts/extract_features.sh
```
and then run
```
bash bash_scripts/estimate_k.sh
```

## Results
Results of our method are reported below. You can download our model checkpoints via the links in the table.

| **Datasets** | **All** | **Old** | **New** | **Models** |
|:------------|:--------:|:---------:|:---------:|:------:|
| CIFAR10 | 96.0 | 97.3 | 95.4 | [ckpt](https://pan.baidu.com/s/1XKHioJp002Lm7P1xmM5Htg?pwd=xhwq) |
| CIFAR100 | 74.2 | 81.2 | 60.3 | [ckpt](https://pan.baidu.com/s/1DbUpDpFj-dlO58w6GqhyKw?pwd=rvkd) |
| ImageNet-100 | 77.6 | 93.5 | 69.7 | [ckpt](https://pan.baidu.com/s/1G1mY85up1ji2LLxMNBJjrw?pwd=rc7o) |
| CUB-200 | 52.1 | 54.3 | 51.0 | [ckpt](https://pan.baidu.com/s/1gtuPMF-itQvt9r5kW7Y32Q?pwd=pg9m) |
| Stanford-Cars | 40.5 | 58.8 | 31.7 | [ckpt](https://pan.baidu.com/s/1PDVhatM6qVUZZwjBVwSgTg?pwd=6337) |
| FGVC-Aircraft | 47.7 | 44.4 | 49.4 | [ckpt](https://pan.baidu.com/s/1SwkobAaT8l-TTlYn7IXWyQ?pwd=06u1) |
| Oxford-Pet | 86.7 | 91.5 | 84.1 | [ckpt](https://pan.baidu.com/s/1kCUfebbKmws9EgYrvgF5Aw?pwd=ck3k) |

## Citation
If you find this repo useful for your research, please consider citing our paper:
```
@inproceedings{fei2022xcon,
    title = {XCon: Learning with Experts for Fine-grained Category Discovery},
    author = {Yixin Fei and Zhongkai Zhao and Siwei Yang and Bingchen Zhao},
    booktitle={British Machine Vision Conference (BMVC)},
    year = {2022}
}
```
--------------------------------------------------------------------------------
/methods/partitioning/subset_len.py:
--------------------------------------------------------------------------------
import argparse
import os
from easydict import EasyDict

from torch.utils.data import DataLoader
import numpy as np
from sklearn.cluster import KMeans
import torch
from torch.optim import SGD, lr_scheduler
from project_utils.cluster_utils import mixed_eval, AverageMeter
from model import vision_transformer as vits

from project_utils.general_utils import init_experiment, get_mean_lr, str2bool, get_dino_head_weights

from data.augmentations import get_transform
from data.get_datasets import get_datasets, get_class_splits
from data.data_utils import MergedDataset

from tqdm import tqdm

from torch.nn import functional as F

from project_utils.cluster_and_log_utils import log_accs_from_preds
from config import exp_root, km_label_path, subset_len_path, dino_pretrain_path

from copy import deepcopy

# TODO: Debug
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


class ContrastiveLearningViewGenerator(object):
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transform(x) for i in range(self.n_views)]


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description='cluster',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v1', 'v2'])

    parser.add_argument('--exp_root', type=str, default=exp_root)
    parser.add_argument('--transform', type=str, default='imagenet')
    parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, scars')
    parser.add_argument('--prop_train_labels', type=float, default=0.5)
    parser.add_argument('--use_ssb_splits', type=str2bool, default=False)
    parser.add_argument('--n_views', default=2, type=int)

    parser.add_argument('--experts_num', type=int, default=8)

    args = parser.parse_args()
    device = torch.device('cuda:0')
    args = get_class_splits(args)

    args.num_labeled_classes = len(args.train_classes)
    args.num_unlabeled_classes = len(args.unlabeled_classes)

    init_experiment(args, runner_name=['metric_learn_gcd'])
    print(f'Using evaluation function {args.eval_funcs[0]} to print results')

    whole_km_labels = np.load(f'{km_label_path}/{args.dataset_name}_km_labels.npy')
    experts_num = np.unique(whole_km_labels)
    print('subset_num:', experts_num)
    if len(experts_num) != args.experts_num:
        raise NotImplementedError

    args.interpolation = 3
    args.crop_pct = 0.875
    args.image_size = 224
    args.feat_dim = 768
    args.num_mlp_layers = 3
    args.mlp_out_dim = 65536

    train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
    train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)

    # --------------------
    # DATASETS
    # --------------------
    train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name,
                                                                                                                  train_transform,
                                                                                                                  test_transform,
                                                                                                                  args)
    print('whole dataset:', len(train_dataset))
    print('labelled:', len(labelled_train_examples))
    print('unlabelled:', len(unlabelled_train_examples_test))

    # ----------------------
    # Building sub-datasets
    # ----------------------
    sub_data_loaders = []
    sub_num = []
    for subset_index in experts_num:
        # Get the indices belonging to this subset
        subset_idx = []
        for idx, km_label in enumerate(whole_km_labels):
            if km_label == subset_index:
                subset_idx.append(idx)

        # Get the subset
        subset = torch.utils.data.Subset(deepcopy(train_dataset), subset_idx)

        # Get the size of each subset
        subdata_num = np.sum(whole_km_labels == subset_index)
        sub_num.append(subdata_num)

        # Data loader for the subset
        sub_data_loader = DataLoader(subset, num_workers=args.num_workers,
                                     batch_size=args.batch_size, shuffle=False)
        sub_data_loaders.append(sub_data_loader)

    label_nums = []
    unlabel_nums = []
    for loader in sub_data_loaders:
        label_num = 0
        unlabel_num = 0
        for batch in loader:
            images, class_labels, uq_idxs, mask_lab = batch
            mask_lab = mask_lab[:, 0].numpy()
            label_num += np.sum(mask_lab == 1)
            unlabel_num += np.sum(mask_lab == 0)
        label_nums.append(label_num)
        unlabel_nums.append(unlabel_num)

    print('labeled subset len:', label_nums)
    print('unlabeled subset len:', unlabel_nums)
    subset_nums = label_nums + unlabel_nums     # list concatenation: labelled counts followed by unlabelled counts

    with open(f'{subset_len_path}/{args.dataset_name}_subset_len.txt', 'w') as f:
        for num in subset_nums:
            f.write(str(num) + '\n')
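A small sketch of reading the file back — the first experts_num entries are the labelled counts and the remainder the unlabelled counts, mirroring the list concatenation above (the path assumes the subset_len_path default from config.py):

```
with open('./partition_out/subset_len/cub_subset_len.txt') as f:
    counts = [int(line) for line in f]

experts_num = len(counts) // 2
label_nums, unlabel_nums = counts[:experts_num], counts[experts_num:]
print(label_nums, unlabel_nums)
```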
parser.add_argument('--prop_train_labels', type=float, default=0.5) 55 | parser.add_argument('--use_ssb_splits', type=str2bool, default=False) 56 | parser.add_argument('--n_views', default=2, type=int) 57 | 58 | parser.add_argument('--experts_num', type=int, default=8) 59 | 60 | args = parser.parse_args() 61 | device = torch.device('cuda:0') 62 | args = get_class_splits(args) 63 | 64 | args.num_labeled_classes = len(args.train_classes) 65 | args.num_unlabeled_classes = len(args.unlabeled_classes) 66 | 67 | init_experiment(args, runner_name=['metric_learn_gcd']) 68 | print(f'Using evaluation function {args.eval_funcs[0]} to print results') 69 | 70 | whole_km_labels = np.load(f'{km_label_path}/{args.dataset_name}_km_labels.npy') 71 | experts_num = np.unique(whole_km_labels) 72 | print('subset_num:', experts_num) 73 | if len(experts_num) !=args.experts_num: 74 | raise NotImplementedError 75 | 76 | args.interpolation = 3 77 | args.crop_pct = 0.875 78 | args.image_size = 224 79 | args.feat_dim = 768 80 | args.num_mlp_layers = 3 81 | args.mlp_out_dim = 65536 82 | 83 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) 84 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) 85 | # -------------------- 86 | # DATASETS 87 | # -------------------- 88 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name, 89 | train_transform, 90 | test_transform, 91 | args) 92 | print('whole dataset:', len(train_dataset)) 93 | print('labelled:', len(labelled_train_examples)) 94 | print('unlabelled:', len(unlabelled_train_examples_test)) 95 | 96 | # ---------------------- 97 | # Building sub dataset 98 | # ------------------- 99 | sub_data_loaders = [] 100 | sub_num = [] 101 | for subset_index in experts_num: 102 | # Get subset index 103 | subset_idx = [] 104 | for idx, km_label in enumerate(whole_km_labels): 105 | if km_label == subset_index: 106 | subset_idx.append(idx) 107 | 108 | # Get the subset 109 | subset = torch.utils.data.Subset(deepcopy(train_dataset), subset_idx) # len(subset_unlabelled)=, len(datasets['train_unlabelled'])=4496 110 | 111 | # Get the numbers of each subset 112 | subdata_num = np.sum(whole_km_labels == subset_index) 113 | sub_num.append(subdata_num) 114 | 115 | # Dala loader for subset 116 | sub_data_loader = DataLoader(subset, num_workers=args.num_workers, 117 | batch_size=args.batch_size, shuffle=False) 118 | sub_data_loaders.append(sub_data_loader) 119 | 120 | label_nums = [] 121 | unlabel_nums = [] 122 | for loader in sub_data_loaders: 123 | label_num = 0 124 | unlabel_num = 0 125 | for batch in loader: 126 | images, class_labels, uq_idxs, mask_lab = batch 127 | mask_lab = mask_lab[:, 0].numpy() 128 | label_num += np.sum(mask_lab == 1) 129 | unlabel_num += np.sum(mask_lab == 0) 130 | label_nums.append(label_num) 131 | unlabel_nums.append(unlabel_num) 132 | 133 | print('labeled subset len:', label_nums) 134 | print('unlabeled subset len:', unlabel_nums) 135 | subset_nums = label_nums + unlabel_nums 136 | 137 | with open (f'{subset_len_path}/{args.dataset_name}_subset_len.txt', 'w') as f: 138 | for num in subset_nums: 139 | f.write(str(num)+'\n') 140 | -------------------------------------------------------------------------------- /data/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy 
import deepcopy 5 | from scipy import io as mat_io 6 | 7 | from torchvision.datasets.folder import default_loader 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import subsample_instances 11 | 12 | car_root = "/data/user-data/fgvc/cars/cars_{}/" 13 | meta_default_path = "/data/user-data/fgvc/cars/devkit/cars_{}.mat" 14 | 15 | class CarsDataset(Dataset): 16 | """ 17 | Cars Dataset 18 | """ 19 | def __init__(self, train=True, limit=0, data_dir=car_root, transform=None, metas=meta_default_path): 20 | 21 | data_dir = data_dir.format('train') if train else data_dir.format('test') 22 | metas = metas.format('train_annos') if train else metas.format('test_annos_withlabels') 23 | 24 | self.loader = default_loader 25 | self.data_dir = data_dir 26 | self.data = [] 27 | self.target = [] 28 | self.train = train 29 | 30 | self.transform = transform 31 | 32 | if not isinstance(metas, str): 33 | raise Exception("Train metas must be string location !") 34 | labels_meta = mat_io.loadmat(metas) 35 | 36 | for idx, img_ in enumerate(labels_meta['annotations'][0]): 37 | if limit: 38 | if idx > limit: 39 | break 40 | 41 | # self.data.append(img_resized) 42 | self.data.append(data_dir + img_[5][0]) 43 | # if self.mode == 'train': 44 | self.target.append(img_[4][0][0]) 45 | 46 | self.uq_idxs = np.array(range(len(self))) 47 | self.target_transform = None 48 | 49 | def __getitem__(self, idx): 50 | 51 | image = self.loader(self.data[idx]) 52 | target = self.target[idx] - 1 53 | 54 | if self.transform is not None: 55 | image = self.transform(image) 56 | 57 | if self.target_transform is not None: 58 | target = self.target_transform(target) 59 | 60 | idx = self.uq_idxs[idx] 61 | 62 | return image, target, idx 63 | 64 | def __len__(self): 65 | return len(self.data) 66 | 67 | 68 | def subsample_dataset(dataset, idxs): 69 | 70 | dataset.data = np.array(dataset.data)[idxs].tolist() 71 | dataset.target = np.array(dataset.target)[idxs].tolist() 72 | dataset.uq_idxs = dataset.uq_idxs[idxs] 73 | 74 | return dataset 75 | 76 | 77 | def subsample_classes(dataset, include_classes=range(160)): 78 | 79 | include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195 80 | cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars] 81 | 82 | target_xform_dict = {} 83 | for i, k in enumerate(include_classes): 84 | target_xform_dict[k] = i 85 | 86 | dataset = subsample_dataset(dataset, cls_idxs) 87 | 88 | # dataset.target_transform = lambda x: target_xform_dict[x] 89 | 90 | return dataset 91 | 92 | def get_train_val_indices(train_dataset, val_split=0.2): 93 | 94 | train_classes = np.unique(train_dataset.target) 95 | 96 | # Get train/test indices 97 | train_idxs = [] 98 | val_idxs = [] 99 | for cls in train_classes: 100 | 101 | cls_idxs = np.where(train_dataset.target == cls)[0] 102 | 103 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 104 | t_ = [x for x in cls_idxs if x not in v_] 105 | 106 | train_idxs.extend(t_) 107 | val_idxs.extend(v_) 108 | 109 | return train_idxs, val_idxs 110 | 111 | 112 | def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8, 113 | split_train_val=False, seed=0): 114 | 115 | np.random.seed(seed) 116 | 117 | # Init entire training set 118 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, metas=meta_default_path, train=True) 119 | 120 | # Get labelled training set which has subsampled classes, then 
subsample some indices from that 121 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 122 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 123 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 124 | 125 | # Split into training and validation sets 126 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 127 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 128 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 129 | val_dataset_labelled_split.transform = test_transform 130 | 131 | # Get unlabelled data 132 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 133 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 134 | 135 | # Get test set for all classes 136 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, metas=meta_default_path, train=False) 137 | 138 | # Either split train into train and val or use test set as val 139 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 140 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 141 | 142 | all_datasets = { 143 | 'train_labelled': train_dataset_labelled, 144 | 'train_unlabelled': train_dataset_unlabelled, 145 | 'val': val_dataset_labelled, 146 | 'test': test_dataset, 147 | } 148 | 149 | return all_datasets 150 | 151 | if __name__ == '__main__': 152 | 153 | x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False) 154 | 155 | print('Printing lens...') 156 | for k, v in x.items(): 157 | if v is not None: 158 | print(f'{k}: {len(v)}') 159 | 160 | print('Printing labelled and unlabelled overlap...') 161 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 162 | print('Printing total instances in train...') 163 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 164 | 165 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}') 166 | print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].target))}') 167 | print(f'Len labelled set: {len(x["train_labelled"])}') 168 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import numpy as np 3 | 4 | import os 5 | 6 | from copy import deepcopy 7 | from data.data_utils import subsample_instances 8 | from config import imagenet_root 9 | 10 | 11 | class ImageNetBase(torchvision.datasets.ImageFolder): 12 | 13 | def __init__(self, root, transform): 14 | 15 | super(ImageNetBase, self).__init__(root, transform) 16 | 17 | self.uq_idxs = np.array(range(len(self))) 18 | 19 | def __getitem__(self, item): 20 | 21 | img, label = super().__getitem__(item) 22 | uq_idx = self.uq_idxs[item] 23 | 24 | return img, label, uq_idx 25 | 26 | 27 | def subsample_dataset(dataset, idxs): 28 | 29 | imgs_ = [] 30 | for i in idxs: 31 | imgs_.append(dataset.imgs[i]) 32 | dataset.imgs = imgs_ 33 | 34 | samples_ = [] 35 | for i in idxs: 36 | samples_.append(dataset.samples[i]) 37 | dataset.samples = samples_ 38 | 39 
41 | 42 | dataset.targets = np.array(dataset.targets)[idxs].tolist() 43 | dataset.uq_idxs = dataset.uq_idxs[idxs] 44 | 45 | return dataset 46 | 47 | 48 | def subsample_classes(dataset, include_classes=list(range(1000))): 49 | 50 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] 51 | 52 | target_xform_dict = {} 53 | for i, k in enumerate(include_classes): 54 | target_xform_dict[k] = i 55 | 56 | dataset = subsample_dataset(dataset, cls_idxs) 57 | dataset.target_transform = lambda x: target_xform_dict[x] 58 | 59 | return dataset 60 | 61 | 62 | def get_train_val_indices(train_dataset, val_split=0.2): 63 | 64 | train_classes = list(set(train_dataset.targets)) 65 | 66 | # Get train/test indices 67 | train_idxs = [] 68 | val_idxs = [] 69 | for cls in train_classes: 70 | 71 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] 72 | 73 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 74 | t_ = [x for x in cls_idxs if x not in v_] 75 | 76 | train_idxs.extend(t_) 77 | val_idxs.extend(v_) 78 | 79 | return train_idxs, val_idxs 80 | 81 | 82 | def get_equal_len_datasets(dataset1, dataset2): 83 | """ 84 | Make two datasets the same length 85 | """ 86 | 87 | if len(dataset1) > len(dataset2): 88 | 89 | rand_idxs = np.random.choice(range(len(dataset1)), size=(len(dataset2),)) 90 | subsample_dataset(dataset1, rand_idxs) 91 | 92 | elif len(dataset2) > len(dataset1): 93 | 94 | rand_idxs = np.random.choice(range(len(dataset2)), size=(len(dataset1),)) 95 | subsample_dataset(dataset2, rand_idxs) 96 | 97 | return dataset1, dataset2
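# Added caveat (behaviour unchanged): np.random.choice above is called with
# its default replace=True, so the shortened dataset can contain duplicate
# instances; pass replace=False if a strict subset of distinct examples is
# required.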
98 | 99 | 100 | def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80), 101 | prop_train_labels=0.8, split_train_val=False, seed=0): 102 | 103 | np.random.seed(seed) 104 | 105 | # Subsample imagenet dataset initially to include 100 classes 106 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False) 107 | subsampled_100_classes = np.sort(subsampled_100_classes) 108 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}') 109 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))} 110 | 111 | # Init entire training set 112 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) 113 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes) 114 | 115 | # Reset dataset 116 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples] 117 | whole_training_set.targets = [s[1] for s in whole_training_set.samples] 118 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set))) 119 | whole_training_set.target_transform = None 120 | 121 | # Get labelled training set which has subsampled classes, then subsample some indices from that 122 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 123 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 124 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 125 | 126 | # Split into training and validation sets 127 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 128 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 129 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 130 | val_dataset_labelled_split.transform = test_transform 131 | 132 | # Get unlabelled data 133 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 134 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 135 | 136 | # Get test set for all classes 137 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) 138 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes) 139 | 140 | # Reset test set 141 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples] 142 | test_dataset.targets = [s[1] for s in test_dataset.samples] 143 | test_dataset.uq_idxs = np.array(range(len(test_dataset))) 144 | test_dataset.target_transform = None 145 | 146 | # Either split train into train and val or use test set as val 147 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 148 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 149 | 150 | all_datasets = { 151 | 'train_labelled': train_dataset_labelled, 152 | 'train_unlabelled': train_dataset_unlabelled, 153 | 'val': val_dataset_labelled, 154 | 'test': test_dataset, 155 | } 156 | 157 | return all_datasets 158 | 159 | 160 | if __name__ == '__main__': 161 | 162 | x = get_imagenet_100_datasets(None, None, split_train_val=False, 163 | train_classes=range(50), prop_train_labels=0.5) 164 | 165 | print('Printing lens...') 166 | for k, v in x.items(): 167 | if v is not None: 168 | print(f'{k}: {len(v)}') 169 | 170 | print('Printing labelled and unlabelled overlap...') 171 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 172 | print('Printing total instances in train...') 173 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 174 | 175 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 176 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}') 177 | print(f'Len labelled set: {len(x["train_labelled"])}') 178 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /project_utils/cluster_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import numpy as np 3 | import sklearn.metrics 4 | import torch 5 | import torch.nn as nn 6 | import matplotlib 7 | matplotlib.use('agg') 8 | from scipy.optimize import linear_sum_assignment as linear_assignment 9 | import random 10 | import os 11 | import argparse 12 | 13 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 14 | from sklearn.metrics import adjusted_rand_score as ari_score 15 | 16 | from sklearn import metrics 17 | import time 18 | 19 | # ------------------------------- 20 | # Evaluation Criteria 21 | # ------------------------------- 22 | def evaluate_clustering(y_true, y_pred): 23 | 24 | start = time.time() 25 | print('Computing metrics...') 26 | if len(set(y_pred)) < 1000: 27 | acc = cluster_acc(y_true.astype(int), y_pred.astype(int)) 28 | else: 29 | acc = None 30 | 31 | nmi = nmi_score(y_true, y_pred) 32 | ari = ari_score(y_true, y_pred) 33 | pur = purity_score(y_true, y_pred) 34 | print(f'Finished computing metrics in {time.time() - start:.1f}s...') 35 | 36 | return acc, nmi, ari, pur
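# Added note: cluster_acc below solves a linear assignment problem on a D x D
# contingency matrix (D = number of clusters/classes), which scales roughly as
# O(D^3); the `< 1000` guard above therefore skips accuracy for very large
# cluster counts and returns acc=None instead.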
37 | 38 | 39 | def cluster_acc(y_true, y_pred, return_ind=False): 40 | """ 41 | Calculate clustering accuracy. Requires scikit-learn to be installed 42 | 43 | # Arguments 44 | y_true: true labels, numpy.array with shape `(n_samples,)` 45 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 46 | 47 | # Return 48 | accuracy, in [0,1] 49 | """ 50 | y_true = y_true.astype(int) 51 | assert y_pred.size == y_true.size 52 | D = max(y_pred.max(), y_true.max()) + 1 53 | w = np.zeros((D, D), dtype=int) 54 | for i in range(y_pred.size): 55 | w[y_pred[i], y_true[i]] += 1 56 | 57 | ind = linear_assignment(w.max() - w) 58 | ind = np.vstack(ind).T 59 | 60 | if return_ind: 61 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind, w 62 | else: 63 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
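# Illustrative example (added; a minimal sanity check, not from the repo):
#   y_true = np.array([0, 0, 1, 1])
#   y_pred = np.array([1, 1, 0, 0])   # same partition, cluster ids permuted
#   cluster_acc(y_true, y_pred)       # -> 1.0, since the optimal matching
#                                     # maps cluster 1 -> class 0 and cluster 0 -> class 1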
64 | 65 | 66 | def purity_score(y_true, y_pred): 67 | # compute contingency matrix (also called confusion matrix) 68 | contingency_matrix = metrics.cluster.contingency_matrix(y_true, y_pred) 69 | # return purity 70 | return np.sum(np.amax(contingency_matrix, axis=0)) / np.sum(contingency_matrix) 71 | 72 | # ------------------------------- 73 | # Mixed Eval Function 74 | # ------------------------------- 75 | def mixed_eval(targets, preds, mask): 76 | 77 | """ 78 | Evaluate clustering metrics on two subsets of data, as defined by the mask 'mask' 79 | (Mask usually corresponding to `Old' and `New' classes in GCD setting) 80 | :param targets: All ground truth labels 81 | :param preds: All predictions 82 | :param mask: Mask defining two subsets 83 | :return: 84 | """ 85 | 86 | mask = mask.astype(bool) 87 | 88 | # Labelled examples 89 | if mask.sum() == 0: # All examples come from unlabelled classes 90 | 91 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int), preds.astype(int)), \ 92 | nmi_score(targets, preds), \ 93 | ari_score(targets, preds) 94 | 95 | print('Unlabelled Classes Test acc {:.4f}, nmi {:.4f}, ari {:.4f}' 96 | .format(unlabelled_acc, unlabelled_nmi, unlabelled_ari)) 97 | 98 | # Also return ratio between labelled and unlabelled examples 99 | return (unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.mean() 100 | 101 | else: 102 | 103 | labelled_acc, labelled_nmi, labelled_ari = cluster_acc(targets.astype(int)[mask], 104 | preds.astype(int)[mask]), \ 105 | nmi_score(targets[mask], preds[mask]), \ 106 | ari_score(targets[mask], preds[mask]) 107 | 108 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int)[~mask], 109 | preds.astype(int)[~mask]), \ 110 | nmi_score(targets[~mask], preds[~mask]), \ 111 | ari_score(targets[~mask], preds[~mask]) 112 | 113 | # Also return ratio between labelled and unlabelled examples 114 | return (labelled_acc, labelled_nmi, labelled_ari), ( 115 | unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.mean() 116 | 117 | 118 | class AverageMeter(object): 119 | """Computes and stores the average and current value""" 120 | def __init__(self): 121 | self.reset() 122 | 123 | def reset(self): 124 | self.val = 0 125 | self.avg = 0 126 | self.sum = 0 127 | self.count = 0 128 | 129 | def update(self, val, n=1): 130 | self.val = val 131 | self.sum += val * n 132 | self.count += n 133 | self.avg = self.sum / self.count 134 | 135 | 136 | class Identity(nn.Module): 137 | def __init__(self): 138 | super(Identity, self).__init__() 139 | def forward(self, x): 140 | return x 141 | 142 | 143 | class BCE(nn.Module): 144 | eps = 1e-7 # Avoid calculating log(0). Use the small value of float16. 145 | def forward(self, prob1, prob2, simi): 146 | # simi: 1->similar; -1->dissimilar; 0->unknown(ignore) 147 | assert len(prob1)==len(prob2)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(prob1)),str(len(prob2)),str(len(simi))) 148 | P = prob1.mul_(prob2) 149 | P = P.sum(1) 150 | P.mul_(simi).add_(simi.eq(-1).type_as(P)) 151 | neglogP = -P.add_(BCE.eps).log_() 152 | return neglogP.mean() 153 | 154 | 155 | def PairEnum(x,mask=None): 156 | 157 | # Enumerate all pairs of feature in x 158 | assert x.ndimension() == 2, 'Input dimension must be 2' 159 | x1 = x.repeat(x.size(0), 1) 160 | x2 = x.repeat(1, x.size(0)).view(-1, x.size(1)) 161 | 162 | if mask is not None: 163 | 164 | xmask = mask.view(-1, 1).repeat(1, x.size(1)) 165 | #dim 0: #sample, dim 1:#feature 166 | x1 = x1[xmask].view(-1, x.size(1)) 167 | x2 = x2[xmask].view(-1, x.size(1)) 168 | 169 | return x1, x2 170 | 171 | 172 | def accuracy(output, target, topk=(1,)): 173 | """Computes the accuracy over the k top predictions for the specified values of k""" 174 | with torch.no_grad(): 175 | maxk = max(topk) 176 | batch_size = target.size(0) 177 | 178 | _, pred = output.topk(maxk, 1, True, True) 179 | pred = pred.t() 180 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 181 | 182 | res = [] 183 | for k in topk: 184 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: `correct` is not guaranteed contiguous after the transpose above 185 | res.append(correct_k.mul_(100.0 / batch_size)) 186 | return res
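# Illustrative usage of `accuracy` above (added): for logits `output` of shape
# (B, C) and integer `target` of shape (B,), accuracy(output, target, topk=(1, 5))
# returns a list [top1, top5] of percentage tensors; concrete values depend on
# the model, e.g. [tensor([75.]), tensor([95.])] (made-up numbers).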
187 | 188 | 189 | def seed_torch(seed=1029): 190 | random.seed(seed) 191 | os.environ['PYTHONHASHSEED'] = str(seed) 192 | np.random.seed(seed) 193 | torch.manual_seed(seed) 194 | torch.cuda.manual_seed(seed) 195 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 196 | torch.backends.cudnn.benchmark = False 197 | torch.backends.cudnn.deterministic = True 198 | 199 | 200 | def str2bool(v): 201 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 202 | return True 203 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 204 | return False 205 | else: 206 | raise argparse.ArgumentTypeError('Boolean value expected.') -------------------------------------------------------------------------------- /data/cifar.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10, CIFAR100 2 | from copy import deepcopy 3 | import numpy as np 4 | 5 | from data.data_utils import subsample_instances 6 | from config import cifar_10_root, cifar_100_root 7 | 8 | 9 | class CustomCIFAR10(CIFAR10): 10 | 11 | def __init__(self, *args, **kwargs): 12 | 13 | super(CustomCIFAR10, self).__init__(*args, **kwargs) 14 | 15 | self.uq_idxs = np.array(range(len(self))) 16 | 17 | def __getitem__(self, item): 18 | 19 | img, label = super().__getitem__(item) 20 | uq_idx = self.uq_idxs[item] 21 | 22 | return img, label, uq_idx 23 | 24 | def __len__(self): 25 | return len(self.targets) 26 | 27 | 28 | class CustomCIFAR100(CIFAR100): 29 | 30 | def __init__(self, *args, **kwargs): 31 | super(CustomCIFAR100, self).__init__(*args, **kwargs) 32 | 33 | self.uq_idxs = np.array(range(len(self))) 34 | 35 | def __getitem__(self, item): 36 | img, label = super().__getitem__(item) 37 | uq_idx = self.uq_idxs[item] 38 | 39 | return img, label, uq_idx 40 | 41 | def __len__(self): 42 | return len(self.targets) 43 | 44 | 45 | def subsample_dataset(dataset, idxs): 46 | 47 | # Allow for the setting in which an empty set of indices is passed 48 | 49 | if len(idxs) > 0: 50 | 51 | dataset.data = dataset.data[idxs] 52 | dataset.targets = np.array(dataset.targets)[idxs].tolist() 53 | dataset.uq_idxs = dataset.uq_idxs[idxs] 54 | 55 | return dataset 56 | 57 | else: 58 | 59 | return None 60 | 61 | 62 | def subsample_classes(dataset, include_classes=(0, 1, 8, 9)): 63 | 64 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] 65 | 66 | target_xform_dict = {} 67 | for i, k in enumerate(include_classes): 68 | target_xform_dict[k] = i 69 | 70 | dataset = subsample_dataset(dataset, cls_idxs) 71 | 72 | # dataset.target_transform = lambda x: target_xform_dict[x] 73 | 74 | return dataset 75 | 76 | 77 | def get_train_val_indices(train_dataset, val_split=0.2): 78 | 79 | train_classes = np.unique(train_dataset.targets) 80 | 81 | # Get train/test indices 82 | train_idxs = [] 83 | val_idxs = [] 84 | for cls in train_classes: 85 | 86 | cls_idxs = np.where(train_dataset.targets == cls)[0] 87 | 88 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 89 | t_ = [x for x in cls_idxs if x not in v_] 90 | 91 | train_idxs.extend(t_) 92 | val_idxs.extend(v_) 93 | 94 | return train_idxs, val_idxs 95 | 96 | 97 | def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9), 98 | prop_train_labels=0.8, split_train_val=False, seed=0): 99 | 100 | np.random.seed(seed) 101 | 102 | # Init entire training set 103 | whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True) 104 | 105 | # Get labelled training set which has subsampled classes, then subsample some indices from that 106 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 107 | subsample_indices = subsample_instances(train_dataset_labelled,
prop_indices_to_subsample=prop_train_labels) 108 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 109 | 110 | # Split into training and validation sets 111 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 112 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 113 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 114 | val_dataset_labelled_split.transform = test_transform 115 | 116 | # Get unlabelled data 117 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 118 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 119 | 120 | # Get test set for all classes 121 | test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False) 122 | 123 | # Either split train into train and val or use test set as val 124 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 125 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 126 | 127 | all_datasets = { 128 | 'train_labelled': train_dataset_labelled, 129 | 'train_unlabelled': train_dataset_unlabelled, 130 | 'val': val_dataset_labelled, 131 | 'test': test_dataset, 132 | } 133 | 134 | return all_datasets 135 | 136 | 137 | def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80), 138 | prop_train_labels=0.8, split_train_val=False, seed=0): 139 | 140 | np.random.seed(seed) 141 | 142 | # Init entire training set 143 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True) 144 | 145 | # Get labelled training set which has subsampled classes, then subsample some indices from that 146 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 147 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 148 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 149 | 150 | # Split into training and validation sets 151 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 152 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 153 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 154 | val_dataset_labelled_split.transform = test_transform 155 | 156 | # Get unlabelled data 157 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 158 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 159 | 160 | # Get test set for all classes 161 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False) 162 | 163 | # Either split train into train and val or use test set as val 164 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 165 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 166 | 167 | all_datasets = { 168 | 'train_labelled': train_dataset_labelled, 169 | 'train_unlabelled': train_dataset_unlabelled, 170 | 'val': val_dataset_labelled, 171 | 'test': test_dataset, 172 | } 173 | 174 | return all_datasets 175 | 176 | 177 | if __name__ == '__main__': 178 | 179 | x = get_cifar_100_datasets(None, None, 
split_train_val=False, 180 | train_classes=range(80), prop_train_labels=0.5) 181 | 182 | print('Printing lens...') 183 | for k, v in x.items(): 184 | if v is not None: 185 | print(f'{k}: {len(v)}') 186 | 187 | print('Printing labelled and unlabelled overlap...') 188 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 189 | print('Printing total instances in train...') 190 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 191 | 192 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 193 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}') 194 | print(f'Len labelled set: {len(x["train_labelled"])}') 195 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/get_datasets.py: -------------------------------------------------------------------------------- 1 | 2 | from data.data_utils import MergedDataset 3 | 4 | from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets 5 | from data.stanford_cars import get_scars_datasets 6 | from data.imagenet import get_imagenet_100_datasets 7 | from data.cub import get_cub_datasets 8 | from data.fgvc_aircraft import get_aircraft_datasets 9 | from data.pets import get_pets_datasets 10 | from data.flower import get_flower_datasets 11 | from data.food import get_food_datasets 12 | 13 | from data.cifar import subsample_classes as subsample_dataset_cifar 14 | from data.stanford_cars import subsample_classes as subsample_dataset_scars 15 | from data.imagenet import subsample_classes as subsample_dataset_imagenet 16 | from data.cub import subsample_classes as subsample_dataset_cub 17 | from data.fgvc_aircraft import subsample_classes as subsample_dataset_air 18 | from data.pets import subsample_classes as subsample_dataset_pets 19 | from data.flower import subsample_classes as subsample_dataset_flower 20 | from data.food import subsample_classes as subsample_dataset_food 21 | 22 | from copy import deepcopy 23 | import pickle 24 | import os 25 | 26 | from config import osr_split_dir 27 | 28 | sub_sample_class_funcs = { 29 | 'cifar10': subsample_dataset_cifar, 30 | 'cifar100': subsample_dataset_cifar, 31 | 'imagenet_100': subsample_dataset_imagenet, 32 | 'cub': subsample_dataset_cub, 33 | 'aircraft': subsample_dataset_air, 34 | 'scars': subsample_dataset_scars, 35 | 'pets': subsample_dataset_pets, 36 | 'flower': subsample_dataset_flower, 37 | 'food': subsample_dataset_food 38 | } 39 | 40 | get_dataset_funcs = { 41 | 'cifar10': get_cifar_10_datasets, 42 | 'cifar100': get_cifar_100_datasets, 43 | 'imagenet_100': get_imagenet_100_datasets, 44 | 'cub': get_cub_datasets, 45 | 'aircraft': get_aircraft_datasets, 46 | 'scars': get_scars_datasets, 47 | 'pets': get_pets_datasets, 48 | 'flower': get_flower_datasets, 49 | 'food': get_food_datasets 50 | } 51 | 52 | 53 | def get_datasets(dataset_name, train_transform, test_transform, args): 54 | 55 | """ 56 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled 57 | test_dataset, 58 | unlabelled_train_examples_test, 59 | datasets 60 | """ 61 | 62 | if dataset_name not in get_dataset_funcs.keys(): 63 | raise ValueError(f'Unknown dataset: {dataset_name}') 64 | 65 | # Get datasets 66 | get_dataset_f = get_dataset_funcs[dataset_name] 67 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform, 68 | train_classes=args.train_classes, 69 |
prop_train_labels=args.prop_train_labels, 70 | split_train_val=False) 71 | 72 | # Set target transforms: 73 | target_transform_dict = {} 74 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)): 75 | target_transform_dict[cls] = i 76 | target_transform = lambda x: target_transform_dict[x] 77 | 78 | # ['train_labelled', 'train_unlabelled', 'val', 'test'] 79 | for dataset_name, dataset in datasets.items(): 80 | if dataset is not None: 81 | dataset.target_transform = target_transform 82 | 83 | # Train split (labelled and unlabelled classes) for training 84 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']), 85 | unlabelled_dataset=deepcopy(datasets['train_unlabelled'])) 86 | 87 | test_dataset = datasets['test'] 88 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled']) 89 | unlabelled_train_examples_test.transform = test_transform 90 | 91 | labelled_train_examples = deepcopy(datasets['train_labelled']) 92 | labelled_train_examples.transform = test_transform 93 | 94 | return train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples 95 | 96 | 97 | def get_class_splits(args): 98 | 99 | # For FGVC datasets, optionally return bespoke splits 100 | if args.dataset_name in ('scars', 'cub', 'aircraft'): 101 | if hasattr(args, 'use_ssb_splits'): 102 | use_ssb_splits = args.use_ssb_splits 103 | else: 104 | use_ssb_splits = False 105 | 106 | # ------------- 107 | # GET CLASS SPLITS 108 | # ------------- 109 | if args.dataset_name == 'cifar10': 110 | 111 | args.image_size = 32 112 | args.train_classes = range(5) 113 | args.unlabeled_classes = range(5, 10) 114 | 115 | elif args.dataset_name == 'cifar100': 116 | 117 | args.image_size = 32 118 | args.train_classes = range(80) 119 | args.unlabeled_classes = range(80, 100) 120 | 121 | elif args.dataset_name == 'tinyimagenet': 122 | 123 | args.image_size = 64 124 | args.train_classes = range(100) 125 | args.unlabeled_classes = range(100, 200) 126 | 127 | 128 | elif args.dataset_name == 'imagenet_100': 129 | 130 | args.image_size = 224 131 | args.train_classes = range(50) 132 | args.unlabeled_classes = range(50, 100) 133 | 134 | elif args.dataset_name == 'scars': 135 | 136 | args.image_size = 224 137 | 138 | if use_ssb_splits: 139 | 140 | split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl') 141 | with open(split_path, 'rb') as handle: 142 | class_info = pickle.load(handle) 143 | 144 | args.train_classes = class_info['known_classes'] 145 | open_set_classes = class_info['unknown_classes'] 146 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 147 | 148 | else: 149 | 150 | args.train_classes = range(98) 151 | args.unlabeled_classes = range(98, 196) 152 | 153 | elif args.dataset_name == 'aircraft': 154 | 155 | args.image_size = 224 156 | if use_ssb_splits: 157 | 158 | split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl') 159 | with open(split_path, 'rb') as handle: 160 | class_info = pickle.load(handle) 161 | 162 | args.train_classes = class_info['known_classes'] 163 | open_set_classes = class_info['unknown_classes'] 164 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 165 | 166 | else: 167 | 168 | args.train_classes = range(50) 169 | args.unlabeled_classes = range(50, 100) 170 | 171 | elif args.dataset_name == 'cub': 172 | 173 | args.image_size = 224 174 | 175 | if use_ssb_splits: 176 | 177 | split_path = 
os.path.join(osr_split_dir, 'cub_osr_splits.pkl') 178 | with open(split_path, 'rb') as handle: 179 | class_info = pickle.load(handle) 180 | 181 | args.train_classes = class_info['known_classes'] 182 | open_set_classes = class_info['unknown_classes'] 183 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 184 | 185 | else: 186 | 187 | args.train_classes = range(100) 188 | args.unlabeled_classes = range(100, 200) 189 | 190 | 191 | elif args.dataset_name == 'pets': 192 | 193 | args.image_size = 224 194 | args.train_classes = range(19) 195 | args.unlabeled_classes = range(19, 37) 196 | 197 | elif args.dataset_name == 'flower': 198 | 199 | args.image_size = 224 200 | args.train_classes = range(51) 201 | args.unlabeled_classes = range(51, 102) 202 | 203 | elif args.dataset_name == 'food': 204 | 205 | args.image_size = 224 206 | args.train_classes = range(51) 207 | args.unlabeled_classes = range(51, 101) 208 | 209 | else: 210 | 211 | raise NotImplementedError 212 | 213 | return args -------------------------------------------------------------------------------- /data/cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from torchvision.datasets.folder import default_loader 7 | from torchvision.datasets.utils import download_url 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import subsample_instances 11 | from config import cub_root 12 | 13 | 14 | class CustomCub2011(Dataset): 15 | base_folder = 'CUB_200_2011/images' 16 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 17 | filename = 'CUB_200_2011.tgz' 18 | tgz_md5 = '97eceeb196236b17998738112f37df78' 19 | 20 | def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=True): 21 | 22 | self.root = os.path.expanduser(root) 23 | self.transform = transform 24 | self.target_transform = target_transform 25 | 26 | self.loader = loader 27 | self.train = train 28 | 29 | 30 | if download: 31 | self._download() 32 | 33 | if not self._check_integrity(): 34 | raise RuntimeError('Dataset not found or corrupted.' 
+ 35 | ' You can use download=True to download it') 36 | 37 | self.uq_idxs = np.array(range(len(self))) 38 | 39 | def _load_metadata(self): 40 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', 41 | names=['img_id', 'filepath']) 42 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), 43 | sep=' ', names=['img_id', 'target']) 44 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), 45 | sep=' ', names=['img_id', 'is_training_img']) 46 | 47 | data = images.merge(image_class_labels, on='img_id') 48 | self.data = data.merge(train_test_split, on='img_id') 49 | 50 | if self.train: 51 | self.data = self.data[self.data.is_training_img == 1] 52 | else: 53 | self.data = self.data[self.data.is_training_img == 0] 54 | 55 | def _check_integrity(self): 56 | try: 57 | self._load_metadata() 58 | except Exception: 59 | return False 60 | 61 | for index, row in self.data.iterrows(): 62 | filepath = os.path.join(self.root, self.base_folder, row.filepath) 63 | if not os.path.isfile(filepath): 64 | print(filepath) 65 | return False 66 | return True 67 | 68 | def _download(self): 69 | import tarfile 70 | 71 | if self._check_integrity(): 72 | print('Files already downloaded and verified') 73 | return 74 | 75 | download_url(self.url, self.root, self.filename, self.tgz_md5) 76 | 77 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 78 | tar.extractall(path=self.root) 79 | 80 | def __len__(self): 81 | return len(self.data) 82 | 83 | def __getitem__(self, idx): 84 | sample = self.data.iloc[idx] 85 | path = os.path.join(self.root, self.base_folder, sample.filepath) 86 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0 87 | img = self.loader(path) 88 | 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | if self.target_transform is not None: 93 | target = self.target_transform(target) 94 | 95 | return img, target, self.uq_idxs[idx] 96 | 97 | 98 | def subsample_dataset(dataset, idxs): 99 | 100 | mask = np.zeros(len(dataset)).astype('bool') 101 | mask[idxs] = True 102 | 103 | dataset.data = dataset.data[mask] 104 | dataset.uq_idxs = dataset.uq_idxs[mask] 105 | 106 | return dataset 107 | 108 | 109 | def subsample_classes(dataset, include_classes=range(160)): 110 | 111 | include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199 112 | cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub] 113 | 114 | # TODO: For now have no target transform 115 | target_xform_dict = {} 116 | for i, k in enumerate(include_classes): 117 | target_xform_dict[k] = i 118 | 119 | dataset = subsample_dataset(dataset, cls_idxs) 120 | 121 | dataset.target_transform = lambda x: target_xform_dict[x] 122 | 123 | return dataset 124 | 125 | 126 | def get_train_val_indices(train_dataset, val_split=0.2): 127 | 128 | train_classes = np.unique(train_dataset.data['target']) 129 | 130 | # Get train/test indices 131 | train_idxs = [] 132 | val_idxs = [] 133 | for cls in train_classes: 134 | 135 | cls_idxs = np.where(train_dataset.data['target'] == cls)[0] 136 | 137 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 138 | t_ = [x for x in cls_idxs if x not in v_] 139 | 140 | train_idxs.extend(t_) 141 | val_idxs.extend(v_) 142 | 143 | return train_idxs, val_idxs 144 | 145 | 146 | def get_cub_datasets(train_transform, 
test_transform, train_classes=range(160), prop_train_labels=0.8, 147 | split_train_val=False, seed=0): 148 | 149 | np.random.seed(seed) 150 | 151 | # Init entire training set 152 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True, download=False) 153 | 154 | # Get labelled training set which has subsampled classes, then subsample some indices from that 155 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 156 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 157 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 158 | 159 | # Split into training and validation sets 160 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 161 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 162 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 163 | val_dataset_labelled_split.transform = test_transform 164 | 165 | # Get unlabelled data 166 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 167 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 168 | 169 | # Get test set for all classes 170 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False, download=False) 171 | 172 | # Either split train into train and val or use test set as val 173 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 174 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 175 | 176 | all_datasets = { 177 | 'train_labelled': train_dataset_labelled, 178 | 'train_unlabelled': train_dataset_unlabelled, 179 | 'val': val_dataset_labelled, 180 | 'test': test_dataset, 181 | } 182 | 183 | return all_datasets 184 | 185 | def main(): 186 | 187 | x = get_cub_datasets(None, None, split_train_val=False, 188 | train_classes=range(100), prop_train_labels=0.5) 189 | 190 | print('Printing lens...') 191 | for k, v in x.items(): 192 | if v is not None: 193 | print(f'{k}: {len(v)}') 194 | 195 | print('Printing labelled and unlabelled overlap...') 196 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 197 | print('Printing total instances in train...') 198 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 199 | 200 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}') 201 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}') 202 | print(f'Len labelled set: {len(x["train_labelled"])}') 203 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') 204 | 205 | if __name__ == '__main__': 206 | main() -------------------------------------------------------------------------------- /methods/clustering/ssk_means.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | from sklearn.cluster import KMeans 7 | import torch 8 | from project_utils.cluster_utils import str2bool 9 | from project_utils.general_utils import seed_torch 10 | from project_utils.cluster_and_log_utils import log_accs_from_preds 11 | 12 | from methods.clustering.feature_vector_dataset import
FeatureVectorDataset 13 | from data.get_datasets import get_datasets, get_class_splits 14 | from methods.clustering.faster_mix_k_means_pytorch import K_Means as SemiSupKMeans 15 | 16 | from tqdm import tqdm 17 | from config import feature_extract_dir 18 | 19 | # TODO: Debug 20 | import warnings 21 | warnings.filterwarnings("ignore", category=DeprecationWarning) 22 | 23 | def kmeans_semi_sup(merge_test_loader, args, K=None): 24 | 25 | """ 26 | In this case, the test loader needs to have the labelled and unlabelled subsets of the training data 27 | """ 28 | 29 | if K is None: 30 | K = args.num_labeled_classes + args.num_unlabeled_classes 31 | 32 | all_feats = [] 33 | targets = np.array([]) 34 | mask_lab = np.array([]) # From all the data, which instances belong to the labelled set 35 | mask_cls = np.array([]) # From all the data, which instances belong to Old classes 36 | 37 | print('Collating features...') 38 | # First extract all features 39 | for batch_idx, (feats, label, _, mask_lab_) in enumerate(tqdm(merge_test_loader)): 40 | 41 | feats = feats.to(device) 42 | 43 | feats = torch.nn.functional.normalize(feats, dim=-1) 44 | 45 | all_feats.append(feats.cpu().numpy()) 46 | targets = np.append(targets, label.cpu().numpy()) 47 | mask_cls = np.append(mask_cls, np.array([True if x.item() in range(len(args.train_classes)) 48 | else False for x in label])) 49 | mask_lab = np.append(mask_lab, mask_lab_.cpu().bool().numpy()) 50 | 51 | # ----------------------- 52 | # K-MEANS 53 | # ----------------------- 54 | mask_lab = mask_lab.astype(bool) 55 | mask_cls = mask_cls.astype(bool) 56 | 57 | all_feats = np.concatenate(all_feats) # (5994, 768) 58 | 59 | l_feats = all_feats[mask_lab] # Get labelled set [1498,768] 60 | u_feats = all_feats[~mask_lab] # Get unlabelled set 61 | l_targets = targets[mask_lab] # Get labelled targets 62 | u_targets = targets[~mask_lab] # Get unlabelled targets 63 | 64 | print('Fitting Semi-Supervised K-Means...') 65 | kmeans = SemiSupKMeans(k=K, tolerance=1e-4, max_iterations=args.max_kmeans_iter, init='k-means++', 66 | n_init=args.k_means_init, random_state=None, n_jobs=None, pairwise_batch_size=1024, mode=None) 67 | 68 | l_feats, u_feats, l_targets, u_targets = (torch.from_numpy(x).to(device) for 69 | x in (l_feats, u_feats, l_targets, u_targets)) 70 | 71 | kmeans.fit_mix(u_feats, l_feats, l_targets) 72 | all_preds = kmeans.labels_.cpu().numpy() 73 | u_targets = u_targets.cpu().numpy() 74 | 75 | # ----------------------- 76 | # EVALUATE 77 | # ----------------------- 78 | # Get preds corresponding to unlabelled set 79 | preds = all_preds[~mask_lab] 80 | 81 | # Get portion of mask_cls which corresponds to the unlabelled set 82 | mask = mask_cls[~mask_lab] 83 | mask = mask.astype(bool) 84 | 85 | # ----------------------- 86 | # EVALUATE 87 | # ----------------------- 88 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=u_targets, y_pred=preds, mask=mask, eval_funcs=args.eval_funcs, 89 | save_name='SS-K-Means Train ACC Unlabelled', print_output=True) 90 | 91 | return all_acc, old_acc, new_acc, kmeans 92 | 93 | if __name__ == "__main__": 94 | 95 | parser = argparse.ArgumentParser( 96 | description='cluster', 97 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 98 | parser.add_argument('--batch_size', default=128, type=int) 99 | parser.add_argument('--num_workers', default=8, type=int) 100 | parser.add_argument('--K', default=None, type=int, help='Set manually to run with custom K') 101 | parser.add_argument('--root_dir', type=str, default=feature_extract_dir) 102 
| parser.add_argument('--warmup_model_exp_id', type=str, default=None) 103 | parser.add_argument('--use_best_model', type=str2bool, default=True) 104 | parser.add_argument('--spatial', type=str2bool, default=False) 105 | parser.add_argument('--semi_sup', type=str2bool, default=True) 106 | parser.add_argument('--max_kmeans_iter', type=int, default=10) 107 | parser.add_argument('--k_means_init', type=int, default=10) 108 | parser.add_argument('--model_name', type=str, default='vit_dino', help='Format is {model_name}_{pretrain}') 109 | parser.add_argument('--dataset_name', type=str, default='aircraft', help='options: cifar10, cifar100, scars') 110 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 111 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v1', 'v2']) 112 | parser.add_argument('--use_ssb_splits', type=str2bool, default=True) 113 | 114 | # ---------------------- 115 | # INIT 116 | # ---------------------- 117 | args = parser.parse_args() 118 | cluster_accs = {} 119 | seed_torch(0) 120 | args.save_dir = os.path.join(args.root_dir, f'{args.model_name}_{args.dataset_name}') 121 | 122 | args = get_class_splits(args) 123 | 124 | args.num_labeled_classes = len(args.train_classes) 125 | args.num_unlabeled_classes = len(args.unlabeled_classes) 126 | 127 | device = torch.device('cuda:0') 128 | args.device = device 129 | print(args) 130 | 131 | if args.warmup_model_exp_id is not None: 132 | 133 | args.save_dir += '_' + args.warmup_model_exp_id 134 | 135 | if args.use_best_model: 136 | args.save_dir += '_best' 137 | 138 | print(f'Using features from experiment: {args.warmup_model_exp_id}') 139 | else: 140 | print(f'Using pretrained {args.model_name} features...') 141 | 142 | print(args.save_dir) 143 | 144 | # -------------------- 145 | # DATASETS 146 | # -------------------- 147 | print('Building datasets...') 148 | train_transform, test_transform = None, None 149 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name, 150 | train_transform, test_transform, args) 151 | 152 | # Set target transforms: 153 | target_transform_dict = {} 154 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)): 155 | target_transform_dict[cls] = i 156 | target_transform = lambda x: target_transform_dict[x] 157 | 158 | # Convert to feature vector dataset 159 | test_dataset = FeatureVectorDataset(base_dataset=test_dataset, feature_root=os.path.join(args.save_dir, 'test')) 160 | unlabelled_train_examples_test = FeatureVectorDataset(base_dataset=unlabelled_train_examples_test, 161 | feature_root=os.path.join(args.save_dir, 'train')) 162 | train_dataset = FeatureVectorDataset(base_dataset=train_dataset, feature_root=os.path.join(args.save_dir, 'train')) 163 | train_dataset.target_transform = target_transform 164 | 165 | unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, 166 | batch_size=args.batch_size, shuffle=False) 167 | test_loader = DataLoader(test_dataset, num_workers=args.num_workers, 168 | batch_size=args.batch_size, shuffle=False) 169 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, 170 | batch_size=args.batch_size, shuffle=False) 171 | 172 | print('Performing SS-K-Means on all of the training data...') 173 | all_acc, old_acc, new_acc, kmeans = kmeans_semi_sup(train_loader, args, K=args.K) 174 | cluster_save_path = os.path.join(args.save_dir, 'ss_kmeans_cluster_centres.pt')
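# Added note: kmeans.cluster_centers_ holds the K fitted centres in the
# normalized feature space; the file written below can be reloaded elsewhere
# with torch.load(cluster_save_path) to assign new features to clusters by
# nearest-centre lookup.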
175 | torch.save(kmeans.cluster_centers_, cluster_save_path) -------------------------------------------------------------------------------- /methods/partitioning/kmeans_subset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from easydict import EasyDict 4 | 5 | from torch.utils.data import DataLoader 6 | import numpy as np 7 | from sklearn.cluster import KMeans 8 | import torch 9 | from torch.optim import SGD, lr_scheduler 10 | from project_utils.cluster_utils import mixed_eval, AverageMeter 11 | from model import vision_transformer as vits 12 | 13 | from project_utils.general_utils import init_experiment, get_mean_lr, str2bool, get_dino_head_weights 14 | 15 | from data.augmentations import get_transform 16 | from data.get_datasets import get_datasets, get_class_splits 17 | from data.data_utils import MergedDataset 18 | 19 | from tqdm import tqdm 20 | 21 | from torch.nn import functional as F 22 | 23 | from project_utils.cluster_and_log_utils import log_accs_from_preds 24 | from config import exp_root, km_label_path, dino_pretrain_path, moco_pretrain_path, mae_pretrain_path 25 | 26 | from copy import deepcopy 27 | # TODO: Debug 28 | import warnings 29 | warnings.filterwarnings("ignore", category=DeprecationWarning) 30 | 31 | 32 | class ContrastiveLearningViewGenerator(object): 33 | """Take two random crops of one image as the query and key.""" 34 | 35 | def __init__(self, base_transform, n_views=2): 36 | self.base_transform = base_transform 37 | self.n_views = n_views 38 | 39 | def __call__(self, x): 40 | return [self.base_transform(x) for i in range(self.n_views)] 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | parser = argparse.ArgumentParser( 46 | description='cluster', 47 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 48 | parser.add_argument('--batch_size', default=256, type=int) 49 | parser.add_argument('--num_workers', default=4, type=int) 50 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v1', 'v2']) 51 | 52 | parser.add_argument('--warmup_model_dir', type=str, default=None) 53 | parser.add_argument('--model_name', type=str, default='vit_dino', help='Format is {model_name}_{pretrain}') 54 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, scars') 55 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 56 | parser.add_argument('--use_ssb_splits', type=str2bool, default=False) 57 | 58 | parser.add_argument('--grad_from_block', type=int, default=11) 59 | parser.add_argument('--lr', type=float, default=0.1) 60 | parser.add_argument('--save_best_thresh', type=float, default=None) 61 | parser.add_argument('--gamma', type=float, default=0.1) 62 | parser.add_argument('--momentum', type=float, default=0.9) 63 | parser.add_argument('--weight_decay', type=float, default=1e-4) 64 | parser.add_argument('--epochs', default=200, type=int) 65 | parser.add_argument('--exp_root', type=str, default=exp_root) 66 | parser.add_argument('--transform', type=str, default='imagenet') 67 | parser.add_argument('--seed', default=1, type=int) 68 | 69 | parser.add_argument('--base_model', type=str, default='vit_dino') 70 | parser.add_argument('--n_views', default=2, type=int) 71 | 72 | parser.add_argument('--experts_num', type=int, default=8) 73 | 74 | parser.add_argument('--pretrain_model', type=str, default='dino') 75 | 76 | args = parser.parse_args() 77 | device = torch.device('cuda:0') 78 | args = 
get_class_splits(args) 79 | 80 | args.num_labeled_classes = len(args.train_classes) 81 | args.num_unlabeled_classes = len(args.unlabeled_classes) 82 | 83 | init_experiment(args, runner_name=['metric_learn_gcd']) 84 | print(f'Using evaluation function {args.eval_funcs[0]} to print results') 85 | 86 | if args.base_model == 'vit_dino': 87 | 88 | args.interpolation = 3 89 | args.crop_pct = 0.875 90 | if args.pretrain_model == 'dino': 91 | pretrain_path = dino_pretrain_path 92 | 93 | model = vits.__dict__['vit_base']() 94 | 95 | if args.pretrain_model == 'dino': 96 | weight = torch.load(pretrain_path, map_location='cpu') 97 | 98 | msg = model.load_state_dict(weight, strict=False) 99 | print(msg) 100 | 101 | if args.warmup_model_dir is not None: 102 | print(f'Loading weights from {args.warmup_model_dir}') 103 | model.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu')) 104 | 105 | model.to(device) 106 | 107 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model 108 | args.image_size = 224 109 | args.feat_dim = 768 110 | args.num_mlp_layers = 3 111 | args.mlp_out_dim = 65536 112 | 113 | # ---------------------- 114 | # HOW MUCH OF BASE MODEL TO FINETUNE 115 | # ---------------------- 116 | for m in model.parameters(): 117 | m.requires_grad = False 118 | 119 | # Only finetune layers from block 'args.grad_from_block' onwards 120 | for name, m in model.named_parameters(): 121 | if 'block' in name: 122 | block_num = int(name.split('.')[1]) 123 | if block_num >= args.grad_from_block: 124 | m.requires_grad = True 125 | 126 | else: 127 | 128 | raise NotImplementedError 129 | 130 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) 131 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) 132 | 133 | # -------------------- 134 | # DATASETS 135 | # -------------------- 136 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name, 137 | train_transform, 138 | test_transform, 139 | args) 140 | whole_train_test_dataset = MergedDataset(deepcopy(labelled_train_examples),deepcopy(unlabelled_train_examples_test)) 141 | 142 | 143 | # -------------------- 144 | # SAMPLER 145 | # Sampler which balances labelled and unlabelled examples in each batch 146 | # -------------------- 147 | label_len = len(train_dataset.labelled_dataset) 148 | unlabelled_len = len(train_dataset.unlabelled_dataset) 149 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] 150 | sample_weights = torch.DoubleTensor(sample_weights) 151 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset)) 152 | 153 | # -------------------- 154 | # DATALOADERS 155 | # -------------------- 156 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, 157 | sampler=sampler, drop_last=True) 158 | test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, 159 | batch_size=args.batch_size, shuffle=False) 160 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers, 161 | batch_size=args.batch_size, shuffle=False) 162 | train_loader_labelled = DataLoader(labelled_train_examples, num_workers=args.num_workers, 163 | batch_size=args.batch_size, shuffle=False) 164 | whole_train_test_loader = DataLoader(whole_train_test_dataset, 
num_workers=args.num_workers, 165 | batch_size=args.batch_size, shuffle=False) 166 | 167 | model.eval() 168 | all_feats = [] 169 | targets = np.array([]) 170 | mask = np.array([]) 171 | for batch_idx, batch in enumerate(tqdm(whole_train_test_loader)): 172 | images, label, _, _ = batch 173 | images = images.to(device) 174 | feats = model(images) 175 | feats = torch.nn.functional.normalize(feats, dim=-1) 176 | all_feats.append(feats.cpu().detach().numpy()) 177 | targets = np.append(targets, label.cpu().detach().numpy()) 178 | mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) 179 | else False for x in label])) 180 | print('Kmeans...') 181 | 182 | all_feats = np.concatenate(all_feats) 183 | kmeans = KMeans(n_clusters=args.experts_num, random_state=0).fit(all_feats) 184 | preds = kmeans.labels_ 185 | print('Done!') 186 | print('feats shape:', all_feats.shape) 187 | 188 | np.save(f'{km_label_path}/{args.dataset_name}_km_labels.npy', preds) 189 |
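# Added note: `preds` holds one k-means cluster id per training image
# (args.experts_num clusters over the normalized backbone features); the .npy
# file saved above is presumably read back with
# np.load(f'{km_label_path}/{args.dataset_name}_km_labels.npy') by the
# training code that consumes this partition.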
-------------------------------------------------------------------------------- /data/augmentations/randaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | """ 4 | 5 | import random 6 | 7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | 13 | def ShearX(img, v): # [-0.3, 0.3] 14 | assert -0.3 <= v <= 0.3 15 | if random.random() > 0.5: 16 | v = -v 17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 18 | 19 | 20 | def ShearY(img, v): # [-0.3, 0.3] 21 | assert -0.3 <= v <= 0.3 22 | if random.random() > 0.5: 23 | v = -v 24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 25 | 26 | 27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 28 | assert -0.45 <= v <= 0.45 29 | if random.random() > 0.5: 30 | v = -v 31 | v = v * img.size[0] 32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 33 | 34 | 35 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 36 | assert 0 <= v 37 | if random.random() > 0.5: 38 | v = -v 39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 40 | 41 | 42 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 43 | assert -0.45 <= v <= 0.45 44 | if random.random() > 0.5: 45 | v = -v 46 | v = v * img.size[1] 47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 48 | 49 | 50 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 51 | assert 0 <= v 52 | if random.random() > 0.5: 53 | v = -v 54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 55 | 56 | 57 | def Rotate(img, v): # [-30, 30] 58 | assert -30 <= v <= 30 59 | if random.random() > 0.5: 60 | v = -v 61 | return img.rotate(v) 62 | 63 | 64 | def AutoContrast(img, _): 65 | return PIL.ImageOps.autocontrast(img) 66 | 67 | 68 | def Invert(img, _): 69 | return PIL.ImageOps.invert(img) 70 | 71 | 72 | def Equalize(img, _): 73 | return PIL.ImageOps.equalize(img) 74 | 75 | 76 | def Flip(img, _): # not from the paper 77 | return PIL.ImageOps.mirror(img) 78 | 79 | 80 | def Solarize(img, v): # [0, 256] 81 | assert 0 <= v <= 256 82 | return PIL.ImageOps.solarize(img, v) 83 | 84 | 85 | def SolarizeAdd(img, addition=0, threshold=128): 86 | img_np = np.array(img).astype(int)  # builtin int: np.int is deprecated/removed in newer NumPy 87 | img_np = img_np + addition 88 | img_np = np.clip(img_np, 0, 255) 89 | img_np = img_np.astype(np.uint8) 90 | img = Image.fromarray(img_np) 91 | return PIL.ImageOps.solarize(img, threshold) 92 | 93 | 94 | def Posterize(img, v): # [4, 8] 95 | v = int(v) 96 | v = max(1, v) 97 | return PIL.ImageOps.posterize(img, v) 98 | 99 | 100 | def Contrast(img, v): # [0.1,1.9] 101 | assert 0.1 <= v <= 1.9 102 | return PIL.ImageEnhance.Contrast(img).enhance(v) 103 | 104 | 105 | def Color(img, v): # [0.1,1.9] 106 | assert 0.1 <= v <= 1.9 107 | return PIL.ImageEnhance.Color(img).enhance(v) 108 | 109 | 110 | def Brightness(img, v): # [0.1,1.9] 111 | assert 0.1 <= v <= 1.9 112 | return PIL.ImageEnhance.Brightness(img).enhance(v) 113 | 114 | 115 | def Sharpness(img, v): # [0.1,1.9] 116 | assert 0.1 <= v <= 1.9 117 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 118 | 119 | 120 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 121 | assert 0.0 <= v <= 0.2 122 | if v <= 0.: 123 | return img 124 | 125 | v = v * img.size[0] 126 | return CutoutAbs(img, v) 127 | 128 | 129 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 130 | # assert 0 <= v <= 20 131 | if v < 0: 132 | return img 133 | w, h = img.size 134 | x0 = np.random.uniform(w) 135 | y0 = np.random.uniform(h) 136 | 137 | x0 = int(max(0, x0 - v / 2.)) 138 | y0 = int(max(0, y0 - v / 2.)) 139 | x1 = min(w, x0 + v) 140 | y1 = min(h, y0 + v) 141 | 142 | xy = (x0, y0, x1, y1) 143 | color = (125, 123, 114) 144 | # color = (0, 0, 0) 145 | img = img.copy() 146 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 147 | return img 148 | 149 | 150 | def SamplePairing(imgs): # [0, 0.4] 151 | def f(img1, v): 152 | i = np.random.choice(len(imgs)) 153 | img2 = PIL.Image.fromarray(imgs[i]) 154 | return PIL.Image.blend(img1, img2, v) 155 | 156 | return f 157 | 158 | 159 | def Identity(img, v): 160 | return img 161 | 162 | 163 | def augment_list(): # 16 operations and their ranges 164 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 165 | # l = [ 166 | # (Identity, 0., 1.0), 167 | # (ShearX, 0., 0.3), # 0 168 | # (ShearY, 0., 0.3), # 1 169 | # (TranslateX, 0., 0.33), # 2 170 | # (TranslateY, 0., 0.33), # 3 171 | # (Rotate, 0, 30), # 4 172 | # (AutoContrast, 0, 1), # 5 173 | # (Invert, 0, 1), # 6 174 | # (Equalize, 0, 1), # 7 175 | # (Solarize, 0, 110), # 8 176 | # (Posterize, 4, 8), # 9 177 | # # (Contrast, 0.1, 1.9), # 10 178 | # (Color, 0.1, 1.9), # 11 179 | # (Brightness, 0.1, 1.9), # 12 180 | # (Sharpness, 0.1, 1.9), # 13 181 | # # (Cutout, 0, 0.2), # 14 182 | # # (SamplePairing(imgs), 0, 0.4), # 15 183 | # ] 184 | 185 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 186 | l = [ 187 | (AutoContrast, 0, 1), 188 | (Equalize, 0, 1), 189 | (Invert, 0, 1), 190 | (Rotate, 0, 30), 191 | (Posterize, 0, 4), 192 | (Solarize, 0, 256), 193 | (SolarizeAdd, 0, 110), 194 | (Color, 0.1, 1.9), 195 | (Contrast, 0.1, 1.9), 196 | (Brightness, 0.1, 1.9), 197 | (Sharpness, 0.1, 1.9), 198 | (ShearX, 0., 0.3), 199 | (ShearY, 0., 0.3), 200 | (CutoutAbs, 0, 40), 201 | (TranslateXabs, 0., 100), 202 | (TranslateYabs, 0., 100), 203 | ] 204 | 205 | return l 206 | 207 | def augment_list_svhn(): # 13 operations and their ranges 208 | 209 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 210 | l = [ 211 | (AutoContrast, 0, 1), 212 | (Equalize, 0, 1), 213 | (Invert, 0, 1), 214 | (Posterize, 0, 4), 215 | (Solarize, 0, 256), 216 | (SolarizeAdd, 0, 110), 217 | (Color, 0.1, 1.9), 218 | (Contrast, 0.1, 1.9), 219 | (Brightness, 0.1, 1.9), 220 | (Sharpness, 0.1, 1.9), 221 | (ShearX, 0., 0.3), 222 | (ShearY, 0., 0.3), 223 | (CutoutAbs, 0, 40), 224 | ] 225 | 226 | return l 227 | 228 | 229 | class Lighting(object): 230 | """Lighting noise (AlexNet-style PCA-based noise)""" 231 | 232 | def __init__(self, alphastd, eigval, eigvec): 233 | self.alphastd = alphastd 234 | self.eigval = torch.Tensor(eigval) 235 | self.eigvec = torch.Tensor(eigvec) 236 | 237 | def __call__(self, img): 238 | if self.alphastd == 0: 239 | return img 240 | 241 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 242 | rgb = self.eigvec.type_as(img).clone() \ 243 | .mul(alpha.view(1, 3).expand(3, 3)) \ 244 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 245 | .sum(1).squeeze() 246 | 247 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 248 | 249 | 250 | class CutoutDefault(object): 251 | """ 252 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 253 | """ 254 | def __init__(self, length): 255 | self.length = length 256 | 257 | def __call__(self, img): 258 | h, w = img.size(1), img.size(2) 259 | mask = np.ones((h, w), float) 260 | y = np.random.randint(h) 261 | x = np.random.randint(w) 262 | 263 | y1 = np.clip(y - self.length // 2, 0, h) 264 | y2 = np.clip(y + self.length // 2, 0, h) 265 | x1 = np.clip(x - self.length // 2, 0, w) 266 | x2 = np.clip(x + self.length // 2, 0, w) 267 | 268 | mask[y1: y2, x1: x2] = 0. 269 | mask = torch.from_numpy(mask) 270 | mask = mask.expand_as(img) 271 | img *= mask 272 | return img 273 | 274 | 275 | class RandAugment: 276 | def __init__(self, n, m, args=None): 277 | self.n = n # [1, 2] 278 | self.m = m # [0...30] 279 | 280 | if args is None: 281 | self.augment_list = augment_list() 282 | 283 | elif args.dataset == 'svhn' or args.dataset == 'mnist': 284 | self.augment_list = augment_list_svhn() 285 | 286 | else: 287 | self.augment_list = augment_list() 288 | 289 | def __call__(self, img): 290 | ops = random.choices(self.augment_list, k=self.n) 291 | for op, minval, maxval in ops: 292 | val = (float(self.m) / 30) * float(maxval - minval) + minval 293 | img = op(img, val) 294 | 295 | return img
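# Added note: RandAugment.__call__ above maps the integer magnitude self.m
# (nominally in [0, 30]) linearly onto each op's (minval, maxval) range:
#     val = (m / 30) * (maxval - minval) + minval
# e.g. m=9 gives Rotate -> 9 degrees (range (0, 30)) and ShearX -> 0.09
# (range (0., 0.3)).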
(Color, 0.1, 1.9), 218 | (Contrast, 0.1, 1.9), 219 | (Brightness, 0.1, 1.9), 220 | (Sharpness, 0.1, 1.9), 221 | (ShearX, 0., 0.3), 222 | (ShearY, 0., 0.3), 223 | (CutoutAbs, 0, 40), 224 | ] 225 | 226 | return l 227 | 228 | 229 | class Lighting(object): 230 | """Lighting noise (AlexNet-style PCA-based noise)""" 231 | 232 | def __init__(self, alphastd, eigval, eigvec): 233 | self.alphastd = alphastd 234 | self.eigval = torch.Tensor(eigval) 235 | self.eigvec = torch.Tensor(eigvec) 236 | 237 | def __call__(self, img): 238 | if self.alphastd == 0: 239 | return img 240 | 241 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 242 | rgb = self.eigvec.type_as(img).clone() \ 243 | .mul(alpha.view(1, 3).expand(3, 3)) \ 244 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 245 | .sum(1).squeeze() 246 | 247 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 248 | 249 | 250 | class CutoutDefault(object): 251 | """ 252 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 253 | """ 254 | def __init__(self, length): 255 | self.length = length 256 | 257 | def __call__(self, img): 258 | h, w = img.size(1), img.size(2) 259 | mask = np.ones((h, w), float) 260 | y = np.random.randint(h) 261 | x = np.random.randint(w) 262 | 263 | y1 = np.clip(y - self.length // 2, 0, h) 264 | y2 = np.clip(y + self.length // 2, 0, h) 265 | x1 = np.clip(x - self.length // 2, 0, w) 266 | x2 = np.clip(x + self.length // 2, 0, w) 267 | 268 | mask[y1: y2, x1: x2] = 0. 269 | mask = torch.from_numpy(mask) 270 | mask = mask.expand_as(img) 271 | img *= mask 272 | return img 273 | 274 | 275 | class RandAugment: 276 | def __init__(self, n, m, args=None): 277 | self.n = n # [1, 2] 278 | self.m = m # [0...30] 279 | 280 | if args is None: 281 | self.augment_list = augment_list() 282 | 283 | elif args.dataset == 'svhn' or args.dataset == 'mnist': 284 | self.augment_list = augment_list_svhn() 285 | 286 | else: 287 | self.augment_list = augment_list() 288 | 289 | def __call__(self, img): 290 | ops = random.choices(self.augment_list, k=self.n) 291 | for op, minval, maxval in ops: 292 | val = (float(self.m) / 30) * float(maxval - minval) + minval 293 | img = op(img, val) 294 | 295 | return img -------------------------------------------------------------------------------- /data/food.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import pathlib 4 | import json 5 | from typing import Any, Callable, Optional, Union, Tuple 6 | from typing import Sequence 7 | 8 | import numpy as np 9 | from copy import deepcopy 10 | from scipy import io as mat_io 11 | 12 | from data.data_utils import subsample_instances 13 | from config import food_root 14 | 15 | from torchvision.datasets.folder import default_loader 16 | from torch.utils.data import Dataset 17 | 18 | from torchvision.datasets.utils import download_and_extract_archive  # used in _download() 19 | 20 | 21 | def make_dataset(dir, image_ids, targets): 22 | assert(len(image_ids) == len(targets)) 23 | images = [] 24 | dir = os.path.expanduser(dir) 25 | for i in range(len(image_ids)): 26 | item = (os.path.join(dir, 'images', 27 | '%s.jpg' % image_ids[i]), targets[i]) 28 | images.append(item) 29 | return images 30 | 31 | class Food101(Dataset): 32 | """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_. 33 | 34 | The Food-101 is a challenging data set of 101 food categories, with 101'000 images. 35 | For each class, 250 manually reviewed test images are provided as well as 750 training images.
36 | On purpose, the training images were not cleaned, and thus still contain some amount of noise. 37 | This comes mostly in the form of intense colors and sometimes wrong labels. All images were 38 | rescaled to have a maximum side length of 512 pixels. 39 | 40 | 41 | Args: 42 | root (string): Root directory of the dataset. 43 | split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. 44 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed 45 | version. E.g., ``transforms.RandomCrop``. 46 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 47 | download (bool, optional): If True, downloads the dataset from the internet and 48 | puts it in root directory. If dataset is already downloaded, it is not 49 | downloaded again. Default is False. 50 | """ 51 | 52 | _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz" 53 | _MD5 = "85eeb15f3717b99a5da872d97d918f87" 54 | splits = ('train', 'test') 55 | 56 | def __init__( 57 | self, 58 | root: str, 59 | split: str = "train", 60 | transform: Optional[Callable] = None, 61 | target_transform: Optional[Callable] = None, 62 | download: bool = False, 63 | loader=default_loader 64 | ) -> None: 65 | if split not in self.splits: 66 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 67 | split, ', '.join(self.splits), 68 | )) 69 | self.split = split 70 | self.root = root 71 | self.transform = transform 72 | self.target_transform = target_transform 73 | self.loader = loader 74 | 75 | self._base_folder = pathlib.Path(self.root) 76 | self._meta_folder = self._base_folder / "meta" 77 | self._images_folder = self._base_folder / "images" 78 | 79 | if download: 80 | self._download() 81 | 82 | if not self._check_exists(): 83 | raise RuntimeError("Dataset not found.
You can use download=True to download it") 84 | 85 | self._labels = [] 86 | self._image_files = [] 87 | self.image_ids = [] 88 | with open(self._meta_folder / f"{split}.json") as f: 89 | metadata = json.loads(f.read()) 90 | 91 | self.classes = sorted(metadata.keys()) 92 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 93 | 94 | for class_label, im_rel_paths in metadata.items(): 95 | self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths) 96 | self.image_ids += [im_rel_path for im_rel_path in im_rel_paths] 97 | self._image_files += [ 98 | self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths 99 | ] 100 | 101 | samples = make_dataset(self.root, self.image_ids, self._labels) 102 | self.samples = samples 103 | 104 | self.uq_idxs = np.array(range(len(self))) 105 | 106 | def __len__(self) -> int: 107 | return len(self.samples) 108 | 109 | def __getitem__(self, idx) -> Tuple[Any, Any]: 110 | 111 | path, target = self.samples[idx] 112 | sample = self.loader(path) 113 | if self.transform is not None: 114 | sample = self.transform(sample) 115 | if self.target_transform is not None: 116 | target = self.target_transform(target) 117 | 118 | return sample, target, self.uq_idxs[idx] 119 | 120 | def extra_repr(self) -> str: 121 | return f"split={self.split}" 122 | 123 | def _check_exists(self) -> bool: 124 | return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder)) 125 | 126 | def _download(self) -> None: 127 | if self._check_exists(): 128 | return 129 | download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) 130 | 131 | def subsample_dataset(dataset, idxs): 132 | 133 | mask = np.zeros(len(dataset)).astype('bool') 134 | mask[idxs] = True 135 | 136 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs] 137 | dataset.uq_idxs = dataset.uq_idxs[mask] 138 | 139 | return dataset 140 | 141 | def subsample_classes(dataset, include_classes=range(60)): 142 | 143 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] 144 | 145 | # TODO: Don't transform targets for now 146 | target_xform_dict = {} 147 | for i, k in enumerate(include_classes): 148 | target_xform_dict[k] = i 149 | 150 | dataset = subsample_dataset(dataset, cls_idxs) 151 | 152 | dataset.target_transform = lambda x: target_xform_dict[x] 153 | 154 | return dataset 155 | 156 | def get_train_val_indices(train_dataset, val_split=0.2): 157 | 158 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)] 159 | train_classes = np.unique(all_targets) 160 | 161 | # Get train/test indices 162 | train_idxs = [] 163 | val_idxs = [] 164 | for cls in train_classes: 165 | cls_idxs = np.where(all_targets == cls)[0] 166 | 167 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 168 | t_ = [x for x in cls_idxs if x not in v_] 169 | 170 | train_idxs.extend(t_) 171 | val_idxs.extend(v_) 172 | 173 | return train_idxs, val_idxs 174 | 175 | def get_food_datasets(train_transform, test_transform, train_classes=range(51), prop_train_labels=0.8, 176 | split_train_val=False, seed=0): 177 | 178 | np.random.seed(seed) 179 | 180 | # Init entire training set 181 | whole_training_set = Food101(root=food_root, transform=train_transform, split='train', download=False) # len = 75750 182 | 183 | # Get labelled training set which has subsampled classes, then subsample some indices from that 184 | train_dataset_labelled =
subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 185 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 186 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 187 | 188 | # Split into training and validation sets 189 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 190 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 191 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 192 | val_dataset_labelled_split.transform = test_transform 193 | 194 | # Get unlabelled data 195 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 196 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 197 | 198 | # Get test set for all classes 199 | test_dataset = Food101(root=food_root, transform=test_transform, split='test', download=False) # len = 25250 200 | 201 | # Either split train into train and val or use test set as val 202 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 203 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 204 | 205 | all_datasets = { 206 | 'train_labelled': train_dataset_labelled, 207 | 'train_unlabelled': train_dataset_unlabelled, 208 | 'val': val_dataset_labelled, 209 | 'test': test_dataset, 210 | } 211 | 212 | return all_datasets 213 | 214 | def main(): 215 | x = get_food_datasets(None, None, split_train_val=False) 216 | 217 | print('Printing lens...') 218 | for k, v in x.items(): 219 | if v is not None: 220 | print(f'{k}: {len(v)}') 221 | 222 | print('Printing labelled and unlabelled overlap...') 223 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 224 | print('Printing total instances in train...') 225 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 226 | print('Printing number of labelled classes...') 227 | print(len(set([i[1] for i in x['train_labelled'].samples]))) 228 | print('Printing total number of classes...') 229 | print(len(set([i[1] for i in x['train_unlabelled'].samples]))) 230 | 231 | 232 | if __name__ == '__main__': 233 | main() -------------------------------------------------------------------------------- /methods/clustering/extract_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import timm 4 | from torchvision import transforms 5 | import torchvision 6 | 7 | import argparse 8 | import os 9 | from tqdm import tqdm 10 | 11 | from data.stanford_cars import CarsDataset 12 | from data.cifar import CustomCIFAR10, CustomCIFAR100, cifar_10_root, cifar_100_root 13 | from data.augmentations import get_transform 14 | from data.imagenet import get_imagenet_100_datasets 15 | from data.data_utils import MergedDataset 16 | from data.cub import CustomCub2011, cub_root 17 | from data.fgvc_aircraft import FGVCAircraft, aircraft_root 18 | from data.pets import OxfordIIITPet, pets_root 19 | 20 | from project_utils.general_utils import strip_state_dict, str2bool 21 | from copy import deepcopy 22 | 23 | from config import feature_extract_dir, dino_pretrain_path 24 | 25 | def extract_features_dino(model, loader, save_dir): 26 | 27 | model.to(device) 28 | model.eval() 29 | 30 
| with torch.no_grad(): 31 | for batch_idx, batch in enumerate(tqdm(loader)): 32 | 33 | images, labels, idxs = batch[:3] 34 | images = images.to(device) 35 | 36 | features = model(images) # CLS_Token for ViT, Average pooled vector for R50 37 | 38 | # Save features 39 | for f, t, uq in zip(features, labels, idxs): 40 | 41 | t = t.item() 42 | uq = uq.item() 43 | 44 | save_path = os.path.join(save_dir, f'{t}', f'{uq}.npy') 45 | torch.save(f.detach().cpu().numpy(), save_path) 46 | 47 | 48 | def extract_features_timm(model, loader, save_dir): 49 | 50 | model.to(device) 51 | model.eval() 52 | 53 | with torch.no_grad(): 54 | for batch_idx, batch in enumerate(tqdm(loader)): 55 | 56 | images, labels, idxs = batch[:3] 57 | images = images.to(device) 58 | 59 | features = model.forward_features(images) # CLS_Token for ViT, Average pooled vector for R50 60 | 61 | # Save features 62 | for f, t, uq in zip(features, labels, idxs): 63 | 64 | t = t.item() 65 | uq = uq.item() 66 | 67 | save_path = os.path.join(save_dir, f'{t}', f'{uq}.npy') 68 | torch.save(f.detach().cpu().numpy(), save_path) 69 | 70 | 71 | if __name__ == "__main__": 72 | 73 | parser = argparse.ArgumentParser( 74 | description='cluster', 75 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 76 | parser.add_argument('--batch_size', default=128, type=int) 77 | parser.add_argument('--num_workers', default=8, type=int) 78 | parser.add_argument('--root_dir', type=str, default=feature_extract_dir) 79 | parser.add_argument('--warmup_model_dir', type=str, 80 | default=None) 81 | parser.add_argument('--use_best_model', type=str2bool, default=True) 82 | parser.add_argument('--model_name', type=str, default='vit_dino', help='Format is {model_name}_{pretrain}') 83 | parser.add_argument('--dataset', type=str, default='aircraft', help='options: cifar10, cifar100, scars, imagenet_100, cub, aircraft, pets') 84 | parser.add_argument('--pretrain_model', type=str, default='dino') 85 | 86 | # ---------------------- 87 | # INIT 88 | # ---------------------- 89 | args = parser.parse_args() 90 | device = torch.device('cuda:0') 91 | 92 | args.save_dir = os.path.join(args.root_dir, f'{args.model_name}_{args.dataset}') 93 | print(args) 94 | 95 | print('Loading model...') 96 | # ---------------------- 97 | # MODEL 98 | # ---------------------- 99 | if args.model_name == 'vit_dino': 100 | 101 | extract_features_func = extract_features_dino 102 | args.interpolation = 3 103 | args.crop_pct = 0.875 104 | if args.pretrain_model == 'dino': 105 | pretrain_path = dino_pretrain_path 106 | 107 | model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16', pretrained=False) 108 | 109 | state_dict = torch.load(pretrain_path, map_location='cpu') 110 | model.load_state_dict(state_dict) 111 | 112 | _, val_transform = get_transform('imagenet', image_size=224, args=args) 113 | 114 | elif args.model_name == 'resnet50_dino': 115 | 116 | extract_features_func = extract_features_dino 117 | args.interpolation = 3 118 | args.crop_pct = 0.875 119 | pretrain_path = '/work/sagar/pretrained_models/dino/dino_resnet50_pretrain.pth' 120 | 121 | model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50', pretrained=False) 122 | 123 | state_dict = torch.load(pretrain_path, map_location='cpu') 124 | model.load_state_dict(state_dict) 125 | 126 | _, val_transform = get_transform('imagenet', image_size=224, args=args) 127 | 128 | else: 129 | 130 | raise NotImplementedError 131 | 132 | if args.warmup_model_dir is not None: 133 | 134 | warmup_id = args.warmup_model_dir.split('(')[1].split(')')[0] 135 | 136 | if
args.use_best_model: 137 | args.warmup_model_dir = args.warmup_model_dir[:-3] + '_best.pt' 138 | args.save_dir += '_(' + args.warmup_model_dir.split('(')[1].split(')')[0] + ')_best' 139 | else: 140 | args.save_dir += '_(' + args.warmup_model_dir.split('(')[1].split(')')[0] + ')' 141 | 142 | print(f'Using weights from {args.warmup_model_dir} ...') 143 | state_dict = torch.load(args.warmup_model_dir) 144 | model.load_state_dict(state_dict) 145 | 146 | print(f'Saving to {args.save_dir}') 147 | 148 | print('Loading data...') 149 | # ---------------------- 150 | # DATASET 151 | # ---------------------- 152 | if args.dataset == 'cifar10': 153 | 154 | train_dataset = CustomCIFAR10(root=cifar_10_root, train=True, transform=val_transform) 155 | test_dataset = CustomCIFAR10(root=cifar_10_root, train=False, transform=val_transform) 156 | targets = list(set(train_dataset.targets)) 157 | 158 | elif args.dataset == 'cifar100': 159 | 160 | train_dataset = CustomCIFAR100(root=cifar_100_root, train=True, transform=val_transform) 161 | test_dataset = CustomCIFAR100(root=cifar_100_root, train=False, transform=val_transform) 162 | targets = list(set(train_dataset.targets)) 163 | 164 | elif args.dataset == 'scars': 165 | 166 | train_dataset = CarsDataset(train=True, transform=val_transform) 167 | test_dataset = CarsDataset(train=False, transform=val_transform) 168 | targets = list(set(train_dataset.target)) 169 | targets = [i - 1 for i in targets] # SCars are labelled 1 - 197. Change to 0 - 196 170 | 171 | elif args.dataset == 'imagenet_100': 172 | 173 | datasets = get_imagenet_100_datasets(train_transform=val_transform, test_transform=val_transform, 174 | train_classes=range(50), 175 | prop_train_labels=0.5) 176 | 177 | datasets['train_labelled'].target_transform = None 178 | datasets['train_unlabelled'].target_transform = None 179 | 180 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']), 181 | unlabelled_dataset=deepcopy(datasets['train_unlabelled'])) 182 | 183 | test_dataset = datasets['test'] 184 | targets = list(set(test_dataset.targets)) 185 | 186 | elif args.dataset == 'cub': 187 | 188 | train_dataset = CustomCub2011(root=cub_root, transform=val_transform, train=True) 189 | test_dataset = CustomCub2011(root=cub_root, transform=val_transform, train=False) 190 | targets = list(set(train_dataset.data.target.values)) 191 | targets = [i - 1 for i in targets] # CUB classes are labelled 1 - 200.
Change to 0 - 199 192 | 193 | elif args.dataset == 'aircraft': 194 | 195 | train_dataset = FGVCAircraft(root=aircraft_root, transform=val_transform, split='trainval') 196 | test_dataset = FGVCAircraft(root=aircraft_root, transform=val_transform, split='test') 197 | targets = list(set([s[1] for s in train_dataset.samples])) 198 | 199 | elif args.dataset == 'pets': 200 | 201 | train_dataset = OxfordIIITPet(root=pets_root, transform=val_transform, split='trainval') 202 | test_dataset = OxfordIIITPet(root=pets_root, transform=val_transform, split='test') 203 | targets = list(set([s[1] for s in train_dataset.samples])) 204 | 205 | else: 206 | 207 | raise NotImplementedError 208 | 209 | # ---------------------- 210 | # DATALOADER 211 | # ---------------------- 212 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 213 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 214 | 215 | print('Creating base directories...') 216 | # ---------------------- 217 | # INIT SAVE DIRS 218 | # Create a directory for each class 219 | # ---------------------- 220 | if not os.path.exists(args.save_dir): 221 | os.makedirs(args.save_dir) 222 | 223 | for fold in ('train', 'test'): 224 | 225 | fold_dir = os.path.join(args.save_dir, fold) 226 | if not os.path.exists(fold_dir): 227 | os.mkdir(fold_dir) 228 | 229 | for t in targets: 230 | target_dir = os.path.join(fold_dir, f'{t}') 231 | if not os.path.exists(target_dir): 232 | os.mkdir(target_dir) 233 | 234 | # ---------------------- 235 | # EXTRACT FEATURES 236 | # ---------------------- 237 | # Extract train features 238 | train_save_dir = os.path.join(args.save_dir, 'train') 239 | print('Extracting features from train split...') 240 | extract_features_func(model=model, loader=train_loader, save_dir=train_save_dir) 241 | 242 | # Extract test features 243 | test_save_dir = os.path.join(args.save_dir, 'test') 244 | print('Extracting features from test split...') 245 | extract_features_func(model=model, loader=test_loader, save_dir=test_save_dir) 246 | 247 | print('Done!') -------------------------------------------------------------------------------- /data/pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import pathlib 4 | from typing import Any, Callable, Optional, Union, Tuple 5 | from typing import Sequence 6 | 7 | import numpy as np 8 | from copy import deepcopy 9 | from scipy import io as mat_io 10 | 11 | from data.data_utils import subsample_instances 12 | from config import pets_root 13 | 14 | from torchvision.datasets.folder import default_loader 15 | from torch.utils.data import Dataset 16 | 17 | from torchvision.datasets.utils import download_and_extract_archive  # used in _download() 18 | 19 | 20 | def make_dataset(dir, image_ids, targets): 21 | assert(len(image_ids) == len(targets)) 22 | images = [] 23 | dir = os.path.expanduser(dir) 24 | for i in range(len(image_ids)): 25 | item = (os.path.join(dir, 'images', 26 | '%s.jpg' % image_ids[i]), targets[i]) 27 | images.append(item) 28 | return images 29 | 30 | class OxfordIIITPet(Dataset): 31 | """`Oxford-IIIT Pet Dataset <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_. 32 | 33 | Args: 34 | root (string): Root directory of the dataset. 35 | split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``. 36 | target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or 37 | ``segmentation``.
Can also be a list to output a tuple with all specified target types. The types represent: 38 | 39 | - ``category`` (int): Label for one of the 37 pet categories. 40 | - ``segmentation`` (PIL image): Segmentation trimap of the image. 41 | 42 | If empty, ``None`` will be returned as target. 43 | 44 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed 45 | version. E.g., ``transforms.RandomCrop``. 46 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 47 | download (bool, optional): If True, downloads the dataset from the internet and puts it into 48 | ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again. 49 | """ 50 | 51 | _RESOURCES = ( 52 | ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"), 53 | ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"), 54 | ) 55 | valid_target_types = ("category", "segmentation") 56 | splits = ('train', 'val', 'trainval', 'test') 57 | 58 | def __init__( 59 | self, 60 | root: str, 61 | split: str = "trainval", 62 | target_types: Union[Sequence[str], str] = "category", 63 | transforms: Optional[Callable] = None,  # not used by this class 64 | transform: Optional[Callable] = None, 65 | target_transform: Optional[Callable] = None, 66 | download: bool = False, 67 | loader=default_loader 68 | ): 69 | if split not in self.splits: 70 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 71 | split, ', '.join(self.splits), 72 | )) 73 | self.split = split 74 | 75 | if isinstance(target_types, str): 76 | target_types = [target_types] 77 | 78 | self.root = root 79 | self.transform = transform 80 | self.target_transform = target_transform 81 | self.loader = loader 82 | 83 | self._base_folder = pathlib.Path(self.root) 84 | self._images_folder = self._base_folder / "images" 85 | self._anns_folder = self._base_folder / "annotations" 86 | self._segs_folder = self._anns_folder / "trimaps" 87 | 88 | if download: 89 | self._download() 90 | 91 | if not self._check_exists(): 92 | raise RuntimeError("Dataset not found. You can use download=True to download it") 93 | 94 | image_ids = [] 95 | self._labels = [] 96 | with open(self._anns_folder / f"{self.split}.txt") as file: 97 | for line in file: 98 | image_id, label, *_ = line.strip().split() 99 | image_ids.append(image_id) 100 | self._labels.append(int(label) - 1) 101 | 102 | self.classes = [ 103 | " ".join(part.title() for part in raw_cls.split("_")) 104 | for raw_cls, _ in sorted( 105 | {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)}, 106 | key=lambda image_id_and_label: image_id_and_label[1], 107 | ) 108 | ] 109 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 110 | 111 | samples = make_dataset(self.root, image_ids, self._labels) 112 | self.samples = samples 113 | self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids] 114 | self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids] 115 | 116 | self.uq_idxs = np.array(range(len(self))) 117 | 118 | def __len__(self) -> int: 119 | return len(self.samples) 120 | 121 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 122 | """ 123 | Args: 124 | index (int): Index 125 | 126 | Returns: 127 | tuple: (sample, target) where target is class_index of the target class.
128 | """ 129 | 130 | path, target = self.samples[idx] 131 | sample = self.loader(path) 132 | if self.transform is not None: 133 | sample = self.transform(sample) 134 | if self.target_transform is not None: 135 | target = self.target_transform(target) 136 | 137 | return sample, target, self.uq_idxs[idx] 138 | 139 | def _check_exists(self) -> bool: 140 | for folder in (self._images_folder, self._anns_folder): 141 | if not (os.path.exists(folder) and os.path.isdir(folder)): 142 | return False 143 | else: 144 | return True 145 | 146 | def _download(self) -> None: 147 | if self._check_exists(): 148 | return 149 | 150 | for url, md5 in self._RESOURCES: 151 | download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5) 152 | 153 | def subsample_dataset(dataset, idxs): 154 | 155 | mask = np.zeros(len(dataset)).astype('bool') 156 | mask[idxs] = True 157 | 158 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs] 159 | dataset.uq_idxs = dataset.uq_idxs[mask] 160 | 161 | return dataset 162 | 163 | def subsample_classes(dataset, include_classes=range(60)): 164 | 165 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] # 1885 166 | 167 | # TODO: Don't transform targets for now 168 | target_xform_dict = {} 169 | for i, k in enumerate(include_classes): 170 | target_xform_dict[k] = i 171 | 172 | dataset = subsample_dataset(dataset, cls_idxs) 173 | 174 | dataset.target_transform = lambda x: target_xform_dict[x] 175 | 176 | return dataset 177 | 178 | def get_train_val_indices(train_dataset, val_split=0.2): 179 | 180 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)] 181 | train_classes = np.unique(all_targets) 182 | 183 | # Get train/test indices 184 | train_idxs = [] 185 | val_idxs = [] 186 | for cls in train_classes: 187 | cls_idxs = np.where(all_targets == cls)[0] 188 | 189 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 190 | t_ = [x for x in cls_idxs if x not in v_] 191 | 192 | train_idxs.extend(t_) 193 | val_idxs.extend(v_) 194 | 195 | return train_idxs, val_idxs 196 | 197 | def get_pets_datasets(train_transform, test_transform, train_classes=range(19), prop_train_labels=0.8, 198 | split_train_val=False, seed=0): 199 | 200 | np.random.seed(seed) 201 | 202 | # Init entire training set 203 | whole_training_set = OxfordIIITPet(root=pets_root, transform=train_transform, split='trainval', download=False) # len = 3680 204 | 205 | # Get labelled training set which has subsampled classes, then subsample some indices from that 206 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 207 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 208 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 209 | 210 | # Split into training and validation sets 211 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 212 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 213 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 214 | val_dataset_labelled_split.transform = test_transform 215 | 216 | # Get unlabelled data 217 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 218 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 219 | 
220 | # Get test set for all classes 221 | test_dataset = OxfordIIITPet(root=pets_root, transform=test_transform, split='test', download=False) 222 | 223 | # Either split train into train and val or use test set as val 224 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 225 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 226 | 227 | all_datasets = { 228 | 'train_labelled': train_dataset_labelled, 229 | 'train_unlabelled': train_dataset_unlabelled, 230 | 'val': val_dataset_labelled, 231 | 'test': test_dataset, 232 | } 233 | 234 | return all_datasets 235 | 236 | def main(): 237 | x = get_pets_datasets(None, None, split_train_val=False) 238 | 239 | print('Printing lens...') 240 | for k, v in x.items(): 241 | if v is not None: 242 | print(f'{k}: {len(v)}') 243 | 244 | print('Printing labelled and unlabelled overlap...') 245 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 246 | print('Printing total instances in train...') 247 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 248 | print('Printing number of labelled classes...') 249 | print(len(set([i[1] for i in x['train_labelled'].samples]))) 250 | print('Printing total number of classes...') 251 | print(len(set([i[1] for i in x['train_unlabelled'].samples]))) 252 | 253 | 254 | if __name__ == '__main__': 255 | main() -------------------------------------------------------------------------------- /data/flower.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from path import Path 4 | import pathlib 5 | from typing import Any, Callable, Optional, Union, Tuple 6 | 7 | import numpy as np 8 | from copy import deepcopy 9 | from scipy import io as mat_io 10 | from scipy.io import loadmat 11 | 12 | from data.data_utils import subsample_instances 13 | from config import flower_root 14 | 15 | from torchvision.datasets.utils import check_integrity 16 | from torchvision.datasets.folder import default_loader 17 | from torch.utils.data import Dataset 18 | 19 | from torchvision.datasets.utils import download_and_extract_archive, download_url  # used in download() 20 | 21 | 22 | def make_dataset(dir, image_ids, targets): 23 | assert(len(image_ids) == len(targets)) 24 | images = [] 25 | dir = os.path.expanduser(dir) 26 | for i in range(len(image_ids)): 27 | item = (os.path.join(dir, 'jpg', 28 | '%s' % image_ids[i]), targets[i]) 29 | images.append(item) 30 | return images 31 | 32 | class Flowers102(Dataset): 33 | """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset. 34 | 35 | .. warning:: 36 | 37 | This class needs `scipy <https://scipy.org/>`_ to load target files from `.mat` format. 38 | 39 | Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The 40 | flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of 41 | between 40 and 258 images. 42 | 43 | The images have large scale, pose and light variations. In addition, there are categories that 44 | have large variations within the category, and several very similar categories. 45 | 46 | Args: 47 | root (string): Root directory of the dataset. 48 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. 49 | transform (callable, optional): A function/transform that takes in a PIL image and returns a 50 | transformed version. E.g., ``transforms.RandomCrop``.
51 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 52 | download (bool, optional): If true, downloads the dataset from the internet and 53 | puts it in root directory. If dataset is already downloaded, it is not 54 | downloaded again. 55 | """ 56 | 57 | _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/" 58 | _file_dict = { # filename, md5 59 | "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"), 60 | "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"), 61 | "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"), 62 | } 63 | _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"} 64 | 65 | def __init__( 66 | self, 67 | root: str, 68 | split: str = "train", 69 | transform: Optional[Callable] = None, 70 | target_transform: Optional[Callable] = None, 71 | download: bool = False, 72 | loader=default_loader 73 | ) -> None: 74 | self.splits = ('train', 'val', 'trainval', 'test') 75 | 76 | if split not in self.splits: 77 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 78 | split, ', '.join(self.splits), 79 | )) 80 | self.split = split 81 | self.root = root 82 | 83 | self.transform = transform 84 | self.target_transform = target_transform 85 | self.loader = loader 86 | 87 | self._base_folder = pathlib.Path(self.root) 88 | self._images_folder = self._base_folder / "jpg" 89 | 90 | if download: 91 | self.download() 92 | 93 | if not self._check_integrity(): 94 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") 95 | 96 | if self.split == 'trainval': 97 | set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True) 98 | image_ids_train = set_ids[self._splits_map['train']].tolist() 99 | image_ids_val = set_ids[self._splits_map['val']].tolist() 100 | image_ids = image_ids_train + image_ids_val 101 | 102 | else: 103 | set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True) 104 | image_ids = set_ids[self._splits_map[self.split]].tolist() 105 | 106 | labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True) 107 | image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1)) 108 | 109 | self._labels = [] 110 | self._image_files = [] 111 | for image_id in image_ids: 112 | self._labels.append(image_id_to_label[image_id]) 113 | self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg") 114 | 115 | image_name = [f"image_{image_id:05d}.jpg" for image_id in image_ids] 116 | 117 | samples = make_dataset(self.root, image_name, self._labels) # 2040 118 | self.samples = samples 119 | self.uq_idxs = np.array(range(len(self))) 120 | 121 | 122 | def __len__(self) -> int: 123 | return len(self.samples) 124 | 125 | def __getitem__(self, idx) -> Tuple[Any, Any]: 126 | """ 127 | Args: 128 | index (int): Index 129 | 130 | Returns: 131 | tuple: (sample, target) where target is class_index of the target class. 
132 | """ 133 | 134 | path, target = self.samples[idx] 135 | sample = self.loader(path) 136 | if self.transform is not None: 137 | sample = self.transform(sample) 138 | if self.target_transform is not None: 139 | target = self.target_transform(target) 140 | 141 | return sample, target, self.uq_idxs[idx] 142 | 143 | def extra_repr(self) -> str: 144 | return f"split={self._split}" 145 | 146 | def _check_integrity(self): 147 | if not (self._images_folder.exists() and self._images_folder.is_dir()): 148 | return False 149 | 150 | for id in ["label", "setid"]: 151 | filename, md5 = self._file_dict[id] 152 | if not check_integrity(str(self._base_folder / filename), md5): 153 | return False 154 | return True 155 | 156 | def download(self): 157 | if self._check_integrity(): 158 | return 159 | download_and_extract_archive( 160 | f"{self._download_url_prefix}{self._file_dict['image'][0]}", 161 | str(self._base_folder), 162 | md5=self._file_dict["image"][1], 163 | ) 164 | for id in ["label", "setid"]: 165 | filename, md5 = self._file_dict[id] 166 | download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5) 167 | 168 | def subsample_dataset(dataset, idxs): 169 | 170 | mask = np.zeros(len(dataset)).astype('bool') 171 | mask[idxs] = True 172 | 173 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs] 174 | dataset.uq_idxs = dataset.uq_idxs[mask] 175 | 176 | return dataset 177 | 178 | def subsample_classes(dataset, include_classes=range(60)): 179 | 180 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] # 1885 181 | 182 | # TODO: Don't transform targets for now 183 | target_xform_dict = {} 184 | for i, k in enumerate(include_classes): 185 | target_xform_dict[k] = i 186 | 187 | dataset = subsample_dataset(dataset, cls_idxs) 188 | 189 | dataset.target_transform = lambda x: target_xform_dict[x] 190 | 191 | return dataset 192 | 193 | def get_train_val_indices(train_dataset, val_split=0.2): 194 | 195 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)] 196 | train_classes = np.unique(all_targets) 197 | 198 | # Get train/test indices 199 | train_idxs = [] 200 | val_idxs = [] 201 | for cls in train_classes: 202 | cls_idxs = np.where(all_targets == cls)[0] 203 | 204 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 205 | t_ = [x for x in cls_idxs if x not in v_] 206 | 207 | train_idxs.extend(t_) 208 | val_idxs.extend(v_) 209 | 210 | return train_idxs, val_idxs 211 | 212 | def get_flower_datasets(train_transform, test_transform, train_classes=range(51), prop_train_labels=0.8, 213 | split_train_val=False, seed=0): 214 | 215 | np.random.seed(seed) 216 | 217 | # Init entire training set 218 | whole_training_set = Flowers102(root=flower_root, transform=train_transform, split='test', download=False) # len = 6149 219 | 220 | # Get labelled training set which has subsampled classes, then subsample some indices from that 221 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 222 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 223 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 224 | 225 | # Split into training and validation sets 226 | train_dataset_labelled_split = Flowers102(root=flower_root, transform=train_transform, split='train', download=False) 227 | val_dataset_labelled_split = Flowers102(root=flower_root, 
transform=train_transform, split='val', download=False) 228 | val_dataset_labelled_split.transform = test_transform 229 | 230 | # Get unlabelled data 231 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 232 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 233 | 234 | # Get test set for all classes 235 | test_dataset = Flowers102(root=flower_root, transform=test_transform, split='trainval', download=False) # len = 2040 236 | 237 | # Either split train into train and val or use test set as val 238 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 239 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 240 | 241 | all_datasets = { 242 | 'train_labelled': train_dataset_labelled, 243 | 'train_unlabelled': train_dataset_unlabelled, 244 | 'val': val_dataset_labelled, 245 | 'test': test_dataset, 246 | } 247 | 248 | return all_datasets 249 | 250 | def main(): 251 | x = get_flower_datasets(None, None, split_train_val=False) 252 | 253 | print('Printing lens...') 254 | for k, v in x.items(): 255 | if v is not None: 256 | print(f'{k}: {len(v)}') 257 | 258 | print('Printing labelled and unlabelled overlap...') 259 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 260 | print('Printing total instances in train...') 261 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 262 | print('Printing number of labelled classes...') 263 | print(len(set([i[1] for i in x['train_labelled'].samples]))) 264 | print('Printing total number of classes...') 265 | print(len(set([i[1] for i in x['train_unlabelled'].samples]))) 266 | 267 | 268 | if __name__ == '__main__': 269 | main() -------------------------------------------------------------------------------- /data/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from torchvision.datasets.folder import default_loader 7 | from torch.utils.data import Dataset 8 | 9 | from data.data_utils import subsample_instances 10 | from config import aircraft_root 11 | 12 | def make_dataset(dir, image_ids, targets): 13 | assert(len(image_ids) == len(targets)) 14 | images = [] 15 | dir = os.path.expanduser(dir) 16 | for i in range(len(image_ids)): 17 | item = (os.path.join(dir, 'data', 'images', 18 | '%s.jpg' % image_ids[i]), targets[i]) 19 | images.append(item) 20 | return images 21 | 22 | 23 | def find_classes(classes_file): 24 | 25 | # read classes file, separating out image IDs and class names 26 | image_ids = [] 27 | targets = [] 28 | f = open(classes_file, 'r') 29 | for line in f: 30 | split_line = line.split(' ') 31 | image_ids.append(split_line[0]) 32 | targets.append(' '.join(split_line[1:])) 33 | f.close() 34 | 35 | # index class names 36 | classes = np.unique(targets) 37 | class_to_idx = {classes[i]: i for i in range(len(classes))} 38 | targets = [class_to_idx[c] for c in targets] 39 | 40 | return (image_ids, targets, classes, class_to_idx) 41 | 42 | 43 | class FGVCAircraft(Dataset): 44 | 45 | """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset. 46 | 47 | Args: 48 | root (string): Root directory path to dataset. 49 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification 50 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
51 | transform (callable, optional): A function/transform that takes in a PIL image 52 | and returns a transformed version. E.g. ``transforms.RandomCrop`` 53 | target_transform (callable, optional): A function/transform that takes in the 54 | target and transforms it. 55 | loader (callable, optional): A function to load an image given its path. 56 | download (bool, optional): If true, downloads the dataset from the internet and 57 | puts it in the root directory. If dataset is already downloaded, it is not 58 | downloaded again. 59 | """ 60 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' 61 | class_types = ('variant', 'family', 'manufacturer') 62 | splits = ('train', 'val', 'trainval', 'test') 63 | 64 | def __init__(self, root, class_type='variant', split='train', transform=None, 65 | target_transform=None, loader=default_loader, download=False): 66 | if split not in self.splits: 67 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 68 | split, ', '.join(self.splits), 69 | )) 70 | if class_type not in self.class_types: 71 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( 72 | class_type, ', '.join(self.class_types), 73 | )) 74 | self.root = os.path.expanduser(root) 75 | self.class_type = class_type 76 | self.split = split 77 | self.classes_file = os.path.join(self.root, 'data', 78 | 'images_%s_%s.txt' % (self.class_type, self.split)) 79 | 80 | if download: 81 | self.download() 82 | 83 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) # 6667, 6667, 100, 100 84 | samples = make_dataset(self.root, image_ids, targets) 85 | self.transform = transform 86 | self.target_transform = target_transform 87 | self.loader = loader 88 | 89 | self.samples = samples 90 | self.classes = classes 91 | self.class_to_idx = class_to_idx 92 | self.train = True if split == 'train' else False 93 | 94 | self.uq_idxs = np.array(range(len(self))) 95 | 96 | def __getitem__(self, index): 97 | """ 98 | Args: 99 | index (int): Index 100 | 101 | Returns: 102 | tuple: (sample, target) where target is class_index of the target class. 103 | """ 104 | 105 | path, target = self.samples[index] 106 | sample = self.loader(path) 107 | if self.transform is not None: 108 | sample = self.transform(sample) 109 | if self.target_transform is not None: 110 | target = self.target_transform(target) 111 | 112 | return sample, target, self.uq_idxs[index] 113 | 114 | def __len__(self): 115 | return len(self.samples) 116 | 117 | def __repr__(self): 118 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 119 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 120 | fmt_str += ' Root Location: {}\n'.format(self.root) 121 | tmp = ' Transforms (if any): ' 122 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 123 | tmp = ' Target Transforms (if any): ' 124 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 125 | return fmt_str 126 | 127 | def _check_exists(self): 128 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ 129 | os.path.exists(self.classes_file) 130 | 131 | def download(self): 132 | """Download the FGVC-Aircraft data if it doesn't exist already.""" 133 | from six.moves import urllib 134 | import tarfile 135 | 136 | if self._check_exists(): 137 | return 138 | 139 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz 140 | print('Downloading %s ... 
(may take a few minutes)' % self.url) 141 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir)) 142 | tar_name = self.url.rpartition('/')[-1] 143 | tar_path = os.path.join(parent_dir, tar_name) 144 | data = urllib.request.urlopen(self.url) 145 | 146 | # download .tar.gz file 147 | with open(tar_path, 'wb') as f: 148 | f.write(data.read()) 149 | 150 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b 151 | data_folder = tar_path[:-len('.tar.gz')]  # slice off the suffix; str.strip removes characters, not a suffix 152 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder)) 153 | tar = tarfile.open(tar_path) 154 | tar.extractall(parent_dir) 155 | 156 | # if necessary, rename data folder to self.root 157 | if not os.path.samefile(data_folder, self.root): 158 | print('Renaming %s to %s ...' % (data_folder, self.root)) 159 | os.rename(data_folder, self.root) 160 | 161 | # delete .tar.gz file 162 | print('Deleting %s ...' % tar_path) 163 | os.remove(tar_path) 164 | 165 | print('Done!') 166 | 167 | 168 | def subsample_dataset(dataset, idxs): 169 | 170 | mask = np.zeros(len(dataset)).astype('bool') 171 | mask[idxs] = True 172 | 173 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs] 174 | dataset.uq_idxs = dataset.uq_idxs[mask] 175 | 176 | return dataset 177 | 178 | 179 | def subsample_classes(dataset, include_classes=range(60)): 180 | 181 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] 182 | 183 | # TODO: Don't transform targets for now 184 | target_xform_dict = {} 185 | for i, k in enumerate(include_classes): 186 | target_xform_dict[k] = i 187 | 188 | dataset = subsample_dataset(dataset, cls_idxs) 189 | 190 | dataset.target_transform = lambda x: target_xform_dict[x] 191 | 192 | return dataset 193 | 194 | 195 | def get_train_val_indices(train_dataset, val_split=0.2): 196 | 197 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)] 198 | train_classes = np.unique(all_targets) 199 | 200 | # Get train/test indices 201 | train_idxs = [] 202 | val_idxs = [] 203 | for cls in train_classes: 204 | cls_idxs = np.where(all_targets == cls)[0] 205 | 206 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 207 | t_ = [x for x in cls_idxs if x not in v_] 208 | 209 | train_idxs.extend(t_) 210 | val_idxs.extend(v_) 211 | 212 | return train_idxs, val_idxs 213 | 214 | 215 | def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8, 216 | split_train_val=False, seed=0): 217 | 218 | np.random.seed(seed) 219 | 220 | # Init entire training set 221 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval') # len = 6667 222 | 223 | # Get labelled training set which has subsampled classes, then subsample some indices from that 224 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 225 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 226 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 227 | 228 | # Split into training and validation sets 229 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 230 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 231 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 232 | val_dataset_labelled_split.transform = test_transform 233 | 234 | # Get
unlabelled data 235 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 236 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 237 | 238 | # Get test set for all classes 239 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test') 240 | 241 | # Either split train into train and val or use test set as val 242 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 243 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 244 | 245 | all_datasets = { 246 | 'train_labelled': train_dataset_labelled, 247 | 'train_unlabelled': train_dataset_unlabelled, 248 | 'val': val_dataset_labelled, 249 | 'test': test_dataset, 250 | } 251 | 252 | return all_datasets 253 | 254 | 255 | def main(): 256 | x = get_aircraft_datasets(None, None, split_train_val=False) 257 | 258 | print('Printing lens...') 259 | for k, v in x.items(): 260 | if v is not None: 261 | print(f'{k}: {len(v)}') 262 | 263 | print('Printing labelled and unlabelled overlap...') 264 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 265 | print('Printing total instances in train...') 266 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 267 | print('Printing number of labelled classes...') 268 | print(len(set([i[1] for i in x['train_labelled'].samples]))) 269 | print('Printing total number of classes...') 270 | print(len(set([i[1] for i in x['train_unlabelled'].samples]))) 271 | 272 | if __name__ == '__main__': 273 | main() 274 | -------------------------------------------------------------------------------- /project_utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | import inspect 6 | import argparse  # used by str2bool 7 | 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 10 | 11 | from datetime import datetime 12 | class AverageMeter(object): 13 | """Computes and stores the average and current value""" 14 | def __init__(self): 15 | self.reset() 16 | 17 | def reset(self): 18 | 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | 32 | def seed_torch(seed=1029): 33 | 34 | random.seed(seed) 35 | os.environ['PYTHONHASHSEED'] = str(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | torch.cuda.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
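# The two cudnn flags below trade speed for reproducibility: benchmark=False
# stops cuDNN from auto-tuning its convolution algorithms (the tuning result
# can differ between runs), and deterministic=True restricts cuDNN to
# deterministic kernels, so repeated runs with the same seed match exactly.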
40 | torch.backends.cudnn.benchmark = False 41 | torch.backends.cudnn.deterministic = True 42 | 43 | 44 | def strip_state_dict(state_dict, strip_key='module.'): 45 | 46 | """ 47 | Strip 'module' from start of state_dict keys 48 | Useful if model has been trained as DataParallel model 49 | """ 50 | 51 | for k in list(state_dict.keys()): 52 | if k.startswith(strip_key): 53 | state_dict[k[len(strip_key):]] = state_dict[k] 54 | del state_dict[k] 55 | 56 | return state_dict 57 | 58 | 59 | def get_dino_head_weights(pretrain_path): 60 | 61 | """ 62 | :param pretrain_path: Path to full DINO pretrained checkpoint as in https://github.com/facebookresearch/dino 63 | 'full_ckpt' 64 | :return: weights only for the projection head 65 | """ 66 | 67 | all_weights = torch.load(pretrain_path) 68 | 69 | head_state_dict = {} 70 | for k, v in all_weights['teacher'].items(): 71 | if 'head' in k and 'last_layer' not in k: 72 | head_state_dict[k] = v 73 | 74 | head_state_dict = strip_state_dict(head_state_dict, strip_key='head.') 75 | 76 | # Deal with weight norm 77 | weight_norm_state_dict = {} 78 | for k, v in all_weights['teacher'].items(): 79 | if 'last_layer' in k: 80 | weight_norm_state_dict[k.split('.')[2]] = v 81 | 82 | linear_shape = weight_norm_state_dict['weight'].shape 83 | dummy_linear = torch.nn.Linear(in_features=linear_shape[1], out_features=linear_shape[0], bias=False) 84 | dummy_linear.load_state_dict(weight_norm_state_dict) 85 | dummy_linear = torch.nn.utils.weight_norm(dummy_linear) 86 | 87 | for k, v in dummy_linear.state_dict().items(): 88 | 89 | head_state_dict['last_layer.' + k] = v 90 | 91 | return head_state_dict 92 | 93 | def transform_moco_state_dict(obj, num_classes): 94 | 95 | """ 96 | :param obj: Moco State Dict 97 | :param num_classes: number of classes for the replacement fc layer 98 | :return: State dict compatible with standard resnet architecture 99 | """ 100 | 101 | newmodel = {} 102 | for k, v in obj.items(): 103 | if not k.startswith("module.encoder_q."): 104 | continue 105 | old_k = k 106 | k = k.replace("module.encoder_q.", "") 107 | 108 | if k.startswith("fc.2"): 109 | continue 110 | 111 | if k.startswith("fc.0"): 112 | k = k.replace("0.", "") 113 | if "weight" in k: 114 | v = torch.randn((num_classes, v.size(1))) 115 | elif "bias" in k: 116 | v = torch.randn((num_classes,)) 117 | 118 | newmodel[k] = v 119 | 120 | return newmodel 121 | 122 | 123 | def init_experiment(args, runner_name=None, exp_id=None): 124 | 125 | args.cuda = torch.cuda.is_available() 126 | 127 | # Get filepath of calling script 128 | if runner_name is None: 129 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:] 130 | 131 | root_dir = os.path.join(args.exp_root, *runner_name) 132 | 133 | if not os.path.exists(root_dir): 134 | os.makedirs(root_dir) 135 | 136 | # Either generate a unique experiment ID, or use one which is passed 137 | if exp_id is None: 138 | 139 | # Unique identifier for experiment 140 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \ 141 | datetime.now().strftime("%S.%f")[:-3] + ')' 142 | 143 | log_dir = os.path.join(root_dir, 'log', now) 144 | while os.path.exists(log_dir): 145 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \ 146 | datetime.now().strftime("%S.%f")[:-3] + ')' 147 | 148 | log_dir = os.path.join(root_dir, 'log', now) 149 | 150 | else: 151 | 152 | log_dir = os.path.join(root_dir, 'log', f'{exp_id}') 153 | 154 | if not
os.path.exists(log_dir): 155 | os.makedirs(log_dir) 156 | args.log_dir = log_dir 157 | 158 | # Instantiate directory to save models to 159 | model_root_dir = os.path.join(args.log_dir, 'checkpoints') 160 | if not os.path.exists(model_root_dir): 161 | os.mkdir(model_root_dir) 162 | 163 | args.model_dir = model_root_dir 164 | args.model_path = os.path.join(args.model_dir, 'model.pt') 165 | 166 | print(f'Experiment saved to: {args.log_dir}') 167 | 168 | args.writer = SummaryWriter(log_dir=args.log_dir) 169 | 170 | hparam_dict = {} 171 | 172 | for k, v in vars(args).items(): 173 | if isinstance(v, (int, float, str, bool, torch.Tensor)): 174 | hparam_dict[k] = v 175 | 176 | args.writer.add_hparams(hparam_dict=hparam_dict, metric_dict={}) 177 | 178 | print(runner_name) 179 | print(args) 180 | 181 | return args 182 | 183 | 184 | def accuracy(output, target, topk=(1,)): 185 | """Computes the accuracy over the k top predictions for the specified values of k""" 186 | with torch.no_grad(): 187 | maxk = max(topk) 188 | batch_size = target.size(0) 189 | 190 | _, pred = output.topk(maxk, 1, True, True) 191 | pred = pred.t() 192 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 193 | 194 | res = [] 195 | for k in topk: 196 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 197 | res.append(correct_k.mul_(100.0 / batch_size)) 198 | return res 199 | 200 | 201 | def str2bool(v): 202 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 203 | return True 204 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 205 | return False 206 | else: 207 | raise argparse.ArgumentTypeError('Boolean value expected.') 208 | 209 | 210 | class ClassificationPredSaver(object): 211 | 212 | def __init__(self, length, save_path=None): 213 | 214 | if save_path is not None: 215 | 216 | # Remove filetype from save_path 217 | save_path = save_path.split('.')[0] 218 | self.save_path = save_path 219 | 220 | self.length = length 221 | 222 | self.all_preds = None 223 | self.all_labels = None 224 | 225 | self.running_start_idx = 0 226 | 227 | def update(self, preds, labels=None): 228 | 229 | # Expect preds in shape B x C 230 | 231 | if torch.is_tensor(preds): 232 | preds = preds.detach().cpu().numpy() 233 | 234 | b, c = preds.shape 235 | 236 | if self.all_preds is None: 237 | self.all_preds = np.zeros((self.length, c)) 238 | 239 | self.all_preds[self.running_start_idx: self.running_start_idx + b] = preds 240 | 241 | if labels is not None: 242 | if torch.is_tensor(labels): 243 | labels = labels.detach().cpu().numpy() 244 | 245 | if self.all_labels is None: 246 | self.all_labels = np.zeros((self.length,)) 247 | 248 | self.all_labels[self.running_start_idx: self.running_start_idx + b] = labels 249 | 250 | # Maintain running index on dataset being evaluated 251 | self.running_start_idx += b 252 | 253 | def save(self): 254 | 255 | # Softmax over preds 256 | preds = torch.from_numpy(self.all_preds) 257 | preds = torch.nn.Softmax(dim=-1)(preds) 258 | self.all_preds = preds.numpy() 259 | 260 | pred_path = self.save_path + '.pth' 261 | print(f'Saving all predictions to {pred_path}') 262 | 263 | torch.save(self.all_preds, pred_path) 264 | 265 | if self.all_labels is not None: 266 | 267 | # Evaluate 268 | self.evaluate() 269 | torch.save(self.all_labels, self.save_path + '_labels.pth') 270 | 271 | def evaluate(self): 272 | 273 | topk = [1, 5, 10] 274 | topk = [k for k in topk if k < self.all_preds.shape[-1]] 275 | acc = accuracy(torch.from_numpy(self.all_preds), torch.from_numpy(self.all_labels), topk=topk) 276 | 277 | for k, a 
in zip(topk, acc):
278 | print(f'Top{k} Acc: {a.item()}')
279 | 
280 | 
281 | def get_acc_auroc_curves(logdir):
282 | 
283 | """
284 | :param logdir: Path to logs: E.g '/work/sagar/open_set_recognition/methods/ARPL/log/(12.03.2021_|_32.570)/'
285 | :return: Dict mapping each scalar tag in the TensorBoard log to an array of [step, value] rows
286 | """
287 | 
288 | event_acc = EventAccumulator(logdir)
289 | event_acc.Reload()
290 | 
291 | # Only gets scalars
292 | log_info = {}
293 | for tag in event_acc.Tags()['scalars']:
294 | 
295 | log_info[tag] = np.array([[x.step, x.value] for x in event_acc.scalars._buckets[tag].items])
296 | 
297 | return log_info
298 | 
299 | 
300 | def get_mean_lr(optimizer):
301 | return torch.mean(torch.Tensor([param_group['lr'] for param_group in optimizer.param_groups])).item()
302 | 
303 | 
304 | class IndicatePlateau(object):
305 | 
306 | def __init__(self, threshold=5e-4, patience_epochs=5, mode='min', threshold_mode='rel'):
307 | 
308 | self.patience = patience_epochs
309 | self.cooldown_counter = 0
310 | self.mode = mode
311 | self.threshold = threshold
312 | self.threshold_mode = threshold_mode
313 | self.best = None
314 | self.num_bad_epochs = None
315 | self.mode_worse = None # the worse value for the chosen mode
316 | self.last_epoch = 0
317 | self._init_is_better(mode=mode, threshold=threshold,
318 | threshold_mode=threshold_mode)
319 | 
320 | 
321 | 
322 | self._reset()
323 | 
324 | def _reset(self):
325 | """Resets num_bad_epochs counter and cooldown counter."""
326 | self.best = self.mode_worse
327 | self.cooldown_counter = 0
328 | self.num_bad_epochs = 0
329 | 
330 | def step(self, metrics, epoch=None):
331 | # convert `metrics` to float, in case it's a zero-dim Tensor
332 | current = float(metrics)
333 | self.last_epoch = epoch
334 | 
335 | if self.is_better(current, self.best):
336 | self.best = current
337 | self.num_bad_epochs = 0
338 | else:
339 | self.num_bad_epochs += 1
340 | 
341 | if self.num_bad_epochs > self.patience:
342 | print('Tracked metric has plateaued')
343 | self._reset()
344 | return True
345 | else:
346 | return False
347 | 
348 | def is_better(self, a, best):
349 | 
350 | if self.mode == 'min' and self.threshold_mode == 'rel':
351 | rel_epsilon = 1. - self.threshold
352 | return a < best * rel_epsilon
353 | 
354 | elif self.mode == 'min' and self.threshold_mode == 'abs':
355 | return a < best - self.threshold
356 | 
357 | elif self.mode == 'max' and self.threshold_mode == 'rel':
358 | rel_epsilon = self.threshold + 1.
359 | return a > best * rel_epsilon
360 | 
361 | else: # mode == 'max' and epsilon_mode == 'abs':
362 | return a > best + self.threshold
363 | 
364 | def _init_is_better(self, mode, threshold, threshold_mode):
365 | 
366 | if mode not in {'min', 'max'}:
367 | raise ValueError('mode ' + mode + ' is unknown!')
368 | if threshold_mode not in {'rel', 'abs'}:
369 | raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
370 | 
371 | if mode == 'min':
372 | self.mode_worse = float('inf')
373 | else: # mode == 'max':
374 | self.mode_worse = -float('inf')
375 | 
376 | self.mode = mode
377 | self.threshold = threshold
378 | self.threshold_mode = threshold_mode
379 | 
380 | 
381 | if __name__ == '__main__':
382 | 
383 | x = IndicatePlateau(threshold=0.0899)
384 | eps = np.arange(0, 2000, 1)
385 | y = np.exp(-0.01 * eps)
386 | 
387 | print(y)
388 | for i, y_ in enumerate(y):
389 | 
390 | z = x.step(y_)
391 | if z:
392 | print(f'Plateaued at epoch {i} with val {y_}')
--------------------------------------------------------------------------------
/methods/estimate_k/estimate_k.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
5 | from torch.utils.data import DataLoader
6 | from sklearn.metrics import adjusted_rand_score as ari_score
7 | import numpy as np
8 | from sklearn.cluster import KMeans
9 | import torch
10 | from project_utils.cluster_utils import cluster_acc
11 | 
12 | from methods.clustering.feature_vector_dataset import FeatureVectorDataset
13 | from data.get_datasets import get_datasets, get_class_splits
14 | 
15 | from config import feature_extract_dir
16 | from tqdm import tqdm
17 | 
18 | from scipy.optimize import minimize_scalar
19 | from functools import partial
20 | 
21 | from project_utils.cluster_utils import str2bool
22 | # Suppress DeprecationWarnings (e.g. from sklearn) raised during clustering
23 | import warnings
24 | warnings.filterwarnings("ignore", category=DeprecationWarning)
25 | 
26 | def test_kmeans(K, merge_test_loader, args=None, verbose=False):
27 | 
28 | """
29 | In this case, the test loader needs to have the labelled and unlabelled subsets of the training data
30 | """
31 | 
32 | if K is None:
33 | K = args.num_labeled_classes + args.num_unlabeled_classes
34 | 
35 | all_feats = []
36 | targets = np.array([])
37 | mask_lab = np.array([]) # From all the data, which instances belong to the labelled set
38 | mask_cls = np.array([]) # From all the data, which instances belong to seen classes
39 | 
40 | print('Collating features...')
41 | # First extract all features
42 | for batch_idx, (feats, label, _, mask_lab_) in enumerate(tqdm(merge_test_loader)):
43 | 
44 | feats = feats.to(device)
45 | 
46 | feats = torch.nn.functional.normalize(feats, dim=-1)
47 | 
48 | all_feats.append(feats.cpu().numpy())
49 | targets = np.append(targets, label.cpu().numpy())
50 | mask_cls = np.append(mask_cls, np.array([True if x.item() in range(len(args.train_classes))
51 | else False for x in label]))
52 | mask_lab = np.append(mask_lab, mask_lab_.cpu().bool().numpy())
53 | 
54 | # -----------------------
55 | # K-MEANS
56 | # -----------------------
57 | mask_lab = mask_lab.astype(bool)
58 | mask_cls = mask_cls.astype(bool)
59 | 
60 | all_feats = np.concatenate(all_feats)
61 | 
62 | print('Fitting K-Means...')
63 | kmeans = KMeans(n_clusters=K, random_state=0).fit(all_feats)
64 | preds = kmeans.labels_
65 | 
66 | # -----------------------
67 | # EVALUATE
68 | # -----------------------
69 | mask = mask_lab
70 | 
71 | 
72 | labelled_acc, labelled_nmi, labelled_ari = cluster_acc(targets.astype(int)[mask], preds.astype(int)[mask]), \
73 | nmi_score(targets[mask], preds[mask]), \
74 | ari_score(targets[mask], preds[mask])
75 | 
76 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int)[~mask],
77 | preds.astype(int)[~mask]), \
78 | nmi_score(targets[~mask], preds[~mask]), \
79 | ari_score(targets[~mask], preds[~mask])
80 | 
81 | if verbose:
82 | print(f'K = {K}')
83 | print('Labelled Instances acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(labelled_acc, labelled_nmi,
84 | labelled_ari))
85 | print('Unlabelled Instances acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(unlabelled_acc, unlabelled_nmi,
86 | unlabelled_ari))
87 | 
88 | 
89 | 
90 | return labelled_acc
91 | # return (labelled_acc, labelled_nmi, labelled_ari), (unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.astype(float).mean()
92 | 
93 | 
94 | def test_kmeans_for_scipy(K, merge_test_loader, args=None, verbose=False):
95 | 
96 | """
97 | In this case, the test loader needs to have the labelled and unlabelled subsets of the training data
98 | """
99 | 
100 | K = int(K)
101 | 
102 | all_feats = []
103 | targets = np.array([])
104 | mask_lab = np.array([]) # From all the data, which instances belong to the labelled set
105 | mask_cls = np.array([]) # From all the data, which instances belong to seen classes
106 | 
107 | print('Collating features...')
108 | # First extract all features
109 | for batch_idx, (feats, label, _, mask_lab_) in enumerate(tqdm(merge_test_loader)):
110 | 
111 | feats = feats.to(device)
112 | 
113 | feats = torch.nn.functional.normalize(feats, dim=-1)
114 | 
115 | all_feats.append(feats.cpu().numpy())
116 | targets = np.append(targets, label.cpu().numpy())
117 | mask_cls = np.append(mask_cls, np.array([True if x.item() in range(len(args.train_classes))
118 | else False for x in label]))
119 | mask_lab = np.append(mask_lab, mask_lab_.cpu().bool().numpy())
120 | 
121 | # -----------------------
122 | # K-MEANS
123 | # -----------------------
124 | mask_lab = mask_lab.astype(bool)
125 | mask_cls = mask_cls.astype(bool)
126 | 
127 | all_feats = np.concatenate(all_feats)
128 | 
129 | print(f'Fitting K-Means for K = {K}...')
130 | kmeans = KMeans(n_clusters=K, random_state=0).fit(all_feats)
131 | preds = kmeans.labels_
132 | 
133 | # -----------------------
134 | # EVALUATE
135 | # -----------------------
136 | mask = mask_lab
137 | 
138 | 
139 | labelled_acc, labelled_nmi, labelled_ari = cluster_acc(targets.astype(int)[mask], preds.astype(int)[mask]), \
140 | nmi_score(targets[mask], preds[mask]), \
141 | ari_score(targets[mask], preds[mask])
142 | 
143 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int)[~mask],
144 | preds.astype(int)[~mask]), \
145 | nmi_score(targets[~mask], preds[~mask]), \
146 | ari_score(targets[~mask], preds[~mask])
147 | 
148 | print(f'K = {K}')
149 | print('Labelled Instances acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(labelled_acc, labelled_nmi,
150 | labelled_ari))
151 | print('Unlabelled Instances acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(unlabelled_acc, unlabelled_nmi,
152 | unlabelled_ari))
153 | 
154 | return -labelled_acc # negated so that scipy's minimiser maximises labelled accuracy
155 | 
156 | 
157 | def binary_search(merge_test_loader, args):
158 | 
159 | min_classes = args.num_labeled_classes
160 | 
161 | # Iter 0
162 | big_k = args.max_classes
163 | small_k = min_classes
164 | diff = big_k - small_k
165 | middle_k = int(0.5 * diff + small_k)
166 | 
167 | labelled_acc_big = test_kmeans(big_k,
merge_test_loader, args) 168 | labelled_acc_small = test_kmeans(small_k, merge_test_loader, args) 169 | labelled_acc_middle = test_kmeans(middle_k, merge_test_loader, args) 170 | 171 | print(f'Iter 0: BigK {big_k}, Acc {labelled_acc_big:.4f} | MiddleK {middle_k}, Acc {labelled_acc_middle:.4f} | SmallK {small_k}, Acc {labelled_acc_small:.4f} ') 172 | all_accs = [labelled_acc_small, labelled_acc_middle, labelled_acc_big] 173 | best_acc_so_far = np.max(all_accs) 174 | best_acc_at_k = np.array([small_k, middle_k, big_k])[np.argmax(all_accs)] 175 | print(f'Best Acc so far {best_acc_so_far:.4f} at K {best_acc_at_k}') 176 | 177 | for i in range(1, int(np.log2(diff)) + 1): 178 | 179 | if labelled_acc_big > labelled_acc_small: 180 | 181 | best_acc = max(labelled_acc_middle, labelled_acc_big) 182 | 183 | small_k = middle_k 184 | labelled_acc_small = labelled_acc_middle 185 | diff = big_k - small_k 186 | middle_k = int(0.5 * diff + small_k) 187 | 188 | else: 189 | 190 | best_acc = max(labelled_acc_middle, labelled_acc_small) 191 | big_k = middle_k 192 | 193 | diff = big_k - small_k 194 | middle_k = int(0.5 * diff + small_k) 195 | labelled_acc_big = labelled_acc_middle 196 | 197 | labelled_acc_middle = test_kmeans(middle_k, merge_test_loader, args) 198 | 199 | print(f'Iter {i}: BigK {big_k}, Acc {labelled_acc_big:.4f} | MiddleK {middle_k}, Acc {labelled_acc_middle:.4f} | SmallK {small_k}, Acc {labelled_acc_small:.4f} ') 200 | all_accs = [labelled_acc_small, labelled_acc_middle, labelled_acc_big] 201 | best_acc_so_far = np.max(all_accs) 202 | best_acc_at_k = np.array([small_k, middle_k, big_k])[np.argmax(all_accs)] 203 | print(f'Best Acc so far {best_acc_so_far:.4f} at K {best_acc_at_k}') 204 | 205 | 206 | def scipy_optimise(merge_test_loader, args): 207 | 208 | small_k = args.num_labeled_classes 209 | big_k = args.max_classes 210 | 211 | test_k_means_partial = partial(test_kmeans_for_scipy, merge_test_loader=merge_test_loader, args=args, verbose=True) 212 | res = minimize_scalar(test_k_means_partial, bounds=(small_k, big_k), method='bounded', options={'disp': True}) 213 | print(f'Optimal K is {res.x}') 214 | 215 | 216 | if __name__ == "__main__": 217 | 218 | parser = argparse.ArgumentParser( 219 | description='cluster', 220 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 221 | parser.add_argument('--batch_size', default=128, type=int) 222 | parser.add_argument('--num_workers', default=8, type=int) 223 | parser.add_argument('--max_classes', default=1000, type=int) 224 | parser.add_argument('--root_dir', type=str, default=feature_extract_dir) 225 | parser.add_argument('--warmup_model_exp_id', type=str, default=None) 226 | parser.add_argument('--model_name', type=str, default='vit_dino', help='Format is {model_name}_{pretrain}') 227 | parser.add_argument('--search_mode', type=str, default='brent', help='Mode for black box optimisation') 228 | parser.add_argument('--dataset_name', type=str, default='cifar10', help='options: cifar10, cifar100, scars') 229 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 230 | parser.add_argument('--use_best_model', type=str2bool, default=True) 231 | 232 | # ---------------------- 233 | # INIT 234 | # ---------------------- 235 | args = parser.parse_args() 236 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 237 | 238 | cluster_accs = {} 239 | 240 | args.save_dir = os.path.join(args.root_dir, f'{args.model_name}_{args.dataset_name}') 241 | 242 | args = get_class_splits(args) 243 | 244 | args.num_labeled_classes = 
len(args.train_classes)
245 | args.num_unlabeled_classes = len(args.unlabeled_classes)
246 | 
247 | print(args)
248 | 
249 | if args.warmup_model_exp_id is not None:
250 | args.save_dir += '_' + args.warmup_model_exp_id
251 | if args.use_best_model:
252 | args.save_dir += '_best'
253 | print(f'Using features from experiment: {args.warmup_model_exp_id}')
254 | else:
255 | print(f'Using pretrained {args.model_name} features...')
256 | 
257 | # --------------------
258 | # DATASETS
259 | # --------------------
260 | print('Building datasets...')
261 | train_transform, test_transform = None, None
262 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name,
263 | train_transform, test_transform, args)
264 | 
265 | # Convert to feature vector dataset
266 | test_dataset = FeatureVectorDataset(base_dataset=test_dataset, feature_root=os.path.join(args.save_dir, 'test'))
267 | unlabelled_train_examples_test = FeatureVectorDataset(base_dataset=unlabelled_train_examples_test,
268 | feature_root=os.path.join(args.save_dir, 'train'))
269 | train_dataset = FeatureVectorDataset(base_dataset=train_dataset, feature_root=os.path.join(args.save_dir, 'train'))
270 | 
271 | # --------------------
272 | # DATALOADERS
273 | # --------------------
274 | unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
275 | batch_size=args.batch_size, shuffle=False)
276 | test_loader = DataLoader(test_dataset, num_workers=args.num_workers,
277 | batch_size=args.batch_size, shuffle=False)
278 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers,
279 | batch_size=args.batch_size, shuffle=False)
280 | 
281 | print('Testing on all of the training data...')
282 | if args.search_mode == 'brent':
283 | print("Optimising with Brent's algorithm")
284 | scipy_optimise(merge_test_loader=train_loader, args=args)
285 | else:
286 | binary_search(train_loader, args)
--------------------------------------------------------------------------------
/model/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 | 
21 | import torch
22 | import torch.nn as nn
23 | 
24 | from torch.nn.init import trunc_normal_
25 | 
26 | def drop_path(x, drop_prob: float = 0., training: bool = False):
27 | if drop_prob == 0.
or not training: 28 | return x 29 | keep_prob = 1 - drop_prob 30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 32 | random_tensor.floor_() # binarize 33 | output = x.div(keep_prob) * random_tensor 34 | return output 35 | 36 | 37 | class DropPath(nn.Module): 38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 39 | """ 40 | def __init__(self, drop_prob=None): 41 | super(DropPath, self).__init__() 42 | self.drop_prob = drop_prob 43 | 44 | def forward(self, x): 45 | return drop_path(x, self.drop_prob, self.training) 46 | 47 | 48 | class Mlp(nn.Module): 49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 50 | super().__init__() 51 | out_features = out_features or in_features 52 | hidden_features = hidden_features or in_features 53 | self.fc1 = nn.Linear(in_features, hidden_features) 54 | self.act = act_layer() 55 | self.fc2 = nn.Linear(hidden_features, out_features) 56 | self.drop = nn.Dropout(drop) 57 | 58 | def forward(self, x): 59 | x = self.fc1(x) 60 | x = self.act(x) 61 | x = self.drop(x) 62 | x = self.fc2(x) 63 | x = self.drop(x) 64 | return x 65 | 66 | 67 | class Attention(nn.Module): 68 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 69 | super().__init__() 70 | self.num_heads = num_heads 71 | head_dim = dim // num_heads 72 | self.scale = qk_scale or head_dim ** -0.5 73 | 74 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 75 | self.attn_drop = nn.Dropout(attn_drop) 76 | self.proj = nn.Linear(dim, dim) 77 | self.proj_drop = nn.Dropout(proj_drop) 78 | 79 | def forward(self, x): 80 | B, N, C = x.shape 81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 82 | q, k, v = qkv[0], qkv[1], qkv[2] 83 | 84 | attn = (q @ k.transpose(-2, -1)) * self.scale 85 | attn = attn.softmax(dim=-1) 86 | attn = self.attn_drop(attn) 87 | 88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 89 | x = self.proj(x) 90 | x = self.proj_drop(x) 91 | return x, attn 92 | 93 | 94 | class Block(nn.Module): 95 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 96 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 97 | super().__init__() 98 | self.norm1 = norm_layer(dim) 99 | self.attn = Attention( 100 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 101 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 102 | self.norm2 = norm_layer(dim) 103 | mlp_hidden_dim = int(dim * mlp_ratio) 104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 105 | 106 | def forward(self, x, return_attention=False): 107 | y, attn = self.attn(self.norm1(x)) 108 | x = x + self.drop_path(y) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | 111 | if return_attention: 112 | return x, attn 113 | else: 114 | return x 115 | 116 | 117 | class PatchEmbed(nn.Module): 118 | """ Image to Patch Embedding 119 | """ 120 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 121 | super().__init__() 122 | num_patches = (img_size // patch_size) * (img_size // patch_size) 123 | self.img_size = img_size 124 | self.patch_size = patch_size 125 | self.num_patches = num_patches 126 | 127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 128 | 129 | def forward(self, x): 130 | B, C, H, W = x.shape 131 | x = self.proj(x).flatten(2).transpose(1, 2) 132 | return x 133 | 134 | 135 | class VisionTransformer(nn.Module): 136 | """ Vision Transformer """ 137 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 138 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 139 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 140 | super().__init__() 141 | self.num_features = self.embed_dim = embed_dim 142 | 143 | self.patch_embed = PatchEmbed( 144 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 145 | num_patches = self.patch_embed.num_patches 146 | 147 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 148 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 149 | self.pos_drop = nn.Dropout(p=drop_rate) 150 | 151 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 152 | self.blocks = nn.ModuleList([ 153 | Block( 154 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 155 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 156 | for i in range(depth)]) 157 | self.norm = norm_layer(embed_dim) 158 | 159 | # Classifier head 160 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 161 | 162 | trunc_normal_(self.pos_embed, std=.02) 163 | trunc_normal_(self.cls_token, std=.02) 164 | self.apply(self._init_weights) 165 | 166 | def _init_weights(self, m): 167 | if isinstance(m, nn.Linear): 168 | trunc_normal_(m.weight, std=.02) 169 | if isinstance(m, nn.Linear) and m.bias is not None: 170 | nn.init.constant_(m.bias, 0) 171 | elif isinstance(m, nn.LayerNorm): 172 | nn.init.constant_(m.bias, 0) 173 | nn.init.constant_(m.weight, 1.0) 174 | 175 | def interpolate_pos_encoding(self, x, w, h): 176 | npatch = x.shape[1] - 1 177 | N = self.pos_embed.shape[1] - 1 178 | if npatch == N and w == h: 179 | return self.pos_embed 180 | class_pos_embed = self.pos_embed[:, 0] 181 | patch_pos_embed = self.pos_embed[:, 1:] 182 | dim = x.shape[-1] 183 | w0 = w // self.patch_embed.patch_size 184 | h0 = h // self.patch_embed.patch_size 185 | # we add a small number to avoid floating point error in the interpolation 186 | # see discussion at https://github.com/facebookresearch/dino/issues/8 187 | w0, h0 = w0 + 0.1, h0 + 0.1 188 | patch_pos_embed = nn.functional.interpolate( 189 | patch_pos_embed.reshape(1, int(math.sqrt(N)), 
int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 190 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 191 | mode='bicubic', 192 | ) 193 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 194 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 195 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 196 | 197 | def prepare_tokens(self, x): 198 | B, nc, w, h = x.shape 199 | x = self.patch_embed(x) # patch linear embedding 200 | 201 | # add the [CLS] token to the embed patch tokens 202 | cls_tokens = self.cls_token.expand(B, -1, -1) 203 | x = torch.cat((cls_tokens, x), dim=1) 204 | 205 | # add positional encoding to each token 206 | x = x + self.interpolate_pos_encoding(x, w, h) 207 | 208 | return self.pos_drop(x) 209 | 210 | def forward(self, x, return_all_patches=False): 211 | x = self.prepare_tokens(x) 212 | for blk in self.blocks: 213 | x = blk(x) 214 | x = self.norm(x) 215 | 216 | if return_all_patches: 217 | return x 218 | else: 219 | return x[:, 0] 220 | 221 | def get_last_selfattention(self, x): 222 | x = self.prepare_tokens(x) 223 | for i, blk in enumerate(self.blocks): 224 | if i < len(self.blocks) - 1: 225 | x = blk(x) 226 | else: 227 | # return attention of the last block 228 | x, attn = blk(x, return_attention=True) 229 | x = self.norm(x) 230 | return x, attn 231 | 232 | def get_intermediate_layers(self, x, n=1): 233 | x = self.prepare_tokens(x) 234 | # we return the output tokens from the `n` last blocks 235 | output = [] 236 | for i, blk in enumerate(self.blocks): 237 | x = blk(x) 238 | if len(self.blocks) - i <= n: 239 | output.append(self.norm(x)) 240 | return output 241 | 242 | 243 | def vit_tiny(patch_size=16, **kwargs): 244 | model = VisionTransformer( 245 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 246 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 247 | return model 248 | 249 | 250 | def vit_small(patch_size=16, **kwargs): 251 | model = VisionTransformer( 252 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 253 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 254 | return model 255 | 256 | 257 | def vit_base(patch_size=16, **kwargs): 258 | model = VisionTransformer( 259 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 260 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 261 | return model 262 | 263 | 264 | class DINOHead(nn.Module): 265 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 266 | super().__init__() 267 | nlayers = max(nlayers, 1) 268 | if nlayers == 1: 269 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 270 | else: 271 | layers = [nn.Linear(in_dim, hidden_dim)] 272 | if use_bn: 273 | layers.append(nn.BatchNorm1d(hidden_dim)) 274 | layers.append(nn.GELU()) 275 | for _ in range(nlayers - 2): 276 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 277 | if use_bn: 278 | layers.append(nn.BatchNorm1d(hidden_dim)) 279 | layers.append(nn.GELU()) 280 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 281 | self.mlp = nn.Sequential(*layers) 282 | self.apply(self._init_weights) 283 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 284 | self.last_layer.weight_g.data.fill_(1) 285 | if norm_last_layer: 286 | self.last_layer.weight_g.requires_grad = False 287 | 288 | def _init_weights(self, m): 289 | if isinstance(m, nn.Linear): 
290 | trunc_normal_(m.weight, std=.02)
291 | if isinstance(m, nn.Linear) and m.bias is not None:
292 | nn.init.constant_(m.bias, 0)
293 | 
294 | def forward(self, x):
295 | x = self.mlp(x)
296 | x = nn.functional.normalize(x, dim=-1, p=2)
297 | x = self.last_layer(x)
298 | return x
299 | 
300 | 
301 | class VisionTransformerWithLinear(nn.Module):
302 | 
303 | def __init__(self, base_vit, num_classes=200):
304 | 
305 | super().__init__()
306 | 
307 | self.base_vit = base_vit
308 | self.fc = nn.Linear(768, num_classes)
309 | 
310 | def forward(self, x, return_features=False):
311 | 
312 | features = self.base_vit(x)
313 | features = torch.nn.functional.normalize(features, dim=-1)
314 | logits = self.fc(features)
315 | 
316 | if return_features:
317 | return logits, features
318 | else:
319 | return logits
320 | 
321 | @torch.no_grad()
322 | def normalize_prototypes(self):
323 | w = self.fc.weight.data.clone()
324 | w = torch.nn.functional.normalize(w, dim=1, p=2)
325 | self.fc.weight.copy_(w)
--------------------------------------------------------------------------------
/methods/clustering/faster_mix_k_means_pytorch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import random
4 | from project_utils.cluster_utils import cluster_acc
5 | from joblib import Parallel, delayed, effective_n_jobs  # sklearn.utils._joblib was removed from modern scikit-learn
6 | from sklearn.utils import check_random_state
7 | import torch
8 | 
9 | def pairwise_distance(data1, data2, batch_size=None):
10 | r'''
11 | Uses broadcasting to compute the pairwise squared Euclidean distances of the data.
12 | The inputs are N*M matrices, where M is the dimension.
13 | We first expand the N*M matrix into an N*1*M matrix A and a 1*N*M matrix B;
14 | a simple elementwise operation on A and B then handles the pairwise computation over all points.
15 | '''
16 | #N*1*M
17 | A = data1.unsqueeze(dim=1)
18 | 
19 | #1*N*M
20 | B = data2.unsqueeze(dim=0)
21 | 
22 | if batch_size is None:
23 | dis = (A-B)**2
24 | #return N*N matrix of pairwise squared distances
25 | dis = dis.sum(dim=-1)
26 | # torch.cuda.empty_cache()
27 | else:
28 | i = 0
29 | dis = torch.zeros(data1.shape[0], data2.shape[0])
30 | while i < data1.shape[0]:
31 | if(i+batch_size < data1.shape[0]):
32 | dis_batch = (A[i:i+batch_size]-B)**2
33 | dis_batch = dis_batch.sum(dim=-1)
34 | dis[i:i+batch_size] = dis_batch
35 | i = i+batch_size
36 | # torch.cuda.empty_cache()
37 | elif(i+batch_size >= data1.shape[0]):
38 | dis_final = (A[i:] - B)**2
39 | dis_final = dis_final.sum(dim=-1)
40 | dis[i:] = dis_final
41 | # torch.cuda.empty_cache()
42 | break
43 | # torch.cuda.empty_cache()
44 | return dis
45 | 
46 | 
47 | class K_Means:
48 | 
49 | def __init__(self, k=3, tolerance=1e-4, max_iterations=100, init='k-means++',
50 | n_init=10, random_state=None, n_jobs=None, pairwise_batch_size=None, mode=None):
51 | self.k = k
52 | self.tolerance = tolerance
53 | self.max_iterations = max_iterations
54 | self.init = init
55 | self.n_init = n_init
56 | self.random_state = random_state
57 | self.n_jobs = n_jobs
58 | self.pairwise_batch_size = pairwise_batch_size
59 | self.mode = mode
60 | 
61 | def split_for_val(self, l_feats, l_targets, val_prop=0.2):
62 | 
63 | np.random.seed(0)
64 | 
65 | # Reserve some labelled examples for validation
66 | num_val_instances = int(val_prop * len(l_targets))
67 | val_idxs = np.random.choice(range(len(l_targets)), size=(num_val_instances), replace=False)
68 | val_idxs.sort()
69 | remaining_idxs = list(set(range(len(l_targets))) -
set(val_idxs.tolist())) 70 | remaining_idxs.sort() 71 | remaining_idxs = np.array(remaining_idxs) 72 | 73 | val_l_targets = l_targets[val_idxs] 74 | val_l_feats = l_feats[val_idxs] 75 | 76 | remaining_l_targets = l_targets[remaining_idxs] 77 | remaining_l_feats = l_feats[remaining_idxs] 78 | 79 | return remaining_l_feats, remaining_l_targets, val_l_feats, val_l_targets 80 | 81 | 82 | def kpp(self, X, pre_centers=None, k=10, random_state=None): 83 | random_state = check_random_state(random_state) 84 | 85 | if pre_centers is not None: 86 | 87 | C = pre_centers 88 | 89 | else: 90 | 91 | C = X[random_state.randint(0, len(X))] 92 | 93 | C = C.view(-1, X.shape[1]) 94 | 95 | while C.shape[0] < k: 96 | 97 | dist = pairwise_distance(X, C, self.pairwise_batch_size) 98 | dist = dist.view(-1, C.shape[0]) 99 | d2, _ = torch.min(dist, dim=1) 100 | prob = d2/d2.sum() 101 | cum_prob = torch.cumsum(prob, dim=0) 102 | r = random_state.rand() 103 | 104 | if len((cum_prob >= r).nonzero()) == 0: 105 | debug = 0 106 | else: 107 | ind = (cum_prob >= r).nonzero()[0][0] 108 | C = torch.cat((C, X[ind].view(1, -1)), dim=0) 109 | 110 | return C 111 | 112 | 113 | def fit_once(self, X, random_state): 114 | 115 | centers = torch.zeros(self.k, X.shape[1]).type_as(X) 116 | labels = -torch.ones(len(X)) 117 | #initialize the centers, the first 'k' elements in the dataset will be our initial centers 118 | 119 | if self.init == 'k-means++': 120 | centers = self.kpp(X, k=self.k, random_state=random_state) 121 | 122 | elif self.init == 'random': 123 | 124 | random_state = check_random_state(self.random_state) 125 | idx = random_state.choice(len(X), self.k, replace=False) 126 | for i in range(self.k): 127 | centers[i] = X[idx[i]] 128 | 129 | else: 130 | for i in range(self.k): 131 | centers[i] = X[i] 132 | 133 | #begin iterations 134 | 135 | best_labels, best_inertia, best_centers = None, None, None 136 | for i in range(self.max_iterations): 137 | 138 | centers_old = centers.clone() 139 | dist = pairwise_distance(X, centers, self.pairwise_batch_size) 140 | mindist, labels = torch.min(dist, dim=1) 141 | inertia = mindist.sum() 142 | 143 | for idx in range(self.k): 144 | selected = torch.nonzero(labels == idx).squeeze() 145 | selected = torch.index_select(X, 0, selected) 146 | centers[idx] = selected.mean(dim=0) 147 | 148 | if best_inertia is None or inertia < best_inertia: 149 | best_labels = labels.clone() 150 | best_centers = centers.clone() 151 | best_inertia = inertia 152 | 153 | center_shift = torch.sum(torch.sqrt(torch.sum((centers - centers_old) ** 2, dim=1))) 154 | if center_shift ** 2 < self.tolerance: 155 | #break out of the main loop if the results are optimal, ie. 
the centers don't change their positions much(more than our tolerance) 156 | break 157 | 158 | return best_labels, best_inertia, best_centers, i + 1 159 | 160 | 161 | def fit_mix_once(self, u_feats, l_feats, l_targets, random_state): 162 | 163 | def supp_idxs(c): 164 | return l_targets.eq(c).nonzero().squeeze(1) 165 | 166 | l_classes = torch.unique(l_targets) 167 | support_idxs = list(map(supp_idxs, l_classes)) 168 | l_centers = torch.stack([l_feats[idx_list].mean(0) for idx_list in support_idxs]) 169 | cat_feats = torch.cat((l_feats, u_feats)) 170 | 171 | centers = torch.zeros([self.k, cat_feats.shape[1]]).type_as(cat_feats) 172 | centers[:len(l_classes)] = l_centers 173 | 174 | labels = -torch.ones(len(cat_feats)).type_as(cat_feats).long() 175 | 176 | l_classes = l_classes.cpu().long().numpy() 177 | l_targets = l_targets.cpu().long().numpy() 178 | l_num = len(l_targets) 179 | cid2ncid = {cid:ncid for ncid, cid in enumerate(l_classes)} # Create the mapping table for New cid (ncid) 180 | for i in range(l_num): 181 | labels[i] = cid2ncid[l_targets[i]] 182 | 183 | #initialize the centers, the first 'k' elements in the dataset will be our initial centers 184 | centers = self.kpp(u_feats, l_centers, k=self.k, random_state=random_state) 185 | 186 | # Begin iterations 187 | best_labels, best_inertia, best_centers = None, None, None 188 | for it in range(self.max_iterations): 189 | centers_old = centers.clone() 190 | 191 | dist = pairwise_distance(u_feats, centers, self.pairwise_batch_size) 192 | u_mindist, u_labels = torch.min(dist, dim=1) 193 | u_inertia = u_mindist.sum() 194 | l_mindist = torch.sum((l_feats - centers[labels[:l_num]])**2, dim=1) 195 | l_inertia = l_mindist.sum() 196 | inertia = u_inertia + l_inertia 197 | labels[l_num:] = u_labels 198 | 199 | for idx in range(self.k): 200 | 201 | selected = torch.nonzero(labels == idx).squeeze() 202 | selected = torch.index_select(cat_feats, 0, selected) 203 | centers[idx] = selected.mean(dim=0) 204 | 205 | if best_inertia is None or inertia < best_inertia: 206 | best_labels = labels.clone() 207 | best_centers = centers.clone() 208 | best_inertia = inertia 209 | 210 | center_shift = torch.sum(torch.sqrt(torch.sum((centers - centers_old) ** 2, dim=1))) 211 | 212 | if center_shift ** 2 < self.tolerance: 213 | #break out of the main loop if the results are optimal, ie. 
the centers don't change their positions much (more than our tolerance)
214 | break
215 | 
216 | return best_labels, best_inertia, best_centers, it + 1
217 | 
218 | 
219 | def fit(self, X):
220 | random_state = check_random_state(self.random_state)
221 | best_inertia = None
222 | if effective_n_jobs(self.n_jobs) == 1:
223 | for it in range(self.n_init):
224 | labels, inertia, centers, n_iters = self.fit_once(X, random_state)
225 | if best_inertia is None or inertia < best_inertia:
226 | self.labels_ = labels.clone()
227 | self.cluster_centers_ = centers.clone()
228 | best_inertia = inertia
229 | self.inertia_ = inertia
230 | self.n_iter_ = n_iters
231 | else:
232 | # parallelisation of k-means runs
233 | seeds = random_state.randint(np.iinfo(np.int32).max, size=self.n_init)
234 | results = Parallel(n_jobs=self.n_jobs, verbose=0)(delayed(self.fit_once)(X, seed) for seed in seeds)
235 | # Get results with the lowest inertia
236 | labels, inertia, centers, n_iters = zip(*results)
237 | best = np.argmin(inertia)
238 | self.labels_ = labels[best]
239 | self.inertia_ = inertia[best]
240 | self.cluster_centers_ = centers[best]
241 | self.n_iter_ = n_iters[best]
242 | 
243 | 
244 | def fit_mix(self, u_feats, l_feats, l_targets):
245 | 
246 | random_state = check_random_state(self.random_state)
247 | best_inertia = None
248 | fit_func = self.fit_mix_once
249 | 
250 | if effective_n_jobs(self.n_jobs) == 1:
251 | for it in range(self.n_init):
252 | 
253 | labels, inertia, centers, n_iters = fit_func(u_feats, l_feats, l_targets, random_state)
254 | 
255 | if best_inertia is None or inertia < best_inertia:
256 | self.labels_ = labels.clone()
257 | self.cluster_centers_ = centers.clone()
258 | best_inertia = inertia
259 | self.inertia_ = inertia
260 | self.n_iter_ = n_iters
261 | 
262 | else:
263 | 
264 | # parallelisation of k-means runs
265 | seeds = random_state.randint(np.iinfo(np.int32).max, size=self.n_init)
266 | results = Parallel(n_jobs=self.n_jobs, verbose=0)(delayed(fit_func)(u_feats, l_feats, l_targets, seed)
267 | for seed in seeds)
268 | # Get results with the lowest inertia
269 | 
270 | labels, inertia, centers, n_iters = zip(*results)
271 | best = np.argmin(inertia)
272 | self.labels_ = labels[best]
273 | self.inertia_ = inertia[best]
274 | self.cluster_centers_ = centers[best]
275 | self.n_iter_ = n_iters[best]
276 | 
277 | 
278 | def main():
279 | 
280 | import matplotlib.pyplot as plt
281 | from matplotlib import style
282 | import pandas as pd
283 | style.use('ggplot')
284 | from sklearn.datasets import make_blobs
285 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
286 | X, y = make_blobs(n_samples=500,
287 | n_features=2,
288 | centers=4,
289 | cluster_std=1,
290 | center_box=(-10.0, 10.0),
291 | shuffle=True,
292 | random_state=1) # For reproducibility
293 | 
294 | cuda = torch.cuda.is_available()
295 | device = torch.device("cuda" if cuda else "cpu")
296 | # X = torch.from_numpy(X).float().to(device)
297 | 
298 | 
299 | y = np.array(y)
300 | l_targets = y[y>1]
301 | l_feats = X[y>1]
302 | u_feats = X[y<2]
303 | cat_feats = np.concatenate((l_feats, u_feats))
304 | y = np.concatenate((y[y>1], y[y<2]))
305 | cat_feats = torch.from_numpy(cat_feats).to(device)
306 | u_feats = torch.from_numpy(u_feats).to(device)
307 | l_feats = torch.from_numpy(l_feats).to(device)
308 | l_targets = torch.from_numpy(l_targets).to(device)
309 | 
310 | km = K_Means(k=4, init='k-means++', random_state=1, n_jobs=None, pairwise_batch_size=10)
311 | 
312 | # km.fit(X)
313 | 
314 | km.fit_mix(u_feats,
l_feats, l_targets) 315 | # X = X.cpu() 316 | X = cat_feats.cpu() 317 | centers = km.cluster_centers_.cpu() 318 | pred = km.labels_.cpu() 319 | print('nmi', nmi_score(pred, y)) 320 | 321 | # Plotting starts here 322 | colors = 10*["g", "c", "b", "k", "r", "m"] 323 | 324 | for i in range(len(X)): 325 | x = X[i] 326 | plt.scatter(x[0], x[1], color = colors[pred[i]],s = 10) 327 | 328 | for i in range(4): 329 | plt.scatter(centers[i][0], centers[i][1], s = 130, marker = "*", color='r') 330 | plt.show() 331 | 332 | if __name__ == "__main__": 333 | main() --------------------------------------------------------------------------------
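
A minimal sketch (not a file from this repository) of how the pieces above can be wired together: it builds the ViT-B/16 backbone from model/vision_transformer.py and loads the DINO projection head through project_utils.general_utils.get_dino_head_weights. The checkpoint path is hypothetical, and the head dimensions (65536-way output, 256-d bottleneck, 3 layers) follow the defaults of the public DINO release rather than anything this repository pins down.

import torch

from model.vision_transformer import vit_base, DINOHead
from project_utils.general_utils import get_dino_head_weights, strip_state_dict

CKPT_PATH = 'dino_vitbase16_full_checkpoint.pth'  # hypothetical local path

# num_classes defaults to 0, so the classifier head is nn.Identity()
backbone = vit_base(patch_size=16)

# DINO 'full' checkpoints store the teacher under 'teacher', with keys prefixed
# 'backbone.' / 'head.'; strip the backbone prefix before loading, and let
# strict=False skip the projection-head keys.
teacher = torch.load(CKPT_PATH, map_location='cpu')['teacher']
backbone.load_state_dict(strip_state_dict(dict(teacher), strip_key='backbone.'), strict=False)

# Projection head, loaded with the helper defined in general_utils.py
head = DINOHead(in_dim=768, out_dim=65536)
head.load_state_dict(get_dino_head_weights(CKPT_PATH))

feats = backbone(torch.randn(2, 3, 224, 224))  # B x 768 [CLS] features
projections = head(feats)                      # B x 65536 DINO projections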
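
In the same spirit, a self-contained sketch of the ClassificationPredSaver accumulation loop from project_utils/general_utils.py; the linear probe and random features below are toy stand-ins for a real classifier and DataLoader, not part of the repository.

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from project_utils.general_utils import ClassificationPredSaver

# Toy stand-ins: 64 random 768-d "features" scored by a linear probe over 10 classes
dataset = TensorDataset(torch.randn(64, 768), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=16, shuffle=False)
model = torch.nn.Linear(768, 10)

os.makedirs('preds', exist_ok=True)  # save() writes into this directory
saver = ClassificationPredSaver(length=len(dataset), save_path='preds/eval_run.pth')

model.eval()
with torch.no_grad():
    for feats, labels in loader:
        saver.update(model(feats), labels)  # update() expects B x C predictions

# save() softmaxes the accumulated logits, writes 'preds/eval_run.pth' and
# 'preds/eval_run_labels.pth', and prints Top-k accuracy for k in {1, 5}
saver.save()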