├── LICENSE
├── README.md
├── assets
│   └── sdat.png
├── common
│   ├── __init__.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── classifier.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── analysis
│   │   │   ├── __init__.py
│   │   │   ├── a_distance.py
│   │   │   └── tsne.py
│   │   ├── data.py
│   │   ├── logger.py
│   │   ├── meter.py
│   │   ├── metric
│   │   │   └── __init__.py
│   │   ├── sam.py
│   │   └── scheduler.py
│   └── vision
│       ├── datasets
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   ├── __init__.cpython-38.pyc
│       │   │   ├── __init__.cpython-39.pyc
│       │   │   ├── _util.cpython-36.pyc
│       │   │   ├── _util.cpython-38.pyc
│       │   │   ├── _util.cpython-39.pyc
│       │   │   ├── aircrafts.cpython-36.pyc
│       │   │   ├── aircrafts.cpython-38.pyc
│       │   │   ├── aircrafts.cpython-39.pyc
│       │   │   ├── coco70.cpython-36.pyc
│       │   │   ├── coco70.cpython-38.pyc
│       │   │   ├── coco70.cpython-39.pyc
│       │   │   ├── cub200.cpython-36.pyc
│       │   │   ├── cub200.cpython-38.pyc
│       │   │   ├── cub200.cpython-39.pyc
│       │   │   ├── digits.cpython-36.pyc
│       │   │   ├── digits.cpython-38.pyc
│       │   │   ├── digits.cpython-39.pyc
│       │   │   ├── domainnet.cpython-36.pyc
│       │   │   ├── domainnet.cpython-38.pyc
│       │   │   ├── domainnet.cpython-39.pyc
│       │   │   ├── dtd.cpython-36.pyc
│       │   │   ├── dtd.cpython-38.pyc
│       │   │   ├── dtd.cpython-39.pyc
│       │   │   ├── eurosat.cpython-36.pyc
│       │   │   ├── eurosat.cpython-38.pyc
│       │   │   ├── eurosat.cpython-39.pyc
│       │   │   ├── imagelist.cpython-36.pyc
│       │   │   ├── imagelist.cpython-38.pyc
│       │   │   ├── imagelist.cpython-39.pyc
│       │   │   ├── imagenet_r.cpython-36.pyc
│       │   │   ├── imagenet_r.cpython-38.pyc
│       │   │   ├── imagenet_r.cpython-39.pyc
│       │   │   ├── imagenet_sketch.cpython-36.pyc
│       │   │   ├── imagenet_sketch.cpython-38.pyc
│       │   │   ├── imagenet_sketch.cpython-39.pyc
│       │   │   ├── office31.cpython-36.pyc
│       │   │   ├── office31.cpython-38.pyc
│       │   │   ├── office31.cpython-39.pyc
│       │   │   ├── officecaltech.cpython-36.pyc
│       │   │   ├── officecaltech.cpython-38.pyc
│       │   │   ├── officecaltech.cpython-39.pyc
│       │   │   ├── officehome.cpython-36.pyc
│       │   │   ├── officehome.cpython-38.pyc
│       │   │   ├── officehome.cpython-39.pyc
│       │   │   ├── oxfordflowers.cpython-36.pyc
│       │   │   ├── oxfordflowers.cpython-38.pyc
│       │   │   ├── oxfordflowers.cpython-39.pyc
│       │   │   ├── oxfordpet.cpython-36.pyc
│       │   │   ├── oxfordpet.cpython-38.pyc
│       │   │   ├── oxfordpet.cpython-39.pyc
│       │   │   ├── pacs.cpython-36.pyc
│       │   │   ├── pacs.cpython-38.pyc
│       │   │   ├── pacs.cpython-39.pyc
│       │   │   ├── patchcamelyon.cpython-36.pyc
│       │   │   ├── patchcamelyon.cpython-38.pyc
│       │   │   ├── patchcamelyon.cpython-39.pyc
│       │   │   ├── resisc45.cpython-36.pyc
│       │   │   ├── resisc45.cpython-38.pyc
│       │   │   ├── resisc45.cpython-39.pyc
│       │   │   ├── retinopathy.cpython-36.pyc
│       │   │   ├── retinopathy.cpython-38.pyc
│       │   │   ├── retinopathy.cpython-39.pyc
│       │   │   ├── stanford_cars.cpython-36.pyc
│       │   │   ├── stanford_cars.cpython-38.pyc
│       │   │   ├── stanford_cars.cpython-39.pyc
│       │   │   ├── stanford_dogs.cpython-36.pyc
│       │   │   ├── stanford_dogs.cpython-38.pyc
│       │   │   ├── stanford_dogs.cpython-39.pyc
│       │   │   ├── visda2017.cpython-36.pyc
│       │   │   ├── visda2017.cpython-38.pyc
│       │   │   └── visda2017.cpython-39.pyc
│       │   ├── _util.py
│       │   ├── domainnet.py
│       │   ├── imagelist.py
│       │   ├── officehome.py
│       │   └── visda2017.py
│       ├── models
│       │   ├── __init__.py
│       │   └── resnet.py
│       └── transforms
│           └── __init__.py
├── dalib
│   ├── adaptation
│   │   ├── __init__.py
│   │   ├── cdan.py
│   │   └── mcc.py
│   └── modules
│       ├── __init__.py
│       ├── domain_discriminator.py
│       ├── entropy.py
│       ├── gl.py
│       ├── grl.py
│       └── kernels.py
├── examples
│   ├── cdan.py
│   ├── cdan_mcc.py
│   ├── cdan_mcc_sdat.py
│   ├── cdan_sdat.py
│   ├── eval.py
│   ├── run_office_home.sh
│   ├── run_visda.sh
│   └── utils.py
└── requirements.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Video Analytics Lab -- IISc
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Smooth Domain Adversarial Training
2 |
3 | **Harsh Rangwani\*, Sumukh K Aithal\*, Mayank Mishra, Arihant Jain, R. Venkatesh Babu**
4 |
5 |
6 |
7 | This is the official PyTorch implementation for our ICML'22 paper: **A Closer Look at Smoothness in Domain Adversarial Training**.[[`Paper`](https://arxiv.org/abs/2206.08213)]
8 |
9 |
10 | [PapersWithCode: Domain Adaptation on Office-Home](https://paperswithcode.com/sota/domain-adaptation-on-office-home?p=a-closer-look-at-smoothness-in-domain-1) [PapersWithCode: Domain Adaptation on VisDA-2017](https://paperswithcode.com/sota/domain-adaptation-on-visda2017?p=a-closer-look-at-smoothness-in-domain-1)
11 |
12 | ## Introduction
13 |
14 | 
15 |
16 |
17 | In recent times, methods converging to smooth optima have shown improved generalization for supervised learning tasks like classification. In this work, we analyze the effect of smoothness-enhancing formulations on domain adversarial training, whose objective is a combination of a task loss (e.g., classification, regression) and adversarial terms. We find that converging to a smooth minimum with respect to (w.r.t.) the task loss stabilizes adversarial training, leading to better performance on the target domain. In contrast, our analysis shows that converging to a smooth minimum w.r.t. the adversarial loss leads to sub-optimal generalization on the target domain. Based on this analysis, we introduce the Smooth Domain Adversarial Training (SDAT) procedure, which effectively enhances the performance of existing domain adversarial methods for both classification and object detection tasks.
18 |
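Stated informally (a paraphrase, not the paper's exact formulation or notation), SDAT applies the sharpness-aware inner maximization only to the task loss, while the adversarial term is left unsmoothed:

$$
\min_{\theta}\;\Big[\max_{\lVert \epsilon \rVert_2 \le \rho} \mathcal{L}_{\text{task}}(\theta + \epsilon)\Big] + \mathcal{L}_{\text{adv}}(\theta)
$$

This is what the two-pass update in the snippet under *DAT Based Method w/ SDAT* below implements: the perturbation ϵ̂ is computed from the task-loss gradient alone, and the combined task + domain loss is then evaluated at the perturbed weights.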
19 | **TLDR:** A few lines of code changes are enough to improve your adversarial domain adaptation algorithm by converting it to its smooth variant.
20 |
21 | ### Why use SDAT?
22 | - Can be combined with any DAT algorithm.
23 | - Easy to integrate with a few lines of code.
24 | - Leads to a significant improvement in accuracy on the target domain.
25 |
44 |
45 | #### DAT Based Method w/ SDAT
46 | We provide the details of the changes required to convert any DAT algorithm (e.g., CDAN, DANN, CDAN+MCC) to its smooth (SDAT) version.
47 |
48 | ```python
49 | optimizer = SAM(classifier.get_parameters(), torch.optim.SGD, rho=args.rho, adaptive=False,
50 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
51 | # optimizer refers to the Smooth optimizer which contains parameters of the feature extractor and classifier.
52 | optimizer.zero_grad()
53 | # ad_optimizer refers to standard SGD optimizer which contains parameters of domain classifier.
54 | ad_optimizer.zero_grad()
55 |
56 | # Calculate task loss
57 | class_prediction, feature = model(x)
58 | task_loss = task_loss_fn(class_prediction, label)
59 | task_loss.backward()
60 |
61 | # Calculate ϵ̂ (w) and add it to the weights
62 | optimizer.first_step()
63 |
64 | # Calculate task loss and domain loss
65 | class_prediction, feature = model(x)
66 | task_loss = task_loss_fn(class_prediction, label)
67 | domain_loss = domain_classifier(feature)
68 | loss = task_loss + domain_loss
69 | loss.backward()
70 |
71 | # Update parameters (Sharpness-Aware update)
72 | optimizer.step()
73 | # Update parameters of domain classifier
74 | ad_optimizer.step()
75 | ```
76 |
77 | ## Getting started
78 |
79 | * ### Requirements
80 |
81 | - pytorch 1.9.1
82 | - torchvision 0.10.1
83 | - wandb 0.12.2
84 | - timm 0.5.5
85 | - prettytable 2.2.0
86 | - scikit-learn
87 |
88 | * ### Installation
89 | ```
90 | git clone https://github.com/val-iisc/SDAT.git
91 | cd SDAT
92 | pip install -r requirements.txt
93 | ```
94 | We use Weights and Biases ([wandb](https://wandb.ai/site)) to track our experiments and results. To track your experiments with wandb, create a new project with your account. The ```project``` and ```entity``` arguments in ```wandb.init``` must be changed accordingly. To disable wandb tracking, the ```log_results``` flag can be used.
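For reference, a minimal sketch of this kind of change (the `project`, `entity`, and run `name` values below are placeholders, and the exact call site inside the training scripts may differ):

```python
import wandb

# Placeholders -- point these at your own wandb project and account/team.
run = wandb.init(project="sdat-experiments", entity="my-wandb-entity",
                 name="Ar2Cl_cdan_mcc_sdat_vit")
run.log({"target_acc": 0.0})  # metrics can then be logged against this run
run.finish()
```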
95 |
96 | * ### Datasets
97 | The datasets used in the repository can be downloaded from the following links:
98 |
99 | - [Office-Home](https://www.hemanthdv.org/officeHomeDataset.html)
- [VisDA-2017](https://github.com/VisionLearningGroup/taskcv-2017-public) (under classification track)
- [DomainNet](http://ai.bu.edu/M3SDA/)
100 |
101 | The datasets are automatically downloaded to the ```data/``` folder if they are not already available.
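Assuming the paths used in the sample commands later in this README (not an exhaustive listing), the dataset roots look like:

```
data/
├── office-home/
└── visda-2017/
```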
102 | ## Training
103 | We report our numbers primarily on two domain adaptation methods: CDAN w/ SDAT and CDAN+MCC w/ SDAT. The training scripts can be found under the `examples` subdirectory.
104 |
105 | ### Domain Adversarial Training (DAT)
106 | To train using standard CDAN and CDAN+MCC, use the `cdan.py` and `cdan_mcc.py` files, respectively. A sample command to train the aforementioned methods with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below.
107 | ```
108 | python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
109 | ```
110 |
111 | ### Smooth Domain Adversarial Training (SDAT)
112 |
113 | To train using our proposed CDAN w/ SDAT and CDAN+MCC w/ SDAT, use the `cdan_sdat.py` and `cdan_mcc_sdat.py` files, respectively.
114 |
115 | A sample command to run CDAN+MCC w/ SDAT with a ViT B-16 backbone on the Office-Home dataset (with Art as the source domain and Clipart as the target domain) is given below.
116 | ```
117 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
118 | ```
119 | Additional commands to reproduce the results can be found in `run_office_home.sh` and `run_visda.sh` under `examples`.
120 |
121 | ### Results
122 | The following table reports the accuracy across the various splits of the Office-Home and VisDA-2017 datasets using CDAN+MCC w/ SDAT with a ViT B-16 backbone. We also provide downloadable weights for the corresponding pretrained classifiers.
123 |
124 |
125 | | Dataset | Source | Target | Accuracy | Checkpoints |
126 | | ----------- | ---------- | ---------- | -------- | ----------- |
127 | | Office-Home | Art | Clipart | 70.8 | ckpt |
128 | | Office-Home | Art | Product | 80.7 | ckpt |
129 | | Office-Home | Art | Real World | 90.5 | ckpt |
130 | | Office-Home | Clipart | Art | 85.2 | ckpt |
131 | | Office-Home | Clipart | Product | 87.3 | ckpt |
132 | | Office-Home | Clipart | Real World | 89.7 | ckpt |
133 | | Office-Home | Product | Art | 84.1 | ckpt |
134 | | Office-Home | Product | Clipart | 70.7 | ckpt |
135 | | Office-Home | Product | Real World | 90.6 | ckpt |
136 | | Office-Home | Real World | Art | 88.3 | ckpt |
137 | | Office-Home | Real World | Clipart | 75.5 | ckpt |
138 | | Office-Home | Real World | Product | 92.1 | ckpt |
139 | | VisDA-2017 | Synthetic | Real | 89.8 | ckpt |
140 |
220 | ### Evaluation
221 | To evaluate a classifier with pretrained weights, use `eval.py` under `examples`. Set the `--weight_path` argument to the path of the weights to be evaluated.
222 |
223 | A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on Office-Home (with Art as the source domain and Clipart as the target domain) is given below.
224 | ```
225 | python eval.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 -b 24 --no-pool --weight_path path_to_weight.pth --log_name Ar2Cl_cdan_mcc_sdat_vit_eval --gpu 0 --phase test
226 | ```
227 | A sample run to evaluate the pretrained ViT B-16 with CDAN+MCC w/ SDAT on VisDA-2017 (with Synthetic as the source domain and Real as the target domain) is given below.
228 |
229 | ```
230 | python eval.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --per-class-eval --train-resizing cen.crop --weight_path path_to_weight.pth --log_name visda_cdan_mcc_sdat_vit_eval --gpu 0 --no-pool --phase test
231 | ```
232 |
233 |
234 | ## Overview of the arguments
235 | Generally, all scripts in the project take the following flags:
236 | - `-a`: Architecture of the backbone. (resnet50|vit_base_patch16_224)
237 | - `-d`: Dataset (OfficeHome|DomainNet)
238 | - `-s`: Source Domain
239 | - `-t`: Target Domain
240 | - `--epochs`: Number of Epochs to be trained for.
241 | - `--no-pool`: Use `--no-pool` for all experiments with a ViT backbone.
242 | - `--log_name`: Name of the run on wandb.
243 | - `--gpu`: GPU id to use.
244 | - `--rho`: $\rho$ value in SDAT (Applicable only for SDAT runs).
245 |
246 | ## Acknowledgement
247 | Our implementation is based on the [Transfer Learning Library](https://github.com/thuml/Transfer-Learning-Library). We use the PyTorch implementation of SAM from https://github.com/davda54/sam.
248 | ## Citation
249 | If you find our paper or codebase useful, please consider citing us as:
250 | ```latex
251 | @InProceedings{rangwani2022closer,
252 | title={A Closer Look at Smoothness in Domain Adversarial Training},
253 | author={Rangwani, Harsh and Aithal, Sumukh K and Mishra, Mayank and Jain, Arihant and Babu, R. Venkatesh},
254 | booktitle={Proceedings of the 39th International Conference on Machine Learning},
255 | year={2022}
256 | }
257 | ```
258 |
--------------------------------------------------------------------------------
/assets/sdat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/assets/sdat.png
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['modules', 'utils', 'vision']
2 |
--------------------------------------------------------------------------------
/common/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .classifier import *
2 |
3 | __all__ = ['classifier']
4 |
--------------------------------------------------------------------------------
/common/modules/classifier.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Optional, List, Dict
2 | import torch.nn as nn
3 | import torch
4 |
5 | __all__ = ['Classifier']
6 |
7 |
8 | class Classifier(nn.Module):
9 | """A generic Classifier class for domain adaptation.
10 |
11 | Args:
12 | backbone (torch.nn.Module): Any backbone to extract 2-d features from data
13 | num_classes (int): Number of classes
14 | bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default
15 | bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1
16 | head (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default
17 | finetune (bool): Whether to finetune the classifier or train from scratch. Default: True
18 | freeze (bool): Freeze the backbone and only train the classifier
19 |
20 | .. note::
21 | Different classifiers are used in different domain adaptation algorithms to achieve better accuracy
22 | respectively, and we provide a suggested `Classifier` for different algorithms.
23 | Remember they are not the core of algorithms. You can implement your own `Classifier` and combine it with
24 | the domain adaptation algorithm in this algorithm library.
25 |
26 | .. note::
27 | The learning rate of this classifier is set to 10 times that of the feature extractor for better accuracy
28 | by default. If you have other optimization strategies, please override :meth:`~Classifier.get_parameters`.
29 |
30 | Inputs:
31 | - x (tensor): input data fed to `backbone`
32 |
33 | Outputs:
34 | - predictions: classifier's predictions
35 | - features: features after `bottleneck` layer and before `head` layer
36 |
37 | Shape:
38 | - Inputs: (minibatch, *) where * means any number of additional dimensions
39 | - predictions: (minibatch, `num_classes`)
40 | - features: (minibatch, `features_dim`)
41 |
42 | """
43 |
44 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,
45 | bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None, finetune=True, pool_layer=None):
46 | super(Classifier, self).__init__()
47 | self.backbone = backbone
48 | self.num_classes = num_classes
49 | if pool_layer is None:
50 | self.pool_layer = nn.Sequential(
51 | nn.AdaptiveAvgPool2d(output_size=(1, 1)),
52 | nn.Flatten()
53 | )
54 | else:
55 | self.pool_layer = pool_layer
56 | if bottleneck is None:
57 | self.bottleneck = nn.Identity()
58 | self._features_dim = backbone.out_features
59 | else:
60 | self.bottleneck = bottleneck
61 | print("[INFORMATION] The bottleneck dim is ", bottleneck_dim)
62 | assert bottleneck_dim > 0
63 | self._features_dim = bottleneck_dim
64 |
65 |
66 | if head is None:
67 | self.head = nn.Linear(self._features_dim, num_classes)
68 | else:
69 | self.head = head
70 | self.finetune = finetune
71 |
72 | @property
73 | def features_dim(self) -> int:
74 | """The dimension of features before the final `head` layer"""
75 | return self._features_dim
76 |
77 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
78 | """"""
79 | f = self.pool_layer(self.backbone(x))
80 | f = self.bottleneck(f)
81 | predictions = self.head(f)
82 | if self.training:
83 | return predictions, f
84 | else:
85 | return predictions
86 |
87 | def get_parameters(self, base_lr=1.0) -> List[Dict]:
88 | """A parameter list which decides optimization hyper-parameters,
89 | such as the relative learning rate of each layer
90 | """
91 | params = [
92 | {"params": self.backbone.parameters(), "lr": 0.1*base_lr if self.finetune else 1.0 * base_lr},
93 | {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
94 | {"params": self.head.parameters(), "lr": 1.0 * base_lr},
95 | ]
96 |
97 | return params
98 |
99 |
100 | class ImageClassifier(Classifier):
101 | pass
102 |
--------------------------------------------------------------------------------
/common/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import CompleteLogger
2 | from .meter import *
3 | from .data import ForeverDataIterator
4 |
5 | __all__ = ['metric', 'analysis', 'meter', 'data', 'logger']
--------------------------------------------------------------------------------
/common/utils/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import torch.nn as nn
4 | import tqdm
5 |
6 |
7 | def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module,
8 | device: torch.device, max_num_features=None) -> torch.Tensor:
9 | """
10 | Fetch data from `data_loader`, and then use `feature_extractor` to collect features
11 |
12 | Args:
13 | data_loader (torch.utils.data.DataLoader): Data loader.
14 | feature_extractor (torch.nn.Module): A feature extractor.
15 | device (torch.device)
16 | max_num_features (int): The max number of mini-batches to collect features from. Default: None (use all batches)
17 |
18 | Returns:
19 | Features in shape (min(len(data_loader), max_num_features) * mini-batch size, :math:`|\mathcal{F}|`).
20 | """
21 | feature_extractor.eval()
22 | all_features = []
23 | with torch.no_grad():
24 | for i, (images, target) in enumerate(tqdm.tqdm(data_loader)):
25 | if max_num_features is not None and i >= max_num_features:
26 | break
27 | images = images.to(device)
28 | feature = feature_extractor(images).cpu()
29 | all_features.append(feature)
30 | return torch.cat(all_features, dim=0)
31 |
--------------------------------------------------------------------------------
/common/utils/analysis/a_distance.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | from torch.utils.data import TensorDataset
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import DataLoader
10 | from torch.optim import SGD
11 | from ..meter import AverageMeter
12 | from ..metric import binary_accuracy
13 |
14 |
15 | class ANet(nn.Module):
16 | def __init__(self, in_feature):
17 | super(ANet, self).__init__()
18 | self.layer = nn.Linear(in_feature, 1)
19 | self.sigmoid = nn.Sigmoid()
20 |
21 | def forward(self, x):
22 | x = self.layer(x)
23 | x = self.sigmoid(x)
24 | return x
25 |
26 |
27 | def calculate(source_feature: torch.Tensor, target_feature: torch.Tensor,
28 | device, progress=True, training_epochs=10):
29 | """
30 | Calculate the :math:`\mathcal{A}`-distance, which is a measure for distribution discrepancy.
31 |
32 | The definition is :math:`dist_\mathcal{A} = 2 (1-2\epsilon)`, where :math:`\epsilon` is the
33 | test error of a classifier trained to discriminate the source from the target.
34 |
35 | Args:
36 | source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
37 | target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
38 | device (torch.device)
39 | progress (bool): if True, displays the progress of training the A-Net
40 | training_epochs (int): the number of epochs when training the classifier
41 |
42 | Returns:
43 | :math:`\mathcal{A}`-distance
44 | """
45 | source_label = torch.ones((source_feature.shape[0], 1))
46 | target_label = torch.zeros((target_feature.shape[0], 1))
47 | feature = torch.cat([source_feature, target_feature], dim=0)
48 | label = torch.cat([source_label, target_label], dim=0)
49 |
50 | dataset = TensorDataset(feature, label)
51 | length = len(dataset)
52 | train_size = int(0.8 * length)
53 | val_size = length - train_size
54 | train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])
55 | train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
56 | val_loader = DataLoader(val_set, batch_size=8, shuffle=False)
57 |
58 | anet = ANet(feature.shape[1]).to(device)
59 | optimizer = SGD(anet.parameters(), lr=0.01)
60 | a_distance = 2.0
61 | for epoch in range(training_epochs):
62 | anet.train()
63 | for (x, label) in train_loader:
64 | x = x.to(device)
65 | label = label.to(device)
66 | anet.zero_grad()
67 | y = anet(x)
68 | loss = F.binary_cross_entropy(y, label)
69 | loss.backward()
70 | optimizer.step()
71 |
72 | anet.eval()
73 | meter = AverageMeter("accuracy", ":4.2f")
74 | with torch.no_grad():
75 | for (x, label) in val_loader:
76 | x = x.to(device)
77 | label = label.to(device)
78 | y = anet(x)
79 | acc = binary_accuracy(y, label)
80 | meter.update(acc, x.shape[0])
81 | error = 1 - meter.avg / 100
82 | a_distance = 2 * (1 - 2 * error)
83 | if progress:
84 | print("epoch {} accuracy: {} A-dist: {}".format(epoch, meter.avg, a_distance))
85 |
86 | return a_distance
87 |
88 |
--------------------------------------------------------------------------------
/common/utils/analysis/tsne.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | import torch
6 | import matplotlib
7 |
8 | matplotlib.use('Agg')
9 | from sklearn.manifold import TSNE
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | import matplotlib.colors as col
13 |
14 |
15 | def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor,
16 | filename: str, source_color='r', target_color='b'):
17 | """
18 | Visualize features from different domains using t-SNE.
19 |
20 | Args:
21 | source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
22 | target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
23 | filename (str): the file name to save t-SNE
24 | source_color (str): the color of the source features. Default: 'r'
25 | target_color (str): the color of the target features. Default: 'b'
26 |
27 | """
28 | source_feature = source_feature.numpy()
29 | target_feature = target_feature.numpy()
30 | features = np.concatenate([source_feature, target_feature], axis=0)
31 |
32 | # map features to 2-d using TSNE
33 | X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)
34 |
35 | # domain labels, 1 represents source while 0 represents target
36 | domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature))))
37 |
38 | # visualize using matplotlib
39 | fig, ax = plt.subplots(figsize=(10, 10))
40 | ax.spines['top'].set_visible(False)
41 | ax.spines['right'].set_visible(False)
42 | ax.spines['bottom'].set_visible(False)
43 | ax.spines['left'].set_visible(False)
44 | plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap([target_color, source_color]), s=20)
45 | plt.xticks([])
46 | plt.yticks([])
47 | plt.savefig(filename)
48 |
--------------------------------------------------------------------------------
/common/utils/data.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import random
3 | import numpy as np
4 |
5 | import torch
6 | from torch.utils.data import Sampler
7 | from torch.utils.data import DataLoader, Dataset
8 | from typing import TypeVar, Iterable, Dict, List
9 |
10 | T_co = TypeVar('T_co', covariant=True)
11 | T = TypeVar('T')
12 |
13 |
14 | def send_to_device(tensor, device):
15 | """
16 | Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.
17 |
18 | Args:
19 | tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
20 | The data to send to a given device.
21 | device (:obj:`torch.device`):
22 | The device to send the data to
23 |
24 | Returns:
25 | The same data structure as :obj:`tensor` with all tensors sent to the proper device.
26 | """
27 | if isinstance(tensor, (list, tuple)):
28 | return type(tensor)(send_to_device(t, device) for t in tensor)
29 | elif isinstance(tensor, dict):
30 | return type(tensor)({k: send_to_device(v, device) for k, v in tensor.items()})
31 | elif not hasattr(tensor, "to"):
32 | return tensor
33 | return tensor.to(device)
34 |
35 |
36 | class ForeverDataIterator:
37 | r"""A data iterator that will never stop producing data"""
38 |
39 | def __init__(self, data_loader: DataLoader, device=None):
40 | self.data_loader = data_loader
41 | self.iter = iter(self.data_loader)
42 | self.device = device
43 |
44 | def __next__(self):
45 | try:
46 | data = next(self.iter)
47 | if self.device is not None:
48 | data = send_to_device(data, self.device)
49 | except StopIteration:
50 | self.iter = iter(self.data_loader)
51 | data = next(self.iter)
52 | if self.device is not None:
53 | data = send_to_device(data, self.device)
54 | return data
55 |
56 | def __len__(self):
57 | return len(self.data_loader)
58 |
59 |
60 | class RandomMultipleGallerySampler(Sampler):
61 | r"""Sampler from `In defense of the Triplet Loss for Person Re-Identification
62 | (ICCV 2017) `_. Assume there are :math:`N` identities in the dataset, this
63 | implementation simply samples :math:`K` images for every identity to form an iterator of size :math:`N\times K`. During
64 | training, we call the ``__iter__`` method of the pytorch dataloader once we reach a ``StopIteration``; this guarantees that
65 | every image in the dataset will eventually be selected and we are not wasting any training data.
66 |
67 | Args:
68 | dataset(list): each element of this list is a tuple (image_path, person_id, camera_id)
69 | num_instances(int, optional): number of images to sample for every identity (:math:`K` here)
70 | """
71 |
72 | def __init__(self, dataset, num_instances=4):
73 | super(RandomMultipleGallerySampler, self).__init__(dataset)
74 | self.dataset = dataset
75 | self.num_instances = num_instances
76 |
77 | self.idx_to_pid = {}
78 | self.cid_list_per_pid = {}
79 | self.idx_list_per_pid = {}
80 |
81 | for idx, (_, pid, cid) in enumerate(dataset):
82 | if pid not in self.cid_list_per_pid:
83 | self.cid_list_per_pid[pid] = []
84 | self.idx_list_per_pid[pid] = []
85 |
86 | self.idx_to_pid[idx] = pid
87 | self.cid_list_per_pid[pid].append(cid)
88 | self.idx_list_per_pid[pid].append(idx)
89 |
90 | self.pid_list = list(self.idx_list_per_pid.keys())
91 | self.num_samples = len(self.pid_list)
92 |
93 | def __len__(self):
94 | return self.num_samples * self.num_instances
95 |
96 | def __iter__(self):
97 | def select_idxes(element_list, target_element):
98 | assert isinstance(element_list, list)
99 | return [i for i, element in enumerate(element_list) if element != target_element]
100 |
101 | pid_idxes = torch.randperm(len(self.pid_list)).tolist()
102 | final_idxes = []
103 |
104 | for perm_id in pid_idxes:
105 | i = random.choice(self.idx_list_per_pid[self.pid_list[perm_id]])
106 | _, _, cid = self.dataset[i]
107 |
108 | final_idxes.append(i)
109 |
110 | pid_i = self.idx_to_pid[i]
111 | cid_list = self.cid_list_per_pid[pid_i]
112 | idx_list = self.idx_list_per_pid[pid_i]
113 | selected_cid_list = select_idxes(cid_list, cid)
114 |
115 | if selected_cid_list:
116 | if len(selected_cid_list) >= self.num_instances:
117 | cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=False)
118 | else:
119 | cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=True)
120 | for cid_idx in cid_idxes:
121 | final_idxes.append(idx_list[cid_idx])
122 | else:
123 | selected_idxes = select_idxes(idx_list, i)
124 | if not selected_idxes:
125 | continue
126 | if len(selected_idxes) >= self.num_instances:
127 | pid_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=False)
128 | else:
129 | pid_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=True)
130 |
131 | for pid_idx in pid_idxes:
132 | final_idxes.append(idx_list[pid_idx])
133 |
134 | return iter(final_idxes)
135 |
136 |
137 | class CombineDataset(Dataset[T_co]):
138 | r"""Dataset as a combination of multiple datasets.
139 | The element of each dataset must be a list, and the i-th element of the combined dataset
140 | is a list splicing of the i-th element of each sub dataset.
141 | The length of the combined dataset is the minimum of the lengths of all sub datasets.
142 |
143 | Arguments:
144 | datasets (sequence): List of datasets to be concatenated
145 | """
146 |
147 | def __init__(self, datasets: Iterable[Dataset]) -> None:
148 | super(CombineDataset, self).__init__()
149 | # Cannot verify that datasets is Sized
150 | assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore
151 | self.datasets = list(datasets)
152 |
153 | def __len__(self):
154 | return min([len(d) for d in self.datasets])
155 |
156 | def __getitem__(self, idx):
157 | return list(itertools.chain(*[d[idx] for d in self.datasets]))
158 |
159 |
160 | def concatenate(tensors):
161 | """concatenate multiple batches into one batch.
162 | ``tensors`` can be :class:`torch.Tensor`, List or Dict, but they must be the same data format.
163 | """
164 | if isinstance(tensors[0], torch.Tensor):
165 | return torch.cat(tensors, dim=0)
166 | elif isinstance(tensors[0], List):
167 | ret = []
168 | for i in range(len(tensors[0])):
169 | ret.append(concatenate([t[i] for t in tensors]))
170 | return ret
171 | elif isinstance(tensors[0], Dict):
172 | ret = dict()
173 | for k in tensors[0].keys():
174 | ret[k] = concatenate([t[k] for t in tensors])
175 | return ret
176 |
--------------------------------------------------------------------------------
/common/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 |
5 | class TextLogger(object):
6 | """Writes stream output to external text file.
7 |
8 | Args:
9 | filename (str): the file to write stream output
10 | stream: the stream to read from. Default: sys.stdout
11 | """
12 | def __init__(self, filename, stream=sys.stdout):
13 | self.terminal = stream
14 | self.log = open(filename, 'a')
15 |
16 | def write(self, message):
17 | self.terminal.write(message)
18 | self.log.write(message)
19 | self.flush()
20 |
21 | def flush(self):
22 | self.terminal.flush()
23 | self.log.flush()
24 |
25 | def close(self):
26 | self.terminal.close()
27 | self.log.close()
28 |
29 |
30 | class CompleteLogger:
31 | """
32 | A useful logger that
33 |
34 | - writes outputs to files and displays them on the console at the same time.
35 | - manages the directory of checkpoints and debugging images.
36 |
37 | Args:
38 | root (str): the root directory of logger
39 | phase (str): the phase of training.
40 |
41 | """
42 |
43 | def __init__(self, root, phase='train'):
44 | self.root = root
45 | self.phase = phase
46 | self.visualize_directory = os.path.join(self.root, "visualize")
47 | self.checkpoint_directory = os.path.join(self.root, "checkpoints")
48 | self.epoch = 0
49 |
50 | os.makedirs(self.root, exist_ok=True)
51 | os.makedirs(self.visualize_directory, exist_ok=True)
52 | os.makedirs(self.checkpoint_directory, exist_ok=True)
53 |
54 | # redirect std out
55 | now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
56 | log_filename = os.path.join(self.root, "{}-{}.txt".format(phase, now))
57 | if os.path.exists(log_filename):
58 | os.remove(log_filename)
59 | self.logger = TextLogger(log_filename)
60 | sys.stdout = self.logger
61 | sys.stderr = self.logger
62 | if phase != 'train':
63 | self.set_epoch(phase)
64 |
65 | def set_epoch(self, epoch):
66 | """Set the epoch number. Please use it during training."""
67 | os.makedirs(os.path.join(self.visualize_directory, str(epoch)), exist_ok=True)
68 | self.epoch = epoch
69 |
70 | def _get_phase_or_epoch(self):
71 | if self.phase == 'train':
72 | return str(self.epoch)
73 | else:
74 | return self.phase
75 |
76 | def get_image_path(self, filename: str):
77 | """
78 | Get the full image path for a specific filename
79 | """
80 | return os.path.join(self.visualize_directory, self._get_phase_or_epoch(), filename)
81 |
82 | def get_checkpoint_path(self, name=None):
83 | """
84 | Get the full checkpoint path.
85 |
86 | Args:
87 | name (optional): the filename (without file extension) to save checkpoint.
88 | If None, when the phase is ``train``, checkpoint will be saved to ``{epoch}.pth``.
89 | Otherwise, will be saved to ``{phase}.pth``.
90 |
91 | """
92 | if name is None:
93 | name = self._get_phase_or_epoch()
94 | name = str(name)
95 | return os.path.join(self.checkpoint_directory, name + ".pth")
96 |
97 | def close(self):
98 | self.logger.close()
99 |
--------------------------------------------------------------------------------
/common/utils/meter.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 |
4 | class AverageMeter(object):
5 | r"""Computes and stores the average and current value.
6 |
7 | Examples::
8 |
9 | >>> # Initialize a meter to record loss
10 | >>> losses = AverageMeter()
11 | >>> # Update meter after every minibatch update
12 | >>> losses.update(loss_value, batch_size)
13 | """
14 | def __init__(self, name: str, fmt: Optional[str] = ':f'):
15 | self.name = name
16 | self.fmt = fmt
17 | self.reset()
18 |
19 | def reset(self):
20 | self.val = 0
21 | self.avg = 0
22 | self.sum = 0
23 | self.count = 0
24 |
25 | def update(self, val, n=1):
26 | self.val = val
27 | self.sum += val * n
28 | self.count += n
29 | if self.count > 0:
30 | self.avg = self.sum / self.count
31 |
32 | def __str__(self):
33 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
34 | return fmtstr.format(**self.__dict__)
35 |
36 |
37 | class AverageMeterDict(object):
38 | def __init__(self, names: List, fmt: Optional[str] = ':f'):
39 | self.dict = {
40 | name: AverageMeter(name, fmt) for name in names
41 | }
42 |
43 | def reset(self):
44 | for meter in self.dict.values():
45 | meter.reset()
46 |
47 | def update(self, accuracies, n=1):
48 | for name, acc in accuracies.items():
49 | self.dict[name].update(acc, n)
50 |
51 | def average(self):
52 | return {
53 | name: meter.avg for name, meter in self.dict.items()
54 | }
55 |
56 | def __getitem__(self, item):
57 | return self.dict[item]
58 |
59 |
60 | class Meter(object):
61 | """Computes and stores the current value."""
62 | def __init__(self, name: str, fmt: Optional[str] = ':f'):
63 | self.name = name
64 | self.fmt = fmt
65 | self.reset()
66 |
67 | def reset(self):
68 | self.val = 0
69 |
70 | def update(self, val):
71 | self.val = val
72 |
73 | def __str__(self):
74 | fmtstr = '{name} {val' + self.fmt + '}'
75 | return fmtstr.format(**self.__dict__)
76 |
77 |
78 | class ProgressMeter(object):
79 | def __init__(self, num_batches, meters, prefix=""):
80 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
81 | self.meters = meters
82 | self.prefix = prefix
83 |
84 | def display(self, batch):
85 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
86 | entries += [str(meter) for meter in self.meters]
87 | print('\t'.join(entries))
88 |
89 | def _get_batch_fmtstr(self, num_batches):
90 | num_digits = len(str(num_batches // 1))
91 | fmt = '{:' + str(num_digits) + 'd}'
92 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
93 |
94 |
95 |
--------------------------------------------------------------------------------
/common/utils/metric/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import prettytable
3 |
4 | __all__ = ['binary_accuracy', 'accuracy', 'ConfusionMatrix']
5 |
6 | def binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
7 | """Computes the accuracy for binary classification"""
8 | with torch.no_grad():
9 | batch_size = target.size(0)
10 | pred = (output >= 0.5).float().t().view(-1)
11 | correct = pred.eq(target.view(-1)).float().sum()
12 | correct.mul_(100. / batch_size)
13 | return correct
14 |
15 |
16 | def accuracy(output, target, topk=(1,)):
17 | r"""
18 | Computes the accuracy over the k top predictions for the specified values of k
19 |
20 | Args:
21 | output (tensor): Classification outputs, :math:`(N, C)` where `C = number of classes`
22 | target (tensor): :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`
23 | topk (sequence[int]): A list of top-N numbers.
24 |
25 | Returns:
26 | Top-N accuracies (N :math:`\in` topK).
27 | """
28 | with torch.no_grad():
29 | maxk = max(topk)
30 | batch_size = target.size(0)
31 |
32 | _, pred = output.topk(maxk, 1, True, True)
33 | pred = pred.t()
34 | correct = pred.eq(target[None])
35 |
36 | res = []
37 | for k in topk:
38 | correct_k = correct[:k].flatten().sum(dtype=torch.float32)
39 | res.append(correct_k * (100.0 / batch_size))
40 | return res
41 |
42 |
43 | class ConfusionMatrix(object):
44 | def __init__(self, num_classes):
45 | self.num_classes = num_classes
46 | self.mat = None
47 |
48 | def update(self, target, output):
49 | """
50 | Update confusion matrix.
51 |
52 | Args:
53 | target: ground truth
54 | output: predictions of models
55 |
56 | Shape:
57 | - target: :math:`(minibatch,)` where each value is a ground-truth class index.
58 | - output: :math:`(minibatch,)` where each value is a predicted class index.
59 | """
60 | n = self.num_classes
61 | if self.mat is None:
62 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
63 | with torch.no_grad():
64 | k = (target >= 0) & (target < n)
65 | inds = n * target[k].to(torch.int64) + output[k]
66 | self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
67 |
68 | def reset(self):
69 | self.mat.zero_()
70 |
71 | def compute(self):
72 | """compute global accuracy, per-class accuracy and per-class IoU"""
73 | h = self.mat.float()
74 | acc_global = torch.diag(h).sum() / h.sum()
75 | acc = torch.diag(h) / h.sum(1)
76 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
77 | return acc_global, acc, iu
78 |
79 | # def reduce_from_all_processes(self):
80 | # if not torch.distributed.is_available():
81 | # return
82 | # if not torch.distributed.is_initialized():
83 | # return
84 | # torch.distributed.barrier()
85 | # torch.distributed.all_reduce(self.mat)
86 |
87 | def __str__(self):
88 | acc_global, acc, iu = self.compute()
89 | return (
90 | 'global correct: {:.1f}\n'
91 | 'average row correct: {}\n'
92 | 'IoU: {}\n'
93 | 'mean IoU: {:.1f}').format(
94 | acc_global.item() * 100,
95 | ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
96 | ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
97 | iu.mean().item() * 100)
98 |
99 | def format(self, classes: list):
100 | """Get the accuracy and IoU for each class in the table format"""
101 | acc_global, acc, iu = self.compute()
102 |
103 | table = prettytable.PrettyTable(["class", "acc", "iou"])
104 | for i, class_name, per_acc, per_iu in zip(range(len(classes)), classes, (acc * 100).tolist(), (iu * 100).tolist()):
105 | table.add_row([class_name, per_acc, per_iu])
106 |
107 | return 'global correct: {:.1f}\nmean correct:{:.1f}\nmean IoU: {:.1f}\n{}'.format(
108 | acc_global.item() * 100, acc.mean().item() * 100, iu.mean().item() * 100, table.get_string())
109 |
110 |
--------------------------------------------------------------------------------
/common/utils/sam.py:
--------------------------------------------------------------------------------
1 | # Credits: https://github.com/davda54/sam
2 |
3 | import torch
4 |
5 | class SAM(torch.optim.Optimizer):
6 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
7 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
8 |
9 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
10 | super(SAM, self).__init__(params, defaults)
11 |
12 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
13 | self.param_groups = self.base_optimizer.param_groups
14 |
15 | @torch.no_grad()
16 | def first_step(self, zero_grad=False):
17 | grad_norm = self._grad_norm()
18 | for group in self.param_groups:
19 | scale = group["rho"] / (grad_norm + 1e-12)
20 |
21 | for p in group["params"]:
22 | if p.grad is None: continue
23 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
24 | p.add_(e_w) # climb to the local maximum "w + e(w)"
25 | self.state[p]["e_w"] = e_w
26 |
27 | if zero_grad: self.zero_grad()
28 |
29 | @torch.no_grad()
30 | def second_step(self, zero_grad=False):
31 | for group in self.param_groups:
32 | for p in group["params"]:
33 | if p.grad is None: continue
34 | p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)"
35 |
36 | self.base_optimizer.step() # do the actual "sharpness-aware" update
37 |
38 | if zero_grad: self.zero_grad()
39 |
40 | @torch.no_grad()
41 | def step(self, closure=None):
42 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
43 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
44 |
45 | self.first_step(zero_grad=True)
46 | closure()
47 | self.second_step()
48 |
49 | def _grad_norm(self):
50 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
51 | norm = torch.norm(
52 | torch.stack([
53 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
54 | for group in self.param_groups for p in group["params"]
55 | if p.grad is not None
56 | ]),
57 | p=2
58 | )
59 | return norm
60 |
61 |
--------------------------------------------------------------------------------
/common/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bisect import bisect_right
3 |
4 |
5 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
6 | r"""Starts with a warm-up phase, then decays the learning rate of each parameter group by gamma once the
7 | number of epochs reaches one of the milestones. When last_epoch=-1, sets the initial lr as lr.
8 |
9 | Args:
10 | optimizer (Optimizer): Wrapped optimizer.
11 | milestones (list): List of epoch indices. Must be increasing.
12 | gamma (float): Multiplicative factor of learning rate decay.
13 | Default: 0.1.
14 | warmup_factor (float): a float number :math:`k` between 0 and 1, the start learning rate of warmup phase
15 | will be set to :math:`k*initial\_lr`
16 | warmup_steps (int): number of warm-up steps.
17 | warmup_method (str): "constant" denotes a constant learning rate during warm-up phase and "linear" denotes a
18 | linear-increasing learning rate during warm-up phase.
19 | last_epoch (int): The index of last epoch. Default: -1.
20 | """
21 |
22 | def __init__(
23 | self,
24 | optimizer,
25 | milestones,
26 | gamma=0.1,
27 | warmup_factor=1.0 / 3,
28 | warmup_steps=500,
29 | warmup_method="linear",
30 | last_epoch=-1,
31 | ):
32 | if not list(milestones) == sorted(milestones):
33 | raise ValueError(
34 | "Milestones should be a list of increasing integers. "
35 | "Got {}".format(milestones)
36 | )
37 |
38 | if warmup_method not in ("constant", "linear"):
39 | raise ValueError(
40 | "Only 'constant' or 'linear' warmup_method accepted"
41 | "got {}".format(warmup_method)
42 | )
43 | self.milestones = milestones
44 | self.gamma = gamma
45 | self.warmup_factor = warmup_factor
46 | self.warmup_steps = warmup_steps
47 | self.warmup_method = warmup_method
48 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
49 |
50 | def get_lr(self):
51 | warmup_factor = 1
52 | if self.last_epoch < self.warmup_steps:
53 | if self.warmup_method == "constant":
54 | warmup_factor = self.warmup_factor
55 | elif self.warmup_method == "linear":
56 | alpha = float(self.last_epoch) / float(self.warmup_steps)
57 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
58 | return [
59 | base_lr
60 | * warmup_factor
61 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
62 | for base_lr in self.base_lrs
63 | ]
64 |
--------------------------------------------------------------------------------
/common/vision/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .officehome import OfficeHome
2 | from .visda2017 import VisDA2017
3 | from .domainnet import DomainNet
4 |
5 | __all__ = ['OfficeHome', "VisDA2017", "DomainNet"]
6 |
7 |
8 |
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/_util.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/_util.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/_util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/_util.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/aircrafts.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/aircrafts.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/aircrafts.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/aircrafts.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/aircrafts.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/aircrafts.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/coco70.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/coco70.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/coco70.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/coco70.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/coco70.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/coco70.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/cub200.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/cub200.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/cub200.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/cub200.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/cub200.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/cub200.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/digits.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/digits.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/digits.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/digits.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/digits.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/digits.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/domainnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/domainnet.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/domainnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/domainnet.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/domainnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/domainnet.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/dtd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/dtd.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/dtd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/dtd.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/dtd.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/dtd.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/eurosat.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/eurosat.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/eurosat.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/eurosat.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/eurosat.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/eurosat.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagelist.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagelist.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagelist.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagelist.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagelist.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagelist.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagenet_r.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_r.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagenet_r.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_r.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagenet_r.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_r.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagenet_sketch.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_sketch.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagenet_sketch.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_sketch.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/imagenet_sketch.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/imagenet_sketch.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/office31.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/office31.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/office31.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/office31.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/office31.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/office31.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/officecaltech.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officecaltech.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/officecaltech.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officecaltech.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/officecaltech.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officecaltech.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/officehome.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officehome.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/officehome.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officehome.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/officehome.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/officehome.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/oxfordflowers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordflowers.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/oxfordflowers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordflowers.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/oxfordflowers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordflowers.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/oxfordpet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordpet.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/oxfordpet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordpet.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/oxfordpet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/oxfordpet.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/pacs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/pacs.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/pacs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/pacs.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/pacs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/pacs.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/patchcamelyon.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/patchcamelyon.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/patchcamelyon.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/patchcamelyon.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/patchcamelyon.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/patchcamelyon.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/resisc45.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/resisc45.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/resisc45.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/resisc45.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/resisc45.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/resisc45.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/retinopathy.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/retinopathy.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/retinopathy.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/retinopathy.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/retinopathy.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/retinopathy.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/stanford_cars.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_cars.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/stanford_cars.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_cars.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/stanford_cars.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_cars.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/stanford_dogs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_dogs.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/stanford_dogs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_dogs.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/stanford_dogs.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/stanford_dogs.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/visda2017.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/visda2017.cpython-36.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/visda2017.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/visda2017.cpython-38.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/__pycache__/visda2017.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/val-iisc/SDAT/33d7429d37522972edda608cb802a686e4b9e794/common/vision/datasets/__pycache__/visda2017.cpython-39.pyc
--------------------------------------------------------------------------------
/common/vision/datasets/_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 | from torchvision.datasets.utils import download_and_extract_archive
4 |
5 |
6 | def download(root: str, file_name: str, archive_name: str, url_link: str):
7 | """
8 |     Download a file from an internet URL.
9 |
10 | Args:
11 |         root (str): The directory to put downloaded files.
12 |         file_name (str): The name of the unzipped file.
13 |         archive_name (str): The name of the archive (zipped file) to download.
14 |         url_link (str): The URL to download the data from.
15 |
16 | .. note::
17 | If `file_name` already exists under path `root`, then it is not downloaded again.
18 | Else `archive_name` will be downloaded from `url_link` and extracted to `file_name`.
19 | """
20 | if not os.path.exists(os.path.join(root, file_name)):
21 | print("Downloading {}".format(file_name))
22 | # if os.path.exists(os.path.join(root, archive_name)):
23 | # os.remove(os.path.join(root, archive_name))
24 | try:
25 | download_and_extract_archive(url_link, download_root=root, filename=archive_name, remove_finished=False)
26 | except Exception:
27 | print("Fail to download {} from url link {}".format(archive_name, url_link))
28 |             print('Please check your internet connection. '
29 |                   "Simply trying again may be fine.")
30 | exit(0)
31 |
32 |
33 | def check_exits(root: str, file_name: str):
34 | """Check whether `file_name` exists under directory `root`. """
35 | if not os.path.exists(os.path.join(root, file_name)):
36 | print("Dataset directory {} not found under {}".format(file_name, root))
37 | exit(-1)
38 |
39 |
40 | def read_list_from_file(file_name: str) -> List[str]:
41 | """Read data from file and convert each line into an element in the list"""
42 | result = []
43 | with open(file_name, "r") as f:
44 | for line in f.readlines():
45 | result.append(line.strip())
46 | return result
47 |
--------------------------------------------------------------------------------
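
A minimal usage sketch for the two helpers above; the root path is an arbitrary example and the URL is the OfficeHome image_list link that appears in officehome.py further down this listing:

    from common.vision.datasets._util import download, check_exits

    root = "data/office_home"
    # Download and extract image_list.zip into `root` only if root/image_list is missing.
    download(root, "image_list", "image_list.zip",
             "https://cloud.tsinghua.edu.cn/f/1b0171a188944313b1f5/?dl=1")
    # Print an error message and exit if the extracted directory still does not exist.
    check_exits(root, "image_list")
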
/common/vision/datasets/domainnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 | from .imagelist import ImageList
4 | from ._util import download as download_data, check_exits
5 |
6 |
7 | class DomainNet(ImageList):
8 | """`DomainNet `_ (cleaned version, recommended)
9 |
10 | See `Moment Matching for Multi-Source Domain Adaptation `_ for details.
11 |
12 | Args:
13 | root (str): Root directory of dataset
14 | task (str): The task (domain) to create dataset. Choices include ``'c'``:clipart, \
15 | ``'i'``: infograph, ``'p'``: painting, ``'q'``: quickdraw, ``'r'``: real, ``'s'``: sketch
16 | split (str, optional): The dataset split, supports ``train``, or ``test``.
17 | download (bool, optional): If true, downloads the dataset from the internet and puts it \
18 | in root directory. If dataset is already downloaded, it is not downloaded again.
19 | transform (callable, optional): A function/transform that takes in an PIL image and returns a \
20 | transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
21 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
22 |
23 |     .. note:: In `root`, the following files will exist after downloading.
24 | ::
25 | clipart/
26 | infograph/
27 | painting/
28 | quickdraw/
29 | real/
30 | sketch/
31 | image_list/
32 | clipart.txt
33 | ...
34 | """
35 | download_list = [
36 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/bf0fe327e4b046eb89ba/?dl=1"),
37 | ("clipart", "clipart.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip"),
38 | ("infograph", "infograph.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip"),
39 | ("painting", "painting.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip"),
40 | ("quickdraw", "quickdraw.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip"),
41 | ("real", "real.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip"),
42 | ("sketch", "sketch.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip"),
43 | ]
44 | image_list = {
45 | "c": "clipart",
46 | "i": "infograph",
47 | "p": "painting",
48 | "q": "quickdraw",
49 | "r": "real",
50 | "s": "sketch",
51 | }
52 | CLASSES = ['aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil',
53 | 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball_bat',
54 | 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench',
55 | 'bicycle', 'binoculars', 'bird', 'birthday_cake', 'blackberry', 'blueberry', 'book', 'boomerang',
56 | 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket',
57 | 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera',
58 | 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling_fan',
59 | 'cello', 'cell_phone', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud',
60 | 'coffee_cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile',
61 | 'crown', 'cruise_ship', 'cup', 'diamond', 'dishwasher', 'diving_board', 'dog', 'dolphin', 'donut',
62 | 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant',
63 | 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire_hydrant',
64 | 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip_flops', 'floor_lamp', 'flower',
65 | 'flying_saucer', 'foot', 'fork', 'frog', 'frying_pan', 'garden', 'garden_hose', 'giraffe', 'goatee',
66 | 'golf_club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones',
67 | 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey_puck', 'hockey_stick', 'horse', 'hospital',
68 | 'hot_air_balloon', 'hot_dog', 'hot_tub', 'hourglass', 'house', 'house_plant', 'hurricane', 'ice_cream',
69 | 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf',
70 | 'leg', 'light_bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster',
71 | 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave',
72 | 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom',
73 | 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paintbrush',
74 | 'paint_can', 'palm_tree', 'panda', 'pants', 'paper_clip', 'parachute', 'parrot', 'passport', 'peanut',
75 | 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup_truck', 'picture_frame', 'pig', 'pillow',
76 | 'pineapple', 'pizza', 'pliers', 'police_car', 'pond', 'pool', 'popsicle', 'postcard', 'potato',
77 | 'power_outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote_control',
78 | 'rhinoceros', 'rifle', 'river', 'roller_coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw',
79 | 'saxophone', 'school_bus', 'scissors', 'scorpion', 'screwdriver', 'sea_turtle', 'see_saw', 'shark',
80 | 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping_bag',
81 | 'smiley_face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer_ball', 'sock', 'speedboat',
82 | 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo',
83 | 'stethoscope', 'stitches', 'stop_sign', 'stove', 'strawberry', 'streetlight', 'string_bean', 'submarine',
84 | 'suitcase', 'sun', 'swan', 'sweater', 'swing_set', 'sword', 'syringe', 'table', 'teapot', 'teddy-bear',
85 | 'telephone', 'television', 'tennis_racquet', 'tent', 'The_Eiffel_Tower', 'The_Great_Wall_of_China',
86 | 'The_Mona_Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado',
87 | 'tractor', 'traffic_light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 't-shirt',
88 | 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing_machine', 'watermelon', 'waterslide',
89 | 'whale', 'wheel', 'windmill', 'wine_bottle', 'wine_glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']
90 |
91 |     def __init__(self, root: str, task: str, split: Optional[str] = 'train', download: Optional[bool] = False, **kwargs):
92 | assert task in self.image_list
93 | assert split in ['train', 'test']
94 | data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split))
95 | print("loading {}".format(data_list_file))
96 |
97 | if download:
98 | list(map(lambda args: download_data(root, *args), self.download_list))
99 | else:
100 | list(map(lambda args: check_exits(root, args[0]), self.download_list))
101 |
102 | super(DomainNet, self).__init__(root, DomainNet.CLASSES, data_list_file=data_list_file, **kwargs)
103 |
104 | @classmethod
105 | def domains(cls):
106 | return list(cls.image_list.keys())
107 |
--------------------------------------------------------------------------------
/common/vision/datasets/imagelist.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional, Callable, Tuple, Any, List
3 | import torchvision.datasets as datasets
4 | from torchvision.datasets.folder import default_loader
5 |
6 |
7 | class ImageList(datasets.VisionDataset):
8 | """A generic Dataset class for image classification
9 |
10 | Args:
11 | root (str): Root directory of dataset
12 | classes (list[str]): The names of all the classes
13 | data_list_file (str): File to read the image list from.
14 | transform (callable, optional): A function/transform that takes in an PIL image \
15 | and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
16 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
17 |
18 | .. note:: In `data_list_file`, each line has 2 values in the following format.
19 | ::
20 | source_dir/dog_xxx.png 0
21 | source_dir/cat_123.png 1
22 | target_dir/dog_xxy.png 0
23 | target_dir/cat_nsdf3.png 1
24 |
25 | The first value is the relative path of an image, and the second value is the label of the corresponding image.
26 |         If your data_list_file has a different format, please override :meth:`~ImageList.parse_data_file`.
27 | """
28 |
29 | def __init__(self, root: str, classes: List[str], data_list_file: str,
30 | transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
31 | super().__init__(root, transform=transform, target_transform=target_transform)
32 | self.samples = self.parse_data_file(data_list_file)
33 | self.classes = classes
34 | self.class_to_idx = {cls: idx
35 | for idx, cls in enumerate(self.classes)}
36 | self.loader = default_loader
37 | self.data_list_file = data_list_file
38 |
39 | def __getitem__(self, index: int) -> Tuple[Any, int]:
40 | """
41 | Args:
42 | index (int): Index
43 |             return (tuple): (image, target) where target is the index of the target class.
44 | """
45 | path, target = self.samples[index]
46 | img = self.loader(path)
47 | if self.transform is not None:
48 | img = self.transform(img)
49 | if self.target_transform is not None and target is not None:
50 | target = self.target_transform(target)
51 | return img, target
52 |
53 | def __len__(self) -> int:
54 | return len(self.samples)
55 |
56 | def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]:
57 | """Parse file to data list
58 |
59 | Args:
60 | file_name (str): The path of data file
61 | return (list): List of (image path, class_index) tuples
62 | """
63 | with open(file_name, "r") as f:
64 | data_list = []
65 | for line in f.readlines():
66 | split_line = line.split()
67 | target = split_line[-1]
68 | path = ' '.join(split_line[:-1])
69 | if not os.path.isabs(path):
70 | path = os.path.join(self.root, path)
71 | target = int(target)
72 | data_list.append((path, target))
73 | return data_list
74 |
75 | @property
76 | def num_classes(self) -> int:
77 | """Number of classes"""
78 | return len(self.classes)
79 |
80 | @classmethod
81 | def domains(cls):
82 |         """All possible domains in this dataset"""
83 |         raise NotImplementedError
84 |
--------------------------------------------------------------------------------
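
A small sketch of the list-file convention that ImageList.parse_data_file expects; the file names and classes below are made up purely for illustration:

    import os
    from common.vision.datasets.imagelist import ImageList

    root = "data/toy"
    os.makedirs(os.path.join(root, "source_dir"), exist_ok=True)
    # Each line: <image path relative to root> <integer class index>
    with open(os.path.join(root, "train.txt"), "w") as f:
        f.write("source_dir/dog_001.png 0\n")
        f.write("source_dir/cat_002.png 1\n")

    dataset = ImageList(root, classes=["dog", "cat"],
                        data_list_file=os.path.join(root, "train.txt"))
    print(len(dataset), dataset.num_classes)  # 2 2
    # dataset[0] would load root/source_dir/dog_001.png with PIL via default_loader.
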
/common/vision/datasets/officehome.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 | from .imagelist import ImageList
4 | from ._util import download as download_data, check_exits
5 |
6 |
7 | class OfficeHome(ImageList):
8 | """`OfficeHome `_ Dataset.
9 |
10 | Args:
11 | root (str): Root directory of dataset
12 | task (str): The task (domain) to create dataset. Choices include ``'Ar'``: Art, \
13 | ``'Cl'``: Clipart, ``'Pr'``: Product and ``'Rw'``: Real_World.
14 | download (bool, optional): If true, downloads the dataset from the internet and puts it \
15 | in root directory. If dataset is already downloaded, it is not downloaded again.
16 | transform (callable, optional): A function/transform that takes in an PIL image and returns a \
17 | transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
18 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
19 |
20 |     .. note:: In `root`, the following files will exist after downloading.
21 | ::
22 | Art/
23 | Alarm_Clock/*.jpg
24 | ...
25 | Clipart/
26 | Product/
27 | Real_World/
28 | image_list/
29 | Art.txt
30 | Clipart.txt
31 | Product.txt
32 | Real_World.txt
33 | """
34 | download_list = [
35 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/1b0171a188944313b1f5/?dl=1"),
36 | ("Art", "Art.tgz", "https://cloud.tsinghua.edu.cn/f/6a006656b9a14567ade2/?dl=1"),
37 | ("Clipart", "Clipart.tgz", "https://cloud.tsinghua.edu.cn/f/ae88aa31d2d7411dad79/?dl=1"),
38 | ("Product", "Product.tgz", "https://cloud.tsinghua.edu.cn/f/f219b0ff35e142b3ab48/?dl=1"),
39 | ("Real_World", "Real_World.tgz", "https://cloud.tsinghua.edu.cn/f/6c19f3f15bb24ed3951a/?dl=1")
40 | ]
41 | image_list = {
42 | "Ar": "image_list/Art.txt",
43 | "Cl": "image_list/Clipart.txt",
44 | "Pr": "image_list/Product.txt",
45 | "Rw": "image_list/Real_World.txt",
46 | }
47 | CLASSES = ['Drill', 'Exit_Sign', 'Bottle', 'Glasses', 'Computer', 'File_Cabinet', 'Shelf', 'Toys', 'Sink',
48 | 'Laptop', 'Kettle', 'Folder', 'Keyboard', 'Flipflops', 'Pencil', 'Bed', 'Hammer', 'ToothBrush', 'Couch',
49 | 'Bike', 'Postit_Notes', 'Mug', 'Webcam', 'Desk_Lamp', 'Telephone', 'Helmet', 'Mouse', 'Pen', 'Monitor',
50 | 'Mop', 'Sneakers', 'Notebook', 'Backpack', 'Alarm_Clock', 'Push_Pin', 'Paper_Clip', 'Batteries', 'Radio',
51 | 'Fan', 'Ruler', 'Pan', 'Screwdriver', 'Trash_Can', 'Printer', 'Speaker', 'Eraser', 'Bucket', 'Chair',
52 | 'Calendar', 'Calculator', 'Flowers', 'Lamp_Shade', 'Spoon', 'Candles', 'Clipboards', 'Scissors', 'TV',
53 | 'Curtains', 'Fork', 'Soda', 'Table', 'Knives', 'Oven', 'Refrigerator', 'Marker']
54 |
55 | def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
56 | assert task in self.image_list
57 | data_list_file = os.path.join(root, self.image_list[task])
58 |
59 | if download:
60 | list(map(lambda args: download_data(root, *args), self.download_list))
61 | else:
62 |             list(map(lambda args: check_exits(root, args[0]), self.download_list))
63 |
64 | super(OfficeHome, self).__init__(root, OfficeHome.CLASSES, data_list_file=data_list_file, **kwargs)
65 |
66 | @classmethod
67 | def domains(cls):
68 | return list(cls.image_list.keys())
69 |
--------------------------------------------------------------------------------
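
A hedged usage sketch for the dataset class above; the root directory, transform pipeline and normalization statistics are illustrative choices, not something the class prescribes:

    import torchvision.transforms as T
    from torch.utils.data import DataLoader
    from common.vision.datasets.officehome import OfficeHome
    from common.vision.transforms import ResizeImage

    train_transform = T.Compose([
        ResizeImage(256),
        T.RandomCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # 'Ar' selects the Art domain; download=True fetches the archives in download_list.
    source = OfficeHome(root="data/office_home", task="Ar", download=True,
                        transform=train_transform)
    loader = DataLoader(source, batch_size=32, shuffle=True, num_workers=2)
    images, labels = next(iter(loader))  # images: (32, 3, 224, 224), labels: (32,)
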
/common/vision/datasets/visda2017.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 | from .imagelist import ImageList
4 | from ._util import download as download_data, check_exits
5 |
6 |
7 | class VisDA2017(ImageList):
8 | """`VisDA-2017 `_ Dataset
9 |
10 | Args:
11 | root (str): Root directory of dataset
12 | task (str): The task (domain) to create dataset. Choices include ``'Synthetic'``: synthetic images and \
13 | ``'Real'``: real-world images.
14 | download (bool, optional): If true, downloads the dataset from the internet and puts it \
15 | in root directory. If dataset is already downloaded, it is not downloaded again.
16 | transform (callable, optional): A function/transform that takes in an PIL image and returns a \
17 | transformed version. E.g, ``transforms.RandomCrop``.
18 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
19 |
20 |     .. note:: In `root`, the following files will exist after downloading.
21 | ::
22 | train/
23 |             aeroplane/
24 | *.png
25 | ...
26 | validation/
27 | image_list/
28 | train.txt
29 | validation.txt
30 | """
31 | download_list = [
32 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/c107de37b8094c5398dc/?dl=1"),
33 | ("train", "train.tar", "http://csr.bu.edu/ftp/visda17/clf/train.tar"),
34 | ("validation", "validation.tar", "http://csr.bu.edu/ftp/visda17/clf/validation.tar")
35 | ]
36 | image_list = {
37 | "Synthetic": "image_list/train.txt",
38 | "Real": "image_list/validation.txt"
39 | }
40 | CLASSES = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife',
41 | 'motorcycle', 'person', 'plant', 'skateboard', 'train', 'truck']
42 |
43 | def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
44 | assert task in self.image_list
45 | data_list_file = os.path.join(root, self.image_list[task])
46 |
47 | if download:
48 | list(map(lambda args: download_data(root, *args), self.download_list))
49 | else:
50 |             list(map(lambda args: check_exits(root, args[0]), self.download_list))
51 |
52 | super(VisDA2017, self).__init__(root, VisDA2017.CLASSES, data_list_file=data_list_file, **kwargs)
53 |
54 | @classmethod
55 | def domains(cls):
56 | return list(cls.image_list.keys())
57 |
--------------------------------------------------------------------------------
/common/vision/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 |
3 | __all__ = ['resnet']
4 |
--------------------------------------------------------------------------------
/common/vision/models/resnet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | from torchvision import models
4 | from torchvision.models.utils import load_state_dict_from_url
5 | #from torch.hub import load_state_dict_from_url
6 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
7 | import copy
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
11 | 'wide_resnet50_2', 'wide_resnet101_2']
12 |
13 |
14 | class ResNet(models.ResNet):
15 | """ResNets without fully connected layer"""
16 |
17 | def __init__(self, *args, **kwargs):
18 | super(ResNet, self).__init__(*args, **kwargs)
19 | self._out_features = self.fc.in_features
20 |
21 | def forward(self, x):
22 | """"""
23 | x = self.conv1(x)
24 | x = self.bn1(x)
25 | x = self.relu(x)
26 | x = self.maxpool(x)
27 |
28 | x = self.layer1(x)
29 | x = self.layer2(x)
30 | x = self.layer3(x)
31 | x = self.layer4(x)
32 |
33 | return x
34 |
35 | @property
36 | def out_features(self) -> int:
37 | """The dimension of output features"""
38 | return self._out_features
39 |
40 | def copy_head(self) -> nn.Module:
41 | """Copy the origin fully connected layer"""
42 | return copy.deepcopy(self.fc)
43 |
44 |
45 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
46 | model = ResNet(block, layers, **kwargs)
47 | if pretrained:
48 | model_dict = model.state_dict()
49 | pretrained_dict = load_state_dict_from_url(model_urls[arch],
50 | progress=progress)
51 |         # remove keys from the pretrained dict that do not appear in the model dict
52 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
53 | model.load_state_dict(pretrained_dict, strict=False)
54 | return model
55 |
56 |
57 | def resnet18(pretrained=False, progress=True, **kwargs):
58 | r"""ResNet-18 model from
59 | `"Deep Residual Learning for Image Recognition" `_
60 |
61 | Args:
62 | pretrained (bool): If True, returns a model pre-trained on ImageNet
63 | progress (bool): If True, displays a progress bar of the download to stderr
64 | """
65 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
66 | **kwargs)
67 |
68 |
69 | def resnet34(pretrained=False, progress=True, **kwargs):
70 | r"""ResNet-34 model from
71 | `"Deep Residual Learning for Image Recognition" `_
72 |
73 | Args:
74 | pretrained (bool): If True, returns a model pre-trained on ImageNet
75 | progress (bool): If True, displays a progress bar of the download to stderr
76 | """
77 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
78 | **kwargs)
79 |
80 |
81 | def resnet50(pretrained=False, progress=True, **kwargs):
82 | r"""ResNet-50 model from
83 | `"Deep Residual Learning for Image Recognition" `_
84 |
85 | Args:
86 | pretrained (bool): If True, returns a model pre-trained on ImageNet
87 | progress (bool): If True, displays a progress bar of the download to stderr
88 | """
89 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
90 | **kwargs)
91 |
92 |
93 | def resnet101(pretrained=False, progress=True, **kwargs):
94 | r"""ResNet-101 model from
95 | `"Deep Residual Learning for Image Recognition" `_
96 |
97 | Args:
98 | pretrained (bool): If True, returns a model pre-trained on ImageNet
99 | progress (bool): If True, displays a progress bar of the download to stderr
100 | """
101 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
102 | **kwargs)
103 |
104 |
105 | def resnet152(pretrained=False, progress=True, **kwargs):
106 | r"""ResNet-152 model from
107 | `"Deep Residual Learning for Image Recognition" `_
108 |
109 | Args:
110 | pretrained (bool): If True, returns a model pre-trained on ImageNet
111 | progress (bool): If True, displays a progress bar of the download to stderr
112 | """
113 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
114 | **kwargs)
115 |
116 |
117 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
118 | r"""ResNeXt-50 32x4d model from
119 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
120 |
121 | Args:
122 | pretrained (bool): If True, returns a model pre-trained on ImageNet
123 | progress (bool): If True, displays a progress bar of the download to stderr
124 | """
125 | kwargs['groups'] = 32
126 | kwargs['width_per_group'] = 4
127 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
128 | pretrained, progress, **kwargs)
129 |
130 |
131 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
132 | r"""ResNeXt-101 32x8d model from
133 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
134 |
135 | Args:
136 | pretrained (bool): If True, returns a model pre-trained on ImageNet
137 | progress (bool): If True, displays a progress bar of the download to stderr
138 | """
139 | kwargs['groups'] = 32
140 | kwargs['width_per_group'] = 8
141 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
142 | pretrained, progress, **kwargs)
143 |
144 |
145 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
146 | r"""Wide ResNet-50-2 model from
147 | `"Wide Residual Networks" `_
148 |
149 | The model is the same as ResNet except for the bottleneck number of channels
150 | which is twice larger in every block. The number of channels in outer 1x1
151 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
152 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
153 |
154 | Args:
155 | pretrained (bool): If True, returns a model pre-trained on ImageNet
156 | progress (bool): If True, displays a progress bar of the download to stderr
157 | """
158 | kwargs['width_per_group'] = 64 * 2
159 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
160 | pretrained, progress, **kwargs)
161 |
162 |
163 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
164 | r"""Wide ResNet-101-2 model from
165 | `"Wide Residual Networks" `_
166 |
167 | The model is the same as ResNet except for the bottleneck number of channels
168 | which is twice larger in every block. The number of channels in outer 1x1
169 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
170 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
171 |
172 | Args:
173 | pretrained (bool): If True, returns a model pre-trained on ImageNet
174 | progress (bool): If True, displays a progress bar of the download to stderr
175 | """
176 | kwargs['width_per_group'] = 64 * 2
177 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
178 | pretrained, progress, **kwargs)
179 |
--------------------------------------------------------------------------------
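
A quick sketch of how this truncated backbone behaves (shapes assume a 224x224 input); pooling and flattening are left to the downstream classifier, which is why the AdaptiveAvgPool2d/Flatten lines are commented out in the ImageClassifier bottlenecks further below:

    import torch
    import torch.nn.functional as F
    from common.vision.models import resnet50

    backbone = resnet50(pretrained=False)  # pretrained=True would load ImageNet weights
    x = torch.randn(4, 3, 224, 224)
    features = backbone(x)                 # (4, 2048, 7, 7): layer4 output, no avgpool/fc
    print(backbone.out_features)           # 2048, i.e. the original fc.in_features
    pooled = torch.flatten(F.adaptive_avg_pool2d(features, 1), 1)  # (4, 2048)
    head = backbone.copy_head()            # deep copy of the original 1000-way fc layer
    logits = head(pooled)                  # (4, 1000)
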
/common/vision/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from torchvision.transforms import Normalize
7 |
8 |
9 | class ResizeImage(object):
10 | """Resize the input PIL Image to the given size.
11 |
12 | Args:
13 | size (sequence or int): Desired output size. If size is a sequence like
14 | (h, w), output size will be matched to this. If size is an int,
15 | output size will be (size, size)
16 | """
17 |
18 | def __init__(self, size):
19 | if isinstance(size, int):
20 | self.size = (int(size), int(size))
21 | else:
22 | self.size = size
23 |
24 | def __call__(self, img):
25 | th, tw = self.size
26 | return img.resize((th, tw))
27 |
28 | def __repr__(self):
29 | return self.__class__.__name__ + '(size={0})'.format(self.size)
30 |
31 |
32 | class MultipleApply:
33 | """Apply a list of transformations to an image and get multiple transformed images.
34 |
35 | Args:
36 | transforms (list or tuple): list of transformations
37 |
38 | Example:
39 |
40 | >>> transform1 = T.Compose([
41 | ... ResizeImage(256),
42 | ... T.RandomCrop(224)
43 | ... ])
44 | >>> transform2 = T.Compose([
45 | ... ResizeImage(256),
46 | ... T.RandomCrop(224),
47 | ... ])
48 | >>> multiply_transform = MultipleApply([transform1, transform2])
49 | """
50 |
51 | def __init__(self, transforms):
52 | self.transforms = transforms
53 |
54 | def __call__(self, image):
55 | return [t(image) for t in self.transforms]
56 |
57 | def __repr__(self):
58 | format_string = self.__class__.__name__ + '('
59 | for t in self.transforms:
60 | format_string += '\n'
61 | format_string += ' {0}'.format(t)
62 | format_string += '\n)'
63 | return format_string
64 |
65 |
66 | class Denormalize(Normalize):
67 | """DeNormalize a tensor image with mean and standard deviation.
68 | Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
69 | channels, this transform will denormalize each channel of the input
70 | ``torch.*Tensor`` i.e.,
71 | ``output[channel] = input[channel] * std[channel] + mean[channel]``
72 |
73 | .. note::
74 | This transform acts out of place, i.e., it does not mutate the input tensor.
75 |
76 | Args:
77 | mean (sequence): Sequence of means for each channel.
78 | std (sequence): Sequence of standard deviations for each channel.
79 |
80 | """
81 |
82 | def __init__(self, mean, std):
83 | mean = np.array(mean)
84 | std = np.array(std)
85 | super().__init__((-mean / std).tolist(), (1 / std).tolist())
86 |
87 |
88 | class NormalizeAndTranspose:
89 | """
90 | First, normalize a tensor image with mean and standard deviation.
91 | Then, convert the shape (H x W x C) to shape (C x H x W).
92 | """
93 |
94 | def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)):
95 | self.mean = np.array(mean, dtype=np.float32)
96 |
97 | def __call__(self, image):
98 | if isinstance(image, Image.Image):
99 | image = np.asarray(image, np.float32)
100 | # change to BGR
101 | image = image[:, :, ::-1]
102 | # normalize
103 | image -= self.mean
104 | image = image.transpose((2, 0, 1)).copy()
105 | elif isinstance(image, torch.Tensor):
106 | # change to BGR
107 | image = image[:, :, [2, 1, 0]]
108 | # normalize
109 | image -= torch.from_numpy(self.mean).to(image.device)
110 | image = image.permute((2, 0, 1))
111 | else:
112 | raise NotImplementedError(type(image))
113 | return image
114 |
115 |
116 | class DeNormalizeAndTranspose:
117 | """
118 | First, convert a tensor image from the shape (C x H x W ) to shape (H x W x C).
119 | Then, denormalize it with mean and standard deviation.
120 | """
121 |
122 | def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)):
123 | self.mean = np.array(mean, dtype=np.float32)
124 |
125 | def __call__(self, image):
126 | image = image.transpose((1, 2, 0))
127 | # denormalize
128 | image += self.mean
129 | # change to RGB
130 | image = image[:, :, ::-1]
131 | return image
132 |
133 |
134 | class RandomErasing(object):
135 | """Random erasing augmentation from `Random Erasing Data Augmentation (CVPR 2017)
136 | `_. This augmentation randomly selects a rectangle region in an image
137 | and erases its pixels.
138 |
139 | Args:
140 | probability (float): The probability that the Random Erasing operation will be performed.
141 | sl (float): Minimum proportion of erased area against input image.
142 | sh (float): Maximum proportion of erased area against input image.
143 | r1 (float): Minimum aspect ratio of erased area.
144 | mean (sequence): Value to fill the erased area.
145 | """
146 |
147 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
148 | self.probability = probability
149 | self.mean = mean
150 | self.sl = sl
151 | self.sh = sh
152 | self.r1 = r1
153 |
154 | def __call__(self, img):
155 |
156 | if random.uniform(0, 1) >= self.probability:
157 | return img
158 |
159 | for attempt in range(100):
160 | area = img.size()[1] * img.size()[2]
161 |
162 | target_area = random.uniform(self.sl, self.sh) * area
163 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
164 |
165 | h = int(round(math.sqrt(target_area * aspect_ratio)))
166 | w = int(round(math.sqrt(target_area / aspect_ratio)))
167 |
168 | if w < img.size()[2] and h < img.size()[1]:
169 | x1 = random.randint(0, img.size()[1] - h)
170 | y1 = random.randint(0, img.size()[2] - w)
171 | if img.size()[0] == 3:
172 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
173 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
174 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
175 | else:
176 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
177 | return img
178 |
179 | return img
180 |
181 | def __repr__(self):
182 | return self.__class__.__name__ + '(p={})'.format(self.probability)
183 |
--------------------------------------------------------------------------------
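
A short, illustrative composition of the transforms defined above; the normalization statistics are the usual ImageNet values and serve only as an example:

    import torchvision.transforms as T
    from common.vision.transforms import ResizeImage, MultipleApply, Denormalize, RandomErasing

    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        ResizeImage(256),
        T.RandomCrop(224),
        T.ToTensor(),
        normalize,
        RandomErasing(probability=0.5),  # expects the (C, H, W) tensor produced by ToTensor
    ])
    # Two independently augmented views of the same image, e.g. for consistency training.
    two_views = MultipleApply([train_transform, train_transform])
    # Denormalize inverts Normalize for visualization: output = input * std + mean.
    denorm = Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
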
/dalib/adaptation/__init__.py:
--------------------------------------------------------------------------------
1 | from . import cdan
2 | from . import mcc
3 |
4 | __all__ = ["cdan", "mcc"]
5 |
--------------------------------------------------------------------------------
/dalib/adaptation/cdan.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from common.modules.classifier import Classifier as ClassifierBase
8 | from common.utils.metric import binary_accuracy
9 | from ..modules.grl import WarmStartGradientReverseLayer
10 | from ..modules.entropy import entropy
11 |
12 |
13 | __all__ = ['ConditionalDomainAdversarialLoss', 'ImageClassifier']
14 |
15 |
16 | class ConditionalDomainAdversarialLoss(nn.Module):
17 | r"""The Conditional Domain Adversarial Loss used in `Conditional Adversarial Domain Adaptation (NIPS 2018) `_
18 |
19 | Conditional Domain adversarial loss measures the domain discrepancy through training a domain discriminator in a
20 | conditional manner. Given domain discriminator :math:`D`, feature representation :math:`f` and
21 | classifier predictions :math:`g`, the definition of CDAN loss is
22 |
23 | .. math::
24 | loss(\mathcal{D}_s, \mathcal{D}_t) &= \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(T(f_i^s, g_i^s))] \\
25 | &+ \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(T(f_j^t, g_j^t))],\\
26 |
27 |     where :math:`T` is a :class:`MultiLinearMap` or :class:`RandomizedMultiLinearMap` which converts two tensors into a single tensor.
28 |
29 | Args:
30 | domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of
31 | features. Its input shape is (N, F) and output shape is (N, 1)
32 | entropy_conditioning (bool, optional): If True, use entropy-aware weight to reweight each training example.
33 | Default: False
34 | randomized (bool, optional): If True, use `randomized multi linear map`. Else, use `multi linear map`.
35 | Default: False
36 | num_classes (int, optional): Number of classes. Default: -1
37 | features_dim (int, optional): Dimension of input features. Default: -1
38 | randomized_dim (int, optional): Dimension of features after randomized. Default: 1024
39 | reduction (str, optional): Specifies the reduction to apply to the output:
40 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
41 | ``'mean'``: the sum of the output will be divided by the number of
42 | elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
43 |
44 | .. note::
45 | You need to provide `num_classes`, `features_dim` and `randomized_dim` **only when** `randomized`
46 | is set True.
47 |
48 | Inputs:
49 | - g_s (tensor): unnormalized classifier predictions on source domain, :math:`g^s`
50 | - f_s (tensor): feature representations on source domain, :math:`f^s`
51 | - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`
52 | - f_t (tensor): feature representations on target domain, :math:`f^t`
53 |
54 | Shape:
55 | - g_s, g_t: :math:`(minibatch, C)` where C means the number of classes.
56 | - f_s, f_t: :math:`(minibatch, F)` where F means the dimension of input features.
57 | - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, )`.
58 |
59 | Examples::
60 |
61 | >>> from dalib.modules.domain_discriminator import DomainDiscriminator
62 | >>> from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss
63 | >>> import torch
64 | >>> num_classes = 2
65 | >>> feature_dim = 1024
66 | >>> batch_size = 10
67 | >>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
68 | >>> loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
69 | >>> # features from source domain and target domain
70 | >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
71 |         >>> # logits output from source domain and target domain
72 | >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
73 | >>> output = loss(g_s, f_s, g_t, f_t)
74 | """
75 |
76 | def __init__(self, domain_discriminator: nn.Module, entropy_conditioning: Optional[bool] = False,
77 | randomized: Optional[bool] = False, num_classes: Optional[int] = -1,
78 | features_dim: Optional[int] = -1, randomized_dim: Optional[int] = 1024,
79 | reduction: Optional[str] = 'mean'):
80 | super(ConditionalDomainAdversarialLoss, self).__init__()
81 | self.domain_discriminator = domain_discriminator
82 | self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True)
83 | self.entropy_conditioning = entropy_conditioning
84 |
85 | if randomized:
86 | assert num_classes > 0 and features_dim > 0 and randomized_dim > 0
87 | self.map = RandomizedMultiLinearMap(features_dim, num_classes, randomized_dim)
88 | else:
89 | self.map = MultiLinearMap()
90 |
91 | self.bce = lambda input, target, weight: F.binary_cross_entropy(input, target, weight,
92 | reduction=reduction) if self.entropy_conditioning \
93 | else F.binary_cross_entropy(input, target, reduction=reduction)
94 | self.domain_discriminator_accuracy = None
95 |
96 | def forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
97 | f = torch.cat((f_s, f_t), dim=0)
98 | g = torch.cat((g_s, g_t), dim=0)
99 | g = F.softmax(g, dim=1).detach()
100 | h = self.grl(self.map(f, g))
101 | d = self.domain_discriminator(h)
102 | d_label = torch.cat((
103 | torch.ones((g_s.size(0), 1)).to(g_s.device),
104 | torch.zeros((g_t.size(0), 1)).to(g_t.device),
105 | ))
106 | weight = 1.0 + torch.exp(-entropy(g))
107 | batch_size = f.size(0)
108 | weight = weight / torch.sum(weight) * batch_size
109 | self.domain_discriminator_accuracy = binary_accuracy(d, d_label)
110 | return self.bce(d, d_label, weight.view_as(d))
111 |
112 |
113 | class RandomizedMultiLinearMap(nn.Module):
114 | """Random multi linear map
115 |
116 | Given two inputs :math:`f` and :math:`g`, the definition is
117 |
118 | .. math::
119 | T_{\odot}(f,g) = \dfrac{1}{\sqrt{d}} (R_f f) \odot (R_g g),
120 |
121 | where :math:`\odot` is element-wise product, :math:`R_f` and :math:`R_g` are random matrices
122 | sampled only once and fixed in training.
123 |
124 | Args:
125 | features_dim (int): dimension of input :math:`f`
126 | num_classes (int): dimension of input :math:`g`
127 | output_dim (int, optional): dimension of output tensor. Default: 1024
128 |
129 | Shape:
130 | - f: (minibatch, features_dim)
131 | - g: (minibatch, num_classes)
132 | - Outputs: (minibatch, output_dim)
133 | """
134 |
135 | def __init__(self, features_dim: int, num_classes: int, output_dim: Optional[int] = 1024):
136 | super(RandomizedMultiLinearMap, self).__init__()
137 | self.Rf = torch.randn(features_dim, output_dim)
138 | self.Rg = torch.randn(num_classes, output_dim)
139 | self.output_dim = output_dim
140 |
141 | def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
142 | f = torch.mm(f, self.Rf.to(f.device))
143 | g = torch.mm(g, self.Rg.to(g.device))
144 | output = torch.mul(f, g) / np.sqrt(float(self.output_dim))
145 | return output
146 |
147 |
148 | class MultiLinearMap(nn.Module):
149 | """Multi linear map
150 |
151 | Shape:
152 | - f: (minibatch, F)
153 | - g: (minibatch, C)
154 | - Outputs: (minibatch, F * C)
155 | """
156 |
157 | def __init__(self):
158 | super(MultiLinearMap, self).__init__()
159 |
160 | def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
161 | batch_size = f.size(0)
162 | output = torch.bmm(g.unsqueeze(2), f.unsqueeze(1))
163 | return output.view(batch_size, -1)
164 |
165 |
166 | class ImageClassifier(ClassifierBase):
167 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
168 | bottleneck = nn.Sequential(
169 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)),
170 | # nn.Flatten(),
171 | nn.Linear(backbone.out_features, bottleneck_dim),
172 | nn.BatchNorm1d(bottleneck_dim),
173 | nn.ReLU()
174 | )
175 | super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
176 |
--------------------------------------------------------------------------------
/dalib/adaptation/mcc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from common.modules.classifier import Classifier as ClassifierBase
7 | from ..modules.entropy import entropy
8 |
9 |
10 | __all__ = ['MinimumClassConfusionLoss', 'ImageClassifier']
11 |
12 |
13 | class MinimumClassConfusionLoss(nn.Module):
14 | r"""
15 | Minimum Class Confusion loss minimizes the class confusion in the target predictions.
16 |
17 | You can see more details in `Minimum Class Confusion for Versatile Domain Adaptation (ECCV 2020) `_
18 |
19 | Args:
20 | temperature (float): The temperature for rescaling; the predictions reduce to the vanilla softmax when
21 | temperature is 1.0.
22 |
23 | .. note::
24 | Make sure that temperature is larger than 0.
25 |
26 | Inputs: g_t
27 | - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`
28 |
29 | Shape:
30 | - g_t: :math:`(minibatch, C)` where C means the number of classes.
31 | - Output: scalar.
32 |
33 | Examples::
34 | >>> temperature = 2.0
35 | >>> loss = MinimumClassConfusionLoss(temperature)
36 | >>> # logits output from target domain
37 | >>> g_t = torch.randn(batch_size, num_classes)
38 | >>> output = loss(g_t)
39 |
40 | MCC can also serve as a regularizer for existing methods.
41 | Examples::
42 | >>> from dalib.modules.domain_discriminator import DomainDiscriminator
43 | >>> num_classes = 2
44 | >>> feature_dim = 1024
45 | >>> batch_size = 10
46 | >>> temperature = 2.0
47 | >>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
48 | >>> cdan_loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
49 | >>> mcc_loss = MinimumClassConfusionLoss(temperature)
50 | >>> # features from source domain and target domain
51 | >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
52 | >>> # logits output from source domain and target domain
53 | >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
54 | >>> total_loss = cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t)
55 | """
56 |
57 | def __init__(self, temperature: float):
58 | super(MinimumClassConfusionLoss, self).__init__()
59 | self.temperature = temperature
60 |
61 | def forward(self, logits: torch.Tensor) -> torch.Tensor:
62 | batch_size, num_classes = logits.shape
63 | predictions = F.softmax(logits / self.temperature, dim=1) # batch_size x num_classes
64 | entropy_weight = entropy(predictions).detach()
65 | entropy_weight = 1 + torch.exp(-entropy_weight)
66 | entropy_weight = (batch_size * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1) # batch_size x 1
67 | class_confusion_matrix = torch.mm((predictions * entropy_weight).transpose(1, 0), predictions) # num_classes x num_classes
68 | class_confusion_matrix = class_confusion_matrix / torch.sum(class_confusion_matrix, dim=1)
69 | mcc_loss = (torch.sum(class_confusion_matrix) - torch.trace(class_confusion_matrix)) / num_classes
70 | return mcc_loss
71 |
72 |
73 | class ImageClassifier(ClassifierBase):
74 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
75 | bottleneck = nn.Sequential(
76 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)),
77 | # nn.Flatten(),
78 | nn.Linear(backbone.out_features, bottleneck_dim),
79 | nn.BatchNorm1d(bottleneck_dim),
80 | nn.ReLU()
81 | )
82 | super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
83 |
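84 | 
85 | if __name__ == '__main__':
86 |     # Sanity sketch added for illustration only (logit values are arbitrary):
87 |     # sharply peaked, unambiguous target logits should give a lower MCC loss
88 |     # than near-uniform ones, since the off-diagonal class confusion vanishes.
89 |     loss_fn = MinimumClassConfusionLoss(temperature=2.0)
90 |     confident = torch.eye(5).repeat(4, 1) * 10.0   # 20 samples, 5 classes, peaked logits
91 |     confused = torch.zeros(20, 5)                  # uniform logits
92 |     print(loss_fn(confident).item(), loss_fn(confused).item())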
--------------------------------------------------------------------------------
/dalib/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .grl import *
2 | from .domain_discriminator import *
3 | from .kernels import *
4 | from .entropy import *
5 |
6 | __all__ = ['grl', 'kernels', 'domain_discriminator', 'entropy']
7 |
--------------------------------------------------------------------------------
/dalib/modules/domain_discriminator.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | from typing import List, Dict
6 | import torch.nn as nn
7 |
8 | __all__ = ['DomainDiscriminator']
9 |
10 |
11 | class DomainDiscriminator(nn.Sequential):
12 | r"""Domain discriminator model from
13 | `Domain-Adversarial Training of Neural Networks (ICML 2015) `_
14 |
15 | Distinguish whether the input features come from the source domain or the target domain.
16 | The source domain label is 1 and the target domain label is 0.
17 |
18 | Args:
19 | in_feature (int): dimension of the input feature
20 | hidden_size (int): dimension of the hidden features
21 | batch_norm (bool): whether to use :class:`~torch.nn.BatchNorm1d`.
22 | Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True.
23 |
24 | Shape:
25 | - Inputs: (minibatch, `in_feature`)
26 | - Outputs: :math:`(minibatch, 1)`
27 | """
28 |
29 | def __init__(self, in_feature: int, hidden_size: int, batch_norm=True):
30 | if batch_norm:
31 | super(DomainDiscriminator, self).__init__(
32 | nn.Linear(in_feature, hidden_size),
33 | nn.BatchNorm1d(hidden_size),
34 | nn.ReLU(),
35 | nn.Linear(hidden_size, hidden_size),
36 | nn.BatchNorm1d(hidden_size),
37 | nn.ReLU(),
38 | nn.Linear(hidden_size, 1),
39 | nn.Sigmoid()
40 | )
41 | else:
42 | super(DomainDiscriminator, self).__init__(
43 | nn.Linear(in_feature, hidden_size),
44 | nn.ReLU(inplace=True),
45 | nn.Dropout(0.5),
46 | nn.Linear(hidden_size, hidden_size),
47 | nn.ReLU(inplace=True),
48 | nn.Dropout(0.5),
49 | nn.Linear(hidden_size, 1),
50 | nn.Sigmoid()
51 | )
52 |
53 | def get_parameters(self) -> List[Dict]:
54 | return [{"params": self.parameters(), "lr": 1.}]
55 |
56 |
57 |
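58 | if __name__ == '__main__':
59 |     # Shape check added for illustration only (sizes are arbitrary): the
60 |     # discriminator maps a feature batch to per-sample domain probabilities
61 |     # in (0, 1) via the final Sigmoid.
62 |     import torch
63 |     discriminator = DomainDiscriminator(in_feature=1024, hidden_size=1024)
64 |     features = torch.randn(8, 1024)
65 |     pred = discriminator(features)
66 |     print(pred.shape, float(pred.min()), float(pred.max()))   # torch.Size([8, 1]), values in (0, 1)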
--------------------------------------------------------------------------------
/dalib/modules/entropy.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | import torch
6 |
7 |
8 | def entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor:
9 | r"""Entropy of prediction.
10 | The definition is:
11 |
12 | .. math::
13 | entropy(p) = - \sum_{c=1}^C p_c \log p_c
14 |
15 | where C is number of classes.
16 |
17 | Args:
18 | predictions (tensor): Classifier predictions. Expected to contain normalized scores (probabilities) for each class
19 | reduction (str, optional): Specifies the reduction to apply to the output:
20 | ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,
21 | ``'mean'``: the sum of the output will be divided by the number of
22 | elements in the output. Default: ``'none'``
23 |
24 | Shape:
25 | - predictions: :math:`(minibatch, C)` where C means the number of classes.
26 | - Output: :math:`(minibatch, )` by default. If :attr:`reduction` is ``'mean'``, then scalar.
27 | """
28 | epsilon = 1e-5
29 | H = -predictions * torch.log(predictions + epsilon)
30 | H = H.sum(dim=1)
31 | if reduction == 'mean':
32 | return H.mean()
33 | else:
34 | return H
35 |
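36 | 
37 | if __name__ == '__main__':
38 |     # Usage sketch added for illustration only (sizes are arbitrary): `entropy`
39 |     # expects probabilities, so convert raw logits with a softmax first.
40 |     probs = torch.softmax(torch.randn(4, 65), dim=1)   # 4 samples, 65 classes
41 |     print(entropy(probs))                        # per-sample entropy, shape (4,)
42 |     print(entropy(probs, reduction='mean'))      # scalar mean entropy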
--------------------------------------------------------------------------------
/dalib/modules/gl.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | from typing import Optional, Any, Tuple
6 | import numpy as np
7 | import torch.nn as nn
8 | from torch.autograd import Function
9 | import torch
10 |
11 |
12 | class GradientFunction(Function):
13 |
14 | @staticmethod
15 | def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
16 | ctx.coeff = coeff
17 | output = input * 1.0
18 | return output
19 |
20 | @staticmethod
21 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
22 | return grad_output * ctx.coeff, None
23 |
24 |
25 | class WarmStartGradientLayer(nn.Module):
26 | """Warm Start Gradient Layer :math:`\mathcal{R}(x)` with warm start
27 |
28 | The forward and backward behaviours are:
29 |
30 | .. math::
31 | \mathcal{R}(x) = x,
32 |
33 | \dfrac{ d\mathcal{R}} {dx} = \lambda I.
34 |
35 | :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule:
36 |
37 | .. math::
38 | \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo
39 |
40 | where :math:`i` is the iteration step.
41 |
42 | Parameters:
43 | - **alpha** (float, optional): :math:`α`. Default: 1.0
44 | - **lo** (float, optional): Initial value of :math:`\lambda`. Default: 0.0
45 | - **hi** (float, optional): Final value of :math:`\lambda`. Default: 1.0
46 | - **max_iters** (int, optional): :math:`N`. Default: 1000
47 | - **auto_step** (bool, optional): If True, increase :math:`i` each time `forward` is called.
48 | Otherwise use function `step` to increase :math:`i`. Default: False
49 | """
50 |
51 | def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,
52 | max_iters: Optional[int] = 1000, auto_step: Optional[bool] = False):
53 | super(WarmStartGradientLayer, self).__init__()
54 | self.alpha = alpha
55 | self.lo = lo
56 | self.hi = hi
57 | self.iter_num = 0
58 | self.max_iters = max_iters
59 | self.auto_step = auto_step
60 |
61 | def forward(self, input: torch.Tensor) -> torch.Tensor:
62 | """"""
63 | coeff = float(
64 | 2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))
65 | - (self.hi - self.lo) + self.lo
66 | )
67 | if self.auto_step:
68 | self.step()
69 | return GradientFunction.apply(input, coeff)
70 |
71 | def step(self):
72 | """Increase iteration number :math:`i` by 1"""
73 | self.iter_num += 1
74 |
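75 | 
76 | if __name__ == '__main__':
77 |     # Behaviour sketch added for illustration only (settings are arbitrary):
78 |     # the forward pass is the identity, while the gradient is scaled by a
79 |     # coefficient that grows from `lo` towards `hi` as `iter_num` approaches
80 |     # `max_iters`.
81 |     layer = WarmStartGradientLayer(alpha=1.0, lo=0.0, hi=1.0, max_iters=5, auto_step=True)
82 |     x = torch.ones(3, requires_grad=True)
83 |     for step in range(5):
84 |         x.grad = None
85 |         layer(x).sum().backward()
86 |         print(step, x.grad[0].item())   # gradient scale increases with each step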
--------------------------------------------------------------------------------
/dalib/modules/grl.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | from typing import Optional, Any, Tuple
6 | import numpy as np
7 | import torch.nn as nn
8 | from torch.autograd import Function
9 | import torch
10 |
11 |
12 | class GradientReverseFunction(Function):
13 |
14 | @staticmethod
15 | def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
16 | ctx.coeff = coeff
17 | output = input * 1.0
18 | return output
19 |
20 | @staticmethod
21 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
22 | return grad_output.neg() * ctx.coeff, None
23 |
24 |
25 | class GradientReverseLayer(nn.Module):
26 | def __init__(self):
27 | super(GradientReverseLayer, self).__init__()
28 |
29 | def forward(self, *input):
30 | return GradientReverseFunction.apply(*input)
31 |
32 |
33 | class WarmStartGradientReverseLayer(nn.Module):
34 | """Gradient Reverse Layer :math:`\mathcal{R}(x)` with warm start
35 |
36 | The forward and backward behaviours are:
37 |
38 | .. math::
39 | \mathcal{R}(x) = x,
40 |
41 | \dfrac{ d\mathcal{R}} {dx} = - \lambda I.
42 |
43 | :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule:
44 |
45 | .. math::
46 | \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo
47 |
48 | where :math:`i` is the iteration step.
49 |
50 | Args:
51 | alpha (float, optional): :math:`α`. Default: 1.0
52 | lo (float, optional): Initial value of :math:`\lambda`. Default: 0.0
53 | hi (float, optional): Final value of :math:`\lambda`. Default: 1.0
54 | max_iters (int, optional): :math:`N`. Default: 1000
55 | auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called.
56 | Otherwise use function `step` to increase :math:`i`. Default: False
57 | """
58 |
59 | def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,
60 | max_iters: Optional[int] = 1000, auto_step: Optional[bool] = False):
61 | super(WarmStartGradientReverseLayer, self).__init__()
62 | self.alpha = alpha
63 | self.lo = lo
64 | self.hi = hi
65 | self.iter_num = 0
66 | self.max_iters = max_iters
67 | self.auto_step = auto_step
68 |
69 | def forward(self, input: torch.Tensor) -> torch.Tensor:
70 | """"""
71 | coeff = float(
72 | 2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))
73 | - (self.hi - self.lo) + self.lo
74 | )
75 | if self.auto_step:
76 | self.step()
77 | return GradientReverseFunction.apply(input, coeff)
78 |
79 | def step(self):
80 | """Increase iteration number :math:`i` by 1"""
81 | self.iter_num += 1
82 |
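83 | 
84 | if __name__ == '__main__':
85 |     # Minimal check added for illustration only: the layer is the identity in the
86 |     # forward pass but negates (and scales) gradients in the backward pass, which
87 |     # is what drives the feature extractor to confuse the domain discriminator.
88 |     grl = GradientReverseLayer()
89 |     x = torch.ones(3, requires_grad=True)
90 |     grl(x).sum().backward()
91 |     print(x.grad)   # tensor([-1., -1., -1.])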
--------------------------------------------------------------------------------
/dalib/modules/kernels.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | __all__ = ['GaussianKernel']
7 |
8 |
9 | class GaussianKernel(nn.Module):
10 | r"""Gaussian Kernel Matrix
11 |
12 | Gaussian Kernel k is defined by
13 |
14 | .. math::
15 | k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right)
16 |
17 | where :math:`x_1, x_2 \in R^d` are 1-d tensors.
18 |
19 | Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),`
20 |
21 | .. math::
22 | K(X)_{i,j} = k(x_i, x_j)
23 |
24 | Also by default, during training this layer keeps running estimates of the
25 | mean of L2 distances, which are then used to set hyperparameter :math:`\sigma`.
26 | Mathematically, the estimation is :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`.
27 | If :attr:`track_running_stats` is set to ``False``, this layer then does not
28 | keep running estimates, and uses a fixed :math:`\sigma` instead.
29 |
30 | Args:
31 | sigma (float, optional): bandwidth :math:`\sigma`. Default: None
32 | track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\sigma^2`.
33 | Otherwise, it won't track such statistics and always uses a fixed :math:`\sigma^2`. Default: ``True``
34 | alpha (float, optional): :math:`\alpha` which decides the magnitude of :math:`\sigma^2` when track_running_stats is set to ``True``
35 |
36 | Inputs:
37 | - X (tensor): input group :math:`X`
38 |
39 | Shape:
40 | - Inputs: :math:`(minibatch, F)` where F means the dimension of input features.
41 | - Outputs: :math:`(minibatch, minibatch)`
42 | """
43 |
44 | def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True,
45 | alpha: Optional[float] = 1.):
46 | super(GaussianKernel, self).__init__()
47 | assert track_running_stats or sigma is not None
48 | self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None
49 | self.track_running_stats = track_running_stats
50 | self.alpha = alpha
51 |
52 | def forward(self, X: torch.Tensor) -> torch.Tensor:
53 | l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2)
54 |
55 | if self.track_running_stats:
56 | self.sigma_square = self.alpha * torch.mean(l2_distance_square.detach())
57 |
58 | return torch.exp(-l2_distance_square / (2 * self.sigma_square))
59 |
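60 | 
61 | if __name__ == '__main__':
62 |     # Small sketch added for illustration only (sizes are arbitrary): the kernel
63 |     # matrix is (minibatch, minibatch), symmetric, with ones on the diagonal
64 |     # since k(x, x) = 1.
65 |     kernel = GaussianKernel(track_running_stats=True, alpha=1.)
66 |     X = torch.randn(6, 32)
67 |     K = kernel(X)
68 |     print(K.shape, torch.allclose(K, K.t()), torch.allclose(K.diagonal(), torch.ones(6)))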
--------------------------------------------------------------------------------
/examples/cdan.py:
--------------------------------------------------------------------------------
1 | # Credits: https://github.com/thuml/Transfer-Learning-Library
2 | import random
3 | import time
4 | import warnings
5 | import sys
6 | import argparse
7 | import shutil
8 | import os.path as osp
9 | import os
10 | import wandb
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | from torch.optim import SGD
16 | from torch.optim.lr_scheduler import LambdaLR
17 | from torch.utils.data import DataLoader
18 | import torch.nn.functional as F
19 |
20 | sys.path.append('../')
21 | from dalib.modules.domain_discriminator import DomainDiscriminator
22 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier
23 | from common.utils.data import ForeverDataIterator
24 | from common.utils.metric import accuracy
25 | from common.utils.meter import AverageMeter, ProgressMeter
26 | from common.utils.logger import CompleteLogger
27 | from common.utils.analysis import collect_feature, tsne, a_distance
28 |
29 | sys.path.append('.')
30 | import utils
31 |
32 |
33 | def main(args: argparse.Namespace):
34 | logger = CompleteLogger(args.log, args.phase)
35 | print(args)
36 |
37 | if args.log_results:
38 | wandb.init(project="DA", entity="SDAT", name=args.log_name)
39 | wandb.config.update(args)
40 | print(args)
41 |
42 | if args.seed is not None:
43 | random.seed(args.seed)
44 | torch.manual_seed(args.seed)
45 | cudnn.deterministic = True
46 | warnings.warn('You have chosen to seed training. '
47 | 'This will turn on the CUDNN deterministic setting, '
48 | 'which can slow down your training considerably! '
49 | 'You may see unexpected behavior when restarting '
50 | 'from checkpoints.')
51 |
52 | cudnn.benchmark = True
53 | device = args.device
54 | # Data loading code
55 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
56 | random_color_jitter=False, resize_size=args.resize_size,
57 | norm_mean=args.norm_mean, norm_std=args.norm_std)
58 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
59 | norm_mean=args.norm_mean, norm_std=args.norm_std)
60 | print("train_transform: ", train_transform)
61 | print("val_transform: ", val_transform)
62 |
63 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
64 | utils.get_dataset(args.data, args.root, args.source,
65 | args.target, train_transform, val_transform)
66 | train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
67 | shuffle=True, num_workers=args.workers, drop_last=True)
68 | train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
69 | shuffle=True, num_workers=args.workers, drop_last=True)
70 | val_loader = DataLoader(
71 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
72 | test_loader = DataLoader(
73 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
74 |
75 | train_source_iter = ForeverDataIterator(train_source_loader)
76 | train_target_iter = ForeverDataIterator(train_target_loader)
77 |
78 | # create model
79 | print("=> using model '{}'".format(args.arch))
80 | backbone = utils.get_model(args.arch, pretrain=not args.scratch)
81 | print(backbone)
82 | pool_layer = nn.Identity() if args.no_pool else None
83 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
84 | pool_layer=pool_layer, finetune=not args.scratch).to(device)
85 | classifier_feature_dim = classifier.features_dim
86 |
87 | if args.randomized:
88 | domain_discri = DomainDiscriminator(
89 | args.randomized_dim, hidden_size=1024).to(device)
90 | else:
91 | domain_discri = DomainDiscriminator(
92 | classifier_feature_dim * num_classes, hidden_size=1024).to(device)
93 |
94 | all_parameters = classifier.get_parameters() + domain_discri.get_parameters()
95 | # define optimizer and lr scheduler
96 | optimizer = SGD(all_parameters, args.lr, momentum=args.momentum,
97 | weight_decay=args.weight_decay, nesterov=True)
98 | t_total = args.iters_per_epoch * args.epochs
99 | print("{INFORMATION} The total number of steps is ", t_total)
100 |
101 | lr_scheduler = LambdaLR(optimizer, lambda x: args.lr *
102 | (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
103 |
104 | # define loss function
105 | domain_adv = ConditionalDomainAdversarialLoss(
106 | domain_discri, entropy_conditioning=args.entropy,
107 | num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,
108 | randomized_dim=args.randomized_dim
109 | ).to(device)
110 |
111 | # resume from the best checkpoint
112 | if args.phase != 'train':
113 | checkpoint = torch.load(
114 | logger.get_checkpoint_path('best'), map_location='cpu')
115 | classifier.load_state_dict(checkpoint)
116 |
117 | # analyze the model
118 | if args.phase == 'analysis':
119 | # extract features from both domains
120 | feature_extractor = nn.Sequential(
121 | classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
122 | source_feature = collect_feature(
123 | train_source_loader, feature_extractor, device)
124 | target_feature = collect_feature(
125 | train_target_loader, feature_extractor, device)
126 | # plot t-SNE
127 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
128 | tsne.visualize(source_feature, target_feature, tSNE_filename)
129 | print("Saving t-SNE to", tSNE_filename)
130 | # calculate A-distance, which is a measure for distribution discrepancy
131 | A_distance = a_distance.calculate(
132 | source_feature, target_feature, device)
133 | print("A-distance =", A_distance)
134 | return
135 |
136 | if args.phase == 'test':
137 | acc1 = utils.validate(test_loader, classifier, args, device)
138 | print(acc1)
139 | return
140 |
141 | # start training
142 | best_acc1 = 0.
143 | for epoch in range(args.epochs):
144 | print("lr_bbone:", lr_scheduler.get_last_lr()[0])
145 | print("lr_btlnck:", lr_scheduler.get_last_lr()[1])
146 | if args.log_results:
147 | wandb.log({"lr_bbone": lr_scheduler.get_last_lr()[0],
148 | "lr_btlnck": lr_scheduler.get_last_lr()[1]})
149 | # train for one epoch
150 | train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
151 | lr_scheduler, epoch, args)
152 |
153 | # evaluate on validation set
154 | acc1 = utils.validate(val_loader, classifier, args, device)
155 | if args.log_results:
156 | wandb.log({'epoch': epoch, 'val_acc': acc1})
157 |
158 | # remember best acc@1 and save checkpoint
159 | torch.save(classifier.state_dict(),
160 | logger.get_checkpoint_path('latest'))
161 | if acc1 > best_acc1:
162 | shutil.copy(logger.get_checkpoint_path('latest'),
163 | logger.get_checkpoint_path('best'))
164 | best_acc1 = max(acc1, best_acc1)
165 |
166 | print("best_acc1 = {:3.1f}".format(best_acc1))
167 |
168 | # evaluate on test set
169 | classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
170 | acc1 = utils.validate(test_loader, classifier, args, device)
171 | print("test_acc1 = {:3.1f}".format(acc1))
172 | if args.log_results:
173 | wandb.log({'epoch': epoch, 'test_acc': acc1})
174 |
175 | logger.close()
176 |
177 |
178 | def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
179 | domain_adv: ConditionalDomainAdversarialLoss, optimizer: SGD,
180 | lr_scheduler, epoch: int, args: argparse.Namespace):
181 |
182 | batch_time = AverageMeter('Time', ':3.1f')
183 | data_time = AverageMeter('Data', ':3.1f')
184 | losses = AverageMeter('Loss', ':3.2f')
185 | trans_losses = AverageMeter('Trans Loss', ':3.2f')
186 | cls_accs = AverageMeter('Cls Acc', ':3.1f')
187 | domain_accs = AverageMeter('Domain Acc', ':3.1f')
188 | progress = ProgressMeter(
189 | args.iters_per_epoch,
190 | [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],
191 | prefix="Epoch: [{}]".format(epoch))
192 |
193 | device = args.device
194 | # switch to train mode
195 | model.train()
196 | domain_adv.train()
197 |
198 | end = time.time()
199 | for i in range(args.iters_per_epoch):
200 | x_s, labels_s = next(train_source_iter)
201 | x_t, _ = next(train_target_iter)
202 |
203 | x_s = x_s.to(device)
204 | x_t = x_t.to(device)
205 | labels_s = labels_s.to(device)
206 |
207 | # measure data loading time
208 | data_time.update(time.time() - end)
209 |
210 | # compute output
211 | x = torch.cat((x_s, x_t), dim=0)
212 | y, f = model(x)
213 | y_s, y_t = y.chunk(2, dim=0)
214 | f_s, f_t = f.chunk(2, dim=0)
215 |
216 | cls_loss = F.cross_entropy(y_s, labels_s)
217 | transfer_loss = domain_adv(y_s, f_s, y_t, f_t)
218 | domain_acc = domain_adv.domain_discriminator_accuracy
219 | loss = cls_loss + transfer_loss * args.trade_off
220 |
221 | cls_acc = accuracy(y_s, labels_s)[0]
222 | if args.log_results:
223 | wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss,
224 | 'transfer_loss': transfer_loss, 'domain_acc': domain_acc})
225 |
226 | losses.update(loss.item(), x_s.size(0))
227 | cls_accs.update(cls_acc, x_s.size(0))
228 | domain_accs.update(domain_acc, x_s.size(0))
229 | trans_losses.update(transfer_loss.item(), x_s.size(0))
230 |
231 | # compute gradient and do SGD step
232 | optimizer.zero_grad()
233 | loss.backward()
234 | optimizer.step()
235 | lr_scheduler.step()
236 |
237 | # measure elapsed time
238 | batch_time.update(time.time() - end)
239 | end = time.time()
240 |
241 | if i % args.print_freq == 0:
242 | progress.display(i)
243 |
244 |
245 | if __name__ == '__main__':
246 | parser = argparse.ArgumentParser(
247 | description='CDAN for Unsupervised Domain Adaptation')
248 | # dataset parameters
249 | parser.add_argument('root', metavar='DIR',
250 | help='root path of dataset')
251 | parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
252 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
253 | ' (default: Office31)')
254 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
255 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
256 | parser.add_argument('--train-resizing', type=str, default='default')
257 | parser.add_argument('--val-resizing', type=str, default='default')
258 | parser.add_argument('--resize-size', type=int, default=224,
259 | help='the image size after resizing')
260 | parser.add_argument('--no-hflip', action='store_true',
261 | help='no random horizontal flipping during training')
262 | parser.add_argument('--norm-mean', type=float, nargs='+',
263 | default=(0.485, 0.456, 0.406), help='normalization mean')
264 | parser.add_argument('--norm-std', type=float, nargs='+',
265 | default=(0.229, 0.224, 0.225), help='normalization std')
266 | # model parameters
267 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
268 | choices=utils.get_model_names(),
269 | help='backbone architecture: ' +
270 | ' | '.join(utils.get_model_names()) +
271 | ' (default: resnet18)')
272 | parser.add_argument('--bottleneck-dim', default=256, type=int,
273 | help='Dimension of bottleneck')
274 | parser.add_argument('--no-pool', action='store_true',
275 | help='no pool layer after the feature extractor.')
276 | parser.add_argument('--scratch', action='store_true',
277 | help='whether to train from scratch.')
278 | parser.add_argument('-r', '--randomized', action='store_true',
279 | help='using randomized multi-linear-map (default: False)')
280 | parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,
281 | help='randomized dimension when using randomized multi-linear-map (default: 1024)')
282 | parser.add_argument('--entropy', default=False,
283 | action='store_true', help='use entropy conditioning')
284 | parser.add_argument('--trade-off', default=1., type=float,
285 | help='the trade-off hyper-parameter for transfer loss')
286 | # training parameters
287 | parser.add_argument('-b', '--batch-size', default=32, type=int,
288 | metavar='N',
289 | help='mini-batch size (default: 32)')
290 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
291 | metavar='LR', help='initial learning rate', dest='lr')
292 | parser.add_argument('--lr-gamma', default=0.001,
293 | type=float, help='parameter for lr scheduler')
294 | parser.add_argument('--lr-decay', default=0.75,
295 | type=float, help='parameter for lr scheduler')
296 | parser.add_argument('--momentum', default=0.9,
297 | type=float, metavar='M', help='momentum')
298 | parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
299 | metavar='W', help='weight decay (default: 1e-3)',
300 | dest='weight_decay')
301 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
302 | help='number of data loading workers (default: 2)')
303 | parser.add_argument('--epochs', default=20, type=int, metavar='N',
304 | help='number of total epochs to run')
305 | parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
306 | help='Number of iterations per epoch')
307 | parser.add_argument('-p', '--print-freq', default=100, type=int,
308 | metavar='N', help='print frequency (default: 100)')
309 | parser.add_argument('--seed', default=None, type=int,
310 | help='seed for initializing training. ')
311 | parser.add_argument('--per-class-eval', action='store_true',
312 | help='whether to output per-class accuracy during evaluation')
313 | parser.add_argument("--log", type=str, default='cdan',
314 | help="Where to save logs, checkpoints and debugging images.")
315 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
316 | help="When phase is 'test', only test the model."
317 | "When phase is 'analysis', only analysis the model.")
318 | parser.add_argument('--log_results', action='store_true',
319 | help="To log results in wandb")
320 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID")
321 | parser.add_argument('--log_name', type=str,
322 | default="log", help="log name for wandb")
323 | args = parser.parse_args()
324 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
325 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
326 | args.device = device
327 | main(args)
328 |
--------------------------------------------------------------------------------
/examples/cdan_mcc.py:
--------------------------------------------------------------------------------
1 | # Credits: https://github.com/thuml/Transfer-Learning-Library
2 | import random
3 | import time
4 | import warnings
5 | import sys
6 | import argparse
7 | import shutil
8 | import os.path as osp
9 | import os
10 | import wandb
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | from torch.optim import SGD
16 | from torch.optim.lr_scheduler import LambdaLR
17 | from torch.utils.data import DataLoader
18 | import torch.nn.functional as F
19 |
20 | sys.path.append('../')
21 | from dalib.modules.domain_discriminator import DomainDiscriminator
22 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier
23 | from dalib.adaptation.mcc import MinimumClassConfusionLoss
24 | from common.utils.data import ForeverDataIterator
25 | from common.utils.metric import accuracy
26 | from common.utils.meter import AverageMeter, ProgressMeter
27 | from common.utils.logger import CompleteLogger
28 | from common.utils.analysis import collect_feature, tsne, a_distance
29 |
30 | sys.path.append('.')
31 | import utils
32 |
33 |
34 | def main(args: argparse.Namespace):
35 | logger = CompleteLogger(args.log, args.phase)
36 | print(args)
37 |
38 | if args.log_results:
39 | wandb.init(project="DA", entity="SDAT", name=args.log_name)
40 | wandb.config.update(args)
41 | print(args)
42 |
43 | if args.seed is not None:
44 | random.seed(args.seed)
45 | torch.manual_seed(args.seed)
46 | cudnn.deterministic = True
47 | warnings.warn('You have chosen to seed training. '
48 | 'This will turn on the CUDNN deterministic setting, '
49 | 'which can slow down your training considerably! '
50 | 'You may see unexpected behavior when restarting '
51 | 'from checkpoints.')
52 |
53 | cudnn.benchmark = True
54 | device = args.device
55 |
56 | # Data loading code
57 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
58 | random_color_jitter=False, resize_size=args.resize_size,
59 | norm_mean=args.norm_mean, norm_std=args.norm_std)
60 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
61 | norm_mean=args.norm_mean, norm_std=args.norm_std)
62 | print("train_transform: ", train_transform)
63 | print("val_transform: ", val_transform)
64 |
65 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
66 | utils.get_dataset(args.data, args.root, args.source,
67 | args.target, train_transform, val_transform)
68 | train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
69 | shuffle=True, num_workers=args.workers, drop_last=True)
70 | train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
71 | shuffle=True, num_workers=args.workers, drop_last=True)
72 | val_loader = DataLoader(
73 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
74 | test_loader = DataLoader(
75 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
76 |
77 | train_source_iter = ForeverDataIterator(train_source_loader)
78 | train_target_iter = ForeverDataIterator(train_target_loader)
79 |
80 | # create model
81 | print("=> using model '{}'".format(args.arch))
82 | backbone = utils.get_model(args.arch, pretrain=not args.scratch)
83 | print(backbone)
84 | pool_layer = nn.Identity() if args.no_pool else None
85 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
86 | pool_layer=pool_layer, finetune=not args.scratch).to(device)
87 | classifier_feature_dim = classifier.features_dim
88 |
89 | if args.randomized:
90 | domain_discri = DomainDiscriminator(
91 | args.randomized_dim, hidden_size=1024).to(device)
92 | else:
93 | domain_discri = DomainDiscriminator(
94 | classifier_feature_dim * num_classes, hidden_size=1024).to(device)
95 |
96 | all_parameters = classifier.get_parameters() + domain_discri.get_parameters()
97 |
98 | # define optimizer and lr scheduler
99 | optimizer = SGD(all_parameters, args.lr, momentum=args.momentum,
100 | weight_decay=args.weight_decay, nesterov=True)
101 | t_total = args.iters_per_epoch * args.epochs
102 | print("{INFORMATION} The total number of steps is ", t_total)
103 |
104 | lr_scheduler = LambdaLR(optimizer, lambda x: args.lr *
105 | (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
106 |
107 | # define loss function
108 | domain_adv = ConditionalDomainAdversarialLoss(
109 | domain_discri, entropy_conditioning=args.entropy,
110 | num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,
111 | randomized_dim=args.randomized_dim
112 | ).to(device)
113 |
114 | mcc_loss = MinimumClassConfusionLoss(temperature=args.temperature)
115 |
116 | # resume from the best checkpoint
117 | if args.phase != 'train':
118 | checkpoint = torch.load(
119 | logger.get_checkpoint_path('best'), map_location='cpu')
120 | classifier.load_state_dict(checkpoint)
121 |
122 | # analyze the model
123 | if args.phase == 'analysis':
124 | # extract features from both domains
125 | feature_extractor = nn.Sequential(
126 | classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
127 | source_feature = collect_feature(
128 | train_source_loader, feature_extractor, device)
129 | target_feature = collect_feature(
130 | train_target_loader, feature_extractor, device)
131 | # plot t-SNE
132 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
133 | tsne.visualize(source_feature, target_feature, tSNE_filename)
134 | print("Saving t-SNE to", tSNE_filename)
135 | # calculate A-distance, which is a measure for distribution discrepancy
136 | A_distance = a_distance.calculate(
137 | source_feature, target_feature, device)
138 | print("A-distance =", A_distance)
139 | return
140 |
141 | if args.phase == 'test':
142 | acc1 = utils.validate(test_loader, classifier, args, device)
143 | print(acc1)
144 | return
145 |
146 | # start training
147 | best_acc1 = 0.
148 | for epoch in range(args.epochs):
149 | print("lr_bbone:", lr_scheduler.get_last_lr()[0])
150 | print("lr_btlnck:", lr_scheduler.get_last_lr()[1])
151 | if args.log_results:
152 | wandb.log({"lr_bbone": lr_scheduler.get_last_lr()[0],
153 | "lr_btlnck": lr_scheduler.get_last_lr()[1]})
154 | # train for one epoch
155 | train(train_source_iter, train_target_iter, classifier, domain_adv, mcc_loss, optimizer,
156 | lr_scheduler, epoch, args)
157 |
158 | # evaluate on validation set
159 | acc1 = utils.validate(val_loader, classifier, args, device)
160 | if args.log_results:
161 | wandb.log({'epoch': epoch, 'val_acc': acc1})
162 |
163 | # remember best acc@1 and save checkpoint
164 | torch.save(classifier.state_dict(),
165 | logger.get_checkpoint_path('latest'))
166 | if acc1 > best_acc1:
167 | shutil.copy(logger.get_checkpoint_path('latest'),
168 | logger.get_checkpoint_path('best'))
169 | best_acc1 = max(acc1, best_acc1)
170 |
171 | print("best_acc1 = {:3.1f}".format(best_acc1))
172 |
173 | # evaluate on test set
174 | classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
175 | acc1 = utils.validate(test_loader, classifier, args, device)
176 | print("test_acc1 = {:3.1f}".format(acc1))
177 | if args.log_results:
178 | wandb.log({'epoch': epoch, 'test_acc': acc1})
179 |
180 | logger.close()
181 |
182 |
183 | def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
184 | domain_adv: ConditionalDomainAdversarialLoss, mcc, optimizer: SGD,
185 | lr_scheduler, epoch: int, args: argparse.Namespace):
186 | batch_time = AverageMeter('Time', ':3.1f')
187 | data_time = AverageMeter('Data', ':3.1f')
188 | losses = AverageMeter('Loss', ':3.2f')
189 | trans_losses = AverageMeter('Trans Loss', ':3.2f')
190 | cls_accs = AverageMeter('Cls Acc', ':3.1f')
191 | domain_accs = AverageMeter('Domain Acc', ':3.1f')
192 | progress = ProgressMeter(
193 | args.iters_per_epoch,
194 | [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],
195 | prefix="Epoch: [{}]".format(epoch))
196 | device = args.device
197 | # switch to train mode
198 | model.train()
199 | domain_adv.train()
200 |
201 | end = time.time()
202 | for i in range(args.iters_per_epoch):
203 | x_s, labels_s = next(train_source_iter)
204 | x_t, _ = next(train_target_iter)
205 |
206 | x_s = x_s.to(device)
207 | x_t = x_t.to(device)
208 | labels_s = labels_s.to(device)
209 |
210 | # measure data loading time
211 | data_time.update(time.time() - end)
212 |
213 | # compute output
214 | x = torch.cat((x_s, x_t), dim=0)
215 | y, f = model(x)
216 | y_s, y_t = y.chunk(2, dim=0)
217 | f_s, f_t = f.chunk(2, dim=0)
218 |
219 | cls_loss = F.cross_entropy(y_s, labels_s)
220 | mcc_loss_value = mcc(y_t)
221 | transfer_loss = domain_adv(y_s, f_s, y_t, f_t) + mcc_loss_value
222 | domain_acc = domain_adv.domain_discriminator_accuracy
223 | loss = cls_loss + transfer_loss * args.trade_off
224 | cls_acc = accuracy(y_s, labels_s)[0]
225 | if args.log_results:
226 | wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss,
227 | 'transfer_loss': transfer_loss, 'domain_acc': domain_acc,
228 | 'mcc_loss': mcc_loss_value})
229 |
230 | losses.update(loss.item(), x_s.size(0))
231 | cls_accs.update(cls_acc, x_s.size(0))
232 | domain_accs.update(domain_acc, x_s.size(0))
233 | trans_losses.update(transfer_loss.item(), x_s.size(0))
234 |
235 | # compute gradient and do SGD step
236 | optimizer.zero_grad()
237 | loss.backward()
238 | optimizer.step()
239 | lr_scheduler.step()
240 |
241 | # measure elapsed time
242 | batch_time.update(time.time() - end)
243 | end = time.time()
244 |
245 | if i % args.print_freq == 0:
246 | progress.display(i)
247 |
248 |
249 | if __name__ == '__main__':
250 | parser = argparse.ArgumentParser(
251 | description='CDAN+MCC for Unsupervised Domain Adaptation')
252 | # dataset parameters
253 | parser.add_argument('root', metavar='DIR',
254 | help='root path of dataset')
255 | parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
256 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
257 | ' (default: Office31)')
258 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
259 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
260 | parser.add_argument('--train-resizing', type=str, default='default')
261 | parser.add_argument('--val-resizing', type=str, default='default')
262 | parser.add_argument('--resize-size', type=int, default=224,
263 | help='the image size after resizing')
264 | parser.add_argument('--no-hflip', action='store_true',
265 | help='no random horizontal flipping during training')
266 | parser.add_argument('--norm-mean', type=float, nargs='+',
267 | default=(0.485, 0.456, 0.406), help='normalization mean')
268 | parser.add_argument('--norm-std', type=float, nargs='+',
269 | default=(0.229, 0.224, 0.225), help='normalization std')
270 | # model parameters
271 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
272 | choices=utils.get_model_names(),
273 | help='backbone architecture: ' +
274 | ' | '.join(utils.get_model_names()) +
275 | ' (default: resnet18)')
276 | parser.add_argument('--bottleneck-dim', default=256, type=int,
277 | help='Dimension of bottleneck')
278 | parser.add_argument('--no-pool', action='store_true',
279 | help='no pool layer after the feature extractor.')
280 | parser.add_argument('--scratch', action='store_true',
281 | help='whether to train from scratch.')
282 | parser.add_argument('-r', '--randomized', action='store_true',
283 | help='using randomized multi-linear-map (default: False)')
284 | parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,
285 | help='randomized dimension when using randomized multi-linear-map (default: 1024)')
286 | parser.add_argument('--entropy', default=False,
287 | action='store_true', help='use entropy conditioning')
288 | parser.add_argument('--trade-off', default=1., type=float,
289 | help='the trade-off hyper-parameter for transfer loss')
290 | # training parameters
291 | parser.add_argument('-b', '--batch-size', default=32, type=int,
292 | metavar='N',
293 | help='mini-batch size (default: 32)')
294 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
295 | metavar='LR', help='initial learning rate', dest='lr')
296 | parser.add_argument('--lr-gamma', default=0.001,
297 | type=float, help='parameter for lr scheduler')
298 | parser.add_argument('--lr-decay', default=0.75,
299 | type=float, help='parameter for lr scheduler')
300 | parser.add_argument('--momentum', default=0.9,
301 | type=float, metavar='M', help='momentum')
302 | parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
303 | metavar='W', help='weight decay (default: 1e-3)',
304 | dest='weight_decay')
305 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
306 | help='number of data loading workers (default: 2)')
307 | parser.add_argument('--epochs', default=20, type=int, metavar='N',
308 | help='number of total epochs to run')
309 | parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
310 | help='Number of iterations per epoch')
311 | parser.add_argument('-p', '--print-freq', default=100, type=int,
312 | metavar='N', help='print frequency (default: 100)')
313 | parser.add_argument('--seed', default=None, type=int,
314 | help='seed for initializing training. ')
315 | parser.add_argument('--per-class-eval', action='store_true',
316 | help='whether to output per-class accuracy during evaluation')
317 | parser.add_argument("--log", type=str, default='cdan',
318 | help="Where to save logs, checkpoints and debugging images.")
319 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
320 | help="When phase is 'test', only test the model."
321 | "When phase is 'analysis', only analysis the model.")
322 | parser.add_argument('--log_results', action='store_true',
323 | help="To log results in wandb")
324 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID")
325 | parser.add_argument('--log_name', type=str,
326 | default="log", help="log name for wandb")
327 | parser.add_argument('--temperature', default=2.0,
328 | type=float, help='parameter temperature scaling')
329 | args = parser.parse_args()
330 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
331 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
332 | args.device = device
333 | main(args)
334 |
--------------------------------------------------------------------------------
/examples/cdan_mcc_sdat.py:
--------------------------------------------------------------------------------
1 | # Credits: https://github.com/thuml/Transfer-Learning-Library
2 | import random
3 | import time
4 | import warnings
5 | import sys
6 | import argparse
7 | import shutil
8 | import os.path as osp
9 | import os
10 | import wandb
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | from torch.optim import SGD
16 | from torch.optim.lr_scheduler import LambdaLR
17 | from torch.utils.data import DataLoader
18 | import torch.nn.functional as F
19 |
20 | sys.path.append('../')
21 | from dalib.modules.domain_discriminator import DomainDiscriminator
22 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier
23 | from dalib.adaptation.mcc import MinimumClassConfusionLoss
24 | from common.utils.data import ForeverDataIterator
25 | from common.utils.metric import accuracy
26 | from common.utils.meter import AverageMeter, ProgressMeter
27 | from common.utils.logger import CompleteLogger
28 | from common.utils.analysis import collect_feature, tsne, a_distance
29 | from common.utils.sam import SAM
30 |
31 | sys.path.append('.')
32 | import utils
33 |
34 |
35 | def main(args: argparse.Namespace):
36 | logger = CompleteLogger(args.log, args.phase)
37 | print(args)
38 |
39 | if args.log_results:
40 | wandb.init(project="DA", entity="SDAT", name=args.log_name)
41 | wandb.config.update(args)
42 | print(args)
43 |
44 | if args.seed is not None:
45 | random.seed(args.seed)
46 | torch.manual_seed(args.seed)
47 | cudnn.deterministic = True
48 | warnings.warn('You have chosen to seed training. '
49 | 'This will turn on the CUDNN deterministic setting, '
50 | 'which can slow down your training considerably! '
51 | 'You may see unexpected behavior when restarting '
52 | 'from checkpoints.')
53 |
54 | cudnn.benchmark = True
55 | device = args.device
56 |
57 | # Data loading code
58 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
59 | random_color_jitter=False, resize_size=args.resize_size,
60 | norm_mean=args.norm_mean, norm_std=args.norm_std)
61 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
62 | norm_mean=args.norm_mean, norm_std=args.norm_std)
63 | print("train_transform: ", train_transform)
64 | print("val_transform: ", val_transform)
65 |
66 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
67 | utils.get_dataset(args.data, args.root, args.source,
68 | args.target, train_transform, val_transform)
69 | train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
70 | shuffle=True, num_workers=args.workers, drop_last=True)
71 | train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
72 | shuffle=True, num_workers=args.workers, drop_last=True)
73 | val_loader = DataLoader(
74 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
75 | test_loader = DataLoader(
76 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
77 |
78 | train_source_iter = ForeverDataIterator(train_source_loader)
79 | train_target_iter = ForeverDataIterator(train_target_loader)
80 |
81 | # create model
82 | print("=> using model '{}'".format(args.arch))
83 | backbone = utils.get_model(args.arch, pretrain=not args.scratch)
84 | print(backbone)
85 | pool_layer = nn.Identity() if args.no_pool else None
86 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
87 | pool_layer=pool_layer, finetune=not args.scratch).to(device)
88 | classifier_feature_dim = classifier.features_dim
89 |
90 | if args.randomized:
91 | domain_discri = DomainDiscriminator(
92 | args.randomized_dim, hidden_size=1024).to(device)
93 | else:
94 | domain_discri = DomainDiscriminator(
95 | classifier_feature_dim * num_classes, hidden_size=1024).to(device)
96 |
97 | # define optimizer and lr scheduler
98 | base_optimizer = torch.optim.SGD
99 | ad_optimizer = SGD(domain_discri.get_parameters(
100 | ), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
101 | optimizer = SAM(classifier.get_parameters(), base_optimizer, rho=args.rho, adaptive=False,
102 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
103 | lr_scheduler = LambdaLR(optimizer, lambda x: args.lr *
104 | (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
105 | lr_scheduler_ad = LambdaLR(
106 | ad_optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
107 |
108 | # define loss function
109 | domain_adv = ConditionalDomainAdversarialLoss(
110 | domain_discri, entropy_conditioning=args.entropy,
111 | num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,
112 | randomized_dim=args.randomized_dim
113 | ).to(device)
114 |
115 | mcc_loss = MinimumClassConfusionLoss(temperature=args.temperature)
116 |
117 | # resume from the best checkpoint
118 | if args.phase != 'train':
119 | checkpoint = torch.load(
120 | logger.get_checkpoint_path('best'), map_location='cpu')
121 | classifier.load_state_dict(checkpoint)
122 |
123 | # analyze the model
124 | if args.phase == 'analysis':
125 | # extract features from both domains
126 | feature_extractor = nn.Sequential(
127 | classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
128 | source_feature = collect_feature(
129 | train_source_loader, feature_extractor, device)
130 | target_feature = collect_feature(
131 | train_target_loader, feature_extractor, device)
132 | # plot t-SNE
133 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
134 | tsne.visualize(source_feature, target_feature, tSNE_filename)
135 | print("Saving t-SNE to", tSNE_filename)
136 | # calculate A-distance, which is a measure for distribution discrepancy
137 | A_distance = a_distance.calculate(
138 | source_feature, target_feature, device)
139 | print("A-distance =", A_distance)
140 | return
141 |
142 | if args.phase == 'test':
143 | acc1 = utils.validate(test_loader, classifier, args, device)
144 | print(acc1)
145 | return
146 |
147 | # start training
148 | best_acc1 = 0.
149 | for epoch in range(args.epochs):
150 | print("lr_bbone:", lr_scheduler.get_last_lr()[0])
151 | print("lr_btlnck:", lr_scheduler.get_last_lr()[1])
152 | if args.log_results:
153 | wandb.log({"lr_bbone": lr_scheduler.get_last_lr()[0],
154 | "lr_btlnck": lr_scheduler.get_last_lr()[1]})
155 | 
156 | # train for one epoch
157 | train(train_source_iter, train_target_iter, classifier, domain_adv, mcc_loss, optimizer, ad_optimizer,
158 | lr_scheduler, lr_scheduler_ad, epoch, args)
159 | # evaluate on validation set
160 | acc1 = utils.validate(val_loader, classifier, args, device)
161 | if args.log_results:
162 | wandb.log({'epoch': epoch, 'val_acc': acc1})
163 |
164 | # remember best acc@1 and save checkpoint
165 | torch.save(classifier.state_dict(),
166 | logger.get_checkpoint_path('latest'))
167 | if acc1 > best_acc1:
168 | shutil.copy(logger.get_checkpoint_path('latest'),
169 | logger.get_checkpoint_path('best'))
170 | best_acc1 = max(acc1, best_acc1)
171 |
172 | print("best_acc1 = {:3.1f}".format(best_acc1))
173 |
174 | # evaluate on test set
175 | classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
176 | acc1 = utils.validate(test_loader, classifier, args, device)
177 | print("test_acc1 = {:3.1f}".format(acc1))
178 | if args.log_results:
179 | wandb.log({'epoch': epoch, 'test_acc': acc1})
180 |
181 | logger.close()
182 |
183 |
184 | def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
185 | domain_adv: ConditionalDomainAdversarialLoss, mcc, optimizer, ad_optimizer,
186 | lr_scheduler: LambdaLR, lr_scheduler_ad, epoch: int, args: argparse.Namespace):
187 | batch_time = AverageMeter('Time', ':3.1f')
188 | data_time = AverageMeter('Data', ':3.1f')
189 | losses = AverageMeter('Loss', ':3.2f')
190 | trans_losses = AverageMeter('Trans Loss', ':3.2f')
191 | cls_accs = AverageMeter('Cls Acc', ':3.1f')
192 | domain_accs = AverageMeter('Domain Acc', ':3.1f')
193 | progress = ProgressMeter(
194 | args.iters_per_epoch,
195 | [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],
196 | prefix="Epoch: [{}]".format(epoch))
197 | device = args.device
198 | # switch to train mode
199 | model.train()
200 | domain_adv.train()
201 |
202 | end = time.time()
203 | for i in range(args.iters_per_epoch):
204 | x_s, labels_s = next(train_source_iter)
205 | x_t, _ = next(train_target_iter)
206 |
207 | x_s = x_s.to(device)
208 | x_t = x_t.to(device)
209 | labels_s = labels_s.to(device)
210 |
211 | # measure data loading time
212 | data_time.update(time.time() - end)
213 | optimizer.zero_grad()
214 | ad_optimizer.zero_grad()
215 |
216 | # compute output
217 | x = torch.cat((x_s, x_t), dim=0)
218 | y, f = model(x)
219 | y_s, y_t = y.chunk(2, dim=0)
220 | f_s, f_t = f.chunk(2, dim=0)
221 | cls_loss = F.cross_entropy(y_s, labels_s)
222 | mcc_loss_value = mcc(y_t)
223 | loss = cls_loss + mcc_loss_value
224 |
225 | loss.backward()
226 |
227 | # Calculate ϵ̂ (w) and add it to the weights
228 | optimizer.first_step(zero_grad=True)
229 |
230 | # Calculate task loss and domain loss
231 | y, f = model(x)
232 | y_s, y_t = y.chunk(2, dim=0)
233 | f_s, f_t = f.chunk(2, dim=0)
234 |
235 | cls_loss = F.cross_entropy(y_s, labels_s)
236 | transfer_loss = domain_adv(y_s, f_s, y_t, f_t) + mcc(y_t)
237 | domain_acc = domain_adv.domain_discriminator_accuracy
238 | loss = cls_loss + transfer_loss * args.trade_off
239 |
240 | cls_acc = accuracy(y_s, labels_s)[0]
241 | if args.log_results:
242 | wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss,
243 | 'transfer_loss': transfer_loss, 'domain_acc': domain_acc})
244 |
245 | losses.update(loss.item(), x_s.size(0))
246 | cls_accs.update(cls_acc, x_s.size(0))
247 | domain_accs.update(domain_acc, x_s.size(0))
248 | trans_losses.update(transfer_loss.item(), x_s.size(0))
249 |
250 | loss.backward()
251 | # Update parameters of domain classifier
252 | ad_optimizer.step()
253 | # Update parameters (Sharpness-Aware update)
254 | optimizer.second_step(zero_grad=True)
255 | lr_scheduler.step()
256 | lr_scheduler_ad.step()
257 |
258 | # measure elapsed time
259 | batch_time.update(time.time() - end)
260 | end = time.time()
261 |
262 | if i % args.print_freq == 0:
263 | progress.display(i)
264 |
265 |
266 | if __name__ == '__main__':
267 | parser = argparse.ArgumentParser(
268 | description='CDAN+MCC with SDAT for Unsupervised Domain Adaptation')
269 | # dataset parameters
270 | parser.add_argument('root', metavar='DIR',
271 | help='root path of dataset')
272 | parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
273 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
274 | ' (default: Office31)')
275 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
276 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
277 | parser.add_argument('--train-resizing', type=str, default='default')
278 | parser.add_argument('--val-resizing', type=str, default='default')
279 | parser.add_argument('--resize-size', type=int, default=224,
280 | help='the image size after resizing')
281 | parser.add_argument('--no-hflip', action='store_true',
282 | help='no random horizontal flipping during training')
283 | parser.add_argument('--norm-mean', type=float, nargs='+',
284 | default=(0.485, 0.456, 0.406), help='normalization mean')
285 | parser.add_argument('--norm-std', type=float, nargs='+',
286 | default=(0.229, 0.224, 0.225), help='normalization std')
287 | # model parameters
288 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
289 | choices=utils.get_model_names(),
290 | help='backbone architecture: ' +
291 | ' | '.join(utils.get_model_names()) +
292 | ' (default: resnet18)')
293 | parser.add_argument('--bottleneck-dim', default=256, type=int,
294 | help='Dimension of bottleneck')
295 | parser.add_argument('--no-pool', action='store_true',
296 | help='no pool layer after the feature extractor.')
297 | parser.add_argument('--scratch', action='store_true',
298 | help='whether train from scratch.')
299 | parser.add_argument('-r', '--randomized', action='store_true',
300 | help='using randomized multi-linear-map (default: False)')
301 | parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,
302 | help='randomized dimension when using randomized multi-linear-map (default: 1024)')
303 | parser.add_argument('--entropy', default=False,
304 | action='store_true', help='use entropy conditioning')
305 | parser.add_argument('--trade-off', default=1., type=float,
306 | help='the trade-off hyper-parameter for transfer loss')
307 | # training parameters
308 | parser.add_argument('-b', '--batch-size', default=32, type=int,
309 | metavar='N',
310 | help='mini-batch size (default: 32)')
311 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
312 | metavar='LR', help='initial learning rate', dest='lr')
313 | parser.add_argument('--lr-gamma', default=0.001,
314 | type=float, help='parameter for lr scheduler')
315 | parser.add_argument('--lr-decay', default=0.75,
316 | type=float, help='parameter for lr scheduler')
317 | parser.add_argument('--momentum', default=0.9,
318 | type=float, metavar='M', help='momentum')
319 | parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
320 | metavar='W', help='weight decay (default: 1e-3)',
321 | dest='weight_decay')
322 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
323 | help='number of data loading workers (default: 2)')
324 | parser.add_argument('--epochs', default=20, type=int, metavar='N',
325 | help='number of total epochs to run')
326 | parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
327 | help='Number of iterations per epoch')
328 | parser.add_argument('-p', '--print-freq', default=100, type=int,
329 | metavar='N', help='print frequency (default: 100)')
330 | parser.add_argument('--seed', default=None, type=int,
331 | help='seed for initializing training. ')
332 | parser.add_argument('--per-class-eval', action='store_true',
333 | help='whether output per-class accuracy during evaluation')
334 | parser.add_argument("--log", type=str, default='cdan',
335 | help="Where to save logs, checkpoints and debugging images.")
336 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
337 | help="When phase is 'test', only test the model."
338 | "When phase is 'analysis', only analysis the model.")
339 | parser.add_argument('--log_results', action='store_true',
340 | help="To log results in wandb")
341 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID")
342 | parser.add_argument('--log_name', type=str,
343 | default="log", help="log name for wandb")
344 | parser.add_argument('--rho', type=float, default=0.05, help="Rho for SAM")
345 | parser.add_argument('--temperature', default=2.0,
346 | type=float, help='temperature scaling parameter')
347 | args = parser.parse_args()
348 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
349 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
350 | args.device = device
351 | main(args)
352 |
--------------------------------------------------------------------------------
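A minimal sketch of the two-pass sharpness-aware update performed in the training loop above, shown on a toy problem. It assumes the SAM wrapper in common/utils/sam.py exposes the first_step()/second_step() interface exactly as called in these scripts and that the snippet is run from the examples/ directory; the linear model and random data are purely illustrative.

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append('../')
from common.utils.sam import SAM

model = nn.Linear(10, 1)
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, adaptive=False,
                lr=0.01, momentum=0.9)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(5):
    # first forward/backward: gradient at the current weights w
    F.mse_loss(model(x), y).backward()
    optimizer.first_step(zero_grad=True)    # move to the perturbed point w + eps_hat(w)

    # second forward/backward: gradient at the perturbed weights
    F.mse_loss(model(x), y).backward()
    optimizer.second_step(zero_grad=True)   # restore w and apply the base SGD step

Note that in cdan_mcc_sdat.py the first backward pass uses only the classification and MCC losses, while the second pass adds the adversarial transfer loss, so the weight perturbation is driven by the task objective alone.
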
/examples/cdan_sdat.py:
--------------------------------------------------------------------------------
1 | # Credits: https://github.com/thuml/Transfer-Learning-Library
2 | import random
3 | import time
4 | import warnings
5 | import sys
6 | import argparse
7 | import shutil
8 | import os.path as osp
9 | import os
10 | import wandb
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | from torch.optim import SGD
16 | from torch.optim.lr_scheduler import LambdaLR
17 | from torch.utils.data import DataLoader
18 | import torch.nn.functional as F
19 |
20 | sys.path.append('../')
21 | from dalib.modules.domain_discriminator import DomainDiscriminator
22 | from dalib.adaptation.cdan import ConditionalDomainAdversarialLoss, ImageClassifier
23 | from common.utils.data import ForeverDataIterator
24 | from common.utils.metric import accuracy
25 | from common.utils.meter import AverageMeter, ProgressMeter
26 | from common.utils.logger import CompleteLogger
27 | from common.utils.analysis import collect_feature, tsne, a_distance
28 | from common.utils.sam import SAM
29 |
30 | sys.path.append('.')
31 | import utils
32 |
33 |
34 | def main(args: argparse.Namespace):
35 | logger = CompleteLogger(args.log, args.phase)
36 | print(args)
37 |
38 | if args.log_results:
39 | wandb.init(project="DA", entity="SDAT", name=args.log_name)
40 | wandb.config.update(args)
41 | print(args)
42 |
43 | if args.seed is not None:
44 | random.seed(args.seed)
45 | torch.manual_seed(args.seed)
46 | cudnn.deterministic = True
47 | warnings.warn('You have chosen to seed training. '
48 | 'This will turn on the CUDNN deterministic setting, '
49 | 'which can slow down your training considerably! '
50 | 'You may see unexpected behavior when restarting '
51 | 'from checkpoints.')
52 |
53 | cudnn.benchmark = True
54 | device = args.device
55 |
56 | # Data loading code
57 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
58 | random_color_jitter=False, resize_size=args.resize_size,
59 | norm_mean=args.norm_mean, norm_std=args.norm_std)
60 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
61 | norm_mean=args.norm_mean, norm_std=args.norm_std)
62 | print("train_transform: ", train_transform)
63 | print("val_transform: ", val_transform)
64 |
65 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
66 | utils.get_dataset(args.data, args.root, args.source,
67 | args.target, train_transform, val_transform)
68 | train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
69 | shuffle=True, num_workers=args.workers, drop_last=True)
70 | train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
71 | shuffle=True, num_workers=args.workers, drop_last=True)
72 | val_loader = DataLoader(
73 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
74 | test_loader = DataLoader(
75 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
76 |
77 | train_source_iter = ForeverDataIterator(train_source_loader)
78 | train_target_iter = ForeverDataIterator(train_target_loader)
79 |
80 | # create model
81 | print("=> using model '{}'".format(args.arch))
82 | backbone = utils.get_model(args.arch, pretrain=not args.scratch)
83 | print(backbone)
84 | pool_layer = nn.Identity() if args.no_pool else None
85 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
86 | pool_layer=pool_layer, finetune=not args.scratch).to(device)
87 | classifier_feature_dim = classifier.features_dim
88 |
89 | if args.randomized:
90 | domain_discri = DomainDiscriminator(
91 | args.randomized_dim, hidden_size=1024).to(device)
92 | else:
93 | domain_discri = DomainDiscriminator(
94 | classifier_feature_dim * num_classes, hidden_size=1024).to(device)
95 |
96 | # define optimizer and lr scheduler
97 | base_optimizer = torch.optim.SGD
98 | ad_optimizer = SGD(domain_discri.get_parameters(
99 | ), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
100 | optimizer = SAM(classifier.get_parameters(), base_optimizer, rho=args.rho, adaptive=False,
101 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
102 | lr_scheduler = LambdaLR(optimizer, lambda x: args.lr *
103 | (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
104 | lr_scheduler_ad = LambdaLR(
105 | ad_optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
106 |
107 | # define loss function
108 | domain_adv = ConditionalDomainAdversarialLoss(
109 | domain_discri, entropy_conditioning=args.entropy,
110 | num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,
111 | randomized_dim=args.randomized_dim
112 | ).to(device)
113 |
114 | # resume from the best checkpoint
115 | if args.phase != 'train':
116 | checkpoint = torch.load(
117 | logger.get_checkpoint_path('best'), map_location='cpu')
118 | classifier.load_state_dict(checkpoint)
119 |
120 | # analyze the model
121 | if args.phase == 'analysis':
122 | # extract features from both domains
123 | feature_extractor = nn.Sequential(
124 | classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
125 | source_feature = collect_feature(
126 | train_source_loader, feature_extractor, device)
127 | target_feature = collect_feature(
128 | train_target_loader, feature_extractor, device)
129 | # plot t-SNE
130 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
131 | tsne.visualize(source_feature, target_feature, tSNE_filename)
132 | print("Saving t-SNE to", tSNE_filename)
133 | # calculate A-distance, which is a measure for distribution discrepancy
134 | A_distance = a_distance.calculate(
135 | source_feature, target_feature, device)
136 | print("A-distance =", A_distance)
137 | return
138 |
139 | if args.phase == 'test':
140 | acc1 = utils.validate(test_loader, classifier, args, device)
141 | print(acc1)
142 | return
143 |
144 | # start training
145 | best_acc1 = 0.
146 | for epoch in range(args.epochs):
147 | print("lr_bbone:", lr_scheduler.get_last_lr()[0])
148 | print("lr_btlnck:", lr_scheduler.get_last_lr()[1])
149 | if args.log_results:
150 | wandb.log({"lr_bbone": lr_scheduler.get_last_lr()[0],
151 | "lr_btlnck": lr_scheduler.get_last_lr()[1]})
152 |
153 | # train for one epoch
154 | train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, ad_optimizer,
155 | lr_scheduler, lr_scheduler_ad, epoch, args)
156 | # evaluate on validation set
157 | acc1 = utils.validate(val_loader, classifier, args, device)
158 | if args.log_results:
159 | wandb.log({'epoch': epoch, 'val_acc': acc1})
160 |
161 | # remember best acc@1 and save checkpoint
162 | torch.save(classifier.state_dict(),
163 | logger.get_checkpoint_path('latest'))
164 | if acc1 > best_acc1:
165 | shutil.copy(logger.get_checkpoint_path('latest'),
166 | logger.get_checkpoint_path('best'))
167 | best_acc1 = max(acc1, best_acc1)
168 |
169 | print("best_acc1 = {:3.1f}".format(best_acc1))
170 |
171 | # evaluate on test set
172 | classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
173 | acc1 = utils.validate(test_loader, classifier, args, device)
174 | print("test_acc1 = {:3.1f}".format(acc1))
175 | if args.log_results:
176 | wandb.log({'epoch': epoch, 'test_acc': acc1})
177 |
178 | logger.close()
179 |
180 |
181 | def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
182 | domain_adv: ConditionalDomainAdversarialLoss, optimizer, ad_optimizer,
183 | lr_scheduler: LambdaLR, lr_scheduler_ad, epoch: int, args: argparse.Namespace):
184 | batch_time = AverageMeter('Time', ':3.1f')
185 | data_time = AverageMeter('Data', ':3.1f')
186 | losses = AverageMeter('Loss', ':3.2f')
187 | trans_losses = AverageMeter('Trans Loss', ':3.2f')
188 | cls_accs = AverageMeter('Cls Acc', ':3.1f')
189 | domain_accs = AverageMeter('Domain Acc', ':3.1f')
190 | progress = ProgressMeter(
191 | args.iters_per_epoch,
192 | [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],
193 | prefix="Epoch: [{}]".format(epoch))
194 |
195 | device = args.device
196 | # switch to train mode
197 | model.train()
198 | domain_adv.train()
199 |
200 | end = time.time()
201 | for i in range(args.iters_per_epoch):
202 | x_s, labels_s = next(train_source_iter)
203 | x_t, _ = next(train_target_iter)
204 |
205 | x_s = x_s.to(device)
206 | x_t = x_t.to(device)
207 | labels_s = labels_s.to(device)
208 |
209 | # measure data loading time
210 | data_time.update(time.time() - end)
211 | optimizer.zero_grad()
212 | ad_optimizer.zero_grad()
213 |
214 | # compute task loss for first step
215 | x = torch.cat((x_s, x_t), dim=0)
216 | y, f = model(x)
217 | y_s, y_t = y.chunk(2, dim=0)
218 | f_s, f_t = f.chunk(2, dim=0)
219 | cls_loss = F.cross_entropy(y_s, labels_s)
220 | loss = cls_loss
221 | loss.backward()
222 |
223 | # Calculate ϵ̂ (w) and add it to the weights
224 | optimizer.first_step(zero_grad=True)
225 |
226 | # Calculate task loss and domain loss
227 | y, f = model(x)
228 | y_s, y_t = y.chunk(2, dim=0)
229 | f_s, f_t = f.chunk(2, dim=0)
230 |
231 | cls_loss = F.cross_entropy(y_s, labels_s)
232 | transfer_loss = domain_adv(y_s, f_s, y_t, f_t)
233 | domain_acc = domain_adv.domain_discriminator_accuracy
234 | loss = cls_loss + transfer_loss * args.trade_off
235 |
236 | cls_acc = accuracy(y_s, labels_s)[0]
237 | if args.log_results:
238 | wandb.log({'iteration': epoch*args.iters_per_epoch + i, 'loss': loss, 'cls_loss': cls_loss,
239 | 'transfer_loss': transfer_loss, 'domain_acc': domain_acc})
240 |
241 | losses.update(loss.item(), x_s.size(0))
242 | cls_accs.update(cls_acc, x_s.size(0))
243 | domain_accs.update(domain_acc, x_s.size(0))
244 | trans_losses.update(transfer_loss.item(), x_s.size(0))
245 |
246 | loss.backward()
247 | # Update parameters of domain classifier
248 | ad_optimizer.step()
249 | # Update parameters (Sharpness-Aware update)
250 | optimizer.second_step(zero_grad=True)
251 | lr_scheduler.step()
252 | lr_scheduler_ad.step()
253 |
254 | # measure elapsed time
255 | batch_time.update(time.time() - end)
256 | end = time.time()
257 |
258 | if i % args.print_freq == 0:
259 | progress.display(i)
260 |
261 |
262 | if __name__ == '__main__':
263 | parser = argparse.ArgumentParser(
264 | description='CDAN with SDAT for Unsupervised Domain Adaptation')
265 | # dataset parameters
266 | parser.add_argument('root', metavar='DIR',
267 | help='root path of dataset')
268 | parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
269 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
270 | ' (default: Office31)')
271 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
272 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
273 | parser.add_argument('--train-resizing', type=str, default='default')
274 | parser.add_argument('--val-resizing', type=str, default='default')
275 | parser.add_argument('--resize-size', type=int, default=224,
276 | help='the image size after resizing')
277 | parser.add_argument('--no-hflip', action='store_true',
278 | help='no random horizontal flipping during training')
279 | parser.add_argument('--norm-mean', type=float, nargs='+',
280 | default=(0.485, 0.456, 0.406), help='normalization mean')
281 | parser.add_argument('--norm-std', type=float, nargs='+',
282 | default=(0.229, 0.224, 0.225), help='normalization std')
283 | # model parameters
284 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
285 | choices=utils.get_model_names(),
286 | help='backbone architecture: ' +
287 | ' | '.join(utils.get_model_names()) +
288 | ' (default: resnet18)')
289 | parser.add_argument('--bottleneck-dim', default=256, type=int,
290 | help='Dimension of bottleneck')
291 | parser.add_argument('--no-pool', action='store_true',
292 | help='no pool layer after the feature extractor.')
293 | parser.add_argument('--scratch', action='store_true',
294 | help='whether train from scratch.')
295 | parser.add_argument('-r', '--randomized', action='store_true',
296 | help='using randomized multi-linear-map (default: False)')
297 | parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,
298 | help='randomized dimension when using randomized multi-linear-map (default: 1024)')
299 | parser.add_argument('--entropy', default=False,
300 | action='store_true', help='use entropy conditioning')
301 | parser.add_argument('--trade-off', default=1., type=float,
302 | help='the trade-off hyper-parameter for transfer loss')
303 | # training parameters
304 | parser.add_argument('-b', '--batch-size', default=32, type=int,
305 | metavar='N',
306 | help='mini-batch size (default: 32)')
307 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
308 | metavar='LR', help='initial learning rate', dest='lr')
309 | parser.add_argument('--lr-gamma', default=0.001,
310 | type=float, help='parameter for lr scheduler')
311 | parser.add_argument('--lr-decay', default=0.75,
312 | type=float, help='parameter for lr scheduler')
313 | parser.add_argument('--momentum', default=0.9,
314 | type=float, metavar='M', help='momentum')
315 | parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
316 | metavar='W', help='weight decay (default: 1e-3)',
317 | dest='weight_decay')
318 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
319 | help='number of data loading workers (default: 2)')
320 | parser.add_argument('--epochs', default=20, type=int, metavar='N',
321 | help='number of total epochs to run')
322 | parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
323 | help='Number of iterations per epoch')
324 | parser.add_argument('-p', '--print-freq', default=100, type=int,
325 | metavar='N', help='print frequency (default: 100)')
326 | parser.add_argument('--seed', default=None, type=int,
327 | help='seed for initializing training. ')
328 | parser.add_argument('--per-class-eval', action='store_true',
329 | help='whether output per-class accuracy during evaluation')
330 | parser.add_argument("--log", type=str, default='cdan',
331 | help="Where to save logs, checkpoints and debugging images.")
332 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
333 | help="When phase is 'test', only test the model."
334 | "When phase is 'analysis', only analysis the model.")
335 | parser.add_argument('--log_results', action='store_true',
336 | help="To log results in wandb")
337 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID")
338 | parser.add_argument('--log_name', type=str,
339 | default="log", help="log name for wandb")
340 | parser.add_argument('--rho', type=float, default=0.05, help="Rho for SAM")
341 | args = parser.parse_args()
342 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
343 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
344 | args.device = device
345 | main(args)
346 |
--------------------------------------------------------------------------------
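Both SDAT scripts above use the same LambdaLR rule for the classifier and the domain discriminator. Purely as an illustration with the default hyper-parameters, the snippet below prints the factor that LambdaLR multiplies into each parameter group's base learning rate at a few optimizer steps.

# Illustrative only: the decay factor lr * (1 + lr_gamma * x) ** (-lr_decay)
# evaluated at a few optimizer steps with the default hyper-parameters.
lr, lr_gamma, lr_decay = 0.01, 0.001, 0.75
for step in (0, 1000, 10000, 30000):
    factor = lr * (1. + lr_gamma * step) ** (-lr_decay)
    print(step, factor)
# step 0     -> 0.01
# step 1000  -> ~0.00595   (after one epoch of 1000 iterations)
# step 10000 -> ~0.00166
# step 30000 -> ~0.00076   (after 30 epochs)
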
/examples/eval.py:
--------------------------------------------------------------------------------
1 | # Credits: https://github.com/thuml/Transfer-Learning-Library
2 | import random
3 | import time
4 | import warnings
5 | import sys
6 | import argparse
7 | import shutil
8 | import os.path as osp
9 | import os
10 | import wandb
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | from torch.utils.data import DataLoader
16 | import torch.nn.functional as F
17 |
18 | sys.path.append('../')
19 | from dalib.adaptation.cdan import ImageClassifier
20 | from common.utils.data import ForeverDataIterator
21 | from common.utils.metric import accuracy
22 | from common.utils.meter import AverageMeter, ProgressMeter
23 |
24 | sys.path.append('.')
25 | import utils
26 |
27 | def main(args: argparse.Namespace):
28 |
29 | if args.log_results:
30 | wandb.init(project="DA", entity="SDAT", name=args.log_name)
31 | wandb.config.update(args)
32 |
33 | cudnn.benchmark = True
34 | device = args.device
35 | # Data loading code
36 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
37 | random_color_jitter=False, resize_size=args.resize_size,
38 | norm_mean=args.norm_mean, norm_std=args.norm_std)
39 | val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
40 | norm_mean=args.norm_mean, norm_std=args.norm_std)
41 |
42 | train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
43 | utils.get_dataset(args.data, args.root, args.source,
44 | args.target, train_transform, val_transform)
45 | val_loader = DataLoader(
46 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
47 | test_loader = DataLoader(
48 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
49 |
50 |
51 | # create model
52 | print("=> using model '{}'".format(args.arch))
53 | backbone = utils.get_model(args.arch, pretrain=not args.scratch)
54 | print(backbone)
55 | pool_layer = nn.Identity() if args.no_pool else None
56 | classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
57 | pool_layer=pool_layer, finetune=not args.scratch).to(device)
58 | classifier_feature_dim = classifier.features_dim
59 |
60 | # resume from the best checkpoint
61 | if args.phase != 'train':
62 | path = args.weight_path
63 | print(f"[INFORMATION] Using the weights stored at {args.weight_path}")
64 | classifier.load_state_dict(torch.load(path))
65 |
66 | if args.phase == 'test':
67 | acc1 = utils.validate(test_loader, classifier, args, device)
68 | print(acc1)
69 | return
70 |
71 | if __name__ == '__main__':
72 | parser = argparse.ArgumentParser(
73 | description='Evaluation of trained models for Unsupervised Domain Adaptation')
74 | # dataset parameters
75 | parser.add_argument('root', metavar='DIR',
76 | help='root path of dataset')
77 | parser.add_argument('-d', '--data', metavar='DATA', default='OfficeHome', choices=utils.get_dataset_names(),
78 | help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
79 | ' (default: OfficeHome)')
80 | parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
81 | parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
82 | parser.add_argument('--train-resizing', type=str, default='default')
83 | parser.add_argument('--val-resizing', type=str, default='default')
84 | parser.add_argument('--resize-size', type=int, default=224,
85 | help='the image size after resizing')
86 | parser.add_argument('--no-hflip', action='store_true',
87 | help='no random horizontal flipping during training')
88 | parser.add_argument('--norm-mean', type=float, nargs='+',
89 | default=(0.485, 0.456, 0.406), help='normalization mean')
90 | parser.add_argument('--norm-std', type=float, nargs='+',
91 | default=(0.229, 0.224, 0.225), help='normalization std')
92 | # model parameters
93 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
94 | choices=utils.get_model_names(),
95 | help='backbone architecture: ' +
96 | ' | '.join(utils.get_model_names()) +
97 | ' (default: resnet18)')
98 | parser.add_argument('--bottleneck-dim', default=256, type=int,
99 | help='Dimension of bottleneck')
100 | parser.add_argument('--no-pool', action='store_true',
101 | help='no pool layer after the feature extractor.')
102 | parser.add_argument('--scratch', action='store_true',
103 | help='whether train from scratch.')
104 | parser.add_argument('-b', '--batch-size', default=32, type=int,
105 | metavar='N',
106 | help='mini-batch size (default: 32)')
107 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
108 | help='number of data loading workers (default: 2)')
109 | parser.add_argument('-p', '--print-freq', default=100, type=int,
110 | metavar='N', help='print frequency (default: 100)')
111 | parser.add_argument('--seed', default=None, type=int,
112 | help='seed for initializing training. ')
113 | parser.add_argument('--per-class-eval', action='store_true',
114 | help='whether output per-class accuracy during evaluation')
115 | parser.add_argument("--weight_path", type=str, default='cdan',
116 | help="Path to the saved weights")
117 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
118 | help="When phase is 'test', only test the model."
119 | "When phase is 'analysis', only analysis the model.")
120 | parser.add_argument('--log_results', action='store_true',
121 | help="To log results in wandb")
122 | parser.add_argument('--gpu', type=str, default="0", help="GPU ID")
123 | parser.add_argument('--log_name', type=str,
124 | default="log", help="log name for wandb")
125 | args = parser.parse_args()
126 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
127 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128 | args.device = device
129 | main(args)
130 |
--------------------------------------------------------------------------------
/examples/run_office_home.sh:
--------------------------------------------------------------------------------
1 | #CDAN (Office-Home-ViT)
2 | python cdan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_vit --gpu 0 --log_results
3 | python cdan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Pr --log_name Ar2Pr_cdan_vit --gpu 0 --log_results
4 | python cdan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Rw --log_name Ar2Rw_cdan_vit --gpu 0 --log_results
5 |
6 | python cdan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Ar --log_name Cl2Ar_cdan_vit --gpu 0 --log_results
7 | python cdan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Pr --log_name Cl2Pr_cdan_vit --gpu 0 --log_results
8 | python cdan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Rw --log_name Cl2Rw_cdan_vit --gpu 0 --log_results
9 |
10 | python cdan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Ar --log_name Pr2Ar_cdan_vit --gpu 0 --log_results
11 | python cdan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Cl --log_name Pr2Cl_cdan_vit --gpu 0 --log_results
12 | python cdan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Rw --log_name Pr2Rw_cdan_vit --gpu 0 --log_results
13 |
14 | python cdan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Ar --log_name Rw2Ar_cdan_vit --gpu 0 --log_results
15 | python cdan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Cl --log_name Rw2Cl_cdan_vit --gpu 0 --log_results
16 | python cdan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Pr --log_name Rw2Pr_cdan_vit --gpu 0 --log_results
17 |
18 | #CDAN_SDAT (Office-Home-ViT)
19 | python cdan_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
20 | python cdan_sdat.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Ar2Pr --log_name Ar2Pr_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
21 | python cdan_sdat.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Ar2Rw --log_name Ar2Rw_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
22 |
23 | python cdan_sdat.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Cl2Ar --log_name Cl2Ar_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
24 | python cdan_sdat.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Cl2Pr --log_name Cl2Pr_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
25 | python cdan_sdat.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Cl2Rw --log_name Cl2Rw_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
26 |
27 | python cdan_sdat.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Pr2Ar --log_name Pr2Ar_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
28 | python cdan_sdat.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Pr2Cl --log_name Pr2Cl_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
29 | python cdan_sdat.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Pr2Rw --log_name Pr2Rw_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
30 |
31 | python cdan_sdat.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Rw2Ar --log_name Rw2Ar_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
32 | python cdan_sdat.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Rw2Cl --log_name Rw2Cl_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
33 | python cdan_sdat.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_sdat_vit/OfficeHome_Rw2Pr --log_name Rw2Pr_cdan_sdat_vit --gpu 0 --rho 0.02 --log_results
34 |
35 | #CDAN_MCC (Office-Home-ViT)
36 | python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
37 | python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Pr --log_name Ar2Pr_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
38 | python cdan_mcc.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Ar2Rw --log_name Ar2Rw_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
39 |
40 | python cdan_mcc.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Cl2Ar --log_name Cl2Ar_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
41 | python cdan_mcc.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Cl2Pr --log_name Cl2Pr_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
42 | python cdan_mcc.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Cl2Rw --log_name Cl2Rw_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
43 |
44 | python cdan_mcc.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Pr2Ar --log_name Pr2Ar_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
45 | python cdan_mcc.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Pr2Cl --log_name Pr2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
46 | python cdan_mcc.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Pr2Rw --log_name Pr2Rw_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
47 |
48 | python cdan_mcc.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Rw2Ar --log_name Rw2Ar_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
49 | python cdan_mcc.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Rw2Cl --log_name Rw2Cl_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
50 | python cdan_mcc.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_vit/OfficeHome_Rw2Pr --log_name Rw2Pr_cdan_mcc_vit --gpu 0 --lr 0.002 --log_results
51 |
52 | #CDAN_MCC_SDAT (Office-Home-ViT)
53 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Cl --log_name Ar2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
54 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Pr --log_name Ar2Pr_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
55 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Ar2Rw --log_name Ar2Rw_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
56 |
57 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Cl2Ar --log_name Cl2Ar_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
58 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Cl2Pr --log_name Cl2Pr_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
59 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Cl2Rw --log_name Cl2Rw_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
60 |
61 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Pr2Ar --log_name Pr2Ar_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
62 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Pr2Cl --log_name Pr2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
63 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Pr2Rw --log_name Pr2Rw_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
64 |
65 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Rw2Ar --log_name Rw2Ar_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
66 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Rw2Cl --log_name Rw2Cl_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
67 | python cdan_mcc_sdat.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_mcc_sdat_vit/OfficeHome_Rw2Pr --log_name Rw2Pr_cdan_mcc_sdat_vit --gpu 0 --rho 0.02 --lr 0.002 --log_results
68 |
--------------------------------------------------------------------------------
/examples/run_visda.sh:
--------------------------------------------------------------------------------
1 | #CDAN (VisDA2017-ViT)
2 | python cdan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.01 --per-class-eval --train-resizing cen.crop --log logs/cdan_vit/VisDA2017 --log_name visda_cdan_vit --gpu 0 --no-pool --log_results
3 |
4 | #CDAN_SDAT (VisDA2017-ViT)
5 | python cdan_sdat.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.01 --per-class-eval --train-resizing cen.crop --log logs/cdan_sdat_vit/VisDA2017 --log_name visda_cdan_sdat_vit --gpu 0 --no-pool --rho 0.005 --log_results
6 |
7 | #CDAN_MCC (VisDA2017-ViT)
8 | python cdan_mcc.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.002 --per-class-eval --train-resizing cen.crop --log logs/cdan_mcc_vit/VisDA2017 --log_name visda_cdan_mcc_vit --gpu 0 --no-pool --log_results
9 |
10 | #CDAN_MCC_SDAT (VisDA2017-ViT)
11 | python cdan_mcc_sdat.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a vit_base_patch16_224 --epochs 15 --seed 0 --lr 0.002 --per-class-eval --train-resizing cen.crop --log logs/cdan_mcc_sdat_vit/VisDA2017 --log_name visda_cdan_mcc_sdat_vit --gpu 0 --no-pool --rho 0.02 --log_results
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/examples/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os.path as osp
3 | import time
4 | import timm
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision.transforms as T
9 | from torch.utils.data import ConcatDataset
10 | import wandb
11 | import wilds
12 |
13 | sys.path.append('../')
14 | import common.vision.datasets as datasets
15 | import common.vision.models as models
16 | from common.vision.transforms import ResizeImage
17 | from common.utils.metric import accuracy, ConfusionMatrix
18 | from common.utils.meter import AverageMeter, ProgressMeter
19 |
20 |
21 | def get_model_names():
22 | return sorted(
23 | name for name in models.__dict__
24 | if name.islower() and not name.startswith("__")
25 | and callable(models.__dict__[name])
26 | ) + timm.list_models()
27 |
28 |
29 | def get_model(model_name, pretrain=True):
30 | if model_name in models.__dict__:
31 | # load models from common.vision.models
32 | backbone = models.__dict__[model_name](pretrained=pretrain)
33 | else:
34 | # load models from pytorch-image-models
35 | backbone = timm.create_model(model_name, pretrained=pretrain)
36 | try:
37 | #backbone.out_features = backbone.get_classifier().in_features
38 | backbone.out_features = 768  # feature dim of the ViT-Base backbones used in the run scripts
39 | backbone.reset_classifier(0, '')
40 | except Exception:  # fall back: read the feature dim from the head and strip it
41 | backbone.out_features = backbone.head.in_features
42 | backbone.head = nn.Identity()
43 | return backbone
44 |
45 |
46 | def convert_from_wilds_dataset(wild_dataset):
47 | class Dataset:
48 | def __init__(self):
49 | self.dataset = wild_dataset
50 |
51 | def __getitem__(self, idx):
52 | x, y, metadata = self.dataset[idx]
53 | return x, y
54 |
55 | def __len__(self):
56 | return len(self.dataset)
57 |
58 | return Dataset()
59 |
60 |
61 | def get_dataset_names():
62 | return sorted(
63 | name for name in datasets.__dict__
64 | if not name.startswith("__") and callable(datasets.__dict__[name])
65 | ) + wilds.supported_datasets + ['Digits']
66 |
67 |
68 | def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None):
69 | if train_target_transform is None:
70 | train_target_transform = train_source_transform
71 | if dataset_name == "Digits":
72 | train_source_dataset = datasets.__dict__[source[0]](osp.join(root, source[0]), download=True,
73 | transform=train_source_transform)
74 | train_target_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), download=True,
75 | transform=train_target_transform)
76 | val_dataset = test_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), split='test',
77 | download=True, transform=val_transform)
78 | class_names = datasets.MNIST.get_classes()
79 | num_classes = len(class_names)
80 | elif dataset_name in datasets.__dict__:
81 | # load datasets from common.vision.datasets
82 | dataset = datasets.__dict__[dataset_name]
83 |
84 | def concat_dataset(tasks, **kwargs):
85 | return ConcatDataset([dataset(task=task, **kwargs) for task in tasks])
86 |
87 | train_source_dataset = concat_dataset(root=root, tasks=source, download=True, transform=train_source_transform)
88 | train_target_dataset = concat_dataset(root=root, tasks=target, download=True, transform=train_target_transform)
89 | val_dataset = concat_dataset(root=root, tasks=target, download=True, transform=val_transform)
90 | if dataset_name == 'DomainNet':
91 | test_dataset = concat_dataset(root=root, tasks=target, split='test', download=True, transform=val_transform)
92 | else:
93 | test_dataset = val_dataset
94 | class_names = train_source_dataset.datasets[0].classes
95 | num_classes = len(class_names)
96 | else:
97 | # load datasets from wilds
98 | dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True)
99 | num_classes = dataset.n_classes
100 | class_names = None
101 | train_source_dataset = convert_from_wilds_dataset(dataset.get_subset('train', transform=train_source_transform))
102 | train_target_dataset = convert_from_wilds_dataset(dataset.get_subset('test', transform=train_target_transform))
103 | val_dataset = test_dataset = convert_from_wilds_dataset(dataset.get_subset('test', transform=val_transform))
104 | return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names
105 |
106 |
107 | def validate(val_loader, model, args, device) -> float:
108 | batch_time = AverageMeter('Time', ':6.3f')
109 | losses = AverageMeter('Loss', ':.4e')
110 | top1 = AverageMeter('Acc@1', ':6.2f')
111 | progress = ProgressMeter(
112 | len(val_loader),
113 | [batch_time, losses, top1],
114 | prefix='Test: ')
115 |
116 | # switch to evaluate mode
117 | model.eval()
118 | if args.per_class_eval:
119 | confmat = ConfusionMatrix(len(args.class_names))
120 | else:
121 | confmat = None
122 |
123 | with torch.no_grad():
124 | end = time.time()
125 | for i, (images, target) in enumerate(val_loader):
126 | images = images.to(device)
127 | target = target.to(device)
128 |
129 | # compute output
130 | output = model(images)
131 | loss = F.cross_entropy(output, target)
132 |
133 | # measure accuracy and record loss
134 | acc1, = accuracy(output, target, topk=(1,))
135 | if confmat:
136 | confmat.update(target, output.argmax(1))
137 | losses.update(loss.item(), images.size(0))
138 | top1.update(acc1.item(), images.size(0))
139 |
140 | # measure elapsed time
141 | batch_time.update(time.time() - end)
142 | end = time.time()
143 |
144 | if i % args.print_freq == 0:
145 | progress.display(i)
146 |
147 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
148 | if confmat:
149 | print(confmat.format(args.class_names))
150 |
151 | return top1.avg
152 |
153 |
154 | def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False,
155 | resize_size=224, norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
156 | """
157 | resizing mode:
158 | - default: resize the image to 256 and take a random resized crop of size 224;
159 | - cen.crop: resize the image to 256 and take the center crop of size 224;
160 | - res.: resize the image to the given resize_size;
161 | """
162 | if resizing == 'default':
163 | transform = T.Compose([
164 | ResizeImage(256),
165 | T.RandomResizedCrop(224)
166 | ])
167 | elif resizing == 'cen.crop':
168 | transform = T.Compose([
169 | ResizeImage(256),
170 | T.CenterCrop(224)
171 | ])
172 | elif resizing == 'ran.crop':
173 | transform = T.Compose([
174 | ResizeImage(256),
175 | T.RandomCrop(224)
176 | ])
177 | elif resizing == 'res.':
178 | transform = ResizeImage(resize_size)
179 | else:
180 | raise NotImplementedError(resizing)
181 | transforms = [transform]
182 | if random_horizontal_flip:
183 | transforms.append(T.RandomHorizontalFlip())
184 | if random_color_jitter:
185 | transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))
186 | transforms.extend([
187 | T.ToTensor(),
188 | T.Normalize(mean=norm_mean, std=norm_std)
189 | ])
190 | return T.Compose(transforms)
191 |
192 |
193 | def get_val_transform(resizing='default', resize_size=224,
194 | norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
195 | """
196 | resizing mode:
197 | - default: resize the image to 256 and take the center crop of size 224;
198 | - res.: resize the image to the given resize_size
199 | """
200 | if resizing == 'default':
201 | transform = T.Compose([
202 | ResizeImage(256),
203 | T.CenterCrop(224),
204 | ])
205 | elif resizing == 'res.':
206 | transform = ResizeImage(resize_size)
207 | else:
208 | raise NotImplementedError(resizing)
209 | return T.Compose([
210 | transform,
211 | T.ToTensor(),
212 | T.Normalize(mean=norm_mean, std=norm_std)
213 | ])
214 |
215 |
216 | def pretrain(train_source_iter, model, optimizer, lr_scheduler, epoch, args, device):
217 | batch_time = AverageMeter('Time', ':3.1f')
218 | data_time = AverageMeter('Data', ':3.1f')
219 | losses = AverageMeter('Loss', ':3.2f')
220 | cls_accs = AverageMeter('Cls Acc', ':3.1f')
221 |
222 | progress = ProgressMeter(
223 | args.iters_per_epoch,
224 | [batch_time, data_time, losses, cls_accs],
225 | prefix="Epoch: [{}]".format(epoch))
226 |
227 | # switch to train mode
228 | model.train()
229 |
230 | end = time.time()
231 | for i in range(args.iters_per_epoch):
232 | x_s, labels_s = next(train_source_iter)
233 | x_s = x_s.to(device)
234 | labels_s = labels_s.to(device)
235 |
236 | # measure data loading time
237 | data_time.update(time.time() - end)
238 |
239 | # compute output
240 | y_s, f_s = model(x_s)
241 |
242 | cls_loss = F.cross_entropy(y_s, labels_s)
243 | loss = cls_loss
244 |
245 | cls_acc = accuracy(y_s, labels_s)[0]
246 | if args.log_results:
247 | wandb.log({'iteration':epoch*args.iters_per_epoch + i, 'loss':loss})
248 |
249 | losses.update(loss.item(), x_s.size(0))
250 | cls_accs.update(cls_acc.item(), x_s.size(0))
251 |
252 | # compute gradient and do SGD step
253 | optimizer.zero_grad()
254 | loss.backward()
255 | optimizer.step()
256 | lr_scheduler.step()
257 |
258 | # measure elapsed time
259 | batch_time.update(time.time() - end)
260 | end = time.time()
261 |
262 | if i % args.print_freq == 0:
263 | progress.display(i)
264 |
--------------------------------------------------------------------------------
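A quick, illustrative use of the transform helpers defined in examples/utils.py above; it assumes the examples/ directory is the working directory so that import utils resolves to this file.

import utils

train_tf = utils.get_train_transform('cen.crop', random_horizontal_flip=True)
val_tf = utils.get_val_transform('default')
print(train_tf)  # composed pipeline: resize to 256, center-crop 224, random flip, ToTensor, Normalize
print(val_tf)    # composed pipeline: resize to 256, center-crop 224, ToTensor, Normalize
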
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.9.1
2 | torchvision==0.10.1
3 | wandb==0.12.2
4 | timm==0.5.5
5 | prettytable==2.2.0
6 |
--------------------------------------------------------------------------------