├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── configs ├── DomainNet │ ├── aggregate_training.yml │ ├── multihead_training.yml │ └── supermask_training.yml └── PACS │ ├── aggregate_training.yml │ ├── multihead_training.yml │ └── supermask_training.yml ├── data ├── DomainNet │ ├── create_dataset.sh │ ├── create_hdf5.py │ ├── download.sh │ └── tv_0.9_splits │ │ ├── clipart_test.txt │ │ ├── clipart_train.txt │ │ ├── clipart_val.txt │ │ ├── infograph_test.txt │ │ ├── infograph_train.txt │ │ ├── infograph_val.txt │ │ ├── painting_test.txt │ │ ├── painting_train.txt │ │ ├── painting_val.txt │ │ ├── quickdraw_test.txt │ │ ├── quickdraw_train.txt │ │ ├── quickdraw_val.txt │ │ ├── real_test.txt │ │ ├── real_train.txt │ │ ├── real_val.txt │ │ ├── sketch_test.txt │ │ ├── sketch_train.txt │ │ └── sketch_val.txt └── PACS │ └── download.txt ├── dataloaders ├── __init__.py └── domain_datasets.py ├── environment.yml ├── images └── DMG_approach_preview.png ├── models ├── __init__.py ├── basic_model.py ├── multihead_model.py ├── subnetwork_supermask_model.py └── supermasks.py ├── run_jobs.sh ├── train_model.py ├── trainers ├── __init__.py ├── aggregate_trainer.py ├── multihead_trainer.py └── subnetwork_supermask_trainer.py └── utils ├── __init__.py ├── inverse_lr_scheduler.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Prithvijit Chattopadhyay 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Domain-Specific-Masks-for-Generalization 2 | 3 | PyTorch implementation of the paper: 4 | 5 | **Learning to Balance Specificity and Invariance for In and Out of Domain Generalization** 6 | Prithvijit Chattopadhyay, Yogesh Balaji, Judy Hoffman 7 | ECCV 2020 8 | 9 | *We introduce Domain-specific Masks for Generalization, a model for improving both in-domain and out-of-domain generalization performance. For domain generalization, the goal is to learn from a set of source domains to produce a single model that will best generalize to an unseen target domain. As such, many prior approaches focus on learning representations which persist across all source domains with the assumption that these domain agnostic representations will generalize well. However, often individual domains contain characteristics which are unique and when leveraged can significantly aid in-domain recognition performance.
To produce a model which best generalizes to both seen and unseen domains, we propose learning domain specific masks. The masks are encouraged to learn a balance of domain-invariant and domain-specific features, thus enabling a model which can benefit from the predictive power of specialized features while retaining the universal applicability of domain-invariant features. We demonstrate competitive performance compared to naive baselines and state-of-the-art methods on both PACS and DomainNet.* 10 | 11 | ![models](images/DMG_approach_preview.png) 12 | 13 | Table of Contents 14 | ================= 15 | 16 | * [Setup and Dependencies](#setup-and-dependencies) 17 | * [Usage](#usage) 18 | * [Download and setup data](#download-and-setup-data) 19 | * [Logging](#logging) 20 | * [Training](#training) 21 | * [Reference](#reference) 22 | 23 | ## Setup and Dependencies 24 | 25 | Our code is implemented in PyTorch (v1.2.0). To setup, do the following: 26 | 1. Install [Anaconda](https://docs.anaconda.com/anaconda/install/linux/) 27 | 2. Get the source: 28 | ``` 29 | git clone https://github.com/prithv1/DMG.git DMG 30 | ``` 31 | 3. Install requirements into the `dmg` virtual environment, using [Anaconda](https://anaconda.org/anaconda/python): 32 | ``` 33 | cd DMG 34 | conda env create -f environment.yml 35 | conda activate dmg 36 | ``` 37 | 38 | ## Usage 39 | 40 | ### Download and setup data 41 | 42 | To download the associated datasets (PACS and DomainNet), execute the following: 43 | 44 | 1. [DomainNet](http://ai.bu.edu/M3SDA/) 45 | ``` 46 | cd data/DomainNet/ 47 | chmod +x download.sh && chmod +x create_dataset.sh 48 | ./download.sh 49 | ./create_dataset.sh 50 | ``` 51 | 2. 
[PACS](http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017) 52 | Download the pre-created `hdf5` files from the [drive folder](https://drive.google.com/drive/folders/1i23DCs4TJ8LQsmBiMxsxo6qZsbhiX0gw?usp=sharing) and place them under `data/PACS/` 53 | 54 | ### Logging 55 | 56 | For logging, we use [Weights & Biases](https://docs.wandb.com/). wandb is installed as part of the conda environment created above. Follow the [quickstart instructions](https://docs.wandb.com/quickstart) to create a free account and log in from your shell. 57 | 58 | ### Training 59 | 60 | Training is managed via experiment configs in the `configs/` folder. To run a job corresponding to a specific multi-source domain shift, do the following: 61 | 62 | 1. Aggregate Training (training a CNN jointly on all domains) 63 | ``` 64 | # Running a job on the DomainNet dataset 65 | python train_model.py \ 66 | --phase aggregate_training \ # training-phase 67 | --config-yml configs/DomainNet/aggregate_training.yml \ # config-file 68 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ # Source domains 69 | DATA.TARGET_DOMAINS sketch \ # Target domain 70 | HJOB.JOB_STRING dmnt_v1 \ # Unique job identifier string 71 | MODEL.BASE_MODEL alexnet # Base CNN architecture 72 | ``` 73 | 2. Multi-head Training (training a CNN with classifier heads per source domain) 74 | ``` 75 | # Running a job on the DomainNet dataset 76 | python train_model.py \ 77 | --phase multihead_training \ # training-phase 78 | --config-yml configs/DomainNet/multihead_training.yml \ # config-file 79 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ # Source domains 80 | DATA.TARGET_DOMAINS sketch \ # Target domain 81 | HJOB.JOB_STRING dmnt_v1 \ # Unique job identifier string 82 | MODEL.BASE_MODEL alexnet \ # Base CNN architecture 83 | MODEL.SPLIT_LAYER classifier.6 # Split base-network at this layer 84 | ``` 85 | 3.
Domain-Specific Mask-based Aggregate Training (DMG) 86 | ``` 87 | # Running a job on the DomainNet dataset 88 | python train_model.py \ 89 | --phase supermask_training \ # training-phase 90 | --config-yml configs/DomainNet/supermask_training.yml \ # config-file 91 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ # Source domains 92 | DATA.TARGET_DOMAINS sketch \ # Target domain 93 | HJOB.JOB_STRING dmnt_v1 \ # Unique job identifier string 94 | MODEL.BASE_MODEL alexnet \ # Base CNN architecture 95 | MODEL.MASK_LAYERS classifier.1,classifier.4,classifier.6 \ # Layers to mask 96 | MODEL.MASK_INIT_SETTING random_uniform \ # How to initialize masks 97 | OPTIM.OVERLAP_LAMBDA 0.1 \ # Strength of overlap penalty 98 | OPTIM.SPARSITY_LAMBDA 0.0 \ # Strength of sparsity penalty 99 | MODEL.POLICY_CONV_MODE False # True if masks are applied to a conv-layer 100 | ``` 101 | 102 | The script `run_jobs.sh` contains sample commands to run jobs across different architectures on both PACS and DomainNet. 103 | 104 | ## Reference 105 | 106 | If you use this code as part of any published research, please cite 107 | ``` 108 | @inproceedings{2020EccvDMG, 109 | author = {Chattopadhyay, Prithvijit and Balaji, Yogesh and Hoffman, Judy}, 110 | title = {Learning to Balance Specificity and Invariance for In and Out of Domain Generalization}, 111 | year = 2020, 112 | booktitle = {European Conference on Computer Vision (ECCV)} 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | A module for package-wide configuration 3 | management.
Inspired by Ross Girshick's yacs template 4 | Also, 5 | kd's source -- https://github.com/kdexd/probnmn-clevr/blob/master/probnmn/config.py 6 | """ 7 | from typing import List, Any 8 | from yacs.config import CfgNode as CN 9 | 10 | 11 | class Config(object): 12 | """ 13 | A collection of all the required configuration parameters. This class is a nested dict-like 14 | structure, with nested keys accessible as attributes. It contains sensible default values for 15 | all the parameters, which may be overridden first through a YAML file and then through 16 | a list of attributes and values. 17 | 18 | - This class definition contains details relevant to all the training phases, 19 | with defaults corresponding to the aggregate training phase 20 | 21 | Parameters 22 | =========== 23 | config_yaml: str 24 | Path to a YAML file containing configuration parameters to override 25 | config_override: List[Any], optional (default=[]) 26 | A list of sequential attributes and values of parameters to override. This happens 27 | after overriding from the YAML file. 28 | 29 | Attributes 30 | =========== 31 | --------------------------- 32 | (HIGH-LEVEL JOB RELATED ARGUMENTS) 33 | 34 | HJOB.RANDOM_SEED: 123 35 | Random seed for numpy and PyTorch for reproducibility 36 | 37 | HJOB.PHASE: "aggregate_training" 38 | Which phase to train on?
One of 39 | - ``aggregate_training'' 40 | - ``multihead_training'' 41 | - ``supermask_training'' 42 | 43 | HJOB.JOB_STRING: "test" 44 | Job string prefix 45 | 46 | HJOB.WANDB_PROJECT: "DMG" 47 | Project name on wandb 48 | 49 | HJOB.WANDB_DIR: "wandb_runs" 50 | Directory to store wandb data in 51 | 52 | -------------------------- 53 | (DATA RELATED ARGUMENTS) 54 | 55 | DATA.DATASET: "PACS" 56 | Dataset to perform experiments on 57 | 58 | DATA.DOMAIN_LIST: "cartoon,photo,sketch" 59 | List of domains to train jointly on 60 | 61 | DATA.TARGET_DOMAINS: "art_painting" 62 | List of target domains to evaluate on 63 | 64 | DATA.DATA_SPLIT_DIR: "data/" 65 | Directory which stores all the data 66 | 67 | DATA.HEAD_MODE: "single" 68 | Mode of the aggregate dataloaders, specific to multi-head models. 69 | Can be either "single" or "multi". 70 | 71 | ------------------------- 72 | (CHECKPOINT RELATED ARGUMENTS) 73 | 74 | CKPT.STORAGE_DIR: "../DMG/" 75 | Directory to store checkpoints in 76 | 77 | ------------------------- 78 | (MODEL DEFINITION RELATED ARGUMENTS) 79 | 80 | MODEL.BASE_MODEL: "alexnet" 81 | The base architecture upon which the proposed model definition is built 82 | 83 | MODEL.PARAM_INIT: "custom" 84 | Whether to initialize new parameters in a standard-vs-custom manner 85 | 86 | MODEL.USE_PRETRAINED: True 87 | Whether to use a pre-trained base model 88 | 89 | MODEL.SPLIT_LAYER: "classifier.0" 90 | The layer at which one should split old/new parameters for finetuning 91 | 92 | MODEL.TRAIN_FORWARD_MODE: "route" 93 | How to forward pass instances for the multi-headed network during training 94 | 95 | MODEL.EVAL_FORWARD_MODE: "route" 96 | How to forward pass instances for the multi-headed network during evaluation 97 | 98 | MODEL.NUM_CLASSES: 7 99 | Number of output classes for the classification task 100 | 101 | MODEL.MASK_LAYERS: "classifier.6" 102 | Comma-separated names of layers at which the conditional computation mask is to be applied 103 | 104 |
MODEL.POLICY_SAMPLE_MODE: "sample" 105 | Sampling mode of the layer-wise mask policies -- ["sample", "greedy"] 106 | 107 | MODEL.POLICY_CONV_MODE: False 108 | Set to True when a single shared mask is applied per channel of a conv-layer 109 | 110 | MODEL.MASK_INIT_SETTING: "random_uniform" 111 | How to initialize the masks -- ["random_uniform", "scalar"] 112 | 113 | MODEL.MASK_INIT_SCALAR: 1.0 114 | Scalar to initialize the masks with -- 1.0 (by default) 115 | 116 | ------------------------- 117 | (DATALOADER RELATED ARGUMENTS) 118 | 119 | DATALOADER.BATCH_SIZE: 64 120 | Batch size for the dataloader 121 | 122 | DATALOADER.DATA_SAMPLING_MODE: "uniform" 123 | Whether to sample data in a uniform / balanced manner 124 | 125 | ------------------------- 126 | (OPTIMIZATION RELATED ARGUMENTS) 127 | 128 | OPTIM.OPTIMIZER: "Adam" 129 | Optimizer to use -- ["Adam", "SGD"] 130 | 131 | OPTIM.LEARNING_RATE: 5e-4 132 | Learning rate to use 133 | 134 | OPTIM.LEARNING_RATE_DECAY_RATE: 0.96 135 | Decay rate to use for learning rate decay 136 | 137 | OPTIM.LEARNING_RATE_DECAY_MODE: "iter" 138 | Whether to decay learning rate per-iteration ("iter") or per-epoch ("epoch") 139 | 140 | OPTIM.LEARNING_RATE_DECAY_STEP: 15000 141 | If we're decaying learning rate per-iteration, what is the decay-step size?
142 | 143 | OPTIM.LEARNING_RATE_SCHEDULER: "exp" 144 | What kind of learning rate scheduler to use -- ["exp", "invlr"] 145 | 146 | OPTIM.WEIGHT_DECAY: 1e-5 147 | Weight decay to use 148 | 149 | OPTIM.MODEL_LEARNING_RATE: 5e-4 150 | Learning rate to use for the base model during meta-train updates 151 | 152 | OPTIM.POLICY_LEARNING_RATE: 5e-4 153 | Learning rate to use for the mask-policies 154 | 155 | OPTIM.POLICY_WEIGHT_DECAY: 1e-5 156 | Weight decay to use for the policy models 157 | 158 | OPTIM.SPARSITY_LAMBDA: 10.0 159 | Coefficient of the sparsity incentive (reward / regularization) 160 | 161 | OPTIM.OVERLAP_LAMBDA: 0.0 162 | Coefficient of the penalty on overlap amongst masks 163 | 164 | --------------------------- 165 | (Training epoch / iteration related arguments) 166 | EP_IT.MAX_EPOCHS: 100 167 | Maximum number of epochs to train the base CNN for 168 | 169 | EP_IT.LOG_INTERVAL: 100 170 | Number of iterations within an epoch after which a terminal log is displayed 171 | 172 | EP_IT.CKPT_STORE_INTERVAL: 100 173 | Number of iterations / epochs after which recurring checkpoints are stored 174 | 175 | --------------------------- 176 | (CPU / GPU Related Arguments) 177 | PROCESS.USE_GPU: True 178 | Whether to use GPU or not 179 | 180 | PROCESS.NUM_WORKERS: 4 181 | Number of dataloader workers to use 182 | 183 | """ 184 | 185 | def __init__(self, config_yaml: str, config_override: List[Any] = []): 186 | 187 | self._C = CN() 188 | 189 | self._C.HJOB = CN() 190 | self._C.HJOB.RANDOM_SEED = 123 191 | self._C.HJOB.PHASE = "aggregate_training" 192 | self._C.HJOB.JOB_STRING = "test" 193 | self._C.HJOB.WANDB_PROJECT = "DMG" 194 | self._C.HJOB.WANDB_DIR = "wandb_runs" 195 | 196 | self._C.DATA = CN() 197 | self._C.DATA.DATASET = "PACS" 198 | self._C.DATA.DOMAIN_LIST = "cartoon,photo,sketch" 199 | self._C.DATA.TARGET_DOMAINS = "art_painting" 200 | self._C.DATA.DATA_SPLIT_DIR = "data/" 201 | self._C.DATA.HEAD_MODE = "single" 202 | 203 | self._C.CKPT = CN() 204 | self._C.CKPT.STORAGE_DIR = "../DMG/" 205 |
206 | self._C.MODEL = CN() 207 | self._C.MODEL.BASE_MODEL = "alexnet" 208 | self._C.MODEL.PARAM_INIT = "custom" 209 | self._C.MODEL.USE_PRETRAINED = True 210 | self._C.MODEL.SPLIT_LAYER = "classifier.0" 211 | self._C.MODEL.TRAIN_FORWARD_MODE = "route" 212 | self._C.MODEL.EVAL_FORWARD_MODE = "route" 213 | self._C.MODEL.NUM_CLASSES = 7 214 | self._C.MODEL.MASK_LAYERS = "classifier.6" 215 | self._C.MODEL.POLICY_SAMPLE_MODE = "sample" 216 | self._C.MODEL.POLICY_CONV_MODE = False 217 | self._C.MODEL.MASK_INIT_SETTING = "random_uniform" 218 | self._C.MODEL.MASK_INIT_SCALAR = 1.0 219 | 220 | self._C.DATALOADER = CN() 221 | self._C.DATALOADER.BATCH_SIZE = 64 222 | self._C.DATALOADER.DATA_SAMPLING_MODE = "uniform" 223 | 224 | self._C.OPTIM = CN() 225 | self._C.OPTIM.OPTIMIZER = "Adam" 226 | self._C.OPTIM.LEARNING_RATE = 5e-4 227 | self._C.OPTIM.LEARNING_RATE_DECAY_RATE = 0.96 228 | self._C.OPTIM.LEARNING_RATE_DECAY_MODE = "iter" 229 | self._C.OPTIM.LEARNING_RATE_DECAY_STEP = 15000 230 | self._C.OPTIM.LEARNING_RATE_SCHEDULER = "exp" 231 | self._C.OPTIM.WEIGHT_DECAY = 1e-5 232 | self._C.OPTIM.MODEL_LEARNING_RATE = 5e-4 233 | self._C.OPTIM.POLICY_LEARNING_RATE = 5e-4 234 | self._C.OPTIM.MODEL_WEIGHT_DECAY = 1e-5 235 | self._C.OPTIM.POLICY_WEIGHT_DECAY = 1e-5 236 | self._C.OPTIM.SPARSITY_LAMBDA = 10.0 237 | self._C.OPTIM.OVERLAP_LAMBDA = 0.0 238 | 239 | self._C.EP_IT = CN() 240 | self._C.EP_IT.MAX_EPOCHS = 100 241 | self._C.EP_IT.MAX_ITER = 20000 242 | self._C.EP_IT.LOG_INTERVAL = 100 243 | self._C.EP_IT.CKPT_STORE_INTERVAL = 100 244 | 245 | self._C.PROCESS = CN() 246 | self._C.PROCESS.USE_GPU = True 247 | self._C.PROCESS.NUM_WORKERS = 4 248 | 249 | # Override parameter values from YAML file first, then from override list 250 | self._C.merge_from_file(config_yaml) 251 | self._C.merge_from_list(config_override) 252 | 253 | # Make an instantiated object of this class immutable 254 | self._C.freeze() 255 | 256 | def dump(self, file_path: str): 257 | """Save config at the specified 
file path. 258 | Parameters 259 | ---------- 260 | file_path: str 261 | (YAML) path to save config at. 262 | """ 263 | self._C.dump(stream=open(file_path, "w")) 264 | 265 | def get_env(self): 266 | """ 267 | Get a string as environment name 268 | based on the config attribute values 269 | and the phase of the job 270 | """ 271 | DSET_PREFIX = "" 272 | ENV_NAME = "" 273 | # Prefix based on dataset 274 | if self._C.DATA.DATASET == "PACS": 275 | DSET_PREFIX = "pacs" 276 | elif self._C.DATA.DATASET == "DomainNet": 277 | DSET_PREFIX = "dmnt" 278 | else: 279 | print("Dataset not supported yet") 280 | 281 | if self._C.HJOB.PHASE == "aggregate_training": 282 | ENV_NAME = [self._C.HJOB.JOB_STRING, DSET_PREFIX, "AGG"] 283 | 284 | DOMAINS = self._C.DATA.DOMAIN_LIST.split(",") 285 | ENV_NAME += DOMAINS 286 | 287 | ENV_NAME += [ 288 | self._C.MODEL.BASE_MODEL, 289 | self._C.MODEL.USE_PRETRAINED, 290 | self._C.MODEL.SPLIT_LAYER, 291 | self._C.OPTIM.OPTIMIZER, 292 | "LR", 293 | self._C.OPTIM.LEARNING_RATE, 294 | self._C.OPTIM.LEARNING_RATE_DECAY_RATE, 295 | self._C.OPTIM.LEARNING_RATE_DECAY_MODE, 296 | self._C.OPTIM.LEARNING_RATE_DECAY_STEP, 297 | ] 298 | 299 | if self._C.OPTIM.LEARNING_RATE_SCHEDULER != "exp": 300 | ENV_NAME += ["LR_SCH", self._C.OPTIM.LEARNING_RATE_SCHEDULER] 301 | 302 | ENV_NAME += [ 303 | "WD", 304 | self._C.OPTIM.WEIGHT_DECAY, 305 | "BS", 306 | self._C.DATALOADER.BATCH_SIZE, 307 | self._C.DATALOADER.DATA_SAMPLING_MODE, 308 | "ME", 309 | self._C.EP_IT.MAX_EPOCHS, 310 | ] 311 | elif self._C.HJOB.PHASE == "multihead_training": 312 | ENV_NAME = [self._C.HJOB.JOB_STRING, DSET_PREFIX, "MH"] 313 | 314 | DOMAINS = self._C.DATA.DOMAIN_LIST.split(",") 315 | ENV_NAME += DOMAINS 316 | 317 | ENV_NAME += [ 318 | self._C.MODEL.BASE_MODEL, 319 | self._C.MODEL.USE_PRETRAINED, 320 | self._C.MODEL.SPLIT_LAYER, 321 | "TR_FWD", 322 | self._C.MODEL.TRAIN_FORWARD_MODE, 323 | "EV_FWD", 324 | self._C.MODEL.EVAL_FORWARD_MODE, 325 | self._C.OPTIM.OPTIMIZER, 326 | "LR", 327 | 
self._C.OPTIM.LEARNING_RATE, 328 | self._C.OPTIM.LEARNING_RATE_DECAY_RATE, 329 | self._C.OPTIM.LEARNING_RATE_DECAY_MODE, 330 | self._C.OPTIM.LEARNING_RATE_DECAY_STEP, 331 | ] 332 | 333 | if self._C.OPTIM.LEARNING_RATE_SCHEDULER != "exp": 334 | ENV_NAME += ["LR_SCH", self._C.OPTIM.LEARNING_RATE_SCHEDULER] 335 | 336 | ENV_NAME += [ 337 | "WD", 338 | self._C.OPTIM.WEIGHT_DECAY, 339 | "BS", 340 | self._C.DATALOADER.BATCH_SIZE, 341 | self._C.DATALOADER.DATA_SAMPLING_MODE, 342 | "ME", 343 | self._C.EP_IT.MAX_EPOCHS, 344 | ] 345 | elif self._C.HJOB.PHASE == "supermask_training": 346 | ENV_NAME = [self._C.HJOB.JOB_STRING, DSET_PREFIX, "SPMSK"] 347 | 348 | DOMAINS = self._C.DATA.DOMAIN_LIST.split(",") 349 | ENV_NAME += DOMAINS 350 | 351 | MASK_LAYERS = "_".join(self._C.MODEL.MASK_LAYERS.split(",")) 352 | 353 | ENV_NAME += [ 354 | self._C.MODEL.BASE_MODEL, 355 | self._C.MODEL.USE_PRETRAINED, 356 | MASK_LAYERS, 357 | ] 358 | 359 | ENV_NAME += [ 360 | self._C.OPTIM.OPTIMIZER, 361 | "LR", 362 | self._C.OPTIM.MODEL_LEARNING_RATE, 363 | self._C.OPTIM.POLICY_LEARNING_RATE, 364 | ] 365 | 366 | if self._C.MODEL.POLICY_SAMPLE_MODE != "sample": 367 | ENV_NAME += ["POL_SMP", self._C.MODEL.POLICY_SAMPLE_MODE] 368 | 369 | if self._C.MODEL.POLICY_CONV_MODE: 370 | ENV_NAME += ["POL_CNV_1"] 371 | 372 | if self._C.OPTIM.SPARSITY_LAMBDA > 0.0: 373 | ENV_NAME += ["L1_SP_", self._C.OPTIM.SPARSITY_LAMBDA] 374 | 375 | if self._C.OPTIM.OVERLAP_LAMBDA > 0.0: 376 | ENV_NAME += ["IOU_OV", self._C.OPTIM.OVERLAP_LAMBDA] 377 | 378 | ENV_NAME += ["MSK_INIT", self._C.MODEL.MASK_INIT_SETTING] 379 | if self._C.MODEL.MASK_INIT_SETTING == "scalar": 380 | ENV_NAME += [self._C.MODEL.MASK_INIT_SCALAR] 381 | 382 | ENV_NAME += [ 383 | self._C.OPTIM.LEARNING_RATE_DECAY_RATE, 384 | self._C.OPTIM.LEARNING_RATE_DECAY_MODE, 385 | self._C.OPTIM.LEARNING_RATE_DECAY_STEP, 386 | "WD", 387 | self._C.OPTIM.MODEL_WEIGHT_DECAY, 388 | self._C.OPTIM.POLICY_WEIGHT_DECAY, 389 | "BS", 390 | self._C.DATALOADER.BATCH_SIZE, 391 | 
self._C.DATALOADER.DATA_SAMPLING_MODE, 392 | "ME", 393 | self._C.EP_IT.MAX_EPOCHS, 394 | ] 395 | 396 | else: 397 | print("Job phase invalid / not supported yet") 398 | 399 | ENV_NAME = [str(x) for x in ENV_NAME] 400 | return "_".join(ENV_NAME) 401 | 402 | def __getattr__(self, attr: str): 403 | return self._C.__getattr__(attr) 404 | 405 | def __repr__(self): 406 | return self._C.__repr__() 407 | 408 | -------------------------------------------------------------------------------- /configs/DomainNet/aggregate_training.yml: -------------------------------------------------------------------------------- 1 | # High-level job related parameters 2 | HJOB: 3 | RANDOM_SEED: 1234 4 | PHASE: "aggregate_training" 5 | JOB_STRING: "test_agg" 6 | 7 | # Dataset related parameters 8 | DATA: 9 | DATASET: "DomainNet" 10 | DOMAIN_LIST: "clipart,infograph,painting,quickdraw,real" 11 | TARGET_DOMAINS: "sketch" 12 | DATA_SPLIT_DIR: "data/DomainNet/tv_0.9_splits/" 13 | HEAD_MODE: "single" 14 | 15 | # Model related parameters 16 | MODEL: 17 | BASE_MODEL: "alexnet" 18 | PARAM_INIT: "custom" 19 | USE_PRETRAINED: True 20 | NUM_CLASSES: 345 21 | 22 | # Checkpoint related parameters 23 | CKPT: 24 | STORAGE_DIR: "../DMG/" 25 | 26 | # Dataloader related parameters 27 | DATALOADER: 28 | BATCH_SIZE: 64 29 | DATA_SAMPLING_MODE: "uniform" 30 | 31 | # Optimizer related parameters 32 | OPTIM: 33 | OPTIMIZER: "Adam" 34 | LEARNING_RATE: 1e-4 35 | LEARNING_RATE_SCHEDULER: "invlr" 36 | LEARNING_RATE_DECAY_RATE: 0.1 37 | LEARNING_RATE_DECAY_MODE: "iter" 38 | LEARNING_RATE_DECAY_STEP: 1 39 | WEIGHT_DECAY: 0.0 40 | 41 | # Epoch / Iteration related parameters 42 | EP_IT: 43 | MAX_EPOCHS: 50 44 | LOG_INTERVAL: 50 45 | CKPT_STORE_INTERVAL: 2 46 | 47 | # Process related parameters 48 | PROCESS: 49 | USE_GPU: True 50 | NUM_WORKERS: 6 -------------------------------------------------------------------------------- /configs/DomainNet/multihead_training.yml: 
-------------------------------------------------------------------------------- 1 | # High-level job related parameters 2 | HJOB: 3 | RANDOM_SEED: 123 4 | PHASE: "multihead_training" 5 | JOB_STRING: "test_mh" 6 | 7 | # Dataset related parameters 8 | DATA: 9 | DATASET: "DomainNet" 10 | DOMAIN_LIST: "clipart,infograph,painting,quickdraw,real" 11 | TARGET_DOMAINS: "sketch" 12 | DATA_SPLIT_DIR: "data/DomainNet/tv_0.9_splits/" 13 | HEAD_MODE: "multi" 14 | 15 | # Model related parameters 16 | MODEL: 17 | BASE_MODEL: "alexnet" 18 | PARAM_INIT: "custom" 19 | USE_PRETRAINED: True 20 | SPLIT_LAYER: "classifier.6" 21 | TRAIN_FORWARD_MODE: "route" 22 | EVAL_FORWARD_MODE: "route" 23 | NUM_CLASSES: 345 24 | 25 | # Checkpoint related parameters 26 | CKPT: 27 | STORAGE_DIR: "../DMG/" 28 | 29 | # Dataloader related parameters 30 | DATALOADER: 31 | BATCH_SIZE: 64 32 | DATA_SAMPLING_MODE: "uniform" 33 | 34 | # Optimizer related parameters 35 | OPTIM: 36 | OPTIMIZER: "Adam" 37 | LEARNING_RATE: 1e-4 38 | LEARNING_RATE_SCHEDULER: "invlr" 39 | LEARNING_RATE_DECAY_RATE: 0.1 40 | LEARNING_RATE_DECAY_MODE: "iter" 41 | LEARNING_RATE_DECAY_STEP: 1 42 | WEIGHT_DECAY: 0.0 43 | 44 | # Epoch / Iteration related parameters 45 | EP_IT: 46 | MAX_EPOCHS: 50 47 | LOG_INTERVAL: 50 48 | CKPT_STORE_INTERVAL: 2 49 | 50 | # Process related parameters 51 | PROCESS: 52 | USE_GPU: True 53 | NUM_WORKERS: 6 -------------------------------------------------------------------------------- /configs/DomainNet/supermask_training.yml: -------------------------------------------------------------------------------- 1 | # High-level job related parameters 2 | HJOB: 3 | RANDOM_SEED: 123 4 | PHASE: "supermask_training" 5 | JOB_STRING: "test_supermask" 6 | 7 | # Dataset related parameters 8 | DATA: 9 | DATASET: "DomainNet" 10 | DOMAIN_LIST: "clipart,infograph,painting,quickdraw,real" 11 | TARGET_DOMAINS: "sketch" 12 | DATA_SPLIT_DIR: "data/DomainNet/tv_0.9_splits/" 13 | 14 | # Model related parameters 15 | MODEL: 16 | 
BASE_MODEL: "alexnet" 17 | PARAM_INIT: "custom" 18 | USE_PRETRAINED: True 19 | NUM_CLASSES: 345 20 | MASK_LAYERS: "classifier.1,classifier.4,classifier.6" 21 | POLICY_SAMPLE_MODE: "sample" 22 | POLICY_CONV_MODE: False 23 | MASK_INIT_SETTING: "random_uniform" 24 | MASK_INIT_SCALAR: 1.0 25 | 26 | # Checkpoint related parameters 27 | CKPT: 28 | STORAGE_DIR: "../DMG/" 29 | 30 | # Dataloader related parameters 31 | DATALOADER: 32 | BATCH_SIZE: 64 33 | DATA_SAMPLING_MODE: "uniform" 34 | 35 | # Optimizer related parameters 36 | OPTIM: 37 | OPTIMIZER: "Adam" 38 | MODEL_LEARNING_RATE: 1e-4 39 | POLICY_LEARNING_RATE: 1e-4 40 | LEARNING_RATE_SCHEDULER: "invlr" 41 | LEARNING_RATE_DECAY_RATE: 0.1 42 | LEARNING_RATE_DECAY_MODE: "iter" 43 | LEARNING_RATE_DECAY_STEP: 1 44 | MODEL_WEIGHT_DECAY: 0.0 45 | POLICY_WEIGHT_DECAY: 0.0 46 | SPARSITY_LAMBDA: 0.0 47 | OVERLAP_LAMBDA: 0.1 48 | 49 | # Epoch / Iteration related parameters 50 | EP_IT: 51 | MAX_EPOCHS: 50 52 | LOG_INTERVAL: 50 53 | CKPT_STORE_INTERVAL: 2 54 | 55 | # Process related parameters 56 | PROCESS: 57 | USE_GPU: True 58 | NUM_WORKERS: 6 -------------------------------------------------------------------------------- /configs/PACS/aggregate_training.yml: -------------------------------------------------------------------------------- 1 | # High-level job related parameters 2 | HJOB: 3 | RANDOM_SEED: 123 4 | PHASE: "aggregate_training" 5 | JOB_STRING: "test_agg" 6 | 7 | # Dataset related parameters 8 | DATA: 9 | DATASET: "PACS" 10 | DOMAIN_LIST: "art_painting,cartoon,photo" 11 | TARGET_DOMAINS: "sketch" 12 | DATA_SPLIT_DIR: "data/PACS/" 13 | HEAD_MODE: "single" 14 | 15 | CKPT: 16 | STORAGE_DIR: "../DMG/" 17 | 18 | # Model related parameters 19 | MODEL: 20 | BASE_MODEL: "alexnet" 21 | PARAM_INIT: "standard" 22 | USE_PRETRAINED: True 23 | SPLIT_LAYER: "classifier.6" 24 | NUM_CLASSES: 7 25 | 26 | # Dataloader related parameters 27 | DATALOADER: 28 | BATCH_SIZE: 64 29 | DATA_SAMPLING_MODE: "uniform" 30 | 31 | # Optimizer 
related parameters 32 | OPTIM: 33 | OPTIMIZER: "Adam" 34 | LEARNING_RATE: 1e-4 35 | LEARNING_RATE_SCHEDULER: "exp" 36 | LEARNING_RATE_DECAY_RATE: 0.99 37 | LEARNING_RATE_DECAY_MODE: "epoch" 38 | LEARNING_RATE_DECAY_STEP: 1 39 | WEIGHT_DECAY: 1e-5 40 | 41 | # Epoch / Iteration related parameters 42 | EP_IT: 43 | MAX_EPOCHS: 100 44 | LOG_INTERVAL: 50 45 | CKPT_STORE_INTERVAL: 2 46 | 47 | # Process related parameters 48 | PROCESS: 49 | USE_GPU: True 50 | NUM_WORKERS: 6 -------------------------------------------------------------------------------- /configs/PACS/multihead_training.yml: -------------------------------------------------------------------------------- 1 | # High-level job related parameters 2 | HJOB: 3 | RANDOM_SEED: 123 4 | PHASE: "multihead_training" 5 | JOB_STRING: "test_mh" 6 | 7 | # Dataset related parameters 8 | DATA: 9 | DATASET: "PACS" 10 | DOMAIN_LIST: "art_painting,cartoon,photo" 11 | TARGET_DOMAINS: "sketch" 12 | DATA_SPLIT_DIR: "data/PACS/" 13 | HEAD_MODE: "multi" 14 | 15 | CKPT: 16 | STORAGE_DIR: "../DMG/" 17 | 18 | # Model related parameters 19 | MODEL: 20 | BASE_MODEL: "alexnet" 21 | PARAM_INIT: "custom" 22 | USE_PRETRAINED: True 23 | SPLIT_LAYER: "classifier.6" 24 | TRAIN_FORWARD_MODE: "route" 25 | EVAL_FORWARD_MODE: "route" 26 | NUM_CLASSES: 7 27 | 28 | # Dataloader related parameters 29 | DATALOADER: 30 | BATCH_SIZE: 64 31 | DATA_SAMPLING_MODE: "uniform" 32 | 33 | # Optimizer related parameters 34 | OPTIM: 35 | OPTIMIZER: "Adam" 36 | LEARNING_RATE: 1e-4 37 | LEARNING_RATE_SCHEDULER: "exp" 38 | LEARNING_RATE_DECAY_RATE: 0.99 39 | LEARNING_RATE_DECAY_MODE: "epoch" 40 | LEARNING_RATE_DECAY_STEP: 1 41 | WEIGHT_DECAY: 1e-5 42 | 43 | # Epoch / Iteration related parameters 44 | EP_IT: 45 | MAX_EPOCHS: 100 46 | LOG_INTERVAL: 50 47 | CKPT_STORE_INTERVAL: 2 48 | 49 | # Process related parameters 50 | PROCESS: 51 | USE_GPU: True 52 | NUM_WORKERS: 6 -------------------------------------------------------------------------------- 
/configs/PACS/supermask_training.yml: -------------------------------------------------------------------------------- 1 | # High-level job related parameters 2 | HJOB: 3 | RANDOM_SEED: 123 4 | PHASE: "supermask_training" 5 | JOB_STRING: "test_supermask" 6 | 7 | # Dataset related parameters 8 | DATA: 9 | DATASET: "PACS" 10 | DOMAIN_LIST: "art_painting,cartoon,photo" 11 | TARGET_DOMAINS: "sketch" 12 | DATA_SPLIT_DIR: "data/PACS/" 13 | 14 | CKPT: 15 | STORAGE_DIR: "../DMG/" 16 | 17 | # Model related parameters 18 | MODEL: 19 | BASE_MODEL: "alexnet" 20 | PARAM_INIT: "standard" 21 | USE_PRETRAINED: True 22 | NUM_CLASSES: 7 23 | MASK_LAYERS: "classifier.1,classifier.4,classifier.6" 24 | POLICY_SAMPLE_MODE: "sample" 25 | POLICY_CONV_MODE: False 26 | MASK_INIT_SETTING: "random_uniform" 27 | MASK_INIT_SCALAR: 1.0 28 | 29 | # Dataloader related parameters 30 | DATALOADER: 31 | BATCH_SIZE: 64 32 | DATA_SAMPLING_MODE: "uniform" 33 | 34 | # Optimizer related parameters 35 | OPTIM: 36 | OPTIMIZER: "Adam" 37 | MODEL_LEARNING_RATE: 1e-4 38 | POLICY_LEARNING_RATE: 1e-4 39 | LEARNING_RATE_SCHEDULER: "exp" 40 | LEARNING_RATE_DECAY_RATE: 0.99 41 | LEARNING_RATE_DECAY_MODE: "epoch" 42 | LEARNING_RATE_DECAY_STEP: 1 43 | MODEL_WEIGHT_DECAY: 1e-5 44 | POLICY_WEIGHT_DECAY: 1e-5 45 | SPARSITY_LAMBDA: 0.0 46 | OVERLAP_LAMBDA: 0.1 47 | 48 | # Epoch / Iteration related parameters 49 | EP_IT: 50 | MAX_EPOCHS: 150 51 | LOG_INTERVAL: 50 52 | CKPT_STORE_INTERVAL: 2 53 | 54 | # Process related parameters 55 | PROCESS: 56 | USE_GPU: True 57 | NUM_WORKERS: 6 -------------------------------------------------------------------------------- /data/DomainNet/create_dataset.sh: -------------------------------------------------------------------------------- 1 | # Create h5 files based on data 2 | python create_hdf5.py -------------------------------------------------------------------------------- /data/DomainNet/create_hdf5.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import sys 4 | import json 5 | import h5py 6 | import random 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | from pprint import pprint 12 | from scipy.stats import entropy 13 | 14 | # DomainNet -- all the 6 domains 15 | all_domains = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] 16 | 17 | # Path to data containing all the files 18 | DATA_SPLIT_DIR = "../DomainNet/" 19 | 20 | 21 | def process_txt(txt_file): 22 | with open(txt_file, "r") as f: 23 | data = [x.strip("\n") for x in f.readlines()] 24 | data = [x.split(" ") for x in data] 25 | data = [[x[0], int(x[1])] for x in data] 26 | return data 27 | 28 | 29 | def create_dataset(split_txt_file, save_path): 30 | # Get split data 31 | print("Loading data from txt file..") 32 | split_data = process_txt(split_txt_file) 33 | 34 | # Get number of images 35 | num_instances = len(split_data) 36 | 37 | # Define data-shape 38 | im_shape = (num_instances, 224, 224, 3) 39 | lbl_shape = (num_instances,) 40 | 41 | # Open an h5 file and create earrays 42 | print("Opening h5 file and creating e-arrays..") 43 | hdf5_file = h5py.File(save_path, mode="w") 44 | hdf5_file.create_dataset("images", im_shape) 45 | hdf5_file.create_dataset("labels", lbl_shape) 46 | 47 | # Store labels in dataset 48 | print("Adding labels..") 49 | hdf5_file["labels"][...] = np.array([x[1] for x in split_data]) 50 | 51 | # Store images in dataset 52 | print("Adding images..") 53 | for i in tqdm(range(num_instances)): 54 | curr_img_path = split_data[i][0] 55 | img = cv2.imread(curr_img_path) 56 | img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC) 57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 58 | 59 | # Save the image file 60 | hdf5_file["images"][i, ...] 
= img[None] 61 | 62 | hdf5_file.close() 63 | 64 | 65 | if __name__ == "__main__": 66 | # Save Path prefix for all files 67 | path_prefix = DATA_SPLIT_DIR + "tv_0.9_splits/" 68 | 69 | # Create hdf5 files for all splits of all domains 70 | for domain in all_domains: 71 | print("*" * 40) 72 | print("*" * 40) 73 | print("Creating files for domain ", domain) 74 | train_file = path_prefix + domain + "_train.txt" 75 | val_file = path_prefix + domain + "_val.txt" 76 | test_file = path_prefix + domain + "_test.txt" 77 | 78 | print("Processing split train") 79 | create_dataset(train_file, path_prefix + domain + "_train.h5") 80 | 81 | print("Processing split val") 82 | create_dataset(val_file, path_prefix + domain + "_val.h5") 83 | 84 | print("Processing split test") 85 | create_dataset(test_file, path_prefix + domain + "_test.h5") 86 | print("*" * 40) 87 | print("*" * 40) 88 | -------------------------------------------------------------------------------- /data/DomainNet/download.sh: -------------------------------------------------------------------------------- 1 | # Download dataset files 2 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip 3 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip 4 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip 5 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip 6 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip 7 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip 8 | 9 | # Download split-details 10 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/clipart_train.txt 11 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/clipart_test.txt 12 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/infograph_train.txt 13 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/infograph_test.txt 14 | wget 
http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/painting_train.txt 15 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/painting_test.txt 16 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/quickdraw_train.txt 17 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/quickdraw_test.txt 18 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/real_train.txt 19 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/real_test.txt 20 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/sketch_train.txt 21 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/sketch_test.txt 22 | 23 | # Unzip the compressed files 24 | for f in *.zip 25 | do 26 | unzip "$f" 27 | done 28 | -------------------------------------------------------------------------------- /data/PACS/download.txt: -------------------------------------------------------------------------------- 1 | (1) Download pre-read hdf5 files from - https://drive.google.com/drive/folders/1i23DCs4TJ8LQsmBiMxsxo6qZsbhiX0gw?usp=sharing 2 | (2) Place the files in ``../PACS/`` -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .domain_datasets import DomainDataset, Aggregate_DomainDataset 2 | 3 | # __all__ must contain names as strings, not the class objects themselves 4 | __all__ = ["DomainDataset", "Aggregate_DomainDataset"] 5 | 6 | -------------------------------------------------------------------------------- /dataloaders/domain_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import csv 4 | import json 5 | import h5py 6 | import copy 7 | import torch 8 | import config 9 | 10 | import numpy as np 11 | 12 | from tqdm import tqdm 13 | from PIL import Image 14 | from pprint import pprint 15 | from scipy.io import loadmat 16 | from torch.utils.data import Dataset, ConcatDataset
17 | from torchvision import transforms, utils 18 | 19 | # PACS Domains 20 | PACS_DOM_LIST = ["art_painting", "cartoon", "photo", "sketch"] 21 | 22 | # DomainNet Domains 23 | DomainNet_DOM_LIST = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] 24 | 25 | 26 | class DomainDataset(Dataset): 27 | def __init__( 28 | self, dataset_name, domain, data_split_dir, phase="train", image_transform=None 29 | ): 30 | super(DomainDataset, self).__init__() 31 | 32 | self.dataset_name = dataset_name 33 | self.domain = domain 34 | self.data_split_dir = data_split_dir 35 | self.phase = phase 36 | self.image_transform = image_transform 37 | 38 | # Load the dataset 39 | if self.dataset_name == "PACS": 40 | self.dataset = {} 41 | self.dataset_file = h5py.File(self.domain_filter(), "r") 42 | temp_imgs = np.array(self.dataset_file["images"]) 43 | temp_labels = np.array(self.dataset_file["labels"]) 44 | temp_imgs = temp_imgs[:, :, :, ::-1] 45 | temp_labels = temp_labels - 1 46 | self.dataset["images"] = temp_imgs 47 | self.dataset["labels"] = temp_labels 48 | self.dataset_file.close() 49 | self.dataset_len = self.dataset["images"].shape[0] 50 | self.n_classes = 7 51 | elif self.dataset_name == "DomainNet": 52 | self.dataset = None 53 | with h5py.File(self.domain_filter(), "r") as file: 54 | self.dataset_len = file["images"].shape[0] 55 | self.n_classes = 345 56 | else: 57 | print("Dataset not supported yet") 58 | 59 | def domain_filter(self): 60 | flist = os.listdir(self.data_split_dir) 61 | if self.dataset_name == "PACS": 62 | dom_flist = [ 63 | x for x in flist if "hdf5" in x and self.domain in x and self.phase in x 64 | ] 65 | elif self.dataset_name == "DomainNet": 66 | dom_flist = [ 67 | x for x in flist if "h5" in x and self.domain in x and self.phase in x 68 | ] 69 | else: 70 | print("Dataset not supported yet") 71 | return os.path.join(self.data_split_dir, dom_flist[0]) 72 | 73 | def __len__(self): 74 | return self.dataset_len 75 | 76 | def __getitem__(self, idx): 
77 | if self.dataset_name == "PACS": 78 | img_arr = self.dataset["images"][idx] 79 | img_lbl = self.dataset["labels"][idx] 80 | elif self.dataset_name == "DomainNet": 81 | if self.dataset is None: 82 | self.dataset = h5py.File(self.domain_filter(), "r") 83 | img_arr = self.dataset["images"][idx] 84 | img_lbl = self.dataset["labels"][idx] 85 | 86 | img_dom = self.domain 87 | 88 | # Convert the image array to an image 89 | img = Image.fromarray(np.uint8(img_arr)) 90 | 91 | # Apply image transformation 92 | if self.image_transform: 93 | img = self.image_transform(img) 94 | 95 | return (img, img_lbl, img_dom) 96 | 97 | 98 | class Aggregate_DomainDataset: 99 | def __init__( 100 | self, 101 | dataset_name, 102 | domain_list, 103 | data_split_dir, 104 | phase="train", 105 | image_transform=None, 106 | batch_size=64, 107 | num_workers=4, 108 | use_gpu=True, 109 | shuffle=True, 110 | ): 111 | super(Aggregate_DomainDataset, self).__init__() 112 | self.dataset_name = dataset_name 113 | self.domain_list = domain_list 114 | self.data_split_dir = data_split_dir 115 | self.phase = phase 116 | self.image_transform = image_transform 117 | self.batch_size = batch_size 118 | self.num_workers = num_workers 119 | self.use_gpu = use_gpu 120 | self.shuffle = shuffle 121 | 122 | # Individual Data Splits 123 | self.indiv_datasets = {} 124 | for domain in self.domain_list: 125 | self.indiv_datasets[domain] = DomainDataset( 126 | self.dataset_name, 127 | domain, 128 | self.data_split_dir, 129 | self.phase, 130 | self.image_transform, 131 | ) 132 | 133 | # Aggregate Data-split 134 | self.aggregate_dataset = ConcatDataset(list(self.indiv_datasets.values())) 135 | 136 | # Store the list of labels and domains 137 | self.instance_labels = [] 138 | self.instance_dom = [] 139 | 140 | for domain in self.domain_list: 141 | self.instance_dom += [domain] * len(self.indiv_datasets[domain]) 142 | self.instance_dom = np.array(self.instance_dom) 143 | 144 | # Creating dataloaders 145 | self.cuda = 
self.use_gpu and torch.cuda.is_available() 146 | kwargs = ( 147 | {"num_workers": self.num_workers, "pin_memory": True} if self.cuda else {} 148 | ) 149 | 150 | self.indiv_dataloaders = {} 151 | for domain in self.domain_list: 152 | self.indiv_dataloaders[domain] = torch.utils.data.DataLoader( 153 | self.indiv_datasets[domain], 154 | batch_size=self.batch_size, 155 | shuffle=self.shuffle, 156 | **kwargs, 157 | ) 158 | 159 | self.aggregate_dataloader = torch.utils.data.DataLoader( 160 | self.aggregate_dataset, 161 | batch_size=self.batch_size, 162 | shuffle=self.shuffle, 163 | **kwargs, 164 | ) 165 | 166 | # Set current dataloader 167 | self.curr_split = self.aggregate_dataset 168 | self.curr_loader = self.aggregate_dataloader 169 | 170 | def set_domain_spec_mode(self, val=True, domain=None): 171 | if val: 172 | self.curr_loader = self.indiv_dataloaders[domain] 173 | self.curr_split = self.indiv_datasets[domain] 174 | else: 175 | self.curr_loader = self.aggregate_dataloader 176 | self.curr_split = self.aggregate_dataset 177 | 178 | @classmethod 179 | def from_config(cls, config, domain_list, phase, image_transform, shuffle): 180 | _C = config 181 | return cls( 182 | dataset_name=_C.DATA.DATASET, 183 | domain_list=domain_list, 184 | data_split_dir=_C.DATA.DATA_SPLIT_DIR, 185 | phase=phase, 186 | image_transform=image_transform, 187 | batch_size=_C.DATALOADER.BATCH_SIZE, 188 | num_workers=_C.PROCESS.NUM_WORKERS, 189 | use_gpu=_C.PROCESS.USE_GPU, 190 | shuffle=shuffle, 191 | ) 192 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dmg 2 | channels: 3 | - anaconda 4 | - menpo 5 | - mlgill 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _pytorch_select=0.2=gpu_0 11 | - apipkg=1.5=py36_0 12 | - asn1crypto=0.24.0=py36_0 13 | - atomicwrites=1.3.0=py36_1 14 | - attrs=19.1.0=py36_1 15 | - backcall=0.1.0=py36_0
16 | - blas=1.0=mkl 17 | - ca-certificates=2019.10.16=0 18 | - certifi=2019.9.11=py36_0 19 | - cffi=1.12.3=py36h2e261b9_0 20 | - chardet=3.0.4=py36_1003 21 | - cloudpickle=1.2.1=py_0 22 | - cryptography=2.7=py36h1ba5d50_0 23 | - cudatoolkit=9.2=0 24 | - cudnn=7.6.0=cuda9.2_0 25 | - cycler=0.10.0=py36_0 26 | - cytoolz=0.10.0=py36h7b6447c_0 27 | - dask-core=2.3.0=py_0 28 | - dbus=1.13.6=h746ee38_0 29 | - decorator=4.4.0=py36_1 30 | - defusedxml=0.6.0=py_0 31 | - entrypoints=0.3=py36_0 32 | - execnet=1.6.1=py_0 33 | - expat=2.2.6=he6710b0_0 34 | - fontconfig=2.13.0=h9420a91_0 35 | - freetype=2.9.1=h8a8886c_1 36 | - glib=2.56.2=hd408876_0 37 | - gmp=6.1.2=h6c8ec71_1 38 | - gst-plugins-base=1.14.0=hbbd80ab_1 39 | - gstreamer=1.14.0=hb453b48_1 40 | - h5py=2.9.0=py36h7918eee_0 41 | - hdf5=1.10.4=hb1b8bf9_0 42 | - icu=58.2=h9c2bf20_1 43 | - idna=2.8=py36_0 44 | - imageio=2.5.0=py36_0 45 | - importlib_metadata=0.19=py36_0 46 | - intel-openmp=2019.4=243 47 | - ipykernel=5.1.2=py36h39e3cac_0 48 | - ipython=7.8.0=py36h39e3cac_0 49 | - ipython_genutils=0.2.0=py36_0 50 | - ipywidgets=7.5.1=py_0 51 | - jedi=0.15.1=py36_0 52 | - jinja2=2.10.1=py36_0 53 | - jpeg=9b=h024ee3a_2 54 | - jsonschema=3.0.2=py36_0 55 | - jupyter=1.0.0=py36_7 56 | - jupyter_client=5.3.1=py_0 57 | - jupyter_console=6.0.0=py36_0 58 | - jupyter_core=4.5.0=py_0 59 | - kiwisolver=1.1.0=py36he6710b0_0 60 | - libedit=3.1.20181209=hc058e9b_0 61 | - libffi=3.2.1=hd88cf55_4 62 | - libgcc-ng=9.1.0=hdf63c60_0 63 | - libgfortran-ng=7.3.0=hdf63c60_0 64 | - libpng=1.6.37=hbc83047_0 65 | - libsodium=1.0.16=h1bed415_0 66 | - libstdcxx-ng=9.1.0=hdf63c60_0 67 | - libtiff=4.0.10=h2733197_2 68 | - libuuid=1.0.3=h1bed415_2 69 | - libxcb=1.13=h1bed415_1 70 | - libxml2=2.9.9=hea5a465_1 71 | - markupsafe=1.1.1=py36h7b6447c_0 72 | - matplotlib=3.1.1=py36h5429711_0 73 | - memory_profiler=0.39=0 74 | - mistune=0.8.4=py36h7b6447c_0 75 | - mkl=2019.4=243 76 | - mkl-service=2.3.0=py36he904b0f_0 77 | - mkl_fft=1.0.14=py36ha843d7b_0 78 | - 
mkl_random=1.0.2=py36hd81dba3_0 79 | - more-itertools=7.2.0=py36_0 80 | - nbconvert=5.5.0=py_0 81 | - nbformat=4.4.0=py36_0 82 | - ncurses=6.1=he6710b0_1 83 | - networkx=2.3=py_0 84 | - ninja=1.9.0=py36hfd86e86_0 85 | - notebook=6.0.0=py36_0 86 | - numpy=1.16.4=py36h7e9f1db_0 87 | - numpy-base=1.16.4=py36hde5b4d6_0 88 | - olefile=0.46=py36_0 89 | - opencv3=3.1.0=py36_0 90 | - openssl=1.1.1d=h7b6447c_3 91 | - packaging=19.1=py36_0 92 | - pandas=0.25.1=py36he6710b0_0 93 | - pandoc=2.2.3.2=0 94 | - pandocfilters=1.4.2=py36_1 95 | - parso=0.5.1=py_0 96 | - pcre=8.43=he6710b0_0 97 | - pexpect=4.7.0=py36_0 98 | - pickleshare=0.7.5=py36_0 99 | - pillow=6.1.0=py36h34e0f95_0 100 | - pip=19.2.2=py36_0 101 | - pluggy=0.12.0=py_0 102 | - prometheus_client=0.7.1=py_0 103 | - prompt_toolkit=2.0.9=py36_0 104 | - ptyprocess=0.6.0=py36_0 105 | - py=1.8.0=py36_0 106 | - pycparser=2.19=py36_0 107 | - pygments=2.4.2=py_0 108 | - pyopenssl=19.0.0=py36_0 109 | - pyparsing=2.4.2=py_0 110 | - pyqt=5.9.2=py36h05f1152_2 111 | - pyrsistent=0.14.11=py36h7b6447c_0 112 | - pysocks=1.7.0=py36_0 113 | - pytest=5.0.1=py36_0 114 | - pytest-xdist=1.13.1=0 115 | - python=3.6.9=h265db76_0 116 | - python-dateutil=2.8.0=py36_0 117 | - pytorch=1.2.0=cuda92py36hd3e106c_0 118 | - pytz=2019.2=py_0 119 | - pywavelets=1.0.3=py36hdd07704_1 120 | - pyzmq=18.1.0=py36he6710b0_0 121 | - qt=5.9.7=h5867ecd_1 122 | - qtconsole=4.5.4=py_0 123 | - readline=7.0=h7b6447c_5 124 | - requests=2.22.0=py36_0 125 | - scikit-image=0.15.0=py36he6710b0_0 126 | - scipy=1.3.1=py36h7c811a0_0 127 | - send2trash=1.5.0=py36_0 128 | - setuptools=41.0.1=py36_0 129 | - setuptools_scm=3.3.3=py_0 130 | - sip=4.19.8=py36hf484d3e_0 131 | - six=1.12.0=py36_0 132 | - sqlite=3.29.0=h7b6447c_0 133 | - terminado=0.8.2=py36_0 134 | - testpath=0.4.2=py36_0 135 | - tk=8.6.9=hed695b0_1002 136 | - toolz=0.10.0=py_0 137 | - torchfile=0.1.0=py_0 138 | - torchvision=0.4.0=cuda92py36h1667eeb_0 139 | - tornado=6.0.3=py36h7b6447c_0 140 | - tqdm=4.35.0=py_0 
141 | - traitlets=4.3.2=py36_0 142 | - urllib3=1.24.2=py36_0 143 | - visdom=0.1.8.8=0 144 | - wcwidth=0.1.7=py36_0 145 | - webencodings=0.5.1=py36_1 146 | - websocket-client=0.56.0=py36_0 147 | - wheel=0.33.4=py36_0 148 | - widgetsnbextension=3.5.1=py36_0 149 | - xz=5.2.4=h14c3975_4 150 | - zeromq=4.3.1=he6710b0_3 151 | - zipp=0.5.2=py_0 152 | - zlib=1.2.11=h7b6447c_3 153 | - zstd=1.3.7=h0b5b093_0 154 | - pip: 155 | - appdirs==1.4.3 156 | - argh==0.26.2 157 | - black==19.3b0 158 | - click==7.0 159 | - configparser==4.0.2 160 | - docker-pycreds==0.4.0 161 | - gitdb2==2.0.6 162 | - gitpython==3.0.5 163 | - gql==0.1.0 164 | - graphql-core==2.0 165 | - nvidia-ml-py3==7.352.0 166 | - pathtools==0.1.2 167 | - promise==2.2.1 168 | - psutil==5.6.7 169 | - rx==3.0.1 170 | - seaborn==0.9.0 171 | - sentry-sdk==0.13.4 172 | - shortuuid==0.5.0 173 | - smmap2==2.0.5 174 | - subprocess32==3.5.4 175 | - toml==0.10.0 176 | - wand==0.5.7 177 | - wandb==0.8.16 178 | - watchdog==0.9.0 -------------------------------------------------------------------------------- /images/DMG_approach_preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prithv1/DMG/6b162b4958b52f4d99d51663e053f58f5a77e7cc/images/DMG_approach_preview.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_model import Basic_Model 2 | from .multihead_model import MultiHead_Model 3 | from .subnetwork_supermask_model import SubNetwork_SuperMask_Model 4 | from .supermasks import SuperMask 5 | 6 | __all__ = [ 7 | "Basic_Model", 8 | "MultiHead_Model", 9 | "SubNetwork_SuperMask_Model", 10 | "SuperMask", 11 | ] 12 | -------------------------------------------------------------------------------- /models/basic_model.py: -------------------------------------------------------------------------------- 1 | """ 2 
| Reference: https://github.com/pytorch/tutorials/blob/master/beginner_source/finetuning_torchvision_models_tutorial.py 3 | """ 4 | import os 5 | import sys 6 | import torch 7 | import torchvision 8 | import torch.utils.data 9 | 10 | import numpy as np 11 | 12 | from utils.misc import weights_init 13 | 14 | from torch import nn, optim 15 | from torch.nn import functional as F 16 | from torchvision import datasets, models, transforms 17 | 18 | ALEXNET_DROPOUT_LAYERS = ["classifier.0", "classifier.3"] 19 | VGG16_DROPOUT_LAYERS = ["classifier.2", "classifier.5"] 20 | 21 | 22 | class Basic_Model(nn.Module): 23 | def __init__( 24 | self, 25 | model_name, 26 | num_classes, 27 | split_layer, 28 | init_setting="custom", 29 | use_pretrained=True, 30 | ): 31 | super(Basic_Model, self).__init__() 32 | self.model_name = model_name 33 | self.num_classes = num_classes 34 | self.split_layer = split_layer 35 | self.use_pretrained = use_pretrained 36 | 37 | self.criterion = nn.CrossEntropyLoss() 38 | 39 | self.input_size = 0 40 | 41 | self.model_fn = getattr(models, self.model_name) 42 | self.model_ft = self.model_fn(pretrained=self.use_pretrained) 43 | 44 | # Iterate over specific model types 45 | if self.model_name == "alexnet": 46 | num_feats = self.model_ft.classifier[6].in_features 47 | self.input_size = 224 48 | self.model_ft.classifier[6] = nn.Linear(num_feats, self.num_classes) 49 | if init_setting == "custom": 50 | self.model_ft.classifier[6].apply(weights_init) 51 | print(self.model_ft.classifier[6].weight) 52 | print(self.model_ft.classifier[6].bias) 53 | elif "resnet" in self.model_name: 54 | num_feats = self.model_ft.fc.in_features 55 | self.input_size = 224 56 | self.model_ft.fc = nn.Linear(num_feats, self.num_classes) 57 | if init_setting == "custom": 58 | self.model_ft.fc.apply(weights_init) 59 | else: 60 | print("Model type not supported yet") 61 | 62 | def forward(self, img): 63 | scores = self.model_ft(img) 64 | return scores 65 | 66 | def loss_fn(self, scores, 
labels): 67 | loss_val = self.criterion(scores, labels) 68 | return loss_val 69 | 70 | def loss_gpu(self, flag=True): 71 | if flag: 72 | self.criterion.cuda() 73 | 74 | @classmethod 75 | def from_config(cls, config): 76 | _C = config 77 | return cls( 78 | model_name=_C.MODEL.BASE_MODEL, 79 | num_classes=_C.MODEL.NUM_CLASSES, 80 | split_layer=_C.MODEL.SPLIT_LAYER, 81 | init_setting=_C.MODEL.PARAM_INIT, 82 | use_pretrained=_C.MODEL.USE_PRETRAINED, 83 | ) 84 | -------------------------------------------------------------------------------- /models/multihead_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/pytorch/tutorials/blob/master/beginner_source/finetuning_torchvision_models_tutorial.py 3 | """ 4 | import os 5 | import sys 6 | import torch 7 | import torchvision 8 | import torch.utils.data 9 | 10 | import numpy as np 11 | 12 | from utils.misc import weights_init 13 | 14 | from pprint import pprint 15 | from torch import nn, optim 16 | from torch.nn import functional as F 17 | from torchvision import datasets, models, transforms 18 | 19 | # Possible split layers for different model architectures 20 | alexnet_task_net_layer = ["classifier." + str(x) for x in range(0, 7)] 21 | vgg16_task_net_layers = ["classifier." + str(x) for x in range(0, 7)] 22 | vgg16_bn_task_net_layers = ["classifier." 
+ str(x) for x in range(0, 7)] 23 | resnet18_task_net_layers = ["fc"] 24 | resnet50_task_net_layers = ["fc"] 25 | 26 | ALEXNET_POSS_SPLIT_LAYERS = [ 27 | "classifier.0", # the entire network post pool5 28 | "classifier.1", # Following linear layer 29 | "classifier.4", # Following linear layer 30 | "classifier.6", # Last linear layer 31 | ] 32 | 33 | RESNET_18_POSS_SPLIT_LAYERS = ["fc"] # Only the last classifier layer 34 | 35 | RESNET_50_POSS_SPLIT_LAYERS = ["fc"] # Only the last classifier layer 36 | 37 | # Identity class 38 | class Identity(nn.Module): 39 | def __init__(self): 40 | super(Identity, self).__init__() 41 | 42 | def forward(self, x): 43 | return x 44 | 45 | 46 | class MultiHead_Model(nn.Module): 47 | def __init__( 48 | self, 49 | domain_list, 50 | model_name, 51 | num_classes, 52 | task_net_layer=None, 53 | init_setting="custom", 54 | use_pretrained=True, 55 | ): 56 | super(MultiHead_Model, self).__init__() 57 | self.model_name = model_name 58 | self.num_classes = num_classes 59 | self.use_pretrained = use_pretrained 60 | self.domain_list = domain_list 61 | self.task_net_layer = task_net_layer 62 | 63 | self.criterion = nn.CrossEntropyLoss() 64 | 65 | self.input_size = 0 66 | 67 | # Define model structure based on the specified arguments 68 | # This decides the feature-network and task-network splits 69 | self.model_fn = getattr(models, self.model_name) 70 | self.model_ft = self.model_fn(pretrained=self.use_pretrained) 71 | if self.model_name == "alexnet": 72 | num_feats = self.model_ft.classifier[6].in_features 73 | self.input_size = 224 74 | self.model_ft.classifier[6] = nn.Linear(num_feats, self.num_classes) 75 | if init_setting == "custom": 76 | self.model_ft.classifier[6].apply(weights_init) 77 | 78 | # Get a list of classifier layers 79 | classifier_layers = ["classifier." 
+ str(x) for x in range(0, 7)] 80 | 81 | # Identify which classifier layer is being split at 82 | classifier_ind = int(self.task_net_layer.split(".")[-1]) 83 | 84 | # Get the whole module list 85 | module_list = list( 86 | self.model_ft.classifier[ 87 | classifier_ind : len(classifier_layers) 88 | ].children() 89 | ) 90 | 91 | # Create task networks for every domain 92 | self.domain_task_nets = nn.ModuleDict( 93 | {x: nn.Sequential(*module_list) for x in self.domain_list} 94 | ) 95 | 96 | # Make older versions identity 97 | for i in range(classifier_ind, 7): 98 | self.model_ft.classifier[i] = Identity() 99 | 100 | elif "resnet" in self.model_name: 101 | num_feats = self.model_ft.fc.in_features 102 | self.input_size = 224 103 | 104 | # Create task networks 105 | self.domain_task_nets = nn.ModuleDict( 106 | { 107 | x: nn.Sequential(nn.Linear(num_feats, self.num_classes)) 108 | for x in self.domain_list 109 | } 110 | ) 111 | 112 | # Weight and bias initialization 113 | if init_setting == "custom": 114 | for x in self.domain_list: 115 | self.domain_task_nets[x][0].apply(weights_init) 116 | 117 | self.model_ft.fc = nn.Identity() 118 | else: 119 | print("Model type not supported") 120 | 121 | # Note that our forward pass 122 | # has to be aware of the domain-ID 123 | # It has to route examples to specific 124 | # domain task networks accordingly 125 | def forward(self, img, dom): 126 | # Extract features first 127 | feats = self.model_ft(img) 128 | # Convert mini-batch domain list to list of indices 129 | scores = [] 130 | for i in range(len(dom)): 131 | scores.append(self.domain_task_nets[dom[i]](torch.unsqueeze(feats[i], 0))) 132 | scores = torch.cat(scores) 133 | return scores 134 | 135 | def loss_fn(self, scores, labels): 136 | # Basic cross-entropy criterion 137 | loss_val = self.criterion(scores, labels) 138 | return loss_val 139 | 140 | def loss_gpu(self, flag=True): 141 | if flag: 142 | self.criterion.cuda() 143 | 144 | # Now we need another evaluation time 
forward mode 145 | # We average the raw un-normalized scores 146 | # from all the heads. One instance -> through all heads 147 | def avg_forward(self, img): 148 | # Extract features 149 | feats = self.model_ft(img) 150 | # Extract scores from all the 151 | # domain-specific task networks 152 | scores = [] 153 | for dom in self.domain_list: 154 | scores.append(self.domain_task_nets[dom](feats)) 155 | 156 | return torch.mean(torch.stack(scores), dim=0) 157 | 158 | # Average probability and then forward 159 | def avg_prob_forward(self, img): 160 | # Extract features 161 | feats = self.model_ft(img) 162 | # Extract scores from all the 163 | # domain-specific task networks 164 | scores = [] 165 | for dom in self.domain_list: 166 | scores.append(nn.Softmax(dim=1)(self.domain_task_nets[dom](feats))) 167 | 168 | return torch.mean(torch.stack(scores), dim=0) 169 | 170 | @classmethod 171 | def from_config(cls, config): 172 | _C = config 173 | domains = _C.DATA.DOMAIN_LIST 174 | if "," in domains: 175 | domains = _C.DATA.DOMAIN_LIST.split(",") 176 | return cls( 177 | domain_list=domains, 178 | model_name=_C.MODEL.BASE_MODEL, 179 | num_classes=_C.MODEL.NUM_CLASSES, 180 | task_net_layer=_C.MODEL.SPLIT_LAYER, 181 | init_setting=_C.MODEL.PARAM_INIT, 182 | use_pretrained=_C.MODEL.USE_PRETRAINED, 183 | ) 184 | -------------------------------------------------------------------------------- /models/subnetwork_supermask_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torchvision 5 | import torch.utils.data 6 | 7 | import numpy as np 8 | 9 | from pprint import pprint 10 | from torch import nn, optim 11 | from torch.nn import functional as F 12 | from torchvision import datasets, models, transforms 13 | 14 | ALEXNET_LAYERS = ["classifier.1", "classifier.4", "classifier.6"] 15 | RESNET_LAYERS = ["layer3", "layer4", "fc"] 16 | 17 | ALEXNET_DROPOUT_LAYERS = ["classifier.0", "classifier.3"] 
18 | 19 | # Specify the sizes of the input 20 | # activations for each of the valid masking layers 21 | RESNET18_LAYER_SIZES = {"layer3": 128 * 28 * 28, "layer4": 256 * 14 * 14, "fc": 512} 22 | 23 | RESNET50_LAYER_SIZES = {"layer3": 512 * 28 * 28, "layer4": 1024 * 14 * 14, "fc": 2048} 24 | 25 | RESNET18_CNV_LAYER_SIZES = {"layer3": 128, "layer4": 256, "fc": 512} 26 | 27 | RESNET50_CNV_LAYER_SIZES = {"layer3": 512, "layer4": 1024, "fc": 2048} 28 | 29 | 30 | class SubNetwork_SuperMask_Model(nn.Module): 31 | def __init__(self, mask_layers, joint_model): 32 | super(SubNetwork_SuperMask_Model, self).__init__() 33 | 34 | # Mask-Layers can be provided as 35 | # [classifier.4, classifier.6] 36 | self.mask_layers = mask_layers 37 | 38 | # The actual multi-head model 39 | self.joint_model = joint_model 40 | 41 | # Store base model in a variable 42 | self.base_model = self.joint_model.model_name 43 | 44 | # List legal mask areas 45 | self.legal_mask_areas = { 46 | "alexnet": ALEXNET_LAYERS, 47 | "resnet18": RESNET_LAYERS, 48 | "resnet50": RESNET_LAYERS, 49 | } 50 | 51 | # We'll go over the model structure and identify 52 | # legal mask / regions -- layers at whose input 53 | # activations it is legal to apply the mask 54 | self.legal_mask_areas, self.mask_areas = {}, {} 55 | 56 | # Define loss 57 | # The `input` is expected to contain raw, unnormalized scores for each class and associated label. 
58 | self.criterion = nn.CrossEntropyLoss() 59 | 60 | def get_mask_struct(self, conv_mode=False): 61 | # Retrieve the input activation sizes 62 | act_size = [] 63 | if self.base_model in ["alexnet", "vgg16", "vgg16_bn"]: 64 | for mask_layer in self.mask_layers: 65 | if "classifier" in mask_layer: 66 | act_size.append( 67 | self.joint_model.model_ft.classifier[ 68 | int(mask_layer.split(".")[-1]) 69 | ].in_features 70 | ) 71 | else: 72 | print("Masking this layer is not supported yet") 73 | elif self.base_model == "resnet18": 74 | for mask_layer in self.mask_layers: 75 | if mask_layer == "layer3": 76 | if conv_mode: 77 | act_size.append(RESNET18_CNV_LAYER_SIZES[mask_layer]) 78 | else: 79 | act_size.append(RESNET18_LAYER_SIZES[mask_layer]) 80 | elif mask_layer == "layer4": 81 | if conv_mode: 82 | act_size.append(RESNET18_CNV_LAYER_SIZES[mask_layer]) 83 | else: 84 | act_size.append(RESNET18_LAYER_SIZES[mask_layer]) 85 | elif mask_layer == "fc": 86 | if conv_mode: 87 | act_size.append(RESNET18_CNV_LAYER_SIZES[mask_layer]) 88 | else: 89 | act_size.append(RESNET18_LAYER_SIZES[mask_layer]) 90 | elif self.base_model == "resnet50": 91 | for mask_layer in self.mask_layers: 92 | if mask_layer == "layer3": 93 | if conv_mode: 94 | act_size.append(RESNET50_CNV_LAYER_SIZES[mask_layer]) 95 | else: 96 | act_size.append(RESNET50_LAYER_SIZES[mask_layer]) 97 | elif mask_layer == "layer4": 98 | if conv_mode: 99 | act_size.append(RESNET50_CNV_LAYER_SIZES[mask_layer]) 100 | else: 101 | act_size.append(RESNET50_LAYER_SIZES[mask_layer]) 102 | elif mask_layer == "fc": 103 | if conv_mode: 104 | act_size.append(RESNET50_CNV_LAYER_SIZES[mask_layer]) 105 | else: 106 | act_size.append(RESNET50_LAYER_SIZES[mask_layer]) 107 | else: 108 | print("Model not supported yet") 109 | return act_size 110 | 111 | def set_dropout_eval(self, flag=True): 112 | # Set the dropout to eval mode for the 113 | # specified networks -- only alexnet, vgg16 114 | if self.base_model == "alexnet": 115 | dropout_layers = 
ALEXNET_DROPOUT_LAYERS 116 | for dropout_layer in dropout_layers: 117 | if flag: 118 | self.joint_model.model_ft.classifier[ 119 | int(dropout_layer.split(".")[-1]) 120 | ].eval() 121 | else: 122 | self.joint_model.model_ft.classifier[ 123 | int(dropout_layer.split(".")[-1]) 124 | ].train() 125 | elif self.base_model == "resnet18": 126 | pass 127 | elif self.base_model == "resnet50": 128 | pass 129 | else: 130 | print("Base model not supported yet") 131 | 132 | def classifier_dom_mask_forward( 133 | self, img, policy_modules, policy_domain, mode="sample" 134 | ): 135 | # This is only for alexnet and vgg-16 based architectures 136 | # We don't need these for resnet-18 and 50 137 | prob_ls, action_ls, logProb_ls = [], [], [] 138 | 139 | # Extract features 140 | feats = self.joint_model.model_ft.features(img) 141 | feats = self.joint_model.model_ft.avgpool(feats) 142 | feats = feats.view(feats.size(0), -1) 143 | # Go through the classifier 144 | for j in range(len(self.joint_model.model_ft.classifier)): 145 | if "classifier." + str(j) in self.mask_layers: 146 | rel_ind = self.mask_layers.index("classifier." 
+ str(j)) 147 | feats, action, probs = policy_modules[rel_ind]( 148 | feats, policy_domain, mode 149 | ) 150 | prob_ls.append(probs) 151 | action_ls.append(action) 152 | feats = self.joint_model.model_ft.classifier[j](feats) 153 | else: 154 | feats = self.joint_model.model_ft.classifier[j](feats) 155 | 156 | scores = feats 157 | return scores, prob_ls, action_ls, logProb_ls 158 | 159 | def forward( 160 | self, img, policy_modules, policy_domain, mode="sample", conv_mode=False 161 | ): 162 | 163 | # Check if the number of layer indices match 164 | # match the number of policy modules 165 | assert len(self.mask_layers) == len( 166 | policy_modules 167 | ), "Unequal number of layers and modules" 168 | 169 | # Data structures to store the scores, 170 | # probabilities, actions and logProbs 171 | prob_ls, action_ls, logProb_ls = [], [], [] 172 | 173 | if self.base_model == "alexnet": 174 | scores, prob_ls, action_ls, logProb_ls = self.classifier_dom_mask_forward( 175 | img, policy_modules, policy_domain, mode 176 | ) 177 | elif self.base_model in ["resnet18", "resnet50"]: 178 | 179 | # Forward pass based on the specified masking layers 180 | feats = self.joint_model.model_ft.conv1(img) 181 | feats = self.joint_model.model_ft.bn1(feats) 182 | feats = self.joint_model.model_ft.relu(feats) 183 | feats = self.joint_model.model_ft.maxpool(feats) 184 | feats = self.joint_model.model_ft.layer1(feats) 185 | feats = self.joint_model.model_ft.layer2(feats) 186 | 187 | if "layer3" in self.mask_layers: 188 | feat_shape = feats.shape 189 | pol_ind = self.mask_layers.index("layer3") 190 | if not conv_mode: 191 | feats = feats.view(feat_shape[0], -1) 192 | feats, action, probs = policy_modules[pol_ind]( 193 | feats, policy_domain, mode, conv_mode 194 | ) 195 | prob_ls.append(probs) 196 | action_ls.append(action) 197 | feats = feats.view(feat_shape) 198 | feats = self.joint_model.model_ft.layer3(feats) 199 | else: 200 | feats = self.joint_model.model_ft.layer3(feats) 201 | 202 | if 
"layer4" in self.mask_layers: 203 | feat_shape = feats.shape 204 | pol_ind = self.mask_layers.index("layer4") 205 | if not conv_mode: 206 | feats = feats.view(feat_shape[0], -1) 207 | feats, action, probs = policy_modules[pol_ind]( 208 | feats, policy_domain, mode, conv_mode 209 | ) 210 | prob_ls.append(probs) 211 | action_ls.append(action) 212 | feats = feats.view(feat_shape) 213 | feats = self.joint_model.model_ft.layer4(feats) 214 | else: 215 | feats = self.joint_model.model_ft.layer4(feats) 216 | 217 | feats = self.joint_model.model_ft.avgpool(feats) 218 | 219 | feats = feats.view(feats.size(0), -1) 220 | if "fc" in self.mask_layers: 221 | pol_ind = self.mask_layers.index("fc") 222 | feats, action, probs = policy_modules[pol_ind]( 223 | feats, policy_domain, mode, conv_mode 224 | ) 225 | prob_ls.append(probs) 226 | action_ls.append(action) 227 | scores = self.joint_model.model_ft.fc(feats) 228 | del feats, probs, action 229 | else: 230 | print("Model not supported yet") 231 | return scores, prob_ls, action_ls 232 | 233 | def loss_fn(self, scores, labels): 234 | # Basic cross-entropy criterion 235 | # loss_val = self.criterion(scores, labels) 236 | loss_val = F.cross_entropy(scores, labels, reduction="none") 237 | return loss_val 238 | 239 | @classmethod 240 | def from_config(cls, config, joint_model, mask_layers): 241 | _C = config 242 | return cls(mask_layers=mask_layers, joint_model=joint_model) 243 | 244 | -------------------------------------------------------------------------------- /models/supermasks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torchvision 5 | import torch.utils.data 6 | 7 | import numpy as np 8 | 9 | from pprint import pprint 10 | from itertools import combinations 11 | 12 | from torch import nn, optim 13 | from torch.nn import functional as F 14 | from torch.distributions import Bernoulli, RelaxedBernoulli 15 | from torchvision import 
datasets, models, transforms 16 | 17 | SMOOTH = 1e-6 18 | 19 | 20 | class SuperMask(nn.Module): 21 | def __init__(self, domain_list, act_size, init_setting="random", init_scalar=1): 22 | super(SuperMask, self).__init__() 23 | self.domain_list = domain_list 24 | self.act_size = act_size 25 | self.init_setting = init_setting 26 | self.init_scalar = init_scalar 27 | 28 | # Define the super mask logits 29 | if self.init_setting == "random_uniform": 30 | self.super_mask_logits = nn.ParameterDict( 31 | { 32 | x: nn.Parameter(torch.rand(self.act_size, requires_grad=True)) 33 | for x in self.domain_list 34 | } 35 | ) 36 | elif self.init_setting == "scalar": 37 | param_tensor = torch.ones(self.act_size, requires_grad=True) 38 | param_tensor = param_tensor.new_tensor( 39 | [self.init_scalar] * self.act_size, requires_grad=True 40 | ) 41 | self.super_mask_logits = nn.ParameterDict( 42 | {x: nn.Parameter(param_tensor.clone()) for x in self.domain_list} 43 | ) 44 | 45 | def forward(self, activation, domain, mode="sample", conv_mode=False): 46 | # Mask repeated along channel dimensions if conv_mode == True 47 | probs = [nn.Sigmoid()(self.super_mask_logits[x]) for x in domain] 48 | probs = torch.stack(probs) 49 | if mode == "sample": 50 | mask_dist = Bernoulli(probs) 51 | hard_mask = mask_dist.sample() 52 | soft_mask = probs 53 | mask = (hard_mask - soft_mask).detach() + soft_mask 54 | if conv_mode and len(activation.shape) > 2: 55 | apply_mask = mask.view(mask.shape[0], mask.shape[1], 1, 1) 56 | apply_mask = apply_mask.repeat( 57 | 1, 1, activation.shape[2], activation.shape[3] 58 | ) 59 | activation = apply_mask * activation 60 | else: 61 | activation = mask * activation 62 | elif mode == "greedy": 63 | hard_mask = (probs > 0.5).float() 64 | soft_mask = probs 65 | mask = (hard_mask - soft_mask).detach() + soft_mask 66 | if conv_mode and len(activation.shape) > 2: 67 | apply_mask = mask.view(mask.shape[0], mask.shape[1], 1, 1) 68 | apply_mask = apply_mask.repeat( 69 | 1, 1, 
activation.shape[2], activation.shape[3] 70 | ) 71 | activation = apply_mask * activation 72 | else: 73 | activation = mask * activation 74 | elif mode == "softscale": 75 | hard_mask = (probs > 0.5).float() 76 | soft_mask = probs 77 | mask = hard_mask 78 | if conv_mode and len(activation.shape) > 2: 79 | apply_mask = soft_mask.view( 80 | soft_mask.shape[0], soft_mask.shape[1], 1, 1 81 | ) 82 | apply_mask = apply_mask.repeat( 83 | 1, 1, activation.shape[2], activation.shape[3] 84 | ) 85 | activation = apply_mask * activation 86 | else: 87 | activation = soft_mask * activation 88 | elif mode == "avg_mask_softscale": 89 | # Average all the source domain masks 90 | # instead of combining them 91 | all_probs = [ 92 | nn.Sigmoid()(self.super_mask_logits[x]) for x in self.domain_list 93 | ] 94 | all_probs = torch.mean(torch.stack(all_probs), 0) 95 | mean_mask = [all_probs for x in domain] 96 | mean_mask = torch.stack(mean_mask) 97 | soft_mask = mean_mask 98 | hard_mask = (mean_mask > 0.5).float() 99 | mask = hard_mask 100 | if conv_mode and len(activation.shape) > 2: 101 | apply_mask = soft_mask.view( 102 | soft_mask.shape[0], soft_mask.shape[1], 1, 1 103 | ) 104 | apply_mask = apply_mask.repeat( 105 | 1, 1, activation.shape[2], activation.shape[3] 106 | ) 107 | activation = apply_mask * activation 108 | else: 109 | activation = soft_mask * activation 110 | 111 | return (activation, mask, soft_mask) 112 | 113 | def sparsity(self, mask): 114 | return torch.mean(mask, dim=1) 115 | 116 | def sparsity_penalty(self): 117 | sparse_pen = 0 118 | for _, v in self.super_mask_logits.items(): 119 | sparse_pen += torch.sum(nn.Sigmoid()(v)) 120 | return sparse_pen 121 | 122 | def overlap_penalty(self): 123 | overlap_pen = 0 124 | domain_pairs = list(combinations(self.domain_list, 2)) 125 | for pair in domain_pairs: 126 | dom1, dom2 = pair 127 | mask1 = nn.Sigmoid()(self.super_mask_logits[dom1]) 128 | mask2 = nn.Sigmoid()(self.super_mask_logits[dom2]) 129 | intersection = 
torch.sum(mask1 * mask2) 130 | union = torch.sum(mask1 + mask2 - mask1 * mask2) 131 | iou = (intersection + SMOOTH) / (union + SMOOTH) 132 | overlap_pen += iou 133 | overlap_pen /= len(domain_pairs) 134 | return overlap_pen 135 | 136 | def mask_overlap(self, layer_name=""): 137 | if layer_name != "": 138 | prefix = layer_name + " : " 139 | else: 140 | prefix = "" 141 | domain_pairs = combinations(self.domain_list, 2) 142 | iou_overlap_dict = {} 143 | for pair in domain_pairs: 144 | mask_0 = nn.Sigmoid()(self.super_mask_logits[pair[0]]) 145 | mask_1 = nn.Sigmoid()(self.super_mask_logits[pair[1]]) 146 | mask_0 = mask_0 > 0.5 147 | mask_1 = mask_1 > 0.5 148 | intersection = (mask_0 & mask_1).float().sum() 149 | union = (mask_0 | mask_1).float().sum() 150 | iou = (intersection + SMOOTH) / (union + SMOOTH) 151 | iou_overlap_dict[ 152 | prefix + pair[0] + ", " + pair[1] + " IoU-Ov" 153 | ] = iou.data.item() 154 | iou_overlap_dict[prefix + "overall IoU-Ov"] = np.mean( 155 | [x for x in list(iou_overlap_dict.values())] 156 | ) 157 | return iou_overlap_dict 158 | 159 | @classmethod 160 | def from_config(cls, config, act_size): 161 | _C = config 162 | domains = _C.DATA.DOMAIN_LIST 163 | if "," in domains: 164 | domains = _C.DATA.DOMAIN_LIST.split(",") 165 | return cls( 166 | domains, act_size, _C.MODEL.MASK_INIT_SETTING, _C.MODEL.MASK_INIT_SCALAR 167 | ) 168 | -------------------------------------------------------------------------------- /run_jobs.sh: -------------------------------------------------------------------------------- 1 | # DomainNet Experiments 2 | # To run experiments on PACS, change configs/DomainNet -> configs/PACS 3 | 4 | # Aggregate Training (AlexNet) 5 | python train_model.py \ 6 | --phase aggregate_training \ 7 | --config-yml configs/DomainNet/aggregate_training.yml \ 8 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ 9 | DATA.TARGET_DOMAINS sketch HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL alexnet 10 | 11 | # Aggregate 
Training (ResNet-18) 12 | python train_model.py \ 13 | --phase aggregate_training \ 14 | --config-yml configs/DomainNet/aggregate_training.yml \ 15 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ 16 | DATA.TARGET_DOMAINS sketch HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL resnet18 17 | 18 | # Aggregate Training (ResNet-50) 19 | python train_model.py \ 20 | --phase aggregate_training \ 21 | --config-yml configs/DomainNet/aggregate_training.yml \ 22 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ 23 | DATA.TARGET_DOMAINS sketch HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL resnet50 24 | 25 | # Multi-Head Training (AlexNet) 26 | python train_model.py \ 27 | --phase multihead_training \ 28 | --config-yml configs/DomainNet/multihead_training.yml \ 29 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ 30 | DATA.TARGET_DOMAINS sketch HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL alexnet \ 31 | MODEL.SPLIT_LAYER classifier.6 32 | 33 | # Multi-Head Training (ResNet-18) 34 | python train_model.py \ 35 | --phase multihead_training \ 36 | --config-yml configs/DomainNet/multihead_training.yml \ 37 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ 38 | DATA.TARGET_DOMAINS sketch HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL resnet18 \ 39 | MODEL.SPLIT_LAYER fc 40 | 41 | # Multi-Head Training (ResNet-50) 42 | python train_model.py \ 43 | --phase multihead_training \ 44 | --config-yml configs/DomainNet/multihead_training.yml \ 45 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real \ 46 | DATA.TARGET_DOMAINS sketch HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL resnet50 \ 47 | MODEL.SPLIT_LAYER fc 48 | 49 | # Domain-Specific Masks for Generalization (AlexNet) 50 | python train_model.py \ 51 | --phase supermask_training \ 52 | --config-yml configs/DomainNet/supermask_training.yml \ 53 | --config-override DATA.DOMAIN_LIST 
clipart,infograph,painting,quickdraw,real DATA.TARGET_DOMAINS sketch \ 54 | HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL alexnet \ 55 | MODEL.MASK_LAYERS classifier.1,classifier.4,classifier.6 \ 56 | MODEL.MASK_INIT_SETTING random_uniform MODEL.POLICY_CONV_MODE False 57 | 58 | # Domain-Specific Masks for Generalization (ResNet-18) 59 | python train_model.py \ 60 | --phase supermask_training \ 61 | --config-yml configs/DomainNet/supermask_training.yml \ 62 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real DATA.TARGET_DOMAINS sketch \ 63 | HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL resnet18 \ 64 | MODEL.MASK_LAYERS layer3,layer4,fc \ 65 | MODEL.MASK_INIT_SETTING random_uniform MODEL.POLICY_CONV_MODE True 66 | 67 | # Domain-Specific Masks for Generalization (ResNet-50) 68 | python train_model.py \ 69 | --phase supermask_training \ 70 | --config-yml configs/DomainNet/supermask_training.yml \ 71 | --config-override DATA.DOMAIN_LIST clipart,infograph,painting,quickdraw,real DATA.TARGET_DOMAINS sketch \ 72 | HJOB.JOB_STRING dmnt_v1 MODEL.BASE_MODEL resnet50 \ 73 | MODEL.MASK_LAYERS layer3,layer4,fc \ 74 | MODEL.MASK_INIT_SETTING random_uniform MODEL.POLICY_CONV_MODE True -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import copy 5 | import time 6 | import torch 7 | import random 8 | import argparse 9 | import torchvision 10 | import dataloaders 11 | import torch.utils.data 12 | 13 | import numpy as np 14 | 15 | from tqdm import tqdm 16 | from config import Config 17 | from torch import nn, optim 18 | from pprint import pprint, pformat 19 | from torch.autograd import Variable 20 | from torch.nn import functional as F 21 | from utils.inverse_lr_scheduler import InvLR 22 | from torchvision import datasets, transforms 23 | 24 | # Import all the model definitions 25 | from 
models import Basic_Model, MultiHead_Model, SubNetwork_SuperMask_Model, SuperMask 26 | 27 | # Import Dataloaders 28 | from dataloaders import DomainDataset, Aggregate_DomainDataset 29 | 30 | # Import all the trainers and evaluators 31 | from trainers import Aggregate_Trainer, MultiHead_Trainer, SubNetwork_SuperMask_Trainer 32 | 33 | parser = argparse.ArgumentParser("Run training for a particular phase") 34 | 35 | parser.add_argument( 36 | "--phase", 37 | required=True, 38 | choices=["aggregate_training", "multihead_training", "supermask_training"], 39 | help="Which phase to train, this must match the 'PHASE' parameter in the provided config.", 40 | ) 41 | 42 | parser.add_argument( 43 | "--config-yml", required=True, help="Path to a config file for a training job" 44 | ) 45 | 46 | parser.add_argument( 47 | "--config-override", 48 | default=[], 49 | nargs="*", 50 | help="A sequence of key-value pairs specifying certain config arguments (with dict-like nesting) using a dot operator. The actual config will be updated and recorded in the serialization directory.", 51 | ) 52 | 53 | if __name__ == "__main__": 54 | # Parse arguments 55 | _A = parser.parse_args() 56 | 57 | # Create a config with default values, then override from config file and _A. 58 | # This config object is immutable, nothing can be changed in this anymore 59 | _C = Config(_A.config_yml, _A.config_override) 60 | 61 | # Match the phase from arguments and config parameters 62 | if _A.phase != _C.HJOB.PHASE: 63 | raise ValueError( 64 | f"Provided `--phase` as {_A.phase}, does not match config PHASE ({_C.HJOB.PHASE})." 
65 | ) 66 | 67 | # Print configs and args 68 | for arg in vars(_A): 69 | print("{:<20}: {}".format(arg, getattr(_A, arg))) 70 | 71 | # Display config to be used 72 | print(_C) 73 | 74 | # Get environment name 75 | ENV_NAME = _C.get_env() 76 | pprint(ENV_NAME) 77 | 78 | # Get checkpoint directory 79 | CKPT_DIR = _C.CKPT.STORAGE_DIR + _C.DATA.DATASET + "/checkpoints/" + ENV_NAME 80 | 81 | # Create directory and save config 82 | if not os.path.exists(CKPT_DIR): 83 | os.makedirs(CKPT_DIR) 84 | _C.dump(os.path.join(CKPT_DIR, "config.yml")) 85 | 86 | # Fix seeds for reproducibility 87 | # Reference - https://pytorch.org/docs/stable/notes/randomness.html 88 | # These five lines control all the major sources of randomness. 89 | np.random.seed(_C.HJOB.RANDOM_SEED) 90 | torch.manual_seed(_C.HJOB.RANDOM_SEED) 91 | torch.cuda.manual_seed_all(_C.HJOB.RANDOM_SEED) 92 | torch.backends.cudnn.benchmark = False 93 | torch.backends.cudnn.deterministic = True 94 | 95 | # Model Definition 96 | if _A.phase == "aggregate_training": 97 | model = Basic_Model.from_config(_C) 98 | input_size = model.input_size 99 | elif _A.phase == "multihead_training": 100 | model = MultiHead_Model.from_config(_C) 101 | input_size = model.input_size 102 | elif _A.phase == "supermask_training": 103 | MASK_LAYERS = [_C.MODEL.MASK_LAYERS] 104 | if "," in _C.MODEL.MASK_LAYERS: 105 | MASK_LAYERS = _C.MODEL.MASK_LAYERS.split(",") 106 | MASK_LAYERS = sorted(MASK_LAYERS) 107 | # Define joint model 108 | joint_model = Basic_Model.from_config(_C) 109 | input_size = joint_model.input_size 110 | # Define conditional computation model 111 | model = SubNetwork_SuperMask_Model.from_config(_C, joint_model, MASK_LAYERS) 112 | # Get init arguments for the policy modules 113 | act_size = model.get_mask_struct(_C.MODEL.POLICY_CONV_MODE) 114 | # Create policy modules 115 | policy_modules = [] 116 | for i in range(len(MASK_LAYERS)): 117 | policy_modules.append(SuperMask.from_config(_C, act_size[i])) 118 | else: 119 | 
print("Training phase not supported") 120 | 121 | # Move model to GPU 122 | if _C.PROCESS.USE_GPU: 123 | model.cuda() 124 | if _A.phase == "supermask_training": 125 | for i in range(len(MASK_LAYERS)): 126 | policy_modules[i].cuda() 127 | 128 | # Data Transformations 129 | data_transforms = { 130 | "val": transforms.Compose( 131 | [ 132 | transforms.Resize((input_size, input_size)), 133 | transforms.ToTensor(), 134 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 135 | ] 136 | ), 137 | "test": transforms.Compose( 138 | [ 139 | transforms.Resize((input_size, input_size)), 140 | transforms.ToTensor(), 141 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 142 | ] 143 | ), 144 | } 145 | 146 | if _C.DATA.DATASET == "PACS": 147 | data_transforms["train"] = transforms.Compose( 148 | [ 149 | transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), 150 | transforms.RandomHorizontalFlip(), 151 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), 152 | transforms.ToTensor(), 153 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 154 | ] 155 | ) 156 | elif _C.DATA.DATASET == "DomainNet": 157 | data_transforms["train"] = transforms.Compose( 158 | [ 159 | transforms.Resize(256), 160 | transforms.RandomCrop(input_size), 161 | transforms.RandomHorizontalFlip(), 162 | transforms.ToTensor(), 163 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 164 | ] 165 | ) 166 | 167 | DOMAIN_LIST = _C.DATA.DOMAIN_LIST 168 | if "," in _C.DATA.DOMAIN_LIST: 169 | DOMAIN_LIST = _C.DATA.DOMAIN_LIST.split(",") 170 | else: 171 | DOMAIN_LIST = [_C.DATA.DOMAIN_LIST] 172 | 173 | TARGET_DOMAINS = _C.DATA.TARGET_DOMAINS 174 | if "," in _C.DATA.TARGET_DOMAINS: 175 | TARGET_DOMAINS = _C.DATA.TARGET_DOMAINS.split(",") 176 | else: 177 | TARGET_DOMAINS = [_C.DATA.TARGET_DOMAINS] 178 | 179 | # Dataloader Definitions 180 | train_split_obj = Aggregate_DomainDataset( 181 | _C.DATA.DATASET, 182 | DOMAIN_LIST, 183 | _C.DATA.DATA_SPLIT_DIR, 184 | 
"train", 185 | data_transforms["train"], 186 | _C.DATALOADER.BATCH_SIZE, 187 | _C.PROCESS.NUM_WORKERS, 188 | _C.PROCESS.USE_GPU, 189 | shuffle=True, 190 | ) 191 | 192 | val_split_obj = Aggregate_DomainDataset( 193 | _C.DATA.DATASET, 194 | DOMAIN_LIST, 195 | _C.DATA.DATA_SPLIT_DIR, 196 | "val", 197 | data_transforms["val"], 198 | _C.DATALOADER.BATCH_SIZE, 199 | _C.PROCESS.NUM_WORKERS, 200 | _C.PROCESS.USE_GPU, 201 | shuffle=False, 202 | ) 203 | 204 | test_split_obj = Aggregate_DomainDataset( 205 | _C.DATA.DATASET, 206 | TARGET_DOMAINS, 207 | _C.DATA.DATA_SPLIT_DIR, 208 | "test", 209 | data_transforms["test"], 210 | _C.DATALOADER.BATCH_SIZE, 211 | _C.PROCESS.NUM_WORKERS, 212 | _C.PROCESS.USE_GPU, 213 | shuffle=False, 214 | ) 215 | 216 | # Setup optimizers and start training 217 | if _A.phase == "aggregate_training" or _A.phase == "multihead_training": 218 | parameters = model.parameters() 219 | if _C.OPTIM.OPTIMIZER == "Adam": 220 | optimizer = optim.Adam( 221 | parameters, 222 | lr=_C.OPTIM.LEARNING_RATE, 223 | betas=(0.9, 0.999), 224 | eps=1e-08, 225 | weight_decay=_C.OPTIM.WEIGHT_DECAY, 226 | ) 227 | else: 228 | print("Optimizer not supported yet") 229 | 230 | if _C.OPTIM.LEARNING_RATE_SCHEDULER == "exp": 231 | scheduler = optim.lr_scheduler.ExponentialLR( 232 | optimizer, _C.OPTIM.LEARNING_RATE_DECAY_RATE 233 | ) 234 | elif _C.OPTIM.LEARNING_RATE_SCHEDULER == "invlr": 235 | scheduler = InvLR(optimizer) 236 | else: 237 | print("LR Scheduler not identified") 238 | 239 | elif _A.phase == "supermask_training": 240 | model_parameters = model.parameters() 241 | policy_module_parameters = [] 242 | for module in policy_modules: 243 | policy_module_parameters += list(module.parameters()) 244 | 245 | if _C.OPTIM.OPTIMIZER == "Adam": 246 | model_optimizer = optim.Adam( 247 | model_parameters, 248 | lr=_C.OPTIM.MODEL_LEARNING_RATE, 249 | betas=(0.9, 0.999), 250 | eps=1e-08, 251 | weight_decay=_C.OPTIM.MODEL_WEIGHT_DECAY, 252 | ) 253 | 254 | policy_optimizer = optim.Adam( 255 
| policy_module_parameters, 256 | lr=_C.OPTIM.POLICY_LEARNING_RATE, 257 | betas=(0.9, 0.999), 258 | eps=1e-08, 259 | weight_decay=_C.OPTIM.POLICY_WEIGHT_DECAY, 260 | ) 261 | else: 262 | print("Optimizer not supported yet") 263 | 264 | if _C.OPTIM.LEARNING_RATE_SCHEDULER == "exp": 265 | model_scheduler = optim.lr_scheduler.ExponentialLR( 266 | model_optimizer, _C.OPTIM.LEARNING_RATE_DECAY_RATE 267 | ) 268 | policy_scheduler = optim.lr_scheduler.ExponentialLR( 269 | policy_optimizer, _C.OPTIM.LEARNING_RATE_DECAY_RATE 270 | ) 271 | elif _C.OPTIM.LEARNING_RATE_SCHEDULER == "invlr": 272 | model_scheduler = InvLR(model_optimizer) 273 | policy_scheduler = InvLR(policy_optimizer) 274 | else: 275 | print("LR Scheduler not identified") 276 | else: 277 | print("Phase not supported yet") 278 | 279 | # Define trainer objects and start training 280 | if _A.phase == "aggregate_training": 281 | Trainer = Aggregate_Trainer( 282 | model, 283 | _C.PROCESS.USE_GPU, 284 | train_split_obj, 285 | val_split_obj, 286 | test_split_obj, 287 | _C.EP_IT.MAX_EPOCHS, 288 | optimizer, 289 | ENV_NAME, 290 | _C.EP_IT.LOG_INTERVAL, 291 | scheduler, 292 | _C.OPTIM.LEARNING_RATE_DECAY_MODE, 293 | _C.OPTIM.LEARNING_RATE_DECAY_STEP, 294 | CKPT_DIR, 295 | _C, 296 | ) 297 | 298 | Trainer.train(DOMAIN_LIST, _C.EP_IT.CKPT_STORE_INTERVAL) 299 | 300 | elif _A.phase == "multihead_training": 301 | Trainer = MultiHead_Trainer( 302 | model, 303 | _C.PROCESS.USE_GPU, 304 | train_split_obj, 305 | val_split_obj, 306 | test_split_obj, 307 | _C.MODEL.TRAIN_FORWARD_MODE, 308 | _C.MODEL.EVAL_FORWARD_MODE, 309 | _C.EP_IT.MAX_EPOCHS, 310 | optimizer, 311 | ENV_NAME, 312 | _C.EP_IT.LOG_INTERVAL, 313 | scheduler, 314 | _C.OPTIM.LEARNING_RATE_DECAY_MODE, 315 | _C.OPTIM.LEARNING_RATE_DECAY_STEP, 316 | CKPT_DIR, 317 | _C, 318 | ) 319 | 320 | Trainer.train(DOMAIN_LIST, _C.EP_IT.CKPT_STORE_INTERVAL) 321 | elif _A.phase == "supermask_training": 322 | 323 | Trainer = SubNetwork_SuperMask_Trainer( 324 | model, 325 | MASK_LAYERS, 
326 | policy_modules, 327 | DOMAIN_LIST, 328 | TARGET_DOMAINS, 329 | train_split_obj, 330 | val_split_obj, 331 | test_split_obj, 332 | _C.EP_IT.MAX_EPOCHS, 333 | model_optimizer, 334 | policy_optimizer, 335 | CKPT_DIR, 336 | _C, 337 | model_scheduler, 338 | policy_scheduler, 339 | _C.OPTIM.LEARNING_RATE_DECAY_MODE, 340 | _C.EP_IT.LOG_INTERVAL, 341 | ENV_NAME, 342 | _C.PROCESS.USE_GPU, 343 | ) 344 | 345 | Trainer.train(_C.EP_IT.CKPT_STORE_INTERVAL) 346 | else: 347 | print("Training phase not identified.") 348 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregate_trainer import Aggregate_Trainer 2 | from .multihead_trainer import MultiHead_Trainer 3 | from .subnetwork_supermask_trainer import SubNetwork_SuperMask_Trainer 4 | 5 | __all__ = ["Aggregate_Trainer", "MultiHead_Trainer", "SubNetwork_SuperMask_Trainer"] 6 | 7 | -------------------------------------------------------------------------------- /trainers/aggregate_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import copy 5 | import time 6 | import wandb 7 | import torch 8 | import random 9 | import argparse 10 | import torchvision 11 | import torch.utils.data 12 | 13 | import numpy as np 14 | 15 | from pprint import pprint 16 | from torch import nn, optim 17 | from torch.autograd import Variable 18 | from torch.nn import functional as F 19 | 20 | from torchvision import datasets, transforms 21 | 22 | 23 | class Aggregate_Trainer: 24 | def __init__( 25 | self, 26 | model, 27 | cuda_flag, 28 | train_split_obj, 29 | val_split_obj, 30 | test_split_obj, 31 | max_epochs, 32 | optimizer, 33 | viz_env_name, 34 | log_interval, 35 | scheduler, 36 | scheduler_mode, 37 | lr_decay_steps, 38 | ckpt_folder, 39 | params, 40 | ): 41 | self.model = model 42 | self.cuda_flag = cuda_flag 
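The supermask phase above keeps two separate Adam optimizers, one over the backbone parameters and one over the collected policy-module (mask-logit) parameters, so the two groups can use different learning rates and weight decays. A minimal self-contained sketch of that two-optimizer pattern (module names here are illustrative, not the repository's):

```python
import torch
from torch import nn, optim

# Hypothetical stand-ins: a small backbone and a single mask-logit tensor.
backbone = nn.Linear(8, 4)
mask_logits = nn.Parameter(torch.zeros(4))

# Separate optimizers mirror model_optimizer / policy_optimizer above,
# allowing distinct lr / weight_decay per parameter group.
model_opt = optim.Adam(backbone.parameters(), lr=1e-4)
policy_opt = optim.Adam([mask_logits], lr=1e-2)

x = torch.randn(2, 8)
# Soft mask scales the backbone output, so both groups receive gradients.
loss = (backbone(x) * torch.sigmoid(mask_logits)).sum()

model_opt.zero_grad()
policy_opt.zero_grad()
loss.backward()
model_opt.step()
policy_opt.step()
```

Stepping both optimizers after a single shared backward pass is what lets the mask logits train at a much higher learning rate than the pretrained backbone.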
43 | self.train_split_obj = train_split_obj 44 | self.val_split_obj = val_split_obj 45 | self.test_split_obj = test_split_obj 46 | self.max_epochs = max_epochs 47 | self.optimizer = optimizer 48 | self.viz_env_name = viz_env_name 49 | self.log_interval = log_interval 50 | self.scheduler = scheduler 51 | self.scheduler_mode = scheduler_mode 52 | self.lr_decay_steps = lr_decay_steps 53 | self.ckpt_folder = ckpt_folder 54 | self.params = params 55 | self.epoch = 0 56 | self.iteration = 0 57 | 58 | self.target_domains = self.params.DATA.TARGET_DOMAINS.split(",") 59 | 60 | # Setup wandb 61 | wandb.init( 62 | project=self.params.HJOB.WANDB_PROJECT, 63 | name=self.viz_env_name, 64 | dir=self.params.HJOB.WANDB_DIR, 65 | ) 66 | 67 | # Watch model 68 | wandb.watch(self.model) 69 | 70 | # Add config 71 | wandb.config.update(params) 72 | 73 | # Criterion to GPU 74 | if self.params.PROCESS.USE_GPU: 75 | self.model.loss_gpu(True) 76 | 77 | def saveModel(self, saveFile, params): 78 | torch.save( 79 | { 80 | "model": self.model.state_dict(), 81 | "optimizer": self.optimizer.state_dict(), 82 | "epoch": self.epoch, 83 | "iteration": self.iteration, 84 | }, 85 | saveFile, 86 | ) 87 | 88 | def train_epoch(self): 89 | self.model.train() 90 | since = time.time() 91 | running_loss, running_corrects = 0, 0 92 | for batch_idx, (image, label, _) in enumerate(self.train_split_obj.curr_loader): 93 | label = label.long() 94 | if self.cuda_flag: 95 | image = image.cuda() 96 | label = label.cuda() 97 | iteration = ( 98 | batch_idx + (self.epoch - 1) * len(self.train_split_obj.curr_loader) + 1 99 | ) 100 | self.iteration = iteration 101 | self.optimizer.zero_grad() 102 | 103 | with torch.set_grad_enabled(True): 104 | outputs = self.model(image) 105 | loss = self.model.loss_fn(outputs, label) 106 | wandb.log({"Training Iter Loss": loss.data.item()}, step=self.iteration) 107 | _, preds = torch.max(outputs, 1) 108 | 109 | loss.backward() 110 | self.optimizer.step() 111 | 112 | if self.scheduler is 
not None: 113 | if self.scheduler_mode == "iter": 114 | self.scheduler.step() 115 | 116 | if self.log_interval is not None: 117 | if batch_idx % self.log_interval == 0: 118 | lr = self.optimizer.param_groups[0]["lr"] 119 | print( 120 | "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 121 | self.epoch, 122 | batch_idx * len(image), 123 | len(self.train_split_obj.curr_loader.dataset), 124 | 100.0 125 | * batch_idx 126 | / len(self.train_split_obj.curr_loader), 127 | loss.item(), 128 | ) 129 | ) 130 | 131 | running_loss += loss.item() * image.size(0) 132 | running_corrects += torch.sum(preds == label.data) 133 | time_elapsed = time.time() - since 134 | wandb.log({"Train Epoch Time": time_elapsed}, step=self.iteration) 135 | epoch_loss = running_loss / len(self.train_split_obj.curr_loader.dataset) 136 | epoch_acc = running_corrects.double() / len( 137 | self.train_split_obj.curr_loader.dataset 138 | ) 139 | return epoch_loss, epoch_acc 140 | 141 | def validate(self, mode="val"): 142 | self.model.eval() 143 | since = time.time() 144 | running_loss, running_corrects = 0, 0 145 | if mode == "val": 146 | rel_loader = self.val_split_obj.curr_loader 147 | elif mode == "test": 148 | rel_loader = self.test_split_obj.curr_loader 149 | else: 150 | print("Split mode not supported yet") 151 | 152 | for batch_idx, (image, label, _) in enumerate(rel_loader): 153 | label = label.long() 154 | if self.cuda_flag: 155 | image = image.cuda() 156 | label = label.cuda() 157 | 158 | with torch.no_grad(): 159 | outputs = self.model(image) 160 | loss = self.model.loss_fn(outputs, label) 161 | _, preds = torch.max(outputs, 1) 162 | running_loss += loss.item() * image.size(0) 163 | running_corrects += torch.sum(preds == label.data) 164 | 165 | time_elapsed = time.time() - since 166 | wandb.log({"Val Epoch Time": time_elapsed}, step=self.iteration) 167 | epoch_loss = running_loss / len(rel_loader.dataset) 168 | epoch_acc = running_corrects.double() / len(rel_loader.dataset) 169 | return 
epoch_loss, epoch_acc 170 | 171 | def train(self, domain_list=None, ckpt_store_interval=20): 172 | # Checkpoint storing interval 173 | ckpt_store_int = ckpt_store_interval 174 | # To keep track of running performance 175 | running_vl_loss_dict = {x: 0 for x in domain_list} 176 | running_vl_acc_dict = {x: 0 for x in domain_list} 177 | 178 | running_ts_loss_dict = {x: 0 for x in self.target_domains} 179 | running_ts_acc_dict = {x: 0 for x in self.target_domains} 180 | 181 | running_vl_loss_dict["overall"] = 0 182 | running_vl_acc_dict["overall"] = 0 183 | 184 | running_ts_loss_dict["overall"] = 0 185 | running_ts_acc_dict["overall"] = 0 186 | 187 | best_score = None 188 | last_epoch = 0 189 | 190 | # We don't need to specify which set of 191 | # parameters to keep gradients on for 192 | for epoch in range(1, self.max_epochs + 1): 193 | self.epoch = epoch 194 | self.train_epoch() 195 | 196 | # Plot iteration versus epochs 197 | wandb.log( 198 | {"Iteration": self.iteration, "Epoch": self.epoch}, step=self.iteration 199 | ) 200 | 201 | # Evaluate on the train and val splits 202 | # 1. 
Training Split 203 | for domain in domain_list: 204 | self.val_split_obj.set_domain_spec_mode(True, domain) 205 | temp_loss, temp_acc = self.validate("val") 206 | running_vl_loss_dict[domain] = temp_loss 207 | running_vl_acc_dict[domain] = temp_acc.data.item() 208 | self.val_split_obj.set_domain_spec_mode(False) 209 | 210 | # Calculate overall performance 211 | # Validation 212 | vl_loss = np.mean([running_vl_loss_dict[x] for x in domain_list]) 213 | vl_acc = np.mean([running_vl_acc_dict[x] for x in domain_list]) 214 | running_vl_loss_dict["overall"] = vl_loss 215 | running_vl_acc_dict["overall"] = vl_acc 216 | 217 | # WANDB Logs 218 | # Log losses of all domains in a single plot (both train + val) 219 | wandb_loss_log = {} 220 | for key, val in running_vl_loss_dict.items(): 221 | wandb_loss_log[key + "_vl_loss"] = val 222 | wandb_acc_log = {} 223 | for key, val in running_vl_acc_dict.items(): 224 | wandb_acc_log[key + "_vl_acc"] = val 225 | 226 | wandb.log(wandb_loss_log, step=self.iteration) 227 | wandb.log(wandb_acc_log, step=self.iteration) 228 | # Log accuracies of all domains in a single plot (both train + val) 229 | # On-screen performance 230 | print("-----------------------------------") 231 | print("Fine-grained validation performance") 232 | print("-----------------------------------") 233 | print("Loss") 234 | pprint(running_vl_loss_dict) 235 | print("Accuracy") 236 | pprint(running_vl_acc_dict) 237 | print("-----------------------------------") 238 | 239 | # Performance on target domains 240 | if self.target_domains is not None: 241 | for domain in self.target_domains: 242 | self.test_split_obj.set_domain_spec_mode(True, domain) 243 | temp_loss, temp_acc = self.validate("test") 244 | running_ts_loss_dict[domain] = temp_loss 245 | running_ts_acc_dict[domain] = temp_acc.data.item() 246 | self.test_split_obj.set_domain_spec_mode(False) 247 | 248 | # Calculate overall performance 249 | ts_loss = np.mean( 250 | [running_ts_loss_dict[x] for x in 
self.target_domains] 251 | ) 252 | ts_acc = np.mean([running_ts_acc_dict[x] for x in self.target_domains]) 253 | running_ts_loss_dict["overall"] = ts_loss 254 | running_ts_acc_dict["overall"] = ts_acc 255 | 256 | # WANDB Logs 257 | # Log losses of all domains in a single plot (both train + val) 258 | wandb_loss_log = {} 259 | for key, val in running_ts_loss_dict.items(): 260 | wandb_loss_log[key + "_ts_loss"] = val 261 | wandb_acc_log = {} 262 | for key, val in running_ts_acc_dict.items(): 263 | wandb_acc_log[key + "_ts_acc"] = val 264 | wandb.log(wandb_loss_log, step=self.iteration) 265 | wandb.log(wandb_acc_log, step=self.iteration) 266 | 267 | # Log accuracies of all domains in a single plot (both train + val) 268 | # On-screen performance 269 | print("-----------------------------------") 270 | print("Fine-grained test performance") 271 | print("-----------------------------------") 272 | print("Loss") 273 | pprint(running_ts_loss_dict) 274 | print("Accuracy") 275 | pprint(running_ts_acc_dict) 276 | print("-----------------------------------") 277 | 278 | # Store checkpoints 279 | if self.epoch % ckpt_store_int == 0 or self.epoch == 0: 280 | self.saveModel( 281 | self.ckpt_folder + "/model_ep_" + str(self.epoch) + ".pth", 282 | self.params, 283 | ) 284 | 285 | # Check for epoch level scheduler step 286 | # Update to match the DomainNet Multi-Source baselines 287 | if self.scheduler is not None: 288 | if self.scheduler_mode == "epoch": 289 | if self.params.OPTIM.LEARNING_RATE_SCHEDULER == "exp": 290 | self.scheduler.step() 291 | elif self.params.OPTIM.LEARNING_RATE_SCHEDULER == "invlr": 292 | self.scheduler.step() 293 | 294 | score = vl_acc 295 | if best_score is None or score > best_score: 296 | best_score = score 297 | self.saveModel(self.ckpt_folder + "/best_so_far.pth", self.params) 298 | with open(self.ckpt_folder + "/val_loss.json", "w") as f: 299 | json.dump(running_vl_loss_dict, f) 300 | with open(self.ckpt_folder + "/val_acc.json", "w") as f: 301 | 
json.dump(running_vl_acc_dict, f) 302 | 303 | # Store best values in tables 304 | best_score_wandb_table = wandb.Table( 305 | columns=["Domain", "Validation Loss", "Validation Accuracy"] 306 | ) 307 | for key, val in running_vl_loss_dict.items(): 308 | best_score_wandb_table.add_data( 309 | key, 310 | str(running_vl_loss_dict[key]), 311 | str(running_vl_acc_dict[key]), 312 | ) 313 | wandb.log({"Best_Score_Table": best_score_wandb_table}) 314 | 315 | best_test_score_wandb_table = wandb.Table( 316 | columns=["Domain", "Test Loss", "Test Accuracy"] 317 | ) 318 | for key, val in running_ts_loss_dict.items(): 319 | best_test_score_wandb_table.add_data( 320 | key, 321 | str(running_ts_loss_dict[key]), 322 | str(running_ts_acc_dict[key]), 323 | ) 324 | wandb.log({"Best_Test_Score_Table": best_test_score_wandb_table}) 325 | -------------------------------------------------------------------------------- /trainers/multihead_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import copy 5 | import time 6 | import torch 7 | import wandb 8 | import random 9 | import argparse 10 | import torchvision 11 | import torch.utils.data 12 | 13 | import numpy as np 14 | 15 | from pprint import pprint 16 | from torch import nn, optim 17 | from torch.autograd import Variable 18 | from torch.nn import functional as F 19 | 20 | from torchvision import datasets, transforms 21 | 22 | 23 | class MultiHead_Trainer: 24 | def __init__( 25 | self, 26 | model, 27 | cuda_flag, 28 | train_split_obj, 29 | val_split_obj, 30 | test_split_obj, 31 | train_forward_mode, 32 | eval_forward_mode, 33 | max_epochs, 34 | optimizer, 35 | viz_env_name, 36 | log_interval, 37 | scheduler, 38 | scheduler_mode, 39 | lr_decay_steps, 40 | ckpt_folder, 41 | params, 42 | ): 43 | self.model = model 44 | self.cuda_flag = cuda_flag 45 | self.train_split_obj = train_split_obj 46 | self.val_split_obj = val_split_obj 47 | 
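The trainer above tracks a running `best_score` over mean validation accuracy and overwrites `best_so_far.pth` whenever an epoch improves on it. A minimal sketch of just that selection logic, with illustrative names:

```python
# Pick the epoch with the highest mean validation accuracy, mirroring the
# best-checkpoint tracking in the trainers (names here are hypothetical).
def select_best(epoch_scores):
    best_epoch, best_score = None, None
    for epoch, score in epoch_scores:
        if best_score is None or score > best_score:
            best_epoch, best_score = epoch, score
            # In the real trainer, saveModel(".../best_so_far.pth") runs here.
    return best_epoch, best_score
```

For example, `select_best([(1, 0.42), (2, 0.55), (3, 0.51)])` returns `(2, 0.55)`: later epochs only replace the checkpoint on strict improvement.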
self.test_split_obj = test_split_obj 48 | self.train_forward_mode = train_forward_mode 49 | self.eval_forward_mode = eval_forward_mode 50 | self.max_epochs = max_epochs 51 | self.optimizer = optimizer 52 | self.viz_env_name = viz_env_name 53 | self.log_interval = log_interval 54 | self.scheduler = scheduler 55 | self.scheduler_mode = scheduler_mode 56 | self.lr_decay_steps = lr_decay_steps 57 | self.ckpt_folder = ckpt_folder 58 | self.params = params 59 | self.epoch = 0 60 | self.iteration = 0 61 | 62 | self.target_domains = self.params.DATA.TARGET_DOMAINS.split(",") 63 | 64 | # Setup wandb 65 | wandb.init( 66 | project=self.params.HJOB.WANDB_PROJECT, 67 | name=self.viz_env_name, 68 | dir=self.params.HJOB.WANDB_DIR, 69 | ) 70 | 71 | # Watch model 72 | wandb.watch(self.model) 73 | 74 | # Add config 75 | wandb.config.update(params) 76 | 77 | def saveModel(self, saveFile, params): 78 | torch.save( 79 | { 80 | "model": self.model.state_dict(), 81 | "optimizer": self.optimizer.state_dict(), 82 | "epoch": self.epoch, 83 | "iteration": self.iteration, 84 | }, 85 | saveFile, 86 | ) 87 | 88 | def train_epoch(self): 89 | self.model.train() 90 | since = time.time() 91 | running_loss, running_corrects = 0, 0 92 | for batch_idx, (image, label, domain) in enumerate( 93 | self.train_split_obj.curr_loader 94 | ): 95 | label = label.long() 96 | if self.cuda_flag: 97 | image = image.cuda() 98 | label = label.cuda() 99 | iteration = ( 100 | batch_idx + (self.epoch - 1) * len(self.train_split_obj.curr_loader) + 1 101 | ) 102 | self.iteration = iteration 103 | self.optimizer.zero_grad() 104 | 105 | with torch.set_grad_enabled(True): 106 | if self.train_forward_mode == "route": 107 | outputs = self.model(image, domain) 108 | elif self.train_forward_mode == "avg_score_fwd": 109 | outputs = self.model.avg_forward(image) 110 | else: 111 | print("Cannot forward pass in this mode") 112 | loss = self.model.loss_fn(outputs, label) 113 | wandb.log({"Training Iter Loss": loss.data.item()}, 
step=self.iteration) 114 | _, preds = torch.max(outputs, 1) 115 | 116 | loss.backward() 117 | self.optimizer.step() 118 | 119 | if self.scheduler is not None: 120 | if self.scheduler_mode == "iter": 121 | self.scheduler.step() 122 | 123 | if self.log_interval is not None: 124 | if batch_idx % self.log_interval == 0: 125 | print( 126 | "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 127 | self.epoch, 128 | batch_idx * len(image), 129 | len(self.train_split_obj.curr_loader.dataset), 130 | 100.0 131 | * batch_idx 132 | / len(self.train_split_obj.curr_loader), 133 | loss.item(), 134 | ) 135 | ) 136 | 137 | running_loss += loss.item() * image.size(0) 138 | running_corrects += torch.sum(preds == label.data) 139 | time_elapsed = time.time() - since 140 | wandb.log({"Train Epoch Time": time_elapsed}, step=self.iteration) 141 | epoch_loss = running_loss / len(self.train_split_obj.curr_loader.dataset) 142 | epoch_acc = running_corrects.double() / len( 143 | self.train_split_obj.curr_loader.dataset 144 | ) 145 | return epoch_loss, epoch_acc 146 | 147 | def validate(self, mode="val"): 148 | self.model.eval() 149 | since = time.time() 150 | running_loss, running_corrects = 0, 0 151 | if mode == "val": 152 | rel_loader = self.val_split_obj.curr_loader 153 | elif mode == "test": 154 | rel_loader = self.test_split_obj.curr_loader 155 | else: 156 | print("Split mode not supported yet") 157 | 158 | for batch_idx, (image, label, domain) in enumerate(rel_loader): 159 | label = label.long() 160 | if self.cuda_flag: 161 | image = image.cuda() 162 | label = label.cuda() 163 | 164 | with torch.set_grad_enabled(False): 165 | if mode == "val": 166 | if self.eval_forward_mode == "route": 167 | outputs = self.model(image, domain) 168 | elif self.eval_forward_mode == "avg_score_fwd": 169 | outputs = self.model.avg_forward(image) 170 | elif self.eval_forward_mode == "avg_prob_fwd": 171 | outputs = self.model.avg_prob_forward(image) 172 | else: 173 | print("Eval forward mode not 
identified") 174 | elif mode == "test": 175 | outputs = self.model.avg_forward(image) 176 | else: 177 | print("Split mode not supported yet") 178 | 179 | loss = self.model.loss_fn(outputs, label) 180 | _, preds = torch.max(outputs, 1) 181 | running_loss += loss.item() * image.size(0) 182 | running_corrects += torch.sum(preds == label.data) 183 | 184 | time_elapsed = time.time() - since 185 | wandb.log({"Val Epoch Time": time_elapsed}, step=self.iteration) 186 | epoch_loss = running_loss / len(rel_loader.dataset) 187 | epoch_acc = running_corrects.double() / len(rel_loader.dataset) 188 | return epoch_loss, epoch_acc 189 | 190 | def train(self, domain_list=None, ckpt_store_interval=20): 191 | # Checkpoint storing interval 192 | ckpt_store_int = ckpt_store_interval 193 | # To keep track of running performance 194 | running_vl_loss_dict = {x: 0 for x in domain_list} 195 | running_vl_acc_dict = {x: 0 for x in domain_list} 196 | running_vl_loss_dict["overall"] = 0 197 | running_vl_acc_dict["overall"] = 0 198 | 199 | running_ts_loss_dict = {x: 0 for x in self.target_domains} 200 | running_ts_acc_dict = {x: 0 for x in self.target_domains} 201 | running_ts_loss_dict["overall"] = 0 202 | running_ts_acc_dict["overall"] = 0 203 | 204 | # Train and repeat the above functions 205 | # with appropriate logging 206 | best_score = None 207 | last_epoch = 0 208 | 209 | # We don't need to specify which set of 210 | # parameters to keep gradients on for 211 | for epoch in range(1, self.max_epochs + 1): 212 | # Set global epoch 213 | self.epoch = epoch 214 | self.train_epoch() 215 | 216 | # Plot iteration versus epochs 217 | wandb.log( 218 | {"Iteration": self.iteration, "Epoch": self.epoch}, step=self.iteration 219 | ) 220 | 221 | # Evaluate on the val split 222 | for domain in domain_list: 223 | self.val_split_obj.set_domain_spec_mode(True, domain) 224 | temp_loss, temp_acc = self.validate("val") 225 | running_vl_loss_dict[domain] = temp_loss 226 | running_vl_acc_dict[domain] = 
temp_acc.data.item() 227 | self.val_split_obj.set_domain_spec_mode(False) 228 | 229 | # Calculate overall performance 230 | # Validation 231 | vl_loss = np.mean([running_vl_loss_dict[x] for x in domain_list]) 232 | vl_acc = np.mean([running_vl_acc_dict[x] for x in domain_list]) 233 | # Store overall data 234 | running_vl_loss_dict["overall"] = vl_loss 235 | running_vl_acc_dict["overall"] = vl_acc 236 | 237 | # WANDB Logs 238 | wandb_loss_log = {} 239 | for key, val in running_vl_loss_dict.items(): 240 | wandb_loss_log[key + "_vl_loss"] = val 241 | wandb_acc_log = {} 242 | for key, val in running_vl_acc_dict.items(): 243 | wandb_acc_log[key + "_vl_acc"] = val 244 | wandb.log(wandb_loss_log, step=self.iteration) 245 | wandb.log(wandb_acc_log, step=self.iteration) 246 | # On-screen performance 247 | print("-----------------------------------") 248 | print("Fine-grained validation performance") 249 | print("-----------------------------------") 250 | print("Loss") 251 | pprint(running_vl_loss_dict) 252 | print("Accuracy") 253 | pprint(running_vl_acc_dict) 254 | print("-----------------------------------") 255 | 256 | # Performance on target domains 257 | if self.target_domains is not None: 258 | for domain in self.target_domains: 259 | self.test_split_obj.set_domain_spec_mode(True, domain) 260 | temp_loss, temp_acc = self.validate("test") 261 | running_ts_loss_dict[domain] = temp_loss 262 | running_ts_acc_dict[domain] = temp_acc.data.item() 263 | self.test_split_obj.set_domain_spec_mode(False) 264 | 265 | # Calculate overall performance 266 | ts_loss = np.mean( 267 | [running_ts_loss_dict[x] for x in self.target_domains] 268 | ) 269 | ts_acc = np.mean([running_ts_acc_dict[x] for x in self.target_domains]) 270 | running_ts_loss_dict["overall"] = ts_loss 271 | running_ts_acc_dict["overall"] = ts_acc 272 | 273 | # WANDB Logs 274 | # Log losses of all domains in a single plot 275 | wandb_loss_log = {} 276 | for key, val in running_ts_loss_dict.items(): 277 | 
wandb_loss_log[key + "_ts_loss"] = val 278 | wandb_acc_log = {} 279 | for key, val in running_ts_acc_dict.items(): 280 | wandb_acc_log[key + "_ts_acc"] = val 281 | wandb.log(wandb_loss_log, step=self.iteration) 282 | wandb.log(wandb_acc_log, step=self.iteration) 283 | 284 | # Log accuracies of all domains in a single plot 285 | # On-screen performance 286 | print("-----------------------------------") 287 | print("Fine-grained test performance") 288 | print("-----------------------------------") 289 | print("Loss") 290 | pprint(running_ts_loss_dict) 291 | print("Accuracy") 292 | pprint(running_ts_acc_dict) 293 | print("-----------------------------------") 294 | 295 | # Store checkpoints 296 | if self.epoch % ckpt_store_int == 0 or self.epoch == 0: 297 | self.saveModel( 298 | self.ckpt_folder + "/model_ep_" + str(self.epoch) + ".pth", 299 | self.params, 300 | ) 301 | 302 | # Check for epoch level scheduler step 303 | if self.scheduler is not None: 304 | if self.scheduler_mode == "epoch": 305 | if self.params.OPTIM.LEARNING_RATE_SCHEDULER == "exp": 306 | self.scheduler.step() 307 | elif self.params.OPTIM.LEARNING_RATE_SCHEDULER == "invlr": 308 | self.scheduler.step() 309 | 310 | score = vl_acc 311 | if best_score is None or score > best_score: 312 | best_score = score 313 | self.saveModel(self.ckpt_folder + "/best_so_far.pth", self.params) 314 | with open(self.ckpt_folder + "/val_loss.json", "w") as f: 315 | json.dump(running_vl_loss_dict, f) 316 | with open(self.ckpt_folder + "/val_acc.json", "w") as f: 317 | json.dump(running_vl_acc_dict, f) 318 | # Store best values in tables 319 | best_score_wandb_table = wandb.Table( 320 | columns=["Domain", "Validation Loss", "Validation Accuracy"] 321 | ) 322 | for key, val in running_vl_loss_dict.items(): 323 | best_score_wandb_table.add_data( 324 | key, 325 | str(running_vl_loss_dict[key]), 326 | str(running_vl_acc_dict[key]), 327 | ) 328 | wandb.log({"Best_Score_Table": best_score_wandb_table}) 329 | 
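The multi-head trainer above switches between three forward modes: `route` (each example goes through its own domain's head), `avg_score_fwd` (average logits over all heads), and `avg_prob_fwd` (average softmax probabilities). A minimal sketch of the distinction, using a hypothetical dict of per-domain classifier heads over shared features (`heads` and the function names below are illustrative, not the repository's actual model API):

```python
import torch
from torch import nn
import torch.nn.functional as F

def route_forward(features, heads, domains):
    # "route": each example is classified by the head of its own source domain
    return torch.stack([heads[d](f) for f, d in zip(features, domains)])

def avg_score_forward(features, heads):
    # "avg_score_fwd": average pre-softmax scores (logits) over all heads
    return torch.stack([h(features) for h in heads.values()]).mean(dim=0)

def avg_prob_forward(features, heads):
    # "avg_prob_fwd": average post-softmax probabilities over all heads
    probs = torch.stack([F.softmax(h(features), dim=1) for h in heads.values()])
    return probs.mean(dim=0)
```

Averaging logits and averaging probabilities generally yield different ensembles, which is why the trainer exposes them as separate evaluation modes; at test time on unseen target domains, per-example routing is unavailable and the trainer falls back to head averaging.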
-------------------------------------------------------------------------------- /trainers/subnetwork_supermask_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import sys 4 | import json 5 | import copy 6 | import time 7 | import wandb 8 | import torch 9 | import random 10 | import argparse 11 | import torchvision 12 | import torch.utils.data 13 | 14 | import numpy as np 15 | 16 | from pprint import pprint 17 | from torch import nn, optim 18 | from torch.autograd import Variable 19 | from torch.nn import functional as F 20 | 21 | from torchvision import datasets, transforms 22 | 23 | 24 | class SubNetwork_SuperMask_Trainer: 25 | def __init__( 26 | self, 27 | model, 28 | mask_layers, 29 | mask_modules, 30 | domain_list, 31 | target_domains, 32 | train_split_obj, 33 | val_split_obj, 34 | test_split_obj, 35 | max_epochs, 36 | model_optimizer, 37 | mask_optimizer, 38 | ckpt_folder, 39 | params, 40 | model_scheduler, 41 | mask_scheduler, 42 | scheduler_mode, 43 | log_interval, 44 | viz_env_name=None, 45 | cuda_flag=True, 46 | ): 47 | self.model = model 48 | self.mask_layers = mask_layers 49 | self.mask_modules = mask_modules 50 | self.domain_list = domain_list 51 | self.target_domains = target_domains 52 | self.train_split_obj = train_split_obj 53 | self.val_split_obj = val_split_obj 54 | self.test_split_obj = test_split_obj 55 | self.max_epochs = max_epochs 56 | self.model_optimizer = model_optimizer 57 | self.mask_optimizer = mask_optimizer 58 | self.ckpt_folder = ckpt_folder 59 | self.params = params 60 | self.model_scheduler = model_scheduler 61 | self.mask_scheduler = mask_scheduler 62 | self.scheduler_mode = scheduler_mode 63 | self.log_interval = log_interval 64 | self.viz_env_name = viz_env_name 65 | self.cuda_flag = cuda_flag 66 | self.epoch = 0 67 | self.iteration = 0 68 | 69 | self.lr_decay_steps = self.params.OPTIM.LEARNING_RATE_DECAY_STEP 70 | 71 | # Setup wandb 72 | wandb.init( 73 
| project=self.params.HJOB.WANDB_PROJECT, 74 | name=self.viz_env_name, 75 | dir=self.params.HJOB.WANDB_DIR, 76 | ) 77 | 78 | # Watch model 79 | wandb.watch(self.model) 80 | for x in self.mask_modules: 81 | wandb.watch(x) 82 | 83 | # Add config 84 | wandb.config.update(params) 85 | 86 | def saveModel( 87 | self, 88 | vl_loss_perf_dict, 89 | vl_acc_perf_dict, 90 | ts_loss_perf_dict, 91 | ts_acc_perf_dict, 92 | saveFile, 93 | ): 94 | # Save the model checkpoints 95 | # as well as the performance dictionaries 96 | ckpt_dict = { 97 | "joint_model": self.model.joint_model.state_dict(), 98 | "model_optimizer": self.model_optimizer.state_dict(), 99 | "mask_optimizer": self.mask_optimizer.state_dict(), 100 | "mask_layers": self.mask_layers, 101 | "loss_on_val": vl_loss_perf_dict, 102 | "acc_on_val": vl_acc_perf_dict, 103 | "loss_on_test": ts_loss_perf_dict, 104 | "acc_on_test": ts_acc_perf_dict, 105 | "epoch": self.epoch, 106 | "iteration": self.iteration, 107 | } 108 | 109 | # Add layer indices and the policy models 110 | for i in range(len(self.mask_layers)): 111 | ckpt_dict[str(self.mask_layers[i]) + "_super_mask"] = self.mask_modules[ 112 | i 113 | ].state_dict() 114 | 115 | # Save in the specified location 116 | torch.save(ckpt_dict, saveFile) 117 | 118 | def train_epoch(self): 119 | since = time.time() 120 | running_loss = 0 121 | running_class_loss = 0 122 | running_sparsity_loss = 0 123 | running_corrects = 0 124 | running_sparsity = 0 125 | 126 | running_mask_ls = {x: [] for x in self.mask_layers} 127 | running_prob_ls = {x: [] for x in self.mask_layers} 128 | 129 | self.model.train() 130 | self.model.set_dropout_eval(True) 131 | 132 | # Set policy modules to train mode 133 | for i in range(len(self.mask_modules)): 134 | self.mask_modules[i].train() 135 | 136 | for batch_idx, (image, label, domain) in enumerate( 137 | self.train_split_obj.curr_loader 138 | ): 139 | label = label.long() 140 | if self.cuda_flag: 141 | image = image.cuda() 142 | label = label.cuda() 
143 | iteration = ( 144 | batch_idx + (self.epoch - 1) * len(self.train_split_obj.curr_loader) + 1 145 | ) 146 | self.iteration = iteration 147 | self.model_optimizer.zero_grad() 148 | self.mask_optimizer.zero_grad() 149 | mask_domain = domain 150 | 151 | with torch.set_grad_enabled(True): 152 | scores, prob_ls, action_ls = self.model( 153 | image, 154 | self.mask_modules, 155 | mask_domain, 156 | self.params.MODEL.POLICY_SAMPLE_MODE, 157 | self.params.MODEL.POLICY_CONV_MODE, 158 | ) 159 | 160 | # Store the probabilities and the actions 161 | # in the specified data-structure 162 | for x in self.mask_layers: 163 | mask_ind = self.mask_layers.index(x) 164 | curr_mask = action_ls[mask_ind].mean(dim=0).cpu().detach().numpy() 165 | curr_probs = prob_ls[mask_ind].mean(dim=0).cpu().detach().numpy() 166 | running_mask_ls[x].append(curr_mask) 167 | running_prob_ls[x].append(curr_probs) 168 | 169 | # Compute classification loss 170 | class_loss = self.model.loss_fn(scores, label) 171 | _, preds = torch.max(scores, 1) 172 | 173 | # Aggregate sparsity across layers 174 | # Also aggregate the sparsity loss 175 | sparsity = [] 176 | sparsity_loss = 0 177 | overlap_loss = 0 178 | for i in range(len(self.mask_modules)): 179 | sparsity.append( 180 | torch.mean(self.mask_modules[i].sparsity(action_ls[i])) 181 | ) 182 | sparsity_loss += self.mask_modules[i].sparsity_penalty() 183 | overlap_loss += self.mask_modules[i].overlap_penalty() 184 | 185 | # Calculate total loss 186 | loss = ( 187 | class_loss.mean() 188 | + self.params.OPTIM.SPARSITY_LAMBDA * sparsity_loss 189 | + self.params.OPTIM.OVERLAP_LAMBDA * overlap_loss 190 | ) 191 | 192 | loss.backward() 193 | self.model_optimizer.step() 194 | self.mask_optimizer.step() 195 | 196 | # Learning rate scheduler (per-iteration steps) 197 | if self.model_scheduler is not None: 198 | if self.scheduler_mode == "iter": 199 | self.model_scheduler.step() 200 | 201 | if self.mask_scheduler is not None: 202 | if self.scheduler_mode == "iter": 
203 | self.mask_scheduler.step() 204 | 205 | # Show running loss 206 | if self.log_interval is not None: 207 | if batch_idx % self.log_interval == 0: 208 | print( 209 | "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 210 | self.epoch, 211 | batch_idx * len(image), 212 | len(self.train_split_obj.curr_loader.dataset), 213 | 100.0 214 | * batch_idx 215 | / len(self.train_split_obj.curr_loader), 216 | loss.item(), 217 | ) 218 | ) 219 | 220 | # Keep track of statistics 221 | running_loss += loss.item() * image.size(0) 222 | running_class_loss += class_loss.mean().item() * image.size(0) 223 | running_sparsity_loss += sparsity_loss.mean().item() * image.size(0) 224 | running_corrects += torch.sum(preds == label.data) 225 | running_sparsity += torch.mean(torch.stack(sparsity)) * image.size(0) 226 | 227 | # Get Mask Similarity Statistics 228 | for i in range(len(self.mask_layers)): 229 | iou_overlap_dict = self.mask_modules[i].mask_overlap( 230 | self.mask_layers[i] 231 | ) 232 | 233 | # Log per-iteration data in wandb 234 | wandb.log({"Iter Train Loss": loss.data.item()}, step=self.iteration) 235 | wandb.log( 236 | {"Iter Train CLoss": class_loss.mean().data.item()}, step=self.iteration 237 | ) 238 | wandb.log( 239 | {"Iter Train SLoss": sparsity_loss.mean().data.item()}, 240 | step=self.iteration, 241 | ) 242 | wandb.log( 243 | { 244 | "Iter Train CAcc": torch.sum(preds == label.data) 245 | .double() 246 | .data.item() 247 | / image.size(0) 248 | }, 249 | step=self.iteration, 250 | ) 251 | wandb.log( 252 | {"Iter Train Sparsity": torch.mean(torch.stack(sparsity)).data.item()}, 253 | step=self.iteration, 254 | ) 255 | epoch_loss = running_loss / len(self.train_split_obj.curr_loader.dataset) 256 | epoch_class_loss = running_class_loss / len( 257 | self.train_split_obj.curr_loader.dataset 258 | ) 259 | epoch_sparsity_loss = running_sparsity_loss / len( 260 | self.train_split_obj.curr_loader.dataset 261 | ) 262 | epoch_acc = running_corrects.double() / len( 263 | 
self.train_split_obj.curr_loader.dataset 264 | ) 265 | epoch_sparsity = running_sparsity.double() / len( 266 | self.train_split_obj.curr_loader.dataset 267 | ) 268 | 269 | # Log per-iteration data in wandb 270 | wandb.log({"Epoch Train Loss": epoch_loss}, step=self.iteration) 271 | wandb.log({"Epoch Train CLoss": epoch_class_loss}, step=self.iteration) 272 | wandb.log({"Epoch Train SLoss": epoch_sparsity_loss}, step=self.iteration) 273 | wandb.log({"Epoch Train CAcc": epoch_acc}, step=self.iteration) 274 | wandb.log({"Epoch Train Sparsity": epoch_sparsity}, step=self.iteration) 275 | time_elapsed = time.time() - since 276 | 277 | def evaluate(self, phase="val"): 278 | since = time.time() 279 | sparsity_stats = [[] for x in range(len(self.mask_modules))] 280 | running_loss, running_corrects = 0, 0 281 | self.model.eval() 282 | for i in range(len(self.mask_modules)): 283 | self.mask_modules[i].eval() 284 | 285 | if phase == "val": 286 | eval_loader = self.val_split_obj.curr_loader 287 | elif phase == "test": 288 | eval_loader = self.test_split_obj.curr_loader 289 | else: 290 | print("Phase not supported for evaluation") 291 | 292 | for batch_idx, (image, label, domain) in enumerate(eval_loader): 293 | label = label.long() 294 | if self.cuda_flag: 295 | image = image.cuda() 296 | label = label.cuda() 297 | 298 | if phase == "val": 299 | policy_domain = domain 300 | with torch.set_grad_enabled(False): 301 | scores, _, action_ls = self.model( 302 | image, 303 | self.mask_modules, 304 | policy_domain, 305 | "greedy", 306 | self.params.MODEL.POLICY_CONV_MODE, 307 | ) 308 | 309 | for i in range(len(self.mask_modules)): 310 | sparsity_stats[i].append( 311 | self.mask_modules[i].sparsity(action_ls[i]).mean().item() 312 | ) 313 | 314 | elif phase == "test": 315 | scores = [] 316 | for eval_domain in self.domain_list: 317 | policy_domain = [eval_domain] * len(domain) 318 | with torch.set_grad_enabled(False): 319 | score, _, action_ls = self.model( 320 | image, 321 | 
self.mask_modules, 322 | policy_domain, 323 | "softscale", 324 | self.params.MODEL.POLICY_CONV_MODE, 325 | ) 326 | scores.append(score) 327 | for i in range(len(self.mask_modules)): 328 | sparsity_stats[i].append( 329 | self.mask_modules[i] 330 | .sparsity(action_ls[i]) 331 | .mean() 332 | .item() 333 | ) 334 | scores = torch.stack(scores) 335 | scores = scores.mean(0) 336 | 337 | _, preds = torch.max(scores, 1) 338 | eval_loss = self.model.loss_fn(scores, label) 339 | eval_loss = eval_loss.mean() 340 | running_loss += eval_loss.item() * image.size(0) 341 | running_corrects += torch.sum(preds == label.data) 342 | 343 | time_elapsed = time.time() - since 344 | epoch_loss = running_loss / len(eval_loader.dataset) 345 | epoch_acc = running_corrects.double() / len(eval_loader.dataset) 346 | sparsity_stats = np.array(sparsity_stats) 347 | sparsity_stats = np.mean(sparsity_stats, axis=1).tolist() 348 | return epoch_loss, epoch_acc, sparsity_stats 349 | 350 | def train(self, ckpt_store_interval=20): 351 | ckpt_store_int = ckpt_store_interval 352 | running_vl_loss_dict = {x: 0 for x in self.domain_list} 353 | running_vl_acc_dict = {x: 0 for x in self.domain_list} 354 | 355 | running_vl_loss_dict["overall"] = 0 356 | running_vl_acc_dict["overall"] = 0 357 | 358 | running_ts_loss_dict = {x: 0 for x in self.target_domains} 359 | running_ts_acc_dict = {x: 0 for x in self.target_domains} 360 | 361 | running_ts_loss_dict["overall"] = 0 362 | running_ts_acc_dict["overall"] = 0 363 | 364 | vl_neurons_activated = [] 365 | for i in range(len(self.mask_layers)): 366 | neuron_stats = {x: 0 for x in self.domain_list} 367 | vl_neurons_activated.append(neuron_stats) 368 | 369 | ts_neurons_activated = [] 370 | for i in range(len(self.mask_layers)): 371 | neuron_stats = {x: 0 for x in self.target_domains} 372 | ts_neurons_activated.append(neuron_stats) 373 | 374 | best_score = None 375 | last_epoch = 0 376 | 377 | for epoch in range(1, self.max_epochs + 1): 378 | self.epoch = epoch 379 | 
self.train_epoch() 380 | 381 | wandb.log( 382 | {"Iteration": self.iteration, "Epoch": self.epoch}, step=self.iteration 383 | ) 384 | 385 | # Performance on the val dataset 386 | for domain in self.domain_list: 387 | self.val_split_obj.set_domain_spec_mode(True, domain) 388 | temp_loss, temp_acc, temp_sparsity = self.evaluate("val") 389 | running_vl_loss_dict[domain] = temp_loss 390 | running_vl_acc_dict[domain] = temp_acc.data.item() 391 | for j in range(len(self.mask_layers)): 392 | vl_neurons_activated[j][domain] = temp_sparsity[j] 393 | wandb.log( 394 | { 395 | domain 396 | + " : Blocks Activated at " 397 | + self.mask_layers[j]: temp_sparsity[j] 398 | }, 399 | step=self.iteration, 400 | ) 401 | 402 | wandb.log( 403 | {domain + " : Validation Loss": temp_loss}, step=self.iteration 404 | ) 405 | wandb.log( 406 | {domain + " : Validation Accuracy": temp_acc.data.item()}, 407 | step=self.iteration, 408 | ) 409 | self.val_split_obj.set_domain_spec_mode(False) 410 | 411 | vl_loss = np.mean([running_vl_loss_dict[x] for x in self.domain_list]) 412 | vl_acc = np.mean([running_vl_acc_dict[x] for x in self.domain_list]) 413 | running_vl_loss_dict["overall"] = vl_loss 414 | running_vl_acc_dict["overall"] = vl_acc 415 | wandb.log( 416 | {"Overall : Validation Loss": running_vl_loss_dict["overall"]}, 417 | step=self.iteration, 418 | ) 419 | wandb.log( 420 | {"Overall : Validation Accuracy": running_vl_acc_dict["overall"]}, 421 | step=self.iteration, 422 | ) 423 | 424 | print("-----------------------------------") 425 | print("Fine-grained validation performance") 426 | print("-----------------------------------") 427 | print("Loss") 428 | pprint(running_vl_loss_dict) 429 | print("Accuracy") 430 | pprint(running_vl_acc_dict) 431 | print("-----------------------------------") 432 | 433 | # Performance on the test dataset 434 | for domain in self.target_domains: 435 | self.test_split_obj.set_domain_spec_mode(True, domain) 436 | temp_loss, temp_acc, temp_sparsity = 
self.evaluate("test") 437 | running_ts_loss_dict[domain] = temp_loss 438 | running_ts_acc_dict[domain] = temp_acc.data.item() 439 | for j in range(len(self.mask_layers)): 440 | ts_neurons_activated[j][domain] = temp_sparsity[j] 441 | wandb.log( 442 | { 443 | domain 444 | + " : Blocks Activated at " 445 | + self.mask_layers[j]: temp_sparsity[j] 446 | }, 447 | step=self.iteration, 448 | ) 449 | 450 | wandb.log({domain + " : Test Loss": temp_loss}, step=self.iteration) 451 | wandb.log( 452 | {domain + " : Test Accuracy": temp_acc.data.item()}, 453 | step=self.iteration, 454 | ) 455 | self.test_split_obj.set_domain_spec_mode(False) 456 | 457 | ts_loss = np.mean([running_ts_loss_dict[x] for x in self.target_domains]) 458 | ts_acc = np.mean([running_ts_acc_dict[x] for x in self.target_domains]) 459 | running_ts_loss_dict["overall"] = ts_loss 460 | running_ts_acc_dict["overall"] = ts_acc 461 | wandb.log( 462 | {"Overall : Test Loss": running_ts_loss_dict["overall"]}, 463 | step=self.iteration, 464 | ) 465 | wandb.log( 466 | {"Overall : Test Accuracy": running_ts_acc_dict["overall"]}, 467 | step=self.iteration, 468 | ) 469 | 470 | print("-----------------------------------") 471 | print("Fine-grained test performance") 472 | print("-----------------------------------") 473 | print("Loss") 474 | pprint(running_ts_loss_dict) 475 | print("Accuracy") 476 | pprint(running_ts_acc_dict) 477 | print("-----------------------------------") 478 | 479 | # Store checkpoints 480 | if self.epoch % ckpt_store_int == 0 or self.epoch == 0: 481 | self.saveModel( 482 | running_vl_loss_dict, 483 | running_vl_acc_dict, 484 | running_ts_loss_dict, 485 | running_ts_acc_dict, 486 | self.ckpt_folder + "/model_ep_" + str(self.epoch) + ".pth", 487 | ) 488 | 489 | score = vl_acc 490 | if best_score is None or score > best_score: 491 | best_score = score 492 | self.saveModel( 493 | running_vl_loss_dict, 494 | running_vl_acc_dict, 495 | running_ts_loss_dict, 496 | running_ts_acc_dict, 497 | 
self.ckpt_folder + "/best_so_far.pth", 498 | ) 499 | with open(self.ckpt_folder + "/val_loss.json", "w") as f: 500 | json.dump(running_vl_loss_dict, f) 501 | with open(self.ckpt_folder + "/val_acc.json", "w") as f: 502 | json.dump(running_vl_acc_dict, f) 503 | with open(self.ckpt_folder + "/test_loss.json", "w") as f: 504 | json.dump(running_ts_loss_dict, f) 505 | with open(self.ckpt_folder + "/test_acc.json", "w") as f: 506 | json.dump(running_ts_acc_dict, f) 507 | 508 | # Store best values in tables 509 | best_vl_score_wandb_table = wandb.Table( 510 | columns=["Domain", "Validation Loss", "Validation Accuracy"] 511 | ) 512 | for key, val in running_vl_loss_dict.items(): 513 | best_vl_score_wandb_table.add_data( 514 | key, 515 | str(running_vl_loss_dict[key]), 516 | str(running_vl_acc_dict[key]), 517 | ) 518 | wandb.log({"Best_Validation_Score_Table": best_vl_score_wandb_table}) 519 | 520 | # Store test values in tables 521 | best_ts_score_wandb_table = wandb.Table( 522 | columns=["Domain", "Test Loss", "Test Accuracy"] 523 | ) 524 | for key, val in running_ts_loss_dict.items(): 525 | best_ts_score_wandb_table.add_data( 526 | key, 527 | str(running_ts_loss_dict[key]), 528 | str(running_ts_acc_dict[key]), 529 | ) 530 | wandb.log({"Best_Test_Score_Table": best_ts_score_wandb_table}) 531 | 532 | # Check for epoch level scheduler step 533 | if self.model_scheduler is not None: 534 | if self.scheduler_mode == "epoch": 535 | if self.params.OPTIM.LEARNING_RATE_SCHEDULER == "exp": 536 | self.model_scheduler.step() 537 | elif self.params.OPTIM.LEARNING_RATE_SCHEDULER == "invlr": 538 | self.model_scheduler.step() 539 | 540 | # if self.mask_scheduler is not None: 541 | # if self.scheduler_mode == "epoch": 542 | # if self.params.OPTIM.LEARNING_RATE_SCHEDULER == "exp": 543 | # self.mask_scheduler.step() 544 | # elif self.params.OPTIM.LEARNING_RATE_SCHEDULER == "invlr": 545 | # self.mask_scheduler.step() 546 | 
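The objective assembled inside `train_epoch` above — per-example classification loss plus sparsity and overlap penalties summed over the mask modules and weighted by `OPTIM.SPARSITY_LAMBDA` and `OPTIM.OVERLAP_LAMBDA` — can be isolated as a small helper. A minimal sketch, assuming each penalty arrives as a scalar tensor per mask layer (the actual penalty definitions live in `models/supermasks.py`):

```python
import torch

def total_loss(class_loss, sparsity_penalties, overlap_penalties,
               sparsity_lambda=0.1, overlap_lambda=0.1):
    # class_loss: per-example classification losses (1-D tensor)
    # *_penalties: one scalar tensor per mask layer, summed across layers
    sparsity_loss = torch.stack(sparsity_penalties).sum()
    overlap_loss = torch.stack(overlap_penalties).sum()
    return (class_loss.mean()
            + sparsity_lambda * sparsity_loss
            + overlap_lambda * overlap_loss)
```

The sparsity term pushes each domain's mask to keep few blocks active, while the overlap term discourages different domains from selecting identical sub-networks; the lambda defaults here are placeholders for the config values.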
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prithv1/DMG/6b162b4958b52f4d99d51663e053f58f5a77e7cc/utils/__init__.py -------------------------------------------------------------------------------- /utils/inverse_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class InvLR(_LRScheduler): 7 | """Decays the learning rate according to the inverse LR schedule 8 | """ 9 | 10 | def __init__(self, optimizer, gamma=0.0001, power=0.75, last_epoch=-1): 11 | self.gamma = gamma 12 | self.power = power 13 | super(InvLR, self).__init__(optimizer, last_epoch) 14 | 15 | def get_lr(self): 16 | factor = ( 17 | (1 + self.gamma * self.last_epoch) 18 | / (1 + self.gamma * (self.last_epoch - 1)) 19 | ) ** (-self.power) 20 | if self.last_epoch == 0: 21 | return [group["lr"] for group in self.optimizer.param_groups] 22 | return [group["lr"] * factor for group in self.optimizer.param_groups] 23 | 24 | def _get_closed_form_lr(self): 25 | return [ 26 | base_lr * ((1 + self.gamma * self.last_epoch) ** (-self.power)) 27 | for base_lr in self.base_lrs 28 | ] 29 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torchvision 5 | import torch.utils.data 6 | 7 | import numpy as np 8 | 9 | from pprint import pprint 10 | from itertools import combinations 11 | 12 | from torch import nn, optim 13 | from torch.nn import functional as F 14 | from torch.distributions import Bernoulli, RelaxedBernoulli 15 | from torchvision import datasets, models, transforms 16 | 17 | SMOOTH = 1e-6 18 | 19 | 20 | def 
weights_init(m): 21 | classname = m.__class__.__name__ 22 | if classname.find("Conv") != -1: 23 | m.weight.data.normal_(0.0, 0.02) 24 | elif classname.find("BatchNorm") != -1: 25 | m.weight.data.normal_(1.0, 0.02) 26 | m.bias.data.fill_(0) 27 | elif classname.find("Linear") != -1: 28 | size = m.weight.size() 29 | m.weight.data.normal_(0.0, 0.001) 30 | m.bias.data.fill_(0) 31 | --------------------------------------------------------------------------------
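As a usage sketch for `InvLR` in `utils/inverse_lr_scheduler.py`: its `_get_closed_form_lr` shows that after `t` steps the learning rate follows `base_lr * (1 + gamma * t) ** (-power)`. The snippet below applies that closed form directly to an SGD optimizer's param groups, mirroring what the incremental `factor` in `get_lr` accumulates to:

```python
import torch

def inv_lr(base_lr, step, gamma=0.0001, power=0.75):
    # Closed form of the inverse-decay schedule implemented by InvLR
    return base_lr * (1 + gamma * step) ** (-power)

# Manually drive an SGD optimizer's learning rate with the closed form
params = [torch.zeros(1, requires_grad=True)]
opt = torch.optim.SGD(params, lr=0.01)
for t in range(1, 10001):
    for group in opt.param_groups:
        group["lr"] = inv_lr(0.01, t)
```

With the default `gamma=0.0001` and `power=0.75`, the decay is gentle: it takes 10,000 steps for the rate to fall by a factor of `2 ** 0.75` (roughly 1.68x), which suits the long DomainNet training schedules configured in `configs/`.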