├── project_utils
│   ├── __init__.py
│   ├── k_means_utils.py
│   ├── pos_embed.py
│   ├── slurm_out_parser.py
│   ├── cluster_and_log_utils.py
│   ├── schedulers.py
│   ├── cluster_utils.py
│   └── general_utils.py
├── assets
│   └── overview.png
├── data
│   ├── ssb_splits
│   │   ├── cub_osr_splits.pkl
│   │   ├── scars_osr_splits.pkl
│   │   ├── aircraft_osr_splits.pkl
│   │   └── herbarium_19_class_splits.pkl
│   ├── data_utils.py
│   ├── augmentations
│   │   ├── cut_out.py
│   │   ├── __init__.py
│   │   └── randaugment.py
│   ├── stanford_cars.py
│   ├── imagenet.py
│   ├── cifar.py
│   ├── get_datasets.py
│   ├── cub.py
│   ├── food.py
│   ├── pets.py
│   ├── flower.py
│   └── fgvc_aircraft.py
├── requirements.txt
├── bash_scripts
│   ├── extract_features.sh
│   ├── estimate_k.sh
│   ├── ssk_means.sh
│   ├── get_subset_len.sh
│   ├── get_kmeans_subset.sh
│   └── representation_learning.sh
├── config.py
├── methods
│   ├── clustering
│   │   ├── feature_vector_dataset.py
│   │   ├── ssk_means.py
│   │   ├── extract_features.py
│   │   └── faster_mix_k_means_pytorch.py
│   ├── partitioning
│   │   ├── subset_len.py
│   │   └── kmeans_subset.py
│   └── estimate_k
│       └── estimate_k.py
├── README.md
└── model
    └── vision_transformer.py
/project_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/assets/overview.png
--------------------------------------------------------------------------------
/data/ssb_splits/cub_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/data/ssb_splits/cub_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/scars_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/data/ssb_splits/scars_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/aircraft_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/data/ssb_splits/aircraft_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/herbarium_19_class_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YiXXin/XCon/HEAD/data/ssb_splits/herbarium_19_class_splits.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | easydict==1.9
2 | matplotlib==3.5.2
3 | numpy==1.17.3
4 | pandas==1.0.3
5 | path.py==12.5.0
6 | Pillow==9.2.0
7 | scikit_learn==1.1.1
8 | scipy==1.7.3
9 | six==1.12.0
10 | tensorboard==2.4.0
11 | timm==0.4.12
12 | torch==1.10.0
13 | torchvision==0.11.1
14 | tqdm==4.36.1
15 |
--------------------------------------------------------------------------------
/bash_scripts/extract_features.sh:
--------------------------------------------------------------------------------
1 | PYTHON='python'
2 |
3 | hostname
4 | nvidia-smi
5 |
6 | export CUDA_VISIBLE_DEVICES=0
7 |
8 | ${PYTHON} -m methods.clustering.extract_features --dataset cub --use_best_model 'True' \
9 | --warmup_model_dir './XCon_outputs/metric_learn_gcd/log/(02.08.2022_|_25.076)/checkpoints/model.pt'
--------------------------------------------------------------------------------
/bash_scripts/estimate_k.sh:
--------------------------------------------------------------------------------
1 | PYTHON='python'
2 |
3 | hostname
4 | export CUDA_VISIBLE_DEVICES=0
5 |
6 | # Get unique log file
7 | SAVE_DIR=./XCon_outputs/outputs/
8 |
9 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
10 | EXP_NUM=$((${EXP_NUM}+1))
11 | echo $EXP_NUM
12 |
13 | ${PYTHON} -m methods.estimate_k.estimate_k --max_classes 1000 --dataset_name cub --search_mode other \
14 | --warmup_model_exp_id '(02.08.2022_|_25.076)' --use_best_model 'True' \
15 | > ${SAVE_DIR}logfile_${EXP_NUM}.out
--------------------------------------------------------------------------------
/bash_scripts/ssk_means.sh:
--------------------------------------------------------------------------------
1 | PYTHON='python'
2 |
3 | hostname
4 | nvidia-smi
5 |
6 | export CUDA_VISIBLE_DEVICES=0
7 |
8 | # Get unique log file
9 | SAVE_DIR=./XCon_outputs/outputs/
10 |
11 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
12 | EXP_NUM=$((${EXP_NUM}+1))
13 | echo $EXP_NUM
14 |
15 | ${PYTHON} -m methods.clustering.ssk_means --dataset 'cub' --semi_sup 'True' --use_ssb_splits 'True' \
16 | --use_best_model 'True' --max_kmeans_iter 200 --k_means_init 100 --warmup_model_exp_id '(02.08.2022_|_25.076)'\
17 | > ${SAVE_DIR}logfile_${EXP_NUM}.out
--------------------------------------------------------------------------------
/bash_scripts/get_subset_len.sh:
--------------------------------------------------------------------------------
1 | PYTHON='python'
2 |
3 | hostname
4 | nvidia-smi
5 |
6 | export CUDA_VISIBLE_DEVICES=0
7 |
8 | # Get unique log file
9 | SAVE_DIR=./XCon_outputs/outputs/
10 |
11 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
12 | EXP_NUM=$((${EXP_NUM}+1))
13 | echo $EXP_NUM
14 |
15 | ${PYTHON} -m methods.partitioning.subset_len \
16 | --dataset_name 'cub' \
17 | --batch_size 256 \
18 | --num_workers 4 \
19 | --use_ssb_splits 'True' \
20 | --transform 'imagenet' \
21 | --eval_funcs 'v2' \
22 | --experts_num 8 \
--------------------------------------------------------------------------------
/bash_scripts/get_kmeans_subset.sh:
--------------------------------------------------------------------------------
1 | PYTHON='python'
2 |
3 | hostname
4 | nvidia-smi
5 |
6 | export CUDA_VISIBLE_DEVICES=0
7 |
8 | # Get unique log file
9 | SAVE_DIR=./XCon_outputs/outputs/
10 |
11 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
12 | EXP_NUM=$((${EXP_NUM}+1))
13 | echo $EXP_NUM
14 |
15 | ${PYTHON} -m methods.partitioning.kmeans_subset \
16 | --dataset_name 'cub' \
17 | --batch_size 256 \
18 | --grad_from_block 11 \
19 | --epochs 100 \
20 | --base_model vit_dino \
21 | --num_workers 4 \
22 | --use_ssb_splits 'True' \
23 | --weight_decay 5e-5 \
24 | --transform 'imagenet' \
25 | --lr 0.1 \
26 | --eval_funcs 'v2' \
27 | --pretrain_model 'dino' \
28 | --experts_num 8 \
--------------------------------------------------------------------------------
/bash_scripts/representation_learning.sh:
--------------------------------------------------------------------------------
1 | PYTHON='python'
2 |
3 | hostname
4 | nvidia-smi
5 |
6 | export CUDA_VISIBLE_DEVICES=0
7 |
8 | # Get unique log file
9 | SAVE_DIR=./XCon_outputs/outputs/
10 |
11 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
12 | EXP_NUM=$((${EXP_NUM}+1))
13 | echo $EXP_NUM
14 |
15 | ${PYTHON} -m methods.representation_learning.representation_learning \
16 | --dataset_name 'cub' \
17 | --batch_size 256 \
18 | --grad_from_block 11 \
19 | --epochs 200 \
20 | --base_model vit_dino \
21 | --num_workers 4 \
22 | --use_ssb_splits 'True' \
23 | --sup_con_weight 0.35 \
24 | --weight_decay 5e-5 \
25 | --contrast_unlabel_only 'False' \
26 | --transform 'imagenet' \
27 | --lr 0.1 \
28 | --eval_funcs 'v2' \
29 | --val_epoch_size 10 \
30 | --use_best_model 'True' \
31 | --use_global_con 'True' \
32 | --expert_weight 0.1 \
33 | --max_kmeans_iter 200 \
34 | --k_means_init 100 \
35 | --best_new 'False' \
36 | --pretrain_model 'dino' \
37 | > ${SAVE_DIR}logfile_${EXP_NUM}.out
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -----------------
2 | # DATASET ROOTS
3 | # -----------------
4 | cifar_10_root = '/data/dataset/cifar10'
5 | cifar_100_root = '/data/dataset/cifar100'
6 | cub_root = '/data/dataset/cub200'
7 | aircraft_root = '/data/user-data/fgvc/aircraft/fgvc-aircraft-2013b'
8 | imagenet_root = '/data/user-data/imagenet'
9 | cars_root = '/data/user-data/fgvc/cars'
10 |
11 | pets_root = '/data/user-data/fgvc/pets'
12 | flower_root = '/data/user-data/fgvc/flower102'
13 | food_root = '/data/user-data/fgvc/food-101'
14 |
15 | # OSR Split dir
16 | osr_split_dir = './data/ssb_splits'
17 |
18 | # -----------------
19 | # PRETRAIN PATHS
20 | # -----------------
21 | dino_pretrain_path = './pretrained_models/dino/dino_vitbase16_pretrain.pth'
22 | moco_pretrain_path = './pretrained_models/moco/vit-b-300ep.pth.tar'
23 | mae_pretrain_path = './pretrained_models/mae/mae_pretrain_vit_base.pth'
24 |
25 | # Dataset partitioning paths
26 | km_label_path = './partition_out/km_labels'
27 | subset_len_path = './partition_out/subset_len'
28 |
29 | # -----------------
30 | # OTHER PATHS
31 | # -----------------
32 | feature_extract_dir = './XCon_outputs/extracted_features' # Extract features to this directory
33 | exp_root = './XCon_outputs' # All logs and checkpoints will be saved here
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 | def subsample_instances(dataset, prop_indices_to_subsample=0.8):
6 |
7 | np.random.seed(0)
8 | subsample_indices = np.random.choice(range(len(dataset)), replace=False,
9 | size=(int(prop_indices_to_subsample * len(dataset)),))
10 |
11 | return subsample_indices
12 |
13 | class MergedDataset(Dataset):
14 |
15 | """
16 | Takes two datasets (labelled_dataset, unlabelled_dataset) and concatenates them,
17 | returning a flag for whether each instance comes from the labelled set
18 | """
19 |
20 | def __init__(self, labelled_dataset, unlabelled_dataset):
21 |
22 | self.labelled_dataset = labelled_dataset
23 | self.unlabelled_dataset = unlabelled_dataset
24 | self.target_transform = None
25 |
26 | def __getitem__(self, item):
27 | if item < len(self.labelled_dataset):
28 | img, label, uq_idx = self.labelled_dataset[item]
29 | labeled_or_not = 1
30 |
31 | else:
32 | img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)]
33 | labeled_or_not = 0
34 |
35 |
36 | return img, label, uq_idx, np.array([labeled_or_not])
37 |
38 | def __len__(self):
39 | return len(self.unlabelled_dataset) + len(self.labelled_dataset)
40 |
41 |
--------------------------------------------------------------------------------
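
A minimal usage sketch of `MergedDataset` (not part of the repo; `ToyDataset` is a hypothetical stand-in for any dataset that returns `(img, label, uq_idx)` tuples, the protocol the class above assumes):

```python
import numpy as np
from torch.utils.data import Dataset

from data.data_utils import MergedDataset


class ToyDataset(Dataset):
    """Hypothetical dataset following the (img, label, uq_idx) protocol."""

    def __init__(self, n, offset=0):
        self.n, self.offset = n, offset
        self.target_transform = None

    def __getitem__(self, i):
        img = np.zeros(3)                   # placeholder "image"
        return img, i % 2, self.offset + i  # (img, label, uq_idx)

    def __len__(self):
        return self.n


merged = MergedDataset(labelled_dataset=ToyDataset(4),
                       unlabelled_dataset=ToyDataset(6, offset=4))
img, label, uq_idx, mask_lab = merged[0]
assert len(merged) == 10 and mask_lab[0] == 1  # first item comes from the labelled set
```
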
/data/augmentations/cut_out.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/hysts/pytorch_cutout
3 | """
4 |
5 | import torch
6 | import numpy as np
7 |
8 | def cutout(mask_size, p, cutout_inside, mask_color=(0, 0, 0)):
9 | mask_size_half = mask_size // 2
10 | offset = 1 if mask_size % 2 == 0 else 0
11 |
12 | def _cutout(image):
13 | image = np.asarray(image).copy()
14 |
15 | if np.random.random() > p:
16 | return image
17 |
18 | h, w = image.shape[:2]
19 |
20 | if cutout_inside:
21 | cxmin, cxmax = mask_size_half, w + offset - mask_size_half
22 | cymin, cymax = mask_size_half, h + offset - mask_size_half
23 | else:
24 | cxmin, cxmax = 0, w + offset
25 | cymin, cymax = 0, h + offset
26 |
27 | cx = np.random.randint(cxmin, cxmax)
28 | cy = np.random.randint(cymin, cymax)
29 | xmin = cx - mask_size_half
30 | ymin = cy - mask_size_half
31 | xmax = xmin + mask_size
32 | ymax = ymin + mask_size
33 | xmin = max(0, xmin)
34 | ymin = max(0, ymin)
35 | xmax = min(w, xmax)
36 | ymax = min(h, ymax)
37 | image[ymin:ymax, xmin:xmax] = mask_color
38 | return image
39 |
40 | return _cutout
41 |
42 | def to_tensor():
43 | def _to_tensor(image):
44 | if len(image.shape) == 3:
45 | return torch.from_numpy(
46 | image.transpose(2, 0, 1).astype(float))
47 | else:
48 | return torch.from_numpy(image[None, :, :].astype(float))
49 |
50 | return _to_tensor
51 |
52 | def normalize(mean, std):
53 |
54 | mean = np.array(mean)
55 | std = np.array(std)
56 |
57 | def _normalize(image):
58 | image = np.asarray(image).astype(float) / 255.
59 | image = (image - mean) / std
60 | return image
61 |
62 | return _normalize
--------------------------------------------------------------------------------
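
A usage sketch mirroring the 'cutout' branch of `data/augmentations/__init__.py` (same mean/std values). The composition order matters: `normalize` and `cutout` operate on numpy arrays, so `to_tensor` must come last.

```python
import numpy as np
from PIL import Image
from torchvision import transforms

from data.augmentations.cut_out import cutout, normalize, to_tensor

transform = transforms.Compose([
    normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
    cutout(mask_size=16, p=1.0, cutout_inside=False),  # masks a random 16x16 square
    to_tensor(),
])

img = Image.fromarray(np.uint8(np.random.rand(32, 32, 3) * 255))
out = transform(img)  # torch tensor of shape (3, 32, 32)
```
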
/methods/clustering/feature_vector_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 | import os
5 | from copy import deepcopy
6 |
7 | from data.data_utils import MergedDataset
8 |
9 |
10 | class FeatureVectorDataset(Dataset):
11 |
12 | def __init__(self, base_dataset, feature_root):
13 |
14 | """
15 | Dataset loads feature vectors instead of images
16 | :param base_dataset: Dataset from which images would come
17 | :param feature_root: Root directory of features
18 |
19 | feature_root should be structured as:
20 | feature_root/class_label/uq_idx.pt (torch files)
21 | """
22 |
23 | self.base_dataset = base_dataset
24 | self.target_transform = deepcopy(base_dataset.target_transform)
25 |
26 | self.base_dataset.target_transform = None
27 | self.base_dataset.transform = None
28 |
29 | self.feature_root = feature_root
30 |
31 | # if uncomment, cannot find the corresponding feature path(only when using expert models)
32 | if isinstance(self.base_dataset, MergedDataset):
33 | self.base_dataset.labelled_dataset.target_transform = None
34 | self.base_dataset.unlabelled_dataset.target_transform = None
35 |
36 | def __getitem__(self, item):
37 |
38 | if isinstance(self.base_dataset, MergedDataset):
39 |
40 | # Get meta info for this instance
41 | _, label, uq_idx, mask_lab = self.base_dataset[item]
42 |
43 | # Load feature vector
44 | feat_path = os.path.join(self.feature_root, f'{label}', f'{uq_idx}.npy')
45 | feature_vector = torch.load(feat_path)
46 |
47 | if self.target_transform is not None:
48 | label = self.target_transform(label)
49 |
50 | return feature_vector, label, uq_idx, mask_lab[0]
51 |
52 | else:
53 |
54 | # Get meta info for this instance
55 | _, label, uq_idx = self.base_dataset[item]
56 |
57 | # Load feature vector
58 | feat_path = os.path.join(self.feature_root, f'{label}', f'{uq_idx}.npy')
59 | feature_vector = torch.load(feat_path)
60 |
61 | if self.target_transform is not None:
62 | label = self.target_transform(label)
63 |
64 | return feature_vector, label, uq_idx
65 |
66 |
67 | def __len__(self):
68 | return len(self.base_dataset)
--------------------------------------------------------------------------------
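
For `FeatureVectorDataset` to find anything, features must be laid out as `feature_root/<label>/<uq_idx>.npy`. A hedged sketch of the writer side (`save_feature` is illustrative, not a repo function); since the loader calls `torch.load`, the files are assumed to be written with `torch.save` despite the `.npy` extension:

```python
import os

import torch


def save_feature(feature_root, label, uq_idx, feat):
    # Layout expected by FeatureVectorDataset: feature_root/<label>/<uq_idx>.npy
    out_dir = os.path.join(feature_root, f'{label}')
    os.makedirs(out_dir, exist_ok=True)
    torch.save(feat, os.path.join(out_dir, f'{uq_idx}.npy'))


save_feature('./XCon_outputs/extracted_features/demo', label=3, uq_idx=17,
             feat=torch.randn(768))
```
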
/project_utils/k_means_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | from torch.utils.data import DataLoader
5 | import numpy as np
6 | from sklearn.cluster import KMeans
7 | import torch
8 | from project_utils.cluster_and_log_utils import log_accs_from_preds
9 |
10 | from methods.clustering.faster_mix_k_means_pytorch import K_Means as SemiSupKMeans
11 |
12 | from tqdm import tqdm
13 |
14 | # Suppress noisy deprecation warnings
15 | import warnings
16 | warnings.filterwarnings("ignore", category=DeprecationWarning)
17 |
18 |
19 | def test_kmeans_semi_sup(model, test_loader, epoch, save_name, args, K=None):
20 |
21 | """
22 | In this case, the test loader needs to have the labelled and unlabelled subsets of the training data
23 | """
24 |
25 | if K is None:
26 | K = args.num_labeled_classes + args.num_unlabeled_classes
27 |
28 | model.eval()
29 | device = torch.device('cuda:0')
30 |
31 | all_feats = []
32 | targets = np.array([])
33 | mask_lab = np.array([]) # From all the data, which instances belong to the labelled set
34 | mask_cls = np.array([]) # From all the data, which instances belong to Old classes
35 |
36 | print('Collating features...')
37 | # First extract all features
38 | for batch_idx, (images, label, _, mask_lab_) in enumerate(tqdm(test_loader)):
39 |
40 | images = images.to(device)
41 |
42 | feats = model(images)
43 |
44 | feats = torch.nn.functional.normalize(feats, dim=-1)
45 |
46 | all_feats.append(feats.cpu().numpy())
47 | targets = np.append(targets, label.cpu().numpy())
48 | mask_cls = np.append(mask_cls, np.array([True if x.item() in range(len(args.train_classes))
49 | else False for x in label]))
50 | mask_lab = np.append(mask_lab, mask_lab_.cpu().bool().numpy())
51 |
52 | # -----------------------
53 | # K-MEANS
54 | # -----------------------
55 | mask_lab = mask_lab.astype(bool)
56 | mask_cls = mask_cls.astype(bool)
57 |
58 | all_feats = np.concatenate(all_feats)
59 |
60 | l_feats = all_feats[mask_lab] # Get labelled set
61 | u_feats = all_feats[~mask_lab] # Get unlabelled set
62 | l_targets = targets[mask_lab] # Get labelled targets
63 | u_targets = targets[~mask_lab] # Get unlabelled targets
64 |
65 | print('Fitting Semi-Supervised K-Means...')
66 | kmeans = SemiSupKMeans(k=K, tolerance=1e-4, max_iterations=args.max_kmeans_iter, init='k-means++',
67 | n_init=args.k_means_init, random_state=None, n_jobs=None, pairwise_batch_size=1024, mode=None)
68 |
69 | l_feats, u_feats, l_targets, u_targets = (torch.from_numpy(x).to(device) for
70 | x in (l_feats, u_feats, l_targets, u_targets))
71 |
72 | kmeans.fit_mix(u_feats, l_feats, l_targets)
73 | all_preds = kmeans.labels_.cpu().numpy()
74 | u_targets = u_targets.cpu().numpy()
75 |
76 | # -----------------------
77 | # EVALUATE
78 | # -----------------------
79 | # Get preds corresponding to unlabelled set
80 | preds = all_preds[~mask_lab]
81 |
82 | # Get portion of mask_cls which corresponds to the unlabelled set
83 | mask = mask_cls[~mask_lab]
84 | mask = mask.astype(bool)
85 |
86 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=u_targets, y_pred=preds, mask=mask, eval_funcs=args.eval_funcs,
87 |                                                 save_name='SS-K-Means Train ACC Unlabelled', T=epoch, print_output=True)
88 |
89 | return all_acc, old_acc, new_acc, kmeans
--------------------------------------------------------------------------------
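
A toy illustration (not from the repo) of the two boolean masks used above: `mask_lab` marks labelled instances and `mask_cls` marks instances from Old (seen) classes; in GCD the unlabelled pool mixes Old and New classes:

```python
import numpy as np

targets = np.array([0, 0, 1, 5, 5, 6])               # 0, 1 are Old classes; 5, 6 are New
mask_lab = np.array([1, 1, 0, 0, 0, 0], dtype=bool)  # only part of the Old data is labelled
mask_cls = np.array([1, 1, 1, 0, 0, 0], dtype=bool)  # True where the class is Old

u_targets = targets[~mask_lab]  # unlabelled targets fed to evaluation: [1, 5, 5, 6]
mask = mask_cls[~mask_lab]      # Old/New split inside the unlabelled set: [T, F, F, F]
```
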
/project_utils/pos_embed.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | # --------------------------------------------------------
6 | # 2D sine-cosine position embedding
7 | # References:
8 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
9 | # MoCo v3: https://github.com/facebookresearch/moco-v3
10 | # --------------------------------------------------------
11 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
12 | """
13 | grid_size: int of the grid height and width
14 | return:
15 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
16 | """
17 | grid_h = np.arange(grid_size, dtype=np.float32)
18 | grid_w = np.arange(grid_size, dtype=np.float32)
19 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
20 | grid = np.stack(grid, axis=0)
21 |
22 | grid = grid.reshape([2, 1, grid_size, grid_size])
23 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
24 | if cls_token:
25 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
26 | return pos_embed
27 |
28 |
29 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
30 | assert embed_dim % 2 == 0
31 |
32 | # use half of dimensions to encode grid_h
33 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
34 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
35 |
36 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
37 | return emb
38 |
39 |
40 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
41 | """
42 | embed_dim: output dimension for each position
43 | pos: a list of positions to be encoded: size (M,)
44 | out: (M, D)
45 | """
46 | assert embed_dim % 2 == 0
47 | omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float is removed in newer NumPy
48 | omega /= embed_dim / 2.
49 | omega = 1. / 10000**omega # (D/2,)
50 |
51 | pos = pos.reshape(-1) # (M,)
52 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
53 |
54 | emb_sin = np.sin(out) # (M, D/2)
55 | emb_cos = np.cos(out) # (M, D/2)
56 |
57 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
58 | return emb
59 |
60 |
61 | # --------------------------------------------------------
62 | # Interpolate position embeddings for high-resolution
63 | # References:
64 | # DeiT: https://github.com/facebookresearch/deit
65 | # --------------------------------------------------------
66 | def interpolate_pos_embed(model, checkpoint_model):
67 | if 'pos_embed' in checkpoint_model:
68 | pos_embed_checkpoint = checkpoint_model['pos_embed']
69 | embedding_size = pos_embed_checkpoint.shape[-1]
70 | num_patches = model.patch_embed.num_patches
71 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
72 | # height (== width) for the checkpoint position embedding
73 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
74 | # height (== width) for the new position embedding
75 | new_size = int(num_patches ** 0.5)
76 | # class_token and dist_token are kept unchanged
77 | if orig_size != new_size:
78 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
79 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
80 | # only the position tokens are interpolated
81 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
82 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
83 | pos_tokens = torch.nn.functional.interpolate(
84 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
85 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
86 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
87 | checkpoint_model['pos_embed'] = new_pos_embed
--------------------------------------------------------------------------------
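
A quick usage sketch: the embedding table for ViT-B/16 at 224x224 input (a 14x14 patch grid), with a zero row prepended for the [CLS] token:

```python
from project_utils.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
assert pos.shape == (1 + 14 * 14, 768)  # one extra all-zero row for the [CLS] token
```
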
/project_utils/slurm_out_parser.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pandas as pd
3 | import os
4 |
5 | pd.options.display.width = 0
6 |
7 | rx_dict = {
8 | 'model_dir': re.compile(r'model_dir=\'(.*?)\''),
9 | 'dataset': re.compile(r'dataset_name=\'(.*?)\''),
10 | 'm': re.compile(r'rand_aug_m=([-+]?\d*)'),
11 | 'n': re.compile(r'rand_aug_n=([-+]?\d*)'),
12 |     'wd': re.compile(r"weight_decay=(.*?),"),
13 |     'sc_weight': re.compile(r"sup_con_weight=(.*?),"),
14 | 'split_idx': re.compile(r'split_idx=(\d)'),
15 | 'epochs': re.compile(r'Train Epoch: (\d)'),
16 | 'part_loss_mode': re.compile(r'part_loss_mode=(\d)'),
17 | 'consistency_weight': re.compile(r'consistency_weight=(.*?),'),
18 |     'lr': re.compile(r"lr=([-+]?\d*\.\d+|\d+)"),
19 |     'Train Accs': re.compile(r"Train Accuracies: ([-+]?\d*\.\d+|\d+)")
20 | }
21 |
22 | save_root_dir = '/work/sagar/open_set_recognition/sweep_summary_files/ensemble_pkls'
23 |
24 |
25 | def get_file(path):
26 |
27 | file = []
28 | with open(path, 'rt') as myfile:
29 | for myline in myfile: # For each line, read to a string,
30 | file.append(myline)
31 |
32 | return file
33 |
34 |
35 | def parse_out_file(path, rx_dict, root_dir=save_root_dir, save_name='test.pkl', save=True, verbose=True):
36 |
37 | file = get_file(path=path)
38 | for i, line in enumerate(file):
39 | if line.find('Namespace') != -1:
40 |
41 | model = {}
42 | s = rx_dict['model_dir'].search(line).group(1)
43 | exp_id = s[s.find("("):s.find(")") + 1]
44 | model['exp_id'] = exp_id
45 | model['dataset'] = rx_dict['dataset'].search(line).group(1)
46 | model['lr'] = rx_dict['lr'].search(line).group(1)
47 |
48 | break
49 |
50 | reverse_file = file[::-1]
51 | for i, line in enumerate(reverse_file):
52 | if line.find('Train Accuracies') != -1:
53 |
54 | model['Train Mean'] = re.findall(r"\d+\.\d+", line)[0]
55 | model['Train Old'] = re.findall(r"\d+\.\d+", line)[1]
56 | model['Train New'] = re.findall(r"\d+\.\d+", line)[2]
57 |
58 | for i_, line_ in enumerate(reverse_file[i:]):
59 | if 'Train Epoch' in line_:
60 | model['Last Epoch'] = re.findall(r'Train Epoch: (\d+)', line_)[0]
61 | break
62 |
63 | break
64 |
65 | for i, line in enumerate(reverse_file):
66 | if line.find('Best Train Accuracies') != -1:
67 |
68 | model['Best Train Mean'] = re.findall(r"\d+\.\d+", line)[0]
69 | model['Best Train Old'] = re.findall(r"\d+\.\d+", line)[1]
70 | model['Best Train New'] = re.findall(r"\d+\.\d+", line)[2]
71 |
72 | for i_, line_ in enumerate(reverse_file[i:]):
73 | if 'Train Epoch' in line_:
74 | model['Best Epoch'] = re.findall(r'Train Epoch: (\d+)', line_)[0]
75 | break
76 |
77 | break
78 |
79 |
80 | data = pd.DataFrame([model])
81 |
82 | if verbose:
83 | print(data)
84 |
85 | if save:
86 |
87 | save_path = os.path.join(root_dir, save_name)
88 | data.to_pickle(save_path)
89 |
90 | else:
91 |
92 | return data
93 |
94 |
95 | def parse_multiple_files(all_paths, rx_dict, root_dir=save_root_dir, save_name='test.pkl', verbose=True, save=False):
96 |
97 | all_data = []
98 | for path in all_paths:
99 |
100 | data = parse_out_file(path, rx_dict, save=False, verbose=False)
101 | all_data.append(data)
102 |
103 | all_data = pd.concat(all_data)
104 | save_path = os.path.join(root_dir, save_name)
105 |
106 | if save:
107 | all_data.to_pickle(save_path)
108 |
109 | if verbose:
110 | print(all_data)
111 |
112 | return all_data
113 |
114 |
115 | # Guarded so that importing this module does not trigger a parse of hard-coded paths
116 | if __name__ == '__main__':
117 |     base_path = '/work/sagar/osr_novel_categories/slurm_outputs/myLog-{}.out'
118 |     all_paths = [base_path.format(i) for i in ['407814_{}'.format(j) for j in range(11)]]
119 |
120 |     data = parse_multiple_files(all_paths, rx_dict, verbose=True, save=False, save_name='test.pkl')
--------------------------------------------------------------------------------
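
A self-contained toy of how the `rx_dict` patterns pull fields out of an argparse `Namespace` line in a log. The two patterns are copied from `rx_dict`; the log line itself is made up:

```python
import re

rx = {
    'dataset': re.compile(r'dataset_name=\'(.*?)\''),
    'lr': re.compile(r"lr=([-+]?\d*\.\d+|\d+)"),
}

line = "Namespace(dataset_name='cub', lr=0.1, weight_decay=5e-05)"
assert rx['dataset'].search(line).group(1) == 'cub'
assert rx['lr'].search(line).group(1) == '0.1'
```
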
/project_utils/cluster_and_log_utils.py:
--------------------------------------------------------------------------------
1 | from project_utils.cluster_utils import cluster_acc, np, linear_assignment
2 | from torch.utils.tensorboard import SummaryWriter
3 | from typing import List
4 |
5 |
6 | def split_cluster_acc_v1(y_true, y_pred, mask):
7 |
8 | """
9 | Evaluate clustering metrics on two subsets of the data, as defined by 'mask'
10 | (the mask usually corresponds to 'Old' and 'New' classes in the GCD setting)
11 | :param y_true: All ground truth labels
12 | :param y_pred: All predictions
13 | :param mask: Mask defining two subsets
14 | :return:
15 | """
16 |
17 | mask = mask.astype(bool)
18 | y_true = y_true.astype(int)
19 | y_pred = y_pred.astype(int)
20 | weight = mask.mean()
21 |
22 | old_acc = cluster_acc(y_true[mask], y_pred[mask])
23 | new_acc = cluster_acc(y_true[~mask], y_pred[~mask])
24 | total_acc = weight * old_acc + (1 - weight) * new_acc
25 |
26 | return total_acc, old_acc, new_acc
27 |
28 |
29 | def split_cluster_acc_v2(y_true, y_pred, mask):
30 | """
31 | Calculate clustering accuracy. Require scikit-learn installed
32 | First compute linear assignment on all data, then look at how good the accuracy is on subsets
33 |
34 | # Arguments
35 | mask: Which instances come from old classes (True) and which ones come from new classes (False)
36 | y: true labels, numpy.array with shape `(n_samples,)`
37 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
38 |
39 | # Return
40 | accuracy, in [0,1]
41 | """
42 | y_true = y_true.astype(int)
43 |
44 | old_classes_gt = set(y_true[mask])
45 | new_classes_gt = set(y_true[~mask])
46 |
47 | assert y_pred.size == y_true.size
48 | D = max(y_pred.max(), y_true.max()) + 1
49 | w = np.zeros((D, D), dtype=int)
50 | for i in range(y_pred.size):
51 | w[y_pred[i], y_true[i]] += 1
52 |
53 | ind = linear_assignment(w.max() - w)
54 | ind = np.vstack(ind).T
55 |
56 | ind_map = {j: i for i, j in ind}
57 | total_acc = sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
58 |
59 | old_acc = 0
60 | total_old_instances = 0
61 | for i in old_classes_gt:
62 | old_acc += w[ind_map[i], i]
63 | total_old_instances += sum(w[:, i])
64 | old_acc /= total_old_instances
65 |
66 | new_acc = 0
67 | total_new_instances = 0
68 | for i in new_classes_gt:
69 | new_acc += w[ind_map[i], i]
70 | total_new_instances += sum(w[:, i])
71 | new_acc /= total_new_instances
72 |
73 | return total_acc, old_acc, new_acc
74 |
75 |
76 | EVAL_FUNCS = {
77 | 'v1': split_cluster_acc_v1,
78 | 'v2': split_cluster_acc_v2,
79 | }
80 |
81 | def log_accs_from_preds(y_true, y_pred, mask, eval_funcs: List[str], save_name: str, T: int=None, writer: SummaryWriter=None,
82 | print_output=False):
83 |
84 | """
85 | Given a list of evaluation functions to use (e.g. ['v1', 'v2']), evaluate and log ACC results
86 |
87 | :param y_true: GT labels
88 | :param y_pred: Predicted indices
89 | :param mask: Which instances belong to Old and New classes
90 | :param T: Epoch
91 | :param eval_funcs: Which evaluation functions to use
92 | :param save_name: What are we evaluating ACC on
93 | :param writer: Tensorboard logger
94 | :return:
95 | """
96 |
97 | mask = mask.astype(bool)
98 | y_true = y_true.astype(int)
99 | y_pred = y_pred.astype(int)
100 |
101 | for i, f_name in enumerate(eval_funcs):
102 |
103 | acc_f = EVAL_FUNCS[f_name]
104 | all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask)
105 | log_name = f'{save_name}_{f_name}'
106 |
107 | if writer is not None:
108 | writer.add_scalars(log_name,
109 | {'Old': old_acc, 'New': new_acc,
110 | 'All': all_acc}, T)
111 |
112 | if i == 0:
113 | to_return = (all_acc, old_acc, new_acc)
114 |
115 | if print_output:
116 | print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}'
117 | print(print_str)
118 |
119 | return to_return
--------------------------------------------------------------------------------
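
A toy sanity check (not from the repo): when the predicted clustering matches the ground truth up to a relabelling, the Hungarian matching in `split_cluster_acc_v2` recovers 100% accuracy on both subsets. This assumes `linear_assignment` in `project_utils/cluster_utils.py` wraps scipy's `linear_sum_assignment`, as the `np.vstack(ind).T` usage above suggests:

```python
import numpy as np

from project_utils.cluster_and_log_utils import split_cluster_acc_v2

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([2, 2, 0, 0, 1, 1])                    # same partition, labels permuted
mask = np.array([True, True, True, True, False, False])  # class 2 plays the "New" role

all_acc, old_acc, new_acc = split_cluster_acc_v2(y_true, y_pred, mask)
assert all_acc == old_acc == new_acc == 1.0
```
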
/project_utils/schedulers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 | def get_scheduler(optimizer, args):
6 |
7 | if args.scheduler == 'step':
8 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
9 | gamma=0.1,
10 | step_size=150)
11 |
12 | elif args.scheduler == 'plateau':
13 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=50)
14 |
15 | elif args.scheduler == 'cosine':
16 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
17 | T_max=args.epochs, eta_min=args.lr * 1e-3)
18 |
19 | elif args.scheduler == 'cosine_warm_restarts':
20 |
21 | try: num_restarts = args.num_restarts
22 | except AttributeError: print('Warning: Num restarts not specified...using 2'); num_restarts = 2
23 |
24 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
25 | T_0=int(args.epochs / (num_restarts + 1)),
26 | eta_min=args.lr * 1e-3)
27 |
28 | elif args.scheduler == 'cosine_warm_restarts_warmup':
29 |
30 | try: num_restarts = args.num_restarts
31 | except AttributeError: print('Warning: Num restarts not specified...using 2'); num_restarts = 2
32 |
33 | scheduler = CosineAnnealingWarmupRestarts_New(warmup_epochs=10, optimizer=optimizer,
34 | T_0=int(args.epochs / (num_restarts + 1)),
35 | eta_min=args.lr * 1e-3)
36 |
37 | elif args.scheduler == 'warm_restarts_plateau':
38 | scheduler = WarmRestartPlateau(T_restart=120, optimizer=optimizer, threshold_mode='abs', threshold=0.5,
39 | mode='min', patience=100)
40 |
41 | elif args.scheduler == 'multi_step':
42 |
43 | try:
44 |
45 | steps = args.steps
46 |
47 | except AttributeError:
48 |
49 | print('Warning: No step list for Multi-Step Scheduler, using constant step of 30 epochs')
50 | steps = [30 * i for i in range(1, 5)]
51 |
52 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps)
53 |
54 | else:
55 |
56 | raise NotImplementedError
57 |
58 | return scheduler
59 |
60 |
61 | class WarmRestartPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
62 |
63 | """
64 | Reduce learning rate on plateau and reset every T_restart epochs
65 | """
66 |
67 | def __init__(self, T_restart, *args, ** kwargs):
68 |
69 | super().__init__(*args, **kwargs)
70 |
71 | self.T_restart = T_restart
72 | self.base_lrs = [group['lr'] for group in self.optimizer.param_groups]
73 |
74 | def step(self, *args, **kwargs):
75 |
76 | super().step(*args, **kwargs)
77 |
78 | if self.last_epoch > 0 and self.last_epoch % self.T_restart == 0:
79 |
80 | for group, lr in zip(self.optimizer.param_groups, self.base_lrs):
81 | group['lr'] = lr
82 |
83 | self._reset()
84 |
85 |
86 | class CosineAnnealingWarmupRestarts_New(torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):
87 |
88 | def __init__(self, warmup_epochs, *args, **kwargs):
89 |
90 | super(CosineAnnealingWarmupRestarts_New, self).__init__(*args, **kwargs)
91 |
92 | # Init optimizer with low learning rate
93 | for param_group in self.optimizer.param_groups:
94 | param_group['lr'] = self.eta_min
95 |
96 | self.warmup_epochs = warmup_epochs
97 |
98 | # Get target LR after warmup is complete
99 | target_lr = self.eta_min + (self.base_lrs[0] - self.eta_min) * (1 + math.cos(math.pi * warmup_epochs / self.T_i)) / 2
100 |
101 | # Linearly interpolate between minimum lr and target_lr
102 | linear_step = (target_lr - self.eta_min) / self.warmup_epochs
103 | self.warmup_lrs = [self.eta_min + linear_step * (n + 1) for n in range(warmup_epochs)]
104 |
105 | def step(self, epoch=None):
106 |
107 | # Called on super class init
108 | if epoch is None:
109 | super(CosineAnnealingWarmupRestarts_New, self).step(epoch=epoch)
110 |
111 | else:
112 | if epoch < self.warmup_epochs:
113 | lr = self.warmup_lrs[epoch]
114 | for param_group in self.optimizer.param_groups:
115 | param_group['lr'] = lr
116 |
117 | # Fulfill misc super() funcs
118 | self.last_epoch = math.floor(epoch)
119 | self.T_cur = epoch
120 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
121 |
122 | else:
123 |
124 | super(CosineAnnealingWarmupRestarts_New, self).step(epoch=epoch)
--------------------------------------------------------------------------------
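
A minimal sketch of driving `get_scheduler`; `SimpleNamespace` stands in for the argparse args, and only the fields the 'cosine' branch actually reads are set:

```python
from types import SimpleNamespace

import torch

from project_utils.schedulers import get_scheduler

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
args = SimpleNamespace(scheduler='cosine', epochs=200, lr=0.1)
scheduler = get_scheduler(optimizer, args)

for epoch in range(args.epochs):
    # ... one epoch of training, including optimizer.step() ...
    scheduler.step()  # LR anneals from args.lr down to args.lr * 1e-3
```
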
/data/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from data.augmentations.cut_out import *
3 | from data.augmentations.randaugment import RandAugment
4 |
5 | def get_transform(transform_type='default', image_size=32, args=None):
6 |
7 | if transform_type == 'imagenet':
8 |
9 | mean = (0.485, 0.456, 0.406)
10 | std = (0.229, 0.224, 0.225)
11 | interpolation = args.interpolation
12 | crop_pct = args.crop_pct
13 |
14 | train_transform = transforms.Compose([
15 | transforms.Resize(int(image_size / crop_pct), interpolation),
16 | transforms.RandomCrop(image_size),
17 | transforms.RandomHorizontalFlip(p=0.5),
18 | transforms.ColorJitter(),
19 | transforms.ToTensor(),
20 | transforms.Normalize(
21 | mean=torch.tensor(mean),
22 | std=torch.tensor(std))
23 | ])
24 |
25 | test_transform = transforms.Compose([
26 | transforms.Resize(int(image_size / crop_pct), interpolation),
27 | transforms.CenterCrop(image_size),
28 | transforms.ToTensor(),
29 | transforms.Normalize(
30 | mean=torch.tensor(mean),
31 | std=torch.tensor(std))
32 | ])
33 |
34 | elif transform_type == 'pytorch-cifar':
35 |
36 | mean = (0.4914, 0.4822, 0.4465)
37 | std = (0.2023, 0.1994, 0.2010)
38 |
39 | train_transform = transforms.Compose([
40 | transforms.RandomCrop(image_size, padding=4),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=mean, std=std),
44 | ])
45 |
46 | test_transform = transforms.Compose([
47 | transforms.Resize((image_size, image_size)),
48 | transforms.ToTensor(),
49 | transforms.Normalize(mean=mean, std=std),
50 | ])
51 |
52 | elif transform_type == 'herbarium_default':
53 |
54 | train_transform = transforms.Compose([
55 | transforms.Resize((image_size, image_size)),
56 | transforms.RandomResizedCrop(image_size, scale=(args.resize_lower_bound, 1)),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | ])
60 |
61 | test_transform = transforms.Compose([
62 | transforms.Resize((image_size, image_size)),
63 | transforms.ToTensor(),
64 | ])
65 |
66 | elif transform_type == 'cutout':
67 |
68 | mean = np.array([0.4914, 0.4822, 0.4465])
69 | std = np.array([0.2470, 0.2435, 0.2616])
70 |
71 | train_transform = transforms.Compose([
72 | transforms.RandomCrop(image_size, padding=4),
73 | transforms.RandomHorizontalFlip(),
74 | normalize(mean, std),
75 | cutout(mask_size=int(image_size / 2),
76 | p=1,
77 | cutout_inside=False),
78 | to_tensor(),
79 | ])
80 | test_transform = transforms.Compose([
81 | transforms.Resize((image_size, image_size)),
82 | transforms.ToTensor(),
83 | transforms.Normalize(mean, std),
84 | ])
85 |
86 | elif transform_type == 'rand-augment':
87 |
88 | mean = (0.485, 0.456, 0.406)
89 | std = (0.229, 0.224, 0.225)
90 |
91 | train_transform = transforms.Compose([
92 | transforms.Resize((image_size, image_size)),
93 | transforms.RandomCrop(image_size, padding=4),
94 | transforms.RandomHorizontalFlip(),
95 | transforms.ToTensor(),
96 | transforms.Normalize(mean=mean, std=std),
97 | ])
98 |
99 | train_transform.transforms.insert(0, RandAugment(args.rand_aug_n, args.rand_aug_m, args=None))
100 |
101 | test_transform = transforms.Compose([
102 | transforms.Resize((image_size, image_size)),
103 | transforms.ToTensor(),
104 | transforms.Normalize(mean=mean, std=std),
105 | ])
106 |
107 | elif transform_type == 'random_affine':
108 |
109 | mean = (0.485, 0.456, 0.406)
110 | std = (0.229, 0.224, 0.225)
111 | interpolation = args.interpolation
112 | crop_pct = args.crop_pct
113 |
114 | train_transform = transforms.Compose([
115 | transforms.Resize((image_size, image_size), interpolation),
116 | transforms.RandomAffine(degrees=(-45, 45),
117 | translate=(0.1, 0.1), shear=(-15, 15), scale=(0.7, args.crop_pct)),
118 | transforms.ColorJitter(),
119 | transforms.ToTensor(),
120 | transforms.Normalize(
121 | mean=torch.tensor(mean),
122 | std=torch.tensor(std))
123 | ])
124 |
125 | test_transform = transforms.Compose([
126 | transforms.Resize(int(image_size / crop_pct), interpolation),
127 | transforms.CenterCrop(image_size),
128 | transforms.ToTensor(),
129 | transforms.Normalize(
130 | mean=torch.tensor(mean),
131 | std=torch.tensor(std))
132 | ])
133 |
134 | else:
135 |
136 | raise NotImplementedError
137 |
138 | return (train_transform, test_transform)
--------------------------------------------------------------------------------
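
A minimal sketch of calling `get_transform`: the 'imagenet' branch reads `args.interpolation` and `args.crop_pct`, the same values `methods/partitioning/subset_len.py` sets before calling it:

```python
from types import SimpleNamespace

from data.augmentations import get_transform

args = SimpleNamespace(interpolation=3, crop_pct=0.875)
train_transform, test_transform = get_transform('imagenet', image_size=224, args=args)
# Both are torchvision Compose pipelines that expect PIL images.
```
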
/README.md:
--------------------------------------------------------------------------------
1 | # XCon: Learning with Experts for Fine-grained Category Discovery
2 | This repo contains the implementation of our paper: "XCon: Learning with Experts for Fine-grained Category Discovery". ([arXiv](https://arxiv.org/abs/2208.01898))
3 | ## Abstract
4 | We address the problem of generalized category discovery (GCD), i.e.
5 | clustering unlabeled images by leveraging the information from a set of
6 | seen classes, where the unlabeled images may contain both seen and
7 | unseen classes. The seen classes serve as an implicit criterion for what
8 | constitutes a class, which makes this setting different from
9 | unsupervised clustering, where the clustering criteria may be ambiguous.
10 | We focus on discovering categories within fine-grained datasets, since
11 | this is one of the most direct applications of category discovery, i.e.
12 | helping experts discover novel concepts in unlabeled data using the
13 | implicit criterion set forth by the seen classes. State-of-the-art
14 | methods for generalized category discovery leverage contrastive learning
15 | to learn the representations, but the high inter-class similarity and
16 | intra-class variance pose a challenge because the negative examples may
17 | contain irrelevant cues for recognizing a category, so the algorithms
18 | may converge to a local minimum. We present a novel method called
19 | Expert-Contrastive Learning (XCon) that helps the model mine useful
20 | information from the images by first partitioning the dataset into
21 | sub-datasets using k-means clustering and then performing contrastive
22 | learning on each sub-dataset to learn fine-grained discriminative
23 | features. Experiments on fine-grained datasets show clearly improved
24 | performance over the previous best methods, indicating our method's effectiveness.
25 |
26 | 
27 |
28 | ## Requirements
29 | - Python 3.8
30 | - PyTorch 1.10.0
31 | - torchvision 0.11.1
32 | ```
33 | pip install -r requirements.txt
34 | ```
35 |
36 | ## Datasets
37 | In our experiments, we use generic image classification datasets including [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html) and [ImageNet](https://image-net.org/download.php).
38 |
39 | We also use fine-grained image classification datasets including [CUB-200](https://www.kaggle.com/datasets/coolerextreme/cub-200-2011/versions/1), [Stanford-Cars](http://ai.stanford.edu/~jkrause/cars/car_dataset.html), [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) and [Oxford-Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/).
40 |
41 | ## Pretrained Checkpoints
42 | Our model is initialized with the parameters pretrained by DINO on ImageNet.
43 | The DINO checkpoint of ViT-B-16 is available [here](https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth).
44 |
45 | ## Training and Evaluation Instructions
46 | ### Step 1. Set config
47 | Set the dataset paths and the directory for saving outputs in `config.py`.
48 | ### Step 2. Dataset partitioning
49 | - Get the k-means labels for partitioning the dataset.
50 | ```
51 | bash bash_scripts/get_kmeans_subset.sh
52 | ```
53 | - Get the lengths of the k expert sub-datasets.
54 | ```
55 | bash bash_scripts/get_subset_len.sh
56 | ```
57 | ### Step 3. Representation learning
58 | Fine-tune the model, evaluating with semi-supervised k-means.
59 | ```
60 | bash bash_scripts/representation_learning.sh
61 | ```
62 | ### Step 4. Semi-supervised k-means
63 | To run semi-supervised k-means alone, first run
64 | ```
65 | bash bash_scripts/extract_features.sh
66 | ```
67 | and then run
68 | ```
69 | bash bash_scripts/ssk_means.sh
70 | ```
71 | ### Step 5. Estimate the number of classes
72 | To estimate the number of classes in the unlabeled dataset, first run
73 | ```
74 | bash bash_scripts/extract_features.sh
75 | ```
76 | and then run
77 | ```
78 | bash bash_scripts/estimate_k.sh
79 | ```
80 |
81 | ## Results
82 | Results of our method are reported below. You can download our model checkpoints via the links in the table.
83 | | **Datasets** | **All** | **Old** | **New** | **Models** |
84 | |:------------|:--------:|:---------:|:---------:|:------:|
85 | | CIFAR10 | 96.0 | 97.3 | 95.4 | [ckpt](https://pan.baidu.com/s/1XKHioJp002Lm7P1xmM5Htg?pwd=xhwq) |
86 | | CIFAR100 | 74.2 | 81.2 | 60.3 | [ckpt](https://pan.baidu.com/s/1DbUpDpFj-dlO58w6GqhyKw?pwd=rvkd) |
87 | | ImageNet-100 | 77.6 | 93.5 | 69.7 | [ckpt](https://pan.baidu.com/s/1G1mY85up1ji2LLxMNBJjrw?pwd=rc7o) |
88 | | CUB-200 | 52.1 | 54.3 | 51.0 | [ckpt](https://pan.baidu.com/s/1gtuPMF-itQvt9r5kW7Y32Q?pwd=pg9m) |
89 | | Stanford-Cars | 40.5 | 58.8 | 31.7 | [ckpt](https://pan.baidu.com/s/1PDVhatM6qVUZZwjBVwSgTg?pwd=6337) |
90 | | FGVC-Aircraft | 47.7 | 44.4 | 49.4 | [ckpt](https://pan.baidu.com/s/1SwkobAaT8l-TTlYn7IXWyQ?pwd=06u1) |
91 | | Oxford-Pet | 86.7 | 91.5 | 84.1 | [ckpt](https://pan.baidu.com/s/1kCUfebbKmws9EgYrvgF5Aw?pwd=ck3k) |
92 |
93 | ## Citation
94 | If you find this repo useful for your research, please consider citing our paper:
95 | ```
96 | @inproceedings{fei2022xcon,
97 | title = {XCon: Learning with Experts for Fine-grained Category Discovery},
98 | author = {Yixin Fei and Zhongkai Zhao and Siwei Yang and Bingchen Zhao},
99 | booktitle={British Machine Vision Conference (BMVC)},
100 | year = {2022}
101 | }
102 | ```
--------------------------------------------------------------------------------
/methods/partitioning/subset_len.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from easydict import EasyDict
4 |
5 | from torch.utils.data import DataLoader
6 | import numpy as np
7 | from sklearn.cluster import KMeans
8 | import torch
9 | from torch.optim import SGD, lr_scheduler
10 | from project_utils.cluster_utils import mixed_eval, AverageMeter
11 | from model import vision_transformer as vits
12 |
13 | from project_utils.general_utils import init_experiment, get_mean_lr, str2bool, get_dino_head_weights
14 |
15 | from data.augmentations import get_transform
16 | from data.get_datasets import get_datasets, get_class_splits
17 | from data.data_utils import MergedDataset
18 |
19 | from tqdm import tqdm
20 |
21 | from torch.nn import functional as F
22 |
23 | from project_utils.cluster_and_log_utils import log_accs_from_preds
24 | from config import exp_root, km_label_path, subset_len_path, dino_pretrain_path
25 |
26 | from copy import deepcopy
27 | # Suppress noisy deprecation warnings
28 | import warnings
29 | warnings.filterwarnings("ignore", category=DeprecationWarning)
30 |
31 | class ContrastiveLearningViewGenerator(object):
32 | """Take two random crops of one image as the query and key."""
33 |
34 | def __init__(self, base_transform, n_views=2):
35 | self.base_transform = base_transform
36 | self.n_views = n_views
37 |
38 | def __call__(self, x):
39 | return [self.base_transform(x) for _ in range(self.n_views)]
40 |
41 |
42 | if __name__ == "__main__":
43 |
44 | parser = argparse.ArgumentParser(
45 | description='cluster',
46 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
47 | parser.add_argument('--batch_size', default=256, type=int)
48 | parser.add_argument('--num_workers', default=4, type=int)
49 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v1', 'v2'])
50 |
51 | parser.add_argument('--exp_root', type=str, default=exp_root)
52 | parser.add_argument('--transform', type=str, default='imagenet')
53 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, scars')
54 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
55 | parser.add_argument('--use_ssb_splits', type=str2bool, default=False)
56 | parser.add_argument('--n_views', default=2, type=int)
57 |
58 | parser.add_argument('--experts_num', type=int, default=8)
59 |
60 | args = parser.parse_args()
61 | device = torch.device('cuda:0')
62 | args = get_class_splits(args)
63 |
64 | args.num_labeled_classes = len(args.train_classes)
65 | args.num_unlabeled_classes = len(args.unlabeled_classes)
66 |
67 | init_experiment(args, runner_name=['metric_learn_gcd'])
68 | print(f'Using evaluation function {args.eval_funcs[0]} to print results')
69 |
70 | whole_km_labels = np.load(f'{km_label_path}/{args.dataset_name}_km_labels.npy')
71 | experts_num = np.unique(whole_km_labels)  # unique k-means subset labels
72 | print('subset labels:', experts_num)
73 | if len(experts_num) != args.experts_num:
74 |     raise ValueError(f'Expected {args.experts_num} k-means subsets, found {len(experts_num)}')
75 |
76 | args.interpolation = 3
77 | args.crop_pct = 0.875
78 | args.image_size = 224
79 | args.feat_dim = 768
80 | args.num_mlp_layers = 3
81 | args.mlp_out_dim = 65536
82 |
83 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
84 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
85 | # --------------------
86 | # DATASETS
87 | # --------------------
88 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name,
89 | train_transform,
90 | test_transform,
91 | args)
92 | print('whole dataset:', len(train_dataset))
93 | print('labelled:', len(labelled_train_examples))
94 | print('unlabelled:', len(unlabelled_train_examples_test))
95 |
96 | # ----------------------
97 | # Building sub dataset
98 | # -------------------
99 | sub_data_loaders = []
100 | sub_num = []
101 | for subset_index in experts_num:
102 | # Get subset index
103 | subset_idx = []
104 | for idx, km_label in enumerate(whole_km_labels):
105 | if km_label == subset_index:
106 | subset_idx.append(idx)
107 |
108 | # Get the subset
109 | subset = torch.utils.data.Subset(deepcopy(train_dataset), subset_idx)
110 |
111 | # Get the numbers of each subset
112 | subdata_num = np.sum(whole_km_labels == subset_index)
113 | sub_num.append(subdata_num)
114 |
115 | # Data loader for subset
116 | sub_data_loader = DataLoader(subset, num_workers=args.num_workers,
117 | batch_size=args.batch_size, shuffle=False)
118 | sub_data_loaders.append(sub_data_loader)
119 |
120 | label_nums = []
121 | unlabel_nums = []
122 | for loader in sub_data_loaders:
123 | label_num = 0
124 | unlabel_num = 0
125 | for batch in loader:
126 | images, class_labels, uq_idxs, mask_lab = batch
127 | mask_lab = mask_lab[:, 0].numpy()
128 | label_num += np.sum(mask_lab == 1)
129 | unlabel_num += np.sum(mask_lab == 0)
130 | label_nums.append(label_num)
131 | unlabel_nums.append(unlabel_num)
132 |
133 | print('labeled subset len:', label_nums)
134 | print('unlabeled subset len:', unlabel_nums)
135 | subset_nums = label_nums + unlabel_nums  # list concatenation: labelled counts first, then unlabelled counts
136 |
137 | with open(f'{subset_len_path}/{args.dataset_name}_subset_len.txt', 'w') as f:
138 | for num in subset_nums:
139 | f.write(str(num)+'\n')
140 |
--------------------------------------------------------------------------------
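
A hedged reader for the file written above (`read_subset_len` is illustrative, not a repo function). Because `subset_nums` is the concatenation of the labelled and unlabelled counts, the first `experts_num` lines are labelled counts and the remainder are unlabelled counts:

```python
def read_subset_len(path, experts_num=8):
    """Read <dataset>_subset_len.txt back into (label_nums, unlabel_nums)."""
    with open(path) as f:
        nums = [int(line) for line in f if line.strip()]
    return nums[:experts_num], nums[experts_num:]

# e.g. label_nums, unlabel_nums = read_subset_len('./partition_out/subset_len/cub_subset_len.txt')
```
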
/data/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 | from scipy import io as mat_io
6 |
7 | from torchvision.datasets.folder import default_loader
8 | from torch.utils.data import Dataset
9 |
10 | from data.data_utils import subsample_instances
11 |
12 | car_root = "/data/user-data/fgvc/cars/cars_{}/"
13 | meta_default_path = "/data/user-data/fgvc/cars/devkit/cars_{}.mat"
14 |
15 | class CarsDataset(Dataset):
16 | """
17 | Cars Dataset
18 | """
19 | def __init__(self, train=True, limit=0, data_dir=car_root, transform=None, metas=meta_default_path):
20 |
21 | data_dir = data_dir.format('train') if train else data_dir.format('test')
22 | metas = metas.format('train_annos') if train else metas.format('test_annos_withlabels')
23 |
24 | self.loader = default_loader
25 | self.data_dir = data_dir
26 | self.data = []
27 | self.target = []
28 | self.train = train
29 |
30 | self.transform = transform
31 |
32 | if not isinstance(metas, str):
33 | raise Exception("Train metas must be a string location!")
34 | labels_meta = mat_io.loadmat(metas)
35 |
36 | for idx, img_ in enumerate(labels_meta['annotations'][0]):
37 | if limit:
38 | if idx > limit:
39 | break
40 |
41 | # self.data.append(img_resized)
42 | self.data.append(data_dir + img_[5][0])
43 | # if self.mode == 'train':
44 | self.target.append(img_[4][0][0])
45 |
46 | self.uq_idxs = np.array(range(len(self)))
47 | self.target_transform = None
48 |
49 | def __getitem__(self, idx):
50 |
51 | image = self.loader(self.data[idx])
52 | target = self.target[idx] - 1
53 |
54 | if self.transform is not None:
55 | image = self.transform(image)
56 |
57 | if self.target_transform is not None:
58 | target = self.target_transform(target)
59 |
60 | idx = self.uq_idxs[idx]
61 |
62 | return image, target, idx
63 |
64 | def __len__(self):
65 | return len(self.data)
66 |
67 |
68 | def subsample_dataset(dataset, idxs):
69 |
70 | dataset.data = np.array(dataset.data)[idxs].tolist()
71 | dataset.target = np.array(dataset.target)[idxs].tolist()
72 | dataset.uq_idxs = dataset.uq_idxs[idxs]
73 |
74 | return dataset
75 |
76 |
77 | def subsample_classes(dataset, include_classes=range(160)):
78 |
79 | include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195
80 | cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars]
81 |
82 | target_xform_dict = {}
83 | for i, k in enumerate(include_classes):
84 | target_xform_dict[k] = i
85 |
86 | dataset = subsample_dataset(dataset, cls_idxs)
87 |
88 | # dataset.target_transform = lambda x: target_xform_dict[x]
89 |
90 | return dataset
91 |
92 | def get_train_val_indices(train_dataset, val_split=0.2):
93 |
94 | train_classes = np.unique(train_dataset.target)
95 |
96 | # Get train/test indices
97 | train_idxs = []
98 | val_idxs = []
99 | for cls in train_classes:
100 |
101 | cls_idxs = np.where(np.array(train_dataset.target) == cls)[0]  # cast list to array for elementwise comparison
102 |
103 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
104 | t_ = [x for x in cls_idxs if x not in v_]
105 |
106 | train_idxs.extend(t_)
107 | val_idxs.extend(v_)
108 |
109 | return train_idxs, val_idxs
110 |
111 |
112 | def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
113 | split_train_val=False, seed=0):
114 |
115 | np.random.seed(seed)
116 |
117 | # Init entire training set
118 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, metas=meta_default_path, train=True)
119 |
120 | # Get labelled training set which has subsampled classes, then subsample some indices from that
121 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
122 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
123 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
124 |
125 | # Split into training and validation sets
126 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
127 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
128 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
129 | val_dataset_labelled_split.transform = test_transform
130 |
131 | # Get unlabelled data
132 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
133 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
134 |
135 | # Get test set for all classes
136 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, metas=meta_default_path, train=False)
137 |
138 | # Either split train into train and val or use test set as val
139 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
140 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
141 |
142 | all_datasets = {
143 | 'train_labelled': train_dataset_labelled,
144 | 'train_unlabelled': train_dataset_unlabelled,
145 | 'val': val_dataset_labelled,
146 | 'test': test_dataset,
147 | }
148 |
149 | return all_datasets
150 |
151 | if __name__ == '__main__':
152 |
153 | x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False)
154 |
155 | print('Printing lens...')
156 | for k, v in x.items():
157 | if v is not None:
158 | print(f'{k}: {len(v)}')
159 |
160 | print('Printing labelled and unlabelled overlap...')
161 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
162 | print('Printing total instances in train...')
163 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
164 |
165 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}')
166 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].target))}')
167 | print(f'Len labelled set: {len(x["train_labelled"])}')
168 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import numpy as np
3 |
4 | import os
5 |
6 | from copy import deepcopy
7 | from data.data_utils import subsample_instances
8 | from config import imagenet_root
9 |
10 |
11 | class ImageNetBase(torchvision.datasets.ImageFolder):
12 |
13 | def __init__(self, root, transform):
14 |
15 | super(ImageNetBase, self).__init__(root, transform)
16 |
17 | self.uq_idxs = np.array(range(len(self)))
18 |
19 | def __getitem__(self, item):
20 |
21 | img, label = super().__getitem__(item)
22 | uq_idx = self.uq_idxs[item]
23 |
24 | return img, label, uq_idx
25 |
26 |
27 | def subsample_dataset(dataset, idxs):
28 |
29 | imgs_ = []
30 | for i in idxs:
31 | imgs_.append(dataset.imgs[i])
32 | dataset.imgs = imgs_
33 |
34 | samples_ = []
35 | for i in idxs:
36 | samples_.append(dataset.samples[i])
37 | dataset.samples = samples_
38 |
39 | # dataset.imgs = [x for i, x in enumerate(dataset.imgs) if i in idxs]
40 | # dataset.samples = [x for i, x in enumerate(dataset.samples) if i in idxs]
41 |
42 | dataset.targets = np.array(dataset.targets)[idxs].tolist()
43 | dataset.uq_idxs = dataset.uq_idxs[idxs]
44 |
45 | return dataset
46 |
47 |
48 | def subsample_classes(dataset, include_classes=list(range(1000))):
49 |
50 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
51 |
52 | target_xform_dict = {}
53 | for i, k in enumerate(include_classes):
54 | target_xform_dict[k] = i
55 |
56 | dataset = subsample_dataset(dataset, cls_idxs)
57 | dataset.target_transform = lambda x: target_xform_dict[x]
58 |
59 | return dataset
60 |
61 |
62 | def get_train_val_indices(train_dataset, val_split=0.2):
63 |
64 | train_classes = list(set(train_dataset.targets))
65 |
66 |     # Get train/val indices
67 | train_idxs = []
68 | val_idxs = []
69 | for cls in train_classes:
70 |
71 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
72 |
73 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
74 | t_ = [x for x in cls_idxs if x not in v_]
75 |
76 | train_idxs.extend(t_)
77 | val_idxs.extend(v_)
78 |
79 | return train_idxs, val_idxs
80 |
81 |
82 | def get_equal_len_datasets(dataset1, dataset2):
83 | """
84 | Make two datasets the same length
85 | """
86 |
87 | if len(dataset1) > len(dataset2):
88 |
89 |         rand_idxs = np.random.choice(range(len(dataset1)), size=(len(dataset2),), replace=False)
90 | subsample_dataset(dataset1, rand_idxs)
91 |
92 | elif len(dataset2) > len(dataset1):
93 |
94 |         rand_idxs = np.random.choice(range(len(dataset2)), size=(len(dataset1),), replace=False)
95 | subsample_dataset(dataset2, rand_idxs)
96 |
97 | return dataset1, dataset2
98 |
99 |
100 | def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80),
101 | prop_train_labels=0.8, split_train_val=False, seed=0):
102 |
103 | np.random.seed(seed)
104 |
105 | # Subsample imagenet dataset initially to include 100 classes
106 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False)
107 | subsampled_100_classes = np.sort(subsampled_100_classes)
108 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}')
109 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))}
110 |
111 | # Init entire training set
112 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
113 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes)
114 |
115 | # Reset dataset
116 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples]
117 | whole_training_set.targets = [s[1] for s in whole_training_set.samples]
118 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set)))
119 | whole_training_set.target_transform = None
120 |
121 | # Get labelled training set which has subsampled classes, then subsample some indices from that
122 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
123 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
124 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
125 |
126 | # Split into training and validation sets
127 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
128 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
129 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
130 | val_dataset_labelled_split.transform = test_transform
131 |
132 | # Get unlabelled data
133 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
134 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
135 |
136 | # Get test set for all classes
137 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
138 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes)
139 |
140 | # Reset test set
141 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples]
142 | test_dataset.targets = [s[1] for s in test_dataset.samples]
143 | test_dataset.uq_idxs = np.array(range(len(test_dataset)))
144 | test_dataset.target_transform = None
145 |
146 | # Either split train into train and val or use test set as val
147 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
148 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
149 |
150 | all_datasets = {
151 | 'train_labelled': train_dataset_labelled,
152 | 'train_unlabelled': train_dataset_unlabelled,
153 | 'val': val_dataset_labelled,
154 | 'test': test_dataset,
155 | }
156 |
157 | return all_datasets
158 |
159 |
160 | if __name__ == '__main__':
161 |
162 | x = get_imagenet_100_datasets(None, None, split_train_val=False,
163 | train_classes=range(50), prop_train_labels=0.5)
164 |
165 | print('Printing lens...')
166 | for k, v in x.items():
167 | if v is not None:
168 | print(f'{k}: {len(v)}')
169 |
170 | print('Printing labelled and unlabelled overlap...')
171 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
172 | print('Printing total instances in train...')
173 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
174 |
175 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
176 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
177 | print(f'Len labelled set: {len(x["train_labelled"])}')
178 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
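The ImageNet-100 construction above remaps the 100 sampled ImageNet class ids onto a contiguous 0-99 range via cls_map. A self-contained sketch of just that remapping step (file names are invented for illustration):

import numpy as np

np.random.seed(0)
sampled = np.sort(np.random.choice(range(1000), size=(100,), replace=False))
cls_map = {orig: new for new, orig in enumerate(sampled)}

# Any (path, original_label) sample list can then be relabelled to 0..99:
samples = [('img_a.jpg', sampled[3]), ('img_b.jpg', sampled[97])]
remapped = [(path, cls_map[t]) for path, t in samples]
assert [t for _, t in remapped] == [3, 97]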
/project_utils/cluster_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 | import numpy as np
3 | import sklearn.metrics
4 | import torch
5 | import torch.nn as nn
6 | import matplotlib
7 | matplotlib.use('agg')
8 | from scipy.optimize import linear_sum_assignment as linear_assignment
9 | import random
10 | import os
11 | import argparse
12 |
13 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
14 | from sklearn.metrics import adjusted_rand_score as ari_score
15 |
16 | from sklearn import metrics
17 | import time
18 |
19 | # -------------------------------
20 | # Evaluation Criteria
21 | # -------------------------------
22 | def evaluate_clustering(y_true, y_pred):
23 |
24 | start = time.time()
25 | print('Computing metrics...')
26 |     if len(set(y_pred)) < 1000:  # Hungarian matching in cluster_acc is too slow for very many clusters
27 | acc = cluster_acc(y_true.astype(int), y_pred.astype(int))
28 | else:
29 | acc = None
30 |
31 | nmi = nmi_score(y_true, y_pred)
32 | ari = ari_score(y_true, y_pred)
33 | pur = purity_score(y_true, y_pred)
34 |     print(f'Finished computing metrics in {time.time() - start:.2f}s...')
35 |
36 | return acc, nmi, ari, pur
37 |
38 |
39 | def cluster_acc(y_true, y_pred, return_ind=False):
40 | """
41 |     Calculate clustering accuracy. Requires scikit-learn to be installed.
42 |
43 | # Arguments
44 |         y_true: true labels, numpy.array with shape `(n_samples,)`
45 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
46 |
47 | # Return
48 | accuracy, in [0,1]
49 | """
50 | y_true = y_true.astype(int)
51 | assert y_pred.size == y_true.size
52 | D = max(y_pred.max(), y_true.max()) + 1
53 | w = np.zeros((D, D), dtype=int)
54 | for i in range(y_pred.size):
55 | w[y_pred[i], y_true[i]] += 1
56 |
57 | ind = linear_assignment(w.max() - w)
58 | ind = np.vstack(ind).T
59 |
60 | if return_ind:
61 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind, w
62 | else:
63 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
64 |
65 |
66 | def purity_score(y_true, y_pred):
67 | # compute contingency matrix (also called confusion matrix)
68 | contingency_matrix = metrics.cluster.contingency_matrix(y_true, y_pred)
69 | # return purity
70 | return np.sum(np.amax(contingency_matrix, axis=0)) / np.sum(contingency_matrix)
71 |
72 | # -------------------------------
73 | # Mixed Eval Function
74 | # -------------------------------
75 | def mixed_eval(targets, preds, mask):
76 |
77 | """
78 |     Evaluate clustering metrics on two subsets of the data, as defined by `mask`
79 |     (the mask usually corresponds to 'Old' and 'New' classes in the GCD setting)
80 | :param targets: All ground truth labels
81 | :param preds: All predictions
82 | :param mask: Mask defining two subsets
83 | :return:
84 | """
85 |
86 | mask = mask.astype(bool)
87 |
88 | # Labelled examples
89 | if mask.sum() == 0: # All examples come from unlabelled classes
90 |
91 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int), preds.astype(int)), \
92 | nmi_score(targets, preds), \
93 | ari_score(targets, preds)
94 |
95 | print('Unlabelled Classes Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'
96 | .format(unlabelled_acc, unlabelled_nmi, unlabelled_ari))
97 |
98 |         # Also return the proportion of labelled examples (mask.mean())
99 | return (unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.mean()
100 |
101 | else:
102 |
103 | labelled_acc, labelled_nmi, labelled_ari = cluster_acc(targets.astype(int)[mask],
104 | preds.astype(int)[mask]), \
105 | nmi_score(targets[mask], preds[mask]), \
106 | ari_score(targets[mask], preds[mask])
107 |
108 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int)[~mask],
109 | preds.astype(int)[~mask]), \
110 | nmi_score(targets[~mask], preds[~mask]), \
111 | ari_score(targets[~mask], preds[~mask])
112 |
113 |         # Also return the proportion of labelled examples (mask.mean())
114 | return (labelled_acc, labelled_nmi, labelled_ari), (
115 | unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.mean()
116 |
117 |
118 | class AverageMeter(object):
119 | """Computes and stores the average and current value"""
120 | def __init__(self):
121 | self.reset()
122 |
123 | def reset(self):
124 | self.val = 0
125 | self.avg = 0
126 | self.sum = 0
127 | self.count = 0
128 |
129 | def update(self, val, n=1):
130 | self.val = val
131 | self.sum += val * n
132 | self.count += n
133 | self.avg = self.sum / self.count
134 |
135 |
136 | class Identity(nn.Module):
137 | def __init__(self):
138 | super(Identity, self).__init__()
139 | def forward(self, x):
140 | return x
141 |
142 |
143 | class BCE(nn.Module):
144 | eps = 1e-7 # Avoid calculating log(0). Use the small value of float16.
145 | def forward(self, prob1, prob2, simi):
146 | # simi: 1->similar; -1->dissimilar; 0->unknown(ignore)
147 | assert len(prob1)==len(prob2)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(prob1)),str(len(prob2)),str(len(simi)))
148 |         P = prob1.mul(prob2)  # out-of-place: mul_ would silently mutate the caller's prob1
149 | P = P.sum(1)
150 | P.mul_(simi).add_(simi.eq(-1).type_as(P))
151 | neglogP = -P.add_(BCE.eps).log_()
152 | return neglogP.mean()
153 |
154 |
155 | def PairEnum(x,mask=None):
156 |
157 |     # Enumerate all pairs of features in x
158 | assert x.ndimension() == 2, 'Input dimension must be 2'
159 | x1 = x.repeat(x.size(0), 1)
160 | x2 = x.repeat(1, x.size(0)).view(-1, x.size(1))
161 |
162 | if mask is not None:
163 |
164 | xmask = mask.view(-1, 1).repeat(1, x.size(1))
165 | #dim 0: #sample, dim 1:#feature
166 | x1 = x1[xmask].view(-1, x.size(1))
167 | x2 = x2[xmask].view(-1, x.size(1))
168 |
169 | return x1, x2
170 |
171 |
172 | def accuracy(output, target, topk=(1,)):
173 | """Computes the accuracy over the k top predictions for the specified values of k"""
174 | with torch.no_grad():
175 | maxk = max(topk)
176 | batch_size = target.size(0)
177 |
178 | _, pred = output.topk(maxk, 1, True, True)
179 | pred = pred.t()
180 | correct = pred.eq(target.view(1, -1).expand_as(pred))
181 |
182 | res = []
183 | for k in topk:
184 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the slice may be non-contiguous
185 | res.append(correct_k.mul_(100.0 / batch_size))
186 | return res
187 |
188 |
189 | def seed_torch(seed=1029):
190 | random.seed(seed)
191 | os.environ['PYTHONHASHSEED'] = str(seed)
192 | np.random.seed(seed)
193 | torch.manual_seed(seed)
194 | torch.cuda.manual_seed(seed)
195 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
196 | torch.backends.cudnn.benchmark = False
197 | torch.backends.cudnn.deterministic = True
198 |
199 |
200 | def str2bool(v):
201 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
202 | return True
203 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
204 | return False
205 | else:
206 | raise argparse.ArgumentTypeError('Boolean value expected.')
--------------------------------------------------------------------------------
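A quick sanity check of cluster_acc: cluster ids are arbitrary, so any relabelling of a correct partition should still score 1.0 after the Hungarian matching. A minimal example, assuming the repo root is on PYTHONPATH:

import numpy as np

from project_utils.cluster_utils import cluster_acc

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([2, 2, 0, 0, 1, 1])  # same partition, permuted cluster ids
assert cluster_acc(y_true, y_pred) == 1.0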
/data/cifar.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR10, CIFAR100
2 | from copy import deepcopy
3 | import numpy as np
4 |
5 | from data.data_utils import subsample_instances
6 | from config import cifar_10_root, cifar_100_root
7 |
8 |
9 | class CustomCIFAR10(CIFAR10):
10 |
11 | def __init__(self, *args, **kwargs):
12 |
13 | super(CustomCIFAR10, self).__init__(*args, **kwargs)
14 |
15 | self.uq_idxs = np.array(range(len(self)))
16 |
17 | def __getitem__(self, item):
18 |
19 | img, label = super().__getitem__(item)
20 | uq_idx = self.uq_idxs[item]
21 |
22 | return img, label, uq_idx
23 |
24 | def __len__(self):
25 | return len(self.targets)
26 |
27 |
28 | class CustomCIFAR100(CIFAR100):
29 |
30 | def __init__(self, *args, **kwargs):
31 | super(CustomCIFAR100, self).__init__(*args, **kwargs)
32 |
33 | self.uq_idxs = np.array(range(len(self)))
34 |
35 | def __getitem__(self, item):
36 | img, label = super().__getitem__(item)
37 | uq_idx = self.uq_idxs[item]
38 |
39 | return img, label, uq_idx
40 |
41 | def __len__(self):
42 | return len(self.targets)
43 |
44 |
45 | def subsample_dataset(dataset, idxs):
46 |
47 |     # Allow for the setting in which an empty set of indices is passed
48 |
49 | if len(idxs) > 0:
50 |
51 | dataset.data = dataset.data[idxs]
52 | dataset.targets = np.array(dataset.targets)[idxs].tolist()
53 | dataset.uq_idxs = dataset.uq_idxs[idxs]
54 |
55 | return dataset
56 |
57 | else:
58 |
59 | return None
60 |
61 |
62 | def subsample_classes(dataset, include_classes=(0, 1, 8, 9)):
63 |
64 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
65 |
66 | target_xform_dict = {}
67 | for i, k in enumerate(include_classes):
68 | target_xform_dict[k] = i
69 |
70 | dataset = subsample_dataset(dataset, cls_idxs)
71 |
72 | # dataset.target_transform = lambda x: target_xform_dict[x]
73 |
74 | return dataset
75 |
76 |
77 | def get_train_val_indices(train_dataset, val_split=0.2):
78 |
79 | train_classes = np.unique(train_dataset.targets)
80 |
81 |     # Get train/val indices
82 | train_idxs = []
83 | val_idxs = []
84 | for cls in train_classes:
85 |
86 |         cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]  # targets is a list, so cast before comparing
87 |
88 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
89 | t_ = [x for x in cls_idxs if x not in v_]
90 |
91 | train_idxs.extend(t_)
92 | val_idxs.extend(v_)
93 |
94 | return train_idxs, val_idxs
95 |
96 |
97 | def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9),
98 | prop_train_labels=0.8, split_train_val=False, seed=0):
99 |
100 | np.random.seed(seed)
101 |
102 | # Init entire training set
103 | whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True)
104 |
105 | # Get labelled training set which has subsampled classes, then subsample some indices from that
106 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
107 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
108 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
109 |
110 | # Split into training and validation sets
111 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
112 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
113 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
114 | val_dataset_labelled_split.transform = test_transform
115 |
116 | # Get unlabelled data
117 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
118 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
119 |
120 | # Get test set for all classes
121 | test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False)
122 |
123 | # Either split train into train and val or use test set as val
124 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
125 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
126 |
127 | all_datasets = {
128 | 'train_labelled': train_dataset_labelled,
129 | 'train_unlabelled': train_dataset_unlabelled,
130 | 'val': val_dataset_labelled,
131 | 'test': test_dataset,
132 | }
133 |
134 | return all_datasets
135 |
136 |
137 | def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80),
138 | prop_train_labels=0.8, split_train_val=False, seed=0):
139 |
140 | np.random.seed(seed)
141 |
142 | # Init entire training set
143 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True)
144 |
145 | # Get labelled training set which has subsampled classes, then subsample some indices from that
146 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
147 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
148 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
149 |
150 | # Split into training and validation sets
151 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
152 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
153 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
154 | val_dataset_labelled_split.transform = test_transform
155 |
156 | # Get unlabelled data
157 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
158 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
159 |
160 | # Get test set for all classes
161 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False)
162 |
163 | # Either split train into train and val or use test set as val
164 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
165 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
166 |
167 | all_datasets = {
168 | 'train_labelled': train_dataset_labelled,
169 | 'train_unlabelled': train_dataset_unlabelled,
170 | 'val': val_dataset_labelled,
171 | 'test': test_dataset,
172 | }
173 |
174 | return all_datasets
175 |
176 |
177 | if __name__ == '__main__':
178 |
179 | x = get_cifar_100_datasets(None, None, split_train_val=False,
180 | train_classes=range(80), prop_train_labels=0.5)
181 |
182 | print('Printing lens...')
183 | for k, v in x.items():
184 | if v is not None:
185 | print(f'{k}: {len(v)}')
186 |
187 | print('Printing labelled and unlabelled overlap...')
188 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
189 | print('Printing total instances in train...')
190 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
191 |
192 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
193 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
194 | print(f'Len labelled set: {len(x["train_labelled"])}')
195 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
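One quirk worth noting: unlike the other dataset modules, subsample_dataset here returns None when given an empty index list rather than an empty dataset, so callers should guard for it. A small illustration (assumes CIFAR-10 is available at, or downloadable to, cifar_10_root):

from copy import deepcopy

from data.cifar import CustomCIFAR10, subsample_dataset
from config import cifar_10_root

ds = CustomCIFAR10(root=cifar_10_root, train=True, download=True)

assert subsample_dataset(deepcopy(ds), []) is None
assert len(subsample_dataset(deepcopy(ds), list(range(10)))) == 10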
/data/get_datasets.py:
--------------------------------------------------------------------------------
1 |
2 | from data.data_utils import MergedDataset
3 |
4 | from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets
5 | from data.stanford_cars import get_scars_datasets
6 | from data.imagenet import get_imagenet_100_datasets
7 | from data.cub import get_cub_datasets
8 | from data.fgvc_aircraft import get_aircraft_datasets
9 | from data.pets import get_pets_datasets
10 | from data.flower import get_flower_datasets
11 | from data.food import get_food_datasets
12 |
13 | from data.cifar import subsample_classes as subsample_dataset_cifar
14 | from data.stanford_cars import subsample_classes as subsample_dataset_scars
15 | from data.imagenet import subsample_classes as subsample_dataset_imagenet
16 | from data.cub import subsample_classes as subsample_dataset_cub
17 | from data.fgvc_aircraft import subsample_classes as subsample_dataset_air
18 | from data.pets import subsample_classes as subsample_dataset_pets
19 | from data.flower import subsample_classes as subsample_dataset_flower
20 | from data.food import subsample_classes as subsample_dataset_food
21 |
22 | from copy import deepcopy
23 | import pickle
24 | import os
25 |
26 | from config import osr_split_dir
27 |
28 | sub_sample_class_funcs = {
29 | 'cifar10': subsample_dataset_cifar,
30 | 'cifar100': subsample_dataset_cifar,
31 | 'imagenet_100': subsample_dataset_imagenet,
32 | 'cub': subsample_dataset_cub,
33 | 'aircraft': subsample_dataset_air,
34 | 'scars': subsample_dataset_scars,
35 | 'pets': subsample_dataset_pets,
36 | 'flower': subsample_dataset_flower,
37 | 'food': subsample_dataset_food
38 | }
39 |
40 | get_dataset_funcs = {
41 | 'cifar10': get_cifar_10_datasets,
42 | 'cifar100': get_cifar_100_datasets,
43 | 'imagenet_100': get_imagenet_100_datasets,
44 | 'cub': get_cub_datasets,
45 | 'aircraft': get_aircraft_datasets,
46 | 'scars': get_scars_datasets,
47 | 'pets': get_pets_datasets,
48 | 'flower': get_flower_datasets,
49 | 'food': get_food_datasets
50 | }
51 |
52 |
53 | def get_datasets(dataset_name, train_transform, test_transform, args):
54 |
55 | """
56 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled
57 | test_dataset,
58 | unlabelled_train_examples_test,
59 | datasets
60 | """
61 |
62 | if dataset_name not in get_dataset_funcs.keys():
63 |         raise ValueError(f'Unknown dataset: {dataset_name}')
64 |
65 | # Get datasets
66 | get_dataset_f = get_dataset_funcs[dataset_name]
67 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform,
68 | train_classes=args.train_classes,
69 | prop_train_labels=args.prop_train_labels,
70 | split_train_val=False)
71 |
72 | # Set target transforms:
73 | target_transform_dict = {}
74 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)):
75 | target_transform_dict[cls] = i
76 | target_transform = lambda x: target_transform_dict[x]
77 |
78 | # ['train_labelled', 'train_unlabelled', 'val', 'test']
79 |     for dataset_key, dataset in datasets.items():  # avoid shadowing the dataset_name argument
80 | if dataset is not None:
81 | dataset.target_transform = target_transform
82 |
83 | # Train split (labelled and unlabelled classes) for training
84 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']),
85 | unlabelled_dataset=deepcopy(datasets['train_unlabelled']))
86 |
87 | test_dataset = datasets['test']
88 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled'])
89 | unlabelled_train_examples_test.transform = test_transform
90 |
91 | labelled_train_examples = deepcopy(datasets['train_labelled'])
92 | labelled_train_examples.transform = test_transform
93 |
94 | return train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples
95 |
96 |
97 | def get_class_splits(args):
98 |
99 | # For FGVC datasets, optionally return bespoke splits
100 | if args.dataset_name in ('scars', 'cub', 'aircraft'):
101 | if hasattr(args, 'use_ssb_splits'):
102 | use_ssb_splits = args.use_ssb_splits
103 | else:
104 | use_ssb_splits = False
105 |
106 | # -------------
107 | # GET CLASS SPLITS
108 | # -------------
109 | if args.dataset_name == 'cifar10':
110 |
111 | args.image_size = 32
112 | args.train_classes = range(5)
113 | args.unlabeled_classes = range(5, 10)
114 |
115 | elif args.dataset_name == 'cifar100':
116 |
117 | args.image_size = 32
118 | args.train_classes = range(80)
119 | args.unlabeled_classes = range(80, 100)
120 |
121 | elif args.dataset_name == 'tinyimagenet':
122 |
123 | args.image_size = 64
124 | args.train_classes = range(100)
125 | args.unlabeled_classes = range(100, 200)
126 |
127 |
128 | elif args.dataset_name == 'imagenet_100':
129 |
130 | args.image_size = 224
131 | args.train_classes = range(50)
132 | args.unlabeled_classes = range(50, 100)
133 |
134 | elif args.dataset_name == 'scars':
135 |
136 | args.image_size = 224
137 |
138 | if use_ssb_splits:
139 |
140 | split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl')
141 | with open(split_path, 'rb') as handle:
142 | class_info = pickle.load(handle)
143 |
144 | args.train_classes = class_info['known_classes']
145 | open_set_classes = class_info['unknown_classes']
146 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
147 |
148 | else:
149 |
150 | args.train_classes = range(98)
151 | args.unlabeled_classes = range(98, 196)
152 |
153 | elif args.dataset_name == 'aircraft':
154 |
155 | args.image_size = 224
156 | if use_ssb_splits:
157 |
158 | split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl')
159 | with open(split_path, 'rb') as handle:
160 | class_info = pickle.load(handle)
161 |
162 | args.train_classes = class_info['known_classes']
163 | open_set_classes = class_info['unknown_classes']
164 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
165 |
166 | else:
167 |
168 | args.train_classes = range(50)
169 | args.unlabeled_classes = range(50, 100)
170 |
171 | elif args.dataset_name == 'cub':
172 |
173 | args.image_size = 224
174 |
175 | if use_ssb_splits:
176 |
177 | split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl')
178 | with open(split_path, 'rb') as handle:
179 | class_info = pickle.load(handle)
180 |
181 | args.train_classes = class_info['known_classes']
182 | open_set_classes = class_info['unknown_classes']
183 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
184 |
185 | else:
186 |
187 | args.train_classes = range(100)
188 | args.unlabeled_classes = range(100, 200)
189 |
190 |
191 | elif args.dataset_name == 'pets':
192 |
193 | args.image_size = 224
194 | args.train_classes = range(19)
195 | args.unlabeled_classes = range(19, 37)
196 |
197 | elif args.dataset_name == 'flower':
198 |
199 | args.image_size = 224
200 | args.train_classes = range(51)
201 | args.unlabeled_classes = range(51, 102)
202 |
203 | elif args.dataset_name == 'food':
204 |
205 | args.image_size = 224
206 | args.train_classes = range(51)
207 | args.unlabeled_classes = range(51, 101)
208 |
209 | else:
210 |
211 | raise NotImplementedError
212 |
213 | return args
--------------------------------------------------------------------------------
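get_class_splits only reads dataset_name (plus, for the FGVC datasets, use_ssb_splits), so a bare argparse.Namespace is enough to inspect the default splits. A minimal sketch:

from argparse import Namespace

from data.get_datasets import get_class_splits

args = get_class_splits(Namespace(dataset_name='cifar100'))
assert list(args.train_classes) == list(range(80))
assert list(args.unlabeled_classes) == list(range(80, 100))
assert args.image_size == 32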
/data/cub.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 |
6 | from torchvision.datasets.folder import default_loader
7 | from torchvision.datasets.utils import download_url
8 | from torch.utils.data import Dataset
9 |
10 | from data.data_utils import subsample_instances
11 | from config import cub_root
12 |
13 |
14 | class CustomCub2011(Dataset):
15 | base_folder = 'CUB_200_2011/images'
16 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
17 | filename = 'CUB_200_2011.tgz'
18 | tgz_md5 = '97eceeb196236b17998738112f37df78'
19 |
20 | def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=True):
21 |
22 | self.root = os.path.expanduser(root)
23 | self.transform = transform
24 | self.target_transform = target_transform
25 |
26 | self.loader = loader
27 | self.train = train
28 |
29 |
30 | if download:
31 | self._download()
32 |
33 | if not self._check_integrity():
34 | raise RuntimeError('Dataset not found or corrupted.' +
35 | ' You can use download=True to download it')
36 |
37 | self.uq_idxs = np.array(range(len(self)))
38 |
39 | def _load_metadata(self):
40 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
41 | names=['img_id', 'filepath'])
42 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
43 | sep=' ', names=['img_id', 'target'])
44 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
45 | sep=' ', names=['img_id', 'is_training_img'])
46 |
47 | data = images.merge(image_class_labels, on='img_id')
48 | self.data = data.merge(train_test_split, on='img_id')
49 |
50 | if self.train:
51 | self.data = self.data[self.data.is_training_img == 1]
52 | else:
53 | self.data = self.data[self.data.is_training_img == 0]
54 |
55 | def _check_integrity(self):
56 | try:
57 | self._load_metadata()
58 | except Exception:
59 | return False
60 |
61 | for index, row in self.data.iterrows():
62 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
63 | if not os.path.isfile(filepath):
64 |                 print(filepath)  # report the missing file
65 | return False
66 | return True
67 |
68 | def _download(self):
69 | import tarfile
70 |
71 | if self._check_integrity():
72 | print('Files already downloaded and verified')
73 | return
74 |
75 | download_url(self.url, self.root, self.filename, self.tgz_md5)
76 |
77 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
78 | tar.extractall(path=self.root)
79 |
80 | def __len__(self):
81 | return len(self.data)
82 |
83 | def __getitem__(self, idx):
84 | sample = self.data.iloc[idx]
85 | path = os.path.join(self.root, self.base_folder, sample.filepath)
86 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0
87 | img = self.loader(path)
88 |
89 | if self.transform is not None:
90 | img = self.transform(img)
91 |
92 | if self.target_transform is not None:
93 | target = self.target_transform(target)
94 |
95 | return img, target, self.uq_idxs[idx]
96 |
97 |
98 | def subsample_dataset(dataset, idxs):
99 |
100 | mask = np.zeros(len(dataset)).astype('bool')
101 | mask[idxs] = True
102 |
103 | dataset.data = dataset.data[mask]
104 | dataset.uq_idxs = dataset.uq_idxs[mask]
105 |
106 | return dataset
107 |
108 |
109 | def subsample_classes(dataset, include_classes=range(160)):
110 |
111 | include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199
112 | cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub]
113 |
114 | # TODO: For now have no target transform
115 | target_xform_dict = {}
116 | for i, k in enumerate(include_classes):
117 | target_xform_dict[k] = i
118 |
119 | dataset = subsample_dataset(dataset, cls_idxs)
120 |
121 | dataset.target_transform = lambda x: target_xform_dict[x]
122 |
123 | return dataset
124 |
125 |
126 | def get_train_val_indices(train_dataset, val_split=0.2):
127 |
128 | train_classes = np.unique(train_dataset.data['target'])
129 |
130 |     # Get train/val indices
131 | train_idxs = []
132 | val_idxs = []
133 | for cls in train_classes:
134 |
135 | cls_idxs = np.where(train_dataset.data['target'] == cls)[0]
136 |
137 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
138 | t_ = [x for x in cls_idxs if x not in v_]
139 |
140 | train_idxs.extend(t_)
141 | val_idxs.extend(v_)
142 |
143 | return train_idxs, val_idxs
144 |
145 |
146 | def get_cub_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
147 | split_train_val=False, seed=0):
148 |
149 | np.random.seed(seed)
150 |
151 | # Init entire training set
152 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True, download=False)
153 |
154 | # Get labelled training set which has subsampled classes, then subsample some indices from that
155 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
156 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
157 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
158 |
159 | # Split into training and validation sets
160 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
161 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
162 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
163 | val_dataset_labelled_split.transform = test_transform
164 |
165 | # Get unlabelled data
166 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
167 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
168 |
169 | # Get test set for all classes
170 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False, download=False)
171 |
172 | # Either split train into train and val or use test set as val
173 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
174 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
175 |
176 | all_datasets = {
177 | 'train_labelled': train_dataset_labelled,
178 | 'train_unlabelled': train_dataset_unlabelled,
179 | 'val': val_dataset_labelled,
180 | 'test': test_dataset,
181 | }
182 |
183 | return all_datasets
184 |
185 | def main():
186 |
187 | x = get_cub_datasets(None, None, split_train_val=False,
188 | train_classes=range(100), prop_train_labels=0.5)
189 |
190 | print('Printing lens...')
191 | for k, v in x.items():
192 | if v is not None:
193 | print(f'{k}: {len(v)}')
194 |
195 | print('Printing labelled and unlabelled overlap...')
196 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
197 | print('Printing total instances in train...')
198 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
199 |
200 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}')
201 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}')
202 | print(f'Len labelled set: {len(x["train_labelled"])}')
203 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
204 |
205 | if __name__ == '__main__':
206 | main()
--------------------------------------------------------------------------------
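subsample_dataset above filters via a boolean mask over the metadata DataFrame rather than positional indexing. The same pattern on toy data, for reference:

import numpy as np
import pandas as pd

data = pd.DataFrame({'img_id': [1, 2, 3, 4], 'target': [1, 1, 2, 2]})
uq_idxs = np.arange(len(data))

idxs = [0, 3]                        # positional indices to keep
mask = np.zeros(len(data), dtype=bool)
mask[idxs] = True

data, uq_idxs = data[mask], uq_idxs[mask]
assert list(data['img_id']) == [1, 4] and list(uq_idxs) == [0, 3]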
/methods/clustering/ssk_means.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | from torch.utils.data import DataLoader
5 | import numpy as np
6 | from sklearn.cluster import KMeans
7 | import torch
8 | from project_utils.cluster_utils import str2bool
9 | from project_utils.general_utils import seed_torch
10 | from project_utils.cluster_and_log_utils import log_accs_from_preds
11 |
12 | from methods.clustering.feature_vector_dataset import FeatureVectorDataset
13 | from data.get_datasets import get_datasets, get_class_splits
14 | from methods.clustering.faster_mix_k_means_pytorch import K_Means as SemiSupKMeans
15 |
16 | from tqdm import tqdm
17 | from config import feature_extract_dir
18 |
19 | # Silence noisy DeprecationWarnings from dependencies (TODO: fix properly)
20 | import warnings
21 | warnings.filterwarnings("ignore", category=DeprecationWarning)
22 |
23 | def kmeans_semi_sup(merge_test_loader, args, K=None):
24 |
25 | """
26 | In this case, the test loader needs to have the labelled and unlabelled subsets of the training data
27 | """
28 |
29 | if K is None:
30 | K = args.num_labeled_classes + args.num_unlabeled_classes
31 |
32 | all_feats = []
33 | targets = np.array([])
34 | mask_lab = np.array([]) # From all the data, which instances belong to the labelled set
35 | mask_cls = np.array([]) # From all the data, which instances belong to Old classes
36 |
37 | print('Collating features...')
38 | # First extract all features
39 | for batch_idx, (feats, label, _, mask_lab_) in enumerate(tqdm(merge_test_loader)):
40 |
41 | feats = feats.to(device)
42 |
43 | feats = torch.nn.functional.normalize(feats, dim=-1)
44 |
45 | all_feats.append(feats.cpu().numpy())
46 | targets = np.append(targets, label.cpu().numpy())
47 |         mask_cls = np.append(mask_cls, np.array([x.item() in range(len(args.train_classes))
48 |                                                  for x in label]))
49 | mask_lab = np.append(mask_lab, mask_lab_.cpu().bool().numpy())
50 |
51 | # -----------------------
52 | # K-MEANS
53 | # -----------------------
54 | mask_lab = mask_lab.astype(bool)
55 | mask_cls = mask_cls.astype(bool)
56 |
57 |     all_feats = np.concatenate(all_feats)  # e.g. (5994, 768) for CUB
58 |
59 |     l_feats = all_feats[mask_lab]  # Get labelled set, e.g. (1498, 768) for CUB
60 | u_feats = all_feats[~mask_lab] # Get unlabelled set
61 | l_targets = targets[mask_lab] # Get labelled targets
62 | u_targets = targets[~mask_lab] # Get unlabelled targets
63 |
64 | print('Fitting Semi-Supervised K-Means...')
65 | kmeans = SemiSupKMeans(k=K, tolerance=1e-4, max_iterations=args.max_kmeans_iter, init='k-means++',
66 | n_init=args.k_means_init, random_state=None, n_jobs=None, pairwise_batch_size=1024, mode=None)
67 |
68 | l_feats, u_feats, l_targets, u_targets = (torch.from_numpy(x).to(device) for
69 | x in (l_feats, u_feats, l_targets, u_targets))
70 |
71 | kmeans.fit_mix(u_feats, l_feats, l_targets)
72 | all_preds = kmeans.labels_.cpu().numpy()
73 | u_targets = u_targets.cpu().numpy()
74 |
75 | # -----------------------
76 | # EVALUATE
77 | # -----------------------
78 | # Get preds corresponding to unlabelled set
79 | preds = all_preds[~mask_lab]
80 |
81 | # Get portion of mask_cls which corresponds to the unlabelled set
82 | mask = mask_cls[~mask_lab]
83 | mask = mask.astype(bool)
84 |
85 | # -----------------------
86 | # EVALUATE
87 | # -----------------------
88 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=u_targets, y_pred=preds, mask=mask, eval_funcs=args.eval_funcs,
89 | save_name='SS-K-Means Train ACC Unlabelled', print_output=True)
90 |
91 | return all_acc, old_acc, new_acc, kmeans
92 |
93 | if __name__ == "__main__":
94 |
95 | parser = argparse.ArgumentParser(
96 | description='cluster',
97 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
98 | parser.add_argument('--batch_size', default=128, type=int)
99 | parser.add_argument('--num_workers', default=8, type=int)
100 | parser.add_argument('--K', default=None, type=int, help='Set manually to run with custom K')
101 | parser.add_argument('--root_dir', type=str, default=feature_extract_dir)
102 | parser.add_argument('--warmup_model_exp_id', type=str, default=None)
103 | parser.add_argument('--use_best_model', type=str2bool, default=True)
104 | parser.add_argument('--spatial', type=str2bool, default=False)
105 | parser.add_argument('--semi_sup', type=str2bool, default=True)
106 | parser.add_argument('--max_kmeans_iter', type=int, default=10)
107 | parser.add_argument('--k_means_init', type=int, default=10)
108 | parser.add_argument('--model_name', type=str, default='vit_dino', help='Format is {model_name}_{pretrain}')
109 |     parser.add_argument('--dataset_name', type=str, default='aircraft', help='options: cifar10, cifar100, imagenet_100, cub, aircraft, scars, pets, flower, food')
110 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
111 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v1', 'v2'])
112 | parser.add_argument('--use_ssb_splits', type=str2bool, default=True)
113 |
114 | # ----------------------
115 | # INIT
116 | # ----------------------
117 | args = parser.parse_args()
118 | cluster_accs = {}
119 | seed_torch(0)
120 | args.save_dir = os.path.join(args.root_dir, f'{args.model_name}_{args.dataset_name}')
121 |
122 | args = get_class_splits(args)
123 |
124 | args.num_labeled_classes = len(args.train_classes)
125 | args.num_unlabeled_classes = len(args.unlabeled_classes)
126 |
127 | device = torch.device('cuda:0')
128 | args.device = device
129 | print(args)
130 |
131 | if args.warmup_model_exp_id is not None:
132 |
133 | args.save_dir += '_' + args.warmup_model_exp_id
134 |
135 | if args.use_best_model:
136 | args.save_dir += '_best'
137 |
138 | print(f'Using features from experiment: {args.warmup_model_exp_id}')
139 | else:
140 | print(f'Using pretrained {args.model_name} features...')
141 |
142 | print(args.save_dir)
143 |
144 | # --------------------
145 | # DATASETS
146 | # --------------------
147 | print('Building datasets...')
148 | train_transform, test_transform = None, None
149 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name,
150 | train_transform, test_transform, args)
151 |
152 | # Set target transforms:
153 | target_transform_dict = {}
154 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)):
155 | target_transform_dict[cls] = i
156 | target_transform = lambda x: target_transform_dict[x]
157 |
158 | # Convert to feature vector dataset
159 | test_dataset = FeatureVectorDataset(base_dataset=test_dataset, feature_root=os.path.join(args.save_dir, 'test'))
160 | unlabelled_train_examples_test = FeatureVectorDataset(base_dataset=unlabelled_train_examples_test,
161 | feature_root=os.path.join(args.save_dir, 'train'))
162 | train_dataset = FeatureVectorDataset(base_dataset=train_dataset, feature_root=os.path.join(args.save_dir, 'train'))
163 | train_dataset.target_transform = target_transform
164 |
165 | unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
166 | batch_size=args.batch_size, shuffle=False)
167 | test_loader = DataLoader(test_dataset, num_workers=args.num_workers,
168 | batch_size=args.batch_size, shuffle=False)
169 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers,
170 | batch_size=args.batch_size, shuffle=False)
171 |
172 |     print('Performing SS-K-Means on all of the training data...')
173 | all_acc, old_acc, new_acc, kmeans = kmeans_semi_sup(train_loader, args, K=args.K)
174 | cluster_save_path = os.path.join(args.save_dir, 'ss_kmeans_cluster_centres.pt')
175 | torch.save(kmeans.cluster_centers_, cluster_save_path)
--------------------------------------------------------------------------------
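A toy run of the semi-supervised k-means used by the script, on random features. This mirrors the constructor and fit_mix calls above; it is a sketch only: the real script runs on GPU, and CPU support depends on the K_Means implementation.

import torch

from methods.clustering.faster_mix_k_means_pytorch import K_Means as SemiSupKMeans

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

l_feats = torch.randn(40, 16).to(device)                   # labelled features
l_targets = torch.randint(0, 2, (40,)).float().to(device)  # labels from 2 'Old' classes (floats, as in the script)
u_feats = torch.randn(60, 16).to(device)                   # unlabelled features

kmeans = SemiSupKMeans(k=4, tolerance=1e-4, max_iterations=10, init='k-means++',
                       n_init=3, random_state=None, n_jobs=None,
                       pairwise_batch_size=1024, mode=None)
kmeans.fit_mix(u_feats, l_feats, l_targets)

# Labelled points stay pinned to their classes; unlabelled points fill all k clusters.
print(kmeans.labels_.shape, kmeans.cluster_centers_.shape)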
/methods/partitioning/kmeans_subset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from easydict import EasyDict
4 |
5 | from torch.utils.data import DataLoader
6 | import numpy as np
7 | from sklearn.cluster import KMeans
8 | import torch
9 | from torch.optim import SGD, lr_scheduler
10 | from project_utils.cluster_utils import mixed_eval, AverageMeter
11 | from model import vision_transformer as vits
12 |
13 | from project_utils.general_utils import init_experiment, get_mean_lr, str2bool, get_dino_head_weights
14 |
15 | from data.augmentations import get_transform
16 | from data.get_datasets import get_datasets, get_class_splits
17 | from data.data_utils import MergedDataset
18 |
19 | from tqdm import tqdm
20 |
21 | from torch.nn import functional as F
22 |
23 | from project_utils.cluster_and_log_utils import log_accs_from_preds
24 | from config import exp_root, km_label_path, dino_pretrain_path, moco_pretrain_path, mae_pretrain_path
25 |
26 | from copy import deepcopy
27 | # Silence noisy DeprecationWarnings from dependencies (TODO: fix properly)
28 | import warnings
29 | warnings.filterwarnings("ignore", category=DeprecationWarning)
30 |
31 |
32 | class ContrastiveLearningViewGenerator(object):
33 | """Take two random crops of one image as the query and key."""
34 |
35 | def __init__(self, base_transform, n_views=2):
36 | self.base_transform = base_transform
37 | self.n_views = n_views
38 |
39 | def __call__(self, x):
40 | return [self.base_transform(x) for i in range(self.n_views)]
41 |
42 |
43 | if __name__ == "__main__":
44 |
45 | parser = argparse.ArgumentParser(
46 | description='cluster',
47 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
48 | parser.add_argument('--batch_size', default=256, type=int)
49 | parser.add_argument('--num_workers', default=4, type=int)
50 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v1', 'v2'])
51 |
52 | parser.add_argument('--warmup_model_dir', type=str, default=None)
53 | parser.add_argument('--model_name', type=str, default='vit_dino', help='Format is {model_name}_{pretrain}')
54 |     parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, aircraft, scars, pets, flower, food')
55 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
56 | parser.add_argument('--use_ssb_splits', type=str2bool, default=False)
57 |
58 | parser.add_argument('--grad_from_block', type=int, default=11)
59 | parser.add_argument('--lr', type=float, default=0.1)
60 | parser.add_argument('--save_best_thresh', type=float, default=None)
61 | parser.add_argument('--gamma', type=float, default=0.1)
62 | parser.add_argument('--momentum', type=float, default=0.9)
63 | parser.add_argument('--weight_decay', type=float, default=1e-4)
64 | parser.add_argument('--epochs', default=200, type=int)
65 | parser.add_argument('--exp_root', type=str, default=exp_root)
66 | parser.add_argument('--transform', type=str, default='imagenet')
67 | parser.add_argument('--seed', default=1, type=int)
68 |
69 | parser.add_argument('--base_model', type=str, default='vit_dino')
70 | parser.add_argument('--n_views', default=2, type=int)
71 |
72 | parser.add_argument('--experts_num', type=int, default=8)
73 |
74 | parser.add_argument('--pretrain_model', type=str, default='dino')
75 |
76 | args = parser.parse_args()
77 | device = torch.device('cuda:0')
78 | args = get_class_splits(args)
79 |
80 | args.num_labeled_classes = len(args.train_classes)
81 | args.num_unlabeled_classes = len(args.unlabeled_classes)
82 |
83 | init_experiment(args, runner_name=['metric_learn_gcd'])
84 | print(f'Using evaluation function {args.eval_funcs[0]} to print results')
85 |
86 | if args.base_model == 'vit_dino':
87 |
88 | args.interpolation = 3
89 | args.crop_pct = 0.875
90 | if args.pretrain_model == 'dino':
91 | pretrain_path = dino_pretrain_path
92 |
93 | model = vits.__dict__['vit_base']()
94 |
95 | if args.pretrain_model == 'dino':
96 | weight = torch.load(pretrain_path, map_location='cpu')
97 |
98 | msg = model.load_state_dict(weight, strict=False)
99 | print(msg)
100 |
101 | if args.warmup_model_dir is not None:
102 | print(f'Loading weights from {args.warmup_model_dir}')
103 | model.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
104 |
105 | model.to(device)
106 |
107 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model
108 | args.image_size = 224
109 | args.feat_dim = 768
110 | args.num_mlp_layers = 3
111 | args.mlp_out_dim = 65536
112 |
113 | # ----------------------
114 | # HOW MUCH OF BASE MODEL TO FINETUNE
115 | # ----------------------
116 | for m in model.parameters():
117 | m.requires_grad = False
118 |
119 | # Only finetune layers from block 'args.grad_from_block' onwards
120 | for name, m in model.named_parameters():
121 | if 'block' in name:
122 | block_num = int(name.split('.')[1])
123 | if block_num >= args.grad_from_block:
124 | m.requires_grad = True
125 |
126 | else:
127 |
128 | raise NotImplementedError
129 |
130 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
131 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
132 |
133 | # --------------------
134 | # DATASETS
135 | # --------------------
136 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name,
137 | train_transform,
138 | test_transform,
139 | args)
140 |     whole_train_test_dataset = MergedDataset(deepcopy(labelled_train_examples), deepcopy(unlabelled_train_examples_test))
141 |
142 |
143 | # --------------------
144 | # SAMPLER
145 | # Sampler which balances labelled and unlabelled examples in each batch
146 | # --------------------
147 | label_len = len(train_dataset.labelled_dataset)
148 | unlabelled_len = len(train_dataset.unlabelled_dataset)
149 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
150 | sample_weights = torch.DoubleTensor(sample_weights)
151 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset))
152 |
153 | # --------------------
154 | # DATALOADERS
155 | # --------------------
156 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
157 | sampler=sampler, drop_last=True)
158 | test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
159 | batch_size=args.batch_size, shuffle=False)
160 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers,
161 | batch_size=args.batch_size, shuffle=False)
162 | train_loader_labelled = DataLoader(labelled_train_examples, num_workers=args.num_workers,
163 | batch_size=args.batch_size, shuffle=False)
164 | whole_train_test_loader = DataLoader(whole_train_test_dataset, num_workers=args.num_workers,
165 | batch_size=args.batch_size, shuffle=False)
166 |
167 | model.eval()
168 | all_feats = []
169 | targets = np.array([])
170 | mask = np.array([])
171 | for batch_idx, batch in enumerate(tqdm(whole_train_test_loader)):
172 | images, label, _, _ = batch
173 | images = images.to(device)
174 | feats = model(images)
175 | feats = torch.nn.functional.normalize(feats, dim=-1)
176 | all_feats.append(feats.cpu().detach().numpy())
177 | targets = np.append(targets, label.cpu().detach().numpy())
178 |         mask = np.append(mask, np.array([x.item() in range(len(args.train_classes))
179 |                                          for x in label]))
180 | print('Kmeans...')
181 |
182 | all_feats = np.concatenate(all_feats)
183 | kmeans = KMeans(n_clusters=args.experts_num, random_state=0).fit(all_feats)
184 | preds = kmeans.labels_
185 | print('Done!')
186 |     print('feats shape:', all_feats.shape)
187 |
188 | np.save(f'{km_label_path}/{args.dataset_name}_km_labels.npy', preds)
189 |
--------------------------------------------------------------------------------
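The saved label file can then be used to carve the dataset into per-expert index sets. A sketch, with dataset_name and experts_num assumed to match an earlier run of the script:

import numpy as np

from config import km_label_path

dataset_name, experts_num = 'scars', 8   # assumed values

preds = np.load(f'{km_label_path}/{dataset_name}_km_labels.npy')
expert_idxs = {e: np.where(preds == e)[0] for e in range(experts_num)}
print({e: len(idxs) for e, idxs in expert_idxs.items()})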
/data/augmentations/randaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
3 | """
4 |
5 | import random
6 |
7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 |
13 | def ShearX(img, v): # [-0.3, 0.3]
14 | assert -0.3 <= v <= 0.3
15 | if random.random() > 0.5:
16 | v = -v
17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
18 |
19 |
20 | def ShearY(img, v): # [-0.3, 0.3]
21 | assert -0.3 <= v <= 0.3
22 | if random.random() > 0.5:
23 | v = -v
24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
25 |
26 |
27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
28 | assert -0.45 <= v <= 0.45
29 | if random.random() > 0.5:
30 | v = -v
31 | v = v * img.size[0]
32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33 |
34 |
35 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
36 | assert 0 <= v
37 | if random.random() > 0.5:
38 | v = -v
39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
40 |
41 |
42 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
43 | assert -0.45 <= v <= 0.45
44 | if random.random() > 0.5:
45 | v = -v
46 | v = v * img.size[1]
47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
48 |
49 |
50 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
51 | assert 0 <= v
52 | if random.random() > 0.5:
53 | v = -v
54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
55 |
56 |
57 | def Rotate(img, v): # [-30, 30]
58 | assert -30 <= v <= 30
59 | if random.random() > 0.5:
60 | v = -v
61 | return img.rotate(v)
62 |
63 |
64 | def AutoContrast(img, _):
65 | return PIL.ImageOps.autocontrast(img)
66 |
67 |
68 | def Invert(img, _):
69 | return PIL.ImageOps.invert(img)
70 |
71 |
72 | def Equalize(img, _):
73 | return PIL.ImageOps.equalize(img)
74 |
75 |
76 | def Flip(img, _): # not from the paper
77 | return PIL.ImageOps.mirror(img)
78 |
79 |
80 | def Solarize(img, v): # [0, 256]
81 | assert 0 <= v <= 256
82 | return PIL.ImageOps.solarize(img, v)
83 |
84 |
85 | def SolarizeAdd(img, addition=0, threshold=128):
86 |     img_np = np.array(img).astype(int)  # np.int was removed in NumPy 1.20+
87 | img_np = img_np + addition
88 | img_np = np.clip(img_np, 0, 255)
89 | img_np = img_np.astype(np.uint8)
90 | img = Image.fromarray(img_np)
91 | return PIL.ImageOps.solarize(img, threshold)
92 |
93 |
94 | def Posterize(img, v): # [4, 8]
95 | v = int(v)
96 | v = max(1, v)
97 | return PIL.ImageOps.posterize(img, v)
98 |
99 |
100 | def Contrast(img, v): # [0.1,1.9]
101 | assert 0.1 <= v <= 1.9
102 | return PIL.ImageEnhance.Contrast(img).enhance(v)
103 |
104 |
105 | def Color(img, v): # [0.1,1.9]
106 | assert 0.1 <= v <= 1.9
107 | return PIL.ImageEnhance.Color(img).enhance(v)
108 |
109 |
110 | def Brightness(img, v): # [0.1,1.9]
111 | assert 0.1 <= v <= 1.9
112 | return PIL.ImageEnhance.Brightness(img).enhance(v)
113 |
114 |
115 | def Sharpness(img, v): # [0.1,1.9]
116 | assert 0.1 <= v <= 1.9
117 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
118 |
119 |
120 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
121 | assert 0.0 <= v <= 0.2
122 | if v <= 0.:
123 | return img
124 |
125 | v = v * img.size[0]
126 | return CutoutAbs(img, v)
127 |
128 |
129 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
130 | # assert 0 <= v <= 20
131 | if v < 0:
132 | return img
133 | w, h = img.size
134 | x0 = np.random.uniform(w)
135 | y0 = np.random.uniform(h)
136 |
137 | x0 = int(max(0, x0 - v / 2.))
138 | y0 = int(max(0, y0 - v / 2.))
139 | x1 = min(w, x0 + v)
140 | y1 = min(h, y0 + v)
141 |
142 | xy = (x0, y0, x1, y1)
143 | color = (125, 123, 114)
144 | # color = (0, 0, 0)
145 | img = img.copy()
146 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
147 | return img
148 |
149 |
150 | def SamplePairing(imgs): # [0, 0.4]
151 | def f(img1, v):
152 | i = np.random.choice(len(imgs))
153 | img2 = PIL.Image.fromarray(imgs[i])
154 | return PIL.Image.blend(img1, img2, v)
155 |
156 | return f
157 |
158 |
159 | def Identity(img, v):
160 | return img
161 |
162 |
163 | def augment_list():  # 16 operations and their ranges
164 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
165 | # l = [
166 | # (Identity, 0., 1.0),
167 | # (ShearX, 0., 0.3), # 0
168 | # (ShearY, 0., 0.3), # 1
169 | # (TranslateX, 0., 0.33), # 2
170 | # (TranslateY, 0., 0.33), # 3
171 | # (Rotate, 0, 30), # 4
172 | # (AutoContrast, 0, 1), # 5
173 | # (Invert, 0, 1), # 6
174 | # (Equalize, 0, 1), # 7
175 | # (Solarize, 0, 110), # 8
176 | # (Posterize, 4, 8), # 9
177 | # # (Contrast, 0.1, 1.9), # 10
178 | # (Color, 0.1, 1.9), # 11
179 | # (Brightness, 0.1, 1.9), # 12
180 | # (Sharpness, 0.1, 1.9), # 13
181 | # # (Cutout, 0, 0.2), # 14
182 | # # (SamplePairing(imgs), 0, 0.4), # 15
183 | # ]
184 |
185 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
186 | l = [
187 | (AutoContrast, 0, 1),
188 | (Equalize, 0, 1),
189 | (Invert, 0, 1),
190 | (Rotate, 0, 30),
191 | (Posterize, 0, 4),
192 | (Solarize, 0, 256),
193 | (SolarizeAdd, 0, 110),
194 | (Color, 0.1, 1.9),
195 | (Contrast, 0.1, 1.9),
196 | (Brightness, 0.1, 1.9),
197 | (Sharpness, 0.1, 1.9),
198 | (ShearX, 0., 0.3),
199 | (ShearY, 0., 0.3),
200 | (CutoutAbs, 0, 40),
201 | (TranslateXabs, 0., 100),
202 | (TranslateYabs, 0., 100),
203 | ]
204 |
205 | return l
206 |
207 | def augment_list_svhn():  # 13 operations and their ranges (no Rotate or Translate for digit data)
208 |
209 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
210 | l = [
211 | (AutoContrast, 0, 1),
212 | (Equalize, 0, 1),
213 | (Invert, 0, 1),
214 | (Posterize, 0, 4),
215 | (Solarize, 0, 256),
216 | (SolarizeAdd, 0, 110),
217 | (Color, 0.1, 1.9),
218 | (Contrast, 0.1, 1.9),
219 | (Brightness, 0.1, 1.9),
220 | (Sharpness, 0.1, 1.9),
221 | (ShearX, 0., 0.3),
222 | (ShearY, 0., 0.3),
223 | (CutoutAbs, 0, 40),
224 | ]
225 |
226 | return l
227 |
228 |
229 | class Lighting(object):
230 | """Lighting noise(AlexNet - style PCA - based noise)"""
231 |
232 | def __init__(self, alphastd, eigval, eigvec):
233 | self.alphastd = alphastd
234 | self.eigval = torch.Tensor(eigval)
235 | self.eigvec = torch.Tensor(eigvec)
236 |
237 | def __call__(self, img):
238 | if self.alphastd == 0:
239 | return img
240 |
241 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
242 | rgb = self.eigvec.type_as(img).clone() \
243 | .mul(alpha.view(1, 3).expand(3, 3)) \
244 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
245 | .sum(1).squeeze()
246 |
247 | return img.add(rgb.view(3, 1, 1).expand_as(img))
248 |
249 |
250 | class CutoutDefault(object):
251 | """
252 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
253 | """
254 | def __init__(self, length):
255 | self.length = length
256 |
257 | def __call__(self, img):
258 | h, w = img.size(1), img.size(2)
259 | mask = np.ones((h, w), np.float32)  # float32: the in-place multiply below requires the mask dtype to match img
260 | y = np.random.randint(h)
261 | x = np.random.randint(w)
262 |
263 | y1 = np.clip(y - self.length // 2, 0, h)
264 | y2 = np.clip(y + self.length // 2, 0, h)
265 | x1 = np.clip(x - self.length // 2, 0, w)
266 | x2 = np.clip(x + self.length // 2, 0, w)
267 |
268 | mask[y1: y2, x1: x2] = 0.
269 | mask = torch.from_numpy(mask)
270 | mask = mask.expand_as(img)
271 | img *= mask
272 | return img
273 |
274 |
275 | class RandAugment:
276 | def __init__(self, n, m, args=None):
277 | self.n = n # [1, 2]
278 | self.m = m # [0...30]
279 |
280 | if args is None:
281 | self.augment_list = augment_list()
282 |
283 | elif args.dataset == 'svhn' or args.dataset == 'mnist':
284 | self.augment_list = augment_list_svhn()
285 |
286 | else:
287 | self.augment_list = augment_list()
288 |
289 | def __call__(self, img):
290 | ops = random.choices(self.augment_list, k=self.n)
291 | for op, minval, maxval in ops:
292 | val = (float(self.m) / 30) * float(maxval - minval) + minval
293 | img = op(img, val)
294 |
295 | return img
--------------------------------------------------------------------------------
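Note on the RandAugment policy above: each call samples n operations uniformly from the list and rescales the single shared magnitude m into each operation's own (minval, maxval) range, so one integer controls the strength of every transform. A minimal usage sketch (not part of the repository; the image path is hypothetical):

    from PIL import Image
    from data.augmentations.randaugment import RandAugment

    img = Image.open('example.jpg')      # hypothetical input image
    augmenter = RandAugment(n=2, m=9)    # apply 2 ops per image, magnitude 9 out of 30
    aug_img = augmenter(img)

    # Each sampled op maps m onto its own range:
    #   val = (m / 30) * (maxval - minval) + minval
    # e.g. for (Rotate, 0, 30) with m=9: val = (9 / 30) * 30 = 9 degrees
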
/data/food.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import pathlib
4 | import json
5 | from typing import Any, Callable, Optional, Union, Tuple
6 | from typing import Sequence
7 |
8 | import numpy as np
9 | from copy import deepcopy
10 | from scipy import io as mat_io
11 |
12 | from data.data_utils import subsample_instances
13 | from config import food_root
14 |
15 | from torchvision.datasets.folder import default_loader
16 | from torch.utils.data import Dataset
17 | from torchvision.datasets.utils import download_and_extract_archive
18 |
19 |
20 |
21 | def make_dataset(dir, image_ids, targets):
22 | assert(len(image_ids) == len(targets))
23 | images = []
24 | dir = os.path.expanduser(dir)
25 | for i in range(len(image_ids)):
26 | item = (os.path.join(dir, 'images',
27 | '%s.jpg' % image_ids[i]), targets[i])
28 | images.append(item)
29 | return images
30 |
31 | class Food101(Dataset):
32 | """`The Food-101 Data Set `_.
33 |
34 | The Food-101 is a challenging data set of 101 food categories, with 101'000 images.
35 | For each class, 250 manually reviewed test images are provided as well as 750 training images.
36 | On purpose, the training images were not cleaned, and thus still contain some amount of noise.
37 | This comes mostly in the form of intense colors and sometimes wrong labels. All images were
38 | rescaled to have a maximum side length of 512 pixels.
39 |
40 |
41 | Args:
42 | root (string): Root directory of the dataset.
43 | split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
44 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
45 | version. E.g, ``transforms.RandomCrop``.
46 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
47 | download (bool, optional): If True, downloads the dataset from the internet and
48 | puts it in root directory. If dataset is already downloaded, it is not
49 | downloaded again. Default is False.
50 | """
51 |
52 | _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
53 | _MD5 = "85eeb15f3717b99a5da872d97d918f87"
54 | splits = ('train', 'test')
55 |
56 | def __init__(
57 | self,
58 | root: str,
59 | split: str = "train",
60 | transform: Optional[Callable] = None,
61 | target_transform: Optional[Callable] = None,
62 | download: bool = False,
63 | loader=default_loader
64 | ) -> None:
65 | if split not in self.splits:
66 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
67 | split, ', '.join(self.splits),
68 | ))
69 | self.split = split
70 | self.root = root
71 | self.transform = transform
72 | self.target_transform = target_transform
73 | self.loader = loader
74 |
75 | self._base_folder = pathlib.Path(self.root)
76 | self._meta_folder = self._base_folder / "meta"
77 | self._images_folder = self._base_folder / "images"
78 |
79 | if download:
80 | self._download()
81 |
82 | if not self._check_exists():
83 | raise RuntimeError("Dataset not found. You can use download=True to download it")
84 |
85 | self._labels = []
86 | self._image_files = []
87 | self.image_ids = []
88 | with open(self._meta_folder / f"{split}.json") as f:
89 | metadata = json.loads(f.read())
90 |
91 | self.classes = sorted(metadata.keys())
92 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
93 |
94 | for class_label, im_rel_paths in metadata.items():
95 | self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
96 | self.image_ids += [im_rel_path for im_rel_path in im_rel_paths]
97 | self._image_files += [
98 | self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
99 | ]
100 |
101 | samples = make_dataset(self.root, self.image_ids, self._labels)
102 | self.samples = samples
103 |
104 | self.uq_idxs = np.array(range(len(self)))
105 |
106 | def __len__(self) -> int:
107 | return len(self.samples)
108 |
109 | def __getitem__(self, idx) -> Tuple[Any, Any]:
110 |
111 | path, target = self.samples[idx]
112 | sample = self.loader(path)
113 | if self.transform is not None:
114 | sample = self.transform(sample)
115 | if self.target_transform is not None:
116 | target = self.target_transform(target)
117 |
118 | return sample, target, self.uq_idxs[idx]
119 |
120 | def extra_repr(self) -> str:
121 | return f"split={self._split}"
122 |
123 | def _check_exists(self) -> bool:
124 | return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
125 |
126 | def _download(self) -> None:
127 | if self._check_exists():
128 | return
129 | download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
130 |
131 | def subsample_dataset(dataset, idxs):
132 |
133 | mask = np.zeros(len(dataset)).astype('bool')
134 | mask[idxs] = True
135 |
136 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs]
137 | dataset.uq_idxs = dataset.uq_idxs[mask]
138 |
139 | return dataset
140 |
141 | def subsample_classes(dataset, include_classes=range(60)):
142 |
143 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes]
144 |
145 | # TODO: Don't transform targets for now
146 | target_xform_dict = {}
147 | for i, k in enumerate(include_classes):
148 | target_xform_dict[k] = i
149 |
150 | dataset = subsample_dataset(dataset, cls_idxs)
151 |
152 | dataset.target_transform = lambda x: target_xform_dict[x]
153 |
154 | return dataset
155 |
156 | def get_train_val_indices(train_dataset, val_split=0.2):
157 |
158 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)]
159 | train_classes = np.unique(all_targets)
160 |
161 | # Get train/test indices
162 | train_idxs = []
163 | val_idxs = []
164 | for cls in train_classes:
165 | cls_idxs = np.where(all_targets == cls)[0]
166 |
167 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
168 | t_ = [x for x in cls_idxs if x not in v_]
169 |
170 | train_idxs.extend(t_)
171 | val_idxs.extend(v_)
172 |
173 | return train_idxs, val_idxs
174 |
175 | def get_food_datasets(train_transform, test_transform, train_classes=range(51), prop_train_labels=0.8,
176 | split_train_val=False, seed=0):
177 |
178 | np.random.seed(seed)
179 |
180 | # Init entire training set
181 | whole_training_set = Food101(root=food_root, transform=train_transform, split='train', download=False) # len = 75750
182 |
183 | # Get labelled training set which has subsampled classes, then subsample some indices from that
184 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
185 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
186 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
187 |
188 | # Split into training and validation sets
189 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
190 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
191 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
192 | val_dataset_labelled_split.transform = test_transform
193 |
194 | # Get unlabelled data
195 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
196 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
197 |
198 | # Get test set for all classes
199 | test_dataset = Food101(root=food_root, transform=test_transform, split='test', download=False) # len = 25250
200 |
201 | # Either split train into train and val or use test set as val
202 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
203 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
204 |
205 | all_datasets = {
206 | 'train_labelled': train_dataset_labelled,
207 | 'train_unlabelled': train_dataset_unlabelled,
208 | 'val': val_dataset_labelled,
209 | 'test': test_dataset,
210 | }
211 |
212 | return all_datasets
213 |
214 | def main():
215 | x = get_food_datasets(None, None, split_train_val=False)
216 |
217 | print('Printing lens...')
218 | for k, v in x.items():
219 | if v is not None:
220 | print(f'{k}: {len(v)}')
221 |
222 | print('Printing labelled and unlabelled overlap...')
223 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
224 | print('Printing total instances in train...')
225 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
226 | print('Printing number of labelled classes...')
227 | print(len(set([i[1] for i in x['train_labelled'].samples])))
228 | print('Printing total number of classes...')
229 | print(len(set([i[1] for i in x['train_unlabelled'].samples])))
230 |
231 |
232 | if __name__ == '__main__':
233 | main()
--------------------------------------------------------------------------------
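A quick size check for the Food-101 splits built above, assuming the standard 750 training images per class (a sketch, not repository code):

    from data.food import get_food_datasets

    datasets = get_food_datasets(train_transform=None, test_transform=None,
                                 train_classes=range(51), prop_train_labels=0.8)

    print(len(datasets['train_labelled']))    # expected 30600 = 51 classes x 750 images x 0.8
    print(len(datasets['train_unlabelled']))  # expected 45150 = 75750 - 30600
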
/methods/clustering/extract_features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import timm
4 | from torchvision import transforms
5 | import torchvision
6 |
7 | import argparse
8 | import os
9 | from tqdm import tqdm
10 |
11 | from data.stanford_cars import CarsDataset
12 | from data.cifar import CustomCIFAR10, CustomCIFAR100, cifar_10_root, cifar_100_root
13 | from data.augmentations import get_transform
14 | from data.imagenet import get_imagenet_100_datasets
15 | from data.data_utils import MergedDataset
16 | from data.cub import CustomCub2011, cub_root
17 | from data.fgvc_aircraft import FGVCAircraft, aircraft_root
18 | from data.pets import OxfordIIITPet, pets_root
19 |
20 | from project_utils.general_utils import strip_state_dict, str2bool
21 | from copy import deepcopy
22 |
23 | from config import feature_extract_dir, dino_pretrain_path
24 |
25 | def extract_features_dino(model, loader, save_dir):
26 |
27 | model.to(device)
28 | model.eval()
29 |
30 | with torch.no_grad():
31 | for batch_idx, batch in enumerate(tqdm(loader)):
32 |
33 | images, labels, idxs = batch[:3]
34 | images = images.to(device)
35 |
36 | features = model(images) # CLS_Token for ViT, Average pooled vector for R50
37 |
38 | # Save features
39 | for f, t, uq in zip(features, labels, idxs):
40 |
41 | t = t.item()
42 | uq = uq.item()
43 |
44 | save_path = os.path.join(save_dir, f'{t}', f'{uq}.npy')
45 | torch.save(f.detach().cpu().numpy(), save_path)  # written with torch.save despite the .npy suffix
46 |
47 |
48 | def extract_features_timm(model, loader, save_dir):
49 |
50 | model.to(device)
51 | model.eval()
52 |
53 | with torch.no_grad():
54 | for batch_idx, batch in enumerate(tqdm(loader)):
55 |
56 | images, labels, idxs = batch[:3]
57 | images = images.to(device)
58 |
59 | features = model.forward_features(images) # CLS_Token for ViT, Average pooled vector for R50
60 |
61 | # Save features
62 | for f, t, uq in zip(features, labels, idxs):
63 |
64 | t = t.item()
65 | uq = uq.item()
66 |
67 | save_path = os.path.join(save_dir, f'{t}', f'{uq}.npy')
68 | torch.save(f.detach().cpu().numpy(), save_path)
69 |
70 |
71 | if __name__ == "__main__":
72 |
73 | parser = argparse.ArgumentParser(
74 | description='cluster',
75 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
76 | parser.add_argument('--batch_size', default=128, type=int)
77 | parser.add_argument('--num_workers', default=8, type=int)
78 | parser.add_argument('--root_dir', type=str, default=feature_extract_dir)
79 | parser.add_argument('--warmup_model_dir', type=str,
80 | default=None)
81 | parser.add_argument('--use_best_model', type=str2bool, default=True)
82 | parser.add_argument('--model_name', type=str, default='vit_dino', help='Format is {model_name}_{pretrain}')
83 | parser.add_argument('--dataset', type=str, default='aircraft', help='options: cifar10, cifar100, scars, imagenet_100, cub, aircraft, pets')
84 | parser.add_argument('--pretrain_model', type=str, default='dino')
85 |
86 | # ----------------------
87 | # INIT
88 | # ----------------------
89 | args = parser.parse_args()
90 | device = torch.device('cuda:0')
91 |
92 | args.save_dir = os.path.join(args.root_dir, f'{args.model_name}_{args.dataset}')
93 | print(args)
94 |
95 | print('Loading model...')
96 | # ----------------------
97 | # MODEL
98 | # ----------------------
99 | if args.model_name == 'vit_dino':
100 |
101 | extract_features_func = extract_features_dino
102 | args.interpolation = 3
103 | args.crop_pct = 0.875
104 | if args.pretrain_model == 'dino':
105 | pretrain_path = dino_pretrain_path
106 |
107 | model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16', pretrained=False)
108 |
109 | state_dict = torch.load(pretrain_path, map_location='cpu')
110 | model.load_state_dict(state_dict)
111 |
112 | _, val_transform = get_transform('imagenet', image_size=224, args=args)
113 |
114 | elif args.model_name == 'resnet50_dino':
115 |
116 | extract_features_func = extract_features_dino
117 | args.interpolation = 3
118 | args.crop_pct = 0.875
119 | pretrain_path = '/work/sagar/pretrained_models/dino/dino_resnet50_pretrain.pth'
120 |
121 | model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50', pretrained=False)
122 |
123 | state_dict = torch.load(pretrain_path, map_location='cpu')
124 | model.load_state_dict(state_dict)
125 |
126 | _, val_transform = get_transform('imagenet', image_size=224, args=args)
127 |
128 | else:
129 |
130 | raise NotImplementedError
131 |
132 | if args.warmup_model_dir is not None:
133 |
134 | warmup_id = args.warmup_model_dir.split('(')[1].split(')')[0]
135 |
136 | if args.use_best_model:
137 | args.warmup_model_dir = args.warmup_model_dir[:-3] + '_best.pt'
138 | args.save_dir += '_(' + args.warmup_model_dir.split('(')[1].split(')')[0] + ')_best'
139 | else:
140 | args.save_dir += '_(' + args.warmup_model_dir.split('(')[1].split(')')[0] + ')'
141 |
142 | print(f'Using weights from {args.warmup_model_dir} ...')
143 | state_dict = torch.load(args.warmup_model_dir)
144 | model.load_state_dict(state_dict)
145 |
146 | print(f'Saving to {args.save_dir}')
147 |
148 | print('Loading data...')
149 | # ----------------------
150 | # DATASET
151 | # ----------------------
152 | if args.dataset == 'cifar10':
153 |
154 | train_dataset = CustomCIFAR10(root=cifar_10_root, train=True, transform=val_transform)
155 | test_dataset = CustomCIFAR10(root=cifar_10_root, train=False, transform=val_transform)
156 | targets = list(set(train_dataset.targets))
157 |
158 | elif args.dataset == 'cifar100':
159 |
160 | train_dataset = CustomCIFAR100(root=cifar_100_root, train=True, transform=val_transform)
161 | test_dataset = CustomCIFAR100(root=cifar_100_root, train=False, transform=val_transform)
162 | targets = list(set(train_dataset.targets))
163 |
164 | elif args.dataset == 'scars':
165 |
166 | train_dataset = CarsDataset(train=True, transform=val_transform)
167 | test_dataset = CarsDataset(train=False, transform=val_transform)
168 | targets = list(set(train_dataset.target))
169 | targets = [i - 1 for i in targets] # SCars are labelled 1 - 197. Change to 0 - 196
170 |
171 | elif args.dataset == 'imagenet_100':
172 |
173 | datasets = get_imagenet_100_datasets(train_transform=val_transform, test_transform=val_transform,
174 | train_classes=range(50),
175 | prop_train_labels=0.5)
176 |
177 | datasets['train_labelled'].target_transform = None
178 | datasets['train_unlabelled'].target_transform = None
179 |
180 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']),
181 | unlabelled_dataset=deepcopy(datasets['train_unlabelled']))
182 |
183 | test_dataset = datasets['test']
184 | targets = list(set(test_dataset.targets))
185 |
186 | elif args.dataset == 'cub':
187 |
188 | train_dataset = CustomCub2011(root=cub_root, transform=val_transform, train=True)
189 | test_dataset = CustomCub2011(root=cub_root, transform=val_transform, train=False)
190 | targets = list(set(train_dataset.data.target.values))
191 | targets = [i - 1 for i in targets] # CUB is labelled 1 - 200. Change to 0 - 199
192 |
193 | elif args.dataset == 'aircraft':
194 |
195 | train_dataset = FGVCAircraft(root=aircraft_root, transform=val_transform, split='trainval')
196 | test_dataset = FGVCAircraft(root=aircraft_root, transform=val_transform, split='test')
197 | targets = list(set([s[1] for s in train_dataset.samples]))
198 |
199 | elif args.dataset == 'pets':
200 |
201 | train_dataset = OxfordIIITPet(root=pets_root, transform=val_transform, split='trainval')
202 | test_dataset = OxfordIIITPet(root=pets_root, transform=val_transform, split='test')
203 | targets = list(set([s[1] for s in train_dataset.samples]))
204 |
205 | else:
206 |
207 | raise NotImplementedError
208 |
209 | # ----------------------
210 | # DATALOADER
211 | # ----------------------
212 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
213 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
214 |
215 | print('Creating base directories...')
216 | # ----------------------
217 | # INIT SAVE DIRS
218 | # Create a directory for each class
219 | # ----------------------
220 | if not os.path.exists(args.save_dir):
221 | os.makedirs(args.save_dir)
222 |
223 | for fold in ('train', 'test'):
224 |
225 | fold_dir = os.path.join(args.save_dir, fold)
226 | if not os.path.exists(fold_dir):
227 | os.mkdir(fold_dir)
228 |
229 | for t in targets:
230 | target_dir = os.path.join(fold_dir, f'{t}')
231 | if not os.path.exists(target_dir):
232 | os.mkdir(target_dir)
233 |
234 | # ----------------------
235 | # EXTRACT FEATURES
236 | # ----------------------
237 | # Extract train features
238 | train_save_dir = os.path.join(args.save_dir, 'train')
239 | print('Extracting features from train split...')
240 | extract_features_func(model=model, loader=train_loader, save_dir=train_save_dir)
241 |
242 | # Extract test features
243 | test_save_dir = os.path.join(args.save_dir, 'test')
244 | print('Extracting features from test split...')
245 | extract_features_func(model=model, loader=test_loader, save_dir=test_save_dir)
246 |
247 | print('Done!')
--------------------------------------------------------------------------------
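The script above writes one feature file per image, laid out as <save_dir>/<train|test>/<class_label>/<uq_idx>.npy. Note the files are written with torch.save, so torch.load (not np.load) is the matching reader despite the .npy suffix. A sketch for loading a split back into memory:

    import os
    import numpy as np
    import torch

    def load_split_features(save_dir, fold='train'):
        feats, labels = [], []
        fold_dir = os.path.join(save_dir, fold)
        for cls in sorted(os.listdir(fold_dir)):          # one sub-directory per class label
            cls_dir = os.path.join(fold_dir, cls)
            for fname in sorted(os.listdir(cls_dir)):
                feats.append(torch.load(os.path.join(cls_dir, fname)))
                labels.append(int(cls))
        return np.stack(feats), np.array(labels)
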
/data/pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import pathlib
4 | from typing import Any, Callable, Optional, Union, Tuple
5 | from typing import Sequence
6 |
7 | import numpy as np
8 | from copy import deepcopy
9 | from scipy import io as mat_io
10 |
11 | from data.data_utils import subsample_instances
12 | from config import pets_root
13 |
14 | from torchvision.datasets.folder import default_loader
15 | from torch.utils.data import Dataset
16 | from torchvision.datasets.utils import download_and_extract_archive
17 |
18 |
19 |
20 | def make_dataset(dir, image_ids, targets):
21 | assert(len(image_ids) == len(targets))
22 | images = []
23 | dir = os.path.expanduser(dir)
24 | for i in range(len(image_ids)):
25 | item = (os.path.join(dir, 'images',
26 | '%s.jpg' % image_ids[i]), targets[i])
27 | images.append(item)
28 | return images
29 |
30 | class OxfordIIITPet(Dataset):
31 | """`Oxford-IIIT Pet Dataset `_.
32 |
33 | Args:
34 | root (string): Root directory of the dataset.
35 | split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
36 | target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or
37 | ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
38 |
39 | - ``category`` (int): Label for one of the 37 pet categories.
40 | - ``segmentation`` (PIL image): Segmentation trimap of the image.
41 |
42 | If empty, ``None`` will be returned as target.
43 |
44 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
45 | version. E.g, ``transforms.RandomCrop``.
46 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
47 | download (bool, optional): If True, downloads the dataset from the internet and puts it into
48 | ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
49 | """
50 |
51 | _RESOURCES = (
52 | ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
53 | ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
54 | )
55 | valid_target_types = ("category", "segmentation")
56 | splits = ('train', 'val', 'trainval', 'test')
57 |
58 | def __init__(
59 | self,
60 | root: str,
61 | split: str = "trainval",
62 | target_types: Union[Sequence[str], str] = "category",
63 | transforms: Optional[Callable] = None,
64 | transform: Optional[Callable] = None,
65 | target_transform: Optional[Callable] = None,
66 | download: bool = False,
67 | loader=default_loader
68 | ):
69 | if split not in self.splits:
70 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
71 | split, ', '.join(self.splits),
72 | ))
73 | self.split = split
74 |
75 | if isinstance(target_types, str):
76 | target_types = [target_types]
77 |
78 | self.root = root
79 | self.transform = transform
80 | self.target_transform = target_transform
81 | self.loader = loader
82 |
83 | self._base_folder = pathlib.Path(self.root)
84 | self._images_folder = self._base_folder / "images"
85 | self._anns_folder = self._base_folder / "annotations"
86 | self._segs_folder = self._anns_folder / "trimaps"
87 |
88 | if download:
89 | self._download()
90 |
91 | if not self._check_exists():
92 | raise RuntimeError("Dataset not found. You can use download=True to download it")
93 |
94 | image_ids = []
95 | self._labels = []
96 | with open(self._anns_folder / f"{self.split}.txt") as file:
97 | for line in file:
98 | image_id, label, *_ = line.strip().split()
99 | image_ids.append(image_id)
100 | self._labels.append(int(label) - 1)
101 |
102 | self.classes = [
103 | " ".join(part.title() for part in raw_cls.split("_"))
104 | for raw_cls, _ in sorted(
105 | {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
106 | key=lambda image_id_and_label: image_id_and_label[1],
107 | )
108 | ]
109 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
110 |
111 | samples = make_dataset(self.root, image_ids, self._labels)
112 | self.samples = samples
113 | self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
114 | self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
115 |
116 | self.uq_idxs = np.array(range(len(self)))
117 |
118 | def __len__(self) -> int:
119 | return len(self.samples)
120 |
121 | def __getitem__(self, idx: int) -> Tuple[Any, Any]:
122 | """
123 | Args:
124 | index (int): Index
125 |
126 | Returns:
127 | tuple: (sample, target) where target is class_index of the target class.
128 | """
129 |
130 | path, target = self.samples[idx]
131 | sample = self.loader(path)
132 | if self.transform is not None:
133 | sample = self.transform(sample)
134 | if self.target_transform is not None:
135 | target = self.target_transform(target)
136 |
137 | return sample, target, self.uq_idxs[idx]
138 |
139 | def _check_exists(self) -> bool:
140 | for folder in (self._images_folder, self._anns_folder):
141 | if not (os.path.exists(folder) and os.path.isdir(folder)):
142 | return False
143 | else:
144 | return True
145 |
146 | def _download(self) -> None:
147 | if self._check_exists():
148 | return
149 |
150 | for url, md5 in self._RESOURCES:
151 | download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)
152 |
153 | def subsample_dataset(dataset, idxs):
154 |
155 | mask = np.zeros(len(dataset)).astype('bool')
156 | mask[idxs] = True
157 |
158 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs]
159 | dataset.uq_idxs = dataset.uq_idxs[mask]
160 |
161 | return dataset
162 |
163 | def subsample_classes(dataset, include_classes=range(60)):
164 |
165 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] # 1885
166 |
167 | # TODO: Don't transform targets for now
168 | target_xform_dict = {}
169 | for i, k in enumerate(include_classes):
170 | target_xform_dict[k] = i
171 |
172 | dataset = subsample_dataset(dataset, cls_idxs)
173 |
174 | dataset.target_transform = lambda x: target_xform_dict[x]
175 |
176 | return dataset
177 |
178 | def get_train_val_indices(train_dataset, val_split=0.2):
179 |
180 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)]
181 | train_classes = np.unique(all_targets)
182 |
183 | # Get train/test indices
184 | train_idxs = []
185 | val_idxs = []
186 | for cls in train_classes:
187 | cls_idxs = np.where(all_targets == cls)[0]
188 |
189 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
190 | t_ = [x for x in cls_idxs if x not in v_]
191 |
192 | train_idxs.extend(t_)
193 | val_idxs.extend(v_)
194 |
195 | return train_idxs, val_idxs
196 |
197 | def get_pets_datasets(train_transform, test_transform, train_classes=range(19), prop_train_labels=0.8,
198 | split_train_val=False, seed=0):
199 |
200 | np.random.seed(seed)
201 |
202 | # Init entire training set
203 | whole_training_set = OxfordIIITPet(root=pets_root, transform=train_transform, split='trainval', download=False) # len = 3680
204 |
205 | # Get labelled training set which has subsampled classes, then subsample some indices from that
206 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
207 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
208 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
209 |
210 | # Split into training and validation sets
211 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
212 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
213 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
214 | val_dataset_labelled_split.transform = test_transform
215 |
216 | # Get unlabelled data
217 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
218 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
219 |
220 | # Get test set for all classes
221 | test_dataset = OxfordIIITPet(root=pets_root, transform=test_transform, split='test', download=False)
222 |
223 | # Either split train into train and val or use test set as val
224 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
225 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
226 |
227 | all_datasets = {
228 | 'train_labelled': train_dataset_labelled,
229 | 'train_unlabelled': train_dataset_unlabelled,
230 | 'val': val_dataset_labelled,
231 | 'test': test_dataset,
232 | }
233 |
234 | return all_datasets
235 |
236 | def main():
237 | x = get_pets_datasets(None, None, split_train_val=False)
238 |
239 | print('Printing lens...')
240 | for k, v in x.items():
241 | if v is not None:
242 | print(f'{k}: {len(v)}')
243 |
244 | print('Printing labelled and unlabelled overlap...')
245 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
246 | print('Printing total instances in train...')
247 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
248 | print('Printing number of labelled classes...')
249 | print(len(set([i[1] for i in x['train_labelled'].samples])))
250 | print('Printing total number of classes...')
251 | print(len(set([i[1] for i in x['train_unlabelled'].samples])))
252 |
253 |
254 | if __name__ == '__main__':
255 | main()
--------------------------------------------------------------------------------
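The subsample_dataset / subsample_classes pair above (repeated across the dataset files) filters samples in place while uq_idxs keeps pointing at positions in the original dataset; that is what lets get_pets_datasets recover the unlabelled complement by set difference. A small sketch of the invariant, assuming pets_root is configured and the data is on disk:

    from copy import deepcopy
    from data.pets import OxfordIIITPet, subsample_classes
    from config import pets_root

    ds = OxfordIIITPet(root=pets_root, split='trainval')
    subset = subsample_classes(deepcopy(ds), include_classes=range(19))

    # every surviving sample can be looked up in the original dataset via uq_idxs
    assert all(ds.samples[uq] == s for uq, s in zip(subset.uq_idxs, subset.samples))
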
/data/flower.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | from path import Path
4 | import pathlib
5 | from typing import Any, Callable, Optional, Union, Tuple
6 |
7 | import numpy as np
8 | from copy import deepcopy
9 | from scipy import io as mat_io
10 | from scipy.io import loadmat
11 |
12 | from data.data_utils import subsample_instances
13 | from config import flower_root
14 |
15 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url
16 | from torchvision.datasets.folder import default_loader
17 | from torch.utils.data import Dataset
18 |
19 |
20 |
21 |
22 | def make_dataset(dir, image_ids, targets):
23 | assert(len(image_ids) == len(targets))
24 | images = []
25 | dir = os.path.expanduser(dir)
26 | for i in range(len(image_ids)):
27 | item = (os.path.join(dir, 'jpg',
28 | '%s' % image_ids[i]), targets[i])
29 | images.append(item)
30 | return images
31 |
32 | class Flowers102(Dataset):
33 | """`Oxford 102 Flower `_ Dataset.
34 |
35 | .. warning::
36 |
37 | This class needs `scipy `_ to load target files from `.mat` format.
38 |
39 | Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
40 | flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
41 | between 40 and 258 images.
42 |
43 | The images have large scale, pose and light variations. In addition, there are categories that
44 | have large variations within the category, and several very similar categories.
45 |
46 | Args:
47 | root (string): Root directory of the dataset.
48 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
49 | transform (callable, optional): A function/transform that takes in an PIL image and returns a
50 | transformed version. E.g, ``transforms.RandomCrop``.
51 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
52 | download (bool, optional): If true, downloads the dataset from the internet and
53 | puts it in root directory. If dataset is already downloaded, it is not
54 | downloaded again.
55 | """
56 |
57 | _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
58 | _file_dict = { # filename, md5
59 | "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
60 | "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
61 | "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
62 | }
63 | _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}
64 |
65 | def __init__(
66 | self,
67 | root: str,
68 | split: str = "train",
69 | transform: Optional[Callable] = None,
70 | target_transform: Optional[Callable] = None,
71 | download: bool = False,
72 | loader=default_loader
73 | ) -> None:
74 | self.splits = ('train', 'val', 'trainval', 'test')
75 |
76 | if split not in self.splits:
77 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
78 | split, ', '.join(self.splits),
79 | ))
80 | self.split = split
81 | self.root = root
82 |
83 | self.transform = transform
84 | self.target_transform = target_transform
85 | self.loader = loader
86 |
87 | self._base_folder = pathlib.Path(self.root)
88 | self._images_folder = self._base_folder / "jpg"
89 |
90 | if download:
91 | self.download()
92 |
93 | if not self._check_integrity():
94 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
95 |
96 | if self.split == 'trainval':
97 | set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
98 | image_ids_train = set_ids[self._splits_map['train']].tolist()
99 | image_ids_val = set_ids[self._splits_map['val']].tolist()
100 | image_ids = image_ids_train + image_ids_val
101 |
102 | else:
103 | set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
104 | image_ids = set_ids[self._splits_map[self.split]].tolist()
105 |
106 | labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
107 | image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))
108 |
109 | self._labels = []
110 | self._image_files = []
111 | for image_id in image_ids:
112 | self._labels.append(image_id_to_label[image_id])
113 | self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")
114 |
115 | image_name = [f"image_{image_id:05d}.jpg" for image_id in image_ids]
116 |
117 | samples = make_dataset(self.root, image_name, self._labels) # 2040
118 | self.samples = samples
119 | self.uq_idxs = np.array(range(len(self)))
120 |
121 |
122 | def __len__(self) -> int:
123 | return len(self.samples)
124 |
125 | def __getitem__(self, idx) -> Tuple[Any, Any]:
126 | """
127 | Args:
128 | index (int): Index
129 |
130 | Returns:
131 | tuple: (sample, target) where target is class_index of the target class.
132 | """
133 |
134 | path, target = self.samples[idx]
135 | sample = self.loader(path)
136 | if self.transform is not None:
137 | sample = self.transform(sample)
138 | if self.target_transform is not None:
139 | target = self.target_transform(target)
140 |
141 | return sample, target, self.uq_idxs[idx]
142 |
143 | def extra_repr(self) -> str:
144 | return f"split={self._split}"
145 |
146 | def _check_integrity(self):
147 | if not (self._images_folder.exists() and self._images_folder.is_dir()):
148 | return False
149 |
150 | for id in ["label", "setid"]:
151 | filename, md5 = self._file_dict[id]
152 | if not check_integrity(str(self._base_folder / filename), md5):
153 | return False
154 | return True
155 |
156 | def download(self):
157 | if self._check_integrity():
158 | return
159 | download_and_extract_archive(
160 | f"{self._download_url_prefix}{self._file_dict['image'][0]}",
161 | str(self._base_folder),
162 | md5=self._file_dict["image"][1],
163 | )
164 | for id in ["label", "setid"]:
165 | filename, md5 = self._file_dict[id]
166 | download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
167 |
168 | def subsample_dataset(dataset, idxs):
169 |
170 | mask = np.zeros(len(dataset)).astype('bool')
171 | mask[idxs] = True
172 |
173 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs]
174 | dataset.uq_idxs = dataset.uq_idxs[mask]
175 |
176 | return dataset
177 |
178 | def subsample_classes(dataset, include_classes=range(60)):
179 |
180 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] # 1885
181 |
182 | # TODO: Don't transform targets for now
183 | target_xform_dict = {}
184 | for i, k in enumerate(include_classes):
185 | target_xform_dict[k] = i
186 |
187 | dataset = subsample_dataset(dataset, cls_idxs)
188 |
189 | dataset.target_transform = lambda x: target_xform_dict[x]
190 |
191 | return dataset
192 |
193 | def get_train_val_indices(train_dataset, val_split=0.2):
194 |
195 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)]
196 | train_classes = np.unique(all_targets)
197 |
198 | # Get train/test indices
199 | train_idxs = []
200 | val_idxs = []
201 | for cls in train_classes:
202 | cls_idxs = np.where(all_targets == cls)[0]
203 |
204 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
205 | t_ = [x for x in cls_idxs if x not in v_]
206 |
207 | train_idxs.extend(t_)
208 | val_idxs.extend(v_)
209 |
210 | return train_idxs, val_idxs
211 |
212 | def get_flower_datasets(train_transform, test_transform, train_classes=range(51), prop_train_labels=0.8,
213 | split_train_val=False, seed=0):
214 |
215 | np.random.seed(seed)
216 |
217 | # Init entire training set
218 | whole_training_set = Flowers102(root=flower_root, transform=train_transform, split='test', download=False) # len = 6149
219 |
220 | # Get labelled training set which has subsampled classes, then subsample some indices from that
221 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
222 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
223 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
224 |
225 | # Split into training and validation sets
226 | train_dataset_labelled_split = Flowers102(root=flower_root, transform=train_transform, split='train', download=False)
227 | val_dataset_labelled_split = Flowers102(root=flower_root, transform=train_transform, split='val', download=False)
228 | val_dataset_labelled_split.transform = test_transform
229 |
230 | # Get unlabelled data
231 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
232 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
233 |
234 | # Get test set for all classes
235 | test_dataset = Flowers102(root=flower_root, transform=test_transform, split='trainval', download=False) # len = 2040
236 |
237 | # Either split train into train and val or use test set as val
238 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
239 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
240 |
241 | all_datasets = {
242 | 'train_labelled': train_dataset_labelled,
243 | 'train_unlabelled': train_dataset_unlabelled,
244 | 'val': val_dataset_labelled,
245 | 'test': test_dataset,
246 | }
247 |
248 | return all_datasets
249 |
250 | def main():
251 | x = get_flower_datasets(None, None, split_train_val=False)
252 |
253 | print('Printing lens...')
254 | for k, v in x.items():
255 | if v is not None:
256 | print(f'{k}: {len(v)}')
257 |
258 | print('Printing labelled and unlabelled overlap...')
259 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
260 | print('Printing total instances in train...')
261 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
262 | print('Printing number of labelled classes...')
263 | print(len(set([i[1] for i in x['train_labelled'].samples])))
264 | print('Printing total number of classes...')
265 | print(len(set([i[1] for i in x['train_unlabelled'].samples])))
266 |
267 |
268 | if __name__ == '__main__':
269 | main()
--------------------------------------------------------------------------------
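Note the deliberate split inversion in get_flower_datasets above: the official Flowers-102 'test' split (6149 images) is the large one, so it serves as the training pool, while 'trainval' (102 classes x 20 images = 2040) is kept for evaluation. A sketch confirming the sizes, assuming the standard setid.mat:

    from data.flower import Flowers102
    from config import flower_root

    train_pool = Flowers102(root=flower_root, split='test')     # 6149 images
    eval_set = Flowers102(root=flower_root, split='trainval')   # 2040 images
    assert len(train_pool) == 6149 and len(eval_set) == 2040
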
/data/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 |
6 | from torchvision.datasets.folder import default_loader
7 | from torch.utils.data import Dataset
8 |
9 | from data.data_utils import subsample_instances
10 | from config import aircraft_root
11 |
12 | def make_dataset(dir, image_ids, targets):
13 | assert(len(image_ids) == len(targets))
14 | images = []
15 | dir = os.path.expanduser(dir)
16 | for i in range(len(image_ids)):
17 | item = (os.path.join(dir, 'data', 'images',
18 | '%s.jpg' % image_ids[i]), targets[i])
19 | images.append(item)
20 | return images
21 |
22 |
23 | def find_classes(classes_file):
24 |
25 | # read classes file, separating out image IDs and class names
26 | image_ids = []
27 | targets = []
28 | f = open(classes_file, 'r')
29 | for line in f:
30 | split_line = line.split(' ')
31 | image_ids.append(split_line[0])
32 | targets.append(' '.join(split_line[1:]))
33 | f.close()
34 |
35 | # index class names
36 | classes = np.unique(targets)
37 | class_to_idx = {classes[i]: i for i in range(len(classes))}
38 | targets = [class_to_idx[c] for c in targets]
39 |
40 | return (image_ids, targets, classes, class_to_idx)
41 |
42 |
43 | class FGVCAircraft(Dataset):
44 |
45 | """`FGVC-Aircraft `_ Dataset.
46 |
47 | Args:
48 | root (string): Root directory path to dataset.
49 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
50 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
51 | transform (callable, optional): A function/transform that takes in a PIL image
52 | and returns a transformed version. E.g. ``transforms.RandomCrop``
53 | target_transform (callable, optional): A function/transform that takes in the
54 | target and transforms it.
55 | loader (callable, optional): A function to load an image given its path.
56 | download (bool, optional): If true, downloads the dataset from the internet and
57 | puts it in the root directory. If dataset is already downloaded, it is not
58 | downloaded again.
59 | """
60 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
61 | class_types = ('variant', 'family', 'manufacturer')
62 | splits = ('train', 'val', 'trainval', 'test')
63 |
64 | def __init__(self, root, class_type='variant', split='train', transform=None,
65 | target_transform=None, loader=default_loader, download=False):
66 | if split not in self.splits:
67 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
68 | split, ', '.join(self.splits),
69 | ))
70 | if class_type not in self.class_types:
71 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
72 | class_type, ', '.join(self.class_types),
73 | ))
74 | self.root = os.path.expanduser(root)
75 | self.class_type = class_type
76 | self.split = split
77 | self.classes_file = os.path.join(self.root, 'data',
78 | 'images_%s_%s.txt' % (self.class_type, self.split))
79 |
80 | if download:
81 | self.download()
82 |
83 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) # 6667, 6667, 100, 100
84 | samples = make_dataset(self.root, image_ids, targets)
85 | self.transform = transform
86 | self.target_transform = target_transform
87 | self.loader = loader
88 |
89 | self.samples = samples
90 | self.classes = classes
91 | self.class_to_idx = class_to_idx
92 | self.train = True if split == 'train' else False
93 |
94 | self.uq_idxs = np.array(range(len(self)))
95 |
96 | def __getitem__(self, index):
97 | """
98 | Args:
99 | index (int): Index
100 |
101 | Returns:
102 | tuple: (sample, target) where target is class_index of the target class.
103 | """
104 |
105 | path, target = self.samples[index]
106 | sample = self.loader(path)
107 | if self.transform is not None:
108 | sample = self.transform(sample)
109 | if self.target_transform is not None:
110 | target = self.target_transform(target)
111 |
112 | return sample, target, self.uq_idxs[index]
113 |
114 | def __len__(self):
115 | return len(self.samples)
116 |
117 | def __repr__(self):
118 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
119 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
120 | fmt_str += ' Root Location: {}\n'.format(self.root)
121 | tmp = ' Transforms (if any): '
122 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
123 | tmp = ' Target Transforms (if any): '
124 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
125 | return fmt_str
126 |
127 | def _check_exists(self):
128 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
129 | os.path.exists(self.classes_file)
130 |
131 | def download(self):
132 | """Download the FGVC-Aircraft data if it doesn't exist already."""
133 | from six.moves import urllib
134 | import tarfile
135 |
136 | if self._check_exists():
137 | return
138 |
139 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
140 | print('Downloading %s ... (may take a few minutes)' % self.url)
141 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
142 | tar_name = self.url.rpartition('/')[-1]
143 | tar_path = os.path.join(parent_dir, tar_name)
144 | data = urllib.request.urlopen(self.url)
145 |
146 | # download .tar.gz file
147 | with open(tar_path, 'wb') as f:
148 | f.write(data.read())
149 |
150 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
151 | data_folder = tar_path[:-len('.tar.gz')]  # str.strip removes characters, not a suffix
152 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
153 | tar = tarfile.open(tar_path)
154 | tar.extractall(parent_dir)
155 |
156 | # if necessary, rename data folder to self.root
157 | if not os.path.samefile(data_folder, self.root):
158 | print('Renaming %s to %s ...' % (data_folder, self.root))
159 | os.rename(data_folder, self.root)
160 |
161 | # delete .tar.gz file
162 | print('Deleting %s ...' % tar_path)
163 | os.remove(tar_path)
164 |
165 | print('Done!')
166 |
167 |
168 | def subsample_dataset(dataset, idxs):
169 |
170 | mask = np.zeros(len(dataset)).astype('bool')
171 | mask[idxs] = True
172 |
173 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs]
174 | dataset.uq_idxs = dataset.uq_idxs[mask]
175 |
176 | return dataset
177 |
178 |
179 | def subsample_classes(dataset, include_classes=range(60)):
180 |
181 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes]
182 |
183 | # TODO: Don't transform targets for now
184 | target_xform_dict = {}
185 | for i, k in enumerate(include_classes):
186 | target_xform_dict[k] = i
187 |
188 | dataset = subsample_dataset(dataset, cls_idxs)
189 |
190 | dataset.target_transform = lambda x: target_xform_dict[x]
191 |
192 | return dataset
193 |
194 |
195 | def get_train_val_indices(train_dataset, val_split=0.2):
196 |
197 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)]
198 | train_classes = np.unique(all_targets)
199 |
200 | # Get train/test indices
201 | train_idxs = []
202 | val_idxs = []
203 | for cls in train_classes:
204 | cls_idxs = np.where(all_targets == cls)[0]
205 |
206 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
207 | t_ = [x for x in cls_idxs if x not in v_]
208 |
209 | train_idxs.extend(t_)
210 | val_idxs.extend(v_)
211 |
212 | return train_idxs, val_idxs
213 |
214 |
215 | def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8,
216 | split_train_val=False, seed=0):
217 |
218 | np.random.seed(seed)
219 |
220 | # Init entire training set
221 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval') # len = 6667
222 |
223 | # Get labelled training set which has subsampled classes, then subsample some indices from that
224 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
225 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
226 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
227 |
228 | # Split into training and validation sets
229 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
230 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
231 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
232 | val_dataset_labelled_split.transform = test_transform
233 |
234 | # Get unlabelled data
235 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
236 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
237 |
238 | # Get test set for all classes
239 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test')
240 |
241 | # Either split train into train and val or use test set as val
242 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
243 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
244 |
245 | all_datasets = {
246 | 'train_labelled': train_dataset_labelled,
247 | 'train_unlabelled': train_dataset_unlabelled,
248 | 'val': val_dataset_labelled,
249 | 'test': test_dataset,
250 | }
251 |
252 | return all_datasets
253 |
254 |
255 | def main():
256 | x = get_aircraft_datasets(None, None, split_train_val=False)
257 |
258 | print('Printing lens...')
259 | for k, v in x.items():
260 | if v is not None:
261 | print(f'{k}: {len(v)}')
262 |
263 | print('Printing labelled and unlabelled overlap...')
264 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
265 | print('Printing total instances in train...')
266 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
267 | print('Printing number of labelled classes...')
268 | print(len(set([i[1] for i in x['train_labelled'].samples])))
269 | print('Printing total number of classes...')
270 | print(len(set([i[1] for i in x['train_unlabelled'].samples])))
271 |
272 | if __name__ == '__main__':
273 | main()
274 |
--------------------------------------------------------------------------------
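FGVC-Aircraft ships three label granularities, and the class above selects one by reading data/images_{class_type}_{split}.txt. A sketch comparing them, assuming aircraft_root is configured and the data is on disk:

    from data.fgvc_aircraft import FGVCAircraft
    from config import aircraft_root

    for class_type in ('variant', 'family', 'manufacturer'):
        ds = FGVCAircraft(root=aircraft_root, class_type=class_type, split='trainval')
        print(class_type, len(ds.classes))   # 100, 70 and 30 classes respectively
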
/project_utils/general_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | import inspect
6 | import argparse  # used by str2bool below
7 | from torch.utils.tensorboard import SummaryWriter
8 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
9 |
10 | from datetime import datetime
11 |
12 | class AverageMeter(object):
13 | """Computes and stores the average and current value"""
14 | def __init__(self):
15 | self.reset()
16 |
17 | def reset(self):
18 |
19 | self.val = 0
20 | self.avg = 0
21 | self.sum = 0
22 | self.count = 0
23 |
24 | def update(self, val, n=1):
25 |
26 | self.val = val
27 | self.sum += val * n
28 | self.count += n
29 | self.avg = self.sum / self.count
30 |
31 |
32 | def seed_torch(seed=1029):
33 |
34 | random.seed(seed)
35 | os.environ['PYTHONHASHSEED'] = str(seed)
36 | np.random.seed(seed)
37 | torch.manual_seed(seed)
38 | torch.cuda.manual_seed(seed)
39 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
40 | torch.backends.cudnn.benchmark = False
41 | torch.backends.cudnn.deterministic = True
42 |
43 |
44 | def strip_state_dict(state_dict, strip_key='module.'):
45 |
46 | """
47 | Strip 'module' from start of state_dict keys
48 | Useful if model has been trained as DataParallel model
49 | """
50 |
51 | for k in list(state_dict.keys()):
52 | if k.startswith(strip_key):
53 | state_dict[k[len(strip_key):]] = state_dict[k]
54 | del state_dict[k]
55 |
56 | return state_dict
57 |
58 |
59 | def get_dino_head_weights(pretrain_path):
60 |
61 | """
62 | :param pretrain_path: Path to full DINO pretrained checkpoint as in https://github.com/facebookresearch/dino
63 | 'full_ckpt'
64 | :return: weights only for the projection head
65 | """
66 |
67 | all_weights = torch.load(pretrain_path)
68 |
69 | head_state_dict = {}
70 | for k, v in all_weights['teacher'].items():
71 | if 'head' in k and 'last_layer' not in k:
72 | head_state_dict[k] = v
73 |
74 | head_state_dict = strip_state_dict(head_state_dict, strip_key='head.')
75 |
76 | # Deal with weight norm
77 | weight_norm_state_dict = {}
78 | for k, v in all_weights['teacher'].items():
79 | if 'last_layer' in k:
80 | weight_norm_state_dict[k.split('.')[2]] = v
81 |
82 | linear_shape = weight_norm_state_dict['weight'].shape
83 | dummy_linear = torch.nn.Linear(in_features=linear_shape[1], out_features=linear_shape[0], bias=False)
84 | dummy_linear.load_state_dict(weight_norm_state_dict)
85 | dummy_linear = torch.nn.utils.weight_norm(dummy_linear)
86 |
87 | for k, v in dummy_linear.state_dict().items():
88 |
89 | head_state_dict['last_layer.' + k] = v
90 |
91 | return head_state_dict
92 |
93 | def transform_moco_state_dict(obj, num_classes):
94 |
95 | """
96 | :param obj: MoCo state dict
97 | :param num_classes: number of training classes for the new linear head
98 | :return: state dict compatible with the standard ResNet architecture
99 | """
100 |
101 | newmodel = {}
102 | for k, v in obj.items():
103 | if not k.startswith("module.encoder_q."):
104 | continue
105 | old_k = k
106 | k = k.replace("module.encoder_q.", "")
107 |
108 | if k.startswith("fc.2"):
109 | continue
110 |
111 | if k.startswith("fc.0"):
112 | k = k.replace("0.", "")
113 | if "weight" in k:
114 | v = torch.randn((num_classes, v.size(1)))
115 | elif "bias" in k:
116 | v = torch.randn((num_classes,))
117 |
118 | newmodel[k] = v
119 |
120 | return newmodel
121 |
122 |
123 | def init_experiment(args, runner_name=None, exp_id=None):
124 |
125 | args.cuda = torch.cuda.is_available()
126 |
127 | # Get filepath of calling script
128 | if runner_name is None:
129 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:]
130 |
131 | root_dir = os.path.join(args.exp_root, *runner_name)
132 |
133 | if not os.path.exists(root_dir):
134 | os.makedirs(root_dir)
135 |
136 | # Either generate a unique experiment ID, or use one which is passed
137 | if exp_id is None:
138 |
139 | # Unique identifier for experiment
140 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \
141 | datetime.now().strftime("%S.%f")[:-3] + ')'
142 |
143 | log_dir = os.path.join(root_dir, 'log', now)
144 | while os.path.exists(log_dir):
145 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \
146 | datetime.now().strftime("%S.%f")[:-3] + ')'
147 |
148 | log_dir = os.path.join(root_dir, 'log', now)
149 |
150 | else:
151 |
152 | log_dir = os.path.join(root_dir, 'log', f'{exp_id}')
153 |
154 | if not os.path.exists(log_dir):
155 | os.makedirs(log_dir)
156 | args.log_dir = log_dir
157 |
158 | # Instantiate directory to save models to
159 | model_root_dir = os.path.join(args.log_dir, 'checkpoints')
160 | if not os.path.exists(model_root_dir):
161 | os.mkdir(model_root_dir)
162 |
163 | args.model_dir = model_root_dir
164 | args.model_path = os.path.join(args.model_dir, 'model.pt')
165 |
166 | print(f'Experiment saved to: {args.log_dir}')
167 |
168 | args.writer = SummaryWriter(log_dir=args.log_dir)
169 |
170 | hparam_dict = {}
171 |
172 | for k, v in vars(args).items():
173 | if isinstance(v, (int, float, str, bool, torch.Tensor)):
174 | hparam_dict[k] = v
175 |
176 | args.writer.add_hparams(hparam_dict=hparam_dict, metric_dict={})
177 |
178 | print(runner_name)
179 | print(args)
180 |
181 | return args
182 |
183 |
184 | def accuracy(output, target, topk=(1,)):
185 | """Computes the accuracy over the k top predictions for the specified values of k"""
186 | with torch.no_grad():
187 | maxk = max(topk)
188 | batch_size = target.size(0)
189 |
190 | _, pred = output.topk(maxk, 1, True, True)
191 | pred = pred.t()
192 | correct = pred.eq(target.view(1, -1).expand_as(pred))
193 |
194 | res = []
195 | for k in topk:
196 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
197 | res.append(correct_k.mul_(100.0 / batch_size))
198 | return res
199 |
200 |
201 | def str2bool(v):
202 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
203 | return True
204 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
205 | return False
206 | else:
207 | raise argparse.ArgumentTypeError('Boolean value expected.')
208 |
209 |
210 | class ClassificationPredSaver(object):
211 |
212 | def __init__(self, length, save_path=None):
213 |
214 | if save_path is not None:
215 |
216 | # Remove filetype from save_path
217 |             save_path = os.path.splitext(save_path)[0]  # robust to dots elsewhere in the path
218 | self.save_path = save_path
219 |
220 | self.length = length
221 |
222 | self.all_preds = None
223 | self.all_labels = None
224 |
225 | self.running_start_idx = 0
226 |
227 | def update(self, preds, labels=None):
228 |
229 | # Expect preds in shape B x C
230 |
231 | if torch.is_tensor(preds):
232 | preds = preds.detach().cpu().numpy()
233 |
234 | b, c = preds.shape
235 |
236 | if self.all_preds is None:
237 | self.all_preds = np.zeros((self.length, c))
238 |
239 | self.all_preds[self.running_start_idx: self.running_start_idx + b] = preds
240 |
241 | if labels is not None:
242 | if torch.is_tensor(labels):
243 | labels = labels.detach().cpu().numpy()
244 |
245 | if self.all_labels is None:
246 | self.all_labels = np.zeros((self.length,))
247 |
248 | self.all_labels[self.running_start_idx: self.running_start_idx + b] = labels
249 |
250 | # Maintain running index on dataset being evaluated
251 | self.running_start_idx += b
252 |
253 | def save(self):
254 |
255 | # Softmax over preds
256 | preds = torch.from_numpy(self.all_preds)
257 | preds = torch.nn.Softmax(dim=-1)(preds)
258 | self.all_preds = preds.numpy()
259 |
260 | pred_path = self.save_path + '.pth'
261 | print(f'Saving all predictions to {pred_path}')
262 |
263 | torch.save(self.all_preds, pred_path)
264 |
265 | if self.all_labels is not None:
266 |
267 | # Evaluate
268 | self.evaluate()
269 | torch.save(self.all_labels, self.save_path + '_labels.pth')
270 |
271 | def evaluate(self):
272 |
273 | topk = [1, 5, 10]
274 | topk = [k for k in topk if k < self.all_preds.shape[-1]]
275 | acc = accuracy(torch.from_numpy(self.all_preds), torch.from_numpy(self.all_labels), topk=topk)
276 |
277 | for k, a in zip(topk, acc):
278 | print(f'Top{k} Acc: {a.item()}')
279 |
280 |
281 | def get_acc_auroc_curves(logdir):
282 |
283 | """
284 | :param logdir: Path to logs: E.g '/work/sagar/open_set_recognition/methods/ARPL/log/(12.03.2021_|_32.570)/'
285 | :return:
286 | """
287 |
288 | event_acc = EventAccumulator(logdir)
289 | event_acc.Reload()
290 |
291 | # Only gets scalars
292 | log_info = {}
293 | for tag in event_acc.Tags()['scalars']:
294 |
295 |         log_info[tag] = np.array([[x.step, x.value] for x in event_acc.Scalars(tag)])  # use the public accessor rather than the private _buckets
296 |
297 | return log_info
298 |
299 |
300 | def get_mean_lr(optimizer):
301 | return torch.mean(torch.Tensor([param_group['lr'] for param_group in optimizer.param_groups])).item()
302 |
303 |
304 | class IndicatePlateau(object):
305 |
306 | def __init__(self, threshold=5e-4, patience_epochs=5, mode='min', threshold_mode='rel'):
307 |
308 | self.patience = patience_epochs
309 | self.cooldown_counter = 0
310 | self.mode = mode
311 | self.threshold = threshold
312 | self.threshold_mode = threshold_mode
313 | self.best = None
314 | self.num_bad_epochs = None
315 | self.mode_worse = None # the worse value for the chosen mode
316 | self.last_epoch = 0
317 |         self._init_is_better(mode=mode, threshold=threshold,
318 |                              threshold_mode=threshold_mode)
319 | 
320 |         self._reset()
323 |
324 | def _reset(self):
325 | """Resets num_bad_epochs counter and cooldown counter."""
326 | self.best = self.mode_worse
327 | self.cooldown_counter = 0
328 | self.num_bad_epochs = 0
329 |
330 | def step(self, metrics, epoch=None):
331 | # convert `metrics` to float, in case it's a zero-dim Tensor
332 | current = float(metrics)
333 | self.last_epoch = epoch
334 |
335 | if self.is_better(current, self.best):
336 | self.best = current
337 | self.num_bad_epochs = 0
338 | else:
339 | self.num_bad_epochs += 1
340 |
341 | if self.num_bad_epochs > self.patience:
342 |             print('Tracked metric has plateaued')
343 | self._reset()
344 | return True
345 | else:
346 | return False
347 |
348 | def is_better(self, a, best):
349 |
350 | if self.mode == 'min' and self.threshold_mode == 'rel':
351 | rel_epsilon = 1. - self.threshold
352 | return a < best * rel_epsilon
353 |
354 | elif self.mode == 'min' and self.threshold_mode == 'abs':
355 | return a < best - self.threshold
356 |
357 | elif self.mode == 'max' and self.threshold_mode == 'rel':
358 | rel_epsilon = self.threshold + 1.
359 | return a > best * rel_epsilon
360 |
361 |         else:  # mode == 'max' and threshold_mode == 'abs':
362 | return a > best + self.threshold
363 |
364 | def _init_is_better(self, mode, threshold, threshold_mode):
365 |
366 | if mode not in {'min', 'max'}:
367 | raise ValueError('mode ' + mode + ' is unknown!')
368 | if threshold_mode not in {'rel', 'abs'}:
369 | raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
370 |
371 | if mode == 'min':
372 | self.mode_worse = float('inf')
373 | else: # mode == 'max':
374 | self.mode_worse = -float('inf')
375 |
376 | self.mode = mode
377 | self.threshold = threshold
378 | self.threshold_mode = threshold_mode
379 |
380 |
381 | if __name__ == '__main__':
382 |
383 | x = IndicatePlateau(threshold=0.0899)
384 | eps = np.arange(0, 2000, 1)
385 | y = np.exp(-0.01 * eps)
386 |
387 | print(y)
388 | for i, y_ in enumerate(y):
389 |
390 | z = x.step(y_)
391 | if z:
392 |             print(f'Plateaued at epoch {i} with val {y_}')
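393 | 
394 | # --- Illustrative usage sketch (not part of the original file) ---
395 | # ClassificationPredSaver pre-allocates one array for the whole dataset and
396 | # fills it batch by batch; `model` and `loader` below are hypothetical names.
397 | #
398 | #     saver = ClassificationPredSaver(length=len(loader.dataset), save_path='preds.pth')
399 | #     for x, y in loader:
400 | #         saver.update(model(x), labels=y)   # expects logits of shape B x C
401 | #     saver.save()                           # saves softmax scores and prints Top-k accuracy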
--------------------------------------------------------------------------------
/methods/estimate_k/estimate_k.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
5 | from torch.utils.data import DataLoader
6 | from sklearn.metrics import adjusted_rand_score as ari_score
7 | import numpy as np
8 | from sklearn.cluster import KMeans
9 | import torch
10 | from project_utils.cluster_utils import cluster_acc
11 |
12 | from methods.clustering.feature_vector_dataset import FeatureVectorDataset
13 | from data.get_datasets import get_datasets, get_class_splits
14 |
15 | from config import feature_extract_dir
16 | from tqdm import tqdm
17 |
18 | from scipy.optimize import minimize_scalar
19 | from functools import partial
20 |
21 | from project_utils.cluster_utils import str2bool
22 | # Suppress deprecation warnings while clustering (TODO: debug and remove)
23 | import warnings
24 | warnings.filterwarnings("ignore", category=DeprecationWarning)
25 |
26 | def test_kmeans(K, merge_test_loader, args=None, verbose=False):
27 |
28 | """
29 | In this case, the test loader needs to have the labelled and unlabelled subsets of the training data
30 | """
31 |
32 | if K is None:
33 | K = args.num_labeled_classes + args.num_unlabeled_classes
34 |
35 | all_feats = []
36 | targets = np.array([])
37 | mask_lab = np.array([]) # From all the data, which instances belong to the labelled set
38 | mask_cls = np.array([]) # From all the data, which instances belong to seen classes
39 |
40 | print('Collating features...')
41 | # First extract all features
42 | for batch_idx, (feats, label, _, mask_lab_) in enumerate(tqdm(merge_test_loader)):
43 |
44 | feats = feats.to(device)
45 |
46 | feats = torch.nn.functional.normalize(feats, dim=-1)
47 |
48 | all_feats.append(feats.cpu().numpy())
49 | targets = np.append(targets, label.cpu().numpy())
50 | mask_cls = np.append(mask_cls, np.array([True if x.item() in range(len(args.train_classes))
51 | else False for x in label]))
52 | mask_lab = np.append(mask_lab, mask_lab_.cpu().bool().numpy())
53 |
54 | # -----------------------
55 | # K-MEANS
56 | # -----------------------
57 | mask_lab = mask_lab.astype(bool)
58 | mask_cls = mask_cls.astype(bool)
59 |
60 | all_feats = np.concatenate(all_feats)
61 |
62 | print('Fitting K-Means...')
63 | kmeans = KMeans(n_clusters=K, random_state=0).fit(all_feats)
64 | preds = kmeans.labels_
65 |
66 | # -----------------------
67 | # EVALUATE
68 | # -----------------------
69 | mask = mask_lab
70 |
71 |
72 | labelled_acc, labelled_nmi, labelled_ari = cluster_acc(targets.astype(int)[mask], preds.astype(int)[mask]), \
73 | nmi_score(targets[mask], preds[mask]), \
74 | ari_score(targets[mask], preds[mask])
75 |
76 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int)[~mask],
77 | preds.astype(int)[~mask]), \
78 | nmi_score(targets[~mask], preds[~mask]), \
79 | ari_score(targets[~mask], preds[~mask])
80 |
81 | if verbose:
82 |         print(f'K = {K}')
83 | print('Labelled Instances acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(labelled_acc, labelled_nmi,
84 | labelled_ari))
85 | print('Unlabelled Instances acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(unlabelled_acc, unlabelled_nmi,
86 | unlabelled_ari))
87 |
88 |
89 |
90 | return labelled_acc
91 | # return (labelled_acc, labelled_nmi, labelled_ari), (unlabelled_acc, unlabelled_nmi, unlabelled_ari), mask.astype(float).mean()
92 |
93 |
94 | def test_kmeans_for_scipy(K, merge_test_loader, args=None, verbose=False):
95 |
96 | """
97 | In this case, the test loader needs to have the labelled and unlabelled subsets of the training data
98 | """
99 |
100 | K = int(K)
101 |
102 | all_feats = []
103 | targets = np.array([])
104 | mask_lab = np.array([]) # From all the data, which instances belong to the labelled set
105 | mask_cls = np.array([]) # From all the data, which instances belong to seen classes
106 |
107 | print('Collating features...')
108 | # First extract all features
109 | for batch_idx, (feats, label, _, mask_lab_) in enumerate(tqdm(merge_test_loader)):
110 |
111 | feats = feats.to(device)
112 |
113 | feats = torch.nn.functional.normalize(feats, dim=-1)
114 |
115 | all_feats.append(feats.cpu().numpy())
116 | targets = np.append(targets, label.cpu().numpy())
117 | mask_cls = np.append(mask_cls, np.array([True if x.item() in range(len(args.train_classes))
118 | else False for x in label]))
119 | mask_lab = np.append(mask_lab, mask_lab_.cpu().bool().numpy())
120 |
121 | # -----------------------
122 | # K-MEANS
123 | # -----------------------
124 | mask_lab = mask_lab.astype(bool)
125 | mask_cls = mask_cls.astype(bool)
126 |
127 | all_feats = np.concatenate(all_feats)
128 |
129 | print(f'Fitting K-Means for K = {K}...')
130 | kmeans = KMeans(n_clusters=K, random_state=0).fit(all_feats)
131 | preds = kmeans.labels_
132 |
133 | # -----------------------
134 | # EVALUATE
135 | # -----------------------
136 | mask = mask_lab
137 |
138 |
139 | labelled_acc, labelled_nmi, labelled_ari = cluster_acc(targets.astype(int)[mask], preds.astype(int)[mask]), \
140 | nmi_score(targets[mask], preds[mask]), \
141 | ari_score(targets[mask], preds[mask])
142 |
143 | unlabelled_acc, unlabelled_nmi, unlabelled_ari = cluster_acc(targets.astype(int)[~mask],
144 | preds.astype(int)[~mask]), \
145 | nmi_score(targets[~mask], preds[~mask]), \
146 | ari_score(targets[~mask], preds[~mask])
147 |
148 | print(f'K = {K}')
149 | print('Labelled Instances acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(labelled_acc, labelled_nmi,
150 | labelled_ari))
151 | print('Unlabelled Instances acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(unlabelled_acc, unlabelled_nmi,
152 | unlabelled_ari))
153 |
154 | return -labelled_acc
155 |
156 |
157 | def binary_search(merge_test_loader, args):
158 |
159 | min_classes = args.num_labeled_classes
160 |
161 | # Iter 0
162 | big_k = args.max_classes
163 | small_k = min_classes
164 | diff = big_k - small_k
165 | middle_k = int(0.5 * diff + small_k)
166 |
167 | labelled_acc_big = test_kmeans(big_k, merge_test_loader, args)
168 | labelled_acc_small = test_kmeans(small_k, merge_test_loader, args)
169 | labelled_acc_middle = test_kmeans(middle_k, merge_test_loader, args)
170 |
171 | print(f'Iter 0: BigK {big_k}, Acc {labelled_acc_big:.4f} | MiddleK {middle_k}, Acc {labelled_acc_middle:.4f} | SmallK {small_k}, Acc {labelled_acc_small:.4f} ')
172 | all_accs = [labelled_acc_small, labelled_acc_middle, labelled_acc_big]
173 | best_acc_so_far = np.max(all_accs)
174 | best_acc_at_k = np.array([small_k, middle_k, big_k])[np.argmax(all_accs)]
175 | print(f'Best Acc so far {best_acc_so_far:.4f} at K {best_acc_at_k}')
176 |
177 | for i in range(1, int(np.log2(diff)) + 1):
178 |
179 | if labelled_acc_big > labelled_acc_small:
180 |
181 | best_acc = max(labelled_acc_middle, labelled_acc_big)
182 |
183 | small_k = middle_k
184 | labelled_acc_small = labelled_acc_middle
185 | diff = big_k - small_k
186 | middle_k = int(0.5 * diff + small_k)
187 |
188 | else:
189 |
190 | best_acc = max(labelled_acc_middle, labelled_acc_small)
191 | big_k = middle_k
192 |
193 | diff = big_k - small_k
194 | middle_k = int(0.5 * diff + small_k)
195 | labelled_acc_big = labelled_acc_middle
196 |
197 | labelled_acc_middle = test_kmeans(middle_k, merge_test_loader, args)
198 |
199 | print(f'Iter {i}: BigK {big_k}, Acc {labelled_acc_big:.4f} | MiddleK {middle_k}, Acc {labelled_acc_middle:.4f} | SmallK {small_k}, Acc {labelled_acc_small:.4f} ')
200 | all_accs = [labelled_acc_small, labelled_acc_middle, labelled_acc_big]
201 | best_acc_so_far = np.max(all_accs)
202 | best_acc_at_k = np.array([small_k, middle_k, big_k])[np.argmax(all_accs)]
203 | print(f'Best Acc so far {best_acc_so_far:.4f} at K {best_acc_at_k}')
204 |
205 |
206 | def scipy_optimise(merge_test_loader, args):
207 |
208 | small_k = args.num_labeled_classes
209 | big_k = args.max_classes
210 |
211 | test_k_means_partial = partial(test_kmeans_for_scipy, merge_test_loader=merge_test_loader, args=args, verbose=True)
212 | res = minimize_scalar(test_k_means_partial, bounds=(small_k, big_k), method='bounded', options={'disp': True})
213 | print(f'Optimal K is {res.x}')
214 |
215 |
216 | if __name__ == "__main__":
217 |
218 | parser = argparse.ArgumentParser(
219 | description='cluster',
220 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
221 | parser.add_argument('--batch_size', default=128, type=int)
222 | parser.add_argument('--num_workers', default=8, type=int)
223 | parser.add_argument('--max_classes', default=1000, type=int)
224 | parser.add_argument('--root_dir', type=str, default=feature_extract_dir)
225 | parser.add_argument('--warmup_model_exp_id', type=str, default=None)
226 | parser.add_argument('--model_name', type=str, default='vit_dino', help='Format is {model_name}_{pretrain}')
227 | parser.add_argument('--search_mode', type=str, default='brent', help='Mode for black box optimisation')
228 | parser.add_argument('--dataset_name', type=str, default='cifar10', help='options: cifar10, cifar100, scars')
229 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
230 | parser.add_argument('--use_best_model', type=str2bool, default=True)
231 |
232 | # ----------------------
233 | # INIT
234 | # ----------------------
235 | args = parser.parse_args()
236 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
237 |
238 | cluster_accs = {}
239 |
240 | args.save_dir = os.path.join(args.root_dir, f'{args.model_name}_{args.dataset_name}')
241 |
242 | args = get_class_splits(args)
243 |
244 | args.num_labeled_classes = len(args.train_classes)
245 | args.num_unlabeled_classes = len(args.unlabeled_classes)
246 |
247 | print(args)
248 |
249 | if args.warmup_model_exp_id is not None:
250 | args.save_dir += '_' + args.warmup_model_exp_id
251 | if args.use_best_model:
252 | args.save_dir += '_best'
253 | print(f'Using features from experiment: {args.warmup_model_exp_id}')
254 | else:
255 | print(f'Using pretrained {args.model_name} features...')
256 |
257 | # --------------------
258 | # DATASETS
259 | # --------------------
260 | print('Building datasets...')
261 | train_transform, test_transform = None, None
262 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets, labelled_train_examples = get_datasets(args.dataset_name,
263 | train_transform, test_transform, args)
264 |
265 | # Convert to feature vector dataset
266 | test_dataset = FeatureVectorDataset(base_dataset=test_dataset, feature_root=os.path.join(args.save_dir, 'test'))
267 | unlabelled_train_examples_test = FeatureVectorDataset(base_dataset=unlabelled_train_examples_test,
268 | feature_root=os.path.join(args.save_dir, 'train'))
269 | train_dataset = FeatureVectorDataset(base_dataset=train_dataset, feature_root=os.path.join(args.save_dir, 'train'))
270 |
271 | # --------------------
272 | # DATALOADERS
273 | # --------------------
274 | unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
275 | batch_size=args.batch_size, shuffle=False)
276 | test_loader = DataLoader(test_dataset, num_workers=args.num_workers,
277 | batch_size=args.batch_size, shuffle=False)
278 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers,
279 | batch_size=args.batch_size, shuffle=False)
280 |
281 |     print('Testing on all instances in the training data...')
282 | if args.search_mode == 'brent':
283 |         print("Optimising with Brent's algorithm")
284 | scipy_optimise(merge_test_loader=train_loader, args=args)
285 | else:
286 | binary_search(train_loader, args)
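287 | 
288 | # --- Illustrative invocation sketch (not part of the original file) ---
289 | # Assumes features were first dumped with methods.clustering.extract_features,
290 | # so that FeatureVectorDataset can find them under `feature_extract_dir`:
291 | #
292 | #     python -m methods.estimate_k.estimate_k --dataset_name cub \
293 | #         --search_mode brent --max_classes 1000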
--------------------------------------------------------------------------------
/model/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 | from torch.nn.init import trunc_normal_
25 |
26 | def drop_path(x, drop_prob: float = 0., training: bool = False):
27 | if drop_prob == 0. or not training:
28 | return x
29 | keep_prob = 1 - drop_prob
30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
32 | random_tensor.floor_() # binarize
33 | output = x.div(keep_prob) * random_tensor
34 | return output
35 |
36 |
37 | class DropPath(nn.Module):
38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
39 | """
40 | def __init__(self, drop_prob=None):
41 | super(DropPath, self).__init__()
42 | self.drop_prob = drop_prob
43 |
44 | def forward(self, x):
45 | return drop_path(x, self.drop_prob, self.training)
46 |
47 |
48 | class Mlp(nn.Module):
49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50 | super().__init__()
51 | out_features = out_features or in_features
52 | hidden_features = hidden_features or in_features
53 | self.fc1 = nn.Linear(in_features, hidden_features)
54 | self.act = act_layer()
55 | self.fc2 = nn.Linear(hidden_features, out_features)
56 | self.drop = nn.Dropout(drop)
57 |
58 | def forward(self, x):
59 | x = self.fc1(x)
60 | x = self.act(x)
61 | x = self.drop(x)
62 | x = self.fc2(x)
63 | x = self.drop(x)
64 | return x
65 |
66 |
67 | class Attention(nn.Module):
68 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
69 | super().__init__()
70 | self.num_heads = num_heads
71 | head_dim = dim // num_heads
72 | self.scale = qk_scale or head_dim ** -0.5
73 |
74 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
75 | self.attn_drop = nn.Dropout(attn_drop)
76 | self.proj = nn.Linear(dim, dim)
77 | self.proj_drop = nn.Dropout(proj_drop)
78 |
79 | def forward(self, x):
80 | B, N, C = x.shape
81 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
82 | q, k, v = qkv[0], qkv[1], qkv[2]
83 |
84 | attn = (q @ k.transpose(-2, -1)) * self.scale
85 | attn = attn.softmax(dim=-1)
86 | attn = self.attn_drop(attn)
87 |
88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
89 | x = self.proj(x)
90 | x = self.proj_drop(x)
91 | return x, attn
92 |
93 |
94 | class Block(nn.Module):
95 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
96 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
97 | super().__init__()
98 | self.norm1 = norm_layer(dim)
99 | self.attn = Attention(
100 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
101 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
102 | self.norm2 = norm_layer(dim)
103 | mlp_hidden_dim = int(dim * mlp_ratio)
104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
105 |
106 | def forward(self, x, return_attention=False):
107 | y, attn = self.attn(self.norm1(x))
108 | x = x + self.drop_path(y)
109 | x = x + self.drop_path(self.mlp(self.norm2(x)))
110 |
111 | if return_attention:
112 | return x, attn
113 | else:
114 | return x
115 |
116 |
117 | class PatchEmbed(nn.Module):
118 | """ Image to Patch Embedding
119 | """
120 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
121 | super().__init__()
122 | num_patches = (img_size // patch_size) * (img_size // patch_size)
123 | self.img_size = img_size
124 | self.patch_size = patch_size
125 | self.num_patches = num_patches
126 |
127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
128 |
129 | def forward(self, x):
130 | B, C, H, W = x.shape
131 | x = self.proj(x).flatten(2).transpose(1, 2)
132 | return x
133 |
134 |
135 | class VisionTransformer(nn.Module):
136 | """ Vision Transformer """
137 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
138 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
139 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
140 | super().__init__()
141 | self.num_features = self.embed_dim = embed_dim
142 |
143 | self.patch_embed = PatchEmbed(
144 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
145 | num_patches = self.patch_embed.num_patches
146 |
147 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
148 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
149 | self.pos_drop = nn.Dropout(p=drop_rate)
150 |
151 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
152 | self.blocks = nn.ModuleList([
153 | Block(
154 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
155 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
156 | for i in range(depth)])
157 | self.norm = norm_layer(embed_dim)
158 |
159 | # Classifier head
160 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
161 |
162 | trunc_normal_(self.pos_embed, std=.02)
163 | trunc_normal_(self.cls_token, std=.02)
164 | self.apply(self._init_weights)
165 |
166 | def _init_weights(self, m):
167 | if isinstance(m, nn.Linear):
168 | trunc_normal_(m.weight, std=.02)
169 | if isinstance(m, nn.Linear) and m.bias is not None:
170 | nn.init.constant_(m.bias, 0)
171 | elif isinstance(m, nn.LayerNorm):
172 | nn.init.constant_(m.bias, 0)
173 | nn.init.constant_(m.weight, 1.0)
174 |
175 | def interpolate_pos_encoding(self, x, w, h):
176 | npatch = x.shape[1] - 1
177 | N = self.pos_embed.shape[1] - 1
178 | if npatch == N and w == h:
179 | return self.pos_embed
180 | class_pos_embed = self.pos_embed[:, 0]
181 | patch_pos_embed = self.pos_embed[:, 1:]
182 | dim = x.shape[-1]
183 | w0 = w // self.patch_embed.patch_size
184 | h0 = h // self.patch_embed.patch_size
185 | # we add a small number to avoid floating point error in the interpolation
186 | # see discussion at https://github.com/facebookresearch/dino/issues/8
187 | w0, h0 = w0 + 0.1, h0 + 0.1
188 | patch_pos_embed = nn.functional.interpolate(
189 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
190 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
191 | mode='bicubic',
192 | )
193 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
194 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
195 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
196 |
197 | def prepare_tokens(self, x):
198 | B, nc, w, h = x.shape
199 | x = self.patch_embed(x) # patch linear embedding
200 |
201 | # add the [CLS] token to the embed patch tokens
202 | cls_tokens = self.cls_token.expand(B, -1, -1)
203 | x = torch.cat((cls_tokens, x), dim=1)
204 |
205 | # add positional encoding to each token
206 | x = x + self.interpolate_pos_encoding(x, w, h)
207 |
208 | return self.pos_drop(x)
209 |
210 | def forward(self, x, return_all_patches=False):
211 | x = self.prepare_tokens(x)
212 | for blk in self.blocks:
213 | x = blk(x)
214 | x = self.norm(x)
215 |
216 | if return_all_patches:
217 | return x
218 | else:
219 | return x[:, 0]
220 |
221 | def get_last_selfattention(self, x):
222 | x = self.prepare_tokens(x)
223 | for i, blk in enumerate(self.blocks):
224 | if i < len(self.blocks) - 1:
225 | x = blk(x)
226 | else:
227 | # return attention of the last block
228 | x, attn = blk(x, return_attention=True)
229 | x = self.norm(x)
230 | return x, attn
231 |
232 | def get_intermediate_layers(self, x, n=1):
233 | x = self.prepare_tokens(x)
234 | # we return the output tokens from the `n` last blocks
235 | output = []
236 | for i, blk in enumerate(self.blocks):
237 | x = blk(x)
238 | if len(self.blocks) - i <= n:
239 | output.append(self.norm(x))
240 | return output
241 |
242 |
243 | def vit_tiny(patch_size=16, **kwargs):
244 | model = VisionTransformer(
245 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
246 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
247 | return model
248 |
249 |
250 | def vit_small(patch_size=16, **kwargs):
251 | model = VisionTransformer(
252 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
253 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
254 | return model
255 |
256 |
257 | def vit_base(patch_size=16, **kwargs):
258 | model = VisionTransformer(
259 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
260 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
261 | return model
262 |
263 |
264 | class DINOHead(nn.Module):
265 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
266 | super().__init__()
267 | nlayers = max(nlayers, 1)
268 | if nlayers == 1:
269 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
270 | else:
271 | layers = [nn.Linear(in_dim, hidden_dim)]
272 | if use_bn:
273 | layers.append(nn.BatchNorm1d(hidden_dim))
274 | layers.append(nn.GELU())
275 | for _ in range(nlayers - 2):
276 | layers.append(nn.Linear(hidden_dim, hidden_dim))
277 | if use_bn:
278 | layers.append(nn.BatchNorm1d(hidden_dim))
279 | layers.append(nn.GELU())
280 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
281 | self.mlp = nn.Sequential(*layers)
282 | self.apply(self._init_weights)
283 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
284 | self.last_layer.weight_g.data.fill_(1)
285 | if norm_last_layer:
286 | self.last_layer.weight_g.requires_grad = False
287 |
288 | def _init_weights(self, m):
289 | if isinstance(m, nn.Linear):
290 | trunc_normal_(m.weight, std=.02)
291 | if isinstance(m, nn.Linear) and m.bias is not None:
292 | nn.init.constant_(m.bias, 0)
293 |
294 | def forward(self, x):
295 | x = self.mlp(x)
296 | x = nn.functional.normalize(x, dim=-1, p=2)
297 | x = self.last_layer(x)
298 | return x
299 |
300 |
301 | class VisionTransformerWithLinear(nn.Module):
302 |
303 | def __init__(self, base_vit, num_classes=200):
304 |
305 | super().__init__()
306 |
307 | self.base_vit = base_vit
308 | self.fc = nn.Linear(768, num_classes)
309 |
310 | def forward(self, x, return_features=False):
311 |
312 | features = self.base_vit(x)
313 | features = torch.nn.functional.normalize(features, dim=-1)
314 | logits = self.fc(features)
315 |
316 | if return_features:
317 | return logits, features
318 | else:
319 | return logits
320 |
321 | @torch.no_grad()
322 | def normalize_prototypes(self):
323 | w = self.fc.weight.data.clone()
324 | w = torch.nn.functional.normalize(w, dim=1, p=2)
325 | self.fc.weight.copy_(w)
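326 | 
327 | # --- Illustrative usage sketch (not part of the original file) ---
328 | # Build a ViT-B/16 backbone and project its [CLS] feature with a DINO head;
329 | # shapes assume a 224x224 input, and out_dim=65536 is DINO's default.
330 | #
331 | #     backbone = vit_base(patch_size=16)
332 | #     head = DINOHead(in_dim=backbone.embed_dim, out_dim=65536)
333 | #     x = torch.randn(2, 3, 224, 224)
334 | #     feats = backbone(x)           # (2, 768): the [CLS] token after the final norm
335 | #     logits = head(feats)          # (2, 65536)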
--------------------------------------------------------------------------------
/methods/clustering/faster_mix_k_means_pytorch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import random
4 | from project_utils.cluster_utils import cluster_acc
5 | from joblib import Parallel, delayed, effective_n_jobs  # sklearn.utils._joblib is not available in recent scikit-learn
6 | from sklearn.utils import check_random_state
7 | import torch
8 |
9 | def pairwise_distance(data1, data2, batch_size=None):
10 |     r'''
11 |     Use broadcasting to compute pairwise squared Euclidean distances.
12 |     The inputs are N*M and K*M matrices, where M is the feature dimension;
13 |     we expand them to an N*1*M matrix A and a 1*K*M matrix B,
14 |     so a single elementwise operation on A and B handles every pair of points
15 |     '''
16 | #N*1*M
17 | A = data1.unsqueeze(dim=1)
18 |
19 | #1*N*M
20 | B = data2.unsqueeze(dim=0)
21 |
22 |     if batch_size is None:
23 |         dis = (A-B)**2
24 |         #return N*K matrix of pairwise squared distances
25 |         dis = dis.sum(dim=-1)
26 | # torch.cuda.empty_cache()
27 | else:
28 | i = 0
29 | dis = torch.zeros(data1.shape[0], data2.shape[0])
30 | while i < data1.shape[0]:
31 | if(i+batch_size < data1.shape[0]):
32 | dis_batch = (A[i:i+batch_size]-B)**2
33 | dis_batch = dis_batch.sum(dim=-1)
34 | dis[i:i+batch_size] = dis_batch
35 | i = i+batch_size
36 | # torch.cuda.empty_cache()
37 | elif(i+batch_size >= data1.shape[0]):
38 | dis_final = (A[i:] - B)**2
39 | dis_final = dis_final.sum(dim=-1)
40 | dis[i:] = dis_final
41 | # torch.cuda.empty_cache()
42 | break
43 | # torch.cuda.empty_cache()
44 | return dis
45 |
46 |
47 | class K_Means:
48 |
49 | def __init__(self, k=3, tolerance=1e-4, max_iterations=100, init='k-means++',
50 | n_init=10, random_state=None, n_jobs=None, pairwise_batch_size=None, mode=None):
51 | self.k = k
52 | self.tolerance = tolerance
53 | self.max_iterations = max_iterations
54 | self.init = init
55 | self.n_init = n_init
56 | self.random_state = random_state
57 | self.n_jobs = n_jobs
58 | self.pairwise_batch_size = pairwise_batch_size
59 | self.mode = mode
60 |
61 | def split_for_val(self, l_feats, l_targets, val_prop=0.2):
62 |
63 | np.random.seed(0)
64 |
65 | # Reserve some labelled examples for validation
66 | num_val_instances = int(val_prop * len(l_targets))
67 | val_idxs = np.random.choice(range(len(l_targets)), size=(num_val_instances), replace=False)
68 | val_idxs.sort()
69 | remaining_idxs = list(set(range(len(l_targets))) - set(val_idxs.tolist()))
70 | remaining_idxs.sort()
71 | remaining_idxs = np.array(remaining_idxs)
72 |
73 | val_l_targets = l_targets[val_idxs]
74 | val_l_feats = l_feats[val_idxs]
75 |
76 | remaining_l_targets = l_targets[remaining_idxs]
77 | remaining_l_feats = l_feats[remaining_idxs]
78 |
79 | return remaining_l_feats, remaining_l_targets, val_l_feats, val_l_targets
80 |
81 |
82 | def kpp(self, X, pre_centers=None, k=10, random_state=None):
83 | random_state = check_random_state(random_state)
84 |
85 | if pre_centers is not None:
86 |
87 | C = pre_centers
88 |
89 | else:
90 |
91 | C = X[random_state.randint(0, len(X))]
92 |
93 | C = C.view(-1, X.shape[1])
94 |
95 | while C.shape[0] < k:
96 |
97 | dist = pairwise_distance(X, C, self.pairwise_batch_size)
98 | dist = dist.view(-1, C.shape[0])
99 | d2, _ = torch.min(dist, dim=1)
100 | prob = d2/d2.sum()
101 | cum_prob = torch.cumsum(prob, dim=0)
102 | r = random_state.rand()
103 |
104 |             if len((cum_prob >= r).nonzero()) == 0:
105 |                 continue  # numerical edge case: no candidate found, resample r
106 |             else:
107 |                 ind = (cum_prob >= r).nonzero()[0][0]
108 | C = torch.cat((C, X[ind].view(1, -1)), dim=0)
109 |
110 | return C
111 |
112 |
113 | def fit_once(self, X, random_state):
114 |
115 | centers = torch.zeros(self.k, X.shape[1]).type_as(X)
116 | labels = -torch.ones(len(X))
117 |         # initialize the centers: k-means++, random sampling, or (fallback) the first 'k' elements of the dataset
118 |
119 | if self.init == 'k-means++':
120 | centers = self.kpp(X, k=self.k, random_state=random_state)
121 |
122 | elif self.init == 'random':
123 |
124 | random_state = check_random_state(self.random_state)
125 | idx = random_state.choice(len(X), self.k, replace=False)
126 | for i in range(self.k):
127 | centers[i] = X[idx[i]]
128 |
129 | else:
130 | for i in range(self.k):
131 | centers[i] = X[i]
132 |
133 | #begin iterations
134 |
135 | best_labels, best_inertia, best_centers = None, None, None
136 | for i in range(self.max_iterations):
137 |
138 | centers_old = centers.clone()
139 | dist = pairwise_distance(X, centers, self.pairwise_batch_size)
140 | mindist, labels = torch.min(dist, dim=1)
141 | inertia = mindist.sum()
142 |
143 | for idx in range(self.k):
144 | selected = torch.nonzero(labels == idx).squeeze()
145 | selected = torch.index_select(X, 0, selected)
146 | centers[idx] = selected.mean(dim=0)
147 |
148 | if best_inertia is None or inertia < best_inertia:
149 | best_labels = labels.clone()
150 | best_centers = centers.clone()
151 | best_inertia = inertia
152 |
153 | center_shift = torch.sum(torch.sqrt(torch.sum((centers - centers_old) ** 2, dim=1)))
154 | if center_shift ** 2 < self.tolerance:
155 |                 # break out of the main loop once the centers move by less than our tolerance
156 | break
157 |
158 | return best_labels, best_inertia, best_centers, i + 1
159 |
160 |
161 | def fit_mix_once(self, u_feats, l_feats, l_targets, random_state):
162 |
163 | def supp_idxs(c):
164 | return l_targets.eq(c).nonzero().squeeze(1)
165 |
166 | l_classes = torch.unique(l_targets)
167 | support_idxs = list(map(supp_idxs, l_classes))
168 | l_centers = torch.stack([l_feats[idx_list].mean(0) for idx_list in support_idxs])
169 | cat_feats = torch.cat((l_feats, u_feats))
170 |
171 | centers = torch.zeros([self.k, cat_feats.shape[1]]).type_as(cat_feats)
172 | centers[:len(l_classes)] = l_centers
173 |
174 | labels = -torch.ones(len(cat_feats)).type_as(cat_feats).long()
175 |
176 | l_classes = l_classes.cpu().long().numpy()
177 | l_targets = l_targets.cpu().long().numpy()
178 | l_num = len(l_targets)
179 | cid2ncid = {cid:ncid for ncid, cid in enumerate(l_classes)} # Create the mapping table for New cid (ncid)
180 | for i in range(l_num):
181 | labels[i] = cid2ncid[l_targets[i]]
182 |
183 |         # initialize the remaining centers with k-means++, seeded with the labelled class means
184 | centers = self.kpp(u_feats, l_centers, k=self.k, random_state=random_state)
185 |
186 | # Begin iterations
187 | best_labels, best_inertia, best_centers = None, None, None
188 | for it in range(self.max_iterations):
189 | centers_old = centers.clone()
190 |
191 | dist = pairwise_distance(u_feats, centers, self.pairwise_batch_size)
192 | u_mindist, u_labels = torch.min(dist, dim=1)
193 | u_inertia = u_mindist.sum()
194 | l_mindist = torch.sum((l_feats - centers[labels[:l_num]])**2, dim=1)
195 | l_inertia = l_mindist.sum()
196 | inertia = u_inertia + l_inertia
197 | labels[l_num:] = u_labels
198 |
199 | for idx in range(self.k):
200 |
201 | selected = torch.nonzero(labels == idx).squeeze()
202 | selected = torch.index_select(cat_feats, 0, selected)
203 | centers[idx] = selected.mean(dim=0)
204 |
205 | if best_inertia is None or inertia < best_inertia:
206 | best_labels = labels.clone()
207 | best_centers = centers.clone()
208 | best_inertia = inertia
209 |
210 | center_shift = torch.sum(torch.sqrt(torch.sum((centers - centers_old) ** 2, dim=1)))
211 |
212 | if center_shift ** 2 < self.tolerance:
213 |                 # break out of the main loop once the centers move by less than our tolerance
214 | break
215 |
216 |         return best_labels, best_inertia, best_centers, it + 1  # 'it', not 'i': the loop variable of this method
217 |
218 |
219 | def fit(self, X):
220 | random_state = check_random_state(self.random_state)
221 | best_inertia = None
222 | if effective_n_jobs(self.n_jobs) == 1:
223 | for it in range(self.n_init):
224 | labels, inertia, centers, n_iters = self.fit_once(X, random_state)
225 | if best_inertia is None or inertia < best_inertia:
226 | self.labels_ = labels.clone()
227 | self.cluster_centers_ = centers.clone()
228 | best_inertia = inertia
229 | self.inertia_ = inertia
230 | self.n_iter_ = n_iters
231 | else:
232 | # parallelisation of k-means runs
233 | seeds = random_state.randint(np.iinfo(np.int32).max, size=self.n_init)
234 | results = Parallel(n_jobs=self.n_jobs, verbose=0)(delayed(self.fit_once)(X, seed) for seed in seeds)
235 | # Get results with the lowest inertia
236 | labels, inertia, centers, n_iters = zip(*results)
237 | best = np.argmin(inertia)
238 | self.labels_ = labels[best]
239 | self.inertia_ = inertia[best]
240 | self.cluster_centers_ = centers[best]
241 | self.n_iter_ = n_iters[best]
242 |
243 |
244 | def fit_mix(self, u_feats, l_feats, l_targets):
245 |
246 | random_state = check_random_state(self.random_state)
247 | best_inertia = None
248 | fit_func = self.fit_mix_once
249 |
250 | if effective_n_jobs(self.n_jobs) == 1:
251 | for it in range(self.n_init):
252 |
253 | labels, inertia, centers, n_iters = fit_func(u_feats, l_feats, l_targets, random_state)
254 |
255 | if best_inertia is None or inertia < best_inertia:
256 | self.labels_ = labels.clone()
257 | self.cluster_centers_ = centers.clone()
258 | best_inertia = inertia
259 | self.inertia_ = inertia
260 | self.n_iter_ = n_iters
261 |
262 | else:
263 |
264 | # parallelisation of k-means runs
265 | seeds = random_state.randint(np.iinfo(np.int32).max, size=self.n_init)
266 | results = Parallel(n_jobs=self.n_jobs, verbose=0)(delayed(fit_func)(u_feats, l_feats, l_targets, seed)
267 | for seed in seeds)
268 | # Get results with the lowest inertia
269 |
270 | labels, inertia, centers, n_iters = zip(*results)
271 | best = np.argmin(inertia)
272 | self.labels_ = labels[best]
273 | self.inertia_ = inertia[best]
274 | self.cluster_centers_ = centers[best]
275 | self.n_iter_ = n_iters[best]
276 |
277 |
278 | def main():
279 |
280 | import matplotlib.pyplot as plt
281 | from matplotlib import style
282 | import pandas as pd
283 | style.use('ggplot')
284 | from sklearn.datasets import make_blobs
285 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
286 | X, y = make_blobs(n_samples=500,
287 | n_features=2,
288 | centers=4,
289 | cluster_std=1,
290 | center_box=(-10.0, 10.0),
291 | shuffle=True,
292 | random_state=1) # For reproducibility
293 |
294 | cuda = torch.cuda.is_available()
295 | device = torch.device("cuda" if cuda else "cpu")
296 | # X = torch.from_numpy(X).float().to(device)
297 |
298 |
299 | y = np.array(y)
300 | l_targets = y[y>1]
301 | l_feats = X[y>1]
302 | u_feats = X[y<2]
303 | cat_feats = np.concatenate((l_feats, u_feats))
304 | y = np.concatenate((y[y>1], y[y<2]))
305 | cat_feats = torch.from_numpy(cat_feats).to(device)
306 | u_feats = torch.from_numpy(u_feats).to(device)
307 | l_feats = torch.from_numpy(l_feats).to(device)
308 | l_targets = torch.from_numpy(l_targets).to(device)
309 |
310 | km = K_Means(k=4, init='k-means++', random_state=1, n_jobs=None, pairwise_batch_size=10)
311 |
312 | # km.fit(X)
313 |
314 | km.fit_mix(u_feats, l_feats, l_targets)
315 | # X = X.cpu()
316 | X = cat_feats.cpu()
317 | centers = km.cluster_centers_.cpu()
318 | pred = km.labels_.cpu()
319 | print('nmi', nmi_score(pred, y))
320 |
321 | # Plotting starts here
322 | colors = 10*["g", "c", "b", "k", "r", "m"]
323 |
324 | for i in range(len(X)):
325 | x = X[i]
326 | plt.scatter(x[0], x[1], color = colors[pred[i]],s = 10)
327 |
328 | for i in range(4):
329 | plt.scatter(centers[i][0], centers[i][1], s = 130, marker = "*", color='r')
330 | plt.show()
331 |
332 | if __name__ == "__main__":
333 | main()
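334 | 
335 | # --- Illustrative note (not part of the original file) ---
336 | # pairwise_distance broadcasts an (N, 1, M) tensor against (1, K, M), so the
337 | # intermediate tensor holds N*K*M floats; passing pairwise_batch_size bounds
338 | # this by processing batch_size rows of X at a time, e.g.:
339 | #
340 | #     X = torch.randn(10000, 256)
341 | #     km = K_Means(k=50, init='k-means++', random_state=0, pairwise_batch_size=1024)
342 | #     km.fit(X)
343 | #     print(km.labels_.shape, km.cluster_centers_.shape)  # (10000,), (50, 256)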
--------------------------------------------------------------------------------