├── LICENSE ├── README.md ├── assets └── sdat.png ├── common ├── __init__.py ├── modules │ ├── __init__.py │ └── classifier.py ├── utils │ ├── __init__.py │ ├── analysis │ │ ├── __init__.py │ │ ├── a_distance.py │ │ └── tsne.py │ ├── data.py │ ├── logger.py │ ├── meter.py │ ├── metric │ │ └── __init__.py │ ├── sam.py │ └── scheduler.py └── vision │ ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── _util.cpython-36.pyc │ │ ├── _util.cpython-38.pyc │ │ ├── _util.cpython-39.pyc │ │ ├── aircrafts.cpython-36.pyc │ │ ├── aircrafts.cpython-38.pyc │ │ ├── aircrafts.cpython-39.pyc │ │ ├── coco70.cpython-36.pyc │ │ ├── coco70.cpython-38.pyc │ │ ├── coco70.cpython-39.pyc │ │ ├── cub200.cpython-36.pyc │ │ ├── cub200.cpython-38.pyc │ │ ├── cub200.cpython-39.pyc │ │ ├── digits.cpython-36.pyc │ │ ├── digits.cpython-38.pyc │ │ ├── digits.cpython-39.pyc │ │ ├── domainnet.cpython-36.pyc │ │ ├── domainnet.cpython-38.pyc │ │ ├── domainnet.cpython-39.pyc │ │ ├── dtd.cpython-36.pyc │ │ ├── dtd.cpython-38.pyc │ │ ├── dtd.cpython-39.pyc │ │ ├── eurosat.cpython-36.pyc │ │ ├── eurosat.cpython-38.pyc │ │ ├── eurosat.cpython-39.pyc │ │ ├── imagelist.cpython-36.pyc │ │ ├── imagelist.cpython-38.pyc │ │ ├── imagelist.cpython-39.pyc │ │ ├── imagenet_r.cpython-36.pyc │ │ ├── imagenet_r.cpython-38.pyc │ │ ├── imagenet_r.cpython-39.pyc │ │ ├── imagenet_sketch.cpython-36.pyc │ │ ├── imagenet_sketch.cpython-38.pyc │ │ ├── imagenet_sketch.cpython-39.pyc │ │ ├── office31.cpython-36.pyc │ │ ├── office31.cpython-38.pyc │ │ ├── office31.cpython-39.pyc │ │ ├── officecaltech.cpython-36.pyc │ │ ├── officecaltech.cpython-38.pyc │ │ ├── officecaltech.cpython-39.pyc │ │ ├── officehome.cpython-36.pyc │ │ ├── officehome.cpython-38.pyc │ │ ├── officehome.cpython-39.pyc │ │ ├── oxfordflowers.cpython-36.pyc │ │ ├── oxfordflowers.cpython-38.pyc │ │ ├── oxfordflowers.cpython-39.pyc │ │ ├── oxfordpet.cpython-36.pyc │ │ ├── oxfordpet.cpython-38.pyc │ │ ├── oxfordpet.cpython-39.pyc │ │ ├── pacs.cpython-36.pyc │ │ ├── pacs.cpython-38.pyc │ │ ├── pacs.cpython-39.pyc │ │ ├── patchcamelyon.cpython-36.pyc │ │ ├── patchcamelyon.cpython-38.pyc │ │ ├── patchcamelyon.cpython-39.pyc │ │ ├── resisc45.cpython-36.pyc │ │ ├── resisc45.cpython-38.pyc │ │ ├── resisc45.cpython-39.pyc │ │ ├── retinopathy.cpython-36.pyc │ │ ├── retinopathy.cpython-38.pyc │ │ ├── retinopathy.cpython-39.pyc │ │ ├── stanford_cars.cpython-36.pyc │ │ ├── stanford_cars.cpython-38.pyc │ │ ├── stanford_cars.cpython-39.pyc │ │ ├── stanford_dogs.cpython-36.pyc │ │ ├── stanford_dogs.cpython-38.pyc │ │ ├── stanford_dogs.cpython-39.pyc │ │ ├── visda2017.cpython-36.pyc │ │ ├── visda2017.cpython-38.pyc │ │ └── visda2017.cpython-39.pyc │ ├── _util.py │ ├── domainnet.py │ ├── imagelist.py │ ├── officehome.py │ └── visda2017.py │ ├── models │ ├── __init__.py │ └── resnet.py │ └── transforms │ └── __init__.py ├── dalib ├── adaptation │ ├── __init__.py │ ├── cdan.py │ └── mcc.py └── modules │ ├── __init__.py │ ├── domain_discriminator.py │ ├── entropy.py │ ├── gl.py │ ├── grl.py │ └── kernels.py ├── examples ├── cdan.py ├── cdan_mcc.py ├── cdan_mcc_sdat.py ├── cdan_sdat.py ├── eval.py ├── run_office_home.sh ├── run_visda.sh └── utils.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Video Analytics Lab -- IISc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 
6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
Smooth Domain Adversarial Training
2 | 3 | **Harsh Rangwani\*, Sumukh K Aithal\*, Mayank Mishra, Arihant Jain, R. Venkatesh Babu** 4 | 5 | 6 | 7 | This is the official PyTorch implementation for our ICML'22 paper: **A Closer Look at Smoothness in Domain Adversarial Training**.[[`Paper`](https://arxiv.org/abs/2206.08213)] 8 | 9 | 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-closer-look-at-smoothness-in-domain-1/domain-adaptation-on-office-home)](https://paperswithcode.com/sota/domain-adaptation-on-office-home?p=a-closer-look-at-smoothness-in-domain-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-closer-look-at-smoothness-in-domain-1/domain-adaptation-on-visda2017)](https://paperswithcode.com/sota/domain-adaptation-on-visda2017?p=a-closer-look-at-smoothness-in-domain-1) 11 | 12 | ## Introduction 13 |
14 | Smooth Domain Adversarial Training 15 |
16 | 17 |

In recent times, methods converging to smooth optima have shown improved generalization for supervised learning tasks like classification. In this work, we analyze the effect of smoothness-enhancing formulations on domain adversarial training, whose objective is a combination of task loss (e.g., classification, regression) and adversarial terms. We find that converging to smooth minima with respect to (w.r.t.) the task loss stabilizes adversarial training, leading to better performance on the target domain. In contrast, our analysis shows that converging to smooth minima w.r.t. the adversarial loss leads to sub-optimal generalization on the target domain. Based on this analysis, we introduce the Smooth Domain Adversarial Training (SDAT) procedure, which effectively enhances the performance of existing domain adversarial methods for both classification and object detection tasks.
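In symbols (a rough sketch of the objective as we read it from the training snippet below; the notation is ours, not taken from the paper): with feature-extractor/classifier weights $\theta$ and domain-discriminator weights $\phi$, SDAT applies the sharpness-aware maximization over a $\rho$-ball only to the task loss,

$$\min_{\theta} \; \Big[ \max_{\lVert \epsilon \rVert_2 \le \rho} \mathcal{L}_{task}(\theta + \epsilon) \Big] \; + \; \mathcal{L}_{adv}(\theta, \phi),$$

while the adversarial term and the domain-discriminator update follow standard DAT.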

18 | 19 | **TLDR:** Change just a few lines of code to improve your adversarial domain adaptation algorithm by converting it to its smooth variant. 20 | 21 | ### Why use SDAT? 22 | - Can be combined with any DAT algorithm. 23 | - Easy to integrate with a few lines of code. 24 | - Leads to a significant improvement in accuracy on the target domain. 25 | 44 | 45 | #### DAT Based Method w/ SDAT 46 | We provide the details of the changes required to convert any DAT algorithm (e.g., CDAN, DANN, CDAN+MCC) to its Smooth DAT version. 47 | 48 | ```python 49 | optimizer = SAM(classifier.get_parameters(), torch.optim.SGD, rho=args.rho, adaptive=False, 50 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 51 | # optimizer refers to the SAM (smooth) optimizer, which contains the parameters of the feature extractor and classifier. 52 | optimizer.zero_grad() 53 | # ad_optimizer refers to a standard SGD optimizer, which contains the parameters of the domain classifier. 54 | ad_optimizer.zero_grad() 55 | 56 | # Calculate task loss 57 | class_prediction, feature = model(x) 58 | task_loss = task_loss_fn(class_prediction, label) 59 | task_loss.backward() 60 | 61 | # Calculate ϵ̂ (w) and add it to the weights 62 | optimizer.first_step(zero_grad=True) 63 | 64 | # Calculate task loss and domain loss at the perturbed weights 65 | class_prediction, feature = model(x) 66 | task_loss = task_loss_fn(class_prediction, label) 67 | domain_loss = domain_classifier(feature) 68 | loss = task_loss + domain_loss 69 | loss.backward() 70 | 71 | # Update parameters (Sharpness-Aware update) 72 | optimizer.second_step() 73 | # Update parameters of domain classifier 74 | ad_optimizer.step() 75 | ``` 76 | 77 | ## Getting started 78 | 79 | * ### Requirements 80 | 88 | * ### Installation 89 | ``` 90 | git clone https://github.com/val-iisc/SDAT.git 91 | cd SDAT 92 | pip install -r requirements.txt 93 | ``` 94 | We use Weights and Biases ([wandb](https://wandb.ai/site)) to track our experiments and results. To track your experiments with wandb, create a new project with your account. The ```project``` and ```entity``` arguments in ```wandb.init``` must be changed accordingly. To disable wandb tracking, the ```log_results``` flag can be used. 95 | 96 | * ### Datasets 97 | The datasets used in the repository can be downloaded from the following links: 98 | 101 | The datasets are automatically downloaded to the ```data/``` folder if they are not already available. 102 | ## Training 103 | We report our numbers primarily on two domain adaptation methods: CDAN w/ SDAT and CDAN+MCC w/ SDAT. The training scripts can be found under the `examples` subdirectory. 104 | 105 | ### Domain Adversarial Training (DAT) 106 | To train using standard CDAN and CDAN+MCC, use the `cdan.py` and `cdan_mcc.py` files, respectively. A sample command to train the aforementioned methods with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below. 107 | ``` 108 | python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 109 | ``` 110 | 111 | ### Smooth Domain Adversarial Training (SDAT) 112 | 113 | To train using our proposed CDAN w/ SDAT and CDAN+MCC w/ SDAT, use the `cdan_sdat.py` and `cdan_mcc_sdat.py` files, respectively.
114 | 115 | A sample script to run CDAN+MCC w/ SDAT with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below. 116 | ``` 117 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 118 | ``` 119 | Additional commands to reproduce the results can be found in `run_office_home.sh` and `run_visda.sh` under `examples`. 120 | 121 | ### Results 122 | The following table reports the accuracy across the various splits of the Office-Home and VisDA-2017 datasets using CDAN+MCC w/ SDAT with a ViT B-16 backbone. We also provide downloadable weights for the corresponding pretrained classifier. 123 |
| Dataset | Source | Target | Accuracy | Checkpoints |
| --- | --- | --- | --- | --- |
| Office-Home | Art | Clipart | 70.8 | ckpt |
| Office-Home | Art | Product | 80.7 | ckpt |
| Office-Home | Art | Real World | 90.5 | ckpt |
| Office-Home | Clipart | Art | 85.2 | ckpt |
| Office-Home | Clipart | Product | 87.3 | ckpt |
| Office-Home | Clipart | Real World | 89.7 | ckpt |
| Office-Home | Product | Art | 84.1 | ckpt |
| Office-Home | Product | Clipart | 70.7 | ckpt |
| Office-Home | Product | Real World | 90.6 | ckpt |
| Office-Home | Real World | Art | 88.3 | ckpt |
| Office-Home | Real World | Clipart | 75.5 | ckpt |
| Office-Home | Real World | Product | 92.1 | ckpt |
| VisDA-2017 | Synthetic | Real | 89.8 | ckpt |
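To poke at a downloaded checkpoint outside of `eval.py` (which loads it for you via `--weight_path`), a minimal sketch is below; the key layout of the `.pth` file is an assumption, not something this README specifies.

```python
import torch

# Load a downloaded checkpoint on CPU and inspect its contents (illustrative only).
state_dict = torch.load("path_to_weight.pth", map_location="cpu")
print(len(state_dict), "entries")
print(list(state_dict)[:5])  # peek at the first few parameter names
# If the file stores the classifier state dict directly, it could be restored with:
# classifier.load_state_dict(state_dict)
```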
219 | 220 | ### Evaluation 221 | To evaluate a classifier with pretrained weights, use the `eval.py` under `examples`. Set the `--weight_path` argument with the path of the weight to be evaluated. 222 | 223 | A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on Office-Home (with Art as source domain and Clipart as the target domain) is given below. 224 | ``` 225 | python eval.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 -b 24 --no-pool --weight_path path_to_weight.pth --log_name Ar2Cl_cdan_mcc_sdat_vit_eval --gpu 0 --phase test 226 | ``` 227 | A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on VisDA-2017 (with Synthetic as source domain and Real as the target domain) is given below. 228 | 229 | ``` 230 | python eval.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --per-class-eval --train-resizing cen.crop --weight_path path_to_weight.pth --log_name visda_cdan_mcc_sdat_vit_eval --gpu 0 --no-pool --phase test 231 | ``` 232 | 233 | 234 | ## Overview of the arguments 235 | Generally, all scripts in the project take the following flags 236 | - `-a`: Architecture of the backbone. (resnet50|vit_base_patch16_224) 237 | - `-d`: Dataset (OfficeHome|DomainNet) 238 | - `-s`: Source Domain 239 | - `-t`: Target Domain 240 | - `--epochs`: Number of Epochs to be trained for. 241 | - `--no-pool`: Use --no-pool for all experiments with ViT backbone. 242 | - `--log_name`: Name of the run on wandb. 243 | - `--gpu`: GPU id to use. 244 | - `--rho`: $\rho$ value in SDAT (Applicable only for SDAT runs). 245 | 246 | ## Acknowledgement 247 | Our implementation is based on the [Transfer Learning Library](https://github.com/thuml/Transfer-Learning-Library). We use the PyTorch implementation of SAM from https://github.com/davda54/sam. 248 | ## Citation 249 | If you find our paper or codebase useful, please consider citing us as: 250 | ```latex 251 | @InProceedings{rangwani2022closer, 252 | title={A Closer Look at Smoothness in Domain Adversarial Training}, 253 | author={Rangwani, Harsh and Aithal, Sumukh K and Mishra, Mayank and Jain, Arihant and Babu, R. Venkatesh}, 254 | booktitle={Proceedings of the 39th International Conference on Machine Learning}, 255 | year={2022} 256 | } 257 | ``` 258 | -------------------------------------------------------------------------------- /assets/sdat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/assets/sdat.png -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['modules', 'utils', 'vision'] 2 | -------------------------------------------------------------------------------- /common/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import * 2 | 3 | __all__ = ['classifier'] 4 | -------------------------------------------------------------------------------- /common/modules/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, List, Dict 2 | import torch.nn as nn 3 | import torch 4 | 5 | __all__ = ['Classifier'] 6 | 7 | 8 | class Classifier(nn.Module): 9 | """A generic Classifier class for domain adaptation. 
10 | 11 | Args: 12 | backbone (torch.nn.Module): Any backbone to extract 2-d features from data 13 | num_classes (int): Number of classes 14 | bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default 15 | bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1 16 | head (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default 17 | finetune (bool): Whether finetune the classifier or train from scratch. Default: True 18 | freeze (bool) : Freeze the backbone and only train the classifier 19 | 20 | .. note:: 21 | Different classifiers are used in different domain adaptation algorithms to achieve better accuracy 22 | respectively, and we provide a suggested `Classifier` for different algorithms. 23 | Remember they are not the core of algorithms. You can implement your own `Classifier` and combine it with 24 | the domain adaptation algorithm in this algorithm library. 25 | 26 | .. note:: 27 | The learning rate of this classifier is set 10 times to that of the feature extractor for better accuracy 28 | by default. If you have other optimization strategies, please over-ride :meth:`~Classifier.get_parameters`. 29 | 30 | Inputs: 31 | - x (tensor): input data fed to `backbone` 32 | 33 | Outputs: 34 | - predictions: classifier's predictions 35 | - features: features after `bottleneck` layer and before `head` layer 36 | 37 | Shape: 38 | - Inputs: (minibatch, *) where * means, any number of additional dimensions 39 | - predictions: (minibatch, `num_classes`) 40 | - features: (minibatch, `features_dim`) 41 | 42 | """ 43 | 44 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None, 45 | bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None, finetune=True, pool_layer=None): 46 | super(Classifier, self).__init__() 47 | self.backbone = backbone 48 | self.num_classes = num_classes 49 | if pool_layer is None: 50 | self.pool_layer = nn.Sequential( 51 | nn.AdaptiveAvgPool2d(output_size=(1, 1)), 52 | nn.Flatten() 53 | ) 54 | else: 55 | self.pool_layer = pool_layer 56 | if bottleneck is None: 57 | self.bottleneck = nn.Identity() 58 | self._features_dim = backbone.out_features 59 | else: 60 | self.bottleneck = bottleneck 61 | print("[INFORMATION] The bottleneck dim is ", bottleneck_dim) 62 | assert bottleneck_dim > 0 63 | self._features_dim = bottleneck_dim 64 | 65 | 66 | if head is None: 67 | self.head = nn.Linear(self._features_dim, num_classes) 68 | else: 69 | self.head = head 70 | self.finetune = finetune 71 | 72 | @property 73 | def features_dim(self) -> int: 74 | """The dimension of features before the final `head` layer""" 75 | return self._features_dim 76 | 77 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 78 | """""" 79 | f = self.pool_layer(self.backbone(x)) 80 | f = self.bottleneck(f) 81 | predictions = self.head(f) 82 | if self.training: 83 | return predictions, f 84 | else: 85 | return predictions 86 | 87 | def get_parameters(self, base_lr=1.0) -> List[Dict]: 88 | """A parameter list which decides optimization hyper-parameters, 89 | such as the relative learning rate of each layer 90 | """ 91 | params = [ 92 | {"params": self.backbone.parameters(), "lr": 0.1*base_lr if self.finetune else 1.0 * base_lr}, 93 | {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, 94 | {"params": self.head.parameters(), "lr": 1.0 * base_lr}, 95 | ] 96 | 97 | return params 98 | 99 | 100 | class ImageClassifier(Classifier): 101 | pass 
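# --- Illustrative usage sketch (not part of this file) ---
# As the docstring above notes, `get_parameters` gives the bottleneck and head a 10x higher
# relative learning rate than a finetuned backbone. A hypothetical setup (the backbone
# constructor and hyper-parameter values below are assumptions, not values from this repo):
#
#     backbone = resnet50(pretrained=True)             # any module exposing `out_features`
#     classifier = ImageClassifier(backbone, num_classes=65)
#     param_groups = classifier.get_parameters(base_lr=1.0)
#     sgd = torch.optim.SGD(param_groups, lr=0.01, momentum=0.9, weight_decay=1e-3)
#     # For SDAT, the same parameter groups can instead be wrapped with the SAM optimizer
#     # from common/utils/sam.py, as shown in the README snippet.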
102 | -------------------------------------------------------------------------------- /common/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import CompleteLogger 2 | from .meter import * 3 | from .data import ForeverDataIterator 4 | 5 | __all__ = ['metric', 'analysis', 'meter', 'data', 'logger'] -------------------------------------------------------------------------------- /common/utils/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import torch.nn as nn 4 | import tqdm 5 | 6 | 7 | def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module, 8 | device: torch.device, max_num_features=None) -> torch.Tensor: 9 | """ 10 | Fetch data from `data_loader`, and then use `feature_extractor` to collect features 11 | 12 | Args: 13 | data_loader (torch.utils.data.DataLoader): Data loader. 14 | feature_extractor (torch.nn.Module): A feature extractor. 15 | device (torch.device) 16 | max_num_features (int): The max number of features to return 17 | 18 | Returns: 19 | Features in shape (min(len(data_loader), max_num_features * mini-batch size), :math:`|\mathcal{F}|`). 20 | """ 21 | feature_extractor.eval() 22 | all_features = [] 23 | with torch.no_grad(): 24 | for i, (images, target) in enumerate(tqdm.tqdm(data_loader)): 25 | if max_num_features is not None and i >= max_num_features: 26 | break 27 | images = images.to(device) 28 | feature = feature_extractor(images).cpu() 29 | all_features.append(feature) 30 | return torch.cat(all_features, dim=0) 31 | -------------------------------------------------------------------------------- /common/utils/analysis/a_distance.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from torch.utils.data import TensorDataset 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from torch.optim import SGD 11 | from ..meter import AverageMeter 12 | from ..metric import binary_accuracy 13 | 14 | 15 | class ANet(nn.Module): 16 | def __init__(self, in_feature): 17 | super(ANet, self).__init__() 18 | self.layer = nn.Linear(in_feature, 1) 19 | self.sigmoid = nn.Sigmoid() 20 | 21 | def forward(self, x): 22 | x = self.layer(x) 23 | x = self.sigmoid(x) 24 | return x 25 | 26 | 27 | def calculate(source_feature: torch.Tensor, target_feature: torch.Tensor, 28 | device, progress=True, training_epochs=10): 29 | """ 30 | Calculate the :math:`\mathcal{A}`-distance, which is a measure for distribution discrepancy. 31 | 32 | The definition is :math:`dist_\mathcal{A} = 2 (1-2\epsilon)`, where :math:`\epsilon` is the 33 | test error of a classifier trained to discriminate the source from the target. 
34 | 35 | Args: 36 | source_feature (tensor): features from source domain in shape :math:`(minibatch, F)` 37 | target_feature (tensor): features from target domain in shape :math:`(minibatch, F)` 38 | device (torch.device) 39 | progress (bool): if True, displays a the progress of training A-Net 40 | training_epochs (int): the number of epochs when training the classifier 41 | 42 | Returns: 43 | :math:`\mathcal{A}`-distance 44 | """ 45 | source_label = torch.ones((source_feature.shape[0], 1)) 46 | target_label = torch.zeros((target_feature.shape[0], 1)) 47 | feature = torch.cat([source_feature, target_feature], dim=0) 48 | label = torch.cat([source_label, target_label], dim=0) 49 | 50 | dataset = TensorDataset(feature, label) 51 | length = len(dataset) 52 | train_size = int(0.8 * length) 53 | val_size = length - train_size 54 | train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size]) 55 | train_loader = DataLoader(train_set, batch_size=2, shuffle=True) 56 | val_loader = DataLoader(val_set, batch_size=8, shuffle=False) 57 | 58 | anet = ANet(feature.shape[1]).to(device) 59 | optimizer = SGD(anet.parameters(), lr=0.01) 60 | a_distance = 2.0 61 | for epoch in range(training_epochs): 62 | anet.train() 63 | for (x, label) in train_loader: 64 | x = x.to(device) 65 | label = label.to(device) 66 | anet.zero_grad() 67 | y = anet(x) 68 | loss = F.binary_cross_entropy(y, label) 69 | loss.backward() 70 | optimizer.step() 71 | 72 | anet.eval() 73 | meter = AverageMeter("accuracy", ":4.2f") 74 | with torch.no_grad(): 75 | for (x, label) in val_loader: 76 | x = x.to(device) 77 | label = label.to(device) 78 | y = anet(x) 79 | acc = binary_accuracy(y, label) 80 | meter.update(acc, x.shape[0]) 81 | error = 1 - meter.avg / 100 82 | a_distance = 2 * (1 - 2 * error) 83 | if progress: 84 | print("epoch {} accuracy: {} A-dist: {}".format(epoch, meter.avg, a_distance)) 85 | 86 | return a_distance 87 | 88 | -------------------------------------------------------------------------------- /common/utils/analysis/tsne.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch 6 | import matplotlib 7 | 8 | matplotlib.use('Agg') 9 | from sklearn.manifold import TSNE 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import matplotlib.colors as col 13 | 14 | 15 | def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor, 16 | filename: str, source_color='r', target_color='b'): 17 | """ 18 | Visualize features from different domains using t-SNE. 19 | 20 | Args: 21 | source_feature (tensor): features from source domain in shape :math:`(minibatch, F)` 22 | target_feature (tensor): features from target domain in shape :math:`(minibatch, F)` 23 | filename (str): the file name to save t-SNE 24 | source_color (str): the color of the source features. Default: 'r' 25 | target_color (str): the color of the target features. 
Default: 'b' 26 | 27 | """ 28 | source_feature = source_feature.numpy() 29 | target_feature = target_feature.numpy() 30 | features = np.concatenate([source_feature, target_feature], axis=0) 31 | 32 | # map features to 2-d using TSNE 33 | X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features) 34 | 35 | # domain labels, 1 represents source while 0 represents target 36 | domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature)))) 37 | 38 | # visualize using matplotlib 39 | fig, ax = plt.subplots(figsize=(10, 10)) 40 | ax.spines['top'].set_visible(False) 41 | ax.spines['right'].set_visible(False) 42 | ax.spines['bottom'].set_visible(False) 43 | ax.spines['left'].set_visible(False) 44 | plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap([target_color, source_color]), s=20) 45 | plt.xticks([]) 46 | plt.yticks([]) 47 | plt.savefig(filename) 48 | -------------------------------------------------------------------------------- /common/utils/data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import random 3 | import numpy as np 4 | 5 | import torch 6 | from torch.utils.data import Sampler 7 | from torch.utils.data import DataLoader, Dataset 8 | from typing import TypeVar, Iterable, Dict, List 9 | 10 | T_co = TypeVar('T_co', covariant=True) 11 | T = TypeVar('T') 12 | 13 | 14 | def send_to_device(tensor, device): 15 | """ 16 | Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device. 17 | 18 | Args: 19 | tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`): 20 | The data to send to a given device. 21 | device (:obj:`torch.device`): 22 | The device to send the data to 23 | 24 | Returns: 25 | The same data structure as :obj:`tensor` with all tensors sent to the proper device. 26 | """ 27 | if isinstance(tensor, (list, tuple)): 28 | return type(tensor)(send_to_device(t, device) for t in tensor) 29 | elif isinstance(tensor, dict): 30 | return type(tensor)({k: send_to_device(v, device) for k, v in tensor.items()}) 31 | elif not hasattr(tensor, "to"): 32 | return tensor 33 | return tensor.to(device) 34 | 35 | 36 | class ForeverDataIterator: 37 | r"""A data iterator that will never stop producing data""" 38 | 39 | def __init__(self, data_loader: DataLoader, device=None): 40 | self.data_loader = data_loader 41 | self.iter = iter(self.data_loader) 42 | self.device = device 43 | 44 | def __next__(self): 45 | try: 46 | data = next(self.iter) 47 | if self.device is not None: 48 | data = send_to_device(data, self.device) 49 | except StopIteration: 50 | self.iter = iter(self.data_loader) 51 | data = next(self.iter) 52 | if self.device is not None: 53 | data = send_to_device(data, self.device) 54 | return data 55 | 56 | def __len__(self): 57 | return len(self.data_loader) 58 | 59 | 60 | class RandomMultipleGallerySampler(Sampler): 61 | r"""Sampler from `In defense of the Triplet Loss for Person Re-Identification 62 | (ICCV 2017) `_. Assume there are :math:`N` identities in the dataset, this 63 | implementation simply samples :math:`K` images for every identity to form an iter of size :math:`N\times K`. During 64 | training, we will call ``__iter__`` method of pytorch dataloader once we reach a ``StopIteration``, this guarantees 65 | every image in the dataset will eventually be selected and we are not wasting any training data. 
66 | 67 | Args: 68 | dataset(list): each element of this list is a tuple (image_path, person_id, camera_id) 69 | num_instances(int, optional): number of images to sample for every identity (:math:`K` here) 70 | """ 71 | 72 | def __init__(self, dataset, num_instances=4): 73 | super(RandomMultipleGallerySampler, self).__init__(dataset) 74 | self.dataset = dataset 75 | self.num_instances = num_instances 76 | 77 | self.idx_to_pid = {} 78 | self.cid_list_per_pid = {} 79 | self.idx_list_per_pid = {} 80 | 81 | for idx, (_, pid, cid) in enumerate(dataset): 82 | if pid not in self.cid_list_per_pid: 83 | self.cid_list_per_pid[pid] = [] 84 | self.idx_list_per_pid[pid] = [] 85 | 86 | self.idx_to_pid[idx] = pid 87 | self.cid_list_per_pid[pid].append(cid) 88 | self.idx_list_per_pid[pid].append(idx) 89 | 90 | self.pid_list = list(self.idx_list_per_pid.keys()) 91 | self.num_samples = len(self.pid_list) 92 | 93 | def __len__(self): 94 | return self.num_samples * self.num_instances 95 | 96 | def __iter__(self): 97 | def select_idxes(element_list, target_element): 98 | assert isinstance(element_list, list) 99 | return [i for i, element in enumerate(element_list) if element != target_element] 100 | 101 | pid_idxes = torch.randperm(len(self.pid_list)).tolist() 102 | final_idxes = [] 103 | 104 | for perm_id in pid_idxes: 105 | i = random.choice(self.idx_list_per_pid[self.pid_list[perm_id]]) 106 | _, _, cid = self.dataset[i] 107 | 108 | final_idxes.append(i) 109 | 110 | pid_i = self.idx_to_pid[i] 111 | cid_list = self.cid_list_per_pid[pid_i] 112 | idx_list = self.idx_list_per_pid[pid_i] 113 | selected_cid_list = select_idxes(cid_list, cid) 114 | 115 | if selected_cid_list: 116 | if len(selected_cid_list) >= self.num_instances: 117 | cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=False) 118 | else: 119 | cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=True) 120 | for cid_idx in cid_idxes: 121 | final_idxes.append(idx_list[cid_idx]) 122 | else: 123 | selected_idxes = select_idxes(idx_list, i) 124 | if not selected_idxes: 125 | continue 126 | if len(selected_idxes) >= self.num_instances: 127 | pid_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=False) 128 | else: 129 | pid_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=True) 130 | 131 | for pid_idx in pid_idxes: 132 | final_idxes.append(idx_list[pid_idx]) 133 | 134 | return iter(final_idxes) 135 | 136 | 137 | class CombineDataset(Dataset[T_co]): 138 | r"""Dataset as a combination of multiple datasets. 139 | The element of each dataset must be a list, and the i-th element of the combined dataset 140 | is a list splicing of the i-th element of each sub dataset. 141 | The length of the combined dataset is the minimum of the lengths of all sub datasets. 142 | 143 | Arguments: 144 | datasets (sequence): List of datasets to be concatenated 145 | """ 146 | 147 | def __init__(self, datasets: Iterable[Dataset]) -> None: 148 | super(CombineDataset, self).__init__() 149 | # Cannot verify that datasets is Sized 150 | assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore 151 | self.datasets = list(datasets) 152 | 153 | def __len__(self): 154 | return min([len(d) for d in self.datasets]) 155 | 156 | def __getitem__(self, idx): 157 | return list(itertools.chain(*[d[idx] for d in self.datasets])) 158 | 159 | 160 | def concatenate(tensors): 161 | """concatenate multiple batches into one batch. 
162 | ``tensors`` can be :class:`torch.Tensor`, List or Dict, but they must be the same data format. 163 | """ 164 | if isinstance(tensors[0], torch.Tensor): 165 | return torch.cat(tensors, dim=0) 166 | elif isinstance(tensors[0], List): 167 | ret = [] 168 | for i in range(len(tensors[0])): 169 | ret.append(concatenate([t[i] for t in tensors])) 170 | return ret 171 | elif isinstance(tensors[0], Dict): 172 | ret = dict() 173 | for k in tensors[0].keys(): 174 | ret[k] = concatenate([t[k] for t in tensors]) 175 | return ret 176 | -------------------------------------------------------------------------------- /common/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | class TextLogger(object): 6 | """Writes stream output to external text file. 7 | 8 | Args: 9 | filename (str): the file to write stream output 10 | stream: the stream to read from. Default: sys.stdout 11 | """ 12 | def __init__(self, filename, stream=sys.stdout): 13 | self.terminal = stream 14 | self.log = open(filename, 'a') 15 | 16 | def write(self, message): 17 | self.terminal.write(message) 18 | self.log.write(message) 19 | self.flush() 20 | 21 | def flush(self): 22 | self.terminal.flush() 23 | self.log.flush() 24 | 25 | def close(self): 26 | self.terminal.close() 27 | self.log.close() 28 | 29 | 30 | class CompleteLogger: 31 | """ 32 | A useful logger that 33 | 34 | - writes outputs to files and displays them on the console at the same time. 35 | - manages the directory of checkpoints and debugging images. 36 | 37 | Args: 38 | root (str): the root directory of logger 39 | phase (str): the phase of training. 40 | 41 | """ 42 | 43 | def __init__(self, root, phase='train'): 44 | self.root = root 45 | self.phase = phase 46 | self.visualize_directory = os.path.join(self.root, "visualize") 47 | self.checkpoint_directory = os.path.join(self.root, "checkpoints") 48 | self.epoch = 0 49 | 50 | os.makedirs(self.root, exist_ok=True) 51 | os.makedirs(self.visualize_directory, exist_ok=True) 52 | os.makedirs(self.checkpoint_directory, exist_ok=True) 53 | 54 | # redirect std out 55 | now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time())) 56 | log_filename = os.path.join(self.root, "{}-{}.txt".format(phase, now)) 57 | if os.path.exists(log_filename): 58 | os.remove(log_filename) 59 | self.logger = TextLogger(log_filename) 60 | sys.stdout = self.logger 61 | sys.stderr = self.logger 62 | if phase != 'train': 63 | self.set_epoch(phase) 64 | 65 | def set_epoch(self, epoch): 66 | """Set the epoch number. Please use it during training.""" 67 | os.makedirs(os.path.join(self.visualize_directory, str(epoch)), exist_ok=True) 68 | self.epoch = epoch 69 | 70 | def _get_phase_or_epoch(self): 71 | if self.phase == 'train': 72 | return str(self.epoch) 73 | else: 74 | return self.phase 75 | 76 | def get_image_path(self, filename: str): 77 | """ 78 | Get the full image path for a specific filename 79 | """ 80 | return os.path.join(self.visualize_directory, self._get_phase_or_epoch(), filename) 81 | 82 | def get_checkpoint_path(self, name=None): 83 | """ 84 | Get the full checkpoint path. 85 | 86 | Args: 87 | name (optional): the filename (without file extension) to save checkpoint. 88 | If None, when the phase is ``train``, checkpoint will be saved to ``{epoch}.pth``. 89 | Otherwise, will be saved to ``{phase}.pth``. 
90 | 91 | """ 92 | if name is None: 93 | name = self._get_phase_or_epoch() 94 | name = str(name) 95 | return os.path.join(self.checkpoint_directory, name + ".pth") 96 | 97 | def close(self): 98 | self.logger.close() 99 | -------------------------------------------------------------------------------- /common/utils/meter.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | 4 | class AverageMeter(object): 5 | r"""Computes and stores the average and current value. 6 | 7 | Examples:: 8 | 9 | >>> # Initialize a meter to record loss 10 | >>> losses = AverageMeter() 11 | >>> # Update meter after every minibatch update 12 | >>> losses.update(loss_value, batch_size) 13 | """ 14 | def __init__(self, name: str, fmt: Optional[str] = ':f'): 15 | self.name = name 16 | self.fmt = fmt 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | if self.count > 0: 30 | self.avg = self.sum / self.count 31 | 32 | def __str__(self): 33 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 34 | return fmtstr.format(**self.__dict__) 35 | 36 | 37 | class AverageMeterDict(object): 38 | def __init__(self, names: List, fmt: Optional[str] = ':f'): 39 | self.dict = { 40 | name: AverageMeter(name, fmt) for name in names 41 | } 42 | 43 | def reset(self): 44 | for meter in self.dict.values(): 45 | meter.reset() 46 | 47 | def update(self, accuracies, n=1): 48 | for name, acc in accuracies.items(): 49 | self.dict[name].update(acc, n) 50 | 51 | def average(self): 52 | return { 53 | name: meter.avg for name, meter in self.dict.items() 54 | } 55 | 56 | def __getitem__(self, item): 57 | return self.dict[item] 58 | 59 | 60 | class Meter(object): 61 | """Computes and stores the current value.""" 62 | def __init__(self, name: str, fmt: Optional[str] = ':f'): 63 | self.name = name 64 | self.fmt = fmt 65 | self.reset() 66 | 67 | def reset(self): 68 | self.val = 0 69 | 70 | def update(self, val): 71 | self.val = val 72 | 73 | def __str__(self): 74 | fmtstr = '{name} {val' + self.fmt + '}' 75 | return fmtstr.format(**self.__dict__) 76 | 77 | 78 | class ProgressMeter(object): 79 | def __init__(self, num_batches, meters, prefix=""): 80 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 81 | self.meters = meters 82 | self.prefix = prefix 83 | 84 | def display(self, batch): 85 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 86 | entries += [str(meter) for meter in self.meters] 87 | print('\t'.join(entries)) 88 | 89 | def _get_batch_fmtstr(self, num_batches): 90 | num_digits = len(str(num_batches // 1)) 91 | fmt = '{:' + str(num_digits) + 'd}' 92 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 93 | 94 | 95 | -------------------------------------------------------------------------------- /common/utils/metric/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import prettytable 3 | 4 | __all__ = ['keypoint_detection'] 5 | 6 | def binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> float: 7 | """Computes the accuracy for binary classification""" 8 | with torch.no_grad(): 9 | batch_size = target.size(0) 10 | pred = (output >= 0.5).float().t().view(-1) 11 | correct = pred.eq(target.view(-1)).float().sum() 12 | correct.mul_(100. 
/ batch_size) 13 | return correct 14 | 15 | 16 | def accuracy(output, target, topk=(1,)): 17 | r""" 18 | Computes the accuracy over the k top predictions for the specified values of k 19 | 20 | Args: 21 | output (tensor): Classification outputs, :math:`(N, C)` where `C = number of classes` 22 | target (tensor): :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1` 23 | topk (sequence[int]): A list of top-N number. 24 | 25 | Returns: 26 | Top-N accuracies (N :math:`\in` topK). 27 | """ 28 | with torch.no_grad(): 29 | maxk = max(topk) 30 | batch_size = target.size(0) 31 | 32 | _, pred = output.topk(maxk, 1, True, True) 33 | pred = pred.t() 34 | correct = pred.eq(target[None]) 35 | 36 | res = [] 37 | for k in topk: 38 | correct_k = correct[:k].flatten().sum(dtype=torch.float32) 39 | res.append(correct_k * (100.0 / batch_size)) 40 | return res 41 | 42 | 43 | class ConfusionMatrix(object): 44 | def __init__(self, num_classes): 45 | self.num_classes = num_classes 46 | self.mat = None 47 | 48 | def update(self, target, output): 49 | """ 50 | Update confusion matrix. 51 | 52 | Args: 53 | target: ground truth 54 | output: predictions of models 55 | 56 | Shape: 57 | - target: :math:`(minibatch, C)` where C means the number of classes. 58 | - output: :math:`(minibatch, C)` where C means the number of classes. 59 | """ 60 | n = self.num_classes 61 | if self.mat is None: 62 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device) 63 | with torch.no_grad(): 64 | k = (target >= 0) & (target < n) 65 | inds = n * target[k].to(torch.int64) + output[k] 66 | self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) 67 | 68 | def reset(self): 69 | self.mat.zero_() 70 | 71 | def compute(self): 72 | """compute global accuracy, per-class accuracy and per-class IoU""" 73 | h = self.mat.float() 74 | acc_global = torch.diag(h).sum() / h.sum() 75 | acc = torch.diag(h) / h.sum(1) 76 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 77 | return acc_global, acc, iu 78 | 79 | # def reduce_from_all_processes(self): 80 | # if not torch.distributed.is_available(): 81 | # return 82 | # if not torch.distributed.is_initialized(): 83 | # return 84 | # torch.distributed.barrier() 85 | # torch.distributed.all_reduce(self.mat) 86 | 87 | def __str__(self): 88 | acc_global, acc, iu = self.compute() 89 | return ( 90 | 'global correct: {:.1f}\n' 91 | 'average row correct: {}\n' 92 | 'IoU: {}\n' 93 | 'mean IoU: {:.1f}').format( 94 | acc_global.item() * 100, 95 | ['{:.1f}'.format(i) for i in (acc * 100).tolist()], 96 | ['{:.1f}'.format(i) for i in (iu * 100).tolist()], 97 | iu.mean().item() * 100) 98 | 99 | def format(self, classes: list): 100 | """Get the accuracy and IoU for each class in the table format""" 101 | acc_global, acc, iu = self.compute() 102 | 103 | table = prettytable.PrettyTable(["class", "acc", "iou"]) 104 | for i, class_name, per_acc, per_iu in zip(range(len(classes)), classes, (acc * 100).tolist(), (iu * 100).tolist()): 105 | table.add_row([class_name, per_acc, per_iu]) 106 | 107 | return 'global correct: {:.1f}\nmean correct:{:.1f}\nmean IoU: {:.1f}\n{}'.format( 108 | acc_global.item() * 100, acc.mean().item() * 100, iu.mean().item() * 100, table.get_string()) 109 | 110 | -------------------------------------------------------------------------------- /common/utils/sam.py: -------------------------------------------------------------------------------- 1 | # Credits: https://github.com/davda54/sam 2 | 3 | import torch 4 | 5 | class SAM(torch.optim.Optimizer): 6 | 
def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 7 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 8 | 9 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 10 | super(SAM, self).__init__(params, defaults) 11 | 12 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 13 | self.param_groups = self.base_optimizer.param_groups 14 | 15 | @torch.no_grad() 16 | def first_step(self, zero_grad=False): 17 | grad_norm = self._grad_norm() 18 | for group in self.param_groups: 19 | scale = group["rho"] / (grad_norm + 1e-12) 20 | 21 | for p in group["params"]: 22 | if p.grad is None: continue 23 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 24 | p.add_(e_w) # climb to the local maximum "w + e(w)" 25 | self.state[p]["e_w"] = e_w 26 | 27 | if zero_grad: self.zero_grad() 28 | 29 | @torch.no_grad() 30 | def second_step(self, zero_grad=False): 31 | for group in self.param_groups: 32 | for p in group["params"]: 33 | if p.grad is None: continue 34 | p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)" 35 | 36 | self.base_optimizer.step() # do the actual "sharpness-aware" update 37 | 38 | if zero_grad: self.zero_grad() 39 | 40 | @torch.no_grad() 41 | def step(self, closure=None): 42 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 43 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 44 | 45 | self.first_step(zero_grad=True) 46 | closure() 47 | self.second_step() 48 | 49 | def _grad_norm(self): 50 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 51 | norm = torch.norm( 52 | torch.stack([ 53 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 54 | for group in self.param_groups for p in group["params"] 55 | if p.grad is not None 56 | ]), 57 | p=2 58 | ) 59 | return norm 60 | 61 | -------------------------------------------------------------------------------- /common/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bisect import bisect_right 3 | 4 | 5 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 6 | r"""Starts with a warm-up phase, then decays the learning rate of each parameter group by gamma once the 7 | number of epoch reaches one of the milestones. When last_epoch=-1, sets initial lr as lr. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | milestones (list): List of epoch indices. Must be increasing. 12 | gamma (float): Multiplicative factor of learning rate decay. 13 | Default: 0.1. 14 | warmup_factor (float): a float number :math:`k` between 0 and 1, the start learning rate of warmup phase 15 | will be set to :math:`k*initial\_lr` 16 | warmup_steps (int): number of warm-up steps. 17 | warmup_method (str): "constant" denotes a constant learning rate during warm-up phase and "linear" denotes a 18 | linear-increasing learning rate during warm-up phase. 19 | last_epoch (int): The index of last epoch. Default: -1. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | optimizer, 25 | milestones, 26 | gamma=0.1, 27 | warmup_factor=1.0 / 3, 28 | warmup_steps=500, 29 | warmup_method="linear", 30 | last_epoch=-1, 31 | ): 32 | if not list(milestones) == sorted(milestones): 33 | raise ValueError( 34 | "Milestones should be a list of" " increasing integers. 
Got {}", 35 | milestones, 36 | ) 37 | 38 | if warmup_method not in ("constant", "linear"): 39 | raise ValueError( 40 | "Only 'constant' or 'linear' warmup_method accepted" 41 | "got {}".format(warmup_method) 42 | ) 43 | self.milestones = milestones 44 | self.gamma = gamma 45 | self.warmup_factor = warmup_factor 46 | self.warmup_steps = warmup_steps 47 | self.warmup_method = warmup_method 48 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 49 | 50 | def get_lr(self): 51 | warmup_factor = 1 52 | if self.last_epoch < self.warmup_steps: 53 | if self.warmup_method == "constant": 54 | warmup_factor = self.warmup_factor 55 | elif self.warmup_method == "linear": 56 | alpha = float(self.last_epoch) / float(self.warmup_steps) 57 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 58 | return [ 59 | base_lr 60 | * warmup_factor 61 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 62 | for base_lr in self.base_lrs 63 | ] 64 | -------------------------------------------------------------------------------- /common/vision/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .officehome import OfficeHome 2 | from .visda2017 import VisDA2017 3 | from .domainnet import DomainNet 4 | 5 | __all__ = ['OfficeHome', "VisDA2017", "DomainNet"] 6 | 7 | 8 | -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/_util.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/_util.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/_util.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/_util.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/aircrafts.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/aircrafts.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/aircrafts.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/aircrafts.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/aircrafts.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/aircrafts.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/coco70.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/coco70.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/coco70.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/coco70.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/coco70.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/coco70.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/cub200.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/cub200.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/cub200.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/cub200.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/cub200.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/cub200.cpython-39.pyc -------------------------------------------------------------------------------- 
/common/vision/datasets/__pycache__/digits.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/digits.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/digits.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/digits.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/digits.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/digits.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/domainnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/domainnet.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/domainnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/domainnet.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/domainnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/domainnet.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/dtd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/dtd.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/dtd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/dtd.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/dtd.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/dtd.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/eurosat.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/eurosat.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/eurosat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/eurosat.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/eurosat.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/eurosat.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagelist.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagelist.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagelist.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagelist.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagelist.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagelist.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagenet_r.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_r.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagenet_r.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_r.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagenet_r.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_r.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagenet_sketch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_sketch.cpython-36.pyc 
-------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagenet_sketch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_sketch.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/imagenet_sketch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_sketch.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/office31.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/office31.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/office31.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/office31.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/office31.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/office31.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/officecaltech.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officecaltech.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/officecaltech.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officecaltech.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/officecaltech.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officecaltech.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/officehome.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officehome.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/officehome.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officehome.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/officehome.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officehome.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/oxfordflowers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordflowers.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/oxfordflowers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordflowers.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/oxfordflowers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordflowers.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/oxfordpet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordpet.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/oxfordpet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordpet.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/oxfordpet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordpet.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/pacs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/pacs.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/pacs.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/pacs.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/pacs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/pacs.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/patchcamelyon.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/patchcamelyon.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/patchcamelyon.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/patchcamelyon.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/patchcamelyon.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/patchcamelyon.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/resisc45.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/resisc45.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/resisc45.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/resisc45.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/resisc45.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/resisc45.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/retinopathy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/retinopathy.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/retinopathy.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/retinopathy.cpython-38.pyc 
-------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/retinopathy.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/retinopathy.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/stanford_cars.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_cars.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/stanford_cars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_cars.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/stanford_cars.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_cars.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/stanford_dogs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_dogs.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/stanford_dogs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_dogs.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/stanford_dogs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_dogs.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/visda2017.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/visda2017.cpython-36.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/visda2017.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/visda2017.cpython-38.pyc -------------------------------------------------------------------------------- /common/vision/datasets/__pycache__/visda2017.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/visda2017.cpython-39.pyc -------------------------------------------------------------------------------- /common/vision/datasets/_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | from torchvision.datasets.utils import download_and_extract_archive 4 | 5 | 6 | def download(root: str, file_name: str, archive_name: str, url_link: str): 7 | """ 8 | Download a file from an internet URL link. 9 | 10 | Args: 11 | root (str): The directory to put downloaded files. 12 | file_name (str): The name of the unzipped file. 13 | archive_name (str): The name of the downloaded archive (zipped file). 14 | url_link (str): The URL to download data from. 15 | 16 | .. note:: 17 | If `file_name` already exists under path `root`, then it is not downloaded again. 18 | Else `archive_name` will be downloaded from `url_link` and extracted to `file_name`. 19 | """ 20 | if not os.path.exists(os.path.join(root, file_name)): 21 | print("Downloading {}".format(file_name)) 22 | # if os.path.exists(os.path.join(root, archive_name)): 23 | # os.remove(os.path.join(root, archive_name)) 24 | try: 25 | download_and_extract_archive(url_link, download_root=root, filename=archive_name, remove_finished=False) 26 | except Exception: 27 | print("Failed to download {} from url link {}".format(archive_name, url_link)) 28 | print('Please check your internet connection. ' 29 | "Simply trying again may be fine.") 30 | exit(0) 31 | 32 | 33 | def check_exits(root: str, file_name: str): 34 | """Check whether `file_name` exists under directory `root`. """ 35 | if not os.path.exists(os.path.join(root, file_name)): 36 | print("Dataset directory {} not found under {}".format(file_name, root)) 37 | exit(-1) 38 | 39 | 40 | def read_list_from_file(file_name: str) -> List[str]: 41 | """Read data from file and convert each line into an element in the list""" 42 | result = [] 43 | with open(file_name, "r") as f: 44 | for line in f.readlines(): 45 | result.append(line.strip()) 46 | return result 47 | -------------------------------------------------------------------------------- /common/vision/datasets/domainnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from .imagelist import ImageList 4 | from ._util import download as download_data, check_exits 5 | 6 | 7 | class DomainNet(ImageList): 8 | """`DomainNet `_ (cleaned version, recommended) 9 | 10 | See `Moment Matching for Multi-Source Domain Adaptation `_ for details. 11 | 12 | Args: 13 | root (str): Root directory of dataset 14 | task (str): The task (domain) to create dataset. Choices include ``'c'``: clipart, \ 15 | ``'i'``: infograph, ``'p'``: painting, ``'q'``: quickdraw, ``'r'``: real, ``'s'``: sketch 16 | split (str, optional): The dataset split, supports ``train`` or ``test``. 17 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 18 | in root directory. If dataset is already downloaded, it is not downloaded again. 19 | transform (callable, optional): A function/transform that takes in a PIL image and returns a \ 20 | transformed version. E.g., :class:`torchvision.transforms.RandomCrop`. 21 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
22 | 23 | .. note:: In `root`, there will exist following files after downloading. 24 | :: 25 | clipart/ 26 | infograph/ 27 | painting/ 28 | quickdraw/ 29 | real/ 30 | sketch/ 31 | image_list/ 32 | clipart.txt 33 | ... 34 | """ 35 | download_list = [ 36 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/bf0fe327e4b046eb89ba/?dl=1"), 37 | ("clipart", "clipart.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip"), 38 | ("infograph", "infograph.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip"), 39 | ("painting", "painting.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip"), 40 | ("quickdraw", "quickdraw.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip"), 41 | ("real", "real.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip"), 42 | ("sketch", "sketch.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip"), 43 | ] 44 | image_list = { 45 | "c": "clipart", 46 | "i": "infograph", 47 | "p": "painting", 48 | "q": "quickdraw", 49 | "r": "real", 50 | "s": "sketch", 51 | } 52 | CLASSES = ['aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil', 53 | 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball_bat', 54 | 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 55 | 'bicycle', 'binoculars', 'bird', 'birthday_cake', 'blackberry', 'blueberry', 'book', 'boomerang', 56 | 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 57 | 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 58 | 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling_fan', 59 | 'cello', 'cell_phone', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 60 | 'coffee_cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 61 | 'crown', 'cruise_ship', 'cup', 'diamond', 'dishwasher', 'diving_board', 'dog', 'dolphin', 'donut', 62 | 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 63 | 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire_hydrant', 64 | 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip_flops', 'floor_lamp', 'flower', 65 | 'flying_saucer', 'foot', 'fork', 'frog', 'frying_pan', 'garden', 'garden_hose', 'giraffe', 'goatee', 66 | 'golf_club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 67 | 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey_puck', 'hockey_stick', 'horse', 'hospital', 68 | 'hot_air_balloon', 'hot_dog', 'hot_tub', 'hourglass', 'house', 'house_plant', 'hurricane', 'ice_cream', 69 | 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf', 70 | 'leg', 'light_bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 71 | 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 72 | 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 73 | 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paintbrush', 74 | 'paint_can', 'palm_tree', 'panda', 'pants', 'paper_clip', 'parachute', 'parrot', 'passport', 'peanut', 75 
| 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup_truck', 'picture_frame', 'pig', 'pillow', 76 | 'pineapple', 'pizza', 'pliers', 'police_car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 77 | 'power_outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote_control', 78 | 'rhinoceros', 'rifle', 'river', 'roller_coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 79 | 'saxophone', 'school_bus', 'scissors', 'scorpion', 'screwdriver', 'sea_turtle', 'see_saw', 'shark', 80 | 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping_bag', 81 | 'smiley_face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer_ball', 'sock', 'speedboat', 82 | 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 83 | 'stethoscope', 'stitches', 'stop_sign', 'stove', 'strawberry', 'streetlight', 'string_bean', 'submarine', 84 | 'suitcase', 'sun', 'swan', 'sweater', 'swing_set', 'sword', 'syringe', 'table', 'teapot', 'teddy-bear', 85 | 'telephone', 'television', 'tennis_racquet', 'tent', 'The_Eiffel_Tower', 'The_Great_Wall_of_China', 86 | 'The_Mona_Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 87 | 'tractor', 'traffic_light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 't-shirt', 88 | 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing_machine', 'watermelon', 'waterslide', 89 | 'whale', 'wheel', 'windmill', 'wine_bottle', 'wine_glass', 'wristwatch', 'yoga', 'zebra', 'zigzag'] 90 | 91 | def __init__(self, root: str, task: str, split: Optional[str] = 'train', download: Optional[bool] = False, **kwargs): 92 | assert task in self.image_list 93 | assert split in ['train', 'test'] 94 | data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split)) 95 | print("loading {}".format(data_list_file)) 96 | 97 | if download: 98 | list(map(lambda args: download_data(root, *args), self.download_list)) 99 | else: 100 | list(map(lambda args: check_exits(root, args[0]), self.download_list)) 101 | 102 | super(DomainNet, self).__init__(root, DomainNet.CLASSES, data_list_file=data_list_file, **kwargs) 103 | 104 | @classmethod 105 | def domains(cls): 106 | return list(cls.image_list.keys()) 107 | -------------------------------------------------------------------------------- /common/vision/datasets/imagelist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Callable, Tuple, Any, List 3 | import torchvision.datasets as datasets 4 | from torchvision.datasets.folder import default_loader 5 | 6 | 7 | class ImageList(datasets.VisionDataset): 8 | """A generic Dataset class for image classification 9 | 10 | Args: 11 | root (str): Root directory of dataset 12 | classes (list[str]): The names of all the classes 13 | data_list_file (str): File to read the image list from. 14 | transform (callable, optional): A function/transform that takes in a PIL image \ 15 | and returns a transformed version. E.g., :class:`torchvision.transforms.RandomCrop`. 16 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 17 | 18 | .. note:: In `data_list_file`, each line has 2 values in the following format.
19 | :: 20 | source_dir/dog_xxx.png 0 21 | source_dir/cat_123.png 1 22 | target_dir/dog_xxy.png 0 23 | target_dir/cat_nsdf3.png 1 24 | 25 | The first value is the relative path of an image, and the second value is the label of the corresponding image. 26 | If your data_list_file has a different format, please override :meth:`~ImageList.parse_data_file`. 27 | """ 28 | 29 | def __init__(self, root: str, classes: List[str], data_list_file: str, 30 | transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): 31 | super().__init__(root, transform=transform, target_transform=target_transform) 32 | self.samples = self.parse_data_file(data_list_file) 33 | self.classes = classes 34 | self.class_to_idx = {cls: idx 35 | for idx, cls in enumerate(self.classes)} 36 | self.loader = default_loader 37 | self.data_list_file = data_list_file 38 | 39 | def __getitem__(self, index: int) -> Tuple[Any, int]: 40 | """ 41 | Args: 42 | index (int): Index 43 | return (tuple): (image, target) where target is index of the target class. 44 | """ 45 | path, target = self.samples[index] 46 | img = self.loader(path) 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | if self.target_transform is not None and target is not None: 50 | target = self.target_transform(target) 51 | return img, target 52 | 53 | def __len__(self) -> int: 54 | return len(self.samples) 55 | 56 | def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]: 57 | """Parse file to data list 58 | 59 | Args: 60 | file_name (str): The path of data file 61 | return (list): List of (image path, class_index) tuples 62 | """ 63 | with open(file_name, "r") as f: 64 | data_list = [] 65 | for line in f.readlines(): 66 | split_line = line.split() 67 | target = split_line[-1] 68 | path = ' '.join(split_line[:-1]) 69 | if not os.path.isabs(path): 70 | path = os.path.join(self.root, path) 71 | target = int(target) 72 | data_list.append((path, target)) 73 | return data_list 74 | 75 | @property 76 | def num_classes(self) -> int: 77 | """Number of classes""" 78 | return len(self.classes) 79 | 80 | @classmethod 81 | def domains(cls): 82 | """All possible domains in this dataset""" 83 | raise NotImplementedError 84 | -------------------------------------------------------------------------------- /common/vision/datasets/officehome.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from .imagelist import ImageList 4 | from ._util import download as download_data, check_exits 5 | 6 | 7 | class OfficeHome(ImageList): 8 | """`OfficeHome `_ Dataset. 9 | 10 | Args: 11 | root (str): Root directory of dataset 12 | task (str): The task (domain) to create dataset. Choices include ``'Ar'``: Art, \ 13 | ``'Cl'``: Clipart, ``'Pr'``: Product and ``'Rw'``: Real_World. 14 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 15 | in root directory. If dataset is already downloaded, it is not downloaded again. 16 | transform (callable, optional): A function/transform that takes in a PIL image and returns a \ 17 | transformed version. E.g., :class:`torchvision.transforms.RandomCrop`. 18 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 19 | 20 | .. note:: In `root`, the following files will exist after downloading. 21 | :: 22 | Art/ 23 | Alarm_Clock/*.jpg 24 | ...
25 | Clipart/ 26 | Product/ 27 | Real_World/ 28 | image_list/ 29 | Art.txt 30 | Clipart.txt 31 | Product.txt 32 | Real_World.txt 33 | """ 34 | download_list = [ 35 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/1b0171a188944313b1f5/?dl=1"), 36 | ("Art", "Art.tgz", "https://cloud.tsinghua.edu.cn/f/6a006656b9a14567ade2/?dl=1"), 37 | ("Clipart", "Clipart.tgz", "https://cloud.tsinghua.edu.cn/f/ae88aa31d2d7411dad79/?dl=1"), 38 | ("Product", "Product.tgz", "https://cloud.tsinghua.edu.cn/f/f219b0ff35e142b3ab48/?dl=1"), 39 | ("Real_World", "Real_World.tgz", "https://cloud.tsinghua.edu.cn/f/6c19f3f15bb24ed3951a/?dl=1") 40 | ] 41 | image_list = { 42 | "Ar": "image_list/Art.txt", 43 | "Cl": "image_list/Clipart.txt", 44 | "Pr": "image_list/Product.txt", 45 | "Rw": "image_list/Real_World.txt", 46 | } 47 | CLASSES = ['Drill', 'Exit_Sign', 'Bottle', 'Glasses', 'Computer', 'File_Cabinet', 'Shelf', 'Toys', 'Sink', 48 | 'Laptop', 'Kettle', 'Folder', 'Keyboard', 'Flipflops', 'Pencil', 'Bed', 'Hammer', 'ToothBrush', 'Couch', 49 | 'Bike', 'Postit_Notes', 'Mug', 'Webcam', 'Desk_Lamp', 'Telephone', 'Helmet', 'Mouse', 'Pen', 'Monitor', 50 | 'Mop', 'Sneakers', 'Notebook', 'Backpack', 'Alarm_Clock', 'Push_Pin', 'Paper_Clip', 'Batteries', 'Radio', 51 | 'Fan', 'Ruler', 'Pan', 'Screwdriver', 'Trash_Can', 'Printer', 'Speaker', 'Eraser', 'Bucket', 'Chair', 52 | 'Calendar', 'Calculator', 'Flowers', 'Lamp_Shade', 'Spoon', 'Candles', 'Clipboards', 'Scissors', 'TV', 53 | 'Curtains', 'Fork', 'Soda', 'Table', 'Knives', 'Oven', 'Refrigerator', 'Marker'] 54 | 55 | def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs): 56 | assert task in self.image_list 57 | data_list_file = os.path.join(root, self.image_list[task]) 58 | 59 | if download: 60 | list(map(lambda args: download_data(root, *args), self.download_list)) 61 | else: 62 | list(map(lambda args: check_exits(root, args[0]), self.download_list)) 63 | 64 | super(OfficeHome, self).__init__(root, OfficeHome.CLASSES, data_list_file=data_list_file, **kwargs) 65 | 66 | @classmethod 67 | def domains(cls): 68 | return list(cls.image_list.keys()) 69 | -------------------------------------------------------------------------------- /common/vision/datasets/visda2017.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from .imagelist import ImageList 4 | from ._util import download as download_data, check_exits 5 | 6 | 7 | class VisDA2017(ImageList): 8 | """`VisDA-2017 `_ Dataset 9 | 10 | Args: 11 | root (str): Root directory of dataset 12 | task (str): The task (domain) to create dataset. Choices include ``'Synthetic'``: synthetic images and \ 13 | ``'Real'``: real-world images. 14 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 15 | in root directory. If dataset is already downloaded, it is not downloaded again. 16 | transform (callable, optional): A function/transform that takes in a PIL image and returns a \ 17 | transformed version. E.g., ``transforms.RandomCrop``. 18 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 19 | 20 | .. note:: In `root`, the following files will exist after downloading. 21 | :: 22 | train/ 23 | aeroplane/ 24 | *.png 25 | ...
26 | validation/ 27 | image_list/ 28 | train.txt 29 | validation.txt 30 | """ 31 | download_list = [ 32 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/c107de37b8094c5398dc/?dl=1"), 33 | ("train", "train.tar", "http://csr.bu.edu/ftp/visda17/clf/train.tar"), 34 | ("validation", "validation.tar", "http://csr.bu.edu/ftp/visda17/clf/validation.tar") 35 | ] 36 | image_list = { 37 | "Synthetic": "image_list/train.txt", 38 | "Real": "image_list/validation.txt" 39 | } 40 | CLASSES = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife', 41 | 'motorcycle', 'person', 'plant', 'skateboard', 'train', 'truck'] 42 | 43 | def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs): 44 | assert task in self.image_list 45 | data_list_file = os.path.join(root, self.image_list[task]) 46 | 47 | if download: 48 | list(map(lambda args: download_data(root, *args), self.download_list)) 49 | else: 50 | list(map(lambda args: check_exits(root, args[0]), self.download_list)) 51 | 52 | super(VisDA2017, self).__init__(root, VisDA2017.CLASSES, data_list_file=data_list_file, **kwargs) 53 | 54 | @classmethod 55 | def domains(cls): 56 | return list(cls.image_list.keys()) 57 | -------------------------------------------------------------------------------- /common/vision/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | 3 | __all__ = ['resnet'] 4 | -------------------------------------------------------------------------------- /common/vision/models/resnet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from torchvision import models 4 | from torchvision.models.utils import load_state_dict_from_url 5 | #from torch.hub import load_state_dict_from_url 6 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls 7 | import copy 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 11 | 'wide_resnet50_2', 'wide_resnet101_2'] 12 | 13 | 14 | class ResNet(models.ResNet): 15 | """ResNets without the fully connected layer""" 16 | 17 | def __init__(self, *args, **kwargs): 18 | super(ResNet, self).__init__(*args, **kwargs) 19 | self._out_features = self.fc.in_features 20 | 21 | def forward(self, x): 22 | """""" 23 | x = self.conv1(x) 24 | x = self.bn1(x) 25 | x = self.relu(x) 26 | x = self.maxpool(x) 27 | 28 | x = self.layer1(x) 29 | x = self.layer2(x) 30 | x = self.layer3(x) 31 | x = self.layer4(x) 32 | 33 | return x 34 | 35 | @property 36 | def out_features(self) -> int: 37 | """The dimension of output features""" 38 | return self._out_features 39 | 40 | def copy_head(self) -> nn.Module: 41 | """Copy the original fully connected layer""" 42 | return copy.deepcopy(self.fc) 43 | 44 | 45 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 46 | model = ResNet(block, layers, **kwargs) 47 | if pretrained: 48 | model_dict = model.state_dict() 49 | pretrained_dict = load_state_dict_from_url(model_urls[arch], 50 | progress=progress) 51 | # remove keys from the pretrained dict that don't appear in the model dict 52 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 53 | model.load_state_dict(pretrained_dict, strict=False) 54 | return model 55 | 56 | 57 | def resnet18(pretrained=False, progress=True, **kwargs): 58 | r"""ResNet-18 model from 59 | `"Deep Residual Learning for Image Recognition" `_ 60 | 61 |
Args: 62 | pretrained (bool): If True, returns a model pre-trained on ImageNet 63 | progress (bool): If True, displays a progress bar of the download to stderr 64 | """ 65 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 66 | **kwargs) 67 | 68 | 69 | def resnet34(pretrained=False, progress=True, **kwargs): 70 | r"""ResNet-34 model from 71 | `"Deep Residual Learning for Image Recognition" `_ 72 | 73 | Args: 74 | pretrained (bool): If True, returns a model pre-trained on ImageNet 75 | progress (bool): If True, displays a progress bar of the download to stderr 76 | """ 77 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 78 | **kwargs) 79 | 80 | 81 | def resnet50(pretrained=False, progress=True, **kwargs): 82 | r"""ResNet-50 model from 83 | `"Deep Residual Learning for Image Recognition" `_ 84 | 85 | Args: 86 | pretrained (bool): If True, returns a model pre-trained on ImageNet 87 | progress (bool): If True, displays a progress bar of the download to stderr 88 | """ 89 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 90 | **kwargs) 91 | 92 | 93 | def resnet101(pretrained=False, progress=True, **kwargs): 94 | r"""ResNet-101 model from 95 | `"Deep Residual Learning for Image Recognition" `_ 96 | 97 | Args: 98 | pretrained (bool): If True, returns a model pre-trained on ImageNet 99 | progress (bool): If True, displays a progress bar of the download to stderr 100 | """ 101 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 102 | **kwargs) 103 | 104 | 105 | def resnet152(pretrained=False, progress=True, **kwargs): 106 | r"""ResNet-152 model from 107 | `"Deep Residual Learning for Image Recognition" `_ 108 | 109 | Args: 110 | pretrained (bool): If True, returns a model pre-trained on ImageNet 111 | progress (bool): If True, displays a progress bar of the download to stderr 112 | """ 113 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 114 | **kwargs) 115 | 116 | 117 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 118 | r"""ResNeXt-50 32x4d model from 119 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 120 | 121 | Args: 122 | pretrained (bool): If True, returns a model pre-trained on ImageNet 123 | progress (bool): If True, displays a progress bar of the download to stderr 124 | """ 125 | kwargs['groups'] = 32 126 | kwargs['width_per_group'] = 4 127 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 128 | pretrained, progress, **kwargs) 129 | 130 | 131 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 132 | r"""ResNeXt-101 32x8d model from 133 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 134 | 135 | Args: 136 | pretrained (bool): If True, returns a model pre-trained on ImageNet 137 | progress (bool): If True, displays a progress bar of the download to stderr 138 | """ 139 | kwargs['groups'] = 32 140 | kwargs['width_per_group'] = 8 141 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 142 | pretrained, progress, **kwargs) 143 | 144 | 145 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 146 | r"""Wide ResNet-50-2 model from 147 | `"Wide Residual Networks" `_ 148 | 149 | The model is the same as ResNet except for the bottleneck number of channels 150 | which is twice larger in every block. The number of channels in outer 1x1 151 | convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048 152 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 153 | 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on ImageNet 156 | progress (bool): If True, displays a progress bar of the download to stderr 157 | """ 158 | kwargs['width_per_group'] = 64 * 2 159 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 160 | pretrained, progress, **kwargs) 161 | 162 | 163 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 164 | r"""Wide ResNet-101-2 model from 165 | `"Wide Residual Networks" `_ 166 | 167 | The model is the same as ResNet except for the bottleneck number of channels 168 | which is twice larger in every block. The number of channels in outer 1x1 169 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 170 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | progress (bool): If True, displays a progress bar of the download to stderr 175 | """ 176 | kwargs['width_per_group'] = 64 * 2 177 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 178 | pretrained, progress, **kwargs) 179 | -------------------------------------------------------------------------------- /common/vision/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from torchvision.transforms import Normalize 7 | 8 | 9 | class ResizeImage(object): 10 | """Resize the input PIL Image to the given size. 11 | 12 | Args: 13 | size (sequence or int): Desired output size. If size is a sequence like 14 | (h, w), output size will be matched to this. If size is an int, 15 | output size will be (size, size) 16 | """ 17 | 18 | def __init__(self, size): 19 | if isinstance(size, int): 20 | self.size = (int(size), int(size)) 21 | else: 22 | self.size = size 23 | 24 | def __call__(self, img): 25 | th, tw = self.size 26 | return img.resize((th, tw)) 27 | 28 | def __repr__(self): 29 | return self.__class__.__name__ + '(size={0})'.format(self.size) 30 | 31 | 32 | class MultipleApply: 33 | """Apply a list of transformations to an image and get multiple transformed images. 34 | 35 | Args: 36 | transforms (list or tuple): list of transformations 37 | 38 | Example: 39 | 40 | >>> transform1 = T.Compose([ 41 | ... ResizeImage(256), 42 | ... T.RandomCrop(224) 43 | ... ]) 44 | >>> transform2 = T.Compose([ 45 | ... ResizeImage(256), 46 | ... T.RandomCrop(224), 47 | ... ]) 48 | >>> multiply_transform = MultipleApply([transform1, transform2]) 49 | """ 50 | 51 | def __init__(self, transforms): 52 | self.transforms = transforms 53 | 54 | def __call__(self, image): 55 | return [t(image) for t in self.transforms] 56 | 57 | def __repr__(self): 58 | format_string = self.__class__.__name__ + '(' 59 | for t in self.transforms: 60 | format_string += '\n' 61 | format_string += ' {0}'.format(t) 62 | format_string += '\n)' 63 | return format_string 64 | 65 | 66 | class Denormalize(Normalize): 67 | """DeNormalize a tensor image with mean and standard deviation. 68 | Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` 69 | channels, this transform will denormalize each channel of the input 70 | ``torch.*Tensor`` i.e., 71 | ``output[channel] = input[channel] * std[channel] + mean[channel]`` 72 | 73 | .. 
note:: 74 | This transform acts out of place, i.e., it does not mutate the input tensor. 75 | 76 | Args: 77 | mean (sequence): Sequence of means for each channel. 78 | std (sequence): Sequence of standard deviations for each channel. 79 | 80 | """ 81 | 82 | def __init__(self, mean, std): 83 | mean = np.array(mean) 84 | std = np.array(std) 85 | super().__init__((-mean / std).tolist(), (1 / std).tolist()) 86 | 87 | 88 | class NormalizeAndTranspose: 89 | """ 90 | First, normalize a tensor image with mean and standard deviation. 91 | Then, convert the shape (H x W x C) to shape (C x H x W). 92 | """ 93 | 94 | def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)): 95 | self.mean = np.array(mean, dtype=np.float32) 96 | 97 | def __call__(self, image): 98 | if isinstance(image, Image.Image): 99 | image = np.asarray(image, np.float32) 100 | # change to BGR 101 | image = image[:, :, ::-1] 102 | # normalize 103 | image -= self.mean 104 | image = image.transpose((2, 0, 1)).copy() 105 | elif isinstance(image, torch.Tensor): 106 | # change to BGR 107 | image = image[:, :, [2, 1, 0]] 108 | # normalize 109 | image -= torch.from_numpy(self.mean).to(image.device) 110 | image = image.permute((2, 0, 1)) 111 | else: 112 | raise NotImplementedError(type(image)) 113 | return image 114 | 115 | 116 | class DeNormalizeAndTranspose: 117 | """ 118 | First, convert a tensor image from the shape (C x H x W ) to shape (H x W x C). 119 | Then, denormalize it with mean and standard deviation. 120 | """ 121 | 122 | def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)): 123 | self.mean = np.array(mean, dtype=np.float32) 124 | 125 | def __call__(self, image): 126 | image = image.transpose((1, 2, 0)) 127 | # denormalize 128 | image += self.mean 129 | # change to RGB 130 | image = image[:, :, ::-1] 131 | return image 132 | 133 | 134 | class RandomErasing(object): 135 | """Random erasing augmentation from `Random Erasing Data Augmentation (CVPR 2017) 136 | `_. This augmentation randomly selects a rectangle region in an image 137 | and erases its pixels. 138 | 139 | Args: 140 | probability (float): The probability that the Random Erasing operation will be performed. 141 | sl (float): Minimum proportion of erased area against input image. 142 | sh (float): Maximum proportion of erased area against input image. 143 | r1 (float): Minimum aspect ratio of erased area. 144 | mean (sequence): Value to fill the erased area. 
145 | """ 146 | 147 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 148 | self.probability = probability 149 | self.mean = mean 150 | self.sl = sl 151 | self.sh = sh 152 | self.r1 = r1 153 | 154 | def __call__(self, img): 155 | 156 | if random.uniform(0, 1) >= self.probability: 157 | return img 158 | 159 | for attempt in range(100): 160 | area = img.size()[1] * img.size()[2] 161 | 162 | target_area = random.uniform(self.sl, self.sh) * area 163 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 164 | 165 | h = int(round(math.sqrt(target_area * aspect_ratio))) 166 | w = int(round(math.sqrt(target_area / aspect_ratio))) 167 | 168 | if w < img.size()[2] and h < img.size()[1]: 169 | x1 = random.randint(0, img.size()[1] - h) 170 | y1 = random.randint(0, img.size()[2] - w) 171 | if img.size()[0] == 3: 172 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 173 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 174 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 175 | else: 176 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 177 | return img 178 | 179 | return img 180 | 181 | def __repr__(self): 182 | return self.__class__.__name__ + '(p={})'.format(self.probability) 183 | -------------------------------------------------------------------------------- /dalib/adaptation/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cdan 2 | from . import mcc 3 | 4 | __all__ = ["cdan", "mcc"] 5 | -------------------------------------------------------------------------------- /dalib/adaptation/cdan.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from common.modules.classifier import Classifier as ClassifierBase 8 | from common.utils.metric import binary_accuracy 9 | from ..modules.grl import WarmStartGradientReverseLayer 10 | from ..modules.entropy import entropy 11 | 12 | 13 | __all__ = ['ConditionalDomainAdversarialLoss', 'ImageClassifier'] 14 | 15 | 16 | class ConditionalDomainAdversarialLoss(nn.Module): 17 | r"""The Conditional Domain Adversarial Loss used in `Conditional Adversarial Domain Adaptation (NIPS 2018) `_ 18 | 19 | Conditional Domain adversarial loss measures the domain discrepancy through training a domain discriminator in a 20 | conditional manner. Given domain discriminator :math:`D`, feature representation :math:`f` and 21 | classifier predictions :math:`g`, the definition of CDAN loss is 22 | 23 | .. math:: 24 | loss(\mathcal{D}_s, \mathcal{D}_t) &= \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(T(f_i^s, g_i^s))] \\ 25 | &+ \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(T(f_j^t, g_j^t))],\\ 26 | 27 | where :math:`T` is a :class:`MultiLinearMap` or :class:`RandomizedMultiLinearMap` which convert two tensors to a single tensor. 28 | 29 | Args: 30 | domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of 31 | features. Its input shape is (N, F) and output shape is (N, 1) 32 | entropy_conditioning (bool, optional): If True, use entropy-aware weight to reweight each training example. 33 | Default: False 34 | randomized (bool, optional): If True, use `randomized multi linear map`. Else, use `multi linear map`. 35 | Default: False 36 | num_classes (int, optional): Number of classes. Default: -1 37 | features_dim (int, optional): Dimension of input features. 
Default: -1 38 | randomized_dim (int, optional): Dimension of features after randomized. Default: 1024 39 | reduction (str, optional): Specifies the reduction to apply to the output: 40 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 41 | ``'mean'``: the sum of the output will be divided by the number of 42 | elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` 43 | 44 | .. note:: 45 | You need to provide `num_classes`, `features_dim` and `randomized_dim` **only when** `randomized` 46 | is set True. 47 | 48 | Inputs: 49 | - g_s (tensor): unnormalized classifier predictions on source domain, :math:`g^s` 50 | - f_s (tensor): feature representations on source domain, :math:`f^s` 51 | - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t` 52 | - f_t (tensor): feature representations on target domain, :math:`f^t` 53 | 54 | Shape: 55 | - g_s, g_t: :math:`(minibatch, C)` where C means the number of classes. 56 | - f_s, f_t: :math:`(minibatch, F)` where F means the dimension of input features. 57 | - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, )`. 58 | 59 | Examples:: 60 | 61 | >>> from dalib.modules.domain_discriminator import DomainDiscriminator 62 | >>> from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss 63 | >>> import torch 64 | >>> num_classes = 2 65 | >>> feature_dim = 1024 66 | >>> batch_size = 10 67 | >>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024) 68 | >>> loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean') 69 | >>> # features from source domain and target domain 70 | >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim) 71 | >>> # logits output from source domain adn target domain 72 | >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes) 73 | >>> output = loss(g_s, f_s, g_t, f_t) 74 | """ 75 | 76 | def __init__(self, domain_discriminator: nn.Module, entropy_conditioning: Optional[bool] = False, 77 | randomized: Optional[bool] = False, num_classes: Optional[int] = -1, 78 | features_dim: Optional[int] = -1, randomized_dim: Optional[int] = 1024, 79 | reduction: Optional[str] = 'mean'): 80 | super(ConditionalDomainAdversarialLoss, self).__init__() 81 | self.domain_discriminator = domain_discriminator 82 | self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True) 83 | self.entropy_conditioning = entropy_conditioning 84 | 85 | if randomized: 86 | assert num_classes > 0 and features_dim > 0 and randomized_dim > 0 87 | self.map = RandomizedMultiLinearMap(features_dim, num_classes, randomized_dim) 88 | else: 89 | self.map = MultiLinearMap() 90 | 91 | self.bce = lambda input, target, weight: F.binary_cross_entropy(input, target, weight, 92 | reduction=reduction) if self.entropy_conditioning \ 93 | else F.binary_cross_entropy(input, target, reduction=reduction) 94 | self.domain_discriminator_accuracy = None 95 | 96 | def forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor: 97 | f = torch.cat((f_s, f_t), dim=0) 98 | g = torch.cat((g_s, g_t), dim=0) 99 | g = F.softmax(g, dim=1).detach() 100 | h = self.grl(self.map(f, g)) 101 | d = self.domain_discriminator(h) 102 | d_label = torch.cat(( 103 | torch.ones((g_s.size(0), 1)).to(g_s.device), 104 | torch.zeros((g_t.size(0), 1)).to(g_t.device), 105 | )) 106 | weight = 1.0 + 
torch.exp(-entropy(g)) 107 | batch_size = f.size(0) 108 | weight = weight / torch.sum(weight) * batch_size 109 | self.domain_discriminator_accuracy = binary_accuracy(d, d_label) 110 | return self.bce(d, d_label, weight.view_as(d)) 111 | 112 | 113 | class RandomizedMultiLinearMap(nn.Module): 114 | """Random multi linear map 115 | 116 | Given two inputs :math:`f` and :math:`g`, the definition is 117 | 118 | .. math:: 119 | T_{\odot}(f,g) = \dfrac{1}{\sqrt{d}} (R_f f) \odot (R_g g), 120 | 121 | where :math:`\odot` is element-wise product, :math:`R_f` and :math:`R_g` are random matrices 122 | sampled only once and fixed in training. 123 | 124 | Args: 125 | features_dim (int): dimension of input :math:`f` 126 | num_classes (int): dimension of input :math:`g` 127 | output_dim (int, optional): dimension of output tensor. Default: 1024 128 | 129 | Shape: 130 | - f: (minibatch, features_dim) 131 | - g: (minibatch, num_classes) 132 | - Outputs: (minibatch, output_dim) 133 | """ 134 | 135 | def __init__(self, features_dim: int, num_classes: int, output_dim: Optional[int] = 1024): 136 | super(RandomizedMultiLinearMap, self).__init__() 137 | self.Rf = torch.randn(features_dim, output_dim) 138 | self.Rg = torch.randn(num_classes, output_dim) 139 | self.output_dim = output_dim 140 | 141 | def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: 142 | f = torch.mm(f, self.Rf.to(f.device)) 143 | g = torch.mm(g, self.Rg.to(g.device)) 144 | output = torch.mul(f, g) / np.sqrt(float(self.output_dim)) 145 | return output 146 | 147 | 148 | class MultiLinearMap(nn.Module): 149 | """Multi linear map 150 | 151 | Shape: 152 | - f: (minibatch, F) 153 | - g: (minibatch, C) 154 | - Outputs: (minibatch, F * C) 155 | """ 156 | 157 | def __init__(self): 158 | super(MultiLinearMap, self).__init__() 159 | 160 | def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: 161 | batch_size = f.size(0) 162 | output = torch.bmm(g.unsqueeze(2), f.unsqueeze(1)) 163 | return output.view(batch_size, -1) 164 | 165 | 166 | class ImageClassifier(ClassifierBase): 167 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): 168 | bottleneck = nn.Sequential( 169 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)), 170 | # nn.Flatten(), 171 | nn.Linear(backbone.out_features, bottleneck_dim), 172 | nn.BatchNorm1d(bottleneck_dim), 173 | nn.ReLU() 174 | ) 175 | super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) 176 | -------------------------------------------------------------------------------- /dalib/adaptation/mcc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from common.modules.classifier import Classifier as ClassifierBase 7 | from ..modules.entropy import entropy 8 | 9 | 10 | __all__ = ['MinimumClassConfusionLoss', 'ImageClassifier'] 11 | 12 | 13 | class MinimumClassConfusionLoss(nn.Module): 14 | r""" 15 | Minimum Class Confusion loss minimizes the class confusion in the target predictions. 16 | 17 | You can see more details in `Minimum Class Confusion for Versatile Domain Adaptation (ECCV 2020) `_ 18 | 19 | Args: 20 | temperature (float) : The temperature for rescaling, the prediction will shrink to vanilla softmax if 21 | temperature is 1.0. 22 | 23 | .. note:: 24 | Make sure that temperature is larger than 0. 
25 | 26 | Inputs: g_t 27 | - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t` 28 | 29 | Shape: 30 | - g_t: :math:`(minibatch, C)` where C means the number of classes. 31 | - Output: scalar. 32 | 33 | Examples:: 34 | >>> temperature = 2.0 35 | >>> loss = MinimumClassConfusionLoss(temperature) 36 | >>> # logits output from target domain 37 | >>> g_t = torch.randn(batch_size, num_classes) 38 | >>> output = loss(g_t) 39 | 40 | MCC can also serve as a regularizer for existing methods. 41 | Examples:: 42 | >>> from dalib.modules.domain_discriminator import DomainDiscriminator 43 | >>> num_classes = 2 44 | >>> feature_dim = 1024 45 | >>> batch_size = 10 46 | >>> temperature = 2.0 47 | >>> discriminator = DomainDiscriminator(in_feature=feature_dim, hidden_size=1024) 48 | >>> cdan_loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean') 49 | >>> mcc_loss = MinimumClassConfusionLoss(temperature) 50 | >>> # features from source domain and target domain 51 | >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim) 52 | >>> # logits output from source domain adn target domain 53 | >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes) 54 | >>> total_loss = cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t) 55 | """ 56 | 57 | def __init__(self, temperature: float): 58 | super(MinimumClassConfusionLoss, self).__init__() 59 | self.temperature = temperature 60 | 61 | def forward(self, logits: torch.Tensor) -> torch.Tensor: 62 | batch_size, num_classes = logits.shape 63 | predictions = F.softmax(logits / self.temperature, dim=1) # batch_size x num_classes 64 | entropy_weight = entropy(predictions).detach() 65 | entropy_weight = 1 + torch.exp(-entropy_weight) 66 | entropy_weight = (batch_size * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1) # batch_size x 1 67 | class_confusion_matrix = torch.mm((predictions * entropy_weight).transpose(1, 0), predictions) # num_classes x num_classes 68 | class_confusion_matrix = class_confusion_matrix / torch.sum(class_confusion_matrix, dim=1) 69 | mcc_loss = (torch.sum(class_confusion_matrix) - torch.trace(class_confusion_matrix)) / num_classes 70 | return mcc_loss 71 | 72 | 73 | class ImageClassifier(ClassifierBase): 74 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): 75 | bottleneck = nn.Sequential( 76 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)), 77 | # nn.Flatten(), 78 | nn.Linear(backbone.out_features, bottleneck_dim), 79 | nn.BatchNorm1d(bottleneck_dim), 80 | nn.ReLU() 81 | ) 82 | super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) 83 | -------------------------------------------------------------------------------- /dalib/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .grl import * 2 | from .domain_discriminator import * 3 | from .kernels import * 4 | from .entropy import * 5 | 6 | __all__ = ['grl', 'kernels', 'domain_discriminator', 'entropy'] 7 | -------------------------------------------------------------------------------- /dalib/modules/domain_discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from typing import List, Dict 6 | import torch.nn as nn 7 | 8 | __all__ = ['DomainDiscriminator'] 9 | 10 | 11 | class 
DomainDiscriminator(nn.Sequential): 12 | r"""Domain discriminator model from 13 | `Domain-Adversarial Training of Neural Networks (ICML 2015) `_ 14 | 15 | Distinguish whether the input features come from the source domain or the target domain. 16 | The source domain label is 1 and the target domain label is 0. 17 | 18 | Args: 19 | in_feature (int): dimension of the input feature 20 | hidden_size (int): dimension of the hidden features 21 | batch_norm (bool): whether use :class:`~torch.nn.BatchNorm1d`. 22 | Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True. 23 | 24 | Shape: 25 | - Inputs: (minibatch, `in_feature`) 26 | - Outputs: :math:`(minibatch, 1)` 27 | """ 28 | 29 | def __init__(self, in_feature: int, hidden_size: int, batch_norm=True): 30 | if batch_norm: 31 | super(DomainDiscriminator, self).__init__( 32 | nn.Linear(in_feature, hidden_size), 33 | nn.BatchNorm1d(hidden_size), 34 | nn.ReLU(), 35 | nn.Linear(hidden_size, hidden_size), 36 | nn.BatchNorm1d(hidden_size), 37 | nn.ReLU(), 38 | nn.Linear(hidden_size, 1), 39 | nn.Sigmoid() 40 | ) 41 | else: 42 | super(DomainDiscriminator, self).__init__( 43 | nn.Linear(in_feature, hidden_size), 44 | nn.ReLU(inplace=True), 45 | nn.Dropout(0.5), 46 | nn.Linear(hidden_size, hidden_size), 47 | nn.ReLU(inplace=True), 48 | nn.Dropout(0.5), 49 | nn.Linear(hidden_size, 1), 50 | nn.Sigmoid() 51 | ) 52 | 53 | def get_parameters(self) -> List[Dict]: 54 | return [{"params": self.parameters(), "lr": 1.}] 55 | 56 | 57 | -------------------------------------------------------------------------------- /dalib/modules/entropy.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch 6 | 7 | 8 | def entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor: 9 | r"""Entropy of prediction. 10 | The definition is: 11 | 12 | .. math:: 13 | entropy(p) = - \sum_{c=1}^C p_c \log p_c 14 | 15 | where C is number of classes. 16 | 17 | Args: 18 | predictions (tensor): Classifier predictions. Expected to contain raw, normalized scores for each class 19 | reduction (str, optional): Specifies the reduction to apply to the output: 20 | ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, 21 | ``'mean'``: the sum of the output will be divided by the number of 22 | elements in the output. Default: ``'mean'`` 23 | 24 | Shape: 25 | - predictions: :math:`(minibatch, C)` where C means the number of classes. 26 | - Output: :math:`(minibatch, )` by default. If :attr:`reduction` is ``'mean'``, then scalar. 27 | """ 28 | epsilon = 1e-5 29 | H = -predictions * torch.log(predictions + epsilon) 30 | H = H.sum(dim=1) 31 | if reduction == 'mean': 32 | return H.mean() 33 | else: 34 | return H 35 | -------------------------------------------------------------------------------- /dalib/modules/gl.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from typing import Optional, Any, Tuple 6 | import numpy as np 7 | import torch.nn as nn 8 | from torch.autograd import Function 9 | import torch 10 | 11 | 12 | class GradientFunction(Function): 13 | 14 | @staticmethod 15 | def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) 
-> torch.Tensor: 16 | ctx.coeff = coeff 17 | output = input * 1.0 18 | return output 19 | 20 | @staticmethod 21 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]: 22 | return grad_output * ctx.coeff, None 23 | 24 | 25 | class WarmStartGradientLayer(nn.Module): 26 | """Warm Start Gradient Layer :math:`\mathcal{R}(x)` with warm start 27 | 28 | The forward and backward behaviours are: 29 | 30 | .. math:: 31 | \mathcal{R}(x) = x, 32 | 33 | \dfrac{ d\mathcal{R}} {dx} = \lambda I. 34 | 35 | :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule: 36 | 37 | .. math:: 38 | \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo 39 | 40 | where :math:`i` is the iteration step. 41 | 42 | Parameters: 43 | - **alpha** (float, optional): :math:`α`. Default: 1.0 44 | - **lo** (float, optional): Initial value of :math:`\lambda`. Default: 0.0 45 | - **hi** (float, optional): Final value of :math:`\lambda`. Default: 1.0 46 | - **max_iters** (int, optional): :math:`N`. Default: 1000 47 | - **auto_step** (bool, optional): If True, increase :math:`i` each time `forward` is called. 48 | Otherwise use function `step` to increase :math:`i`. Default: False 49 | """ 50 | 51 | def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1., 52 | max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False): 53 | super(WarmStartGradientLayer, self).__init__() 54 | self.alpha = alpha 55 | self.lo = lo 56 | self.hi = hi 57 | self.iter_num = 0 58 | self.max_iters = max_iters 59 | self.auto_step = auto_step 60 | 61 | def forward(self, input: torch.Tensor) -> torch.Tensor: 62 | """""" 63 | coeff = np.float( 64 | 2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters)) 65 | - (self.hi - self.lo) + self.lo 66 | ) 67 | if self.auto_step: 68 | self.step() 69 | return GradientFunction.apply(input, coeff) 70 | 71 | def step(self): 72 | """Increase iteration number :math:`i` by 1""" 73 | self.iter_num += 1 74 | -------------------------------------------------------------------------------- /dalib/modules/grl.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from typing import Optional, Any, Tuple 6 | import numpy as np 7 | import torch.nn as nn 8 | from torch.autograd import Function 9 | import torch 10 | 11 | 12 | class GradientReverseFunction(Function): 13 | 14 | @staticmethod 15 | def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor: 16 | ctx.coeff = coeff 17 | output = input * 1.0 18 | return output 19 | 20 | @staticmethod 21 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]: 22 | return grad_output.neg() * ctx.coeff, None 23 | 24 | 25 | class GradientReverseLayer(nn.Module): 26 | def __init__(self): 27 | super(GradientReverseLayer, self).__init__() 28 | 29 | def forward(self, *input): 30 | return GradientReverseFunction.apply(*input) 31 | 32 | 33 | class WarmStartGradientReverseLayer(nn.Module): 34 | """Gradient Reverse Layer :math:`\mathcal{R}(x)` with warm start 35 | 36 | The forward and backward behaviours are: 37 | 38 | .. math:: 39 | \mathcal{R}(x) = x, 40 | 41 | \dfrac{ d\mathcal{R}} {dx} = - \lambda I. 42 | 43 | :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule: 44 | 45 | .. 
math:: 46 | \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo 47 | 48 | where :math:`i` is the iteration step. 49 | 50 | Args: 51 | alpha (float, optional): :math:`α`. Default: 1.0 52 | lo (float, optional): Initial value of :math:`\lambda`. Default: 0.0 53 | hi (float, optional): Final value of :math:`\lambda`. Default: 1.0 54 | max_iters (int, optional): :math:`N`. Default: 1000 55 | auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called. 56 | Otherwise use function `step` to increase :math:`i`. Default: False 57 | """ 58 | 59 | def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1., 60 | max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False): 61 | super(WarmStartGradientReverseLayer, self).__init__() 62 | self.alpha = alpha 63 | self.lo = lo 64 | self.hi = hi 65 | self.iter_num = 0 66 | self.max_iters = max_iters 67 | self.auto_step = auto_step 68 | 69 | def forward(self, input: torch.Tensor) -> torch.Tensor: 70 | """""" 71 | coeff = np.float( 72 | 2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters)) 73 | - (self.hi - self.lo) + self.lo 74 | ) 75 | if self.auto_step: 76 | self.step() 77 | return GradientReverseFunction.apply(input, coeff) 78 | 79 | def step(self): 80 | """Increase iteration number :math:`i` by 1""" 81 | self.iter_num += 1 82 | -------------------------------------------------------------------------------- /dalib/modules/kernels.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | __all__ = ['GaussianKernel'] 7 | 8 | 9 | class GaussianKernel(nn.Module): 10 | r"""Gaussian Kernel Matrix 11 | 12 | Gaussian Kernel k is defined by 13 | 14 | .. math:: 15 | k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right) 16 | 17 | where :math:`x_1, x_2 \in R^d` are 1-d tensors. 18 | 19 | Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),` 20 | 21 | .. math:: 22 | K(X)_{i,j} = k(x_i, x_j) 23 | 24 | Also by default, during training this layer keeps running estimates of the 25 | mean of L2 distances, which are then used to set hyperparameter :math:`\sigma`. 26 | Mathematically, the estimation is :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`. 27 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 28 | keep running estimates, and use a fixed :math:`\sigma` instead. 29 | 30 | Args: 31 | sigma (float, optional): bandwidth :math:`\sigma`. Default: None 32 | track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\sigma^2`. 33 | Otherwise, it won't track such statistics and always uses fix :math:`\sigma^2`. Default: ``True`` 34 | alpha (float, optional): :math:`\alpha` which decides the magnitude of :math:`\sigma^2` when track_running_stats is set to ``True`` 35 | 36 | Inputs: 37 | - X (tensor): input group :math:`X` 38 | 39 | Shape: 40 | - Inputs: :math:`(minibatch, F)` where F means the dimension of input features. 
41 | - Outputs: :math:`(minibatch, minibatch)` 42 | """ 43 | 44 | def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True, 45 | alpha: Optional[float] = 1.): 46 | super(GaussianKernel, self).__init__() 47 | assert track_running_stats or sigma is not None 48 | self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None 49 | self.track_running_stats = track_running_stats 50 | self.alpha = alpha 51 | 52 | def forward(self, X: torch.Tensor) -> torch.Tensor: 53 | l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2) 54 | 55 | if self.track_running_stats: 56 | self.sigma_square = self.alpha * torch.mean(l2_distance_square.detach()) 57 | 58 | return torch.exp(-l2_distance_square / (2 * self.sigma_square)) 59 | -------------------------------------------------------------------------------- /examples/cdan.py: -------------------------------------------------------------------------------- 1 | # Credits: https://github.com/thuml/Transfer-Learning-Library 2 | import random 3 | import time 4 | import warnings 5 | import sys 6 | import argparse 7 | import shutil 8 | import os.path as osp 9 | import os 10 | import wandb 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | from torch.optim import SGD 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.utils.data import DataLoader 18 | import torch.nn.functional as F 19 | 20 | sys.path.append('../') 21 | from dalib.modules.domain_discriminator import DomainDiscriminator 22 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier 23 | from common.utils.data import ForeverDataIterator 24 | from common.utils.metric import accuracy 25 | from common.utils.meter import AverageMeter, ProgressMeter 26 | from common.utils.logger import CompleteLogger 27 | from common.utils.analysis import collect_feature, tsne, a_distance 28 | 29 | sys.path.append('.') 30 | import utils 31 | 32 | 33 | def main(args: argparse.Namespace): 34 | logger = CompleteLogger(args.log, args.phase) 35 | print(args) 36 | 37 | if args.log_results: 38 | wandb.init(project="DA", entity="SDAT", name=args.log_name) 39 | wandb.config.update(args) 40 | print(args) 41 | 42 | if args.seed is not None: 43 | random.seed(args.seed) 44 | torch.manual_seed(args.seed) 45 | cudnn.deterministic = True 46 | warnings.warn('You have chosen to seed training. ' 47 | 'This will turn on the CUDNN deterministic setting, ' 48 | 'which can slow down your training considerably! 
' 49 | 'You may see unexpected behavior when restarting ' 50 | 'from checkpoints.') 51 | 52 | cudnn.benchmark = True 53 | device = args.device 54 | # Data loading code 55 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip, 56 | random_color_jitter=False, resize_size=args.resize_size, 57 | norm_mean=args.norm_mean, norm_std=args.norm_std) 58 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, 59 | norm_mean=args.norm_mean, norm_std=args.norm_std) 60 | print("train_transform: ", train_transform) 61 | print("val_transform: ", val_transform) 62 | 63 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ 64 | utils.get_dataset(args.data, args.root, args.source, 65 | args.target, train_transform, val_transform) 66 | train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, 67 | shuffle=True, num_workers=args.workers, drop_last=True) 68 | train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, 69 | shuffle=True, num_workers=args.workers, drop_last=True) 70 | val_loader = DataLoader( 71 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 72 | test_loader = DataLoader( 73 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 74 | 75 | train_source_iter = ForeverDataIterator(train_source_loader) 76 | train_target_iter = ForeverDataIterator(train_target_loader) 77 | 78 | # create model 79 | print("=> using model '{}'".format(args.arch)) 80 | backbone = utils.get_model(args.arch, pretrain=not args.scratch) 81 | print(backbone) 82 | pool_layer = nn.Identity() if args.no_pool else None 83 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, 84 | pool_layer=pool_layer, finetune=not args.scratch).to(device) 85 | classifier_feature_dim = classifier.features_dim 86 | 87 | if args.randomized: 88 | domain_discri = DomainDiscriminator( 89 | args.randomized_dim, hidden_size=1024).to(device) 90 | else: 91 | domain_discri = DomainDiscriminator( 92 | classifier_feature_dim * num_classes, hidden_size=1024).to(device) 93 | 94 | all_parameters = classifier.get_parameters() + domain_discri.get_parameters() 95 | # define optimizer and lr scheduler 96 | optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, 97 | weight_decay=args.weight_decay, nesterov=True) 98 | t_total = args.iters_per_epoch * args.epochs 99 | print("{INFORMATION} The total number of steps is ", t_total) 100 | 101 | lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * 102 | (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) 103 | 104 | # define loss function 105 | domain_adv = ConditionalDomainAdversarialLoss( 106 | domain_discri, entropy_conditioning=args.entropy, 107 | num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized, 108 | randomized_dim=args.randomized_dim 109 | ).to(device) 110 | 111 | # resume from the best checkpoint 112 | if args.phase != 'train': 113 | checkpoint = torch.load( 114 | logger.get_checkpoint_path('best'), map_location='cpu') 115 | classifier.load_state_dict(checkpoint) 116 | 117 | # analysis the model 118 | if args.phase == 'analysis': 119 | # extract features from both domains 120 | feature_extractor = nn.Sequential( 121 | classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) 122 | source_feature = collect_feature( 123 | train_source_loader, feature_extractor, device) 124 | target_feature = collect_feature( 125 | train_target_loader, feature_extractor, device) 126 | # plot t-SNE 127 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') 128 | tsne.visualize(source_feature, target_feature, tSNE_filename) 129 | print("Saving t-SNE to", tSNE_filename) 130 | # calculate A-distance, which is a measure for distribution discrepancy 131 | A_distance = a_distance.calculate( 132 | source_feature, target_feature, device) 133 | print("A-distance =", A_distance) 134 | return 135 | 136 | if args.phase == 'test': 137 | acc1 = utils.validate(test_loader, classifier, args, device) 138 | print(acc1) 139 | return 140 | 141 | # start training 142 | best_acc1 = 0. 143 | for epoch in range(args.epochs): 144 | print("lr_bbone:", lr_scheduler.get_last_lr()[0]) 145 | print("lr_btlnck:", lr_scheduler.get_last_lr()[1]) 146 | if args.log_results: 147 | wandb.log({"lr_bbone": lr_scheduler.get_last_lr()[0], 148 | "lr_btlnck": lr_scheduler.get_last_lr()[1]}) 149 | # train for one epoch 150 | train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, 151 | lr_scheduler, epoch, args) 152 | 153 | # evaluate on validation set 154 | acc1 = utils.validate(val_loader, classifier, args, device) 155 | if args.log_results: 156 | wandb.log({'epoch': epoch, 'val_acc': acc1}) 157 | 158 | # remember best acc@1 and save checkpoint 159 | torch.save(classifier.state_dict(), 160 | logger.get_checkpoint_path('latest')) 161 | if acc1 > best_acc1: 162 | shutil.copy(logger.get_checkpoint_path('latest'), 163 | logger.get_checkpoint_path('best')) 164 | best_acc1 = max(acc1, best_acc1) 165 | 166 | print("best_acc1 = {:3.1f}".format(best_acc1)) 167 | 168 | # evaluate on test set 169 | classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) 170 | acc1 = utils.validate(test_loader, classifier, args, device) 171 | print("test_acc1 = {:3.1f}".format(acc1)) 172 | if args.log_results: 173 | wandb.log({'epoch': epoch, 'test_acc': acc1}) 174 | 175 | logger.close() 176 | 177 | 178 | def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, 179 | domain_adv: ConditionalDomainAdversarialLoss, optimizer: SGD, 180 | lr_scheduler, epoch: int, args: argparse.Namespace): 181 | 182 | batch_time = AverageMeter('Time', ':3.1f') 183 | data_time = AverageMeter('Data', ':3.1f') 184 | losses = AverageMeter('Loss', ':3.2f') 185 | trans_losses = AverageMeter('Trans Loss', ':3.2f') 186 | cls_accs = AverageMeter('Cls Acc', ':3.1f') 187 | domain_accs = AverageMeter('Domain Acc', ':3.1f') 188 | progress = ProgressMeter( 189 | args.iters_per_epoch, 190 | 
[batch_time, data_time, losses, trans_losses, cls_accs, domain_accs], 191 | prefix="Epoch: [{}]".format(epoch)) 192 | 193 | device = args.device 194 | # switch to train mode 195 | model.train() 196 | domain_adv.train() 197 | 198 | end = time.time() 199 | for i in range(args.iters_per_epoch): 200 | x_s, labels_s = next(train_source_iter) 201 | x_t, _ = next(train_target_iter) 202 | 203 | x_s = x_s.to(device) 204 | x_t = x_t.to(device) 205 | labels_s = labels_s.to(device) 206 | 207 | # measure data loading time 208 | data_time.update(time.time() - end) 209 | 210 | # compute output 211 | x = torch.cat((x_s, x_t), dim=0) 212 | y, f = model(x) 213 | y_s, y_t = y.chunk(2, dim=0) 214 | f_s, f_t = f.chunk(2, dim=0) 215 | 216 | cls_loss = F.cross_entropy(y_s, labels_s) 217 | transfer_loss = domain_adv(y_s, f_s, y_t, f_t) 218 | domain_acc = domain_adv.domain_discriminator_accuracy 219 | loss = cls_loss + transfer_loss * args.trade_off 220 | 221 | cls_acc = accuracy(y_s, labels_s)[0] 222 | if args.log_results: 223 | wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss, 224 | 'transfer_loss': transfer_loss, 'domain_acc': domain_acc}) 225 | 226 | losses.update(loss.item(), x_s.size(0)) 227 | cls_accs.update(cls_acc, x_s.size(0)) 228 | domain_accs.update(domain_acc, x_s.size(0)) 229 | trans_losses.update(transfer_loss.item(), x_s.size(0)) 230 | 231 | # compute gradient and do SGD step 232 | optimizer.zero_grad() 233 | loss.backward() 234 | optimizer.step() 235 | lr_scheduler.step() 236 | 237 | # measure elapsed time 238 | batch_time.update(time.time() - end) 239 | end = time.time() 240 | 241 | if i % args.print_freq == 0: 242 | progress.display(i) 243 | 244 | 245 | if __name__ == '__main__': 246 | parser = argparse.ArgumentParser( 247 | description='CDAN for Unsupervised Domain Adaptation') 248 | # dataset parameters 249 | parser.add_argument('root', metavar='DIR', 250 | help='root path of dataset') 251 | parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), 252 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) + 253 | ' (default: Office31)') 254 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') 255 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') 256 | parser.add_argument('--train-resizing', type=str, default='default') 257 | parser.add_argument('--val-resizing', type=str, default='default') 258 | parser.add_argument('--resize-size', type=int, default=224, 259 | help='the image size after resizing') 260 | parser.add_argument('--no-hflip', action='store_true', 261 | help='no random horizontal flipping during training') 262 | parser.add_argument('--norm-mean', type=float, nargs='+', 263 | default=(0.485, 0.456, 0.406), help='normalization mean') 264 | parser.add_argument('--norm-std', type=float, nargs='+', 265 | default=(0.229, 0.224, 0.225), help='normalization std') 266 | # model parameters 267 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 268 | choices=utils.get_model_names(), 269 | help='backbone architecture: ' + 270 | ' | '.join(utils.get_model_names()) + 271 | ' (default: resnet18)') 272 | parser.add_argument('--bottleneck-dim', default=256, type=int, 273 | help='Dimension of bottleneck') 274 | parser.add_argument('--no-pool', action='store_true', 275 | help='no pool layer after the feature extractor.') 276 | parser.add_argument('--scratch', action='store_true', 277 | help='whether train from scratch.') 278 | 
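    # A minimal sketch (not part of the original script) of how the conditional feature fed to
    # DomainDiscriminator above is built by dalib/adaptation/cdan.py. Without --randomized (flag
    # defined just below), MultiLinearMap flattens the outer product of predictions and features,
    # so the discriminator takes features_dim * num_classes inputs; with --randomized,
    # RandomizedMultiLinearMap projects the pair down to --randomized-dim inputs instead.
    # The dimensions below are arbitrary toy values chosen only for illustration.
    # >>> import torch
    # >>> from dalib.adaptation.cdan import MultiLinearMap, RandomizedMultiLinearMap
    # >>> f = torch.randn(4, 256)                       # bottleneck features (batch, features_dim)
    # >>> g = torch.softmax(torch.randn(4, 65), dim=1)  # classifier predictions (batch, num_classes)
    # >>> MultiLinearMap()(f, g).shape                  # torch.Size([4, 16640]) == 256 * 65
    # >>> RandomizedMultiLinearMap(256, 65, 1024)(f, g).shape   # torch.Size([4, 1024])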
parser.add_argument('-r', '--randomized', action='store_true', 279 | help='using randomized multi-linear-map (default: False)') 280 | parser.add_argument('-rd', '--randomized-dim', default=1024, type=int, 281 | help='randomized dimension when using randomized multi-linear-map (default: 1024)') 282 | parser.add_argument('--entropy', default=False, 283 | action='store_true', help='use entropy conditioning') 284 | parser.add_argument('--trade-off', default=1., type=float, 285 | help='the trade-off hyper-parameter for transfer loss') 286 | # training parameters 287 | parser.add_argument('-b', '--batch-size', default=32, type=int, 288 | metavar='N', 289 | help='mini-batch size (default: 32)') 290 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 291 | metavar='LR', help='initial learning rate', dest='lr') 292 | parser.add_argument('--lr-gamma', default=0.001, 293 | type=float, help='parameter for lr scheduler') 294 | parser.add_argument('--lr-decay', default=0.75, 295 | type=float, help='parameter for lr scheduler') 296 | parser.add_argument('--momentum', default=0.9, 297 | type=float, metavar='M', help='momentum') 298 | parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, 299 | metavar='W', help='weight decay (default: 1e-3)', 300 | dest='weight_decay') 301 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 302 | help='number of data loading workers (default: 2)') 303 | parser.add_argument('--epochs', default=20, type=int, metavar='N', 304 | help='number of total epochs to run') 305 | parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, 306 | help='Number of iterations per epoch') 307 | parser.add_argument('-p', '--print-freq', default=100, type=int, 308 | metavar='N', help='print frequency (default: 100)') 309 | parser.add_argument('--seed', default=None, type=int, 310 | help='seed for initializing training. ') 311 | parser.add_argument('--per-class-eval', action='store_true', 312 | help='whether output per-class accuracy during evaluation') 313 | parser.add_argument("--log", type=str, default='cdan', 314 | help="Where to save logs, checkpoints and debugging images.") 315 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], 316 | help="When phase is 'test', only test the model." 
317 | "When phase is 'analysis', only analysis the model.") 318 | parser.add_argument('--log_results', action='store_true', 319 | help="To log results in wandb") 320 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID") 321 | parser.add_argument('--log_name', type=str, 322 | default="log", help="log name for wandb") 323 | args = parser.parse_args() 324 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 325 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 326 | args.device = device 327 | main(args) 328 | -------------------------------------------------------------------------------- /examples/cdan_mcc.py: -------------------------------------------------------------------------------- 1 | # Credits: https://github.com/thuml/Transfer-Learning-Library 2 | import random 3 | import time 4 | import warnings 5 | import sys 6 | import argparse 7 | import shutil 8 | import os.path as osp 9 | import os 10 | import wandb 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | from torch.optim import SGD 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.utils.data import DataLoader 18 | import torch.nn.functional as F 19 | 20 | sys.path.append('../') 21 | from dalib.modules.domain_discriminator import DomainDiscriminator 22 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier 23 | from dalib.adaptation.mcc import MinimumClassConfusionLoss 24 | from common.utils.data import ForeverDataIterator 25 | from common.utils.metric import accuracy 26 | from common.utils.meter import AverageMeter, ProgressMeter 27 | from common.utils.logger import CompleteLogger 28 | from common.utils.analysis import collect_feature, tsne, a_distance 29 | 30 | sys.path.append('.') 31 | import utils 32 | 33 | 34 | def main(args: argparse.Namespace): 35 | logger = CompleteLogger(args.log, args.phase) 36 | print(args) 37 | 38 | if args.log_results: 39 | wandb.init(project="DA", entity="SDAT", name=args.log_name) 40 | wandb.config.update(args) 41 | print(args) 42 | 43 | if args.seed is not None: 44 | random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | cudnn.deterministic = True 47 | warnings.warn('You have chosen to seed training. ' 48 | 'This will turn on the CUDNN deterministic setting, ' 49 | 'which can slow down your training considerably! 
' 50 | 'You may see unexpected behavior when restarting ' 51 | 'from checkpoints.') 52 | 53 | cudnn.benchmark = True 54 | device = args.device 55 | 56 | # Data loading code 57 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip, 58 | random_color_jitter=False, resize_size=args.resize_size, 59 | norm_mean=args.norm_mean, norm_std=args.norm_std) 60 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, 61 | norm_mean=args.norm_mean, norm_std=args.norm_std) 62 | print("train_transform: ", train_transform) 63 | print("val_transform: ", val_transform) 64 | 65 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ 66 | utils.get_dataset(args.data, args.root, args.source, 67 | args.target, train_transform, val_transform) 68 | train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, 69 | shuffle=True, num_workers=args.workers, drop_last=True) 70 | train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, 71 | shuffle=True, num_workers=args.workers, drop_last=True) 72 | val_loader = DataLoader( 73 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 74 | test_loader = DataLoader( 75 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 76 | 77 | train_source_iter = ForeverDataIterator(train_source_loader) 78 | train_target_iter = ForeverDataIterator(train_target_loader) 79 | 80 | # create model 81 | print("=> using model '{}'".format(args.arch)) 82 | backbone = utils.get_model(args.arch, pretrain=not args.scratch) 83 | print(backbone) 84 | pool_layer = nn.Identity() if args.no_pool else None 85 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, 86 | pool_layer=pool_layer, finetune=not args.scratch).to(device) 87 | classifier_feature_dim = classifier.features_dim 88 | 89 | if args.randomized: 90 | domain_discri = DomainDiscriminator( 91 | args.randomized_dim, hidden_size=1024).to(device) 92 | else: 93 | domain_discri = DomainDiscriminator( 94 | classifier_feature_dim * num_classes, hidden_size=1024).to(device) 95 | 96 | all_parameters = classifier.get_parameters() + domain_discri.get_parameters() 97 | 98 | # define optimizer and lr scheduler 99 | optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, 100 | weight_decay=args.weight_decay, nesterov=True) 101 | t_total = args.iters_per_epoch * args.epochs 102 | print("{INFORMATION} The total number of steps is ", t_total) 103 | 104 | lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * 105 | (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) 106 | 107 | # define loss function 108 | domain_adv = ConditionalDomainAdversarialLoss( 109 | domain_discri, entropy_conditioning=args.entropy, 110 | num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized, 111 | randomized_dim=args.randomized_dim 112 | ).to(device) 113 | 114 | mcc_loss = MinimumClassConfusionLoss(temperature=args.temperature) 115 | 116 | # resume from the best checkpoint 117 | if args.phase != 'train': 118 | checkpoint = torch.load( 119 | logger.get_checkpoint_path('best'), map_location='cpu') 120 | classifier.load_state_dict(checkpoint) 121 | 122 | # analysis the model 123 | if args.phase == 'analysis': 124 | # extract features from both domains 125 | feature_extractor = nn.Sequential( 126 | classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) 127 | source_feature = collect_feature( 128 | train_source_loader, feature_extractor, device) 129 | target_feature = collect_feature( 130 | train_target_loader, feature_extractor, device) 131 | # plot t-SNE 132 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') 133 | tsne.visualize(source_feature, target_feature, tSNE_filename) 134 | print("Saving t-SNE to", tSNE_filename) 135 | # calculate A-distance, which is a measure for distribution discrepancy 136 | A_distance = a_distance.calculate( 137 | source_feature, target_feature, device) 138 | print("A-distance =", A_distance) 139 | return 140 | 141 | if args.phase == 'test': 142 | acc1 = utils.validate(test_loader, classifier, args, device) 143 | print(acc1) 144 | return 145 | 146 | # start training 147 | best_acc1 = 0. 148 | for epoch in range(args.epochs): 149 | print("lr_bbone:", lr_scheduler.get_last_lr()[0]) 150 | print("lr_btlnck:", lr_scheduler.get_last_lr()[1]) 151 | if args.log_results: 152 | wandb.log({"lr_bbone": lr_scheduler.get_lr()[0], 153 | "lr_btlnck": lr_scheduler.get_last_lr()[1]}) 154 | # train for one epoch 155 | train(train_source_iter, train_target_iter, classifier, domain_adv, mcc_loss, optimizer, 156 | lr_scheduler, epoch, args) 157 | 158 | # evaluate on validation set 159 | acc1 = utils.validate(val_loader, classifier, args, device) 160 | if args.log_results: 161 | wandb.log({'epoch': epoch, 'val_acc': acc1}) 162 | 163 | # remember best acc@1 and save checkpoint 164 | torch.save(classifier.state_dict(), 165 | logger.get_checkpoint_path('latest')) 166 | if acc1 > best_acc1: 167 | shutil.copy(logger.get_checkpoint_path('latest'), 168 | logger.get_checkpoint_path('best')) 169 | best_acc1 = max(acc1, best_acc1) 170 | 171 | print("best_acc1 = {:3.1f}".format(best_acc1)) 172 | 173 | # evaluate on test set 174 | classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) 175 | acc1 = utils.validate(test_loader, classifier, args, device) 176 | print("test_acc1 = {:3.1f}".format(acc1)) 177 | if args.log_results: 178 | wandb.log({'epoch': epoch, 'test_acc': acc1}) 179 | 180 | logger.close() 181 | 182 | 183 | def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, 184 | domain_adv: ConditionalDomainAdversarialLoss, mcc, optimizer: SGD, 185 | lr_scheduler, epoch: int, args: argparse.Namespace): 186 | batch_time = AverageMeter('Time', ':3.1f') 187 | data_time = AverageMeter('Data', ':3.1f') 188 | losses = AverageMeter('Loss', ':3.2f') 189 | trans_losses = AverageMeter('Trans Loss', ':3.2f') 190 | cls_accs = AverageMeter('Cls Acc', ':3.1f') 191 | domain_accs = AverageMeter('Domain 
Acc', ':3.1f') 192 | progress = ProgressMeter( 193 | args.iters_per_epoch, 194 | [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs], 195 | prefix="Epoch: [{}]".format(epoch)) 196 | 197 | # switch to train mode 198 | model.train() 199 | domain_adv.train() 200 | 201 | end = time.time() 202 | for i in range(args.iters_per_epoch): 203 | x_s, labels_s = next(train_source_iter) 204 | x_t, _ = next(train_target_iter) 205 | 206 | x_s = x_s.to(device) 207 | x_t = x_t.to(device) 208 | labels_s = labels_s.to(device) 209 | 210 | # measure data loading time 211 | data_time.update(time.time() - end) 212 | 213 | # compute output 214 | x = torch.cat((x_s, x_t), dim=0) 215 | y, f = model(x) 216 | y_s, y_t = y.chunk(2, dim=0) 217 | f_s, f_t = f.chunk(2, dim=0) 218 | 219 | cls_loss = F.cross_entropy(y_s, labels_s) 220 | transfer_loss = domain_adv(y_s, f_s, y_t, f_t) + mcc(y_t) 221 | mcc_loss_value = mcc(y_t) 222 | domain_acc = domain_adv.domain_discriminator_accuracy 223 | loss = cls_loss + transfer_loss * args.trade_off 224 | cls_acc = accuracy(y_s, labels_s)[0] 225 | if args.log_results: 226 | wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss, 227 | 'transfer_loss': transfer_loss, 'iteration': epoch*args.iters_per_epoch + i, 228 | 'domain_acc': domain_acc, 'mcc_loss': mcc_loss_value}) 229 | 230 | losses.update(loss.item(), x_s.size(0)) 231 | cls_accs.update(cls_acc, x_s.size(0)) 232 | domain_accs.update(domain_acc, x_s.size(0)) 233 | trans_losses.update(transfer_loss.item(), x_s.size(0)) 234 | 235 | # compute gradient and do SGD step 236 | optimizer.zero_grad() 237 | loss.backward() 238 | optimizer.step() 239 | lr_scheduler.step() 240 | 241 | # measure elapsed time 242 | batch_time.update(time.time() - end) 243 | end = time.time() 244 | 245 | if i % args.print_freq == 0: 246 | progress.display(i) 247 | 248 | 249 | if __name__ == '__main__': 250 | parser = argparse.ArgumentParser( 251 | description='CDAN+MCC for Unsupervised Domain Adaptation') 252 | # dataset parameters 253 | parser.add_argument('root', metavar='DIR', 254 | help='root path of dataset') 255 | parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), 256 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) + 257 | ' (default: Office31)') 258 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') 259 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') 260 | parser.add_argument('--train-resizing', type=str, default='default') 261 | parser.add_argument('--val-resizing', type=str, default='default') 262 | parser.add_argument('--resize-size', type=int, default=224, 263 | help='the image size after resizing') 264 | parser.add_argument('--no-hflip', action='store_true', 265 | help='no random horizontal flipping during training') 266 | parser.add_argument('--norm-mean', type=float, nargs='+', 267 | default=(0.485, 0.456, 0.406), help='normalization mean') 268 | parser.add_argument('--norm-std', type=float, nargs='+', 269 | default=(0.229, 0.224, 0.225), help='normalization std') 270 | # model parameters 271 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 272 | choices=utils.get_model_names(), 273 | help='backbone architecture: ' + 274 | ' | '.join(utils.get_model_names()) + 275 | ' (default: resnet18)') 276 | parser.add_argument('--bottleneck-dim', default=256, type=int, 277 | help='Dimension of bottleneck') 278 | parser.add_argument('--no-pool', action='store_true', 
279 | help='no pool layer after the feature extractor.') 280 | parser.add_argument('--scratch', action='store_true', 281 | help='whether train from scratch.') 282 | parser.add_argument('-r', '--randomized', action='store_true', 283 | help='using randomized multi-linear-map (default: False)') 284 | parser.add_argument('-rd', '--randomized-dim', default=1024, type=int, 285 | help='randomized dimension when using randomized multi-linear-map (default: 1024)') 286 | parser.add_argument('--entropy', default=False, 287 | action='store_true', help='use entropy conditioning') 288 | parser.add_argument('--trade-off', default=1., type=float, 289 | help='the trade-off hyper-parameter for transfer loss') 290 | # training parameters 291 | parser.add_argument('-b', '--batch-size', default=32, type=int, 292 | metavar='N', 293 | help='mini-batch size (default: 32)') 294 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 295 | metavar='LR', help='initial learning rate', dest='lr') 296 | parser.add_argument('--lr-gamma', default=0.001, 297 | type=float, help='parameter for lr scheduler') 298 | parser.add_argument('--lr-decay', default=0.75, 299 | type=float, help='parameter for lr scheduler') 300 | parser.add_argument('--momentum', default=0.9, 301 | type=float, metavar='M', help='momentum') 302 | parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, 303 | metavar='W', help='weight decay (default: 1e-3)', 304 | dest='weight_decay') 305 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 306 | help='number of data loading workers (default: 2)') 307 | parser.add_argument('--epochs', default=20, type=int, metavar='N', 308 | help='number of total epochs to run') 309 | parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, 310 | help='Number of iterations per epoch') 311 | parser.add_argument('-p', '--print-freq', default=100, type=int, 312 | metavar='N', help='print frequency (default: 100)') 313 | parser.add_argument('--seed', default=None, type=int, 314 | help='seed for initializing training. ') 315 | parser.add_argument('--per-class-eval', action='store_true', 316 | help='whether output per-class accuracy during evaluation') 317 | parser.add_argument("--log", type=str, default='cdan', 318 | help="Where to save logs, checkpoints and debugging images.") 319 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], 320 | help="When phase is 'test', only test the model." 
321 | "When phase is 'analysis', only analysis the model.") 322 | parser.add_argument('--log_results', action='store_true', 323 | help="To log results in wandb") 324 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID") 325 | parser.add_argument('--log_name', type=str, 326 | default="log", help="log name for wandb") 327 | parser.add_argument('--temperature', default=2.0, 328 | type=float, help='parameter temperature scaling') 329 | args = parser.parse_args() 330 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 331 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 332 | args.device = device 333 | main(args) 334 | -------------------------------------------------------------------------------- /examples/cdan_mcc_sdat.py: -------------------------------------------------------------------------------- 1 | # Credits: https://github.com/thuml/Transfer-Learning-Library 2 | import random 3 | import time 4 | import warnings 5 | import sys 6 | import argparse 7 | import shutil 8 | import os.path as osp 9 | import os 10 | import wandb 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | from torch.optim import SGD 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.utils.data import DataLoader 18 | import torch.nn.functional as F 19 | 20 | sys.path.append('../') 21 | from dalib.modules.domain_discriminator import DomainDiscriminator 22 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier 23 | from dalib.adaptation.mcc import MinimumClassConfusionLoss 24 | from common.utils.data import ForeverDataIterator 25 | from common.utils.metric import accuracy 26 | from common.utils.meter import AverageMeter, ProgressMeter 27 | from common.utils.logger import CompleteLogger 28 | from common.utils.analysis import collect_feature, tsne, a_distance 29 | from common.utils.sam import SAM 30 | 31 | sys.path.append('.') 32 | import utils 33 | 34 | 35 | def main(args: argparse.Namespace): 36 | logger = CompleteLogger(args.log, args.phase) 37 | print(args) 38 | 39 | if args.log_results: 40 | wandb.init(project="DA", entity="SDAT", name=args.log_name) 41 | wandb.config.update(args) 42 | print(args) 43 | 44 | if args.seed is not None: 45 | random.seed(args.seed) 46 | torch.manual_seed(args.seed) 47 | cudnn.deterministic = True 48 | warnings.warn('You have chosen to seed training. ' 49 | 'This will turn on the CUDNN deterministic setting, ' 50 | 'which can slow down your training considerably! 
' 51 | 'You may see unexpected behavior when restarting ' 52 | 'from checkpoints.') 53 | 54 | cudnn.benchmark = True 55 | device = args.device 56 | 57 | # Data loading code 58 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip, 59 | random_color_jitter=False, resize_size=args.resize_size, 60 | norm_mean=args.norm_mean, norm_std=args.norm_std) 61 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, 62 | norm_mean=args.norm_mean, norm_std=args.norm_std) 63 | print("train_transform: ", train_transform) 64 | print("val_transform: ", val_transform) 65 | 66 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ 67 | utils.get_dataset(args.data, args.root, args.source, 68 | args.target, train_transform, val_transform) 69 | train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, 70 | shuffle=True, num_workers=args.workers, drop_last=True) 71 | train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, 72 | shuffle=True, num_workers=args.workers, drop_last=True) 73 | val_loader = DataLoader( 74 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 75 | test_loader = DataLoader( 76 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 77 | 78 | train_source_iter = ForeverDataIterator(train_source_loader) 79 | train_target_iter = ForeverDataIterator(train_target_loader) 80 | 81 | # create model 82 | print("=> using model '{}'".format(args.arch)) 83 | backbone = utils.get_model(args.arch, pretrain=not args.scratch) 84 | print(backbone) 85 | pool_layer = nn.Identity() if args.no_pool else None 86 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, 87 | pool_layer=pool_layer, finetune=not args.scratch).to(device) 88 | classifier_feature_dim = classifier.features_dim 89 | 90 | if args.randomized: 91 | domain_discri = DomainDiscriminator( 92 | args.randomized_dim, hidden_size=1024).to(device) 93 | else: 94 | domain_discri = DomainDiscriminator( 95 | classifier_feature_dim * num_classes, hidden_size=1024).to(device) 96 | 97 | # define optimizer and lr scheduler 98 | base_optimizer = torch.optim.SGD 99 | ad_optimizer = SGD(domain_discri.get_parameters( 100 | ), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 101 | optimizer = SAM(classifier.get_parameters(), base_optimizer, rho=args.rho, adaptive=False, 102 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 103 | lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * 104 | (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) 105 | lr_scheduler_ad = LambdaLR( 106 | ad_optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) 107 | 108 | # define loss function 109 | domain_adv = ConditionalDomainAdversarialLoss( 110 | domain_discri, entropy_conditioning=args.entropy, 111 | num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized, 112 | randomized_dim=args.randomized_dim 113 | ).to(device) 114 | 115 | mcc_loss = MinimumClassConfusionLoss(temperature=args.temperature) 116 | 117 | # resume from the best checkpoint 118 | if args.phase != 'train': 119 | checkpoint = torch.load( 120 | logger.get_checkpoint_path('best'), map_location='cpu') 121 | classifier.load_state_dict(checkpoint) 122 | 123 | # analysis the model 124 | if args.phase == 'analysis': 125 | # extract features from both domains 126 | feature_extractor = nn.Sequential( 127 | classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) 128 | source_feature = collect_feature( 129 | train_source_loader, feature_extractor, device) 130 | target_feature = collect_feature( 131 | train_target_loader, feature_extractor, device) 132 | # plot t-SNE 133 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') 134 | tsne.visualize(source_feature, target_feature, tSNE_filename) 135 | print("Saving t-SNE to", tSNE_filename) 136 | # calculate A-distance, which is a measure for distribution discrepancy 137 | A_distance = a_distance.calculate( 138 | source_feature, target_feature, device) 139 | print("A-distance =", A_distance) 140 | return 141 | 142 | if args.phase == 'test': 143 | acc1 = utils.validate(test_loader, classifier, args, device) 144 | print(acc1) 145 | return 146 | 147 | # start training 148 | best_acc1 = 0. 149 | for epoch in range(args.epochs): 150 | print("lr_bbone:", lr_scheduler.get_last_lr()[0]) 151 | print("lr_btlnck:", lr_scheduler.get_last_lr()[1]) 152 | if args.log_results: 153 | wandb.log({"lr_bbone": lr_scheduler.get_last_lr()[0], 154 | "lr_btlnck": lr_scheduler.get_last_lr()[1]}) 155 | # train for one epoch 156 | 157 | train(train_source_iter, train_target_iter, classifier, domain_adv, mcc_loss, optimizer, ad_optimizer, 158 | lr_scheduler, lr_scheduler_ad, epoch, args) 159 | # evaluate on validation set 160 | acc1 = utils.validate(val_loader, classifier, args, device) 161 | if args.log_results: 162 | wandb.log({'epoch': epoch, 'val_acc': acc1}) 163 | 164 | # remember best acc@1 and save checkpoint 165 | torch.save(classifier.state_dict(), 166 | logger.get_checkpoint_path('latest')) 167 | if acc1 > best_acc1: 168 | shutil.copy(logger.get_checkpoint_path('latest'), 169 | logger.get_checkpoint_path('best')) 170 | best_acc1 = max(acc1, best_acc1) 171 | 172 | print("best_acc1 = {:3.1f}".format(best_acc1)) 173 | 174 | # evaluate on test set 175 | classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) 176 | acc1 = utils.validate(test_loader, classifier, args, device) 177 | print("test_acc1 = {:3.1f}".format(acc1)) 178 | if args.log_results: 179 | wandb.log({'epoch': epoch, 'test_acc': acc1}) 180 | 181 | logger.close() 182 | 183 | 184 | def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, 185 | domain_adv: ConditionalDomainAdversarialLoss, mcc, optimizer, ad_optimizer, 186 | lr_scheduler: LambdaLR, lr_scheduler_ad, epoch: int, args: argparse.Namespace): 187 | batch_time = AverageMeter('Time', ':3.1f') 188 | data_time = AverageMeter('Data', ':3.1f') 189 | losses = AverageMeter('Loss', ':3.2f') 190 | trans_losses = AverageMeter('Trans Loss', ':3.2f') 191 | cls_accs = 
AverageMeter('Cls Acc', ':3.1f') 192 | domain_accs = AverageMeter('Domain Acc', ':3.1f') 193 | progress = ProgressMeter( 194 | args.iters_per_epoch, 195 | [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs], 196 | prefix="Epoch: [{}]".format(epoch)) 197 | 198 | # switch to train mode 199 | model.train() 200 | domain_adv.train() 201 | 202 | end = time.time() 203 | for i in range(args.iters_per_epoch): 204 | x_s, labels_s = next(train_source_iter) 205 | x_t, _ = next(train_target_iter) 206 | 207 | x_s = x_s.to(device) 208 | x_t = x_t.to(device) 209 | labels_s = labels_s.to(device) 210 | 211 | # measure data loading time 212 | data_time.update(time.time() - end) 213 | optimizer.zero_grad() 214 | ad_optimizer.zero_grad() 215 | 216 | # compute output 217 | x = torch.cat((x_s, x_t), dim=0) 218 | y, f = model(x) 219 | y_s, y_t = y.chunk(2, dim=0) 220 | f_s, f_t = f.chunk(2, dim=0) 221 | cls_loss = F.cross_entropy(y_s, labels_s) 222 | mcc_loss_value = mcc(y_t) 223 | loss = cls_loss + mcc_loss_value 224 | 225 | loss.backward() 226 | 227 | # Calculate ϵ̂ (w) and add it to the weights 228 | optimizer.first_step(zero_grad=True) 229 | 230 | # Calculate task loss and domain loss 231 | y, f = model(x) 232 | y_s, y_t = y.chunk(2, dim=0) 233 | f_s, f_t = f.chunk(2, dim=0) 234 | 235 | cls_loss = F.cross_entropy(y_s, labels_s) 236 | transfer_loss = domain_adv(y_s, f_s, y_t, f_t) + mcc(y_t) 237 | domain_acc = domain_adv.domain_discriminator_accuracy 238 | loss = cls_loss + transfer_loss * args.trade_off 239 | 240 | cls_acc = accuracy(y_s, labels_s)[0] 241 | if args.log_results: 242 | wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss, 243 | 'transfer_loss': transfer_loss, 'domain_acc': domain_acc}) 244 | 245 | losses.update(loss.item(), x_s.size(0)) 246 | cls_accs.update(cls_acc, x_s.size(0)) 247 | domain_accs.update(domain_acc, x_s.size(0)) 248 | trans_losses.update(transfer_loss.item(), x_s.size(0)) 249 | 250 | loss.backward() 251 | # Update parameters of domain classifier 252 | ad_optimizer.step() 253 | # Update parameters (Sharpness-Aware update) 254 | optimizer.second_step(zero_grad=True) 255 | lr_scheduler.step() 256 | lr_scheduler_ad.step() 257 | 258 | # measure elapsed time 259 | batch_time.update(time.time() - end) 260 | end = time.time() 261 | 262 | if i % args.print_freq == 0: 263 | progress.display(i) 264 | 265 | 266 | if __name__ == '__main__': 267 | parser = argparse.ArgumentParser( 268 | description='CDAN+MCC with SDAT for Unsupervised Domain Adaptation') 269 | # dataset parameters 270 | parser.add_argument('root', metavar='DIR', 271 | help='root path of dataset') 272 | parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), 273 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) + 274 | ' (default: Office31)') 275 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') 276 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') 277 | parser.add_argument('--train-resizing', type=str, default='default') 278 | parser.add_argument('--val-resizing', type=str, default='default') 279 | parser.add_argument('--resize-size', type=int, default=224, 280 | help='the image size after resizing') 281 | parser.add_argument('--no-hflip', action='store_true', 282 | help='no random horizontal flipping during training') 283 | parser.add_argument('--norm-mean', type=float, nargs='+', 284 | default=(0.485, 0.456, 0.406), help='normalization mean') 285 | 
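    # A minimal sketch (not part of the original script) of the two-pass sharpness-aware update
    # performed in train() above, isolated on a toy model. It assumes common/utils/sam.py exposes
    # the usual SAM(params, base_optimizer, rho=..., **kwargs) wrapper with first_step() and
    # second_step(), which is exactly how the training loop above calls it. In SDAT, the first
    # backward pass uses only the smoothed task (+ MCC) loss, and the second pass adds the
    # adversarial transfer loss before the actual parameter update.
    # >>> import torch, torch.nn as nn, torch.nn.functional as F
    # >>> from common.utils.sam import SAM
    # >>> net = nn.Linear(8, 3)
    # >>> opt = SAM(net.parameters(), torch.optim.SGD, rho=0.05, lr=0.01, momentum=0.9)
    # >>> x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))
    # >>> F.cross_entropy(net(x), y).backward()
    # >>> opt.first_step(zero_grad=True)         # perturb weights toward the local worst case
    # >>> F.cross_entropy(net(x), y).backward()  # re-evaluate the loss at the perturbed weights
    # >>> opt.second_step(zero_grad=True)        # restore weights and take the base SGD step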
parser.add_argument('--norm-std', type=float, nargs='+', 286 | default=(0.229, 0.224, 0.225), help='normalization std') 287 | # model parameters 288 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 289 | choices=utils.get_model_names(), 290 | help='backbone architecture: ' + 291 | ' | '.join(utils.get_model_names()) + 292 | ' (default: resnet18)') 293 | parser.add_argument('--bottleneck-dim', default=256, type=int, 294 | help='Dimension of bottleneck') 295 | parser.add_argument('--no-pool', action='store_true', 296 | help='no pool layer after the feature extractor.') 297 | parser.add_argument('--scratch', action='store_true', 298 | help='whether train from scratch.') 299 | parser.add_argument('-r', '--randomized', action='store_true', 300 | help='using randomized multi-linear-map (default: False)') 301 | parser.add_argument('-rd', '--randomized-dim', default=1024, type=int, 302 | help='randomized dimension when using randomized multi-linear-map (default: 1024)') 303 | parser.add_argument('--entropy', default=False, 304 | action='store_true', help='use entropy conditioning') 305 | parser.add_argument('--trade-off', default=1., type=float, 306 | help='the trade-off hyper-parameter for transfer loss') 307 | # training parameters 308 | parser.add_argument('-b', '--batch-size', default=32, type=int, 309 | metavar='N', 310 | help='mini-batch size (default: 32)') 311 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 312 | metavar='LR', help='initial learning rate', dest='lr') 313 | parser.add_argument('--lr-gamma', default=0.001, 314 | type=float, help='parameter for lr scheduler') 315 | parser.add_argument('--lr-decay', default=0.75, 316 | type=float, help='parameter for lr scheduler') 317 | parser.add_argument('--momentum', default=0.9, 318 | type=float, metavar='M', help='momentum') 319 | parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, 320 | metavar='W', help='weight decay (default: 1e-3)', 321 | dest='weight_decay') 322 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 323 | help='number of data loading workers (default: 2)') 324 | parser.add_argument('--epochs', default=20, type=int, metavar='N', 325 | help='number of total epochs to run') 326 | parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, 327 | help='Number of iterations per epoch') 328 | parser.add_argument('-p', '--print-freq', default=100, type=int, 329 | metavar='N', help='print frequency (default: 100)') 330 | parser.add_argument('--seed', default=None, type=int, 331 | help='seed for initializing training. ') 332 | parser.add_argument('--per-class-eval', action='store_true', 333 | help='whether output per-class accuracy during evaluation') 334 | parser.add_argument("--log", type=str, default='cdan', 335 | help="Where to save logs, checkpoints and debugging images.") 336 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], 337 | help="When phase is 'test', only test the model." 
338 | "When phase is 'analysis', only analysis the model.") 339 | parser.add_argument('--log_results', action='store_true', 340 | help="To log results in wandb") 341 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID") 342 | parser.add_argument('--log_name', type=str, 343 | default="log", help="log name for wandb") 344 | parser.add_argument('--rho', type=float, default=0.05, help="GPU ID") 345 | parser.add_argument('--temperature', default=2.0, 346 | type=float, help='parameter temperature scaling') 347 | args = parser.parse_args() 348 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 349 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 350 | args.device = device 351 | main(args) 352 | -------------------------------------------------------------------------------- /examples/cdan_sdat.py: -------------------------------------------------------------------------------- 1 | # Credits: https://github.com/thuml/Transfer-Learning-Library 2 | import random 3 | import time 4 | import warnings 5 | import sys 6 | import argparse 7 | import shutil 8 | import os.path as osp 9 | import os 10 | import wandb 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | from torch.optim import SGD 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.utils.data import DataLoader 18 | import torch.nn.functional as F 19 | 20 | sys.path.append('../') 21 | from dalib.modules.domain_discriminator import DomainDiscriminator 22 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier 23 | from common.utils.data import ForeverDataIterator 24 | from common.utils.metric import accuracy 25 | from common.utils.meter import AverageMeter, ProgressMeter 26 | from common.utils.logger import CompleteLogger 27 | from common.utils.analysis import collect_feature, tsne, a_distance 28 | from common.utils.sam import SAM 29 | 30 | sys.path.append('.') 31 | import utils 32 | 33 | 34 | def main(args: argparse.Namespace): 35 | logger = CompleteLogger(args.log, args.phase) 36 | print(args) 37 | 38 | if args.log_results: 39 | wandb.init(project="DA", entity="SDAT", name=args.log_name) 40 | wandb.config.update(args) 41 | print(args) 42 | 43 | if args.seed is not None: 44 | random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | cudnn.deterministic = True 47 | warnings.warn('You have chosen to seed training. ' 48 | 'This will turn on the CUDNN deterministic setting, ' 49 | 'which can slow down your training considerably! 
' 50 | 'You may see unexpected behavior when restarting ' 51 | 'from checkpoints.') 52 | 53 | cudnn.benchmark = True 54 | device = args.device 55 | 56 | # Data loading code 57 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip, 58 | random_color_jitter=False, resize_size=args.resize_size, 59 | norm_mean=args.norm_mean, norm_std=args.norm_std) 60 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, 61 | norm_mean=args.norm_mean, norm_std=args.norm_std) 62 | print("train_transform: ", train_transform) 63 | print("val_transform: ", val_transform) 64 | 65 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ 66 | utils.get_dataset(args.data, args.root, args.source, 67 | args.target, train_transform, val_transform) 68 | train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, 69 | shuffle=True, num_workers=args.workers, drop_last=True) 70 | train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, 71 | shuffle=True, num_workers=args.workers, drop_last=True) 72 | val_loader = DataLoader( 73 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 74 | test_loader = DataLoader( 75 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 76 | 77 | train_source_iter = ForeverDataIterator(train_source_loader) 78 | train_target_iter = ForeverDataIterator(train_target_loader) 79 | 80 | # create model 81 | print("=> using model '{}'".format(args.arch)) 82 | backbone = utils.get_model(args.arch, pretrain=not args.scratch) 83 | print(backbone) 84 | pool_layer = nn.Identity() if args.no_pool else None 85 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, 86 | pool_layer=pool_layer, finetune=not args.scratch).to(device) 87 | classifier_feature_dim = classifier.features_dim 88 | 89 | if args.randomized: 90 | domain_discri = DomainDiscriminator( 91 | args.randomized_dim, hidden_size=1024).to(device) 92 | else: 93 | domain_discri = DomainDiscriminator( 94 | classifier_feature_dim * num_classes, hidden_size=1024).to(device) 95 | 96 | # define optimizer and lr scheduler 97 | base_optimizer = torch.optim.SGD 98 | ad_optimizer = SGD(domain_discri.get_parameters( 99 | ), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 100 | optimizer = SAM(classifier.get_parameters(), base_optimizer, rho=args.rho, adaptive=False, 101 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 102 | lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * 103 | (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) 104 | lr_scheduler_ad = LambdaLR( 105 | ad_optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) 106 | 107 | # define loss function 108 | domain_adv = ConditionalDomainAdversarialLoss( 109 | domain_discri, entropy_conditioning=args.entropy, 110 | num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized, 111 | randomized_dim=args.randomized_dim 112 | ).to(device) 113 | 114 | # resume from the best checkpoint 115 | if args.phase != 'train': 116 | checkpoint = torch.load( 117 | logger.get_checkpoint_path('best'), map_location='cpu') 118 | classifier.load_state_dict(checkpoint) 119 | 120 | # analysis the model 121 | if args.phase == 'analysis': 122 | # extract features from both domains 123 | feature_extractor = nn.Sequential( 124 | classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) 125 | source_feature = collect_feature( 126 | train_source_loader, feature_extractor, device) 127 | target_feature = collect_feature( 128 | train_target_loader, feature_extractor, device) 129 | # plot t-SNE 130 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') 131 | tsne.visualize(source_feature, target_feature, tSNE_filename) 132 | print("Saving t-SNE to", tSNE_filename) 133 | # calculate A-distance, which is a measure for distribution discrepancy 134 | A_distance = a_distance.calculate( 135 | source_feature, target_feature, device) 136 | print("A-distance =", A_distance) 137 | return 138 | 139 | if args.phase == 'test': 140 | acc1 = utils.validate(test_loader, classifier, args, device) 141 | print(acc1) 142 | return 143 | 144 | # start training 145 | best_acc1 = 0. 146 | for epoch in range(args.epochs): 147 | print("lr_bbone:", lr_scheduler.get_last_lr()[0]) 148 | print("lr_btlnck:", lr_scheduler.get_last_lr()[1]) 149 | if args.log_results: 150 | wandb.log({"lr_bbone": lr_scheduler.get_last_lr()[0], 151 | "lr_btlnck": lr_scheduler.get_last_lr()[1]}) 152 | 153 | # train for one epoch 154 | train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, ad_optimizer, 155 | lr_scheduler, lr_scheduler_ad, epoch, args) 156 | # evaluate on validation set 157 | acc1 = utils.validate(val_loader, classifier, args, device) 158 | if args.log_results: 159 | wandb.log({'epoch': epoch, 'val_acc': acc1}) 160 | 161 | # remember best acc@1 and save checkpoint 162 | torch.save(classifier.state_dict(), 163 | logger.get_checkpoint_path('latest')) 164 | if acc1 > best_acc1: 165 | shutil.copy(logger.get_checkpoint_path('latest'), 166 | logger.get_checkpoint_path('best')) 167 | best_acc1 = max(acc1, best_acc1) 168 | 169 | print("best_acc1 = {:3.1f}".format(best_acc1)) 170 | 171 | # evaluate on test set 172 | classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) 173 | acc1 = utils.validate(test_loader, classifier, args, device) 174 | print("test_acc1 = {:3.1f}".format(acc1)) 175 | if args.log_results: 176 | wandb.log({'epoch': epoch, 'test_acc': acc1}) 177 | 178 | logger.close() 179 | 180 | 181 | def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, 182 | domain_adv: ConditionalDomainAdversarialLoss, optimizer, ad_optimizer, 183 | lr_scheduler: LambdaLR, lr_scheduler_ad, epoch: int, args: argparse.Namespace): 184 | batch_time = AverageMeter('Time', ':3.1f') 185 | data_time = AverageMeter('Data', ':3.1f') 186 | losses = AverageMeter('Loss', ':3.2f') 187 | trans_losses = AverageMeter('Trans Loss', ':3.2f') 188 | cls_accs = AverageMeter('Cls Acc', ':3.1f') 189 | domain_accs = AverageMeter('Domain Acc', ':3.1f') 190 | 
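# Note on the SDAT step in the training loop below (the perturbation rule is
# assumed to follow the standard SAM first_step/second_step scheme wrapped by
# common/utils/sam.py):
#   1. backward() on the source classification loss only;
#   2. optimizer.first_step() perturbs the classifier weights by roughly
#      eps(w) = rho * grad / ||grad||_2, i.e. an ascent step on the task loss
#      scaled by --rho;
#   3. the classification and CDAN transfer losses are recomputed at the
#      perturbed weights and backpropagated;
#   4. ad_optimizer.step() updates the domain discriminator with plain SGD,
#      while optimizer.second_step() applies the sharpness-aware update to
#      the classifier.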
progress = ProgressMeter( 191 | args.iters_per_epoch, 192 | [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs], 193 | prefix="Epoch: [{}]".format(epoch)) 194 | 195 | device = args.device 196 | # switch to train mode 197 | model.train() 198 | domain_adv.train() 199 | 200 | end = time.time() 201 | for i in range(args.iters_per_epoch): 202 | x_s, labels_s = next(train_source_iter) 203 | x_t, _ = next(train_target_iter) 204 | 205 | x_s = x_s.to(device) 206 | x_t = x_t.to(device) 207 | labels_s = labels_s.to(device) 208 | 209 | # measure data loading time 210 | data_time.update(time.time() - end) 211 | optimizer.zero_grad() 212 | ad_optimizer.zero_grad() 213 | 214 | # compute task loss for first step 215 | x = torch.cat((x_s, x_t), dim=0) 216 | y, f = model(x) 217 | y_s, y_t = y.chunk(2, dim=0) 218 | f_s, f_t = f.chunk(2, dim=0) 219 | cls_loss = F.cross_entropy(y_s, labels_s) 220 | loss = cls_loss 221 | loss.backward() 222 | 223 | # Calculate ϵ̂ (w) and add it to the weights 224 | optimizer.first_step(zero_grad=True) 225 | 226 | # Calculate task loss and domain loss 227 | y, f = model(x) 228 | y_s, y_t = y.chunk(2, dim=0) 229 | f_s, f_t = f.chunk(2, dim=0) 230 | 231 | cls_loss = F.cross_entropy(y_s, labels_s) 232 | transfer_loss = domain_adv(y_s, f_s, y_t, f_t) 233 | domain_acc = domain_adv.domain_discriminator_accuracy 234 | loss = cls_loss + transfer_loss * args.trade_off 235 | 236 | cls_acc = accuracy(y_s, labels_s)[0] 237 | if args.log_results: 238 | wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss, 239 | 'transfer_loss': transfer_loss, 'domain_acc': domain_acc}) 240 | 241 | losses.update(loss.item(), x_s.size(0)) 242 | cls_accs.update(cls_acc, x_s.size(0)) 243 | domain_accs.update(domain_acc, x_s.size(0)) 244 | trans_losses.update(transfer_loss.item(), x_s.size(0)) 245 | 246 | loss.backward() 247 | # Update parameters of domain classifier 248 | ad_optimizer.step() 249 | # Update parameters (Sharpness-Aware update) 250 | optimizer.second_step(zero_grad=True) 251 | lr_scheduler.step() 252 | lr_scheduler_ad.step() 253 | 254 | # measure elapsed time 255 | batch_time.update(time.time() - end) 256 | end = time.time() 257 | 258 | if i % args.print_freq == 0: 259 | progress.display(i) 260 | 261 | 262 | if __name__ == '__main__': 263 | parser = argparse.ArgumentParser( 264 | description='CDAN with SDAT for Unsupervised Domain Adaptation') 265 | # dataset parameters 266 | parser.add_argument('root', metavar='DIR', 267 | help='root path of dataset') 268 | parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), 269 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) + 270 | ' (default: Office31)') 271 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') 272 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') 273 | parser.add_argument('--train-resizing', type=str, default='default') 274 | parser.add_argument('--val-resizing', type=str, default='default') 275 | parser.add_argument('--resize-size', type=int, default=224, 276 | help='the image size after resizing') 277 | parser.add_argument('--no-hflip', action='store_true', 278 | help='no random horizontal flipping during training') 279 | parser.add_argument('--norm-mean', type=float, nargs='+', 280 | default=(0.485, 0.456, 0.406), help='normalization mean') 281 | parser.add_argument('--norm-std', type=float, nargs='+', 282 | default=(0.229, 0.224, 0.225), help='normalization std') 283 | # 
model parameters 284 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 285 | choices=utils.get_model_names(), 286 | help='backbone architecture: ' + 287 | ' | '.join(utils.get_model_names()) + 288 | ' (default: resnet18)') 289 | parser.add_argument('--bottleneck-dim', default=256, type=int, 290 | help='Dimension of bottleneck') 291 | parser.add_argument('--no-pool', action='store_true', 292 | help='no pool layer after the feature extractor.') 293 | parser.add_argument('--scratch', action='store_true', 294 | help='whether train from scratch.') 295 | parser.add_argument('-r', '--randomized', action='store_true', 296 | help='using randomized multi-linear-map (default: False)') 297 | parser.add_argument('-rd', '--randomized-dim', default=1024, type=int, 298 | help='randomized dimension when using randomized multi-linear-map (default: 1024)') 299 | parser.add_argument('--entropy', default=False, 300 | action='store_true', help='use entropy conditioning') 301 | parser.add_argument('--trade-off', default=1., type=float, 302 | help='the trade-off hyper-parameter for transfer loss') 303 | # training parameters 304 | parser.add_argument('-b', '--batch-size', default=32, type=int, 305 | metavar='N', 306 | help='mini-batch size (default: 32)') 307 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 308 | metavar='LR', help='initial learning rate', dest='lr') 309 | parser.add_argument('--lr-gamma', default=0.001, 310 | type=float, help='parameter for lr scheduler') 311 | parser.add_argument('--lr-decay', default=0.75, 312 | type=float, help='parameter for lr scheduler') 313 | parser.add_argument('--momentum', default=0.9, 314 | type=float, metavar='M', help='momentum') 315 | parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, 316 | metavar='W', help='weight decay (default: 1e-3)', 317 | dest='weight_decay') 318 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 319 | help='number of data loading workers (default: 2)') 320 | parser.add_argument('--epochs', default=20, type=int, metavar='N', 321 | help='number of total epochs to run') 322 | parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, 323 | help='Number of iterations per epoch') 324 | parser.add_argument('-p', '--print-freq', default=100, type=int, 325 | metavar='N', help='print frequency (default: 100)') 326 | parser.add_argument('--seed', default=None, type=int, 327 | help='seed for initializing training. ') 328 | parser.add_argument('--per-class-eval', action='store_true', 329 | help='whether output per-class accuracy during evaluation') 330 | parser.add_argument("--log", type=str, default='cdan', 331 | help="Where to save logs, checkpoints and debugging images.") 332 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], 333 | help="When phase is 'test', only test the model." 
334 | "When phase is 'analysis', only analysis the model.") 335 | parser.add_argument('--log_results', action='store_true', 336 | help="To log results in wandb") 337 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID") 338 | parser.add_argument('--log_name', type=str, 339 | default="log", help="log name for wandb") 340 | parser.add_argument('--rho', type=float, default=0.05, help="GPU ID") 341 | args = parser.parse_args() 342 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 343 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 344 | args.device = device 345 | main(args) 346 | -------------------------------------------------------------------------------- /examples/eval.py: -------------------------------------------------------------------------------- 1 | # Credits: https://github.com/thuml/Transfer-Learning-Library 2 | import random 3 | import time 4 | import warnings 5 | import sys 6 | import argparse 7 | import shutil 8 | import os.path as osp 9 | import os 10 | import wandb 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | from torch.utils.data import DataLoader 16 | import torch.nn.functional as F 17 | 18 | sys.path.append('../') 19 | from dalib.adaptation.cdan import ImageClassifier 20 | from common.utils.data import ForeverDataIterator 21 | from common.utils.metric import accuracy 22 | from common.utils.meter import AverageMeter, ProgressMeter 23 | 24 | sys.path.append('.') 25 | import utils 26 | 27 | def main(args: argparse.Namespace): 28 | 29 | if args.log_results: 30 | wandb.init(project="DA", entity="SDAT", name=args.log_name) 31 | wandb.config.update(args) 32 | 33 | cudnn.benchmark = True 34 | device = args.device 35 | # Data loading code 36 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip, 37 | random_color_jitter=False, resize_size=args.resize_size, 38 | norm_mean=args.norm_mean, norm_std=args.norm_std) 39 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, 40 | norm_mean=args.norm_mean, norm_std=args.norm_std) 41 | 42 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ 43 | utils.get_dataset(args.data, args.root, args.source, 44 | args.target, train_transform, val_transform) 45 | val_loader = DataLoader( 46 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 47 | test_loader = DataLoader( 48 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 49 | 50 | 51 | # create model 52 | print("=> using model '{}'".format(args.arch)) 53 | backbone = utils.get_model(args.arch, pretrain=not args.scratch) 54 | print(backbone) 55 | pool_layer = nn.Identity() if args.no_pool else None 56 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, 57 | pool_layer=pool_layer, finetune=not args.scratch).to(device) 58 | classifier_feature_dim = classifier.features_dim 59 | 60 | # resume from the best checkpoint 61 | if args.phase != 'train': 62 | path = args.weight_path 63 | print(f"[INFORMATION] Using the weights stored at {args.weight_path}") 64 | classifier.load_state_dict(torch.load(path)) 65 | 66 | if args.phase == 'test': 67 | acc1 = utils.validate(test_loader, classifier, args, device) 68 | print(acc1) 69 | return 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser( 73 | description='CDAN for Unsupervised Domain Adaptation') 74 | # 
dataset parameters 75 | parser.add_argument('root', metavar='DIR', 76 | help='root path of dataset') 77 | parser.add_argument('-d', '--data', metavar='DATA', default='OfficeHome', choices=utils.get_dataset_names(), 78 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) + 79 | ' (default: OfficeHome)') 80 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') 81 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') 82 | parser.add_argument('--train-resizing', type=str, default='default') 83 | parser.add_argument('--val-resizing', type=str, default='default') 84 | parser.add_argument('--resize-size', type=int, default=224, 85 | help='the image size after resizing') 86 | parser.add_argument('--no-hflip', action='store_true', 87 | help='no random horizontal flipping during training') 88 | parser.add_argument('--norm-mean', type=float, nargs='+', 89 | default=(0.485, 0.456, 0.406), help='normalization mean') 90 | parser.add_argument('--norm-std', type=float, nargs='+', 91 | default=(0.229, 0.224, 0.225), help='normalization std') 92 | # model parameters 93 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 94 | choices=utils.get_model_names(), 95 | help='backbone architecture: ' + 96 | ' | '.join(utils.get_model_names()) + 97 | ' (default: resnet18)') 98 | parser.add_argument('--bottleneck-dim', default=256, type=int, 99 | help='Dimension of bottleneck') 100 | parser.add_argument('--no-pool', action='store_true', 101 | help='no pool layer after the feature extractor.') 102 | parser.add_argument('--scratch', action='store_true', 103 | help='whether train from scratch.') 104 | parser.add_argument('-b', '--batch-size', default=32, type=int, 105 | metavar='N', 106 | help='mini-batch size (default: 32)') 107 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 108 | help='number of data loading workers (default: 2)') 109 | parser.add_argument('-p', '--print-freq', default=100, type=int, 110 | metavar='N', help='print frequency (default: 100)') 111 | parser.add_argument('--seed', default=None, type=int, 112 | help='seed for initializing training. ') 113 | parser.add_argument('--per-class-eval', action='store_true', 114 | help='whether output per-class accuracy during evaluation') 115 | parser.add_argument("--weight_path", type=str, default='cdan', 116 | help="Path to the saved weights") 117 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], 118 | help="When phase is 'test', only test the model." 
119 | "When phase is 'analysis', only analysis the model.") 120 | parser.add_argument('--log_results', action='store_true', 121 | help="To log results in wandb") 122 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID") 123 | parser.add_argument('--log_name', type=str, 124 | default="log", help="log name for wandb") 125 | args = parser.parse_args() 126 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 127 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 128 | args.device = device 129 | main(args) 130 | -------------------------------------------------------------------------------- /examples/run_office_home.sh: -------------------------------------------------------------------------------- 1 | #CDAN (Office-Home-ViT) 2 | python cdan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_vit --gpu 0 --log_results 3 | python cdan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Pr --log_name Ar2Pr_cdan_vit --gpu 0 --log_results 4 | python cdan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Rw --log_name Ar2Rw_cdan_vit --gpu 0 --log_results 5 | 6 | python cdan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Ar --log_name Cl2Ar_cdan_vit --gpu 0 --log_results 7 | python cdan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Pr --log_name Cl2Pr_cdan_vit --gpu 0 --log_results 8 | python cdan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Rw --log_name Cl2Rw_cdan_vit --gpu 0 --log_results 9 | 10 | python cdan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Ar --log_name Pr2Ar_cdan_vit --gpu 0 --log_results 11 | python cdan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Cl --log_name Pr2Cl_cdan_vit --gpu 0 --log_results 12 | python cdan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Rw --log_name Pr2Rw_cdan_vit --gpu 0 --log_results 13 | 14 | python cdan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Ar --log_name Rw2Ar_cdan_vit --gpu 0 --log_results 15 | python cdan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Cl --log_name Rw2Cl_cdan_vit --gpu 0 --log_results 16 | python cdan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Pr --log_name Rw2Pr_cdan_vit --gpu 0 --log_results 17 | 18 | #CDAN_SDAT (Office-Home-ViT) 19 | python cdan_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 
20 | python cdan_sdat.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Ar2Pr --log_name Ar2Pr_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 21 | python cdan_sdat.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Ar2Rw --log_name Ar2Rw_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 22 | 23 | python cdan_sdat.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Cl2Ar --log_name Cl2Ar_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 24 | python cdan_sdat.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Cl2Pr --log_name Cl2Pr_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 25 | python cdan_sdat.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Cl2Rw --log_name Cl2Rw_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 26 | 27 | python cdan_sdat.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Pr2Ar --log_name Pr2Ar_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 28 | python cdan_sdat.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Pr2Cl --log_name Pr2Cl_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 29 | python cdan_sdat.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Pr2Rw --log_name Pr2Rw_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 30 | 31 | python cdan_sdat.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Rw2Ar --log_name Rw2Ar_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 32 | python cdan_sdat.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Rw2Cl --log_name Rw2Cl_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 33 | python cdan_sdat.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Rw2Pr --log_name Rw2Pr_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results 34 | 35 | #CDAN_MCC (Office-Home-ViT) 36 | python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 37 | python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Pr --log_name Ar2Pr_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 38 | python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Rw --log_name Ar2Rw_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 39 | 40 | python cdan_mcc.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Cl2Ar --log_name Cl2Ar_cdan_mcc_vit --gpu 0 --lr 
0.002 --log_results 41 | python cdan_mcc.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Cl2Pr --log_name Cl2Pr_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 42 | python cdan_mcc.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Cl2Rw --log_name Cl2Rw_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 43 | 44 | python cdan_mcc.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Pr2Ar --log_name Pr2Ar_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 45 | python cdan_mcc.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Pr2Cl --log_name Pr2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 46 | python cdan_mcc.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Pr2Rw --log_name Pr2Rw_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 47 | 48 | python cdan_mcc.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Rw2Ar --log_name Rw2Ar_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 49 | python cdan_mcc.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Rw2Cl --log_name Rw2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 50 | python cdan_mcc.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Rw2Pr --log_name Rw2Pr_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results 51 | 52 | #CDAN_MCC_SDAT (Office-Home-ViT) 53 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 54 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Pr --log_name Ar2Pr_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 55 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Rw --log_name Ar2Rw_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 56 | 57 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Cl2Ar --log_name Cl2Ar_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 58 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Cl2Pr --log_name Cl2Pr_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 59 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Cl2Rw --log_name Cl2Rw_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 60 | 61 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Pr 
-t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Pr2Ar --log_name Pr2Ar_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 62 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Pr2Cl --log_name Pr2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 63 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Pr2Rw --log_name Pr2Rw_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 64 | 65 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Rw2Ar --log_name Rw2Ar_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 66 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Rw2Cl --log_name Rw2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 67 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Rw2Pr --log_name Rw2Pr_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results 68 | -------------------------------------------------------------------------------- /examples/run_visda.sh: -------------------------------------------------------------------------------- 1 | #CDAN (VisDA2017-ViT) 2 | python cdan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.01 --per-class-eval --train-resizing cen.crop --log logs/cdan_vit/VisDA2017 --log_name visda_cdan_vit --gpu 0 --no-pool --log_results 3 | 4 | #CDAN_SDAT (VisDA2017-ViT) 5 | python cdan_sdat.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.01 --per-class-eval --train-resizing cen.crop --log logs/cdan_sdat_vit/VisDA2017 --log_name visda_cdan_sdat_vit --gpu 0 --no-pool --rho 0.005 --log_results 6 | 7 | #CDAN_MCC (VisDA2017-ViT) 8 | python cdan_mcc.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.002 --per-class-eval --train-resizing cen.crop --log logs/cdan_mcc_vit/VisDA2017 --log_name visda_cdan_mcc_vit --gpu 0 --no-pool --log_results 9 | 10 | #CDAN_MCC_SDAT (VisDA2017-ViT) 11 | python cdan_mcc_sdat.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.002 --per-class-eval --train-resizing cen.crop --log logs/cdan_mcc_sdat_vit/VisDA2017 --log_name visda_cdan_mcc_sdat_vit --gpu 0 --no-pool --rho 0.02 --log_results 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | import time 4 | import timm 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision.transforms as T 9 | from torch.utils.data import ConcatDataset 10 | import wandb 11 | import wilds 12 | 13 | sys.path.append('../') 14 | import common.vision.datasets as datasets 15 | import common.vision.models as models 16 | from common.vision.transforms import 
ResizeImage 17 | from common.utils.metric import accuracy, ConfusionMatrix 18 | from common.utils.meter import AverageMeter, ProgressMeter 19 | 20 | 21 | def get_model_names(): 22 | return sorted( 23 | name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name]) 26 | ) + timm.list_models() 27 | 28 | 29 | def get_model(model_name, pretrain=True): 30 | if model_name in models.__dict__: 31 | # load models from common.vision.models 32 | backbone = models.__dict__[model_name](pretrained=pretrain) 33 | else: 34 | # load models from pytorch-image-models 35 | backbone = timm.create_model(model_name, pretrained=pretrain) 36 | try: 37 | #backbone.out_features = backbone.get_classifier().in_features 38 | backbone.out_features = 768 39 | backbone.reset_classifier(0, '') 40 | except: 41 | backbone.out_features = backbone.head.in_features 42 | backbone.head = nn.Identity() 43 | return backbone 44 | 45 | 46 | def convert_from_wilds_dataset(wild_dataset): 47 | class Dataset: 48 | def __init__(self): 49 | self.dataset = wild_dataset 50 | 51 | def __getitem__(self, idx): 52 | x, y, metadata = self.dataset[idx] 53 | return x, y 54 | 55 | def __len__(self): 56 | return len(self.dataset) 57 | 58 | return Dataset() 59 | 60 | 61 | def get_dataset_names(): 62 | return sorted( 63 | name for name in datasets.__dict__ 64 | if not name.startswith("__") and callable(datasets.__dict__[name]) 65 | ) + wilds.supported_datasets + ['Digits'] 66 | 67 | 68 | def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None): 69 | if train_target_transform is None: 70 | train_target_transform = train_source_transform 71 | if dataset_name == "Digits": 72 | train_source_dataset = datasets.__dict__[source[0]](osp.join(root, source[0]), download=True, 73 | transform=train_source_transform) 74 | train_target_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), download=True, 75 | transform=train_target_transform) 76 | val_dataset = test_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), split='test', 77 | download=True, transform=val_transform) 78 | class_names = datasets.MNIST.get_classes() 79 | num_classes = len(class_names) 80 | elif dataset_name in datasets.__dict__: 81 | # load datasets from common.vision.datasets 82 | dataset = datasets.__dict__[dataset_name] 83 | 84 | def concat_dataset(tasks, **kwargs): 85 | return ConcatDataset([dataset(task=task, **kwargs) for task in tasks]) 86 | 87 | train_source_dataset = concat_dataset(root=root, tasks=source, download=True, transform=train_source_transform) 88 | train_target_dataset = concat_dataset(root=root, tasks=target, download=True, transform=train_target_transform) 89 | val_dataset = concat_dataset(root=root, tasks=target, download=True, transform=val_transform) 90 | if dataset_name == 'DomainNet': 91 | test_dataset = concat_dataset(root=root, tasks=target, split='test', download=True, transform=val_transform) 92 | else: 93 | test_dataset = val_dataset 94 | class_names = train_source_dataset.datasets[0].classes 95 | num_classes = len(class_names) 96 | else: 97 | # load datasets from wilds 98 | dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True) 99 | num_classes = dataset.n_classes 100 | class_names = None 101 | train_source_dataset = convert_from_wilds_dataset(dataset.get_subset('train', transform=train_source_transform)) 102 | train_target_dataset = convert_from_wilds_dataset(dataset.get_subset('test', 
transform=train_target_transform)) 103 | val_dataset = test_dataset = convert_from_wilds_dataset(dataset.get_subset('test', transform=val_transform)) 104 | return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names 105 | 106 | 107 | def validate(val_loader, model, args, device) -> float: 108 | batch_time = AverageMeter('Time', ':6.3f') 109 | losses = AverageMeter('Loss', ':.4e') 110 | top1 = AverageMeter('Acc@1', ':6.2f') 111 | progress = ProgressMeter( 112 | len(val_loader), 113 | [batch_time, losses, top1], 114 | prefix='Test: ') 115 | 116 | # switch to evaluate mode 117 | model.eval() 118 | if args.per_class_eval: 119 | confmat = ConfusionMatrix(len(args.class_names)) 120 | else: 121 | confmat = None 122 | 123 | with torch.no_grad(): 124 | end = time.time() 125 | for i, (images, target) in enumerate(val_loader): 126 | images = images.to(device) 127 | target = target.to(device) 128 | 129 | # compute output 130 | output = model(images) 131 | loss = F.cross_entropy(output, target) 132 | 133 | # measure accuracy and record loss 134 | acc1, = accuracy(output, target, topk=(1,)) 135 | if confmat: 136 | confmat.update(target, output.argmax(1)) 137 | losses.update(loss.item(), images.size(0)) 138 | top1.update(acc1.item(), images.size(0)) 139 | 140 | # measure elapsed time 141 | batch_time.update(time.time() - end) 142 | end = time.time() 143 | 144 | if i % args.print_freq == 0: 145 | progress.display(i) 146 | 147 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) 148 | if confmat: 149 | print(confmat.format(args.class_names)) 150 | 151 | return top1.avg 152 | 153 | 154 | def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False, 155 | resize_size=224, norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)): 156 | """ 157 | resizing mode: 158 | - default: resize the image to 256 and take a random resized crop of size 224; 159 | - cen.crop: resize the image to 256 and take the center crop of size 224; 160 | - res: resize the image to 224; 161 | """ 162 | if resizing == 'default': 163 | transform = T.Compose([ 164 | ResizeImage(256), 165 | T.RandomResizedCrop(224) 166 | ]) 167 | elif resizing == 'cen.crop': 168 | transform = T.Compose([ 169 | ResizeImage(256), 170 | T.CenterCrop(224) 171 | ]) 172 | elif resizing == 'ran.crop': 173 | transform = T.Compose([ 174 | ResizeImage(256), 175 | T.RandomCrop(224) 176 | ]) 177 | elif resizing == 'res.': 178 | transform = ResizeImage(resize_size) 179 | else: 180 | raise NotImplementedError(resizing) 181 | transforms = [transform] 182 | if random_horizontal_flip: 183 | transforms.append(T.RandomHorizontalFlip()) 184 | if random_color_jitter: 185 | transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)) 186 | transforms.extend([ 187 | T.ToTensor(), 188 | T.Normalize(mean=norm_mean, std=norm_std) 189 | ]) 190 | return T.Compose(transforms) 191 | 192 | 193 | def get_val_transform(resizing='default', resize_size=224, 194 | norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)): 195 | """ 196 | resizing mode: 197 | - default: resize the image to 256 and take the center crop of size 224; 198 | – res.: resize the image to 224 199 | """ 200 | if resizing == 'default': 201 | transform = T.Compose([ 202 | ResizeImage(256), 203 | T.CenterCrop(224), 204 | ]) 205 | elif resizing == 'res.': 206 | transform = ResizeImage(resize_size) 207 | else: 208 | raise NotImplementedError(resizing) 209 | return T.Compose([ 210 | transform, 211 | 
T.ToTensor(), 212 | T.Normalize(mean=norm_mean, std=norm_std) 213 | ]) 214 | 215 | 216 | def pretrain(train_source_iter, model, optimizer, lr_scheduler, epoch, args, device): 217 | batch_time = AverageMeter('Time', ':3.1f') 218 | data_time = AverageMeter('Data', ':3.1f') 219 | losses = AverageMeter('Loss', ':3.2f') 220 | cls_accs = AverageMeter('Cls Acc', ':3.1f') 221 | 222 | progress = ProgressMeter( 223 | args.iters_per_epoch, 224 | [batch_time, data_time, losses, cls_accs], 225 | prefix="Epoch: [{}]".format(epoch)) 226 | 227 | # switch to train mode 228 | model.train() 229 | 230 | end = time.time() 231 | for i in range(args.iters_per_epoch): 232 | x_s, labels_s = next(train_source_iter) 233 | x_s = x_s.to(device) 234 | labels_s = labels_s.to(device) 235 | 236 | # measure data loading time 237 | data_time.update(time.time() - end) 238 | 239 | # compute output 240 | y_s, f_s = model(x_s) 241 | 242 | cls_loss = F.cross_entropy(y_s, labels_s) 243 | loss = cls_loss 244 | 245 | cls_acc = accuracy(y_s, labels_s)[0] 246 | if args.log_results: 247 | wandb.log({'iteration':epoch*args.iters_per_epoch + i, 'loss':loss}) 248 | 249 | losses.update(loss.item(), x_s.size(0)) 250 | cls_accs.update(cls_acc.item(), x_s.size(0)) 251 | 252 | # compute gradient and do SGD step 253 | optimizer.zero_grad() 254 | loss.backward() 255 | optimizer.step() 256 | lr_scheduler.step() 257 | 258 | # measure elapsed time 259 | batch_time.update(time.time() - end) 260 | end = time.time() 261 | 262 | if i % args.print_freq == 0: 263 | progress.display(i) 264 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.1 2 | torchvision==0.10.1 3 | wandb==0.12.2 4 | timm==0.5.5 5 | prettytable==2.2.0 6 | --------------------------------------------------------------------------------
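For reference, a minimal sketch of the two-step SAM update that cdan_sdat.py and cdan_mcc_sdat.py rely on, shown on a toy model. The constructor arguments and the first_step/second_step calls mirror how the training scripts use the wrapper in common/utils/sam.py; the toy model, data, and hyper-parameter values are illustrative placeholders, not part of the repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

from common.utils.sam import SAM

model = nn.Linear(10, 3)  # stand-in for the ImageClassifier used in the examples
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, adaptive=False,
                lr=0.01, momentum=0.9, weight_decay=1e-3, nesterov=True)

x = torch.randn(8, 10)
y = torch.randint(0, 3, (8,))

# First pass: the gradient of the task loss defines the ascent direction.
F.cross_entropy(model(x), y).backward()
optimizer.first_step(zero_grad=True)   # perturb weights to w + eps(w)

# Second pass: the loss at the perturbed weights drives the actual update.
F.cross_entropy(model(x), y).backward()
optimizer.second_step(zero_grad=True)  # restore w and apply the base SGD step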