├── .gitignore
├── assets
│   └── teaser.jpg
├── data
│   ├── ssb_splits
│   │   ├── cub_osr_splits.pkl
│   │   ├── scars_osr_splits.pkl
│   │   ├── aircraft_osr_splits.pkl
│   │   └── herbarium_19_class_splits.pkl
│   ├── augmentations
│   │   └── __init__.py
│   ├── data_utils.py
│   ├── get_datasets.py
│   ├── stanford_cars.py
│   ├── herbarium_19.py
│   ├── cifar.py
│   ├── cub.py
│   ├── imagenet.py
│   └── fgvc_aircraft.py
├── requirements.txt
├── scripts
│   ├── run_cub.sh
│   ├── run_cars.sh
│   ├── run_aircraft.sh
│   ├── run_cifar10.sh
│   ├── run_cifar100.sh
│   ├── run_herb19.sh
│   ├── run_imagenet100.sh
│   └── run_imagenet1k.sh
├── config.py
├── LICENSE
├── util
│   ├── general_utils.py
│   └── cluster_and_log_utils.py
├── README.md
├── model.py
├── train.py
└── train_mp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | dev_outputs
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/assets/teaser.jpg
--------------------------------------------------------------------------------
/data/ssb_splits/cub_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/data/ssb_splits/cub_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/scars_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/data/ssb_splits/scars_osr_splits.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | loguru
2 | numpy
3 | pandas
4 | scikit-learn
5 | scipy
6 | torch==1.10.0
7 | torchvision==0.11.1
8 | tqdm
--------------------------------------------------------------------------------
/data/ssb_splits/aircraft_osr_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/data/ssb_splits/aircraft_osr_splits.pkl
--------------------------------------------------------------------------------
/data/ssb_splits/herbarium_19_class_splits.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVMI-Lab/SimGCD/HEAD/data/ssb_splits/herbarium_19_class_splits.pkl
--------------------------------------------------------------------------------
/scripts/run_cub.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | CUDA_VISIBLE_DEVICES=0 python train.py \
7 | --dataset_name 'cub' \
8 | --batch_size 128 \
9 | --grad_from_block 11 \
10 | --epochs 200 \
11 | --num_workers 8 \
12 | --use_ssb_splits \
13 | --sup_weight 0.35 \
14 | --weight_decay 5e-5 \
15 | --transform 'imagenet' \
16 | --lr 0.1 \
17 | --eval_funcs 'v2' \
18 | --warmup_teacher_temp 0.07 \
19 | --teacher_temp 0.04 \
20 | --warmup_teacher_temp_epochs 30 \
21 | --memax_weight 2 \
22 | --exp_name cub_simgcd
23 |
--------------------------------------------------------------------------------
/scripts/run_cars.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | CUDA_VISIBLE_DEVICES=0 python train.py \
7 | --dataset_name 'scars' \
8 | --batch_size 128 \
9 | --grad_from_block 11 \
10 | --epochs 200 \
11 | --num_workers 8 \
12 | --use_ssb_splits \
13 | --sup_weight 0.35 \
14 | --weight_decay 5e-5 \
15 | --transform 'imagenet' \
16 | --lr 0.1 \
17 | --eval_funcs 'v2' \
18 | --warmup_teacher_temp 0.07 \
19 | --teacher_temp 0.04 \
20 | --warmup_teacher_temp_epochs 30 \
21 | --memax_weight 1 \
22 | --exp_name scars_simgcd
23 |
--------------------------------------------------------------------------------
/scripts/run_aircraft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | CUDA_VISIBLE_DEVICES=0 python train.py \
7 | --dataset_name 'aircraft' \
8 | --batch_size 128 \
9 | --grad_from_block 11 \
10 | --epochs 200 \
11 | --num_workers 8 \
12 | --use_ssb_splits \
13 | --sup_weight 0.35 \
14 | --weight_decay 5e-5 \
15 | --transform 'imagenet' \
16 | --lr 0.1 \
17 | --eval_funcs 'v2' \
18 | --warmup_teacher_temp 0.07 \
19 | --teacher_temp 0.04 \
20 | --warmup_teacher_temp_epochs 30 \
21 | --memax_weight 1 \
22 | --exp_name aircraft_simgcd
23 |
--------------------------------------------------------------------------------
/scripts/run_cifar10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | CUDA_VISIBLE_DEVICES=0 python train.py \
7 | --dataset_name 'cifar10' \
8 | --batch_size 128 \
9 | --grad_from_block 11 \
10 | --epochs 200 \
11 | --num_workers 8 \
12 | --use_ssb_splits \
13 | --sup_weight 0.35 \
14 | --weight_decay 5e-5 \
15 | --transform 'imagenet' \
16 | --lr 0.1 \
17 | --eval_funcs 'v2' \
18 | --warmup_teacher_temp 0.07 \
19 | --teacher_temp 0.04 \
20 | --warmup_teacher_temp_epochs 30 \
21 | --memax_weight 1 \
22 | --exp_name cifar10_simgcd
23 |
--------------------------------------------------------------------------------
/scripts/run_cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | CUDA_VISIBLE_DEVICES=0 python train.py \
7 | --dataset_name 'cifar100' \
8 | --batch_size 128 \
9 | --grad_from_block 11 \
10 | --epochs 200 \
11 | --num_workers 8 \
12 | --use_ssb_splits \
13 | --sup_weight 0.35 \
14 | --weight_decay 5e-5 \
15 | --transform 'imagenet' \
16 | --lr 0.1 \
17 | --eval_funcs 'v2' \
18 | --warmup_teacher_temp 0.07 \
19 | --teacher_temp 0.04 \
20 | --warmup_teacher_temp_epochs 30 \
21 | --memax_weight 4 \
22 | --exp_name cifar100_simgcd
23 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -----------------
2 | # DATASET ROOTS
3 | # -----------------
4 | cifar_10_root = '${DATASET_DIR}/cifar10'
5 | cifar_100_root = '${DATASET_DIR}/cifar100'
6 | cub_root = '${DATASET_DIR}/cub'
7 | aircraft_root = '${DATASET_DIR}/fgvc-aircraft-2013b'
8 | car_root = '${DATASET_DIR}/cars'
9 | herbarium_dataroot = '${DATASET_DIR}/herbarium_19'
10 | imagenet_root = '${DATASET_DIR}/ImageNet'
11 |
12 | # OSR Split dir
13 | osr_split_dir = 'data/ssb_splits'
14 |
15 | # -----------------
16 | # OTHER PATHS
17 | # -----------------
18 | exp_root = 'dev_outputs' # All logs and checkpoints will be saved here
--------------------------------------------------------------------------------
/scripts/run_herb19.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | CUDA_VISIBLE_DEVICES=0 python train.py \
7 | --dataset_name 'herbarium_19' \
8 | --batch_size 128 \
9 | --grad_from_block 11 \
10 | --epochs 200 \
11 | --num_workers 8 \
12 | --use_ssb_splits \
13 | --sup_weight 0.35 \
14 | --weight_decay 5e-5 \
15 | --transform 'imagenet' \
16 | --lr 0.1 \
17 | --eval_funcs 'v2' 'v2b' \
18 | --warmup_teacher_temp 0.07 \
19 | --teacher_temp 0.04 \
20 | --warmup_teacher_temp_epochs 30 \
21 | --memax_weight 1 \
22 | --exp_name herb19_simgcd
23 |
--------------------------------------------------------------------------------
/scripts/run_imagenet100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | CUDA_VISIBLE_DEVICES=0 python train.py \
7 | --dataset_name 'imagenet_100' \
8 | --batch_size 128 \
9 | --grad_from_block 11 \
10 | --epochs 200 \
11 | --num_workers 8 \
12 | --use_ssb_splits \
13 | --sup_weight 0.35 \
14 | --weight_decay 5e-5 \
15 | --transform 'imagenet' \
16 | --lr 0.1 \
17 | --eval_funcs 'v2' \
18 | --warmup_teacher_temp 0.07 \
19 | --teacher_temp 0.04 \
20 | --warmup_teacher_temp_epochs 30 \
21 | --memax_weight 1 \
22 | --exp_name imagenet100_simgcd
23 |
--------------------------------------------------------------------------------
/scripts/run_imagenet1k.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | set -x
5 |
6 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --master_port 12348 --nproc_per_node=8 train_mp.py \
7 | --dataset_name 'imagenet_1k' \
8 | --batch_size 128 \
9 | --grad_from_block 11 \
10 | --epochs 200 \
11 | --num_workers 8 \
12 | --use_ssb_splits \
13 | --sup_weight 0.35 \
14 | --weight_decay 5e-5 \
15 | --transform 'imagenet' \
16 | --lr 0.1 \
17 | --eval_funcs 'v2' \
18 | --warmup_teacher_temp 0.07 \
19 | --teacher_temp 0.04 \
20 | --warmup_teacher_temp_epochs 30 \
21 | --memax_weight 1 \
22 | --exp_name imagenet1k_simgcd \
23 | --print_freq 100
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Xin Wen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/data/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 |
3 | import torch
4 |
5 | def get_transform(transform_type='imagenet', image_size=32, args=None):
6 |
7 | if transform_type == 'imagenet':
8 |
9 | mean = (0.485, 0.456, 0.406)
10 | std = (0.229, 0.224, 0.225)
11 | interpolation = args.interpolation
12 | crop_pct = args.crop_pct
13 |
14 | train_transform = transforms.Compose([
15 | transforms.Resize(int(image_size / crop_pct), interpolation),
16 | transforms.RandomCrop(image_size),
17 | transforms.RandomHorizontalFlip(p=0.5),
18 |                 transforms.ColorJitter(),  # NOTE: default arguments make this an identity transform; pass brightness/contrast/saturation/hue to enable jitter
19 | transforms.ToTensor(),
20 | transforms.Normalize(
21 | mean=torch.tensor(mean),
22 | std=torch.tensor(std))
23 | ])
24 |
25 | test_transform = transforms.Compose([
26 | transforms.Resize(int(image_size / crop_pct), interpolation),
27 | transforms.CenterCrop(image_size),
28 | transforms.ToTensor(),
29 | transforms.Normalize(
30 | mean=torch.tensor(mean),
31 | std=torch.tensor(std))
32 | ])
33 |
34 | else:
35 |
36 | raise NotImplementedError
37 |
38 | return (train_transform, test_transform)
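39 | 
40 | # Usage sketch (illustrative): `args` must provide `interpolation` and `crop_pct`,
41 | # which the training entry point is expected to set, e.g.
42 | #     train_transform, test_transform = get_transform('imagenet', image_size=224, args=args)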
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset
3 |
4 | def subsample_instances(dataset, prop_indices_to_subsample=0.8):
5 |
6 |     np.random.seed(0)  # fixed seed: the labelled/unlabelled split is reproducible across calls
7 | subsample_indices = np.random.choice(range(len(dataset)), replace=False,
8 | size=(int(prop_indices_to_subsample * len(dataset)),))
9 |
10 | return subsample_indices
11 |
12 | class MergedDataset(Dataset):
13 |
14 | """
15 | Takes two datasets (labelled_dataset, unlabelled_dataset) and merges them
16 | Allows you to iterate over them in parallel
17 | """
18 |
19 | def __init__(self, labelled_dataset, unlabelled_dataset):
20 |
21 | self.labelled_dataset = labelled_dataset
22 | self.unlabelled_dataset = unlabelled_dataset
23 | self.target_transform = None
24 |
25 | def __getitem__(self, item):
26 |
27 | if item < len(self.labelled_dataset):
28 | img, label, uq_idx = self.labelled_dataset[item]
29 | labeled_or_not = 1
30 |
31 | else:
32 |
33 | img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)]
34 | labeled_or_not = 0
35 |
36 |
37 | return img, label, uq_idx, np.array([labeled_or_not])
38 |
39 | def __len__(self):
40 | return len(self.unlabelled_dataset) + len(self.labelled_dataset)
41 |
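42 | # Usage sketch (hypothetical variables): wrap the labelled/unlabelled train splits
43 | # produced by data/get_datasets.py; each item is (img, label, uq_idx, labeled_or_not).
44 | #     merged = MergedDataset(labelled_dataset=lab_ds, unlabelled_dataset=unlab_ds)
45 | #     img, label, uq_idx, labeled_or_not = merged[0]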
--------------------------------------------------------------------------------
/util/general_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import inspect
4 |
5 | from datetime import datetime
6 | from loguru import logger
7 |
8 | class AverageMeter(object):
9 | """Computes and stores the average and current value"""
10 | def __init__(self):
11 | self.reset()
12 |
13 | def reset(self):
14 |
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 |
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 |
27 |
28 | def init_experiment(args, runner_name=None, exp_id=None):
29 | # Get filepath of calling script
30 | if runner_name is None:
31 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:]
32 |
33 | root_dir = os.path.join(args.exp_root, *runner_name)
34 |
35 | if not os.path.exists(root_dir):
36 | os.makedirs(root_dir)
37 |
38 | # Either generate a unique experiment ID, or use one which is passed
39 | if exp_id is None:
40 |
41 | if args.exp_name is None:
42 | raise ValueError("Need to specify the experiment name")
43 | # Unique identifier for experiment
44 | now = '{}_({:02d}.{:02d}.{}_|_'.format(args.exp_name, datetime.now().day, datetime.now().month, datetime.now().year) + \
45 | datetime.now().strftime("%S.%f")[:-3] + ')'
46 |
47 | log_dir = os.path.join(root_dir, 'log', now)
48 | while os.path.exists(log_dir):
49 | now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \
50 | datetime.now().strftime("%S.%f")[:-3] + ')'
51 |
52 | log_dir = os.path.join(root_dir, 'log', now)
53 |
54 | else:
55 |
56 | log_dir = os.path.join(root_dir, 'log', f'{exp_id}')
57 |
58 | if not os.path.exists(log_dir):
59 | os.makedirs(log_dir)
60 |
61 |
62 | logger.add(os.path.join(log_dir, 'log.txt'))
63 | args.logger = logger
64 | args.log_dir = log_dir
65 |
66 | # Instantiate directory to save models to
67 | model_root_dir = os.path.join(args.log_dir, 'checkpoints')
68 | if not os.path.exists(model_root_dir):
69 | os.mkdir(model_root_dir)
70 |
71 | args.model_dir = model_root_dir
72 | args.model_path = os.path.join(args.model_dir, 'model.pt')
73 |
74 | print(f'Experiment saved to: {args.log_dir}')
75 |
76 | hparam_dict = {}
77 |
78 | for k, v in vars(args).items():
79 | if isinstance(v, (int, float, str, bool, torch.Tensor)):
80 | hparam_dict[k] = v
81 |
82 | print(runner_name)
83 | print(args)
84 |
85 | return args
86 |
87 |
88 | class DistributedWeightedSampler(torch.utils.data.distributed.DistributedSampler):
89 |
90 | def __init__(self, dataset, weights, num_samples, num_replicas=None, rank=None,
91 | replacement=True, generator=None):
92 | super(DistributedWeightedSampler, self).__init__(dataset, num_replicas, rank)
93 | if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
94 | num_samples <= 0:
95 | raise ValueError("num_samples should be a positive integer "
96 | "value, but got num_samples={}".format(num_samples))
97 | if not isinstance(replacement, bool):
98 | raise ValueError("replacement should be a boolean value, but got "
99 | "replacement={}".format(replacement))
100 | self.weights = torch.as_tensor(weights, dtype=torch.double)
101 | self.num_samples = num_samples
102 | self.replacement = replacement
103 | self.generator = generator
104 | self.weights = self.weights[self.rank::self.num_replicas]
105 | self.num_samples = self.num_samples // self.num_replicas
106 |
107 | def __iter__(self):
108 |         rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)  # sample within this rank's stripe of the weights
109 |         rand_tensor = self.rank + rand_tensor * self.num_replicas  # map stripe-local indices back to global dataset indices
110 | yield from iter(rand_tensor.tolist())
111 |
112 | def __len__(self):
113 | return self.num_samples
114 |
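115 | # Usage sketch (hypothetical weights): each rank keeps a disjoint stripe of the
116 | # weight vector and draws num_samples // world_size indices per epoch, e.g.
117 | #     sampler = DistributedWeightedSampler(train_set, weights, num_samples=len(train_set))
118 | #     loader = torch.utils.data.DataLoader(train_set, batch_size=128, sampler=sampler)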
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Parametric Classification for Generalized Category Discovery: A Baseline Study
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | Parametric Classification for Generalized Category Discovery: A Baseline Study (ICCV 2023)
12 | By Xin Wen*, Bingchen Zhao*, and Xiaojuan Qi.
16 |
17 |
18 | ![teaser](assets/teaser.jpg)
19 |
20 | Generalized Category Discovery (GCD) aims to discover novel categories in unlabelled datasets using knowledge learned from labelled samples.
21 | Previous studies argued that parametric classifiers are prone to overfitting to seen categories, and endorsed using a non-parametric classifier formed with semi-supervised $k$-means.
22 |
23 | However, in this study, we investigate the failure of parametric classifiers, verify the effectiveness of previous design choices when high-quality supervision is available, and identify unreliable pseudo-labels as a key problem. We demonstrate that two prediction biases exist: the classifier tends to predict seen classes more often, and produces an imbalanced distribution across seen and novel categories.
24 | Based on these findings, we propose a simple yet effective parametric classification method that benefits from entropy regularisation, achieves state-of-the-art performance on multiple GCD benchmarks and shows strong robustness to unknown class numbers.
25 | We hope the investigation and proposed simple framework can serve as a strong baseline to facilitate future studies in this field.
26 |
27 | ## Running
28 |
29 | ### Dependencies
30 |
31 | ```
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ### Config
36 |
37 | Set paths to datasets and desired log directories in ```config.py```
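A minimal example, assuming your datasets live under `/path/to/datasets` (illustrative paths only):

```
# config.py (excerpt)
cifar_10_root = '/path/to/datasets/cifar10'
cub_root = '/path/to/datasets/cub'
exp_root = 'dev_outputs'  # all logs and checkpoints are saved here
```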
38 |
39 |
40 | ### Datasets
41 |
42 | We use fine-grained benchmarks in this paper, including:
43 |
44 | * [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)
45 |
46 | We also use generic object recognition datasets, including:
47 |
48 | * [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet-100/1K](https://image-net.org/download.php)
49 |
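The loaders in `data/` expect roughly the following layout under the roots set in `config.py` (a sketch inferred from the dataset classes; the FGVC-Aircraft and ImageNet internals are not shown in this listing, so only their roots appear):

```
${DATASET_DIR}/
├── cifar10/              # torchvision CIFAR-10 download
├── cifar100/             # torchvision CIFAR-100 download
├── cub/CUB_200_2011/     # images/, images.txt, train_test_split.txt, ...
├── cars/                 # devkit/, cars_train/, cars_test/
├── fgvc-aircraft-2013b/
├── herbarium_19/         # small-train/, small-validation/
└── ImageNet/
```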
50 |
51 | ### Scripts
52 |
53 | **Train the model**:
54 |
55 | ```
56 | bash scripts/run_${DATASET_NAME}.sh
57 | ```
58 |
59 | We found that selecting the model by 'Old' class accuracy can lead to over-fitting, and since 'New' class labels on a held-out validation set should be assumed unavailable, we suggest skipping model selection altogether and simply using the last-epoch model.
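
Per `util/general_utils.py`, the final weights are written to `<log_dir>/checkpoints/model.pt` under `exp_root`. A minimal loading sketch (the log directory is timestamped per run, so substitute your own path; the exact structure of the saved object is defined in `train.py`, which is not shown here):

```
import torch

ckpt = torch.load('dev_outputs/<runner>/log/<run_id>/checkpoints/model.pt', map_location='cpu')
```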
60 |
61 | ## Results
62 | Accuracy (%) on 'All', 'Old', and 'New' classes, as reported in the paper (averaged over 3 runs) and reproduced with the current codebase (averaged over 5 runs):
63 |
64 | | Dataset | All (Paper, 3 runs) | Old (Paper) | New (Paper) | All (GitHub, 5 runs) | Old (GitHub) | New (GitHub) |
|---|---|---|---|---|---|---|
| CIFAR10 | 97.1±0.0 | 95.1±0.1 | 98.1±0.1 | 97.0±0.1 | 93.9±0.1 | 98.5±0.1 |
| CIFAR100 | 80.1±0.9 | 81.2±0.4 | 77.8±2.0 | 79.8±0.6 | 81.1±0.5 | 77.4±2.5 |
| ImageNet-100 | 83.0±1.2 | 93.1±0.2 | 77.9±1.9 | 83.6±1.4 | 92.4±0.1 | 79.1±2.2 |
| ImageNet-1K | 57.1±0.1 | 77.3±0.1 | 46.9±0.2 | 57.0±0.4 | 77.1±0.1 | 46.9±0.5 |
| CUB | 60.3±0.1 | 65.6±0.9 | 57.7±0.4 | 61.5±0.5 | 65.7±0.5 | 59.4±0.8 |
| Stanford Cars | 53.8±2.2 | 71.9±1.7 | 45.0±2.4 | 53.4±1.6 | 71.5±1.6 | 44.6±1.7 |
| FGVC-Aircraft | 54.2±1.9 | 59.1±1.2 | 51.8±2.3 | 54.3±0.7 | 59.4±0.4 | 51.7±1.2 |
| Herbarium 19 | 44.0±0.4 | 58.0±0.4 | 36.4±0.8 | 44.2±0.2 | 57.6±0.6 | 37.0±0.4 |
65 |
66 | ## Citing this work
67 |
68 | If you find this repo useful for your research, please consider citing our paper:
69 |
70 | ```
71 | @inproceedings{wen2023simgcd,
72 | author = {Wen, Xin and Zhao, Bingchen and Qi, Xiaojuan},
73 | title = {Parametric Classification for Generalized Category Discovery: A Baseline Study},
74 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
75 | year = {2023},
76 | pages = {16590-16600}
77 | }
78 | ```
79 |
80 | ## Acknowledgements
81 |
82 | The codebase is largely built on this repo: https://github.com/sgvaze/generalized-category-discovery.
83 |
84 | ## License
85 |
86 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
87 |
--------------------------------------------------------------------------------
/data/get_datasets.py:
--------------------------------------------------------------------------------
1 | from data.data_utils import MergedDataset
2 |
3 | from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets
4 | from data.herbarium_19 import get_herbarium_datasets
5 | from data.stanford_cars import get_scars_datasets
6 | from data.imagenet import get_imagenet_100_datasets, get_imagenet_1k_datasets
7 | from data.cub import get_cub_datasets
8 | from data.fgvc_aircraft import get_aircraft_datasets
9 |
10 | from copy import deepcopy
11 | import pickle
12 | import os
13 |
14 | from config import osr_split_dir
15 |
16 |
17 | get_dataset_funcs = {
18 | 'cifar10': get_cifar_10_datasets,
19 | 'cifar100': get_cifar_100_datasets,
20 | 'imagenet_100': get_imagenet_100_datasets,
21 | 'imagenet_1k': get_imagenet_1k_datasets,
22 | 'herbarium_19': get_herbarium_datasets,
23 | 'cub': get_cub_datasets,
24 | 'aircraft': get_aircraft_datasets,
25 | 'scars': get_scars_datasets
26 | }
27 |
28 |
29 | def get_datasets(dataset_name, train_transform, test_transform, args):
30 |
31 |     """
32 |     :return: train_dataset: MergedDataset which concatenates the labelled and unlabelled train splits,
33 |              test_dataset,
34 |              unlabelled_train_examples_test (the unlabelled train split with the test-time transform),
35 |              datasets (dict of all underlying splits)
36 |     """
37 |
38 |     # Validate the requested dataset name
39 |     if dataset_name not in get_dataset_funcs.keys():
40 |         raise ValueError(f'Unknown dataset: {dataset_name}')
41 |
42 | # Get datasets
43 | get_dataset_f = get_dataset_funcs[dataset_name]
44 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform,
45 | train_classes=args.train_classes,
46 | prop_train_labels=args.prop_train_labels,
47 | split_train_val=False)
48 | # Set target transforms:
49 | target_transform_dict = {}
50 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)):
51 | target_transform_dict[cls] = i
52 | target_transform = lambda x: target_transform_dict[x]
53 |
54 |     for ds_name, ds in datasets.items():
55 |         if ds is not None:
56 |             ds.target_transform = target_transform
57 |
58 | # Train split (labelled and unlabelled classes) for training
59 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']),
60 | unlabelled_dataset=deepcopy(datasets['train_unlabelled']))
61 |
62 | test_dataset = datasets['test']
63 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled'])
64 | unlabelled_train_examples_test.transform = test_transform
65 |
66 | return train_dataset, test_dataset, unlabelled_train_examples_test, datasets
67 |
68 |
69 | def get_class_splits(args):
70 |
71 | # For FGVC datasets, optionally return bespoke splits
72 | if args.dataset_name in ('scars', 'cub', 'aircraft'):
73 | if hasattr(args, 'use_ssb_splits'):
74 | use_ssb_splits = args.use_ssb_splits
75 | else:
76 | use_ssb_splits = False
77 |
78 | # -------------
79 | # GET CLASS SPLITS
80 | # -------------
81 | if args.dataset_name == 'cifar10':
82 |
83 | args.image_size = 32
84 | args.train_classes = range(5)
85 | args.unlabeled_classes = range(5, 10)
86 |
87 | elif args.dataset_name == 'cifar100':
88 |
89 | args.image_size = 32
90 | args.train_classes = range(80)
91 | args.unlabeled_classes = range(80, 100)
92 |
93 | elif args.dataset_name == 'herbarium_19':
94 |
95 | args.image_size = 224
96 | herb_path_splits = os.path.join(osr_split_dir, 'herbarium_19_class_splits.pkl')
97 |
98 | with open(herb_path_splits, 'rb') as handle:
99 | class_splits = pickle.load(handle)
100 |
101 | args.train_classes = class_splits['Old']
102 | args.unlabeled_classes = class_splits['New']
103 |
104 | elif args.dataset_name == 'imagenet_100':
105 |
106 | args.image_size = 224
107 | args.train_classes = range(50)
108 | args.unlabeled_classes = range(50, 100)
109 |
110 | elif args.dataset_name == 'imagenet_1k':
111 |
112 | args.image_size = 224
113 | args.train_classes = range(500)
114 | args.unlabeled_classes = range(500, 1000)
115 |
116 | elif args.dataset_name == 'scars':
117 |
118 | args.image_size = 224
119 |
120 | if use_ssb_splits:
121 |
122 | split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl')
123 | with open(split_path, 'rb') as handle:
124 | class_info = pickle.load(handle)
125 |
126 | args.train_classes = class_info['known_classes']
127 | open_set_classes = class_info['unknown_classes']
128 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
129 |
130 | else:
131 |
132 | args.train_classes = range(98)
133 | args.unlabeled_classes = range(98, 196)
134 |
135 | elif args.dataset_name == 'aircraft':
136 |
137 | args.image_size = 224
138 | if use_ssb_splits:
139 |
140 | split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl')
141 | with open(split_path, 'rb') as handle:
142 | class_info = pickle.load(handle)
143 |
144 | args.train_classes = class_info['known_classes']
145 | open_set_classes = class_info['unknown_classes']
146 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
147 |
148 | else:
149 |
150 | args.train_classes = range(50)
151 | args.unlabeled_classes = range(50, 100)
152 |
153 | elif args.dataset_name == 'cub':
154 |
155 | args.image_size = 224
156 |
157 | if use_ssb_splits:
158 |
159 | split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl')
160 | with open(split_path, 'rb') as handle:
161 | class_info = pickle.load(handle)
162 |
163 | args.train_classes = class_info['known_classes']
164 | open_set_classes = class_info['unknown_classes']
165 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
166 |
167 | else:
168 |
169 | args.train_classes = range(100)
170 | args.unlabeled_classes = range(100, 200)
171 |
172 | else:
173 |
174 | raise NotImplementedError
175 |
176 | return args
177 |
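178 | # Usage sketch (hypothetical argparse namespace): get_class_splits populates
179 | # args.train_classes / args.unlabeled_classes, which get_datasets relies on, e.g.
180 | #     args = get_class_splits(args)
181 | #     train_ds, test_ds, unlab_eval_ds, all_ds = get_datasets('cub', train_t, test_t, args)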
--------------------------------------------------------------------------------
/data/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 | from scipy import io as mat_io
6 |
7 | from torchvision.datasets.folder import default_loader
8 | from torch.utils.data import Dataset
9 |
10 | from data.data_utils import subsample_instances
11 | from config import car_root
12 |
13 | class CarsDataset(Dataset):
14 | """
15 | Cars Dataset
16 | """
17 | def __init__(self, train=True, limit=0, data_dir=car_root, transform=None):
18 |
19 | metas = os.path.join(data_dir, 'devkit/cars_train_annos.mat') if train else os.path.join(data_dir, 'devkit/cars_test_annos_withlabels.mat')
20 | data_dir = os.path.join(data_dir, 'cars_train/') if train else os.path.join(data_dir, 'cars_test/')
21 |
22 | self.loader = default_loader
23 | self.data_dir = data_dir
24 | self.data = []
25 | self.target = []
26 | self.train = train
27 |
28 | self.transform = transform
29 |
30 |         if not isinstance(metas, str):
31 |             raise TypeError("metas must be a string path to the annotations .mat file")
32 | labels_meta = mat_io.loadmat(metas)
33 |
34 | for idx, img_ in enumerate(labels_meta['annotations'][0]):
35 | if limit:
36 | if idx > limit:
37 | break
38 |
39 | # self.data.append(img_resized)
40 | self.data.append(data_dir + img_[5][0])
41 | # if self.mode == 'train':
42 | self.target.append(img_[4][0][0])
43 |
44 | self.uq_idxs = np.array(range(len(self)))
45 | self.target_transform = None
46 |
47 | def __getitem__(self, idx):
48 |
49 | image = self.loader(self.data[idx])
50 | target = self.target[idx] - 1
51 |
52 | if self.transform is not None:
53 | image = self.transform(image)
54 |
55 | if self.target_transform is not None:
56 | target = self.target_transform(target)
57 |
58 | idx = self.uq_idxs[idx]
59 |
60 | return image, target, idx
61 |
62 | def __len__(self):
63 | return len(self.data)
64 |
65 |
66 | def subsample_dataset(dataset, idxs):
67 |
68 | dataset.data = np.array(dataset.data)[idxs].tolist()
69 | dataset.target = np.array(dataset.target)[idxs].tolist()
70 | dataset.uq_idxs = dataset.uq_idxs[idxs]
71 |
72 | return dataset
73 |
74 |
75 | def subsample_classes(dataset, include_classes=range(160)):
76 |
77 | include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195
78 | cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars]
79 |
80 | target_xform_dict = {}
81 | for i, k in enumerate(include_classes):
82 | target_xform_dict[k] = i
83 |
84 | dataset = subsample_dataset(dataset, cls_idxs)
85 |
86 | # dataset.target_transform = lambda x: target_xform_dict[x]
87 |
88 | return dataset
89 |
90 | def get_train_val_indices(train_dataset, val_split=0.2):
91 |
92 | train_classes = np.unique(train_dataset.target)
93 |
94 | # Get train/test indices
95 | train_idxs = []
96 | val_idxs = []
97 | for cls in train_classes:
98 |
99 |         cls_idxs = np.where(np.array(train_dataset.target) == cls)[0]  # target is a list; cast so the comparison broadcasts
100 |
101 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
102 | t_ = [x for x in cls_idxs if x not in v_]
103 |
104 | train_idxs.extend(t_)
105 | val_idxs.extend(v_)
106 |
107 | return train_idxs, val_idxs
108 |
109 |
110 | def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
111 | split_train_val=False, seed=0):
112 |
113 | np.random.seed(seed)
114 |
115 | # Init entire training set
116 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, train=True)
117 |
118 | # Get labelled training set which has subsampled classes, then subsample some indices from that
119 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
120 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
121 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
122 |
123 | # Split into training and validation sets
124 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
125 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
126 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
127 | val_dataset_labelled_split.transform = test_transform
128 |
129 | # Get unlabelled data
130 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
131 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
132 |
133 | # Get test set for all classes
134 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, train=False)
135 |
136 | # Either split train into train and val or use test set as val
137 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
138 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
139 |
140 | all_datasets = {
141 | 'train_labelled': train_dataset_labelled,
142 | 'train_unlabelled': train_dataset_unlabelled,
143 | 'val': val_dataset_labelled,
144 | 'test': test_dataset,
145 | }
146 |
147 | return all_datasets
148 |
149 | if __name__ == '__main__':
150 |
151 | x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False)
152 |
153 | print('Printing lens...')
154 | for k, v in x.items():
155 | if v is not None:
156 | print(f'{k}: {len(v)}')
157 |
158 | print('Printing labelled and unlabelled overlap...')
159 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
160 | print('Printing total instances in train...')
161 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
162 |
163 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}')
164 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].target))}')
165 | print(f'Len labelled set: {len(x["train_labelled"])}')
166 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
/data/herbarium_19.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torchvision
4 | import numpy as np
5 | from copy import deepcopy
6 |
7 | from data.data_utils import subsample_instances
8 | from config import herbarium_dataroot
9 |
10 | class HerbariumDataset19(torchvision.datasets.ImageFolder):
11 |
12 | def __init__(self, *args, **kwargs):
13 |
14 | # Process metadata json for training images into a DataFrame
15 | super().__init__(*args, **kwargs)
16 |
17 | self.uq_idxs = np.array(range(len(self)))
18 |
19 | def __getitem__(self, idx):
20 |
21 | img, label = super().__getitem__(idx)
22 | uq_idx = self.uq_idxs[idx]
23 |
24 | return img, label, uq_idx
25 |
26 |
27 | def subsample_dataset(dataset, idxs):
28 |
29 | mask = np.zeros(len(dataset)).astype('bool')
30 | mask[idxs] = True
31 |
32 | dataset.samples = np.array(dataset.samples)[mask].tolist()
33 | dataset.targets = np.array(dataset.targets)[mask].tolist()
34 |
35 | dataset.uq_idxs = dataset.uq_idxs[mask]
36 |
37 | dataset.samples = [[x[0], int(x[1])] for x in dataset.samples]
38 | dataset.targets = [int(x) for x in dataset.targets]
39 |
40 | return dataset
41 |
42 |
43 | def subsample_classes(dataset, include_classes=range(250)):
44 |
45 | cls_idxs = [x for x, l in enumerate(dataset.targets) if l in include_classes]
46 |
47 | target_xform_dict = {}
48 | for i, k in enumerate(include_classes):
49 | target_xform_dict[k] = i
50 |
51 | dataset = subsample_dataset(dataset, cls_idxs)
52 |
53 | dataset.target_transform = lambda x: target_xform_dict[x]
54 |
55 | return dataset
56 |
57 |
58 | def get_train_val_indices(train_dataset, val_instances_per_class=5):
59 |
60 | train_classes = list(set(train_dataset.targets))
61 |
62 | # Get train/test indices
63 | train_idxs = []
64 | val_idxs = []
65 | for cls in train_classes:
66 |
67 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
68 |
69 | # Have a balanced test set
70 | v_ = np.random.choice(cls_idxs, replace=False, size=(val_instances_per_class,))
71 | t_ = [x for x in cls_idxs if x not in v_]
72 |
73 | train_idxs.extend(t_)
74 | val_idxs.extend(v_)
75 |
76 | return train_idxs, val_idxs
77 |
78 |
79 | def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8,
80 | seed=0, split_train_val=False):
81 |
82 | np.random.seed(seed)
83 |
84 | # Init entire training set
85 | train_dataset = HerbariumDataset19(transform=train_transform,
86 | root=os.path.join(herbarium_dataroot, 'small-train'))
87 |
88 | # Get labelled training set which has subsampled classes, then subsample some indices from that
89 | # TODO: Subsampling unlabelled set in uniform random fashion from training data, will contain many instances of dominant class
90 | train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes)
91 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
92 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
93 |
94 | # Split into training and validation sets
95 | if split_train_val:
96 |
97 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled,
98 | val_instances_per_class=5)
99 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
100 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
101 | val_dataset_labelled_split.transform = test_transform
102 |
103 | else:
104 |
105 | train_dataset_labelled_split, val_dataset_labelled_split = None, None
106 |
107 | # Get unlabelled data
108 | unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs)
109 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices)))
110 |
111 | # Get test dataset
112 | test_dataset = HerbariumDataset19(transform=test_transform,
113 | root=os.path.join(herbarium_dataroot, 'small-validation'))
114 |
115 | # Transform dict
116 | unlabelled_classes = list(set(train_dataset.targets) - set(train_classes))
117 | target_xform_dict = {}
118 | for i, k in enumerate(list(train_classes) + unlabelled_classes):
119 | target_xform_dict[k] = i
120 |
121 | test_dataset.target_transform = lambda x: target_xform_dict[x]
122 | train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x]
123 |
124 | # Either split train into train and val or use test set as val
125 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
126 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
127 |
128 | all_datasets = {
129 | 'train_labelled': train_dataset_labelled,
130 | 'train_unlabelled': train_dataset_unlabelled,
131 | 'val': val_dataset_labelled,
132 | 'test': test_dataset,
133 | }
134 |
135 | return all_datasets
136 |
137 | if __name__ == '__main__':
138 |
139 | np.random.seed(0)
140 |     train_classes = np.random.choice(range(683), size=683 // 2, replace=False)
141 |
142 | x = get_herbarium_datasets(None, None, train_classes=train_classes,
143 | prop_train_labels=0.5)
144 |
145 | assert set(x['train_unlabelled'].targets) == set(range(683))
146 |
147 | print('Printing lens...')
148 | for k, v in x.items():
149 | if v is not None:
150 | print(f'{k}: {len(v)}')
151 |
152 | print('Printing labelled and unlabelled overlap...')
153 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
154 | print('Printing total instances in train...')
155 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
156 | print('Printing number of labelled classes...')
157 | print(len(set(x['train_labelled'].targets)))
158 | print('Printing total number of classes...')
159 | print(len(set(x['train_unlabelled'].targets)))
160 |
161 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
162 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
163 | print(f'Len labelled set: {len(x["train_labelled"])}')
164 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
/util/cluster_and_log_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import numpy as np
4 | from scipy.optimize import linear_sum_assignment as linear_assignment
5 |
6 |
7 | def all_sum_item(item):
8 | item = torch.tensor(item).cuda()
9 | dist.all_reduce(item)
10 | return item.item()
11 |
12 | def split_cluster_acc_v2(y_true, y_pred, mask):
13 |     """
14 |     Calculate clustering accuracy. Requires scipy (for linear_sum_assignment).
15 |     First computes a linear (Hungarian) assignment between predicted clusters and ground-truth classes on all data, then reports accuracy on the old/new subsets
16 | 
17 |     # Arguments
18 |         mask: Which instances come from old classes (True) and which ones come from new classes (False)
19 |         y_true: true labels, numpy.array with shape `(n_samples,)`
20 |         y_pred: predicted labels, numpy.array with shape `(n_samples,)`
21 | 
22 |     # Return
23 |         (total_acc, old_acc, new_acc), each in [0, 1]
24 |     """
25 | y_true = y_true.astype(int)
26 |
27 | old_classes_gt = set(y_true[mask])
28 | new_classes_gt = set(y_true[~mask])
29 |
30 | assert y_pred.size == y_true.size
31 | D = max(y_pred.max(), y_true.max()) + 1
32 | w = np.zeros((D, D), dtype=int)
33 | for i in range(y_pred.size):
34 | w[y_pred[i], y_true[i]] += 1
35 |
36 | ind = linear_assignment(w.max() - w)
37 | ind = np.vstack(ind).T
38 |
39 | ind_map = {j: i for i, j in ind}
40 | total_acc = sum([w[i, j] for i, j in ind])
41 | total_instances = y_pred.size
42 | try:
43 | if dist.get_world_size() > 0:
44 | total_acc = all_sum_item(total_acc)
45 | total_instances = all_sum_item(total_instances)
46 | except:
47 | pass
48 | total_acc /= total_instances
49 |
50 | old_acc = 0
51 | total_old_instances = 0
52 | for i in old_classes_gt:
53 | old_acc += w[ind_map[i], i]
54 | total_old_instances += sum(w[:, i])
55 |
56 | try:
57 | if dist.get_world_size() > 0:
58 | old_acc = all_sum_item(old_acc)
59 | total_old_instances = all_sum_item(total_old_instances)
60 | except:
61 | pass
62 | old_acc /= total_old_instances
63 |
64 | new_acc = 0
65 | total_new_instances = 0
66 | for i in new_classes_gt:
67 | new_acc += w[ind_map[i], i]
68 | total_new_instances += sum(w[:, i])
69 |
70 | try:
71 | if dist.get_world_size() > 0:
72 | new_acc = all_sum_item(new_acc)
73 | total_new_instances = all_sum_item(total_new_instances)
74 | except:
75 | pass
76 | new_acc /= total_new_instances
77 |
78 | return total_acc, old_acc, new_acc
79 |
80 |
81 | def split_cluster_acc_v2_balanced(y_true, y_pred, mask):
82 |     """
83 |     Calculate class-balanced clustering accuracy. Requires scipy (for linear_sum_assignment).
84 |     As split_cluster_acc_v2, but accuracies are computed per class and then averaged, so every class contributes equally
85 | 
86 |     # Arguments
87 |         mask: Which instances come from old classes (True) and which ones come from new classes (False)
88 |         y_true: true labels, numpy.array with shape `(n_samples,)`
89 |         y_pred: predicted labels, numpy.array with shape `(n_samples,)`
90 | 
91 |     # Return
92 |         (total_acc, old_acc, new_acc), each in [0, 1]
93 |     """
94 | y_true = y_true.astype(int)
95 |
96 | old_classes_gt = set(y_true[mask])
97 | new_classes_gt = set(y_true[~mask])
98 |
99 | assert y_pred.size == y_true.size
100 | D = max(y_pred.max(), y_true.max()) + 1
101 | w = np.zeros((D, D), dtype=int)
102 | for i in range(y_pred.size):
103 | w[y_pred[i], y_true[i]] += 1
104 |
105 | ind = linear_assignment(w.max() - w)
106 | ind = np.vstack(ind).T
107 |
108 | ind_map = {j: i for i, j in ind}
109 |
110 | old_acc = np.zeros(len(old_classes_gt))
111 | total_old_instances = np.zeros(len(old_classes_gt))
112 | for idx, i in enumerate(old_classes_gt):
113 | old_acc[idx] += w[ind_map[i], i]
114 | total_old_instances[idx] += sum(w[:, i])
115 |
116 | new_acc = np.zeros(len(new_classes_gt))
117 | total_new_instances = np.zeros(len(new_classes_gt))
118 | for idx, i in enumerate(new_classes_gt):
119 | new_acc[idx] += w[ind_map[i], i]
120 | total_new_instances[idx] += sum(w[:, i])
121 |
122 |     try:
123 |         if dist.get_world_size() > 0:
124 |             old_acc, new_acc = torch.from_numpy(old_acc).cuda(), torch.from_numpy(new_acc).cuda()
125 |             total_old_instances, total_new_instances = torch.from_numpy(total_old_instances).cuda(), torch.from_numpy(total_new_instances).cuda()  # counts must also be tensors before all_reduce
126 |             dist.all_reduce(old_acc), dist.all_reduce(new_acc), dist.all_reduce(total_old_instances), dist.all_reduce(total_new_instances)
127 |             old_acc, new_acc = old_acc.cpu().numpy(), new_acc.cpu().numpy()
128 |             total_old_instances, total_new_instances = total_old_instances.cpu().numpy(), total_new_instances.cpu().numpy()
129 |     except:
130 |         pass
131 |
132 | total_acc = np.concatenate([old_acc, new_acc]) / np.concatenate([total_old_instances, total_new_instances])
133 | old_acc /= total_old_instances
134 | new_acc /= total_new_instances
135 | total_acc, old_acc, new_acc = total_acc.mean(), old_acc.mean(), new_acc.mean()
136 | return total_acc, old_acc, new_acc
137 |
138 |
139 | EVAL_FUNCS = {
140 | 'v2': split_cluster_acc_v2,
141 | 'v2b': split_cluster_acc_v2_balanced
142 | }
143 |
144 | def log_accs_from_preds(y_true, y_pred, mask, eval_funcs, save_name, T=None,
145 | print_output=True, args=None):
146 |
147 |     """
148 |     Given a list of evaluation functions to use (e.g. ['v2', 'v2b']), evaluate and log ACC results
149 | 
150 |     :param y_true: GT labels
151 |     :param y_pred: Predicted cluster indices
152 |     :param mask: Which instances belong to Old (True) and which to New (False) classes
153 |     :param T: Epoch
154 |     :param eval_funcs: Which evaluation functions to use (keys of EVAL_FUNCS)
155 |     :param save_name: What we are evaluating ACC on
156 |     :param print_output: Whether to print/log the formatted results
157 |     :return: (all_acc, old_acc, new_acc) from the first evaluation function
158 |     """
159 |
160 | mask = mask.astype(bool)
161 | y_true = y_true.astype(int)
162 | y_pred = y_pred.astype(int)
163 |
164 | for i, f_name in enumerate(eval_funcs):
165 |
166 | acc_f = EVAL_FUNCS[f_name]
167 | all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask)
168 | log_name = f'{save_name}_{f_name}'
169 |
170 | if i == 0:
171 | to_return = (all_acc, old_acc, new_acc)
172 |
173 | if print_output:
174 | print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}'
175 | try:
176 | if dist.get_rank() == 0:
177 | try:
178 | args.logger.info(print_str)
179 | except:
180 | print(print_str)
181 | except:
182 | pass
183 |
184 | return to_return
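185 | 
186 | # Usage sketch with toy arrays (illustrative only; without torch.distributed
187 | # initialised, the all_reduce branches above are skipped):
188 | #     y_true = np.array([0, 0, 1, 1]); y_pred = np.array([1, 1, 0, 0])
189 | #     mask = np.array([True, True, False, False])
190 | #     all_acc, old_acc, new_acc = split_cluster_acc_v2(y_true, y_pred, mask)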
--------------------------------------------------------------------------------
/data/cifar.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR10, CIFAR100
2 | from copy import deepcopy
3 | import numpy as np
4 |
5 | from data.data_utils import subsample_instances
6 | from config import cifar_10_root, cifar_100_root
7 |
8 |
9 | class CustomCIFAR10(CIFAR10):
10 |
11 | def __init__(self, *args, **kwargs):
12 |
13 | super(CustomCIFAR10, self).__init__(*args, **kwargs)
14 |
15 | self.uq_idxs = np.array(range(len(self)))
16 |
17 | def __getitem__(self, item):
18 |
19 | img, label = super().__getitem__(item)
20 | uq_idx = self.uq_idxs[item]
21 |
22 | return img, label, uq_idx
23 |
24 | def __len__(self):
25 | return len(self.targets)
26 |
27 |
28 | class CustomCIFAR100(CIFAR100):
29 |
30 | def __init__(self, *args, **kwargs):
31 | super(CustomCIFAR100, self).__init__(*args, **kwargs)
32 |
33 | self.uq_idxs = np.array(range(len(self)))
34 |
35 | def __getitem__(self, item):
36 | img, label = super().__getitem__(item)
37 | uq_idx = self.uq_idxs[item]
38 |
39 | return img, label, uq_idx
40 |
41 | def __len__(self):
42 | return len(self.targets)
43 |
44 |
45 | def subsample_dataset(dataset, idxs):
46 |
47 | # Allow for setting in which all empty set of indices is passed
48 |
49 | if len(idxs) > 0:
50 |
51 | dataset.data = dataset.data[idxs]
52 | dataset.targets = np.array(dataset.targets)[idxs].tolist()
53 | dataset.uq_idxs = dataset.uq_idxs[idxs]
54 |
55 | return dataset
56 |
57 | else:
58 |
59 | return None
60 |
61 |
62 | def subsample_classes(dataset, include_classes=(0, 1, 8, 9)):
63 |
64 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
65 |
66 | target_xform_dict = {}
67 | for i, k in enumerate(include_classes):
68 | target_xform_dict[k] = i
69 |
70 | dataset = subsample_dataset(dataset, cls_idxs)
71 |
72 | # dataset.target_transform = lambda x: target_xform_dict[x]
73 |
74 | return dataset
75 |
76 |
77 | def get_train_val_indices(train_dataset, val_split=0.2):
78 |
79 | train_classes = np.unique(train_dataset.targets)
80 |
81 | # Get train/test indices
82 | train_idxs = []
83 | val_idxs = []
84 | for cls in train_classes:
85 |
86 |         cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]  # targets is a list; cast so the comparison broadcasts
87 |
88 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
89 | t_ = [x for x in cls_idxs if x not in v_]
90 |
91 | train_idxs.extend(t_)
92 | val_idxs.extend(v_)
93 |
94 | return train_idxs, val_idxs
95 |
96 |
97 | def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9),
98 | prop_train_labels=0.8, split_train_val=False, seed=0):
99 |
100 | np.random.seed(seed)
101 |
102 | # Init entire training set
103 | whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True)
104 |
105 | # Get labelled training set which has subsampled classes, then subsample some indices from that
106 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
107 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
108 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
109 |
110 | # Split into training and validation sets
111 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
112 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
113 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
114 | val_dataset_labelled_split.transform = test_transform
115 |
116 | # Get unlabelled data
117 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
118 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
119 |
120 | # Get test set for all classes
121 | test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False)
122 |
123 | # Either split train into train and val or use test set as val
124 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
125 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
126 |
127 | all_datasets = {
128 | 'train_labelled': train_dataset_labelled,
129 | 'train_unlabelled': train_dataset_unlabelled,
130 | 'val': val_dataset_labelled,
131 | 'test': test_dataset,
132 | }
133 |
134 | return all_datasets
135 |
136 |
137 | def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80),
138 | prop_train_labels=0.8, split_train_val=False, seed=0):
139 |
140 | np.random.seed(seed)
141 |
142 | # Init entire training set
143 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True)
144 |
145 | # Get labelled training set which has subsampled classes, then subsample some indices from that
146 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
147 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
148 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
149 |
150 | # Split into training and validation sets
151 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
152 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
153 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
154 | val_dataset_labelled_split.transform = test_transform
155 |
156 | # Get unlabelled data
157 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
158 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
159 |
160 | # Get test set for all classes
161 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False)
162 |
163 | # Either split train into train and val or use test set as val
164 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
165 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
166 |
167 | all_datasets = {
168 | 'train_labelled': train_dataset_labelled,
169 | 'train_unlabelled': train_dataset_unlabelled,
170 | 'val': val_dataset_labelled,
171 | 'test': test_dataset,
172 | }
173 |
174 | return all_datasets
175 |
176 |
177 | if __name__ == '__main__':
178 |
179 | x = get_cifar_100_datasets(None, None, split_train_val=False,
180 | train_classes=range(80), prop_train_labels=0.5)
181 |
182 | print('Printing lens...')
183 | for k, v in x.items():
184 | if v is not None:
185 | print(f'{k}: {len(v)}')
186 |
187 | print('Printing labelled and unlabelled overlap...')
188 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
189 | print('Printing total instances in train...')
190 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
191 |
192 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
193 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
194 | print(f'Len labelled set: {len(x["train_labelled"])}')
195 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
/data/cub.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 |
6 | from torchvision.datasets.folder import default_loader
7 | from torchvision.datasets.utils import download_url
8 | from torch.utils.data import Dataset
9 |
10 | from data.data_utils import subsample_instances
11 | from config import cub_root
12 |
13 |
14 | class CustomCub2011(Dataset):
15 | base_folder = 'CUB_200_2011/images'
16 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
17 | filename = 'CUB_200_2011.tgz'
18 | tgz_md5 = '97eceeb196236b17998738112f37df78'
19 |
20 | def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=True):
21 |
22 | self.root = os.path.expanduser(root)
23 | self.transform = transform
24 | self.target_transform = target_transform
25 |
26 | self.loader = loader
27 | self.train = train
28 |
29 |
30 | if download:
31 | self._download()
32 |
33 | if not self._check_integrity():
34 | raise RuntimeError('Dataset not found or corrupted.' +
35 | ' You can use download=True to download it')
36 |
37 | self.uq_idxs = np.array(range(len(self)))
38 |
39 | def _load_metadata(self):
40 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
41 | names=['img_id', 'filepath'])
42 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
43 | sep=' ', names=['img_id', 'target'])
44 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
45 | sep=' ', names=['img_id', 'is_training_img'])
46 |
47 | data = images.merge(image_class_labels, on='img_id')
48 | self.data = data.merge(train_test_split, on='img_id')
49 |
50 | if self.train:
51 | self.data = self.data[self.data.is_training_img == 1]
52 | else:
53 | self.data = self.data[self.data.is_training_img == 0]
54 |
55 | def _check_integrity(self):
56 | try:
57 | self._load_metadata()
58 | except Exception:
59 | return False
60 |
61 | for index, row in self.data.iterrows():
62 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
63 | if not os.path.isfile(filepath):
64 | print(filepath)
65 | return False
66 | return True
67 |
68 | def _download(self):
69 | import tarfile
70 |
71 | if self._check_integrity():
72 | print('Files already downloaded and verified')
73 | return
74 |
75 | download_url(self.url, self.root, self.filename, self.tgz_md5)
76 |
77 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
78 | tar.extractall(path=self.root)
79 |
80 | def __len__(self):
81 | return len(self.data)
82 |
83 | def __getitem__(self, idx):
84 | sample = self.data.iloc[idx]
85 | path = os.path.join(self.root, self.base_folder, sample.filepath)
86 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0
87 | img = self.loader(path)
88 |
89 | if self.transform is not None:
90 | img = self.transform(img)
91 |
92 | if self.target_transform is not None:
93 | target = self.target_transform(target)
94 |
95 | return img, target, self.uq_idxs[idx]
96 |
97 |
98 | def subsample_dataset(dataset, idxs):
99 |
100 | mask = np.zeros(len(dataset)).astype('bool')
101 | mask[idxs] = True
102 |
103 | dataset.data = dataset.data[mask]
104 | dataset.uq_idxs = dataset.uq_idxs[mask]
105 |
106 | return dataset
107 |
108 |
109 | def subsample_classes(dataset, include_classes=range(160)):
110 |
111 | include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199
112 | cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub]
113 |
114 |     # Remap the kept classes to contiguous 0..N-1 labels via a target transform
115 | target_xform_dict = {}
116 | for i, k in enumerate(include_classes):
117 | target_xform_dict[k] = i
118 |
119 | dataset = subsample_dataset(dataset, cls_idxs)
120 |
121 | dataset.target_transform = lambda x: target_xform_dict[x]
122 |
123 | return dataset
124 |
125 |
126 | def get_train_val_indices(train_dataset, val_split=0.2):
127 |
128 | train_classes = np.unique(train_dataset.data['target'])
129 |
130 | # Get train/test indices
131 | train_idxs = []
132 | val_idxs = []
133 | for cls in train_classes:
134 |
135 | cls_idxs = np.where(train_dataset.data['target'] == cls)[0]
136 |
137 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
138 | t_ = [x for x in cls_idxs if x not in v_]
139 |
140 | train_idxs.extend(t_)
141 | val_idxs.extend(v_)
142 |
143 | return train_idxs, val_idxs
144 |
145 |
146 | def get_cub_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
147 | split_train_val=False, seed=0, download=False):
148 |
149 | np.random.seed(seed)
150 |
151 | # Init entire training set
152 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True, download=download)
153 |
154 | # Get labelled training set which has subsampled classes, then subsample some indices from that
155 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
156 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
157 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
158 |
159 | # Split into training and validation sets
160 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
161 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
162 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
163 | val_dataset_labelled_split.transform = test_transform
164 |
165 | # Get unlabelled data
166 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
167 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
168 |
169 | # Get test set for all classes
170 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False)
171 |
172 | # Either split train into train and val or use test set as val
173 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
174 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
175 |
176 | all_datasets = {
177 | 'train_labelled': train_dataset_labelled,
178 | 'train_unlabelled': train_dataset_unlabelled,
179 | 'val': val_dataset_labelled,
180 | 'test': test_dataset,
181 | }
182 |
183 | return all_datasets
184 |
185 | if __name__ == '__main__':
186 |
187 | x = get_cub_datasets(None, None, split_train_val=False,
188 | train_classes=range(100), prop_train_labels=0.5)
189 |
190 | print('Printing lens...')
191 | for k, v in x.items():
192 | if v is not None:
193 | print(f'{k}: {len(v)}')
194 |
195 | print('Printing labelled and unlabelled overlap...')
196 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
197 | print('Printing total instances in train...')
198 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
199 |
200 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}')
201 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}')
202 | print(f'Len labelled set: {len(x["train_labelled"])}')
203 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
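The `__main__` block above passes `None` transforms, which is only enough for the
length and overlap checks it prints. A minimal sketch of calling `get_cub_datasets`
with real transforms (the pipeline below is a plain torchvision stand-in using the
standard ImageNet statistics, not the repo's own `data/augmentations` pipeline, and
it assumes `cub_root` in config.py points at an extracted CUB_200_2011 directory):

    from torchvision import transforms
    from data.cub import get_cub_datasets

    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    datasets = get_cub_datasets(train_tf, test_tf,
                                train_classes=range(100),  # first 100 classes labelled
                                prop_train_labels=0.5)     # label 50% of their images
    img, target, uq_idx = datasets['train_labelled'][0]    # note the 3-tuple return
--------------------------------------------------------------------------------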
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import numpy as np
3 |
4 | import os
5 |
6 | from copy import deepcopy
7 | from data.data_utils import subsample_instances
8 | from config import imagenet_root
9 |
10 |
11 | class ImageNetBase(torchvision.datasets.ImageFolder):
12 |
13 | def __init__(self, root, transform):
14 |
15 | super(ImageNetBase, self).__init__(root, transform)
16 |
17 | self.uq_idxs = np.array(range(len(self)))
18 |
19 | def __getitem__(self, item):
20 |
21 | img, label = super().__getitem__(item)
22 | uq_idx = self.uq_idxs[item]
23 |
24 | return img, label, uq_idx
25 |
26 |
27 | def subsample_dataset(dataset, idxs):
28 |
29 | imgs_ = []
30 | for i in idxs:
31 | imgs_.append(dataset.imgs[i])
32 | dataset.imgs = imgs_
33 |
34 | samples_ = []
35 | for i in idxs:
36 | samples_.append(dataset.samples[i])
37 | dataset.samples = samples_
38 |
39 | # dataset.imgs = [x for i, x in enumerate(dataset.imgs) if i in idxs]
40 | # dataset.samples = [x for i, x in enumerate(dataset.samples) if i in idxs]
41 |
42 | dataset.targets = np.array(dataset.targets)[idxs].tolist()
43 | dataset.uq_idxs = dataset.uq_idxs[idxs]
44 |
45 | return dataset
46 |
47 |
48 | def subsample_classes(dataset, include_classes=list(range(1000))):
49 |
50 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
51 |
52 | target_xform_dict = {}
53 | for i, k in enumerate(include_classes):
54 | target_xform_dict[k] = i
55 |
56 | dataset = subsample_dataset(dataset, cls_idxs)
57 | dataset.target_transform = lambda x: target_xform_dict[x]
58 |
59 | return dataset
60 |
61 |
62 | def get_train_val_indices(train_dataset, val_split=0.2):
63 |
64 | train_classes = list(set(train_dataset.targets))
65 |
66 | # Get train/test indices
67 | train_idxs = []
68 | val_idxs = []
69 | for cls in train_classes:
70 |
71 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
72 |
73 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
74 | t_ = [x for x in cls_idxs if x not in v_]
75 |
76 | train_idxs.extend(t_)
77 | val_idxs.extend(v_)
78 |
79 | return train_idxs, val_idxs
80 |
81 |
82 | def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80),
83 | prop_train_labels=0.8, split_train_val=False, seed=0):
84 |
85 | np.random.seed(seed)
86 |
87 | # Subsample imagenet dataset initially to include 100 classes
88 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False)
89 | subsampled_100_classes = np.sort(subsampled_100_classes)
90 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}')
91 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))}
92 |
93 | # Init entire training set
94 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
95 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes)
96 |
97 | # Reset dataset
98 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples]
99 | whole_training_set.targets = [s[1] for s in whole_training_set.samples]
100 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set)))
101 | whole_training_set.target_transform = None
102 |
103 | # Get labelled training set which has subsampled classes, then subsample some indices from that
104 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
105 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
106 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
107 |
108 | # Split into training and validation sets
109 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
110 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
111 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
112 | val_dataset_labelled_split.transform = test_transform
113 |
114 | # Get unlabelled data
115 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
116 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
117 |
118 | # Get test set for all classes
119 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
120 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes)
121 |
122 | # Reset test set
123 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples]
124 | test_dataset.targets = [s[1] for s in test_dataset.samples]
125 | test_dataset.uq_idxs = np.array(range(len(test_dataset)))
126 | test_dataset.target_transform = None
127 |
128 | # Either split train into train and val or use test set as val
129 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
130 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
131 |
132 | all_datasets = {
133 | 'train_labelled': train_dataset_labelled,
134 | 'train_unlabelled': train_dataset_unlabelled,
135 | 'val': val_dataset_labelled,
136 | 'test': test_dataset,
137 | }
138 |
139 | return all_datasets
140 |
141 |
142 | def get_imagenet_1k_datasets(train_transform, test_transform, train_classes=range(500),
143 | prop_train_labels=0.5, split_train_val=False, seed=0):
144 |
145 | np.random.seed(seed)
146 |
147 | # Init entire training set
148 | whole_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
149 |
150 | # Get labelled training set which has subsampled classes, then subsample some indices from that
151 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
152 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
153 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
154 |
155 | # Split into training and validation sets
156 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
157 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
158 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
159 | val_dataset_labelled_split.transform = test_transform
160 |
161 | # Get unlabelled data
162 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
163 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
164 |
165 | # Get test set for all classes
166 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
167 |
168 | # Either split train into train and val or use test set as val
169 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
170 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
171 |
172 | all_datasets = {
173 | 'train_labelled': train_dataset_labelled,
174 | 'train_unlabelled': train_dataset_unlabelled,
175 | 'val': val_dataset_labelled,
176 | 'test': test_dataset,
177 | }
178 |
179 | return all_datasets
180 |
181 |
182 |
183 | if __name__ == '__main__':
184 |
185 | x = get_imagenet_100_datasets(None, None, split_train_val=False,
186 | train_classes=range(50), prop_train_labels=0.5)
187 |
188 | print('Printing lens...')
189 | for k, v in x.items():
190 | if v is not None:
191 | print(f'{k}: {len(v)}')
192 |
193 | print('Printing labelled and unlabelled overlap...')
194 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
195 | print('Printing total instances in train...')
196 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
197 |
198 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
199 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
200 | print(f'Len labelled set: {len(x["train_labelled"])}')
201 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
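Note how `get_imagenet_100_datasets` first lets `subsample_classes` install a
`target_transform`, then rewrites `samples`/`targets` in place with `cls_map` and
clears the transform, so later `subsample_classes` calls (which read the raw
`targets`) already see contiguous 0..99 labels. A toy sketch of that remap, with
made-up class ids standing in for ImageFolder samples:

    import numpy as np

    # Hypothetical (path, original class id) pairs standing in for dataset.samples.
    samples = [('a.jpg', 7), ('b.jpg', 42), ('c.jpg', 7), ('d.jpg', 993)]
    chosen = np.sort(np.array([7, 42, 993]))        # the classes kept, sorted
    cls_map = {c: i for i, c in enumerate(chosen)}  # 7 -> 0, 42 -> 1, 993 -> 2

    samples = [(p, cls_map[t]) for p, t in samples]  # rewrite labels in samples
    targets = [t for _, t in samples]                # keep targets consistent
    assert targets == [0, 1, 0, 2]

Because the 100 classes are drawn after `np.random.seed(seed)`, the same `seed`
reproduces the same class subset (and hence the same splits) across runs.
--------------------------------------------------------------------------------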
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class DINOHead(nn.Module):
7 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True,
8 | nlayers=3, hidden_dim=2048, bottleneck_dim=256):
9 | super().__init__()
10 | nlayers = max(nlayers, 1)
11 | if nlayers == 1:
12 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
13 | elif nlayers != 0:
14 | layers = [nn.Linear(in_dim, hidden_dim)]
15 | if use_bn:
16 | layers.append(nn.BatchNorm1d(hidden_dim))
17 | layers.append(nn.GELU())
18 | for _ in range(nlayers - 2):
19 | layers.append(nn.Linear(hidden_dim, hidden_dim))
20 | if use_bn:
21 | layers.append(nn.BatchNorm1d(hidden_dim))
22 | layers.append(nn.GELU())
23 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
24 | self.mlp = nn.Sequential(*layers)
25 | self.apply(self._init_weights)
26 | self.last_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False))
27 | self.last_layer.weight_g.data.fill_(1)
28 | if norm_last_layer:
29 | self.last_layer.weight_g.requires_grad = False
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | torch.nn.init.trunc_normal_(m.weight, std=.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 |
37 | def forward(self, x):
38 | x_proj = self.mlp(x)
39 | x = nn.functional.normalize(x, dim=-1, p=2)
40 | # x = x.detach()
41 | logits = self.last_layer(x)
42 | return x_proj, logits
43 |
44 |
45 | class ContrastiveLearningViewGenerator(object):
46 | """Take two random crops of one image as the query and key."""
47 |
48 | def __init__(self, base_transform, n_views=2):
49 | self.base_transform = base_transform
50 | self.n_views = n_views
51 |
52 | def __call__(self, x):
53 | if not isinstance(self.base_transform, list):
54 | return [self.base_transform(x) for i in range(self.n_views)]
55 | else:
56 | return [self.base_transform[i](x) for i in range(self.n_views)]
57 |
58 | class SupConLoss(torch.nn.Module):
59 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
60 | It also supports the unsupervised contrastive loss in SimCLR
61 | From: https://github.com/HobbitLong/SupContrast"""
62 | def __init__(self, temperature=0.07, contrast_mode='all',
63 | base_temperature=0.07):
64 | super(SupConLoss, self).__init__()
65 | self.temperature = temperature
66 | self.contrast_mode = contrast_mode
67 | self.base_temperature = base_temperature
68 |
69 | def forward(self, features, labels=None, mask=None):
70 | """Compute loss for model. If both `labels` and `mask` are None,
71 | it degenerates to SimCLR unsupervised loss:
72 | https://arxiv.org/pdf/2002.05709.pdf
73 | Args:
74 | features: hidden vector of shape [bsz, n_views, ...].
75 | labels: ground truth of shape [bsz].
76 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
77 | has the same class as sample i. Can be asymmetric.
78 | Returns:
79 | A loss scalar.
80 | """
81 |
82 | device = (torch.device('cuda')
83 | if features.is_cuda
84 | else torch.device('cpu'))
85 |
86 | if len(features.shape) < 3:
87 |             raise ValueError('`features` needs to be [bsz, n_views, ...], '
88 | 'at least 3 dimensions are required')
89 | if len(features.shape) > 3:
90 | features = features.view(features.shape[0], features.shape[1], -1)
91 |
92 | batch_size = features.shape[0]
93 | if labels is not None and mask is not None:
94 | raise ValueError('Cannot define both `labels` and `mask`')
95 | elif labels is None and mask is None:
96 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
97 | elif labels is not None:
98 | labels = labels.contiguous().view(-1, 1)
99 | if labels.shape[0] != batch_size:
100 | raise ValueError('Num of labels does not match num of features')
101 | mask = torch.eq(labels, labels.T).float().to(device)
102 | else:
103 | mask = mask.float().to(device)
104 |
105 | contrast_count = features.shape[1]
106 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
107 | if self.contrast_mode == 'one':
108 | anchor_feature = features[:, 0]
109 | anchor_count = 1
110 | elif self.contrast_mode == 'all':
111 | anchor_feature = contrast_feature
112 | anchor_count = contrast_count
113 | else:
114 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
115 |
116 | # compute logits
117 | anchor_dot_contrast = torch.div(
118 | torch.matmul(anchor_feature, contrast_feature.T),
119 | self.temperature)
120 |
121 | # for numerical stability
122 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
123 | logits = anchor_dot_contrast - logits_max.detach()
124 |
125 | # tile mask
126 | mask = mask.repeat(anchor_count, contrast_count)
127 | # mask-out self-contrast cases
128 | logits_mask = torch.scatter(
129 | torch.ones_like(mask),
130 | 1,
131 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
132 | 0
133 | )
134 | mask = mask * logits_mask
135 |
136 | # compute log_prob
137 | exp_logits = torch.exp(logits) * logits_mask
138 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
139 |
140 | # compute mean of log-likelihood over positive
141 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
142 |
143 | # loss
144 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
145 | loss = loss.view(anchor_count, batch_size).mean()
146 |
147 | return loss
148 |
149 |
150 |
151 | def info_nce_logits(features, n_views=2, temperature=1.0, device='cuda'):
152 |
153 | b_ = 0.5 * int(features.size(0))
154 |
155 | labels = torch.cat([torch.arange(b_) for i in range(n_views)], dim=0)
156 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
157 | labels = labels.to(device)
158 |
159 | features = F.normalize(features, dim=1)
160 |
161 | similarity_matrix = torch.matmul(features, features.T)
162 |
163 | # discard the main diagonal from both: labels and similarities matrix
164 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
165 | labels = labels[~mask].view(labels.shape[0], -1)
166 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
167 |
168 | # select and combine multiple positives
169 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
170 |
171 |     # select only the negatives
172 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
173 |
174 | logits = torch.cat([positives, negatives], dim=1)
175 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
176 |
177 | logits = logits / temperature
178 | return logits, labels
179 |
180 |
181 | def get_params_groups(model):
182 | regularized = []
183 | not_regularized = []
184 | for name, param in model.named_parameters():
185 | if not param.requires_grad:
186 | continue
187 | # we do not regularize biases nor Norm parameters
188 | if name.endswith(".bias") or len(param.shape) == 1:
189 | not_regularized.append(param)
190 | else:
191 | regularized.append(param)
192 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
193 |
194 |
195 | class DistillLoss(nn.Module):
196 | def __init__(self, warmup_teacher_temp_epochs, nepochs,
197 | ncrops=2, warmup_teacher_temp=0.07, teacher_temp=0.04,
198 | student_temp=0.1):
199 | super().__init__()
200 | self.student_temp = student_temp
201 | self.ncrops = ncrops
202 | self.teacher_temp_schedule = np.concatenate((
203 | np.linspace(warmup_teacher_temp,
204 | teacher_temp, warmup_teacher_temp_epochs),
205 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
206 | ))
207 |
208 | def forward(self, student_output, teacher_output, epoch):
209 | """
210 | Cross-entropy between softmax outputs of the teacher and student networks.
211 | """
212 | student_out = student_output / self.student_temp
213 | student_out = student_out.chunk(self.ncrops)
214 |
215 | # teacher centering and sharpening
216 | temp = self.teacher_temp_schedule[epoch]
217 | teacher_out = F.softmax(teacher_output / temp, dim=-1)
218 | teacher_out = teacher_out.detach().chunk(2)
219 |
220 | total_loss = 0
221 | n_loss_terms = 0
222 | for iq, q in enumerate(teacher_out):
223 | for v in range(len(student_out)):
224 | if v == iq:
225 | # we skip cases where student and teacher operate on the same view
226 | continue
227 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
228 | total_loss += loss.mean()
229 | n_loss_terms += 1
230 | total_loss /= n_loss_terms
231 | return total_loss
232 |
--------------------------------------------------------------------------------
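A minimal smoke test for the three losses above, using random features on CPU
(shapes mirror the two-view training loop in train.py; the batch size, feature
dimension and class count below are arbitrary):

    import torch
    from model import info_nce_logits, SupConLoss, DistillLoss

    bsz, dim, n_cls = 8, 256, 10

    # info_nce_logits expects both views stacked along the batch dim: [2*bsz, dim].
    feats = torch.randn(2 * bsz, dim)
    logits, labels = info_nce_logits(feats, n_views=2, temperature=1.0, device='cpu')
    unsup_con_loss = torch.nn.CrossEntropyLoss()(logits, labels)

    # SupConLoss expects [bsz, n_views, dim] with L2-normalised features.
    sup_feats = torch.nn.functional.normalize(torch.randn(bsz, 2, dim), dim=-1)
    sup_labels = torch.randint(0, n_cls, (bsz,))
    sup_con_loss = SupConLoss()(sup_feats, labels=sup_labels)

    # DistillLoss consumes the classifier logits of both views, again stacked.
    out = torch.randn(2 * bsz, n_cls)
    distill = DistillLoss(warmup_teacher_temp_epochs=30, nepochs=200)
    cluster_loss = distill(out, out.detach(), epoch=0)
    print(unsup_con_loss.item(), sup_con_loss.item(), cluster_loss.item())
--------------------------------------------------------------------------------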
/data/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from copy import deepcopy
4 |
5 | from torchvision.datasets.folder import default_loader
6 | from torch.utils.data import Dataset
7 |
8 | from data.data_utils import subsample_instances
9 | from config import aircraft_root
10 |
11 | def make_dataset(dir, image_ids, targets):
12 | assert(len(image_ids) == len(targets))
13 | images = []
14 | dir = os.path.expanduser(dir)
15 | for i in range(len(image_ids)):
16 | item = (os.path.join(dir, 'data', 'images',
17 | '%s.jpg' % image_ids[i]), targets[i])
18 | images.append(item)
19 | return images
20 |
21 |
22 | def find_classes(classes_file):
23 |
24 | # read classes file, separating out image IDs and class names
25 | image_ids = []
26 | targets = []
27 | f = open(classes_file, 'r')
28 | for line in f:
29 | split_line = line.split(' ')
30 | image_ids.append(split_line[0])
31 | targets.append(' '.join(split_line[1:]))
32 | f.close()
33 |
34 | # index class names
35 | classes = np.unique(targets)
36 | class_to_idx = {classes[i]: i for i in range(len(classes))}
37 | targets = [class_to_idx[c] for c in targets]
38 |
39 | return (image_ids, targets, classes, class_to_idx)
40 |
41 |
42 | class FGVCAircraft(Dataset):
43 |
44 | """`FGVC-Aircraft `_ Dataset.
45 |
46 | Args:
47 | root (string): Root directory path to dataset.
48 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
49 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
50 | transform (callable, optional): A function/transform that takes in a PIL image
51 | and returns a transformed version. E.g. ``transforms.RandomCrop``
52 | target_transform (callable, optional): A function/transform that takes in the
53 | target and transforms it.
54 | loader (callable, optional): A function to load an image given its path.
55 | download (bool, optional): If true, downloads the dataset from the internet and
56 | puts it in the root directory. If dataset is already downloaded, it is not
57 | downloaded again.
58 | """
59 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
60 | class_types = ('variant', 'family', 'manufacturer')
61 | splits = ('train', 'val', 'trainval', 'test')
62 |
63 | def __init__(self, root, class_type='variant', split='train', transform=None,
64 | target_transform=None, loader=default_loader, download=False):
65 | if split not in self.splits:
66 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
67 | split, ', '.join(self.splits),
68 | ))
69 | if class_type not in self.class_types:
70 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
71 | class_type, ', '.join(self.class_types),
72 | ))
73 | self.root = os.path.expanduser(root)
74 | self.class_type = class_type
75 | self.split = split
76 | self.classes_file = os.path.join(self.root, 'data',
77 | 'images_%s_%s.txt' % (self.class_type, self.split))
78 |
79 | if download:
80 | self.download()
81 |
82 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
83 | samples = make_dataset(self.root, image_ids, targets)
84 |
85 | self.transform = transform
86 | self.target_transform = target_transform
87 | self.loader = loader
88 |
89 | self.samples = samples
90 | self.classes = classes
91 | self.class_to_idx = class_to_idx
92 | self.train = True if split == 'train' else False
93 |
94 | self.uq_idxs = np.array(range(len(self)))
95 |
96 | def __getitem__(self, index):
97 | """
98 | Args:
99 | index (int): Index
100 |
101 | Returns:
102 | tuple: (sample, target) where target is class_index of the target class.
103 | """
104 |
105 | path, target = self.samples[index]
106 | sample = self.loader(path)
107 | if self.transform is not None:
108 | sample = self.transform(sample)
109 | if self.target_transform is not None:
110 | target = self.target_transform(target)
111 |
112 | return sample, target, self.uq_idxs[index]
113 |
114 | def __len__(self):
115 | return len(self.samples)
116 |
117 | def __repr__(self):
118 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
119 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
120 | fmt_str += ' Root Location: {}\n'.format(self.root)
121 | tmp = ' Transforms (if any): '
122 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
123 | tmp = ' Target Transforms (if any): '
124 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
125 | return fmt_str
126 |
127 | def _check_exists(self):
128 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
129 | os.path.exists(self.classes_file)
130 |
131 | def download(self):
132 | """Download the FGVC-Aircraft data if it doesn't exist already."""
133 | from six.moves import urllib
134 | import tarfile
135 |
136 | if self._check_exists():
137 | return
138 |
139 |         # prepare to download data to PARENT_DIR/fgvc-aircraft-2013b.tar.gz
140 | print('Downloading %s ... (may take a few minutes)' % self.url)
141 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
142 | tar_name = self.url.rpartition('/')[-1]
143 | tar_path = os.path.join(parent_dir, tar_name)
144 | data = urllib.request.urlopen(self.url)
145 |
146 | # download .tar.gz file
147 | with open(tar_path, 'wb') as f:
148 | f.write(data.read())
149 |
150 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
151 |         data_folder = tar_path[:-len('.tar.gz')]  # drop the '.tar.gz' suffix
152 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
153 |         with tarfile.open(tar_path) as tar:
154 |             tar.extractall(parent_dir)
155 |
156 | # if necessary, rename data folder to self.root
157 | if not os.path.samefile(data_folder, self.root):
158 | print('Renaming %s to %s ...' % (data_folder, self.root))
159 | os.rename(data_folder, self.root)
160 |
161 | # delete .tar.gz file
162 | print('Deleting %s ...' % tar_path)
163 | os.remove(tar_path)
164 |
165 | print('Done!')
166 |
167 |
168 | def subsample_dataset(dataset, idxs):
169 |
170 | mask = np.zeros(len(dataset)).astype('bool')
171 | mask[idxs] = True
172 |
173 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs]
174 | dataset.uq_idxs = dataset.uq_idxs[mask]
175 |
176 | return dataset
177 |
178 |
179 | def subsample_classes(dataset, include_classes=range(60)):
180 |
181 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes]
182 |
183 |     # Remap the kept classes to contiguous 0..N-1 labels via a target transform
184 | target_xform_dict = {}
185 | for i, k in enumerate(include_classes):
186 | target_xform_dict[k] = i
187 |
188 | dataset = subsample_dataset(dataset, cls_idxs)
189 |
190 | dataset.target_transform = lambda x: target_xform_dict[x]
191 |
192 | return dataset
193 |
194 |
195 | def get_train_val_indices(train_dataset, val_split=0.2):
196 |
197 |     all_targets = np.array([t for (p, t) in train_dataset.samples])
198 | train_classes = np.unique(all_targets)
199 |
200 | # Get train/test indices
201 | train_idxs = []
202 | val_idxs = []
203 | for cls in train_classes:
204 | cls_idxs = np.where(all_targets == cls)[0]
205 |
206 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
207 | t_ = [x for x in cls_idxs if x not in v_]
208 |
209 | train_idxs.extend(t_)
210 | val_idxs.extend(v_)
211 |
212 | return train_idxs, val_idxs
213 |
214 |
215 | def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8,
216 | split_train_val=False, seed=0):
217 |
218 | np.random.seed(seed)
219 |
220 | # Init entire training set
221 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval')
222 |
223 | # Get labelled training set which has subsampled classes, then subsample some indices from that
224 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
225 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
226 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
227 |
228 | # Split into training and validation sets
229 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
230 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
231 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
232 | val_dataset_labelled_split.transform = test_transform
233 |
234 | # Get unlabelled data
235 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
236 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
237 |
238 | # Get test set for all classes
239 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test')
240 |
241 | # Either split train into train and val or use test set as val
242 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
243 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
244 |
245 | all_datasets = {
246 | 'train_labelled': train_dataset_labelled,
247 | 'train_unlabelled': train_dataset_unlabelled,
248 | 'val': val_dataset_labelled,
249 | 'test': test_dataset,
250 | }
251 |
252 | return all_datasets
253 |
254 | if __name__ == '__main__':
255 |
256 | x = get_aircraft_datasets(None, None, split_train_val=False)
257 |
258 | print('Printing lens...')
259 | for k, v in x.items():
260 | if v is not None:
261 | print(f'{k}: {len(v)}')
262 |
263 | print('Printing labelled and unlabelled overlap...')
264 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
265 | print('Printing total instances in train...')
266 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
267 | print('Printing number of labelled classes...')
268 | print(len(set([i[1] for i in x['train_labelled'].samples])))
269 | print('Printing total number of classes...')
270 | print(len(set([i[1] for i in x['train_unlabelled'].samples])))
271 |
--------------------------------------------------------------------------------
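`FGVCAircraft` reads its annotation files straight from `root/data`, so the
extracted archive must end up at `aircraft_root` itself (the `download` method
renames the folder for this reason). A small sketch that checks the expected
layout before building datasets; the paths follow `_check_exists` and
`find_classes` above, and `check_aircraft_layout` is a hypothetical helper, not
part of the repo:

    import os
    from config import aircraft_root

    def check_aircraft_layout(root=aircraft_root, class_type='variant'):
        """Return the files FGVCAircraft will look for that are missing."""
        expected = [os.path.join(root, 'data', 'images')]
        expected += [os.path.join(root, 'data', 'images_%s_%s.txt' % (class_type, split))
                     for split in ('train', 'val', 'trainval', 'test')]
        return [p for p in expected if not os.path.exists(p)]

    missing = check_aircraft_layout()
    if missing:
        print('Missing:', *missing, sep='\n  ')
--------------------------------------------------------------------------------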
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch.optim import SGD, lr_scheduler
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 |
11 | from data.augmentations import get_transform
12 | from data.get_datasets import get_datasets, get_class_splits
13 |
14 | from util.general_utils import AverageMeter, init_experiment
15 | from util.cluster_and_log_utils import log_accs_from_preds
16 | from config import exp_root
17 | from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups
18 |
19 |
20 | def train(student, train_loader, test_loader, unlabelled_train_loader, args):
21 | params_groups = get_params_groups(student)
22 | optimizer = SGD(params_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
23 | fp16_scaler = None
24 | if args.fp16:
25 | fp16_scaler = torch.cuda.amp.GradScaler()
26 |
27 | exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(
28 | optimizer,
29 | T_max=args.epochs,
30 | eta_min=args.lr * 1e-3,
31 | )
32 |
33 |
34 | cluster_criterion = DistillLoss(
35 | args.warmup_teacher_temp_epochs,
36 | args.epochs,
37 | args.n_views,
38 | args.warmup_teacher_temp,
39 | args.teacher_temp,
40 | )
41 |
42 | # # inductive
43 | # best_test_acc_lab = 0
44 | # # transductive
45 | # best_train_acc_lab = 0
46 | # best_train_acc_ubl = 0
47 | # best_train_acc_all = 0
48 |
49 | for epoch in range(args.epochs):
50 | loss_record = AverageMeter()
51 |
52 | student.train()
53 | for batch_idx, batch in enumerate(train_loader):
54 | images, class_labels, uq_idxs, mask_lab = batch
55 | mask_lab = mask_lab[:, 0]
56 |
57 | class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool()
58 | images = torch.cat(images, dim=0).cuda(non_blocking=True)
59 |
60 | with torch.cuda.amp.autocast(fp16_scaler is not None):
61 | student_proj, student_out = student(images)
62 | teacher_out = student_out.detach()
63 |
64 | # clustering, sup
65 | sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0)
66 | sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0)
67 | cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels)
68 |
69 | # clustering, unsup
70 | cluster_loss = cluster_criterion(student_out, teacher_out, epoch)
71 | avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0)
72 | me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
73 | cluster_loss += args.memax_weight * me_max_loss
74 |
75 |                 # representation learning, unsup
76 | contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj)
77 | contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels)
78 |
79 | # representation learning, sup
80 | student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1)
81 | student_proj = torch.nn.functional.normalize(student_proj, dim=-1)
82 | sup_con_labels = class_labels[mask_lab]
83 | sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels)
84 |
85 | pstr = ''
86 | pstr += f'cls_loss: {cls_loss.item():.4f} '
87 | pstr += f'cluster_loss: {cluster_loss.item():.4f} '
88 | pstr += f'sup_con_loss: {sup_con_loss.item():.4f} '
89 | pstr += f'contrastive_loss: {contrastive_loss.item():.4f} '
90 |
91 | loss = 0
92 | loss += (1 - args.sup_weight) * cluster_loss + args.sup_weight * cls_loss
93 | loss += (1 - args.sup_weight) * contrastive_loss + args.sup_weight * sup_con_loss
94 |
95 |                 # Record loss and take an optimiser step
96 | loss_record.update(loss.item(), class_labels.size(0))
97 | optimizer.zero_grad()
98 | if fp16_scaler is None:
99 | loss.backward()
100 | optimizer.step()
101 | else:
102 | fp16_scaler.scale(loss).backward()
103 | fp16_scaler.step(optimizer)
104 | fp16_scaler.update()
105 |
106 | if batch_idx % args.print_freq == 0:
107 | args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}'
108 | .format(epoch, batch_idx, len(train_loader), loss.item(), pstr))
109 |
110 | args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg))
111 |
112 | args.logger.info('Testing on unlabelled examples in the training data...')
113 | all_acc, old_acc, new_acc = test(student, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args)
114 | # args.logger.info('Testing on disjoint test set...')
115 | # all_acc_test, old_acc_test, new_acc_test = test(student, test_loader, epoch=epoch, save_name='Test ACC', args=args)
116 |
117 |
118 | args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
119 | # args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))
120 |
121 | # Step schedule
122 | exp_lr_scheduler.step()
123 |
124 | save_dict = {
125 | 'model': student.state_dict(),
126 | 'optimizer': optimizer.state_dict(),
127 | 'epoch': epoch + 1,
128 | }
129 |
130 | torch.save(save_dict, args.model_path)
131 | args.logger.info("model saved to {}.".format(args.model_path))
132 |
133 | # if old_acc_test > best_test_acc_lab:
134 | #
135 | # args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
136 | # args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
137 | #
138 | # torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
139 | # args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
140 | #
141 | # # inductive
142 | # best_test_acc_lab = old_acc_test
143 | # # transductive
144 | # best_train_acc_lab = old_acc
145 | # best_train_acc_ubl = new_acc
146 | # best_train_acc_all = all_acc
147 | #
148 | # args.logger.info(f'Exp Name: {args.exp_name}')
149 | # args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')
150 |
151 |
152 | def test(model, test_loader, epoch, save_name, args):
153 |
154 | model.eval()
155 |
156 | preds, targets = [], []
157 | mask = np.array([])
158 | for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)):
159 | images = images.cuda(non_blocking=True)
160 | with torch.no_grad():
161 | _, logits = model(images)
162 | preds.append(logits.argmax(1).cpu().numpy())
163 | targets.append(label.cpu().numpy())
164 | mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) else False for x in label]))
165 |
166 | preds = np.concatenate(preds)
167 | targets = np.concatenate(targets)
168 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask,
169 | T=epoch, eval_funcs=args.eval_funcs, save_name=save_name,
170 | args=args)
171 |
172 | return all_acc, old_acc, new_acc
173 |
174 |
175 | if __name__ == "__main__":
176 |
177 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
178 | parser.add_argument('--batch_size', default=128, type=int)
179 | parser.add_argument('--num_workers', default=8, type=int)
180 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2p'])
181 |
182 | parser.add_argument('--warmup_model_dir', type=str, default=None)
183 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
184 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
185 | parser.add_argument('--use_ssb_splits', action='store_true', default=True)
186 |
187 | parser.add_argument('--grad_from_block', type=int, default=11)
188 | parser.add_argument('--lr', type=float, default=0.1)
189 | parser.add_argument('--gamma', type=float, default=0.1)
190 | parser.add_argument('--momentum', type=float, default=0.9)
191 | parser.add_argument('--weight_decay', type=float, default=1e-4)
192 | parser.add_argument('--epochs', default=200, type=int)
193 | parser.add_argument('--exp_root', type=str, default=exp_root)
194 | parser.add_argument('--transform', type=str, default='imagenet')
195 | parser.add_argument('--sup_weight', type=float, default=0.35)
196 | parser.add_argument('--n_views', default=2, type=int)
197 |
198 | parser.add_argument('--memax_weight', type=float, default=2)
199 | parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.')
200 | parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup) of the teacher temperature.')
201 | parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.')
202 |
203 | parser.add_argument('--fp16', action='store_true', default=False)
204 | parser.add_argument('--print_freq', default=10, type=int)
205 | parser.add_argument('--exp_name', default=None, type=str)
206 |
207 | # ----------------------
208 | # INIT
209 | # ----------------------
210 | args = parser.parse_args()
211 | device = torch.device('cuda:0')
212 | args = get_class_splits(args)
213 |
214 | args.num_labeled_classes = len(args.train_classes)
215 | args.num_unlabeled_classes = len(args.unlabeled_classes)
216 |
217 | init_experiment(args, runner_name=['simgcd'])
218 | args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
219 |
220 | torch.backends.cudnn.benchmark = True
221 |
222 | # ----------------------
223 | # BASE MODEL
224 | # ----------------------
225 | args.interpolation = 3
226 | args.crop_pct = 0.875
227 |
228 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
229 |
230 | if args.warmup_model_dir is not None:
231 | args.logger.info(f'Loading weights from {args.warmup_model_dir}')
232 | backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
233 |
234 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model
235 | args.image_size = 224
236 | args.feat_dim = 768
237 | args.num_mlp_layers = 3
238 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
239 |
240 | # ----------------------
241 | # HOW MUCH OF BASE MODEL TO FINETUNE
242 | # ----------------------
243 | for m in backbone.parameters():
244 | m.requires_grad = False
245 |
246 | # Only finetune layers from block 'args.grad_from_block' onwards
247 | for name, m in backbone.named_parameters():
248 | if 'block' in name:
249 | block_num = int(name.split('.')[1])
250 | if block_num >= args.grad_from_block:
251 | m.requires_grad = True
252 |
253 |
254 |     args.logger.info('model built')
255 |
256 | # --------------------
257 | # CONTRASTIVE TRANSFORM
258 | # --------------------
259 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
260 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
261 | # --------------------
262 | # DATASETS
263 | # --------------------
264 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name,
265 | train_transform,
266 | test_transform,
267 | args)
268 |
269 | # --------------------
270 | # SAMPLER
271 | # Sampler which balances labelled and unlabelled examples in each batch
272 | # --------------------
273 | label_len = len(train_dataset.labelled_dataset)
274 | unlabelled_len = len(train_dataset.unlabelled_dataset)
275 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
276 | sample_weights = torch.DoubleTensor(sample_weights)
277 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset))
278 |
279 | # --------------------
280 | # DATALOADERS
281 | # --------------------
282 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
283 | sampler=sampler, drop_last=True, pin_memory=True)
284 | test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
285 | batch_size=256, shuffle=False, pin_memory=False)
286 | # test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers,
287 | # batch_size=256, shuffle=False, pin_memory=False)
288 |
289 | # ----------------------
290 | # PROJECTION HEAD
291 | # ----------------------
292 | projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers)
293 | model = nn.Sequential(backbone, projector).to(device)
294 |
295 | # ----------------------
296 | # TRAIN
297 | # ----------------------
298 | # train(model, train_loader, test_loader_labelled, test_loader_unlabelled, args)
299 | train(model, train_loader, None, test_loader_unlabelled, args)
300 |
--------------------------------------------------------------------------------
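The `me_max_loss` term in the loop above is the mean-entropy maximisation
regulariser: writing K for the number of classes and p for the batch-averaged
prediction, it equals log K - H(p), which is non-negative and zero exactly when
p is uniform, so it pushes the classifier to use all clusters.
`torch.log(avg_probs ** (-avg_probs))` is just -p * log p written so that p = 0
contributes 0 instead of NaN. A quick numerical check (sizes are arbitrary):

    import math
    import torch

    avg_probs = torch.softmax(torch.randn(100), dim=0)  # stand-in for the batch mean

    entropy = -torch.sum(avg_probs * torch.log(avg_probs))           # H(p)
    entropy_trick = torch.sum(torch.log(avg_probs ** (-avg_probs)))  # same value
    assert torch.allclose(entropy, entropy_trick, atol=1e-5)

    me_max = -entropy_trick + math.log(float(len(avg_probs)))        # log K - H(p)
    uniform = torch.full((100,), 1 / 100)
    at_uniform = (-torch.sum(torch.log(uniform ** (-uniform)))
                  + math.log(100.0))
    print(me_max.item(), at_uniform.item())  # second value is ~0
--------------------------------------------------------------------------------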
/train_mp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import math
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.distributed as dist
9 | import torch.backends.cudnn as cudnn
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 |
13 | from data.augmentations import get_transform
14 | from data.get_datasets import get_datasets, get_class_splits
15 |
16 | from util.general_utils import AverageMeter, init_experiment, DistributedWeightedSampler
17 | from util.cluster_and_log_utils import log_accs_from_preds
18 | from config import exp_root
19 | from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups
20 |
21 |
22 | def get_parser():
23 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
24 |
25 | parser.add_argument('--batch_size', default=128, type=int)
26 | parser.add_argument('--num_workers', default=8, type=int)
27 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2b'])
28 |
29 | parser.add_argument('--warmup_model_dir', type=str, default=None)
30 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
31 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
32 | parser.add_argument('--use_ssb_splits', action='store_true', default=True)
33 |
34 | parser.add_argument('--grad_from_block', type=int, default=11)
35 | parser.add_argument('--lr', type=float, default=0.1)
36 | parser.add_argument('--gamma', type=float, default=0.1)
37 | parser.add_argument('--momentum', type=float, default=0.9)
38 | parser.add_argument('--weight_decay', type=float, default=1e-4)
39 | parser.add_argument('--epochs', default=200, type=int)
40 | parser.add_argument('--exp_root', type=str, default=exp_root)
41 | parser.add_argument('--transform', type=str, default='imagenet')
42 | parser.add_argument('--sup_weight', type=float, default=0.35)
43 | parser.add_argument('--n_views', default=2, type=int)
44 |
45 | parser.add_argument('--memax_weight', type=float, default=2)
46 | parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.')
47 | parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup)of the teacher temperature.')
48 | parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.')
49 |
50 | parser.add_argument('--fp16', action='store_true', default=False)
51 | parser.add_argument('--print_freq', default=10, type=int)
52 | parser.add_argument('--exp_name', default=None, type=str)
53 |
54 | # ----------------------
55 | # INIT
56 | # ----------------------
57 | args = parser.parse_args()
58 | args = get_class_splits(args)
59 |
60 | args.num_labeled_classes = len(args.train_classes)
61 | args.num_unlabeled_classes = len(args.unlabeled_classes)
62 |
63 |     if os.environ.get("LOCAL_RANK") is not None:
64 | args.local_rank = int(os.environ["LOCAL_RANK"])
65 |
66 | return args
67 |
68 |
69 | def main(args):
70 | # ----------------------
71 | # BASE MODEL
72 | # ----------------------
73 | args.interpolation = 3
74 | args.crop_pct = 0.875
75 |
76 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
77 |
78 | if args.warmup_model_dir is not None:
79 | if dist.get_rank() == 0:
80 | args.logger.info(f'Loading weights from {args.warmup_model_dir}')
81 | backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
82 |
83 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model
84 | args.image_size = 224
85 | args.feat_dim = 768
86 | args.num_mlp_layers = 3
87 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
88 |
89 | # ----------------------
90 | # HOW MUCH OF BASE MODEL TO FINETUNE
91 | # ----------------------
92 | for m in backbone.parameters():
93 | m.requires_grad = False
94 |
95 | # Only finetune layers from block 'args.grad_from_block' onwards
96 | for name, m in backbone.named_parameters():
97 | if 'block' in name:
98 | block_num = int(name.split('.')[1])
99 | if block_num >= args.grad_from_block:
100 | m.requires_grad = True
101 |
102 | if dist.get_rank() == 0:
103 |         args.logger.info('model built')
104 |
105 | # --------------------
106 | # CONTRASTIVE TRANSFORM
107 | # --------------------
108 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
109 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
110 | # --------------------
111 | # DATASETS
112 | # --------------------
113 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name,
114 | train_transform,
115 | test_transform,
116 | args)
117 |
118 | # --------------------
119 | # SAMPLER
120 | # Sampler which balances labelled and unlabelled examples in each batch
121 | # --------------------
122 | label_len = len(train_dataset.labelled_dataset)
123 | unlabelled_len = len(train_dataset.unlabelled_dataset)
124 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
125 | sample_weights = torch.DoubleTensor(sample_weights)
126 | train_sampler = DistributedWeightedSampler(train_dataset, sample_weights, num_samples=len(train_dataset))
127 | unlabelled_train_sampler = torch.utils.data.distributed.DistributedSampler(unlabelled_train_examples_test)
128 | # test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
129 | # --------------------
130 | # DATALOADERS
131 | # --------------------
132 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
133 | sampler=train_sampler, drop_last=True, pin_memory=True)
134 | unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, batch_size=256,
135 | shuffle=False, sampler=unlabelled_train_sampler, pin_memory=False)
136 | # test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=256,
137 | # shuffle=False, sampler=test_sampler, pin_memory=False)
138 |
139 | # ----------------------
140 | # PROJECTION HEAD
141 | # ----------------------
142 | projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers)
143 | model = nn.Sequential(backbone, projector).cuda()
144 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
145 |
146 | params_groups = get_params_groups(model)
147 | optimizer = torch.optim.SGD(
148 | params_groups,
149 | lr=args.lr * (args.batch_size * dist.get_world_size() / 128), # linear scaling rule
150 | momentum=args.momentum,
151 | weight_decay=args.weight_decay
152 | )
153 |
154 | fp16_scaler = None
155 | if args.fp16:
156 | fp16_scaler = torch.cuda.amp.GradScaler()
157 |
158 | exp_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
159 | optimizer,
160 | T_max=args.epochs,
161 | eta_min=args.lr * (args.batch_size * dist.get_world_size() / 128) * 1e-3,
162 | )
163 |
164 | cluster_criterion = DistillLoss(
165 | args.warmup_teacher_temp_epochs,
166 | args.epochs,
167 | args.n_views,
168 | args.warmup_teacher_temp,
169 | args.teacher_temp,
170 | )
171 |
172 | # # inductive
173 | # best_test_acc_lab = 0
174 | # # transductive
175 | # best_train_acc_lab = 0
176 | # best_train_acc_ubl = 0
177 | # best_train_acc_all = 0
178 |
179 | for epoch in range(args.epochs):
180 | train_sampler.set_epoch(epoch)
181 | train(model, train_loader, optimizer, fp16_scaler, exp_lr_scheduler, cluster_criterion, epoch, args)
182 |
183 | unlabelled_train_sampler.set_epoch(epoch)
184 | # test_sampler.set_epoch(epoch)
185 | if dist.get_rank() == 0:
186 | args.logger.info('Testing on unlabelled examples in the training data...')
187 | all_acc, old_acc, new_acc = test(model, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args)
188 | # if dist.get_rank() == 0:
189 | # args.logger.info('Testing on disjoint test set...')
190 | # all_acc_test, old_acc_test, new_acc_test = test(model, test_loader, epoch=epoch, save_name='Test ACC', args=args)
191 |
192 | if dist.get_rank() == 0:
193 | args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
194 | # args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))
195 |
196 | save_dict = {
197 | 'model': model.state_dict(),
198 | 'optimizer': optimizer.state_dict(),
199 | 'epoch': epoch + 1,
200 | }
201 |
202 | torch.save(save_dict, args.model_path)
203 | args.logger.info("model saved to {}.".format(args.model_path))
204 |
205 | # if old_acc_test > best_test_acc_lab and dist.get_rank() == 0:
206 | # args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
207 | # args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
208 | #
209 | # torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
210 | # args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
211 | #
212 | # # inductive
213 | # best_test_acc_lab = old_acc_test
214 | # # transductive
215 | # best_train_acc_lab = old_acc
216 | # best_train_acc_ubl = new_acc
217 | # best_train_acc_all = all_acc
218 | #
219 | # if dist.get_rank() == 0:
220 | # args.logger.info(f'Exp Name: {args.exp_name}')
221 | # args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')
222 |
223 |
224 | def train(student, train_loader, optimizer, scaler, scheduler, cluster_criterion, epoch, args):
225 | loss_record = AverageMeter()
226 |
227 | student.train()
228 | for batch_idx, batch in enumerate(train_loader):
229 | images, class_labels, uq_idxs, mask_lab = batch
230 | mask_lab = mask_lab[:, 0]
231 |
232 | class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool()
233 | images = torch.cat(images, dim=0).cuda(non_blocking=True)
234 |
235 | with torch.cuda.amp.autocast(scaler is not None):
236 | student_proj, student_out = student(images)
237 | teacher_out = student_out.detach()
238 |
239 | # clustering, sup
240 | sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0)
241 | sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0)
242 | cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels)
243 |
244 | # clustering, unsup
245 | cluster_loss = cluster_criterion(student_out, teacher_out, epoch)
246 | avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0)
247 | me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
248 | cluster_loss += args.memax_weight * me_max_loss
249 |
250 |             # representation learning, unsup
251 | contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj)
252 | contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels)
253 |
254 | # representation learning, sup
255 | student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1)
256 | student_proj = torch.nn.functional.normalize(student_proj, dim=-1)
257 | sup_con_labels = class_labels[mask_lab]
258 | sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels)
259 |
260 | pstr = ''
261 | pstr += f'cls_loss: {cls_loss.item():.4f} '
262 | pstr += f'cluster_loss: {cluster_loss.item():.4f} '
263 | pstr += f'sup_con_loss: {sup_con_loss.item():.4f} '
264 | pstr += f'contrastive_loss: {contrastive_loss.item():.4f} '
265 |
266 | loss = 0
267 | loss += (1 - args.sup_weight) * cluster_loss + args.sup_weight * cls_loss
268 | loss += (1 - args.sup_weight) * contrastive_loss + args.sup_weight * sup_con_loss
269 |
270 |             # Record loss and take an optimiser step
271 | loss_record.update(loss.item(), class_labels.size(0))
272 | optimizer.zero_grad()
273 | if scaler is None:
274 | loss.backward()
275 | optimizer.step()
276 | else:
277 | scaler.scale(loss).backward()
278 | scaler.step(optimizer)
279 | scaler.update()
280 |
281 | if batch_idx % args.print_freq == 0 and dist.get_rank() == 0:
282 | args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}'
283 | .format(epoch, batch_idx, len(train_loader), loss.item(), pstr))
284 | # Step schedule
285 | scheduler.step()
286 |
287 | if dist.get_rank() == 0:
288 | args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg))
289 |
290 |
291 | def test(model, test_loader, epoch, save_name, args):
292 |
293 | model.eval()
294 |
295 | preds, targets = [], []
296 | mask = np.array([])
297 | for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)):
298 | images = images.cuda(non_blocking=True)
299 | with torch.no_grad():
300 | _, logits = model(images)
301 | preds.append(logits.argmax(1).cpu().numpy())
302 | targets.append(label.cpu().numpy())
303 | mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) else False for x in label]))
304 |
305 | preds = np.concatenate(preds)
306 | targets = np.concatenate(targets)
307 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask,
308 | T=epoch, eval_funcs=args.eval_funcs, save_name=save_name,
309 | args=args)
310 |
311 | return all_acc, old_acc, new_acc
312 |
313 |
314 | if __name__ == '__main__':
315 | args = get_parser()
316 |
317 | torch.cuda.set_device(args.local_rank)
318 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
319 | cudnn.benchmark = True
320 |
321 | if dist.get_rank() == 0:
322 | init_experiment(args, runner_name=['simgcd'])
323 | args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
324 |
325 | main(args)
--------------------------------------------------------------------------------
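train_mp.py reads LOCAL_RANK from the environment and initialises the process
group with init_method='env://', which matches a torchrun launch
(`torchrun --nproc_per_node=<gpus> train_mp.py <flags>`). A minimal sketch of
the variables such a launcher exports, useful only for debugging the script as
a world-size-1 job without a launcher (address and port are placeholders):

    import os

    # Hypothetical single-process stand-in for what torchrun would export.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')
    os.environ.setdefault('LOCAL_RANK', '0')
--------------------------------------------------------------------------------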