├── .gitignore
├── assets
│   └── teaser.jpg
├── data
│   ├── ssb_splits
│   │   ├── cub_osr_splits.pkl
│   │   ├── scars_osr_splits.pkl
│   │   ├── aircraft_osr_splits.pkl
│   │   └── herbarium_19_class_splits.pkl
│   ├── augmentations
│   │   └── __init__.py
│   ├── data_utils.py
│   ├── get_datasets.py
│   ├── stanford_cars.py
│   ├── herbarium_19.py
│   ├── cifar.py
│   ├── cub.py
│   ├── imagenet.py
│   └── fgvc_aircraft.py
├── requirements.txt
├── scripts
│   ├── run_cub.sh
│   ├── run_cars.sh
│   ├── run_aircraft.sh
│   ├── run_cifar10.sh
│   ├── run_cifar100.sh
│   ├── run_herb19.sh
│   ├── run_imagenet100.sh
│   └── run_imagenet1k.sh
├── config.py
├── LICENSE
├── util
│   ├── general_utils.py
│   └── cluster_and_log_utils.py
├── README.md
├── model.py
├── train.py
└── train_mp.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | dev_outputs -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/assets/teaser.jpg -------------------------------------------------------------------------------- /data/ssb_splits/cub_osr_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/data/ssb_splits/cub_osr_splits.pkl -------------------------------------------------------------------------------- /data/ssb_splits/scars_osr_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/data/ssb_splits/scars_osr_splits.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru 2 | numpy 3 | pandas 4 | scikit_learn 5 | scipy 6 | torch==1.10.0 7 | torchvision==0.11.1 8 | tqdm -------------------------------------------------------------------------------- /data/ssb_splits/aircraft_osr_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/data/ssb_splits/aircraft_osr_splits.pkl -------------------------------------------------------------------------------- /data/ssb_splits/herbarium_19_class_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/data/ssb_splits/herbarium_19_class_splits.pkl -------------------------------------------------------------------------------- /scripts/run_cub.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --dataset_name 'cub' \ 8 | --batch_size 128 \ 9 | --grad_from_block 11 \ 10 | --epochs 200 \ 11 | --num_workers 8 \ 12 | --use_ssb_splits \ 13 | --sup_weight 0.35 \ 14 | --weight_decay 5e-5 \ 15 | --transform 'imagenet' \ 16 | --lr 0.1 \ 17 | --eval_funcs 'v2' \ 18 | --warmup_teacher_temp 0.07 \ 19 | --teacher_temp 0.04 \ 20 | --warmup_teacher_temp_epochs 30 \ 21 | --memax_weight 2 \ 22 | --exp_name cub_simgcd 23 | -------------------------------------------------------------------------------- /scripts/run_cars.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --dataset_name 'scars' \ 8 |
--batch_size 128 \ 9 | --grad_from_block 11 \ 10 | --epochs 200 \ 11 | --num_workers 8 \ 12 | --use_ssb_splits \ 13 | --sup_weight 0.35 \ 14 | --weight_decay 5e-5 \ 15 | --transform 'imagenet' \ 16 | --lr 0.1 \ 17 | --eval_funcs 'v2' \ 18 | --warmup_teacher_temp 0.07 \ 19 | --teacher_temp 0.04 \ 20 | --warmup_teacher_temp_epochs 30 \ 21 | --memax_weight 1 \ 22 | --exp_name scars_simgcd 23 | -------------------------------------------------------------------------------- /scripts/run_aircraft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --dataset_name 'aircraft' \ 8 | --batch_size 128 \ 9 | --grad_from_block 11 \ 10 | --epochs 200 \ 11 | --num_workers 8 \ 12 | --use_ssb_splits \ 13 | --sup_weight 0.35 \ 14 | --weight_decay 5e-5 \ 15 | --transform 'imagenet' \ 16 | --lr 0.1 \ 17 | --eval_funcs 'v2' \ 18 | --warmup_teacher_temp 0.07 \ 19 | --teacher_temp 0.04 \ 20 | --warmup_teacher_temp_epochs 30 \ 21 | --memax_weight 1 \ 22 | --exp_name aircraft_simgcd 23 | -------------------------------------------------------------------------------- /scripts/run_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --dataset_name 'cifar10' \ 8 | --batch_size 128 \ 9 | --grad_from_block 11 \ 10 | --epochs 200 \ 11 | --num_workers 8 \ 12 | --use_ssb_splits \ 13 | --sup_weight 0.35 \ 14 | --weight_decay 5e-5 \ 15 | --transform 'imagenet' \ 16 | --lr 0.1 \ 17 | --eval_funcs 'v2' \ 18 | --warmup_teacher_temp 0.07 \ 19 | --teacher_temp 0.04 \ 20 | --warmup_teacher_temp_epochs 30 \ 21 | --memax_weight 1 \ 22 | --exp_name cifar10_simgcd 23 | -------------------------------------------------------------------------------- /scripts/run_cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --dataset_name 'cifar100' \ 8 | --batch_size 128 \ 9 | --grad_from_block 11 \ 10 | --epochs 200 \ 11 | --num_workers 8 \ 12 | --use_ssb_splits \ 13 | --sup_weight 0.35 \ 14 | --weight_decay 5e-5 \ 15 | --transform 'imagenet' \ 16 | --lr 0.1 \ 17 | --eval_funcs 'v2' \ 18 | --warmup_teacher_temp 0.07 \ 19 | --teacher_temp 0.04 \ 20 | --warmup_teacher_temp_epochs 30 \ 21 | --memax_weight 4 \ 22 | --exp_name cifar100_simgcd 23 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # ----------------- 2 | # DATASET ROOTS 3 | # ----------------- 4 | cifar_10_root = '${DATASET_DIR}/cifar10' 5 | cifar_100_root = '${DATASET_DIR}/cifar100' 6 | cub_root = '${DATASET_DIR}/cub' 7 | aircraft_root = '${DATASET_DIR}/fgvc-aircraft-2013b' 8 | car_root = '${DATASET_DIR}/cars' 9 | herbarium_dataroot = '${DATASET_DIR}/herbarium_19' 10 | imagenet_root = '${DATASET_DIR}/ImageNet' 11 | 12 | # OSR Split dir 13 | osr_split_dir = 'data/ssb_splits' 14 | 15 | # ----------------- 16 | # OTHER PATHS 17 | # ----------------- 18 | exp_root = 'dev_outputs' # All logs and checkpoints will be saved here -------------------------------------------------------------------------------- /scripts/run_herb19.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | 
CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --dataset_name 'herbarium_19' \ 8 | --batch_size 128 \ 9 | --grad_from_block 11 \ 10 | --epochs 200 \ 11 | --num_workers 8 \ 12 | --use_ssb_splits \ 13 | --sup_weight 0.35 \ 14 | --weight_decay 5e-5 \ 15 | --transform 'imagenet' \ 16 | --lr 0.1 \ 17 | --eval_funcs 'v2' 'v2b' \ 18 | --warmup_teacher_temp 0.07 \ 19 | --teacher_temp 0.04 \ 20 | --warmup_teacher_temp_epochs 30 \ 21 | --memax_weight 1 \ 22 | --exp_name herb19_simgcd 23 | -------------------------------------------------------------------------------- /scripts/run_imagenet100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | CUDA_VISIBLE_DEVICES=0 python train.py \ 7 | --dataset_name 'imagenet_100' \ 8 | --batch_size 128 \ 9 | --grad_from_block 11 \ 10 | --epochs 200 \ 11 | --num_workers 8 \ 12 | --use_ssb_splits \ 13 | --sup_weight 0.35 \ 14 | --weight_decay 5e-5 \ 15 | --transform 'imagenet' \ 16 | --lr 0.1 \ 17 | --eval_funcs 'v2' \ 18 | --warmup_teacher_temp 0.07 \ 19 | --teacher_temp 0.04 \ 20 | --warmup_teacher_temp_epochs 30 \ 21 | --memax_weight 1 \ 22 | --exp_name imagenet100_simgcd 23 | -------------------------------------------------------------------------------- /scripts/run_imagenet1k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --master_port 12348 --nproc_per_node=8 train_mp.py \ 7 | --dataset_name 'imagenet_1k' \ 8 | --batch_size 128 \ 9 | --grad_from_block 11 \ 10 | --epochs 200 \ 11 | --num_workers 8 \ 12 | --use_ssb_splits \ 13 | --sup_weight 0.35 \ 14 | --weight_decay 5e-5 \ 15 | --transform 'imagenet' \ 16 | --lr 0.1 \ 17 | --eval_funcs 'v2' \ 18 | --warmup_teacher_temp 0.07 \ 19 | --teacher_temp 0.04 \ 20 | --warmup_teacher_temp_epochs 30 \ 21 | --memax_weight 1 \ 22 | --exp_name imagenet1k_simgcd \ 23 | --print_freq 100 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Xin Wen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /data/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | import torch 4 | 5 | def get_transform(transform_type='imagenet', image_size=32, args=None): 6 | 7 | if transform_type == 'imagenet': 8 | 9 | mean = (0.485, 0.456, 0.406) 10 | std = (0.229, 0.224, 0.225) 11 | interpolation = args.interpolation 12 | crop_pct = args.crop_pct 13 | 14 | train_transform = transforms.Compose([ 15 | transforms.Resize(int(image_size / crop_pct), interpolation), 16 | transforms.RandomCrop(image_size), 17 | transforms.RandomHorizontalFlip(p=0.5), 18 | transforms.ColorJitter(), 19 | transforms.ToTensor(), 20 | transforms.Normalize( 21 | mean=torch.tensor(mean), 22 | std=torch.tensor(std)) 23 | ]) 24 | 25 | test_transform = transforms.Compose([ 26 | transforms.Resize(int(image_size / crop_pct), interpolation), 27 | transforms.CenterCrop(image_size), 28 | transforms.ToTensor(), 29 | transforms.Normalize( 30 | mean=torch.tensor(mean), 31 | std=torch.tensor(std)) 32 | ]) 33 | 34 | else: 35 | 36 | raise NotImplementedError 37 | 38 | return (train_transform, test_transform) -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | 4 | def subsample_instances(dataset, prop_indices_to_subsample=0.8): 5 | 6 | np.random.seed(0) 7 | subsample_indices = np.random.choice(range(len(dataset)), replace=False, 8 | size=(int(prop_indices_to_subsample * len(dataset)),)) 9 | 10 | return subsample_indices 11 | 12 | class MergedDataset(Dataset): 13 | 14 | """ 15 | Takes two datasets (labelled_dataset, unlabelled_dataset) and merges them 16 | Allows you to iterate over them in parallel 17 | """ 18 | 19 | def __init__(self, labelled_dataset, unlabelled_dataset): 20 | 21 | self.labelled_dataset = labelled_dataset 22 | self.unlabelled_dataset = unlabelled_dataset 23 | self.target_transform = None 24 | 25 | def __getitem__(self, item): 26 | 27 | if item < len(self.labelled_dataset): 28 | img, label, uq_idx = self.labelled_dataset[item] 29 | labeled_or_not = 1 30 | 31 | else: 32 | 33 | img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)] 34 | labeled_or_not = 0 35 | 36 | 37 | return img, label, uq_idx, np.array([labeled_or_not]) 38 | 39 | def __len__(self): 40 | return len(self.unlabelled_dataset) + len(self.labelled_dataset) 41 | -------------------------------------------------------------------------------- /util/general_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import inspect 4 | 5 | from datetime import datetime 6 | from loguru import logger 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | def init_experiment(args, runner_name=None, exp_id=None): 29 | # Get filepath of calling script 30 | if runner_name is None: 31 | runner_name = 
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:] 32 | 33 | root_dir = os.path.join(args.exp_root, *runner_name) 34 | 35 | if not os.path.exists(root_dir): 36 | os.makedirs(root_dir) 37 | 38 | # Either generate a unique experiment ID, or use one which is passed 39 | if exp_id is None: 40 | 41 | if args.exp_name is None: 42 | raise ValueError("Need to specify the experiment name") 43 | # Unique identifier for experiment 44 | now = '{}_({:02d}.{:02d}.{}_|_'.format(args.exp_name, datetime.now().day, datetime.now().month, datetime.now().year) + \ 45 | datetime.now().strftime("%S.%f")[:-3] + ')' 46 | 47 | log_dir = os.path.join(root_dir, 'log', now) 48 | while os.path.exists(log_dir): 49 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \ 50 | datetime.now().strftime("%S.%f")[:-3] + ')' 51 | 52 | log_dir = os.path.join(root_dir, 'log', now) 53 | 54 | else: 55 | 56 | log_dir = os.path.join(root_dir, 'log', f'{exp_id}') 57 | 58 | if not os.path.exists(log_dir): 59 | os.makedirs(log_dir) 60 | 61 | 62 | logger.add(os.path.join(log_dir, 'log.txt')) 63 | args.logger = logger 64 | args.log_dir = log_dir 65 | 66 | # Instantiate directory to save models to 67 | model_root_dir = os.path.join(args.log_dir, 'checkpoints') 68 | if not os.path.exists(model_root_dir): 69 | os.mkdir(model_root_dir) 70 | 71 | args.model_dir = model_root_dir 72 | args.model_path = os.path.join(args.model_dir, 'model.pt') 73 | 74 | print(f'Experiment saved to: {args.log_dir}') 75 | 76 | hparam_dict = {} 77 | 78 | for k, v in vars(args).items(): 79 | if isinstance(v, (int, float, str, bool, torch.Tensor)): 80 | hparam_dict[k] = v 81 | 82 | print(runner_name) 83 | print(args) 84 | 85 | return args 86 | 87 | 88 | class DistributedWeightedSampler(torch.utils.data.distributed.DistributedSampler): 89 | 90 | def __init__(self, dataset, weights, num_samples, num_replicas=None, rank=None, 91 | replacement=True, generator=None): 92 | super(DistributedWeightedSampler, self).__init__(dataset, num_replicas, rank) 93 | if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \ 94 | num_samples <= 0: 95 | raise ValueError("num_samples should be a positive integer " 96 | "value, but got num_samples={}".format(num_samples)) 97 | if not isinstance(replacement, bool): 98 | raise ValueError("replacement should be a boolean value, but got " 99 | "replacement={}".format(replacement)) 100 | self.weights = torch.as_tensor(weights, dtype=torch.double) 101 | self.num_samples = num_samples 102 | self.replacement = replacement 103 | self.generator = generator 104 | self.weights = self.weights[self.rank::self.num_replicas] 105 | self.num_samples = self.num_samples // self.num_replicas 106 | 107 | def __iter__(self): 108 | rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator) 109 | rand_tensor = self.rank + rand_tensor * self.num_replicas 110 | yield from iter(rand_tensor.tolist()) 111 | 112 | def __len__(self): 113 | return self.num_samples 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parametric Classification for Generalized Category Discovery: A Baseline Study 2 | 3 | 4 |
4 | 
11 | Parametric Classification for Generalized Category Discovery: A Baseline Study (ICCV 2023)
12 | By Xin Wen*, Bingchen Zhao*, and Xiaojuan Qi.
16 | 
17 | 18 | ![teaser](assets/teaser.jpg) 19 | 20 | Generalized Category Discovery (GCD) aims to discover novel categories in unlabelled datasets using knowledge learned from labelled samples. 21 | Previous studies argued that parametric classifiers are prone to overfitting to seen categories, and endorsed using a non-parametric classifier formed with semi-supervised $k$-means. 22 | 23 | However, in this study, we investigate the failure of parametric classifiers, verify the effectiveness of previous design choices when high-quality supervision is available, and identify unreliable pseudo-labels as a key problem. We demonstrate that two prediction biases exist: the classifier tends to predict seen classes more often, and produces an imbalanced distribution across seen and novel categories. 24 | Based on these findings, we propose a simple yet effective parametric classification method that benefits from entropy regularisation, achieves state-of-the-art performance on multiple GCD benchmarks and shows strong robustness to unknown class numbers. 25 | We hope the investigation and proposed simple framework can serve as a strong baseline to facilitate future studies in this field. 26 | 27 | ## Running 28 | 29 | ### Dependencies 30 | 31 | ``` 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ### Config 36 | 37 | Set paths to datasets and desired log directories in ```config.py```. 38 | 39 | 40 | ### Datasets 41 | 42 | We use fine-grained benchmarks in this paper, including: 43 | 44 | * [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6) 45 | 46 | We also use generic object recognition datasets, including: 47 | 48 | * [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet-100/1K](https://image-net.org/download.php) 49 | 50 | 51 | ### Scripts 52 | 53 | **Train the model**: 54 | 55 | ``` 56 | bash scripts/run_${DATASET_NAME}.sh 57 | ``` 58 | 59 | We found that selecting the model by 'Old'-class performance can lead to over-fitting, and since 'New'-class labels on the held-out validation set should be assumed unavailable, we suggest not performing model selection and simply using the last-epoch model. 60 | 61 | ## Results 62 | Our results: 63 | 64 |
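'All', 'Old' and 'New' report clustering accuracy (ACC) on all evaluated instances, on instances of seen ('Old') classes, and on instances of novel ('New') classes, respectively, computed by Hungarian-matching predicted cluster ids to ground-truth labels over all instances (the `'v2'` metric in `util/cluster_and_log_utils.py`; `'v2b'` is its class-balanced variant). A minimal sketch of the metric follows the table.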
<table>
  <tr>
    <th>Source</th>
    <th colspan="3">Paper (3 runs)</th>
    <th colspan="3">Current Github (5 runs)</th>
  </tr>
  <tr>
    <th>Dataset</th>
    <th>All</th><th>Old</th><th>New</th>
    <th>All</th><th>Old</th><th>New</th>
  </tr>
  <tr><td>CIFAR10</td><td>97.1±0.0</td><td>95.1±0.1</td><td>98.1±0.1</td><td>97.0±0.1</td><td>93.9±0.1</td><td>98.5±0.1</td></tr>
  <tr><td>CIFAR100</td><td>80.1±0.9</td><td>81.2±0.4</td><td>77.8±2.0</td><td>79.8±0.6</td><td>81.1±0.5</td><td>77.4±2.5</td></tr>
  <tr><td>ImageNet-100</td><td>83.0±1.2</td><td>93.1±0.2</td><td>77.9±1.9</td><td>83.6±1.4</td><td>92.4±0.1</td><td>79.1±2.2</td></tr>
  <tr><td>ImageNet-1K</td><td>57.1±0.1</td><td>77.3±0.1</td><td>46.9±0.2</td><td>57.0±0.4</td><td>77.1±0.1</td><td>46.9±0.5</td></tr>
  <tr><td>CUB</td><td>60.3±0.1</td><td>65.6±0.9</td><td>57.7±0.4</td><td>61.5±0.5</td><td>65.7±0.5</td><td>59.4±0.8</td></tr>
  <tr><td>Stanford Cars</td><td>53.8±2.2</td><td>71.9±1.7</td><td>45.0±2.4</td><td>53.4±1.6</td><td>71.5±1.6</td><td>44.6±1.7</td></tr>
  <tr><td>FGVC-Aircraft</td><td>54.2±1.9</td><td>59.1±1.2</td><td>51.8±2.3</td><td>54.3±0.7</td><td>59.4±0.4</td><td>51.7±1.2</td></tr>
  <tr><td>Herbarium 19</td><td>44.0±0.4</td><td>58.0±0.4</td><td>36.4±0.8</td><td>44.2±0.2</td><td>57.6±0.6</td><td>37.0±0.4</td></tr>
</table>
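As a reference for how these numbers are computed, here is a minimal single-process sketch of the `'v2'` metric; the repo version in `util/cluster_and_log_utils.py` is equivalent up to an all-reduce of the counts across distributed workers, and the toy labels below are made up purely for illustration:

```
# Minimal sketch of the 'v2' ACC: Hungarian-match cluster ids to class ids
# over ALL instances, then score the 'Old' (mask=True) and 'New' subsets.
import numpy as np
from scipy.optimize import linear_sum_assignment

def split_acc(y_true, y_pred, mask):
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)                  # contingency matrix
    for p, t in zip(y_pred, y_true):
        w[p, t] += 1                                 # rows: clusters, cols: classes
    rows, cols = linear_sum_assignment(w.max() - w)  # maximise matched counts
    cls_to_cluster = {c: r for r, c in zip(rows, cols)}
    hit = np.array([cls_to_cluster[t] == p for p, t in zip(y_pred, y_true)])
    return hit.mean(), hit[mask].mean(), hit[~mask].mean()  # All, Old, New

# Toy example: clusters 0 and 1 are swapped w.r.t. the labels; matching resolves it.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([1, 1, 0, 0, 2, 2])
mask = np.array([True, True, True, True, False, False])  # True = seen ('Old') class
print(split_acc(y_true, y_pred, mask))  # (1.0, 1.0, 1.0)
```

Note that the assignment is computed jointly over old and new classes before the per-subset accuracies are read off, as the docstring in `util/cluster_and_log_utils.py` describes.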
65 | 66 | ## Citing this work 67 | 68 | If you find this repo useful for your research, please consider citing our paper: 69 | 70 | ``` 71 | @inproceedings{wen2023simgcd, 72 | author = {Wen, Xin and Zhao, Bingchen and Qi, Xiaojuan}, 73 | title = {Parametric Classification for Generalized Category Discovery: A Baseline Study}, 74 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 75 | year = {2023}, 76 | pages = {16590-16600} 77 | } 78 | ``` 79 | 80 | ## Acknowledgements 81 | 82 | The codebase is largely built on this repo: https://github.com/sgvaze/generalized-category-discovery. 83 | 84 | ## License 85 | 86 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 87 | -------------------------------------------------------------------------------- /data/get_datasets.py: -------------------------------------------------------------------------------- 1 | from data.data_utils import MergedDataset 2 | 3 | from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets 4 | from data.herbarium_19 import get_herbarium_datasets 5 | from data.stanford_cars import get_scars_datasets 6 | from data.imagenet import get_imagenet_100_datasets, get_imagenet_1k_datasets 7 | from data.cub import get_cub_datasets 8 | from data.fgvc_aircraft import get_aircraft_datasets 9 | 10 | from copy import deepcopy 11 | import pickle 12 | import os 13 | 14 | from config import osr_split_dir 15 | 16 | 17 | get_dataset_funcs = { 18 | 'cifar10': get_cifar_10_datasets, 19 | 'cifar100': get_cifar_100_datasets, 20 | 'imagenet_100': get_imagenet_100_datasets, 21 | 'imagenet_1k': get_imagenet_1k_datasets, 22 | 'herbarium_19': get_herbarium_datasets, 23 | 'cub': get_cub_datasets, 24 | 'aircraft': get_aircraft_datasets, 25 | 'scars': get_scars_datasets 26 | } 27 | 28 | 29 | def get_datasets(dataset_name, train_transform, test_transform, args): 30 | 31 | """ 32 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled 33 | test_dataset, 34 | unlabelled_train_examples_test, 35 | datasets 36 | """ 37 | 38 | # 39 | if dataset_name not in get_dataset_funcs.keys(): 40 | raise ValueError 41 | 42 | # Get datasets 43 | get_dataset_f = get_dataset_funcs[dataset_name] 44 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform, 45 | train_classes=args.train_classes, 46 | prop_train_labels=args.prop_train_labels, 47 | split_train_val=False) 48 | # Set target transforms: 49 | target_transform_dict = {} 50 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)): 51 | target_transform_dict[cls] = i 52 | target_transform = lambda x: target_transform_dict[x] 53 | 54 | for dataset_name, dataset in datasets.items(): 55 | if dataset is not None: 56 | dataset.target_transform = target_transform 57 | 58 | # Train split (labelled and unlabelled classes) for training 59 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']), 60 | unlabelled_dataset=deepcopy(datasets['train_unlabelled'])) 61 | 62 | test_dataset = datasets['test'] 63 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled']) 64 | unlabelled_train_examples_test.transform = test_transform 65 | 66 | return train_dataset, test_dataset, unlabelled_train_examples_test, datasets 67 | 68 | 69 | def get_class_splits(args): 70 | 71 | # For FGVC datasets, optionally return bespoke splits 72 | if args.dataset_name in ('scars', 'cub', 'aircraft'): 73 | if hasattr(args, 
'use_ssb_splits'): 74 | use_ssb_splits = args.use_ssb_splits 75 | else: 76 | use_ssb_splits = False 77 | 78 | # ------------- 79 | # GET CLASS SPLITS 80 | # ------------- 81 | if args.dataset_name == 'cifar10': 82 | 83 | args.image_size = 32 84 | args.train_classes = range(5) 85 | args.unlabeled_classes = range(5, 10) 86 | 87 | elif args.dataset_name == 'cifar100': 88 | 89 | args.image_size = 32 90 | args.train_classes = range(80) 91 | args.unlabeled_classes = range(80, 100) 92 | 93 | elif args.dataset_name == 'herbarium_19': 94 | 95 | args.image_size = 224 96 | herb_path_splits = os.path.join(osr_split_dir, 'herbarium_19_class_splits.pkl') 97 | 98 | with open(herb_path_splits, 'rb') as handle: 99 | class_splits = pickle.load(handle) 100 | 101 | args.train_classes = class_splits['Old'] 102 | args.unlabeled_classes = class_splits['New'] 103 | 104 | elif args.dataset_name == 'imagenet_100': 105 | 106 | args.image_size = 224 107 | args.train_classes = range(50) 108 | args.unlabeled_classes = range(50, 100) 109 | 110 | elif args.dataset_name == 'imagenet_1k': 111 | 112 | args.image_size = 224 113 | args.train_classes = range(500) 114 | args.unlabeled_classes = range(500, 1000) 115 | 116 | elif args.dataset_name == 'scars': 117 | 118 | args.image_size = 224 119 | 120 | if use_ssb_splits: 121 | 122 | split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl') 123 | with open(split_path, 'rb') as handle: 124 | class_info = pickle.load(handle) 125 | 126 | args.train_classes = class_info['known_classes'] 127 | open_set_classes = class_info['unknown_classes'] 128 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 129 | 130 | else: 131 | 132 | args.train_classes = range(98) 133 | args.unlabeled_classes = range(98, 196) 134 | 135 | elif args.dataset_name == 'aircraft': 136 | 137 | args.image_size = 224 138 | if use_ssb_splits: 139 | 140 | split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl') 141 | with open(split_path, 'rb') as handle: 142 | class_info = pickle.load(handle) 143 | 144 | args.train_classes = class_info['known_classes'] 145 | open_set_classes = class_info['unknown_classes'] 146 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 147 | 148 | else: 149 | 150 | args.train_classes = range(50) 151 | args.unlabeled_classes = range(50, 100) 152 | 153 | elif args.dataset_name == 'cub': 154 | 155 | args.image_size = 224 156 | 157 | if use_ssb_splits: 158 | 159 | split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl') 160 | with open(split_path, 'rb') as handle: 161 | class_info = pickle.load(handle) 162 | 163 | args.train_classes = class_info['known_classes'] 164 | open_set_classes = class_info['unknown_classes'] 165 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 166 | 167 | else: 168 | 169 | args.train_classes = range(100) 170 | args.unlabeled_classes = range(100, 200) 171 | 172 | else: 173 | 174 | raise NotImplementedError 175 | 176 | return args 177 | -------------------------------------------------------------------------------- /data/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | from scipy import io as mat_io 6 | 7 | from torchvision.datasets.folder import default_loader 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import 
subsample_instances 11 | from config import car_root 12 | 13 | class CarsDataset(Dataset): 14 | """ 15 | Cars Dataset 16 | """ 17 | def __init__(self, train=True, limit=0, data_dir=car_root, transform=None): 18 | 19 | metas = os.path.join(data_dir, 'devkit/cars_train_annos.mat') if train else os.path.join(data_dir, 'devkit/cars_test_annos_withlabels.mat') 20 | data_dir = os.path.join(data_dir, 'cars_train/') if train else os.path.join(data_dir, 'cars_test/') 21 | 22 | self.loader = default_loader 23 | self.data_dir = data_dir 24 | self.data = [] 25 | self.target = [] 26 | self.train = train 27 | 28 | self.transform = transform 29 | 30 | if not isinstance(metas, str): 31 | raise Exception("Train metas must be string location !") 32 | labels_meta = mat_io.loadmat(metas) 33 | 34 | for idx, img_ in enumerate(labels_meta['annotations'][0]): 35 | if limit: 36 | if idx > limit: 37 | break 38 | 39 | # self.data.append(img_resized) 40 | self.data.append(data_dir + img_[5][0]) 41 | # if self.mode == 'train': 42 | self.target.append(img_[4][0][0]) 43 | 44 | self.uq_idxs = np.array(range(len(self))) 45 | self.target_transform = None 46 | 47 | def __getitem__(self, idx): 48 | 49 | image = self.loader(self.data[idx]) 50 | target = self.target[idx] - 1 51 | 52 | if self.transform is not None: 53 | image = self.transform(image) 54 | 55 | if self.target_transform is not None: 56 | target = self.target_transform(target) 57 | 58 | idx = self.uq_idxs[idx] 59 | 60 | return image, target, idx 61 | 62 | def __len__(self): 63 | return len(self.data) 64 | 65 | 66 | def subsample_dataset(dataset, idxs): 67 | 68 | dataset.data = np.array(dataset.data)[idxs].tolist() 69 | dataset.target = np.array(dataset.target)[idxs].tolist() 70 | dataset.uq_idxs = dataset.uq_idxs[idxs] 71 | 72 | return dataset 73 | 74 | 75 | def subsample_classes(dataset, include_classes=range(160)): 76 | 77 | include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195 78 | cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars] 79 | 80 | target_xform_dict = {} 81 | for i, k in enumerate(include_classes): 82 | target_xform_dict[k] = i 83 | 84 | dataset = subsample_dataset(dataset, cls_idxs) 85 | 86 | # dataset.target_transform = lambda x: target_xform_dict[x] 87 | 88 | return dataset 89 | 90 | def get_train_val_indices(train_dataset, val_split=0.2): 91 | 92 | train_classes = np.unique(train_dataset.target) 93 | 94 | # Get train/test indices 95 | train_idxs = [] 96 | val_idxs = [] 97 | for cls in train_classes: 98 | 99 | cls_idxs = np.where(train_dataset.target == cls)[0] 100 | 101 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 102 | t_ = [x for x in cls_idxs if x not in v_] 103 | 104 | train_idxs.extend(t_) 105 | val_idxs.extend(v_) 106 | 107 | return train_idxs, val_idxs 108 | 109 | 110 | def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8, 111 | split_train_val=False, seed=0): 112 | 113 | np.random.seed(seed) 114 | 115 | # Init entire training set 116 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, train=True) 117 | 118 | # Get labelled training set which has subsampled classes, then subsample some indices from that 119 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 120 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 121 | 
train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 122 | 123 | # Split into training and validation sets 124 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 125 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 126 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 127 | val_dataset_labelled_split.transform = test_transform 128 | 129 | # Get unlabelled data 130 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 131 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 132 | 133 | # Get test set for all classes 134 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, train=False) 135 | 136 | # Either split train into train and val or use test set as val 137 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 138 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 139 | 140 | all_datasets = { 141 | 'train_labelled': train_dataset_labelled, 142 | 'train_unlabelled': train_dataset_unlabelled, 143 | 'val': val_dataset_labelled, 144 | 'test': test_dataset, 145 | } 146 | 147 | return all_datasets 148 | 149 | if __name__ == '__main__': 150 | 151 | x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False) 152 | 153 | print('Printing lens...') 154 | for k, v in x.items(): 155 | if v is not None: 156 | print(f'{k}: {len(v)}') 157 | 158 | print('Printing labelled and unlabelled overlap...') 159 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 160 | print('Printing total instances in train...') 161 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 162 | 163 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}') 164 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].target))}') 165 | print(f'Len labelled set: {len(x["train_labelled"])}') 166 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/herbarium_19.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision 4 | import numpy as np 5 | from copy import deepcopy 6 | 7 | from data.data_utils import subsample_instances 8 | from config import herbarium_dataroot 9 | 10 | class HerbariumDataset19(torchvision.datasets.ImageFolder): 11 | 12 | def __init__(self, *args, **kwargs): 13 | 14 | # Process metadata json for training images into a DataFrame 15 | super().__init__(*args, **kwargs) 16 | 17 | self.uq_idxs = np.array(range(len(self))) 18 | 19 | def __getitem__(self, idx): 20 | 21 | img, label = super().__getitem__(idx) 22 | uq_idx = self.uq_idxs[idx] 23 | 24 | return img, label, uq_idx 25 | 26 | 27 | def subsample_dataset(dataset, idxs): 28 | 29 | mask = np.zeros(len(dataset)).astype('bool') 30 | mask[idxs] = True 31 | 32 | dataset.samples = np.array(dataset.samples)[mask].tolist() 33 | dataset.targets = np.array(dataset.targets)[mask].tolist() 34 | 35 | dataset.uq_idxs = dataset.uq_idxs[mask] 36 | 37 | dataset.samples = [[x[0], int(x[1])] for x in dataset.samples] 38 | dataset.targets = [int(x) for x in dataset.targets] 39 | 40 | return dataset 41 | 42 | 43 | def subsample_classes(dataset, 
include_classes=range(250)): 44 | 45 | cls_idxs = [x for x, l in enumerate(dataset.targets) if l in include_classes] 46 | 47 | target_xform_dict = {} 48 | for i, k in enumerate(include_classes): 49 | target_xform_dict[k] = i 50 | 51 | dataset = subsample_dataset(dataset, cls_idxs) 52 | 53 | dataset.target_transform = lambda x: target_xform_dict[x] 54 | 55 | return dataset 56 | 57 | 58 | def get_train_val_indices(train_dataset, val_instances_per_class=5): 59 | 60 | train_classes = list(set(train_dataset.targets)) 61 | 62 | # Get train/test indices 63 | train_idxs = [] 64 | val_idxs = [] 65 | for cls in train_classes: 66 | 67 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] 68 | 69 | # Have a balanced test set 70 | v_ = np.random.choice(cls_idxs, replace=False, size=(val_instances_per_class,)) 71 | t_ = [x for x in cls_idxs if x not in v_] 72 | 73 | train_idxs.extend(t_) 74 | val_idxs.extend(v_) 75 | 76 | return train_idxs, val_idxs 77 | 78 | 79 | def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8, 80 | seed=0, split_train_val=False): 81 | 82 | np.random.seed(seed) 83 | 84 | # Init entire training set 85 | train_dataset = HerbariumDataset19(transform=train_transform, 86 | root=os.path.join(herbarium_dataroot, 'small-train')) 87 | 88 | # Get labelled training set which has subsampled classes, then subsample some indices from that 89 | # TODO: Subsampling unlabelled set in uniform random fashion from training data, will contain many instances of dominant class 90 | train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes) 91 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 92 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 93 | 94 | # Split into training and validation sets 95 | if split_train_val: 96 | 97 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, 98 | val_instances_per_class=5) 99 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 100 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 101 | val_dataset_labelled_split.transform = test_transform 102 | 103 | else: 104 | 105 | train_dataset_labelled_split, val_dataset_labelled_split = None, None 106 | 107 | # Get unlabelled data 108 | unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs) 109 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices))) 110 | 111 | # Get test dataset 112 | test_dataset = HerbariumDataset19(transform=test_transform, 113 | root=os.path.join(herbarium_dataroot, 'small-validation')) 114 | 115 | # Transform dict 116 | unlabelled_classes = list(set(train_dataset.targets) - set(train_classes)) 117 | target_xform_dict = {} 118 | for i, k in enumerate(list(train_classes) + unlabelled_classes): 119 | target_xform_dict[k] = i 120 | 121 | test_dataset.target_transform = lambda x: target_xform_dict[x] 122 | train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x] 123 | 124 | # Either split train into train and val or use test set as val 125 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 126 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 127 | 128 | all_datasets = { 129 | 'train_labelled': train_dataset_labelled, 130 | 
'train_unlabelled': train_dataset_unlabelled, 131 | 'val': val_dataset_labelled, 132 | 'test': test_dataset, 133 | } 134 | 135 | return all_datasets 136 | 137 | if __name__ == '__main__': 138 | 139 | np.random.seed(0) 140 | train_classes = np.random.choice(range(683,), size=(int(683 / 2)), replace=False) 141 | 142 | x = get_herbarium_datasets(None, None, train_classes=train_classes, 143 | prop_train_labels=0.5) 144 | 145 | assert set(x['train_unlabelled'].targets) == set(range(683)) 146 | 147 | print('Printing lens...') 148 | for k, v in x.items(): 149 | if v is not None: 150 | print(f'{k}: {len(v)}') 151 | 152 | print('Printing labelled and unlabelled overlap...') 153 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 154 | print('Printing total instances in train...') 155 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 156 | print('Printing number of labelled classes...') 157 | print(len(set(x['train_labelled'].targets))) 158 | print('Printing total number of classes...') 159 | print(len(set(x['train_unlabelled'].targets))) 160 | 161 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 162 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}') 163 | print(f'Len labelled set: {len(x["train_labelled"])}') 164 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /util/cluster_and_log_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import numpy as np 4 | from scipy.optimize import linear_sum_assignment as linear_assignment 5 | 6 | 7 | def all_sum_item(item): 8 | item = torch.tensor(item).cuda() 9 | dist.all_reduce(item) 10 | return item.item() 11 | 12 | def split_cluster_acc_v2(y_true, y_pred, mask): 13 | """ 14 | Calculate clustering accuracy. 
Requires scipy installed (uses `linear_sum_assignment`) 15 | First compute linear assignment on all data, then look at how good the accuracy is on subsets 16 | 17 | # Arguments 18 | mask: Which instances come from old classes (True) and which ones come from new classes (False) 19 | y_true: true labels, numpy.array with shape `(n_samples,)` 20 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 21 | 22 | # Return 23 | accuracy, in [0,1] 24 | """ 25 | y_true = y_true.astype(int) 26 | 27 | old_classes_gt = set(y_true[mask]) 28 | new_classes_gt = set(y_true[~mask]) 29 | 30 | assert y_pred.size == y_true.size 31 | D = max(y_pred.max(), y_true.max()) + 1 32 | w = np.zeros((D, D), dtype=int) 33 | for i in range(y_pred.size): 34 | w[y_pred[i], y_true[i]] += 1 35 | 36 | ind = linear_assignment(w.max() - w) 37 | ind = np.vstack(ind).T 38 | 39 | ind_map = {j: i for i, j in ind} 40 | total_acc = sum([w[i, j] for i, j in ind]) 41 | total_instances = y_pred.size 42 | try: 43 | if dist.get_world_size() > 0: 44 | total_acc = all_sum_item(total_acc) 45 | total_instances = all_sum_item(total_instances) 46 | except: 47 | pass 48 | total_acc /= total_instances 49 | 50 | old_acc = 0 51 | total_old_instances = 0 52 | for i in old_classes_gt: 53 | old_acc += w[ind_map[i], i] 54 | total_old_instances += sum(w[:, i]) 55 | 56 | try: 57 | if dist.get_world_size() > 0: 58 | old_acc = all_sum_item(old_acc) 59 | total_old_instances = all_sum_item(total_old_instances) 60 | except: 61 | pass 62 | old_acc /= total_old_instances 63 | 64 | new_acc = 0 65 | total_new_instances = 0 66 | for i in new_classes_gt: 67 | new_acc += w[ind_map[i], i] 68 | total_new_instances += sum(w[:, i]) 69 | 70 | try: 71 | if dist.get_world_size() > 0: 72 | new_acc = all_sum_item(new_acc) 73 | total_new_instances = all_sum_item(total_new_instances) 74 | except: 75 | pass 76 | new_acc /= total_new_instances 77 | 78 | return total_acc, old_acc, new_acc 79 | 80 | 81 | def split_cluster_acc_v2_balanced(y_true, y_pred, mask): 82 | """ 83 | Calculate clustering accuracy. 
Requires scipy installed (uses `linear_sum_assignment`) 84 | First compute linear assignment on all data, then look at how good the accuracy is on subsets 85 | 86 | # Arguments 87 | mask: Which instances come from old classes (True) and which ones come from new classes (False) 88 | y_true: true labels, numpy.array with shape `(n_samples,)` 89 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 90 | 91 | # Return 92 | accuracy, in [0,1] 93 | """ 94 | y_true = y_true.astype(int) 95 | 96 | old_classes_gt = set(y_true[mask]) 97 | new_classes_gt = set(y_true[~mask]) 98 | 99 | assert y_pred.size == y_true.size 100 | D = max(y_pred.max(), y_true.max()) + 1 101 | w = np.zeros((D, D), dtype=int) 102 | for i in range(y_pred.size): 103 | w[y_pred[i], y_true[i]] += 1 104 | 105 | ind = linear_assignment(w.max() - w) 106 | ind = np.vstack(ind).T 107 | 108 | ind_map = {j: i for i, j in ind} 109 | 110 | old_acc = np.zeros(len(old_classes_gt)) 111 | total_old_instances = np.zeros(len(old_classes_gt)) 112 | for idx, i in enumerate(old_classes_gt): 113 | old_acc[idx] += w[ind_map[i], i] 114 | total_old_instances[idx] += sum(w[:, i]) 115 | 116 | new_acc = np.zeros(len(new_classes_gt)) 117 | total_new_instances = np.zeros(len(new_classes_gt)) 118 | for idx, i in enumerate(new_classes_gt): 119 | new_acc[idx] += w[ind_map[i], i] 120 | total_new_instances[idx] += sum(w[:, i]) 121 | 122 | try: 123 | if dist.get_world_size() > 0: 124 | old_acc, new_acc = torch.from_numpy(old_acc).cuda(), torch.from_numpy(new_acc).cuda() 125 | dist.all_reduce(old_acc), dist.all_reduce(new_acc) 126 | dist.all_reduce(total_old_instances), dist.all_reduce(total_new_instances) 127 | old_acc, new_acc = old_acc.cpu().numpy(), new_acc.cpu().numpy() 128 | total_old_instances, total_new_instances = total_old_instances.cpu().numpy(), total_new_instances.cpu().numpy() 129 | except: 130 | pass 131 | 132 | total_acc = np.concatenate([old_acc, new_acc]) / np.concatenate([total_old_instances, total_new_instances]) 133 | old_acc /= total_old_instances 134 | new_acc /= total_new_instances 135 | total_acc, old_acc, new_acc = total_acc.mean(), old_acc.mean(), new_acc.mean() 136 | return total_acc, old_acc, new_acc 137 | 138 | 139 | EVAL_FUNCS = { 140 | 'v2': split_cluster_acc_v2, 141 | 'v2b': split_cluster_acc_v2_balanced 142 | } 143 | 144 | def log_accs_from_preds(y_true, y_pred, mask, eval_funcs, save_name, T=None, 145 | print_output=True, args=None): 146 | 147 | """ 148 | Given a list of evaluation functions to use (e.g. ['v2', 'v2b']) evaluate and log ACC results 149 | 150 | :param y_true: GT labels 151 | :param y_pred: Predicted indices 152 | :param mask: Which instances belong to Old and New classes 153 | :param T: Epoch 154 | :param eval_funcs: Which evaluation functions to use 155 | :param save_name: What are we evaluating ACC on 157 | :return: 158 | """ 159 | 160 | mask = mask.astype(bool) 161 | y_true = y_true.astype(int) 162 | y_pred = y_pred.astype(int) 163 | 164 | for i, f_name in enumerate(eval_funcs): 165 | 166 | acc_f = EVAL_FUNCS[f_name] 167 | all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask) 168 | log_name = f'{save_name}_{f_name}' 169 | 170 | if i == 0: 171 | to_return = (all_acc, old_acc, new_acc) 172 | 173 | if print_output: 174 | print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}' 175 | try: 176 | if dist.get_rank() == 0: 177 | try: 178 | args.logger.info(print_str) 179 | except: 180 | print(print_str) 181 | except: 182 | pass 183 | 184 | return to_return 
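A small usage sketch for the evaluation entry point above, assuming a single process (the `torch.distributed` branches then no-op via the bare `except` blocks, and in that case nothing is printed — the accuracies are only returned). The arrays are toy values, and `save_name` is just a hypothetical tag used to build the log string; the exact tags used by `train.py` are not shown in this dump:

```
# Hypothetical single-process call; toy arrays for illustration only.
import numpy as np
from util.cluster_and_log_utils import log_accs_from_preds

y_true = np.array([0, 0, 1, 1, 2, 2])  # ground-truth class ids
y_pred = np.array([1, 1, 0, 0, 2, 2])  # predicted cluster ids
mask = np.array([1, 1, 1, 1, 0, 0])    # 1 = instance of a seen ('Old') class
all_acc, old_acc, new_acc = log_accs_from_preds(
    y_true, y_pred, mask, eval_funcs=['v2', 'v2b'],
    save_name='unlabelled_train_acc', T=0)  # returns the 'v2' triplet
```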
-------------------------------------------------------------------------------- /data/cifar.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10, CIFAR100 2 | from copy import deepcopy 3 | import numpy as np 4 | 5 | from data.data_utils import subsample_instances 6 | from config import cifar_10_root, cifar_100_root 7 | 8 | 9 | class CustomCIFAR10(CIFAR10): 10 | 11 | def __init__(self, *args, **kwargs): 12 | 13 | super(CustomCIFAR10, self).__init__(*args, **kwargs) 14 | 15 | self.uq_idxs = np.array(range(len(self))) 16 | 17 | def __getitem__(self, item): 18 | 19 | img, label = super().__getitem__(item) 20 | uq_idx = self.uq_idxs[item] 21 | 22 | return img, label, uq_idx 23 | 24 | def __len__(self): 25 | return len(self.targets) 26 | 27 | 28 | class CustomCIFAR100(CIFAR100): 29 | 30 | def __init__(self, *args, **kwargs): 31 | super(CustomCIFAR100, self).__init__(*args, **kwargs) 32 | 33 | self.uq_idxs = np.array(range(len(self))) 34 | 35 | def __getitem__(self, item): 36 | img, label = super().__getitem__(item) 37 | uq_idx = self.uq_idxs[item] 38 | 39 | return img, label, uq_idx 40 | 41 | def __len__(self): 42 | return len(self.targets) 43 | 44 | 45 | def subsample_dataset(dataset, idxs): 46 | 47 | # Allow for setting in which an empty set of indices is passed 48 | 49 | if len(idxs) > 0: 50 | 51 | dataset.data = dataset.data[idxs] 52 | dataset.targets = np.array(dataset.targets)[idxs].tolist() 53 | dataset.uq_idxs = dataset.uq_idxs[idxs] 54 | 55 | return dataset 56 | 57 | else: 58 | 59 | return None 60 | 61 | 62 | def subsample_classes(dataset, include_classes=(0, 1, 8, 9)): 63 | 64 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] 65 | 66 | target_xform_dict = {} 67 | for i, k in enumerate(include_classes): 68 | target_xform_dict[k] = i 69 | 70 | dataset = subsample_dataset(dataset, cls_idxs) 71 | 72 | # dataset.target_transform = lambda x: target_xform_dict[x] 73 | 74 | return dataset 75 | 76 | 77 | def get_train_val_indices(train_dataset, val_split=0.2): 78 | 79 | train_classes = np.unique(train_dataset.targets) 80 | 81 | # Get train/test indices 82 | train_idxs = [] 83 | val_idxs = [] 84 | for cls in train_classes: 85 | 86 | cls_idxs = np.where(train_dataset.targets == cls)[0] 87 | 88 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 89 | t_ = [x for x in cls_idxs if x not in v_] 90 | 91 | train_idxs.extend(t_) 92 | val_idxs.extend(v_) 93 | 94 | return train_idxs, val_idxs 95 | 96 | 97 | def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9), 98 | prop_train_labels=0.8, split_train_val=False, seed=0): 99 | 100 | np.random.seed(seed) 101 | 102 | # Init entire training set 103 | whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True) 104 | 105 | # Get labelled training set which has subsampled classes, then subsample some indices from that 106 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 107 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 108 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 109 | 110 | # Split into training and validation sets 111 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 112 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), 
train_idxs) 113 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 114 | val_dataset_labelled_split.transform = test_transform 115 | 116 | # Get unlabelled data 117 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 118 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 119 | 120 | # Get test set for all classes 121 | test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False) 122 | 123 | # Either split train into train and val or use test set as val 124 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 125 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 126 | 127 | all_datasets = { 128 | 'train_labelled': train_dataset_labelled, 129 | 'train_unlabelled': train_dataset_unlabelled, 130 | 'val': val_dataset_labelled, 131 | 'test': test_dataset, 132 | } 133 | 134 | return all_datasets 135 | 136 | 137 | def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80), 138 | prop_train_labels=0.8, split_train_val=False, seed=0): 139 | 140 | np.random.seed(seed) 141 | 142 | # Init entire training set 143 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True) 144 | 145 | # Get labelled training set which has subsampled classes, then subsample some indices from that 146 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 147 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 148 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 149 | 150 | # Split into training and validation sets 151 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 152 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 153 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 154 | val_dataset_labelled_split.transform = test_transform 155 | 156 | # Get unlabelled data 157 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 158 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 159 | 160 | # Get test set for all classes 161 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False) 162 | 163 | # Either split train into train and val or use test set as val 164 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 165 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 166 | 167 | all_datasets = { 168 | 'train_labelled': train_dataset_labelled, 169 | 'train_unlabelled': train_dataset_unlabelled, 170 | 'val': val_dataset_labelled, 171 | 'test': test_dataset, 172 | } 173 | 174 | return all_datasets 175 | 176 | 177 | if __name__ == '__main__': 178 | 179 | x = get_cifar_100_datasets(None, None, split_train_val=False, 180 | train_classes=range(80), prop_train_labels=0.5) 181 | 182 | print('Printing lens...') 183 | for k, v in x.items(): 184 | if v is not None: 185 | print(f'{k}: {len(v)}') 186 | 187 | print('Printing labelled and unlabelled overlap...') 188 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 189 | 
print('Printing total instances in train...') 190 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 191 | 192 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 193 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}') 194 | print(f'Len labelled set: {len(x["train_labelled"])}') 195 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from torchvision.datasets.folder import default_loader 7 | from torchvision.datasets.utils import download_url 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import subsample_instances 11 | from config import cub_root 12 | 13 | 14 | class CustomCub2011(Dataset): 15 | base_folder = 'CUB_200_2011/images' 16 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 17 | filename = 'CUB_200_2011.tgz' 18 | tgz_md5 = '97eceeb196236b17998738112f37df78' 19 | 20 | def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=True): 21 | 22 | self.root = os.path.expanduser(root) 23 | self.transform = transform 24 | self.target_transform = target_transform 25 | 26 | self.loader = loader 27 | self.train = train 28 | 29 | 30 | if download: 31 | self._download() 32 | 33 | if not self._check_integrity(): 34 | raise RuntimeError('Dataset not found or corrupted.' + 35 | ' You can use download=True to download it') 36 | 37 | self.uq_idxs = np.array(range(len(self))) 38 | 39 | def _load_metadata(self): 40 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', 41 | names=['img_id', 'filepath']) 42 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), 43 | sep=' ', names=['img_id', 'target']) 44 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), 45 | sep=' ', names=['img_id', 'is_training_img']) 46 | 47 | data = images.merge(image_class_labels, on='img_id') 48 | self.data = data.merge(train_test_split, on='img_id') 49 | 50 | if self.train: 51 | self.data = self.data[self.data.is_training_img == 1] 52 | else: 53 | self.data = self.data[self.data.is_training_img == 0] 54 | 55 | def _check_integrity(self): 56 | try: 57 | self._load_metadata() 58 | except Exception: 59 | return False 60 | 61 | for index, row in self.data.iterrows(): 62 | filepath = os.path.join(self.root, self.base_folder, row.filepath) 63 | if not os.path.isfile(filepath): 64 | print(filepath) 65 | return False 66 | return True 67 | 68 | def _download(self): 69 | import tarfile 70 | 71 | if self._check_integrity(): 72 | print('Files already downloaded and verified') 73 | return 74 | 75 | download_url(self.url, self.root, self.filename, self.tgz_md5) 76 | 77 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 78 | tar.extractall(path=self.root) 79 | 80 | def __len__(self): 81 | return len(self.data) 82 | 83 | def __getitem__(self, idx): 84 | sample = self.data.iloc[idx] 85 | path = os.path.join(self.root, self.base_folder, sample.filepath) 86 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0 87 | img = self.loader(path) 88 | 89 | if self.transform is not None: 90 | img = 
self.transform(img) 91 | 92 | if self.target_transform is not None: 93 | target = self.target_transform(target) 94 | 95 | return img, target, self.uq_idxs[idx] 96 | 97 | 98 | def subsample_dataset(dataset, idxs): 99 | 100 | mask = np.zeros(len(dataset)).astype('bool') 101 | mask[idxs] = True 102 | 103 | dataset.data = dataset.data[mask] 104 | dataset.uq_idxs = dataset.uq_idxs[mask] 105 | 106 | return dataset 107 | 108 | 109 | def subsample_classes(dataset, include_classes=range(160)): 110 | 111 | include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199 112 | cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub] 113 | 114 | # TODO: For now have no target transform 115 | target_xform_dict = {} 116 | for i, k in enumerate(include_classes): 117 | target_xform_dict[k] = i 118 | 119 | dataset = subsample_dataset(dataset, cls_idxs) 120 | 121 | dataset.target_transform = lambda x: target_xform_dict[x] 122 | 123 | return dataset 124 | 125 | 126 | def get_train_val_indices(train_dataset, val_split=0.2): 127 | 128 | train_classes = np.unique(train_dataset.data['target']) 129 | 130 | # Get train/test indices 131 | train_idxs = [] 132 | val_idxs = [] 133 | for cls in train_classes: 134 | 135 | cls_idxs = np.where(train_dataset.data['target'] == cls)[0] 136 | 137 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 138 | t_ = [x for x in cls_idxs if x not in v_] 139 | 140 | train_idxs.extend(t_) 141 | val_idxs.extend(v_) 142 | 143 | return train_idxs, val_idxs 144 | 145 | 146 | def get_cub_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8, 147 | split_train_val=False, seed=0, download=False): 148 | 149 | np.random.seed(seed) 150 | 151 | # Init entire training set 152 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True, download=download) 153 | 154 | # Get labelled training set which has subsampled classes, then subsample some indices from that 155 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 156 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 157 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 158 | 159 | # Split into training and validation sets 160 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 161 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 162 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 163 | val_dataset_labelled_split.transform = test_transform 164 | 165 | # Get unlabelled data 166 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 167 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 168 | 169 | # Get test set for all classes 170 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False) 171 | 172 | # Either split train into train and val or use test set as val 173 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 174 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 175 | 176 | all_datasets = { 177 | 'train_labelled': train_dataset_labelled, 178 | 'train_unlabelled': 
train_dataset_unlabelled, 179 | 'val': val_dataset_labelled, 180 | 'test': test_dataset, 181 | } 182 | 183 | return all_datasets 184 | 185 | if __name__ == '__main__': 186 | 187 | x = get_cub_datasets(None, None, split_train_val=False, 188 | train_classes=range(100), prop_train_labels=0.5) 189 | 190 | print('Printing lens...') 191 | for k, v in x.items(): 192 | if v is not None: 193 | print(f'{k}: {len(v)}') 194 | 195 | print('Printing labelled and unlabelled overlap...') 196 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 197 | print('Printing total instances in train...') 198 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 199 | 200 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}') 201 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}') 202 | print(f'Len labelled set: {len(x["train_labelled"])}') 203 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import numpy as np 3 | 4 | import os 5 | 6 | from copy import deepcopy 7 | from data.data_utils import subsample_instances 8 | from config import imagenet_root 9 | 10 | 11 | class ImageNetBase(torchvision.datasets.ImageFolder): 12 | 13 | def __init__(self, root, transform): 14 | 15 | super(ImageNetBase, self).__init__(root, transform) 16 | 17 | self.uq_idxs = np.array(range(len(self))) 18 | 19 | def __getitem__(self, item): 20 | 21 | img, label = super().__getitem__(item) 22 | uq_idx = self.uq_idxs[item] 23 | 24 | return img, label, uq_idx 25 | 26 | 27 | def subsample_dataset(dataset, idxs): 28 | 29 | imgs_ = [] 30 | for i in idxs: 31 | imgs_.append(dataset.imgs[i]) 32 | dataset.imgs = imgs_ 33 | 34 | samples_ = [] 35 | for i in idxs: 36 | samples_.append(dataset.samples[i]) 37 | dataset.samples = samples_ 38 | 39 | # dataset.imgs = [x for i, x in enumerate(dataset.imgs) if i in idxs] 40 | # dataset.samples = [x for i, x in enumerate(dataset.samples) if i in idxs] 41 | 42 | dataset.targets = np.array(dataset.targets)[idxs].tolist() 43 | dataset.uq_idxs = dataset.uq_idxs[idxs] 44 | 45 | return dataset 46 | 47 | 48 | def subsample_classes(dataset, include_classes=list(range(1000))): 49 | 50 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] 51 | 52 | target_xform_dict = {} 53 | for i, k in enumerate(include_classes): 54 | target_xform_dict[k] = i 55 | 56 | dataset = subsample_dataset(dataset, cls_idxs) 57 | dataset.target_transform = lambda x: target_xform_dict[x] 58 | 59 | return dataset 60 | 61 | 62 | def get_train_val_indices(train_dataset, val_split=0.2): 63 | 64 | train_classes = list(set(train_dataset.targets)) 65 | 66 | # Get train/val indices 67 | train_idxs = [] 68 | val_idxs = [] 69 | for cls in train_classes: 70 | 71 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] 72 | 73 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 74 | t_ = [x for x in cls_idxs if x not in v_] 75 | 76 | train_idxs.extend(t_) 77 | val_idxs.extend(v_) 78 | 79 | return train_idxs, val_idxs 80 | 81 | 82 | def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80), 83 | prop_train_labels=0.8, split_train_val=False, seed=0): 84 | 85 | np.random.seed(seed) 86 | 87 | # 
Subsample imagenet dataset initially to include 100 classes 88 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False) 89 | subsampled_100_classes = np.sort(subsampled_100_classes) 90 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}') 91 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))} 92 | 93 | # Init entire training set 94 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) 95 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes) 96 | 97 | # Reset dataset 98 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples] 99 | whole_training_set.targets = [s[1] for s in whole_training_set.samples] 100 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set))) 101 | whole_training_set.target_transform = None 102 | 103 | # Get labelled training set which has subsampled classes, then subsample some indices from that 104 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 105 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 106 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 107 | 108 | # Split into training and validation sets 109 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 110 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 111 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 112 | val_dataset_labelled_split.transform = test_transform 113 | 114 | # Get unlabelled data 115 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 116 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 117 | 118 | # Get test set for all classes 119 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) 120 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes) 121 | 122 | # Reset test set 123 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples] 124 | test_dataset.targets = [s[1] for s in test_dataset.samples] 125 | test_dataset.uq_idxs = np.array(range(len(test_dataset))) 126 | test_dataset.target_transform = None 127 | 128 | # Either split train into train and val or use test set as val 129 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 130 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 131 | 132 | all_datasets = { 133 | 'train_labelled': train_dataset_labelled, 134 | 'train_unlabelled': train_dataset_unlabelled, 135 | 'val': val_dataset_labelled, 136 | 'test': test_dataset, 137 | } 138 | 139 | return all_datasets 140 | 141 | 142 | def get_imagenet_1k_datasets(train_transform, test_transform, train_classes=range(500), 143 | prop_train_labels=0.5, split_train_val=False, seed=0): 144 | 145 | np.random.seed(seed) 146 | 147 | # Init entire training set 148 | whole_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) 149 | 150 | # Get labelled training set which has subsampled classes, then subsample some indices from that 151 | train_dataset_labelled = 
subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 152 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 153 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 154 | 155 | # Split into training and validation sets 156 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 157 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 158 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 159 | val_dataset_labelled_split.transform = test_transform 160 | 161 | # Get unlabelled data 162 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 163 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 164 | 165 | # Get test set for all classes 166 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) 167 | 168 | # Either split train into train and val or use test set as val 169 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 170 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 171 | 172 | all_datasets = { 173 | 'train_labelled': train_dataset_labelled, 174 | 'train_unlabelled': train_dataset_unlabelled, 175 | 'val': val_dataset_labelled, 176 | 'test': test_dataset, 177 | } 178 | 179 | return all_datasets 180 | 181 | 182 | 183 | if __name__ == '__main__': 184 | 185 | x = get_imagenet_100_datasets(None, None, split_train_val=False, 186 | train_classes=range(50), prop_train_labels=0.5) 187 | 188 | print('Printing lens...') 189 | for k, v in x.items(): 190 | if v is not None: 191 | print(f'{k}: {len(v)}') 192 | 193 | print('Printing labelled and unlabelled overlap...') 194 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 195 | print('Printing total instances in train...') 196 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 197 | 198 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 199 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}') 200 | print(f'Len labelled set: {len(x["train_labelled"])}') 201 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class DINOHead(nn.Module): 7 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, 8 | nlayers=3, hidden_dim=2048, bottleneck_dim=256): 9 | super().__init__() 10 | nlayers = max(nlayers, 1) 11 | if nlayers == 1: 12 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 13 | elif nlayers != 0: 14 | layers = [nn.Linear(in_dim, hidden_dim)] 15 | if use_bn: 16 | layers.append(nn.BatchNorm1d(hidden_dim)) 17 | layers.append(nn.GELU()) 18 | for _ in range(nlayers - 2): 19 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 20 | if use_bn: 21 | layers.append(nn.BatchNorm1d(hidden_dim)) 22 | layers.append(nn.GELU()) 23 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 24 | self.mlp = nn.Sequential(*layers) 25 | self.apply(self._init_weights) 26 | self.last_layer = 
nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False)) 27 | self.last_layer.weight_g.data.fill_(1) 28 | if norm_last_layer: 29 | self.last_layer.weight_g.requires_grad = False 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | torch.nn.init.trunc_normal_(m.weight, std=.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x_proj = self.mlp(x) 39 | x = nn.functional.normalize(x, dim=-1, p=2) 40 | # x = x.detach() 41 | logits = self.last_layer(x) 42 | return x_proj, logits 43 | 44 | 45 | class ContrastiveLearningViewGenerator(object): 46 | """Take two random crops of one image as the query and key.""" 47 | 48 | def __init__(self, base_transform, n_views=2): 49 | self.base_transform = base_transform 50 | self.n_views = n_views 51 | 52 | def __call__(self, x): 53 | if not isinstance(self.base_transform, list): 54 | return [self.base_transform(x) for i in range(self.n_views)] 55 | else: 56 | return [self.base_transform[i](x) for i in range(self.n_views)] 57 | 58 | class SupConLoss(torch.nn.Module): 59 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 60 | It also supports the unsupervised contrastive loss in SimCLR 61 | From: https://github.com/HobbitLong/SupContrast""" 62 | def __init__(self, temperature=0.07, contrast_mode='all', 63 | base_temperature=0.07): 64 | super(SupConLoss, self).__init__() 65 | self.temperature = temperature 66 | self.contrast_mode = contrast_mode 67 | self.base_temperature = base_temperature 68 | 69 | def forward(self, features, labels=None, mask=None): 70 | """Compute loss for model. If both `labels` and `mask` are None, 71 | it degenerates to SimCLR unsupervised loss: 72 | https://arxiv.org/pdf/2002.05709.pdf 73 | Args: 74 | features: hidden vector of shape [bsz, n_views, ...]. 75 | labels: ground truth of shape [bsz]. 76 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 77 | has the same class as sample i. Can be asymmetric. 78 | Returns: 79 | A loss scalar. 
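        Example (editor's sketch; tensor sizes are illustrative assumptions only):
            >>> feats = F.normalize(torch.randn(8, 2, 128), dim=-1)  # [bsz, n_views, dim]
            >>> labels = torch.randint(0, 4, (8,))
            >>> loss = SupConLoss()(feats, labels=labels)  # scalar tensor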
80 | """ 81 | 82 | device = (torch.device('cuda') 83 | if features.is_cuda 84 | else torch.device('cpu')) 85 | 86 | if len(features.shape) < 3: 87 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 88 | ' at least 3 dimensions are required') 89 | if len(features.shape) > 3: 90 | features = features.view(features.shape[0], features.shape[1], -1) 91 | 92 | batch_size = features.shape[0] 93 | if labels is not None and mask is not None: 94 | raise ValueError('Cannot define both `labels` and `mask`') 95 | elif labels is None and mask is None: 96 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 97 | elif labels is not None: 98 | labels = labels.contiguous().view(-1, 1) 99 | if labels.shape[0] != batch_size: 100 | raise ValueError('Num of labels does not match num of features') 101 | mask = torch.eq(labels, labels.T).float().to(device) 102 | else: 103 | mask = mask.float().to(device) 104 | 105 | contrast_count = features.shape[1] 106 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 107 | if self.contrast_mode == 'one': 108 | anchor_feature = features[:, 0] 109 | anchor_count = 1 110 | elif self.contrast_mode == 'all': 111 | anchor_feature = contrast_feature 112 | anchor_count = contrast_count 113 | else: 114 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 115 | 116 | # compute logits 117 | anchor_dot_contrast = torch.div( 118 | torch.matmul(anchor_feature, contrast_feature.T), 119 | self.temperature) 120 | 121 | # for numerical stability 122 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 123 | logits = anchor_dot_contrast - logits_max.detach() 124 | 125 | # tile mask 126 | mask = mask.repeat(anchor_count, contrast_count) 127 | # mask-out self-contrast cases 128 | logits_mask = torch.scatter( 129 | torch.ones_like(mask), 130 | 1, 131 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 132 | 0 133 | ) 134 | mask = mask * logits_mask 135 | 136 | # compute log_prob 137 | exp_logits = torch.exp(logits) * logits_mask 138 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 139 | 140 | # compute mean of log-likelihood over positive 141 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 142 | 143 | # loss 144 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 145 | loss = loss.view(anchor_count, batch_size).mean() 146 | 147 | return loss 148 | 149 | 150 | 151 | def info_nce_logits(features, n_views=2, temperature=1.0, device='cuda'): 152 | 153 | b_ = features.size(0) // 2 154 | 155 | labels = torch.cat([torch.arange(b_) for _ in range(n_views)], dim=0) 156 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() 157 | labels = labels.to(device) 158 | 159 | features = F.normalize(features, dim=1) 160 | 161 | similarity_matrix = torch.matmul(features, features.T) 162 | 163 | # discard the main diagonal from both: labels and similarities matrix 164 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device) 165 | labels = labels[~mask].view(labels.shape[0], -1) 166 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) 167 | 168 | # select and combine multiple positives 169 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1) 170 | 171 | # select only the negatives 172 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) 173 | 174 | logits = torch.cat([positives, negatives], dim=1) 175 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
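    # editor's note: the zero `labels` above are correct because `positives` was
    # concatenated first, putting each row's positive logit in column 0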
176 | 177 | logits = logits / temperature 178 | return logits, labels 179 | 180 | 181 | def get_params_groups(model): 182 | regularized = [] 183 | not_regularized = [] 184 | for name, param in model.named_parameters(): 185 | if not param.requires_grad: 186 | continue 187 | # we do not regularize biases nor Norm parameters 188 | if name.endswith(".bias") or len(param.shape) == 1: 189 | not_regularized.append(param) 190 | else: 191 | regularized.append(param) 192 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 193 | 194 | 195 | class DistillLoss(nn.Module): 196 | def __init__(self, warmup_teacher_temp_epochs, nepochs, 197 | ncrops=2, warmup_teacher_temp=0.07, teacher_temp=0.04, 198 | student_temp=0.1): 199 | super().__init__() 200 | self.student_temp = student_temp 201 | self.ncrops = ncrops 202 | self.teacher_temp_schedule = np.concatenate(( 203 | np.linspace(warmup_teacher_temp, 204 | teacher_temp, warmup_teacher_temp_epochs), 205 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp 206 | )) 207 | 208 | def forward(self, student_output, teacher_output, epoch): 209 | """ 210 | Cross-entropy between softmax outputs of the teacher and student networks. 211 | """ 212 | student_out = student_output / self.student_temp 213 | student_out = student_out.chunk(self.ncrops) 214 | 215 | # teacher centering and sharpening 216 | temp = self.teacher_temp_schedule[epoch] 217 | teacher_out = F.softmax(teacher_output / temp, dim=-1) 218 | teacher_out = teacher_out.detach().chunk(2) 219 | 220 | total_loss = 0 221 | n_loss_terms = 0 222 | for iq, q in enumerate(teacher_out): 223 | for v in range(len(student_out)): 224 | if v == iq: 225 | # we skip cases where student and teacher operate on the same view 226 | continue 227 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 228 | total_loss += loss.mean() 229 | n_loss_terms += 1 230 | total_loss /= n_loss_terms 231 | return total_loss 232 | -------------------------------------------------------------------------------- /data/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from copy import deepcopy 4 | 5 | from torchvision.datasets.folder import default_loader 6 | from torch.utils.data import Dataset 7 | 8 | from data.data_utils import subsample_instances 9 | from config import aircraft_root 10 | 11 | def make_dataset(dir, image_ids, targets): 12 | assert(len(image_ids) == len(targets)) 13 | images = [] 14 | dir = os.path.expanduser(dir) 15 | for i in range(len(image_ids)): 16 | item = (os.path.join(dir, 'data', 'images', 17 | '%s.jpg' % image_ids[i]), targets[i]) 18 | images.append(item) 19 | return images 20 | 21 | 22 | def find_classes(classes_file): 23 | 24 | # read classes file, separating out image IDs and class names 25 | image_ids = [] 26 | targets = [] 27 | f = open(classes_file, 'r') 28 | for line in f: 29 | split_line = line.split(' ') 30 | image_ids.append(split_line[0]) 31 | targets.append(' '.join(split_line[1:])) 32 | f.close() 33 | 34 | # index class names 35 | classes = np.unique(targets) 36 | class_to_idx = {classes[i]: i for i in range(len(classes))} 37 | targets = [class_to_idx[c] for c in targets] 38 | 39 | return (image_ids, targets, classes, class_to_idx) 40 | 41 | 42 | class FGVCAircraft(Dataset): 43 | 44 | """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset. 45 | 46 | Args: 47 | root (string): Root directory path to dataset. 48 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification 49 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
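        split (string, optional): Which annotation split to load, one of ``train``,
            ``val``, ``trainval`` or ``test`` (editor's addition; mirrors the
            ``splits`` tuple defined on the class).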
50 | transform (callable, optional): A function/transform that takes in a PIL image 51 | and returns a transformed version. E.g. ``transforms.RandomCrop`` 52 | target_transform (callable, optional): A function/transform that takes in the 53 | target and transforms it. 54 | loader (callable, optional): A function to load an image given its path. 55 | download (bool, optional): If true, downloads the dataset from the internet and 56 | puts it in the root directory. If dataset is already downloaded, it is not 57 | downloaded again. 58 | """ 59 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' 60 | class_types = ('variant', 'family', 'manufacturer') 61 | splits = ('train', 'val', 'trainval', 'test') 62 | 63 | def __init__(self, root, class_type='variant', split='train', transform=None, 64 | target_transform=None, loader=default_loader, download=False): 65 | if split not in self.splits: 66 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 67 | split, ', '.join(self.splits), 68 | )) 69 | if class_type not in self.class_types: 70 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( 71 | class_type, ', '.join(self.class_types), 72 | )) 73 | self.root = os.path.expanduser(root) 74 | self.class_type = class_type 75 | self.split = split 76 | self.classes_file = os.path.join(self.root, 'data', 77 | 'images_%s_%s.txt' % (self.class_type, self.split)) 78 | 79 | if download: 80 | self.download() 81 | 82 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) 83 | samples = make_dataset(self.root, image_ids, targets) 84 | 85 | self.transform = transform 86 | self.target_transform = target_transform 87 | self.loader = loader 88 | 89 | self.samples = samples 90 | self.classes = classes 91 | self.class_to_idx = class_to_idx 92 | self.train = split == 'train' 93 | 94 | self.uq_idxs = np.array(range(len(self))) 95 | 96 | def __getitem__(self, index): 97 | """ 98 | Args: 99 | index (int): Index 100 | 101 | Returns: 102 | tuple: (sample, target, uq_idx) where target is class_index of the target class. 
103 | """ 104 | 105 | path, target = self.samples[index] 106 | sample = self.loader(path) 107 | if self.transform is not None: 108 | sample = self.transform(sample) 109 | if self.target_transform is not None: 110 | target = self.target_transform(target) 111 | 112 | return sample, target, self.uq_idxs[index] 113 | 114 | def __len__(self): 115 | return len(self.samples) 116 | 117 | def __repr__(self): 118 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 119 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 120 | fmt_str += ' Root Location: {}\n'.format(self.root) 121 | tmp = ' Transforms (if any): ' 122 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 123 | tmp = ' Target Transforms (if any): ' 124 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 125 | return fmt_str 126 | 127 | def _check_exists(self): 128 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ 129 | os.path.exists(self.classes_file) 130 | 131 | def download(self): 132 | """Download the FGVC-Aircraft data if it doesn't exist already.""" 133 | from six.moves import urllib 134 | import tarfile 135 | 136 | if self._check_exists(): 137 | return 138 | 139 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz 140 | print('Downloading %s ... (may take a few minutes)' % self.url) 141 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir)) 142 | tar_name = self.url.rpartition('/')[-1] 143 | tar_path = os.path.join(parent_dir, tar_name) 144 | data = urllib.request.urlopen(self.url) 145 | 146 | # download .tar.gz file 147 | with open(tar_path, 'wb') as f: 148 | f.write(data.read()) 149 | 150 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b 151 | data_folder = tar_path[:-len('.tar.gz')]  # drop the suffix (str.strip would remove characters, not the suffix) 152 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder)) 153 | tar = tarfile.open(tar_path) 154 | tar.extractall(parent_dir) 155 | 156 | # if necessary, rename data folder to self.root 157 | if not os.path.samefile(data_folder, self.root): 158 | print('Renaming %s to %s ...' % (data_folder, self.root)) 159 | os.rename(data_folder, self.root) 160 | 161 | # delete .tar.gz file 162 | print('Deleting %s ...' % tar_path)
163 | os.remove(tar_path) 164 | 165 | print('Done!') 166 | 167 | 168 | def subsample_dataset(dataset, idxs): 169 | 170 | mask = np.zeros(len(dataset)).astype('bool') 171 | mask[idxs] = True 172 | 173 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs] 174 | dataset.uq_idxs = dataset.uq_idxs[mask] 175 | 176 | return dataset 177 | 178 | 179 | def subsample_classes(dataset, include_classes=range(60)): 180 | 181 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] 182 | 183 | # TODO: Don't transform targets for now 184 | target_xform_dict = {} 185 | for i, k in enumerate(include_classes): 186 | target_xform_dict[k] = i 187 | 188 | dataset = subsample_dataset(dataset, cls_idxs) 189 | 190 | dataset.target_transform = lambda x: target_xform_dict[x] 191 | 192 | return dataset 193 | 194 | 195 | def get_train_val_indices(train_dataset, val_split=0.2): 196 | 197 | all_targets = np.array([t for (p, t) in train_dataset.samples])  # as an array so the == below broadcasts per class 198 | train_classes = np.unique(all_targets) 199 | 200 | # Get train/val indices 201 | train_idxs = [] 202 | val_idxs = [] 203 | for cls in train_classes: 204 | cls_idxs = np.where(all_targets == cls)[0] 205 | 206 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 207 | t_ = [x for x in cls_idxs if x not in v_] 208 | 209 | train_idxs.extend(t_) 210 | val_idxs.extend(v_) 211 | 212 | return train_idxs, val_idxs 213 | 214 | 215 | def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8, 216 | split_train_val=False, seed=0): 217 | 218 | np.random.seed(seed) 219 | 220 | # Init entire training set 221 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval') 222 | 223 | # Get labelled training set which has subsampled classes, then subsample some indices from that 224 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 225 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 226 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 227 | 228 | # Split into training and validation sets 229 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 230 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 231 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 232 | val_dataset_labelled_split.transform = test_transform 233 | 234 | # Get unlabelled data 235 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 236 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 237 | 238 | # Get test set for all classes 239 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test') 240 | 241 | # Either split train into train and val or use test set as val 242 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 243 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 244 | 245 | all_datasets = { 246 | 'train_labelled': train_dataset_labelled, 247 | 'train_unlabelled': train_dataset_unlabelled, 248 | 'val': val_dataset_labelled, 249 | 'test': test_dataset, 250 | } 251 | 252 | return all_datasets 253 | 254 | if __name__ == '__main__': 255 | 256 | x = get_aircraft_datasets(None, None, split_train_val=False)
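    # editor's note: transforms are None here because this __main__ block only smoke-tests
    # split sizes and the disjointness of the labelled/unlabelled indices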
257 | 258 | print('Printing lens...') 259 | for k, v in x.items(): 260 | if v is not None: 261 | print(f'{k}: {len(v)}') 262 | 263 | print('Printing labelled and unlabelled overlap...') 264 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 265 | print('Printing total instances in train...') 266 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 267 | print('Printing number of labelled classes...') 268 | print(len(set([i[1] for i in x['train_labelled'].samples]))) 269 | print('Printing total number of classes...') 270 | print(len(set([i[1] for i in x['train_unlabelled'].samples]))) 271 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import SGD, lr_scheduler 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from data.augmentations import get_transform 12 | from data.get_datasets import get_datasets, get_class_splits 13 | 14 | from util.general_utils import AverageMeter, init_experiment 15 | from util.cluster_and_log_utils import log_accs_from_preds 16 | from config import exp_root 17 | from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups 18 | 19 | 20 | def train(student, train_loader, test_loader, unlabelled_train_loader, args): 21 | params_groups = get_params_groups(student) 22 | optimizer = SGD(params_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 23 | fp16_scaler = None 24 | if args.fp16: 25 | fp16_scaler = torch.cuda.amp.GradScaler() 26 | 27 | exp_lr_scheduler = lr_scheduler.CosineAnnealingLR( 28 | optimizer, 29 | T_max=args.epochs, 30 | eta_min=args.lr * 1e-3, 31 | ) 32 | 33 | 34 | cluster_criterion = DistillLoss( 35 | args.warmup_teacher_temp_epochs, 36 | args.epochs, 37 | args.n_views, 38 | args.warmup_teacher_temp, 39 | args.teacher_temp, 40 | ) 41 | 42 | # # inductive 43 | # best_test_acc_lab = 0 44 | # # transductive 45 | # best_train_acc_lab = 0 46 | # best_train_acc_ubl = 0 47 | # best_train_acc_all = 0 48 | 49 | for epoch in range(args.epochs): 50 | loss_record = AverageMeter() 51 | 52 | student.train() 53 | for batch_idx, batch in enumerate(train_loader): 54 | images, class_labels, uq_idxs, mask_lab = batch 55 | mask_lab = mask_lab[:, 0] 56 | 57 | class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool() 58 | images = torch.cat(images, dim=0).cuda(non_blocking=True) 59 | 60 | with torch.cuda.amp.autocast(fp16_scaler is not None): 61 | student_proj, student_out = student(images) 62 | teacher_out = student_out.detach() 63 | 64 | # clustering, sup 65 | sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0) 66 | sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0) 67 | cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels) 68 | 69 | # clustering, unsup 70 | cluster_loss = cluster_criterion(student_out, teacher_out, epoch) 71 | avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0) 72 | me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs))) 73 | cluster_loss += args.memax_weight * me_max_loss 74 | 75 | # representation learning, unsup 
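                # editor's sketch: with n_views=2 and batch size B, `student_proj` is [2B, d];
                # info_nce_logits pairs row i with row i+B as the positive and uses the other
                # 2B-2 rows as negatives, so the target is always column 0. Illustrative shapes:
                #   logits, labels = info_nce_logits(features=torch.randn(256, 256).cuda())
                #   logits.shape == (256, 255); labels.eq(0).all()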
76 | contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj) 77 | contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels) 78 | 79 | # representation learning, sup 80 | student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1) 81 | student_proj = torch.nn.functional.normalize(student_proj, dim=-1) 82 | sup_con_labels = class_labels[mask_lab] 83 | sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels) 84 | 85 | pstr = '' 86 | pstr += f'cls_loss: {cls_loss.item():.4f} ' 87 | pstr += f'cluster_loss: {cluster_loss.item():.4f} ' 88 | pstr += f'sup_con_loss: {sup_con_loss.item():.4f} ' 89 | pstr += f'contrastive_loss: {contrastive_loss.item():.4f} ' 90 | 91 | loss = 0 92 | loss += (1 - args.sup_weight) * cluster_loss + args.sup_weight * cls_loss 93 | loss += (1 - args.sup_weight) * contrastive_loss + args.sup_weight * sup_con_loss 94 | 95 | # Record the running loss 96 | loss_record.update(loss.item(), class_labels.size(0)) 97 | optimizer.zero_grad() 98 | if fp16_scaler is None: 99 | loss.backward() 100 | optimizer.step() 101 | else: 102 | fp16_scaler.scale(loss).backward() 103 | fp16_scaler.step(optimizer) 104 | fp16_scaler.update() 105 | 106 | if batch_idx % args.print_freq == 0: 107 | args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}' 108 | .format(epoch, batch_idx, len(train_loader), loss.item(), pstr)) 109 | 110 | args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg)) 111 | 112 | args.logger.info('Testing on unlabelled examples in the training data...') 113 | all_acc, old_acc, new_acc = test(student, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args) 114 | # args.logger.info('Testing on disjoint test set...') 115 | # all_acc_test, old_acc_test, new_acc_test = test(student, test_loader, epoch=epoch, save_name='Test ACC', args=args) 116 | 117 | 118 | args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) 119 | # args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test)) 120 | 121 | # Step schedule 122 | exp_lr_scheduler.step() 123 | 124 | save_dict = { 125 | 'model': student.state_dict(), 126 | 'optimizer': optimizer.state_dict(), 127 | 'epoch': epoch + 1, 128 | } 129 | 130 | torch.save(save_dict, args.model_path) 131 | args.logger.info("model saved to {}.".format(args.model_path)) 132 | 133 | # if old_acc_test > best_test_acc_lab: 134 | # 135 | # args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...') 136 | # args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) 137 | # 138 | # torch.save(save_dict, args.model_path[:-3] + f'_best.pt') 139 | # args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt')) 140 | # 141 | # # inductive 142 | # best_test_acc_lab = old_acc_test 143 | # # transductive 144 | # best_train_acc_lab = old_acc 145 | # best_train_acc_ubl = new_acc 146 | # best_train_acc_all = all_acc 147 | # 148 | # args.logger.info(f'Exp Name: {args.exp_name}') 149 | # args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}') 150 | 151 | 152 | def test(model, test_loader, epoch, save_name, args): 153 | 154 | model.eval() 155 | 156 | preds, targets = [], [] 157 | mask = np.array([])
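    # editor's note: `mask` flags samples whose ground-truth label belongs to the labelled
    # ("Old") classes; log_accs_from_preds uses it to report All / Old / New accuracy separately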
158 | for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)): 159 | images = images.cuda(non_blocking=True) 160 | with torch.no_grad(): 161 | _, logits = model(images) 162 | preds.append(logits.argmax(1).cpu().numpy()) 163 | targets.append(label.cpu().numpy()) 164 | mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) else False for x in label])) 165 | 166 | preds = np.concatenate(preds) 167 | targets = np.concatenate(targets) 168 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask, 169 | T=epoch, eval_funcs=args.eval_funcs, save_name=save_name, 170 | args=args) 171 | 172 | return all_acc, old_acc, new_acc 173 | 174 | 175 | if __name__ == "__main__": 176 | 177 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 178 | parser.add_argument('--batch_size', default=128, type=int) 179 | parser.add_argument('--num_workers', default=8, type=int) 180 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2p']) 181 | 182 | parser.add_argument('--warmup_model_dir', type=str, default=None) 183 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19') 184 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 185 | parser.add_argument('--use_ssb_splits', action='store_true', default=True) 186 | 187 | parser.add_argument('--grad_from_block', type=int, default=11) 188 | parser.add_argument('--lr', type=float, default=0.1) 189 | parser.add_argument('--gamma', type=float, default=0.1) 190 | parser.add_argument('--momentum', type=float, default=0.9) 191 | parser.add_argument('--weight_decay', type=float, default=1e-4) 192 | parser.add_argument('--epochs', default=200, type=int) 193 | parser.add_argument('--exp_root', type=str, default=exp_root) 194 | parser.add_argument('--transform', type=str, default='imagenet') 195 | parser.add_argument('--sup_weight', type=float, default=0.35) 196 | parser.add_argument('--n_views', default=2, type=int) 197 | 198 | parser.add_argument('--memax_weight', type=float, default=2) 199 | parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.') 200 | parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup) of the teacher temperature.') 201 | parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.') 202 | 203 | parser.add_argument('--fp16', action='store_true', default=False) 204 | parser.add_argument('--print_freq', default=10, type=int) 205 | parser.add_argument('--exp_name', default=None, type=str) 206 | 207 | # ---------------------- 208 | # INIT 209 | # ---------------------- 210 | args = parser.parse_args() 211 | device = torch.device('cuda:0') 212 | args = get_class_splits(args) 213 | 214 | args.num_labeled_classes = len(args.train_classes) 215 | args.num_unlabeled_classes = len(args.unlabeled_classes) 216 | 217 | init_experiment(args, runner_name=['simgcd']) 218 | args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results') 219 | 220 | torch.backends.cudnn.benchmark = True 221 | 222 | # ---------------------- 223 | # BASE MODEL 224 | # ---------------------- 225 | args.interpolation = 3 226 | args.crop_pct = 0.875 227 | 228 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 229 | 
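    # editor's note: torch.hub.load downloads the DINO ViT-B/16 checkpoint on first use and
    # caches it (under ~/.cache/torch/hub by default), so the first run needs network access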
230 | if args.warmup_model_dir is not None: 231 | args.logger.info(f'Loading weights from {args.warmup_model_dir}') 232 | backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu')) 233 | 234 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model 235 | args.image_size = 224 236 | args.feat_dim = 768 237 | args.num_mlp_layers = 3 238 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes 239 | 240 | # ---------------------- 241 | # HOW MUCH OF BASE MODEL TO FINETUNE 242 | # ---------------------- 243 | for m in backbone.parameters(): 244 | m.requires_grad = False 245 | 246 | # Only finetune layers from block 'args.grad_from_block' onwards 247 | for name, m in backbone.named_parameters(): 248 | if 'block' in name: 249 | block_num = int(name.split('.')[1]) 250 | if block_num >= args.grad_from_block: 251 | m.requires_grad = True 252 | 253 | 254 | args.logger.info('model built') 255 | 256 | # -------------------- 257 | # CONTRASTIVE TRANSFORM 258 | # -------------------- 259 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) 260 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) 261 | # -------------------- 262 | # DATASETS 263 | # -------------------- 264 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name, 265 | train_transform, 266 | test_transform, 267 | args) 268 | 269 | # -------------------- 270 | # SAMPLER 271 | # Sampler which balances labelled and unlabelled examples in each batch 272 | # -------------------- 273 | label_len = len(train_dataset.labelled_dataset) 274 | unlabelled_len = len(train_dataset.unlabelled_dataset) 275 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] 276 | sample_weights = torch.DoubleTensor(sample_weights) 277 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset)) 278 | 279 | # -------------------- 280 | # DATALOADERS 281 | # -------------------- 282 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, 283 | sampler=sampler, drop_last=True, pin_memory=True) 284 | test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, 285 | batch_size=256, shuffle=False, pin_memory=False) 286 | # test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers, 287 | # batch_size=256, shuffle=False, pin_memory=False) 288 | 289 | # ---------------------- 290 | # PROJECTION HEAD 291 | # ---------------------- 292 | projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers) 293 | model = nn.Sequential(backbone, projector).to(device) 294 | 295 | # ---------------------- 296 | # TRAIN 297 | # ---------------------- 298 | # train(model, train_loader, test_loader_labelled, test_loader_unlabelled, args) 299 | train(model, train_loader, None, test_loader_unlabelled, args) -------------------------------------------------------------------------------- /train_mp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.distributed as dist 9 | import torch.backends.cudnn as cudnn 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 
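# NOTE (editor): this script reads LOCAL_RANK from the environment and initialises the process
# group with backend='nccl', init_method='env://', so it expects a torchrun-style launcher;
# a hypothetical single-node invocation: torchrun --nproc_per_node=4 train_mp.py --dataset_name 'cub' --fp16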
13 | from data.augmentations import get_transform 14 | from data.get_datasets import get_datasets, get_class_splits 15 | 16 | from util.general_utils import AverageMeter, init_experiment, DistributedWeightedSampler 17 | from util.cluster_and_log_utils import log_accs_from_preds 18 | from config import exp_root 19 | from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups 20 | 21 | 22 | def get_parser(): 23 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | 25 | parser.add_argument('--batch_size', default=128, type=int) 26 | parser.add_argument('--num_workers', default=8, type=int) 27 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2b']) 28 | 29 | parser.add_argument('--warmup_model_dir', type=str, default=None) 30 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19') 31 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 32 | parser.add_argument('--use_ssb_splits', action='store_true', default=True) 33 | 34 | parser.add_argument('--grad_from_block', type=int, default=11) 35 | parser.add_argument('--lr', type=float, default=0.1) 36 | parser.add_argument('--gamma', type=float, default=0.1) 37 | parser.add_argument('--momentum', type=float, default=0.9) 38 | parser.add_argument('--weight_decay', type=float, default=1e-4) 39 | parser.add_argument('--epochs', default=200, type=int) 40 | parser.add_argument('--exp_root', type=str, default=exp_root) 41 | parser.add_argument('--transform', type=str, default='imagenet') 42 | parser.add_argument('--sup_weight', type=float, default=0.35) 43 | parser.add_argument('--n_views', default=2, type=int) 44 | 45 | parser.add_argument('--memax_weight', type=float, default=2) 46 | parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.') 47 | parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup) of the teacher temperature.') 48 | parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.') 49 | 50 | parser.add_argument('--fp16', action='store_true', default=False) 51 | parser.add_argument('--print_freq', default=10, type=int) 52 | parser.add_argument('--exp_name', default=None, type=str) 53 | 54 | # ---------------------- 55 | # INIT 56 | # ---------------------- 57 | args = parser.parse_args() 58 | args = get_class_splits(args) 59 | 60 | args.num_labeled_classes = len(args.train_classes) 61 | args.num_unlabeled_classes = len(args.unlabeled_classes) 62 | 63 | if os.environ.get("LOCAL_RANK") is not None: 64 | args.local_rank = int(os.environ["LOCAL_RANK"]) 65 | 66 | return args 67 | 68 | 69 | def main(args): 70 | # ---------------------- 71 | # BASE MODEL 72 | # ---------------------- 73 | args.interpolation = 3 74 | args.crop_pct = 0.875 75 | 76 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 77 | 78 | if args.warmup_model_dir is not None: 79 | if dist.get_rank() == 0: 80 | args.logger.info(f'Loading weights from {args.warmup_model_dir}') 81 | backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu')) 82 | 83 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model 84 | args.image_size = 224 85 | args.feat_dim = 768 
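    # editor's note: args.mlp_out_dim below equals num_labeled_classes + num_unlabeled_classes,
    # i.e. one classification prototype per known or novel class (100 + 100 = 200 for the CUB
    # split exercised in data/cub.py's __main__; figures are illustrative)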
86 | args.num_mlp_layers = 3 87 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes 88 | 89 | # ---------------------- 90 | # HOW MUCH OF BASE MODEL TO FINETUNE 91 | # ---------------------- 92 | for m in backbone.parameters(): 93 | m.requires_grad = False 94 | 95 | # Only finetune layers from block 'args.grad_from_block' onwards 96 | for name, m in backbone.named_parameters(): 97 | if 'block' in name: 98 | block_num = int(name.split('.')[1]) 99 | if block_num >= args.grad_from_block: 100 | m.requires_grad = True 101 | 102 | if dist.get_rank() == 0: 103 | args.logger.info('model built') 104 | 105 | # -------------------- 106 | # CONTRASTIVE TRANSFORM 107 | # -------------------- 108 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) 109 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) 110 | # -------------------- 111 | # DATASETS 112 | # -------------------- 113 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name, 114 | train_transform, 115 | test_transform, 116 | args) 117 | 118 | # -------------------- 119 | # SAMPLER 120 | # Sampler which balances labelled and unlabelled examples in each batch 121 | # -------------------- 122 | label_len = len(train_dataset.labelled_dataset) 123 | unlabelled_len = len(train_dataset.unlabelled_dataset) 124 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] 125 | sample_weights = torch.DoubleTensor(sample_weights) 126 | train_sampler = DistributedWeightedSampler(train_dataset, sample_weights, num_samples=len(train_dataset)) 127 | unlabelled_train_sampler = torch.utils.data.distributed.DistributedSampler(unlabelled_train_examples_test) 128 | # test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) 129 | # -------------------- 130 | # DATALOADERS 131 | # -------------------- 132 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, 133 | sampler=train_sampler, drop_last=True, pin_memory=True) 134 | unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, batch_size=256, 135 | shuffle=False, sampler=unlabelled_train_sampler, pin_memory=False) 136 | # test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=256, 137 | # shuffle=False, sampler=test_sampler, pin_memory=False) 138 | 139 | # ---------------------- 140 | # PROJECTION HEAD 141 | # ---------------------- 142 | projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers) 143 | model = nn.Sequential(backbone, projector).cuda() 144 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 145 | 146 | params_groups = get_params_groups(model) 147 | optimizer = torch.optim.SGD( 148 | params_groups, 149 | lr=args.lr * (args.batch_size * dist.get_world_size() / 128), # linear scaling rule 150 | momentum=args.momentum, 151 | weight_decay=args.weight_decay 152 | ) 153 | 154 | fp16_scaler = None 155 | if args.fp16: 156 | fp16_scaler = torch.cuda.amp.GradScaler() 157 | 158 | exp_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 159 | optimizer, 160 | T_max=args.epochs, 161 | eta_min=args.lr * (args.batch_size * dist.get_world_size() / 128) * 1e-3, 162 | ) 163 | 164 | cluster_criterion = DistillLoss( 165 | args.warmup_teacher_temp_epochs, 166 | 
args.epochs, 167 | args.n_views, 168 | args.warmup_teacher_temp, 169 | args.teacher_temp, 170 | ) 171 | 172 | # # inductive 173 | # best_test_acc_lab = 0 174 | # # transductive 175 | # best_train_acc_lab = 0 176 | # best_train_acc_ubl = 0 177 | # best_train_acc_all = 0 178 | 179 | for epoch in range(args.epochs): 180 | train_sampler.set_epoch(epoch) 181 | train(model, train_loader, optimizer, fp16_scaler, exp_lr_scheduler, cluster_criterion, epoch, args) 182 | 183 | unlabelled_train_sampler.set_epoch(epoch) 184 | # test_sampler.set_epoch(epoch) 185 | if dist.get_rank() == 0: 186 | args.logger.info('Testing on unlabelled examples in the training data...') 187 | all_acc, old_acc, new_acc = test(model, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args) 188 | # if dist.get_rank() == 0: 189 | # args.logger.info('Testing on disjoint test set...') 190 | # all_acc_test, old_acc_test, new_acc_test = test(model, test_loader, epoch=epoch, save_name='Test ACC', args=args) 191 | 192 | if dist.get_rank() == 0: 193 | args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) 194 | # args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test)) 195 | 196 | save_dict = { 197 | 'model': model.state_dict(), 198 | 'optimizer': optimizer.state_dict(), 199 | 'epoch': epoch + 1, 200 | } 201 | 202 | torch.save(save_dict, args.model_path) 203 | args.logger.info("model saved to {}.".format(args.model_path)) 204 | 205 | # if old_acc_test > best_test_acc_lab and dist.get_rank() == 0: 206 | # args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...') 207 | # args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) 208 | # 209 | # torch.save(save_dict, args.model_path[:-3] + f'_best.pt') 210 | # args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt')) 211 | # 212 | # # inductive 213 | # best_test_acc_lab = old_acc_test 214 | # # transductive 215 | # best_train_acc_lab = old_acc 216 | # best_train_acc_ubl = new_acc 217 | # best_train_acc_all = all_acc 218 | # 219 | # if dist.get_rank() == 0: 220 | # args.logger.info(f'Exp Name: {args.exp_name}') 221 | # args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}') 222 | 223 | 224 | def train(student, train_loader, optimizer, scaler, scheduler, cluster_criterion, epoch, args): 225 | loss_record = AverageMeter() 226 | 227 | student.train() 228 | for batch_idx, batch in enumerate(train_loader): 229 | images, class_labels, uq_idxs, mask_lab = batch 230 | mask_lab = mask_lab[:, 0] 231 | 232 | class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool() 233 | images = torch.cat(images, dim=0).cuda(non_blocking=True) 234 | 235 | with torch.cuda.amp.autocast(scaler is not None): 236 | student_proj, student_out = student(images) 237 | teacher_out = student_out.detach() 238 | 239 | # clustering, sup 240 | sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0) 241 | sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0) 242 | cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels) 243 | 244 | # clustering, unsup 245 | cluster_loss = cluster_criterion(student_out, teacher_out, epoch) 246 | avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0) 247 | 
me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs))) 248 | cluster_loss += args.memax_weight * me_max_loss 249 | 250 | # representation learning, unsup 251 | contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj) 252 | contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels) 253 | 254 | # representation learning, sup 255 | student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1) 256 | student_proj = torch.nn.functional.normalize(student_proj, dim=-1) 257 | sup_con_labels = class_labels[mask_lab] 258 | sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels) 259 | 260 | pstr = '' 261 | pstr += f'cls_loss: {cls_loss.item():.4f} ' 262 | pstr += f'cluster_loss: {cluster_loss.item():.4f} ' 263 | pstr += f'sup_con_loss: {sup_con_loss.item():.4f} ' 264 | pstr += f'contrastive_loss: {contrastive_loss.item():.4f} ' 265 | 266 | loss = 0 267 | loss += (1 - args.sup_weight) * cluster_loss + args.sup_weight * cls_loss 268 | loss += (1 - args.sup_weight) * contrastive_loss + args.sup_weight * sup_con_loss 269 | 270 | # Record the running loss 271 | loss_record.update(loss.item(), class_labels.size(0)) 272 | optimizer.zero_grad() 273 | if scaler is None: 274 | loss.backward() 275 | optimizer.step() 276 | else: 277 | scaler.scale(loss).backward() 278 | scaler.step(optimizer) 279 | scaler.update() 280 | 281 | if batch_idx % args.print_freq == 0 and dist.get_rank() == 0: 282 | args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}' 283 | .format(epoch, batch_idx, len(train_loader), loss.item(), pstr)) 284 | # Step schedule 285 | scheduler.step() 286 | 287 | if dist.get_rank() == 0: 288 | args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg)) 289 | 290 | 291 | def test(model, test_loader, epoch, save_name, args): 292 | 293 | model.eval() 294 | 295 | preds, targets = [], [] 296 | mask = np.array([]) 297 | for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)): 298 | images = images.cuda(non_blocking=True) 299 | with torch.no_grad(): 300 | _, logits = model(images) 301 | preds.append(logits.argmax(1).cpu().numpy()) 302 | targets.append(label.cpu().numpy()) 303 | mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) else False for x in label])) 304 | 305 | preds = np.concatenate(preds) 306 | targets = np.concatenate(targets) 307 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask, 308 | T=epoch, eval_funcs=args.eval_funcs, save_name=save_name, 309 | args=args) 310 | 311 | return all_acc, old_acc, new_acc 312 | 313 | 314 | if __name__ == '__main__': 315 | args = get_parser() 316 | 317 | torch.cuda.set_device(args.local_rank) 318 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 319 | cudnn.benchmark = True 320 | 321 | if dist.get_rank() == 0: 322 | init_experiment(args, runner_name=['simgcd']) 323 | args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results') 324 | 325 | main(args) --------------------------------------------------------------------------------
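Editor's note: a minimal, self-contained sketch of the mean-entropy-maximisation regulariser that both train.py and train_mp.py add to the cluster loss (batch and prototype counts are illustrative assumptions, not values from the repo):

import math
import torch

student_out = torch.randn(256, 200)  # [batch, num_prototypes]; hypothetical sizes
avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0)
# -sum(p * log p) written as sum(log(p ** -p)), exactly as in the training scripts;
# subtracting it from log(K) makes the loss zero iff the mean prediction is uniform
me_max_loss = -torch.sum(torch.log(avg_probs ** (-avg_probs))) + math.log(float(len(avg_probs)))
print(f'me_max_loss: {me_max_loss.item():.4f}')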