├── data ├── assets │ └── overview.png ├── ssb_splits │ ├── cub_osr_splits.pkl │ ├── scars_osr_splits.pkl │ ├── aircraft_osr_splits.pkl │ └── herbarium_19_class_splits.pkl ├── data_utils.py ├── augmentations │ ├── cut_out.py │ ├── __init__.py │ └── randaugment.py ├── stanford_cars.py ├── herbarium_19.py ├── get_datasets.py ├── cifar.py ├── cub.py ├── imagenet.py └── fgvc_aircraft.py ├── .gitignore ├── config.py ├── requirements.txt ├── models ├── dino.py └── vision_transformer.py ├── bash_scripts ├── meanshift_clustering.sh └── contrastive_meanshift_training.sh ├── project_utils ├── general_utils.py ├── cluster_and_log_utils.py └── cluster_utils.py ├── methods ├── loss.py ├── contrastive_meanshift_training.py └── meanshift_clustering.py └── README.md /data/assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/assets/overview.png -------------------------------------------------------------------------------- /data/ssb_splits/cub_osr_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/ssb_splits/cub_osr_splits.pkl -------------------------------------------------------------------------------- /data/ssb_splits/scars_osr_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/ssb_splits/scars_osr_splits.pkl -------------------------------------------------------------------------------- /data/ssb_splits/aircraft_osr_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/ssb_splits/aircraft_osr_splits.pkl -------------------------------------------------------------------------------- /data/ssb_splits/herbarium_19_class_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/ssb_splits/herbarium_19_class_splits.pkl -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | .vscode/ 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Logs 11 | log/ 12 | wandb/ 13 | data/features/ 14 | data/memory 15 | 16 | #slurm 17 | *.out 18 | 19 | models/*.pth 20 | *.pt 21 | *.gif -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # ----------------- 2 | # DATASET ROOTS 3 | # ----------------- 4 | DATASET_ROOT='/home/suachoi/datasets/' 5 | cifar_10_root = DATASET_ROOT+'cifar10' 6 | cifar_100_root = DATASET_ROOT+'cifar100' 7 | cub_root = DATASET_ROOT 8 | aircraft_root = DATASET_ROOT+'fgvc-aircraft-2013b' 9 | herbarium_dataroot = DATASET_ROOT+'herbarium_19' 10 | imagenet_root = DATASET_ROOT+'imagenet' 11 | 12 | # OSR Split dir 13 | osr_split_dir = './data/ssb_splits' 14 | 15 | dino_pretrain_path = './models/dino_vitbase16_pretrain.pth' 16 | exp_root = './log' # All logs and checkpoints will be saved here 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | 
appdirs==1.4.4 3 | certifi==2024.2.2 4 | chardet==5.2.0 5 | charset-normalizer==2.0.4 6 | click==8.1.7 7 | contourpy==1.2.1 8 | cycler==0.12.1 9 | docker-pycreds==0.4.0 10 | fonttools==4.51.0 11 | gitdb==4.0.11 12 | GitPython==3.1.43 13 | grpcio==1.62.1 14 | idna==3.4 15 | joblib==1.3.2 16 | kiwisolver==1.4.5 17 | Markdown==3.6 18 | MarkupSafe==2.1.5 19 | matplotlib==3.8.4 20 | mkl-fft==1.3.8 21 | mkl-random==1.2.4 22 | mkl-service==2.4.0 23 | numpy==1.26.4 24 | packaging==24.0 25 | pandas==2.2.1 26 | pillow==10.2.0 27 | pip==23.3.1 28 | protobuf==4.25.3 29 | psutil==5.9.8 30 | pyparsing==3.1.2 31 | python-dateutil==2.9.0.post0 32 | pytz==2024.1 33 | PyYAML==6.0.1 34 | requests==2.31.0 35 | scikit-learn==1.4.1.post1 36 | scipy==1.13.0 37 | sentry-sdk==1.44.0 38 | setproctitle==1.3.3 39 | setuptools==68.2.2 40 | six==1.16.0 41 | smmap==5.0.1 42 | tensorboard==2.16.2 43 | tensorboard-data-server==0.7.2 44 | threadpoolctl==3.4.0 45 | torch==1.13.1 46 | torchaudio==0.13.1 47 | torchvision==0.14.1 48 | tqdm==4.66.2 49 | typing_extensions==4.9.0 50 | tzdata==2024.1 51 | urllib3==2.1.0 52 | wandb==0.16.5 53 | Werkzeug==3.0.2 54 | wheel==0.41.2 55 | -------------------------------------------------------------------------------- /models/dino.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from models import vision_transformer as vits 8 | 9 | class DINO(nn.Module): 10 | def __init__(self, args): 11 | super(DINO, self).__init__() 12 | self.k = args.k 13 | self.backbone = self.init_backbone(args.pretrain_path) 14 | self.img_projection_head = vits.__dict__['DINOHead'](in_dim=args.feat_dim, out_dim=args.feat_dim, nlayers=args.num_mlp_layers) 15 | 16 | def forward(self, image): 17 | feat = self.backbone(image) 18 | feat = self.img_projection_head(feat) 19 | feat = F.normalize(feat, dim=-1) 20 | 21 | return feat 22 | 23 | def init_backbone(self, pretrain_path): 24 | model = vits.__dict__['vit_base']() 25 | state_dict = torch.load(os.path.join(pretrain_path, 'dino_vitbase16_pretrain.pth'), map_location='cpu') 26 | model.load_state_dict(state_dict) 27 | for m in model.parameters(): 28 | m.requires_grad = False 29 | 30 | for name, m in model.named_parameters(): 31 | if 'block' in name: 32 | block_num = int(name.split('.')[1]) 33 | if block_num >= 11: 34 | m.requires_grad = True 35 | 36 | return model -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | 4 | def subsample_instances(dataset, prop_indices_to_subsample=0.8): 5 | 6 | np.random.seed(0) 7 | subsample_indices = np.random.choice(range(len(dataset)), replace=False, 8 | size=(int(prop_indices_to_subsample * len(dataset)),)) 9 | 10 | return subsample_indices 11 | 12 | class MergedDataset(Dataset): 13 | 14 | """ 15 | Takes two datasets (labelled_dataset, unlabelled_dataset) and merges them 16 | Allows you to iterate over them in parallel 17 | """ 18 | 19 | def __init__(self, labelled_dataset, unlabelled_dataset): 20 | 21 | self.labelled_dataset = labelled_dataset 22 | self.unlabelled_dataset = unlabelled_dataset 23 | self.target_transform = None 24 | 25 | def __getitem__(self, item): 26 | 27 | if item < len(self.labelled_dataset): 28 | img, label, uq_idx = self.labelled_dataset[item] 29 | 
labeled_or_not = 1 30 | 31 | else: 32 | 33 | img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)] 34 | labeled_or_not = 0 35 | 36 | 37 | return img, label, uq_idx, np.array([labeled_or_not]) 38 | 39 | def __len__(self): 40 | return len(self.unlabelled_dataset) + len(self.labelled_dataset) 41 | -------------------------------------------------------------------------------- /data/augmentations/cut_out.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/hysts/pytorch_cutout 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | 8 | def cutout(mask_size, p, cutout_inside, mask_color=(0, 0, 0)): 9 | mask_size_half = mask_size // 2 10 | offset = 1 if mask_size % 2 == 0 else 0 11 | 12 | def _cutout(image): 13 | image = np.asarray(image).copy() 14 | 15 | if np.random.random() > p: 16 | return image 17 | 18 | h, w = image.shape[:2] 19 | 20 | if cutout_inside: 21 | cxmin, cxmax = mask_size_half, w + offset - mask_size_half 22 | cymin, cymax = mask_size_half, h + offset - mask_size_half 23 | else: 24 | cxmin, cxmax = 0, w + offset 25 | cymin, cymax = 0, h + offset 26 | 27 | cx = np.random.randint(cxmin, cxmax) 28 | cy = np.random.randint(cymin, cymax) 29 | xmin = cx - mask_size_half 30 | ymin = cy - mask_size_half 31 | xmax = xmin + mask_size 32 | ymax = ymin + mask_size 33 | xmin = max(0, xmin) 34 | ymin = max(0, ymin) 35 | xmax = min(w, xmax) 36 | ymax = min(h, ymax) 37 | image[ymin:ymax, xmin:xmax] = mask_color 38 | return image 39 | 40 | return _cutout 41 | 42 | def to_tensor(): 43 | def _to_tensor(image): 44 | if len(image.shape) == 3: 45 | return torch.from_numpy( 46 | image.transpose(2, 0, 1).astype(float)) 47 | else: 48 | return torch.from_numpy(image[None, :, :].astype(float)) 49 | 50 | return _to_tensor 51 | 52 | def normalize(mean, std): 53 | 54 | mean = np.array(mean) 55 | std = np.array(std) 56 | 57 | def _normalize(image): 58 | image = np.asarray(image).astype(float) / 255. 
59 | image = (image - mean) / std 60 | return image 61 | 62 | return _normalize -------------------------------------------------------------------------------- /bash_scripts/meanshift_clustering.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | ################################################### 4 | # GCD 5 | ################################################### 6 | 7 | python -m methods.meanshift_clustering \ 8 | --dataset_name cifar100 \ 9 | --warmup_model_dir 'cifar100_best' 10 | 11 | python -m methods.meanshift_clustering \ 12 | --dataset_name imagenet_100 \ 13 | --warmup_model_dir 'imagenet_best' 14 | 15 | python -m methods.meanshift_clustering \ 16 | --dataset_name cub \ 17 | --warmup_model_dir 'cub_best' 18 | 19 | python -m methods.meanshift_clustering \ 20 | --dataset_name scars \ 21 | --warmup_model_dir 'scars_best' 22 | 23 | python -m methods.meanshift_clustering \ 24 | --dataset_name aircraft \ 25 | --warmup_model_dir 'aircraft_best' 26 | 27 | python -m methods.meanshift_clustering \ 28 | --dataset_name herbarium_19 \ 29 | --warmup_model_dir 'herbarium_19_best' 30 | 31 | ################################################### 32 | # INDUCTIVE GCD 33 | ################################################### 34 | 35 | python -m methods.meanshift_clustering \ 36 | --dataset_name cifar100 \ 37 | --warmup_model_dir 'cifar100_best' \ 38 | --inductive 39 | 40 | python -m methods.meanshift_clustering \ 41 | --dataset_name imagenet_100 \ 42 | --warmup_model_dir 'imagenet_best' \ 43 | --inductive 44 | 45 | python -m methods.meanshift_clustering \ 46 | --dataset_name cub \ 47 | --warmup_model_dir 'cub_best' \ 48 | --inductive 49 | 50 | python -m methods.meanshift_clustering \ 51 | --dataset_name scars \ 52 | --warmup_model_dir 'scars_best' \ 53 | --inductive 54 | 55 | python -m methods.meanshift_clustering \ 56 | --dataset_name aircraft \ 57 | --warmup_model_dir 'aircraft_best' \ 58 | --inductive 59 | 60 | python -m methods.meanshift_clustering \ 61 | --dataset_name herbarium_19 \ 62 | --warmup_model_dir 'herbarium_19_best' \ 63 | --inductive -------------------------------------------------------------------------------- /bash_scripts/contrastive_meanshift_training.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | ################################################### 4 | # GCD 5 | ################################################### 6 | 7 | python -m methods.contrastive_meanshift_training \ 8 | --dataset_name 'cifar100' \ 9 | --lr 0.01 \ 10 | --temperature 0.3 \ 11 | --wandb 12 | 13 | python -m methods.contrastive_meanshift_training \ 14 | --dataset_name 'imagenet_100' \ 15 | --lr 0.01 \ 16 | --temperature 0.3 \ 17 | --wandb 18 | 19 | python -m methods.contrastive_meanshift_training \ 20 | --dataset_name 'cub' \ 21 | --lr 0.05 \ 22 | --temperature 0.25 \ 23 | --eta_min 5e-3 \ 24 | --wandb 25 | 26 | python -m methods.contrastive_meanshift_training \ 27 | --dataset_name 'scars' \ 28 | --lr 0.05 \ 29 | --temperature 0.25 \ 30 | --eta_min 5e-3 \ 31 | --wandb 32 | 33 | python -m methods.contrastive_meanshift_training \ 34 | --dataset_name 'aircraft' \ 35 | --lr 0.05 \ 36 | --temperature 0.25 \ 37 | --eta_min 5e-3 \ 38 | --wandb 39 | 40 | python -m methods.contrastive_meanshift_training \ 41 | --dataset_name 'herbarium_19' \ 42 | --lr 0.05 \ 43 | --temperature 0.25 \ 44 | --eta_min 5e-3 \ 45 | --wandb 46 | 47 | 48 | 
################################################### 49 | # INDUCTIVE GCD 50 | ################################################### 51 | 52 | python -m methods.contrastive_meanshift_training \ 53 | --dataset_name 'cifar100' \ 54 | --lr 0.01 \ 55 | --temperature 0.3 \ 56 | --inductive \ 57 | --wandb 58 | 59 | python -m methods.contrastive_meanshift_training \ 60 | --dataset_name 'imagenet_100' \ 61 | --lr 0.01 \ 62 | --temperature 0.3 \ 63 | --inductive \ 64 | --wandb 65 | 66 | python -m methods.contrastive_meanshift_training \ 67 | --dataset_name 'cub' \ 68 | --lr 0.05 \ 69 | --temperature 0.25 \ 70 | --eta_min 5e-3 \ 71 | --inductive \ 72 | --wandb 73 | 74 | python -m methods.contrastive_meanshift_training \ 75 | --dataset_name 'scars' \ 76 | --lr 0.05 \ 77 | --temperature 0.25 \ 78 | --eta_min 5e-3 \ 79 | --inductive \ 80 | --wandb 81 | 82 | python -m methods.contrastive_meanshift_training \ 83 | --dataset_name 'aircraft' \ 84 | --lr 0.05 \ 85 | --temperature 0.25 \ 86 | --eta_min 5e-3 \ 87 | --inductive \ 88 | --wandb 89 | 90 | python -m methods.contrastive_meanshift_training \ 91 | --dataset_name 'herbarium_19' \ 92 | --lr 0.05 \ 93 | --temperature 0.25 \ 94 | --eta_min 5e-3 \ 95 | --inductive \ 96 | --wandb -------------------------------------------------------------------------------- /project_utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | import inspect 6 | 7 | from datetime import datetime 8 | 9 | def seed_torch(seed=1029): 10 | 11 | random.seed(seed) 12 | os.environ['PYTHONHASHSEED'] = str(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 17 | torch.backends.cudnn.benchmark = False 18 | torch.backends.cudnn.deterministic = True 19 | 20 | 21 | def get_dino_head_weights(pretrain_path): 22 | 23 | """ 24 | :param pretrain_path: Path to full DINO pretrained checkpoint as in https://github.com/facebookresearch/dino 25 | 'full_ckpt' 26 | :return: weights only for the projection head 27 | """ 28 | 29 | all_weights = torch.load(pretrain_path) 30 | 31 | head_state_dict = {} 32 | for k, v in all_weights['teacher'].items(): 33 | if 'head' in k and 'last_layer' not in k: 34 | head_state_dict[k] = v 35 | 36 | head_state_dict = strip_state_dict(head_state_dict, strip_key='head.') 37 | 38 | # Deal with weight norm 39 | weight_norm_state_dict = {} 40 | for k, v in all_weights['teacher'].items(): 41 | if 'last_layer' in k: 42 | weight_norm_state_dict[k.split('.')[2]] = v 43 | 44 | linear_shape = weight_norm_state_dict['weight'].shape 45 | dummy_linear = torch.nn.Linear(in_features=linear_shape[1], out_features=linear_shape[0], bias=False) 46 | dummy_linear.load_state_dict(weight_norm_state_dict) 47 | dummy_linear = torch.nn.utils.weight_norm(dummy_linear) 48 | 49 | for k, v in dummy_linear.state_dict().items(): 50 | 51 | head_state_dict['last_layer.' 
+ k] = v 52 | 53 | return head_state_dict 54 | 55 | 56 | def init_experiment(args, runner_name=None, exp_id=None): 57 | 58 | args.cuda = torch.cuda.is_available() 59 | 60 | # Get filepath of calling script 61 | if runner_name is None: 62 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:] 63 | 64 | root_dir = os.path.join(args.exp_root, *runner_name) 65 | 66 | if not os.path.exists(root_dir): 67 | os.makedirs(root_dir) 68 | 69 | # Either generate a unique experiment ID, or use one which is passed 70 | if exp_id is None: 71 | 72 | # Unique identifier for experiment 73 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \ 74 | datetime.now().strftime("%S.%f")[:-3] + ')' 75 | 76 | log_dir = os.path.join(root_dir, 'log', now) 77 | while os.path.exists(log_dir): 78 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \ 79 | datetime.now().strftime("%S.%f")[:-3] + ')' 80 | 81 | log_dir = os.path.join(root_dir, 'log', now) 82 | 83 | else: 84 | 85 | log_dir = os.path.join(root_dir, 'log', f'{exp_id}') 86 | 87 | if not os.path.exists(log_dir): 88 | os.makedirs(log_dir) 89 | args.log_dir = log_dir 90 | 91 | # Instantiate directory to save models to 92 | model_root_dir = os.path.join(args.log_dir, 'checkpoints') 93 | if not os.path.exists(model_root_dir): 94 | os.mkdir(model_root_dir) 95 | 96 | args.model_dir = model_root_dir 97 | args.model_path = os.path.join(args.model_dir, 'model.pt') 98 | 99 | print(f'Experiment saved to: {args.log_dir}') 100 | 101 | print(runner_name) 102 | print(args) 103 | 104 | return args 105 | 106 | 107 | def str2bool(v): 108 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 109 | return True 110 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 111 | return False 112 | else: 113 | raise argparse.ArgumentTypeError('Boolean value expected.') 114 | 115 | -------------------------------------------------------------------------------- /data/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from data.augmentations.cut_out import * 3 | from data.augmentations.randaugment import RandAugment 4 | 5 | def get_transform(transform_type='default', image_size=32, args=None): 6 | 7 | if transform_type == 'imagenet': 8 | 9 | mean = (0.485, 0.456, 0.406) 10 | std = (0.229, 0.224, 0.225) 11 | interpolation = args.interpolation 12 | crop_pct = args.crop_pct 13 | 14 | train_transform = transforms.Compose([ 15 | transforms.Resize(int(image_size / crop_pct), interpolation), 16 | transforms.RandomCrop(image_size), 17 | transforms.RandomHorizontalFlip(p=0.5), 18 | transforms.ColorJitter(), 19 | transforms.ToTensor(), 20 | transforms.Normalize( 21 | mean=torch.tensor(mean), 22 | std=torch.tensor(std)) 23 | ]) 24 | 25 | test_transform = transforms.Compose([ 26 | transforms.Resize(int(image_size / crop_pct), interpolation), 27 | transforms.CenterCrop(image_size), 28 | transforms.ToTensor(), 29 | transforms.Normalize( 30 | mean=torch.tensor(mean), 31 | std=torch.tensor(std)) 32 | ]) 33 | 34 | elif transform_type == 'pytorch-cifar': 35 | 36 | mean = (0.4914, 0.4822, 0.4465) 37 | std = (0.2023, 0.1994, 0.2010) 38 | 39 | train_transform = transforms.Compose([ 40 | transforms.RandomCrop(image_size, padding=4), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=mean, std=std), 44 | ]) 45 | 46 | 
test_transform = transforms.Compose([ 47 | transforms.Resize((image_size, image_size)), 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=mean, std=std), 50 | ]) 51 | 52 | elif transform_type == 'herbarium_default': 53 | 54 | train_transform = transforms.Compose([ 55 | transforms.Resize((image_size, image_size)), 56 | transforms.RandomResizedCrop(image_size, scale=(args.resize_lower_bound, 1)), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | ]) 60 | 61 | test_transform = transforms.Compose([ 62 | transforms.Resize((image_size, image_size)), 63 | transforms.ToTensor(), 64 | ]) 65 | 66 | elif transform_type == 'cutout': 67 | 68 | mean = np.array([0.4914, 0.4822, 0.4465]) 69 | std = np.array([0.2470, 0.2435, 0.2616]) 70 | 71 | train_transform = transforms.Compose([ 72 | transforms.RandomCrop(image_size, padding=4), 73 | transforms.RandomHorizontalFlip(), 74 | normalize(mean, std), 75 | cutout(mask_size=int(image_size / 2), 76 | p=1, 77 | cutout_inside=False), 78 | to_tensor(), 79 | ]) 80 | test_transform = transforms.Compose([ 81 | transforms.Resize((image_size, image_size)), 82 | transforms.ToTensor(), 83 | transforms.Normalize(mean, std), 84 | ]) 85 | 86 | elif transform_type == 'rand-augment': 87 | 88 | mean = (0.485, 0.456, 0.406) 89 | std = (0.229, 0.224, 0.225) 90 | 91 | train_transform = transforms.Compose([ 92 | transforms.Resize((image_size, image_size)), 93 | transforms.RandomCrop(image_size, padding=4), 94 | transforms.RandomHorizontalFlip(), 95 | transforms.ToTensor(), 96 | transforms.Normalize(mean=mean, std=std), 97 | ]) 98 | 99 | train_transform.transforms.insert(0, RandAugment(args.rand_aug_n, args.rand_aug_m, args=None)) 100 | 101 | test_transform = transforms.Compose([ 102 | transforms.Resize((image_size, image_size)), 103 | transforms.ToTensor(), 104 | transforms.Normalize(mean=mean, std=std), 105 | ]) 106 | 107 | elif transform_type == 'random_affine': 108 | 109 | mean = (0.485, 0.456, 0.406) 110 | std = (0.229, 0.224, 0.225) 111 | interpolation = args.interpolation 112 | crop_pct = args.crop_pct 113 | 114 | train_transform = transforms.Compose([ 115 | transforms.Resize((image_size, image_size), interpolation), 116 | transforms.RandomAffine(degrees=(-45, 45), 117 | translate=(0.1, 0.1), shear=(-15, 15), scale=(0.7, args.crop_pct)), 118 | transforms.ColorJitter(), 119 | transforms.ToTensor(), 120 | transforms.Normalize( 121 | mean=torch.tensor(mean), 122 | std=torch.tensor(std)) 123 | ]) 124 | 125 | test_transform = transforms.Compose([ 126 | transforms.Resize(int(image_size / crop_pct), interpolation), 127 | transforms.CenterCrop(image_size), 128 | transforms.ToTensor(), 129 | transforms.Normalize( 130 | mean=torch.tensor(mean), 131 | std=torch.tensor(std)) 132 | ]) 133 | 134 | else: 135 | 136 | raise NotImplementedError 137 | 138 | return (train_transform, test_transform) -------------------------------------------------------------------------------- /methods/loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class SupConLoss(nn.Module): 9 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 
10 | It also supports the unsupervised contrastive loss in SimCLR""" 11 | def __init__(self, temperature=0.07, contrast_mode='all', 12 | base_temperature=0.07): 13 | super(SupConLoss, self).__init__() 14 | self.temperature = temperature 15 | self.contrast_mode = contrast_mode 16 | self.base_temperature = base_temperature 17 | 18 | def forward(self, features, labels=None, mask=None): 19 | """Compute loss for model. If both `labels` and `mask` are None, 20 | it degenerates to SimCLR unsupervised loss: 21 | https://arxiv.org/pdf/2002.05709.pdf 22 | 23 | Args: 24 | features: hidden vector of shape [bsz, n_views, ...]. 25 | labels: ground truth of shape [bsz]. 26 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 27 | has the same class as sample i. Can be asymmetric. 28 | Returns: 29 | A loss scalar. 30 | """ 31 | device = (torch.device('cuda') 32 | if features.is_cuda 33 | else torch.device('cpu')) 34 | 35 | if len(features.shape) < 3: 36 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 37 | 'at least 3 dimensions are required') 38 | if len(features.shape) > 3: 39 | features = features.view(features.shape[0], features.shape[1], -1) 40 | 41 | batch_size = features.shape[0] 42 | if labels is not None and mask is not None: 43 | raise ValueError('Cannot define both `labels` and `mask`') 44 | elif labels is None and mask is None: 45 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 46 | elif labels is not None: 47 | labels = labels.contiguous().view(-1, 1) 48 | if labels.shape[0] != batch_size: 49 | raise ValueError('Num of labels does not match num of features') 50 | mask = torch.eq(labels, labels.T).float().to(device) 51 | else: 52 | mask = mask.float().to(device) 53 | 54 | contrast_count = features.shape[1] 55 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 56 | if self.contrast_mode == 'one': 57 | anchor_feature = features[:, 0] 58 | anchor_count = 1 59 | elif self.contrast_mode == 'all': 60 | anchor_feature = contrast_feature 61 | anchor_count = contrast_count # n_views 62 | else: 63 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 64 | 65 | # compute logits 66 | anchor_dot_contrast = torch.div( 67 | torch.matmul(anchor_feature, contrast_feature.T), 68 | self.temperature) 69 | # for numerical stability 70 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 71 | logits = anchor_dot_contrast - logits_max.detach() 72 | 73 | # tile mask 74 | mask = mask.repeat(anchor_count, contrast_count) 75 | # mask-out self-contrast cases 76 | logits_mask = torch.scatter( 77 | torch.ones_like(mask), 78 | 1, 79 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 80 | 0 81 | ) 82 | mask = mask * logits_mask 83 | 84 | # compute log_prob 85 | exp_logits = torch.exp(logits) * logits_mask 86 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 87 | 88 | # compute mean of log-likelihood over positive 89 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 90 | 91 | # loss 92 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 93 | loss = loss.view(anchor_count, batch_size).mean() 94 | 95 | return loss 96 | 97 | class ConMeanShiftLoss(nn.Module): 98 | def __init__(self, args): 99 | super(ConMeanShiftLoss, self).__init__() 100 | self.n_views = args.n_views 101 | self.alpha = args.alpha 102 | self.temperature = args.temperature 103 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 104 | 105 | def forward(self, knn_emb, features): 106 | 
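        # Shapes, inferred from the code below: `features` is [bsz * n_views, dim],
        # with the n_views augmented views stacked along the batch axis, and
        # `knn_emb` holds the aggregated embedding of each feature's k nearest
        # neighbours; the mean-shifted feature (1 - alpha) * f + alpha * knn_emb
        # is computed and re-normalised below.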
batch_size = int(features.size(0)) // self.n_views 107 | positive_mask = torch.eye(batch_size, dtype=torch.float32).to(self.device) 108 | positive_mask = positive_mask.repeat(self.n_views, self.n_views) 109 | negative_mask = torch.ones_like(positive_mask) 110 | negative_mask[positive_mask>0] = 0 111 | self_mask = torch.eye(batch_size * self.n_views, dtype=torch.float32).to(self.device) 112 | positive_mask[self_mask>0] = 0 113 | 114 | #compute logits 115 | meanshift_feat = (1-self.alpha) * features + self.alpha * knn_emb 116 | norm = torch.sqrt(torch.sum((torch.pow(meanshift_feat, 2)), dim=-1)).unsqueeze(1).detach() 117 | meanshift_feat = meanshift_feat / norm 118 | 119 | anchor_dot_contrast = torch.div(torch.matmul(meanshift_feat, meanshift_feat.T), self.temperature) 120 | 121 | # for numerical stability 122 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 123 | logits = anchor_dot_contrast - logits_max.detach() 124 | 125 | # compute log_prob 126 | exp_logits = torch.exp(logits) * negative_mask 127 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 128 | 129 | # compute mean of log-likelihood over positive 130 | loss = - ((positive_mask * log_prob).sum(1) / positive_mask.sum(1)) 131 | loss = loss.view(self.n_views, batch_size).mean() 132 | 133 | return loss 134 | -------------------------------------------------------------------------------- /project_utils/cluster_and_log_utils.py: -------------------------------------------------------------------------------- 1 | from project_utils.cluster_utils import cluster_acc, np, linear_assignment 2 | from torch.utils.tensorboard import SummaryWriter 3 | from typing import List 4 | from scipy.cluster.hierarchy import linkage, fcluster 5 | from sklearn.metrics import silhouette_score 6 | 7 | def split_cluster_acc_v1(y_true, y_pred, mask): 8 | 9 | """ 10 | Evaluate clustering metrics on two subsets of data, as defined by the mask 'mask' 11 | (Mask usually corresponding to `Old' and `New' classes in GCD setting) 12 | :param targets: All ground truth labels 13 | :param preds: All predictions 14 | :param mask: Mask defining two subsets 15 | :return: 16 | """ 17 | 18 | mask = mask.astype(bool) 19 | y_true = y_true.astype(int) 20 | y_pred = y_pred.astype(int) 21 | weight = mask.mean() 22 | old_acc = cluster_acc(y_true[mask], y_pred[mask]) 23 | new_acc = cluster_acc(y_true[~mask], y_pred[~mask]) 24 | total_acc = weight * old_acc + (1 - weight) * new_acc 25 | 26 | return total_acc, old_acc, new_acc 27 | 28 | def split_cluster_acc_v2(y_true, y_pred, mask): 29 | """ 30 | Calculate clustering accuracy. 
Require scikit-learn installed 31 | First compute linear assignment on all data, then look at how good the accuracy is on subsets 32 | 33 | # Arguments 34 | mask: Which instances come from old classes (True) and which ones come from new classes (False) 35 | y: true labels, numpy.array with shape `(n_samples,)` 36 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 37 | 38 | # Return 39 | accuracy, in [0,1] 40 | """ 41 | y_true = y_true.astype(int) 42 | 43 | old_classes_gt = set(y_true[mask]) 44 | new_classes_gt = set(y_true[~mask]) 45 | 46 | assert y_pred.size == y_true.size 47 | D = max(y_pred.max(), y_true.max()) + 1 48 | w = np.zeros((D, D), dtype=int) 49 | for i in range(y_pred.size): 50 | w[y_pred[i], y_true[i]] += 1 51 | # w: pred x label count 52 | ind = linear_assignment(w.max() - w) 53 | ind = np.vstack(ind).T 54 | 55 | ind_map = {j: i for i, j in ind} 56 | total_acc = sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 57 | 58 | old_acc = 0 59 | classwise_acc = [] 60 | total_old_instances = 0 61 | for i in old_classes_gt: 62 | old_acc += w[ind_map[i], i] 63 | total_old_instances += sum(w[:, i]) 64 | classwise_acc.append(w[ind_map[i], i]/sum(w[:, i])) 65 | old_acc /= total_old_instances 66 | 67 | new_acc = 0 68 | total_new_instances = 0 69 | for i in new_classes_gt: 70 | new_acc += w[ind_map[i], i] # w[pred matching with ith label, ith label] 71 | total_new_instances += sum(w[:, i]) 72 | classwise_acc.append(w[ind_map[i], i]/sum(w[:, i])) 73 | new_acc /= total_new_instances 74 | 75 | return total_acc, old_acc, new_acc 76 | 77 | 78 | EVAL_FUNCS = { 79 | 'v1': split_cluster_acc_v1, 80 | 'v2': split_cluster_acc_v2, 81 | } 82 | 83 | def log_accs_from_preds(y_true, y_pred, mask, eval_funcs: List[str], save_name: str, T: int=None, writer: SummaryWriter=None, 84 | print_output=False): 85 | 86 | """ 87 | Given a list of evaluation functions to use (e.g ['v1', 'v2']) evaluate and log ACC results 88 | 89 | :param y_true: GT labels 90 | :param y_pred: Predicted indices 91 | :param mask: Which instances belong to Old and New classes 92 | :param T: Epoch 93 | :param eval_funcs: Which evaluation functions to use 94 | :param save_name: What are we evaluating ACC on 95 | :param writer: Tensorboard logger 96 | :return: 97 | """ 98 | 99 | mask = mask.astype(bool) 100 | y_true = y_true.astype(int) 101 | y_pred = y_pred.astype(int) 102 | 103 | for i, f_name in enumerate(eval_funcs): 104 | 105 | acc_f = EVAL_FUNCS[f_name] 106 | all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask) 107 | log_name = f'{save_name}_{f_name}' 108 | 109 | if writer is not None: 110 | writer.add_scalars(log_name, 111 | {'Old': old_acc, 'New': new_acc, 112 | 'All': all_acc}, T) 113 | 114 | if i == 0: 115 | to_return = (all_acc, old_acc, new_acc) 116 | 117 | if print_output: 118 | print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}' 119 | print(print_str) 120 | 121 | return to_return 122 | 123 | 124 | def test_agglo(epoch, feats, targets, mask, save_name, args): 125 | best_acc = 0 126 | mask = mask.astype(bool) 127 | feats = np.concatenate(feats) 128 | linked = linkage(feats, method="ward") 129 | 130 | gt_dist = linked[:, 2][-args.num_labeled_classes-args.num_unlabeled_classes] 131 | preds = fcluster(linked, t=gt_dist, criterion='distance') 132 | test_all_acc_test, test_old_acc_test, test_new_acc_test = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask, 133 | T=0, eval_funcs=args.eval_funcs, save_name=save_name) 134 | 135 | dist = linked[:, 
2][:-args.num_labeled_classes] 136 | tolerance = 0 137 | for d in reversed(dist): 138 | preds = fcluster(linked, t=d, criterion='distance') 139 | k = max(preds) 140 | all_acc_test, old_acc_test, new_acc_test = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask, 141 | T=0, eval_funcs=args.eval_funcs, save_name=save_name) 142 | 143 | if old_acc_test > best_acc: # save best labeled acc without knowing GT K 144 | best_acc = old_acc_test 145 | best_acc_k = k 146 | best_acc_d = d 147 | tolerance = 0 148 | else: 149 | tolerance += 1 150 | 151 | if tolerance == 50 : 152 | break 153 | return test_all_acc_test, test_old_acc_test, test_new_acc_test, best_acc, best_acc_k 154 | 155 | -------------------------------------------------------------------------------- /project_utils/cluster_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import numpy as np 3 | import sklearn.metrics 4 | import torch 5 | import torch.nn as nn 6 | import matplotlib 7 | matplotlib.use('agg') 8 | from scipy.optimize import linear_sum_assignment as linear_assignment 9 | import random 10 | import os 11 | import argparse 12 | 13 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score 14 | from sklearn.metrics import adjusted_rand_score as ari_score 15 | 16 | from sklearn import metrics 17 | import time 18 | 19 | # ------------------------------- 20 | # Evaluation Criteria 21 | # ------------------------------- 22 | def evaluate_clustering(y_true, y_pred): 23 | 24 | start = time.time() 25 | print('Computing metrics...') 26 | if len(set(y_pred)) < 1000: 27 | acc = cluster_acc(y_true.astype(int), y_pred.astype(int)) 28 | else: 29 | acc = None 30 | 31 | nmi = nmi_score(y_true, y_pred) 32 | ari = ari_score(y_true, y_pred) 33 | pur = purity_score(y_true, y_pred) 34 | print(f'Finished computing metrics {time.time() - start}...') 35 | 36 | return acc, nmi, ari, pur 37 | 38 | 39 | def cluster_acc(y_true, y_pred, return_ind=False): 40 | """ 41 | Calculate clustering accuracy. 
Require scikit-learn installed 42 | 43 | # Arguments 44 | y: true labels, numpy.array with shape `(n_samples,)` 45 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 46 | 47 | # Return 48 | accuracy, in [0,1] 49 | """ 50 | y_true = y_true.astype(int) 51 | assert y_pred.size == y_true.size 52 | D = max(y_pred.max(), y_true.max()) + 1 53 | w = np.zeros((D, D), dtype=int) 54 | for i in range(y_pred.size): 55 | w[y_pred[i], y_true[i]] += 1 56 | 57 | ind = linear_assignment(w.max() - w) 58 | ind = np.vstack(ind).T 59 | 60 | if return_ind: 61 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind, w 62 | else: 63 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 64 | 65 | 66 | def purity_score(y_true, y_pred): 67 | # compute contingency matrix (also called confusion matrix) 68 | contingency_matrix = metrics.cluster.contingency_matrix(y_true, y_pred) 69 | # return purity 70 | return np.sum(np.amax(contingency_matrix, axis=0)) / np.sum(contingency_matrix) 71 | 72 | # ------------------------------- 73 | # Mixed Eval Function 74 | # ------------------------------- 75 | def mixed_eval(targets, preds, mask): 76 | 77 | """ 78 | Evaluate clustering metrics on two subsets of data, as defined by the mask 'mask' 79 | (Mask usually corresponding to `Old' and `New' classes in GCD setting) 80 | :param targets: All ground truth labels 81 | :param preds: All predictions 82 | :param mask: Mask defining two subsets 83 | :return: 84 | """ 85 | 86 | mask = mask.astype(bool) 87 | 88 | # Labelled examples 89 | if mask.sum() == 0: # All examples come from unlabelled classes 90 | 91 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int), preds.astype(int)), \ 92 | nmi_score(targets, preds), \ 93 | ari_score(targets, preds) 94 | 95 | print('Unlabelled Classes Test acc {:.4f}, nmi {:.4f}, ari {:.4f}' 96 | .format(unlabelled_acc, unlabelled_nmi, unlabelled_ari)) 97 | 98 | # Also return ratio between labelled and unlabelled examples 99 | return (unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.mean() 100 | 101 | else: 102 | 103 | labelled_acc, labelled_nmi, labelled_ari = cluster_acc(targets.astype(int)[mask], 104 | preds.astype(int)[mask]), \ 105 | nmi_score(targets[mask], preds[mask]), \ 106 | ari_score(targets[mask], preds[mask]) 107 | 108 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int)[~mask], 109 | preds.astype(int)[~mask]), \ 110 | nmi_score(targets[~mask], preds[~mask]), \ 111 | ari_score(targets[~mask], preds[~mask]) 112 | 113 | # Also return ratio between labelled and unlabelled examples 114 | return (labelled_acc, labelled_nmi, labelled_ari), ( 115 | unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.mean() 116 | 117 | 118 | class AverageMeter(object): 119 | """Computes and stores the average and current value""" 120 | def __init__(self): 121 | self.reset() 122 | 123 | def reset(self): 124 | self.val = 0 125 | self.avg = 0 126 | self.sum = 0 127 | self.count = 0 128 | 129 | def update(self, val, n=1): 130 | self.val = val 131 | self.sum += val * n 132 | self.count += n 133 | self.avg = self.sum / self.count 134 | 135 | 136 | class Identity(nn.Module): 137 | def __init__(self): 138 | super(Identity, self).__init__() 139 | def forward(self, x): 140 | return x 141 | 142 | 143 | class BCE(nn.Module): 144 | eps = 1e-7 # Avoid calculating log(0). Use the small value of float16. 
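    # prob1/prob2 are paired rows of class-probability vectors (built with
    # PairEnum below); simi marks each pair as +1 (same class) or -1 (different).
    # With P = sum_c prob1_c * prob2_c, the loss below reduces to -log(P) for
    # positive pairs and -log(1 - P) for negative pairs.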
145 | def forward(self, prob1, prob2, simi): 146 | # simi: 1->similar; -1->dissimilar; 0->unknown(ignore) 147 | assert len(prob1)==len(prob2)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(prob1)),str(len(prob2)),str(len(simi))) 148 | P = prob1.mul_(prob2) 149 | P = P.sum(1) 150 | P.mul_(simi).add_(simi.eq(-1).type_as(P)) 151 | neglogP = -P.add_(BCE.eps).log_() 152 | return neglogP.mean() 153 | 154 | 155 | def PairEnum(x,mask=None): 156 | 157 | # Enumerate all pairs of feature in x 158 | assert x.ndimension() == 2, 'Input dimension must be 2' 159 | x1 = x.repeat(x.size(0), 1) 160 | x2 = x.repeat(1, x.size(0)).view(-1, x.size(1)) 161 | 162 | if mask is not None: 163 | 164 | xmask = mask.view(-1, 1).repeat(1, x.size(1)) 165 | #dim 0: #sample, dim 1:#feature 166 | x1 = x1[xmask].view(-1, x.size(1)) 167 | x2 = x2[xmask].view(-1, x.size(1)) 168 | 169 | return x1, x2 170 | 171 | 172 | def accuracy(output, target, topk=(1,)): 173 | """Computes the accuracy over the k top predictions for the specified values of k""" 174 | with torch.no_grad(): 175 | maxk = max(topk) 176 | batch_size = target.size(0) 177 | 178 | _, pred = output.topk(maxk, 1, True, True) 179 | pred = pred.t() 180 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 181 | 182 | res = [] 183 | for k in topk: 184 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 185 | res.append(correct_k.mul_(100.0 / batch_size)) 186 | return res 187 | 188 | 189 | def seed_torch(seed=1029): 190 | random.seed(seed) 191 | os.environ['PYTHONHASHSEED'] = str(seed) 192 | np.random.seed(seed) 193 | torch.manual_seed(seed) 194 | torch.cuda.manual_seed(seed) 195 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 196 | torch.backends.cudnn.benchmark = False 197 | torch.backends.cudnn.deterministic = True 198 | 199 | 200 | def str2bool(v): 201 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 202 | return True 203 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 204 | return False 205 | else: 206 | raise argparse.ArgumentTypeError('Boolean value expected.') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Contrastive Mean-Shift Learning for Generalized Category Discovery

Sua Choi · Dahyun Kang · Minsu Cho

Pohang University of Science and Technology (POSTECH)

[Paper] [Project page]

![result](data/assets/overview.png)
## Environment installation
This project is built upon the following environment:
* [Python 3.10](https://www.python.org)
* [CUDA 11.7](https://developer.nvidia.com/cuda-toolkit)
* [PyTorch 1.13.1](https://pytorch.org)

The package requirements can be installed via `requirements.txt`:
```bash
pip install -r requirements.txt
```

## Datasets
We use fine-grained benchmarks in this paper, including:
* [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)

We also use generic object recognition datasets, including:
* [CIFAR100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet](https://image-net.org/download.php)

Please follow [this repo](https://github.com/sgvaze/generalized-category-discovery) to set up the data.

Download the datasets, ssb splits, and pretrained backbone following the file structure below, and set `DATASET_ROOT={YOUR DIRECTORY}` in `config.py`.

```
DATASET_ROOT/
├── cifar100/
│   ├── cifar-100-python/
│   │   ├── meta/
│   ├── ...
├── CUB_200_2011/
│   ├── attributes/
│   ├── ...
├── ...
```
```
CMS/
├── data/
│   ├── ssb_splits/
├── models/
│   ├── dino_vitbase16_pretrain.pth
├── ...
```

## Training
```bash
bash bash_scripts/contrastive_meanshift_training.sh
```
Example bash commands for training are as follows:
```bash
# GCD
python -m methods.contrastive_meanshift_training \
    --dataset_name 'cub' \
    --lr 0.05 \
    --temperature 0.25 \
    --wandb

# Inductive GCD
python -m methods.contrastive_meanshift_training \
    --dataset_name 'cub' \
    --lr 0.05 \
    --temperature 0.25 \
    --inductive \
    --wandb
```

## Evaluation
```bash
bash bash_scripts/meanshift_clustering.sh
```
An example evaluation command is given below; set `--model_name` to the checkpoint you want to evaluate:
```bash
python -m methods.meanshift_clustering \
    --dataset_name 'cub' \
    --model_name 'cub_best'
```

## Results and checkpoints

### Experimental results on GCD task.
| Dataset | All | Old | Novel | Checkpoints |
|:---|:---:|:---:|:---:|:---:|
| CIFAR100 | 82.3 | 85.7 | 75.5 | link |
| ImageNet100 | 84.7 | 95.6 | 79.2 | link |
| CUB | 68.2 | 76.5 | 64.0 | link |
| Stanford Cars | 56.9 | 76.1 | 47.6 | link |
| FGVC-Aircraft | 56.0 | 63.4 | 52.3 | link |
| Herbarium19 | 36.4 | 54.9 | 26.4 | link |
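The All, Old, and Novel numbers above are clustering accuracies (ACC): predicted cluster indices are matched one-to-one to ground-truth classes with the Hungarian algorithm, and the matched accuracy is then split over old and novel classes (see `split_cluster_acc_v2` in `project_utils/cluster_and_log_utils.py`). A minimal, self-contained sketch of the core matching step, on toy labels rather than repo data:

```python
# Minimal sketch of the ACC computation: Hungarian matching between cluster
# ids and class labels on a toy example (not data from the paper).
import numpy as np
from scipy.optimize import linear_sum_assignment

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([1, 1, 0, 0, 2, 2])   # cluster ids permuted w.r.t. labels

D = max(y_pred.max(), y_true.max()) + 1
w = np.zeros((D, D), dtype=int)
for p, t in zip(y_pred, y_true):
    w[p, t] += 1                         # contingency: cluster x label counts
row, col = linear_sum_assignment(w.max() - w)   # maximise matched counts
acc = w[row, col].sum() / y_pred.size
print(acc)                               # 1.0; a pure relabelling is perfect
```

Note that `split_cluster_acc_v2` computes this matching jointly over old and novel classes before splitting the score, so both accuracies share a single cluster-to-class assignment.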
### Experimental results on inductive GCD task.
| Dataset | All | Old | Novel | Checkpoints |
|:---|:---:|:---:|:---:|:---:|
| CIFAR100 | 80.7 | 84.4 | 65.9 | link |
| ImageNet100 | 85.7 | 95.7 | 75.8 | link |
| CUB | 69.7 | 76.5 | 63.0 | link |
| Stanford Cars | 57.8 | 75.2 | 41.0 | link |
| FGVC-Aircraft | 53.3 | 62.7 | 43.8 | link |
| Herbarium19 | 46.2 | 53.0 | 38.9 | link |
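To evaluate a downloaded checkpoint outside the provided scripts, the model can be rebuilt as in `models/dino.py` and the weights restored. Below is a minimal sketch under stated assumptions: the argument values are hypothetical stand-ins for the training argparse namespace, the checkpoint file is assumed to hold a plain `state_dict`, and `pretrain_path` must point at the directory containing `dino_vitbase16_pretrain.pth`:

```python
# Hypothetical loading sketch; paths, k, and head depth are assumptions.
import torch
from types import SimpleNamespace

from models.dino import DINO

args = SimpleNamespace(
    k=8,                       # number of nearest neighbours (assumed value)
    feat_dim=768,              # ViT-B/16 feature dimension
    num_mlp_layers=3,          # projection-head depth (assumed value)
    pretrain_path='./models',  # directory with dino_vitbase16_pretrain.pth
)
model = DINO(args)
state = torch.load('cub_best.pt', map_location='cpu')  # downloaded checkpoint
model.load_state_dict(state)
model.eval()
```

Feature extraction then amounts to `model(images)`, which returns L2-normalised projection features (see `DINO.forward`).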
214 | 215 | 216 | 217 | ## Citation 218 | If you find our code or paper useful, please consider citing our paper: 219 | 220 | ```BibTeX 221 | @inproceedings{choi2024contrastive, 222 | title={Contrastive Mean-Shift Learning for Generalized Category Discovery}, 223 | author={Choi, Sua and Kang, Dahyun and Cho, Minsu}, 224 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 225 | year={2024} 226 | } 227 | ``` 228 | 229 | ## Related Repos 230 | The codebase is largely built on [Generalized Category Discovery](https://github.com/sgvaze/generalized-category-discovery) and [PromptCAL](https://github.com/sheng-eatamath/PromptCAL). 231 | 232 | 233 | ## Acknowledgements 234 | This work was supported by the NRF grant (NRF-2021R1A2C3012728 (50%)) and the IITP grants (2022-0-00113: Developing a Sustainable Collaborative Multi-modal Lifelong Learning Framework (45%), 2019-0-01906: 235 | AI Graduate School Program at POSTECH (5%)) funded by Ministry of Science and ICT, Korea. 236 | 237 | -------------------------------------------------------------------------------- /data/augmentations/randaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | """ 4 | 5 | import random 6 | 7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | 13 | def ShearX(img, v): # [-0.3, 0.3] 14 | assert -0.3 <= v <= 0.3 15 | if random.random() > 0.5: 16 | v = -v 17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 18 | 19 | 20 | def ShearY(img, v): # [-0.3, 0.3] 21 | assert -0.3 <= v <= 0.3 22 | if random.random() > 0.5: 23 | v = -v 24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 25 | 26 | 27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 28 | assert -0.45 <= v <= 0.45 29 | if random.random() > 0.5: 30 | v = -v 31 | v = v * img.size[0] 32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 33 | 34 | 35 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 36 | assert 0 <= v 37 | if random.random() > 0.5: 38 | v = -v 39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 40 | 41 | 42 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 43 | assert -0.45 <= v <= 0.45 44 | if random.random() > 0.5: 45 | v = -v 46 | v = v * img.size[1] 47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 48 | 49 | 50 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 51 | assert 0 <= v 52 | if random.random() > 0.5: 53 | v = -v 54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 55 | 56 | 57 | def Rotate(img, v): # [-30, 30] 58 | assert -30 <= v <= 30 59 | if random.random() > 0.5: 60 | v = -v 61 | return img.rotate(v) 62 | 63 | 64 | def AutoContrast(img, _): 65 | return PIL.ImageOps.autocontrast(img) 66 | 67 | 68 | def Invert(img, _): 69 | return PIL.ImageOps.invert(img) 70 | 71 | 72 | def Equalize(img, _): 73 | return PIL.ImageOps.equalize(img) 74 | 75 | 76 | def Flip(img, _): # not from the paper 77 | return PIL.ImageOps.mirror(img) 78 | 79 | 80 | def Solarize(img, v): # [0, 256] 81 | assert 0 <= v <= 256 82 | return PIL.ImageOps.solarize(img, v) 83 | 84 | 85 | def SolarizeAdd(img, addition=0, threshold=128): 86 | img_np = np.array(img).astype(np.int) 87 | img_np = img_np + 
addition 88 | img_np = np.clip(img_np, 0, 255) 89 | img_np = img_np.astype(np.uint8) 90 | img = Image.fromarray(img_np) 91 | return PIL.ImageOps.solarize(img, threshold) 92 | 93 | 94 | def Posterize(img, v): # [4, 8] 95 | v = int(v) 96 | v = max(1, v) 97 | return PIL.ImageOps.posterize(img, v) 98 | 99 | 100 | def Contrast(img, v): # [0.1,1.9] 101 | assert 0.1 <= v <= 1.9 102 | return PIL.ImageEnhance.Contrast(img).enhance(v) 103 | 104 | 105 | def Color(img, v): # [0.1,1.9] 106 | assert 0.1 <= v <= 1.9 107 | return PIL.ImageEnhance.Color(img).enhance(v) 108 | 109 | 110 | def Brightness(img, v): # [0.1,1.9] 111 | assert 0.1 <= v <= 1.9 112 | return PIL.ImageEnhance.Brightness(img).enhance(v) 113 | 114 | 115 | def Sharpness(img, v): # [0.1,1.9] 116 | assert 0.1 <= v <= 1.9 117 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 118 | 119 | 120 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 121 | assert 0.0 <= v <= 0.2 122 | if v <= 0.: 123 | return img 124 | 125 | v = v * img.size[0] 126 | return CutoutAbs(img, v) 127 | 128 | 129 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 130 | # assert 0 <= v <= 20 131 | if v < 0: 132 | return img 133 | w, h = img.size 134 | x0 = np.random.uniform(w) 135 | y0 = np.random.uniform(h) 136 | 137 | x0 = int(max(0, x0 - v / 2.)) 138 | y0 = int(max(0, y0 - v / 2.)) 139 | x1 = min(w, x0 + v) 140 | y1 = min(h, y0 + v) 141 | 142 | xy = (x0, y0, x1, y1) 143 | color = (125, 123, 114) 144 | # color = (0, 0, 0) 145 | img = img.copy() 146 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 147 | return img 148 | 149 | 150 | def SamplePairing(imgs): # [0, 0.4] 151 | def f(img1, v): 152 | i = np.random.choice(len(imgs)) 153 | img2 = PIL.Image.fromarray(imgs[i]) 154 | return PIL.Image.blend(img1, img2, v) 155 | 156 | return f 157 | 158 | 159 | def Identity(img, v): 160 | return img 161 | 162 | 163 | def augment_list(): # 16 oeprations and their ranges 164 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 165 | # l = [ 166 | # (Identity, 0., 1.0), 167 | # (ShearX, 0., 0.3), # 0 168 | # (ShearY, 0., 0.3), # 1 169 | # (TranslateX, 0., 0.33), # 2 170 | # (TranslateY, 0., 0.33), # 3 171 | # (Rotate, 0, 30), # 4 172 | # (AutoContrast, 0, 1), # 5 173 | # (Invert, 0, 1), # 6 174 | # (Equalize, 0, 1), # 7 175 | # (Solarize, 0, 110), # 8 176 | # (Posterize, 4, 8), # 9 177 | # # (Contrast, 0.1, 1.9), # 10 178 | # (Color, 0.1, 1.9), # 11 179 | # (Brightness, 0.1, 1.9), # 12 180 | # (Sharpness, 0.1, 1.9), # 13 181 | # # (Cutout, 0, 0.2), # 14 182 | # # (SamplePairing(imgs), 0, 0.4), # 15 183 | # ] 184 | 185 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 186 | l = [ 187 | (AutoContrast, 0, 1), 188 | (Equalize, 0, 1), 189 | (Invert, 0, 1), 190 | (Rotate, 0, 30), 191 | (Posterize, 0, 4), 192 | (Solarize, 0, 256), 193 | (SolarizeAdd, 0, 110), 194 | (Color, 0.1, 1.9), 195 | (Contrast, 0.1, 1.9), 196 | (Brightness, 0.1, 1.9), 197 | (Sharpness, 0.1, 1.9), 198 | (ShearX, 0., 0.3), 199 | (ShearY, 0., 0.3), 200 | (CutoutAbs, 0, 40), 201 | (TranslateXabs, 0., 100), 202 | (TranslateYabs, 0., 100), 203 | ] 204 | 205 | return l 206 | 207 | def augment_list_svhn(): # 16 oeprations and their ranges 208 | 209 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 210 | l = [ 211 | (AutoContrast, 0, 1), 212 | (Equalize, 0, 1), 213 | (Invert, 0, 1), 214 | (Posterize, 0, 4), 215 | 
(Solarize, 0, 256), 216 | (SolarizeAdd, 0, 110), 217 | (Color, 0.1, 1.9), 218 | (Contrast, 0.1, 1.9), 219 | (Brightness, 0.1, 1.9), 220 | (Sharpness, 0.1, 1.9), 221 | (ShearX, 0., 0.3), 222 | (ShearY, 0., 0.3), 223 | (CutoutAbs, 0, 40), 224 | ] 225 | 226 | return l 227 | 228 | 229 | class Lighting(object): 230 | """Lighting noise(AlexNet - style PCA - based noise)""" 231 | 232 | def __init__(self, alphastd, eigval, eigvec): 233 | self.alphastd = alphastd 234 | self.eigval = torch.Tensor(eigval) 235 | self.eigvec = torch.Tensor(eigvec) 236 | 237 | def __call__(self, img): 238 | if self.alphastd == 0: 239 | return img 240 | 241 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 242 | rgb = self.eigvec.type_as(img).clone() \ 243 | .mul(alpha.view(1, 3).expand(3, 3)) \ 244 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 245 | .sum(1).squeeze() 246 | 247 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 248 | 249 | 250 | class CutoutDefault(object): 251 | """ 252 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 253 | """ 254 | def __init__(self, length): 255 | self.length = length 256 | 257 | def __call__(self, img): 258 | h, w = img.size(1), img.size(2) 259 | mask = np.ones((h, w), float) 260 | y = np.random.randint(h) 261 | x = np.random.randint(w) 262 | 263 | y1 = np.clip(y - self.length // 2, 0, h) 264 | y2 = np.clip(y + self.length // 2, 0, h) 265 | x1 = np.clip(x - self.length // 2, 0, w) 266 | x2 = np.clip(x + self.length // 2, 0, w) 267 | 268 | mask[y1: y2, x1: x2] = 0. 269 | mask = torch.from_numpy(mask) 270 | mask = mask.expand_as(img) 271 | img *= mask 272 | return img 273 | 274 | 275 | class RandAugment: 276 | def __init__(self, n, m, args=None): 277 | self.n = n # [1, 2] 278 | self.m = m # [0...30] 279 | 280 | if args is None: 281 | self.augment_list = augment_list() 282 | 283 | elif args.dataset == 'svhn' or args.dataset == 'mnist': 284 | self.augment_list = augment_list_svhn() 285 | 286 | else: 287 | self.augment_list = augment_list() 288 | 289 | def __call__(self, img): 290 | ops = random.choices(self.augment_list, k=self.n) 291 | for op, minval, maxval in ops: 292 | val = (float(self.m) / 30) * float(maxval - minval) + minval 293 | img = op(img, val) 294 | 295 | return img -------------------------------------------------------------------------------- /data/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | from scipy import io as mat_io 6 | 7 | from torchvision.datasets.folder import default_loader 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import subsample_instances 11 | 12 | car_root = "/home/suachoi/datasets/stanford_car/cars_{}/" 13 | meta_default_path = "/home/suachoi/datasets/stanford_car/devkit/cars_{}.mat" 14 | 15 | class CarsDataset(Dataset): 16 | """ 17 | Cars Dataset 18 | """ 19 | def __init__(self, train=True, limit=0, data_dir=car_root, transform=None, metas=meta_default_path): 20 | 21 | data_dir = data_dir.format('train') if train else data_dir.format('test') 22 | metas = metas.format('train_annos') if train else metas.format('test_annos_withlabels') 23 | 24 | self.loader = default_loader 25 | self.data_dir = data_dir 26 | self.data = [] 27 | self.target = [] 28 | self.train = train 29 | 30 | self.transform = transform 31 | 32 | if not isinstance(metas, str): 33 | raise Exception("Train metas must be string location !") 34 | labels_meta = mat_io.loadmat(metas) 35 | 36 
| for idx, img_ in enumerate(labels_meta['annotations'][0]): 37 | if limit: 38 | if idx > limit: 39 | break 40 | 41 | self.data.append(data_dir + img_[5][0]) 42 | self.target.append(img_[4][0][0]) 43 | 44 | self.uq_idxs = np.array(range(len(self))) 45 | self.target_transform = None 46 | 47 | def __getitem__(self, idx): 48 | 49 | image = self.loader(self.data[idx]) 50 | target = self.target[idx] - 1 51 | 52 | if self.transform is not None: 53 | image = self.transform(image) 54 | 55 | if self.target_transform is not None: 56 | target = self.target_transform(target) 57 | 58 | idx = self.uq_idxs[idx] 59 | 60 | return image, target, idx 61 | 62 | def __len__(self): 63 | return len(self.data) 64 | 65 | 66 | def subsample_dataset(dataset, idxs, absolute=True): 67 | mask = np.zeros(len(dataset)).astype('bool') 68 | if absolute==True: 69 | mask[idxs] = True 70 | else: 71 | idxs = set(idxs) 72 | mask = np.array([i in idxs for i in dataset.uq_idxs]) 73 | 74 | dataset.data = np.array(dataset.data)[mask].tolist() 75 | dataset.target = np.array(dataset.target)[mask].tolist() 76 | dataset.uq_idxs = dataset.uq_idxs[mask] 77 | 78 | return dataset 79 | 80 | 81 | def subsample_classes(dataset, include_classes=range(160)): 82 | 83 | include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195 84 | cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars] 85 | 86 | target_xform_dict = {} 87 | for i, k in enumerate(include_classes): 88 | target_xform_dict[k] = i 89 | 90 | dataset = subsample_dataset(dataset, cls_idxs) 91 | 92 | return dataset 93 | 94 | def get_train_val_indices(train_dataset, val_split=0.2): 95 | 96 | train_classes = np.unique(train_dataset.target) 97 | 98 | # Get train/test indices 99 | train_idxs = [] 100 | val_idxs = [] 101 | for cls in train_classes: 102 | 103 | cls_idxs = np.where(train_dataset.target == cls)[0] 104 | 105 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 106 | t_ = [x for x in cls_idxs if x not in v_] 107 | 108 | train_idxs.extend(t_) 109 | val_idxs.extend(v_) 110 | 111 | return train_idxs, val_idxs 112 | 113 | 114 | def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8, 115 | split_train_val=False, seed=0): 116 | 117 | np.random.seed(seed) 118 | 119 | # Init entire training set 120 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, metas=meta_default_path, train=True) 121 | 122 | # Get labelled training set which has subsampled classes, then subsample some indices from that 123 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 124 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 125 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 126 | 127 | # Split into training and validation sets 128 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 129 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 130 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 131 | val_dataset_labelled_split.transform = test_transform 132 | 133 | # Get unlabelled data 134 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 135 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), 
np.array(list(unlabelled_indices))) 136 | 137 | # Get test set for all classes 138 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, metas=meta_default_path, train=False) 139 | 140 | # Either split train into train and val or use test set as val 141 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 142 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 143 | 144 | all_datasets = { 145 | 'train_labelled': train_dataset_labelled, 146 | 'train_unlabelled': train_dataset_unlabelled, 147 | 'val': val_dataset_labelled, 148 | 'test': test_dataset, 149 | } 150 | 151 | return all_datasets 152 | 153 | def get_scars_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80), 154 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1): 155 | 156 | np.random.seed(seed) 157 | 158 | # Init entire training set 159 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, metas=meta_default_path, train=True) 160 | 161 | # Get labelled training set which has subsampled classes, then subsample some indices from that 162 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 163 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 164 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 165 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 166 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False) 167 | 168 | # Split into training and validation sets 169 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split) 170 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 171 | val_dataset_labelled.transform = test_transform 172 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 173 | 174 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split) 175 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs) 176 | val_dataset_unlabelled.transform = test_transform 177 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs) 178 | 179 | val_dataset_unlabelled.transform = test_transform 180 | 181 | 182 | # Get test set for all classes 183 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, metas=meta_default_path, train=False) 184 | 185 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}') 186 | 187 | all_datasets = { 188 | 'train_labelled': train_dataset_labelled, 189 | 'train_unlabelled': train_dataset_unlabelled, 190 | 'val': [val_dataset_labelled, val_dataset_unlabelled], 191 | 'test': test_dataset, 192 | } 193 | 194 | return all_datasets 195 | 196 | if __name__ == '__main__': 197 | 198 | x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False) 199 | 200 | print('Printing lens...') 201 | for k, v in x.items(): 202 | if v is not None: 203 | print(f'{k}: {len(v)}') 204 | 205 | print('Printing labelled and unlabelled overlap...') 206 | 
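# [Editor's note] Illustrative sketch, not part of the original repository: the
# unlabelled pool above is the set complement of the labelled uq_idxs, and the
# *_with_gcdval variant passes absolute=False because, once instances have been
# subsampled, dataset positions and persistent uq_idxs no longer coincide.
# Toy numbers:
import numpy as np
whole_uq_idxs = np.arange(10)                  # ids of the whole training set
labelled_uq_idxs = {0, 2, 5, 7}                # ids kept as labelled
unlabelled = np.array(sorted(set(whole_uq_idxs) - labelled_uq_idxs))
print(unlabelled)                              # [1 3 4 6 8 9]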
print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 207 | print('Printing total instances in train...') 208 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 209 | 210 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}') 211 | print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].target))}') 212 | print(f'Len labelled set: {len(x["train_labelled"])}') 213 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/herbarium_19.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision 4 | import numpy as np 5 | from copy import deepcopy 6 | 7 | from data.data_utils import subsample_instances 8 | from config import herbarium_dataroot 9 | 10 | class HerbariumDataset19(torchvision.datasets.ImageFolder): 11 | 12 | def __init__(self, *args, **kwargs): 13 | 14 | # Process metadata json for training images into a DataFrame 15 | super().__init__(*args, **kwargs) 16 | 17 | self.uq_idxs = np.array(range(len(self))) 18 | 19 | def __getitem__(self, idx): 20 | 21 | img, label = super().__getitem__(idx) 22 | uq_idx = self.uq_idxs[idx] 23 | 24 | return img, label, uq_idx 25 | 26 | 27 | def subsample_dataset(dataset, idxs, absolute=True): 28 | mask = np.zeros(len(dataset)).astype('bool') 29 | if absolute==True: 30 | mask[idxs] = True 31 | else: 32 | idxs = set(idxs) 33 | mask = np.array([i in idxs for i in dataset.uq_idxs]) 34 | 35 | dataset.samples = np.array(dataset.samples)[mask].tolist() 36 | dataset.targets = np.array(dataset.targets)[mask].tolist() 37 | 38 | dataset.uq_idxs = dataset.uq_idxs[mask] 39 | 40 | dataset.samples = [[x[0], int(x[1])] for x in dataset.samples] 41 | dataset.targets = [int(x) for x in dataset.targets] 42 | 43 | return dataset 44 | 45 | 46 | def subsample_classes(dataset, include_classes=range(250)): 47 | 48 | cls_idxs = [x for x, l in enumerate(dataset.targets) if l in include_classes] 49 | 50 | target_xform_dict = {} 51 | for i, k in enumerate(include_classes): 52 | target_xform_dict[k] = i 53 | 54 | dataset = subsample_dataset(dataset, cls_idxs) 55 | 56 | dataset.target_transform = lambda x: target_xform_dict[x] 57 | 58 | return dataset 59 | 60 | 61 | def get_train_val_indices(train_dataset, val_instances_per_class=5): 62 | 63 | train_classes = list(set(train_dataset.targets)) 64 | 65 | # Get train/test indices 66 | train_idxs = [] 67 | val_idxs = [] 68 | for cls in train_classes: 69 | 70 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] 71 | 72 | # Have a balanced test set 73 | val_size = min(cls_idxs.shape[0], val_instances_per_class) 74 | v_ = np.random.choice(cls_idxs, replace=False, size=(val_size,)) 75 | t_ = [x for x in cls_idxs if x not in v_] 76 | 77 | train_idxs.extend(t_) 78 | val_idxs.extend(v_) 79 | 80 | return train_idxs, val_idxs 81 | 82 | 83 | def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8, 84 | seed=0, split_train_val=False): 85 | 86 | np.random.seed(seed) 87 | 88 | # Init entire training set 89 | train_dataset = HerbariumDataset19(transform=train_transform, 90 | root=os.path.join(herbarium_dataroot, 'small-train')) 91 | 92 | # Get labelled training set which has subsampled classes, then subsample some indices from that 93 | # TODO: Subsampling unlabelled set in uniform random fashion from training data, will contain many 
instances of dominant class 94 | train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes) 95 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 96 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 97 | 98 | # Split into training and validation sets 99 | if split_train_val: 100 | 101 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, 102 | val_instances_per_class=5) 103 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 104 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 105 | val_dataset_labelled_split.transform = test_transform 106 | 107 | else: 108 | 109 | train_dataset_labelled_split, val_dataset_labelled_split = None, None 110 | 111 | # Get unlabelled data 112 | unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs) 113 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices))) 114 | 115 | # Get test dataset 116 | test_dataset = HerbariumDataset19(transform=test_transform, 117 | root=os.path.join(herbarium_dataroot, 'small-validation')) 118 | 119 | # Transform dict 120 | unlabelled_classes = list(set(train_dataset.targets) - set(train_classes)) 121 | target_xform_dict = {} 122 | for i, k in enumerate(list(train_classes) + unlabelled_classes): 123 | target_xform_dict[k] = i 124 | 125 | test_dataset.target_transform = lambda x: target_xform_dict[x] 126 | train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x] 127 | 128 | # Either split train into train and val or use test set as val 129 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 130 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 131 | 132 | all_datasets = { 133 | 'train_labelled': train_dataset_labelled, 134 | 'train_unlabelled': train_dataset_unlabelled, 135 | 'val': val_dataset_labelled, 136 | 'test': test_dataset, 137 | } 138 | 139 | return all_datasets 140 | 141 | def get_herbarium_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80), 142 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1): 143 | 144 | np.random.seed(seed) 145 | 146 | # Init entire training set 147 | whole_training_set = HerbariumDataset19(transform=train_transform, 148 | root=os.path.join(herbarium_dataroot, 'small-train')) 149 | # Get labelled training set which has subsampled classes, then subsample some indices from that 150 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 151 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 152 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 153 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 154 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False) 155 | 156 | # Split into training and validation sets 157 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_instances_per_class=5) 158 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 159 | val_dataset_labelled.transform = test_transform 160 | train_dataset_labelled = 
subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 161 | 162 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_instances_per_class=5) 163 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs) 164 | val_dataset_unlabelled.transform = test_transform 165 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs) 166 | 167 | val_dataset_unlabelled.transform = test_transform 168 | 169 | 170 | # Get test set for all classes 171 | test_dataset = HerbariumDataset19(transform=test_transform, 172 | root=os.path.join(herbarium_dataroot, 'small-validation')) 173 | 174 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}') 175 | 176 | all_datasets = { 177 | 'train_labelled': train_dataset_labelled, 178 | 'train_unlabelled': train_dataset_unlabelled, 179 | 'val': [val_dataset_labelled, val_dataset_unlabelled], 180 | 'test': test_dataset, 181 | } 182 | 183 | return all_datasets 184 | 185 | if __name__ == '__main__': 186 | 187 | np.random.seed(0) 188 | train_classes = np.random.choice(range(683,), size=(int(683 / 2)), replace=False) 189 | 190 | x = get_herbarium_datasets(None, None, train_classes=train_classes, 191 | prop_train_labels=0.5) 192 | 193 | assert set(x['train_unlabelled'].targets) == set(range(683)) 194 | 195 | print('Printing lens...') 196 | for k, v in x.items(): 197 | if v is not None: 198 | print(f'{k}: {len(v)}') 199 | 200 | print('Printing labelled and unlabelled overlap...') 201 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 202 | print('Printing total instances in train...') 203 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 204 | print('Printing number of labelled classes...') 205 | print(len(set(x['train_labelled'].targets))) 206 | print('Printing total number of classes...') 207 | print(len(set(x['train_unlabelled'].targets))) 208 | 209 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 210 | print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].targets))}') 211 | print(f'Len labelled set: {len(x["train_labelled"])}') 212 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/get_datasets.py: -------------------------------------------------------------------------------- 1 | from data.data_utils import MergedDataset 2 | 3 | from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets, get_cifar_100_datasets_with_gcdval 4 | from data.herbarium_19 import get_herbarium_datasets, get_herbarium_datasets_with_gcdval 5 | from data.stanford_cars import get_scars_datasets, get_scars_datasets_with_gcdval 6 | from data.imagenet import get_imagenet_100_datasets, get_imagenet_100_gcd_datasets_with_gcdval 7 | from data.cub import get_cub_datasets, get_cub_datasets_with_gcdval 8 | from data.fgvc_aircraft import get_aircraft_datasets, get_aircraft_datasets_with_gcdval 9 | 10 | from data.cifar import subsample_classes as subsample_dataset_cifar 11 | from data.herbarium_19 import subsample_classes as subsample_dataset_herb 12 | from data.stanford_cars import subsample_classes as subsample_dataset_scars 13 | from data.imagenet import subsample_classes as subsample_dataset_imagenet 14 | from data.cub import subsample_classes as 
subsample_dataset_cub 15 | from data.fgvc_aircraft import subsample_classes as subsample_dataset_air 16 | 17 | from copy import deepcopy 18 | import pickle 19 | import os 20 | 21 | from config import osr_split_dir 22 | 23 | sub_sample_class_funcs = { 24 | 'cifar10': subsample_dataset_cifar, 25 | 'cifar100': subsample_dataset_cifar, 26 | 'imagenet_100': subsample_dataset_imagenet, 27 | 'herbarium_19': subsample_dataset_herb, 28 | 'cub': subsample_dataset_cub, 29 | 'aircraft': subsample_dataset_air, 30 | 'scars': subsample_dataset_scars 31 | } 32 | 33 | get_dataset_funcs = { 34 | 'cifar10': get_cifar_10_datasets, 35 | 'cifar100': get_cifar_100_datasets, 36 | 'imagenet_100': get_imagenet_100_datasets, 37 | 'herbarium_19': get_herbarium_datasets, 38 | 'cub': get_cub_datasets, 39 | 'aircraft': get_aircraft_datasets, 40 | 'scars': get_scars_datasets 41 | } 42 | 43 | get_dataset_funcs_with_gcdval = { 44 | 'cifar100': get_cifar_100_datasets_with_gcdval, 45 | 'cub': get_cub_datasets_with_gcdval, 46 | 'imagenet_100': get_imagenet_100_gcd_datasets_with_gcdval, 47 | 'aircraft': get_aircraft_datasets_with_gcdval, 48 | 'scars': get_scars_datasets_with_gcdval, 49 | 'herbarium_19': get_herbarium_datasets_with_gcdval 50 | } 51 | 52 | 53 | def get_datasets(dataset_name, train_transform, test_transform, args): 54 | 55 | """ 56 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled 57 | test_dataset, 58 | unlabelled_train_examples_test, 59 | datasets 60 | """ 61 | 62 | # 63 | if dataset_name not in get_dataset_funcs.keys(): 64 | raise ValueError 65 | 66 | # Get datasets 67 | get_dataset_f = get_dataset_funcs[dataset_name] 68 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform, 69 | train_classes=args.train_classes, 70 | prop_train_labels=args.prop_train_labels, 71 | split_train_val=False) 72 | 73 | # Set target transforms: 74 | target_transform_dict = {} 75 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)): 76 | target_transform_dict[cls] = i 77 | target_transform = lambda x: target_transform_dict[x] 78 | 79 | for dataset_name, dataset in datasets.items(): 80 | if dataset is not None: 81 | dataset.target_transform = target_transform 82 | 83 | # Train split (labelled and unlabelled classes) for training 84 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']), 85 | unlabelled_dataset=deepcopy(datasets['train_unlabelled'])) 86 | 87 | test_dataset = datasets['test'] 88 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled']) 89 | unlabelled_train_examples_test.transform = test_transform 90 | 91 | return train_dataset, test_dataset, unlabelled_train_examples_test, datasets 92 | 93 | 94 | def get_datasets_with_gcdval(dataset_name, train_transform, test_transform, args): 95 | """ 96 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled 97 | test_dataset, 98 | unlabelled_train_examples_test, 99 | datasets 100 | """ 101 | 102 | # 103 | if dataset_name not in get_dataset_funcs.keys(): 104 | raise ValueError 105 | 106 | # Get datasets 107 | get_dataset_f = get_dataset_funcs_with_gcdval[dataset_name] 108 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform, 109 | train_classes=args.train_classes, 110 | prop_train_labels=args.prop_train_labels, 111 | split_train_val=True, val_split=0.1) 112 | 113 | # Set target transforms: 114 | target_transform_dict = {} 115 | for i, cls in enumerate(list(args.train_classes) + 
list(args.unlabeled_classes)): 116 | target_transform_dict[cls] = i 117 | target_transform = lambda x: target_transform_dict[x] 118 | 119 | for dataset_name, dataset in datasets.items(): 120 | if dataset is not None: 121 | if isinstance(dataset, list): 122 | for d in dataset: 123 | d.target_transform = target_transform 124 | else: 125 | dataset.target_transform = target_transform 126 | 127 | # Train split (labelled and unlabelled classes) for training 128 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']), 129 | unlabelled_dataset=deepcopy(datasets['train_unlabelled'])) 130 | val_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['val'][0]), 131 | unlabelled_dataset=deepcopy(datasets['val'][1])) 132 | test_dataset = datasets['test'] 133 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled']) 134 | unlabelled_train_examples_test.transform = test_transform 135 | 136 | return train_dataset, test_dataset, unlabelled_train_examples_test, val_dataset, datasets 137 | 138 | def get_class_splits(args): 139 | 140 | # For FGVC datasets, optionally return bespoke splits 141 | if args.dataset_name in ('scars', 'cub', 'aircraft'): 142 | if hasattr(args, 'use_ssb_splits'): 143 | use_ssb_splits = args.use_ssb_splits 144 | else: 145 | use_ssb_splits = False 146 | 147 | # ------------- 148 | # GET CLASS SPLITS 149 | # ------------- 150 | if args.dataset_name == 'cifar10': 151 | 152 | args.image_size = 32 153 | args.train_classes = range(5) 154 | args.unlabeled_classes = range(5, 10) 155 | 156 | elif args.dataset_name == 'cifar100': 157 | 158 | args.image_size = 32 159 | args.train_classes = range(80) 160 | args.unlabeled_classes = range(80, 100) 161 | 162 | elif args.dataset_name == 'tinyimagenet': 163 | 164 | args.image_size = 64 165 | args.train_classes = range(100) 166 | args.unlabeled_classes = range(100, 200) 167 | 168 | elif args.dataset_name == 'herbarium_19': 169 | 170 | args.image_size = 224 171 | herb_path_splits = os.path.join(osr_split_dir, 'herbarium_19_class_splits.pkl') 172 | 173 | with open(herb_path_splits, 'rb') as handle: 174 | class_splits = pickle.load(handle) 175 | 176 | args.train_classes = class_splits['Old'] 177 | args.unlabeled_classes = class_splits['New'] 178 | 179 | elif args.dataset_name == 'imagenet_100': 180 | 181 | args.image_size = 224 182 | args.train_classes = range(50) 183 | args.unlabeled_classes = range(50, 100) 184 | 185 | elif args.dataset_name == 'scars': 186 | 187 | args.image_size = 224 188 | 189 | if use_ssb_splits: 190 | 191 | split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl') 192 | with open(split_path, 'rb') as handle: 193 | class_info = pickle.load(handle) 194 | 195 | args.train_classes = class_info['known_classes'] 196 | open_set_classes = class_info['unknown_classes'] 197 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 198 | 199 | else: 200 | 201 | args.train_classes = range(98) 202 | args.unlabeled_classes = range(98, 196) 203 | 204 | elif args.dataset_name == 'aircraft': 205 | 206 | args.image_size = 224 207 | if use_ssb_splits: 208 | 209 | split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl') 210 | with open(split_path, 'rb') as handle: 211 | class_info = pickle.load(handle) 212 | 213 | args.train_classes = class_info['known_classes'] 214 | open_set_classes = class_info['unknown_classes'] 215 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 216 | 217 
| else: 218 | 219 | args.train_classes = range(50) 220 | args.unlabeled_classes = range(50, 100) 221 | 222 | elif args.dataset_name == 'cub': 223 | 224 | args.image_size = 224 225 | 226 | if use_ssb_splits: 227 | 228 | split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl') 229 | with open(split_path, 'rb') as handle: 230 | class_info = pickle.load(handle) 231 | 232 | args.train_classes = class_info['known_classes'] 233 | open_set_classes = class_info['unknown_classes'] 234 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 235 | 236 | else: 237 | 238 | args.train_classes = range(100) 239 | args.unlabeled_classes = range(100, 200) 240 | 241 | elif args.dataset_name == 'chinese_traffic_signs': 242 | 243 | args.image_size = 224 244 | args.train_classes = range(28) 245 | args.unlabeled_classes = range(28, 56) 246 | 247 | else: 248 | 249 | raise NotImplementedError 250 | 251 | return args -------------------------------------------------------------------------------- /data/cifar.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10, CIFAR100 2 | from copy import deepcopy 3 | import numpy as np 4 | 5 | from data.data_utils import subsample_instances 6 | from config import cifar_10_root, cifar_100_root 7 | 8 | 9 | class CustomCIFAR10(CIFAR10): 10 | 11 | def __init__(self, *args, **kwargs): 12 | 13 | super(CustomCIFAR10, self).__init__(*args, **kwargs) 14 | 15 | self.uq_idxs = np.array(range(len(self))) 16 | 17 | def __getitem__(self, item): 18 | 19 | img, label = super().__getitem__(item) 20 | uq_idx = self.uq_idxs[item] 21 | 22 | return img, label, uq_idx 23 | 24 | def __len__(self): 25 | return len(self.targets) 26 | 27 | 28 | class CustomCIFAR100(CIFAR100): 29 | 30 | def __init__(self, *args, **kwargs): 31 | super(CustomCIFAR100, self).__init__(*args, **kwargs) 32 | 33 | self.uq_idxs = np.array(range(len(self))) 34 | 35 | def __getitem__(self, item): 36 | img, label = super().__getitem__(item) 37 | uq_idx = self.uq_idxs[item] 38 | return img, label, uq_idx 39 | 40 | def __len__(self): 41 | return len(self.targets) 42 | 43 | 44 | def subsample_dataset(dataset, idxs, absolute=True): 45 | mask = np.zeros(len(dataset)).astype('bool') 46 | if absolute==True: 47 | mask[idxs] = True 48 | else: 49 | idxs = set(idxs) 50 | mask = np.array([i in idxs for i in dataset.uq_idxs]) 51 | 52 | dataset.data = dataset.data[mask] 53 | dataset.targets = np.array(dataset.targets)[mask].tolist() 54 | dataset.uq_idxs = dataset.uq_idxs[mask] 55 | 56 | return dataset 57 | 58 | 59 | def subsample_classes(dataset, include_classes=(0, 1, 8, 9)): 60 | 61 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] 62 | 63 | target_xform_dict = {} 64 | for i, k in enumerate(include_classes): 65 | target_xform_dict[k] = i 66 | 67 | dataset = subsample_dataset(dataset, cls_idxs) 68 | 69 | # dataset.target_transform = lambda x: target_xform_dict[x] 70 | 71 | return dataset 72 | 73 | 74 | def get_train_val_indices(train_dataset, val_split=0.2): 75 | 76 | train_classes = np.unique(train_dataset.targets) 77 | 78 | # Get train/test indices 79 | train_idxs = [] 80 | val_idxs = [] 81 | for cls in train_classes: 82 | 83 | cls_idxs = np.where(train_dataset.targets == cls)[0] 84 | 85 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 86 | t_ = [x for x in cls_idxs if x not in v_] 87 | 88 | train_idxs.extend(t_) 89 | val_idxs.extend(v_) 90 
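# [Editor's note] Illustrative sketch, not part of the original repository:
# get_train_val_indices above holds out a fixed proportion of each class,
# whereas the herbarium_19 variant caps validation at val_instances_per_class,
# with a min() guard for rare classes. Toy comparison for one class:
import numpy as np
cls_idxs = np.arange(12)
prop_val = np.random.choice(cls_idxs, replace=False, size=(int(0.2 * 12),))  # 2 images
capped_val_size = min(cls_idxs.shape[0], 5)                                  # at most 5 images
capped_val = np.random.choice(cls_idxs, replace=False, size=(capped_val_size,))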
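# [Editor's note] Illustrative sketch, not part of the original repository: the
# SSB split pickles loaded by get_class_splits earlier in data/get_datasets.py
# hold known classes plus unknown classes bucketed by difficulty; the unlabelled
# class list is their concatenation. Hypothetical toy contents mirroring the
# keys used there:
class_info = {
    'known_classes': [0, 1, 2],
    'unknown_classes': {'Hard': [5], 'Medium': [4], 'Easy': [3]},
}
open_set = class_info['unknown_classes']
unlabeled_classes = open_set['Hard'] + open_set['Medium'] + open_set['Easy']
assert unlabeled_classes == [5, 4, 3]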
| 91 | return train_idxs, val_idxs 92 | 93 | 94 | def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9), 95 | prop_train_labels=0.8, split_train_val=False, seed=0): 96 | 97 | np.random.seed(seed) 98 | 99 | # Init entire training set 100 | whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True) 101 | 102 | # Get labelled training set which has subsampled classes, then subsample some indices from that 103 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 104 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 105 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 106 | 107 | # Split into training and validation sets 108 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 109 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 110 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 111 | val_dataset_labelled_split.transform = test_transform 112 | 113 | # Get unlabelled data 114 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 115 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 116 | 117 | # Get test set for all classes 118 | test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False) 119 | 120 | # Either split train into train and val or use test set as val 121 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 122 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 123 | 124 | all_datasets = { 125 | 'train_labelled': train_dataset_labelled, 126 | 'train_unlabelled': train_dataset_unlabelled, 127 | 'val': val_dataset_labelled, 128 | 'test': test_dataset, 129 | } 130 | 131 | return all_datasets 132 | 133 | 134 | def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80), 135 | prop_train_labels=0.8, split_train_val=False, seed=0): 136 | 137 | np.random.seed(seed) 138 | 139 | # Init entire training set 140 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True) 141 | 142 | # Get labelled training set which has subsampled classes, then subsample some indices from that 143 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 144 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 145 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 146 | 147 | # Split into training and validation sets 148 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 149 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 150 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 151 | val_dataset_labelled_split.transform = test_transform 152 | 153 | # Get unlabelled data 154 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 155 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 156 | 157 | # Get test set for all classes 158 | test_dataset = 
CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False) 159 | 160 | # Either split train into train and val or use test set as val 161 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 162 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 163 | 164 | all_datasets = { 165 | 'train_labelled': train_dataset_labelled, 166 | 'train_unlabelled': train_dataset_unlabelled, 167 | 'val': val_dataset_labelled, 168 | 'test': test_dataset, 169 | } 170 | 171 | return all_datasets 172 | 173 | 174 | def get_cifar_100_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80), 175 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1): 176 | 177 | np.random.seed(seed) 178 | 179 | # Init entire training set 180 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True) 181 | 182 | # Split whole dataset by train classes 183 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 184 | # Subsample train classes into trainl 185 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 186 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 187 | # All remaining samples into trainu 188 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 189 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False) 190 | 191 | # Split trainl into train/val 192 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split) 193 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 194 | val_dataset_labelled.transform = test_transform 195 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 196 | 197 | # Split trainu into train/val 198 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split) 199 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs) 200 | val_dataset_unlabelled.transform = test_transform 201 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs) 202 | 203 | val_dataset_unlabelled.transform = test_transform 204 | 205 | 206 | # Get test set for all classes 207 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False) 208 | 209 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}') 210 | 211 | all_datasets = { 212 | 'train_labelled': train_dataset_labelled, 213 | 'train_unlabelled': train_dataset_unlabelled, 214 | 'val': [val_dataset_labelled, val_dataset_unlabelled], 215 | 'test': test_dataset, 216 | } 217 | 218 | return all_datasets 219 | 220 | if __name__ == '__main__': 221 | 222 | x = get_cifar_100_datasets(None, None, split_train_val=False, 223 | train_classes=range(80), prop_train_labels=0.5) 224 | 225 | print('Printing lens...') 226 | for k, v in x.items(): 227 | if v is not None: 228 | print(f'{k}: {len(v)}') 229 | 230 | print('Printing labelled and unlabelled overlap...') 231 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 232 | print('Printing total instances 
in train...') 233 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 234 | 235 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 236 | print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].targets))}') 237 | print(f'Len labelled set: {len(x["train_labelled"])}') 238 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from torchvision.datasets.folder import default_loader 7 | from torchvision.datasets.utils import download_url 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import subsample_instances 11 | from config import cub_root 12 | 13 | 14 | class CustomCub2011(Dataset): 15 | base_folder = 'CUB_200_2011/images' 16 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 17 | filename = 'CUB_200_2011.tgz' 18 | tgz_md5 = '97eceeb196236b17998738112f37df78' 19 | 20 | def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=True): 21 | 22 | self.root = os.path.expanduser(root) 23 | self.transform = transform 24 | self.target_transform = target_transform 25 | 26 | self.loader = loader 27 | self.train = train 28 | 29 | 30 | if download: 31 | self._download() 32 | 33 | if not self._check_integrity(): 34 | raise RuntimeError('Dataset not found or corrupted.' + 35 | ' You can use download=True to download it') 36 | 37 | self.uq_idxs = np.array(range(len(self))) 38 | 39 | def _load_metadata(self): 40 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', 41 | names=['img_id', 'filepath']) 42 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), 43 | sep=' ', names=['img_id', 'target']) 44 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), 45 | sep=' ', names=['img_id', 'is_training_img']) 46 | 47 | data = images.merge(image_class_labels, on='img_id') 48 | self.data = data.merge(train_test_split, on='img_id') 49 | 50 | if self.train: 51 | self.data = self.data[self.data.is_training_img == 1] 52 | else: 53 | self.data = self.data[self.data.is_training_img == 0] 54 | 55 | def _check_integrity(self): 56 | try: 57 | self._load_metadata() 58 | except Exception: 59 | return False 60 | 61 | for index, row in self.data.iterrows(): 62 | filepath = os.path.join(self.root, self.base_folder, row.filepath) 63 | if not os.path.isfile(filepath): 64 | print(filepath) 65 | return False 66 | return True 67 | 68 | def _download(self): 69 | import tarfile 70 | 71 | if self._check_integrity(): 72 | print('Files already downloaded and verified') 73 | return 74 | 75 | download_url(self.url, self.root, self.filename, self.tgz_md5) 76 | 77 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 78 | tar.extractall(path=self.root) 79 | 80 | def __len__(self): 81 | return len(self.data) 82 | 83 | def __getitem__(self, idx): 84 | sample = self.data.iloc[idx] 85 | path = os.path.join(self.root, self.base_folder, sample.filepath) 86 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0 87 | img = self.loader(path) 88 | 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | if 
self.target_transform is not None: 93 | target = self.target_transform(target) 94 | 95 | return img, target, self.uq_idxs[idx] 96 | 97 | 98 | def subsample_dataset(dataset, idxs, absolute=True): 99 | mask = np.zeros(len(dataset)).astype('bool') 100 | if absolute==True: 101 | mask[idxs] = True 102 | else: 103 | idxs = set(idxs) 104 | mask = np.array([i in idxs for i in dataset.uq_idxs]) 105 | 106 | dataset.data = dataset.data[mask] 107 | dataset.uq_idxs = dataset.uq_idxs[mask] 108 | 109 | return dataset 110 | 111 | def subsample_classes(dataset, include_classes=range(160)): 112 | 113 | include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199 114 | cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub] 115 | 116 | target_xform_dict = {} 117 | for i, k in enumerate(include_classes): 118 | target_xform_dict[k] = i 119 | 120 | dataset = subsample_dataset(dataset, cls_idxs) 121 | 122 | dataset.target_transform = lambda x: target_xform_dict[x] 123 | 124 | return dataset 125 | 126 | 127 | def get_train_val_indices(train_dataset, val_split=0.2): 128 | 129 | train_classes = np.unique(train_dataset.data['target']) 130 | 131 | # Get train/test indices 132 | train_idxs = [] 133 | val_idxs = [] 134 | for cls in train_classes: 135 | 136 | cls_idxs = np.where(train_dataset.data['target'] == cls)[0] 137 | 138 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 139 | t_ = [x for x in cls_idxs if x not in v_] 140 | 141 | train_idxs.extend(t_) 142 | val_idxs.extend(v_) 143 | 144 | return train_idxs, val_idxs 145 | 146 | 147 | def get_cub_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8, 148 | split_train_val=False, seed=0): 149 | 150 | np.random.seed(seed) 151 | 152 | # Init entire training set 153 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True) 154 | 155 | # Get labelled training set which has subsampled classes, then subsample some indices from that 156 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 157 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 158 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 159 | 160 | # Split into training and validation sets 161 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 162 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 163 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 164 | val_dataset_labelled_split.transform = test_transform 165 | 166 | # Get unlabelled data 167 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 168 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 169 | 170 | # Get test set for all classes 171 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False) 172 | 173 | # Either split train into train and val or use test set as val 174 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 175 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 176 | 177 | all_datasets = { 178 | 'train_labelled': train_dataset_labelled, 179 | 'train_unlabelled': 
train_dataset_unlabelled, 180 | 'val': val_dataset_labelled, 181 | 'test': test_dataset, 182 | } 183 | 184 | return all_datasets 185 | 186 | def get_cub_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80), 187 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1): 188 | 189 | np.random.seed(seed) 190 | 191 | # Init entire training set 192 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True) 193 | 194 | # Get labelled training set which has subsampled classes, then subsample some indices from that 195 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 196 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 197 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 198 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 199 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False) 200 | 201 | # Split into training and validation sets 202 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split) 203 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 204 | val_dataset_labelled.transform = test_transform 205 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 206 | 207 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split) 208 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs) 209 | val_dataset_unlabelled.transform = test_transform 210 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs) 211 | 212 | val_dataset_unlabelled.transform = test_transform 213 | 214 | 215 | # Get test set for all classes 216 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False) 217 | 218 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}') 219 | 220 | all_datasets = { 221 | 'train_labelled': train_dataset_labelled, 222 | 'train_unlabelled': train_dataset_unlabelled, 223 | 'val': [val_dataset_labelled, val_dataset_unlabelled], 224 | 'test': test_dataset, 225 | } 226 | 227 | return all_datasets 228 | 229 | if __name__ == '__main__': 230 | 231 | x = get_cub_datasets(None, None, split_train_val=False, 232 | train_classes=range(100), prop_train_labels=0.5) 233 | 234 | print('Printing lens...') 235 | for k, v in x.items(): 236 | if v is not None: 237 | print(f'{k}: {len(v)}') 238 | 239 | print('Printing labelled and unlabelled overlap...') 240 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 241 | print('Printing total instances in train...') 242 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 243 | 244 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}') 245 | print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}') 246 | print(f'Len labelled set: {len(x["train_labelled"])}') 247 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/imagenet.py: 
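# [Editor's note] Illustrative sketch, not part of the original repository:
# CustomCub2011 above builds its index by merging CUB's three metadata files on
# img_id and filtering by the train/test flag. Miniature stand-in frames:
import pandas as pd
images = pd.DataFrame({'img_id': [1, 2], 'filepath': ['a.jpg', 'b.jpg']})
labels = pd.DataFrame({'img_id': [1, 2], 'target': [5, 7]})
split = pd.DataFrame({'img_id': [1, 2], 'is_training_img': [1, 0]})
data = images.merge(labels, on='img_id').merge(split, on='img_id')
train_rows = data[data.is_training_img == 1]   # one training image: a.jpg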
-------------------------------------------------------------------------------- 1 | import torchvision 2 | import numpy as np 3 | 4 | import os 5 | 6 | from copy import deepcopy 7 | from data.data_utils import subsample_instances 8 | from config import imagenet_root 9 | 10 | 11 | class ImageNetBase(torchvision.datasets.ImageFolder): 12 | 13 | def __init__(self, root, transform): 14 | 15 | super(ImageNetBase, self).__init__(root, transform) 16 | 17 | self.uq_idxs = np.array(range(len(self))) 18 | 19 | def __getitem__(self, item): 20 | 21 | img, label = super().__getitem__(item) 22 | uq_idx = self.uq_idxs[item] 23 | 24 | return img, label, uq_idx 25 | 26 | def subsample_dataset(dataset, idxs, absolute=True): 27 | mask = np.zeros(len(dataset)).astype('bool') 28 | if absolute==True: 29 | mask[idxs] = True 30 | else: 31 | idxs = set(idxs) 32 | mask = np.array([i in idxs for i in dataset.uq_idxs]) 33 | 34 | dataset.samples = [s for m, s in zip(mask, dataset.samples) if m==True] 35 | dataset.targets = [t for m, t in zip(mask, dataset.targets) if m==True] 36 | dataset.uq_idxs = dataset.uq_idxs[mask] 37 | 38 | return dataset 39 | 40 | def subsample_classes(dataset, include_classes=list(range(1000))): 41 | 42 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] 43 | 44 | target_xform_dict = {} 45 | for i, k in enumerate(include_classes): 46 | target_xform_dict[k] = i 47 | 48 | dataset = subsample_dataset(dataset, cls_idxs) 49 | dataset.target_transform = lambda x: target_xform_dict[x] 50 | 51 | return dataset 52 | 53 | 54 | def get_train_val_indices(train_dataset, val_split=0.2): 55 | 56 | train_classes = list(set(train_dataset.targets)) 57 | 58 | # Get train/test indices 59 | train_idxs = [] 60 | val_idxs = [] 61 | for cls in train_classes: 62 | 63 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] 64 | 65 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 66 | t_ = [x for x in cls_idxs if x not in v_] 67 | 68 | train_idxs.extend(t_) 69 | val_idxs.extend(v_) 70 | 71 | return train_idxs, val_idxs 72 | 73 | 74 | def get_equal_len_datasets(dataset1, dataset2): 75 | """ 76 | Make two datasets the same length 77 | """ 78 | 79 | if len(dataset1) > len(dataset2): 80 | 81 | rand_idxs = np.random.choice(range(len(dataset1)), size=(len(dataset2, ))) 82 | subsample_dataset(dataset1, rand_idxs) 83 | 84 | elif len(dataset2) > len(dataset1): 85 | 86 | rand_idxs = np.random.choice(range(len(dataset2)), size=(len(dataset1, ))) 87 | subsample_dataset(dataset2, rand_idxs) 88 | 89 | return dataset1, dataset2 90 | 91 | 92 | def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80), 93 | prop_train_labels=0.8, split_train_val=False, seed=0): 94 | 95 | np.random.seed(seed) 96 | 97 | # Subsample imagenet dataset initially to include 100 classes 98 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False) 99 | subsampled_100_classes = np.sort(subsampled_100_classes) 100 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}') 101 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))} 102 | 103 | # Init entire training set 104 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) 105 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes) 106 | 107 | # Reset dataset 108 | whole_training_set.samples = [(s[0], 
cls_map[s[1]]) for s in whole_training_set.samples] 109 | whole_training_set.targets = [s[1] for s in whole_training_set.samples] 110 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set))) 111 | whole_training_set.target_transform = None 112 | 113 | # Get labelled training set which has subsampled classes, then subsample some indices from that 114 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 115 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 116 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 117 | 118 | # Split into training and validation sets 119 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 120 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 121 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 122 | val_dataset_labelled_split.transform = test_transform 123 | 124 | # Get unlabelled data 125 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 126 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 127 | 128 | # Get test set for all classes 129 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) 130 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes) 131 | # Reset test set 132 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples] 133 | test_dataset.targets = [s[1] for s in test_dataset.samples] 134 | test_dataset.uq_idxs = np.array(range(len(test_dataset))) 135 | test_dataset.target_transform = None 136 | 137 | #test_dataset = subsample_classes(test_dataset, include_classes=train_classes) 138 | 139 | # Either split train into train and val or use test set as val 140 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 141 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 142 | 143 | all_datasets = { 144 | 'train_labelled': train_dataset_labelled, 145 | 'train_unlabelled': train_dataset_unlabelled, 146 | 'val': val_dataset_labelled, 147 | 'test': test_dataset, 148 | } 149 | 150 | return all_datasets 151 | 152 | def get_imagenet_100_gcd_datasets_with_gcdval(train_transform, test_transform, train_classes=range(50), 153 | prop_train_labels=0.5, split_train_val=True, seed=0, val_split=0.1): 154 | 155 | np.random.seed(seed) 156 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False) 157 | subsampled_100_classes = np.sort(subsampled_100_classes) 158 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}') 159 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))} 160 | 161 | # Init entire training set 162 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) 163 | # NOTE: use GCD paper IN split 164 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes) 165 | # Reset dataset 166 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples] 167 | whole_training_set.targets = [s[1] for s in whole_training_set.samples] 168 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set))) 
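# [Editor's note] Illustrative sketch, not part of the original repository: the
# ImageNet-100 construction above draws 100 of the 1000 class ids at random and
# relabels them contiguously via cls_map before any further splitting. In
# miniature:
import numpy as np
chosen = np.sort(np.random.RandomState(0).choice(1000, size=(5,), replace=False))
cls_map = {int(c): i for i, c in enumerate(chosen)}
samples = [('img0.jpg', int(chosen[3]))]           # raw ImageNet label
samples = [(p, cls_map[t]) for p, t in samples]    # -> contiguous label 3
assert samples[0][1] == 3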
169 | whole_training_set.target_transform = None 170 | 171 | # Get labelled training set which has subsampled classes, then subsample some indices from that 172 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 173 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 174 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 175 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 176 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False) 177 | 178 | # Split into training and validation sets 179 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split) 180 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 181 | val_dataset_labelled.transform = test_transform 182 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 183 | 184 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split) 185 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs) 186 | val_dataset_unlabelled.transform = test_transform 187 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs) 188 | 189 | val_dataset_unlabelled.transform = test_transform 190 | 191 | 192 | # Get test set for all classes 193 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) 194 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes) 195 | 196 | # Reset test set 197 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples] 198 | test_dataset.targets = [s[1] for s in test_dataset.samples] 199 | test_dataset.uq_idxs = np.array(range(len(test_dataset))) 200 | test_dataset.target_transform = None 201 | 202 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}') 203 | 204 | all_datasets = { 205 | 'train_labelled': train_dataset_labelled, 206 | 'train_unlabelled': train_dataset_unlabelled, 207 | 'val': [val_dataset_labelled, val_dataset_unlabelled], 208 | 'test': test_dataset, 209 | } 210 | 211 | return all_datasets 212 | 213 | 214 | if __name__ == '__main__': 215 | 216 | x = get_imagenet_100_datasets(None, None, split_train_val=False, 217 | train_classes=range(50), prop_train_labels=0.5) 218 | 219 | print('Printing lens...') 220 | for k, v in x.items(): 221 | if v is not None: 222 | print(f'{k}: {len(v)}') 223 | 224 | print('Printing labelled and unlabelled overlap...') 225 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 226 | print('Printing total instances in train...') 227 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 228 | 229 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 230 | print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].targets))}') 231 | print(f'Len labelled set: {len(x["train_labelled"])}') 232 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /methods/contrastive_meanshift_training.py: 
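# [Editor's note] Illustrative sketch, not part of the original repository:
# every dataset module above ships the same subsample_dataset helper with two
# modes. absolute=True treats idxs as positions in the current dataset, while
# absolute=False matches them against the persistent uq_idxs, which survive
# earlier rounds of subsampling:
import numpy as np
uq_idxs = np.array([3, 5, 9])                        # ids left after a prior subsample
mask_abs = np.zeros(len(uq_idxs), dtype=bool)
mask_abs[[0, 2]] = True                              # positions 0 and 2
mask_rel = np.array([i in {5, 9} for i in uq_idxs])  # ids 5 and 9
print(uq_idxs[mask_abs], uq_idxs[mask_rel])          # [3 9] [5 9]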
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | import wandb 8 | 9 | from torch.nn import functional as F 10 | from torch.utils.data import DataLoader 11 | from torch.optim import SGD, lr_scheduler 12 | from tqdm import tqdm 13 | 14 | from project_utils.cluster_utils import AverageMeter 15 | from data.augmentations import get_transform 16 | from data.get_datasets import get_datasets, get_datasets_with_gcdval, get_class_splits 17 | from project_utils.cluster_and_log_utils import * 18 | from project_utils.general_utils import init_experiment, str2bool 19 | 20 | from models.dino import * 21 | from methods.loss import * 22 | 23 | import warnings 24 | warnings.filterwarnings("ignore", category=DeprecationWarning) 25 | 26 | class ContrastiveLearningViewGenerator(object): 27 | """Take two random crops of one image as the query and key.""" 28 | 29 | def __init__(self, base_transform, n_views=2): 30 | self.base_transform = base_transform 31 | self.n_views = n_views 32 | 33 | def __call__(self, x): 34 | return [self.base_transform(x) for i in range(self.n_views)] 35 | 36 | 37 | def train(model, train_loader, test_loader, train_loader_memory, args): 38 | 39 | optimizer = SGD(list(model.module.parameters()), lr=args.lr, momentum=args.momentum, 40 | weight_decay=args.weight_decay) 41 | exp_lr_scheduler = lr_scheduler.CosineAnnealingLR( 42 | optimizer, 43 | T_max=args.epochs, 44 | eta_min=args.lr * args.eta_min, 45 | ) 46 | 47 | sup_con_crit = SupConLoss() 48 | unsup_con_crit = ConMeanShiftLoss(args) 49 | 50 | best_agglo_score = 0 51 | best_agglo_img_score = 0 52 | best_img_old_score = 0 53 | 54 | for epoch in range(args.epochs): 55 | loss_record = AverageMeter() 56 | with torch.no_grad(): 57 | model.eval() 58 | all_feats = [] 59 | for batch_idx, batch in enumerate(tqdm(train_loader_memory)): 60 | images, class_labels, uq_idxs, mask_lab = batch 61 | images = torch.cat(images, dim=0).to(device) 62 | 63 | features = model(images) 64 | all_feats.append(features.detach().cpu()) 65 | all_feats = torch.cat(all_feats) 66 | 67 | model.train() 68 | for batch_idx, batch in enumerate(tqdm(train_loader)): 69 | images, class_labels, uq_idxs, mask_lab = batch 70 | class_labels, mask_lab = class_labels.to(device), mask_lab[:, 0].to(device).bool() 71 | images = torch.cat(images, dim=0).to(device) 72 | 73 | features = model(images) 74 | 75 | classwise_sim = torch.einsum('b d, n d -> b n', features.cpu(), all_feats) 76 | 77 | _, indices = classwise_sim.topk(k=args.k+1, dim=-1, largest=True, sorted=True) 78 | indices = indices[:, 1:] 79 | knn_emb = torch.mean(all_feats[indices, :].view(-1, args.k, args.feat_dim), dim=1).to(device) 80 | 81 | if args.contrast_unlabel_only: 82 | # Contrastive loss only on unlabelled instances 83 | f1, f2 = [f[~mask_lab] for f in features.chunk(2)] 84 | con_feats = torch.cat([f1, f2], dim=0) 85 | f3, f4 = [f[~mask_lab] for f in knn_emb.chunk(2)] 86 | con_knn_emb = torch.cat([f3, f4], dim=0) 87 | con_uq_idxs = uq_idxs[~mask_lab] 88 | 89 | else: 90 | # Contrastive loss for all examples 91 | con_feats = features 92 | con_knn_emb = knn_emb 93 | con_uq_idxs = uq_idxs 94 | 95 | unsup_con_loss = unsup_con_crit(con_knn_emb, con_feats) 96 | 97 | # Supervised contrastive loss 98 | f1, f2 = [f[mask_lab] for f in features.chunk(2)] 99 | sup_con_feats = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 100 | sup_con_labels = class_labels[mask_lab] 101 | 
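# [Editor's note] Illustrative sketch, not part of the original repository: the
# mean-shift target built in the loop above. Features are L2-normalised, so the
# einsum is cosine similarity; the top-(k+1) neighbours are taken from the
# memory bank and the first hit (the query itself) is dropped before averaging:
import torch
import torch.nn.functional as F
bank = F.normalize(torch.randn(6, 4), dim=-1)               # hypothetical memory features
queries = bank[:2]                                          # queries drawn from the bank
sim = torch.einsum('b d, n d -> b n', queries, bank)
_, idx = sim.topk(k=3, dim=-1, largest=True, sorted=True)   # k=2 neighbours plus self
knn_emb = bank[idx[:, 1:]].mean(dim=1)                      # mean of the 2 nearest others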
sup_con_loss = sup_con_crit(sup_con_feats, labels=sup_con_labels) 102 | 103 | # Total loss 104 | loss = (1 - args.sup_con_weight) * unsup_con_loss + args.sup_con_weight * (sup_con_loss) 105 | 106 | loss_record.update(loss.item(), class_labels.size(0)) 107 | optimizer.zero_grad() 108 | loss.backward() 109 | optimizer.step() 110 | 111 | print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg)) 112 | 113 | with torch.no_grad(): 114 | model.eval() 115 | all_feats_val = [] 116 | targets = np.array([]) 117 | mask = np.array([]) 118 | 119 | for batch_idx, batch in enumerate(tqdm(test_loader)): 120 | images, label, _ = batch[:3] 121 | images = images.cuda() 122 | 123 | features = model(images) 124 | all_feats_val.append(features.detach().cpu().numpy()) 125 | targets = np.append(targets, label.cpu().numpy()) 126 | mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) 127 | else False for x in label])) 128 | 129 | # ----------------- 130 | # Clustering 131 | # ----------------- 132 | img_all_acc_test, img_old_acc_test, img_new_acc_test, img_agglo_score, estimated_k = test_agglo(epoch, all_feats_val, targets, mask, "Test/ACC", args) 133 | if args.wandb: 134 | wandb.log({ 'test/all': img_all_acc_test, 'test/base': img_old_acc_test, 'test/novel': img_new_acc_test, 135 | 'score/agglo': img_agglo_score, 'score/estimated_k': estimated_k}, step=epoch) 136 | 137 | print('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(img_all_acc_test, img_old_acc_test, img_new_acc_test)) 138 | 139 | # Step schedule 140 | exp_lr_scheduler.step() 141 | torch.save(model.state_dict(), args.model_path) 142 | 143 | if img_agglo_score > best_agglo_img_score: 144 | torch.save({ 145 | 'k': estimated_k, 146 | 'model_state_dict': model.state_dict()} 147 | , args.model_path[:-3] + f'_best.pt') 148 | best_agglo_img_score = img_agglo_score 149 | 150 | 151 | if __name__ == "__main__": 152 | 153 | parser = argparse.ArgumentParser( 154 | description='cluster', 155 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 156 | parser.add_argument('--batch_size', default=128, type=int) 157 | parser.add_argument('--num_workers', default=16, type=int) 158 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2']) 159 | 160 | parser.add_argument('--warmup_model_dir', type=str, default=None) 161 | parser.add_argument('--exp_root', type=str, default='/home/suachoi/CMS/log') 162 | parser.add_argument('--pretrain_path', type=str, default='/home/suachoi/CMS/models') 163 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, scars') 164 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 165 | parser.add_argument('--use_ssb_splits', type=bool, default=True) 166 | 167 | parser.add_argument('--lr', type=float, default=0.01) 168 | parser.add_argument('--eta_min', type=float, default=1e-3) 169 | parser.add_argument('--epochs', default=200, type=int) 170 | parser.add_argument('--momentum', type=float, default=0.9) 171 | parser.add_argument('--weight_decay', type=float, default=5e-5) 172 | parser.add_argument('--transform', type=str, default='imagenet') 173 | parser.add_argument('--seed', default=1, type=int) 174 | parser.add_argument('--n_views', default=2, type=int) 175 | parser.add_argument('--contrast_unlabel_only', type=bool, default=False) 176 | parser.add_argument('--sup_con_weight', type=float, default=0.35) 177 | parser.add_argument('--temperature', type=float, default=0.5) 178 | 179 
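# [Editor's note] Illustrative sketch, not part of the original repository: the
# objective above is a convex combination of the two contrastive terms,
# loss = (1 - w) * unsup + w * sup with w = --sup_con_weight (default 0.35),
# and CosineAnnealingLR decays the learning rate to a floor of lr * eta_min:
def total_loss(unsup, sup, w=0.35):          # w mirrors --sup_con_weight
    return (1 - w) * unsup + w * sup

lr, eta_min_factor = 0.01, 1e-3
lr_floor = lr * eta_min_factor               # 1e-05, reached at T_max = epochs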
| parser.add_argument('--alpha', type=float, default=0.5) 180 | parser.add_argument('--k', default=8, type=int) 181 | parser.add_argument('--inductive', action='store_true') 182 | parser.add_argument('--wandb', action='store_true', help='Flag to log at wandb') 183 | 184 | # ---------------------- 185 | # INIT 186 | # ---------------------- 187 | args = parser.parse_args() 188 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 189 | args = get_class_splits(args) 190 | 191 | args.feat_dim = 768 192 | args.interpolation = 3 193 | args.crop_pct = 0.875 194 | args.image_size = 224 195 | args.num_mlp_layers = 3 196 | args.num_labeled_classes = len(args.train_classes) 197 | args.num_unlabeled_classes = len(args.unlabeled_classes) 198 | 199 | init_experiment(args, runner_name=['cms']) 200 | print(f'Using evaluation function {args.eval_funcs[0]} to print results') 201 | 202 | if args.wandb: 203 | wandb.init(project='CMS') 204 | wandb.config.update(args) 205 | 206 | 207 | # -------------------- 208 | # MODEL 209 | # -------------------- 210 | model = DINO(args) 211 | model = nn.DataParallel(model).to(device) 212 | 213 | # -------------------- 214 | # CONTRASTIVE TRANSFORM 215 | # -------------------- 216 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) 217 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) 218 | 219 | # -------------------- 220 | # DATASETS 221 | # -------------------- 222 | if args.inductive: 223 | train_dataset, test_dataset, unlabelled_train_examples_test, val_datasets, datasets = get_datasets_with_gcdval(args.dataset_name, train_transform, test_transform, args) 224 | else: 225 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name, train_transform, test_transform, args) 226 | 227 | 228 | # -------------------- 229 | # SAMPLER 230 | # Sampler which balances labelled and unlabelled examples in each batch 231 | # -------------------- 232 | label_len = len(train_dataset.labelled_dataset) 233 | unlabelled_len = len(train_dataset.unlabelled_dataset) 234 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] 235 | sample_weights = torch.DoubleTensor(sample_weights) 236 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset)) 237 | 238 | # -------------------- 239 | # DATALOADERS 240 | # -------------------- 241 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, sampler=sampler, drop_last=True) 242 | train_loader_memory = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, drop_last=False) 243 | if args.inductive: 244 | test_loader_labelled = DataLoader(val_datasets, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 245 | else: 246 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 247 | 248 | 249 | # ---------------------- 250 | # TRAIN 251 | # ---------------------- 252 | train(model, train_loader, test_loader_labelled, train_loader_memory, args) -------------------------------------------------------------------------------- /models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from torch.nn.init import trunc_normal_ 25 | 26 | def drop_path(x, drop_prob: float = 0., training: bool = False): 27 | if drop_prob == 0. or not training: 28 | return x 29 | keep_prob = 1 - drop_prob 30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 32 | random_tensor.floor_() # binarize 33 | output = x.div(keep_prob) * random_tensor 34 | return output 35 | 36 | 37 | class DropPath(nn.Module): 38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 39 | """ 40 | def __init__(self, drop_prob=None): 41 | super(DropPath, self).__init__() 42 | self.drop_prob = drop_prob 43 | 44 | def forward(self, x): 45 | return drop_path(x, self.drop_prob, self.training) 46 | 47 | 48 | class Mlp(nn.Module): 49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 50 | super().__init__() 51 | out_features = out_features or in_features 52 | hidden_features = hidden_features or in_features 53 | self.fc1 = nn.Linear(in_features, hidden_features) 54 | self.act = act_layer() 55 | self.fc2 = nn.Linear(hidden_features, out_features) 56 | self.drop = nn.Dropout(drop) 57 | 58 | def forward(self, x): 59 | x = self.fc1(x) 60 | x = self.act(x) 61 | x = self.drop(x) 62 | x = self.fc2(x) 63 | x = self.drop(x) 64 | return x 65 | 66 | 67 | class Attention(nn.Module): 68 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 69 | super().__init__() 70 | self.num_heads = num_heads 71 | head_dim = dim // num_heads 72 | self.scale = qk_scale or head_dim ** -0.5 73 | 74 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 75 | self.attn_drop = nn.Dropout(attn_drop) 76 | self.proj = nn.Linear(dim, dim) 77 | self.proj_drop = nn.Dropout(proj_drop) 78 | 79 | def forward(self, x): 80 | B, N, C = x.shape 81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 82 | q, k, v = qkv[0], qkv[1], qkv[2] 83 | 84 | attn = (q @ k.transpose(-2, -1)) * self.scale 85 | attn = attn.softmax(dim=-1) 86 | attn = self.attn_drop(attn) 87 | 88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 89 | x = self.proj(x) 90 | x = self.proj_drop(x) 91 | return x, attn 92 | 93 | 94 | class Block(nn.Module): 95 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 96 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 97 | super().__init__() 98 | self.norm1 = norm_layer(dim) 99 | self.attn = Attention( 100 | dim, num_heads=num_heads, 
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 101 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 102 | self.norm2 = norm_layer(dim) 103 | mlp_hidden_dim = int(dim * mlp_ratio) 104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 105 | 106 | def forward(self, x, return_attention=False): 107 | y, attn = self.attn(self.norm1(x)) 108 | x = x + self.drop_path(y) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | 111 | if return_attention: 112 | return x, attn 113 | else: 114 | return x 115 | 116 | 117 | class PatchEmbed(nn.Module): 118 | """ Image to Patch Embedding 119 | """ 120 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 121 | super().__init__() 122 | num_patches = (img_size // patch_size) * (img_size // patch_size) 123 | self.img_size = img_size 124 | self.patch_size = patch_size 125 | self.num_patches = num_patches 126 | 127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 128 | 129 | def forward(self, x): 130 | B, C, H, W = x.shape 131 | x = self.proj(x).flatten(2).transpose(1, 2) 132 | return x 133 | 134 | 135 | class VisionTransformer(nn.Module): 136 | """ Vision Transformer """ 137 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 138 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 139 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 140 | super().__init__() 141 | self.num_features = self.embed_dim = embed_dim 142 | 143 | self.patch_embed = PatchEmbed( 144 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 145 | num_patches = self.patch_embed.num_patches 146 | 147 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 148 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 149 | self.pos_drop = nn.Dropout(p=drop_rate) 150 | 151 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 152 | self.blocks = nn.ModuleList([ 153 | Block( 154 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 155 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 156 | for i in range(depth)]) 157 | self.norm = norm_layer(embed_dim) 158 | 159 | # Classifier head 160 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 161 | 162 | trunc_normal_(self.pos_embed, std=.02) 163 | trunc_normal_(self.cls_token, std=.02) 164 | self.apply(self._init_weights) 165 | 166 | def _init_weights(self, m): 167 | if isinstance(m, nn.Linear): 168 | trunc_normal_(m.weight, std=.02) 169 | if isinstance(m, nn.Linear) and m.bias is not None: 170 | nn.init.constant_(m.bias, 0) 171 | elif isinstance(m, nn.LayerNorm): 172 | nn.init.constant_(m.bias, 0) 173 | nn.init.constant_(m.weight, 1.0) 174 | 175 | def interpolate_pos_encoding(self, x, w, h): 176 | npatch = x.shape[1] - 1 177 | N = self.pos_embed.shape[1] - 1 178 | if npatch == N and w == h: 179 | return self.pos_embed 180 | class_pos_embed = self.pos_embed[:, 0] 181 | patch_pos_embed = self.pos_embed[:, 1:] 182 | dim = x.shape[-1] 183 | w0 = w // self.patch_embed.patch_size 184 | h0 = h // self.patch_embed.patch_size 185 | # we add a small number to avoid floating point error in the interpolation 186 | # see discussion at https://github.com/facebookresearch/dino/issues/8 187 
| w0, h0 = w0 + 0.1, h0 + 0.1 188 | patch_pos_embed = nn.functional.interpolate( 189 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 190 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 191 | mode='bicubic', 192 | ) 193 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 194 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 195 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 196 | 197 | def prepare_tokens(self, x): 198 | B, nc, w, h = x.shape 199 | x = self.patch_embed(x) # patch linear embedding 200 | 201 | # add the [CLS] token to the embed patch tokens 202 | cls_tokens = self.cls_token.expand(B, -1, -1) 203 | x = torch.cat((cls_tokens, x), dim=1) 204 | 205 | # add positional encoding to each token 206 | x = x + self.interpolate_pos_encoding(x, w, h) 207 | 208 | return self.pos_drop(x) 209 | 210 | def forward(self, x, return_all_patches=False): 211 | x = self.prepare_tokens(x) 212 | for blk in self.blocks: 213 | x = blk(x) 214 | x = self.norm(x) 215 | 216 | if return_all_patches: 217 | return x 218 | else: 219 | return x[:, 0] 220 | 221 | def get_last_selfattention(self, x): 222 | x = self.prepare_tokens(x) 223 | for i, blk in enumerate(self.blocks): 224 | if i < len(self.blocks) - 1: 225 | x = blk(x) 226 | else: 227 | # return attention of the last block 228 | x, attn = blk(x, return_attention=True) 229 | x = self.norm(x) 230 | return x, attn 231 | 232 | def get_intermediate_layers(self, x, n=1): 233 | x = self.prepare_tokens(x) 234 | # we return the output tokens from the `n` last blocks 235 | output = [] 236 | for i, blk in enumerate(self.blocks): 237 | x = blk(x) 238 | if len(self.blocks) - i <= n: 239 | output.append(self.norm(x)) 240 | return output 241 | 242 | 243 | def vit_tiny(patch_size=16, **kwargs): 244 | model = VisionTransformer( 245 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 246 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 247 | return model 248 | 249 | 250 | def vit_small(patch_size=16, **kwargs): 251 | model = VisionTransformer( 252 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 253 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 254 | return model 255 | 256 | 257 | def vit_base(patch_size=16, **kwargs): 258 | model = VisionTransformer( 259 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 260 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 261 | return model 262 | 263 | 264 | class DINOHead(nn.Module): 265 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 266 | super().__init__() 267 | nlayers = max(nlayers, 1) 268 | if nlayers == 1: 269 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 270 | else: 271 | layers = [nn.Linear(in_dim, hidden_dim)] 272 | if use_bn: 273 | layers.append(nn.BatchNorm1d(hidden_dim)) 274 | layers.append(nn.GELU()) 275 | for _ in range(nlayers - 2): 276 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 277 | if use_bn: 278 | layers.append(nn.BatchNorm1d(hidden_dim)) 279 | layers.append(nn.GELU()) 280 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 281 | self.mlp = nn.Sequential(*layers) 282 | self.apply(self._init_weights) 283 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 284 | self.last_layer.weight_g.data.fill_(1) 285 | if 
norm_last_layer: 286 | self.last_layer.weight_g.requires_grad = False 287 | 288 | def _init_weights(self, m): 289 | if isinstance(m, nn.Linear): 290 | trunc_normal_(m.weight, std=.02) 291 | if isinstance(m, nn.Linear) and m.bias is not None: 292 | nn.init.constant_(m.bias, 0) 293 | 294 | def forward(self, x): 295 | x = self.mlp(x) 296 | x = nn.functional.normalize(x, dim=-1, p=2) 297 | x = self.last_layer(x) 298 | return x 299 | 300 | 301 | class VisionTransformerWithLinear(nn.Module): 302 | 303 | def __init__(self, base_vit, num_classes=200): 304 | 305 | super().__init__() 306 | 307 | self.base_vit = base_vit 308 | self.fc = nn.Linear(768, num_classes) 309 | 310 | def forward(self, x, return_features=False): 311 | 312 | features = self.base_vit(x) 313 | features = torch.nn.functional.normalize(features, dim=-1) 314 | logits = self.fc(features) 315 | 316 | if return_features: 317 | return logits, features 318 | else: 319 | return logits 320 | 321 | @torch.no_grad() 322 | def normalize_prototypes(self): 323 | w = self.fc.weight.data.clone() 324 | w = torch.nn.functional.normalize(w, dim=1, p=2) 325 | self.fc.weight.copy_(w) -------------------------------------------------------------------------------- /data/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from torchvision.datasets.folder import default_loader 7 | from torch.utils.data import Dataset 8 | 9 | from data.data_utils import subsample_instances 10 | from config import aircraft_root 11 | 12 | def make_dataset(dir, image_ids, targets): 13 | assert(len(image_ids) == len(targets)) 14 | images = [] 15 | dir = os.path.expanduser(dir) 16 | for i in range(len(image_ids)): 17 | item = (os.path.join(dir, 'data', 'images', 18 | '%s.jpg' % image_ids[i]), targets[i]) 19 | images.append(item) 20 | return images 21 | 22 | 23 | def find_classes(classes_file): 24 | 25 | # read classes file, separating out image IDs and class names 26 | image_ids = [] 27 | targets = [] 28 | f = open(classes_file, 'r') 29 | for line in f: 30 | split_line = line.split(' ') 31 | image_ids.append(split_line[0]) 32 | targets.append(' '.join(split_line[1:])) 33 | f.close() 34 | 35 | # index class names 36 | classes = np.unique(targets) 37 | class_to_idx = {classes[i]: i for i in range(len(classes))} 38 | targets = [class_to_idx[c] for c in targets] 39 | 40 | return (image_ids, targets, classes, class_to_idx) 41 | 42 | 43 | class FGVCAircraft(Dataset): 44 | 45 | """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset. 46 | 47 | Args: 48 | root (string): Root directory path to dataset. 49 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification 50 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). 51 | transform (callable, optional): A function/transform that takes in a PIL image 52 | and returns a transformed version. E.g. ``transforms.RandomCrop`` 53 | target_transform (callable, optional): A function/transform that takes in the 54 | target and transforms it. 55 | loader (callable, optional): A function to load an image given its path. 56 | download (bool, optional): If true, downloads the dataset from the internet and 57 | puts it in the root directory. If dataset is already downloaded, it is not 58 | downloaded again. 
59 | """ 60 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' 61 | class_types = ('variant', 'family', 'manufacturer') 62 | splits = ('train', 'val', 'trainval', 'test') 63 | 64 | def __init__(self, root, class_type='variant', split='train', transform=None, 65 | target_transform=None, loader=default_loader, download=False): 66 | if split not in self.splits: 67 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 68 | split, ', '.join(self.splits), 69 | )) 70 | if class_type not in self.class_types: 71 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( 72 | class_type, ', '.join(self.class_types), 73 | )) 74 | self.root = os.path.expanduser(root) 75 | self.class_type = class_type 76 | self.split = split 77 | self.classes_file = os.path.join(self.root, 'data', 78 | 'images_%s_%s.txt' % (self.class_type, self.split)) 79 | 80 | if download: 81 | self.download() 82 | 83 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) 84 | samples = make_dataset(self.root, image_ids, targets) 85 | 86 | self.transform = transform 87 | self.target_transform = target_transform 88 | self.loader = loader 89 | 90 | self.samples = samples 91 | self.classes = classes 92 | self.class_to_idx = class_to_idx 93 | self.train = True if split == 'train' else False 94 | 95 | self.uq_idxs = np.array(range(len(self))) 96 | 97 | def __getitem__(self, index): 98 | """ 99 | Args: 100 | index (int): Index 101 | 102 | Returns: 103 | tuple: (sample, target) where target is class_index of the target class. 104 | """ 105 | 106 | path, target = self.samples[index] 107 | sample = self.loader(path) 108 | if self.transform is not None: 109 | sample = self.transform(sample) 110 | if self.target_transform is not None: 111 | target = self.target_transform(target) 112 | 113 | return sample, target, self.uq_idxs[index] 114 | 115 | def __len__(self): 116 | return len(self.samples) 117 | 118 | def __repr__(self): 119 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 120 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 121 | fmt_str += ' Root Location: {}\n'.format(self.root) 122 | tmp = ' Transforms (if any): ' 123 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 124 | tmp = ' Target Transforms (if any): ' 125 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 126 | return fmt_str 127 | 128 | def _check_exists(self): 129 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ 130 | os.path.exists(self.classes_file) 131 | 132 | def download(self): 133 | """Download the FGVC-Aircraft data if it doesn't exist already.""" 134 | from six.moves import urllib 135 | import tarfile 136 | 137 | if self._check_exists(): 138 | return 139 | 140 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz 141 | print('Downloading %s ... (may take a few minutes)' % self.url) 142 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir)) 143 | tar_name = self.url.rpartition('/')[-1] 144 | tar_path = os.path.join(parent_dir, tar_name) 145 | data = urllib.request.urlopen(self.url) 146 | 147 | # download .tar.gz file 148 | with open(tar_path, 'wb') as f: 149 | f.write(data.read()) 150 | 151 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b 152 | data_folder = tar_path.strip('.tar.gz') 153 | print('Extracting %s to %s ... 
(may take a few minutes)' % (tar_path, data_folder)) 154 | tar = tarfile.open(tar_path) 155 | tar.extractall(parent_dir) 156 | 157 | # if necessary, rename data folder to self.root 158 | if not os.path.samefile(data_folder, self.root): 159 | print('Renaming %s to %s ...' % (data_folder, self.root)) 160 | os.rename(data_folder, self.root) 161 | 162 | # delete .tar.gz file 163 | print('Deleting %s ...' % tar_path) 164 | os.remove(tar_path) 165 | 166 | print('Done!') 167 | 168 | 169 | def subsample_dataset(dataset, idxs, absolute=True): 170 | 171 | mask = np.zeros(len(dataset)).astype('bool') 172 | if absolute==True: 173 | mask[idxs] = True 174 | else: 175 | idxs = set(idxs) 176 | mask = np.array([i in idxs for i in dataset.uq_idxs]) 177 | 178 | dataset.samples = [(p, t) for m, (p, t) in zip(mask, dataset.samples) if m==True] 179 | dataset.uq_idxs = dataset.uq_idxs[mask] 180 | 181 | return dataset 182 | 183 | 184 | def subsample_classes(dataset, include_classes=range(60)): 185 | 186 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] 187 | 188 | # TODO: Don't transform targets for now 189 | target_xform_dict = {} 190 | for i, k in enumerate(include_classes): 191 | target_xform_dict[k] = i 192 | 193 | dataset = subsample_dataset(dataset, cls_idxs) 194 | 195 | dataset.target_transform = lambda x: target_xform_dict[x] 196 | 197 | return dataset 198 | 199 | 200 | def get_train_val_indices(train_dataset, val_split=0.2): 201 | 202 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)] 203 | train_classes = np.unique(all_targets) 204 | 205 | # Get train/test indices 206 | train_idxs = [] 207 | val_idxs = [] 208 | for cls in train_classes: 209 | cls_idxs = np.where(all_targets == cls)[0] 210 | 211 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 212 | t_ = [x for x in cls_idxs if x not in v_] 213 | 214 | train_idxs.extend(t_) 215 | val_idxs.extend(v_) 216 | 217 | return train_idxs, val_idxs 218 | 219 | 220 | def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8, 221 | split_train_val=False, seed=0): 222 | 223 | np.random.seed(seed) 224 | 225 | # Init entire training set 226 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval') 227 | 228 | # Get labelled training set which has subsampled classes, then subsample some indices from that 229 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 230 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 231 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 232 | 233 | # Split into training and validation sets 234 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 235 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 236 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 237 | val_dataset_labelled_split.transform = test_transform 238 | 239 | # Get unlabelled data 240 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 241 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 242 | 243 | # Get test set for all classes 244 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test') 245 | 246 | # 
Either split train into train and val or use test set as val 247 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 248 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 249 | 250 | all_datasets = { 251 | 'train_labelled': train_dataset_labelled, 252 | 'train_unlabelled': train_dataset_unlabelled, 253 | 'val': val_dataset_labelled, 254 | 'test': test_dataset, 255 | } 256 | 257 | return all_datasets 258 | 259 | def get_aircraft_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80), 260 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1): 261 | 262 | np.random.seed(seed) 263 | 264 | # Init entire training set 265 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval') 266 | 267 | # Get labelled training set which has subsampled classes, then subsample some indices from that 268 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 269 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 270 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 271 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 272 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False) 273 | 274 | # Split into training and validation sets 275 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split) 276 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 277 | val_dataset_labelled.transform = test_transform 278 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 279 | 280 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split) 281 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs) 282 | val_dataset_unlabelled.transform = test_transform 283 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs) 284 | 285 | val_dataset_unlabelled.transform = test_transform 286 | 287 | 288 | # Get test set for all classes 289 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test') 290 | 291 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}') 292 | 293 | all_datasets = { 294 | 'train_labelled': train_dataset_labelled, 295 | 'train_unlabelled': train_dataset_unlabelled, 296 | 'val': [val_dataset_labelled, val_dataset_unlabelled], 297 | 'test': test_dataset, 298 | } 299 | 300 | return all_datasets 301 | 302 | if __name__ == '__main__': 303 | 304 | x = get_aircraft_datasets(None, None, split_train_val=False) 305 | 306 | print('Printing lens...') 307 | for k, v in x.items(): 308 | if v is not None: 309 | print(f'{k}: {len(v)}') 310 | 311 | print('Printing labelled and unlabelled overlap...') 312 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 313 | print('Printing total instances in train...') 314 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 315 | print('Printing number of labelled classes...') 316 | print(len(set([i[1] for i in 
x['train_labelled'].samples]))) 317 | print('Printing total number of classes...') 318 | print(len(set([i[1] for i in x['train_unlabelled'].samples]))) 319 | -------------------------------------------------------------------------------- /methods/meanshift_clustering.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import timm 4 | import argparse 5 | import numpy as np 6 | 7 | 8 | from tqdm import tqdm 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader 12 | from torchvision import transforms 13 | 14 | from data.stanford_cars import CarsDataset 15 | from data.cifar import CustomCIFAR10, CustomCIFAR100, cifar_10_root, cifar_100_root 16 | from data.herbarium_19 import HerbariumDataset19, herbarium_dataroot 17 | from data.augmentations import get_transform 18 | from data.imagenet import get_imagenet_100_datasets 19 | from data.data_utils import MergedDataset 20 | from data.cub import CustomCub2011, cub_root, get_cub_datasets 21 | from data.fgvc_aircraft import FGVCAircraft, aircraft_root 22 | from data.get_datasets import get_datasets, get_class_splits, get_datasets_with_gcdval 23 | 24 | from project_utils.general_utils import strip_state_dict, str2bool 25 | from copy import deepcopy 26 | from project_utils.cluster_and_log_utils import * 27 | 28 | from models.dino import * 29 | from models import vision_transformer as vits 30 | from config import dino_pretrain_path 31 | 32 | 33 | def iterative_meanshift(model, loader, args): 34 | """ 35 | This function measures clustering accuracies on GCD setup in both GT, predicted number of classes 36 | 37 | Clustering : labeled train dataset, unlabeled train dataset 38 | Stopping condition : labeled train dataset 39 | Acc : unlabeled train dataset 40 | """ 41 | num_clusters = [args.num_labeled_classes + args.num_unlabeled_classes, args.num_clusters] 42 | acc = [0, 0] 43 | max_acc = [0, 0] 44 | tolerance = [0, 0] 45 | final_acc = [(0,0,0), (0,0,0)] 46 | print('Predicted number of clusters: ', args.num_clusters) 47 | 48 | all_feats = torch.zeros(size=(len(loader.dataset), args.feat_dim)) 49 | new_feats = torch.zeros(size=(len(loader.dataset), args.feat_dim)) 50 | targets = torch.zeros(len(loader.dataset), dtype=int) 51 | mask_lab = torch.zeros(len(loader.dataset), dtype=bool) 52 | mask_cls = torch.zeros(len(loader.dataset), dtype=bool) 53 | 54 | with torch.no_grad(): 55 | for epoch in range(args.epochs): 56 | # Save embeddings 57 | for batch_idx, batch in enumerate(tqdm(loader)): 58 | images, label, uq_idxs, mask_lab_ = batch 59 | if epoch == 0: 60 | images = torch.Tensor(images).to(device) 61 | all_feats[uq_idxs] = model(images).detach().cpu() 62 | targets[uq_idxs] = label.cpu() 63 | mask_lab[uq_idxs] = mask_lab_.squeeze(1).cpu().bool() 64 | else: 65 | classwise_sim = torch.einsum('b d, n d -> b n', all_feats[uq_idxs], all_feats) 66 | _, indices = classwise_sim.topk(k=args.k+1, dim=-1, largest=True, sorted=True) 67 | indices = indices[:, 1:] 68 | knn_emb = torch.mean(all_feats[indices].view(-1, args.k, args.feat_dim), dim=1) 69 | new_feats[uq_idxs] = (1-args.alpha) * all_feats[uq_idxs] + args.alpha * knn_emb.detach().cpu() 70 | 71 | 72 | if epoch == 0: 73 | mask_cls = np.isin(targets, range(len(args.train_classes))).astype(bool) 74 | mask_lab = mask_lab.numpy().astype(bool) 75 | l_targets = targets[mask_lab].numpy() 76 | u_targets = targets[~mask_lab].numpy() 77 | mask = mask_cls[~mask_lab] 78 | mask = mask.astype(bool) 79 | 
else: 80 | norm = torch.sqrt(torch.sum((torch.pow(new_feats, 2)), dim=-1)).unsqueeze(1) 81 | new_feats = new_feats / norm 82 | all_feats = new_feats 83 | 84 | # Agglomerative clustering 85 | linked = linkage(all_feats, method="ward") 86 | for i in range(len(num_clusters)): 87 | if num_clusters[i]: 88 | print('num clusters', num_clusters[i]) 89 | else: 90 | continue 91 | 92 | threshold = linked[:, 2][-num_clusters[i]] 93 | preds = fcluster(linked, t=threshold, criterion='distance') 94 | 95 | old_acc_train = compute_acc(l_targets, preds[mask_lab]) 96 | all_acc_test, old_acc_test, new_acc_test = log_accs_from_preds(y_true=u_targets, y_pred=preds[~mask_lab], mask=mask, 97 | T=epoch, eval_funcs=args.eval_funcs, save_name='IMS unlabeled train ACC', print_output=True) 98 | 99 | # Stopping condition with tolerance 100 | tolerance[i] = 0 if max_acc[i] < old_acc_train else tolerance[i] + 1 101 | 102 | if max_acc[i] <= old_acc_train: 103 | max_acc[i] = old_acc_train 104 | acc[i] = (all_acc_test, old_acc_test, new_acc_test) 105 | 106 | # Stop 107 | if tolerance[i] >= 2: 108 | num_clusters[i] = 0 109 | final_acc[i] = acc[i] 110 | 111 | # If both GT & predicted K stopped, break 112 | if sum(num_clusters) == 0: 113 | break 114 | 115 | print(f'ACC with GT number of clusters: All {final_acc[0][0]:.4f} | Old {final_acc[0][1]:.4f} | New {final_acc[0][2]:.4f}') 116 | print(f'ACC with predicted number of clusters: All {final_acc[1][0]:.4f} | Old {final_acc[1][1]:.4f} | New {final_acc[1][2]:.4f}') 117 | 118 | 119 | def iterative_meanshift_inductive(model, loader, val_loader, args): 120 | """ 121 | This function measures clustering accuracies on inductive GCD setup in both GT, predicted number of classes 122 | Clustering : test dataset 123 | Stopping condition : val dataset 124 | Acc : test dataset 125 | """ 126 | 127 | num_clusters = [args.num_labeled_classes + args.num_unlabeled_classes, args.num_clusters] 128 | acc = [0, 0] 129 | max_acc = [0, 0] 130 | tolerance = [0, 0] 131 | final_acc = [(0,0,0), (0,0,0)] 132 | print('Predicted number of clusters: ', args.num_clusters) 133 | 134 | all_feats = torch.zeros(size=(len(loader.dataset), args.feat_dim)) 135 | new_feats = torch.zeros(size=(len(loader.dataset), args.feat_dim)) 136 | targets = torch.zeros(len(loader.dataset), dtype=int) 137 | mask_lab = torch.zeros(len(loader.dataset), dtype=bool) 138 | mask_cls = torch.zeros(len(loader.dataset), dtype=bool) 139 | 140 | all_feats_val = [] 141 | new_feats_val = [] 142 | targets_val = [] 143 | mask_cls_val = [] 144 | with torch.no_grad(): 145 | for epoch in range(args.epochs): 146 | # Save embeddings (test) 147 | for batch_idx, batch in enumerate(tqdm(loader)): 148 | images, label, uq_idxs = batch 149 | if epoch == 0: 150 | images = torch.Tensor(images).to(device) 151 | all_feats[uq_idxs] = model(images).detach().cpu() 152 | targets[uq_idxs] = label.cpu() 153 | else: 154 | classwise_sim = torch.einsum('b d, n d -> b n', all_feats[uq_idxs], all_feats) 155 | _, indices = classwise_sim.topk(k=args.k+1, dim=-1, largest=True, sorted=True) 156 | indices = indices[:, 1:] 157 | knn_emb = torch.mean(all_feats[indices].view(-1, args.k, args.feat_dim), dim=1) 158 | new_feats[uq_idxs] = (1-args.alpha) * all_feats[uq_idxs] + args.alpha * knn_emb.detach().cpu() 159 | 160 | if epoch == 0: 161 | mask_cls = np.isin(targets, range(len(args.train_classes))).astype(bool) 162 | targets = np.array(targets) 163 | else: 164 | norm = torch.sqrt(torch.sum((torch.pow(new_feats, 2)), dim=-1)).unsqueeze(1) 165 | new_feats = new_feats / norm 166 | 
all_feats = new_feats 167 | 168 | # Save embeddings (val) 169 | for batch_idx, batch in enumerate(tqdm(val_loader)): 170 | images, label, uq_idxs = batch[:3] 171 | if epoch == 0: 172 | images = torch.Tensor(images).to(device) 173 | all_feats_val.append(model(images).detach().cpu()) 174 | targets_val.append(label.cpu()) 175 | else: 176 | start_idx = batch_idx*args.batch_size 177 | classwise_sim = torch.einsum('b d, n d -> b n', all_feats_val[start_idx:start_idx+len(uq_idxs)], all_feats_val) 178 | _, indices = classwise_sim.topk(k=args.k+1, dim=-1, largest=True, sorted=True) 179 | indices = indices[:, 1:] 180 | knn_emb_val = torch.mean(all_feats_val[indices].view(-1, args.k, args.feat_dim), dim=1) 181 | new_feats_val[start_idx:start_idx+len(uq_idxs)] = (1-args.alpha) * all_feats_val[start_idx:start_idx+len(uq_idxs)] + args.alpha * knn_emb_val.detach().cpu() 182 | 183 | if epoch == 0: 184 | all_feats_val = torch.cat(all_feats_val) 185 | targets_val = np.array(torch.cat(targets_val)) 186 | mask_cls_val = np.isin(targets_val, range(len(args.train_classes))).astype(bool) 187 | new_feats_val = all_feats_val 188 | else: 189 | norm = torch.sqrt(torch.sum((torch.pow(new_feats_val, 2)), dim=-1)).unsqueeze(1) 190 | new_feats_val = new_feats_val / norm 191 | all_feats_val = new_feats_val 192 | 193 | # Agglomerative clustering 194 | linked = linkage(all_feats, method="ward") 195 | linked_val = linkage(all_feats_val, method="ward") 196 | for i in range(len(num_clusters)): 197 | if num_clusters[i]: 198 | print('num clusters', num_clusters[i]) 199 | else: 200 | continue 201 | 202 | # acc of validation set 203 | threshold = linked[:, 2][-num_clusters[i]] 204 | preds_val = fcluster(linked_val, t=threshold, criterion='distance') 205 | old_acc_val = compute_acc(targets_val[mask_cls_val], preds_val[mask_cls_val]) 206 | 207 | # acc of test set 208 | threshold = linked[:, 2][-num_clusters[i]] 209 | preds = fcluster(linked, t=threshold, criterion='distance') 210 | all_acc_test, old_acc_test, new_acc_test = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask_cls, 211 | T=epoch, eval_funcs=args.eval_funcs, save_name='IMS test ACC', print_output=True) 212 | 213 | # Stopping condition with tolerance 214 | tolerance[i] = 0 if max_acc[i] < old_acc_val else tolerance[i] + 1 215 | 216 | if max_acc[i] <= old_acc_val: 217 | max_acc[i] = old_acc_val 218 | acc[i] = (all_acc_test, old_acc_test, new_acc_test) 219 | 220 | if tolerance[i] >= 2: 221 | num_clusters[i] = 0 222 | final_acc[i] = acc[i] 223 | 224 | # If both GT & predicted K stopped, break 225 | if sum(num_clusters) == 0: 226 | break 227 | 228 | print(f'ACC with GT number of clusters: All {final_acc[0][0]:.4f} | Old {final_acc[0][1]:.4f} | New {final_acc[0][2]:.4f}') 229 | print(f'ACC with predicted number of clusters: All {final_acc[1][0]:.4f} | Old {final_acc[1][1]:.4f} | New {final_acc[1][2]:.4f}') 230 | 231 | 232 | def compute_acc(y_true, y_pred): 233 | y_true = y_true.astype(int) 234 | old_classes_gt = set(y_true) 235 | 236 | assert y_pred.size == y_true.size 237 | D = max(y_pred.max(), y_true.max()) + 1 238 | w = np.zeros((D, D), dtype=int) 239 | for i in range(y_pred.size): 240 | w[y_pred[i], y_true[i]] += 1 241 | # w: pred x label count 242 | ind = linear_assignment(w.max() - w) 243 | ind = np.vstack(ind).T 244 | 245 | ind_map = {j: i for i, j in ind} 246 | total_acc = sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 247 | 248 | return total_acc 249 | 250 | 251 | if __name__ == "__main__": 252 | 253 | parser = argparse.ArgumentParser( 254 | 
description='cluster', 255 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 256 | parser.add_argument('--batch_size', default=128, type=int) 257 | parser.add_argument('--num_workers', default=8, type=int) 258 | parser.add_argument('--model_name', type=str, default='cms', help='Format is {model_name}_{pretrain}') 259 | parser.add_argument('--pretrain_path', type=str, default='./models') 260 | parser.add_argument('--transform', type=str, default='imagenet') 261 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2']) 262 | parser.add_argument('--use_ssb_splits', type=str2bool, default=True) 263 | parser.add_argument('--dataset_name', type=str, default='aircraft', help='options: cifar10, cifar100, scars') 264 | parser.add_argument('--epochs', default=20, type=int) 265 | parser.add_argument('--feat_dim', default=768, type=int) 266 | parser.add_argument('--num_clusters', default=None, type=int) 267 | parser.add_argument('--inductive', action='store_true') 268 | parser.add_argument('--k', default=8, type=int) 269 | parser.add_argument('--alpha', type=float, default=0.5) 270 | 271 | # ---------------------- 272 | # INIT 273 | # ---------------------- 274 | args = parser.parse_args() 275 | device = torch.device('cuda:0') 276 | args = get_class_splits(args) 277 | 278 | args.num_labeled_classes = len(args.train_classes) 279 | args.num_unlabeled_classes = len(args.unlabeled_classes) 280 | print(args) 281 | 282 | args.model_name = "/home/suachoi/CMS/log/metric_learn_gcd/log/" + args.model_name + "/checkpoints/model_best.pt" 283 | print(f'Using weights from {args.model_name} ...') 284 | 285 | # ---------------------- 286 | # MODEL 287 | # ---------------------- 288 | args.interpolation = 3 289 | args.crop_pct = 0.875 290 | args.prop_train_labels = 0.5 291 | args.num_mlp_layers = 3 292 | args.feat_dim = 768 293 | 294 | model = DINO(args) 295 | model = nn.DataParallel(model) 296 | model.to(device) 297 | model.eval() 298 | 299 | state_dict = torch.load(args.model_name) 300 | model.load_state_dict(state_dict['model_state_dict'], strict=False) 301 | args.num_clusters = state_dict['k'] 302 | 303 | # ---------------------- 304 | # DATASET 305 | # ---------------------- 306 | train_transform, test_transform = get_transform(args.transform, image_size=224, args=args) 307 | if args.inductive: 308 | _, test_dataset, _, val_dataset, _ = get_datasets_with_gcdval(args.dataset_name, test_transform, test_transform, args) 309 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) 310 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) 311 | iterative_meanshift_inductive(model, test_loader, val_loader, args) 312 | else: 313 | train_dataset, _, _, _ = get_datasets(args.dataset_name, test_transform, test_transform, args) 314 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) 315 | iterative_meanshift(model, train_loader, args) --------------------------------------------------------------------------------
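
Note on the balanced sampler in methods/contrastive_meanshift_training.py: labelled examples receive weight 1 and unlabelled examples weight label_len / unlabelled_len, so both pools contribute equal total probability mass and a batch is roughly half labelled in expectation. A minimal sketch of that behaviour, with made-up pool sizes (nothing below is part of the repository):

import torch

label_len, unlabelled_len = 1000, 3000
# Weight 1 per labelled item, label_len/unlabelled_len per unlabelled item:
# each pool sums to the same total weight (label_len).
weights = [1.0] * label_len + [label_len / unlabelled_len] * unlabelled_len
sampler = torch.utils.data.WeightedRandomSampler(
    torch.DoubleTensor(weights), num_samples=label_len + unlabelled_len)

drawn = torch.tensor(list(sampler))
print((drawn < label_len).float().mean())  # ~0.5, i.e. half the draws are labelled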
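
The clustering loops in methods/meanshift_clustering.py alternate two steps: each L2-normalised embedding is replaced by a convex combination of itself and the mean of its k nearest neighbours, and the shifted features are then cut with Ward-linkage agglomerative clustering at the merge height that leaves the target number of clusters. A compact, self-contained sketch of one such round, with random features standing in for the DINO embeddings (all names and sizes here are illustrative only):

import torch
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

def meanshift_step(feats, k=8, alpha=0.5):
    sim = feats @ feats.T                             # cosine similarity; feats are unit norm
    _, idx = sim.topk(k=k + 1, dim=-1, largest=True)  # k+1 because each point is its own top hit
    knn_mean = feats[idx[:, 1:]].mean(dim=1)          # mean of the k true neighbours
    shifted = (1 - alpha) * feats + alpha * knn_mean  # the (1 - alpha) / alpha update used above
    return torch.nn.functional.normalize(shifted, dim=-1)

feats = torch.nn.functional.normalize(torch.randn(256, 768), dim=-1)
for _ in range(5):
    feats = meanshift_step(feats)

num_clusters = 10
linked = linkage(feats.numpy().astype(np.float64), method='ward')
threshold = linked[:, 2][-num_clusters]               # merge height at which num_clusters groups remain
preds = fcluster(linked, t=threshold, criterion='distance')
print(np.unique(preds).size)                          # num_clusters, assuming distinct merge heights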
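
compute_acc scores a clustering by building the prediction/label contingency matrix and solving a Hungarian assignment over it; linear_assignment presumably arrives through the wildcard import from project_utils.cluster_and_log_utils. An equivalent sketch written directly against scipy's linear_sum_assignment (assumed here to behave the same way):

import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_acc(y_true, y_pred):
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)                # w[i, j] counts prediction i paired with label j
    for p, t in zip(y_pred, y_true):
        w[p, t] += 1
    row, col = linear_sum_assignment(w.max() - w)  # max-weight matching via a min-cost assignment
    return w[row, col].sum() / y_pred.size

print(cluster_acc(np.array([1, 1, 0, 0]), np.array([0, 0, 1, 1])))  # 1.0 after relabelling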