├── data
│   ├── assets
│   │   └── overview.png
│   ├── ssb_splits
│   │   ├── cub_osr_splits.pkl
│   │   ├── scars_osr_splits.pkl
│   │   ├── aircraft_osr_splits.pkl
│   │   └── herbarium_19_class_splits.pkl
│   ├── data_utils.py
│   ├── augmentations
│   │   ├── cut_out.py
│   │   ├── __init__.py
│   │   └── randaugment.py
│   ├── stanford_cars.py
│   ├── herbarium_19.py
│   ├── get_datasets.py
│   ├── cifar.py
│   ├── cub.py
│   ├── imagenet.py
│   └── fgvc_aircraft.py
├── .gitignore
├── config.py
├── requirements.txt
├── models
│   ├── dino.py
│   └── vision_transformer.py
├── bash_scripts
│   ├── meanshift_clustering.sh
│   └── contrastive_meanshift_training.sh
├── project_utils
│   ├── general_utils.py
│   ├── cluster_and_log_utils.py
│   └── cluster_utils.py
├── methods
│   ├── loss.py
│   ├── contrastive_meanshift_training.py
│   └── meanshift_clustering.py
└── README.md
/data/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/assets/overview.png
--------------------------------------------------------------------------------
/data/ssb_splits/cub_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/ssb_splits/cub_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/scars_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/ssb_splits/scars_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/aircraft_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/ssb_splits/aircraft_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/herbarium_19_class_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sua-choi/CMS/HEAD/data/ssb_splits/herbarium_19_class_splits.pkl
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | .vscode/
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Logs
11 | log/
12 | wandb/
13 | data/features/
14 | data/memory
15 |
16 | #slurm
17 | *.out
18 |
19 | models/*.pth
20 | *.pt
21 | *.gif
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -----------------
2 | # DATASET ROOTS
3 | # -----------------
4 | DATASET_ROOT='/home/suachoi/datasets/'
5 | cifar_10_root = DATASET_ROOT+'cifar10'
6 | cifar_100_root = DATASET_ROOT+'cifar100'
7 | cub_root = DATASET_ROOT
8 | aircraft_root = DATASET_ROOT+'fgvc-aircraft-2013b'
9 | herbarium_dataroot = DATASET_ROOT+'herbarium_19'
10 | imagenet_root = DATASET_ROOT+'imagenet'
11 |
12 | # OSR Split dir
13 | osr_split_dir = './data/ssb_splits'
14 |
15 | dino_pretrain_path = './models/dino_vitbase16_pretrain.pth'
16 | exp_root = './log' # All logs and checkpoints will be saved here
17 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | appdirs==1.4.4
3 | certifi==2024.2.2
4 | chardet==5.2.0
5 | charset-normalizer==2.0.4
6 | click==8.1.7
7 | contourpy==1.2.1
8 | cycler==0.12.1
9 | docker-pycreds==0.4.0
10 | fonttools==4.51.0
11 | gitdb==4.0.11
12 | GitPython==3.1.43
13 | grpcio==1.62.1
14 | idna==3.4
15 | joblib==1.3.2
16 | kiwisolver==1.4.5
17 | Markdown==3.6
18 | MarkupSafe==2.1.5
19 | matplotlib==3.8.4
20 | mkl-fft==1.3.8
21 | mkl-random==1.2.4
22 | mkl-service==2.4.0
23 | numpy==1.26.4
24 | packaging==24.0
25 | pandas==2.2.1
26 | pillow==10.2.0
27 | pip==23.3.1
28 | protobuf==4.25.3
29 | psutil==5.9.8
30 | pyparsing==3.1.2
31 | python-dateutil==2.9.0.post0
32 | pytz==2024.1
33 | PyYAML==6.0.1
34 | requests==2.31.0
35 | scikit-learn==1.4.1.post1
36 | scipy==1.13.0
37 | sentry-sdk==1.44.0
38 | setproctitle==1.3.3
39 | setuptools==68.2.2
40 | six==1.16.0
41 | smmap==5.0.1
42 | tensorboard==2.16.2
43 | tensorboard-data-server==0.7.2
44 | threadpoolctl==3.4.0
45 | torch==1.13.1
46 | torchaudio==0.13.1
47 | torchvision==0.14.1
48 | tqdm==4.66.2
49 | typing_extensions==4.9.0
50 | tzdata==2024.1
51 | urllib3==2.1.0
52 | wandb==0.16.5
53 | Werkzeug==3.0.2
54 | wheel==0.41.2
55 |
--------------------------------------------------------------------------------
/models/dino.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from models import vision_transformer as vits
8 |
9 | class DINO(nn.Module):
10 | def __init__(self, args):
11 | super(DINO, self).__init__()
12 | self.k = args.k
13 | self.backbone = self.init_backbone(args.pretrain_path)
14 | self.img_projection_head = vits.__dict__['DINOHead'](in_dim=args.feat_dim, out_dim=args.feat_dim, nlayers=args.num_mlp_layers)
15 |
16 | def forward(self, image):
17 | feat = self.backbone(image)
18 | feat = self.img_projection_head(feat)
19 | feat = F.normalize(feat, dim=-1)
20 |
21 | return feat
22 |
23 | def init_backbone(self, pretrain_path):
24 | model = vits.__dict__['vit_base']()
25 | state_dict = torch.load(os.path.join(pretrain_path, 'dino_vitbase16_pretrain.pth'), map_location='cpu')
26 | model.load_state_dict(state_dict)
27 | for m in model.parameters():
28 | m.requires_grad = False
29 |
30 | for name, m in model.named_parameters():
31 | if 'block' in name:
32 | block_num = int(name.split('.')[1])
33 | if block_num >= 11:
34 | m.requires_grad = True
35 |
36 | return model
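37 |
38 | # Usage sketch (illustrative, not part of the original file). Assumes
39 | # args.pretrain_path points at the directory holding dino_vitbase16_pretrain.pth
40 | # (cf. the os.path.join above) and that vision_transformer.py provides
41 | # 'vit_base' and 'DINOHead' as in the official DINO repo.
42 | if __name__ == '__main__':
43 |     from argparse import Namespace
44 |     args = Namespace(k=8, pretrain_path='./models', feat_dim=768, num_mlp_layers=3)
45 |     model = DINO(args).eval()
46 |     with torch.no_grad():
47 |         feat = model(torch.randn(2, 3, 224, 224))
48 |     print(feat.shape)  # L2-normalized projections, e.g. torch.Size([2, 768])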
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset
3 |
4 | def subsample_instances(dataset, prop_indices_to_subsample=0.8):
5 |
6 | np.random.seed(0)
7 | subsample_indices = np.random.choice(range(len(dataset)), replace=False,
8 | size=(int(prop_indices_to_subsample * len(dataset)),))
9 |
10 | return subsample_indices
11 |
12 | class MergedDataset(Dataset):
13 |
14 | """
15 | Takes two datasets (labelled_dataset, unlabelled_dataset) and merges them
16 | Allows you to iterate over them in parallel
17 | """
18 |
19 | def __init__(self, labelled_dataset, unlabelled_dataset):
20 |
21 | self.labelled_dataset = labelled_dataset
22 | self.unlabelled_dataset = unlabelled_dataset
23 | self.target_transform = None
24 |
25 | def __getitem__(self, item):
26 |
27 | if item < len(self.labelled_dataset):
28 | img, label, uq_idx = self.labelled_dataset[item]
29 | labeled_or_not = 1
30 |
31 | else:
32 |
33 | img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)]
34 | labeled_or_not = 0
35 |
36 |
37 | return img, label, uq_idx, np.array([labeled_or_not])
38 |
39 | def __len__(self):
40 | return len(self.unlabelled_dataset) + len(self.labelled_dataset)
41 |
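42 | # Usage sketch (illustrative): any pair of datasets whose __getitem__ returns
43 | # (img, label, uq_idx) can be merged; the appended flag is 1 for labelled items.
44 | if __name__ == '__main__':
45 |     class _Toy(Dataset):
46 |         def __init__(self, n, offset=0):
47 |             self.n, self.offset = n, offset
48 |         def __len__(self):
49 |             return self.n
50 |         def __getitem__(self, i):
51 |             return np.zeros(3), 0, self.offset + i
52 |     merged = MergedDataset(_Toy(4), _Toy(6, offset=4))
53 |     img, label, uq_idx, flag = merged[5]
54 |     print(len(merged), flag)  # 10 [0]  (index 5 falls in the unlabelled part)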
--------------------------------------------------------------------------------
/data/augmentations/cut_out.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/hysts/pytorch_cutout
3 | """
4 |
5 | import torch
6 | import numpy as np
7 |
8 | def cutout(mask_size, p, cutout_inside, mask_color=(0, 0, 0)):
9 | mask_size_half = mask_size // 2
10 | offset = 1 if mask_size % 2 == 0 else 0
11 |
12 | def _cutout(image):
13 | image = np.asarray(image).copy()
14 |
15 | if np.random.random() > p:
16 | return image
17 |
18 | h, w = image.shape[:2]
19 |
20 | if cutout_inside:
21 | cxmin, cxmax = mask_size_half, w + offset - mask_size_half
22 | cymin, cymax = mask_size_half, h + offset - mask_size_half
23 | else:
24 | cxmin, cxmax = 0, w + offset
25 | cymin, cymax = 0, h + offset
26 |
27 | cx = np.random.randint(cxmin, cxmax)
28 | cy = np.random.randint(cymin, cymax)
29 | xmin = cx - mask_size_half
30 | ymin = cy - mask_size_half
31 | xmax = xmin + mask_size
32 | ymax = ymin + mask_size
33 | xmin = max(0, xmin)
34 | ymin = max(0, ymin)
35 | xmax = min(w, xmax)
36 | ymax = min(h, ymax)
37 | image[ymin:ymax, xmin:xmax] = mask_color
38 | return image
39 |
40 | return _cutout
41 |
42 | def to_tensor():
43 | def _to_tensor(image):
44 | if len(image.shape) == 3:
45 | return torch.from_numpy(
46 | image.transpose(2, 0, 1).astype(float))
47 | else:
48 | return torch.from_numpy(image[None, :, :].astype(float))
49 |
50 | return _to_tensor
51 |
52 | def normalize(mean, std):
53 |
54 | mean = np.array(mean)
55 | std = np.array(std)
56 |
57 | def _normalize(image):
58 | image = np.asarray(image).astype(float) / 255.
59 | image = (image - mean) / std
60 | return image
61 |
62 | return _normalize
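63 |
64 | # Usage sketch (illustrative): the three factories compose into a PIL -> tensor
65 | # pipeline. cutout() and normalize() operate on numpy arrays, so they must run
66 | # before to_tensor(), as in the 'cutout' branch of data/augmentations/__init__.py.
67 | if __name__ == '__main__':
68 |     from PIL import Image
69 |     pipeline = [normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
70 |                 cutout(mask_size=16, p=1.0, cutout_inside=True),
71 |                 to_tensor()]
72 |     x = Image.new('RGB', (32, 32), color=(128, 128, 128))
73 |     for fn in pipeline:
74 |         x = fn(x)
75 |     print(x.shape)  # torch.Size([3, 32, 32])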
--------------------------------------------------------------------------------
/bash_scripts/meanshift_clustering.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | ###################################################
4 | # GCD
5 | ###################################################
6 |
7 | python -m methods.meanshift_clustering \
8 | --dataset_name cifar100 \
9 | --warmup_model_dir 'cifar100_best'
10 |
11 | python -m methods.meanshift_clustering \
12 | --dataset_name imagenet_100 \
13 | --warmup_model_dir 'imagenet_best'
14 |
15 | python -m methods.meanshift_clustering \
16 | --dataset_name cub \
17 | --warmup_model_dir 'cub_best'
18 |
19 | python -m methods.meanshift_clustering \
20 | --dataset_name scars \
21 | --warmup_model_dir 'scars_best'
22 |
23 | python -m methods.meanshift_clustering \
24 | --dataset_name aircraft \
25 | --warmup_model_dir 'aircraft_best'
26 |
27 | python -m methods.meanshift_clustering \
28 | --dataset_name herbarium_19 \
29 | --warmup_model_dir 'herbarium_19_best'
30 |
31 | ###################################################
32 | # INDUCTIVE GCD
33 | ###################################################
34 |
35 | python -m methods.meanshift_clustering \
36 | --dataset_name cifar100 \
37 | --warmup_model_dir 'cifar100_best' \
38 | --inductive
39 |
40 | python -m methods.meanshift_clustering \
41 | --dataset_name imagenet_100 \
42 | --warmup_model_dir 'imagenet_best' \
43 | --inductive
44 |
45 | python -m methods.meanshift_clustering \
46 | --dataset_name cub \
47 | --warmup_model_dir 'cub_best' \
48 | --inductive
49 |
50 | python -m methods.meanshift_clustering \
51 | --dataset_name scars \
52 | --warmup_model_dir 'scars_best' \
53 | --inductive
54 |
55 | python -m methods.meanshift_clustering \
56 | --dataset_name aircraft \
57 | --warmup_model_dir 'aircraft_best' \
58 | --inductive
59 |
60 | python -m methods.meanshift_clustering \
61 | --dataset_name herbarium_19 \
62 | --warmup_model_dir 'herbarium_19_best' \
63 | --inductive
--------------------------------------------------------------------------------
/bash_scripts/contrastive_meanshift_training.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | ###################################################
4 | # GCD
5 | ###################################################
6 |
7 | python -m methods.contrastive_meanshift_training \
8 | --dataset_name 'cifar100' \
9 | --lr 0.01 \
10 | --temperature 0.3 \
11 | --wandb
12 |
13 | python -m methods.contrastive_meanshift_training \
14 | --dataset_name 'imagenet_100' \
15 | --lr 0.01 \
16 | --temperature 0.3 \
17 | --wandb
18 |
19 | python -m methods.contrastive_meanshift_training \
20 | --dataset_name 'cub' \
21 | --lr 0.05 \
22 | --temperature 0.25 \
23 | --eta_min 5e-3 \
24 | --wandb
25 |
26 | python -m methods.contrastive_meanshift_training \
27 | --dataset_name 'scars' \
28 | --lr 0.05 \
29 | --temperature 0.25 \
30 | --eta_min 5e-3 \
31 | --wandb
32 |
33 | python -m methods.contrastive_meanshift_training \
34 | --dataset_name 'aircraft' \
35 | --lr 0.05 \
36 | --temperature 0.25 \
37 | --eta_min 5e-3 \
38 | --wandb
39 |
40 | python -m methods.contrastive_meanshift_training \
41 | --dataset_name 'herbarium_19' \
42 | --lr 0.05 \
43 | --temperature 0.25 \
44 | --eta_min 5e-3 \
45 | --wandb
46 |
47 |
48 | ###################################################
49 | # INDUCTIVE GCD
50 | ###################################################
51 |
52 | python -m methods.contrastive_meanshift_training \
53 | --dataset_name 'cifar100' \
54 | --lr 0.01 \
55 | --temperature 0.3 \
56 | --inductive \
57 | --wandb
58 |
59 | python -m methods.contrastive_meanshift_training \
60 | --dataset_name 'imagenet_100' \
61 | --lr 0.01 \
62 | --temperature 0.3 \
63 | --inductive \
64 | --wandb
65 |
66 | python -m methods.contrastive_meanshift_training \
67 | --dataset_name 'cub' \
68 | --lr 0.05 \
69 | --temperature 0.25 \
70 | --eta_min 5e-3 \
71 | --inductive \
72 | --wandb
73 |
74 | python -m methods.contrastive_meanshift_training \
75 | --dataset_name 'scars' \
76 | --lr 0.05 \
77 | --temperature 0.25 \
78 | --eta_min 5e-3 \
79 | --inductive \
80 | --wandb
81 |
82 | python -m methods.contrastive_meanshift_training \
83 | --dataset_name 'aircraft' \
84 | --lr 0.05 \
85 | --temperature 0.25 \
86 | --eta_min 5e-3 \
87 | --inductive \
88 | --wandb
89 |
90 | python -m methods.contrastive_meanshift_training \
91 | --dataset_name 'herbarium_19' \
92 | --lr 0.05 \
93 | --temperature 0.25 \
94 | --eta_min 5e-3 \
95 | --inductive \
96 | --wandb
--------------------------------------------------------------------------------
/project_utils/general_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | import inspect
6 | import argparse  # used by str2bool below
7 | from datetime import datetime
8 |
9 | def seed_torch(seed=1029):
10 |
11 | random.seed(seed)
12 | os.environ['PYTHONHASHSEED'] = str(seed)
13 | np.random.seed(seed)
14 | torch.manual_seed(seed)
15 | torch.cuda.manual_seed(seed)
16 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
17 | torch.backends.cudnn.benchmark = False
18 | torch.backends.cudnn.deterministic = True
19 |
20 |
21 | def get_dino_head_weights(pretrain_path):
22 |
23 | """
24 | :param pretrain_path: Path to full DINO pretrained checkpoint as in https://github.com/facebookresearch/dino
25 | 'full_ckpt'
26 | :return: weights only for the projection head
27 | """
28 |
29 | all_weights = torch.load(pretrain_path)
30 |
31 | head_state_dict = {}
32 | for k, v in all_weights['teacher'].items():
33 | if 'head' in k and 'last_layer' not in k:
34 | head_state_dict[k] = v
35 |
36 | head_state_dict = strip_state_dict(head_state_dict, strip_key='head.')
37 |
38 | # Deal with weight norm
39 | weight_norm_state_dict = {}
40 | for k, v in all_weights['teacher'].items():
41 | if 'last_layer' in k:
42 | weight_norm_state_dict[k.split('.')[2]] = v
43 |
44 | linear_shape = weight_norm_state_dict['weight'].shape
45 | dummy_linear = torch.nn.Linear(in_features=linear_shape[1], out_features=linear_shape[0], bias=False)
46 | dummy_linear.load_state_dict(weight_norm_state_dict)
47 | dummy_linear = torch.nn.utils.weight_norm(dummy_linear)
48 |
49 | for k, v in dummy_linear.state_dict().items():
50 |
51 | head_state_dict['last_layer.' + k] = v
52 |
53 | return head_state_dict
54 |
55 |
56 | def init_experiment(args, runner_name=None, exp_id=None):
57 |
58 | args.cuda = torch.cuda.is_available()
59 |
60 | # Get filepath of calling script
61 | if runner_name is None:
62 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:]
63 |
64 | root_dir = os.path.join(args.exp_root, *runner_name)
65 |
66 | if not os.path.exists(root_dir):
67 | os.makedirs(root_dir)
68 |
69 | # Either generate a unique experiment ID, or use one which is passed
70 | if exp_id is None:
71 |
72 | # Unique identifier for experiment
73 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \
74 | datetime.now().strftime("%S.%f")[:-3] + ')'
75 |
76 | log_dir = os.path.join(root_dir, 'log', now)
77 | while os.path.exists(log_dir):
78 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \
79 | datetime.now().strftime("%S.%f")[:-3] + ')'
80 |
81 | log_dir = os.path.join(root_dir, 'log', now)
82 |
83 | else:
84 |
85 | log_dir = os.path.join(root_dir, 'log', f'{exp_id}')
86 |
87 | if not os.path.exists(log_dir):
88 | os.makedirs(log_dir)
89 | args.log_dir = log_dir
90 |
91 | # Instantiate directory to save models to
92 | model_root_dir = os.path.join(args.log_dir, 'checkpoints')
93 | if not os.path.exists(model_root_dir):
94 | os.mkdir(model_root_dir)
95 |
96 | args.model_dir = model_root_dir
97 | args.model_path = os.path.join(args.model_dir, 'model.pt')
98 |
99 | print(f'Experiment saved to: {args.log_dir}')
100 |
101 | print(runner_name)
102 | print(args)
103 |
104 | return args
105 |
106 |
107 | def str2bool(v):
108 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
109 | return True
110 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
111 | return False
112 | else:
113 | raise argparse.ArgumentTypeError('Boolean value expected.')
114 |
115 |
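116 | # Usage sketch (illustrative): seed everything before building models and
117 | # dataloaders; str2bool works as an argparse `type` for boolean flags
118 | # (the flag name below is made up for the demo).
119 | if __name__ == '__main__':
120 |     seed_torch(0)
121 |     parser = argparse.ArgumentParser()
122 |     parser.add_argument('--use_best_model', type=str2bool, default=True)
123 |     print(parser.parse_args(['--use_best_model', 'no']).use_best_model)  # False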
--------------------------------------------------------------------------------
/data/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from data.augmentations.cut_out import *
3 | from data.augmentations.randaugment import RandAugment
4 | import torch; import numpy as np  # used directly below (torch.tensor / np.array)
5 | def get_transform(transform_type='default', image_size=32, args=None):
6 |
7 | if transform_type == 'imagenet':
8 |
9 | mean = (0.485, 0.456, 0.406)
10 | std = (0.229, 0.224, 0.225)
11 | interpolation = args.interpolation
12 | crop_pct = args.crop_pct
13 |
14 | train_transform = transforms.Compose([
15 | transforms.Resize(int(image_size / crop_pct), interpolation),
16 | transforms.RandomCrop(image_size),
17 | transforms.RandomHorizontalFlip(p=0.5),
18 | transforms.ColorJitter(),
19 | transforms.ToTensor(),
20 | transforms.Normalize(
21 | mean=torch.tensor(mean),
22 | std=torch.tensor(std))
23 | ])
24 |
25 | test_transform = transforms.Compose([
26 | transforms.Resize(int(image_size / crop_pct), interpolation),
27 | transforms.CenterCrop(image_size),
28 | transforms.ToTensor(),
29 | transforms.Normalize(
30 | mean=torch.tensor(mean),
31 | std=torch.tensor(std))
32 | ])
33 |
34 | elif transform_type == 'pytorch-cifar':
35 |
36 | mean = (0.4914, 0.4822, 0.4465)
37 | std = (0.2023, 0.1994, 0.2010)
38 |
39 | train_transform = transforms.Compose([
40 | transforms.RandomCrop(image_size, padding=4),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=mean, std=std),
44 | ])
45 |
46 | test_transform = transforms.Compose([
47 | transforms.Resize((image_size, image_size)),
48 | transforms.ToTensor(),
49 | transforms.Normalize(mean=mean, std=std),
50 | ])
51 |
52 | elif transform_type == 'herbarium_default':
53 |
54 | train_transform = transforms.Compose([
55 | transforms.Resize((image_size, image_size)),
56 | transforms.RandomResizedCrop(image_size, scale=(args.resize_lower_bound, 1)),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | ])
60 |
61 | test_transform = transforms.Compose([
62 | transforms.Resize((image_size, image_size)),
63 | transforms.ToTensor(),
64 | ])
65 |
66 | elif transform_type == 'cutout':
67 |
68 | mean = np.array([0.4914, 0.4822, 0.4465])
69 | std = np.array([0.2470, 0.2435, 0.2616])
70 |
71 | train_transform = transforms.Compose([
72 | transforms.RandomCrop(image_size, padding=4),
73 | transforms.RandomHorizontalFlip(),
74 | normalize(mean, std),
75 | cutout(mask_size=int(image_size / 2),
76 | p=1,
77 | cutout_inside=False),
78 | to_tensor(),
79 | ])
80 | test_transform = transforms.Compose([
81 | transforms.Resize((image_size, image_size)),
82 | transforms.ToTensor(),
83 | transforms.Normalize(mean, std),
84 | ])
85 |
86 | elif transform_type == 'rand-augment':
87 |
88 | mean = (0.485, 0.456, 0.406)
89 | std = (0.229, 0.224, 0.225)
90 |
91 | train_transform = transforms.Compose([
92 | transforms.Resize((image_size, image_size)),
93 | transforms.RandomCrop(image_size, padding=4),
94 | transforms.RandomHorizontalFlip(),
95 | transforms.ToTensor(),
96 | transforms.Normalize(mean=mean, std=std),
97 | ])
98 |
99 | train_transform.transforms.insert(0, RandAugment(args.rand_aug_n, args.rand_aug_m, args=None))
100 |
101 | test_transform = transforms.Compose([
102 | transforms.Resize((image_size, image_size)),
103 | transforms.ToTensor(),
104 | transforms.Normalize(mean=mean, std=std),
105 | ])
106 |
107 | elif transform_type == 'random_affine':
108 |
109 | mean = (0.485, 0.456, 0.406)
110 | std = (0.229, 0.224, 0.225)
111 | interpolation = args.interpolation
112 | crop_pct = args.crop_pct
113 |
114 | train_transform = transforms.Compose([
115 | transforms.Resize((image_size, image_size), interpolation),
116 | transforms.RandomAffine(degrees=(-45, 45),
117 | translate=(0.1, 0.1), shear=(-15, 15), scale=(0.7, args.crop_pct)),
118 | transforms.ColorJitter(),
119 | transforms.ToTensor(),
120 | transforms.Normalize(
121 | mean=torch.tensor(mean),
122 | std=torch.tensor(std))
123 | ])
124 |
125 | test_transform = transforms.Compose([
126 | transforms.Resize(int(image_size / crop_pct), interpolation),
127 | transforms.CenterCrop(image_size),
128 | transforms.ToTensor(),
129 | transforms.Normalize(
130 | mean=torch.tensor(mean),
131 | std=torch.tensor(std))
132 | ])
133 |
134 | else:
135 |
136 | raise NotImplementedError
137 |
138 | return (train_transform, test_transform)
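139 |
140 | # Usage sketch (illustrative): the 'imagenet' branch reads args.interpolation
141 | # and args.crop_pct; the values below follow common ViT defaults.
142 | if __name__ == '__main__':
143 |     from argparse import Namespace
144 |     args = Namespace(interpolation=3, crop_pct=0.875)  # 3 = PIL bicubic
145 |     train_t, test_t = get_transform('imagenet', image_size=224, args=args)
146 |     print(test_t)  # Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize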
--------------------------------------------------------------------------------
/methods/loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class SupConLoss(nn.Module):
9 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
10 | It also supports the unsupervised contrastive loss in SimCLR"""
11 | def __init__(self, temperature=0.07, contrast_mode='all',
12 | base_temperature=0.07):
13 | super(SupConLoss, self).__init__()
14 | self.temperature = temperature
15 | self.contrast_mode = contrast_mode
16 | self.base_temperature = base_temperature
17 |
18 | def forward(self, features, labels=None, mask=None):
19 | """Compute loss for model. If both `labels` and `mask` are None,
20 | it degenerates to SimCLR unsupervised loss:
21 | https://arxiv.org/pdf/2002.05709.pdf
22 |
23 | Args:
24 | features: hidden vector of shape [bsz, n_views, ...].
25 | labels: ground truth of shape [bsz].
26 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
27 | has the same class as sample i. Can be asymmetric.
28 | Returns:
29 | A loss scalar.
30 | """
31 | device = (torch.device('cuda')
32 | if features.is_cuda
33 | else torch.device('cpu'))
34 |
35 | if len(features.shape) < 3:
36 |             raise ValueError('`features` needs to be [bsz, n_views, ...], '
37 |                              'at least 3 dimensions are required')
38 | if len(features.shape) > 3:
39 | features = features.view(features.shape[0], features.shape[1], -1)
40 |
41 | batch_size = features.shape[0]
42 | if labels is not None and mask is not None:
43 | raise ValueError('Cannot define both `labels` and `mask`')
44 | elif labels is None and mask is None:
45 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
46 | elif labels is not None:
47 | labels = labels.contiguous().view(-1, 1)
48 | if labels.shape[0] != batch_size:
49 | raise ValueError('Num of labels does not match num of features')
50 | mask = torch.eq(labels, labels.T).float().to(device)
51 | else:
52 | mask = mask.float().to(device)
53 |
54 | contrast_count = features.shape[1]
55 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
56 | if self.contrast_mode == 'one':
57 | anchor_feature = features[:, 0]
58 | anchor_count = 1
59 | elif self.contrast_mode == 'all':
60 | anchor_feature = contrast_feature
61 | anchor_count = contrast_count # n_views
62 | else:
63 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
64 |
65 | # compute logits
66 | anchor_dot_contrast = torch.div(
67 | torch.matmul(anchor_feature, contrast_feature.T),
68 | self.temperature)
69 | # for numerical stability
70 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
71 | logits = anchor_dot_contrast - logits_max.detach()
72 |
73 | # tile mask
74 | mask = mask.repeat(anchor_count, contrast_count)
75 | # mask-out self-contrast cases
76 | logits_mask = torch.scatter(
77 | torch.ones_like(mask),
78 | 1,
79 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
80 | 0
81 | )
82 | mask = mask * logits_mask
83 |
84 | # compute log_prob
85 | exp_logits = torch.exp(logits) * logits_mask
86 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
87 |
88 | # compute mean of log-likelihood over positive
89 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
90 |
91 | # loss
92 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
93 | loss = loss.view(anchor_count, batch_size).mean()
94 |
95 | return loss
96 |
97 | class ConMeanShiftLoss(nn.Module):
98 | def __init__(self, args):
99 | super(ConMeanShiftLoss, self).__init__()
100 | self.n_views = args.n_views
101 | self.alpha = args.alpha
102 | self.temperature = args.temperature
103 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
104 |
105 | def forward(self, knn_emb, features):
106 | batch_size = int(features.size(0)) // self.n_views
107 | positive_mask = torch.eye(batch_size, dtype=torch.float32).to(self.device)
108 | positive_mask = positive_mask.repeat(self.n_views, self.n_views)
109 | negative_mask = torch.ones_like(positive_mask)
110 | negative_mask[positive_mask>0] = 0
111 | self_mask = torch.eye(batch_size * self.n_views, dtype=torch.float32).to(self.device)
112 | positive_mask[self_mask>0] = 0
113 |
114 | #compute logits
115 | meanshift_feat = (1-self.alpha) * features + self.alpha * knn_emb
116 | norm = torch.sqrt(torch.sum((torch.pow(meanshift_feat, 2)), dim=-1)).unsqueeze(1).detach()
117 | meanshift_feat = meanshift_feat / norm
118 |
119 | anchor_dot_contrast = torch.div(torch.matmul(meanshift_feat, meanshift_feat.T), self.temperature)
120 |
121 | # for numerical stability
122 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
123 | logits = anchor_dot_contrast - logits_max.detach()
124 |
125 | # compute log_prob
126 | exp_logits = torch.exp(logits) * negative_mask
127 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
128 |
129 | # compute mean of log-likelihood over positive
130 | loss = - ((positive_mask * log_prob).sum(1) / positive_mask.sum(1))
131 | loss = loss.view(self.n_views, batch_size).mean()
132 |
133 | return loss
134 |
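135 | # Usage sketch (illustrative): SupConLoss consumes [bsz, n_views, d] features;
136 | # ConMeanShiftLoss consumes the views stacked along the batch axis plus a
137 | # same-shape k-NN mean embedding (random stand-ins here).
138 | if __name__ == '__main__':
139 |     from argparse import Namespace
140 |     dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
141 |     f = F.normalize(torch.randn(8, 2, 128, device=dev), dim=-1)
142 |     print(SupConLoss(temperature=0.25)(f, labels=torch.arange(8, device=dev) % 4))
143 |     args = Namespace(n_views=2, alpha=0.5, temperature=0.25)
144 |     feats = F.normalize(torch.randn(16, 128, device=dev), dim=-1)    # [bsz*n_views, d]
145 |     knn_emb = F.normalize(torch.randn(16, 128, device=dev), dim=-1)  # k-NN mean stand-in
146 |     print(ConMeanShiftLoss(args)(knn_emb, feats))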
--------------------------------------------------------------------------------
/project_utils/cluster_and_log_utils.py:
--------------------------------------------------------------------------------
1 | from project_utils.cluster_utils import cluster_acc, np, linear_assignment
2 | from torch.utils.tensorboard import SummaryWriter
3 | from typing import List
4 | from scipy.cluster.hierarchy import linkage, fcluster
5 | from sklearn.metrics import silhouette_score
6 |
7 | def split_cluster_acc_v1(y_true, y_pred, mask):
8 |
9 | """
10 | Evaluate clustering metrics on two subsets of data, as defined by the mask 'mask'
11 | (Mask usually corresponding to `Old' and `New' classes in GCD setting)
12 |     :param y_true: All ground truth labels
13 |     :param y_pred: All predictions
14 | :param mask: Mask defining two subsets
15 | :return:
16 | """
17 |
18 | mask = mask.astype(bool)
19 | y_true = y_true.astype(int)
20 | y_pred = y_pred.astype(int)
21 | weight = mask.mean()
22 | old_acc = cluster_acc(y_true[mask], y_pred[mask])
23 | new_acc = cluster_acc(y_true[~mask], y_pred[~mask])
24 | total_acc = weight * old_acc + (1 - weight) * new_acc
25 |
26 | return total_acc, old_acc, new_acc
27 |
28 | def split_cluster_acc_v2(y_true, y_pred, mask):
29 | """
30 | Calculate clustering accuracy. Require scikit-learn installed
31 | First compute linear assignment on all data, then look at how good the accuracy is on subsets
32 |
33 | # Arguments
34 | mask: Which instances come from old classes (True) and which ones come from new classes (False)
35 |         y_true: true labels, numpy.array with shape `(n_samples,)`
36 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
37 |
38 | # Return
39 | accuracy, in [0,1]
40 | """
41 | y_true = y_true.astype(int)
42 |
43 | old_classes_gt = set(y_true[mask])
44 | new_classes_gt = set(y_true[~mask])
45 |
46 | assert y_pred.size == y_true.size
47 | D = max(y_pred.max(), y_true.max()) + 1
48 | w = np.zeros((D, D), dtype=int)
49 | for i in range(y_pred.size):
50 | w[y_pred[i], y_true[i]] += 1
51 | # w: pred x label count
52 | ind = linear_assignment(w.max() - w)
53 | ind = np.vstack(ind).T
54 |
55 | ind_map = {j: i for i, j in ind}
56 | total_acc = sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
57 |
58 | old_acc = 0
59 | classwise_acc = []
60 | total_old_instances = 0
61 | for i in old_classes_gt:
62 | old_acc += w[ind_map[i], i]
63 | total_old_instances += sum(w[:, i])
64 | classwise_acc.append(w[ind_map[i], i]/sum(w[:, i]))
65 | old_acc /= total_old_instances
66 |
67 | new_acc = 0
68 | total_new_instances = 0
69 | for i in new_classes_gt:
70 | new_acc += w[ind_map[i], i] # w[pred matching with ith label, ith label]
71 | total_new_instances += sum(w[:, i])
72 | classwise_acc.append(w[ind_map[i], i]/sum(w[:, i]))
73 | new_acc /= total_new_instances
74 |
75 | return total_acc, old_acc, new_acc
76 |
77 |
78 | EVAL_FUNCS = {
79 | 'v1': split_cluster_acc_v1,
80 | 'v2': split_cluster_acc_v2,
81 | }
82 |
83 | def log_accs_from_preds(y_true, y_pred, mask, eval_funcs: List[str], save_name: str, T: int=None, writer: SummaryWriter=None,
84 | print_output=False):
85 |
86 | """
87 | Given a list of evaluation functions to use (e.g ['v1', 'v2']) evaluate and log ACC results
88 |
89 | :param y_true: GT labels
90 | :param y_pred: Predicted indices
91 | :param mask: Which instances belong to Old and New classes
92 | :param T: Epoch
93 | :param eval_funcs: Which evaluation functions to use
94 | :param save_name: What are we evaluating ACC on
95 | :param writer: Tensorboard logger
96 | :return:
97 | """
98 |
99 | mask = mask.astype(bool)
100 | y_true = y_true.astype(int)
101 | y_pred = y_pred.astype(int)
102 |
103 | for i, f_name in enumerate(eval_funcs):
104 |
105 | acc_f = EVAL_FUNCS[f_name]
106 | all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask)
107 | log_name = f'{save_name}_{f_name}'
108 |
109 | if writer is not None:
110 | writer.add_scalars(log_name,
111 | {'Old': old_acc, 'New': new_acc,
112 | 'All': all_acc}, T)
113 |
114 | if i == 0:
115 | to_return = (all_acc, old_acc, new_acc)
116 |
117 | if print_output:
118 | print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}'
119 | print(print_str)
120 |
121 | return to_return
122 |
123 |
124 | def test_agglo(epoch, feats, targets, mask, save_name, args):
125 | best_acc = 0
126 | mask = mask.astype(bool)
127 | feats = np.concatenate(feats)
128 | linked = linkage(feats, method="ward")
129 |
130 | gt_dist = linked[:, 2][-args.num_labeled_classes-args.num_unlabeled_classes]
131 | preds = fcluster(linked, t=gt_dist, criterion='distance')
132 | test_all_acc_test, test_old_acc_test, test_new_acc_test = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask,
133 | T=0, eval_funcs=args.eval_funcs, save_name=save_name)
134 |
135 | dist = linked[:, 2][:-args.num_labeled_classes]
136 | tolerance = 0
137 | for d in reversed(dist):
138 | preds = fcluster(linked, t=d, criterion='distance')
139 | k = max(preds)
140 | all_acc_test, old_acc_test, new_acc_test = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask,
141 | T=0, eval_funcs=args.eval_funcs, save_name=save_name)
142 |
143 | if old_acc_test > best_acc: # save best labeled acc without knowing GT K
144 | best_acc = old_acc_test
145 | best_acc_k = k
146 | best_acc_d = d
147 | tolerance = 0
148 | else:
149 | tolerance += 1
150 |
151 |         if tolerance == 50:
152 | break
153 | return test_all_acc_test, test_old_acc_test, test_new_acc_test, best_acc, best_acc_k
154 |
155 |
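156 | # Usage sketch (illustrative): score a toy prediction where the first half of
157 | # the classes are 'Old' (mask=True); the cluster ids are a permutation of the
158 | # labels, so all three accuracies come out as 1.0.
159 | if __name__ == '__main__':
160 |     y_true = np.array([0, 0, 1, 1, 2, 2, 3, 3])
161 |     y_pred = np.array([1, 1, 0, 0, 3, 3, 2, 2])
162 |     print(split_cluster_acc_v2(y_true, y_pred, mask=y_true < 2))  # (1.0, 1.0, 1.0)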
--------------------------------------------------------------------------------
/project_utils/cluster_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 | import numpy as np
3 | import sklearn.metrics
4 | import torch
5 | import torch.nn as nn
6 | import matplotlib
7 | matplotlib.use('agg')
8 | from scipy.optimize import linear_sum_assignment as linear_assignment
9 | import random
10 | import os
11 | import argparse
12 |
13 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
14 | from sklearn.metrics import adjusted_rand_score as ari_score
15 |
16 | from sklearn import metrics
17 | import time
18 |
19 | # -------------------------------
20 | # Evaluation Criteria
21 | # -------------------------------
22 | def evaluate_clustering(y_true, y_pred):
23 |
24 | start = time.time()
25 | print('Computing metrics...')
26 | if len(set(y_pred)) < 1000:
27 | acc = cluster_acc(y_true.astype(int), y_pred.astype(int))
28 | else:
29 | acc = None
30 |
31 | nmi = nmi_score(y_true, y_pred)
32 | ari = ari_score(y_true, y_pred)
33 | pur = purity_score(y_true, y_pred)
34 | print(f'Finished computing metrics {time.time() - start}...')
35 |
36 | return acc, nmi, ari, pur
37 |
38 |
39 | def cluster_acc(y_true, y_pred, return_ind=False):
40 | """
41 | Calculate clustering accuracy. Require scikit-learn installed
42 |
43 | # Arguments
44 |         y_true: true labels, numpy.array with shape `(n_samples,)`
45 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
46 |
47 | # Return
48 | accuracy, in [0,1]
49 | """
50 | y_true = y_true.astype(int)
51 | assert y_pred.size == y_true.size
52 | D = max(y_pred.max(), y_true.max()) + 1
53 | w = np.zeros((D, D), dtype=int)
54 | for i in range(y_pred.size):
55 | w[y_pred[i], y_true[i]] += 1
56 |
57 | ind = linear_assignment(w.max() - w)
58 | ind = np.vstack(ind).T
59 |
60 | if return_ind:
61 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind, w
62 | else:
63 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
64 |
65 |
66 | def purity_score(y_true, y_pred):
67 | # compute contingency matrix (also called confusion matrix)
68 | contingency_matrix = metrics.cluster.contingency_matrix(y_true, y_pred)
69 | # return purity
70 | return np.sum(np.amax(contingency_matrix, axis=0)) / np.sum(contingency_matrix)
71 |
72 | # -------------------------------
73 | # Mixed Eval Function
74 | # -------------------------------
75 | def mixed_eval(targets, preds, mask):
76 |
77 | """
78 | Evaluate clustering metrics on two subsets of data, as defined by the mask 'mask'
79 | (Mask usually corresponding to `Old' and `New' classes in GCD setting)
80 | :param targets: All ground truth labels
81 | :param preds: All predictions
82 | :param mask: Mask defining two subsets
83 | :return:
84 | """
85 |
86 | mask = mask.astype(bool)
87 |
88 | # Labelled examples
89 | if mask.sum() == 0: # All examples come from unlabelled classes
90 |
91 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int), preds.astype(int)), \
92 | nmi_score(targets, preds), \
93 | ari_score(targets, preds)
94 |
95 | print('Unlabelled Classes Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'
96 | .format(unlabelled_acc, unlabelled_nmi, unlabelled_ari))
97 |
98 | # Also return ratio between labelled and unlabelled examples
99 | return (unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.mean()
100 |
101 | else:
102 |
103 | labelled_acc, labelled_nmi, labelled_ari = cluster_acc(targets.astype(int)[mask],
104 | preds.astype(int)[mask]), \
105 | nmi_score(targets[mask], preds[mask]), \
106 | ari_score(targets[mask], preds[mask])
107 |
108 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int)[~mask],
109 | preds.astype(int)[~mask]), \
110 | nmi_score(targets[~mask], preds[~mask]), \
111 | ari_score(targets[~mask], preds[~mask])
112 |
113 | # Also return ratio between labelled and unlabelled examples
114 | return (labelled_acc, labelled_nmi, labelled_ari), (
115 | unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.mean()
116 |
117 |
118 | class AverageMeter(object):
119 | """Computes and stores the average and current value"""
120 | def __init__(self):
121 | self.reset()
122 |
123 | def reset(self):
124 | self.val = 0
125 | self.avg = 0
126 | self.sum = 0
127 | self.count = 0
128 |
129 | def update(self, val, n=1):
130 | self.val = val
131 | self.sum += val * n
132 | self.count += n
133 | self.avg = self.sum / self.count
134 |
135 |
136 | class Identity(nn.Module):
137 | def __init__(self):
138 | super(Identity, self).__init__()
139 | def forward(self, x):
140 | return x
141 |
142 |
143 | class BCE(nn.Module):
144 | eps = 1e-7 # Avoid calculating log(0). Use the small value of float16.
145 | def forward(self, prob1, prob2, simi):
146 | # simi: 1->similar; -1->dissimilar; 0->unknown(ignore)
147 | assert len(prob1)==len(prob2)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(prob1)),str(len(prob2)),str(len(simi)))
148 | P = prob1.mul_(prob2)
149 | P = P.sum(1)
150 | P.mul_(simi).add_(simi.eq(-1).type_as(P))
151 | neglogP = -P.add_(BCE.eps).log_()
152 | return neglogP.mean()
153 |
154 |
155 | def PairEnum(x,mask=None):
156 |
157 | # Enumerate all pairs of feature in x
158 | assert x.ndimension() == 2, 'Input dimension must be 2'
159 | x1 = x.repeat(x.size(0), 1)
160 | x2 = x.repeat(1, x.size(0)).view(-1, x.size(1))
161 |
162 | if mask is not None:
163 |
164 | xmask = mask.view(-1, 1).repeat(1, x.size(1))
165 | #dim 0: #sample, dim 1:#feature
166 | x1 = x1[xmask].view(-1, x.size(1))
167 | x2 = x2[xmask].view(-1, x.size(1))
168 |
169 | return x1, x2
170 |
171 |
172 | def accuracy(output, target, topk=(1,)):
173 | """Computes the accuracy over the k top predictions for the specified values of k"""
174 | with torch.no_grad():
175 | maxk = max(topk)
176 | batch_size = target.size(0)
177 |
178 | _, pred = output.topk(maxk, 1, True, True)
179 | pred = pred.t()
180 | correct = pred.eq(target.view(1, -1).expand_as(pred))
181 |
182 | res = []
183 | for k in topk:
184 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the top-k slice is non-contiguous
185 | res.append(correct_k.mul_(100.0 / batch_size))
186 | return res
187 |
188 |
189 | def seed_torch(seed=1029):
190 | random.seed(seed)
191 | os.environ['PYTHONHASHSEED'] = str(seed)
192 | np.random.seed(seed)
193 | torch.manual_seed(seed)
194 | torch.cuda.manual_seed(seed)
195 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
196 | torch.backends.cudnn.benchmark = False
197 | torch.backends.cudnn.deterministic = True
198 |
199 |
200 | def str2bool(v):
201 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
202 | return True
203 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
204 | return False
205 | else:
206 | raise argparse.ArgumentTypeError('Boolean value expected.')
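207 |
208 | # Usage sketch (illustrative): cluster_acc matches predicted cluster ids to
209 | # ground-truth labels via the Hungarian algorithm before scoring.
210 | if __name__ == '__main__':
211 |     y_true = np.array([0, 0, 1, 1, 2, 2])
212 |     y_pred = np.array([2, 2, 0, 0, 1, 1])  # same clustering, permuted ids
213 |     print(cluster_acc(y_true, y_pred))          # 1.0
214 |     print(evaluate_clustering(y_true, y_pred))  # (acc, nmi, ari, purity)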
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Contrastive Mean-Shift Learning for Generalized Category Discovery
2 |
3 | Sua Choi, Dahyun Kang, Minsu Cho. CVPR 2024.
4 |
5 | ![overview](data/assets/overview.png)
6 |
27 | ## Environment installation
28 | This project is built upon the following environment:
29 | * [Python 3.10](https://www.python.org)
30 | * [CUDA 11.7](https://developer.nvidia.com/cuda-toolkit)
31 | * [PyTorch 1.13.1](https://pytorch.org)
32 |
33 | The package requirements can be installed via `requirements.txt`,
34 | ```bash
35 | pip install -r requirements.txt
36 | ```
37 |
38 | ## Datasets
39 | We use fine-grained benchmarks in this paper, including:
40 | * [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)
41 |
42 | We also use generic object recognition datasets, including:
43 | * [CIFAR100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet](https://image-net.org/download.php)
44 |
45 | Please follow [this repo](https://github.com/sgvaze/generalized-category-discovery) to set up the data.
46 |
47 | Download the datasets, SSB splits, and the pretrained DINO backbone following the file structures below, and set `DATASET_ROOT={YOUR DIRECTORY}` in `config.py`.
48 |
49 | ```
50 | DATASET_ROOT/
51 | ├── cifar100/
52 | │   ├── cifar-100-python/
53 | │ │ ├── meta/
54 | │ ├── ...
55 | ├── CUB_200_2011/
56 | │ ├── attributes/
57 | │ ├── ...
58 | ├── ...
59 | ```
60 | ```
61 | CMS/
62 | ├── data/
63 | │ ├── ssb_splits/
64 | ├── models/
65 | │ ├── dino_vitbase16_pretrain.pth
66 | ├── ...
67 | ```
68 |
69 |
70 | ## Training
71 | ```bash
72 | bash bash_scripts/contrastive_meanshift_training.sh
73 | ```
74 | Example bash commands for training are as follows:
75 | ```bash
76 | # GCD
77 | python -m methods.contrastive_meanshift_training \
78 | --dataset_name 'cub' \
79 | --lr 0.05 \
80 | --temperature 0.25 \
81 | --wandb
82 |
83 | # Inductive GCD
84 | python -m methods.contrastive_meanshift_training \
85 | --dataset_name 'cub' \
86 | --lr 0.05 \
87 | --temperature 0.25 \
88 | --inductive \
89 | --wandb
90 | ```
91 |
92 | ## Evaluation
93 | ```bash
94 | bash bash_scripts/meanshift_clustering.sh
95 | ```
96 | An example bash command for evaluation is as follows; set `model_name` to the checkpoint you want to evaluate.
97 | ```bash
98 | python -m methods.meanshift_clustering \
99 | --dataset_name 'cub' \
100 | --model_name 'cub_best' \
101 | ```
102 |
103 |
104 | ## Results and checkpoints
105 |
106 | ### Experimental results on the GCD task
107 |
108 |
109 |
110 | | Dataset       | All  | Old  | Novel | Checkpoints |
111 | |---------------|------|------|-------|-------------|
112 | | CIFAR100      | 82.3 | 85.7 | 75.5  | link        |
113 | | ImageNet100   | 84.7 | 95.6 | 79.2  | link        |
114 | | CUB           | 68.2 | 76.5 | 64.0  | link        |
115 | | Stanford Cars | 56.9 | 76.1 | 47.6  | link        |
116 | | FGVC-Aircraft | 56.0 | 63.4 | 52.3  | link        |
117 | | Herbarium19   | 36.4 | 54.9 | 26.4  | link        |
157 |
158 |
159 |
160 |
161 | ### Experimental results on the inductive GCD task
162 |
163 |
164 |
165 | | Dataset       | All  | Old  | Novel | Checkpoints |
166 | |---------------|------|------|-------|-------------|
167 | | CIFAR100      | 80.7 | 84.4 | 65.9  | link        |
168 | | ImageNet100   | 85.7 | 95.7 | 75.8  | link        |
169 | | CUB           | 69.7 | 76.5 | 63.0  | link        |
170 | | Stanford Cars | 57.8 | 75.2 | 41.0  | link        |
171 | | FGVC-Aircraft | 53.3 | 62.7 | 43.8  | link        |
172 | | Herbarium19   | 46.2 | 53.0 | 38.9  | link        |
212 |
213 |
214 |
215 |
216 |
217 | ## Citation
218 | If you find our code or paper useful, please consider citing our paper:
219 |
220 | ```BibTeX
221 | @inproceedings{choi2024contrastive,
222 | title={Contrastive Mean-Shift Learning for Generalized Category Discovery},
223 | author={Choi, Sua and Kang, Dahyun and Cho, Minsu},
224 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
225 | year={2024}
226 | }
227 | ```
228 |
229 | ## Related Repos
230 | The codebase is largely built on [Generalized Category Discovery](https://github.com/sgvaze/generalized-category-discovery) and [PromptCAL](https://github.com/sheng-eatamath/PromptCAL).
231 |
232 |
233 | ## Acknowledgements
234 | This work was supported by the NRF grant (NRF-2021R1A2C3012728 (50%)) and the IITP grants (2022-0-00113: Developing a Sustainable Collaborative Multi-modal Lifelong Learning Framework (45%), 2019-0-01906:
235 | AI Graduate School Program at POSTECH (5%)) funded by the Ministry of Science and ICT, Korea.
236 |
237 |
--------------------------------------------------------------------------------
/data/augmentations/randaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
3 | """
4 |
5 | import random
6 |
7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 |
13 | def ShearX(img, v): # [-0.3, 0.3]
14 | assert -0.3 <= v <= 0.3
15 | if random.random() > 0.5:
16 | v = -v
17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
18 |
19 |
20 | def ShearY(img, v): # [-0.3, 0.3]
21 | assert -0.3 <= v <= 0.3
22 | if random.random() > 0.5:
23 | v = -v
24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
25 |
26 |
27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
28 | assert -0.45 <= v <= 0.45
29 | if random.random() > 0.5:
30 | v = -v
31 | v = v * img.size[0]
32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33 |
34 |
35 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
36 | assert 0 <= v
37 | if random.random() > 0.5:
38 | v = -v
39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
40 |
41 |
42 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
43 | assert -0.45 <= v <= 0.45
44 | if random.random() > 0.5:
45 | v = -v
46 | v = v * img.size[1]
47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
48 |
49 |
50 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
51 | assert 0 <= v
52 | if random.random() > 0.5:
53 | v = -v
54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
55 |
56 |
57 | def Rotate(img, v): # [-30, 30]
58 | assert -30 <= v <= 30
59 | if random.random() > 0.5:
60 | v = -v
61 | return img.rotate(v)
62 |
63 |
64 | def AutoContrast(img, _):
65 | return PIL.ImageOps.autocontrast(img)
66 |
67 |
68 | def Invert(img, _):
69 | return PIL.ImageOps.invert(img)
70 |
71 |
72 | def Equalize(img, _):
73 | return PIL.ImageOps.equalize(img)
74 |
75 |
76 | def Flip(img, _): # not from the paper
77 | return PIL.ImageOps.mirror(img)
78 |
79 |
80 | def Solarize(img, v): # [0, 256]
81 | assert 0 <= v <= 256
82 | return PIL.ImageOps.solarize(img, v)
83 |
84 |
85 | def SolarizeAdd(img, addition=0, threshold=128):
86 |     img_np = np.array(img).astype(int)  # np.int was removed in NumPy >= 1.24
87 | img_np = img_np + addition
88 | img_np = np.clip(img_np, 0, 255)
89 | img_np = img_np.astype(np.uint8)
90 | img = Image.fromarray(img_np)
91 | return PIL.ImageOps.solarize(img, threshold)
92 |
93 |
94 | def Posterize(img, v): # [4, 8]
95 | v = int(v)
96 | v = max(1, v)
97 | return PIL.ImageOps.posterize(img, v)
98 |
99 |
100 | def Contrast(img, v): # [0.1,1.9]
101 | assert 0.1 <= v <= 1.9
102 | return PIL.ImageEnhance.Contrast(img).enhance(v)
103 |
104 |
105 | def Color(img, v): # [0.1,1.9]
106 | assert 0.1 <= v <= 1.9
107 | return PIL.ImageEnhance.Color(img).enhance(v)
108 |
109 |
110 | def Brightness(img, v): # [0.1,1.9]
111 | assert 0.1 <= v <= 1.9
112 | return PIL.ImageEnhance.Brightness(img).enhance(v)
113 |
114 |
115 | def Sharpness(img, v): # [0.1,1.9]
116 | assert 0.1 <= v <= 1.9
117 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
118 |
119 |
120 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
121 | assert 0.0 <= v <= 0.2
122 | if v <= 0.:
123 | return img
124 |
125 | v = v * img.size[0]
126 | return CutoutAbs(img, v)
127 |
128 |
129 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
130 | # assert 0 <= v <= 20
131 | if v < 0:
132 | return img
133 | w, h = img.size
134 | x0 = np.random.uniform(w)
135 | y0 = np.random.uniform(h)
136 |
137 | x0 = int(max(0, x0 - v / 2.))
138 | y0 = int(max(0, y0 - v / 2.))
139 | x1 = min(w, x0 + v)
140 | y1 = min(h, y0 + v)
141 |
142 | xy = (x0, y0, x1, y1)
143 | color = (125, 123, 114)
144 | # color = (0, 0, 0)
145 | img = img.copy()
146 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
147 | return img
148 |
149 |
150 | def SamplePairing(imgs): # [0, 0.4]
151 | def f(img1, v):
152 | i = np.random.choice(len(imgs))
153 | img2 = PIL.Image.fromarray(imgs[i])
154 | return PIL.Image.blend(img1, img2, v)
155 |
156 | return f
157 |
158 |
159 | def Identity(img, v):
160 | return img
161 |
162 |
163 | def augment_list(): # 16 operations and their ranges
164 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
165 | # l = [
166 | # (Identity, 0., 1.0),
167 | # (ShearX, 0., 0.3), # 0
168 | # (ShearY, 0., 0.3), # 1
169 | # (TranslateX, 0., 0.33), # 2
170 | # (TranslateY, 0., 0.33), # 3
171 | # (Rotate, 0, 30), # 4
172 | # (AutoContrast, 0, 1), # 5
173 | # (Invert, 0, 1), # 6
174 | # (Equalize, 0, 1), # 7
175 | # (Solarize, 0, 110), # 8
176 | # (Posterize, 4, 8), # 9
177 | # # (Contrast, 0.1, 1.9), # 10
178 | # (Color, 0.1, 1.9), # 11
179 | # (Brightness, 0.1, 1.9), # 12
180 | # (Sharpness, 0.1, 1.9), # 13
181 | # # (Cutout, 0, 0.2), # 14
182 | # # (SamplePairing(imgs), 0, 0.4), # 15
183 | # ]
184 |
185 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
186 | l = [
187 | (AutoContrast, 0, 1),
188 | (Equalize, 0, 1),
189 | (Invert, 0, 1),
190 | (Rotate, 0, 30),
191 | (Posterize, 0, 4),
192 | (Solarize, 0, 256),
193 | (SolarizeAdd, 0, 110),
194 | (Color, 0.1, 1.9),
195 | (Contrast, 0.1, 1.9),
196 | (Brightness, 0.1, 1.9),
197 | (Sharpness, 0.1, 1.9),
198 | (ShearX, 0., 0.3),
199 | (ShearY, 0., 0.3),
200 | (CutoutAbs, 0, 40),
201 | (TranslateXabs, 0., 100),
202 | (TranslateYabs, 0., 100),
203 | ]
204 |
205 | return l
206 |
207 | def augment_list_svhn(): # 13 operations and their ranges
208 |
209 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
210 | l = [
211 | (AutoContrast, 0, 1),
212 | (Equalize, 0, 1),
213 | (Invert, 0, 1),
214 | (Posterize, 0, 4),
215 | (Solarize, 0, 256),
216 | (SolarizeAdd, 0, 110),
217 | (Color, 0.1, 1.9),
218 | (Contrast, 0.1, 1.9),
219 | (Brightness, 0.1, 1.9),
220 | (Sharpness, 0.1, 1.9),
221 | (ShearX, 0., 0.3),
222 | (ShearY, 0., 0.3),
223 | (CutoutAbs, 0, 40),
224 | ]
225 |
226 | return l
227 |
228 |
229 | class Lighting(object):
230 | """Lighting noise(AlexNet - style PCA - based noise)"""
231 |
232 | def __init__(self, alphastd, eigval, eigvec):
233 | self.alphastd = alphastd
234 | self.eigval = torch.Tensor(eigval)
235 | self.eigvec = torch.Tensor(eigvec)
236 |
237 | def __call__(self, img):
238 | if self.alphastd == 0:
239 | return img
240 |
241 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
242 | rgb = self.eigvec.type_as(img).clone() \
243 | .mul(alpha.view(1, 3).expand(3, 3)) \
244 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
245 | .sum(1).squeeze()
246 |
247 | return img.add(rgb.view(3, 1, 1).expand_as(img))
248 |
249 |
250 | class CutoutDefault(object):
251 | """
252 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
253 | """
254 | def __init__(self, length):
255 | self.length = length
256 |
257 | def __call__(self, img):
258 | h, w = img.size(1), img.size(2)
259 | mask = np.ones((h, w), float)
260 | y = np.random.randint(h)
261 | x = np.random.randint(w)
262 |
263 | y1 = np.clip(y - self.length // 2, 0, h)
264 | y2 = np.clip(y + self.length // 2, 0, h)
265 | x1 = np.clip(x - self.length // 2, 0, w)
266 | x2 = np.clip(x + self.length // 2, 0, w)
267 |
268 | mask[y1: y2, x1: x2] = 0.
269 | mask = torch.from_numpy(mask)
270 | mask = mask.expand_as(img)
271 | img *= mask
272 | return img
273 |
274 |
275 | class RandAugment:
276 | def __init__(self, n, m, args=None):
277 | self.n = n # [1, 2]
278 | self.m = m # [0...30]
279 |
280 | if args is None:
281 | self.augment_list = augment_list()
282 |
283 | elif args.dataset == 'svhn' or args.dataset == 'mnist':
284 | self.augment_list = augment_list_svhn()
285 |
286 | else:
287 | self.augment_list = augment_list()
288 |
289 | def __call__(self, img):
290 | ops = random.choices(self.augment_list, k=self.n)
291 | for op, minval, maxval in ops:
292 | val = (float(self.m) / 30) * float(maxval - minval) + minval
293 | img = op(img, val)
294 |
295 | return img
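296 |
297 | # Usage sketch (illustrative): apply n=2 randomly chosen ops at magnitude m=9.
298 | if __name__ == '__main__':
299 |     img = Image.new('RGB', (224, 224), color=(128, 128, 128))
300 |     out = RandAugment(n=2, m=9)(img)
301 |     print(out.size)  # (224, 224)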
--------------------------------------------------------------------------------
/data/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 | from scipy import io as mat_io
6 |
7 | from torchvision.datasets.folder import default_loader
8 | from torch.utils.data import Dataset
9 |
10 | from data.data_utils import subsample_instances
11 |
12 | car_root = "/home/suachoi/datasets/stanford_car/cars_{}/"
13 | meta_default_path = "/home/suachoi/datasets/stanford_car/devkit/cars_{}.mat"
14 |
15 | class CarsDataset(Dataset):
16 | """
17 | Cars Dataset
18 | """
19 | def __init__(self, train=True, limit=0, data_dir=car_root, transform=None, metas=meta_default_path):
20 |
21 | data_dir = data_dir.format('train') if train else data_dir.format('test')
22 | metas = metas.format('train_annos') if train else metas.format('test_annos_withlabels')
23 |
24 | self.loader = default_loader
25 | self.data_dir = data_dir
26 | self.data = []
27 | self.target = []
28 | self.train = train
29 |
30 | self.transform = transform
31 |
32 | if not isinstance(metas, str):
33 |             raise Exception("Train metas must be a string location!")
34 | labels_meta = mat_io.loadmat(metas)
35 |
36 | for idx, img_ in enumerate(labels_meta['annotations'][0]):
37 | if limit:
38 | if idx > limit:
39 | break
40 |
41 | self.data.append(data_dir + img_[5][0])
42 | self.target.append(img_[4][0][0])
43 |
44 | self.uq_idxs = np.array(range(len(self)))
45 | self.target_transform = None
46 |
47 | def __getitem__(self, idx):
48 |
49 | image = self.loader(self.data[idx])
50 | target = self.target[idx] - 1
51 |
52 | if self.transform is not None:
53 | image = self.transform(image)
54 |
55 | if self.target_transform is not None:
56 | target = self.target_transform(target)
57 |
58 | idx = self.uq_idxs[idx]
59 |
60 | return image, target, idx
61 |
62 | def __len__(self):
63 | return len(self.data)
64 |
65 |
66 | def subsample_dataset(dataset, idxs, absolute=True):
67 | mask = np.zeros(len(dataset)).astype('bool')
68 |     if absolute:
69 | mask[idxs] = True
70 | else:
71 | idxs = set(idxs)
72 | mask = np.array([i in idxs for i in dataset.uq_idxs])
73 |
74 | dataset.data = np.array(dataset.data)[mask].tolist()
75 | dataset.target = np.array(dataset.target)[mask].tolist()
76 | dataset.uq_idxs = dataset.uq_idxs[mask]
77 |
78 | return dataset
79 |
80 |
81 | def subsample_classes(dataset, include_classes=range(160)):
82 |
83 | include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195
84 | cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars]
85 |
86 | target_xform_dict = {}
87 | for i, k in enumerate(include_classes):
88 | target_xform_dict[k] = i
89 |
90 | dataset = subsample_dataset(dataset, cls_idxs)
91 |
92 | return dataset
93 |
94 | def get_train_val_indices(train_dataset, val_split=0.2):
95 |
96 | train_classes = np.unique(train_dataset.target)
97 |
98 | # Get train/test indices
99 | train_idxs = []
100 | val_idxs = []
101 | for cls in train_classes:
102 |
103 | cls_idxs = np.where(train_dataset.target == cls)[0]
104 |
105 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
106 | t_ = [x for x in cls_idxs if x not in v_]
107 |
108 | train_idxs.extend(t_)
109 | val_idxs.extend(v_)
110 |
111 | return train_idxs, val_idxs
112 |
113 |
114 | def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
115 | split_train_val=False, seed=0):
116 |
117 | np.random.seed(seed)
118 |
119 | # Init entire training set
120 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, metas=meta_default_path, train=True)
121 |
122 | # Get labelled training set which has subsampled classes, then subsample some indices from that
123 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
124 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
125 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
126 |
127 | # Split into training and validation sets
128 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
129 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
130 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
131 | val_dataset_labelled_split.transform = test_transform
132 |
133 | # Get unlabelled data
134 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
135 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
136 |
137 | # Get test set for all classes
138 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, metas=meta_default_path, train=False)
139 |
140 | # Either split train into train and val or use test set as val
141 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
142 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
143 |
144 | all_datasets = {
145 | 'train_labelled': train_dataset_labelled,
146 | 'train_unlabelled': train_dataset_unlabelled,
147 | 'val': val_dataset_labelled,
148 | 'test': test_dataset,
149 | }
150 |
151 | return all_datasets
152 |
153 | def get_scars_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80),
154 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1):
155 |
156 | np.random.seed(seed)
157 |
158 | # Init entire training set
159 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, metas=meta_default_path, train=True)
160 |
161 | # Get labelled training set which has subsampled classes, then subsample some indices from that
162 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
163 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
164 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
165 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
166 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False)
167 |
168 | # Split into training and validation sets
169 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split)
170 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
171 | val_dataset_labelled.transform = test_transform
172 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
173 |
174 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split)
175 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs)
176 | val_dataset_unlabelled.transform = test_transform
177 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs)
178 |
179 |
180 |
181 |
182 | # Get test set for all classes
183 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, metas=meta_default_path, train=False)
184 |
185 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}')
186 |
187 | all_datasets = {
188 | 'train_labelled': train_dataset_labelled,
189 | 'train_unlabelled': train_dataset_unlabelled,
190 | 'val': [val_dataset_labelled, val_dataset_unlabelled],
191 | 'test': test_dataset,
192 | }
193 |
194 | return all_datasets
195 |
196 | if __name__ == '__main__':
197 |
198 | x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False)
199 |
200 | print('Printing lens...')
201 | for k, v in x.items():
202 | if v is not None:
203 | print(f'{k}: {len(v)}')
204 |
205 | print('Printing labelled and unlabelled overlap...')
206 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
207 | print('Printing total instances in train...')
208 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
209 |
210 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}')
211 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].target))}')
212 | print(f'Len labelled set: {len(x["train_labelled"])}')
213 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
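
A note on the `absolute` flag in `subsample_dataset` above, which recurs in every dataset file: with `absolute=True` the indices are positions in the current dataset, while with `absolute=False` they are matched against the persistent `uq_idxs` that survive earlier rounds of subsampling. A minimal sketch of the two interpretations, using a hypothetical toy object that carries the same attributes the helper mutates:

    import numpy as np

    class Toy:
        # Hypothetical stand-in with the attributes subsample_dataset touches
        def __init__(self):
            self.data = ['a', 'b', 'c', 'd']
            self.target = [1, 2, 3, 4]
            self.uq_idxs = np.array([10, 11, 12, 13])  # ids left over from earlier subsampling
        def __len__(self):
            return len(self.data)

    def subsample(dataset, idxs, absolute=True):
        # Same logic as the repo's subsample_dataset, renamed for this sketch
        mask = np.zeros(len(dataset)).astype('bool')
        if absolute:
            mask[idxs] = True  # idxs are positions 0..len-1
        else:
            idxs = set(idxs)
            mask = np.array([i in idxs for i in dataset.uq_idxs])  # idxs are uq_idxs values
        dataset.data = np.array(dataset.data)[mask].tolist()
        dataset.target = np.array(dataset.target)[mask].tolist()
        dataset.uq_idxs = dataset.uq_idxs[mask]
        return dataset

    print(subsample(Toy(), [0, 2]).uq_idxs)                    # positional -> [10 12]
    print(subsample(Toy(), [10, 12], absolute=False).uq_idxs)  # by uq_idx  -> [10 12]
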
/data/herbarium_19.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torchvision
4 | import numpy as np
5 | from copy import deepcopy
6 |
7 | from data.data_utils import subsample_instances
8 | from config import herbarium_dataroot
9 |
10 | class HerbariumDataset19(torchvision.datasets.ImageFolder):
11 |
12 | def __init__(self, *args, **kwargs):
13 |
14 | # Process metadata json for training images into a DataFrame
15 | super().__init__(*args, **kwargs)
16 |
17 | self.uq_idxs = np.array(range(len(self)))
18 |
19 | def __getitem__(self, idx):
20 |
21 | img, label = super().__getitem__(idx)
22 | uq_idx = self.uq_idxs[idx]
23 |
24 | return img, label, uq_idx
25 |
26 |
27 | def subsample_dataset(dataset, idxs, absolute=True):
28 | mask = np.zeros(len(dataset)).astype('bool')
29 |     if absolute:
30 | mask[idxs] = True
31 | else:
32 | idxs = set(idxs)
33 | mask = np.array([i in idxs for i in dataset.uq_idxs])
34 |
35 | dataset.samples = np.array(dataset.samples)[mask].tolist()
36 | dataset.targets = np.array(dataset.targets)[mask].tolist()
37 |
38 | dataset.uq_idxs = dataset.uq_idxs[mask]
39 |
40 | dataset.samples = [[x[0], int(x[1])] for x in dataset.samples]
41 | dataset.targets = [int(x) for x in dataset.targets]
42 |
43 | return dataset
44 |
45 |
46 | def subsample_classes(dataset, include_classes=range(250)):
47 |
48 | cls_idxs = [x for x, l in enumerate(dataset.targets) if l in include_classes]
49 |
50 | target_xform_dict = {}
51 | for i, k in enumerate(include_classes):
52 | target_xform_dict[k] = i
53 |
54 | dataset = subsample_dataset(dataset, cls_idxs)
55 |
56 | dataset.target_transform = lambda x: target_xform_dict[x]
57 |
58 | return dataset
59 |
60 |
61 | def get_train_val_indices(train_dataset, val_instances_per_class=5):
62 |
63 | train_classes = list(set(train_dataset.targets))
64 |
65 | # Get train/test indices
66 | train_idxs = []
67 | val_idxs = []
68 | for cls in train_classes:
69 |
70 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
71 |
72 |         # Keep the per-class validation set balanced
73 | val_size = min(cls_idxs.shape[0], val_instances_per_class)
74 | v_ = np.random.choice(cls_idxs, replace=False, size=(val_size,))
75 | t_ = [x for x in cls_idxs if x not in v_]
76 |
77 | train_idxs.extend(t_)
78 | val_idxs.extend(v_)
79 |
80 | return train_idxs, val_idxs
81 |
82 |
83 | def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8,
84 | seed=0, split_train_val=False):
85 |
86 | np.random.seed(seed)
87 |
88 | # Init entire training set
89 | train_dataset = HerbariumDataset19(transform=train_transform,
90 | root=os.path.join(herbarium_dataroot, 'small-train'))
91 |
92 | # Get labelled training set which has subsampled classes, then subsample some indices from that
93 |     # TODO: the unlabelled set is subsampled uniformly at random from the training data, so it will contain many instances of the dominant classes
94 | train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes)
95 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
96 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
97 |
98 | # Split into training and validation sets
99 | if split_train_val:
100 |
101 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled,
102 | val_instances_per_class=5)
103 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
104 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
105 | val_dataset_labelled_split.transform = test_transform
106 |
107 | else:
108 |
109 | train_dataset_labelled_split, val_dataset_labelled_split = None, None
110 |
111 | # Get unlabelled data
112 | unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs)
113 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices)))
114 |
115 | # Get test dataset
116 | test_dataset = HerbariumDataset19(transform=test_transform,
117 | root=os.path.join(herbarium_dataroot, 'small-validation'))
118 |
119 | # Transform dict
120 | unlabelled_classes = list(set(train_dataset.targets) - set(train_classes))
121 | target_xform_dict = {}
122 | for i, k in enumerate(list(train_classes) + unlabelled_classes):
123 | target_xform_dict[k] = i
124 |
125 | test_dataset.target_transform = lambda x: target_xform_dict[x]
126 | train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x]
127 |
128 | # Either split train into train and val or use test set as val
129 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
130 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
131 |
132 | all_datasets = {
133 | 'train_labelled': train_dataset_labelled,
134 | 'train_unlabelled': train_dataset_unlabelled,
135 | 'val': val_dataset_labelled,
136 | 'test': test_dataset,
137 | }
138 |
139 | return all_datasets
140 |
141 | def get_herbarium_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80),
142 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1):
143 |
144 | np.random.seed(seed)
145 |
146 | # Init entire training set
147 | whole_training_set = HerbariumDataset19(transform=train_transform,
148 | root=os.path.join(herbarium_dataroot, 'small-train'))
149 | # Get labelled training set which has subsampled classes, then subsample some indices from that
150 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
151 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
152 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
153 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
154 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False)
155 |
156 | # Split into training and validation sets
157 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_instances_per_class=5)
158 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
159 | val_dataset_labelled.transform = test_transform
160 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
161 |
162 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_instances_per_class=5)
163 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs)
164 | val_dataset_unlabelled.transform = test_transform
165 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs)
166 |
167 |
168 |
169 |
170 | # Get test set for all classes
171 | test_dataset = HerbariumDataset19(transform=test_transform,
172 | root=os.path.join(herbarium_dataroot, 'small-validation'))
173 |
174 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}')
175 |
176 | all_datasets = {
177 | 'train_labelled': train_dataset_labelled,
178 | 'train_unlabelled': train_dataset_unlabelled,
179 | 'val': [val_dataset_labelled, val_dataset_unlabelled],
180 | 'test': test_dataset,
181 | }
182 |
183 | return all_datasets
184 |
185 | if __name__ == '__main__':
186 |
187 | np.random.seed(0)
188 |     train_classes = np.random.choice(range(683), size=683 // 2, replace=False)
189 |
190 | x = get_herbarium_datasets(None, None, train_classes=train_classes,
191 | prop_train_labels=0.5)
192 |
193 | assert set(x['train_unlabelled'].targets) == set(range(683))
194 |
195 | print('Printing lens...')
196 | for k, v in x.items():
197 | if v is not None:
198 | print(f'{k}: {len(v)}')
199 |
200 | print('Printing labelled and unlabelled overlap...')
201 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
202 | print('Printing total instances in train...')
203 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
204 | print('Printing number of labelled classes...')
205 | print(len(set(x['train_labelled'].targets)))
206 | print('Printing total number of classes...')
207 | print(len(set(x['train_unlabelled'].targets)))
208 |
209 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
210 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
211 | print(f'Len labelled set: {len(x["train_labelled"])}')
212 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
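
The target-remapping block near the end of `get_herbarium_datasets` renumbers classes so that the labelled ('Old') classes occupy 0..|Old|-1 and the novel classes follow. A small worked example of that dict construction, with made-up class ids:

    # Toy numbers, not the real Herbarium-19 split
    train_classes = [3, 7]                           # labelled ("Old") classes
    targets_in_train = {1, 3, 5, 7}                  # classes present in the training data
    unlabelled_classes = list(targets_in_train - set(train_classes))

    target_xform_dict = {}
    for i, k in enumerate(list(train_classes) + unlabelled_classes):
        target_xform_dict[k] = i

    # Old classes map to 0 and 1; the novel classes take the remaining slots,
    # e.g. {3: 0, 7: 1, 1: 2, 5: 3} (the order of the novel classes is unspecified)
    print(target_xform_dict)
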
/data/get_datasets.py:
--------------------------------------------------------------------------------
1 | from data.data_utils import MergedDataset
2 |
3 | from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets, get_cifar_100_datasets_with_gcdval
4 | from data.herbarium_19 import get_herbarium_datasets, get_herbarium_datasets_with_gcdval
5 | from data.stanford_cars import get_scars_datasets, get_scars_datasets_with_gcdval
6 | from data.imagenet import get_imagenet_100_datasets, get_imagenet_100_gcd_datasets_with_gcdval
7 | from data.cub import get_cub_datasets, get_cub_datasets_with_gcdval
8 | from data.fgvc_aircraft import get_aircraft_datasets, get_aircraft_datasets_with_gcdval
9 |
10 | from data.cifar import subsample_classes as subsample_dataset_cifar
11 | from data.herbarium_19 import subsample_classes as subsample_dataset_herb
12 | from data.stanford_cars import subsample_classes as subsample_dataset_scars
13 | from data.imagenet import subsample_classes as subsample_dataset_imagenet
14 | from data.cub import subsample_classes as subsample_dataset_cub
15 | from data.fgvc_aircraft import subsample_classes as subsample_dataset_air
16 |
17 | from copy import deepcopy
18 | import pickle
19 | import os
20 |
21 | from config import osr_split_dir
22 |
23 | sub_sample_class_funcs = {
24 | 'cifar10': subsample_dataset_cifar,
25 | 'cifar100': subsample_dataset_cifar,
26 | 'imagenet_100': subsample_dataset_imagenet,
27 | 'herbarium_19': subsample_dataset_herb,
28 | 'cub': subsample_dataset_cub,
29 | 'aircraft': subsample_dataset_air,
30 | 'scars': subsample_dataset_scars
31 | }
32 |
33 | get_dataset_funcs = {
34 | 'cifar10': get_cifar_10_datasets,
35 | 'cifar100': get_cifar_100_datasets,
36 | 'imagenet_100': get_imagenet_100_datasets,
37 | 'herbarium_19': get_herbarium_datasets,
38 | 'cub': get_cub_datasets,
39 | 'aircraft': get_aircraft_datasets,
40 | 'scars': get_scars_datasets
41 | }
42 |
43 | get_dataset_funcs_with_gcdval = {
44 | 'cifar100': get_cifar_100_datasets_with_gcdval,
45 | 'cub': get_cub_datasets_with_gcdval,
46 | 'imagenet_100': get_imagenet_100_gcd_datasets_with_gcdval,
47 | 'aircraft': get_aircraft_datasets_with_gcdval,
48 | 'scars': get_scars_datasets_with_gcdval,
49 | 'herbarium_19': get_herbarium_datasets_with_gcdval
50 | }
51 |
52 |
53 | def get_datasets(dataset_name, train_transform, test_transform, args):
54 |
55 | """
56 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled
57 | test_dataset,
58 | unlabelled_train_examples_test,
59 | datasets
60 | """
61 |
62 |     # Validate the dataset name
63 |     if dataset_name not in get_dataset_funcs:
64 |         raise ValueError(f'Unrecognised dataset: {dataset_name}')
65 |
66 | # Get datasets
67 | get_dataset_f = get_dataset_funcs[dataset_name]
68 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform,
69 | train_classes=args.train_classes,
70 | prop_train_labels=args.prop_train_labels,
71 | split_train_val=False)
72 |
73 | # Set target transforms:
74 | target_transform_dict = {}
75 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)):
76 | target_transform_dict[cls] = i
77 | target_transform = lambda x: target_transform_dict[x]
78 |
79 |     for dataset in datasets.values():
80 | if dataset is not None:
81 | dataset.target_transform = target_transform
82 |
83 | # Train split (labelled and unlabelled classes) for training
84 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']),
85 | unlabelled_dataset=deepcopy(datasets['train_unlabelled']))
86 |
87 | test_dataset = datasets['test']
88 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled'])
89 | unlabelled_train_examples_test.transform = test_transform
90 |
91 | return train_dataset, test_dataset, unlabelled_train_examples_test, datasets
92 |
93 |
94 | def get_datasets_with_gcdval(dataset_name, train_transform, test_transform, args):
95 | """
96 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled
97 | test_dataset,
98 | unlabelled_train_examples_test,
99 | datasets
100 | """
101 |
102 |     # Validate against the gcd-val table; not every dataset has a gcd-val variant (e.g. cifar10 does not)
103 |     if dataset_name not in get_dataset_funcs_with_gcdval:
104 |         raise ValueError(f'Unrecognised dataset: {dataset_name}')
105 |
106 | # Get datasets
107 | get_dataset_f = get_dataset_funcs_with_gcdval[dataset_name]
108 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform,
109 | train_classes=args.train_classes,
110 | prop_train_labels=args.prop_train_labels,
111 | split_train_val=True, val_split=0.1)
112 |
113 | # Set target transforms:
114 | target_transform_dict = {}
115 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)):
116 | target_transform_dict[cls] = i
117 | target_transform = lambda x: target_transform_dict[x]
118 |
119 |     for dataset in datasets.values():
120 | if dataset is not None:
121 | if isinstance(dataset, list):
122 | for d in dataset:
123 | d.target_transform = target_transform
124 | else:
125 | dataset.target_transform = target_transform
126 |
127 | # Train split (labelled and unlabelled classes) for training
128 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']),
129 | unlabelled_dataset=deepcopy(datasets['train_unlabelled']))
130 | val_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['val'][0]),
131 | unlabelled_dataset=deepcopy(datasets['val'][1]))
132 | test_dataset = datasets['test']
133 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled'])
134 | unlabelled_train_examples_test.transform = test_transform
135 |
136 | return train_dataset, test_dataset, unlabelled_train_examples_test, val_dataset, datasets
137 |
138 | def get_class_splits(args):
139 |
140 | # For FGVC datasets, optionally return bespoke splits
141 | if args.dataset_name in ('scars', 'cub', 'aircraft'):
142 | if hasattr(args, 'use_ssb_splits'):
143 | use_ssb_splits = args.use_ssb_splits
144 | else:
145 | use_ssb_splits = False
146 |
147 | # -------------
148 | # GET CLASS SPLITS
149 | # -------------
150 | if args.dataset_name == 'cifar10':
151 |
152 | args.image_size = 32
153 | args.train_classes = range(5)
154 | args.unlabeled_classes = range(5, 10)
155 |
156 | elif args.dataset_name == 'cifar100':
157 |
158 | args.image_size = 32
159 | args.train_classes = range(80)
160 | args.unlabeled_classes = range(80, 100)
161 |
162 | elif args.dataset_name == 'tinyimagenet':
163 |
164 | args.image_size = 64
165 | args.train_classes = range(100)
166 | args.unlabeled_classes = range(100, 200)
167 |
168 | elif args.dataset_name == 'herbarium_19':
169 |
170 | args.image_size = 224
171 | herb_path_splits = os.path.join(osr_split_dir, 'herbarium_19_class_splits.pkl')
172 |
173 | with open(herb_path_splits, 'rb') as handle:
174 | class_splits = pickle.load(handle)
175 |
176 | args.train_classes = class_splits['Old']
177 | args.unlabeled_classes = class_splits['New']
178 |
179 | elif args.dataset_name == 'imagenet_100':
180 |
181 | args.image_size = 224
182 | args.train_classes = range(50)
183 | args.unlabeled_classes = range(50, 100)
184 |
185 | elif args.dataset_name == 'scars':
186 |
187 | args.image_size = 224
188 |
189 | if use_ssb_splits:
190 |
191 | split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl')
192 | with open(split_path, 'rb') as handle:
193 | class_info = pickle.load(handle)
194 |
195 | args.train_classes = class_info['known_classes']
196 | open_set_classes = class_info['unknown_classes']
197 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
198 |
199 | else:
200 |
201 | args.train_classes = range(98)
202 | args.unlabeled_classes = range(98, 196)
203 |
204 | elif args.dataset_name == 'aircraft':
205 |
206 | args.image_size = 224
207 | if use_ssb_splits:
208 |
209 | split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl')
210 | with open(split_path, 'rb') as handle:
211 | class_info = pickle.load(handle)
212 |
213 | args.train_classes = class_info['known_classes']
214 | open_set_classes = class_info['unknown_classes']
215 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
216 |
217 | else:
218 |
219 | args.train_classes = range(50)
220 | args.unlabeled_classes = range(50, 100)
221 |
222 | elif args.dataset_name == 'cub':
223 |
224 | args.image_size = 224
225 |
226 | if use_ssb_splits:
227 |
228 | split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl')
229 | with open(split_path, 'rb') as handle:
230 | class_info = pickle.load(handle)
231 |
232 | args.train_classes = class_info['known_classes']
233 | open_set_classes = class_info['unknown_classes']
234 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
235 |
236 | else:
237 |
238 | args.train_classes = range(100)
239 | args.unlabeled_classes = range(100, 200)
240 |
241 | elif args.dataset_name == 'chinese_traffic_signs':
242 |
243 | args.image_size = 224
244 | args.train_classes = range(28)
245 | args.unlabeled_classes = range(28, 56)
246 |
247 | else:
248 |
249 | raise NotImplementedError
250 |
251 | return args
--------------------------------------------------------------------------------
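
`get_datasets` expects an `args` object carrying `train_classes`, `unlabeled_classes` and `prop_train_labels`; `get_class_splits` fills in the first two from the dataset name. A minimal sketch of wiring them together, assuming the CIFAR-100 data is present at the root configured in `config.py` (transforms are left as `None` purely for illustration; real runs pass the pipelines from `data.augmentations`):

    from argparse import Namespace

    from data.get_datasets import get_class_splits, get_datasets

    args = Namespace(dataset_name='cifar100', prop_train_labels=0.5)
    args = get_class_splits(args)  # sets args.image_size, args.train_classes, args.unlabeled_classes

    train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(
        args.dataset_name, train_transform=None, test_transform=None, args=args)
    print(len(train_dataset), len(test_dataset))
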
/data/cifar.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR10, CIFAR100
2 | from copy import deepcopy
3 | import numpy as np
4 |
5 | from data.data_utils import subsample_instances
6 | from config import cifar_10_root, cifar_100_root
7 |
8 |
9 | class CustomCIFAR10(CIFAR10):
10 |
11 | def __init__(self, *args, **kwargs):
12 |
13 | super(CustomCIFAR10, self).__init__(*args, **kwargs)
14 |
15 | self.uq_idxs = np.array(range(len(self)))
16 |
17 | def __getitem__(self, item):
18 |
19 | img, label = super().__getitem__(item)
20 | uq_idx = self.uq_idxs[item]
21 |
22 | return img, label, uq_idx
23 |
24 | def __len__(self):
25 | return len(self.targets)
26 |
27 |
28 | class CustomCIFAR100(CIFAR100):
29 |
30 | def __init__(self, *args, **kwargs):
31 | super(CustomCIFAR100, self).__init__(*args, **kwargs)
32 |
33 | self.uq_idxs = np.array(range(len(self)))
34 |
35 | def __getitem__(self, item):
36 | img, label = super().__getitem__(item)
37 | uq_idx = self.uq_idxs[item]
38 | return img, label, uq_idx
39 |
40 | def __len__(self):
41 | return len(self.targets)
42 |
43 |
44 | def subsample_dataset(dataset, idxs, absolute=True):
45 | mask = np.zeros(len(dataset)).astype('bool')
46 |     if absolute:
47 | mask[idxs] = True
48 | else:
49 | idxs = set(idxs)
50 | mask = np.array([i in idxs for i in dataset.uq_idxs])
51 |
52 | dataset.data = dataset.data[mask]
53 | dataset.targets = np.array(dataset.targets)[mask].tolist()
54 | dataset.uq_idxs = dataset.uq_idxs[mask]
55 |
56 | return dataset
57 |
58 |
59 | def subsample_classes(dataset, include_classes=(0, 1, 8, 9)):
60 |
61 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
62 |
63 | target_xform_dict = {}
64 | for i, k in enumerate(include_classes):
65 | target_xform_dict[k] = i
66 |
67 | dataset = subsample_dataset(dataset, cls_idxs)
68 |
69 | # dataset.target_transform = lambda x: target_xform_dict[x]
70 |
71 | return dataset
72 |
73 |
74 | def get_train_val_indices(train_dataset, val_split=0.2):
75 |
76 | train_classes = np.unique(train_dataset.targets)
77 |
78 | # Get train/test indices
79 | train_idxs = []
80 | val_idxs = []
81 | for cls in train_classes:
82 |
83 |         cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
84 |
85 |         v_ = np.random.choice(cls_idxs, replace=False, size=(int(val_split * len(cls_idxs)),))
86 | t_ = [x for x in cls_idxs if x not in v_]
87 |
88 | train_idxs.extend(t_)
89 | val_idxs.extend(v_)
90 |
91 | return train_idxs, val_idxs
92 |
93 |
94 | def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9),
95 | prop_train_labels=0.8, split_train_val=False, seed=0):
96 |
97 | np.random.seed(seed)
98 |
99 | # Init entire training set
100 | whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True)
101 |
102 | # Get labelled training set which has subsampled classes, then subsample some indices from that
103 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
104 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
105 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
106 |
107 | # Split into training and validation sets
108 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
109 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
110 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
111 | val_dataset_labelled_split.transform = test_transform
112 |
113 | # Get unlabelled data
114 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
115 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
116 |
117 | # Get test set for all classes
118 | test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False)
119 |
120 | # Either split train into train and val or use test set as val
121 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
122 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
123 |
124 | all_datasets = {
125 | 'train_labelled': train_dataset_labelled,
126 | 'train_unlabelled': train_dataset_unlabelled,
127 | 'val': val_dataset_labelled,
128 | 'test': test_dataset,
129 | }
130 |
131 | return all_datasets
132 |
133 |
134 | def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80),
135 | prop_train_labels=0.8, split_train_val=False, seed=0):
136 |
137 | np.random.seed(seed)
138 |
139 | # Init entire training set
140 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True)
141 |
142 | # Get labelled training set which has subsampled classes, then subsample some indices from that
143 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
144 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
145 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
146 |
147 | # Split into training and validation sets
148 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
149 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
150 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
151 | val_dataset_labelled_split.transform = test_transform
152 |
153 | # Get unlabelled data
154 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
155 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
156 |
157 | # Get test set for all classes
158 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False)
159 |
160 | # Either split train into train and val or use test set as val
161 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
162 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
163 |
164 | all_datasets = {
165 | 'train_labelled': train_dataset_labelled,
166 | 'train_unlabelled': train_dataset_unlabelled,
167 | 'val': val_dataset_labelled,
168 | 'test': test_dataset,
169 | }
170 |
171 | return all_datasets
172 |
173 |
174 | def get_cifar_100_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80),
175 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1):
176 |
177 | np.random.seed(seed)
178 |
179 | # Init entire training set
180 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True)
181 |
182 | # Split whole dataset by train classes
183 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
184 |     # Subsample instances to form the labelled train set (train_l)
185 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
186 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
187 |     # All remaining training samples form the unlabelled set (train_u)
188 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
189 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False)
190 |
191 | # Split trainl into train/val
192 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split)
193 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
194 | val_dataset_labelled.transform = test_transform
195 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
196 |
197 | # Split trainu into train/val
198 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split)
199 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs)
200 | val_dataset_unlabelled.transform = test_transform
201 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs)
202 |
203 |
204 |
205 |
206 | # Get test set for all classes
207 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False)
208 |
209 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}')
210 |
211 | all_datasets = {
212 | 'train_labelled': train_dataset_labelled,
213 | 'train_unlabelled': train_dataset_unlabelled,
214 | 'val': [val_dataset_labelled, val_dataset_unlabelled],
215 | 'test': test_dataset,
216 | }
217 |
218 | return all_datasets
219 |
220 | if __name__ == '__main__':
221 |
222 | x = get_cifar_100_datasets(None, None, split_train_val=False,
223 | train_classes=range(80), prop_train_labels=0.5)
224 |
225 | print('Printing lens...')
226 | for k, v in x.items():
227 | if v is not None:
228 | print(f'{k}: {len(v)}')
229 |
230 | print('Printing labelled and unlabelled overlap...')
231 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
232 | print('Printing total instances in train...')
233 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
234 |
235 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
236 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
237 | print(f'Len labelled set: {len(x["train_labelled"])}')
238 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
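
`get_train_val_indices` above splits each class independently, so the validation set keeps the class balance of the training pool; since `int(val_split * n)` rounds down, very small classes may contribute no validation samples at all. A quick illustration with a mocked dataset exposing only the `targets` attribute the helper reads:

    import numpy as np
    from types import SimpleNamespace

    from data.cifar import get_train_val_indices

    np.random.seed(0)
    # 10 instances of class 0 and 3 of class 1
    mock = SimpleNamespace(targets=np.array([0] * 10 + [1] * 3))
    train_idxs, val_idxs = get_train_val_indices(mock, val_split=0.2)
    print(len(train_idxs), len(val_idxs))  # -> 11 2: class 1 is too small to reach the val set
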
/data/cub.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 |
6 | from torchvision.datasets.folder import default_loader
7 | from torchvision.datasets.utils import download_url
8 | from torch.utils.data import Dataset
9 |
10 | from data.data_utils import subsample_instances
11 | from config import cub_root
12 |
13 |
14 | class CustomCub2011(Dataset):
15 | base_folder = 'CUB_200_2011/images'
16 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
17 | filename = 'CUB_200_2011.tgz'
18 | tgz_md5 = '97eceeb196236b17998738112f37df78'
19 |
20 | def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=True):
21 |
22 | self.root = os.path.expanduser(root)
23 | self.transform = transform
24 | self.target_transform = target_transform
25 |
26 | self.loader = loader
27 | self.train = train
28 |
29 |
30 | if download:
31 | self._download()
32 |
33 | if not self._check_integrity():
34 | raise RuntimeError('Dataset not found or corrupted.' +
35 | ' You can use download=True to download it')
36 |
37 | self.uq_idxs = np.array(range(len(self)))
38 |
39 | def _load_metadata(self):
40 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
41 | names=['img_id', 'filepath'])
42 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
43 | sep=' ', names=['img_id', 'target'])
44 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
45 | sep=' ', names=['img_id', 'is_training_img'])
46 |
47 | data = images.merge(image_class_labels, on='img_id')
48 | self.data = data.merge(train_test_split, on='img_id')
49 |
50 | if self.train:
51 | self.data = self.data[self.data.is_training_img == 1]
52 | else:
53 | self.data = self.data[self.data.is_training_img == 0]
54 |
55 | def _check_integrity(self):
56 | try:
57 | self._load_metadata()
58 | except Exception:
59 | return False
60 |
61 | for index, row in self.data.iterrows():
62 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
63 | if not os.path.isfile(filepath):
64 | print(filepath)
65 | return False
66 | return True
67 |
68 | def _download(self):
69 | import tarfile
70 |
71 | if self._check_integrity():
72 | print('Files already downloaded and verified')
73 | return
74 |
75 | download_url(self.url, self.root, self.filename, self.tgz_md5)
76 |
77 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
78 | tar.extractall(path=self.root)
79 |
80 | def __len__(self):
81 | return len(self.data)
82 |
83 | def __getitem__(self, idx):
84 | sample = self.data.iloc[idx]
85 | path = os.path.join(self.root, self.base_folder, sample.filepath)
86 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0
87 | img = self.loader(path)
88 |
89 | if self.transform is not None:
90 | img = self.transform(img)
91 |
92 | if self.target_transform is not None:
93 | target = self.target_transform(target)
94 |
95 | return img, target, self.uq_idxs[idx]
96 |
97 |
98 | def subsample_dataset(dataset, idxs, absolute=True):
99 | mask = np.zeros(len(dataset)).astype('bool')
100 |     if absolute:
101 | mask[idxs] = True
102 | else:
103 | idxs = set(idxs)
104 | mask = np.array([i in idxs for i in dataset.uq_idxs])
105 |
106 | dataset.data = dataset.data[mask]
107 | dataset.uq_idxs = dataset.uq_idxs[mask]
108 |
109 | return dataset
110 |
111 | def subsample_classes(dataset, include_classes=range(160)):
112 |
113 | include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199
114 | cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub]
115 |
116 | target_xform_dict = {}
117 | for i, k in enumerate(include_classes):
118 | target_xform_dict[k] = i
119 |
120 | dataset = subsample_dataset(dataset, cls_idxs)
121 |
122 | dataset.target_transform = lambda x: target_xform_dict[x]
123 |
124 | return dataset
125 |
126 |
127 | def get_train_val_indices(train_dataset, val_split=0.2):
128 |
129 | train_classes = np.unique(train_dataset.data['target'])
130 |
131 | # Get train/test indices
132 | train_idxs = []
133 | val_idxs = []
134 | for cls in train_classes:
135 |
136 | cls_idxs = np.where(train_dataset.data['target'] == cls)[0]
137 |
138 |         v_ = np.random.choice(cls_idxs, replace=False, size=(int(val_split * len(cls_idxs)),))
139 | t_ = [x for x in cls_idxs if x not in v_]
140 |
141 | train_idxs.extend(t_)
142 | val_idxs.extend(v_)
143 |
144 | return train_idxs, val_idxs
145 |
146 |
147 | def get_cub_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
148 | split_train_val=False, seed=0):
149 |
150 | np.random.seed(seed)
151 |
152 | # Init entire training set
153 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True)
154 |
155 | # Get labelled training set which has subsampled classes, then subsample some indices from that
156 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
157 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
158 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
159 |
160 | # Split into training and validation sets
161 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
162 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
163 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
164 | val_dataset_labelled_split.transform = test_transform
165 |
166 | # Get unlabelled data
167 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
168 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
169 |
170 | # Get test set for all classes
171 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False)
172 |
173 | # Either split train into train and val or use test set as val
174 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
175 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
176 |
177 | all_datasets = {
178 | 'train_labelled': train_dataset_labelled,
179 | 'train_unlabelled': train_dataset_unlabelled,
180 | 'val': val_dataset_labelled,
181 | 'test': test_dataset,
182 | }
183 |
184 | return all_datasets
185 |
186 | def get_cub_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80),
187 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1):
188 |
189 | np.random.seed(seed)
190 |
191 | # Init entire training set
192 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True)
193 |
194 | # Get labelled training set which has subsampled classes, then subsample some indices from that
195 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
196 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
197 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
198 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
199 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False)
200 |
201 | # Split into training and validation sets
202 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split)
203 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
204 | val_dataset_labelled.transform = test_transform
205 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
206 |
207 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split)
208 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs)
209 | val_dataset_unlabelled.transform = test_transform
210 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs)
211 |
212 |
213 |
214 |
215 | # Get test set for all classes
216 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False)
217 |
218 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}')
219 |
220 | all_datasets = {
221 | 'train_labelled': train_dataset_labelled,
222 | 'train_unlabelled': train_dataset_unlabelled,
223 | 'val': [val_dataset_labelled, val_dataset_unlabelled],
224 | 'test': test_dataset,
225 | }
226 |
227 | return all_datasets
228 |
229 | if __name__ == '__main__':
230 |
231 | x = get_cub_datasets(None, None, split_train_val=False,
232 | train_classes=range(100), prop_train_labels=0.5)
233 |
234 | print('Printing lens...')
235 | for k, v in x.items():
236 | if v is not None:
237 | print(f'{k}: {len(v)}')
238 |
239 | print('Printing labelled and unlabelled overlap...')
240 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
241 | print('Printing total instances in train...')
242 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
243 |
244 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}')
245 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}')
246 | print(f'Len labelled set: {len(x["train_labelled"])}')
247 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
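
`CustomCub2011._load_metadata` builds its index by merging three whitespace-separated metadata files on `img_id` and then filtering on the train/test flag. A tiny self-contained sketch of that pandas pattern, with fabricated rows standing in for the real CUB files:

    import pandas as pd

    # Fabricated stand-ins for images.txt, image_class_labels.txt and train_test_split.txt
    images = pd.DataFrame({'img_id': [1, 2], 'filepath': ['001/a.jpg', '002/b.jpg']})
    labels = pd.DataFrame({'img_id': [1, 2], 'target': [1, 2]})
    split = pd.DataFrame({'img_id': [1, 2], 'is_training_img': [1, 0]})

    data = images.merge(labels, on='img_id').merge(split, on='img_id')
    print(data[data.is_training_img == 1])  # the rows that train=True keeps
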
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import numpy as np
3 |
4 | import os
5 |
6 | from copy import deepcopy
7 | from data.data_utils import subsample_instances
8 | from config import imagenet_root
9 |
10 |
11 | class ImageNetBase(torchvision.datasets.ImageFolder):
12 |
13 | def __init__(self, root, transform):
14 |
15 | super(ImageNetBase, self).__init__(root, transform)
16 |
17 | self.uq_idxs = np.array(range(len(self)))
18 |
19 | def __getitem__(self, item):
20 |
21 | img, label = super().__getitem__(item)
22 | uq_idx = self.uq_idxs[item]
23 |
24 | return img, label, uq_idx
25 |
26 | def subsample_dataset(dataset, idxs, absolute=True):
27 | mask = np.zeros(len(dataset)).astype('bool')
28 |     if absolute:
29 | mask[idxs] = True
30 | else:
31 | idxs = set(idxs)
32 | mask = np.array([i in idxs for i in dataset.uq_idxs])
33 |
34 |     dataset.samples = [s for m, s in zip(mask, dataset.samples) if m]
35 |     dataset.targets = [t for m, t in zip(mask, dataset.targets) if m]
36 | dataset.uq_idxs = dataset.uq_idxs[mask]
37 |
38 | return dataset
39 |
40 | def subsample_classes(dataset, include_classes=list(range(1000))):
41 |
42 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
43 |
44 | target_xform_dict = {}
45 | for i, k in enumerate(include_classes):
46 | target_xform_dict[k] = i
47 |
48 | dataset = subsample_dataset(dataset, cls_idxs)
49 | dataset.target_transform = lambda x: target_xform_dict[x]
50 |
51 | return dataset
52 |
53 |
54 | def get_train_val_indices(train_dataset, val_split=0.2):
55 |
56 | train_classes = list(set(train_dataset.targets))
57 |
58 | # Get train/test indices
59 | train_idxs = []
60 | val_idxs = []
61 | for cls in train_classes:
62 |
63 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
64 |
65 |         v_ = np.random.choice(cls_idxs, replace=False, size=(int(val_split * len(cls_idxs)),))
66 | t_ = [x for x in cls_idxs if x not in v_]
67 |
68 | train_idxs.extend(t_)
69 | val_idxs.extend(v_)
70 |
71 | return train_idxs, val_idxs
72 |
73 |
74 | def get_equal_len_datasets(dataset1, dataset2):
75 | """
76 | Make two datasets the same length
77 | """
78 |
79 | if len(dataset1) > len(dataset2):
80 |
81 |         rand_idxs = np.random.choice(range(len(dataset1)), size=(len(dataset2),), replace=False)
82 | subsample_dataset(dataset1, rand_idxs)
83 |
84 | elif len(dataset2) > len(dataset1):
85 |
86 |         rand_idxs = np.random.choice(range(len(dataset2)), size=(len(dataset1),), replace=False)
87 | subsample_dataset(dataset2, rand_idxs)
88 |
89 | return dataset1, dataset2
90 |
91 |
92 | def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80),
93 | prop_train_labels=0.8, split_train_val=False, seed=0):
94 |
95 | np.random.seed(seed)
96 |
97 | # Subsample imagenet dataset initially to include 100 classes
98 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False)
99 | subsampled_100_classes = np.sort(subsampled_100_classes)
100 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}')
101 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))}
102 |
103 | # Init entire training set
104 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
105 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes)
106 |
107 | # Reset dataset
108 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples]
109 | whole_training_set.targets = [s[1] for s in whole_training_set.samples]
110 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set)))
111 | whole_training_set.target_transform = None
112 |
113 | # Get labelled training set which has subsampled classes, then subsample some indices from that
114 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
115 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
116 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
117 |
118 | # Split into training and validation sets
119 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
120 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
121 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
122 | val_dataset_labelled_split.transform = test_transform
123 |
124 | # Get unlabelled data
125 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
126 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
127 |
128 | # Get test set for all classes
129 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
130 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes)
131 | # Reset test set
132 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples]
133 | test_dataset.targets = [s[1] for s in test_dataset.samples]
134 | test_dataset.uq_idxs = np.array(range(len(test_dataset)))
135 | test_dataset.target_transform = None
136 |
137 | #test_dataset = subsample_classes(test_dataset, include_classes=train_classes)
138 |
139 | # Either split train into train and val or use test set as val
140 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
141 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
142 |
143 | all_datasets = {
144 | 'train_labelled': train_dataset_labelled,
145 | 'train_unlabelled': train_dataset_unlabelled,
146 | 'val': val_dataset_labelled,
147 | 'test': test_dataset,
148 | }
149 |
150 | return all_datasets
151 |
152 | def get_imagenet_100_gcd_datasets_with_gcdval(train_transform, test_transform, train_classes=range(50),
153 | prop_train_labels=0.5, split_train_val=True, seed=0, val_split=0.1):
154 |
155 | np.random.seed(seed)
156 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False)
157 | subsampled_100_classes = np.sort(subsampled_100_classes)
158 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}')
159 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))}
160 |
161 | # Init entire training set
162 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
163 | # NOTE: use GCD paper IN split
164 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes)
165 | # Reset dataset
166 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples]
167 | whole_training_set.targets = [s[1] for s in whole_training_set.samples]
168 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set)))
169 | whole_training_set.target_transform = None
170 |
171 | # Get labelled training set which has subsampled classes, then subsample some indices from that
172 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
173 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
174 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
175 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
176 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False)
177 |
178 | # Split into training and validation sets
179 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split)
180 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
181 | val_dataset_labelled.transform = test_transform
182 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
183 |
184 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split)
185 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs)
186 | val_dataset_unlabelled.transform = test_transform
187 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs)
188 |
189 |
190 |
191 |
192 | # Get test set for all classes
193 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
194 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes)
195 |
196 | # Reset test set
197 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples]
198 | test_dataset.targets = [s[1] for s in test_dataset.samples]
199 | test_dataset.uq_idxs = np.array(range(len(test_dataset)))
200 | test_dataset.target_transform = None
201 |
202 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}')
203 |
204 | all_datasets = {
205 | 'train_labelled': train_dataset_labelled,
206 | 'train_unlabelled': train_dataset_unlabelled,
207 | 'val': [val_dataset_labelled, val_dataset_unlabelled],
208 | 'test': test_dataset,
209 | }
210 |
211 | return all_datasets
212 |
213 |
214 | if __name__ == '__main__':
215 |
216 | x = get_imagenet_100_datasets(None, None, split_train_val=False,
217 | train_classes=range(50), prop_train_labels=0.5)
218 |
219 | print('Printing lens...')
220 | for k, v in x.items():
221 | if v is not None:
222 | print(f'{k}: {len(v)}')
223 |
224 | print('Printing labelled and unlabelled overlap...')
225 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
226 | print('Printing total instances in train...')
227 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
228 |
229 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
230 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
231 | print(f'Len labelled set: {len(x["train_labelled"])}')
232 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
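
Both ImageNet-100 builders draw the same 100 classes because NumPy is reseeded with `seed` before `np.random.choice`; `cls_map` then renumbers those ImageNet-1k class ids down to a contiguous 0-99 range. A compact sketch of the remap step on fabricated `samples` (5 classes instead of 100 for brevity):

    import numpy as np

    np.random.seed(0)
    subsampled = np.sort(np.random.choice(range(1000), size=(5,), replace=False))
    cls_map = {i: j for i, j in zip(subsampled, range(len(subsampled)))}

    # Fabricated (path, original_class) pairs as an ImageFolder would hold them
    samples = [('img0.jpg', subsampled[0]), ('img1.jpg', subsampled[3])]
    samples = [(s[0], cls_map[s[1]]) for s in samples]
    targets = [s[1] for s in samples]
    print(samples, targets)  # classes renumbered into the 0..4 range
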
/methods/contrastive_meanshift_training.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | import wandb
8 |
9 | from torch.nn import functional as F
10 | from torch.utils.data import DataLoader
11 | from torch.optim import SGD, lr_scheduler
12 | from tqdm import tqdm
13 |
14 | from project_utils.cluster_utils import AverageMeter
15 | from data.augmentations import get_transform
16 | from data.get_datasets import get_datasets, get_datasets_with_gcdval, get_class_splits
17 | from project_utils.cluster_and_log_utils import *
18 | from project_utils.general_utils import init_experiment, str2bool
19 |
20 | from models.dino import *
21 | from methods.loss import *
22 |
23 | import warnings
24 | warnings.filterwarnings("ignore", category=DeprecationWarning)
25 |
26 | class ContrastiveLearningViewGenerator(object):
27 | """Take two random crops of one image as the query and key."""
28 |
29 | def __init__(self, base_transform, n_views=2):
30 | self.base_transform = base_transform
31 | self.n_views = n_views
32 |
33 | def __call__(self, x):
34 |         return [self.base_transform(x) for _ in range(self.n_views)]
35 |
36 |
37 | def train(model, train_loader, test_loader, train_loader_memory, args):
38 |
39 | optimizer = SGD(list(model.module.parameters()), lr=args.lr, momentum=args.momentum,
40 | weight_decay=args.weight_decay)
41 | exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(
42 | optimizer,
43 | T_max=args.epochs,
44 | eta_min=args.lr * args.eta_min,
45 | )
46 |
47 | sup_con_crit = SupConLoss()
48 | unsup_con_crit = ConMeanShiftLoss(args)
49 |
50 | best_agglo_score = 0
51 | best_agglo_img_score = 0
52 | best_img_old_score = 0
53 |
54 | for epoch in range(args.epochs):
55 | loss_record = AverageMeter()
56 | with torch.no_grad():
57 | model.eval()
58 | all_feats = []
59 | for batch_idx, batch in enumerate(tqdm(train_loader_memory)):
60 | images, class_labels, uq_idxs, mask_lab = batch
61 | images = torch.cat(images, dim=0).to(device)
62 |
63 | features = model(images)
64 | all_feats.append(features.detach().cpu())
65 | all_feats = torch.cat(all_feats)
66 |
67 | model.train()
68 | for batch_idx, batch in enumerate(tqdm(train_loader)):
69 | images, class_labels, uq_idxs, mask_lab = batch
70 | class_labels, mask_lab = class_labels.to(device), mask_lab[:, 0].to(device).bool()
71 | images = torch.cat(images, dim=0).to(device)
72 |
73 | features = model(images)
74 |
75 |             classwise_sim = torch.einsum('b d, n d -> b n', features.cpu(), all_feats)  # similarity of batch features to the feature memory
76 |
77 |             _, indices = classwise_sim.topk(k=args.k+1, dim=-1, largest=True, sorted=True)
78 |             indices = indices[:, 1:]  # drop the top-1 neighbour (typically the sample itself)
79 |             knn_emb = torch.mean(all_feats[indices, :].view(-1, args.k, args.feat_dim), dim=1).to(device)  # mean-shifted embedding: mean of the k nearest neighbours
80 |
81 | if args.contrast_unlabel_only:
82 | # Contrastive loss only on unlabelled instances
83 | f1, f2 = [f[~mask_lab] for f in features.chunk(2)]
84 | con_feats = torch.cat([f1, f2], dim=0)
85 | f3, f4 = [f[~mask_lab] for f in knn_emb.chunk(2)]
86 | con_knn_emb = torch.cat([f3, f4], dim=0)
87 | con_uq_idxs = uq_idxs[~mask_lab]
88 |
89 | else:
90 | # Contrastive loss for all examples
91 | con_feats = features
92 | con_knn_emb = knn_emb
93 | con_uq_idxs = uq_idxs
94 |
95 | unsup_con_loss = unsup_con_crit(con_knn_emb, con_feats)
96 |
97 | # Supervised contrastive loss
98 | f1, f2 = [f[mask_lab] for f in features.chunk(2)]
99 | sup_con_feats = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
100 | sup_con_labels = class_labels[mask_lab]
101 | sup_con_loss = sup_con_crit(sup_con_feats, labels=sup_con_labels)
102 |
103 | # Total loss
104 | loss = (1 - args.sup_con_weight) * unsup_con_loss + args.sup_con_weight * (sup_con_loss)
105 |
106 | loss_record.update(loss.item(), class_labels.size(0))
107 | optimizer.zero_grad()
108 | loss.backward()
109 | optimizer.step()
110 |
111 | print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
112 |
113 | with torch.no_grad():
114 | model.eval()
115 | all_feats_val = []
116 | targets = np.array([])
117 | mask = np.array([])
118 |
119 | for batch_idx, batch in enumerate(tqdm(test_loader)):
120 | images, label, _ = batch[:3]
121 | images = images.cuda()
122 |
123 | features = model(images)
124 | all_feats_val.append(features.detach().cpu().numpy())
125 | targets = np.append(targets, label.cpu().numpy())
126 |                 mask = np.append(mask, np.isin(label.cpu().numpy(),
127 |                                                np.arange(len(args.train_classes))))
128 |
129 | # -----------------
130 | # Clustering
131 | # -----------------
132 | img_all_acc_test, img_old_acc_test, img_new_acc_test, img_agglo_score, estimated_k = test_agglo(epoch, all_feats_val, targets, mask, "Test/ACC", args)
133 | if args.wandb:
134 | wandb.log({ 'test/all': img_all_acc_test, 'test/base': img_old_acc_test, 'test/novel': img_new_acc_test,
135 | 'score/agglo': img_agglo_score, 'score/estimated_k': estimated_k}, step=epoch)
136 |
137 | print('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(img_all_acc_test, img_old_acc_test, img_new_acc_test))
138 |
139 | # Step schedule
140 | exp_lr_scheduler.step()
141 | torch.save(model.state_dict(), args.model_path)
142 |
143 | if img_agglo_score > best_agglo_img_score:
144 |             torch.save({
145 |                 'k': estimated_k,
146 |                 'model_state_dict': model.state_dict()},
147 |                 args.model_path[:-3] + '_best.pt')
148 | best_agglo_img_score = img_agglo_score
149 |
150 |
151 | if __name__ == "__main__":
152 |
153 | parser = argparse.ArgumentParser(
154 | description='cluster',
155 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
156 | parser.add_argument('--batch_size', default=128, type=int)
157 | parser.add_argument('--num_workers', default=16, type=int)
158 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2'])
159 |
160 | parser.add_argument('--warmup_model_dir', type=str, default=None)
161 | parser.add_argument('--exp_root', type=str, default='/home/suachoi/CMS/log')
162 | parser.add_argument('--pretrain_path', type=str, default='/home/suachoi/CMS/models')
163 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, scars')
164 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
165 | parser.add_argument('--use_ssb_splits', type=bool, default=True)
166 |
167 | parser.add_argument('--lr', type=float, default=0.01)
168 | parser.add_argument('--eta_min', type=float, default=1e-3)
169 | parser.add_argument('--epochs', default=200, type=int)
170 | parser.add_argument('--momentum', type=float, default=0.9)
171 | parser.add_argument('--weight_decay', type=float, default=5e-5)
172 | parser.add_argument('--transform', type=str, default='imagenet')
173 | parser.add_argument('--seed', default=1, type=int)
174 | parser.add_argument('--n_views', default=2, type=int)
175 |     parser.add_argument('--contrast_unlabel_only', type=str2bool, default=False)  # str2bool for the same reason as --use_ssb_splits
176 | parser.add_argument('--sup_con_weight', type=float, default=0.35)
177 | parser.add_argument('--temperature', type=float, default=0.5)
178 |
179 | parser.add_argument('--alpha', type=float, default=0.5)
180 | parser.add_argument('--k', default=8, type=int)
181 | parser.add_argument('--inductive', action='store_true')
182 | parser.add_argument('--wandb', action='store_true', help='Flag to log at wandb')
183 |
184 | # ----------------------
185 | # INIT
186 | # ----------------------
187 | args = parser.parse_args()
188 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
189 | args = get_class_splits(args)
190 |
191 | args.feat_dim = 768
192 | args.interpolation = 3
193 | args.crop_pct = 0.875
194 | args.image_size = 224
195 | args.num_mlp_layers = 3
196 | args.num_labeled_classes = len(args.train_classes)
197 | args.num_unlabeled_classes = len(args.unlabeled_classes)
198 |
199 | init_experiment(args, runner_name=['cms'])
200 | print(f'Using evaluation function {args.eval_funcs[0]} to print results')
201 |
202 | if args.wandb:
203 | wandb.init(project='CMS')
204 | wandb.config.update(args)
205 |
206 |
207 | # --------------------
208 | # MODEL
209 | # --------------------
210 | model = DINO(args)
211 | model = nn.DataParallel(model).to(device)
212 |
213 | # --------------------
214 | # CONTRASTIVE TRANSFORM
215 | # --------------------
216 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
217 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
218 |
219 | # --------------------
220 | # DATASETS
221 | # --------------------
222 | if args.inductive:
223 | train_dataset, test_dataset, unlabelled_train_examples_test, val_datasets, datasets = get_datasets_with_gcdval(args.dataset_name, train_transform, test_transform, args)
224 | else:
225 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name, train_transform, test_transform, args)
226 |
227 |
228 | # --------------------
229 | # SAMPLER
230 |     # Sampler which balances labelled and unlabelled examples in each batch (weight 1 per labelled item, label_len/unlabelled_len per unlabelled item, so each pool gets equal total weight)
231 | # --------------------
232 | label_len = len(train_dataset.labelled_dataset)
233 | unlabelled_len = len(train_dataset.unlabelled_dataset)
234 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
235 | sample_weights = torch.DoubleTensor(sample_weights)
236 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset))
237 |
238 | # --------------------
239 | # DATALOADERS
240 | # --------------------
241 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, sampler=sampler, drop_last=True)
242 | train_loader_memory = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, drop_last=False)
243 | if args.inductive:
244 | test_loader_labelled = DataLoader(val_datasets, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
245 | else:
246 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
247 |
248 |
249 | # ----------------------
250 | # TRAIN
251 | # ----------------------
252 | train(model, train_loader, test_loader_labelled, train_loader_memory, args)
--------------------------------------------------------------------------------
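A minimal, self-contained sketch of the kNN mean-shift step in train() above (the classwise_sim/topk block): each batch embedding is paired with the mean of its k nearest neighbours in the memory bank, and that mean serves as the positive in the unsupervised contrastive loss. All tensors below are random toy stand-ins, not repo data.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    k, feat_dim = 4, 16
    memory = F.normalize(torch.randn(100, feat_dim), dim=-1)  # stand-in for all_feats
    feats = F.normalize(torch.randn(8, feat_dim), dim=-1)     # stand-in for one batch

    # Cosine similarity of the batch against the bank (all vectors are unit-norm).
    sim = torch.einsum('b d, n d -> b n', feats, memory)

    # Take k+1 neighbours and drop the nearest, mirroring the self-exclusion in
    # the training loop, where each image's own embedding sits in the bank.
    _, idx = sim.topk(k=k + 1, dim=-1, largest=True, sorted=True)
    idx = idx[:, 1:]

    # Mean-shifted positives: one averaged neighbour embedding per batch item.
    knn_emb = memory[idx].mean(dim=1)
    print(knn_emb.shape)  # torch.Size([8, 16])

In the actual loop the batch holds two augmented views concatenated along the batch dimension, so knn_emb is chunked the same way as features before entering the loss.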
/models/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 | from torch.nn.init import trunc_normal_
25 |
26 | def drop_path(x, drop_prob: float = 0., training: bool = False):
27 | if drop_prob == 0. or not training:
28 | return x
29 | keep_prob = 1 - drop_prob
30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
32 | random_tensor.floor_() # binarize
33 | output = x.div(keep_prob) * random_tensor
34 | return output
35 |
36 |
37 | class DropPath(nn.Module):
38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
39 | """
40 | def __init__(self, drop_prob=None):
41 | super(DropPath, self).__init__()
42 | self.drop_prob = drop_prob
43 |
44 | def forward(self, x):
45 | return drop_path(x, self.drop_prob, self.training)
46 |
47 |
48 | class Mlp(nn.Module):
49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50 | super().__init__()
51 | out_features = out_features or in_features
52 | hidden_features = hidden_features or in_features
53 | self.fc1 = nn.Linear(in_features, hidden_features)
54 | self.act = act_layer()
55 | self.fc2 = nn.Linear(hidden_features, out_features)
56 | self.drop = nn.Dropout(drop)
57 |
58 | def forward(self, x):
59 | x = self.fc1(x)
60 | x = self.act(x)
61 | x = self.drop(x)
62 | x = self.fc2(x)
63 | x = self.drop(x)
64 | return x
65 |
66 |
67 | class Attention(nn.Module):
68 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
69 | super().__init__()
70 | self.num_heads = num_heads
71 | head_dim = dim // num_heads
72 | self.scale = qk_scale or head_dim ** -0.5
73 |
74 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
75 | self.attn_drop = nn.Dropout(attn_drop)
76 | self.proj = nn.Linear(dim, dim)
77 | self.proj_drop = nn.Dropout(proj_drop)
78 |
79 | def forward(self, x):
80 | B, N, C = x.shape
81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
82 | q, k, v = qkv[0], qkv[1], qkv[2]
83 |
84 | attn = (q @ k.transpose(-2, -1)) * self.scale
85 | attn = attn.softmax(dim=-1)
86 | attn = self.attn_drop(attn)
87 |
88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
89 | x = self.proj(x)
90 | x = self.proj_drop(x)
91 | return x, attn
92 |
93 |
94 | class Block(nn.Module):
95 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
96 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
97 | super().__init__()
98 | self.norm1 = norm_layer(dim)
99 | self.attn = Attention(
100 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
101 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
102 | self.norm2 = norm_layer(dim)
103 | mlp_hidden_dim = int(dim * mlp_ratio)
104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
105 |
106 | def forward(self, x, return_attention=False):
107 | y, attn = self.attn(self.norm1(x))
108 | x = x + self.drop_path(y)
109 | x = x + self.drop_path(self.mlp(self.norm2(x)))
110 |
111 | if return_attention:
112 | return x, attn
113 | else:
114 | return x
115 |
116 |
117 | class PatchEmbed(nn.Module):
118 | """ Image to Patch Embedding
119 | """
120 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
121 | super().__init__()
122 | num_patches = (img_size // patch_size) * (img_size // patch_size)
123 | self.img_size = img_size
124 | self.patch_size = patch_size
125 | self.num_patches = num_patches
126 |
127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
128 |
129 | def forward(self, x):
130 | B, C, H, W = x.shape
131 | x = self.proj(x).flatten(2).transpose(1, 2)
132 | return x
133 |
134 |
135 | class VisionTransformer(nn.Module):
136 | """ Vision Transformer """
137 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
138 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
139 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
140 | super().__init__()
141 | self.num_features = self.embed_dim = embed_dim
142 |
143 | self.patch_embed = PatchEmbed(
144 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
145 | num_patches = self.patch_embed.num_patches
146 |
147 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
148 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
149 | self.pos_drop = nn.Dropout(p=drop_rate)
150 |
151 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
152 | self.blocks = nn.ModuleList([
153 | Block(
154 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
155 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
156 | for i in range(depth)])
157 | self.norm = norm_layer(embed_dim)
158 |
159 | # Classifier head
160 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
161 |
162 | trunc_normal_(self.pos_embed, std=.02)
163 | trunc_normal_(self.cls_token, std=.02)
164 | self.apply(self._init_weights)
165 |
166 | def _init_weights(self, m):
167 | if isinstance(m, nn.Linear):
168 | trunc_normal_(m.weight, std=.02)
169 | if isinstance(m, nn.Linear) and m.bias is not None:
170 | nn.init.constant_(m.bias, 0)
171 | elif isinstance(m, nn.LayerNorm):
172 | nn.init.constant_(m.bias, 0)
173 | nn.init.constant_(m.weight, 1.0)
174 |
175 | def interpolate_pos_encoding(self, x, w, h):
176 | npatch = x.shape[1] - 1
177 | N = self.pos_embed.shape[1] - 1
178 | if npatch == N and w == h:
179 | return self.pos_embed
180 | class_pos_embed = self.pos_embed[:, 0]
181 | patch_pos_embed = self.pos_embed[:, 1:]
182 | dim = x.shape[-1]
183 | w0 = w // self.patch_embed.patch_size
184 | h0 = h // self.patch_embed.patch_size
185 | # we add a small number to avoid floating point error in the interpolation
186 | # see discussion at https://github.com/facebookresearch/dino/issues/8
187 | w0, h0 = w0 + 0.1, h0 + 0.1
188 | patch_pos_embed = nn.functional.interpolate(
189 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
190 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
191 | mode='bicubic',
192 | )
193 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
194 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
195 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
196 |
197 | def prepare_tokens(self, x):
198 | B, nc, w, h = x.shape
199 | x = self.patch_embed(x) # patch linear embedding
200 |
201 | # add the [CLS] token to the embed patch tokens
202 | cls_tokens = self.cls_token.expand(B, -1, -1)
203 | x = torch.cat((cls_tokens, x), dim=1)
204 |
205 | # add positional encoding to each token
206 | x = x + self.interpolate_pos_encoding(x, w, h)
207 |
208 | return self.pos_drop(x)
209 |
210 | def forward(self, x, return_all_patches=False):
211 | x = self.prepare_tokens(x)
212 | for blk in self.blocks:
213 | x = blk(x)
214 | x = self.norm(x)
215 |
216 | if return_all_patches:
217 | return x
218 | else:
219 | return x[:, 0]
220 |
221 | def get_last_selfattention(self, x):
222 | x = self.prepare_tokens(x)
223 | for i, blk in enumerate(self.blocks):
224 | if i < len(self.blocks) - 1:
225 | x = blk(x)
226 | else:
227 | # return attention of the last block
228 | x, attn = blk(x, return_attention=True)
229 | x = self.norm(x)
230 | return x, attn
231 |
232 | def get_intermediate_layers(self, x, n=1):
233 | x = self.prepare_tokens(x)
234 | # we return the output tokens from the `n` last blocks
235 | output = []
236 | for i, blk in enumerate(self.blocks):
237 | x = blk(x)
238 | if len(self.blocks) - i <= n:
239 | output.append(self.norm(x))
240 | return output
241 |
242 |
243 | def vit_tiny(patch_size=16, **kwargs):
244 | model = VisionTransformer(
245 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
246 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
247 | return model
248 |
249 |
250 | def vit_small(patch_size=16, **kwargs):
251 | model = VisionTransformer(
252 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
253 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
254 | return model
255 |
256 |
257 | def vit_base(patch_size=16, **kwargs):
258 | model = VisionTransformer(
259 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
260 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
261 | return model
262 |
263 |
264 | class DINOHead(nn.Module):
265 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
266 | super().__init__()
267 | nlayers = max(nlayers, 1)
268 | if nlayers == 1:
269 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
270 | else:
271 | layers = [nn.Linear(in_dim, hidden_dim)]
272 | if use_bn:
273 | layers.append(nn.BatchNorm1d(hidden_dim))
274 | layers.append(nn.GELU())
275 | for _ in range(nlayers - 2):
276 | layers.append(nn.Linear(hidden_dim, hidden_dim))
277 | if use_bn:
278 | layers.append(nn.BatchNorm1d(hidden_dim))
279 | layers.append(nn.GELU())
280 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
281 | self.mlp = nn.Sequential(*layers)
282 | self.apply(self._init_weights)
283 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
284 | self.last_layer.weight_g.data.fill_(1)
285 | if norm_last_layer:
286 | self.last_layer.weight_g.requires_grad = False
287 |
288 | def _init_weights(self, m):
289 | if isinstance(m, nn.Linear):
290 | trunc_normal_(m.weight, std=.02)
291 | if isinstance(m, nn.Linear) and m.bias is not None:
292 | nn.init.constant_(m.bias, 0)
293 |
294 | def forward(self, x):
295 | x = self.mlp(x)
296 | x = nn.functional.normalize(x, dim=-1, p=2)
297 | x = self.last_layer(x)
298 | return x
299 |
300 |
301 | class VisionTransformerWithLinear(nn.Module):
302 |
303 | def __init__(self, base_vit, num_classes=200):
304 |
305 | super().__init__()
306 |
307 | self.base_vit = base_vit
308 | self.fc = nn.Linear(768, num_classes)
309 |
310 | def forward(self, x, return_features=False):
311 |
312 | features = self.base_vit(x)
313 | features = torch.nn.functional.normalize(features, dim=-1)
314 | logits = self.fc(features)
315 |
316 | if return_features:
317 | return logits, features
318 | else:
319 | return logits
320 |
321 | @torch.no_grad()
322 | def normalize_prototypes(self):
323 | w = self.fc.weight.data.clone()
324 | w = torch.nn.functional.normalize(w, dim=1, p=2)
325 | self.fc.weight.copy_(w)
--------------------------------------------------------------------------------
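A quick shape check for the ViT above; a sketch assuming it is run from the repo root, with randomly initialised weights and a dummy batch. The CLS embedding returned by forward() is what the rest of the codebase consumes as the image feature.

    import torch
    from models.vision_transformer import vit_base

    model = vit_base(patch_size=16)
    model.eval()

    x = torch.randn(2, 3, 224, 224)                    # dummy batch
    with torch.no_grad():
        cls = model(x)                                 # (2, 768): CLS token after the final norm
        tokens = model(x, return_all_patches=True)     # (2, 197, 768): CLS + 14x14 patch tokens
        last2 = model.get_intermediate_layers(x, n=2)  # last 2 blocks' tokens, each (2, 197, 768)
    print(cls.shape, tokens.shape, len(last2))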
/data/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 |
6 | from torchvision.datasets.folder import default_loader
7 | from torch.utils.data import Dataset
8 |
9 | from data.data_utils import subsample_instances
10 | from config import aircraft_root
11 |
12 | def make_dataset(dir, image_ids, targets):
13 | assert(len(image_ids) == len(targets))
14 | images = []
15 | dir = os.path.expanduser(dir)
16 | for i in range(len(image_ids)):
17 | item = (os.path.join(dir, 'data', 'images',
18 | '%s.jpg' % image_ids[i]), targets[i])
19 | images.append(item)
20 | return images
21 |
22 |
23 | def find_classes(classes_file):
24 |
25 | # read classes file, separating out image IDs and class names
26 | image_ids = []
27 | targets = []
28 | f = open(classes_file, 'r')
29 | for line in f:
30 | split_line = line.split(' ')
31 | image_ids.append(split_line[0])
32 | targets.append(' '.join(split_line[1:]))
33 | f.close()
34 |
35 | # index class names
36 | classes = np.unique(targets)
37 | class_to_idx = {classes[i]: i for i in range(len(classes))}
38 | targets = [class_to_idx[c] for c in targets]
39 |
40 | return (image_ids, targets, classes, class_to_idx)
41 |
42 |
43 | class FGVCAircraft(Dataset):
44 |
45 | """`FGVC-Aircraft `_ Dataset.
46 |
47 | Args:
48 | root (string): Root directory path to dataset.
49 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
50 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
51 | transform (callable, optional): A function/transform that takes in a PIL image
52 | and returns a transformed version. E.g. ``transforms.RandomCrop``
53 | target_transform (callable, optional): A function/transform that takes in the
54 | target and transforms it.
55 | loader (callable, optional): A function to load an image given its path.
56 | download (bool, optional): If true, downloads the dataset from the internet and
57 | puts it in the root directory. If dataset is already downloaded, it is not
58 | downloaded again.
59 | """
60 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
61 | class_types = ('variant', 'family', 'manufacturer')
62 | splits = ('train', 'val', 'trainval', 'test')
63 |
64 | def __init__(self, root, class_type='variant', split='train', transform=None,
65 | target_transform=None, loader=default_loader, download=False):
66 | if split not in self.splits:
67 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
68 | split, ', '.join(self.splits),
69 | ))
70 | if class_type not in self.class_types:
71 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
72 | class_type, ', '.join(self.class_types),
73 | ))
74 | self.root = os.path.expanduser(root)
75 | self.class_type = class_type
76 | self.split = split
77 | self.classes_file = os.path.join(self.root, 'data',
78 | 'images_%s_%s.txt' % (self.class_type, self.split))
79 |
80 | if download:
81 | self.download()
82 |
83 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
84 | samples = make_dataset(self.root, image_ids, targets)
85 |
86 | self.transform = transform
87 | self.target_transform = target_transform
88 | self.loader = loader
89 |
90 | self.samples = samples
91 | self.classes = classes
92 | self.class_to_idx = class_to_idx
93 |         self.train = (split == 'train')
94 |
95 | self.uq_idxs = np.array(range(len(self)))
96 |
97 | def __getitem__(self, index):
98 | """
99 | Args:
100 | index (int): Index
101 |
102 | Returns:
103 | tuple: (sample, target) where target is class_index of the target class.
104 | """
105 |
106 | path, target = self.samples[index]
107 | sample = self.loader(path)
108 | if self.transform is not None:
109 | sample = self.transform(sample)
110 | if self.target_transform is not None:
111 | target = self.target_transform(target)
112 |
113 | return sample, target, self.uq_idxs[index]
114 |
115 | def __len__(self):
116 | return len(self.samples)
117 |
118 | def __repr__(self):
119 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
120 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
121 | fmt_str += ' Root Location: {}\n'.format(self.root)
122 | tmp = ' Transforms (if any): '
123 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
124 | tmp = ' Target Transforms (if any): '
125 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
126 | return fmt_str
127 |
128 | def _check_exists(self):
129 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
130 | os.path.exists(self.classes_file)
131 |
132 | def download(self):
133 | """Download the FGVC-Aircraft data if it doesn't exist already."""
134 | from six.moves import urllib
135 | import tarfile
136 |
137 | if self._check_exists():
138 | return
139 |
140 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
141 | print('Downloading %s ... (may take a few minutes)' % self.url)
142 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
143 | tar_name = self.url.rpartition('/')[-1]
144 | tar_path = os.path.join(parent_dir, tar_name)
145 | data = urllib.request.urlopen(self.url)
146 |
147 | # download .tar.gz file
148 | with open(tar_path, 'wb') as f:
149 | f.write(data.read())
150 |
151 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
152 |         data_folder = tar_path[:-len('.tar.gz')]  # str.strip() removes a character set, not a suffix
153 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
154 | tar = tarfile.open(tar_path)
155 | tar.extractall(parent_dir)
156 |
157 | # if necessary, rename data folder to self.root
158 | if not os.path.samefile(data_folder, self.root):
159 | print('Renaming %s to %s ...' % (data_folder, self.root))
160 | os.rename(data_folder, self.root)
161 |
162 | # delete .tar.gz file
163 | print('Deleting %s ...' % tar_path)
164 | os.remove(tar_path)
165 |
166 | print('Done!')
167 |
168 |
169 | def subsample_dataset(dataset, idxs, absolute=True):
170 |
171 | mask = np.zeros(len(dataset)).astype('bool')
172 |     if absolute:  # idxs address positions in the current (possibly already subsampled) dataset
173 |         mask[idxs] = True
174 |     else:  # idxs are original uq_idxs, matched by membership
175 |         idxs = set(idxs)
176 |         mask = np.array([i in idxs for i in dataset.uq_idxs])
177 |
178 |     dataset.samples = [(p, t) for m, (p, t) in zip(mask, dataset.samples) if m]
179 | dataset.uq_idxs = dataset.uq_idxs[mask]
180 |
181 | return dataset
182 |
183 |
184 | def subsample_classes(dataset, include_classes=range(60)):
185 |
186 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes]
187 |
188 | # TODO: Don't transform targets for now
189 | target_xform_dict = {}
190 | for i, k in enumerate(include_classes):
191 | target_xform_dict[k] = i
192 |
193 | dataset = subsample_dataset(dataset, cls_idxs)
194 |
195 | dataset.target_transform = lambda x: target_xform_dict[x]
196 |
197 | return dataset
198 |
199 |
200 | def get_train_val_indices(train_dataset, val_split=0.2):
201 |
202 |     all_targets = np.array([t for (p, t) in train_dataset.samples])  # ndarray so the elementwise == below is well-defined
203 | train_classes = np.unique(all_targets)
204 |
205 | # Get train/test indices
206 | train_idxs = []
207 | val_idxs = []
208 | for cls in train_classes:
209 | cls_idxs = np.where(all_targets == cls)[0]
210 |
211 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
212 | t_ = [x for x in cls_idxs if x not in v_]
213 |
214 | train_idxs.extend(t_)
215 | val_idxs.extend(v_)
216 |
217 | return train_idxs, val_idxs
218 |
219 |
220 | def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8,
221 | split_train_val=False, seed=0):
222 |
223 | np.random.seed(seed)
224 |
225 | # Init entire training set
226 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval')
227 |
228 | # Get labelled training set which has subsampled classes, then subsample some indices from that
229 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
230 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
231 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
232 |
233 | # Split into training and validation sets
234 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
235 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
236 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
237 | val_dataset_labelled_split.transform = test_transform
238 |
239 | # Get unlabelled data
240 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
241 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
242 |
243 | # Get test set for all classes
244 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test')
245 |
246 | # Either split train into train and val or use test set as val
247 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
248 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
249 |
250 | all_datasets = {
251 | 'train_labelled': train_dataset_labelled,
252 | 'train_unlabelled': train_dataset_unlabelled,
253 | 'val': val_dataset_labelled,
254 | 'test': test_dataset,
255 | }
256 |
257 | return all_datasets
258 |
259 | def get_aircraft_datasets_with_gcdval(train_transform, test_transform, train_classes=range(80),
260 | prop_train_labels=0.8, split_train_val=True, seed=0, val_split=0.1):
261 |
262 | np.random.seed(seed)
263 |
264 | # Init entire training set
265 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval')
266 |
267 | # Get labelled training set which has subsampled classes, then subsample some indices from that
268 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
269 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
270 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
271 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
272 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)), absolute=False)
273 |
274 | # Split into training and validation sets
275 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, val_split=val_split)
276 | val_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
277 | val_dataset_labelled.transform = test_transform
278 | train_dataset_labelled = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
279 |
280 | train_idxs, val_idxs = get_train_val_indices(train_dataset_unlabelled, val_split=val_split)
281 | val_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), val_idxs)
282 | val_dataset_unlabelled.transform = test_transform
283 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset_unlabelled), train_idxs)
284 |
285 |
286 |
287 |
288 | # Get test set for all classes
289 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test')
290 |
291 | print(f'total={len(whole_training_set)} train={len(train_dataset_labelled)} {len(train_dataset_unlabelled)} val={len(val_dataset_labelled)} {len(val_dataset_unlabelled)} test={len(test_dataset)}')
292 |
293 | all_datasets = {
294 | 'train_labelled': train_dataset_labelled,
295 | 'train_unlabelled': train_dataset_unlabelled,
296 | 'val': [val_dataset_labelled, val_dataset_unlabelled],
297 | 'test': test_dataset,
298 | }
299 |
300 | return all_datasets
301 |
302 | if __name__ == '__main__':
303 |
304 | x = get_aircraft_datasets(None, None, split_train_val=False)
305 |
306 | print('Printing lens...')
307 | for k, v in x.items():
308 | if v is not None:
309 | print(f'{k}: {len(v)}')
310 |
311 | print('Printing labelled and unlabelled overlap...')
312 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
313 | print('Printing total instances in train...')
314 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
315 | print('Printing number of labelled classes...')
316 | print(len(set([i[1] for i in x['train_labelled'].samples])))
317 | print('Printing total number of classes...')
318 | print(len(set([i[1] for i in x['train_unlabelled'].samples])))
319 |
--------------------------------------------------------------------------------
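A toy illustration of the two indexing modes of subsample_dataset() above; ToyDataset is hypothetical and only mimics the attributes the function touches. With absolute=True the indices address positions in the current dataset; with absolute=False they are matched against the original uq_idxs, which is why get_aircraft_datasets_with_gcdval() passes absolute=False once the labelled instances have already been removed.

    import numpy as np
    from data.fgvc_aircraft import subsample_dataset  # assumes the repo root is on PYTHONPATH

    class ToyDataset:
        """Hypothetical stand-in exposing only what subsample_dataset touches."""
        def __init__(self):
            self.samples = [(f'img_{i}.jpg', i % 2) for i in range(6)]
            self.uq_idxs = np.arange(6)

        def __len__(self):
            return len(self.samples)

    # Note: subsample_dataset mutates its argument in place.
    a = subsample_dataset(ToyDataset(), [0, 2])    # keep positions 0 and 2
    print(a.uq_idxs)                               # [0 2]

    b = subsample_dataset(a, [2], absolute=False)  # keep the item whose uq_idx == 2
    print(b.uq_idxs)                               # [2] -- position 1 of a; absolute=True would hit an IndexError here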
/methods/meanshift_clustering.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import timm
4 | import argparse
5 | import numpy as np
6 |
7 |
8 | from tqdm import tqdm
9 | from torch import nn
10 | from torch.nn import functional as F
11 | from torch.utils.data import DataLoader
12 | from torchvision import transforms
13 |
14 | from data.stanford_cars import CarsDataset
15 | from data.cifar import CustomCIFAR10, CustomCIFAR100, cifar_10_root, cifar_100_root
16 | from data.herbarium_19 import HerbariumDataset19, herbarium_dataroot
17 | from data.augmentations import get_transform
18 | from data.imagenet import get_imagenet_100_datasets
19 | from data.data_utils import MergedDataset
20 | from data.cub import CustomCub2011, cub_root, get_cub_datasets
21 | from data.fgvc_aircraft import FGVCAircraft, aircraft_root
22 | from data.get_datasets import get_datasets, get_class_splits, get_datasets_with_gcdval
23 |
24 | from project_utils.general_utils import strip_state_dict, str2bool
25 | from copy import deepcopy
26 | from project_utils.cluster_and_log_utils import *
27 |
28 | from models.dino import *
29 | from models import vision_transformer as vits
30 | from config import dino_pretrain_path
31 | from scipy.cluster.hierarchy import linkage, fcluster  # may duplicate the * imports above; made explicit for clarity
32 | from scipy.optimize import linear_sum_assignment as linear_assignment
33 | def iterative_meanshift(model, loader, args):
34 | """
35 |     Measure clustering accuracy in the GCD setup, with both the ground-truth and the predicted number of classes.
36 |
37 |     Clustering         : labelled train dataset + unlabelled train dataset
38 |     Stopping condition : labelled train dataset
39 |     Acc                : unlabelled train dataset
40 | """
41 | num_clusters = [args.num_labeled_classes + args.num_unlabeled_classes, args.num_clusters]
42 | acc = [0, 0]
43 | max_acc = [0, 0]
44 | tolerance = [0, 0]
45 | final_acc = [(0,0,0), (0,0,0)]
46 | print('Predicted number of clusters: ', args.num_clusters)
47 |
48 | all_feats = torch.zeros(size=(len(loader.dataset), args.feat_dim))
49 | new_feats = torch.zeros(size=(len(loader.dataset), args.feat_dim))
50 | targets = torch.zeros(len(loader.dataset), dtype=int)
51 | mask_lab = torch.zeros(len(loader.dataset), dtype=bool)
52 | mask_cls = torch.zeros(len(loader.dataset), dtype=bool)
53 |
54 | with torch.no_grad():
55 | for epoch in range(args.epochs):
56 | # Save embeddings
57 | for batch_idx, batch in enumerate(tqdm(loader)):
58 | images, label, uq_idxs, mask_lab_ = batch
59 | if epoch == 0:
60 | images = torch.Tensor(images).to(device)
61 | all_feats[uq_idxs] = model(images).detach().cpu()
62 | targets[uq_idxs] = label.cpu()
63 | mask_lab[uq_idxs] = mask_lab_.squeeze(1).cpu().bool()
64 | else:
65 | classwise_sim = torch.einsum('b d, n d -> b n', all_feats[uq_idxs], all_feats)
66 | _, indices = classwise_sim.topk(k=args.k+1, dim=-1, largest=True, sorted=True)
67 | indices = indices[:, 1:]
68 | knn_emb = torch.mean(all_feats[indices].view(-1, args.k, args.feat_dim), dim=1)
69 | new_feats[uq_idxs] = (1-args.alpha) * all_feats[uq_idxs] + args.alpha * knn_emb.detach().cpu()
70 |
71 |
72 | if epoch == 0:
73 | mask_cls = np.isin(targets, range(len(args.train_classes))).astype(bool)
74 | mask_lab = mask_lab.numpy().astype(bool)
75 | l_targets = targets[mask_lab].numpy()
76 | u_targets = targets[~mask_lab].numpy()
77 | mask = mask_cls[~mask_lab]
78 | mask = mask.astype(bool)
79 | else:
80 | norm = torch.sqrt(torch.sum((torch.pow(new_feats, 2)), dim=-1)).unsqueeze(1)
81 | new_feats = new_feats / norm
82 | all_feats = new_feats
83 |
84 | # Agglomerative clustering
85 | linked = linkage(all_feats, method="ward")
86 | for i in range(len(num_clusters)):
87 | if num_clusters[i]:
88 | print('num clusters', num_clusters[i])
89 | else:
90 | continue
91 |
92 | threshold = linked[:, 2][-num_clusters[i]]
93 | preds = fcluster(linked, t=threshold, criterion='distance')
94 |
95 | old_acc_train = compute_acc(l_targets, preds[mask_lab])
96 | all_acc_test, old_acc_test, new_acc_test = log_accs_from_preds(y_true=u_targets, y_pred=preds[~mask_lab], mask=mask,
97 | T=epoch, eval_funcs=args.eval_funcs, save_name='IMS unlabeled train ACC', print_output=True)
98 |
99 | # Stopping condition with tolerance
100 | tolerance[i] = 0 if max_acc[i] < old_acc_train else tolerance[i] + 1
101 |
102 | if max_acc[i] <= old_acc_train:
103 | max_acc[i] = old_acc_train
104 | acc[i] = (all_acc_test, old_acc_test, new_acc_test)
105 |
106 | # Stop
107 | if tolerance[i] >= 2:
108 | num_clusters[i] = 0
109 | final_acc[i] = acc[i]
110 |
111 | # If both GT & predicted K stopped, break
112 | if sum(num_clusters) == 0:
113 | break
114 |
115 | print(f'ACC with GT number of clusters: All {final_acc[0][0]:.4f} | Old {final_acc[0][1]:.4f} | New {final_acc[0][2]:.4f}')
116 | print(f'ACC with predicted number of clusters: All {final_acc[1][0]:.4f} | Old {final_acc[1][1]:.4f} | New {final_acc[1][2]:.4f}')
117 |
118 |
119 | def iterative_meanshift_inductive(model, loader, val_loader, args):
120 | """
121 |     Measure clustering accuracy in the inductive GCD setup, with both the ground-truth and the predicted number of classes.
122 |     Clustering         : test dataset
123 |     Stopping condition : val dataset
124 |     Acc                : test dataset
125 | """
126 |
127 | num_clusters = [args.num_labeled_classes + args.num_unlabeled_classes, args.num_clusters]
128 | acc = [0, 0]
129 | max_acc = [0, 0]
130 | tolerance = [0, 0]
131 | final_acc = [(0,0,0), (0,0,0)]
132 | print('Predicted number of clusters: ', args.num_clusters)
133 |
134 | all_feats = torch.zeros(size=(len(loader.dataset), args.feat_dim))
135 | new_feats = torch.zeros(size=(len(loader.dataset), args.feat_dim))
136 | targets = torch.zeros(len(loader.dataset), dtype=int)
137 | mask_lab = torch.zeros(len(loader.dataset), dtype=bool)
138 | mask_cls = torch.zeros(len(loader.dataset), dtype=bool)
139 |
140 | all_feats_val = []
141 | new_feats_val = []
142 | targets_val = []
143 | mask_cls_val = []
144 | with torch.no_grad():
145 | for epoch in range(args.epochs):
146 | # Save embeddings (test)
147 | for batch_idx, batch in enumerate(tqdm(loader)):
148 | images, label, uq_idxs = batch
149 | if epoch == 0:
150 | images = torch.Tensor(images).to(device)
151 | all_feats[uq_idxs] = model(images).detach().cpu()
152 | targets[uq_idxs] = label.cpu()
153 | else:
154 | classwise_sim = torch.einsum('b d, n d -> b n', all_feats[uq_idxs], all_feats)
155 | _, indices = classwise_sim.topk(k=args.k+1, dim=-1, largest=True, sorted=True)
156 | indices = indices[:, 1:]
157 | knn_emb = torch.mean(all_feats[indices].view(-1, args.k, args.feat_dim), dim=1)
158 | new_feats[uq_idxs] = (1-args.alpha) * all_feats[uq_idxs] + args.alpha * knn_emb.detach().cpu()
159 |
160 | if epoch == 0:
161 | mask_cls = np.isin(targets, range(len(args.train_classes))).astype(bool)
162 | targets = np.array(targets)
163 | else:
164 | norm = torch.sqrt(torch.sum((torch.pow(new_feats, 2)), dim=-1)).unsqueeze(1)
165 | new_feats = new_feats / norm
166 | all_feats = new_feats
167 |
168 | # Save embeddings (val)
169 | for batch_idx, batch in enumerate(tqdm(val_loader)):
170 | images, label, uq_idxs = batch[:3]
171 | if epoch == 0:
172 | images = torch.Tensor(images).to(device)
173 | all_feats_val.append(model(images).detach().cpu())
174 | targets_val.append(label.cpu())
175 | else:
176 | start_idx = batch_idx*args.batch_size
177 | classwise_sim = torch.einsum('b d, n d -> b n', all_feats_val[start_idx:start_idx+len(uq_idxs)], all_feats_val)
178 | _, indices = classwise_sim.topk(k=args.k+1, dim=-1, largest=True, sorted=True)
179 | indices = indices[:, 1:]
180 | knn_emb_val = torch.mean(all_feats_val[indices].view(-1, args.k, args.feat_dim), dim=1)
181 | new_feats_val[start_idx:start_idx+len(uq_idxs)] = (1-args.alpha) * all_feats_val[start_idx:start_idx+len(uq_idxs)] + args.alpha * knn_emb_val.detach().cpu()
182 |
183 | if epoch == 0:
184 | all_feats_val = torch.cat(all_feats_val)
185 | targets_val = np.array(torch.cat(targets_val))
186 | mask_cls_val = np.isin(targets_val, range(len(args.train_classes))).astype(bool)
187 | new_feats_val = all_feats_val
188 | else:
189 | norm = torch.sqrt(torch.sum((torch.pow(new_feats_val, 2)), dim=-1)).unsqueeze(1)
190 | new_feats_val = new_feats_val / norm
191 | all_feats_val = new_feats_val
192 |
193 | # Agglomerative clustering
194 | linked = linkage(all_feats, method="ward")
195 | linked_val = linkage(all_feats_val, method="ward")
196 | for i in range(len(num_clusters)):
197 | if num_clusters[i]:
198 | print('num clusters', num_clusters[i])
199 | else:
200 | continue
201 |
202 | # acc of validation set
203 | threshold = linked[:, 2][-num_clusters[i]]
204 | preds_val = fcluster(linked_val, t=threshold, criterion='distance')
205 | old_acc_val = compute_acc(targets_val[mask_cls_val], preds_val[mask_cls_val])
206 |
207 | # acc of test set
208 | threshold = linked[:, 2][-num_clusters[i]]
209 | preds = fcluster(linked, t=threshold, criterion='distance')
210 | all_acc_test, old_acc_test, new_acc_test = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask_cls,
211 | T=epoch, eval_funcs=args.eval_funcs, save_name='IMS test ACC', print_output=True)
212 |
213 | # Stopping condition with tolerance
214 | tolerance[i] = 0 if max_acc[i] < old_acc_val else tolerance[i] + 1
215 |
216 | if max_acc[i] <= old_acc_val:
217 | max_acc[i] = old_acc_val
218 | acc[i] = (all_acc_test, old_acc_test, new_acc_test)
219 |
220 | if tolerance[i] >= 2:
221 | num_clusters[i] = 0
222 | final_acc[i] = acc[i]
223 |
224 | # If both GT & predicted K stopped, break
225 | if sum(num_clusters) == 0:
226 | break
227 |
228 | print(f'ACC with GT number of clusters: All {final_acc[0][0]:.4f} | Old {final_acc[0][1]:.4f} | New {final_acc[0][2]:.4f}')
229 | print(f'ACC with predicted number of clusters: All {final_acc[1][0]:.4f} | Old {final_acc[1][1]:.4f} | New {final_acc[1][2]:.4f}')
230 |
231 |
232 | def compute_acc(y_true, y_pred):
233 | y_true = y_true.astype(int)
234 | old_classes_gt = set(y_true)
235 |
236 | assert y_pred.size == y_true.size
237 | D = max(y_pred.max(), y_true.max()) + 1
238 | w = np.zeros((D, D), dtype=int)
239 | for i in range(y_pred.size):
240 | w[y_pred[i], y_true[i]] += 1
241 | # w: pred x label count
242 | ind = linear_assignment(w.max() - w)
243 | ind = np.vstack(ind).T
244 |
245 | ind_map = {j: i for i, j in ind}
246 | total_acc = sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
247 |
248 | return total_acc
249 |
250 |
251 | if __name__ == "__main__":
252 |
253 | parser = argparse.ArgumentParser(
254 | description='cluster',
255 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
256 | parser.add_argument('--batch_size', default=128, type=int)
257 | parser.add_argument('--num_workers', default=8, type=int)
258 |     # NOTE: '--model_name' is defined once, below; registering the same option twice makes argparse raise an ArgumentError.
259 | parser.add_argument('--pretrain_path', type=str, default='./models')
260 | parser.add_argument('--transform', type=str, default='imagenet')
261 |     parser.add_argument('--eval_funcs', nargs='+', default=['v2'], help='Which eval functions to use')
262 | parser.add_argument('--use_ssb_splits', type=str2bool, default=True)
263 | parser.add_argument('--model_name', type=str, default='cms', help='Format is {model_name}_{pretrain}')
264 | parser.add_argument('--dataset_name', type=str, default='aircraft', help='options: cifar10, cifar100, scars')
265 | parser.add_argument('--epochs', default=20, type=int)
266 | parser.add_argument('--feat_dim', default=768, type=int)
267 | parser.add_argument('--num_clusters', default=None, type=int)
268 | parser.add_argument('--inductive', action='store_true')
269 | parser.add_argument('--k', default=8, type=int)
270 | parser.add_argument('--alpha', type=float, default=0.5)
271 |
272 | # ----------------------
273 | # INIT
274 | # ----------------------
275 | args = parser.parse_args()
276 | device = torch.device('cuda:0')
277 | args = get_class_splits(args)
278 |
279 | args.num_labeled_classes = len(args.train_classes)
280 | args.num_unlabeled_classes = len(args.unlabeled_classes)
281 | print(args)
282 |
283 | args.model_name = "/home/suachoi/CMS/log/metric_learn_gcd/log/" + args.model_name + "/checkpoints/model_best.pt"
284 | print(f'Using weights from {args.model_name} ...')
285 |
286 | # ----------------------
287 | # MODEL
288 | # ----------------------
289 | args.interpolation = 3
290 | args.crop_pct = 0.875
291 | args.prop_train_labels = 0.5
292 | args.num_mlp_layers = 3
293 | args.feat_dim = 768
294 |
295 | model = DINO(args)
296 | model = nn.DataParallel(model)
297 | model.to(device)
298 | model.eval()
299 |
300 | state_dict = torch.load(args.model_name)
301 | model.load_state_dict(state_dict['model_state_dict'], strict=False)
302 | args.num_clusters = state_dict['k']
303 |
304 | # ----------------------
305 | # DATASET
306 | # ----------------------
307 | train_transform, test_transform = get_transform(args.transform, image_size=224, args=args)
308 | if args.inductive:
309 | _, test_dataset, _, val_dataset, _ = get_datasets_with_gcdval(args.dataset_name, test_transform, test_transform, args)
310 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
311 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
312 | iterative_meanshift_inductive(model, test_loader, val_loader, args)
313 | else:
314 | train_dataset, _, _, _ = get_datasets(args.dataset_name, test_transform, test_transform, args)
315 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
316 | iterative_meanshift(model, train_loader, args)
--------------------------------------------------------------------------------
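A self-contained sketch (toy blobs, not repo data) of the two building blocks the loops above rely on: cutting a Ward dendrogram at the height that yields exactly K flat clusters (the linked[:, 2][-K] trick), and scoring the partition against labels with Hungarian matching as in compute_acc().

    import numpy as np
    from scipy.cluster.hierarchy import linkage, fcluster
    from scipy.optimize import linear_sum_assignment

    rng = np.random.default_rng(0)
    # Toy features: three well-separated 8-d Gaussian blobs, 30 points each.
    X = np.concatenate([rng.normal(loc=c, scale=0.1, size=(30, 8)) for c in (0.0, 3.0, 6.0)])
    y_true = np.repeat([0, 1, 2], 30)

    K = 3
    Z = linkage(X, method='ward')  # (n-1, 4); column 2 holds monotonically increasing merge distances
    threshold = Z[:, 2][-K]        # undoing the top K-1 merges leaves exactly K clusters
    y_pred = fcluster(Z, t=threshold, criterion='distance')
    # Equivalent shortcut: fcluster(Z, t=K, criterion='maxclust')

    # Hungarian matching between cluster ids and class labels, as in compute_acc().
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)
    for p, t in zip(y_pred, y_true):
        w[p, t] += 1                               # contingency table: pred x label
    row, col = linear_sum_assignment(w.max() - w)  # maximise the total matched count
    acc = w[row, col].sum() / y_pred.size
    print(f'clustering accuracy: {acc:.4f}')       # ~1.0 on these toy blobs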