├── requirements.txt
├── DeepDA
│   ├── requirements.txt
│   ├── loss_funcs
│   │   ├── __init__.py
│   │   ├── mmd.py
│   │   ├── adv.py
│   │   └── lmmd.py
│   ├── run.sh
│   ├── DSAN.yaml
│   ├── utils.py
│   ├── data_loader.py
│   ├── transfer_losses.py
│   ├── diff_transfer_losses.py
│   ├── models.py
│   ├── backbones.py
│   └── dsan.py
├── AngularGapPre.pdf
├── acmmm2022_poster_mmfp2324.pdf
├── standard_curriculum_learning
│   ├── prediction_depth
│   │   ├── requirement.txt
│   │   ├── run_pd.sh
│   │   ├── README.md
│   │   ├── plot_pd_hist.py
│   │   ├── get_pd_vgg.py
│   │   ├── knndnn.py
│   │   └── get_pd_resnet.py
│   ├── orders
│   │   ├── hsf.pt
│   │   ├── angular_gap_order.npy
│   │   ├── forgetting_events.pkl
│   │   └── classification_margin.pt
│   ├── utils
│   │   ├── __init__.py
│   │   ├── get_data.py
│   │   └── utils.py
│   ├── README.md
│   ├── main_standard.py
│   └── main_curriculum_learning.py
├── visualize
│   ├── __init__.py
│   ├── plot_reliability_diagrams.py
│   └── plot_angular_space.py
├── LICIENCE
├── README.md
├── .gitignore
├── utils.py
├── angularloss.py
├── difficulty.py
├── models.py
├── calibration.py
├── resnet.py
└── main.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | timm
2 | torch
3 | torchvision
4 | ConfigArgParse==1.4.1
--------------------------------------------------------------------------------
/DeepDA/requirements.txt:
--------------------------------------------------------------------------------
1 | ConfigArgParse==1.4.1
2 | torch==1.8.1
3 | torchvision==0.9.1
--------------------------------------------------------------------------------
/AngularGapPre.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pengbohua/AngularGap/HEAD/AngularGapPre.pdf
--------------------------------------------------------------------------------
/acmmm2022_poster_mmfp2324.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pengbohua/AngularGap/HEAD/acmmm2022_poster_mmfp2324.pdf
--------------------------------------------------------------------------------
/DeepDA/loss_funcs/__init__.py:
--------------------------------------------------------------------------------
1 | from loss_funcs.mmd import *
2 | from loss_funcs.lmmd import *
3 | from loss_funcs.adv import *
--------------------------------------------------------------------------------
/standard_curriculum_learning/prediction_depth/requirement.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch
3 | torchvision
4 | scikit-learn
5 | matplotlib
--------------------------------------------------------------------------------
/standard_curriculum_learning/orders/hsf.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengbohua/AngularGap/HEAD/standard_curriculum_learning/orders/hsf.pt -------------------------------------------------------------------------------- /standard_curriculum_learning/orders/angular_gap_order.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengbohua/AngularGap/HEAD/standard_curriculum_learning/orders/angular_gap_order.npy -------------------------------------------------------------------------------- /standard_curriculum_learning/orders/forgetting_events.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengbohua/AngularGap/HEAD/standard_curriculum_learning/orders/forgetting_events.pkl -------------------------------------------------------------------------------- /standard_curriculum_learning/orders/classification_margin.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengbohua/AngularGap/HEAD/standard_curriculum_learning/orders/classification_margin.pt -------------------------------------------------------------------------------- /standard_curriculum_learning/prediction_depth/run_pd.sh: -------------------------------------------------------------------------------- 1 | python3 get_pd_resnet_wsgn.py --result_dir ./cl_results_wsgn 2 | python3 get_pd_resnet.py --result_dir ./cl_results_resnet 3 | 4 | -------------------------------------------------------------------------------- /DeepDA/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | GPU_ID=0 3 | data_dir=/home/data/office31 4 | # Office31 5 | CUDA_VISIBLE_DEVICES=$GPU_ID python dsan.py --config ./DSAN.yaml --data_dir $data_dir --src_domain amazon --tgt_domain webcam -------------------------------------------------------------------------------- /visualize/__init__.py: -------------------------------------------------------------------------------- 1 | from visualize.plot_angular_space import plot2d, plot3d, ConvNet 2 | from visualize.plot_reliability_diagrams import plot_multiclass_reliability_diagram 3 | 4 | __all__ = ["plot3d", "plot2d", "plot_multiclass_reliability_diagram", "ConvNet"] 5 | -------------------------------------------------------------------------------- /DeepDA/DSAN.yaml: -------------------------------------------------------------------------------- 1 | # Backbone 2 | backbone: resnet50 3 | 4 | # Transfer loss related 5 | transfer_loss_weight: 0.5 6 | transfer_loss: lmmd 7 | 8 | # Optimizer related 9 | lr: 0.01 10 | weight_decay: 5e-4 11 | momentum: 0.9 12 | lr_scheduler: True 13 | lr_gamma: 0.0003 14 | lr_decay: 0.75 15 | 16 | # Training related 17 | n_iter_per_epoch: 500 18 | n_epoch: 20 19 | 20 | # Others 21 | seed: 1 22 | num_workers: 4 23 | -------------------------------------------------------------------------------- /standard_curriculum_learning/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .utils import get_model, get_optimizer, get_scheduler, LossTracker, AverageMeter, ProgressMeter, accuracy, balance_order_val,balance_order,get_pacing_function,run_cmd, shuffling_small_bucket 3 | from .get_data import get_dataset 4 | from .cifar_label import CIFAR100N 5 | from .cos_vis import 
plot_spheral_space, get_embeds
6 | __all__ = ["get_dataset", "AverageMeter", "ProgressMeter", "accuracy", "get_optimizer", "get_scheduler", "get_model", "LossTracker", "CIFAR100N", "balance_order_val", "balance_order", "get_pacing_function", "run_cmd", "plot_spheral_space", "get_embeds", "shuffling_small_bucket"]
7 |
--------------------------------------------------------------------------------
/standard_curriculum_learning/README.md:
--------------------------------------------------------------------------------
1 | # Standard curriculum learning evaluation
2 | We simplify current curriculum learning methods and use standard curriculum learning (paced learning) to compare different precomputed image difficulty scores head-to-head.
3 |
4 | Image difficulty scores are saved in `./orders`.
5 |
6 | We provide an unofficial implementation of Prediction Depth.
7 |
8 | # Acknowledgement
9 | The implementation of standard curriculum learning borrows from the following works:
10 |
11 | [1] When Do Curricula Work? https://arxiv.org/abs/2012.03107 [repo](https://github.com/google-research/understanding-curricula/blob/main/main_w_test.py)
12 |
13 | [2] Deep Learning Through the Lens of Example Difficulty. https://arxiv.org/abs/2106.09647
--------------------------------------------------------------------------------
/LICIENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Anonymous
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Angular Gap
2 | This is the code necessary to run the experiments described in the ACM MM'22 paper [Angular Gap: Reducing the Uncertainty of Image Difficulty through Model Calibration](https://arxiv.org/abs/2207.08525).
3 | ## Requirements
4 | All the required packages can be installed by running `pip install -r requirements.txt`.
5 |
6 | Or pull our domain adaptation Docker image and run experiments with:
7 | ```
8 | docker pull marvinpeng2022/da-testbed
9 | ```
10 | ## Difficulty estimation
11 | ```shell
12 | python main.py --dst cifar10 --arch resnet18
13 | ```
14 | ## Visualization
15 | ```shell
16 | python main.py --dst cifar10 --arch visualization
17 | ```
18 | ## Domain adaptation
19 | For domain adaptation, we have released our implementations of CRST and Curricular DSAN.
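20 | By default, `run.sh` trains Curricular DSAN on Office-31 with the amazon → webcam transfer pair configured in `DSAN.yaml`; edit `data_dir` (and the `--src_domain`/`--tgt_domain` flags) inside the script to run other pairs.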
21 | ```shell
22 | cd DeepDA
23 | bash run.sh
24 | ```
25 | ## Video
26 | [Presentation](https://files.atypon.com/acm/f7197189de64e2075eb0a2c2d1eee630) and [Slides](https://github.com/pengbohua/AngularGap/blob/main/AngularGapPre.pdf)
27 |
28 | If you make use of this code in your work, please cite the following paper:
29 | ```
30 | @article{peng2022angular,
31 |   title={Angular Gap: Reducing the Uncertainty of Image Difficulty through Model Calibration},
32 |   author={Peng, Bohua and Islam, Mobarakol and Tu, Mei},
33 |   journal={arXiv preprint arXiv:2207.08525},
34 |   year={2022}
35 | }
36 | ```
--------------------------------------------------------------------------------
/standard_curriculum_learning/prediction_depth/README.md:
--------------------------------------------------------------------------------
1 | # Unofficial Implementation of Prediction Depth
2 | This is a community script for [Deep Learning Through the Lens of Example Difficulty](https://arxiv.org/abs/2106.09647).
3 |
4 | ## Requirements
5 | ```shell script
6 | pip3 install -r requirement.txt
7 | ```
8 | ## Get Started
9 | ### Modify CIFAR10 to return the index of each data point (Important)
10 | Change `__getitem__` of `torchvision.datasets.CIFAR10` so that it also returns the index of the current data point:
11 | ```python
12 | # line 130: return img, target
13 | return (img, target), index
14 | ```
15 | Make a log directory for each architecture: ResNet18 with Weight Standardization and Group Norm, the original ResNet18, and VGG16.
16 | ```shell script
17 | mkdir ./cl_results_wsgn
18 | mkdir ./cl_results_resnet
19 | mkdir ./cl_results_vgg
20 | ```
21 | Changing the number of random seeds (line 284 in get_pd_resnet.py) lets you train more models and average the resulting PD.
22 | Run the training, then plot the 2D histograms of prediction depth for the train and validation splits:
23 | ```shell script
24 | python3 get_pd_resnet_wsgn.py --result_dir ./cl_results_wsgn --train_ratio 0.5 --knn_k 30
25 | python3 plot_pd_hist.py --result_dir ./cl_results_wsgn
26 | python3 get_pd_resnet.py --result_dir ./cl_results_resnet --train_ratio 0.5 --knn_k 30
27 | python3 plot_pd_hist.py --result_dir ./cl_results_resnet
28 | ```
29 |
30 | ## Run PD in one line
31 | Alternatively, run the following to get all of the previous results in one line:
32 | ```shell script
33 | sh run_pd.sh
34 | ```
--------------------------------------------------------------------------------
/DeepDA/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tensorboardX import SummaryWriter
3 |
4 |
5 | class AverageMeter(object):
6 |     """Computes and stores the average and current value"""
7 |
8 |     def __init__(self):
9 |         self.reset()
10 |
11 |     def reset(self):
12 |         self.val = 0
13 |         self.avg = 0
14 |         self.sum = 0
15 |         self.count = 0
16 |
17 |     def update(self, val, n=1):
18 |         self.val = val
19 |         self.sum += val * n
20 |         self.count += n
21 |         self.avg = self.sum / self.count
22 |
23 |
24 | def str2bool(v):
25 |     if isinstance(v, bool):
26 |         return v
27 |     if v.lower() in ("yes", "true", "t", "y", "1"):
28 |         return True
29 |     elif v.lower() in ("no", "false", "f", "n", "0"):
30 |         return False
31 |     else:
32 |         raise ValueError("Boolean value expected.")
33 |
34 |
35 | class TensorboardLogger(object):
36 |     def __init__(self, log_dir):
37 |         self.writer = SummaryWriter(logdir=log_dir)
38 |         self.step = 0
39 |
40 |     def set_step(self, step=None):
41 |         if step is not None:
42 |             self.step = step
43 |         else:
44 |             self.step += 1
45 |
46 |     def update(self, head="scalar", step=None, **kwargs):
47 |         for k, v in kwargs.items():
48 |             if v is None:
49 | continue 50 | if isinstance(v, torch.Tensor): 51 | v = v.item() 52 | assert isinstance(v, (float, int)) 53 | self.writer.add_scalar( 54 | head + "/" + k, v, self.step if step is None else step 55 | ) 56 | 57 | def flush(self): 58 | self.writer.flush() 59 | -------------------------------------------------------------------------------- /visualize/plot_reliability_diagrams.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def plot_multiclass_reliability_diagram( 6 | y_true, p_pred, n_bins=15, title=None, fig=None, ax=None, legend=True 7 | ): 8 | """ 9 | y_true: needs to be (n_samples,) 10 | """ 11 | if fig is None and ax is None: 12 | fig = plt.figure() 13 | if ax is None: 14 | ax = fig.add_subplot(111) 15 | 16 | if title is not None: 17 | ax.set_title(title) 18 | 19 | y_true = y_true.flatten() 20 | p_pred = p_pred.flatten() 21 | 22 | bin_size = 1 / n_bins 23 | centers = np.linspace(bin_size / 2, 1.0 - bin_size / 2, n_bins) 24 | true_proportion = np.zeros(n_bins) 25 | 26 | pred_mean = np.zeros(n_bins) 27 | for i, center in enumerate(centers): 28 | if i == 0: 29 | # First bin include lower bound 30 | bin_indices = np.where( 31 | np.logical_and( 32 | p_pred >= center - bin_size / 2, p_pred <= center + bin_size / 2 33 | ) 34 | ) 35 | else: 36 | bin_indices = np.where( 37 | np.logical_and( 38 | p_pred > center - bin_size / 2, p_pred <= center + bin_size / 2 39 | ) 40 | ) 41 | true_proportion[i] = np.mean(y_true[bin_indices]) 42 | pred_mean[i] = np.mean(p_pred[bin_indices]) 43 | 44 | ax.bar( 45 | centers, 46 | true_proportion, 47 | width=bin_size, 48 | edgecolor="black", 49 | color="blue", 50 | label="True class prop.", 51 | ) 52 | ax.bar( 53 | centers, 54 | true_proportion - pred_mean, 55 | bottom=pred_mean, 56 | width=bin_size / 2, 57 | edgecolor="red", 58 | color="#ffc8c6", 59 | alpha=1, 60 | label="Gap pred. mean", 61 | ) 62 | if legend: 63 | ax.legend() 64 | 65 | ax.plot([0, 1], [0, 1], linestyle="--") 66 | ax.set_xlim([0, 1]) 67 | ax.set_ylim([0, 1]) 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | #   install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/DeepDA/loss_funcs/mmd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MMDLoss(nn.Module):
6 |     def __init__(
7 |         self, kernel_type="rbf", kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs
8 |     ):
9 |         super(MMDLoss, self).__init__()
10 |         self.kernel_num = kernel_num
11 |         self.kernel_mul = kernel_mul
12 |         self.fix_sigma = fix_sigma
13 |         self.kernel_type = kernel_type
14 |
15 |     def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
16 |         n_samples = int(source.size()[0]) + int(target.size()[0])
17 |         total = torch.cat([source, target], dim=0)
18 |         total0 = total.unsqueeze(0).expand(
19 |             int(total.size(0)), int(total.size(0)), int(total.size(1))
20 |         )
21 |         total1 = total.unsqueeze(1).expand(
22 |             int(total.size(0)), int(total.size(0)), int(total.size(1))
23 |         )
24 |         L2_distance = ((total0 - total1) ** 2).sum(2)
25 |         if fix_sigma:
26 |             bandwidth = fix_sigma
27 |         else:
28 |             # median-style heuristic: mean pairwise squared distance
29 |             bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
30 |         bandwidth /= kernel_mul ** (kernel_num // 2)
31 |         bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
32 |         kernel_val = [
33 |             torch.exp(-L2_distance / bandwidth_temp)
34 |             for bandwidth_temp in bandwidth_list
35 |         ]
36 |         return sum(kernel_val)
37 |
38 |     def linear_mmd2(self, f_of_X, f_of_Y):
39 |         loss = 0.0
40 |         delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
41 |         loss = delta.dot(delta.T)
42 |         return loss
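    # Usage sketch (illustrative): for source/target feature batches f_src and
    # f_tgt of shape (B, F), `MMDLoss(kernel_type="rbf")(f_src, f_tgt)` returns a
    # scalar estimate of the squared MMD between the two feature distributions.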
43 |
44 |     def forward(self, source, target):
45 |         if self.kernel_type == "linear":
46 |             return self.linear_mmd2(source, target)
47 |         elif self.kernel_type == "rbf":
48 |             batch_size = int(source.size()[0])
49 |             kernels = self.guassian_kernel(
50 |                 source,
51 |                 target,
52 |                 kernel_mul=self.kernel_mul,
53 |                 kernel_num=self.kernel_num,
54 |                 fix_sigma=self.fix_sigma,
55 |             )
56 |             XX = torch.mean(kernels[:batch_size, :batch_size])
57 |             YY = torch.mean(kernels[batch_size:, batch_size:])
58 |             XY = torch.mean(kernels[:batch_size, batch_size:])
59 |             YX = torch.mean(kernels[batch_size:, :batch_size])
60 |             loss = torch.mean(XX + YY - XY - YX)
61 |             return loss
--------------------------------------------------------------------------------
/standard_curriculum_learning/utils/get_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchvision.transforms as transforms
3 | import torchvision.datasets as datasets
4 |
5 | def get_dataset(dataset_name, data_dir, split, order=None, rand_fraction=None, clean=False, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
6 |     dataset = globals()[f'get_{dataset_name}'](dataset_name, data_dir, split, transform=transform, imsize=imsize, bucket=bucket, **kwargs)
7 |
8 |     item = dataset.__getitem__(0)[0]
9 |     print(item.size(0))
10 |     dataset.nchannels = item.size(0)
11 |     dataset.imsize = item.size(1)
12 |     return dataset
13 |
14 |
15 | def get_aug(split, imsize=None, aug='large'):
16 |     if aug == 'large':
17 |         if split == 'train':
18 |             return [transforms.RandomHorizontalFlip(0.5),
19 |                     transforms.Resize(224),
20 |                     ]
21 |         else:
22 |             # resize ImageNet-sized inputs instead of center-cropping
23 |             return [transforms.Resize(224)]
24 |     else:
25 |         imsize = imsize if imsize is not None else 32
26 |         if split == 'train':
27 |             train_transform = []
28 |             #return [transforms.RandomCrop(imsize, padding=round(imsize / 8))]
29 |             train_transform.append(transforms.RandomCrop(32, padding=4))
30 |             train_transform.append(transforms.RandomHorizontalFlip())
31 |             return train_transform
32 |         else:
33 |             return [transforms.Resize(imsize), transforms.CenterCrop(imsize)]
34 |
35 |
36 | def get_transform(dataset_name, split, normalize=None, transform=None, imsize=None, aug='large'):
37 |     if transform is None:
38 |         if normalize is None:
39 |             if aug == 'large':
40 |                 if 'cifar100' in dataset_name:
41 |                     normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
42 |                 else:
43 |                     # imagenet
44 |                     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |             else:
46 |                 if 'cifar100' in dataset_name:  # check first: 'cifar10' is a substring of 'cifar100'
47 |                     normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
48 |                 elif 'cifar10' in dataset_name:
49 |                     normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
50 |
51 |         transform = transforms.Compose(get_aug(split, imsize=imsize, aug=aug)
52 |                                        + [transforms.ToTensor(), normalize])
53 |     return transform
54 |
55 | # Warning: the validation set and the test set are the same thing in this implementation.
56 |
57 | def get_cifar10(dataset_name, data_dir, split, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
58 |     transform = get_transform(dataset_name, split, transform=transform, imsize=imsize, aug='small')
59 |     return datasets.CIFAR10(data_dir, train=(split=='train'), transform=transform, download=True, **kwargs)
60 |
61 | def get_cifar100(dataset_name, data_dir, split, transform=None, imsize=None, bucket='pytorch-data', **kwargs):
62 |     transform = get_transform(dataset_name, split, transform=transform, imsize=imsize, aug='small')
63 |     return datasets.CIFAR100(data_dir, train=(split=='train'),
transform=transform, download=True,) 64 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import os 4 | import time 5 | import numpy as np 6 | import json 7 | 8 | 9 | class LossTracker(object): 10 | def __init__(self, num, prefix="", print_freq=1): 11 | self.print_freq = print_freq 12 | self.batch_time = AverageMeter("Time", ":6.3f") 13 | self.losses = AverageMeter("Loss", ":.4e") 14 | self.top1 = AverageMeter("Acc@1", ":6.2f") 15 | self.top5 = AverageMeter("Acc@5", ":6.2f") 16 | self.progress = ProgressMeter( 17 | num, [self.batch_time, self.losses, self.top1, self.top5], prefix=prefix 18 | ) 19 | self.end = time.time() 20 | 21 | def update(self, loss, output, target): 22 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 23 | self.losses.update(loss.item(), output.size(0)) 24 | self.top1.update(acc1[0], output.size(0)) 25 | self.top5.update(acc5[0], output.size(0)) 26 | 27 | def display(self, step): 28 | self.batch_time.update(time.time() - self.end) 29 | self.end = time.time() 30 | if step % self.print_freq == 0: 31 | self.progress.display(step) 32 | 33 | 34 | class AverageMeter(object): 35 | """Computes and stores the average and current value""" 36 | 37 | def __init__(self, name, fmt=":f"): 38 | self.name = name 39 | self.fmt = fmt 40 | self.reset() 41 | 42 | def reset(self): 43 | self.val = 0 44 | self.avg = 0 45 | self.sum = 0 46 | self.count = 0 47 | 48 | def update(self, val, n=1): 49 | self.val = val 50 | self.sum += val * n 51 | self.count += n 52 | self.avg = self.sum / self.count 53 | 54 | def __str__(self): 55 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 56 | return fmtstr.format(**self.__dict__) 57 | 58 | 59 | class ProgressMeter(object): 60 | def __init__(self, num_batches, meters, prefix=""): 61 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 62 | self.meters = meters 63 | self.prefix = prefix 64 | 65 | def display(self, batch): 66 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 67 | entries += [str(meter) for meter in self.meters] 68 | print("\t".join(entries), flush=True) 69 | 70 | def _get_batch_fmtstr(self, num_batches): 71 | num_digits = len(str(num_batches // 1)) 72 | fmt = "{:" + str(num_digits) + "d}" 73 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 74 | 75 | 76 | def accuracy(output, target, topk=(1,)): 77 | """Computes the accuracy over the k top predictions for the specified values of k""" 78 | with torch.no_grad(): 79 | maxk = max(topk) 80 | batch_size = target.size(0) 81 | 82 | _, pred = output.topk(maxk, 1, True, True) 83 | pred = pred.t() 84 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 85 | 86 | res = [] 87 | for k in topk: 88 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 89 | res.append(correct_k.mul_(100.0 / batch_size)) 90 | return res 91 | -------------------------------------------------------------------------------- /standard_curriculum_learning/prediction_depth/plot_pd_hist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import matplotlib.pyplot as plt 5 | import argparse 6 | import os 7 | 8 | parser = argparse.ArgumentParser(description='arguments to compute prediction depth for each data sample') 9 | parser.add_argument('--result_dir', default='./cl_results_wsgn', type=str, help='directory to 
save ckpt and results')
10 | parser.add_argument('--arch', default='resnet', type=str, help='arch for prediction depth')
11 | parser.add_argument('--knn_k', default=30, type=int, help='k nearest neighbors of knn classifier')
12 | parser.add_argument('--num_samples', default=10000, type=int, help='number of samples in the current dst')
13 |
14 | args = parser.parse_args()
15 |
16 | seeds = [1111, 2222, 3333, 4444, 5555, 6666]
17 |
18 | arch = args.arch
19 | pd_dir = args.result_dir
20 |
21 | print('computing prediction depth in train split')
22 | pd_train_split = np.zeros((len(seeds), args.num_samples))
23 | for i, sd in enumerate(seeds):
24 |
25 |     f = os.path.join(pd_dir, '{}train_seed{}_f_trainpd.pkl'.format(arch, sd))
26 |     with open(f, 'r') as p:
27 |         pd_dict = json.load(p)
28 |     for k, v in pd_dict.items():
29 |         pd_train_split[i, int(k)] = v[0]
30 |     f = os.path.join(pd_dir, '{}train_seed{}_fflip_trainpd.pkl'.format(arch, sd))
31 |     with open(f, 'r') as p:
32 |         pd_dict = json.load(p)
33 |     for k, v in pd_dict.items():
34 |         pd_train_split[i, int(k)] = v[0]
35 |
36 | print(pd_train_split.shape)
37 | pd_train_split_avg = pd_train_split.mean(0)
38 | train_split_small_pds = np.where((pd_train_split_avg > 1) & (pd_train_split_avg <= 2))[0]
39 |
40 | print('computing prediction depth in test split')
41 | pd_test_split = np.zeros((len(seeds), args.num_samples))
42 | for i, sd in enumerate(seeds):
43 |     f = os.path.join(pd_dir, '{}_seed{}_f_test_pd.pkl'.format(arch, sd))
44 |     with open(f, 'r') as p:
45 |         pd_dict = json.load(p)
46 |     for k, v in pd_dict.items():
47 |         pd_test_split[i, int(k)] = v[0]
48 |
49 |     f = os.path.join(pd_dir, '{}_seed{}_fflip_test_pd.pkl'.format(arch, sd))
50 |     with open(f, 'r') as p:
51 |         pd_dict = json.load(p)
52 |     for k, v in pd_dict.items():
53 |         pd_test_split[i, int(k)] = v[0]
54 |
55 | def show_sample(index, dataset):
56 |     img, _ = dataset[index]
57 |     img = img.permute(1, 2, 0).numpy()
58 |     plt.imshow(img)
59 |     plt.savefig('./easy_samples/img{}.png'.format(index))
60 |     plt.show()
61 |
62 | pd_test_split_avg = pd_test_split.mean(0)
63 |
64 |
65 | H, x_edges, y_edges = np.histogram2d(pd_test_split_avg - 1, pd_train_split_avg - 1, bins=(np.linspace(0, 9, 50), np.linspace(0, 9, 50)))
66 | plt.figure()
67 | H[H < 1e-7] = np.nan
68 | H = H.T
69 | X, Y = np.meshgrid(x_edges, y_edges)
70 | plt.pcolormesh(X, Y, H)
71 | plt.xlabel('validation split prediction depth')
72 | plt.ylabel('train split prediction depth')
73 | plt.colorbar()
74 | plt.savefig(os.path.join(pd_dir, 'prediction_depth_12resnet{}.png'.format(args.knn_k)))
75 | plt.show()
--------------------------------------------------------------------------------
/DeepDA/loss_funcs/adv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | class LambdaSheduler(nn.Module):
8 |     def __init__(self, gamma=1.0, max_iter=1000, **kwargs):
9 |         super(LambdaSheduler, self).__init__()
10 |         self.gamma = gamma
11 |         self.max_iter = max_iter
12 |         self.curr_iter = 0
13 |
14 |     def lamb(self):
15 |         p = self.curr_iter / self.max_iter
16 |         lamb = 2. / (1. + np.exp(-self.gamma * p)) - 1
17 |         return lamb
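    # lamb() follows the GRL-style warm-up schedule 2 / (1 + exp(-gamma * p)) - 1
    # with p = curr_iter / max_iter, so it grows monotonically from 0 and the
    # transfer loss is phased in gradually as training progresses.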
18 |
19 |     def step(self):
20 |         self.curr_iter = min(self.curr_iter + 1, self.max_iter)
21 |
22 | class AdversarialLoss(nn.Module):
23 |     '''
24 |     Acknowledgement: The adversarial loss implementation is inspired by http://transfer.thuml.ai/
25 |     '''
26 |     def __init__(self, gamma=1.0, max_iter=1000, use_lambda_scheduler=True, **kwargs):
27 |         super(AdversarialLoss, self).__init__()
28 |         self.domain_classifier = Discriminator()
29 |         self.use_lambda_scheduler = use_lambda_scheduler
30 |         if self.use_lambda_scheduler:
31 |             self.lambda_scheduler = LambdaSheduler(gamma, max_iter)
32 |
33 |     def forward(self, source, target):
34 |         lamb = 1.0
35 |         if self.use_lambda_scheduler:
36 |             lamb = self.lambda_scheduler.lamb()
37 |             self.lambda_scheduler.step()
38 |         source_loss = self.get_adversarial_result(source, True, lamb)
39 |         target_loss = self.get_adversarial_result(target, False, lamb)
40 |         adv_loss = 0.5 * (source_loss + target_loss)
41 |         return adv_loss
42 |
43 |     def get_adversarial_result(self, x, source=True, lamb=1.0):
44 |         x = ReverseLayerF.apply(x, lamb)
45 |         domain_pred = self.domain_classifier(x)
46 |         device = domain_pred.device
47 |         if source:
48 |             domain_label = torch.ones(len(x), 1).long()
49 |         else:
50 |             domain_label = torch.zeros(len(x), 1).long()
51 |         loss_fn = nn.BCELoss()
52 |         loss_adv = loss_fn(domain_pred, domain_label.float().to(device))
53 |         return loss_adv
54 |
55 |
56 | class ReverseLayerF(Function):
57 |     @staticmethod
58 |     def forward(ctx, x, alpha):
59 |         ctx.alpha = alpha
60 |         return x.view_as(x)
61 |
62 |     @staticmethod
63 |     def backward(ctx, grad_output):
64 |         output = grad_output.neg() * ctx.alpha
65 |         return output, None
66 |
67 | class Discriminator(nn.Module):
68 |     def __init__(self, input_dim=256, hidden_dim=256):
69 |         super(Discriminator, self).__init__()
70 |         self.input_dim = input_dim
71 |         self.hidden_dim = hidden_dim
72 |         layers = [
73 |             nn.Linear(input_dim, hidden_dim),
74 |             nn.BatchNorm1d(hidden_dim),
75 |             nn.ReLU(),
76 |             nn.Linear(hidden_dim, hidden_dim),
77 |             nn.BatchNorm1d(hidden_dim),
78 |             nn.ReLU(),
79 |             nn.Linear(hidden_dim, 1),
80 |             nn.Sigmoid()
81 |         ]
82 |         self.layers = torch.nn.Sequential(*layers)
83 |
84 |     def forward(self, x):
85 |         return self.layers(x)
--------------------------------------------------------------------------------
/angularloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class AngularLoss(nn.Module):
7 |     def __init__(
8 |         self, in_features, out_features, loss_type="arcface", eps=1e-7, s=None, m=None
9 |     ):
10 |         """
11 |         AngularLoss
12 |         Four loss_type options are available: 'sl' (softmax), 'nsl' (normalized softmax), 'arcface' (ArcFace) and 'cosface' (CosFace).
13 |         ArcFace: https://arxiv.org/abs/1801.07698
14 |         CosFace: https://arxiv.org/abs/1801.05599
15 |         """
16 |         super(AngularLoss, self).__init__()
17 |         self.num_classes = out_features
18 |         loss_type = loss_type.lower()
19 |         assert loss_type in ["sl", "nsl", "arcface", "cosface"]
20 |         if loss_type == "arcface":
21 |             self.s = torch.ones(1) * 64.0 if not s else torch.ones(1) * s
22 |             self.m = torch.ones(1) * 0.5 if not m else torch.ones(1) * m
23 |         if loss_type == "cosface":
24 |             self.s = torch.ones(1) * 30.0 if not s else torch.ones(1) * s
25 |             self.m = torch.ones(1) * 0.35 if not m else torch.ones(1) * m
26 |         if loss_type == "nsl":
27 |             self.s = torch.ones(1) * 15.0 if not s else torch.ones(1) * s
28 |             self.m =
torch.ones(1) * 0.0 if not m else torch.ones(1) * m 29 | self.loss_type = loss_type 30 | self.in_features = in_features 31 | self.out_features = out_features 32 | self.fc = nn.Linear(in_features, out_features, bias=False) 33 | self.fc_softmax = nn.Linear(in_features, out_features, bias=True) 34 | self.eps = eps 35 | 36 | def forward(self, x, labels): 37 | """ 38 | input shape (N, in_features) 39 | """ 40 | 41 | assert len(x) == len(labels) 42 | assert torch.min(labels) >= 0 43 | assert torch.max(labels) < self.out_features 44 | if self.loss_type == "sl": 45 | logits = self.fc_softmax(x) 46 | W = F.normalize(self.fc_softmax.weight.data, p=2, dim=1) # C x F 47 | x = F.normalize(x, p=2, dim=1) # B x F 48 | cosine = torch.mm(x, W.t().contiguous()) # B x C 49 | return logits, cosine.detach() 50 | 51 | W = F.normalize(self.fc.weight.data, p=2, dim=1) # C x F 52 | 53 | x = F.normalize(x, p=2, dim=1) # B x F 54 | # cosine dists 55 | cosine = torch.mm(x, W.t().contiguous()) # B x C 56 | 57 | # move s, m to the same device as wf 58 | self.s = self.s.to(x.device) 59 | self.m = self.m.to(x.device) 60 | 61 | if self.loss_type == "nsl": 62 | logits = cosine * self.s 63 | elif self.loss_type == "cosface": 64 | m_hot = nn.functional.one_hot(labels, num_classes=self.num_classes) * self.m 65 | cosine_m = cosine - m_hot 66 | logits = cosine_m * self.s 67 | elif self.loss_type == "arcface": 68 | m_hot = nn.functional.one_hot(labels, num_classes=self.num_classes) * self.m 69 | cosine_m = cosine.clamp(-1.0 + self.eps, 1 - self.eps).acos() 70 | cosine_m += m_hot 71 | logits = cosine_m.cos() * self.s 72 | else: 73 | raise NotImplementedError 74 | return logits, cosine.detach() 75 | -------------------------------------------------------------------------------- /difficulty.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn.functional as F 4 | from scipy.stats import kendalltau, spearmanr 5 | import numpy as np 6 | 7 | 8 | def angular_gap(cos_dists, label_onehot, posterior=None): 9 | 10 | batch_size, num_cls = cos_dists.shape 11 | if posterior is None: 12 | posterior = torch.ones(1, num_cls, device=label_onehot.device) 13 | targets_cosine = torch.sum(cos_dists * label_onehot, 1) 14 | min_excl_cosine, min_angle_excl_idx = torch.max( 15 | cos_dists 16 | * (torch.ones_like(label_onehot, device=label_onehot.device) - label_onehot), 17 | dim=1, 18 | ) 19 | 20 | cos_margin = targets_cosine - min_excl_cosine 21 | targets_angle = torch.acos(torch.clamp(targets_cosine, -1.0 + 1e-7, 1.0 - 1e-7)) 22 | min_excl_angle = torch.acos(torch.clamp(min_excl_cosine, -1.0 + 1e-7, 1.0 - 1e-7)) 23 | s_max_excl = ( 24 | posterior.expand(batch_size, num_cls) 25 | .gather(1, min_angle_excl_idx[:, None]) 26 | .squeeze() 27 | ) 28 | s_y = torch.sum(posterior.expand(batch_size, num_cls) * label_onehot, 1) 29 | 30 | angular_margin = targets_angle * s_y - min_excl_angle * s_max_excl 31 | return targets_cosine, cos_margin, targets_angle, angular_margin 32 | 33 | 34 | def avh(cosine_dists, targets): 35 | """' 36 | @param cosine_dists: B x C 37 | @param targets: C 38 | @return: 39 | """ 40 | ang_dists = torch.acos(torch.clamp(cosine_dists, -1.0 + 1e-7, 1.0 - 1e-7)) 41 | avh = ( 42 | ang_dists.gather(1, targets[:, None]) / ang_dists.sum(1, keepdim=True).squeeze() 43 | ) 44 | return avh 45 | 46 | 47 | def get_confidence_output_margin(logits, label_onehot): 48 | confidence = F.softmax(logits, 1) 49 | targets_conf = torch.sum(confidence * label_onehot, 
1)
50 |
51 |     max_excl_conf, _ = torch.max(
52 |         confidence
53 |         * (torch.ones_like(label_onehot, device=label_onehot.device) - label_onehot),
54 |         dim=1,
55 |     )
56 |     conf_output_margin = targets_conf - max_excl_conf
57 |     return targets_conf, conf_output_margin
58 |
59 |
60 | def embed_norm(cls_embeddings: Tensor) -> Tensor:
61 |     """
62 |
63 |     @param cls_embeddings: N (C) x F
64 |     @return: per-embedding l2 norms, shape (N,)
65 |     """
66 |     return torch.norm(cls_embeddings, dim=1)
67 |
68 |
69 | def get_kendalltau(score1, score2, standard_order=True):
70 |     if standard_order:
71 |         return kendalltau(score1, score2)
72 |     elif isinstance(score1, dict) and isinstance(score2, dict):
73 |         res = np.zeros((len(score1), 2))
74 |         for (k0, v0), (k1, v1) in zip(score1.items(), score2.items()):
75 |             res[int(k0), 0] = v0
76 |             res[int(k1), 1] = v1
77 |         return kendalltau(res[:, 0], res[:, 1])
78 |
79 |
80 | def get_spearman(score1, score2, standard_order=True):
81 |     """
82 |     Spearman rank correlation: 1 - 6*sum((r1-r2)**2)/(n*(n**2-1))
83 |     @param score1, score2: difficulty scores, as aligned sequences or {index: score} dicts
84 |     @return: correlation and p-value from scipy.stats.spearmanr
85 |     """
86 |     if standard_order:
87 |         return spearmanr(score1, score2)
88 |     elif isinstance(score1, dict) and isinstance(score2, dict):
89 |         res = np.zeros((len(score1), 2))
90 |         for (k0, v0), (k1, v1) in zip(score1.items(), score2.items()):
91 |             res[int(k0), 0] = v0
92 |             res[int(k1), 1] = v1
93 |         return spearmanr(res[:, 0], res[:, 1])
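
# Example (illustrative sketch): given cosine scores `cos` of shape (B, C),
# as returned by AngularLoss.forward, and integer labels `y`,
#   onehot = F.one_hot(y, num_classes=cos.size(1)).float()
#   _, cos_margin, _, ang_margin = angular_gap(cos, onehot)
# yields per-sample difficulty scores; with the default (uniform) posterior,
# smaller `ang_margin` values correspond to easier examples.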
--------------------------------------------------------------------------------
/DeepDA/data_loader.py:
--------------------------------------------------------------------------------
1 | from torchvision import datasets, transforms
2 | import torch
3 |
4 |
5 | def load_data(data_folder, batch_size, train, num_workers=0, **kwargs):
6 |     transform = {
7 |         "train": transforms.Compose(
8 |             [
9 |                 transforms.Resize([256, 256]),
10 |                 transforms.RandomCrop(224),
11 |                 transforms.RandomHorizontalFlip(),
12 |                 transforms.ColorJitter(
13 |                     brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5
14 |                 ),
15 |                 transforms.ToTensor(),
16 |                 transforms.Normalize(
17 |                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
18 |                 ),
19 |             ]
20 |         ),
21 |         "test": transforms.Compose(
22 |             [
23 |                 transforms.Resize([224, 224]),
24 |                 transforms.ToTensor(),
25 |                 transforms.Normalize(
26 |                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
27 |                 ),
28 |             ]
29 |         ),
30 |     }
31 |     data = datasets.ImageFolder(
32 |         root=data_folder, transform=transform["train" if train else "test"]
33 |     )
34 |     data_loader = get_data_loader(
35 |         data,
36 |         batch_size=batch_size,
37 |         shuffle=True if train else False,
38 |         num_workers=num_workers,
39 |         **kwargs,
40 |         drop_last=True if train else False
41 |     )
42 |     n_class = len(data.classes)
43 |     return data_loader, n_class
44 |
45 |
46 | def get_data_loader(
47 |     dataset,
48 |     batch_size,
49 |     shuffle=True,
50 |     drop_last=False,
51 |     num_workers=0,
52 |     infinite_data_loader=False,
53 |     **kwargs
54 | ):
55 |     if not infinite_data_loader:
56 |         return torch.utils.data.DataLoader(
57 |             dataset,
58 |             batch_size=batch_size,
59 |             shuffle=shuffle,
60 |             drop_last=drop_last,
61 |             num_workers=num_workers,
62 |             **kwargs
63 |         )
64 |     else:
65 |         return InfiniteDataLoader(
66 |             dataset,
67 |             batch_size=batch_size,
68 |             shuffle=shuffle,
69 |             drop_last=drop_last,
70 |             num_workers=num_workers,
71 |             **kwargs
72 |         )
73 |
74 |
75 | class _InfiniteSampler(torch.utils.data.Sampler):
76 |     """Wraps another Sampler to yield an infinite stream."""
77 |
78 |     def __init__(self, sampler):
79 |         self.sampler = sampler
80 |
81 |     def __iter__(self):
82 |         while True:
83 |             for batch in self.sampler:
84 |                 yield batch
85 |
86 |
87 | class InfiniteDataLoader:
88 |     def __init__(
89 |         self,
90 |         dataset,
91 |         batch_size,
92 |         shuffle=True,
93 |         drop_last=False,
94 |         num_workers=0,
95 |         weights=None,
96 |         **kwargs
97 |     ):
98 |         if weights is not None:
99 |             sampler = torch.utils.data.WeightedRandomSampler(
100 |                 weights, replacement=False, num_samples=batch_size
101 |             )
102 |         else:
103 |             sampler = torch.utils.data.RandomSampler(dataset, replacement=False)
104 |
105 |         batch_sampler = torch.utils.data.BatchSampler(
106 |             sampler, batch_size=batch_size, drop_last=drop_last
107 |         )
108 |
109 |         self._infinite_iterator = iter(
110 |             torch.utils.data.DataLoader(
111 |                 dataset,
112 |                 num_workers=num_workers,
113 |                 batch_sampler=_InfiniteSampler(batch_sampler),
114 |             )
115 |         )
116 |
117 |     def __iter__(self):
118 |         while True:
119 |             yield next(self._infinite_iterator)
120 |
121 |     def __len__(self):
122 |         return 0  # Always return 0
--------------------------------------------------------------------------------
/visualize/plot_angular_space.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 |
7 | def plot2d(embeds, labels, num_classes, fig_path="./unit2d.pdf"):
8 |
9 |     fig = plt.figure(figsize=(10, 10))
10 |     ax = fig.add_subplot(111)
11 |     xlabels = [
12 |         "airplane",
13 |         "car",
14 |         "bird",
15 |         "cat",
16 |         "deer",
17 |         "dog",
18 |         "frog",
19 |         "horse",
20 |         "ship",
21 |         "truck",
22 |     ]
23 |
24 |     embeds = F.normalize(embeds, dim=1)
25 |     embeds = embeds.cpu().numpy()
26 |     for i in range(num_classes):
27 |         ax.scatter(
28 |             embeds[labels == i, 0], embeds[labels == i, 1], label=xlabels[i], s=10
29 |         )
30 |
31 |     ax.set_xlim([-1.1, 1.1])
32 |     ax.set_ylim([-1.1, 1.1])
33 |     plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
34 |     ax.set_aspect(1)
35 |     plt.subplots_adjust(right=0.75)
36 |     plt.savefig(fig_path)
37 |     plt.show()
38 |
39 |
40 | def plot3d(embeds, labels, num_classes, fig_path="./unit3d.pdf"):
41 |
42 |     fig = plt.figure(figsize=(10, 10))
43 |     ax = fig.add_subplot(111, projection="3d")
44 |
45 |     # Create a sphere
46 |     r = 1
47 |     pi = np.pi
48 |     cos = np.cos
49 |     sin = np.sin
50 |     phi, theta = np.mgrid[0.0:pi:100j, 0.0 : 2.0 * pi : 100j]
51 |     # theta = np.zeros_like(theta)
52 |     x = r * sin(phi) * cos(theta)
53 |     y = r * sin(phi) * sin(theta)
54 |     z = r * cos(phi)
55 |     ax.plot_surface(x, y, z, rstride=1, cstride=1, color="w", alpha=0.3, linewidth=0)
56 |     embeds = F.normalize(embeds, dim=1)
57 |     embeds = embeds.cpu().numpy()
58 |     xlabels = [
59 |         "airplane",
60 |         "car",
61 |         "bird",
62 |         "cat",
63 |         "deer",
64 |         "dog",
65 |         "frog",
66 |         "horse",
67 |         "ship",
68 |         "truck",
69 |     ]
70 |     for i in range(num_classes):
71 |         ax.scatter(
72 |             embeds[labels == i, 0],
73 |             embeds[labels == i, 1],
74 |             embeds[labels == i, 2],
75 |             label=xlabels[i],
76 |             s=10,
77 |         )
78 |
79 |     ax.set_xlim([-1, 1])
80 |     ax.set_ylim([-1, 1])
81 |     ax.set_zlim([-1, 1])
82 |     plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left", fontsize="small")
83 |     plt.tight_layout()
84 |     plt.savefig(fig_path)
85 |     plt.show()
86 |
87 |
88 | class ConvNet(nn.Module):
89 |     def __init__(self, latent_dim):
90 |         super(ConvNet, self).__init__()
91 |         self.latent_dim = latent_dim
92 |         self.layer1 = nn.Sequential(
93 |             nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=0),
94 |             nn.ReLU(),
95 |             nn.BatchNorm2d(32),
96 |         )
97 |         self.layer2 = nn.Sequential(
98 |             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0),
99 |             nn.ReLU(),
100 |             nn.BatchNorm2d(64),
101 |         )
102 |         self.layer3 = nn.Sequential(
103 |             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
104 |             nn.ReLU(),
105 |             nn.BatchNorm2d(128),
106 |             nn.MaxPool2d(kernel_size=2, stride=2),
107 |         )
108 |         self.layer4 = nn.Sequential(
109 |             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=0),
110 |             nn.ReLU(),
111 |             nn.BatchNorm2d(256),
112 |         )
113 |         self.layer5 = nn.Sequential(
114 |             nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
115 |             nn.ReLU(),
116 |             nn.BatchNorm2d(512),
117 |             nn.MaxPool2d(kernel_size=8, stride=1),
118 |         )
119 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
120 |         self.fc_projection = nn.Linear(512, latent_dim)
121 |
122 |     def forward(self, x):
123 |         batch_size = x.shape[0]
124 |         x = self.layer1(x)
125 |         x = self.layer2(x)
126 |         x = self.layer3(x)
127 |         x = self.layer4(x)
128 |         x = self.layer5(x)
129 |         x = self.avg_pool(x)
130 |         x = x.view(batch_size, -1)
131 |         x = self.fc_projection(x)
132 |         return x
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from angularloss import AngularLoss
4 | from torchvision.models import vgg16_bn, alexnet
5 | from resnet import resnet18, resnet34, resnet50, resnet101
6 | from visualize import ConvNet
7 |
8 | device = "cuda" if torch.cuda.is_available() else "cpu"
9 |
10 |
11 | class Squeeze(nn.Module):
12 |     def forward(self, x):
13 |         if len(x.shape) == 4:
14 |             return x.squeeze(2).squeeze(2)
15 |         elif len(x.shape) == 2:
16 |             return x
17 |         else:
18 |             raise ValueError("invalid input shape")
19 |
20 |
21 | class Baseline(nn.Module):
22 |     def __init__(self, num_classes=10, latent_dim=512, arch="visualization"):
23 |         super(Baseline, self).__init__()
24 |         self.num_classes = num_classes
25 |         if arch == "visualization":
26 |             self.convlayers = ConvNet(latent_dim=3)
27 |             self.fc_final = nn.Linear(3, num_classes)
28 |         elif arch == "vgg16":
29 |             self.convlayers = vgg16_bn(pretrained=True, num_classes=latent_dim)
30 |             self.fc_final = nn.Linear(latent_dim, num_classes)
31 |         elif arch == "resnet18":
32 |             self.convlayers = resnet18(pretrained=True, num_classes=latent_dim)
33 |             self.fc_final = nn.Linear(latent_dim, num_classes)
34 |         elif arch == "resnet34":
35 |             self.convlayers = resnet34(pretrained=True, num_classes=latent_dim)
36 |             self.fc_final = nn.Linear(latent_dim, num_classes)
37 |         elif arch == "resnet50":
38 |             self.convlayers = resnet50(pretrained=True, num_classes=latent_dim)
39 |             self.fc_final = nn.Linear(latent_dim, num_classes)
40 |         elif arch == "resnet101":
41 |             self.convlayers = resnet101(pretrained=True, num_classes=latent_dim)
42 |             self.fc_final = nn.Linear(latent_dim, num_classes)
43 |
44 |     def forward(self, x, embed=False):
45 |         x = self.convlayers(x)
46 |         if embed:
47 |             return x
48 |         x = self.fc_final(x)
49 |         return x
50 |
51 |
52 | class AngularNet(nn.Module):
53 |     def __init__(
54 |         self,
55 |         num_classes=10,
56 |         loss_type="nsl",
57 |         arch="visualization",
58 |         latent_dim=512,
59 |         s=None,
60 |         m=None,
61 |     ):
62 |         super(AngularNet, self).__init__()
63 |         self.num_classes = num_classes
64 |         if arch == "visualization":
65 |             self.feat = ConvNet(latent_dim=3)
66 |             self.feat.out_features = 3
67 |             self.angular_loss = AngularLoss(
68 |                 3, num_classes, loss_type=loss_type, s=s, m=m
69 |             )
70 |         elif arch == "alexnet":
71 |             self.feat = alexnet(pretrained=True)
72 |
self.angular_loss = AngularLoss( 73 | latent_dim, num_classes, loss_type=loss_type, s=s, m=m 74 | ) 75 | elif arch == "vgg16": 76 | self.feat = vgg16_bn(pretrained=True) 77 | self.angular_loss = AngularLoss( 78 | latent_dim, num_classes, loss_type=loss_type, s=s, m=m 79 | ) 80 | elif arch == "resnet18": 81 | self.feat = resnet18(pretrained=True) 82 | self.angular_loss = AngularLoss( 83 | latent_dim, num_classes, loss_type=loss_type, s=s, m=m 84 | ) 85 | elif arch == "resnet34": 86 | self.feat = resnet34(pretrained=True) 87 | self.angular_loss = AngularLoss( 88 | latent_dim, num_classes, loss_type=loss_type, s=s, m=m 89 | ) 90 | elif arch == "resnet50": 91 | self.feat = resnet50(pretrained=True) 92 | self.angular_loss = AngularLoss( 93 | latent_dim, num_classes, loss_type=loss_type, s=s, m=m 94 | ) 95 | elif arch == "resnet101": 96 | self.feat = resnet101(pretrained=True) 97 | self.angular_loss = AngularLoss( 98 | latent_dim, num_classes, loss_type=loss_type, s=s, m=m 99 | ) 100 | else: 101 | raise NotImplementedError 102 | if arch == "visualization": 103 | self.linear_project = nn.Identity() 104 | else: 105 | self.linear_project = nn.Linear(self.feat.out_features, latent_dim) 106 | 107 | def forward(self, x, labels=None, embed=False): 108 | x = self.feat(x) 109 | x = self.linear_project(x) 110 | if embed: 111 | return x 112 | else: 113 | logits, cos = self.angular_loss(x, labels) 114 | return logits, cos 115 | -------------------------------------------------------------------------------- /DeepDA/loss_funcs/lmmd.py: -------------------------------------------------------------------------------- 1 | from loss_funcs.mmd import MMDLoss 2 | from loss_funcs.adv import LambdaSheduler 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class LMMDLoss(MMDLoss, LambdaSheduler): 8 | def __init__( 9 | self, 10 | num_class, 11 | kernel_type="rbf", 12 | kernel_mul=2.0, 13 | kernel_num=5, 14 | fix_sigma=None, 15 | gamma=1.0, 16 | max_iter=1000, 17 | **kwargs 18 | ): 19 | """ 20 | Local MMD 21 | """ 22 | super(LMMDLoss, self).__init__( 23 | kernel_type, kernel_mul, kernel_num, fix_sigma, **kwargs 24 | ) 25 | super(MMDLoss, self).__init__(gamma, max_iter, **kwargs) 26 | self.num_class = num_class 27 | 28 | def forward(self, source, target, source_label, target_logits): 29 | if self.kernel_type == "linear": 30 | raise NotImplementedError("Linear kernel is not supported yet.") 31 | 32 | elif self.kernel_type == "rbf": 33 | batch_size = source.size()[0] 34 | weight_ss, weight_tt, weight_st = self.cal_weight( 35 | source_label, target_logits 36 | ) 37 | weight_ss = torch.from_numpy(weight_ss).cuda() # B, B 38 | weight_tt = torch.from_numpy(weight_tt).cuda() 39 | weight_st = torch.from_numpy(weight_st).cuda() 40 | 41 | kernels = self.guassian_kernel( 42 | source, 43 | target, 44 | kernel_mul=self.kernel_mul, 45 | kernel_num=self.kernel_num, 46 | fix_sigma=self.fix_sigma, 47 | ) 48 | loss = torch.Tensor([0]).cuda() 49 | if torch.sum(torch.isnan(sum(kernels))): 50 | return loss 51 | SS = kernels[:batch_size, :batch_size] 52 | TT = kernels[batch_size:, batch_size:] 53 | ST = kernels[:batch_size, batch_size:] 54 | 55 | loss += torch.sum(weight_ss * SS + weight_tt * TT - 2 * weight_st * ST) 56 | # Dynamic weighting 57 | lamb = self.lamb() 58 | self.step() 59 | loss = loss * lamb 60 | return loss, weight_ss.mean(), weight_tt.mean(), weight_st.mean() 61 | 62 | def cal_weight(self, source_label, target_logits): 63 | batch_size = source_label.size()[0] 64 | source_label = source_label.cpu().data.numpy() 65 | 
source_label_onehot = np.eye(self.num_class)[source_label]  # one-hot
66 |         source_label_sum = np.sum(source_label_onehot, axis=0).reshape(
67 |             1, self.num_class
68 |         )
69 |         source_label_sum[source_label_sum == 0] = 100  # avoid division by zero for absent classes
70 |         source_label_onehot = (
71 |             source_label_onehot / source_label_sum
72 |         )  # weight each sample by the reciprocal of its per-class count in the batch (class balancing)
73 |
74 |         # Pseudo labels: the target domain is unlabeled, so predictions decide which classes are present below
75 |         target_label = target_logits.cpu().data.max(1)[1].numpy()
76 |
77 |         target_logits = target_logits.cpu().data.numpy()
78 |         target_logits_sum = np.sum(target_logits, axis=0).reshape(1, self.num_class)
79 |         target_logits_sum[target_logits_sum == 0] = 100
80 |         target_logits = (
81 |             target_logits / target_logits_sum
82 |         )  # unlike the source, the target uses soft pseudo labels, again normalized per class for balance
83 |
84 |
85 |         weight_ss = np.zeros((batch_size, batch_size))
86 |         weight_tt = np.zeros((batch_size, batch_size))
87 |         weight_st = np.zeros((batch_size, batch_size))
88 |
89 |         set_s = set(source_label)  # some classes may be absent from the current batch
90 |         set_t = set(target_label)
91 |         count = 0
92 |         for i in range(self.num_class):  # loop over the C classes
93 |             if i in set_s and i in set_t:
94 |                 s_tvec = source_label_onehot[:, i].reshape(batch_size, -1)  # (B, 1)
95 |                 t_tvec = target_logits[:, i].reshape(batch_size, -1)  # (B, 1)
96 |
97 |                 ss = np.dot(s_tvec, s_tvec.T)  # (B, B) pairwise weight matrix for the current batch
98 |                 weight_ss = weight_ss + ss  # accumulate per-class weights; averaged over classes below
99 |
100 |                 tt = np.dot(t_tvec, t_tvec.T)
101 |                 weight_tt = weight_tt + tt
102 |                 st = np.dot(s_tvec, t_tvec.T)
103 |                 weight_st = weight_st + st
104 |                 count += 1
105 |
106 |         length = count
107 |         if length != 0:
108 |             weight_ss = weight_ss / length
109 |             weight_tt = weight_tt / length
110 |             weight_st = weight_st / length
111 |         else:
112 |             weight_ss = np.array([0])
113 |             weight_tt = np.array([0])
114 |             weight_st = np.array([0])
115 |         return (
116 |             weight_ss.astype("float32"),
117 |             weight_tt.astype("float32"),
118 |             weight_st.astype("float32"),
119 |         )
120 |
121 |
122 | if __name__ == "__main__":
123 |     lmmd = LMMDLoss(4)
124 |     torch.manual_seed(666)
125 |     src_label, tar_logits = torch.LongTensor([0, 3, 0]), torch.softmax(
126 |         torch.randn(3, 4), 1
127 |     )
128 |     print("original logits", tar_logits)
129 |     lmmd.cal_weight(src_label, tar_logits)
130 |
--------------------------------------------------------------------------------
/DeepDA/transfer_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from loss_funcs import *
4 | from loss_funcs.mmd import MMDLoss
5 | from loss_funcs.adv import LambdaSheduler
6 |
7 | import numpy as np
8 |
9 |
10 | class TransferLoss(nn.Module):
11 |     def __init__(self, loss_type, **kwargs):
12 |         super(TransferLoss, self).__init__()
13 |         self.loss_type = loss_type
14 |         if loss_type == "mmd":
15 |             self.loss_func = MMDLoss(**kwargs)
16 |         elif loss_type == "lmmd":
17 |             self.loss_func = LMMDLoss(**kwargs)
18 |         elif loss_type == "coral":
19 |             self.loss_func = CORAL
20 |         elif loss_type == "adv":
21 |             self.loss_func = AdversarialLoss(**kwargs)
22 |         elif loss_type == "daan":
23 |             self.loss_func = DAANLoss(**kwargs)
24 |         elif loss_type == "bnm":
25 |             self.loss_func = BNM
26 |         else:
27 |             print("WARNING: No valid transfer loss function is used.")
28 |             self.loss_func = lambda x, y: 0  # return 0
29 |
30 |     def forward(self, source, target, **kwargs):
31 |         return self.loss_func(source, target, **kwargs)
32 |
33 |
34 | class DFTransferLoss(nn.Module):
35 |     def __init__(self, loss_type, **kwargs):
36 | super(DFTransferLoss, self).__init__() 37 | self.loss_type = loss_type 38 | if loss_type == "lmmd": 39 | self.loss_func = DFLMMDLoss(**kwargs) 40 | else: 41 | print("WARNING: No valid transfer loss function is used.") 42 | self.loss_func = lambda x, y: 0 # return 0 43 | 44 | def forward(self, source, target, **kwargs): 45 | return self.loss_func(source, target, **kwargs) 46 | 47 | 48 | class DFLMMDLoss(MMDLoss, LambdaSheduler): 49 | def __init__( 50 | self, 51 | num_class, 52 | k, 53 | kernel_type="rbf", 54 | kernel_mul=2.0, 55 | kernel_num=5, 56 | fix_sigma=None, 57 | gamma=1.0, 58 | max_iter=1000, 59 | **kwargs 60 | ): 61 | """ 62 | Local MMD 63 | """ 64 | super(DFLMMDLoss, self).__init__( 65 | kernel_type, kernel_mul, kernel_num, fix_sigma, **kwargs 66 | ) 67 | super(MMDLoss, self).__init__(gamma, max_iter, **kwargs) 68 | self.num_class = num_class 69 | self.k = k 70 | 71 | def forward(self, source, target, source_label, target_logits, src_conf_margin): 72 | if self.kernel_type == "linear": 73 | raise NotImplementedError("Linear kernel is not supported yet.") 74 | 75 | elif self.kernel_type == "rbf": 76 | batch_size = source.size()[0] 77 | weight_ss, weight_tt, weight_st = self.cal_weight( 78 | source_label, target_logits 79 | ) 80 | weight_ss = torch.from_numpy(weight_ss).cuda() # B, B 81 | weight_tt = torch.from_numpy(weight_tt).cuda() 82 | weight_st = torch.from_numpy(weight_st).cuda() 83 | 84 | kernels = self.guassian_kernel( 85 | source, 86 | target, 87 | kernel_mul=self.kernel_mul, 88 | kernel_num=self.kernel_num, 89 | fix_sigma=self.fix_sigma, 90 | ) 91 | loss = torch.Tensor([0]).cuda() 92 | if torch.sum(torch.isnan(sum(kernels))): 93 | return loss 94 | SS = kernels[:batch_size, :batch_size] 95 | TT = kernels[batch_size:, batch_size:] 96 | ST = kernels[:batch_size, batch_size:] 97 | 98 | # loss += torch.sum(weight_ss * SS + weight_tt * TT - 2 * weight_st * ST) 99 | src_conf_margin = torch.sigmoid(src_conf_margin * self.k) 100 | loss += torch.sum( 101 | src_conf_margin * weight_ss * SS 102 | + weight_tt * TT 103 | - 2 * torch.sqrt(src_conf_margin) * weight_st * ST 104 | ) 105 | # Dynamic weighting 106 | lamb = self.lamb() 107 | self.step() 108 | loss = loss * lamb 109 | return loss, weight_ss.mean(), weight_tt.mean(), weight_st.mean() 110 | 111 | def cal_weight(self, source_label, target_logits): 112 | batch_size = source_label.size()[0] 113 | source_label = source_label.cpu().data.numpy() 114 | source_label_onehot = np.eye(self.num_class)[source_label] # one hot 115 | source_label_sum = np.sum(source_label_onehot, axis=0).reshape( 116 | 1, self.num_class 117 | ) 118 | source_label_sum[source_label_sum == 0] = 100 119 | source_label_onehot = source_label_onehot / source_label_sum # label ratio 120 | 121 | # Pseudo label 122 | target_label = target_logits.cpu().data.max(1)[1].numpy() 123 | 124 | target_logits = target_logits.cpu().data.numpy() 125 | target_logits_sum = np.sum(target_logits, axis=0).reshape(1, self.num_class) 126 | target_logits_sum[target_logits_sum == 0] = 100 127 | target_logits = target_logits / target_logits_sum 128 | 129 | weight_ss = np.zeros((batch_size, batch_size)) 130 | weight_tt = np.zeros((batch_size, batch_size)) 131 | weight_st = np.zeros((batch_size, batch_size)) 132 | 133 | set_s = set(source_label) 134 | set_t = set(target_label) 135 | count = 0 136 | for i in range(self.num_class): # (B, C) 137 | if i in set_s and i in set_t: 138 | s_tvec = source_label_onehot[:, i].reshape(batch_size, -1) # (B, 1) 139 | t_tvec = target_logits[:, 
i].reshape(batch_size, -1) # (B, 1) 140 | 141 | ss = np.dot(s_tvec, s_tvec.T) # (B, B) 142 | weight_ss = weight_ss + ss 143 | tt = np.dot(t_tvec, t_tvec.T) 144 | weight_tt = weight_tt + tt 145 | st = np.dot(s_tvec, t_tvec.T) 146 | weight_st = weight_st + st 147 | count += 1 148 | 149 | length = count 150 | if length != 0: 151 | weight_ss = weight_ss / length 152 | weight_tt = weight_tt / length 153 | weight_st = weight_st / length 154 | else: 155 | weight_ss = np.array([0]) 156 | weight_tt = np.array([0]) 157 | weight_st = np.array([0]) 158 | return ( 159 | weight_ss.astype("float32"), 160 | weight_tt.astype("float32"), 161 | weight_st.astype("float32"), 162 | ) 163 | -------------------------------------------------------------------------------- /DeepDA/diff_transfer_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from loss_funcs import * 4 | from loss_funcs.mmd import MMDLoss 5 | from loss_funcs.adv import LambdaSheduler 6 | import torch 7 | import numpy as np 8 | 9 | 10 | class TransferLoss(nn.Module): 11 | def __init__(self, loss_type, **kwargs): 12 | super(TransferLoss, self).__init__() 13 | self.loss_type = loss_type 14 | if loss_type == "mmd": 15 | self.loss_func = MMDLoss(**kwargs) 16 | elif loss_type == "lmmd": 17 | self.loss_func = LMMDLoss(**kwargs) 18 | elif loss_type == "coral": 19 | self.loss_func = CORAL 20 | elif loss_type == "adv": 21 | self.loss_func = AdversarialLoss(**kwargs) 22 | elif loss_type == "daan": 23 | self.loss_func = DAANLoss(**kwargs) 24 | elif loss_type == "bnm": 25 | self.loss_func = BNM 26 | else: 27 | print("WARNING: No valid transfer loss function is used.") 28 | self.loss_func = lambda x, y: 0 # return 0 29 | 30 | def forward(self, source, target, **kwargs): 31 | return self.loss_func(source, target, **kwargs) 32 | 33 | 34 | class DFTransferLoss(nn.Module): 35 | def __init__(self, loss_type, **kwargs): 36 | super(DFTransferLoss, self).__init__() 37 | self.loss_type = loss_type 38 | if loss_type == "mmd": 39 | self.loss_func = MMDLoss(**kwargs) 40 | elif loss_type == "lmmd": 41 | self.loss_func = DFLMMDLoss(**kwargs) 42 | else: 43 | print("WARNING: No valid transfer loss function is used.") 44 | self.loss_func = lambda x, y: 0 # return 0 45 | 46 | def forward(self, source, target, **kwargs): 47 | return self.loss_func(source, target, **kwargs) 48 | 49 | 50 | class DFLMMDLoss(MMDLoss, LambdaSheduler): 51 | def __init__( 52 | self, 53 | num_class, 54 | kernel_type="rbf", 55 | kernel_mul=2.0, 56 | kernel_num=5, 57 | fix_sigma=None, 58 | gamma=1.0, 59 | max_iter=1000, 60 | **kwargs 61 | ): 62 | """ 63 | Local MMD 64 | """ 65 | super(DFLMMDLoss, self).__init__( 66 | kernel_type, kernel_mul, kernel_num, fix_sigma, **kwargs 67 | ) 68 | super(MMDLoss, self).__init__(gamma, max_iter, **kwargs) 69 | self.num_class = num_class 70 | 71 | def forward( 72 | self, 73 | source, 74 | target, 75 | source_label, 76 | target_logits, 77 | src_conf_margin, 78 | tar_conf_margin, 79 | ): 80 | if self.kernel_type == "linear": 81 | raise NotImplementedError("Linear kernel is not supported yet.") 82 | 83 | elif self.kernel_type == "rbf": 84 | batch_size = source.size()[0] 85 | weight_ss, weight_tt, weight_st = self.cal_weight( 86 | source_label, target_logits 87 | ) 88 | weight_ss = torch.from_numpy(weight_ss).cuda() # B, B 89 | weight_tt = torch.from_numpy(weight_tt).cuda() 90 | weight_st = torch.from_numpy(weight_st).cuda() 91 | 92 | kernels = self.guassian_kernel( 93 | source, 94 | 
target, 95 | kernel_mul=self.kernel_mul, 96 | kernel_num=self.kernel_num, 97 | fix_sigma=self.fix_sigma, 98 | ) 99 | loss = torch.Tensor([0]).cuda() 100 | if torch.sum(torch.isnan(sum(kernels))): 101 | return loss 102 | SS = kernels[:batch_size, :batch_size] 103 | TT = kernels[batch_size:, batch_size:] 104 | ST = kernels[:batch_size, batch_size:] 105 | 106 | # loss += torch.sum(weight_ss * SS + weight_tt * TT - 2 * weight_st * ST) 107 | src_conf_margin = torch.sigmoid(src_conf_margin) 108 | tar_conf_margin = torch.sigmoid(tar_conf_margin) 109 | loss += torch.sum( 110 | src_conf_margin * weight_ss * SS 111 | + tar_conf_margin * weight_tt * TT 112 | - 2 * torch.sqrt(src_conf_margin * tar_conf_margin) * weight_st * ST 113 | ) 114 | # Dynamic weighting 115 | lamb = self.lamb() 116 | self.step() 117 | loss = loss * lamb 118 | return loss, weight_ss.mean(), weight_tt.mean(), weight_st.mean() 119 | 120 | def cal_weight(self, source_label, target_logits): 121 | batch_size = source_label.size()[0] 122 | source_label = source_label.cpu().data.numpy() 123 | source_label_onehot = np.eye(self.num_class)[source_label] # one hot 124 | source_label_sum = np.sum(source_label_onehot, axis=0).reshape( 125 | 1, self.num_class 126 | ) 127 | source_label_sum[source_label_sum == 0] = 100 128 | source_label_onehot = source_label_onehot / source_label_sum # label ratio 129 | 130 | # Pseudo label 131 | target_label = target_logits.cpu().data.max(1)[1].numpy() 132 | 133 | target_logits = target_logits.cpu().data.numpy() 134 | target_logits_sum = np.sum(target_logits, axis=0).reshape(1, self.num_class) 135 | target_logits_sum[target_logits_sum == 0] = 100 136 | target_logits = target_logits / target_logits_sum 137 | 138 | weight_ss = np.zeros((batch_size, batch_size)) 139 | weight_tt = np.zeros((batch_size, batch_size)) 140 | weight_st = np.zeros((batch_size, batch_size)) 141 | 142 | set_s = set(source_label) 143 | set_t = set(target_label) 144 | count = 0 145 | for i in range(self.num_class): # (B, C) 146 | if i in set_s and i in set_t: 147 | s_tvec = source_label_onehot[:, i].reshape(batch_size, -1) # (B, 1) 148 | t_tvec = target_logits[:, i].reshape(batch_size, -1) # (B, 1) 149 | 150 | ss = np.dot(s_tvec, s_tvec.T) # (B, B) 151 | weight_ss = weight_ss + ss 152 | tt = np.dot(t_tvec, t_tvec.T) 153 | weight_tt = weight_tt + tt 154 | st = np.dot(s_tvec, t_tvec.T) 155 | weight_st = weight_st + st 156 | count += 1 157 | 158 | length = count 159 | if length != 0: 160 | weight_ss = weight_ss / length 161 | weight_tt = weight_tt / length 162 | weight_st = weight_st / length 163 | else: 164 | weight_ss = np.array([0]) 165 | weight_tt = np.array([0]) 166 | weight_st = np.array([0]) 167 | return ( 168 | weight_ss.astype("float32"), 169 | weight_tt.astype("float32"), 170 | weight_st.astype("float32"), 171 | ) 172 | -------------------------------------------------------------------------------- /calibration.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | device = "cuda" if torch.cuda.is_available() else "cpu" 6 | 7 | 8 | class TemperatureScaling(nn.Module): 9 | def __init__( 10 | self, 11 | ): 12 | super().__init__() 13 | self.temp = nn.Parameter(torch.ones(1, dtype=torch.float)) 14 | 15 | def forward(self, x): 16 | x = x * torch.clamp(self.temp, 0.8, 1.2) 17 | return x 18 | 19 | 20 | class DiagonalScaling(nn.Module): 21 | def __init__(self, class_nums): 22 | super().__init__() 23 | self.class_nums = class_nums 24 | 
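        # Class-wise affine map on the logits: x @ diag + bias. Initialized at
        # the identity; note that the full matrix (not just its diagonal) is
        # learnable here.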
self.diag = nn.Parameter(torch.eye(class_nums))
        self.bias = nn.Parameter(torch.zeros(class_nums))

    def forward(self, x):
        x = torch.mm(x, self.diag) + self.bias
        return x


class MatrixScaling(nn.Module):
    def __init__(self, class_nums, off_diagonal_intercept_regularization=True):
        super().__init__()
        self.class_nums = class_nums
        self.odir = off_diagonal_intercept_regularization
        self.mat = torch.nn.Linear(class_nums, class_nums)

    def forward(self, x):
        return self.mat(x)


def calibrationMapping(
    num_cls,
    model,
    val_loader,
    calibration_type="diagonal_scaling",
    calibration_lr=0.01,
    max_iter=10,
    ms_odir=True,
    ms_l=1e-5,
    ms_mu=1e-5,
):
    model.eval()
    logits_list = []
    labels_list = []
    num_classes = num_cls
    criterion = nn.CrossEntropyLoss()
    if calibration_type == "matrix_scaling":
        temp = MatrixScaling(num_classes, ms_odir)
    elif calibration_type == "temperature_scaling":
        temp = TemperatureScaling()
    elif calibration_type == "diagonal_scaling":
        temp = DiagonalScaling(num_classes)
    else:
        raise NotImplementedError
    temp = temp.to(device)
    optimizer = torch.optim.LBFGS(
        temp.parameters(), lr=calibration_lr, max_iter=max_iter
    )

    with torch.no_grad():
        for data, labels in val_loader:
            data, labels = data.cuda(non_blocking=True), labels.cuda(non_blocking=True)
            logits, _ = model(data, labels)
            logits_list.append(logits)
            labels_list.append(labels)
        logits = torch.cat(logits_list).to(device)
        labels = torch.cat(labels_list).to(device)

    def closure():
        optimizer.zero_grad()
        if calibration_type == "matrix_scaling" and temp.odir:
            assert (
                ms_l is not None and ms_mu is not None
            ), "assign l and mu to apply odir regularization"
            regularization = 0
            for i in range(temp.class_nums):
                # off-diagonal regularization; bias-magnitude regularization
                regularization += ms_l * torch.sum(
                    torch.square(temp.mat.weight[0:i, i])
                )
                regularization += ms_l * torch.sum(
                    torch.square(temp.mat.weight[i + 1 :, i])
                )
                regularization += ms_mu * torch.square(temp.mat.bias[i])
            _loss = criterion(temp(logits), labels) + regularization
        else:
            _loss = criterion(temp(logits), labels)
        _loss.backward()
        return _loss

    optimizer.step(closure)
    return temp


def ece_eval(preds, targets, n_bins=10, bg_cls=0):
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    confidences, predictions = np.max(preds, 1), np.argmax(preds, 1)
    confidences, predictions = (
        confidences[targets > bg_cls],
        predictions[targets > bg_cls],
    )
    accuracies = predictions == targets[targets > bg_cls]
    Bm, acc, conf = np.zeros(n_bins), np.zeros(n_bins), np.zeros(n_bins)
    ece = 0.0
    bin_idx = 0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = np.logical_and(confidences > bin_lower, confidences <= bin_upper)
        bin_size = np.sum(in_bin)

        Bm[bin_idx] = bin_size
        if bin_size > 0:
            accuracy_in_bin = np.sum(accuracies[in_bin])
            acc[bin_idx] = accuracy_in_bin / Bm[bin_idx]
            confidence_in_bin = np.sum(confidences[in_bin])
            conf[bin_idx] = confidence_in_bin / Bm[bin_idx]
        bin_idx += 1

    ece_all = Bm * np.abs((acc -
conf)) / Bm.sum() 134 | ece = ece_all.sum() 135 | return ece, acc, conf, Bm 136 | 137 | 138 | def tace_eval(preds, targets, n_bins=10, threshold=1e-4, bg_cls=0): 139 | init = 0 140 | if bg_cls == 0: 141 | init = 1 142 | preds = preds.astype(np.float32) 143 | targets = targets.astype(np.float16) 144 | n_img, n_classes = preds.shape[:2] 145 | Bm_all, acc_all, conf_all = np.zeros(n_bins), np.zeros(n_bins), np.zeros(n_bins) 146 | ece_all = [] 147 | for cur_class in range(init, n_classes): 148 | cur_class_conf = preds[:, cur_class] 149 | cur_class_conf = cur_class_conf.flatten() 150 | cur_class_conf_sorted = np.sort(cur_class_conf) 151 | targets_vec = targets.flatten() 152 | targets_sorted = targets_vec[cur_class_conf.argsort()] 153 | # target must be sorted along with cls conf 154 | targets_sorted = targets_sorted[cur_class_conf_sorted > threshold] 155 | cur_class_conf_sorted = cur_class_conf_sorted[cur_class_conf_sorted > threshold] 156 | bin_size = len(cur_class_conf_sorted) // n_bins 157 | ece_cls, Bm, acc, conf = ( 158 | np.zeros(n_bins), 159 | np.zeros(n_bins), 160 | np.zeros(n_bins), 161 | np.zeros(n_bins), 162 | ) 163 | bin_idx = 0 164 | for bin_i in range(n_bins): 165 | bin_start_ind = bin_i * bin_size 166 | if bin_i < n_bins - 1: 167 | bin_end_ind = bin_start_ind + bin_size 168 | else: 169 | bin_end_ind = len(targets_sorted) 170 | bin_size = ( 171 | bin_end_ind - bin_start_ind 172 | ) # extend last bin until the end of prediction array 173 | # print('bin start', cur_class_conf_sorted[bin_start_ind]) 174 | # print('bin end', cur_class_conf_sorted[bin_end_ind-1]) 175 | # Bm contains size to compute proportion 176 | Bm[bin_idx] = bin_size 177 | # compute bin acc with indices 178 | bin_acc = targets_sorted[bin_start_ind:bin_end_ind] == cur_class 179 | acc[bin_idx] = np.sum(bin_acc) / bin_size 180 | bin_conf = cur_class_conf_sorted[bin_start_ind:bin_end_ind] 181 | conf[bin_idx] = np.sum(bin_conf) / bin_size 182 | bin_idx += 1 183 | # weighted average 184 | ece_cls = Bm * np.abs((acc - conf)) / (Bm.sum()) 185 | ece_all.append(np.mean(ece_cls)) 186 | Bm_all += Bm 187 | acc_all += acc 188 | conf_all += conf 189 | ece, acc_all, conf_all = ( 190 | np.mean(ece_all), 191 | acc_all / (n_classes - init), 192 | conf_all / (n_classes - init), 193 | ) 194 | return ece, acc_all, conf_all, Bm_all 195 | -------------------------------------------------------------------------------- /DeepDA/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transfer_losses import TransferLoss, DFTransferLoss 4 | import backbones 5 | from apex import amp 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | 10 | class TransferNet(nn.Module): 11 | def __init__( 12 | self, 13 | num_class, 14 | base_net="resnet50", 15 | transfer_loss="lmmd", 16 | use_bottleneck=True, 17 | bottleneck_width=256, 18 | max_iter=1000, 19 | **kwargs 20 | ): 21 | super(TransferNet, self).__init__() 22 | self.num_class = num_class 23 | self.base_network = backbones.get_backbone(base_net) 24 | self.use_bottleneck = use_bottleneck 25 | self.transfer_loss = transfer_loss 26 | self.amp = None 27 | self.tr_weight = 10 28 | self.optimizer = None 29 | self.lr_scheduler = None 30 | if self.use_bottleneck: 31 | bottleneck_list = [ 32 | nn.Linear(self.base_network.output_num(), bottleneck_width), 33 | nn.ReLU(), 34 | ] 35 | self.bottleneck_layer = nn.Sequential(*bottleneck_list) 36 | feature_dim = bottleneck_width 37 | else: 38 | feature_dim = 
self.base_network.output_num()

        self.classifier_layer = nn.Linear(feature_dim, num_class)
        transfer_loss_args = {
            "loss_type": "lmmd",
            "max_iter": max_iter,
            "num_class": num_class,
        }
        # self.adapt_loss = TransferLoss(**transfer_loss_args)
        transfer_loss_args["k"] = nn.parameter.Parameter(
            torch.ones(1, device="cuda") * 10
        )
        self.adapt_loss = DFTransferLoss(**transfer_loss_args)
        self.criterion = torch.nn.CrossEntropyLoss()

        self.mmd = MMDLoss()
        self.gk1 = GaussianKernel(1.0, 1.0)
        self.gk5 = GaussianKernel(1.0, 5.0)

    def forward(self, source, target, source_label):
        source = self.base_network(source)
        target = self.base_network(target)
        if self.use_bottleneck:
            source = self.bottleneck_layer(source)
            target = self.bottleneck_layer(target)
        # classification
        source_clf = self.classifier_layer(source)
        src_target_mask = torch.zeros(
            len(source_label), self.num_class, device=source.device
        )
        src_target_mask = src_target_mask.scatter(
            dim=1,
            index=source_label[:, None],
            src=torch.ones((len(source_label), 1), device=source.device),
        )  # B, C
        _, src_conf_margin = get_confidence_output_margin(source_clf, src_target_mask)

        clf_loss = self.criterion(source_clf, source_label)

        # transfer
        kwargs = {}
        kwargs["source_label"] = source_label
        target_clf = self.classifier_layer(target)
        kwargs["target_logits"] = torch.nn.functional.softmax(target_clf, dim=1)

        with torch.no_grad():
            curr_mmd = self.mmd(source, target)
            curr_gk1 = self.gk1(source, target).mean()
            curr_gk5 = self.gk5(source, target).mean()

        tar_pseudo_labels = target_clf.argmax(1)
        tar_target_mask = torch.zeros(
            len(tar_pseudo_labels), self.num_class, device=target.device
        )
        tar_target_mask = tar_target_mask.scatter(
            dim=1,
            index=tar_pseudo_labels[:, None],
            src=torch.ones(len(tar_pseudo_labels), 1, device=target.device),
        )
        # tar_conf_margin is consumed by the DFLMMDLoss variant in
        # diff_transfer_losses.py; the transfer_losses.py variant only uses
        # the source margin.
        _, tar_conf_margin = get_confidence_output_margin(target_clf, tar_target_mask)

        kwargs["src_conf_margin"] = src_conf_margin
        transfer_loss, _, _, _ = self.adapt_loss(source, target, **kwargs)
        loss = clf_loss + transfer_loss * self.tr_weight
        if self.amp:
            self.optimizer.zero_grad()
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()
        else:
            pass
        return clf_loss, transfer_loss, curr_mmd, curr_gk1, curr_gk5

    def get_parameters(self, initial_lr=1.0):
        params = [
            {"params": self.base_network.parameters(), "lr": 0.1 * initial_lr},
            {"params": self.classifier_layer.parameters(), "lr": 1.0 * initial_lr},
        ]
        if self.use_bottleneck:
            params.append(
                {"params": self.bottleneck_layer.parameters(), "lr": 1.0 * initial_lr}
            )
        # Loss-dependent
        if self.transfer_loss == "adv":
            params.append(
                {
                    "params": self.adapt_loss.loss_func.domain_classifier.parameters(),
                    "lr": 1.0 * initial_lr,
                }
            )
        elif self.transfer_loss == "daan":
            params.append(
                {
                    "params": self.adapt_loss.loss_func.domain_classifier.parameters(),
                    "lr": 1.0 * initial_lr,
                }
            )
            params.append(
                {
                    "params": self.adapt_loss.loss_func.local_classifiers.parameters(),
                    "lr": 1.0 * initial_lr,
                }
            )

        return params

    def predict(self, x):
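        # Inference-only path: backbone -> bottleneck -> classifier head.
        # No transfer loss or confidence margin is computed here, and the
        # bottleneck is applied unconditionally (assumes use_bottleneck=True).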
features = self.base_network(x) 143 | x = self.bottleneck_layer(features) 144 | clf = self.classifier_layer(x) 145 | return clf 146 | 147 | def epoch_based_processing(self, *args, **kwargs): 148 | if self.transfer_loss == "daan": 149 | self.adapt_loss.loss_func.update_dynamic_factor(*args, **kwargs) 150 | else: 151 | pass 152 | 153 | 154 | def get_confidence_output_margin(logits, label_mask): 155 | confidence = F.softmax(logits, 1) 156 | targets_conf = torch.sum(confidence * label_mask, 1) 157 | 158 | max_excl_conf, _ = torch.max( 159 | confidence * (torch.ones_like(label_mask, device=logits.device) - label_mask), 160 | dim=1, 161 | ) 162 | conf_output_margin = targets_conf - max_excl_conf 163 | return targets_conf, conf_output_margin 164 | 165 | 166 | class MMDLoss(nn.Module): 167 | def forward(self, src_feats: Tensor, tar_feats: Tensor) -> Tensor: 168 | delta = src_feats.mean(0) - tar_feats.mean(0) 169 | loss = delta.dot(delta.t()) 170 | return loss 171 | 172 | 173 | class GaussianKernel(nn.Module): 174 | def __init__(self, amplitude: float, sigma: float): 175 | super(GaussianKernel, self).__init__() 176 | self.amplitude = amplitude 177 | self.sigma = sigma 178 | 179 | def get_covariance_matrix(self, src_feats: Tensor, tar_feats: Tensor) -> Tensor: 180 | """ 181 | :param src_feats: src_batch_size x F 182 | :param tar_feats: tar_batch_size x F 183 | :return: Tensor of src_batch_size x tar_batch_size 184 | """ 185 | 186 | distances_array = torch.stack( 187 | [ 188 | torch.stack([torch.linalg.norm(x_p - x_q) for x_q in tar_feats]) 189 | for x_p in src_feats 190 | ] 191 | ) 192 | covariance_matrix = self.amplitude * torch.exp( 193 | (-1 / (2 * self.sigma ** 2)) * (distances_array ** 2) 194 | ) 195 | 196 | return covariance_matrix 197 | 198 | def forward(self, src_feats: Tensor, tar_feats: Tensor) -> Tensor: 199 | return self.get_covariance_matrix(src_feats, tar_feats) 200 | -------------------------------------------------------------------------------- /DeepDA/backbones.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | import torch 4 | from typing import List, Optional, Dict, Tuple 5 | 6 | resnet_dict = { 7 | "resnet18": models.resnet18, 8 | "resnet34": models.resnet34, 9 | "resnet50": models.resnet50, 10 | "resnet101": models.resnet101, 11 | "resnet152": models.resnet152, 12 | } 13 | 14 | 15 | def get_backbone(name): 16 | if "resnet" in name.lower(): 17 | return ResNetBackbone(name) 18 | elif "alexnet" == name.lower(): 19 | return AlexNetBackbone() 20 | elif "dann" == name.lower(): 21 | return DaNNBackbone() 22 | 23 | 24 | class DaNNBackbone(nn.Module): 25 | def __init__(self, n_input=224 * 224 * 3, n_hidden=256): 26 | super(DaNNBackbone, self).__init__() 27 | self.layer_input = nn.Linear(n_input, n_hidden) 28 | self.dropout = nn.Dropout(p=0.5) 29 | self.relu = nn.ReLU() 30 | self._feature_dim = n_hidden 31 | 32 | def forward(self, x): 33 | x = x.view(x.size(0), -1) 34 | x = self.layer_input(x) 35 | x = self.dropout(x) 36 | x = self.relu(x) 37 | return x 38 | 39 | def output_num(self): 40 | return self._feature_dim 41 | 42 | 43 | # convnet without the last layer 44 | class AlexNetBackbone(nn.Module): 45 | def __init__(self): 46 | super(AlexNetBackbone, self).__init__() 47 | model_alexnet = models.alexnet(pretrained=True) 48 | self.features = model_alexnet.features 49 | self.classifier = nn.Sequential() 50 | for i in range(6): 51 | self.classifier.add_module( 52 | "classifier" + str(i), 
model_alexnet.classifier[i] 53 | ) 54 | self._feature_dim = model_alexnet.classifier[6].in_features 55 | 56 | def forward(self, x): 57 | x = self.features(x) 58 | x = x.view(x.size(0), 256 * 6 * 6) 59 | x = self.classifier(x) 60 | return x 61 | 62 | def output_num(self): 63 | return self._feature_dim 64 | 65 | 66 | class ResNetBackbone(nn.Module): 67 | def __init__(self, network_type): 68 | super(ResNetBackbone, self).__init__() 69 | state_dict = torch.load("/home/checkpoint/resnet50/resnet50-19c8e357.pth") 70 | # state_dict = torch.load('/home/checkpoint/visda/best_resnet_train.pth') 71 | resnet = resnet_dict[network_type](pretrained=False) 72 | resnet.load_state_dict(state_dict) 73 | self.conv1 = resnet.conv1 74 | self.bn1 = resnet.bn1 75 | self.relu = resnet.relu 76 | self.maxpool = resnet.maxpool 77 | self.layer1 = resnet.layer1 78 | self.layer2 = resnet.layer2 79 | self.layer3 = resnet.layer3 80 | self.layer4 = resnet.layer4 81 | self.avgpool = resnet.avgpool 82 | self._feature_dim = resnet.fc.in_features 83 | del resnet 84 | 85 | def forward(self, x): 86 | x = self.conv1(x) 87 | x = self.bn1(x) 88 | x = self.relu(x) 89 | x = self.maxpool(x) 90 | x = self.layer1(x) 91 | x = self.layer2(x) 92 | x = self.layer3(x) 93 | x = self.layer4(x) 94 | x = self.avgpool(x) 95 | x = x.view(x.size(0), -1) 96 | return x 97 | 98 | def output_num(self): 99 | return self._feature_dim 100 | 101 | 102 | class ClassifierBase(nn.Module): 103 | """A generic Classifier class for domain adaptation. 104 | 105 | Args: 106 | backbone (torch.nn.Module): Any backbone to extract 2-d features from data 107 | num_classes (int): Number of classes 108 | bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default 109 | bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1 110 | head (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default 111 | finetune (bool): Whether finetune the classifier or train from scratch. Default: True 112 | 113 | .. note:: 114 | Different classifiers are used in different domain adaptation algorithms to achieve better accuracy 115 | respectively, and we provide a suggested `Classifier` for different algorithms. 116 | Remember they are not the core of algorithms. You can implement your own `Classifier` and combine it with 117 | the domain adaptation algorithm in this algorithm library. 118 | 119 | .. note:: 120 | The learning rate of this classifier is set 10 times to that of the feature extractor for better accuracy 121 | by default. If you have other optimization strategies, please over-ride :meth:`~Classifier.get_parameters`. 
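    .. note::
        A minimal wiring sketch (illustrative names, not part of this file):
        build ``backbone = resnet50(pretrained=True)``, wrap it with
        ``clf = ImageClassifier(backbone, num_classes=31, bottleneck_dim=256)``,
        and call ``predictions, features = clf(images)`` in training mode.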
122 | 123 | Inputs: 124 | - x (tensor): input data fed to `backbone` 125 | 126 | Outputs: 127 | - predictions: classifier's predictions 128 | - features: features after `bottleneck` layer and before `head` layer 129 | 130 | Shape: 131 | - Inputs: (minibatch, *) where * means, any number of additional dimensions 132 | - predictions: (minibatch, `num_classes`) 133 | - features: (minibatch, `features_dim`) 134 | 135 | """ 136 | 137 | def __init__( 138 | self, 139 | backbone: nn.Module, 140 | num_classes: int, 141 | bottleneck: Optional[nn.Module] = None, 142 | bottleneck_dim: Optional[int] = -1, 143 | head: Optional[nn.Module] = None, 144 | finetune=True, 145 | pool_layer=None, 146 | ): 147 | super(ClassifierBase, self).__init__() 148 | self.backbone = backbone 149 | self.num_classes = num_classes 150 | if pool_layer is None: 151 | self.pool_layer = nn.Sequential( 152 | nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten() 153 | ) 154 | else: 155 | self.pool_layer = pool_layer 156 | if bottleneck is None: 157 | self.bottleneck = nn.Identity() 158 | self._features_dim = backbone.out_features 159 | else: 160 | self.bottleneck = bottleneck 161 | assert bottleneck_dim > 0 162 | self._features_dim = bottleneck_dim 163 | 164 | if head is None: 165 | self.head = nn.Linear(self._features_dim, num_classes) 166 | else: 167 | self.head = head 168 | self.finetune = finetune 169 | 170 | @property 171 | def features_dim(self) -> int: 172 | """The dimension of features before the final `head` layer""" 173 | return self._features_dim 174 | 175 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 176 | """""" 177 | f = self.pool_layer(self.backbone(x)) 178 | f = self.bottleneck(f) 179 | predictions = self.head(f) 180 | if self.training: 181 | return predictions, f 182 | else: 183 | return predictions 184 | 185 | def get_parameters(self, base_lr=1.0) -> List[Dict]: 186 | """A parameter list which decides optimization hyper-parameters, 187 | such as the relative learning rate of each layer 188 | """ 189 | params = [ 190 | { 191 | "params": self.backbone.parameters(), 192 | "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr, 193 | }, 194 | {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, 195 | {"params": self.head.parameters(), "lr": 1.0 * base_lr}, 196 | ] 197 | 198 | return params 199 | 200 | 201 | class ImageClassifier(ClassifierBase): 202 | def __init__( 203 | self, 204 | backbone: nn.Module, 205 | num_classes: int, 206 | bottleneck_dim: Optional[int] = 256, 207 | **kwargs 208 | ): 209 | bottleneck = nn.Sequential( 210 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)), 211 | # nn.Flatten(), 212 | nn.Linear(backbone.out_features, bottleneck_dim), 213 | nn.BatchNorm1d(bottleneck_dim), 214 | nn.ReLU(), 215 | ) 216 | super(ImageClassifier, self).__init__( 217 | backbone, num_classes, bottleneck, bottleneck_dim, **kwargs 218 | ) 219 | -------------------------------------------------------------------------------- /standard_curriculum_learning/main_standard.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import random
import warnings
import json
import collections
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.nn.parallel
import torchvision.transforms as T
import torch.optim
import torch.utils.data
from torch.utils.data import Subset
from utils import get_dataset, get_model, get_optimizer, get_scheduler
from utils import LossTracker, run_cmd
from torchvision.datasets import CIFAR10
from utils import get_pacing_function, balance_order_val

parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--data-dir', default='dataset',
                    help='path to dataset')
parser.add_argument('--order-dir', default='angular_gap_order.npy',
                    help='path to train val idx')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    help='model architecture (default: resnet18)')
parser.add_argument('--dataset', default='cifar10', type=str,
                    help='dataset')
parser.add_argument('--printfreq', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--workers', default=4, type=int,
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=100, type=int,
                    help='number of total epochs to run')
parser.add_argument('-b', '--batchsize', default=128, type=int,
                    help='mini-batch size (default: 128)')
parser.add_argument('--optimizer', default="sgd", type=str,
                    help='optimizer')
parser.add_argument('--scheduler', default="cosine", type=str,
                    help='lr scheduler')
parser.add_argument('--lr', default=0.1, type=float,
                    help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--seed', default=1111, type=int,
                    help='seed for initializing training')
parser.add_argument('--half', default=False, type=bool,
                    help='training with half precision')
parser.add_argument('--lr_decay', default=0.1, type=float,
                    help='lr decay for milestone scheduler')
# curriculum params
parser.add_argument("--ordering", default="standard", type=str,
                    help="which test case to use. supports: standard")
parser.add_argument('--rand-fraction', default=0., type=float,
                    help='label corruption (default: 0)')
args = parser.parse_args()

def main():
    set_seed(args.seed)
    train_transform = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=(0.247, 0.243, 0.261))
    ])
    test_transform = T.Compose([
        T.Resize(36),
        T.CenterCrop(32),
        T.ToTensor(),
        T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=(0.247, 0.243, 0.261))
    ])
    # the difficulty orders index the 50k-sample CIFAR-10 training split,
    # so the curriculum set must be built with train=True
    tr_set = CIFAR10('./', train=True, download=True, transform=train_transform)

    # initiate a recorder for saving and loading stats and checkpoints
    if 'hsf' in args.order_dir:
        instance_loss = torch.load(os.path.join('./orders', args.order_dir), map_location=torch.device('cpu'))
        order = [k for k, v in sorted(instance_loss.items(), key=lambda it: it[1])]
    elif 'classification_margin' in args.order_dir:
        angular_gap = torch.load(os.path.join('./orders', args.order_dir), map_location=torch.device('cpu')).numpy()
        ordering = collections.defaultdict(list)
        list(map(lambda a, b: ordering[a].append(b), np.arange(len(angular_gap)), angular_gap))
        order = [k for k, v in sorted(ordering.items(), key=lambda item: -1 * item[1][0])]
    elif 'forgetting_events.pkl' in args.order_dir:
        order_dir = os.path.join('./orders', args.order_dir)
        with open(order_dir, 'rb') as f:
            order_dict = pickle.load(f)
        indices = order_dict['indices']
        forget_counts = order_dict['forgetting counts']
        indices_order = {}
        for ind, count in zip(indices, forget_counts):
            indices_order[int(ind)] = count
        order = [k for k, v in sorted(indices_order.items(), key=lambda it: it[1])]  # few forgetting events to many, i.e. easy to hard
    elif 'angular_gap_order' in args.order_dir:
        order = np.load('orders/angular_gap_order.npy')
    elif 'cscore' in args.order_dir:
        instance_loss = torch.load(args.order_dir, map_location=torch.device('cpu'))
        order = [k for k, v in sorted(instance_loss.items(), key=lambda it: torch.mean(torch.cat(it[1], 0)))]
    else:
        print(
            'Please check that the file %s is in your folder ./orders.'
            ' See ./orders/README.md for instructions on how to create the folder.'
            % (args.order_dir)
        )
        raise NotImplementedError
    print('number classes', len(tr_set.classes))
    order, order_val = balance_order_val(order, tr_set, num_classes=len(tr_set.classes), valp=0.0)
    order.extend(order_val)
    print(len(order))

    # check the statistics
    bs = args.batchsize
    N = len(order)
    myiterations = (N//bs+1)*args.epochs

    # initial training
    model = get_model(args.arch, nchannels=3, imsize=32, nclasses=10, args=args)
    optimizer = get_optimizer(args.optimizer, model.parameters(), args.lr, args.momentum, args.wd)
    scheduler = get_scheduler(args.scheduler, optimizer, num_epochs=myiterations)

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "test_loss": [], "test_acc": [], "iter": [0,]}

    trainsets = Subset(tr_set, order)

    # held-out split for evaluation, so train=False here
    val_set = CIFAR10('./', train=False, download=True, transform=test_transform)

    test_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batchsize*2,
                                              shuffle=False, num_workers=args.workers, pin_memory=True)

    train_loader = torch.utils.data.DataLoader(trainsets, batch_size=args.batchsize,
                                               shuffle=False, num_workers=args.workers, pin_memory=True)

    criterion = nn.CrossEntropyLoss().cuda()

    iterations = 0
    history_per_iteration_record = {'train_acc': []}
    os.makedirs('./results', exist_ok=True)  # ensure the output folder exists
    for epoch in range(args.epochs):
        tr_loss, tr_acc1, iterations, history_per_iter = standard_train(train_loader, model, criterion, optimizer, scheduler, epoch, iterations)
        print('epoch', epoch, 'lr', optimizer.param_groups[0]['lr'], 'train_loss', tr_loss, 'train_acc_top1', tr_acc1)
        history_per_iteration_record['train_acc'].extend(history_per_iter['train_acc'])
        test_loss, test_acc1 = standard_validate(test_loader, model, criterion)
        # print ("%s epoch %s iterations w/ LEARNING RATE %s"%(epoch, iterations,optimizer.param_groups[0]["lr"]))
        print('epoch', epoch, 'test_acc_top1', test_acc1)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc1.item())
        history["train_loss"].append(tr_loss)
        history["train_acc"].append(tr_acc1.item())
        history["iter"].append(iterations)
        torch.save(history, "./results/{}_{}_{}.pt".format(args.dataset, args.ordering, args.order_dir[:10]))
        torch.save(history_per_iteration_record, "./results/{}_{}_order{}.pt".format(args.dataset, args.ordering, args.order_dir[:10]))

def standard_train(train_loader, model, criterion, optimizer, scheduler, epoch, iterations):
    # switch to train mode
    model.train()
    history_per_iterations = {'train_acc': []}
    tracker = LossTracker(len(train_loader), f'Epoch: [{epoch}]', args.printfreq)
    for i, (images, target) in enumerate(train_loader):
        iterations += 1
        images, target = cuda_transfer(images, target)
        output = model(images)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tracker.update(loss, output, target)
        history_per_iterations['train_acc'].append(tracker.top1.avg)
        scheduler.step()
    return tracker.losses.avg, tracker.top1.avg, iterations, history_per_iterations

def standard_validate(val_loader, model, criterion):
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        tracker = LossTracker(len(val_loader), f'val', args.printfreq)
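        # Single full pass over the held-out loader; LossTracker maintains the
        # running averages of loss and top-1 accuracy returned below.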
for i, (images, target) in enumerate(val_loader): 189 | images, target = cuda_transfer(images, target) 190 | output = model(images) 191 | loss = criterion(output, target) 192 | tracker.update(loss, output, target) 193 | return tracker.losses.avg, tracker.top1.avg 194 | 195 | def set_seed(seed=None): 196 | if seed is not None: 197 | random.seed(args.seed) 198 | torch.manual_seed(args.seed) 199 | torch.backends.cudnn.deterministic = True 200 | warnings.warn('You have chosen to seed training. ' 201 | 'This will turn on the CUDNN deterministic setting, ' 202 | 'which can slow down your training considerably! ' 203 | 'You may see unexpected behavior when restarting ' 204 | 'from checkpoints.') 205 | 206 | def cuda_transfer(images, target): 207 | images = images.cuda(non_blocking=True) 208 | target = target.cuda(non_blocking=True) 209 | return images, target 210 | 211 | if __name__ == '__main__': 212 | main() 213 | 214 | -------------------------------------------------------------------------------- /DeepDA/dsan.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import configargparse 4 | import data_loader 5 | import os 6 | import torch 7 | import models 8 | import utils 9 | from utils import str2bool 10 | import numpy as np 11 | import random 12 | import time 13 | from utils import TensorboardLogger 14 | from apex import amp 15 | import json 16 | 17 | currtime = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime(time.time())) 18 | os.makedirs("/home/DSANLog/%s" % currtime, exist_ok=True) 19 | log_path = "/home/DSANLog/%s" % currtime 20 | tensorboardLogger = TensorboardLogger(log_path) 21 | 22 | 23 | def get_parser(): 24 | """Get default arguments.""" 25 | parser = configargparse.ArgumentParser( 26 | description="Transfer learning config parser", 27 | config_file_parser_class=configargparse.YAMLConfigFileParser, 28 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter, 29 | ) 30 | # general configuration 31 | parser.add("--config", is_config_file=True, help="config file path") 32 | parser.add("--seed", type=int, default=0) 33 | parser.add_argument("--num_workers", type=int, default=0) 34 | 35 | # network related 36 | parser.add_argument("--backbone", type=str, default="resnet50") 37 | parser.add_argument("--use_bottleneck", type=str2bool, default=True) 38 | 39 | # data loading related 40 | parser.add_argument("--data_dir", type=str, default="/home/data/office31") 41 | parser.add_argument("--src_domain", type=str, default="amazon") 42 | parser.add_argument("--tgt_domain", type=str, default="webcam") 43 | 44 | # training related 45 | parser.add_argument("--batch_size", type=int, default=64) 46 | parser.add_argument("--n_epoch", type=int, default=100) 47 | parser.add_argument("--early_stop", type=int, default=0, help="Early stopping") 48 | parser.add_argument( 49 | "--epoch_based_training", 50 | type=str2bool, 51 | default=False, 52 | help="Epoch-based training / Iteration-based training", 53 | ) 54 | parser.add_argument( 55 | "--n_iter_per_epoch", 56 | type=int, 57 | default=20, 58 | help="Used in Iteration-based training", 59 | ) 60 | 61 | # optimizer related 62 | parser.add_argument("--lr", type=float, default=1e-3) 63 | parser.add_argument("--momentum", type=float, default=0.9) 64 | parser.add_argument("--weight_decay", type=float, default=5e-4) 65 | 66 | # learning rate scheduler related 67 | parser.add_argument("--lr_gamma", type=float, default=0.0003) 68 | parser.add_argument("--lr_decay", type=float, default=0.75) 69 
| parser.add_argument("--lr_scheduler", type=str2bool, default=True) 70 | 71 | # transfer related 72 | parser.add_argument("--transfer_loss_weight", type=float, default=10) 73 | parser.add_argument("--transfer_loss", type=str, default="lmmd") 74 | return parser 75 | 76 | 77 | def set_random_seed(seed=0): 78 | # seed setting 79 | random.seed(seed) 80 | np.random.seed(seed) 81 | torch.manual_seed(seed) 82 | torch.cuda.manual_seed(seed) 83 | torch.backends.cudnn.deterministic = False 84 | torch.backends.cudnn.benchmark = True 85 | 86 | 87 | def load_data(args): 88 | """ 89 | src_domain, tgt_domain data to load 90 | """ 91 | folder_src = os.path.join(args.data_dir, args.src_domain + "/images") 92 | folder_tgt = os.path.join(args.data_dir, args.tgt_domain + "/images") 93 | source_loader, n_class = data_loader.load_data( 94 | folder_src, 95 | args.batch_size, 96 | infinite_data_loader=not args.epoch_based_training, 97 | train=True, 98 | num_workers=args.num_workers, 99 | ) 100 | target_train_loader, _ = data_loader.load_data( 101 | folder_tgt, 102 | args.batch_size, 103 | infinite_data_loader=not args.epoch_based_training, 104 | train=True, 105 | num_workers=args.num_workers, 106 | ) 107 | target_test_loader, _ = data_loader.load_data( 108 | folder_tgt, 109 | args.batch_size, 110 | infinite_data_loader=False, 111 | train=False, 112 | num_workers=args.num_workers, 113 | ) 114 | return source_loader, target_train_loader, target_test_loader, n_class 115 | 116 | 117 | def get_model(args): 118 | model = models.TransferNet( 119 | args.n_class, 120 | transfer_loss=args.transfer_loss, 121 | base_net=args.backbone, 122 | max_iter=args.max_iter, 123 | use_bottleneck=args.use_bottleneck, 124 | ).to(args.device) 125 | return model 126 | 127 | 128 | def get_optimizer(model, args): 129 | initial_lr = args.lr if not args.lr_scheduler else 1.0 130 | params = model.get_parameters(initial_lr=initial_lr) 131 | optimizer = torch.optim.SGD( 132 | params, 133 | lr=args.lr, 134 | momentum=args.momentum, 135 | weight_decay=args.weight_decay, 136 | nesterov=False, 137 | ) 138 | return optimizer 139 | 140 | 141 | def get_scheduler(optimizer, args): 142 | scheduler = torch.optim.lr_scheduler.LambdaLR( 143 | optimizer, 144 | lambda x: args.lr * (1.0 + args.lr_gamma * float(x)) ** (-args.lr_decay), 145 | ) 146 | return scheduler 147 | 148 | 149 | def test(model, target_test_loader, args): 150 | model.eval() 151 | test_loss = utils.AverageMeter() 152 | correct = 0 153 | criterion = torch.nn.CrossEntropyLoss() 154 | len_target_dataset = len(target_test_loader.dataset) 155 | with torch.no_grad(): 156 | for data, target in target_test_loader: 157 | data, target = data.to(args.device), target.to(args.device) 158 | s_output = model.predict(data) 159 | loss = criterion(s_output, target) 160 | test_loss.update(loss.item()) 161 | pred = torch.max(s_output, 1)[1] 162 | correct += torch.sum(pred == target) 163 | acc = 100.0 * correct / len_target_dataset 164 | return acc, test_loss.avg 165 | 166 | 167 | def train( 168 | source_loader, 169 | target_train_loader, 170 | target_test_loader, 171 | model, 172 | optimizer, 173 | lr_scheduler, 174 | args, 175 | ): 176 | len_source_loader = len(source_loader) 177 | len_target_loader = len(target_train_loader) 178 | n_batch = min(len_source_loader, len_target_loader) 179 | if n_batch == 0: 180 | n_batch = args.n_iter_per_epoch 181 | 182 | iter_source, iter_target = iter(source_loader), iter(target_train_loader) 183 | 184 | best_acc = 0 185 | stop = 0 186 | history = 
collections.defaultdict(list)

    for e in range(1, args.n_epoch + 1):
        model.train()
        train_loss_clf = utils.AverageMeter()
        train_loss_transfer = utils.AverageMeter()
        train_loss_total = utils.AverageMeter()
        mmdLog = utils.AverageMeter()
        gk1Log = utils.AverageMeter()
        gk5Log = utils.AverageMeter()

        model.epoch_based_processing(n_batch)

        if max(len_target_loader, len_source_loader) != 0:
            iter_source, iter_target = iter(source_loader), iter(target_train_loader)

        lr_scalar = optimizer.param_groups[0]["lr"]
        for _ in range(n_batch):
            data_source, label_source = next(iter_source)  # .next()
            data_target, _ = next(iter_target)  # .next()
            data_source, label_source = data_source.to(args.device), label_source.to(
                args.device
            )
            data_target = data_target.to(args.device)

            clf_loss, transfer_loss, mmd, gk1, gk5 = model(
                data_source, data_target, label_source
            )
            loss = clf_loss + args.transfer_loss_weight * transfer_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if lr_scheduler:
                lr_scheduler.step()

            train_loss_clf.update(clf_loss.item())
            train_loss_transfer.update(transfer_loss.item())
            train_loss_total.update(loss.item())

            mmdLog.update(mmd.item())
            gk1Log.update(gk1.item())
            gk5Log.update(gk5.item())

        history["mmd"].append(mmdLog.avg)
        history["gk1"].append(gk1Log.avg)
        history["gk5"].append(gk5Log.avg)

        tensorboardLogger.update(step=e, lr=lr_scalar)
        tensorboardLogger.update(step=e, tr_cls_loss=train_loss_clf.avg)
        tensorboardLogger.update(step=e, tr_transfer=train_loss_transfer.avg)
        tensorboardLogger.update(step=e, tr_loss_total=train_loss_total.avg)

        tensorboardLogger.update(step=e, mmd=mmdLog.avg)
        tensorboardLogger.update(step=e, gk1=gk1Log.avg)
        tensorboardLogger.update(step=e, gk5=gk5Log.avg)

        info = "Epoch: [{:2d}/{}], cls_loss: {:.4f}, transfer_loss: {:.4f}, total_Loss: {:.4f}".format(
            e,
            args.n_epoch,
            train_loss_clf.avg,
            train_loss_transfer.avg,
            train_loss_total.avg,
        )
        # Test
        stop += 1
        test_acc, test_loss = test(model, target_test_loader, args)
        test_acc = test_acc.item()  # tensor -> float so it can be logged and JSON-serialized
        info += ", test_loss {:.4f}, test_acc: {:.4f}".format(test_loss, test_acc)
        tensorboardLogger.update(step=e, test_acc=test_acc)
        history["acc"].append(test_acc)
        with open("/home/DSANLog/%s/stat.json" % currtime, "w") as f:
            json.dump(history, f)

        if best_acc < test_acc:
            best_acc = test_acc
            stop = 0
            torch.save(model.state_dict(), "/home/DSANLog/%s/best.pth" % currtime)

        # early stopping
        if args.early_stop > 0 and stop >= args.early_stop:
            print("early stopping")
            print(info)
            break
        print(info)
    print("Transfer result: {:.4f}".format(best_acc))


def main():
    parser = get_parser()
    args = parser.parse_args()
    setattr(
        args, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    print(args)
    set_random_seed(args.seed)
    source_loader, target_train_loader, target_test_loader, n_class = load_data(args)
    setattr(args, "n_class", n_class)
    if args.epoch_based_training:
        setattr(
            args,
            "max_iter",
            args.n_epoch * min(len(source_loader), len(target_train_loader)),
        )
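    # With epoch-based training, max_iter follows the shorter of the two
    # loaders; otherwise the infinite data loaders make len(loader) == 0 and
    # train() falls back to n_iter_per_epoch, so max_iter is set directly.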
else: 290 | setattr(args, "max_iter", args.n_epoch * args.n_iter_per_epoch) 291 | model = get_model(args) 292 | optimizer = get_optimizer(model, args) 293 | 294 | if args.lr_scheduler: 295 | scheduler = get_scheduler(optimizer, args) 296 | else: 297 | scheduler = None 298 | train( 299 | source_loader, 300 | target_train_loader, 301 | target_test_loader, 302 | model, 303 | optimizer, 304 | scheduler, 305 | args, 306 | ) 307 | 308 | 309 | if __name__ == "__main__": 310 | main() 311 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls 4 | import copy 5 | import torch 6 | from torch.hub import load_state_dict_from_url 7 | 8 | __all__ = [ 9 | "ResNet", 10 | "resnet18", 11 | "resnet34", 12 | "resnet50", 13 | "resnet101", 14 | "resnet152", 15 | "resnext50_32x4d", 16 | "resnext101_32x8d", 17 | "wide_resnet50_2", 18 | "wide_resnet101_2", 19 | ] 20 | 21 | model_urls = { 22 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 23 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 24 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 25 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 26 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 27 | "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", 28 | "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", 29 | "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", 30 | "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", 31 | } 32 | 33 | 34 | class ResNet(models.ResNet): 35 | """ResNets without fully connected layer""" 36 | 37 | def __init__(self, *args, **kwargs): 38 | super(ResNet, self).__init__(*args, **kwargs) 39 | self._out_features = self.fc.in_features # get out features dimension 40 | 41 | def forward(self, x): 42 | """""" 43 | x = self.conv1(x) 44 | x = self.bn1(x) 45 | x = self.relu(x) 46 | x = self.maxpool(x) 47 | 48 | x = self.layer1(x) 49 | x = self.layer2(x) 50 | x = self.layer3(x) 51 | x = self.layer4(x) 52 | 53 | x = self.avgpool(x) 54 | x = torch.flatten(x, 1) 55 | # x = x.view(-1, self._out_features) 56 | return x 57 | 58 | @property 59 | def out_features(self) -> int: 60 | """The dimension of output features""" 61 | return self._out_features 62 | 63 | def copy_head(self) -> nn.Module: 64 | """Copy the origin fully connected layer""" 65 | return copy.deepcopy(self.fc) 66 | 67 | 68 | class DropoutBlock(nn.Module): 69 | """ 70 | same as a basic block but adding dropout to it 71 | """ 72 | 73 | def __init__( 74 | self, basic_block: BasicBlock, dropout_rate: float = 0.0, force_dropout=True 75 | ): 76 | super(DropoutBlock, self).__init__() 77 | self.conv1 = basic_block.conv1 78 | self.bn1 = basic_block.bn1 79 | self.relu = basic_block.relu 80 | self.conv2 = basic_block.conv2 81 | self.bn2 = basic_block.bn2 82 | self.downsample = basic_block.downsample 83 | self.stride = basic_block.stride 84 | self.force_dropout = force_dropout 85 | self.dropout_rate = dropout_rate 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | out = 
torch.nn.functional.dropout(
            out, p=self.dropout_rate, training=self.training or self.force_dropout
        )
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = torch.nn.functional.dropout(
            out, p=self.dropout_rate, training=self.training or self.force_dropout
        )

        out += identity
        out = self.relu(out)

        return out


class DropoutResnet(nn.Module):
    """adds dropout to an existing resnet"""

    def __init__(self, source_resnet: ResNet, dropout_rate: float = 0.0):

        super(DropoutResnet, self).__init__()
        self._norm_layer = source_resnet._norm_layer

        self.inplanes = source_resnet.inplanes
        self.dilation = source_resnet.dilation
        self.groups = source_resnet.groups
        self.base_width = source_resnet.base_width
        self.conv1 = source_resnet.conv1
        self.bn1 = source_resnet.bn1
        self.relu = source_resnet.relu
        self.maxpool = source_resnet.maxpool
        self.layer1 = self._make_layer(source_resnet.layer1, dropout_rate)
        self.layer2 = self._make_layer(source_resnet.layer2, dropout_rate)
        self.layer3 = self._make_layer(source_resnet.layer3, dropout_rate)
        self.layer4 = self._make_layer(source_resnet.layer4, dropout_rate)
        self.avgpool = source_resnet.avgpool
        self.fc = source_resnet.fc

    @staticmethod
    def _set_force_dropout_on_layer(force_dropout: bool, layer: nn.Sequential):
        for block in layer.children():
            block.force_dropout = force_dropout

    def set_force_dropout(self, force_dropout):
        self._set_force_dropout_on_layer(force_dropout, self.layer1)
        self._set_force_dropout_on_layer(force_dropout, self.layer2)
        self._set_force_dropout_on_layer(force_dropout, self.layer3)
        self._set_force_dropout_on_layer(force_dropout, self.layer4)

    def _make_layer(self, source_layer: nn.Sequential, dropout_rate):
        return nn.Sequential(
            *[DropoutBlock(block, dropout_rate) for block in source_layer.children()]
        )

    def _forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    # Allow for accessing the forward method in an inherited class
    forward = _forward


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pretrained_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        # remove keys from the pretrained dict that don't appear in the model dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        del pretrained_dict["fc.weight"]
        del pretrained_dict["fc.bias"]
        model.load_state_dict(pretrained_dict, strict=False)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet18",
BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) 194 | 195 | 196 | def resnet34(pretrained=False, progress=True, **kwargs): 197 | r"""ResNet-34 model from 198 | `"Deep Residual Learning for Image Recognition" `_ 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | progress (bool): If True, displays a progress bar of the download to stderr 203 | """ 204 | return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) 205 | 206 | 207 | def resnet50(pretrained=False, progress=True, **kwargs): 208 | r"""ResNet-50 model from 209 | `"Deep Residual Learning for Image Recognition" `_ 210 | 211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | progress (bool): If True, displays a progress bar of the download to stderr 214 | """ 215 | return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 216 | 217 | 218 | def resnet101(pretrained=False, progress=True, **kwargs): 219 | r"""ResNet-101 model from 220 | `"Deep Residual Learning for Image Recognition" `_ 221 | 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | progress (bool): If True, displays a progress bar of the download to stderr 225 | """ 226 | return _resnet( 227 | "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs 228 | ) 229 | 230 | 231 | def resnet152(pretrained=False, progress=True, **kwargs): 232 | r"""ResNet-152 model from 233 | `"Deep Residual Learning for Image Recognition" `_ 234 | 235 | Args: 236 | pretrained (bool): If True, returns a model pre-trained on ImageNet 237 | progress (bool): If True, displays a progress bar of the download to stderr 238 | """ 239 | return _resnet( 240 | "resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs 241 | ) 242 | 243 | 244 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 245 | r"""ResNeXt-50 32x4d model from 246 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | progress (bool): If True, displays a progress bar of the download to stderr 251 | """ 252 | kwargs["groups"] = 32 253 | kwargs["width_per_group"] = 4 254 | return _resnet( 255 | "resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs 256 | ) 257 | 258 | 259 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 260 | r"""ResNeXt-101 32x8d model from 261 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 262 | 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | progress (bool): If True, displays a progress bar of the download to stderr 266 | """ 267 | kwargs["groups"] = 32 268 | kwargs["width_per_group"] = 8 269 | return _resnet( 270 | "resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs 271 | ) 272 | 273 | 274 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 275 | r"""Wide ResNet-50-2 model from 276 | `"Wide Residual Networks" `_ 277 | 278 | The model is the same as ResNet except for the bottleneck number of channels 279 | which is twice larger in every block. The number of channels in outer 1x1 280 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 281 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
282 | 
283 |     Args:
284 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
285 |         progress (bool): If True, displays a progress bar of the download to stderr
286 |     """
287 |     kwargs["width_per_group"] = 64 * 2
288 |     return _resnet(
289 |         "wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs
290 |     )
291 | 
292 | 
293 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
294 |     r"""Wide ResNet-101-2 model from
295 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
296 | 
297 |     The model is the same as ResNet except for the bottleneck number of channels,
298 |     which is twice as large in every block. The number of channels in outer 1x1
299 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
300 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
301 | 
302 |     Args:
303 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
304 |         progress (bool): If True, displays a progress bar of the download to stderr
305 |     """
306 |     kwargs["width_per_group"] = 64 * 2
307 |     return _resnet(
308 |         "wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs
309 |     )
310 | 
--------------------------------------------------------------------------------
/standard_curriculum_learning/main_curriculum_learning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
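# The curriculum loop below grows the training subset with a pacing function
# obtained from utils.get_pacing_function. As a rough illustration of that
# contract (this helper is hypothetical, not the project's implementation),
# a linear pacer maps a training step to a sample budget:
def _example_linear_pacer(step, total_steps, data_size, start_frac=0.1):
    # fraction of the ordered data made available at this step, capped at 1.0
    frac = min(1.0, start_frac + (1.0 - start_frac) * step / float(total_steps))
    return int(frac * data_size)
# e.g. _example_linear_pacer(0, 1000, 50000) -> 5000 samples to start with.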
14 | 
15 | import argparse
16 | import os
17 | import random
18 | import wget
19 | import time
20 | import warnings
21 | import json
22 | import collections
23 | import numpy as np
24 | 
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.parallel
28 | import torch.optim
29 | import torch.utils.data
30 | from torch.utils.data import Subset
31 | 
32 | from utils import get_dataset, get_model, get_optimizer, get_scheduler
33 | from utils import LossTracker, run_cmd
34 | from torch.utils.data import DataLoader
35 | from utils import get_pacing_function, balance_order_val
36 | 
37 | parser = argparse.ArgumentParser(description='PyTorch Training')
38 | parser.add_argument('--data-dir', default='dataset',
39 |                     help='path to dataset')
40 | parser.add_argument('--order-dir', default='cifar10-cscores-orig-order.npz',
41 |                     help='path to train val idx')
42 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
43 |                     help='model architecture (default: resnet50)')
44 | parser.add_argument('--dataset', default='cifar10', type=str,
45 |                     help='dataset')
46 | parser.add_argument('--printfreq', default=10, type=int,
47 |                     help='print frequency (default: 10)')
48 | parser.add_argument('--workers', default=4, type=int,
49 |                     help='number of data loading workers (default: 4)')
50 | parser.add_argument('--epochs', default=100, type=int,
51 |                     help='number of total epochs to run')
52 | parser.add_argument('-b', '--batchsize', default=128, type=int,
53 |                     help='mini-batch size (default: 128); this is the total batch size across all GPUs')
54 | parser.add_argument('--optimizer', default="sgd", type=str,
55 |                     help='optimizer')
56 | parser.add_argument('--scheduler', default="cosine", type=str,
57 |                     help='lr scheduler')
58 | parser.add_argument('--lr', default=0.1, type=float,
59 |                     help='initial learning rate', dest='lr')
60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
61 |                     help='momentum')
62 | parser.add_argument('--wd', default=5e-4, type=float,
63 |                     help='weight decay (default: 5e-4)')
64 | parser.add_argument('--seed', default=None, type=int,
65 |                     help='seed for initializing training')
66 | parser.add_argument('--half', default=False, action='store_true',
67 |                     help='training with half precision')
68 | # curriculum params
69 | parser.add_argument("--pacing-f", default="linear", type=str, help="which pacing function to take")
70 | parser.add_argument('--pacing-a', default=1., type=float,
71 |                     help='pacing function parameter a (default: 1.0)')
72 | parser.add_argument('--pacing-b', default=1., type=float,
73 |                     help='pacing function parameter b (default: 1.0)')
74 | parser.add_argument("--ordering", default="curr", type=str, help="which ordering to use; supports: standard, curr (curriculum), anti_curr and random")
75 | parser.add_argument('--rand-fraction', default=0., type=float,
76 |                     help='fraction of corrupted labels (default: 0)')
77 | args = parser.parse_args()
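# Example invocation (illustrative; the flags are the ones defined above,
# paths and values are placeholders):
#   python main_curriculum_learning.py --dataset cifar10 --arch resnet18 \
#       --ordering curr --pacing-f linear --pacing-a 1.0 --pacing-b 1.0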
78 | def main():
79 |     set_seed(args.seed)
80 |     # create the training and validation datasets and initialize the dataloaders
81 |     tr_set = get_dataset(args.dataset, args.data_dir, 'train', rand_fraction=args.rand_fraction)
82 |     if args.dataset == "cifar100N":
83 |         val_set = get_dataset("cifar100", args.data_dir, 'val')
84 |         tr_set_clean = get_dataset("cifar100", args.data_dir, 'train')
85 |     else:
86 |         val_set = get_dataset(args.dataset, args.data_dir, 'val')
87 |     train_loader = torch.utils.data.DataLoader(tr_set, batch_size=args.batchsize,\
88 |                                                shuffle=True, num_workers=args.workers, pin_memory=True)
89 |     test_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batchsize*2,
90 |                                               shuffle=False, num_workers=args.workers, pin_memory=True)
91 | 
92 |     criterion_ind = nn.CrossEntropyLoss(reduction="none").cuda()
93 |     # initialize a recorder for saving and loading stats and checkpoints
94 |     if 'cscores-orig-order.npz' in args.order_dir:
95 |         temp_path = os.path.join("orders", args.dataset+'-cscores-orig-order.npz')
96 |         if not os.path.isfile(temp_path):
97 |             print('Downloading cifar10-cscores-orig-order.npz and cifar100-cscores-orig-order.npz to the orders folder')
98 |             if 'cifar100' == args.dataset:
99 |                 url = 'https://pluskid.github.io/structural-regularity/cscores/cifar100-cscores-orig-order.npz'
100 |             if 'cifar10' == args.dataset:
101 |                 url = 'https://pluskid.github.io/structural-regularity/cscores/cifar10-cscores-orig-order.npz'
102 |             wget.download(url, './orders')
103 |         temp_x = np.load(temp_path)['scores']
104 |         ordering = collections.defaultdict(list)
105 |         list(map(lambda a, b: ordering[a].append(b), np.arange(len(temp_x)), temp_x))
106 |         order = [k for k, v in sorted(ordering.items(), key=lambda item: -1*item[1][0])]
107 |     else:
108 |         print('Please check that the file %s is in your orders folder. See ./orders/README.md for instructions on how to create it.' % (args.order_dir))
109 |         order = [x for x in list(torch.load(os.path.join("orders", args.order_dir)).keys())]
110 | 
111 |     order, order_val = balance_order_val(order, tr_set, num_classes=len(tr_set.classes))
112 | 
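    # `order` ranks training indices from easy to hard under the chosen difficulty
    # score (for the c-score file above, higher-scoring samples come first);
    # "anti_curr" reverses it into hardest-first and "random" shuffles it,
    # giving the anti-curriculum and random-order baselines.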
113 |     # decide between CL, anti-CL and random ordering
114 |     if args.ordering == "random":
115 |         np.random.shuffle(order)
116 |     elif args.ordering == "anti_curr":
117 |         order = [x for x in reversed(order)]
118 | 
119 |     # check the statistics
120 |     bs = args.batchsize
121 |     N = len(order)
122 |     myiterations = (N//bs+1)*args.epochs
123 | 
124 |     # initialize the model, optimizer and scheduler
125 |     model = get_model(args.arch, tr_set.nchannels, tr_set.imsize, len(tr_set.classes), args.half)
126 |     optimizer = get_optimizer(args.optimizer, model.parameters(), args.lr, args.momentum, args.wd)
127 |     scheduler = get_scheduler(args.scheduler, optimizer, num_epochs=myiterations)
128 | 
129 |     start_epoch = 0
130 |     total_iter = 0
131 |     history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "test_loss": [], "test_acc": [], "iter": [0,] }
132 |     start_time = time.time()
133 | 
134 |     if args.dataset == "cifar100N":
135 |         val_set = Subset(tr_set_clean, order_val)
136 |     else:
137 |         val_set = Subset(tr_set, order_val)
138 |     val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batchsize*2,
139 |                                              shuffle=False, num_workers=args.workers, pin_memory=True)
140 |     trainsets = Subset(tr_set, order)
141 |     train_loader = torch.utils.data.DataLoader(trainsets, batch_size=args.batchsize,
142 |                                                shuffle=True, num_workers=args.workers, pin_memory=True)
143 |     criterion = nn.CrossEntropyLoss().cuda()
144 | 
145 |     if args.ordering == "standard":
146 |         iterations = 0
147 |         for epoch in range(args.epochs):
148 |             tr_loss, tr_acc1, iterations = train(train_loader, model, criterion, optimizer, scheduler, epoch, iterations)
149 |             val_loss, val_acc1 = validate(val_loader, model, criterion)
150 |             test_loss, test_acc1 = validate(test_loader, model, criterion)
151 |             print("epoch %s, %s iterations, learning rate %s" % (epoch, iterations, optimizer.param_groups[0]["lr"]))
152 |             history["test_loss"].append(test_loss)
153 |             history["test_acc"].append(test_acc1)
154 |             history["val_loss"].append(val_loss)
155 |             history["val_acc"].append(val_acc1)
156 |             history["train_loss"].append(tr_loss)
157 |             history["train_acc"].append(tr_acc1)
158 |             history["iter"].append(iterations)
159 |             torch.save(history, "stat.pt")
160 |     else:
161 |         all_sum = N/(myiterations*(myiterations+1)/2)
162 |         iter_per_epoch = N//bs
163 |         pre_iterations = 0
164 |         startIter = 0
165 |         pacing_function = get_pacing_function(myiterations, N, args)
166 | 
167 |         startIter_next = pacing_function(0) # <=======================================
168 |         print('0 iter data between %s and %s w/ Pacing %s' % (startIter, startIter_next, args.pacing_f,))
169 |         trainsets = Subset(tr_set, list(order[startIter:max(startIter_next, 256)]))
170 |         train_loader = torch.utils.data.DataLoader(trainsets, batch_size=args.batchsize,
171 |                                                    shuffle=True, num_workers=args.workers, pin_memory=True)
172 |         dataiter = iter(train_loader)
173 |         step = 0
174 | 
175 |         while step < myiterations:
176 |             tracker = LossTracker(len(train_loader), f'iteration : [{step}]', args.printfreq)
177 |             for images, target in train_loader:
178 |                 step += 1
179 |                 images, target = cuda_transfer(images, target)
180 |                 output = model(images)
181 |                 loss = criterion(output, target)
182 |                 optimizer.zero_grad()
183 |                 loss.backward()
184 |                 optimizer.step()
185 |                 scheduler.step()  # stepped once per iteration here, unlike once per epoch in train() below
186 | 
tracker.update(loss, output, target) 187 | tracker.display(step-pre_iterations) 188 | 189 | #If we hit the end of the dynamic epoch build a new data loader 190 | pre_iterations = step 191 | if startIter_next <= N: 192 | startIter_next = pacing_function(step)# <======================================= 193 | print ("%s iter data between %s and %s w/ Pacing %s and LEARNING RATE %s "%(step,startIter,startIter_next,args.pacing_f, optimizer.param_groups[0]["lr"])) 194 | train_loader = torch.utils.data.DataLoader(Subset(tr_set, list(order[startIter:max(startIter_next,256)])),\ 195 | batch_size=args.batchsize,\ 196 | shuffle=True, num_workers=args.workers, pin_memory=True) 197 | # start your record 198 | if step > 50: 199 | tr_loss, tr_acc1 = tracker.losses.avg, tracker.top1.avg 200 | val_loss, val_acc1 = validate(val_loader, model, criterion) 201 | test_loss, test_acc1 = validate(test_loader, model, criterion) 202 | # record 203 | history["test_loss"].append(test_loss) 204 | history["test_acc"].append(test_acc1) 205 | history["val_loss"].append(val_loss) 206 | history["val_acc"].append(val_acc1) 207 | history["train_loss"].append(tr_loss) 208 | history["train_acc"].append(tr_acc1) 209 | history['iter'].append(step) 210 | torch.save(history,"stat.pt") 211 | # reinitialization<================= 212 | model.train() 213 | 214 | 215 | def train(train_loader, model, criterion, optimizer,scheduler, epoch, iterations): 216 | # switch to train mode 217 | model.train() 218 | tracker = LossTracker(len(train_loader), f'Epoch: [{epoch}]', args.printfreq) 219 | for i, (images, target) in enumerate(train_loader): 220 | iterations += 1 221 | images, target = cuda_transfer(images, target) 222 | output = model(images) 223 | loss = criterion(output, target) 224 | optimizer.zero_grad() 225 | loss.backward() 226 | optimizer.step() 227 | tracker.update(loss, output, target) 228 | tracker.display(i) 229 | scheduler.step() 230 | return tracker.losses.avg, tracker.top1.avg, iterations 231 | 232 | def validate(val_loader, model, criterion): 233 | # switch to evaluate mode 234 | model.eval() 235 | with torch.no_grad(): 236 | tracker = LossTracker(len(val_loader), f'val', args.printfreq) 237 | for i, (images, target) in enumerate(val_loader): 238 | images, target = cuda_transfer(images, target) 239 | output = model(images) 240 | loss = criterion(output, target) 241 | tracker.update(loss, output, target) 242 | tracker.display(i) 243 | return tracker.losses.avg, tracker.top1.avg 244 | 245 | def set_seed(seed=None): 246 | if seed is not None: 247 | random.seed(args.seed) 248 | torch.manual_seed(args.seed) 249 | torch.backends.cudnn.deterministic = True 250 | warnings.warn('You have chosen to seed training. ' 251 | 'This will turn on the CUDNN deterministic setting, ' 252 | 'which can slow down your training considerably! 
'
253 |                       'You may see unexpected behavior when restarting '
254 |                       'from checkpoints.')
255 | 
256 | def cuda_transfer(images, target):
257 |     images = images.cuda(non_blocking=True)
258 |     target = target.cuda(non_blocking=True)
259 |     if args.half: images = images.half()
260 |     return images, target
261 | 
262 | if __name__ == '__main__':
263 |     main()
264 | 
265 | 
--------------------------------------------------------------------------------
/standard_curriculum_learning/prediction_depth/get_pd_vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from knndnn import VGGPD, MLP7, ResNetPD, BasicBlockPD
3 | from torchvision.datasets import CIFAR10
4 | import torchvision.transforms as T
5 | from knndnn import knn_predict
6 | from torch.utils.data import DataLoader, Subset
7 | import torch.nn as nn
8 | import collections
9 | import numpy as np
10 | import json
11 | from torch.cuda.amp import autocast
12 | from torch.optim.lr_scheduler import CosineAnnealingLR
13 | from torchvision.models import vgg16
14 | from sklearn.model_selection import train_test_split
15 | import torch.nn.functional as F
16 | import random
17 | import warnings
18 | import argparse
19 | import os
20 | 
21 | parser = argparse.ArgumentParser(description='arguments to compute prediction depth for each data sample')
22 | parser.add_argument('--train_ratio', default=0.5, type=float, help='ratio of train split / total data split')
23 | parser.add_argument('--result_dir', default='./cl_results_vgg', type=str, help='directory to save ckpt and results')
24 | parser.add_argument('--data', default='cifar10', type=str, help='dataset')
25 | parser.add_argument('--arch', default='vgg', type=str, help='vgg / mlp / resnet')
26 | parser.add_argument('--get_train_pd', default=True, type=bool, help='get prediction depth for the training split')
27 | parser.add_argument('--get_val_pd', default=True, type=bool, help='get prediction depth for the validation split')
28 | parser.add_argument('--resume', default=False, type=bool, help='resume from the ckpt')
29 | parser.add_argument('--fraction', default=0.4, type=float, help='ratio of noise')
30 | parser.add_argument('--half', default=False, type=bool, help='use amp if GPU memory is 15 GB; set to False if GPU memory is 32 GB')
31 | parser.add_argument('--num_epochs', default=80, type=int, help='number of epochs for training')
32 | parser.add_argument('--total_iteration', default=15000, type=int, help='stop training once it exceeds this many iterations')
33 | parser.add_argument('--num_classes', default=10, type=int, help='number of classes')
34 | parser.add_argument('--num_samples', default=10000, type=int, help='number of samples')
35 | parser.add_argument('--knn_k', default=30, type=int, help='k nearest neighbors of the knn classifier')
36 | 
37 | args = parser.parse_args()
38 | 
39 | # hyper parameters
40 | # cifar10 is wrapped so that batches arrive as ((img, label), index)
41 | if args.arch == 'mlp':
42 |     # depth index starts from 0 and ends with max_prediction_depth - 1
43 |     max_prediction_depth = 7
44 | elif args.arch == 'vgg':
45 |     max_prediction_depth = 14
46 | elif args.arch == 'resnet':
47 |     max_prediction_depth = 10
48 | 
49 | lr_init = 0.04
50 | momentum = 0.9
51 | lr_decay = 0.2
52 | if args.arch == 'mlp':
53 |     mile_stones = [1250, 4000, 12000]
54 | elif args.arch == 'vgg':
55 |     mile_stones = [1000, 5000]
56 | elif args.arch == 'resnet':
57 |     mile_stones = [7000]
58 | 
59 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
60 | 
61 | 
62 | def
mile_stone_step(optimizer, curr_iter): 63 | if curr_iter in mile_stones: 64 | for param_gp in optimizer.param_groups: 65 | param_gp['lr'] *= lr_decay 66 | 67 | 68 | def trainer(trainloader, testloader, model, optimizer, num_epochs, criterion, random_sd, flip): 69 | curr_iteration = 0 70 | cos_scheduler = CosineAnnealingLR(optimizer, num_epochs) 71 | history = {'train_loss': [], 'test_loss': [], 'train_acc': [], 'test_acc': []} 72 | for epo in range(num_epochs): 73 | train_acc = 0 74 | train_num_total = 0 75 | for (imgs, labels), idx in trainloader: 76 | curr_iteration += 1 77 | imgs, labels = imgs.cuda(non_blocking=True), labels.cuda(non_blocking=True) 78 | logits = model(imgs, train=True) 79 | loss = criterion(logits, labels) 80 | prds = logits.argmax(1) 81 | train_acc += sum(prds == labels) 82 | train_num_total += imgs.shape[0] 83 | 84 | optimizer.zero_grad() 85 | loss.backward() 86 | optimizer.step() 87 | # mile_stone_step(optimizer, curr_iteration) 88 | cos_scheduler.step() 89 | history['train_loss'].append(loss.item()) 90 | history['train_acc'].append(train_acc.item() / train_num_total) 91 | print('epoch:', epo, 'lr', optimizer.param_groups[0]['lr'], 'loss', loss.item(), 'train_acc', 92 | train_acc.item() / train_num_total) 93 | torch.save(model.state_dict(), os.path.join(args.result_dir, 'ms{}_{}sgd{}_{}.pt'.format(args.arch, args.data, random_sd, flip))) 94 | with torch.no_grad(): 95 | test_acc = 0 96 | test_num_total = 0 97 | for (imgs, labels), idx in testloader: 98 | imgs, labels = imgs.cuda(non_blocking=True), labels.cuda(non_blocking=True) 99 | logits = model(imgs, train=True) 100 | loss = criterion(logits, labels) 101 | prds = logits.argmax(1) 102 | test_acc += sum(prds == labels) 103 | test_num_total += imgs.shape[0] 104 | print('epoch:', epo, 'lr', optimizer.param_groups[0]['lr'], 'loss', loss.item(), 'test_acc', 105 | test_acc.item() / test_num_total) 106 | history['test_loss'].append(loss.item()) 107 | history['test_acc'].append(test_acc.item() / test_num_total) 108 | with open(os.path.join(args.result_dir, 'train_test_history_{}_sd{}_{}.pt'.format(args.arch, seed, flip)), 'w') as f: 109 | json.dump(history, f) 110 | 111 | if curr_iteration >= args.total_iteration: 112 | break 113 | return model 114 | 115 | 116 | def _get_feature_bank_from_kth_layer(model, dataloader, k): 117 | print(k, 'layer feature bank gotten') 118 | with torch.no_grad(): 119 | for (img, all_label), idx in dataloader: 120 | img = img.cuda(non_blocking=True) 121 | all_label = all_label.cuda(non_blocking=True) 122 | if args.half: 123 | with autocast(): 124 | _, fms = model(img, k, train=False) 125 | else: 126 | _, fms = model(img, k, train=False) 127 | return fms, all_label 128 | 129 | 130 | def get_knn_prds_k_layer(model, evaloader, floader, k, train_split=True): 131 | knn_labels_all = [] 132 | knn_conf_gt_all = [] # This statistics can be noisy 133 | indices_all = [] 134 | f_bank, all_labels = _get_feature_bank_from_kth_layer(model, floader, k) 135 | f_bank = f_bank.t().contiguous() 136 | with torch.no_grad(): 137 | for j, ((imgs, labels), idx) in enumerate(evaloader): 138 | imgs = imgs.cuda(non_blocking=True) 139 | labels_b = labels.cuda(non_blocking=True) 140 | nm_cls = 10 141 | if args.half: 142 | with autocast(): 143 | _, inp_f_curr = model(imgs, k, train=False) 144 | else: 145 | _, inp_f_curr = model(imgs, k, train=False) 146 | knn_scores = knn_predict(inp_f_curr, f_bank, all_labels, classes=nm_cls, knn_k=args.knn_k, knn_t=1, rm_top1=train_split) # B x C 147 | knn_probs = F.normalize(knn_scores, 
p=1, dim=1) 148 | knn_labels_prd = knn_probs.argmax(1) 149 | knn_conf_gt = knn_probs.gather(dim=1, index=labels_b[:, None]) # B x 1 150 | knn_labels_all.append(knn_labels_prd) 151 | knn_conf_gt_all.append(knn_conf_gt) 152 | indices_all.append(idx) 153 | knn_labels_all = torch.cat(knn_labels_all, dim=0) # N x 1 154 | knn_conf_gt_all = torch.cat(knn_conf_gt_all, dim=0).squeeze() 155 | indices_all = np.concatenate(indices_all, 0) 156 | return knn_labels_all, knn_conf_gt_all, indices_all 157 | 158 | 159 | def _get_prediction_depth(knn_labels_all): 160 | """ 161 | get prediction depth for a sample. reverse knn labels list and increase the counter until the label is different 162 | :param knn_labels_all: 163 | :return: 164 | """ 165 | pd = 0 166 | knn_labels_all = list(reversed(knn_labels_all)) 167 | while knn_labels_all[pd] == knn_labels_all[0] and pd <= max_prediction_depth - 2: 168 | pd += 1 169 | return max_prediction_depth - pd 170 | 171 | def set_seed(seed=1234): 172 | if seed is not None: 173 | random.seed(seed) 174 | np.random.seed(seed) 175 | torch.manual_seed(seed) 176 | # torch.backends.cudnn.deterministic = True 177 | warnings.warn('You have chosen to seed training. ' 178 | 'This will turn on the CUDNN deterministic setting, ' 179 | 'which can slow down your training considerably! ' 180 | 'You may see unexpected behavior when restarting ' 181 | 'from checkpoints.') 182 | 183 | def main(train_idx, val_idx, random_seed=1234, flip=''): 184 | # for simplicity, we do not use data augmentation when measuring difficulty 185 | # CIFAR10 w / 40% (Fixed) Randomized Labels 186 | # only the training dataset is shuffle. Datasets for prediction depth and testing remains the same as cifar10 original 187 | train_transform = T.Compose([ 188 | T.RandomCrop(32, padding=4), 189 | T.RandomHorizontalFlip(), 190 | T.ToTensor(), 191 | T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=(0.247, 0.243, 0.261)) 192 | ]) 193 | test_transform = T.Compose([T.ToTensor(), 194 | T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=(0.247, 0.243, 0.261)) 195 | ]) 196 | if args.data == 'cifar10': 197 | trainset = CIFAR10('./', transform=train_transform, train=False, download=True) 198 | testset = CIFAR10('./', transform=test_transform, train=True, download=True) 199 | else: 200 | raise NotImplementedError 201 | 202 | train_split = Subset(trainset, train_idx) 203 | supportset = train_split 204 | val_split = Subset(trainset, val_idx) 205 | trainloader = DataLoader(train_split, batch_size=128, shuffle=True, num_workers=2, pin_memory=True) 206 | testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2, pin_memory=True) 207 | 208 | supportloader = DataLoader(supportset, batch_size=len(supportset), shuffle=False, num_workers=1, pin_memory=True) 209 | if args.get_train_pd: 210 | # pd (train) data order follows train_indices 211 | evaluate_loader_train = DataLoader(train_split, batch_size=200, shuffle=False, num_workers=1, pin_memory=True) 212 | if args.get_val_pd: 213 | # pd (val) data order follows val_indices 214 | evaluate_loader_test = DataLoader(val_split, batch_size=200, shuffle=False, num_workers=1, pin_memory=True) 215 | 216 | if args.arch == 'mlp': 217 | model = MLP7(args.num_classes) 218 | elif args.arch == 'vgg': 219 | ecd = vgg16().features 220 | model = VGGPD(ecd, args.num_classes) 221 | elif args.arch == 'resnet': 222 | model = ResNetPD(BasicBlockPD, [2, 2, 2, 2], temp=1.0) 223 | else: 224 | raise NotImplementedError 225 | 226 | model = model.to(device) 227 | criterion = nn.CrossEntropyLoss() 228 | 
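    # Worked example for _get_prediction_depth above (hypothetical kNN trace):
    # with max_prediction_depth = 14 and per-layer kNN labels
    # [2, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], the trailing run of twelve
    # consistent labels counts 12 agreeing layers, so the returned depth is
    # 14 - 12 = 2: the kNN prediction first settles at layer index 2.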
229 | optimizer = torch.optim.SGD(model.parameters(), lr=lr_init, momentum=momentum) 230 | if not args.resume: 231 | model = trainer(trainloader, testloader, model, optimizer, args.num_epochs, criterion, random_seed, flip) 232 | else: 233 | print('loading model from ckpt') 234 | model.load_state_dict(torch.load(os.path.join(args.result_dir, 'ms{}_{}sgd{}_{}.pt'.format(args.arch, args.data, random_seed, flip)))) 235 | 236 | if args.get_train_pd: 237 | # TODO exclude current batch from support set 238 | index_knn_y = collections.defaultdict(list) 239 | index_pd = collections.defaultdict(list) 240 | knn_gt_conf_all = collections.defaultdict(list) 241 | for k in range(max_prediction_depth): 242 | knn_labels, knn_conf_gt_all, indices_all = get_knn_prds_k_layer(model, evaluate_loader_train, supportloader, 243 | k, train_split=args.get_train_pd) 244 | for idx, knn_l, knn_conf_gt in zip(indices_all, knn_labels, knn_conf_gt_all): 245 | index_knn_y[int(idx)].append(knn_l.item()) 246 | knn_gt_conf_all[int(idx)].append(knn_conf_gt.item()) 247 | for idx, knn_ls in index_knn_y.items(): 248 | index_pd[idx].append(_get_prediction_depth(knn_ls)) 249 | 250 | print(len(index_pd), len(index_knn_y), len(knn_gt_conf_all)) 251 | with open(os.path.join(args.result_dir, 'ms{}train_seed{}_f{}_trainpd.pkl'.format(args.arch, random_seed, flip)), 'w') as f: 252 | json.dump(index_pd, f) 253 | 254 | if args.get_val_pd: 255 | index_knn_y = collections.defaultdict(list) 256 | index_pd = collections.defaultdict(list) 257 | knn_gt_conf_all = collections.defaultdict(list) 258 | for k in range(max_prediction_depth): 259 | knn_labels, knn_conf_gt_all, indices_all = get_knn_prds_k_layer(model, evaluate_loader_test, supportloader, 260 | k, train_split=not(args.get_val_pd)) 261 | for idx, knn_l, knn_conf_gt in zip(indices_all, knn_labels, knn_conf_gt_all): 262 | index_knn_y[int(idx)].append(knn_l.item()) 263 | knn_gt_conf_all[int(idx)].append(knn_conf_gt.item()) 264 | for idx, knn_ls in index_knn_y.items(): 265 | index_pd[idx].append(_get_prediction_depth(knn_ls)) 266 | 267 | print(len(index_pd), len(index_knn_y), len(knn_gt_conf_all)) 268 | with open(os.path.join(args.result_dir, 'ms{}_seed{}_f{}_test_pd.pkl'.format(args.arch, random_seed, flip)), 'w') as f: 269 | json.dump(index_pd, f) 270 | 271 | 272 | if __name__ == '__main__': 273 | seeds = [9203, 9304, 9837, 9612, 3456, 5210] 274 | for seed in seeds: 275 | set_seed(seed) 276 | train_indices, val_indices = train_test_split(np.arange(args.num_samples), train_size=args.train_ratio, 277 | test_size=(1 - args.train_ratio)) # split the data 278 | main(train_indices, val_indices, random_seed=seed, flip='') 279 | main(val_indices, train_indices, random_seed=seed, flip='flip') 280 | -------------------------------------------------------------------------------- /standard_curriculum_learning/prediction_depth/knndnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class VGGPD(nn.Module): 6 | def __init__( 7 | self, 8 | encoder=None, 9 | num_classes=100 10 | ): 11 | super(VGGPD, self).__init__() 12 | self.encoder = encoder 13 | self.classifier = nn.Sequential(nn.Flatten(), 14 | nn.Linear(512, num_classes)) 15 | 16 | def forward(self, x, k=0, train=True): 17 | """ 18 | 19 | :param x: 20 | :param k: output fms from the kth conv2d or the last layer 21 | :return: 22 | """ 23 | n_layer = 0 24 | _fm = None 25 | for m in self.encoder.children(): 26 | x = m(x) 27 | if not 
train:
28 |                 if isinstance(m, nn.Conv2d):
29 |                     if n_layer == k:
30 |                         return None, x.view(x.shape[0], -1)  # B x (C x F x F)
31 |                     n_layer += 1
32 |         logits = self.classifier(x)
33 |         if not train:
34 |             if k == n_layer:
35 |                 _fm = torch.softmax(logits, 1)
36 |                 return None, _fm.view(_fm.shape[0], -1)  # B x C
37 |         else:
38 |             return logits
39 | 
40 | 
41 | class MLP7(nn.Module):
42 |     def __init__(self, num_classes=10):
43 |         super(MLP7, self).__init__()
44 |         test_in = torch.randn(1, 3, 32, 32).view(1, -1)
45 |         self.fl = nn.Flatten()
46 |         self.d1 = nn.Linear(test_in.shape[1], 2048)
47 |         self.d2 = nn.Linear(2048, 2048)
48 |         self.d3 = nn.Linear(2048, 2048)
49 |         self.d4 = nn.Linear(2048, 2048)
50 |         self.d5 = nn.Linear(2048, 2048)
51 |         self.d6 = nn.Linear(2048, 2048)
52 |         self.d7 = nn.Linear(2048, num_classes)
53 | 
54 |     def forward(self, x, k=0, train=True):
55 |         representations = []
56 |         f1 = self.d1(self.fl(x))
57 |         representations.append(f1)  # B x F
58 |         f2 = self.d2(torch.relu_(f1))
59 |         representations.append(f2)
60 |         f3 = self.d3(torch.relu_(f2))
61 |         representations.append(f3)
62 |         f4 = self.d4(torch.relu_(f3))
63 |         representations.append(f4)
64 |         f5 = self.d5(torch.relu_(f4))
65 |         representations.append(f5)
66 |         f6 = self.d6(torch.relu_(f5))
67 |         representations.append(f6)
68 |         logits = self.d7(torch.relu_(f6))
69 | 
70 |         # the last representation is added after softmax
71 |         f7 = torch.softmax(logits, dim=1)
72 |         representations.append(f7)
73 |         if train:
74 |             return logits
75 |         else:
76 |             return None, representations[k]
77 | 
78 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t, rm_top1=True, dist='l2'):
79 |     # feature: [B, D] query features; feature_bank: [D, N], already transposed by the caller
80 |     B, D = feature.shape
81 |     D, N = feature_bank.shape
82 |     if dist == 'cosine':
83 |         feature = F.normalize(feature, dim=1, p=2.0)
84 |         feature_bank = F.normalize(feature_bank, dim=0, p=2.0)  # normalize the feature dim
85 |         sim_matrix = torch.mm(feature, feature_bank)  # similarity ---> [B, N]
86 |     elif dist == 'l2':
87 |         # negative squared euclidean distance, so that topk returns the nearest points
88 |         sim_matrix = -torch.cdist(feature, feature_bank.t().contiguous(), p=2.0).pow(2)
89 |         # similarity ---> [B, N]
90 |     else:
91 |         raise NotImplementedError
92 | 
93 |     # [B, K]
94 |     if rm_top1:
95 |         sim_weight_add_one, sim_indices_add_one = sim_matrix.topk(k=(knn_k + 1), dim=-1)
96 |         sim_weight, sim_indices = sim_weight_add_one[:, 1:], sim_indices_add_one[:, 1:]  # remove the nearest pt, i.e. the currently evaluated pt itself when it lies in the train split
97 |     else:
98 |         sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
99 |     # [B, K] labels of the selected neighbours, gathered from the feature bank along dim 1
100 |     sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
101 | 
102 |     sim_weight = (sim_weight / knn_t).exp()
103 | 
104 |     # counts for each class
105 |     one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
106 |     # [B*K, C]
107 |     one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
108 |     # weighted score ---> [B, C]
109 |     pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
110 |     # pred_prob = F.normalize(pred_scores, p=1, dim=1)
111 |     # pred_labels = pred_scores.argsort(dim=-1, descending=True)  # rank the knn labels
112 |     return pred_scores
113 | 
114 | class BasicBlockPD(nn.Module):
115 |     expansion = 1
116 | 
117 |     def __init__(self, in_planes, planes, stride=1):
118 |         super(BasicBlockPD,
self).__init__() 119 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 120 | self.bn1 = nn.BatchNorm2d(planes) 121 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 122 | self.bn2 = nn.BatchNorm2d(planes) 123 | 124 | self.shortcut = nn.Sequential() 125 | if stride != 1 or in_planes != self.expansion*planes: 126 | self.shortcut = nn.Sequential( 127 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(self.expansion*planes) 129 | ) 130 | 131 | def forward(self, x, train=True): 132 | out = F.relu_(self.bn1(self.conv1(x))) 133 | out = self.bn2(self.conv2(out)) 134 | out += self.shortcut(x) 135 | if not train: 136 | return None, out 137 | else: 138 | out = F.relu(out) 139 | return out 140 | 141 | 142 | class ResNetPD(nn.Module): 143 | def __init__(self, block, num_blocks, num_classes=10, temp=1.0): 144 | super(ResNetPD, self).__init__() 145 | self.in_planes = 64 146 | 147 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 148 | self.bn1 = nn.BatchNorm2d(64) 149 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 150 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 151 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 152 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 153 | self.fc = nn.Linear(512*block.expansion, num_classes) 154 | self.temp = temp 155 | 156 | def _make_layer(self, block, planes, num_blocks, stride): 157 | strides = [stride] + [1]*(num_blocks-1) 158 | layers = [] 159 | for stride in strides: 160 | layers.append(block(self.in_planes, planes, stride)) 161 | self.in_planes = planes * block.expansion 162 | return nn.Sequential(*layers) 163 | 164 | def forward(self, x, k=0, train=True): 165 | ''' 166 | 167 | :param x: 168 | :param k: 169 | :param train: switch model to test and extract the FMs of the kth layer 170 | :return: 171 | ''' 172 | i = 0 173 | out = self.bn1(self.conv1(x)) 174 | if k==i and not(train): 175 | return None, out.view(out.shape[0], -1) 176 | out = torch.relu_(out) 177 | i +=1 178 | for module in self.layer1: 179 | if k ==i and not(train): 180 | _, out = module(out, train=False) # take the output of ResBlock before relu 181 | return None, out.view(out.shape[0], -1) 182 | else: 183 | out = module(out) 184 | out = torch.relu_(out) 185 | i+=1 186 | 187 | for module in self.layer2: 188 | if k ==i and not(train): 189 | _, out = module(out, train=False) # take the output of ResBlock before relu 190 | return None, out.view(out.shape[0], -1) 191 | else: 192 | out = module(out) 193 | out = torch.relu_(out) 194 | i+=1 195 | for module in self.layer3: 196 | if k ==i and not(train): 197 | _, out = module(out, train=False) # take the output of ResBlock before relu 198 | return None, out.view(out.shape[0], -1) 199 | else: 200 | out = module(out) 201 | out = torch.relu_(out) 202 | i+=1 203 | for module in self.layer4: 204 | if k ==i and not(train): 205 | _, out = module(out, train=False) # take the output of ResBlock before relu 206 | return None, out.view(out.shape[0], -1) 207 | else: 208 | out = module(out) 209 | out = torch.relu_(out) 210 | i+=1 211 | out = F.avg_pool2d(out, 4) 212 | out = out.view(out.size(0), -1) 213 | out = self.fc(out) / self.temp 214 | if k == i and not (train): 215 | _f = F.softmax(out, 1) # take the output of softmax 216 | return None, _f 217 | else: 218 | return out 219 | 220 | class Conv2d(nn.Conv2d): 
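    """Conv2d with Weight Standardization: in forward(), the kernel is
    mean-centred and divided by its per-filter standard deviation before the
    convolution; it is paired with GroupNorm in the ResNetWS variant below."""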
221 | 222 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 223 | padding=0, dilation=1, groups=1, bias=True): 224 | super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, 225 | padding, dilation, groups, bias) 226 | 227 | def forward(self, x): 228 | weight = self.weight 229 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, 230 | keepdim=True).mean(dim=3, keepdim=True) 231 | weight = weight - weight_mean 232 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 233 | weight = weight / std.expand_as(weight) 234 | return F.conv2d(x, weight, self.bias, self.stride, 235 | self.padding, self.dilation, self.groups) 236 | 237 | 238 | class BasicBlockWS(nn.Module): 239 | expansion = 1 240 | 241 | def __init__(self, in_planes, planes, stride=1): 242 | super(BasicBlockWS, self).__init__() 243 | self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 244 | self.gn1 = nn.GroupNorm(1, planes) 245 | self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 246 | self.gn2 = nn.GroupNorm(1, planes) 247 | 248 | self.shortcut = nn.Sequential() 249 | if stride != 1 or in_planes != self.expansion*planes: 250 | self.shortcut = nn.Sequential( 251 | Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 252 | nn.GroupNorm(1, self.expansion*planes) 253 | ) 254 | 255 | def forward(self, x, train=True): 256 | out = F.relu_(self.gn1(self.conv1(x))) 257 | out = self.gn2(self.conv2(out)) 258 | out += self.shortcut(x) 259 | if not train: 260 | return None, out 261 | else: 262 | out = F.relu(out) 263 | return out 264 | 265 | 266 | class ResNetWS(nn.Module): 267 | ''' 268 | We use Conv2d (weight standardization) to replace nn.Conv2d and Group norm to replace BN2d 269 | ''' 270 | def __init__(self, block, num_blocks, num_classes=10, temp=1.0): 271 | super(ResNetWS, self).__init__() 272 | self.in_planes = 64 273 | 274 | self.conv1 = Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 275 | self.gn1 = nn.GroupNorm(1, 64) 276 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 277 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 278 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 279 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 280 | self.fc = nn.Linear(512*block.expansion, num_classes) 281 | self.temp = temp 282 | 283 | def _make_layer(self, block, planes, num_blocks, stride): 284 | strides = [stride] + [1]*(num_blocks-1) 285 | layers = [] 286 | for stride in strides: 287 | layers.append(block(self.in_planes, planes, stride)) 288 | self.in_planes = planes * block.expansion 289 | return nn.Sequential(*layers) 290 | 291 | def forward(self, x, k=0, train=True): 292 | ''' 293 | 294 | :param x: 295 | :param k: 296 | :param train: switch model to test and extract the FMs of the kth layer 297 | :return: 298 | ''' 299 | i = 0 300 | out = self.gn1(self.conv1(x)) 301 | if k==i and not(train): 302 | return None, out.view(out.shape[0], -1) 303 | out = torch.relu_(out) 304 | i +=1 305 | for module in self.layer1: 306 | if k ==i and not(train): 307 | _, out = module(out, train=False) # take the output of ResBlock before relu 308 | return None, out.view(out.shape[0], -1) 309 | else: 310 | out = module(out) 311 | out = torch.relu_(out) 312 | i+=1 313 | 314 | for module in self.layer2: 315 | if k ==i and not(train): 316 | _, out = module(out, train=False) # take the output of 
ResBlock before relu 317 | return None, out.view(out.shape[0], -1) 318 | else: 319 | out = module(out) 320 | out = torch.relu_(out) 321 | i+=1 322 | for module in self.layer3: 323 | if k ==i and not(train): 324 | _, out = module(out, train=False) # take the output of ResBlock before relu 325 | return None, out.view(out.shape[0], -1) 326 | else: 327 | out = module(out) 328 | out = torch.relu_(out) 329 | i+=1 330 | for module in self.layer4: 331 | if k ==i and not(train): 332 | _, out = module(out, train=False) # take the output of ResBlock before relu 333 | return None, out.view(out.shape[0], -1) 334 | else: 335 | out = module(out) 336 | out = torch.relu_(out) 337 | i+=1 338 | out = F.avg_pool2d(out, 4) 339 | out = out.view(out.size(0), -1) 340 | out = self.fc(out) / self.temp 341 | if k == i and not (train): 342 | _f = F.softmax(out, 1) # take the output of softmax 343 | return None, _f 344 | else: 345 | return out 346 | -------------------------------------------------------------------------------- /standard_curriculum_learning/utils/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import collections 3 | import torch 4 | from torch import Tensor 5 | import torch.nn as nn 6 | import torch.backends.cudnn as cudnn 7 | import subprocess 8 | import os 9 | import time 10 | import shutil 11 | from datetime import datetime 12 | import torch.optim as optim 13 | from torch.optim import lr_scheduler 14 | import sys 15 | sys.path.append("..") 16 | import torchvision.models as models 17 | import numpy as np 18 | import math 19 | import random 20 | 21 | def run_cmd(cmd_str, prev_sp=None): 22 | """ 23 | This function runs the linux command cmr_str as a subprocess after waiting 24 | for prev_sp subprocess to finish 25 | """ 26 | if prev_sp is not None: 27 | prev_sp.wait() 28 | return subprocess.Popen(cmd_str, shell=True)#, stdout=open(os.devnull, 'w'), stderr=open(os.devnull, 'w')) 29 | 30 | def get_model(model_name, nchannels=3, imsize=32, nclasses=10, args=None): 31 | 32 | print("=> creating model '{}'".format(model_name)) 33 | if imsize < 128 and model_name in models.__dict__: 34 | model = models.__dict__[model_name](num_classes=nclasses) 35 | print(model) 36 | else: 37 | raise NotImplementedError 38 | model = model.cuda() 39 | cudnn.benchmark = True 40 | return model 41 | 42 | def get_optimizer(optimizer_name, parameters, lr, momentum=0, weight_decay=0): 43 | if optimizer_name == 'sgd': 44 | return optim.SGD(parameters, lr, momentum=momentum, weight_decay=weight_decay) 45 | elif optimizer_name == 'nesterov_sgd': 46 | return optim.SGD(parameters, lr, momentum=momentum, weight_decay=weight_decay, nesterov=True) 47 | elif optimizer_name == 'rmsprop': 48 | return optim.RMSprop(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 49 | elif optimizer_name == 'adagrad': 50 | return optim.Adagrad(parameters, lr=lr, weight_decay=weight_decay) 51 | elif optimizer_name == 'adam': 52 | return optim.Adam(parameters, lr=lr, weight_decay=weight_decay) 53 | 54 | def get_scheduler(scheduler_name, optimizer, num_epochs, **kwargs): 55 | if scheduler_name == 'constant': 56 | return lr_scheduler.StepLR(optimizer, num_epochs, gamma=1, **kwargs) 57 | 58 | elif scheduler_name == 'step2': 59 | return lr_scheduler.StepLR(optimizer, round(num_epochs / 2), gamma=0.1, **kwargs) 60 | elif scheduler_name == 'step3': 61 | return lr_scheduler.StepLR(optimizer, round(num_epochs / 3), gamma=0.1, **kwargs) 62 | elif scheduler_name == 
'exponential': 63 | return lr_scheduler.ExponentialLR(optimizer, (1e-3) ** (1 / num_epochs), **kwargs) 64 | elif scheduler_name == 'cosine': 65 | return lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, **kwargs) 66 | elif scheduler_name == 'step-more': 67 | return lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2, **kwargs) 68 | 69 | 70 | def run_cmd(cmd_str, prev_sp=None): 71 | """ 72 | This function runs the linux command cmr_str as a subprocess after waiting 73 | for prev_sp subprocess to finish 74 | """ 75 | if prev_sp is not None: 76 | prev_sp.wait() 77 | return subprocess.Popen(cmd_str, shell=True)#, stdout=open(os.devnull, 'w'), stderr=open(os.devnull, 'w')) 78 | 79 | 80 | class LossTracker(object): 81 | def __init__(self, num, prefix='', print_freq=1): 82 | self.print_freq=print_freq 83 | self.batch_time = AverageMeter('Time', ':6.3f') 84 | self.losses = AverageMeter('Loss', ':.4e') 85 | self.top1 = AverageMeter('Acc@1', ':6.2f') 86 | self.top5 = AverageMeter('Acc@5', ':6.2f') 87 | self.progress = ProgressMeter( num, [self.batch_time, self.losses, self.top1, self.top5], prefix=prefix) 88 | self.end = time.time() 89 | 90 | def update(self, loss, output, target): 91 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 92 | self.losses.update(loss.item(), output.size(0)) 93 | self.top1.update(acc1[0], output.size(0)) 94 | self.top5.update(acc5[0], output.size(0)) 95 | 96 | def display(self, step): 97 | self.batch_time.update(time.time() - self.end) 98 | self.end = time.time() 99 | if step % self.print_freq == 0: 100 | self.progress.display(step) 101 | 102 | class AverageMeter(object): 103 | """Computes and stores the average and current value""" 104 | def __init__(self, name, fmt=':f'): 105 | self.name = name 106 | self.fmt = fmt 107 | self.reset() 108 | 109 | def reset(self): 110 | self.val = 0 111 | self.avg = 0 112 | self.sum = 0 113 | self.count = 0 114 | 115 | def update(self, val, n=1): 116 | self.val = val 117 | self.sum += val * n 118 | self.count += n 119 | self.avg = self.sum / self.count 120 | 121 | 122 | def __str__(self): 123 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 124 | return fmtstr.format(**self.__dict__) 125 | 126 | 127 | class ProgressMeter(object): 128 | def __init__(self, num_batches, meters, prefix=""): 129 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 130 | self.meters = meters 131 | self.prefix = prefix 132 | 133 | def display(self, batch): 134 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 135 | entries += [str(meter) for meter in self.meters] 136 | print('\t'.join(entries), flush=True) 137 | 138 | def _get_batch_fmtstr(self, num_batches): 139 | num_digits = len(str(num_batches // 1)) 140 | fmt = '{:' + str(num_digits) + 'd}' 141 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 142 | 143 | 144 | def accuracy(output, target, topk=(1,)): 145 | """Computes the accuracy over the k top predictions for the specified values of k""" 146 | with torch.no_grad(): 147 | maxk = max(topk) 148 | batch_size = target.size(0) 149 | 150 | _, pred = output.topk(maxk, 1, True, True) 151 | pred = pred.t() 152 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 153 | 154 | res = [] 155 | for k in topk: 156 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 157 | res.append(correct_k.mul_(100.0 / batch_size)) 158 | return res 159 | 160 | def balance_order(order, dataset, num_classes=10): 161 | 162 | class_orders = collections.defaultdict(list) 163 | for i in 
range(len(order)): 164 | class_orders[dataset.targets[order[i]]].append(i) 165 | length = [] 166 | for cls in range(num_classes): 167 | length.append(len(class_orders[cls])) 168 | new_order = [] 169 | 170 | for group_idx in range(min(length)): 171 | group = sorted([class_orders[cls][group_idx] for cls in range(num_classes)]) 172 | new_order.extend([order[idx] for idx in group]) 173 | 174 | for group_idx in range(min(length), max(length)): 175 | cls_idx = [cls for cls in range(num_classes) if group_idx= args.total_iteration: 116 | break 117 | return model 118 | 119 | 120 | def _get_feature_bank_from_kth_layer(model, dataloader, k): 121 | """ 122 | Get the FMs of the kth layer of current model for all data points in dataloader 123 | :param model: ResNet18(10 layers) / VGG16 (14 layers) 124 | :param dataloader: support set 125 | :param k: k th layer's output from the current model 126 | :return: FMs, labels of the support set 127 | """ 128 | print(k, 'layer feature bank gotten') 129 | fms_list = [] 130 | all_label_list = [] 131 | with torch.no_grad(): 132 | for (img, all_label), idx in dataloader: 133 | img = img.to(device) 134 | all_label = all_label.to(device) 135 | if args.half: 136 | with autocast(): 137 | _, fms = model(img, k, train=False) 138 | fms_list.append(fms) 139 | else: 140 | _, fms = model(img, k, train=False) 141 | fms_list.append(fms) 142 | all_label_list.append(all_label) 143 | fms = torch.cat(fms_list, 0) 144 | all_label = torch.cat(all_label_list, 0) 145 | return fms, all_label 146 | 147 | 148 | def get_knn_prds_k_layer(model, evaloader, floader, k, train_split=True): 149 | """ 150 | 151 | :param model: 152 | :param evaloader: data split we want to evaluate (train split/ test split) 153 | :param floader: support set loader 154 | :param k: k th layer 155 | :param train_split: if it is the train split, remove the label of the current evaluating point 156 | :return: lists of labels from knn classifier, confidence scores and index of each instance 157 | """ 158 | knn_labels_all = [] 159 | knn_conf_gt_all = [] # This statistics can be noisy due to different temperature 160 | indices_all = [] 161 | f_bank, all_labels = _get_feature_bank_from_kth_layer(model, floader, k) 162 | f_bank = f_bank.t().contiguous() 163 | warnings.warn('temperature will affect predictions when using multiple splits as evaluations') 164 | 165 | with torch.no_grad(): 166 | for j, ((imgs, labels), idx) in enumerate(evaloader): 167 | imgs = imgs.cuda(non_blocking=True) 168 | labels_b = labels.cuda(non_blocking=True) 169 | nm_cls = 10 170 | _, inp_f_curr = model(imgs, k, train=False) 171 | knn_scores = torch.zeros(imgs.shape[0], nm_cls).cuda() 172 | knn_scores += knn_predict(inp_f_curr, f_bank, all_labels, classes=nm_cls, knn_k=(args.knn_k), knn_t=1, rm_top1=train_split) # B x C 173 | knn_probs = F.normalize(knn_scores, p=1, dim=1) 174 | knn_labels_prd = knn_probs.argmax(1) 175 | knn_conf_gt = knn_probs.gather(dim=1, index=labels_b[:, None]) # B x 1 176 | knn_labels_all.append(knn_labels_prd) 177 | knn_conf_gt_all.append(knn_conf_gt) 178 | indices_all.append(idx) 179 | knn_labels_all = torch.cat(knn_labels_all, dim=0) # N x 1 180 | knn_conf_gt_all = torch.cat(knn_conf_gt_all, dim=0).squeeze() 181 | indices_all = np.concatenate(indices_all, 0) 182 | del f_bank, all_labels, inp_f_curr 183 | return knn_labels_all, knn_conf_gt_all, indices_all 184 | 185 | 186 | def _get_prediction_depth(knn_labels_all): 187 | """ 188 | get prediction depth for a sample. 
reverse knn labels list and increase the counter until the label is different 189 | :param knn_labels_all: 190 | :return: 191 | """ 192 | num_consistent = 0 193 | knn_labels_all = list(reversed(knn_labels_all)) 194 | while knn_labels_all[num_consistent] == knn_labels_all[0] and num_consistent <= max_prediction_depth - 2: 195 | num_consistent += 1 196 | return max_prediction_depth - num_consistent 197 | 198 | def set_seed(seed=1234): 199 | if seed is not None: 200 | random.seed(seed) 201 | np.random.seed(seed) 202 | torch.manual_seed(seed) 203 | # torch.backends.cudnn.deterministic = True 204 | warnings.warn('You have chosen to seed training. ' 205 | 'This will turn on the CUDNN deterministic setting, ' 206 | 'which can slow down your training considerably! ' 207 | 'You may see unexpected behavior when restarting ' 208 | 'from checkpoints.') 209 | 210 | def reset_param(net): 211 | print('reset sequential parameters') 212 | for module in net.children(): 213 | if isinstance(module, nn.Sequential): 214 | reset_param(module) 215 | 216 | if hasattr(module, 'reset_parameters'): 217 | module.reset_parameters() 218 | else: 219 | pass 220 | 221 | def main(train_idx, val_idx, random_seed=1234, flip=''): 222 | # for simplicity, we do not use data augmentation when measuring difficulty 223 | # CIFAR10 w / 40% (Fixed) Randomized Labels 224 | # only the training dataset is shuffle. Datasets for prediction depth and testing remains the same as cifar10 original 225 | train_transform = T.Compose([T.ToTensor(), 226 | T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=(0.247, 0.243, 0.261)) 227 | ]) 228 | if args.data == 'cifar10': 229 | trainset = CIFAR10('./', transform=train_transform, train=True, download=False) 230 | testset = CIFAR10('./', transform=train_transform, train=False, download=False) 231 | else: 232 | trainset = CIFAR100('./', transform=train_transform, train=True, download=False) 233 | testset = CIFAR100('./', transform=train_transform, train=False, download=False) 234 | 235 | train_split = Subset(trainset, train_idx) 236 | supportset = train_split 237 | val_split = Subset(trainset, val_idx) 238 | trainloader = DataLoader(train_split, batch_size=128, shuffle=True, num_workers=2, pin_memory=True) 239 | testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=2, pin_memory=True) 240 | 241 | supportloader = DataLoader(supportset, batch_size=500, shuffle=False, num_workers=1, pin_memory=True) 242 | if args.get_train_pd: 243 | # pd (train) data order follows train_indices 244 | evaluate_loader_train = DataLoader(train_split, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 245 | if args.get_val_pd: 246 | # pd (val) data order follows val_indices 247 | evaluate_loader_test = DataLoader(val_split, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 248 | 249 | if args.arch == 'mlp7': 250 | model = MLP7(args.num_classes) 251 | elif args.arch == 'vgg16': 252 | ecd = vgg16().features 253 | reset_param(ecd) 254 | model = VGGPD(ecd, args.num_classes) 255 | elif args.arch == 'resnet18': 256 | model = ResNetPD(BasicBlockPD, [2, 2, 2, 2], temp=1.0, num_classes=args.num_classes) 257 | else: 258 | raise NotImplementedError 259 | 260 | model = model.to(device) 261 | criterion = nn.CrossEntropyLoss() 262 | 263 | optimizer = torch.optim.SGD(model.parameters(), lr=lr_init, momentum=momentum) 264 | if not args.resume: 265 | model = trainer(trainloader, testloader, model, optimizer, args.num_epochs, criterion, random_seed, flip) 266 | else: 267 | print('loading model from 
ckpt') 268 | model.load_state_dict(torch.load(os.path.join(args.result_dir, '{}_{}sgd{}_{}.pt'.format(args.arch, args.data, random_seed, flip)))) 269 | 270 | model.eval() 271 | if args.get_train_pd: 272 | index_knn_y = collections.defaultdict(list) 273 | index_pd = collections.defaultdict(list) 274 | knn_gt_conf_all = collections.defaultdict(list) 275 | for k in range(max_prediction_depth): 276 | # knn predictions, confidence, sample indices at k layer 277 | knn_labels, knn_conf_gt_all, indices_all = get_knn_prds_k_layer(model, evaluate_loader_train, supportloader, 278 | k, train_split=args.get_train_pd) 279 | for idx, knn_l, knn_conf_gt in zip(indices_all, knn_labels, knn_conf_gt_all): 280 | index_knn_y[int(idx)].append(knn_l.item()) 281 | knn_gt_conf_all[int(idx)].append(knn_conf_gt.item()) 282 | for idx, knn_ls in index_knn_y.items(): 283 | index_pd[idx].append(_get_prediction_depth(knn_ls)) 284 | 285 | print(len(index_pd), len(index_knn_y), len(knn_gt_conf_all)) 286 | with open(os.path.join(args.result_dir, '{}train_seed{}_f{}_trainpd.pkl'.format(args.arch, random_seed, flip)), 'w') as f: 287 | json.dump(index_pd, f) 288 | 289 | if args.get_val_pd: 290 | index_knn_y = collections.defaultdict(list) 291 | index_pd = collections.defaultdict(list) 292 | knn_gt_conf_all = collections.defaultdict(list) 293 | for k in range(max_prediction_depth): 294 | knn_labels, knn_conf_gt_all, indices_all = get_knn_prds_k_layer(model, evaluate_loader_test, supportloader, 295 | k, train_split=not(args.get_val_pd)) 296 | for idx, knn_l, knn_conf_gt in zip(indices_all, knn_labels, knn_conf_gt_all): 297 | index_knn_y[int(idx)].append(knn_l.item()) 298 | knn_gt_conf_all[int(idx)].append(knn_conf_gt.item()) 299 | for idx, knn_ls in index_knn_y.items(): 300 | index_pd[idx].append(_get_prediction_depth(knn_ls)) 301 | 302 | print(len(index_pd), len(index_knn_y), len(knn_gt_conf_all)) 303 | with open(os.path.join(args.result_dir, '{}_seed{}_f{}_test_pd.pkl'.format(args.arch, random_seed, flip)), 'w') as f: 304 | json.dump(index_pd, f) 305 | 306 | 307 | if __name__ == '__main__': 308 | seeds = [1111, 2222, 3333, 4444, 5555, 6666] 309 | for seed in seeds: 310 | set_seed(seed) 311 | train_indices, val_indices = train_test_split(np.arange(10000), train_size=args.train_ratio, 312 | test_size=(1 - args.train_ratio)) # split the data 313 | main(train_indices, val_indices, random_seed=seed, flip='') 314 | main(val_indices, train_indices, random_seed=seed, flip='flip') 315 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import json 7 | import torch.nn as nn 8 | from torch.nn import DataParallel 9 | import collections 10 | import time 11 | from torchvision import datasets, transforms 12 | from tqdm import tqdm 13 | from difficulty import * 14 | from models import AngularNet, Baseline 15 | from utils import LossTracker, AverageMeter 16 | from calibration import calibrationMapping, ece_eval, tace_eval 17 | from timm.data import Mixup 18 | from torch.utils.data import Dataset, DataLoader 19 | import visualize 20 | 21 | 22 | def main(arg): 23 | set_seed(arg.seed) 24 | num_epochs = arg.epochs 25 | 26 | if arg.dst == "cifar10": 27 | train_ds = datasets.CIFAR10( 28 | root="./data", 29 | train=True, 30 | transform=transforms.Compose( 31 | [ 32 | transforms.RandomCrop(32, padding=4), 33 | 
transforms.RandomHorizontalFlip(), 34 | transforms.ToTensor(), 35 | transforms.Normalize( 36 | mean=[0.4914, 0.4822, 0.4465], std=(0.247, 0.243, 0.261) 37 | ), 38 | ] 39 | ), 40 | download=True, 41 | ) 42 | test_ds = datasets.CIFAR10( 43 | root="./data", 44 | train=False, 45 | transform=transforms.Compose( 46 | [ 47 | transforms.Resize(36), 48 | transforms.CenterCrop(32), 49 | transforms.ToTensor(), 50 | transforms.Normalize( 51 | mean=[0.4914, 0.4822, 0.4465], std=(0.247, 0.243, 0.261) 52 | ), 53 | ] 54 | ), 55 | download=True, 56 | ) 57 | 58 | validation_size = 5000 59 | train_indices = range(50000)[:-validation_size] 60 | val_indices = range(50000)[-validation_size:] 61 | vis_indices = range(1000) 62 | train_sampler = torch.utils.data.SubsetRandomSampler(train_indices) 63 | valid_sampler = torch.utils.data.SubsetRandomSampler(val_indices) 64 | vis_sampler = torch.utils.data.SubsetRandomSampler(vis_indices) 65 | else: 66 | raise NotImplementedError 67 | 68 | train_loader = torch.utils.data.DataLoader( 69 | dataset=train_ds, batch_size=args.batch_size, sampler=train_sampler 70 | ) 71 | 72 | valid_loader = torch.utils.data.DataLoader( 73 | dataset=train_ds, batch_size=args.batch_size, sampler=valid_sampler 74 | ) 75 | 76 | test_loader = torch.utils.data.DataLoader( 77 | dataset=test_ds, 78 | batch_size=100, 79 | shuffle=False, 80 | ) 81 | os.makedirs("./figs", exist_ok=True) 82 | loss_type = arg.loss_type 83 | history = collections.defaultdict(list) 84 | print("Training {} model....".format(loss_type)) 85 | net = DataParallel( 86 | AngularNet( 87 | num_classes=len(train_loader.dataset.classes), 88 | loss_type=loss_type, 89 | arch=arg.arch, 90 | s=arg.s, 91 | m=arg.m, 92 | ) 93 | ).to(device) 94 | net, calibration_map = train_hyper( 95 | net, 96 | num_epochs, 97 | train_loader, 98 | valid_loader, 99 | test_loader, 100 | loss_type, 101 | history, 102 | arg, 103 | ) 104 | if args.arch == "visualization": 105 | angular_embeds, angular_labels = get_embeds(net, test_loader) 106 | visualize.plot3d( 107 | angular_embeds, 108 | angular_labels, 109 | num_classes=10, 110 | fig_path="./figs/{}.png".format(loss_type), 111 | ) 112 | print("Saved {} figure".format(loss_type)) 113 | del angular_embeds, angular_labels 114 | 115 | 116 | def train_hyper( 117 | model, 118 | total_epochs, 119 | train_loader, 120 | valid_loader, 121 | test_loader, 122 | loss_type, 123 | history, 124 | arg, 125 | ): 126 | # optimizer = torch.optim.SGD(model.parameters(), lr=arg.lr, momentum=0.9, weight_decay=5e-4) 127 | optimizer = torch.optim.Adam(model.parameters(), lr=arg.lr) 128 | if arg.aug == "mixup": 129 | mixup_args = { 130 | "mixup_alpha": 1.0, 131 | "cutmix_alpha": 0.0, 132 | "cutmix_minmax": None, 133 | "prob": 1.0, 134 | "switch_prob": 0.0, 135 | "mode": "batch", 136 | "label_smoothing": 0, 137 | "num_classes": 10, 138 | } 139 | aug_func = Mixup(**mixup_args) 140 | else: 141 | aug_func = nn.Identity() 142 | 143 | calibration_map = None 144 | 145 | step = 0 146 | for epoch in range(total_epochs): 147 | tracker = LossTracker(len(train_loader), "step : [{}]".format(step), 1000) 148 | for i, (b_data, b_labels) in enumerate(train_loader): 149 | b_data = b_data.cuda(non_blocking=True) 150 | b_labels = b_labels.cuda(non_blocking=True) 151 | optimizer.zero_grad() 152 | 153 | if arg.aug == "mixup": 154 | b_data, b_labels_soft = aug_func(b_data, b_labels) 155 | logits, _ = model(b_data, b_labels) 156 | loss = F.cross_entropy(logits, b_labels_soft) 157 | else: 158 | b_data = aug_func(b_data) 159 | logits, _ = model(b_data, 
143 |     calibration_map = None
144 | 
145 |     step = 0
146 |     for epoch in range(total_epochs):
147 |         tracker = LossTracker(len(train_loader), "step : [{}]".format(step), 1000)
148 |         for i, (b_data, b_labels) in enumerate(train_loader):
149 |             b_data = b_data.cuda(non_blocking=True)
150 |             b_labels = b_labels.cuda(non_blocking=True)
151 |             optimizer.zero_grad()
152 | 
153 |             if arg.aug == "mixup":
154 |                 b_data, b_labels_soft = aug_func(b_data, b_labels)
155 |                 logits, _ = model(b_data, b_labels)
156 |                 loss = F.cross_entropy(logits, b_labels_soft)
157 |             else:
158 |                 b_data = aug_func(b_data)
159 |                 logits, _ = model(b_data, b_labels)
160 |                 loss = F.cross_entropy(logits, b_labels)
161 | 
162 |             loss = loss.mean()  # reduce per-GPU losses for DataParallel
163 | 
164 |             loss.backward()
165 |             torch.nn.utils.clip_grad_norm_(model.parameters(), arg.clip)
166 |             optimizer.step()
167 | 
168 |             tracker.update(loss, logits, b_labels)
169 |             step += 1
170 |         print(
171 |             "{}: Epoch [{}/{}], Loss: {:.4f}".format(
172 |                 loss_type, epoch + 1, total_epochs, loss.item()
173 |             )
174 |         )
175 |         # scheduler.step()
176 |         history["train_loss"].append(tracker.losses.avg)
177 |         history["train_acc"].append(tracker.top1.avg)
178 |         history["train_top5"].append(tracker.top5.avg)
179 |         history["lr"].append(optimizer.param_groups[0]["lr"])
180 | 
181 |         # calibrate and test
182 |         calibration_map = calibrationMapping(
183 |             10,
184 |             model,
185 |             valid_loader,
186 |             calibration_type=arg.calibration,
187 |             calibration_lr=arg.calib_lr,
188 |         )
189 |         (
190 |             test_acc,
191 |             target_confs,
192 |             output_margins,
193 |             anggaps,
194 |             cosgaps,
195 |             norm_angles,
196 |         ) = test_calibrated_statistics(
197 |             test_loader,
198 |             model,
199 |             calibration_map,
200 |             calibration_type=arg.calibration,
201 |             history=history,
202 |         )
203 | 
204 |         print("Test accuracy: %s" % test_acc)
205 |         history["acc"].append(test_acc)
206 | 
207 |         diff_dict = {
208 |             "target_confs": target_confs.tolist(),
209 |             "output_margins": output_margins.tolist(),
210 |             "anggaps": anggaps.tolist(),
211 |             "cosgaps": cosgaps.tolist(),
212 |             "avh": norm_angles.tolist(),
213 |         }
214 |         with open(
215 |             os.path.join(
216 |                 arg.result_dir,
217 |                 "diff_score{}_{}_{}scale{}sd{}_{}epoch{}.json".format(
218 |                     arg.dst,
219 |                     arg.calibration,
220 |                     arg.arch,
221 |                     arg.s,
222 |                     arg.seed,
223 |                     arg.loss_type,
224 |                     epoch,
225 |                 ),
226 |             ),
227 |             "w",
228 |         ) as f:
229 |             json.dump(diff_dict, f)
230 | 
231 |         torch.save(
232 |             history,
233 |             os.path.join(
234 |                 arg.result_dir,
235 |                 "{}{}history_loss{}scale{}sd{}{}.pt".format(
236 |                     arg.dst, arg.calibration, arg.arch, arg.s, arg.seed, arg.loss_type
237 |                 ),
238 |             ),
239 |         )
240 |         torch.save(
241 |             model.state_dict(),
242 |             os.path.join(
243 |                 arg.result_dir,
244 |                 "{}{}model{}_scale{}seed{}{}.pt".format(
245 |                     arg.dst, arg.calibration, arg.arch, arg.s, arg.seed, arg.loss_type
246 |                 ),
247 |             ),
248 |         )
249 |         torch.save(
250 |             optimizer.state_dict(),
251 |             os.path.join(
252 |                 arg.result_dir,
253 |                 "{}{}optimizer{}_scale{}seed{}{}.pt".format(
254 |                     arg.dst, arg.calibration, arg.arch, arg.s, arg.seed, arg.loss_type
255 |                 ),
256 |             ),
257 |         )
258 |         torch.save(
259 |             calibration_map.state_dict(),
260 |             os.path.join(
261 |                 arg.result_dir,
262 |                 "{}{}calibMap{}_scale{}seed{}{}.pt".format(
263 |                     arg.dst, arg.calibration, arg.arch, arg.s, arg.seed, arg.loss_type
264 |                 ),
265 |             ),
266 |         )
267 | 
268 |     return model, calibration_map
269 | 
270 | 
271 | def test_calibrated_statistics(
272 |     testloader,
273 |     model,
274 |     calibration_map=None,
275 |     calibration_type="diagonal_scaling",
276 |     history=None,
277 | ):
278 |     num_classes = len(testloader.dataset.classes)
279 |     model.eval()
280 | 
281 |     criterion = nn.CrossEntropyLoss()
282 |     if calibration_type == "diagonal_scaling":
283 |         post = calibration_map.diag.diag()
284 |         post = post.unsqueeze(0)
285 |     elif calibration_type == "temperature_scaling":
286 |         post = calibration_map.temp
287 |         post = post.unsqueeze(0).expand(-1, num_classes)
288 |     elif calibration_type == "matrix_scaling":
289 |         post = calibration_map.mat.weight.diag() - calibration_map.mat.bias  # diagonal approximation of the full matrix map
290 |         post = post.unsqueeze(0)
291 |     else:
292 |         raise NotImplementedError
293 | 
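    # Editor's note (a hedged reading of the code): `post` holds the per-class
    # slope that the learned calibration map applies to the logits. Because the
    # model's logits are roughly s * cos(theta_c) (up to the additive margin on
    # the target class), the same slopes can be pushed onto the cosine distances
    # directly; for diagonal scaling with diag = (d_1, ..., d_C):
    #     calibrated_logit_c = d_c * s * cos(theta_c)  =>  cos'_c = d_c * cos(theta_c)
    # which is what `post * cos_dists` computes in the test loop below.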
print("start testing") 295 | with torch.no_grad(): 296 | # uncalibrated stats 297 | uncalibrate_loss_tracker = LossTracker(len(testloader), "val", 1000) 298 | total_uncalibrate_cos_gap = AverageMeter("uncalib_cos_gap", ":.4e") 299 | total_uncalibrate_ang_gap = AverageMeter("uncalib_ang_gap", ":.4e") 300 | total_uncalibrate_norm_angle = AverageMeter("uncalib_norm_angle", ":.4e") 301 | total_uncalibrate_target_conf = AverageMeter("uncalib_target_conf", ":.4e") 302 | total_uncalibrate_output_margin = AverageMeter("uncalib_output_margin", ":.4e") 303 | uncalibrate_ece_avg = AverageMeter("uncalibrate ece", ":6.2f") 304 | uncalibrate_tace_avg = AverageMeter("uncalibrate tace", ":6.2f") 305 | 306 | # calibrated stats 307 | calibrate_loss_tracker = LossTracker(len(testloader), "val", 1000) 308 | total_calibrate_cos_gap = AverageMeter("cos_gap", ":.4e") 309 | total_calibrate_ang_gap = AverageMeter("ang_gap", ":.4e") 310 | total_calibrate_norm_angle = AverageMeter("norm_angle", ":.4e") 311 | total_calibrate_target_conf = AverageMeter("target_conf", ":.4e") 312 | total_calibrate_output_margin = AverageMeter("output_margin", ":.4e") 313 | calibrate_ece_avg = AverageMeter("calibrate ece", ":6.2f") 314 | calibrate_tace_avg = AverageMeter("calibrate tace", ":6.2f") 315 | 316 | target_confs = [] 317 | output_margins = [] 318 | anggaps = [] 319 | cosgaps = [] 320 | norm_angles = [] 321 | test_start = time.time() 322 | for (images, targets) in testloader: 323 | batch_size = len(images) 324 | target_onehot = torch.zeros(len(targets), num_classes) 325 | target_onehot = target_onehot.scatter( 326 | dim=1, index=targets[:, None], src=torch.ones(len(targets), 1) 327 | ).to( 328 | device 329 | ) # B, C 330 | images, targets = images.cuda(non_blocking=True), targets.cuda( 331 | non_blocking=True 332 | ) 333 | 334 | # uncalibrate_statics 335 | uncalibrated_logits, cos_dists = model(images, targets) 336 | uncalibrate_loss = criterion(uncalibrated_logits, targets) 337 | 338 | ( 339 | uncalib_target_confidence, 340 | uncalib_output_margin, 341 | ) = get_confidence_output_margin(uncalibrated_logits, target_onehot) 342 | _, uncalib_cos_gap, _, uncalib_ang_gap = angular_gap( 343 | cos_dists, target_onehot 344 | ) 345 | uncalib_norm_angle = avh(cos_dists, targets=targets) 346 | 347 | # recode uncalibrated stats 348 | uncalibrate_loss_tracker.update( 349 | uncalibrate_loss, uncalibrated_logits, targets 350 | ) 351 | total_uncalibrate_target_conf.update( 352 | uncalib_target_confidence.mean(), batch_size 353 | ) 354 | total_uncalibrate_output_margin.update( 355 | uncalib_output_margin.mean(), batch_size 356 | ) 357 | total_uncalibrate_cos_gap.update(uncalib_cos_gap.mean(), batch_size) 358 | total_uncalibrate_ang_gap.update(uncalib_ang_gap.mean(), batch_size) 359 | total_uncalibrate_norm_angle.update(uncalib_norm_angle.mean(), batch_size) 360 | # uncalib_prds_n = torch.softmax(uncalibrated_logits, dim=1).cpu().detach().numpy() 361 | # np_targets = targets.cpu().numpy() 362 | # un_ece, _, _, _ = ece_eval(uncalib_prds_n, np_targets) 363 | # un_tace, _, _, _ = tace_eval(uncalib_prds_n, np_targets) 364 | # uncalibrate_ece_avg.update(un_ece, np_targets.shape[0]) 365 | # uncalibrate_tace_avg.update(un_tace, np_targets.shape[0]) 366 | 367 | # record stats after calibration 368 | cos_dists_calibrated = post * cos_dists # B, C 369 | calibrated_logits = calibration_map(uncalibrated_logits) 370 | calibrate_loss = criterion(calibrated_logits, targets) 371 | 372 | calib_target_confidence, calib_output_margin = 
375 |             _, cos_gap, _, ang_gap = angular_gap(
376 |                 cos_dists_calibrated, target_onehot, calibration_map.diag.data.diag()  # assumes diagonal scaling
377 |             )
378 |             norm_angle = avh(cos_dists_calibrated, targets=targets)
379 |             # record calibrated stats
380 |             # calib_prds_n = torch.softmax(calibrated_logits, dim=1).cpu().detach().numpy()
381 |             # ece, _, _, _ = ece_eval(calib_prds_n, np_targets)
382 |             # tace, _, _, _ = tace_eval(calib_prds_n, np_targets)
383 |             # calibrate_ece_avg.update(ece, np_targets.shape[0])
384 |             # calibrate_tace_avg.update(tace, np_targets.shape[0])
385 | 
386 |             calibrate_loss_tracker.update(calibrate_loss, calibrated_logits, targets)
387 |             total_calibrate_target_conf.update(
388 |                 calib_target_confidence.mean(), batch_size
389 |             )
390 |             total_calibrate_output_margin.update(calib_output_margin.mean(), batch_size)
391 |             total_calibrate_cos_gap.update(cos_gap.mean(), batch_size)
392 |             total_calibrate_ang_gap.update(ang_gap.mean(), batch_size)
393 |             total_calibrate_norm_angle.update(norm_angle.mean(), batch_size)
394 | 
395 |             target_confs.append(calib_target_confidence)
396 |             output_margins.append(calib_output_margin)
397 |             anggaps.append(ang_gap)
398 |             cosgaps.append(cos_gap)
399 |             norm_angles.append(norm_angle)
400 | 
401 |         test_end = time.time()
402 |         print("testing, elapsed time: {}".format(test_end - test_start))
403 |         # test convergence
404 |         # class feature norm
405 |         history["cls_emb_norm"].append(embed_norm(model.module.angular_loss.fc.weight))
406 |         # data feature norm (computed on the last test batch)
407 |         emb = model(images, embed=True)
408 |         history["data_emb_norm"].append(embed_norm(emb))
409 |         del emb
410 | 
411 |         history["uncalibrate_test_loss"].append(uncalibrate_loss_tracker.losses.avg)
412 |         history["uncalibrate_test_acc1"].append(uncalibrate_loss_tracker.top1.avg)
413 |         history["uncalibrate_test_acc5"].append(uncalibrate_loss_tracker.top5.avg)
414 |         history["uncalibrate_cos_gap"].append(total_uncalibrate_cos_gap.avg)
415 |         history["uncalibrate_ang_gap"].append(total_uncalibrate_ang_gap.avg)
416 |         history["uncalibrate_norm_angle"].append(total_uncalibrate_norm_angle.avg)
417 |         history["uncalibrate_target_conf"].append(total_uncalibrate_target_conf.avg)
418 |         history["uncalibrate_output_gap_margin"].append(
419 |             total_uncalibrate_output_margin.avg
420 |         )
421 |         # calibrated
422 |         history["calibrate_test_loss"].append(calibrate_loss_tracker.losses.avg)
423 |         history["calibrate_test_acc1"].append(calibrate_loss_tracker.top1.avg)
424 |         history["calibrate_test_acc5"].append(calibrate_loss_tracker.top5.avg)
425 |         history["calibrate_cos_gap"].append(total_calibrate_cos_gap.avg)
426 |         history["calibrate_ang_gap"].append(total_calibrate_ang_gap.avg)
427 |         history["calibrate_norm_angle"].append(total_calibrate_norm_angle.avg)
428 |         history["calibrate_target_conf"].append(total_calibrate_target_conf.avg)
429 |         history["calibrate_output_gap_margin"].append(total_calibrate_output_margin.avg)
430 | 
431 |     target_confs = torch.cat(target_confs, 0).cpu().numpy()
432 |     output_margins = torch.cat(output_margins, 0).cpu().numpy()
433 |     anggaps = torch.cat(anggaps, 0).cpu().numpy()
434 |     cosgaps = torch.cat(cosgaps, 0).cpu().numpy()
435 |     norm_angles = torch.cat(norm_angles, 0).cpu().numpy()
436 |     return (
437 |         uncalibrate_loss_tracker.top1.avg,
438 |         target_confs,
439 |         output_margins,
440 |         anggaps,
441 |         cosgaps,
442 |         norm_angles,
443 |     )
444 | 
445 | 
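# Editor's sketch: get_spearman / get_kendalltau presumably come from
# difficulty.py via the star import at the top of this file. Plausible minimal
# stand-ins (assuming scipy is available) would look like:
#
#     from scipy import stats
#
#     def get_spearman(scores, human_scores):
#         return stats.spearmanr(scores, human_scores).correlation
#
#     def get_kendalltau(scores, human_scores):
#         return stats.kendalltau(scores, human_scores).correlation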
446 | def get_calibrated_difficulty_correlation(
447 |     model,
448 |     testloader,
449 |     calibration_map=None,
450 |     calibration_type="diagonal_scaling",
451 |     arg=None,
452 | ):
453 |     human_score = np.load(os.path.join("./orders", arg.human_score))
454 | 
455 |     correlations = {}
456 |     # switch to evaluate mode
457 |     model = model.cuda()
458 |     model.eval()
459 |     num_classes = 10
460 |     criterion = nn.CrossEntropyLoss()
461 |     if calibration_type == "diagonal_scaling":
462 |         post = calibration_map.diag.diag()
463 |     elif calibration_type == "temperature_scaling":
464 |         post = calibration_map.temp
465 |     elif calibration_type == "matrix_scaling":
466 |         post = calibration_map.mat.weight.diag() - calibration_map.mat.bias
467 |     else:
468 |         raise NotImplementedError
469 |     with torch.no_grad():
470 |         target_confs = []
471 |         output_margins = []
472 |         anggaps = []
473 |         cosgaps = []
474 |         norm_angles = []
475 | 
476 |         for i, (images, targets) in enumerate(tqdm(testloader)):
477 |             batch_size = len(images)
478 |             target_onehot = torch.zeros(len(targets), num_classes)
479 |             target_onehot = target_onehot.scatter(
480 |                 dim=1, index=targets[:, None], src=torch.ones(len(targets), 1)
481 |             ).to(
482 |                 device
483 |             )  # B, C
484 |             images, targets = images.cuda(non_blocking=True), targets.cuda(
485 |                 non_blocking=True
486 |             )
487 | 
488 |             # uncalibrated statistics
489 |             uncalibrated_logits, cos_dists = model(images, targets)
490 |             uncalibrate_loss = criterion(uncalibrated_logits, targets)
491 | 
492 |             (
493 |                 uncalib_target_confidence,
494 |                 uncalib_output_margin,
495 |             ) = get_confidence_output_margin(uncalibrated_logits, target_onehot)
496 |             _, uncalib_cos_gap, _, uncalib_ang_gap = angular_gap(
497 |                 cos_dists, target_onehot
498 |             )
499 |             uncalib_norm_angle = avh(cos_dists, targets=targets)
500 | 
501 |             cos_dists_calibrated = post * cos_dists  # B, C
502 |             calibrated_logits = calibration_map(uncalibrated_logits)
503 |             calibrate_loss = criterion(calibrated_logits, targets)
504 | 
505 |             calib_target_confidence, calib_output_margin = get_confidence_output_margin(
506 |                 calibrated_logits, target_onehot
507 |             )
508 |             _, cos_gap, _, ang_gap = angular_gap(
509 |                 cos_dists_calibrated, target_onehot, calibration_map.diag.data.diag()
510 |             )
511 |             norm_angle = avh(cos_dists_calibrated, targets=targets)
512 | 
513 |             target_confs.append(calib_target_confidence)  # list.append returns None, so no reassignment
514 |             output_margins.append(calib_output_margin)
515 |             anggaps.append(ang_gap)
516 |             cosgaps.append(cos_gap)
517 |             norm_angles.append(norm_angle)
518 | 
519 |     total_target_conf = torch.cat(target_confs, 0).cpu().numpy()
520 |     total_output_margin = torch.cat(output_margins, 0).cpu().numpy()
521 |     total_ang_gap = torch.cat(anggaps, 0).cpu().numpy()
522 |     total_cos_gap = torch.cat(cosgaps, 0).cpu().numpy()
523 |     total_norm_angle = torch.cat(norm_angles, 0).cpu().numpy()
524 | 
525 |     correlations["target_confidence_spearman"] = get_spearman(
526 |         total_target_conf, human_score
527 |     )
528 |     correlations["target_confidence_tau"] = get_kendalltau(
529 |         total_target_conf, human_score
530 |     )
531 | 
532 |     correlations["output_margin_spearman"] = get_spearman(
533 |         total_output_margin, human_score
534 |     )
535 |     correlations["output_margin_tau"] = get_kendalltau(total_output_margin, human_score)
536 | 
537 |     correlations["cos_gap_spearman"] = get_spearman(total_cos_gap, human_score)
538 |     correlations["cos_gap_tau"] = get_kendalltau(total_cos_gap, human_score)
539 | 
540 |     correlations["ang_gap_spearman"] = get_spearman(total_ang_gap, human_score)
541 |     correlations["ang_gap_tau"] = get_kendalltau(total_ang_gap, human_score)
542 | 
correlations["norm_angle_spearman"] = get_spearman(total_norm_angle, human_score) 544 | correlations["norm_angle_tau"] = get_kendalltau(total_norm_angle, human_score) 545 | return correlations 546 | 547 | 548 | def get_embeds(model, loader): 549 | model = model.to(device).eval() 550 | full_embeds = [] 551 | full_labels = [] 552 | with torch.no_grad(): 553 | for feats, labels in loader: 554 | feats = feats[:200].to(device) 555 | full_labels.append(labels[:200]) 556 | embeds = model(feats, embed=True) 557 | full_embeds.append(F.normalize(embeds, dim=1)) 558 | return torch.cat(full_embeds, 0), torch.cat(full_labels, 0) 559 | 560 | 561 | def parse_args(): 562 | parser = argparse.ArgumentParser( 563 | description="Run AngularGap and Baseline experiments on CIFAR10" 564 | ) 565 | parser.add_argument( 566 | "--batch-size", 567 | type=int, 568 | default=512, 569 | help="input batch size for training (default: 512)", 570 | ) 571 | parser.add_argument( 572 | "--arch", 573 | type=str, 574 | default="resnet18", 575 | help="visualization/resnet18/alexnet/vgg16", 576 | ) 577 | parser.add_argument( 578 | "--epochs", 579 | type=int, 580 | default=100, 581 | help="number of epochs to train each model for (default: 100)", 582 | ) 583 | parser.add_argument("--seed", type=int, default=9999, help="random seed") 584 | parser.add_argument( 585 | "--lr", type=float, default=0.01, help="learning rate (default: 0.01)" 586 | ) 587 | parser.add_argument( 588 | "--use-cuda", default=True, type=bool, help="enables CUDA training" 589 | ) 590 | parser.add_argument("--resume", default=False, type=bool, help="resume") 591 | parser.add_argument("--aug", default="none", type=str, help="none/mixup") 592 | parser.add_argument( 593 | "--resume_path", 594 | default=0, 595 | type=str, 596 | help="resume checkpoint from this directory", 597 | ) 598 | parser.add_argument("--s", default=30.0, type=float, help="scale") 599 | parser.add_argument("--dst", default="cifar10", type=str, help="train dataset") 600 | parser.add_argument( 601 | "--human_score", 602 | default="cifar10_human_probs.npy", 603 | type=str, 604 | help="path to CIFAR10-H human score", 605 | ) 606 | parser.add_argument( 607 | "--calibration", default="diagonal_scaling", type=str, help="calibration method" 608 | ) 609 | parser.add_argument("--result_dir", default="", type=str) 610 | parser.add_argument("--m", default=0.35, type=float) 611 | parser.add_argument("--cuda_no", default="0, 1", type=str) 612 | parser.add_argument("--clip", default=5.0, type=float) 613 | parser.add_argument( 614 | "--calib_lr", default=0.01, type=float, help="path to human score" 615 | ) 616 | parser.add_argument( 617 | "--loss_type", default="nsl", type=str, help="path to human score" 618 | ) 619 | parser.add_argument( 620 | "--num_gpus", default=1, type=int, help="num of gpus for training" 621 | ) 622 | parser.add_argument("--amp", default=False, type=bool, help="use cuda amp") 623 | args = parser.parse_args() 624 | return args 625 | 626 | 627 | def set_seed(seed=None): 628 | if seed is not None: 629 | random.seed(args.seed) 630 | torch.manual_seed(args.seed) 631 | 632 | 633 | if __name__ == "__main__": 634 | args = parse_args() 635 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_no 636 | set_seed(args.seed) 637 | if args.use_cuda and torch.cuda.is_available(): 638 | device = torch.device("cuda") 639 | torch.backends.cudnn.benchmark = True 640 | 641 | main(args) 642 | --------------------------------------------------------------------------------