├── data
    ├── __init__.py
    ├── augmentations.py
    ├── get_datasets.py
    ├── inatloc.py
    ├── imagenet.py
    └── openimages.py
├── models
    ├── __init__.py
    ├── resnet.py
    └── builder.py
├── methods
    ├── __init__.py
    ├── estimate_k.py
    ├── gcam.py
    ├── metric.py
    └── contrastive_co_training.py
├── project_utils
    ├── __init__.py
    ├── cluster_and_log_utils.py
    └── utils.py
├── method.png
├── bash_scripts
    ├── gcam_inatloc.sh
    ├── gcam_imagenet.sh
    ├── gcam_openimages.sh
    ├── estimate_k.sh
    ├── train_inatloc.sh
    ├── train_imagenet.sh
    └── train_openimages.sh
├── config.py
└── README.md
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/methods/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/project_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryylcc/OWSOL/HEAD/method.png
--------------------------------------------------------------------------------
/bash_scripts/gcam_inatloc.sh:
--------------------------------------------------------------------------------
1 | hostname
2 | nvidia-smi
3 |
4 | export CUDA_VISIBLE_DEVICES=3
5 |
6 | python -m methods.gcam \
7 |     --batch_size 128 \
8 |     --dataset_name 'iNatLoc' \
9 |     --model_path './save/iNatLoc/lr0.0005_scl0.5_mcl1.0_mc5_e10/checkpoint_0009.pth.tar' \
10 |     --partitions 'all' 'known' 'nov_s' 'nov_d'
--------------------------------------------------------------------------------
/bash_scripts/gcam_imagenet.sh:
--------------------------------------------------------------------------------
1 | hostname
2 | nvidia-smi
3 |
4 | export CUDA_VISIBLE_DEVICES=3
5 |
6 | python -m methods.gcam \
7 |     --batch_size 128 \
8 |     --dataset_name 'ImageNet' \
9 |     --model_path './save/ImageNet/lr0.0005_scl0.5_mcl1.0_mc5_e10/checkpoint_0009.pth.tar' \
10 |     --partitions 'all' 'known' 'nov_s' 'nov_d'
--------------------------------------------------------------------------------
/bash_scripts/gcam_openimages.sh:
--------------------------------------------------------------------------------
1 | hostname
2 | nvidia-smi
3 |
4 | export CUDA_VISIBLE_DEVICES=3
5 |
6 | python -m methods.gcam \
7 |     --batch_size 128 \
8 |     --dataset_name 'OpenImages' \
9 |     --model_path './save/OpenImages/lr0.0005_scl0.5_mcl1.0_mc5_e10/checkpoint_0009.pth.tar' \
10 |     --partitions 'all' 'known' 'nov_d'
11 |
--------------------------------------------------------------------------------
/bash_scripts/estimate_k.sh:
--------------------------------------------------------------------------------
1 | nvidia-smi
2 | hostname
3 | export CUDA_VISIBLE_DEVICES=6,7
4 | # Get unique log file
5 | SAVE_DIR=/raid/zhaochuan/master/NCL/code/OWSOL/log/estimate_k/
6 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
7 | EXP_NUM=$((${EXP_NUM}+1))
8 | echo $EXP_NUM
9 |
10 | python -m methods.estimate_k \
11 |     --max_classes 5000 \
12 |     --batch_size 64 \
13 |     --dataset_name 'iNatLoc' \
14 |     --search_mode 'brent' \
15 |     > ${SAVE_DIR}logfile_${EXP_NUM}.log
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -----------------
2 | # DATASET ROOTS
3 | # -----------------
4 |
5 |
6 | imagenet_root = '../dataset/ILSVRC'
7 | imagenet_meta = '../metadata/ImageNet'
8 | imagenet_sup_queue_path = '../metadata/ImageNet/sup_queue.pkl'
9 |
10 | inatloc_root = '../dataset/iNatLoc'
11 | inatloc_meta = '../metadata/iNatLoc'
12 | inatloc_sup_queue_path = '../metadata/iNatLoc/scl_queue.pkl'
13 |
14 | openimages_root = '../dataset/OpenImages'
15 | openimages_meta = '../metadata/OpenImages'
16 | openimages_sup_queue_path = '../metadata/OpenImages/scl_queue.pkl'
17 |
18 |
19 | # -----------------
20 | # CKPT PATHS
21 | # -----------------
22 | pretrained_path = '../ckpt/moco_v2_200ep_pretrain.pth.tar'
23 |
--------------------------------------------------------------------------------
/bash_scripts/train_inatloc.sh:
--------------------------------------------------------------------------------
1 | hostname
2 | nvidia-smi
3 | export CUDA_VISIBLE_DEVICES=4,5,6,7
4 |
5 | # Get unique log file
6 | SAVE_DIR='/raid/zhaochuan/master/NCL/code/OWSOL/log/iNatLoc/'
7 | mkdir -p ${SAVE_DIR}
8 | LR='0.001'
9 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
10 | EXP_NUM=$((${EXP_NUM}+1))
11 | echo $EXP_NUM
12 |
13 | python -m methods.contrastive_co_training \
14 |     --dataset_name 'iNatLoc' \
15 |     -b 128 \
16 |     --lr ${LR} \
17 |     --epochs 10 \
18 |     --dist_url 'tcp://localhost:10001' \
19 |     --multiprocessing_distributed --world_size 1 --rank 0 \
20 |     --num_cluster 5000 \
21 |     --mcl_k 4096 \
22 |     --num_multi_centroids 5 \
23 |     --scl_weight 0.5 \
24 |     --mcl_weight 1.0 \
25 |     > ${SAVE_DIR}logfile_${EXP_NUM}_${LR}.log
--------------------------------------------------------------------------------
/bash_scripts/train_imagenet.sh:
--------------------------------------------------------------------------------
1 | hostname
2 | nvidia-smi
3 | export CUDA_VISIBLE_DEVICES=4,5,6,7
4 |
5 | # Get unique log file
6 | SAVE_DIR='/raid/xjheng/zc/NCL/code/OWSOL/log/ImageNet/'
7 | mkdir -p ${SAVE_DIR}
8 | LR='0.003'
9 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
10 | EXP_NUM=$((${EXP_NUM}+1))
11 | echo $EXP_NUM
12 |
13 |
14 | python -m methods.contrastive_co_training \
15 |     --dataset_name 'ImageNet' \
16 |     -b 256 \
17 |     --lr ${LR} \
18 |     --epochs 10 \
19 |     --dist_url 'tcp://localhost:10001' \
20 |     --multiprocessing_distributed --world_size 1 --rank 0 \
21 |     --num_cluster 50000 \
22 |     --mcl_k 16384 \
23 |     --num_multi_centroids 5 \
24 |     --scl_weight 0.5 \
25 |     --mcl_weight 1.0 \
26 |     > ${SAVE_DIR}logfile_${EXP_NUM}_${LR}.log
27 |
28 |
--------------------------------------------------------------------------------
/bash_scripts/train_openimages.sh:
--------------------------------------------------------------------------------
1 | hostname
2 | nvidia-smi
3 | export CUDA_VISIBLE_DEVICES=4,5
4 | # Get unique log file
5 | SAVE_DIR='/home/zhaochuan/workspace/NCL/code/OWSOL/log/OpenImages/'
6 | mkdir -p ${SAVE_DIR}
7 | LR='0.001'
8 | EXP_NUM=$(ls ${SAVE_DIR} | wc -l)
9 | EXP_NUM=$((${EXP_NUM}+1))
10 | echo $EXP_NUM
11 |
12 |
13 | python -m methods.contrastive_co_training \
14 |     --dataset_name 'OpenImages' \
15 |     -b 64 \
16 |     --lr ${LR} \
17 |     --epochs 10 \
18 |     --dist_url 'tcp://localhost:10001' \
19 |     --multiprocessing_distributed --world_size 1 --rank 0 \
20 |     --num_cluster 2500 \
21 |     --mcl_k 2048 \
22 |     --num_multi_centroids 5 \
23 |     --scl_weight 0.5 \
24 |     --mcl_weight 1.0 \
25 |     > ${SAVE_DIR}logfile_${EXP_NUM}_${LR}.log
26 |
27 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OWSOL
2 | Code for our paper Open-World Weakly-Supervised Object Localization.
3 |
4 | ![](method.png)
5 |
6 |
7 |
8 | ## Dependencies
9 |
10 | - Python 3
11 | - PyTorch 1.7.1
12 | - OpenCV-Python
13 | - NumPy
14 | - SciPy
15 | - Matplotlib
16 | - faiss-gpu
17 | - munch
18 |
19 | ## Dataset
20 |
21 | ### ImageNet-1K
22 |
23 | The "train" and "val" splits of the original [ImageNet](http://www.image-net.org/) are treated as our `train` and `test` sets, and [ImageNetV2](https://github.com/modestyachts/ImageNetV2) is treated as our `val` set.
24 |
25 | Make sure your `dataset/ILSVRC` folder is structured as follows:
26 | ```
27 | ├── ILSVRC/
28 | |   ├── train/
29 | |   |   |── n01440764
30 | |   |   |── n01443537
31 | |   |   |── ...
32 | |   ├── val
33 | |   |   |── n01440764
34 | |   |   |── n01443537
35 | |   |   |── ...
36 | |   ├── val2
37 | |   |   |── 0
38 | |   |   |── 1
39 | |   |   └── ...
40 | ```
41 |
42 | ## Metadata
43 |
44 | You can download the dataset annotations from [Download metadata](https://pan.baidu.com/s/1QNsIb6UMn63J2XzHGSwONw?pwd=dqdf); the password is `dqdf`.
45 |
46 | Make sure your `metadata/ImageNet` folder is structured as follows:
47 |
48 | ```
49 | ├── ImageNet/
50 | |   ├── train/
51 | |   |   |── class_labels.txt
52 | |   |   |── image_ids.txt
53 | |   |   |── image_ids_labeled.txt
54 | |   ├── test
55 | |   |   |── class_labels.txt
56 | |   |   |── image_ids.txt
57 | |   |   |── image_sizes.txt
58 | |   |   |── localization.txt
59 | |   |   |── partitions.txt
60 | |   ├── val
61 | |   |   |── class_labels.txt
62 | |   |   |── image_ids.txt
63 | |   |   |── image_sizes.txt
64 | |   |   |── localization.txt
65 | |   |   |── partitions.txt
66 | ```
67 |
68 | For `test` and `val`, the partition of each image into **Known**, **Nov-S** and **Nov-D** is given in `partitions.txt`: **0**, **1** and **2** correspond to **Known**, **Nov-S** and **Nov-D**, respectively.
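
As a quick reference, here is a minimal sketch of reading `partitions.txt` into the three evaluation splits; the repository's own loader for this file is `get_partitions` in `data/imagenet.py` and `data/inatloc.py`:

```python
from collections import defaultdict

# 0 --> Known, 1 --> Nov-S, 2 --> Nov-D (names match the --partitions flags
# used by the gcam bash scripts)
PARTITION_NAMES = {0: 'known', 1: 'nov_s', 2: 'nov_d'}

def split_by_partition(partitions_file):
    """Group image ids by their partition label."""
    splits = defaultdict(list)
    with open(partitions_file) as f:
        for line in f:
            image_id, partition_string = line.strip('\n').split(',')
            splits[PARTITION_NAMES[int(partition_string)]].append(image_id)
    return splits

# e.g. splits = split_by_partition('metadata/ImageNet/test/partitions.txt')
```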
69 |
70 | ## Training
71 |
72 | Perform contrastive representation co-learning on the ImageNet-1K dataset:
73 |
74 | ```
75 | bash bash_scripts/train_imagenet.sh
76 | ```
77 |
78 | ## G-CAM
79 |
80 | Perform G-CAM on the ImageNet-1K dataset:
81 |
82 | ```
83 | bash bash_scripts/gcam_imagenet.sh
84 | ```
85 |
86 |
--------------------------------------------------------------------------------
/data/augmentations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 |
4 | def get_transform(transform_type='default', image_size=224, args=None):
5 |
6 |     # ImageNet, iNatLoc and OpenImages share the same augmentation recipe,
7 |     # so a single branch covers all three datasets.
8 |     if transform_type in ('ImageNet', 'iNatLoc', 'OpenImages'):
9 |
10 |         mean = (0.485, 0.456, 0.406)
11 |         std = (0.229, 0.224, 0.225)
12 |         interpolation = args.interpolation
13 |         crop_pct = args.crop_pct
14 |
15 |         train_transform = transforms.Compose([
16 |             transforms.Resize(int(image_size / crop_pct), interpolation),
17 |             transforms.RandomCrop(image_size),
18 |             transforms.RandomHorizontalFlip(p=0.5),
19 |             transforms.ColorJitter(),
20 |             transforms.ToTensor(),
21 |             transforms.Normalize(
22 |                 mean=torch.tensor(mean),
23 |                 std=torch.tensor(std))
24 |         ])
25 |
26 |         test_transform = transforms.Compose([
27 |             transforms.Resize(int(image_size / crop_pct), interpolation),
28 |             transforms.CenterCrop(image_size),
29 |             transforms.ToTensor(),
30 |             transforms.Normalize(
31 |                 mean=torch.tensor(mean),
32 |                 std=torch.tensor(std))
33 |         ])
34 |
35 |     return (train_transform, test_transform)
--------------------------------------------------------------------------------
/data/get_datasets.py:
--------------------------------------------------------------------------------
1 |
2 | from data.imagenet import get_imagenet_datasets, get_imagenet_datasets_cluster, get_imagenet_datasets_gcam, get_imagenet_datasets_estimate_k
3 | from data.inatloc import get_inatloc_datasets, get_inatloc_datasets_cluster, 
get_inatloc_datasets_gcam, get_inatloc_datasets_estimate_k 4 | from data.openimages import get_openimages_datasets, get_openimages_datasets_cluster, get_openimages_datasets_gcam, get_openimages_datasets_estimate_k 5 | 6 | 7 | 8 | get_dataset_funcs = { 9 | 'ImageNet': get_imagenet_datasets, 10 | 'iNatLoc': get_inatloc_datasets, 11 | 'OpenImages': get_openimages_datasets 12 | } 13 | 14 | get_dataset_cluster_funcs = { 15 | 'ImageNet': get_imagenet_datasets_cluster, 16 | 'iNatLoc': get_inatloc_datasets_cluster, 17 | 'OpenImages': get_openimages_datasets_cluster 18 | } 19 | 20 | get_dataset_gcam_funcs = { 21 | 'ImageNet': get_imagenet_datasets_gcam, 22 | 'iNatLoc': get_inatloc_datasets_gcam, 23 | 'OpenImages': get_openimages_datasets_gcam 24 | } 25 | 26 | get_dataset_estimate_k_funcs = { 27 | 'ImageNet': get_imagenet_datasets_estimate_k, 28 | 'iNatLoc': get_inatloc_datasets_estimate_k, 29 | 'OpenImages': get_openimages_datasets_estimate_k 30 | } 31 | 32 | def get_datasets(dataset_name, train_transform, test_transform): 33 | 34 | """ 35 | :return: train_dataset: labelled and unlabelled 36 | test_dataset, 37 | val_dataset, 38 | """ 39 | if dataset_name not in get_dataset_funcs.keys(): 40 | raise ValueError 41 | 42 | # Get datasets 43 | get_dataset_f = get_dataset_funcs[dataset_name] 44 | train_dataset, eval_dataset, val_dataset = get_dataset_f(train_transform=train_transform, test_transform=test_transform) 45 | 46 | return train_dataset, eval_dataset, val_dataset 47 | 48 | 49 | 50 | def get_datasets_cluster(dataset_name, test_transform): 51 | 52 | if dataset_name not in get_dataset_cluster_funcs.keys(): 53 | raise ValueError 54 | 55 | # Get datasets 56 | get_dataset_f = get_dataset_cluster_funcs[dataset_name] 57 | test_dataset = get_dataset_f(test_transform=test_transform) 58 | 59 | return test_dataset 60 | 61 | 62 | 63 | def get_datasets_gcam(dataset_name, test_transform, target_and_pred): 64 | 65 | if dataset_name not in get_dataset_gcam_funcs.keys(): 66 | raise ValueError 67 | 68 | # Get datasets 69 | get_dataset_f = get_dataset_gcam_funcs[dataset_name] 70 | test_dataset = get_dataset_f(test_transform=test_transform, target_and_pred=target_and_pred) 71 | 72 | return test_dataset 73 | 74 | 75 | 76 | def get_datasets_estimate_k(dataset_name, test_transform): 77 | 78 | if dataset_name not in get_dataset_estimate_k_funcs.keys(): 79 | raise ValueError 80 | 81 | # Get datasets 82 | get_dataset_f = get_dataset_estimate_k_funcs[dataset_name] 83 | val_dataset = get_dataset_f(test_transform=test_transform) 84 | 85 | return val_dataset 86 | 87 | 88 | 89 | def get_class_splits(args): 90 | 91 | # ------------- 92 | # GET CLASS SPLITS 93 | # ------------- 94 | if args.dataset_name == 'ImageNet': 95 | 96 | args.image_size = 224 97 | args.known_categories = range(500) 98 | args.novel_categories = range(500, 1000) 99 | args.nov_s_categories = range(500, 750) 100 | args.nov_d_categories = range(750, 1000) 101 | 102 | 103 | elif args.dataset_name == 'iNatLoc': 104 | args.image_size = 224 105 | args.known_categories = range(250) 106 | args.novel_categories = range(250, 500) 107 | args.nov_s_categories = range(250, 375) 108 | args.nov_d_categories = range(375, 500) 109 | 110 | 111 | elif args.dataset_name == 'OpenImages': 112 | 113 | args.image_size = 224 114 | args.known_categories = range(75) 115 | args.novel_categories = range(75, 150) 116 | 117 | else: 118 | 119 | raise NotImplementedError 120 | 121 | return args -------------------------------------------------------------------------------- 
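
The dispatch tables above give each dataset a common entry point: `get_class_splits` attaches the split definition to `args`, and `get_datasets` builds the train/cluster/val datasets from a pair of transforms. A minimal usage sketch, assuming an `args` namespace that carries the `interpolation` and `crop_pct` fields read by `get_transform` (the values below are illustrative placeholders, not settings taken from the paper):

```python
import argparse

from data.augmentations import get_transform
from data.get_datasets import get_class_splits, get_datasets

# Hypothetical settings; interpolation=3 is PIL bicubic in older torchvision.
args = argparse.Namespace(dataset_name='iNatLoc', interpolation=3, crop_pct=0.875)
args = get_class_splits(args)  # adds image_size and the known/novel category ranges

train_transform, test_transform = get_transform(
    args.dataset_name, image_size=args.image_size, args=args)
train_dataset, cluster_dataset, val_dataset = get_datasets(
    args.dataset_name, train_transform, test_transform)
```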
/project_utils/cluster_and_log_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import linear_sum_assignment as linear_assignment 3 | 4 | def cluster_acc(y_true, y_pred, return_ind=False): 5 | """ 6 | Calculate clustering accuracy. Require scikit-learn installed 7 | 8 | # Arguments 9 | y: true labels, numpy.array with shape `(n_samples,)` 10 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 11 | 12 | # Return 13 | accuracy, in [0,1] 14 | """ 15 | y_true = y_true.astype(int) 16 | assert y_pred.size == y_true.size 17 | D = max(y_pred.max(), y_true.max()) + 1 18 | w = np.zeros((D, D), dtype=int) 19 | for i in range(y_pred.size): 20 | w[y_pred[i], y_true[i]] += 1 21 | 22 | ind = linear_assignment(w.max() - w) 23 | ind = np.vstack(ind).T 24 | 25 | if return_ind: 26 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind, w 27 | else: 28 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 29 | 30 | 31 | def split_cluster_acc_v1(y_true, y_pred, mask): 32 | 33 | """ 34 | Evaluate clustering metrics on two subsets of data, as defined by the mask 'mask' 35 | (Mask usually corresponding to `known' and `novel' categories in OWSOL setting) 36 | :param targets: All ground truth labels 37 | :param preds: All predictions 38 | :param mask: Mask defining two subsets 39 | :return: 40 | """ 41 | 42 | mask = mask.astype(bool) 43 | y_true = y_true.astype(int) 44 | y_pred = y_pred.astype(int) 45 | weight = mask.mean() 46 | 47 | known_acc, ind_known, w_known = cluster_acc(y_true[mask], y_pred[mask], return_ind=True) 48 | novel_acc, ind_novel, w_novel = cluster_acc(y_true[~mask], y_pred[~mask], return_ind=True) 49 | total_acc = weight * known_acc + (1 - weight) * novel_acc 50 | 51 | # j:gt i:pred 52 | ind_map_known = {i: j for i, j in ind_known} 53 | ind_map_novel = {i: j for i, j in ind_novel} 54 | 55 | return total_acc, known_acc, novel_acc, ind_map_known, ind_map_novel 56 | 57 | def split_cluster_acc_v2(y_true, y_pred, mask): 58 | 59 | """ 60 | Calculate clustering accuracy. 
Require scikit-learn installed 61 | First compute linear assignment on all data, then look at how good the accuracy is on subsets 62 | 63 | # Arguments 64 | mask: Which instances come from known categories (True) and which ones come from novel categories (False) 65 | y: true labels, numpy.array with shape `(n_samples,)` 66 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 67 | 68 | # Return 69 | accuracy, in [0,1] 70 | """ 71 | y_true = y_true.astype(int) 72 | 73 | known_categories_gt = set(y_true[mask]) 74 | novel_categories_gt = set(y_true[~mask]) 75 | 76 | assert y_pred.size == y_true.size 77 | D = max(y_pred.max(), y_true.max()) + 1 78 | w = np.zeros((D, D), dtype=int) 79 | for i in range(y_pred.size): 80 | w[y_pred[i], y_true[i]] += 1 81 | 82 | # max - ori --> min 83 | ind = linear_assignment(w.max() - w) 84 | ind = np.vstack(ind).T 85 | 86 | # j:gt i:pred 87 | ind_map = {j: i for i, j in ind} 88 | ind_map_pre2gt = {i: j for i, j in ind} 89 | total_acc = sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 90 | 91 | known_acc = 0 92 | total_known_instances = 0 93 | for i in known_categories_gt: 94 | known_acc += w[ind_map[i], i] 95 | total_known_instances += sum(w[:, i]) 96 | known_acc /= total_known_instances 97 | 98 | novel_acc = 0 99 | total_novel_instances = 0 100 | for i in novel_categories_gt: 101 | novel_acc += w[ind_map[i], i] 102 | total_novel_instances += sum(w[:, i]) 103 | novel_acc /= total_novel_instances 104 | 105 | return total_acc, known_acc, novel_acc, ind_map_pre2gt 106 | 107 | 108 | 109 | def log_accs_from_preds(y_true, y_pred, mask): 110 | 111 | 112 | mask = mask.astype(bool) 113 | y_true = y_true.astype(int) 114 | y_pred = y_pred.astype(int) 115 | 116 | 117 | all_acc, known_acc, novel_acc, _, = split_cluster_acc_v2(y_true, y_pred, mask) 118 | to_return = (all_acc, known_acc, novel_acc) 119 | 120 | return to_return 121 | 122 | 123 | def log_accs_from_preds_infer(y_true, y_pred, mask): 124 | 125 | 126 | mask = mask.astype(bool) 127 | y_true = y_true.astype(int) 128 | y_pred = y_pred.astype(int) 129 | 130 | all_acc, known_acc, novel_acc, ind_map_pre2gt = split_cluster_acc_v2(y_true, y_pred, mask) 131 | to_return = (all_acc, known_acc, novel_acc, ind_map_pre2gt) 132 | 133 | 134 | return to_return -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | padding=dilation, bias=False, dilation=dilation) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = nn.BatchNorm2d(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = 
self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class ResNet(nn.Module): 47 | 48 | def __init__(self, block, layers, num_classes, strides=(2, 2, 2, 1), dilations=(1, 1, 1, 1), large_feature_map=True): 49 | 50 | stride_l3 = 1 if large_feature_map else 2 51 | self.inplanes = 64 52 | super(ResNet, self).__init__() 53 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=strides[0], padding=3, 54 | bias=False) 55 | self.bn1 = nn.BatchNorm2d(64) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 58 | 59 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1, dilation=dilations[0]) 60 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1]) 61 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_l3, dilation=dilations[2]) 62 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3]) 63 | 64 | 65 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 66 | self.fc = nn.Linear(512 * block.expansion, num_classes) 67 | # self.fc = nn.Sequential(nn.Linear(512 * block.expansion, 512 * block.expansion), nn.ReLU(), nn.Linear(512 * block.expansion, num_classes)) 68 | 69 | 70 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 71 | downsample = None 72 | if stride != 1 or self.inplanes != planes * block.expansion: 73 | downsample = nn.Sequential( 74 | nn.Conv2d(self.inplanes, planes * block.expansion, 75 | kernel_size=1, stride=stride, bias=False), 76 | nn.BatchNorm2d(planes * block.expansion), 77 | ) 78 | 79 | layers = [block(self.inplanes, planes, stride, downsample, dilation=1)] 80 | self.inplanes = planes * block.expansion 81 | for i in range(1, blocks): 82 | layers.append(block(self.inplanes, planes, dilation=dilation)) 83 | 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | x = self.conv1(x) 88 | x = self.bn1(x) 89 | x = self.relu(x) 90 | x = self.maxpool(x) 91 | 92 | x = self.layer1(x) 93 | x = self.layer2(x) 94 | x = self.layer3(x) 95 | x = self.layer4(x) 96 | feature_map = x 97 | # feature_map = x.detach().clone() 98 | 99 | x = self.avgpool(x) 100 | x = x.view(x.size(0), -1) 101 | feature = x 102 | logits = self.fc(x) 103 | 104 | 105 | return {'feature': feature, 'feature_map': feature_map, 'logits': logits} 106 | 107 | 108 | 109 | def resnet50(num_classes, pretrained, **kwargs): 110 | 111 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes, **kwargs) 112 | if pretrained != None: 113 | state_dict = {} 114 | old_state_dict = torch.load(pretrained, map_location='cpu')['state_dict'] 115 | for key in old_state_dict.keys(): 116 | if key.startswith('module.encoder_q'): 117 | print(key) 118 | new_key = key.split('encoder_q.')[1] 119 | state_dict[new_key] = old_state_dict[key] 120 | model.load_state_dict(state_dict, strict=False) 121 | print("resnet50 pretrained initialized") 122 | 123 | return model 124 | 125 | -------------------------------------------------------------------------------- /project_utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020-present NAVER Corp. 
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | this software and associated documentation files (the "Software"), to deal in
6 | the Software without restriction, including without limitation the rights to
7 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8 | the Software, and to permit persons to whom the Software is furnished to do so,
9 | subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in all
12 | copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
17 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20 | """
21 |
22 | import json
23 | import numpy as np
24 | import os
25 | import sys
26 |
27 |
28 | class Logger(object):
29 |     """Log stdout messages."""
30 |
31 |     def __init__(self, outfile):
32 |         self.terminal = sys.stdout
33 |         self.log = open(outfile, "w")
34 |         sys.stdout = self
35 |
36 |     def write(self, message):
37 |         self.terminal.write(message)
38 |         self.log.write(message)
39 |
40 |     def flush(self):
41 |         self.terminal.flush()
42 |
43 |
44 | def t2n(t):
45 |     return t.detach().cpu().numpy().astype(float)
46 |
47 |
48 | def check_scoremap_validity(scoremap):
49 |     if not isinstance(scoremap, np.ndarray):
50 |         raise TypeError("Scoremap must be a numpy array; it is {}."
51 |                         .format(type(scoremap)))
52 |     if scoremap.dtype != float:
53 |         raise TypeError("Scoremap must be of float type; it is of {} type."
54 |                         .format(scoremap.dtype))
55 |     if len(scoremap.shape) != 2:
56 |         raise ValueError("Scoremap must be a 2D array; it is {}D."
57 |                          .format(len(scoremap.shape)))
58 |     if np.isnan(scoremap).any():
59 |         raise ValueError("Scoremap must not contain nans.")
60 |     if (scoremap > 1).any() or (scoremap < 0).any():
61 |         raise ValueError("Scoremap must be in range [0, 1]."
62 |                          "scoremap.min()={}, scoremap.max()={}."
63 |                          .format(scoremap.min(), scoremap.max()))
64 |
65 |
66 | def string_contains_any(string, substring_list):
67 |     for substring in substring_list:
68 |         if substring in string:
69 |             return True
70 |     return False
71 |
72 |
73 | class Reporter(object):
74 |     def __init__(self, reporter_log_root, epoch):
75 |         self.log_file = os.path.join(reporter_log_root, str(epoch))
76 |         self.epoch = epoch
77 |         self.report_dict = {
78 |             'summary': True,
79 |             'step': self.epoch,
80 |         }
81 |
82 |     def add(self, key, val):
83 |         self.report_dict.update({key: val})
84 |
85 |     def write(self):
86 |         log_file = self.log_file
87 |         while os.path.isfile(log_file):
88 |             log_file += '_'
89 |         with open(log_file, 'w') as f:
90 |             f.write(json.dumps(self.report_dict))
91 |
92 |
93 | def check_box_convention(boxes, convention):
94 |     """
95 |     Args:
96 |         boxes: numpy.ndarray(dtype=np.int or np.float, shape=(num_boxes, 4))
97 |         convention: string. One of ['x0y0x1y1', 'xywh'].
98 |     Raises:
99 |         RuntimeError if box does not meet the convention.
100 |     """
101 |     if (boxes < 0).any():
102 |         raise RuntimeError("Box coordinates must be non-negative.")
103 |
104 |     if len(boxes.shape) == 1:
105 |         boxes = np.expand_dims(boxes, 0)
106 |     elif len(boxes.shape) != 2:
107 |         raise RuntimeError("Box array must have dimension (4) or "
108 |                            "(num_boxes, 4).")
109 |
110 |     if boxes.shape[1] != 4:
111 |         raise RuntimeError("Box array must have dimension (4) or "
112 |                            "(num_boxes, 4).")
113 |
114 |     if convention == 'x0y0x1y1':
115 |         widths = boxes[:, 2] - boxes[:, 0]
116 |         heights = boxes[:, 3] - boxes[:, 1]
117 |     elif convention == 'xywh':
118 |         widths = boxes[:, 2]
119 |         heights = boxes[:, 3]
120 |     else:
121 |         raise ValueError("Unknown convention {}.".format(convention))
122 |
123 |     if (widths < 0).any() or (heights < 0).any():
124 |         raise RuntimeError("Boxes do not follow the {} convention."
125 |                            .format(convention))
126 |
--------------------------------------------------------------------------------
/methods/estimate_k.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | from scipy.optimize import minimize_scalar
5 | from functools import partial
6 |
7 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
8 | from sklearn.metrics import adjusted_rand_score as ari_score
9 | from sklearn.cluster import KMeans
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torchvision import transforms
14 | from torch.utils.data import DataLoader
15 |
16 | from models.resnet import resnet50
17 | from data.get_datasets import get_class_splits, get_datasets_estimate_k
18 | from project_utils.cluster_and_log_utils import cluster_acc
19 |
20 |
21 |
22 | # TODO: Debug
23 | import warnings
24 | warnings.filterwarnings("ignore", category=DeprecationWarning)
25 |
26 | def test_kmeans(K, all_feats=None, targets=None, mask_cls=None, args=None, verbose=False):
27 |
28 |     """
29 |     In this case, the val loader needs to have known and novel categories
30 |     """
31 |
32 |     if K is None:
33 |         K = args.num_known_categories + args.num_novel_categories
34 |
35 |
36 |     print('Fitting K-Means...')
37 |     kmeans = KMeans(n_clusters=K, random_state=0).fit(all_feats)
38 |     preds = kmeans.labels_
39 |
40 |     # -----------------------
41 |     # EVALUATE
42 |     # -----------------------
43 |     mask = mask_cls
44 |
45 |
46 |     known_acc, known_nmi, known_ari = cluster_acc(targets.astype(int)[mask], preds.astype(int)[mask]), \
47 |                                       nmi_score(targets[mask], preds[mask]), \
48 |                                       ari_score(targets[mask], preds[mask])
49 |
50 |     novel_acc, novel_nmi, novel_ari = cluster_acc(targets.astype(int)[~mask],
51 |                                                   preds.astype(int)[~mask]), \
52 |                                       nmi_score(targets[~mask], preds[~mask]), \
53 |                                       ari_score(targets[~mask], preds[~mask])
54 |
55 |     if verbose:
56 |         print(f'K = {K}')
57 |         print('Known Categories acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(known_acc, known_nmi,
58 |                                                                            known_ari))
59 |         print('Novel Categories acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(novel_acc, novel_nmi,
60 |                                                                            novel_ari))
61 |
62 |
63 |     return known_acc
64 |
65 |
66 |
67 | def test_kmeans_for_scipy(K, all_feats=None, targets=None, mask_cls=None, args=None, verbose=False):
68 |
69 |     """
70 |     In this case, the val loader needs to have known and novel categories
71 |     """
72 |
73 |     K = int(K)
74 |
75 |     print(f'Fitting K-Means for K = {K}...')
76 |     kmeans = KMeans(n_clusters=K, random_state=0).fit(all_feats)
77 |     preds = kmeans.labels_
78 |
79 |     # -----------------------
80 |     # EVALUATE
81 |     # -----------------------
82 |     mask = 
mask_cls 83 | 84 | 85 | known_acc, known_nmi, known_ari = cluster_acc(targets.astype(int)[mask], preds.astype(int)[mask]), \ 86 | nmi_score(targets[mask], preds[mask]), \ 87 | ari_score(targets[mask], preds[mask]) 88 | 89 | novel_acc, novel_nmi, novel_ari = cluster_acc(targets.astype(int)[~mask], 90 | preds.astype(int)[~mask]), \ 91 | nmi_score(targets[~mask], preds[~mask]), \ 92 | ari_score(targets[~mask], preds[~mask]) 93 | 94 | print(f'K = {K}') 95 | print('Known Categories acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(known_acc, known_nmi, 96 | known_ari)) 97 | print('Novel Categories acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(novel_acc, novel_nmi, 98 | novel_ari)) 99 | 100 | return -known_acc 101 | 102 | 103 | def binary_search(all_feats, targets, mask_cls, args): 104 | 105 | min_classes = args.num_known_categories 106 | 107 | # Iter 0 108 | big_k = args.max_classes 109 | small_k = min_classes 110 | diff = big_k - small_k 111 | middle_k = int(0.5 * diff + small_k) 112 | 113 | known_acc_big = test_kmeans(big_k, all_feats, targets, mask_cls, args) 114 | known_acc_small = test_kmeans(small_k, all_feats, targets, mask_cls, args) 115 | known_acc_middle = test_kmeans(middle_k, all_feats, targets, mask_cls, args) 116 | 117 | print(f'Iter 0: BigK {big_k}, Acc {known_acc_big:.4f} | MiddleK {middle_k}, Acc {known_acc_middle:.4f} | SmallK {small_k}, Acc {known_acc_small:.4f} ') 118 | all_accs = [known_acc_small, known_acc_middle, known_acc_big] 119 | best_acc_so_far = np.max(all_accs) 120 | best_acc_at_k = np.array([small_k, middle_k, big_k])[np.argmax(all_accs)] 121 | print(f'Best Acc so far {best_acc_so_far:.4f} at K {best_acc_at_k}') 122 | 123 | for i in range(1, int(np.log2(diff)) + 1): 124 | 125 | if known_acc_big > known_acc_small: 126 | 127 | best_acc = max(known_acc_middle, known_acc_big) 128 | 129 | small_k = middle_k 130 | known_acc_small = known_acc_middle 131 | diff = big_k - small_k 132 | middle_k = int(0.5 * diff + small_k) 133 | 134 | else: 135 | 136 | best_acc = max(known_acc_middle, known_acc_small) 137 | big_k = middle_k 138 | 139 | diff = big_k - small_k 140 | middle_k = int(0.5 * diff + small_k) 141 | known_acc_big = known_acc_middle 142 | 143 | known_acc_middle = test_kmeans(middle_k, all_feats, targets, mask_cls, args) 144 | 145 | print(f'Iter {i}: BigK {big_k}, Acc {known_acc_big:.4f} | MiddleK {middle_k}, Acc {known_acc_middle:.4f} | SmallK {small_k}, Acc {known_acc_small:.4f} ') 146 | all_accs = [known_acc_small, known_acc_middle, known_acc_big] 147 | best_acc_so_far = np.max(all_accs) 148 | best_acc_at_k = np.array([small_k, middle_k, big_k])[np.argmax(all_accs)] 149 | print(f'Best Acc so far {best_acc_so_far:.4f} at K {best_acc_at_k}') 150 | 151 | 152 | def scipy_optimise(all_feats, targets, mask_cls, args): 153 | 154 | small_k = args.num_known_categories 155 | big_k = args.max_classes 156 | 157 | test_k_means_partial = partial(test_kmeans_for_scipy, all_feats=all_feats, targets=targets, mask_cls=mask_cls, args=args, verbose=True) 158 | res = minimize_scalar(test_k_means_partial, bounds=(small_k, big_k), method='bounded', options={'disp': True}) 159 | print(f'Optimal K is {res.x}') 160 | 161 | 162 | if __name__ == "__main__": 163 | 164 | parser = argparse.ArgumentParser( 165 | description='estimate k', 166 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 167 | parser.add_argument('--batch_size', default=128, type=int) 168 | parser.add_argument('--workers', default=16, type=int) 169 | parser.add_argument('--max_classes', default=1000, type=int) 170 | 
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50')
171 |     parser.add_argument('--search_mode', type=str, default='brent', help='Mode for black box optimisation')
172 |     parser.add_argument('--dataset_name', type=str, default='iNatLoc', help='options: ImageNet, iNatLoc, OpenImages')
173 |     parser.add_argument('--model_path', type=str, default='/data/zhaochuan/NCL/code/GWSOL/save_imagenet/checkpoint_0009.pth.tar')
174 |
175 |     # ----------------------
176 |     # INIT
177 |     # ----------------------
178 |     args = parser.parse_args()
179 |     args = get_class_splits(args)
180 |     args.num_known_categories = len(args.known_categories)
181 |     args.num_novel_categories = len(args.novel_categories)
182 |     print(args)
183 |
184 |
185 |     print("=> creating model '{}'".format(args.arch))
186 |     model = resnet50(num_classes=1000, pretrained=None)  # checkpoint weights are loaded below
187 |
188 |     # load pretrained
189 |     state_dict = {}
190 |     old_state_dict = torch.load(args.model_path, map_location='cpu')['state_dict']
191 |     for key in old_state_dict.keys():
192 |         if key.startswith('module.encoder_q'):
193 |             print(key)
194 |             new_key = key.split('encoder_q.')[1]
195 |             state_dict[new_key] = old_state_dict[key]
196 |     model.load_state_dict(state_dict, strict=False)
197 |     model = nn.DataParallel(model)
198 |     for name, parameter in model.named_parameters():
199 |         print(name)
200 |         print(parameter)
201 |     model.cuda()
202 |     model.eval()
203 |
204 |     # --------------------
205 |     # DATASETS
206 |     # --------------------
207 |     print('Building datasets...')
208 |
209 |     test_transform = transforms.Compose([
210 |         transforms.Resize((224, 224)),
211 |         transforms.ToTensor(),
212 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
213 |     ])
214 |
215 |     val_dataset = get_datasets_estimate_k(args.dataset_name, test_transform)
216 |     val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.workers, shuffle=False)
217 |
218 |
219 |     all_feats = []
220 |     targets = np.array([])
221 |     mask_cls = np.array([])  # From all the data, which instances belong to seen classes
222 |
223 |     print('Collating features...')
224 |     # First extract all features
225 |     for batch_idx, (images, label, img_ids) in enumerate(val_loader):
226 |
227 |         with torch.no_grad():  # feature extraction only; no gradients needed
228 |             feats = model(images.cuda())['feature']
229 |
230 |         feats = torch.nn.functional.normalize(feats, dim=-1)
231 |
232 |         all_feats.append(feats.detach().cpu().numpy())
233 |
234 |         targets = np.append(targets, label.cpu().numpy())
235 |         mask_cls = np.append(mask_cls, np.array([True if x.item() in args.known_categories
236 |                                                  else False for x in label]))
237 |
238 |     # -----------------------
239 |     # K-MEANS
240 |     # -----------------------
241 |     mask_cls = mask_cls.astype(bool)
242 |     all_feats = np.concatenate(all_feats)
243 |
244 |     print('Testing on the val set...')
245 |     if args.search_mode == 'brent':
246 |         print("Optimising with Brent's algorithm")
247 |         scipy_optimise(all_feats=all_feats, targets=targets, mask_cls=mask_cls, args=args)
248 |     else:
249 |         binary_search(all_feats=all_feats, targets=targets, mask_cls=mask_cls, args=args)
--------------------------------------------------------------------------------
/models/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | from random import sample 5 | 6 | 7 | class MoCo(nn.Module): 8 | """ 9 | Build a MoCo model with: a query encoder, a key encoder, and a queue 10 | https://arxiv.org/abs/1911.05722 11 | """ 12 | def __init__(self, base_encoder, pretrained_path, scl_queue, num_labeled_classes=500, num_multi_centers=5, dim=128, K=16384, m=0.999, T=0.07, step=12): 13 | """ 14 | dim: feature dimension (default: 128) 15 | K: queue size; number of negative keys (default: 65536) 16 | m: moco momentum of updating key encoder (default: 0.999) 17 | T: softmax temperature (default: 0.07) 18 | """ 19 | super(MoCo, self).__init__() 20 | 21 | self.K = K 22 | self.m = m 23 | self.T = T 24 | self.step = step 25 | self.num_labeled_classes = num_labeled_classes 26 | self.num_multi_centers = num_multi_centers 27 | self.scl_queue_loc = torch.arange(0, num_labeled_classes * step, step=step, dtype=torch.long) 28 | 29 | # create the encoders 30 | # num_classes is the output fc dimension 31 | self.encoder_q = base_encoder(num_classes=dim, pretrained=pretrained_path) 32 | self.encoder_k = base_encoder(num_classes=dim, pretrained=pretrained_path) 33 | 34 | dim_mlp = self.encoder_q.fc.weight.shape[1] 35 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 36 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 37 | 38 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 39 | param_k.data.copy_(param_q.data) # initialize 40 | param_k.requires_grad = False # not update by gradient 41 | 42 | 43 | # create the queue for supervised contrastive learning 44 | self.register_buffer("scl_queue", scl_queue) 45 | self.scl_queue = nn.functional.normalize(self.scl_queue, dim=0) 46 | self.register_buffer("scl_queue_ptr", torch.arange(0, num_labeled_classes * step, step=step, dtype=torch.long)) 47 | 48 | 49 | 50 | @torch.no_grad() 51 | def _momentum_update_key_encoder(self): 52 | """ 53 | Momentum update of the key encoder 54 | """ 55 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 56 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 57 | 58 | @torch.no_grad() 59 | def _scl_dequeue_and_enqueue(self, keys, labels): 60 | # gather keys before updating queue 61 | keys = concat_all_gather(keys) 62 | labels = concat_all_gather(labels) 63 | 64 | for i, key in enumerate(keys): 65 | 66 | ptr = int(self.scl_queue_ptr[labels[i]]) 67 | 68 | 69 | # replace the keys at ptr (dequeue and enqueue) --> self.scl_queue: (dim, K_supervised) 70 | self.scl_queue[:, ptr] = key.T 71 | ptr = (ptr + 1) % self.step + self.scl_queue_loc[labels[i]] # move pointer 72 | 73 | self.scl_queue_ptr[labels[i]] = ptr 74 | 75 | @torch.no_grad() 76 | def _batch_shuffle_ddp(self, x): 77 | """ 78 | Batch shuffle, for making use of BatchNorm. 79 | *** Only support DistributedDataParallel (DDP) model. 
*** 80 | """ 81 | # gather from all gpus 82 | batch_size_this = x.shape[0] 83 | x_gather = concat_all_gather(x) 84 | batch_size_all = x_gather.shape[0] 85 | 86 | num_gpus = batch_size_all // batch_size_this 87 | 88 | # random shuffle index 89 | idx_shuffle = torch.randperm(batch_size_all).cuda() 90 | 91 | # broadcast to all gpus 92 | torch.distributed.broadcast(idx_shuffle, src=0) 93 | 94 | # index for restoring 95 | idx_unshuffle = torch.argsort(idx_shuffle) 96 | 97 | # shuffled index for this gpu 98 | gpu_idx = torch.distributed.get_rank() 99 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 100 | 101 | return x_gather[idx_this], idx_unshuffle 102 | 103 | @torch.no_grad() 104 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 105 | """ 106 | Undo batch shuffle. 107 | *** Only support DistributedDataParallel (DDP) model. *** 108 | """ 109 | # gather from all gpus 110 | batch_size_this = x.shape[0] 111 | x_gather = concat_all_gather(x) 112 | batch_size_all = x_gather.shape[0] 113 | 114 | num_gpus = batch_size_all // batch_size_this 115 | 116 | # restored index for this gpu 117 | gpu_idx = torch.distributed.get_rank() 118 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 119 | 120 | return x_gather[idx_this] 121 | 122 | def forward(self, im_q=None, im_q_lb=None, im_k_lb=None, targets=None, is_eval=False, cluster_result=None, index=None): 123 | 124 | # extract features for mcl 125 | if is_eval: 126 | k = self.encoder_k(im_q)['logits'] 127 | k = nn.functional.normalize(k, dim=1) 128 | return k 129 | 130 | """ 131 | Input: 132 | im_q: a batch of images for mcl 133 | im_q_lb: a batch of query images for scl 134 | im_k_lb: a batch of key images for scl 135 | Output: 136 | scl_logits, scl_labels, mcl_logits, mcl_labels 137 | """ 138 | 139 | # compute query features for mcl 140 | q = self.encoder_q(im_q)['logits'] # queries: NxC 141 | q = nn.functional.normalize(q, dim=1) 142 | 143 | q_lb = self.encoder_q(im_q_lb)['logits'] 144 | q_lb = nn.functional.normalize(q_lb, dim=1) 145 | 146 | 147 | # compute key features 148 | with torch.no_grad(): # no gradient to keys 149 | self._momentum_update_key_encoder() # update the key encoder 150 | 151 | # shuffle for making use of BN 152 | im_k_lb, idx_unshuffle = self._batch_shuffle_ddp(im_k_lb) 153 | 154 | k_lb = self.encoder_k(im_k_lb)['logits'] # keys: NxC 155 | k_lb = nn.functional.normalize(k_lb, dim=1) 156 | 157 | # undo shuffle 158 | k_lb = self._batch_unshuffle_ddp(k_lb, idx_unshuffle) 159 | # print(k_lb.shape) 160 | 161 | # -------------------- 162 | # compute scl_logits 163 | # -------------------- 164 | # Einstein sum is more intuitive 165 | l_aug = torch.einsum('nc,nc->n', [q_lb, k_lb]).unsqueeze(-1) 166 | l_sup = torch.einsum('nc,ck->nk', [q_lb, self.scl_queue.clone().detach()]) 167 | 168 | # scl_logits: N x (1 + self.num_labeled_classes * step) 169 | scl_logits = torch.cat([l_aug, l_sup], dim=1) 170 | 171 | # apply temperature 172 | scl_logits /= self.T 173 | 174 | 175 | 176 | # generate scl_labels to mask negative 177 | ''' 178 | batch_size = 2 179 | one_hot: 180 | [[0, 0, 1], 181 | [1, 0, 0]] 182 | scl_labels(step = 2): 183 | [[0, 0, 0, 0, 1, 1], 184 | [1, 1, 0, 0, 0, 0]] 185 | ''' 186 | batch_size = q_lb.shape[0] 187 | labels = targets.unsqueeze(1).cuda() 188 | 189 | one_hot = torch.zeros((batch_size, self.num_labeled_classes), dtype=torch.long).cuda() 190 | one_hot = one_hot.scatter(1, labels, 1) 191 | scl_labels = one_hot.unsqueeze(-1).repeat(1, 1, self.step).view(batch_size, -1) 192 | 193 | aug_labels = torch.ones([batch_size,1], 
dtype=torch.long).cuda() 194 | scl_labels = torch.cat([aug_labels, scl_labels], dim=1) # N x (1 + self.num_labeled_classes * step) 195 | 196 | # dequeue and enqueue 197 | self._scl_dequeue_and_enqueue(k_lb,labels.squeeze(1)) 198 | 199 | 200 | 201 | # -------------------- 202 | # compute mcl_logits 203 | # -------------------- 204 | im2cluster, centroids, density = cluster_result['im2cluster'], cluster_result['centroids'], cluster_result['density'] 205 | 206 | centroids = centroids / density.unsqueeze(dim=1) 207 | # get positive centroids 208 | pos_centroid_ids = im2cluster[:, index].reshape(self.num_multi_centers * batch_size) # 5, batch_size --> 5 * batch_size 209 | pos_centroids = centroids[pos_centroid_ids].reshape(self.num_multi_centers, batch_size, -1).mean(dim=0) 210 | 211 | # sample negative centroids 212 | all_centroids_id = [i for i in range(im2cluster.max()+1)] 213 | neg_centroids_id = set(all_centroids_id) - set(pos_centroid_ids.tolist()) 214 | neg_centroids_id = sample(neg_centroids_id, self.K) #sample 16384 negative centroids 215 | neg_centroids = centroids[neg_centroids_id] 216 | 217 | # q --> (Batch, 128) pos_centroids --> (Batch, 128) neg_centroids --> (K, 128) 218 | pos_logits = torch.matmul(q.unsqueeze(dim=1), pos_centroids.unsqueeze(dim=-1)).reshape(-1, 1) # (Batch, 1) 219 | neg_logits = torch.matmul(q, neg_centroids.T) # (Batch, K) 220 | mcl_logits = torch.cat([pos_logits, neg_logits], dim=1) 221 | mcl_labels = torch.zeros(mcl_logits.shape[0], dtype=torch.long).cuda() 222 | 223 | 224 | return scl_logits, scl_labels, mcl_logits, mcl_labels 225 | 226 | 227 | def forward_feature(self, im_q): 228 | k = self.encoder_k(im_q)['feature'] 229 | k = nn.functional.normalize(k, dim=1) 230 | return k 231 | 232 | def forward_feature_map(self, im_q): 233 | q = self.encoder_q(im_q)['feature_map'] 234 | return q 235 | 236 | 237 | # utils 238 | @torch.no_grad() 239 | def concat_all_gather(tensor): 240 | """ 241 | Performs all_gather operation on the provided tensors. 242 | *** Warning ***: torch.distributed.all_gather has no gradient. 243 | """ 244 | tensors_gather = [torch.ones_like(tensor) 245 | for _ in range(torch.distributed.get_world_size())] 246 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 247 | 248 | output = torch.cat(tensors_gather, dim=0) 249 | return output 250 | 251 | 252 | 253 | -------------------------------------------------------------------------------- /data/inatloc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | import munch 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset, DataLoader 8 | import torchvision.transforms as transforms 9 | from config import inatloc_root, inatloc_meta 10 | 11 | 12 | def mch(**kwargs): 13 | return munch.Munch(dict(**kwargs)) 14 | 15 | 16 | def configure_metadata(metadata_root): 17 | metadata = mch() 18 | metadata.image_ids = os.path.join(metadata_root, 'image_ids.txt') 19 | metadata.image_ids_lb = os.path.join(metadata_root, 'image_ids_labeled.txt') 20 | metadata.class_labels = os.path.join(metadata_root, 'class_labels.txt') 21 | return metadata 22 | 23 | def configure_metadata_infer(metadata_root, cluster_preds_name): 24 | metadata = mch() 25 | metadata.image_ids = os.path.join(metadata_root, 'image_ids.txt') 26 | metadata.class_labels = os.path.join(metadata_root, cluster_preds_name) 27 | metadata.partitions = os.path.join(metadata_root, 'partitions.txt') 28 | metadata.image_sizes = os.path.join(metadata_root, 'image_sizes.txt') 29 | metadata.localization = os.path.join(metadata_root, 'localization.txt') 30 | return metadata 31 | 32 | 33 | def get_image_ids(metadata): 34 | """ 35 | image_ids.txt has the structure 36 | 37 | path/image1.jpg 38 | path/image2.jpg 39 | path/image3.jpg 40 | ... 41 | """ 42 | image_ids = [] 43 | with open(metadata.image_ids) as f: 44 | for line in f.readlines(): 45 | image_ids.append(line.strip('\n')) 46 | return image_ids 47 | 48 | 49 | def get_image_ids_lb(metadata): 50 | """ 51 | image_ids_labeled.txt has the structure 52 | 53 | path/image1.jpg 54 | path/image2.jpg 55 | path/image3.jpg 56 | ... 57 | """ 58 | image_ids_lb = [] 59 | with open(metadata.image_ids_lb) as f: 60 | for line in f.readlines(): 61 | image_ids_lb.append(line.strip('\n')) 62 | return image_ids_lb 63 | 64 | 65 | def get_class_labels(metadata): 66 | """ 67 | class_labels.txt has the structure 68 | 69 | , 70 | path/image1.jpg,0 71 | path/image2.jpg,1 72 | path/image3.jpg,1 73 | ... 74 | """ 75 | class_labels = {} 76 | with open(metadata.class_labels) as f: 77 | for line in f.readlines(): 78 | image_id, class_label_string = line.strip('\n').split(',') 79 | class_labels[image_id] = int(class_label_string) 80 | return class_labels 81 | 82 | 83 | def get_partitions(metadata): 84 | """ 85 | partitions.txt has the structure 86 | 87 | , 0 --> known, 1 --> nov-s, 2 --> nov-d 88 | path/image1.jpg,0 89 | path/image2.jpg,1 90 | path/image3.jpg,2 91 | ... 92 | """ 93 | partitions = {} 94 | with open(metadata.partitions) as f: 95 | for line in f.readlines(): 96 | image_id, partition_string = line.strip('\n').split(',') 97 | partitions[image_id] = int(partition_string) 98 | return partitions 99 | 100 | 101 | def get_class_labels_and_pred(metadata): 102 | """ 103 | class_labels_preds.txt has the structure 104 | 105 | ,, 106 | path/image1.jpg,0,0 107 | path/image2.jpg,1,1 108 | path/image3.jpg,2,2 109 | ... 110 | """ 111 | class_labels_and_preds = {} 112 | with open(metadata.class_labels) as f: 113 | for line in f.readlines(): 114 | image_id, class_label_string, pred_string = line.strip('\n').split(',') 115 | class_labels_and_preds[image_id] = (int(class_label_string), int(pred_string)) 116 | return class_labels_and_preds 117 | 118 | 119 | def get_image_sizes(metadata): 120 | """ 121 | image_sizes.txt has the structure 122 | 123 | ,, 124 | path/image1.jpg,500,300 125 | path/image2.jpg,1000,600 126 | path/image3.jpg,500,300 127 | ... 
128 | """ 129 | image_sizes = {} 130 | with open(metadata.image_sizes) as f: 131 | for line in f.readlines(): 132 | image_id, ws, hs = line.strip('\n').split(',') 133 | w, h = int(ws), int(hs) 134 | image_sizes[image_id] = (w, h) 135 | return image_sizes 136 | 137 | 138 | def get_bounding_boxes(metadata): 139 | """ 140 | localization.txt (for bounding box) has the structure 141 | 142 | ,,,, 143 | path/image1.jpg,156,163,318,230 144 | path/image1.jpg,23,12,101,259 145 | path/image2.jpg,143,142,394,248 146 | path/image3.jpg,28,94,485,303 147 | ... 148 | 149 | One image may contain multiple boxes (multiple boxes for the same path). 150 | """ 151 | boxes = {} 152 | with open(metadata.localization) as f: 153 | for line in f.readlines(): 154 | image_id, x0s, x1s, y0s, y1s = line.strip('\n').split(',') 155 | x0, x1, y0, y1 = int(x0s), int(x1s), int(y0s), int(y1s) 156 | if image_id in boxes: 157 | boxes[image_id].append((x0, x1, y0, y1)) 158 | else: 159 | boxes[image_id] = [(x0, x1, y0, y1)] 160 | return boxes 161 | 162 | def pil_loader(path: str): 163 | with open(path, 'rb') as f: 164 | img = Image.open(f) 165 | return img.convert('RGB') 166 | 167 | class TrainDataset(Dataset): 168 | def __init__(self, data_root, meta_root, transform): 169 | self.data_root = data_root 170 | self.transform = transform 171 | self.metadata = configure_metadata(meta_root) 172 | self.image_ids = get_image_ids(self.metadata) 173 | self.image_ids_lb = get_image_ids_lb(self.metadata) 174 | self.image_labels = get_class_labels(self.metadata) 175 | 176 | def __getitem__(self, idx): 177 | image_id = self.image_ids[idx] 178 | image_lb_id = self.image_ids_lb[idx] 179 | image_lb_label = self.image_labels[image_lb_id] 180 | image = pil_loader(os.path.join(self.data_root, image_id)) 181 | image_lb = pil_loader(os.path.join(self.data_root, image_lb_id)) 182 | image = self.transform(image) 183 | image_lb = self.transform(image_lb) 184 | 185 | 186 | return image, image_lb, image_lb_label, idx 187 | 188 | def __len__(self): 189 | return len(self.image_ids) 190 | 191 | 192 | class MclDataset(Dataset): 193 | def __init__(self, data_root, meta_root, transform): 194 | self.data_root = data_root 195 | self.transform = transform 196 | self.metadata = configure_metadata(meta_root) 197 | self.image_ids = get_image_ids(self.metadata) 198 | 199 | def __getitem__(self, idx): 200 | image_id = self.image_ids[idx] 201 | image = pil_loader(os.path.join(self.data_root, image_id)) 202 | image = self.transform(image) 203 | return image, idx 204 | 205 | def __len__(self): 206 | return len(self.image_ids) 207 | 208 | 209 | class EvalDataset(Dataset): 210 | def __init__(self, data_root, meta_root, transform): 211 | self.data_root = data_root 212 | self.transform = transform 213 | self.metadata = configure_metadata(meta_root) 214 | self.image_ids = get_image_ids(self.metadata) 215 | self.image_labels = get_class_labels(self.metadata) 216 | 217 | def __getitem__(self, idx): 218 | image_id = self.image_ids[idx] 219 | image_label = self.image_labels[image_id] 220 | image = pil_loader(os.path.join(self.data_root, image_id)) 221 | image = self.transform(image) 222 | return image, image_label 223 | 224 | def __len__(self): 225 | return len(self.image_ids) 226 | 227 | 228 | class ClusterDataset(Dataset): 229 | def __init__(self, data_root, meta_root, transform): 230 | self.data_root = data_root 231 | self.transform = transform 232 | self.metadata = configure_metadata(meta_root) 233 | self.image_ids = get_image_ids(self.metadata) 234 | self.image_labels = 
get_class_labels(self.metadata)
235 |
236 |     def __getitem__(self, idx):
237 |         image_id = self.image_ids[idx]
238 |         image_label = self.image_labels[image_id]
239 |         image = pil_loader(os.path.join(self.data_root, image_id))
240 |         image = self.transform(image)
241 |         return image, image_label, image_id
242 |
243 |     def __len__(self):
244 |         return len(self.image_ids)
245 |
246 |
247 | class GcamDataSet(Dataset):
248 |     def __init__(self, data_root, meta_root, preds_name, transform):
249 |         self.data_root = data_root
250 |         self.transform = transform
251 |         self.metadata = configure_metadata_infer(meta_root, preds_name)
252 |         self.image_ids = get_image_ids(self.metadata)
253 |         self.image_labels_and_pred = get_class_labels_and_pred(self.metadata)
254 |
255 |     def __getitem__(self, idx):
256 |         image_id = self.image_ids[idx]
257 |         image_label, pred = self.image_labels_and_pred[image_id]
258 |         image = pil_loader(os.path.join(self.data_root, image_id))
259 |         image = self.transform(image)
260 |         return image, pred, image_id
261 |
262 |     def __len__(self):
263 |         return len(self.image_ids)
264 |
265 | def get_inatloc_datasets(train_transform, test_transform):
266 |
267 |     meta_train = os.path.join(inatloc_meta, 'train')
268 |     meta_val = os.path.join(inatloc_meta, 'val')
269 |     meta_test = os.path.join(inatloc_meta, 'test')
270 |
271 |     train_dataset = TrainDataset(inatloc_root, meta_train, train_transform)
272 |     cluster_dataset = MclDataset(inatloc_root, meta_train, test_transform)
273 |     val_dataset = EvalDataset(inatloc_root, meta_val, test_transform)
274 |
275 |     # return train_dataset, test_dataset, val_dataset
276 |     return train_dataset, cluster_dataset, val_dataset
277 |
278 |
279 | def get_inatloc_datasets_cluster(test_transform):
280 |
281 |     meta_test = os.path.join(inatloc_meta, 'test')
282 |     test_dataset = ClusterDataset(inatloc_root, meta_test, test_transform)
283 |
284 |     return test_dataset
285 |
286 |
287 | def get_inatloc_datasets_gcam(test_transform, target_and_pred):
288 |
289 |     meta_test = os.path.join(inatloc_meta, 'test')
290 |     test_dataset = GcamDataSet(inatloc_root, meta_test, target_and_pred, test_transform)
291 |
292 |     return test_dataset
293 |
294 | def get_inatloc_datasets_estimate_k(test_transform):
295 |
296 |     meta_val = os.path.join(inatloc_meta, 'val')
297 |     val_dataset = ClusterDataset(inatloc_root, meta_val, test_transform)
298 |
299 |     return val_dataset
300 |
301 |
302 |
303 | if __name__ == '__main__':
304 |     import numpy as np
305 |     np.set_printoptions(threshold=np.inf)
306 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
307 |                                      std=[0.229, 0.224, 0.225])
308 |     augmentation = [
309 |         transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
310 |         transforms.ToTensor(),
311 |         normalize
312 |     ]
313 |     dataset = TrainDataset('/data/zhaochuan/NCL/code/wsolevaluation-master/dataset/ILSVRC/', '/data/zhaochuan/NCL/code/GWSOL/metadata/inatloc/train/', transforms.Compose(augmentation))
314 |     train_loader = DataLoader(
315 |         dataset, batch_size=16, shuffle=True,
316 |         num_workers=8, pin_memory=True, drop_last=True)
317 |     print(len(train_loader))
318 |     # TrainDataset yields (image, image_lb, image_lb_label, idx)
319 |     for i, (images, images_lb, labels_lb, idx) in enumerate(train_loader):
320 |         print(images.shape)
321 |         print(labels_lb.shape)
322 |         print(idx)
323 |
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. 
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import os
3 | import munch
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import Dataset, DataLoader
8 | import torchvision.transforms as transforms
9 | from config import imagenet_root, imagenet_meta
10 | 
11 | 
12 | def mch(**kwargs):
13 |     return munch.Munch(dict(**kwargs))
14 | 
15 | 
16 | def configure_metadata(metadata_root):
17 |     metadata = mch()
18 |     metadata.image_ids = os.path.join(metadata_root, 'image_ids.txt')
19 |     metadata.image_ids_lb = os.path.join(metadata_root, 'image_ids_labeled.txt')
20 |     metadata.class_labels = os.path.join(metadata_root, 'class_labels.txt')
21 |     return metadata
22 | 
23 | def configure_metadata_infer(metadata_root, cluster_preds_name):
24 |     metadata = mch()
25 |     metadata.image_ids = os.path.join(metadata_root, 'image_ids.txt')
26 |     metadata.class_labels = os.path.join(metadata_root, cluster_preds_name)
27 |     metadata.partitions = os.path.join(metadata_root, 'partitions.txt')
28 |     metadata.image_sizes = os.path.join(metadata_root, 'image_sizes.txt')
29 |     metadata.localization = os.path.join(metadata_root, 'localization.txt')
30 |     return metadata
31 | 
32 | 
33 | def get_image_ids(metadata):
34 |     """
35 |     image_ids.txt has the structure
36 | 
37 |     path/image1.jpg
38 |     path/image2.jpg
39 |     path/image3.jpg
40 |     ...
41 |     """
42 |     image_ids = []
43 |     with open(metadata.image_ids) as f:
44 |         for line in f.readlines():
45 |             image_ids.append(line.strip('\n'))
46 |     return image_ids
47 | 
48 | 
49 | def get_image_ids_lb(metadata):
50 |     """
51 |     image_ids_labeled.txt has the structure
52 | 
53 |     path/image1.jpg
54 |     path/image2.jpg
55 |     path/image3.jpg
56 |     ...
57 |     """
58 |     image_ids_lb = []
59 |     with open(metadata.image_ids_lb) as f:
60 |         for line in f.readlines():
61 |             image_ids_lb.append(line.strip('\n'))
62 |     return image_ids_lb
63 | 
64 | 
65 | def get_class_labels(metadata):
66 |     """
67 |     class_labels.txt has the structure
68 | 
69 |     <path>,<integer_class_label>
70 |     path/image1.jpg,0
71 |     path/image2.jpg,1
72 |     path/image3.jpg,1
73 |     ...
74 |     """
75 |     class_labels = {}
76 |     with open(metadata.class_labels) as f:
77 |         for line in f.readlines():
78 |             image_id, class_label_string = line.strip('\n').split(',')
79 |             class_labels[image_id] = int(class_label_string)
80 |     return class_labels
81 | 
82 | 
83 | def get_partitions(metadata):
84 |     """
85 |     partitions.txt has the structure
86 | 
87 |     <path>,<integer_partition>  0 --> known, 1 --> nov-s, 2 --> nov-d
88 |     path/image1.jpg,0
89 |     path/image2.jpg,1
90 |     path/image3.jpg,2
91 |     ...
92 |     """
93 |     partitions = {}
94 |     with open(metadata.partitions) as f:
95 |         for line in f.readlines():
96 |             image_id, partition_string = line.strip('\n').split(',')
97 |             partitions[image_id] = int(partition_string)
98 |     return partitions
99 | 
100 | 
101 | def get_class_labels_and_pred(metadata):
102 |     """
103 |     class_labels_preds.txt has the structure
104 | 
105 |     <path>,<integer_class_label>,<integer_pred>
106 |     path/image1.jpg,0,0
107 |     path/image2.jpg,1,1
108 |     path/image3.jpg,2,2
109 |     ...
110 |     """
111 |     class_labels_and_preds = {}
112 |     with open(metadata.class_labels) as f:
113 |         for line in f.readlines():
114 |             image_id, class_label_string, pred_string = line.strip('\n').split(',')
115 |             class_labels_and_preds[image_id] = (int(class_label_string), int(pred_string))
116 |     return class_labels_and_preds
117 | 
118 | 
119 | def get_image_sizes(metadata):
120 |     """
121 |     image_sizes.txt has the structure
122 | 
123 |     <path>,<w>,<h>
124 |     path/image1.jpg,500,300
125 |     path/image2.jpg,1000,600
126 |     path/image3.jpg,500,300
127 |     ...
128 | """ 129 | image_sizes = {} 130 | with open(metadata.image_sizes) as f: 131 | for line in f.readlines(): 132 | image_id, ws, hs = line.strip('\n').split(',') 133 | w, h = int(ws), int(hs) 134 | image_sizes[image_id] = (w, h) 135 | return image_sizes 136 | 137 | 138 | def get_bounding_boxes(metadata): 139 | """ 140 | localization.txt (for bounding box) has the structure 141 | 142 | ,,,, 143 | path/image1.jpg,156,163,318,230 144 | path/image1.jpg,23,12,101,259 145 | path/image2.jpg,143,142,394,248 146 | path/image3.jpg,28,94,485,303 147 | ... 148 | 149 | One image may contain multiple boxes (multiple boxes for the same path). 150 | """ 151 | boxes = {} 152 | with open(metadata.localization) as f: 153 | for line in f.readlines(): 154 | image_id, x0s, x1s, y0s, y1s = line.strip('\n').split(',') 155 | x0, x1, y0, y1 = int(x0s), int(x1s), int(y0s), int(y1s) 156 | if image_id in boxes: 157 | boxes[image_id].append((x0, x1, y0, y1)) 158 | else: 159 | boxes[image_id] = [(x0, x1, y0, y1)] 160 | return boxes 161 | 162 | def pil_loader(path: str): 163 | with open(path, 'rb') as f: 164 | img = Image.open(f) 165 | return img.convert('RGB') 166 | 167 | class TrainDataset(Dataset): 168 | def __init__(self, data_root, meta_root, transform): 169 | self.data_root = data_root 170 | self.transform = transform 171 | self.metadata = configure_metadata(meta_root) 172 | self.image_ids = get_image_ids(self.metadata) 173 | self.image_ids_lb = get_image_ids_lb(self.metadata) 174 | self.image_labels = get_class_labels(self.metadata) 175 | 176 | def __getitem__(self, idx): 177 | image_id = self.image_ids[idx] 178 | image_lb_id = self.image_ids_lb[idx] 179 | image_lb_label = self.image_labels[image_lb_id] 180 | image = pil_loader(os.path.join(self.data_root, image_id)) 181 | image_lb = pil_loader(os.path.join(self.data_root, image_lb_id)) 182 | image = self.transform(image) 183 | image_lb = self.transform(image_lb) 184 | 185 | 186 | return image, image_lb, image_lb_label, idx 187 | 188 | def __len__(self): 189 | return len(self.image_ids) 190 | 191 | 192 | class MclDataset(Dataset): 193 | def __init__(self, data_root, meta_root, transform): 194 | self.data_root = data_root 195 | self.transform = transform 196 | self.metadata = configure_metadata(meta_root) 197 | self.image_ids = get_image_ids(self.metadata) 198 | 199 | def __getitem__(self, idx): 200 | image_id = self.image_ids[idx] 201 | image = pil_loader(os.path.join(self.data_root, image_id)) 202 | image = self.transform(image) 203 | return image, idx 204 | 205 | def __len__(self): 206 | return len(self.image_ids) 207 | 208 | 209 | class EvalDataset(Dataset): 210 | def __init__(self, data_root, meta_root, transform): 211 | self.data_root = data_root 212 | self.transform = transform 213 | self.metadata = configure_metadata(meta_root) 214 | self.image_ids = get_image_ids(self.metadata) 215 | self.image_labels = get_class_labels(self.metadata) 216 | 217 | def __getitem__(self, idx): 218 | image_id = self.image_ids[idx] 219 | image_label = self.image_labels[image_id] 220 | image = pil_loader(os.path.join(self.data_root, image_id)) 221 | image = self.transform(image) 222 | return image, image_label 223 | 224 | def __len__(self): 225 | return len(self.image_ids) 226 | 227 | 228 | class ClusterDataset(Dataset): 229 | def __init__(self, data_root, meta_root, transform): 230 | self.data_root = data_root 231 | self.transform = transform 232 | self.metadata = configure_metadata(meta_root) 233 | self.image_ids = get_image_ids(self.metadata) 234 | self.image_labels = 
234 |         self.image_labels = get_class_labels(self.metadata)
235 | 
236 |     def __getitem__(self, idx):
237 |         image_id = self.image_ids[idx]
238 |         image_label = self.image_labels[image_id]
239 |         image = pil_loader(os.path.join(self.data_root, image_id))
240 |         image = self.transform(image)
241 |         return image, image_label, image_id
242 | 
243 |     def __len__(self):
244 |         return len(self.image_ids)
245 | 
246 | 
247 | class GcamDataSet(Dataset):
248 |     def __init__(self, data_root, meta_root, preds_name, transform):
249 |         self.data_root = data_root
250 |         self.transform = transform
251 |         self.metadata = configure_metadata_infer(meta_root, preds_name)
252 |         self.image_ids = get_image_ids(self.metadata)
253 |         self.image_labels_and_pred = get_class_labels_and_pred(self.metadata)
254 | 
255 |     def __getitem__(self, idx):
256 |         image_id = self.image_ids[idx]
257 |         image_label, pred = self.image_labels_and_pred[image_id]
258 |         image = pil_loader(os.path.join(self.data_root, image_id))
259 |         image = self.transform(image)
260 |         return image, pred, image_id
261 | 
262 |     def __len__(self):
263 |         return len(self.image_ids)
264 | 
265 | def get_imagenet_datasets(train_transform, test_transform):
266 | 
267 |     meta_train = os.path.join(imagenet_meta, 'train')
268 |     meta_val = os.path.join(imagenet_meta, 'val')
269 |     meta_test = os.path.join(imagenet_meta, 'test')
270 | 
271 |     train_dataset = TrainDataset(imagenet_root, meta_train, train_transform)
272 |     cluster_dataset = MclDataset(imagenet_root, meta_train, test_transform)
273 |     val_dataset = EvalDataset(imagenet_root, meta_val, test_transform)
274 | 
275 |     # return train_dataset, test_dataset, val_dataset
276 |     return train_dataset, cluster_dataset, val_dataset
277 | 
278 | 
279 | def get_imagenet_datasets_cluster(test_transform):
280 | 
281 |     meta_test = os.path.join(imagenet_meta, 'test')
282 |     test_dataset = ClusterDataset(imagenet_root, meta_test, test_transform)
283 | 
284 |     return test_dataset
285 | 
286 | 
287 | def get_imagenet_datasets_gcam(test_transform, target_and_pred):
288 | 
289 |     meta_test = os.path.join(imagenet_meta, 'test')
290 |     test_dataset = GcamDataSet(imagenet_root, meta_test, target_and_pred, test_transform)
291 | 
292 |     return test_dataset
293 | 
294 | def get_imagenet_datasets_estimate_k(test_transform):
295 | 
296 |     meta_val = os.path.join(imagenet_meta, 'val')
297 |     val_dataset = ClusterDataset(imagenet_root, meta_val, test_transform)
298 | 
299 |     return val_dataset
300 | 
301 | 
302 | 
303 | if __name__=='__main__':
304 |     import numpy as np
305 |     np.set_printoptions(threshold=np.inf)
306 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
307 |                                      std=[0.229, 0.224, 0.225])
308 |     augmentation = [
309 |         transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
310 |         transforms.ToTensor(),
311 |         normalize
312 |     ]
313 |     dataset = TrainDataset('/data/zhaochuan/NCL/code/wsolevaluation-master/dataset/ILSVRC/', '/data/zhaochuan/NCL/code/GWSOL/metadata/ImageNet/train/', transforms.Compose(augmentation))
314 |     train_loader = DataLoader(
315 |         dataset, batch_size=16, shuffle=True,
316 |         num_workers=8, pin_memory=True, drop_last=True)
317 |     print(len(train_loader))
318 |     for i, (images, images_lb, labels, idx) in enumerate(train_loader):
319 |         print(images.shape)
320 |         print(images_lb.shape)
321 |         print(labels)
322 | 
323 | 
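For orientation, a minimal sketch of what one TrainDataset batch holds, matching the (image, image_lb, image_lb_label, idx) tuple returned above. The batch size and transform are illustrative, and it assumes the module-level imagenet_root/imagenet_meta point at real data.

demo_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
demo_set = TrainDataset(imagenet_root, os.path.join(imagenet_meta, 'train'), demo_transform)
demo_loader = DataLoader(demo_set, batch_size=4, shuffle=True)
images, images_lb, labels_lb, idx = next(iter(demo_loader))
# images:    unlabeled crops,                shape (4, 3, 224, 224)
# images_lb: labeled crops,                  shape (4, 3, 224, 224)
# labels_lb: class ids of the labeled crops, shape (4,)
# idx:       dataset indices of the unlabeled crops (used for per-sample
#            cluster lookups during contrastive co-training)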
--------------------------------------------------------------------------------
/data/openimages.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import os
3 | import munch
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import Dataset, DataLoader
8 | import torchvision.transforms as transforms
9 | from config import openimages_root, openimages_meta
10 | 
11 | 
12 | def mch(**kwargs):
13 |     return munch.Munch(dict(**kwargs))
14 | 
15 | 
16 | def configure_metadata(metadata_root):
17 |     metadata = mch()
18 |     metadata.image_ids = os.path.join(metadata_root, 'image_ids.txt')
19 |     metadata.image_ids_lb = os.path.join(metadata_root, 'image_ids_labeled.txt')
20 |     metadata.class_labels = os.path.join(metadata_root, 'class_labels.txt')
21 |     return metadata
22 | 
23 | def configure_metadata_infer(metadata_root, cluster_preds_name):
24 |     metadata = mch()
25 |     metadata.image_ids = os.path.join(metadata_root, 'image_ids.txt')
26 |     metadata.class_labels = os.path.join(metadata_root, cluster_preds_name)
27 |     metadata.partitions = os.path.join(metadata_root, 'partitions.txt')
28 |     metadata.image_sizes = os.path.join(metadata_root, 'image_sizes.txt')
29 |     metadata.localization = os.path.join(metadata_root, 'localization.txt')
30 |     return metadata
31 | 
32 | 
33 | def get_image_ids(metadata):
34 |     """
35 |     image_ids.txt has the structure
36 | 
37 |     path/image1.jpg
38 |     path/image2.jpg
39 |     path/image3.jpg
40 |     ...
41 |     """
42 |     image_ids = []
43 |     with open(metadata.image_ids) as f:
44 |         for line in f.readlines():
45 |             image_ids.append(line.strip('\n'))
46 |     return image_ids
47 | 
48 | 
49 | def get_image_ids_lb(metadata):
50 |     """
51 |     image_ids_labeled.txt has the structure
52 | 
53 |     path/image1.jpg
54 |     path/image2.jpg
55 |     path/image3.jpg
56 |     ...
57 |     """
58 |     image_ids_lb = []
59 |     with open(metadata.image_ids_lb) as f:
60 |         for line in f.readlines():
61 |             image_ids_lb.append(line.strip('\n'))
62 |     return image_ids_lb
63 | 
64 | 
65 | def get_class_labels(metadata):
66 |     """
67 |     class_labels.txt has the structure
68 | 
69 |     <path>,<integer_class_label>
70 |     path/image1.jpg,0
71 |     path/image2.jpg,1
72 |     path/image3.jpg,1
73 |     ...
74 |     """
75 |     class_labels = {}
76 |     with open(metadata.class_labels) as f:
77 |         for line in f.readlines():
78 |             image_id, class_label_string = line.strip('\n').split(',')
79 |             class_labels[image_id] = int(class_label_string)
80 |     return class_labels
81 | 
82 | 
83 | def get_partitions(metadata):
84 |     """
85 |     partitions.txt has the structure
86 | 
87 |     <path>,<integer_partition>  0 --> known, 1 --> nov-s, 2 --> nov-d
88 |     path/image1.jpg,0
89 |     path/image2.jpg,1
90 |     path/image3.jpg,2
91 |     ...
92 |     """
93 |     partitions = {}
94 |     with open(metadata.partitions) as f:
95 |         for line in f.readlines():
96 |             image_id, partition_string = line.strip('\n').split(',')
97 |             partitions[image_id] = int(partition_string)
98 |     return partitions
99 | 
100 | 
101 | def get_class_labels_and_pred(metadata):
102 |     """
103 |     class_labels_preds.txt has the structure
104 | 
105 |     <path>,<integer_class_label>,<integer_pred>
106 |     path/image1.jpg,0,0
107 |     path/image2.jpg,1,1
108 |     path/image3.jpg,2,2
109 |     ...
110 |     """
111 |     class_labels_and_preds = {}
112 |     with open(metadata.class_labels) as f:
113 |         for line in f.readlines():
114 |             image_id, class_label_string, pred_string = line.strip('\n').split(',')
115 |             class_labels_and_preds[image_id] = (int(class_label_string), int(pred_string))
116 |     return class_labels_and_preds
117 | 
118 | 
119 | def get_image_sizes(metadata):
120 |     """
121 |     image_sizes.txt has the structure
122 | 
123 |     <path>,<w>,<h>
124 |     path/image1.jpg,500,300
125 |     path/image2.jpg,1000,600
126 |     path/image3.jpg,500,300
127 |     ...
128 | """ 129 | image_sizes = {} 130 | with open(metadata.image_sizes) as f: 131 | for line in f.readlines(): 132 | image_id, ws, hs = line.strip('\n').split(',') 133 | w, h = int(ws), int(hs) 134 | image_sizes[image_id] = (w, h) 135 | return image_sizes 136 | 137 | 138 | def get_bounding_boxes(metadata): 139 | """ 140 | localization.txt (for bounding box) has the structure 141 | 142 | ,,,, 143 | path/image1.jpg,156,163,318,230 144 | path/image1.jpg,23,12,101,259 145 | path/image2.jpg,143,142,394,248 146 | path/image3.jpg,28,94,485,303 147 | ... 148 | 149 | One image may contain multiple boxes (multiple boxes for the same path). 150 | """ 151 | boxes = {} 152 | with open(metadata.localization) as f: 153 | for line in f.readlines(): 154 | image_id, x0s, x1s, y0s, y1s = line.strip('\n').split(',') 155 | x0, x1, y0, y1 = int(x0s), int(x1s), int(y0s), int(y1s) 156 | if image_id in boxes: 157 | boxes[image_id].append((x0, x1, y0, y1)) 158 | else: 159 | boxes[image_id] = [(x0, x1, y0, y1)] 160 | return boxes 161 | 162 | def pil_loader(path: str): 163 | with open(path, 'rb') as f: 164 | img = Image.open(f) 165 | return img.convert('RGB') 166 | 167 | class TrainDataset(Dataset): 168 | def __init__(self, data_root, meta_root, transform): 169 | self.data_root = data_root 170 | self.transform = transform 171 | self.metadata = configure_metadata(meta_root) 172 | self.image_ids = get_image_ids(self.metadata) 173 | self.image_ids_lb = get_image_ids_lb(self.metadata) 174 | self.image_labels = get_class_labels(self.metadata) 175 | 176 | def __getitem__(self, idx): 177 | image_id = self.image_ids[idx] 178 | image_lb_id = self.image_ids_lb[idx] 179 | image_lb_label = self.image_labels[image_lb_id] 180 | image = pil_loader(os.path.join(self.data_root, image_id)) 181 | image_lb = pil_loader(os.path.join(self.data_root, image_lb_id)) 182 | image = self.transform(image) 183 | image_lb = self.transform(image_lb) 184 | 185 | 186 | return image, image_lb, image_lb_label, idx 187 | 188 | def __len__(self): 189 | return len(self.image_ids) 190 | 191 | 192 | class MclDataset(Dataset): 193 | def __init__(self, data_root, meta_root, transform): 194 | self.data_root = data_root 195 | self.transform = transform 196 | self.metadata = configure_metadata(meta_root) 197 | self.image_ids = get_image_ids(self.metadata) 198 | 199 | def __getitem__(self, idx): 200 | image_id = self.image_ids[idx] 201 | image = pil_loader(os.path.join(self.data_root, image_id)) 202 | image = self.transform(image) 203 | return image, idx 204 | 205 | def __len__(self): 206 | return len(self.image_ids) 207 | 208 | 209 | class EvalDataset(Dataset): 210 | def __init__(self, data_root, meta_root, transform): 211 | self.data_root = data_root 212 | self.transform = transform 213 | self.metadata = configure_metadata(meta_root) 214 | self.image_ids = get_image_ids(self.metadata) 215 | self.image_labels = get_class_labels(self.metadata) 216 | 217 | def __getitem__(self, idx): 218 | image_id = self.image_ids[idx] 219 | image_label = self.image_labels[image_id] 220 | image = pil_loader(os.path.join(self.data_root, image_id)) 221 | image = self.transform(image) 222 | return image, image_label 223 | 224 | def __len__(self): 225 | return len(self.image_ids) 226 | 227 | 228 | class ClusterDataset(Dataset): 229 | def __init__(self, data_root, meta_root, transform): 230 | self.data_root = data_root 231 | self.transform = transform 232 | self.metadata = configure_metadata(meta_root) 233 | self.image_ids = get_image_ids(self.metadata) 234 | self.image_labels = 
234 |         self.image_labels = get_class_labels(self.metadata)
235 | 
236 |     def __getitem__(self, idx):
237 |         image_id = self.image_ids[idx]
238 |         image_label = self.image_labels[image_id]
239 |         image = pil_loader(os.path.join(self.data_root, image_id))
240 |         image = self.transform(image)
241 |         return image, image_label, image_id
242 | 
243 |     def __len__(self):
244 |         return len(self.image_ids)
245 | 
246 | 
247 | class GcamDataSet(Dataset):
248 |     def __init__(self, data_root, meta_root, preds_name, transform):
249 |         self.data_root = data_root
250 |         self.transform = transform
251 |         self.metadata = configure_metadata_infer(meta_root, preds_name)
252 |         self.image_ids = get_image_ids(self.metadata)
253 |         self.image_labels_and_pred = get_class_labels_and_pred(self.metadata)
254 | 
255 |     def __getitem__(self, idx):
256 |         image_id = self.image_ids[idx]
257 |         image_label, pred = self.image_labels_and_pred[image_id]
258 |         image = pil_loader(os.path.join(self.data_root, image_id))
259 |         image = self.transform(image)
260 |         return image, pred, image_id
261 | 
262 |     def __len__(self):
263 |         return len(self.image_ids)
264 | 
265 | def get_openimages_datasets(train_transform, test_transform):
266 | 
267 |     meta_train = os.path.join(openimages_meta, 'train')
268 |     meta_val = os.path.join(openimages_meta, 'val')
269 |     meta_test = os.path.join(openimages_meta, 'test')
270 | 
271 |     train_dataset = TrainDataset(openimages_root, meta_train, train_transform)
272 |     cluster_dataset = MclDataset(openimages_root, meta_train, test_transform)
273 |     val_dataset = EvalDataset(openimages_root, meta_val, test_transform)
274 | 
275 |     # return train_dataset, test_dataset, val_dataset
276 |     return train_dataset, cluster_dataset, val_dataset
277 | 
278 | 
279 | def get_openimages_datasets_cluster(test_transform):
280 | 
281 |     meta_test = os.path.join(openimages_meta, 'test')
282 |     test_dataset = ClusterDataset(openimages_root, meta_test, test_transform)
283 | 
284 |     return test_dataset
285 | 
286 | 
287 | def get_openimages_datasets_gcam(test_transform, target_and_pred):
288 | 
289 |     meta_test = os.path.join(openimages_meta, 'test')
290 |     test_dataset = GcamDataSet(openimages_root, meta_test, target_and_pred, test_transform)
291 | 
292 |     return test_dataset
293 | 
294 | def get_openimages_datasets_estimate_k(test_transform):
295 | 
296 |     meta_val = os.path.join(openimages_meta, 'val')
297 |     val_dataset = ClusterDataset(openimages_root, meta_val, test_transform)
298 | 
299 |     return val_dataset
300 | 
301 | 
302 | 
303 | if __name__=='__main__':
304 |     import numpy as np
305 |     np.set_printoptions(threshold=np.inf)
306 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
307 |                                      std=[0.229, 0.224, 0.225])
308 |     augmentation = [
309 |         transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
310 |         transforms.ToTensor(),
311 |         normalize
312 |     ]
313 |     dataset = TrainDataset('/data/zhaochuan/NCL/code/wsolevaluation-master/dataset/ILSVRC/', '/data/zhaochuan/NCL/code/GWSOL/metadata/openimages/train/', transforms.Compose(augmentation))
314 |     train_loader = DataLoader(
315 |         dataset, batch_size=16, shuffle=True,
316 |         num_workers=8, pin_memory=True, drop_last=True)
317 |     print(len(train_loader))
318 |     for i, (images, images_lb, labels, idx) in enumerate(train_loader):
319 |         print(images.shape)
320 |         print(images_lb.shape)
321 |         print(labels)
322 | 
323 | 
--------------------------------------------------------------------------------
/methods/gcam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import argparse
4 | import random
5 | import shutil 6 | import munch 7 | import numpy as np 8 | from sklearn.cluster import KMeans 9 | from os.path import join as ospj 10 | from os.path import dirname as ospd 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torchvision import transforms 15 | from torch.utils.data import DataLoader 16 | 17 | from models.resnet import resnet50 18 | from methods.metric import BoxEvaluator 19 | from data.imagenet import configure_metadata_infer 20 | from data.get_datasets import get_class_splits, get_datasets_cluster, get_datasets_gcam 21 | 22 | from project_utils.utils import t2n 23 | from project_utils.utils import Logger 24 | from project_utils.cluster_and_log_utils import log_accs_from_preds_infer 25 | 26 | 27 | 28 | def mch(**kwargs): 29 | return munch.Munch(dict(**kwargs)) 30 | 31 | def set_random_seed(seed): 32 | if seed is None: 33 | return 34 | np.random.seed(seed) 35 | random.seed(seed) 36 | torch.manual_seed(seed) 37 | 38 | def str2bool(v): 39 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 40 | return True 41 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 42 | return False 43 | else: 44 | raise argparse.ArgumentTypeError('Boolean value expected.') 45 | 46 | def configure_log_folder(args): 47 | log_folder = ospj('infer_log', args.dataset_name, args.experiment_name) 48 | 49 | if os.path.isdir(log_folder): 50 | if args.override_cache: 51 | shutil.rmtree(log_folder, ignore_errors=True) 52 | else: 53 | raise RuntimeError("Experiment with the same name exists: {}" 54 | .format(log_folder)) 55 | os.makedirs(log_folder) 56 | return log_folder 57 | 58 | 59 | def configure_log(args): 60 | log_file_name = ospj(args.log_folder, 'log.log') 61 | Logger(log_file_name) 62 | 63 | 64 | parser = argparse.ArgumentParser() 65 | 66 | # Util 67 | parser.add_argument('--seed', type=int, default=0) 68 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50') 69 | parser.add_argument('--experiment_name', type=str, default='G-CAM_') 70 | parser.add_argument('--workers', default=8, type=int, help='number of data loading workers (default: 8)') 71 | parser.add_argument('--batch_size', default=128, type=int) 72 | parser.add_argument('--override_cache', type=str2bool, nargs='?', const=True, default=False) 73 | 74 | 75 | # Data 76 | parser.add_argument('--dataset_name', type=str, default='ImageNet') 77 | parser.add_argument('--metadata_root', type=str, default='./metadata/') 78 | parser.add_argument('--model_path', type=str, default=None) 79 | parser.add_argument('--interpolation', default=3, type=int) 80 | parser.add_argument('--crop_pct', type=float, default=0.875) 81 | parser.add_argument('--partitions', nargs='+', default=['all', 'known', 'nov_s', 'nov_d']) 82 | 83 | 84 | # G-CAM Setting 85 | parser.add_argument('--cam_curve_interval', type=float, default=.001, help='CAM curve interval') 86 | parser.add_argument('--multi_contour_eval', type=str2bool, nargs='?', const=True, default=True) 87 | parser.add_argument('--multi_iou_eval', type=str2bool, nargs='?', const=True, default=True) 88 | parser.add_argument('--iou_threshold_list', nargs='+', type=int, default=[30, 50, 70]) 89 | 90 | # Save name 91 | parser.add_argument('--weight_dict_name', type=str, default='weight_') 92 | parser.add_argument('--weight_dict_path', type=str, default=None) 93 | parser.add_argument('--labels_and_preds_name', type=str, default='labels_and_preds_') 94 | parser.add_argument('--pred_save_path', type=str, default='test') 95 | 96 | 97 | 98 | args = parser.parse_args() 99 | args = get_class_splits(args) 
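# NOTE (assumed interface): get_class_splits, defined in data/get_datasets.py, is
# expected to attach list-like args.known_categories / args.novel_categories;
# the two len() calls just below and the later K-Means step rely on exactly
# these attributes.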
100 | args.num_known_categories = len(args.known_categories) 101 | args.num_novel_categories = len(args.novel_categories) 102 | 103 | model_name = args.model_path.split('/')[-2] 104 | args.experiment_name = args.experiment_name + model_name 105 | args.log_folder = configure_log_folder(args) 106 | configure_log(args) 107 | 108 | args.metadata_root = ospj(args.metadata_root, args.dataset_name) 109 | args.pred_save_path = ospj(args.metadata_root, args.pred_save_path) 110 | args.weight_dict_name = args.weight_dict_name + model_name 111 | args.labels_and_preds_name = args.labels_and_preds_name + model_name 112 | args.weight_dict_path = ospj(args.log_folder, args.weight_dict_name) 113 | 114 | 115 | 116 | def normalize_scoremap(gcam): 117 | """ 118 | Args: 119 | gcam: numpy.ndarray(size=(H, W), dtype=np.float) 120 | Returns: 121 | numpy.ndarray(size=(H, W), dtype=np.float) between 0 and 1. 122 | If input array is constant, a zero-array is returned. 123 | """ 124 | if np.isnan(gcam).any(): 125 | return np.zeros_like(gcam) 126 | if gcam.min() == gcam.max(): 127 | return np.zeros_like(gcam) 128 | gcam -= gcam.min() 129 | gcam /= gcam.max() 130 | return gcam 131 | 132 | def get_batch_weight(targets, weight_dict): 133 | weight_list = [] 134 | for item in targets: 135 | weight_list.append(torch.from_numpy(weight_dict[item])) 136 | gcam_weights = torch.stack(weight_list) 137 | return gcam_weights 138 | 139 | 140 | 141 | class GcamComputer(object): 142 | def __init__(self, model, loader, metadata_root, labels_and_preds, 143 | iou_threshold_list, dataset_name, partitions, multi_contour_eval, 144 | gcam_curve_interval=.001, log_folder=None, weight_dict=None): 145 | self.model = model 146 | self.model.eval() 147 | self.loader = loader 148 | self.log_folder = log_folder 149 | self.weight_dict = weight_dict 150 | 151 | 152 | metadata = configure_metadata_infer(metadata_root, labels_and_preds) 153 | gcam_threshold_list = list(np.arange(0, 1, gcam_curve_interval)) 154 | 155 | self.evaluator = BoxEvaluator(metadata=metadata, 156 | dataset_name=dataset_name, 157 | gcam_threshold_list=gcam_threshold_list, 158 | iou_threshold_list=iou_threshold_list, 159 | partitions=partitions, 160 | multi_contour_eval=multi_contour_eval) 161 | 162 | def compute_and_evaluate_gcams(self): 163 | print("Computing and evaluating G-CAMs.") 164 | for images, targets, image_ids in self.loader: 165 | image_size = images.shape[2:] 166 | images = images.cuda() 167 | feature_map = self.model(images)['feature_map'] 168 | targets = targets.detach().clone().cpu().numpy() 169 | gcam_weights = get_batch_weight(targets, self.weight_dict).cuda() 170 | gcams = (gcam_weights.view(*feature_map.shape[:2], 1, 1) * 171 | feature_map).mean(1, keepdim=False) 172 | gcams = t2n(gcams) 173 | 174 | for gcam, image_id in zip(gcams, image_ids): 175 | gcam_resized = cv2.resize(gcam, image_size, 176 | interpolation=cv2.INTER_CUBIC) 177 | gcam_normalized = normalize_scoremap(gcam_resized) 178 | gcam_path = ospj(self.log_folder, 'scoremaps', image_id) 179 | if not os.path.exists(ospd(gcam_path)): 180 | os.makedirs(ospd(gcam_path)) 181 | np.save(ospj(gcam_path), gcam_normalized) 182 | self.evaluator.accumulate(gcam_normalized, image_id) 183 | return self.evaluator.compute() 184 | 185 | 186 | 187 | if __name__=='__main__': 188 | set_random_seed(args.seed) 189 | 190 | 191 | print("=> creating model '{}'".format(args.arch)) 192 | model = resnet50(num_classes=128, pretrained=None) 193 | 194 | # load pretrained 195 | state_dict = {} 196 | old_state_dict = 
torch.load(args.model_path, map_location='cpu')['state_dict']
197 |     for key in old_state_dict.keys():
198 |         if key.startswith('module.encoder_q'):
199 |             print(key)
200 |             new_key = key.split('encoder_q.')[1]
201 |             state_dict[new_key] = old_state_dict[key]
202 |     model.load_state_dict(state_dict, strict=False)
203 |     model = nn.DataParallel(model)
204 |     # sanity check: print parameter names only (dumping the full tensors would flood the log)
205 |     for name, parameter in model.named_parameters():
206 |         print(name)
207 |     model.cuda()
208 |     model.eval()
209 | 
210 |     test_transform = transforms.Compose([
211 |         transforms.Resize((224, 224)),
212 |         transforms.ToTensor(),
213 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
214 |     ])
215 | 
216 |     ############################################### Clustering ##############################################
217 |     # --------------------
218 |     # DATASETS
219 |     # --------------------
220 | 
221 | 
222 |     cluster_dataset = get_datasets_cluster(dataset_name=args.dataset_name, test_transform=test_transform)
223 |     cluster_loader = DataLoader(cluster_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
224 | 
225 |     # --------------------
226 |     # EXTRACT FEATURE TO CLUSTER
227 |     # --------------------
228 | 
229 |     ft_list = []
230 |     id_list = []
231 |     targets = np.array([])
232 |     mask = np.array([])
233 |     weight_dict = {}
234 |     print('Collating features...')
235 |     # First extract all features
236 |     for batch_idx, (images, label, img_ids) in enumerate(cluster_loader):
237 | 
238 |         feats = model(images.cuda())['feature']
239 |         feats = torch.nn.functional.normalize(feats, dim=-1)
240 |         print(feats.shape)
241 |         ft_list.append(feats.detach().cpu().numpy())
242 |         targets = np.append(targets, label.cpu().numpy())
243 |         # distinguish labeled classes
244 |         mask = np.append(mask, np.array([True if x.item() in args.known_categories
245 |                                          else False for x in label]))
246 |         id_list += img_ids
247 | 
248 |     all_feats = np.concatenate(ft_list)
249 |     print("ft_all.shape: ", all_feats.shape)
250 |     print('Fitting K-Means...')
251 |     kmeans = KMeans(n_clusters=args.num_known_categories + args.num_novel_categories, random_state=0).fit(all_feats)
252 |     preds = kmeans.labels_
253 |     centroids = kmeans.cluster_centers_
254 | 
255 |     print('labels.shape:', preds.shape)
256 |     print('weight.shape: ', centroids.shape)
257 | 
258 |     all_acc, known_acc, novel_acc, ind_map_pre2gt = log_accs_from_preds_infer(y_true=targets, y_pred=preds, mask=mask)
259 | 
260 |     print('Accuracies: All {:.4f} | Known {:.4f} | Novel {:.4f}'.format(all_acc, known_acc, novel_acc))
261 | 
262 | 
263 | 
264 |     # match pred with target
265 |     pred2gt = np.array([])
266 |     pred2gt = np.append(pred2gt, np.array([int(ind_map_pre2gt[pred]) for pred in preds]))
267 |     print('pred2gt.shape:', pred2gt.shape)
268 | 
269 |     # use pred2gt to save the weight_dict
270 |     for i, item in enumerate(centroids):
271 |         weight_dict[int(ind_map_pre2gt[i])] = item
272 |     print("len(weight_dict): ", len(weight_dict.keys()))
273 |     np.save(args.weight_dict_path, weight_dict)
274 | 
275 |     with open(ospj(args.pred_save_path, args.labels_and_preds_name), 'w') as f:
276 |         for i, item in enumerate(id_list):
277 |             f.write(item + ',' + str(int(targets[i])) + ',' + str(int(pred2gt[i])))
278 |             f.write('\n')
279 | 
280 | 
281 |     ############################################## INFER G-CAM ##############################################
282 | 
283 | 
284 |     # --------------------
285 |     # DATASETS
286 |     # --------------------
287 |     gcam_dataset = get_datasets_gcam(dataset_name=args.dataset_name, test_transform=test_transform, target_and_pred=args.labels_and_preds_name)
288 |     gcam_loader = DataLoader(gcam_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
289 | 
290 | 
291 | 
292 |     gcam_computer = GcamComputer(
293 |         model=model,
294 |         loader=gcam_loader,
295 |         metadata_root=os.path.join(args.metadata_root, 'test'),
296 |         labels_and_preds=args.labels_and_preds_name,
297 |         iou_threshold_list=args.iou_threshold_list,
298 |         dataset_name=args.dataset_name,
299 |         partitions=args.partitions,
300 |         gcam_curve_interval=args.cam_curve_interval,
301 |         multi_contour_eval=args.multi_contour_eval,
302 |         log_folder=args.log_folder,
303 |         weight_dict=weight_dict
304 |     )
305 | 
306 |     loc_acc, clus_loc_acc, clus_acc = gcam_computer.compute_and_evaluate_gcams()
307 | 
308 |     print('######################## Clus Acc ########################')
309 |     for k, v in clus_acc.items():
310 |         print(k, ":", v)
311 |     print('###################### Clus Loc Acc ######################')
312 |     for k, v in clus_loc_acc.items():
313 |         print(k, ":", np.average(v))
314 |     print('######################## Loc Acc #########################')
315 |     for k, v in loc_acc.items():
316 |         print(k, ":", np.average(v))
317 | 
318 | 
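A toy shape walk-through of the centroid-weighted score-map computation in GcamComputer.compute_and_evaluate_gcams above. All sizes are illustrative; 2048 assumes a ResNet-50 feature map.

import torch

B, C, H, W = 2, 2048, 7, 7               # batch, channels, spatial dims
feature_map = torch.randn(B, C, H, W)    # stands in for model(images)['feature_map']
gcam_weights = torch.randn(B, C)         # one matched cluster centroid per image

# same expression as in compute_and_evaluate_gcams():
gcams = (gcam_weights.view(*feature_map.shape[:2], 1, 1) * feature_map).mean(1, keepdim=False)
print(gcams.shape)                       # torch.Size([2, 7, 7]) -- one score map per image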
--------------------------------------------------------------------------------
/methods/metric.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2020-present NAVER Corp.
3 | 
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | this software and associated documentation files (the "Software"), to deal in
6 | the Software without restriction, including without limitation the rights to
7 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
8 | the Software, and to permit persons to whom the Software is furnished to do so,
9 | subject to the following conditions:
10 | 
11 | The above copyright notice and this permission notice shall be included in all
12 | copies or substantial portions of the Software.
13 | 
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
16 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
17 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
18 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
19 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20 | """
21 | 
22 | import cv2
23 | import numpy as np
24 | 
25 | from data.imagenet import get_image_ids
26 | from data.imagenet import get_bounding_boxes
27 | from data.imagenet import get_image_sizes
28 | from data.imagenet import get_partitions
29 | from data.imagenet import get_class_labels_and_pred
30 | from project_utils.utils import check_scoremap_validity
31 | from project_utils.utils import check_box_convention
32 | 
33 | 
34 | _RESIZE_LENGTH = 224
35 | _CONTOUR_INDEX = 1 if cv2.__version__.split('.')[0] == '3' else 0
36 | _PARTITION_MAP = {0:'known', 1:'nov_s', 2:'nov_d'}
37 | 
38 | 
39 | 
40 | def calculate_multiple_iou(box_a, box_b):
41 |     """
42 |     Args:
43 |         box_a: numpy.ndarray(dtype=np.int, shape=(num_a, 4))
44 |             x0y0x1y1 convention.
45 |         box_b: numpy.ndarray(dtype=np.int, shape=(num_b, 4))
46 |             x0y0x1y1 convention.
47 |     Returns:
48 |         ious: numpy.ndarray(dtype=np.float, shape=(num_a, num_b))
49 |     """
50 |     num_a = box_a.shape[0]
51 |     num_b = box_b.shape[0]
52 | 
53 |     check_box_convention(box_a, 'x0y0x1y1')
54 |     check_box_convention(box_b, 'x0y0x1y1')
55 | 
56 |     # num_a x 4 -> num_a x num_b x 4
57 |     box_a = np.tile(box_a, num_b)
58 |     box_a = np.expand_dims(box_a, axis=1).reshape((num_a, num_b, -1))
59 | 
60 |     # num_b x 4 -> num_b x num_a x 4
61 |     box_b = np.tile(box_b, num_a)
62 |     box_b = np.expand_dims(box_b, axis=1).reshape((num_b, num_a, -1))
63 | 
64 |     # num_b x num_a x 4 -> num_a x num_b x 4
65 |     box_b = np.transpose(box_b, (1, 0, 2))
66 | 
67 |     # num_a x num_b
68 |     min_x = np.maximum(box_a[:, :, 0], box_b[:, :, 0])
69 |     min_y = np.maximum(box_a[:, :, 1], box_b[:, :, 1])
70 |     max_x = np.minimum(box_a[:, :, 2], box_b[:, :, 2])
71 |     max_y = np.minimum(box_a[:, :, 3], box_b[:, :, 3])
72 | 
73 |     # num_a x num_b
74 |     area_intersect = (np.maximum(0, max_x - min_x + 1)
75 |                       * np.maximum(0, max_y - min_y + 1))
76 |     area_a = ((box_a[:, :, 2] - box_a[:, :, 0] + 1) *
77 |               (box_a[:, :, 3] - box_a[:, :, 1] + 1))
78 |     area_b = ((box_b[:, :, 2] - box_b[:, :, 0] + 1) *
79 |               (box_b[:, :, 3] - box_b[:, :, 1] + 1))
80 | 
81 |     denominator = area_a + area_b - area_intersect
82 |     degenerate_indices = np.where(denominator <= 0)
83 |     denominator[degenerate_indices] = 1
84 | 
85 |     ious = area_intersect / denominator
86 |     ious[degenerate_indices] = 0
87 |     return ious
88 | 
89 | 
90 | def resize_bbox(box, image_size, resize_size):
91 |     """
92 |     Args:
93 |         box: iterable (ints) of length 4 (x0, y0, x1, y1)
94 |         image_size: iterable (ints) of length 2 (width, height)
95 |         resize_size: iterable (ints) of length 2 (width, height)
96 | 
97 |     Returns:
98 |         new_box: iterable (ints) of length 4 (x0, y0, x1, y1)
99 |     """
100 |     check_box_convention(np.array(box), 'x0y0x1y1')
101 |     box_x0, box_y0, box_x1, box_y1 = map(float, box)
102 |     image_w, image_h = map(float, image_size)
103 |     new_image_w, new_image_h = map(float, resize_size)
104 | 
105 |     newbox_x0 = box_x0 * new_image_w / image_w
106 |     newbox_y0 = box_y0 * new_image_h / image_h
107 |     newbox_x1 = box_x1 * new_image_w / image_w
108 |     newbox_y1 = box_y1 * new_image_h / image_h
109 |     return int(newbox_x0), int(newbox_y0), int(newbox_x1), int(newbox_y1)
110 | 
111 | 
112 | def compute_bboxes_from_scoremaps(scoremap, scoremap_threshold_list,
113 |                                   multi_contour_eval=False):
114 |     """
115 |     Args:
116 |         scoremap: numpy.ndarray(dtype=np.float32, size=(H, W)) between 0 and 1
117 |         scoremap_threshold_list: iterable
118 |         multi_contour_eval: flag for multi-contour evaluation
119 | 
120 |     Returns:
121 |         estimated_boxes_at_each_thr: list of estimated boxes (list of np.array)
122 |             at each cam threshold
123 |         number_of_box_list: list of the number of boxes at each cam threshold
124 |     """
125 |     check_scoremap_validity(scoremap)
126 |     height, width = scoremap.shape
127 |     scoremap_image = np.expand_dims((scoremap * 255).astype(np.uint8), 2)
128 | 
129 |     def scoremap2bbox(threshold):
130 |         _, thr_gray_heatmap = cv2.threshold(
131 |             src=scoremap_image,
132 |             thresh=int(threshold * np.max(scoremap_image)),
133 |             maxval=255,
134 |             type=cv2.THRESH_BINARY)
135 |         contours = cv2.findContours(
136 |             image=thr_gray_heatmap,
137 |             mode=cv2.RETR_TREE,
138 |             method=cv2.CHAIN_APPROX_SIMPLE)[_CONTOUR_INDEX]
139 | 
140 |         if len(contours) == 0:
141 |             return np.asarray([[0, 0, 0, 0]]), 1
142 | 
143 |         if not multi_contour_eval:
144 |             contours = [max(contours, key=cv2.contourArea)]
145 | 
146 |         estimated_boxes = []
147 |         for contour in 
contours: 148 | x, y, w, h = cv2.boundingRect(contour) 149 | x0, y0, x1, y1 = x, y, x + w, y + h 150 | x1 = min(x1, width - 1) 151 | y1 = min(y1, height - 1) 152 | estimated_boxes.append([x0, y0, x1, y1]) 153 | 154 | return np.asarray(estimated_boxes), len(contours) 155 | 156 | estimated_boxes_at_each_thr = [] 157 | number_of_box_list = [] 158 | for threshold in scoremap_threshold_list: 159 | boxes, number_of_box = scoremap2bbox(threshold) 160 | estimated_boxes_at_each_thr.append(boxes) 161 | number_of_box_list.append(number_of_box) 162 | 163 | return estimated_boxes_at_each_thr, number_of_box_list 164 | 165 | 166 | 167 | 168 | class LocalizationEvaluator(object): 169 | """ Abstract class for localization evaluation over score maps. 170 | 171 | The class is designed to operate in a for loop (e.g. batch-wise cam 172 | score map computation). At initialization, __init__ registers paths to 173 | annotations and data containers for evaluation. At each iteration, 174 | each score map is passed to the accumulate() method along with its image_id. 175 | After the for loop is finalized, compute() is called to compute the final 176 | localization performance. 177 | """ 178 | 179 | def __init__(self, metadata, dataset_name, gcam_threshold_list, 180 | iou_threshold_list, partitions, multi_contour_eval): 181 | self.metadata = metadata 182 | self.gcam_threshold_list = gcam_threshold_list 183 | self.iou_threshold_list = iou_threshold_list 184 | self.dataset_partitions = partitions 185 | self.dataset_name = dataset_name 186 | self.multi_contour_eval = multi_contour_eval 187 | 188 | def accumulate(self, scoremap, image_id): 189 | raise NotImplementedError 190 | 191 | def compute(self): 192 | raise NotImplementedError 193 | 194 | 195 | class BoxEvaluator(LocalizationEvaluator): 196 | def __init__(self, **kwargs): 197 | super(BoxEvaluator, self).__init__(**kwargs) 198 | 199 | self.image_ids = get_image_ids(metadata=self.metadata) 200 | self.resize_length = _RESIZE_LENGTH 201 | # self.cnt = {'all':0, 'known':0, 'nov_s':0, 'nov_d':0} 202 | # self.cnt_clus = {'all':0, 'known':0, 'nov_s':0, 'nov_d':0} 203 | 204 | self.cnt = {partition:0 for partition in self.dataset_partitions} 205 | self.cnt_clus = {partition:0 for partition in self.dataset_partitions} 206 | 207 | self.num_correct = \ 208 | {partition: {iou_threshold: np.zeros(len(self.gcam_threshold_list)) 209 | for iou_threshold in self.iou_threshold_list} 210 | for partition in self.dataset_partitions} 211 | 212 | self.num_correct_top1 = \ 213 | {partition: {iou_threshold: np.zeros(len(self.gcam_threshold_list)) 214 | for iou_threshold in self.iou_threshold_list} 215 | for partition in self.dataset_partitions} 216 | 217 | self.original_bboxes = get_bounding_boxes(self.metadata) 218 | self.image_sizes = get_image_sizes(self.metadata) 219 | self.gt_bboxes = self._load_resized_boxes(self.original_bboxes) 220 | self.target_and_pred = get_class_labels_and_pred(self.metadata) 221 | self.partition = get_partitions(self.metadata) 222 | 223 | def _load_resized_boxes(self, original_bboxes): 224 | resized_bbox = {image_id: [ 225 | resize_bbox(bbox, self.image_sizes[image_id], 226 | (self.resize_length, self.resize_length)) 227 | for bbox in original_bboxes[image_id]] 228 | for image_id in self.image_ids} 229 | return resized_bbox 230 | 231 | def accumulate(self, scoremap, image_id): 232 | """ 233 | From a score map, a box is inferred (compute_bboxes_from_scoremaps). 234 | The box is compared against GT boxes. 
Count a scoremap as a correct
235 |         prediction if the IOU against at least one box is greater than a certain
236 |         threshold (_IOU_THRESHOLD).
237 | 
238 |         Args:
239 |             scoremap: numpy.ndarray(size=(H, W), dtype=np.float)
240 |             image_id: string.
241 |         """
242 |         target, pred = self.target_and_pred[image_id]
243 |         partition = self.partition[image_id]
244 |         partition = _PARTITION_MAP[partition]
245 | 
246 | 
247 |         boxes_at_thresholds, number_of_box_list = compute_bboxes_from_scoremaps(
248 |             scoremap=scoremap,
249 |             scoremap_threshold_list=self.gcam_threshold_list,
250 |             multi_contour_eval=self.multi_contour_eval)
251 | 
252 |         # (N_threshold) --> (num_all_boxes, 4) <==> (num_a, 4)
253 |         boxes_at_thresholds = np.concatenate(boxes_at_thresholds, axis=0)
254 | 
255 |         # num_a x num_b
256 |         multiple_iou = calculate_multiple_iou(
257 |             np.array(boxes_at_thresholds),
258 |             np.array(self.gt_bboxes[image_id]))
259 | 
260 |         # For each scoremap threshold, keep the best IoU between any predicted
261 |         # box at that threshold and any GT box.
262 |         # nr_box: number of predicted boxes at each threshold
263 |         # sliced_multiple_iou: best IoU per threshold
264 | 
265 |         idx = 0
266 |         sliced_multiple_iou = []  # len: 1000
267 |         for nr_box in number_of_box_list:
268 |             sliced_multiple_iou.append(
269 |                 max(multiple_iou.max(1)[idx:idx + nr_box]))
270 |             idx += nr_box
271 |         # record clus_acc
272 |         if target == pred:
273 |             self.cnt_clus['all'] += 1
274 |             self.cnt_clus[partition] += 1
275 | 
276 | 
277 |         for _THRESHOLD in self.iou_threshold_list:
278 |             correct_threshold_indices = \
279 |                 np.where(np.asarray(sliced_multiple_iou) >= (_THRESHOLD/100))[0]
280 |             # record loc_acc
281 |             self.num_correct['all'][_THRESHOLD][correct_threshold_indices] += 1
282 |             self.num_correct[partition][_THRESHOLD][correct_threshold_indices] += 1
283 |             # record clus_loc_acc
284 |             if target == pred:
285 |                 self.num_correct_top1['all'][_THRESHOLD][correct_threshold_indices] += 1
286 |                 self.num_correct_top1[partition][_THRESHOLD][correct_threshold_indices] += 1
287 | 
288 | 
289 |         self.cnt['all'] += 1
290 |         self.cnt[partition] += 1
291 | 
292 |     def compute(self):
293 |         """
294 |         Returns:
295 |             max_localization_accuracy: float. The ratio of images where the
296 |                 box prediction is correct. The best scoremap threshold is taken
297 |                 for the final performance.
298 |         """
299 |         loc_acc = {partition: [] for partition in self.dataset_partitions}
300 |         clus_loc_acc = {partition: [] for partition in self.dataset_partitions}
301 |         clus_acc = {}
302 |         for partition in self.dataset_partitions:
303 |             for _THRESHOLD in self.iou_threshold_list:
304 |                 # best scoremap thresholds are selected on the 'all' partition
305 |                 localization_accuracies_all = self.num_correct['all'][_THRESHOLD] * 100. / float(self.cnt['all'])
306 |                 loc_acc_max_index = np.where(localization_accuracies_all == localization_accuracies_all.max())
306 |                 cluster_localization_accuracies_all = self.num_correct_top1['all'][_THRESHOLD] * 100. / float(self.cnt['all'])
307 |                 clus_loc_acc_max_index = np.where(cluster_localization_accuracies_all == cluster_localization_accuracies_all.max())
308 | 
309 |                 # using the best threshold
310 |                 localization_accuracies = self.num_correct[partition][_THRESHOLD] * 100. / float(self.cnt[partition])
311 |                 cluster_localization_accuracies = self.num_correct_top1[partition][_THRESHOLD] * 100. / float(self.cnt[partition])
312 |                 loc_acc[partition].append(localization_accuracies[loc_acc_max_index].max())
313 |                 clus_loc_acc[partition].append(cluster_localization_accuracies[clus_loc_acc_max_index].max())
314 | 
315 | 
316 |             clus_acc[partition] = self.cnt_clus[partition] / self.cnt[partition]
317 | 
318 |         return loc_acc, clus_loc_acc, clus_acc
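A toy numeric illustration of the best-threshold selection in compute() above (made-up counts, five score-map thresholds instead of 1000):

import numpy as np

cnt_all = 4                                  # images seen by accumulate()
num_correct_all = np.array([1, 3, 4, 2, 0])  # correct localizations per threshold
acc_all = num_correct_all * 100. / float(cnt_all)   # [25., 75., 100., 50., 0.]
best_index = np.where(acc_all == acc_all.max())     # threshold index 2 is best
print(acc_all[best_index].max())             # 100.0 -> the reported loc_acc entry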
--------------------------------------------------------------------------------
/methods/contrastive_co_training.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import os
3 | import math
4 | import random
5 | import shutil
6 | import time
7 | import argparse
8 | import builtins
9 | import datetime
10 | import pickle
11 | import warnings
12 | import faiss
13 | import numpy as np
14 | from sklearn.cluster import KMeans
15 | 
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.parallel
19 | import torch.nn.functional as F
20 | import torch.backends.cudnn as cudnn
21 | import torch.distributed as dist
22 | import torch.optim
23 | import torch.multiprocessing as mp
24 | import torch.utils.data
25 | import torch.utils.data.distributed
26 | 
27 | 
28 | from models.builder import MoCo
29 | from models.resnet import resnet50
30 | 
31 | from data.augmentations import get_transform
32 | from data.get_datasets import get_datasets, get_class_splits
33 | 
34 | from project_utils.cluster_and_log_utils import log_accs_from_preds
35 | 
36 | 
37 | from config import pretrained_path, imagenet_sup_queue_path, inatloc_sup_queue_path, openimages_sup_queue_path
38 | 
39 | 
40 | def str2bool(v):
41 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
42 |         return True
43 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
44 |         return False
45 |     else:
46 |         raise argparse.ArgumentTypeError('Boolean value expected.')
47 | 
48 | 
49 | parser = argparse.ArgumentParser(description='PyTorch OWSOL Training')
50 | parser.add_argument('--dataset_name', type=str, default='ImageNet',
51 |                     help='options: ImageNet, iNatLoc, OpenImages')
52 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50')
53 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
54 |                     help='number of data loading workers (default: 8)')
55 | parser.add_argument('--epochs', default=10, type=int, metavar='N',
56 |                     help='number of total epochs to run')
57 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
58 |                     help='manual epoch number (useful on restarts)')
59 | parser.add_argument('-b', '--batch_size', default=256, type=int,
60 |                     metavar='N',
61 |                     help='mini-batch size (default: 256), this is the total '
62 |                          'batch size of all GPUs on the current node when '
63 |                          'using Data Parallel or Distributed Data Parallel')
64 | parser.add_argument('--eval_batch_size', default=64, type=int)
65 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
66 |                     metavar='LR', help='initial learning rate', dest='lr')
67 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
68 |                     help='learning rate schedule (when to drop lr by 10x)')
69 | parser.add_argument('--cos', action='store_true',
70 |                     help='use cosine lr schedule')
71 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
72 |                     help='momentum of SGD solver')
73 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
74 |                     metavar='W', help='weight decay (default: 1e-4)',
75 |                     dest='weight_decay')
76 | parser.add_argument('-p', '--print-freq', default=100, type=int,
77 |                     metavar='N', help='print frequency (default: 100)')
78 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
79 |                     help='path to latest checkpoint (default: none)')
80 | parser.add_argument('--world_size', default=-1, type=int,
81 |                     help='number of nodes for distributed training')
82 | parser.add_argument('--rank', default=-1, type=int,
83 |                     help='node rank for distributed training')
84 | parser.add_argument('--dist_url', default="env://", type=str,
85 |                     help='url used to set up distributed training')
86 | parser.add_argument('--dist_backend', default='nccl', type=str,
87 |                     help='distributed backend')
88 | parser.add_argument('--seed', default=None, type=int,
89 |                     help='seed for initializing training. ')
90 | parser.add_argument('--gpu', default=None, type=int,
91 |                     help='GPU id to use.')
92 | parser.add_argument('--multiprocessing_distributed', action='store_true',
93 |                     help='Use multi-processing distributed training to launch '
94 |                          'N processes per node, which has N GPUs. This is the '
95 |                          'fastest way to use PyTorch for either single node or '
96 |                          'multi node data parallel training')
97 | # augmentation
98 | parser.add_argument('--n_views', default=2, type=int)
99 | parser.add_argument('--interpolation', default=3, type=int)
100 | parser.add_argument('--crop_pct', type=float, default=0.875)
101 | 
102 | # model specific configs
103 | parser.add_argument('--mlp_dim', default=128, type=int,
104 |                     help='feature dimension (default: 128)')
105 | parser.add_argument('--moco_m', default=0.999, type=float,
106 |                     help='moco momentum of updating key encoder (default: 0.999)')
107 | parser.add_argument('--scl_t', default=0.07, type=float,
108 |                     help='softmax temperature for scl (default: 0.07)')
109 | parser.add_argument('--mcl_t', default=0.2, type=float,
110 |                     help='base temperature for mcl (default: 0.2)')
111 | parser.add_argument('--num_cluster', default=50000, type=int,
112 |                     help='number of clusters')
113 | parser.add_argument('--num_multi_centroids', default=5, type=int,
114 |                     help='number of multi centroids')
115 | parser.add_argument('--mcl_k', default=4096, type=int,
116 |                     help='number of negative centroids')
117 | parser.add_argument('--scl_weight', type=float, default=0.5)
118 | parser.add_argument('--mcl_weight', type=float, default=0.5)
119 | 
120 | 
121 | 
122 | class SupConLoss(nn.Module):
123 |     def __init__(self):
124 |         super(SupConLoss, self).__init__()
125 | 
126 | 
127 |     def forward(self, logits, mask):
128 | 
129 |         # logits --> Nx(1+scl_k)
130 |         # for numerical stability
131 |         logits_max, _ = torch.max(logits, dim=1, keepdim=True)
132 |         logits = logits - logits_max.detach()
133 | 
134 | 
135 |         # compute log_prob
136 |         exp_logits = torch.exp(logits)
137 |         log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))  # log(A/B)=log(A)-log(B) --> torch.log(exp_logits) - torch.log(exp_logits.sum(1, keepdim=True))
138 | 
139 |         # compute mean of log-likelihood over positive
140 |         mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
141 | 
142 |         # loss
143 |         loss = -1 * mean_log_prob_pos
144 |         loss = loss.mean()
145 | 
146 |         return loss
147 | 
148 | 
149 | class ContrastiveLearningViewGenerator(object):
150 |     """Take two random crops of one image as the query and key."""
151 | 
152 |     def __init__(self, base_transform, n_views=2):
153 |         self.base_transform = base_transform
154 |         self.n_views = n_views
155 | 
156 |     def __call__(self, x):
157 |         return [self.base_transform(x) for i in range(self.n_views)]
158 | 
159 | 
160 | def 
main(): 161 | args = parser.parse_args() 162 | args = get_class_splits(args) 163 | args.num_known_categories = len(args.known_categories) 164 | args.num_novel_categories = len(args.novel_categories) 165 | 166 | 167 | if args.seed is not None: 168 | random.seed(args.seed) 169 | np.random.seed(args.seed) 170 | torch.manual_seed(args.seed) 171 | torch.cuda.manual_seed(args.seed) 172 | cudnn.deterministic = True 173 | warnings.warn('You have chosen to seed training. ' 174 | 'This will turn on the CUDNN deterministic setting, ' 175 | 'which can slow down your training considerably! ' 176 | 'You may see unexpected behavior when restarting ' 177 | 'from checkpoints.') 178 | 179 | if args.gpu is not None: 180 | warnings.warn('You have chosen a specific GPU. This will completely ' 181 | 'disable data parallelism.') 182 | 183 | if args.dist_url == "env://" and args.world_size == -1: 184 | args.world_size = int(os.environ["WORLD_SIZE"]) 185 | 186 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 187 | print('num_cluster:', args.num_cluster) 188 | print('num_multi_centroids:', args.num_multi_centroids ) 189 | print('learning rate:', args.lr) 190 | print('scl_weight: ', args.scl_weight) 191 | print('mcl_weight: ', args.mcl_weight) 192 | 193 | ngpus_per_node = torch.cuda.device_count() 194 | if args.multiprocessing_distributed: 195 | # Since we have ngpus_per_node processes per node, the total world_size 196 | # needs to be adjusted accordingly 197 | args.world_size = ngpus_per_node * args.world_size 198 | # Use torch.multiprocessing.spawn to launch distributed processes: the 199 | # main_worker process function 200 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 201 | else: 202 | # Simply call main_worker function 203 | main_worker(args.gpu, ngpus_per_node, args) 204 | 205 | 206 | def main_worker(gpu, ngpus_per_node, args): 207 | args.gpu = gpu 208 | 209 | # suppress printing if not master 210 | if args.multiprocessing_distributed and args.gpu != 0: 211 | def print_pass(*args): 212 | pass 213 | builtins.print = print_pass 214 | 215 | if args.gpu is not None: 216 | print("Use GPU: {} for training".format(args.gpu)) 217 | 218 | if args.distributed: 219 | if args.dist_url == "env://" and args.rank == -1: 220 | args.rank = int(os.environ["RANK"]) 221 | if args.multiprocessing_distributed: 222 | # For multiprocessing distributed training, rank needs to be the 223 | # global rank among all the processes 224 | args.rank = args.rank * ngpus_per_node + gpu 225 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,timeout=datetime.timedelta(0, 7200), 226 | world_size=args.world_size, rank=args.rank) 227 | 228 | # create model 229 | if args.dataset_name == 'ImageNet': 230 | f = open(imagenet_sup_queue_path, 'rb') 231 | sup_queue = pickle.load(f) 232 | elif args.dataset_name == 'iNatLoc': 233 | f = open(inatloc_sup_queue_path, 'rb') 234 | sup_queue = pickle.load(f) 235 | elif args.dataset_name == 'OpenImages': 236 | f = open(openimages_sup_queue_path, 'rb') 237 | sup_queue = pickle.load(f) 238 | 239 | 240 | 241 | 242 | print("=> creating model '{}'".format(args.arch)) 243 | model = MoCo( 244 | resnet50, pretrained_path, sup_queue, args.num_known_categories, args.num_multi_centroids, 245 | args.mlp_dim, args.mcl_k, args.moco_m, args.scl_t) 246 | print(model) 247 | 248 | for name, parms in model.named_parameters(): 249 | print('-->name:', name) 250 | print('-->grad_requirs:',parms.requires_grad) 251 | 252 | for name,parameter in 
model.named_parameters(): 253 | print(name) 254 | print(parameter) 255 | print('#'*80) 256 | 257 | 258 | if args.distributed: 259 | # For multiprocessing distributed, DistributedDataParallel constructor 260 | # should always set the single device scope, otherwise, 261 | # DistributedDataParallel will use all available devices. 262 | if args.gpu is not None: 263 | torch.cuda.set_device(args.gpu) 264 | model.cuda(args.gpu) 265 | # When using a single GPU per process and per 266 | # DistributedDataParallel, we need to divide the batch size 267 | # ourselves based on the total number of GPUs we have 268 | args.batch_size = int(args.batch_size / ngpus_per_node) 269 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 270 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 271 | else: 272 | model.cuda() 273 | # DistributedDataParallel will divide and allocate batch_size to all 274 | # available GPUs if device_ids are not set 275 | model = torch.nn.parallel.DistributedDataParallel(model) 276 | elif args.gpu is not None: 277 | torch.cuda.set_device(args.gpu) 278 | model = model.cuda(args.gpu) 279 | # comment out the following line for debugging 280 | raise NotImplementedError("Only DistributedDataParallel is supported.") 281 | else: 282 | # AllGather implementation (batch shuffle, queue update, etc.) in 283 | # this code only supports DistributedDataParallel. 284 | raise NotImplementedError("Only DistributedDataParallel is supported.") 285 | 286 | 287 | 288 | # define loss function (criterion) and optimizer 289 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 290 | criterion_scl = SupConLoss().cuda(args.gpu) 291 | 292 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 293 | momentum=args.momentum, 294 | weight_decay=args.weight_decay) 295 | # optimizer = torch.optim.AdamW(model.parameters(), args.lr, 296 | # weight_decay=args.weight_decay) 297 | 298 | # optionally resume from a checkpoint 299 | if args.resume: 300 | if os.path.isfile(args.resume): 301 | print("=> loading checkpoint '{}'".format(args.resume)) 302 | if args.gpu is None: 303 | checkpoint = torch.load(args.resume) 304 | else: 305 | # Map model to be loaded to specified single gpu. 
318 | 
319 |     # --------------------
320 |     # CONTRASTIVE TRANSFORM
321 |     # --------------------
322 |     train_transform, test_transform = get_transform(args.dataset_name, image_size=args.image_size, args=args)
323 |     train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
324 | 
325 |     # --------------------
326 |     # DATASETS
327 |     # --------------------
328 |     train_dataset, eval_dataset, val_dataset = get_datasets(args.dataset_name, train_transform, test_transform)
329 | 
330 | 
331 | 
332 | 
333 |     if args.distributed:
334 |         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
335 |         eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset, shuffle=False)
336 |         # val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
337 |     else:
338 |         train_sampler = None
339 |         eval_sampler = None
340 |         # val_sampler = None
341 | 
342 |     train_loader = torch.utils.data.DataLoader(
343 |         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
344 |         num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
345 | 
346 |     eval_loader = torch.utils.data.DataLoader(
347 |         eval_dataset, batch_size=args.eval_batch_size, shuffle=False,
348 |         sampler=eval_sampler, num_workers=args.workers, pin_memory=True)
349 | 
350 |     val_loader = torch.utils.data.DataLoader(
351 |         val_dataset, num_workers=args.workers, batch_size=args.eval_batch_size, shuffle=False)
352 | 
353 |     best_val_acc = 0
354 |     for epoch in range(args.start_epoch, args.epochs):
355 | 
356 |         # compute momentum features for center-cropped images
357 |         features = compute_features(eval_loader, model, args)
358 | 
359 |         # placeholders for the clustering result (identical shapes on every rank)
360 |         cluster_result = {}
361 |         cluster_result['im2cluster'] = torch.zeros(args.num_multi_centroids, len(eval_dataset), dtype=torch.long).cuda()
362 |         cluster_result['centroids'] = torch.zeros(int(args.num_cluster), args.mlp_dim).cuda()
363 |         cluster_result['density'] = torch.zeros(int(args.num_cluster)).cuda()
364 | 
365 |         if args.gpu == 0:
366 |             features[torch.norm(features, dim=1) > 1.5] /= 2  # account for the few samples that are computed twice
367 |             features = features.numpy()
368 |             cluster_result = run_kmeans(features, args)  # run k-means clustering on the master node
369 |             # save the clustering result
370 |             # torch.save(cluster_result, os.path.join(args.exp_dir, 'clusters_%d' % epoch))
371 | 
372 |         dist.barrier()
373 |         # broadcast the clustering result from rank 0 to all other ranks
374 |         print('----- broadcast -----')
375 |         dist.broadcast(cluster_result['im2cluster'], 0, async_op=False)
376 |         dist.broadcast(cluster_result['centroids'], 0, async_op=False)
377 |         dist.broadcast(cluster_result['density'], 0, async_op=False)
378 |         print("im2cluster:", cluster_result['im2cluster'].shape)
379 |         print("centroids:", cluster_result['centroids'].shape)
380 |         print("density:", cluster_result['density'].shape)
381 | 
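        # The placeholders above are what make the broadcast work:
        # torch.distributed.broadcast copies rank 0's buffer into same-shaped
        # buffers on every other rank, so the placeholder shapes must exactly
        # match what run_kmeans returns. A minimal sketch of the same pattern
        # (expensive_computation is hypothetical):
        #
        #   buf = torch.zeros(10).cuda()        # identical shape on every rank
        #   if dist.get_rank() == 0:
        #       buf = expensive_computation()   # only rank 0 fills it
        #   dist.barrier()
        #   dist.broadcast(buf, 0)              # all ranks now hold rank 0's values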
382 | 
383 |         if args.distributed:
384 |             train_sampler.set_epoch(epoch)
385 |         adjust_learning_rate(optimizer, epoch, args)
386 | 
387 |         # train for one epoch
388 |         train(train_loader, model, criterion, criterion_scl, optimizer, epoch, cluster_result, args)
389 | 
390 |         with torch.no_grad():
391 |             print('Testing on val set...')
392 |             all_acc, known_acc, novel_acc = test_kmeans(model, val_loader, args=args)
393 |             print('Val Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, known_acc, novel_acc))
394 | 
395 |             if known_acc > best_val_acc:
396 |                 print('Best val Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, known_acc, novel_acc))
397 |                 best_val_acc = known_acc
398 | 
399 |         if (epoch + 1) % 10 == 0:
400 |             if not args.multiprocessing_distributed or (args.multiprocessing_distributed
401 |                     and args.rank % ngpus_per_node == 0):
402 |                 save_checkpoint({
403 |                     'epoch': epoch + 1,
404 |                     'arch': args.arch,
405 |                     'state_dict': model.state_dict(),
406 |                     'optimizer': optimizer.state_dict(),
407 |                 }, is_best=False, save_path='save/{}/lr{}_scl{}_mcl{}_mc{}_e{}'.format(args.dataset_name, args.lr, args.scl_weight, args.mcl_weight, args.num_multi_centroids,
408 |                     args.epochs), filename='checkpoint_{:04d}.pth.tar'.format(epoch))
409 | 
410 | 
411 | 
412 | 
413 | def train(train_loader, model, criterion, criterion_scl, optimizer, epoch, cluster_result, args):
414 |     batch_time = AverageMeter('Time', ':6.3f')
415 |     data_time = AverageMeter('Data', ':6.3f')
416 |     losses = AverageMeter('Loss', ':.4e')
417 |     loss_mcl = AverageMeter('loss_mcl', ':.4e')
418 |     loss_scl = AverageMeter('loss_scl', ':.4e')
419 |     acc_mcl = AverageMeter('Acc@Proto', ':6.2f')
420 | 
421 |     progress = ProgressMeter(
422 |         len(train_loader),
423 |         [losses, loss_mcl, loss_scl, acc_mcl],
424 |         prefix="Epoch: [{}]".format(epoch))
425 | 
426 |     # switch to train mode
427 |     model.train()
428 | 
429 |     end = time.time()
430 |     for i, (images, images_lb, labels, index) in enumerate(train_loader):
431 |         # measure data loading time
432 |         data_time.update(time.time() - end)
433 | 
434 |         if args.gpu is not None:
435 |             images[0] = images[0].cuda(args.gpu, non_blocking=True)
436 |             images[1] = images[1].cuda(args.gpu, non_blocking=True)
437 |             images_lb[0] = images_lb[0].cuda(args.gpu, non_blocking=True)
438 |             images_lb[1] = images_lb[1].cuda(args.gpu, non_blocking=True)
439 |             labels = labels.cuda(args.gpu, non_blocking=True)
440 | 
441 | 
442 |         # compute output
443 |         scl_logits, scl_labels, mcl_logits, mcl_labels = model(im_q=images[0], im_q_lb=images_lb[0], im_k_lb=images_lb[1],
444 |                                                                targets=labels, is_eval=False, cluster_result=cluster_result, index=index)
445 | 
446 | 
447 |         loss1 = criterion(mcl_logits, mcl_labels)
448 |         acc = accuracy(mcl_logits, mcl_labels)[0]
449 |         acc_mcl.update(acc[0], images[0].size(0))
450 | 
451 |         loss2 = criterion_scl(scl_logits, scl_labels)
452 |         loss = args.mcl_weight * loss1 + args.scl_weight * loss2
453 | 
454 | 
455 |         losses.update(loss.item(), images[0].size(0))
456 |         loss_mcl.update(loss1.item(), images[0].size(0))
457 |         loss_scl.update(loss2.item(), images[0].size(0))
458 | 
459 | 
460 |         # compute gradient and do SGD step
461 |         optimizer.zero_grad()
462 |         loss.backward()
463 |         optimizer.step()
464 | 
465 |         # measure elapsed time
466 |         batch_time.update(time.time() - end)
467 |         end = time.time()
468 | 
469 |         if i % args.print_freq == 0:
470 |             progress.display(i)
471 | 
472 | 
473 | 
474 | def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar'):
475 |     if not os.path.exists(save_path):
476 |         os.makedirs(save_path)
477 |     torch.save(state, os.path.join(save_path, filename))
478 |     if is_best:
479 |         shutil.copyfile(os.path.join(save_path, filename), os.path.join(save_path, 'model_best.pth.tar'))
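# For reference, the save_path format string used above expands as follows
# (hypothetical hyperparameter values):
#
#   'save/{}/lr{}_scl{}_mcl{}_mc{}_e{}'.format('iNatLoc', 0.001, 0.5, 1.0, 5, 10)
#   # -> 'save/iNatLoc/lr0.001_scl0.5_mcl1.0_mc5_e10'
#   'checkpoint_{:04d}.pth.tar'.format(9)
#   # -> 'checkpoint_0009.pth.tar'
#
# i.e. the directory name encodes the learning rate, both loss weights, the
# number of multi-centroids, and the epoch budget.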
480 | 
481 | def test_kmeans(model, test_loader, args):
482 | 
483 |     model.eval()
484 | 
485 |     all_feats = []
486 |     targets = np.array([])
487 |     mask = np.array([])
488 | 
489 |     print('Collating features...')
490 |     # First extract all features
491 |     for batch_idx, (images, label) in enumerate(test_loader):
492 | 
493 |         images = images.cuda(args.gpu, non_blocking=True)
494 | 
495 |         # Pass images through the base model and the additional learnable transform (linear layer)
496 |         feats = model.module.forward_feature(images)
497 | 
498 |         feats = torch.nn.functional.normalize(feats, dim=-1)
499 | 
500 |         all_feats.append(feats.cpu().numpy())
501 |         targets = np.append(targets, label.cpu().numpy())
502 |         # distinguish known (labelled) categories from novel ones
503 |         mask = np.append(mask, np.array([True if x.item() in args.known_categories
504 |                                          else False for x in label]))
505 | 
506 |     # -----------------------
507 |     # K-MEANS
508 |     # -----------------------
509 |     print('Fitting K-Means...')
510 |     all_feats = np.concatenate(all_feats)
511 |     kmeans = KMeans(n_clusters=args.num_known_categories + args.num_novel_categories, random_state=0).fit(all_feats)
512 |     preds = kmeans.labels_
513 |     print('Done!')
514 | 
515 |     # -----------------------
516 |     # EVALUATE
517 |     # -----------------------
518 |     all_acc, known_acc, novel_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask)
519 | 
520 |     return all_acc, known_acc, novel_acc
521 | 
522 | 
523 | 
524 | def compute_features(eval_loader, model, args):
525 |     print('Computing features to cluster for mcl...')
526 |     model.eval()
527 |     features = torch.zeros(len(eval_loader.dataset), args.mlp_dim).cuda()
528 |     for i, (images, index) in enumerate(eval_loader):
529 |         with torch.no_grad():
530 |             images = images.cuda(non_blocking=True)
531 |             feat = model(images, is_eval=True)
532 |             features[index] = feat
533 |     dist.barrier()
534 |     dist.all_reduce(features, op=dist.ReduceOp.SUM)
535 |     return features.cpu()
536 | 
537 | 
538 | def run_kmeans(x, args):
539 |     """
540 |     Args:
541 |         x: data to be clustered
542 |     """
543 | 
544 |     print('Performing k-means clustering...')
545 |     results = {}
546 |     # initialize faiss clustering parameters
547 |     d = x.shape[1]
548 |     k = int(args.num_cluster)
549 |     clus = faiss.Clustering(d, k)
550 |     clus.verbose = True
551 |     clus.niter = 20
552 |     clus.nredo = 5
553 |     clus.seed = 0
554 |     clus.max_points_per_centroid = 1000
555 |     clus.min_points_per_centroid = 10
556 | 
557 |     res = faiss.StandardGpuResources()
558 |     res.setTempMemory(2048 * 1024 * 1024)
559 |     cfg = faiss.GpuIndexFlatConfig()
560 |     cfg.useFloat16 = False
561 |     cfg.device = args.gpu
562 |     index = faiss.GpuIndexFlatL2(res, d, cfg)
563 | 
564 |     clus.train(x, index)
565 | 
566 |     D, I = index.search(x, args.num_multi_centroids)  # for each sample: distances to, and ids of, its top-k nearest centroids
567 |     im2cluster = []
568 |     for i in range(args.num_multi_centroids):
569 |         im2cluster_sub = [int(n[i]) for n in I]  # [x.shape[0]]
570 |         im2cluster.append(im2cluster_sub)
571 |         if i == 0:
572 |             im2cluster0 = im2cluster_sub
573 | 
574 | 
575 |     # get cluster centroids
576 |     centroids = faiss.vector_to_array(clus.centroids).reshape(k, d)
577 | 
578 |     # sample-to-centroid distances for each cluster
579 |     Dcluster = [[] for c in range(k)]  # [[], [], [], [], [] ...k]
580 |     for im, i in enumerate(im2cluster0):
581 |         Dcluster[i].append(D[im][0])
582 | 
583 |     # concentration estimation (phi)
584 |     density = np.zeros(k)
585 |     for i, dists in enumerate(Dcluster):  # 'dists', to avoid shadowing torch.distributed's 'dist'
586 |         if len(dists) > 1:
587 |             d_mean = (np.asarray(dists) ** 0.5).mean()
588 |             density[i] = d_mean
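    # The concentration (phi) estimate: for each cluster, phi is the mean L2
    # distance from its members to the centroid (faiss returns squared
    # distances, hence the ** 0.5); the lines below then handle singleton
    # clusters, clamp to the [10th, 90th] percentile, and rescale so the mean
    # equals mcl_t. A tiny numeric sketch with hypothetical distances:
    #
    #   dists = np.array([0.04, 0.09, 0.25])  # squared distances to the centroid
    #   phi = (dists ** 0.5).mean()           # (0.2 + 0.3 + 0.5) / 3 = 1/3
    #
    # Denser clusters get a smaller phi, which (presumably, given how
    # PCL-style methods use phi) sharpens their prototype logits in training.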
589 | 
590 |     # if a cluster has at most one point, use the max density to estimate its concentration
591 |     dmax = density.max()
592 |     for i, dists in enumerate(Dcluster):
593 |         if len(dists) <= 1:
594 |             density[i] = dmax
595 | 
596 |     density = density.clip(np.percentile(density, 10), np.percentile(density, 90))  # clamp extreme values for stability
597 |     density = args.mcl_t * density / density.mean()  # rescale so the mean equals the temperature mcl_t
598 | 
599 | 
600 | 
601 |     # convert to cuda Tensors for broadcast
602 |     centroids = torch.Tensor(centroids).cuda()
603 |     centroids = nn.functional.normalize(centroids, p=2, dim=1)
604 | 
605 |     im2cluster = torch.LongTensor(im2cluster).cuda()
606 |     density = torch.Tensor(density).cuda()
607 | 
608 |     results['centroids'] = centroids
609 |     results['density'] = density
610 |     results['im2cluster'] = im2cluster
611 | 
612 |     print("K-means done!")
613 |     return results
614 | 
615 | 
616 | class AverageMeter(object):
617 |     """Computes and stores the average and current value"""
618 |     def __init__(self, name, fmt=':f'):
619 |         self.name = name
620 |         self.fmt = fmt
621 |         self.reset()
622 | 
623 |     def reset(self):
624 |         self.val = 0
625 |         self.avg = 0
626 |         self.sum = 0
627 |         self.count = 0
628 | 
629 |     def update(self, val, n=1):
630 |         self.val = val
631 |         self.sum += val * n
632 |         self.count += n
633 |         self.avg = self.sum / self.count
634 | 
635 |     def __str__(self):
636 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
637 |         return fmtstr.format(**self.__dict__)
638 | 
639 | 
640 | class ProgressMeter(object):
641 |     def __init__(self, num_batches, meters, prefix=""):
642 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
643 |         self.meters = meters
644 |         self.prefix = prefix
645 | 
646 |     def display(self, batch):
647 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
648 |         entries += [str(meter) for meter in self.meters]
649 |         print('\t'.join(entries))
650 | 
651 |     def _get_batch_fmtstr(self, num_batches):
652 |         num_digits = len(str(num_batches))
653 |         fmt = '{:' + str(num_digits) + 'd}'
654 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
655 | 
656 | 
657 | def adjust_learning_rate(optimizer, epoch, args):
658 |     """Decay the learning rate based on schedule"""
659 |     lr = args.lr
660 |     if args.cos:  # cosine lr schedule
661 |         lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
662 |     else:  # stepwise lr schedule
663 |         for milestone in args.schedule:
664 |             lr *= 0.1 if epoch >= milestone else 1.
665 |     for param_group in optimizer.param_groups:
666 |         param_group['lr'] = lr
667 | 
668 | 
669 | def accuracy(output, target, topk=(1,)):
670 |     """Computes the accuracy over the k top predictions for the specified values of k"""
671 |     with torch.no_grad():
672 |         maxk = max(topk)
673 |         batch_size = target.size(0)
674 | 
675 |         _, pred = output.topk(maxk, 1, True, True)
676 |         pred = pred.t()
677 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
678 | 
679 |         res = []
680 |         for k in topk:
681 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
682 |             res.append(correct_k.mul_(100.0 / batch_size))
683 |         return res
684 | 
685 | 
686 | if __name__ == '__main__':
687 |     main()
688 | 
--------------------------------------------------------------------------------