├── .gitattributes ├── .gitignore ├── CITATION.bib ├── README.md ├── assets └── fade.png ├── conda.yml ├── fade ├── __init__.py ├── config │ ├── config.yaml │ ├── dataset │ │ ├── Adult2.yaml │ │ ├── DigitFive.yaml │ │ ├── Office31.yaml │ │ ├── OfficeHome65.yaml │ │ └── comb │ │ │ ├── Adult2FM.yaml │ │ │ ├── Digit.yaml │ │ │ ├── Office31_X2X_1s_3t.yaml │ │ │ └── OfficeHome65_X2X_1s_3t.yaml │ ├── model │ │ ├── AdultDnnSplitAdv.yaml │ │ ├── Office31_CnnSplitAdv.yaml │ │ └── OfficeHome65CnnSplitAdv.yaml │ ├── server │ │ ├── FedAdv.yaml │ │ ├── FedAvg.yaml │ │ └── FedLocal.yaml │ └── user │ │ ├── generic.yaml │ │ ├── group_adv.yaml │ │ └── group_adv_office_uda.yaml ├── data │ ├── __init__.py │ ├── federalize.py │ ├── meta │ │ ├── __init__.py │ │ ├── adult.py │ │ ├── office_caltech.py │ │ └── usps.py │ ├── multi_domain.py │ └── utils.py ├── file.py ├── mainx.py ├── model │ ├── __init__.py │ ├── adult.py │ ├── adv.py │ ├── mnist.py │ ├── office.py │ ├── shot.py │ ├── shot_digit.py │ ├── split.py │ └── utils.py ├── server │ ├── FedAdv.py │ ├── FedAvg.py │ ├── __init__.py │ └── base.py ├── user │ ├── __init__.py │ ├── base.py │ ├── cdan_loss.py │ ├── generic.py │ ├── group_adv.py │ └── shot_digit_loss.py └── utils.py └── sweeps ├── Office31_UDA ├── A2X_cdan_nuser_sweep.yaml ├── A2X_dann_nuser_sweep.yaml ├── A2X_shot_nuser_sweep.yaml ├── A_fedavg.sh ├── D2X_cdan_nuser_sweep.yaml ├── D2X_dann_nuser_sweep.yaml ├── D2X_shot_nuser_sweep.yaml ├── D_fedavg.sh ├── W2X_cdan_nuser_sweep.yaml ├── W2X_dann_nuser_sweep.yaml ├── W2X_shot_nuser_sweep.yaml ├── W_fedavg.sh └── sweep_all.sh └── OfficeHome65_1to3_uda_iid ├── R2X_cdan_nuser_sweep.yaml ├── R2X_dann_nuser_sweep.yaml ├── R2X_shot_nuser_sweep.yaml └── sweep_all.sh /.gitattributes: -------------------------------------------------------------------------------- 1 | 2 | *.pt filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # Created by https://www.gitignore.io/api/macos,latex 107 | # Edit at https://www.gitignore.io/?templates=macos,latex 108 | 109 | ### LaTeX ### 110 | ## Core latex/pdflatex auxiliary files: 111 | *.aux 112 | *.lof 113 | *.log 114 | *.lot 115 | *.fls 116 | *.out 117 | *.toc 118 | *.fmt 119 | *.fot 120 | *.cb 121 | *.cb2 122 | .*.lb 123 | 124 | ## Intermediate documents: 125 | *.dvi 126 | *.xdv 127 | *-converted-to.* 128 | # these rules might exclude image files for figures etc. 129 | # *.ps 130 | # *.eps 131 | # *.pdf 132 | 133 | ## Generated if empty string is given at "Please type another file name for output:" 134 | .pdf 135 | 136 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 137 | *.bbl 138 | *.bcf 139 | *.blg 140 | *-blx.aux 141 | *-blx.bib 142 | *.run.xml 143 | 144 | ## Build tool auxiliary files: 145 | *.fdb_latexmk 146 | *.synctex 147 | *.synctex(busy) 148 | *.synctex.gz 149 | *.synctex.gz(busy) 150 | *.pdfsync 151 | 152 | ## Build tool directories for auxiliary files 153 | # latexrun 154 | latex.out/ 155 | 156 | ## Auxiliary and intermediate files from other packages: 157 | # algorithms 158 | *.alg 159 | *.loa 160 | 161 | # achemso 162 | acs-*.bib 163 | 164 | # amsthm 165 | *.thm 166 | 167 | # beamer 168 | *.nav 169 | *.pre 170 | *.snm 171 | *.vrb 172 | 173 | # changes 174 | *.soc 175 | 176 | # comment 177 | *.cut 178 | 179 | # cprotect 180 | *.cpt 181 | 182 | # elsarticle (documentclass of Elsevier journals) 183 | *.spl 184 | 185 | # endnotes 186 | *.ent 187 | 188 | # fixme 189 | *.lox 190 | 191 | # feynmf/feynmp 192 | *.mf 193 | *.mp 194 | *.t[1-9] 195 | *.t[1-9][0-9] 196 | *.tfm 197 | 198 | #(r)(e)ledmac/(r)(e)ledpar 199 | *.end 200 | *.?end 201 | *.[1-9] 202 | *.[1-9][0-9] 203 | *.[1-9][0-9][0-9] 204 | *.[1-9]R 205 | *.[1-9][0-9]R 206 | *.[1-9][0-9][0-9]R 207 | *.eledsec[1-9] 208 | *.eledsec[1-9]R 209 | *.eledsec[1-9][0-9] 210 | *.eledsec[1-9][0-9]R 211 | *.eledsec[1-9][0-9][0-9] 212 | *.eledsec[1-9][0-9][0-9]R 213 | 214 | # glossaries 215 | *.acn 216 | *.acr 217 | *.glg 218 | *.glo 219 | *.gls 220 | *.glsdefs 221 | 222 | # gnuplottex 223 | *-gnuplottex-* 224 | 225 | # gregoriotex 226 | *.gaux 227 | *.gtex 228 | 229 | # htlatex 230 | *.4ct 231 | *.4tc 232 | *.idv 233 | *.lg 234 | *.trc 235 | *.xref 236 | 237 | # hyperref 238 | *.brf 239 | 240 | # knitr 241 | *-concordance.tex 242 | # TODO Comment the next line if you want to keep your tikz 
graphics files 243 | *.tikz 244 | *-tikzDictionary 245 | 246 | # listings 247 | *.lol 248 | 249 | # luatexja-ruby 250 | *.ltjruby 251 | 252 | # makeidx 253 | *.idx 254 | *.ilg 255 | *.ind 256 | *.ist 257 | 258 | # minitoc 259 | *.maf 260 | *.mlf 261 | *.mlt 262 | *.mtc[0-9]* 263 | *.slf[0-9]* 264 | *.slt[0-9]* 265 | *.stc[0-9]* 266 | 267 | # minted 268 | _minted* 269 | *.pyg 270 | 271 | # morewrites 272 | *.mw 273 | 274 | # nomencl 275 | *.nlg 276 | *.nlo 277 | *.nls 278 | 279 | # pax 280 | *.pax 281 | 282 | # pdfpcnotes 283 | *.pdfpc 284 | 285 | # sagetex 286 | *.sagetex.sage 287 | *.sagetex.py 288 | *.sagetex.scmd 289 | 290 | # scrwfile 291 | *.wrt 292 | 293 | # sympy 294 | *.sout 295 | *.sympy 296 | sympy-plots-for-*.tex/ 297 | 298 | # pdfcomment 299 | *.upa 300 | *.upb 301 | 302 | # pythontex 303 | *.pytxcode 304 | pythontex-files-*/ 305 | 306 | # tcolorbox 307 | *.listing 308 | 309 | # thmtools 310 | *.loe 311 | 312 | # TikZ & PGF 313 | *.dpth 314 | *.md5 315 | *.auxlock 316 | 317 | # todonotes 318 | *.tdo 319 | 320 | # vhistory 321 | *.hst 322 | *.ver 323 | 324 | # easy-todo 325 | *.lod 326 | 327 | # xcolor 328 | *.xcp 329 | 330 | # xmpincl 331 | *.xmpi 332 | 333 | # xindy 334 | *.xdy 335 | 336 | # xypic precompiled matrices 337 | *.xyc 338 | 339 | # endfloat 340 | *.ttt 341 | *.fff 342 | 343 | # Latexian 344 | TSWLatexianTemp* 345 | 346 | ## Editors: 347 | # WinEdt 348 | *.bak 349 | *.sav 350 | 351 | # Texpad 352 | .texpadtmp 353 | 354 | # LyX 355 | *.lyx~ 356 | 357 | # Kile 358 | *.backup 359 | 360 | # KBibTeX 361 | *~[0-9]* 362 | 363 | # auto folder when using emacs and auctex 364 | ./auto/* 365 | *.el 366 | 367 | # expex forward references with \gathertags 368 | *-tags.tex 369 | 370 | # standalone packages 371 | *.sta 372 | 373 | ### LaTeX Patch ### 374 | # glossaries 375 | *.glstex 376 | 377 | ### macOS ### 378 | # General 379 | .DS_Store 380 | .AppleDouble 381 | .LSOverride 382 | 383 | # Icon must end with two \r 384 | Icon 385 | 386 | # Thumbnails 387 | ._* 388 | 389 | # Files that might appear in the root of a volume 390 | .DocumentRevisions-V100 391 | .fseventsd 392 | .Spotlight-V100 393 | .TemporaryItems 394 | .Trashes 395 | .VolumeIcon.icns 396 | .com.apple.timemachine.donotpresent 397 | 398 | # Directories potentially created on remote AFP share 399 | .AppleDB 400 | .AppleDesktop 401 | Network Trash Folder 402 | Temporary Items 403 | .apdisk 404 | 405 | # End of https://www.gitignore.io/api/macos,latex 406 | fade/config/tele/nintendo.yaml 407 | *.zip 408 | *.tar.gz 409 | wandb/ 410 | out/ 411 | outputs/ 412 | data/ 413 | -------------------------------------------------------------------------------- /CITATION.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{hong2021federated, 2 | title={Federated Adversarial Debiasing for Fair and Transferable Representations}, 3 | author={Hong, Junyuan and Zhu, Zhuangdi and Yu, Shuyang and Wang, Zhangyang and Dodge, Hiroko and Zhou, Jiayu}, 4 | booktitle={Proceedings of the 27th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining}, 5 | year={2021} 6 | } 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Federated Adversarial Debiasing (FADE) 2 | ====================================== 3 | 4 | Code for paper: "Federated Adversarial Debiasing for Fair and Transferable Representations" Junyuan Hong, Zhuangdi Zhu, Shuyang Yu, Zhangyang 
Wang, Hiroko Dodge, and Jiayu Zhou. *KDD'21*
5 | [[paper]](https://dl.acm.org/doi/10.1145/3447548.3467281) [[slides]](https://jyhong.gitlab.io/publication/fade2021kdd/slides.pdf)
6 | 
7 | **TL;DR**: FADE is the first work showing that clients can optimize a group-to-group adversarial debiasing objective [1] **without the adversarial data being present on the local device**. The technique is applicable to unsupervised domain adaptation (UDA) and group-fair learning. In UDA, our method outperforms the state-of-the-art source-data-free UDA method (SHOT) in federated learning.
8 | 
9 | ![adversarial objective w/o adversarial data](https://user-images.githubusercontent.com/6964516/160862893-fba4e6a3-298e-4cb1-b7f0-d39bdde64b68.png)
10 | 
11 | [1] Ganin, Y., & Lempitsky, V. (2015). Unsupervised domain adaptation by backpropagation. *ICML*.
12 | 
13 | **Abstract**
14 | 
15 | Federated learning is a distributed learning framework that is communication efficient and provides protection over participating users' raw training data. One outstanding challenge of federated learning comes from the users' heterogeneity, and learning from such data may yield biased and unfair models for minority groups. While adversarial learning is commonly used in centralized learning for mitigating bias, there are significant barriers when extending it to the federated framework. In this work, we study these barriers and address them by proposing a novel approach, Federated Adversarial DEbiasing (FADE). FADE does not require users' sensitive group information for debiasing and offers users the freedom to opt out of the adversarial component when privacy or computational costs become a concern. We show that, ideally, FADE can attain the same global optimality as the centralized algorithm. We then analyze when its convergence may fail in practice and propose a simple yet effective method to address the problem. Finally, we demonstrate the effectiveness of the proposed framework through extensive empirical studies, including the problem settings of unsupervised domain adaptation and fair learning.
16 | 
17 | ![FADE](assets/fade.png)
18 | 
19 | ## Usage
20 | ### Setup Environment
21 | Clone the repository and set up the environment.
22 | ```shell
23 | git clone git@github.com:illidanlab/FADE.git
24 | cd FADE
25 | # create conda env
26 | conda env create -f conda.yml
27 | conda activate fade
28 | # run
29 | python -m fade.mainx
30 | ```
31 | 
32 | To run repeated experiments, we use `wandb` for logging. Run
33 | ```shell
34 | wandb sweep
35 | ```
36 | Note that you need a wandb account; you will be asked to log in on the first run.
37 | 
38 | ### Download Datasets
39 | 
40 | * **Office**: Download the zip file from [here](https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view) (preprocessed by [SHOT](https://github.com/tim-learn/SHOT)) and unpack it into `./data/office31`. Verify the file structure to make sure no image paths are missing.
41 | * **OfficeHome**: Download the zip file from [here](https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view) (preprocessed by [SHOT](https://github.com/tim-learn/SHOT)) and unpack it into `./data/OfficeHome65`. Verify the file structure to make sure no image paths are missing.
42 | 
43 | ### Download Pre-trained Source-domain Models
44 | 
45 | For each UDA task, we first pre-train a model on the source domain.
You can pre-train these models yourself:
46 | ```shell
47 | source sweeps/Office31_UDA/A_fedavg.sh
48 | ```
49 | Alternatively, you may download the pre-trained source-domain models from [here](https://www.dropbox.com/sh/0imy8vft8o3mph8/AABhNuzbW02OmwboMu84e672a?dl=0) and place them under `out/models/`.
50 | 
51 | ### Pre-trained adapted models
52 | 
53 | To be added soon.
54 | 
55 | ## Run UDA experiments
56 | 
57 | * Office dataset
58 | ```shell
59 | # pre-train the model on domain A (use D_fedavg.sh / W_fedavg.sh for domains D and W).
60 | source sweeps/Office31_UDA/A_fedavg.sh
61 | # create wandb sweeps for A2X, D2X, W2X, where X is one of the remaining two domains.
62 | # the command will print the agent commands.
63 | source sweeps/Office31_UDA/sweep_all.sh
64 | # Run the wandb agent commands from the printed output or the sweep page.
65 | wandb agent
66 | ```
67 | Demo wandb project page: [fade-demo-Office31_X2X_UDA](https://wandb.ai/jyhong/fade-demo-Office31_X2X_UDA?workspace=user-jyhong). Check the [sweeps](https://wandb.ai/jyhong/fade-demo-Office31_X2X_UDA/sweeps?workspace=user-jyhong) there.
68 | * OfficeHome dataset
69 | ```shell
70 | # pre-train the model on domain R
71 | source sweeps/OfficeHome65_1to3_uda_iid/R_fedavg.sh
72 | # create wandb sweeps for R2X, where X is one of the remaining domains.
73 | # the command will print the agent commands.
74 | source sweeps/OfficeHome65_1to3_uda_iid/sweep_all.sh
75 | # Run the wandb agent commands from the printed output or the sweep page.
76 | wandb agent
77 | ```
78 | 
79 | ![image](https://user-images.githubusercontent.com/6964516/160864437-f124fcbd-da17-422a-b9b2-189f745d9f3b.png)
80 | 
81 | ## Extend with other debiasing methods
82 | 
83 | To extend the FADE framework with other debiasing methods, you need to update the user and server code. To start, please read the `GroupAdvUser` class in [fade/user/group_adv.py](fade/user/group_adv.py) and `FedAdv` in [fade/server/FedAdv.py](fade/server/FedAdv.py).
84 | 
85 | Typically, you will need to override the `compute_loss` function of the `GroupAdvUser` class to customize your loss computation; a minimal (untested) sketch is included at the end of this README.
86 | 
87 | ------------
88 | 
89 | If you find the repository useful, please cite our paper.
90 | ```bibtex
91 | @inproceedings{hong2021federated,
92 |   title={Federated Adversarial Debiasing for Fair and Transferable Representations},
93 |   author={Hong, Junyuan and Zhu, Zhuangdi and Yu, Shuyang and Wang, Zhangyang and Dodge, Hiroko and Zhou, Jiayu},
94 |   booktitle={Proceedings of the 27th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining},
95 |   year={2021}
96 | }
97 | ```
98 | 
99 | **Acknowledgement**
100 | 
101 | This material is based in part upon work supported by the National Science Foundation under Grant IIS-1749940, EPCN-2053272, Office of Naval Research N00014-20-1-2382, and National Institute on Aging (NIA) R01AG051628, R01AG056102, P30AG066518, P30AG024978, RF1AG072449.
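
**Example: customizing `compute_loss`** (referenced from the *Extend with other debiasing methods* section)

The sketch below is a minimal, untested illustration of the extension point described above. The exact signature of `compute_loss`, the attribute names (`self.model`, `self.adv_lambda`), and the batch layout are assumptions made for illustration only; check `fade/user/group_adv.py` for the actual interface before adapting it.

```python
# Hypothetical sketch -- not the shipped implementation. The real interface of
# GroupAdvUser.compute_loss lives in fade/user/group_adv.py and may differ.
import torch.nn.functional as F

from fade.user.group_adv import GroupAdvUser


class MyDebiasUser(GroupAdvUser):
    def compute_loss(self, batch):
        """Combine the task loss with a custom debiasing term."""
        x, y = batch                                   # assumed (input, label) batch layout
        class_logits, group_logits = self.model(x)     # assumed split-model outputs
        task_loss = F.cross_entropy(class_logits, y)   # supervised task objective
        # Replace the placeholder below with your own debiasing objective, e.g. a
        # different group loss than the bce/cdan/wd options in user/group_adv.yaml.
        debias_loss = -F.log_softmax(group_logits, dim=1).mean()
        return task_loss + self.adv_lambda * debias_loss
```

A new user class would also need to be wired into the server side (e.g. `FedAdv`) so that it is instantiated in place of `GroupAdvUser`; see `fade/server/FedAdv.py` for where users are created.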
102 | -------------------------------------------------------------------------------- /assets/fade.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/illidanlab/FADE/7997485ab6470fd31c2f9353bf2415a1bec87363/assets/fade.png -------------------------------------------------------------------------------- /conda.yml: -------------------------------------------------------------------------------- 1 | name: fade 2 | channels: 3 | - pytorch 4 | - hcc 5 | - cvxgrp 6 | - anaconda 7 | - conda-forge 8 | - oxfordcontrol 9 | - defaults 10 | dependencies: 11 | - pytorch>=1.3 12 | # - cudatoolkit=10.1 13 | - numpy>=1.15 14 | - torchvision>=0.4 15 | - scipy>=1.2 16 | - python==3.7.5 17 | - seaborn 18 | - matplotlib 19 | - click 20 | - coloredlogs 21 | - scikit-learn 22 | - tqdm 23 | - requests 24 | - tensorboard 25 | - hydra-core==1.0.4 26 | - pip: 27 | - wandb # ==0.9.7 # 0.10 has known issue with joblib. See https://github.com/wandb/client/issues/1525 28 | - hydra_colorlog 29 | - hydra-joblib-launcher 30 | -------------------------------------------------------------------------------- /fade/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/illidanlab/FADE/7997485ab6470fd31c2f9353bf2415a1bec87363/fade/__init__.py -------------------------------------------------------------------------------- /fade/config/config.yaml: -------------------------------------------------------------------------------- 1 | # This is the default experiment config. 2 | defaults: 3 | - model: cnn-split 4 | - server: FedAvg 5 | - user: generic 6 | - dataset: Mnist # Note the runtime entry name is 'dataset' instead of 'data'. 7 | - hydra/job_logging: colorlog 8 | - hydra/hydra_logging: colorlog 9 | - hydra/launcher: joblib 10 | 11 | # Example: dataset=comb/MnistM_c5u40 server.num_users=1 times=3 user.batch_size=5 num_glob_iters=400 user.local_epochs=50 user.optimizer.learning_rate=0.005 server.beta=1.0 n_jobs=3 server.name=pFedSplit model=cnn-split user.optimizer.name=sgd model.mid_dim=20 12 | 13 | # not hashed config 14 | num_glob_iters: 800 15 | partial_eval: false # Do partial evaluation instead of all-user evaluation at each global round 16 | n_rep: 5 # Number of repetition times 17 | i_rep: -1 # should be smaller than n_rep or negative (ignored) 18 | n_jobs: 1 # num of parallel jobs 19 | device: 'cuda' # run device (cpu | cuda) 20 | action: 'train' # 'train' | 'eval' | 'avg' | 'check_files' 21 | logging: 'WARN' 22 | 23 | # only keys listed below will be hashed as server unique name. 24 | hash_keys: ["model", "server", "user", "dataset", "seed", "name"] 25 | 26 | load_model: # NOTE: for action=eval, even if load_modle.do=false, the saved default models will still be loaded. 27 | do: false 28 | load: [server, user] # Example: [server, user], [server], [] 29 | # TODO this has to be moved to model dict. such that pretrained and train-from-scratch can be distinguished. 30 | hash_name: null # could be hash name of a server or null to load from default path. Could be absolute path to super-dir of `server.pt`. 31 | disable_save_user_model: True 32 | 33 | # will be hashed 34 | name: "dmtl" # experiment name 35 | seed: 42 # for repetitions, use different seed 36 | 37 | logger: 38 | loggers: ["wandb"] 39 | log_user: true # log for each user. 
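  # The ${...} values in the wandb block below are Hydra/OmegaConf interpolations,
  # resolved at runtime from the composed config (e.g. ${name}-${dataset.name}
  # becomes "<experiment name>-<dataset name>" for the wandb project).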
40 | wandb: 41 | name: rep_${i_rep}_${user.optimizer.learning_rate} 42 | project: ${name}-${dataset.name} 43 | group: ${server.name}-${user.name}-${model.name} 44 | # offline: false # Run offline (data can be streamed later to wandb servers). 45 | 46 | # Set config only for evaluation (which will not change the hash of the experiment). 47 | # The nested config will replace the original setting during aciton=eval. 48 | #eval_config: 49 | # dataset: 50 | 51 | plot: 52 | func: plot_hidden_states 53 | kwargs: 54 | color_by: class 55 | info_group: test -------------------------------------------------------------------------------- /fade/config/dataset/Adult2.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: Adult2W 3 | seed: 42 4 | n_user: 1 5 | n_class: 2 6 | class_stride: 1 7 | partition_mode: uni 8 | min_n_sample_per_share: 2 # num sample per class 9 | max_n_sample_per_share: -1 10 | max_n_sample_per_class: 1000 # equivalent to -1 in practice 11 | user_data_format: tensor 12 | n_sample_per_class: -1 # test size per class -------------------------------------------------------------------------------- /fade/config/dataset/DigitFive.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: Mnist # Mnist | MnistM | SVHN | USPS | SynDigit 3 | seed: 42 4 | n_user: 1 5 | n_class: 10 6 | class_stride: 2 7 | partition_mode: uni 8 | min_n_sample_per_share: 2 # num sample per class 9 | max_n_sample_per_share: -1 10 | max_n_sample_per_class: -1 # 1000 11 | n_channel: 3 # 1 or 3 12 | user_data_format: tensor 13 | random_crop_rot: True # only used by USPS TODO do we need to set false? 14 | resize: 28 15 | -------------------------------------------------------------------------------- /fade/config/dataset/Office31.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: Office31A 3 | seed: 42 4 | n_user: 1 5 | n_class: 31 6 | class_stride: 1 7 | partition_mode: uni 8 | min_n_sample_per_share: 2 # num sample per class 9 | max_n_sample_per_share: -1 10 | max_n_sample_per_class: 1000 # equivalent to -1 in practice 11 | user_data_format: dataset 12 | feature_type: images -------------------------------------------------------------------------------- /fade/config/dataset/OfficeHome65.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: OfficeHome65A 3 | seed: 42 4 | n_user: 1 5 | n_class: 65 6 | class_stride: 1 7 | partition_mode: uni 8 | min_n_sample_per_share: 2 # num sample per class 9 | max_n_sample_per_share: -1 10 | max_n_sample_per_class: 1000 # equivalent to -1 in practice 11 | user_data_format: dataset 12 | feature_type: images -------------------------------------------------------------------------------- /fade/config/dataset/comb/Adult2FM.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Fairness experiment 4 | dataset: 5 | name: "comb/Adult2FM" 6 | meta_fed_ds: federalize 7 | meta_datasets: 8 | - name: Adult2F 9 | seed: 42 10 | n_user: 10 11 | n_class: 2 12 | class_stride: 1 13 | min_n_sample_per_share: 2 # num sample per class 14 | max_n_sample_per_share: -1 15 | max_n_sample_per_class: 200 16 | # n_sample_per_class: 200 # use all test samples 17 | user_data_format: tensor 18 | partition_distribution: uni 19 | - name: Adult2M 20 | seed: 42 21 | n_user: 10 22 | 
n_class: 2 23 | class_stride: 1 24 | min_n_sample_per_share: 2 # num sample per class 25 | max_n_sample_per_share: -1 26 | max_n_sample_per_class: 200 27 | # n_sample_per_class: 200 # use all test samples 28 | user_data_format: dataset 29 | partition_distribution: uni 30 | -------------------------------------------------------------------------------- /fade/config/dataset/comb/Digit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Domain adaptation Mnist -> USPS 4 | # This config only support `federalize` to create the dataset. 5 | # Do not change the order. 6 | dataset: 7 | name: "comb/Digit" 8 | meta_fed_ds: federalize # federalize | extend (default) 9 | # meta_datasets: ["Mnist", "MnistM"] 10 | # comb_mode: "sep" # sep: Each user only have a single domain data. mix: Each user has all domain. 11 | # total_n_user: 20 12 | meta_datasets: 13 | - name: SVHN # target domain 14 | seed: 42 15 | n_user: 1 16 | n_class: 10 17 | class_stride: 2 18 | partition_mode: uni 19 | min_n_sample_per_share: 2 # num sample per class per user shard 20 | max_n_sample_per_share: -1 21 | max_n_sample_per_class: -1 22 | n_channel: 3 # 1 or 3 23 | user_data_format: tensor 24 | random_crop_rot: True 25 | resize: 28 26 | - name: Mnist # source domain 27 | seed: 42 28 | n_user: 1 29 | n_class: 10 30 | class_stride: 2 31 | partition_mode: uni 32 | min_n_sample_per_share: 2 # num sample per class 33 | max_n_sample_per_share: -1 34 | max_n_sample_per_class: -1 35 | n_channel: 3 # 1 or 3 36 | user_data_format: tensor 37 | -------------------------------------------------------------------------------- /fade/config/dataset/comb/Office31_X2X_1s_3t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # 1to1 Domain adaptation on Office31 dataset 4 | # 1 source user vs 3 target user 5 | # Do not change the order. 6 | dataset: 7 | name: "comb/Office31_X2X_1s_3t" 8 | meta_fed_ds: federalize 9 | meta_datasets: 10 | - name: Office31W # target domain: Webcam 11 | seed: 42 12 | n_user: 3 13 | n_class: 15 14 | class_stride: 10 15 | min_n_sample_per_share: 2 # num sample per class per user shard 16 | max_n_sample_per_share: -1 17 | max_n_sample_per_class: -1 18 | user_data_format: dataset 19 | feature_type: images 20 | partition_mode: uni 21 | - name: Office31A # source domain: Amazon 22 | seed: 42 23 | n_user: 1 24 | n_class: 31 25 | class_stride: 1 26 | min_n_sample_per_share: 2 # num sample per class 27 | max_n_sample_per_share: -1 28 | max_n_sample_per_class: -1 29 | user_data_format: dataset 30 | feature_type: images 31 | partition_mode: uni 32 | -------------------------------------------------------------------------------- /fade/config/dataset/comb/OfficeHome65_X2X_1s_3t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # 1to1 Domain adaptation on Office31 dataset 4 | # 1 source user vs 3 target user 5 | # Do not change the order. 
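# Illustration (derived from the "seq" class split in fade/data/federalize.py):
# with the target settings below (n_user: 3, n_class: 25, class_stride: 20) over
# the 65 OfficeHome classes, target user i receives classes (20*i)..(20*i+24),
# i.e. 0-24, 20-44 and 40-64, while the single source user holds all 65 classes.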
6 | dataset: 7 | name: "comb/OfficeHome65_X2X_1s_3t" 8 | meta_fed_ds: federalize # federalize | extend (default) 9 | meta_datasets: 10 | - name: OfficeHome65C # target domain: Clipart 11 | seed: 42 12 | n_user: 3 13 | n_class: 25 14 | class_stride: 20 15 | min_n_sample_per_share: 2 # num sample per class 16 | max_n_sample_per_share: -1 17 | max_n_sample_per_class: -1 18 | user_data_format: dataset 19 | feature_type: images 20 | partition_mode: uni 21 | - name: OfficeHome65A # source domain: Arts 22 | seed: 42 23 | n_user: 1 24 | n_class: 65 25 | class_stride: 1 26 | min_n_sample_per_share: 2 # num sample per class 27 | max_n_sample_per_share: -1 28 | max_n_sample_per_class: -1 29 | user_data_format: dataset 30 | feature_type: images 31 | partition_mode: uni 32 | -------------------------------------------------------------------------------- /fade/config/model/AdultDnnSplitAdv.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: 'AdultDNNSplit' 3 | n_task: 1 # > 0 for task/group adv training. <=0 otherwise 4 | rev_lambda_scale: 1. # 0 to disable backward from task discriminator. >0 to reverse grad. <0 to normal grad. NOTE this is just a constant scale to the lambda. The real lambda will be further scheduled on run. 5 | n_class: 2 # This depends on dataset. 6 | mid_dim: 64 7 | freeze_backbone: False 8 | freeze_decoder: False 9 | disable_bn_stat: False # set True in testing or fine-tuning 10 | CDAN_task: False -------------------------------------------------------------------------------- /fade/config/model/Office31_CnnSplitAdv.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: 'OfficeCnnSplit' 3 | backbone: resnet50 4 | n_task: 1 # > 0 for task/group adv training. <=0 otherwise 5 | rev_lambda_scale: 1. # 0 to disable backward from task discriminator. >0 to reverse grad. <0 to normal grad. NOTE this is just a constant scale to the lambda. The real lambda will be further scheduled on run. 6 | n_class: 31 # This depends on dataset. 7 | mid_dim: 256 8 | freeze_backbone: False 9 | freeze_decoder: False 10 | disable_bn_stat: False # set True in testing or fine-tuning 11 | CDAN_task: False 12 | bottleneck_type: dropout 13 | -------------------------------------------------------------------------------- /fade/config/model/OfficeHome65CnnSplitAdv.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: 'OfficeCnnSplit' 3 | backbone: resnet50 4 | n_task: 1 # > 0 for task/group adv training. <=0 otherwise 5 | rev_lambda_scale: 1. # 0 to disable backward from task discriminator. >0 to reverse grad. <0 to normal grad. NOTE this is just a constant scale to the lambda. The real lambda will be further scheduled on run. 6 | n_class: 65 # This depends on dataset. 7 | mid_dim: 256 8 | freeze_backbone: True 9 | freeze_decoder: False 10 | disable_bn_stat: False # set True in testing or fine-tuning 11 | CDAN_task: False 12 | bottleneck_type: dropout -------------------------------------------------------------------------------- /fade/config/server/FedAdv.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | # Unsupervised Domain Adaptation by Adv Training. 3 | # Assumptions: 4 | # All users are processing same task, i.e., p(y|x) are the same. Thus, we share all params like FedAvg. 5 | # Domain differs by p(x) 6 | # One group of users are unsupervised. 
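# Untested example of composing these configs for Office31 UDA (the override
# names below simply follow the files under fade/config/ and Hydra's
# group=option syntax; adjust them to your experiment):
#   python -m fade.mainx server=FedAdv user=group_adv_office_uda \
#     model=Office31_CnnSplitAdv dataset=comb/Office31_X2X_1s_3t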
7 | server: 8 | name: FedAdv 9 | alg: FedAdv # server alg 10 | num_users: 1 # Number of Users per global round 11 | beta: 1. # Average moving parameter for pFedMe, or Second learning rate of Per-FedAvg 12 | share_mode: all # share all because MnistM does not use labels to train. 13 | sync_optimizer: false # true to sync optimizers between users after each global run. NOTE: sync only support num_users==1. 14 | user_selection: "sequential" 15 | fair_update: False 16 | # group_label_mode: ??? # {Mnist:supervised,MnistM:unsupervised} 17 | # Example for setting 'group_label_mode': 18 | # +server.group_label_mode.Mnist=supervised 19 | # +server.group_label_mode.MnistM=unsupervised 20 | privacy: 21 | enable: false 22 | # Add only on use 23 | # user_clip_norm: -1 # User-DP clip norm (layer wise) 24 | # user_dp_sigma: -1 # User-DP noise sigma 25 | -------------------------------------------------------------------------------- /fade/config/server/FedAvg.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: FedAvg 3 | alg: FedAvg # server alg 4 | num_users: 20 # Number of Users per global round 5 | beta: 1. # Average moving parameter for pFedMe, or Second learning rate of Per-FedAvg 6 | share_mode: all 7 | privacy: 8 | enable: false 9 | # Add only on use 10 | # user_clip_norm: -1 # User-DP clip norm (layer wise) 11 | # user_dp_sigma: -1 # User-DP noise sigma 12 | -------------------------------------------------------------------------------- /fade/config/server/FedLocal.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: FedLocal 3 | alg: FedAvg # server alg 4 | num_users: 20 # Number of Users per global round 5 | beta: 1. # Average moving parameter for pFedMe, or Second learning rate of Per-FedAvg 6 | share_mode: private 7 | privacy: 8 | enable: false 9 | # Add only on use 10 | # user_clip_norm: -1 # User-DP clip norm (layer wise) 11 | # user_dp_sigma: -1 # User-DP noise sigma 12 | -------------------------------------------------------------------------------- /fade/config/user/generic.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: generic 3 | loss: xent 4 | optimizer: 5 | name: sgd 6 | learning_rate: 0.005 # Local learning rate 7 | personal_learning_rate: 0.09 # Persionalized learning rate to caculate theta aproximately using K steps 8 | privacy: 9 | enable: false 10 | # user_clip_norm: -1 # User-DP clip norm (layer wise) 11 | # user_dp_sigma: -1 # User-DP noise sigma 12 | batch_size: 20 13 | local_epochs: 20 14 | lamda: 0 # Regularization term 15 | K: 5 # Computation steps on local fine-tune -------------------------------------------------------------------------------- /fade/config/user/group_adv.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: group_adv 3 | loss: xent 4 | group_loss: bce # bce | sq_bce | xent | wd (Wasserstein distance) | dib | cdan 5 | adv_lambda: 1. # coef for adv loss. 
6 | optimizer: 7 | name: sgd_sch 8 | learning_rate: 0.01 # Local learning rate 9 | personal_learning_rate: 0.09 # Persionalized learning rate to caculate theta aproximately using K steps 10 | privacy: 11 | enable: false 12 | # user_clip_norm: -1 # User-DP clip norm (layer wise) 13 | # user_dp_sigma: -1 # User-DP noise sigma 14 | batch_size: 32 15 | local_epochs: 20 16 | lamda: 0 # Regularization term 17 | K: 5 # Computation steps on local fine-tune -------------------------------------------------------------------------------- /fade/config/user/group_adv_office_uda.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: group_adv 3 | loss: sxent 4 | group_loss: bce # bce | sq_bce | xent | wd (Wasserstein distance) | dib | cdan 5 | adv_lambda: 1. # coef for adv loss. 6 | optimizer: 7 | name: sgd_sch 8 | learning_rate: 0.01 # Local learning rate 9 | personal_learning_rate: 0.09 # Persionalized learning rate to caculate theta aproximately using K steps 10 | privacy: 11 | enable: false 12 | # user_clip_norm: -1 # User-DP clip norm (layer wise) 13 | # user_dp_sigma: -1 # User-DP noise sigma 14 | batch_size: 32 15 | local_epochs: -1 16 | total_local_epochs: 10 17 | relabel_coef: 0. 18 | # useless 19 | lamda: 0 # Regularization term 20 | K: 5 # Computation steps on local fine-tune -------------------------------------------------------------------------------- /fade/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | import torch 4 | import numpy as np 5 | import logging 6 | from omegaconf import DictConfig, OmegaConf 7 | from fade.utils import hash_config 8 | from fade.file import FileManager 9 | 10 | from typing import TYPE_CHECKING 11 | if TYPE_CHECKING: 12 | from typing import Dict 13 | 14 | 15 | class FedDataset(object): 16 | """Multi-domain federated dataset.""" 17 | def __init__(self, cfg): 18 | self.cfg = cfg 19 | self.rng = None # Random State for generating 20 | 21 | def get_hash_name(self): 22 | return "/".join([self.cfg.name, hash_config(self.cfg, exclude_keys="viz")]) 23 | 24 | def generate(self): 25 | print("==" * 20) 26 | print(f"/// Generating {self.cfg.name} dataset ///") 27 | if hasattr(self.cfg, "seed"): 28 | self.rng = np.random.RandomState(self.cfg.seed) 29 | fed_dict = self._generate() 30 | self.save(fed_dict) 31 | print("==" * 20) 32 | 33 | def _generate(self) -> Dict: 34 | """Generate federated dataset. 35 | 36 | Returns: 37 | fed_dict: A dict consists of federated data. 
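                Typically keyed by subset name ("train", "test"); each subset dict
                holds "users", "user_data" and "num_samples" entries (see
                FedExtDataset.construct_fed_dict in fade/data/federalize.py).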
38 | """ 39 | raise NotImplementedError() 40 | 41 | def viz(self): 42 | """Visualize to explore data.""" 43 | raise NotImplementedError() 44 | 45 | def exist(self, subset): 46 | """Return true if the `subset`.pt file exists.""" 47 | root_path = FileManager.data(os.path.join(self.get_hash_name(), subset), is_dir=True, overwrite=False, create_new=False) 48 | file_path = os.path.join(root_path, f"{subset}.pt") 49 | if not os.path.exists(file_path): 50 | print(f"Not found: {file_path}") 51 | return False 52 | else: 53 | return True 54 | 55 | def save(self, fed_dict): 56 | for subset in fed_dict: 57 | root_path = FileManager.data(os.path.join(self.get_hash_name(), subset), is_dir=True, overwrite=True) 58 | file_path = os.path.join(root_path, f"{subset}.pt") 59 | 60 | # with open(file_path, 'wb') as outfile: 61 | print(f"Dumping {subset} data => {file_path}") 62 | torch.save(fed_dict[subset], file_path) 63 | 64 | def load(self, generate_if_not_exist=False, subsets=["train", "test"]): 65 | """Load fed dict.""" 66 | # check existence. 67 | for subset in ("train", "test"): 68 | ex = self.exist(subset) 69 | if not ex: 70 | print(f"NOT found {subset} for {self.cfg.name} dataset.") 71 | if generate_if_not_exist: 72 | print(f"\n====== Regenerate =====") 73 | self.generate() 74 | print(f"====== Generation Finished =====") 75 | else: 76 | raise FileNotFoundError(f"{subset} for {self.cfg.name}") 77 | print() 78 | # Loading 79 | fed_dict = {} 80 | for subset in subsets: 81 | root_path = FileManager.data(os.path.join(self.get_hash_name(), subset), 82 | is_dir=True, overwrite=True) 83 | file_path = os.path.join(root_path, f"{subset}.pt") 84 | 85 | # with open(file_path, 'rb') as f: 86 | print(f"Load {subset} data <= {file_path}") 87 | fed_dict[subset] = torch.load(file_path) 88 | return fed_dict 89 | 90 | def getLogger(self, fname, subset, root_path=None, logger_name=None): 91 | if logger_name is None: 92 | logger_name = self.__class__.__name__ 93 | if root_path is None: 94 | log_fname = FileManager.data(os.path.join(self.get_hash_name(), subset, fname), is_dir=False, 95 | overwrite=True) 96 | else: 97 | log_fname = os.path.join(root_path, fname) 98 | print(f"Detail log to file: {log_fname}") 99 | logger = logging.getLogger(logger_name) 100 | logger.setLevel(logging.DEBUG) 101 | fh = logging.FileHandler(filename=log_fname, mode='w') 102 | fh.setFormatter(logging.Formatter('[%(asctime)s - %(levelname)s -' 103 | ' %(filename)s:%(funcName)s] %(message)s')) 104 | logger.addHandler(fh) 105 | logger.propagate = False # Set False to disable stdout print. 
106 | return logger, log_fname 107 | 108 | 109 | def read_fed_dataset(cfg: DictConfig): 110 | # if cfg.name == "comb/MnistM": 111 | if cfg.name.startswith("comb/"): 112 | from .multi_domain import MDFedDataset as FedDataset 113 | elif cfg.name in ("Mnist", "MnistM", "SVHN", "USPS") or cfg.name.startswith("ReviewBow") \ 114 | or cfg.name.startswith("ReviewTok") \ 115 | or cfg.name.startswith("Office31") or cfg.name.startswith("OfficeHome65")\ 116 | or cfg.name.startswith("DomainNet"): 117 | from .federalize import FedExtDataset as FedDataset 118 | else: 119 | raise ValueError(f"Unknown data: {cfg.name} with config: \n{OmegaConf.to_yaml(cfg)}") 120 | fed_dict = FedDataset(cfg).load(generate_if_not_exist=True) 121 | 122 | groups = fed_dict["train"]["hierarchies"] if "hierarchies" in fed_dict["train"] else [] 123 | fed_dict["train"]["hierarchies"] = groups 124 | fed_dict["test"]["hierarchies"] = groups 125 | assert fed_dict["train"]['users'] == fed_dict["test"]['users'] 126 | 127 | return fed_dict 128 | -------------------------------------------------------------------------------- /fade/data/federalize.py: -------------------------------------------------------------------------------- 1 | """Transform central dataset into the format of federated. 2 | 3 | Available datasets: 4 | DigitFive: Mnist | MnistM | SVHN | USPS | SynDigit 5 | Example: python -m fade.data.extend -cn DigitFive name=SVHN n_class=5 n_user=20 6 | """ 7 | import logging 8 | import os 9 | import hydra 10 | import numpy as np 11 | import torch 12 | import torchvision.transforms as transforms 13 | from torch.utils.data import DataLoader, Subset, TensorDataset 14 | 15 | from fade.file import FileManager 16 | from . import FedDataset 17 | from .meta import load_meta_dataset 18 | from .utils import basic_stat_fed_dict 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class FedExtDataset(FedDataset): 24 | """Extend meta dataset (formatted by pytorch Dataset) to federated. 25 | 26 | To add new dataset: 27 | 1. Update `fade.data.meta.load_meta_dataset` with the data name and configures. 28 | 2. Add the data name entry to `read_fed_dataset` in fade.data.__init__. 29 | """ 30 | def _generate(self): 31 | print("Number of users: {}".format(self.cfg.n_user)) 32 | print("Number of classes: {}".format(self.cfg.n_class)) 33 | 34 | print(f"=== Reading source dataset ===") 35 | _root_name = self.cfg.name 36 | if _root_name.lower().startswith("emnist"): 37 | _root_name = _root_name[:-1] 38 | dataset_root_path = FileManager.data(os.path.join(_root_name, "data"), is_dir=True) 39 | 40 | test_set, train_set, _est_max_n_sample_per_shard, _min_n_sample_per_shard = \ 41 | load_meta_dataset(self.cfg, self.cfg.name, dataset_root_path) 42 | # FIXME This is ad-hoc. Set the default value in cfg. 
43 | if self.cfg.max_n_sample_per_share is None: 44 | self.cfg.max_n_sample_per_share = _est_max_n_sample_per_shard 45 | if self.cfg.min_n_sample_per_share is None: 46 | self.cfg.min_n_sample_per_share = _min_n_sample_per_shard 47 | 48 | print("\n=== Processing training set ===") 49 | _, SRC_N_CLASS, train_idx_by_class = preprocess_dataset(train_set) 50 | 51 | # n_test_sample, train_idx_by_class, test_idx_by_class, SRC_N_CLASS = \ 52 | # preprocess_dataset(testset, trainset) 53 | assert SRC_N_CLASS >= self.cfg.n_class, \ 54 | f"Found N_CLASS_PER_USER={self.cfg.n_class} larger than SRC_N_CLASS={SRC_N_CLASS}" 55 | 56 | class_by_user = split_classes_by_user(SRC_N_CLASS, self.cfg.n_user, 57 | self.cfg.n_class, self.cfg.class_stride) 58 | 59 | # Split class into user shares 60 | classes, n_share_by_class = np.unique(class_by_user, return_counts=True) 61 | if len(classes) != SRC_N_CLASS: 62 | logger.warning(f"After user class splitting, only {len(classes)} are used which " 63 | f"is not equal to total {SRC_N_CLASS}.") 64 | 65 | partitioner = Partitioner( 66 | self.rng, partition_mode=self.cfg.partition_mode, 67 | max_n_sample_per_share=self.cfg.max_n_sample_per_share, 68 | min_n_sample_per_share=self.cfg.min_n_sample_per_share, 69 | max_n_sample=self.cfg.max_n_sample_per_class) 70 | train_idx_by_user = self.split_data(n_share_by_class, train_idx_by_class, 71 | class_by_user, partitioner) 72 | 73 | print("\n=== Processing test set ===") 74 | n_test_sample, test_SRC_N_CLASS, test_idx_by_class = preprocess_dataset(test_set) 75 | # SRC_N_CLASS may not equal test_SRC_N_CLASS 76 | assert SRC_N_CLASS == test_SRC_N_CLASS 77 | assert test_SRC_N_CLASS >= self.cfg.n_class, \ 78 | f"Found N_CLASS_PER_USER={self.cfg.n_class} larger than SRC_N_CLASS={test_SRC_N_CLASS}" 79 | 80 | # FIXME for test set, n_test_sample=-1 ==> Use all samples, Otherwise, randomly select. 
81 | assert not hasattr(self.cfg, "n_test_sample"), "Not support cfg: n_test_sample" 82 | partitioner = Partitioner(self.rng, partition_mode="uni", 83 | max_n_sample=-1, max_n_sample_per_share=-1, 84 | min_n_sample_per_share=2) 85 | test_idx_by_user = self.split_data(n_share_by_class, test_idx_by_class, 86 | class_by_user, partitioner) 87 | 88 | # Create data structure 89 | print("\n=== Construct data dict ===") 90 | fed_dict = { 91 | 'train': self.construct_fed_dict(train_set, train_idx_by_user, is_train=True), 92 | 'test': self.construct_fed_dict(test_set, test_idx_by_user, is_train=False), 93 | } 94 | 95 | for subset in ["train", "test"]: 96 | print(f"{subset.upper()} #sample by user:", fed_dict['train']['num_samples']) 97 | simple_stat(fed_dict['train']['num_samples']) 98 | print("Total_samples:", sum(fed_dict['train']['num_samples'] + fed_dict['test']['num_samples']), "TRAIN", 99 | sum(fed_dict['train']['num_samples']), "TEST", sum(fed_dict['test']['num_samples'])) 100 | 101 | return fed_dict 102 | 103 | def viz(self, do=True, subset="train", user_idx=0, title=''): 104 | print(f"Analysis of dataset: {self.cfg.name}") 105 | fed_dict = self.load(generate_if_not_exist=True) 106 | dataset = fed_dict[subset] 107 | print(f"== Basic stat of {subset} set ==") 108 | basic_stat_fed_dict(dataset, verbose=True) 109 | 110 | import matplotlib.pyplot as plt 111 | from .utils import plot_sample_size_dist 112 | ax = plot_sample_size_dist(fed_dict, subset) 113 | ax.set(title=f"{self.cfg.name} {subset}") 114 | plt.tight_layout() 115 | plt.show() 116 | 117 | if hasattr(self.cfg, 'feature_type') and self.cfg.feature_type == "images": 118 | # Plot the images of one user. 119 | user = dataset['users'][user_idx] 120 | from .utils import grid_imshow 121 | if self.cfg.user_data_format == "index": 122 | ds = Subset(dataset['dataset'], dataset['user_data'][user]['idx']) 123 | class_name_fh = lambda i: dataset['dataset'].classes[targets[i].numpy()] 124 | elif self.cfg.user_data_format == "dataset": 125 | ds = dataset['user_data'][user]['dataset'] 126 | classes = ds.dataset.classes if isinstance(ds, Subset) else ds.classes 127 | class_name_fh = lambda i: classes[targets[i].numpy()] 128 | elif self.cfg.user_data_format == "tensor": 129 | ds = TensorDataset(dataset['user_data'][user]['x'], dataset['user_data'][user]['y']) 130 | class_name_fh = lambda i: targets[i] 131 | else: 132 | raise ValueError(f"self.cfg.user_data_format: {self.cfg.user_data_format}") 133 | # print(dataset['dataset'].targets[dataset['user_data'][user]['idx']]) 134 | loader = DataLoader(ds, batch_size=16, shuffle=True) 135 | 136 | imgs, targets = next(iter(loader)) 137 | grid_imshow(imgs, normalize=True, title=class_name_fh) 138 | plt.show() 139 | 140 | def split_data(self, n_share_by_class, train_idx_by_class, class_by_user, 141 | partitioner): 142 | """The train_idx_by_class will be split into user according to class_by_user and 143 | n_share_by_class. Partitioned by partitioner.""" 144 | n_user = len(class_by_user) 145 | data_idx_by_user = [[] for _ in range(n_user)] 146 | print(f" # of shards for each class: {n_share_by_class}") 147 | for cl in range(len(n_share_by_class)): 148 | n_share = n_share_by_class[cl] 149 | data_idxs = train_idx_by_class[cl] 150 | n_smp = len(data_idxs) 151 | print(f" Split {n_smp} samples of class {cl} into {n_share} shares.") 152 | 153 | # Split data of class. 
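            # Illustrative example (not from a real run): with partition_mode="uni"
            # and min_n_sample_per_share=2, partitioner(100, 3) yields [33, 33, 34]:
            # each share gets the 2-sample floor plus an even split of the remainder.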
154 | partitions = partitioner(n_smp, n_share) 155 | simple_stat(partitions) 156 | end_idxs = [0] + np.cumsum(partitions).tolist() 157 | 158 | # Assign shares to users. 159 | self.rng.shuffle(data_idxs) # in-place 160 | i_share = 0 161 | for user in range(n_user): 162 | if cl in class_by_user[user]: 163 | start_i = end_idxs[i_share] 164 | end_i = end_idxs[i_share+1] 165 | data_idx_by_user[user].extend(data_idxs[start_i:end_i]) 166 | i_share += 1 167 | assert i_share == n_share, f"Share is not fully used. Generate {n_share} shares, " \ 168 | f"but only {i_share} shares are used." 169 | if end_i < len(data_idxs): 170 | logger.warning(f" Use {end_i} out of {len(data_idxs)} samples. Total {len(data_idxs) - end_i} samples are droped.") 171 | return data_idx_by_user 172 | 173 | def construct_fed_dict(self, dataset, data_idxs_by_user, is_train): 174 | data_dict = {'users': [], 'user_data': {}, 'num_samples': []} 175 | if self.cfg.user_data_format == "tensor": 176 | load_all_data_to_tensor(dataset) 177 | elif self.cfg.user_data_format == "index": 178 | data_dict['dataset'] = dataset 179 | elif self.cfg.user_data_format == "dataset": 180 | pass 181 | else: 182 | raise ValueError(f"self.cfg.user_data_format: {self.cfg.user_data_format}") 183 | 184 | if hasattr(self.cfg, 'niid_distort_train') and self.cfg.niid_distort_train: 185 | assert self.cfg.user_data_format == "dataset", "Only support niid_distort_train " \ 186 | "when user_data_format=dataset" 187 | for i in range(self.cfg.n_user): 188 | uname = 'f_{0:05d}'.format(i) 189 | 190 | data_dict['users'].append(uname) 191 | if self.cfg.user_data_format == "index": 192 | data_dict['user_data'][uname] = { 193 | 'idx': data_idxs_by_user[i] 194 | } 195 | elif self.cfg.user_data_format == "dataset": 196 | if is_train and hasattr(self.cfg, 'niid_distort_train') and self.cfg.niid_distort_train: 197 | # Rearrange distort transform in non-iid manner 198 | from copy import deepcopy 199 | from .meta.distort import DistortTransform, MultiDistortTransform,\ 200 | PRESET_DISTORTION_SETS 201 | from .meta.office_caltech import DistortPathMaker 202 | ds = deepcopy(dataset) 203 | ts = [] 204 | if isinstance(self.cfg.distort_train, str) \ 205 | and self.cfg.distort_train in PRESET_DISTORTION_SETS: 206 | distort_train = PRESET_DISTORTION_SETS[self.cfg.distort_train] 207 | else: 208 | distort_train = self.cfg.distort_train 209 | assert self.cfg.n_user == len(distort_train), \ 210 | f"Not enough distort methods for {self.cfg.n_user} users. " \ 211 | f"All distort: {distort_train}." 
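                    # The loop below rebuilds the transform pipeline for user i,
                    # swapping the shared distortion transform for this user's own
                    # distortion type, so the corruption becomes part of the non-iid split.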
212 | for t in ds.transform.transforms: 213 | # replace 214 | if isinstance(t, MultiDistortTransform): 215 | t = DistortTransform(distort_train[i], t.severity) 216 | elif isinstance(t, DistortPathMaker): 217 | t = DistortPathMaker(distort_train[i], t.severity) 218 | ts.append(t) 219 | ds.transform = transforms.Compose(ts) 220 | else: 221 | ds = dataset 222 | data_dict['user_data'][uname] = { 223 | 'dataset': Subset(ds, data_idxs_by_user[i]) 224 | } 225 | elif self.cfg.user_data_format == "tensor": 226 | idxs = data_idxs_by_user[i] 227 | data_dict['user_data'][uname] = { 228 | 'x': torch.tensor(dataset.data[idxs], dtype=torch.float32), 229 | 'y': torch.tensor(dataset.targets[idxs], dtype=torch.int64) 230 | } 231 | else: 232 | raise ValueError(f"self.cfg.user_data_format: {self.cfg.user_data_format}") 233 | data_dict['num_samples'].append(len(data_idxs_by_user[i])) 234 | return data_dict 235 | 236 | 237 | def split_classes_by_user(total_n_class, n_user, n_class_per_user, class_stride, mode="seq"): 238 | """ 239 | 240 | Args: 241 | total_n_class (): 242 | n_user (): 243 | n_class_per_user (): 244 | class_stride (): 245 | mode (): 246 | 247 | Returns: 248 | user_classes is a list where user_classes[i] is a list of classes for user i. 249 | """ 250 | print(f"Split {total_n_class} classes into {n_user} users. ") 251 | if mode == "seq": 252 | print(f" MODE: {mode}") 253 | print(f" {n_class_per_user} classes per user and {class_stride} stride.") 254 | user_classes = [[] for _ in range(n_user)] 255 | 256 | for user in range(n_user): 257 | for j in range(n_class_per_user): 258 | l = (user * class_stride + j) % total_n_class 259 | user_classes[user].append(l) 260 | # TODO flatten user_classes 261 | elif mode == "random": 262 | # Randomly assign some classes 263 | raise NotImplementedError() 264 | else: 265 | raise RuntimeError(f"Unknown mode: {mode}") 266 | print(f"Classes by user") 267 | for user in range(n_user): 268 | print(f" user {user}: {user_classes[user]}") 269 | return user_classes 270 | 271 | 272 | def preprocess_dataset(dataset): 273 | """Get data indexes for each class and check the sample shape.""" 274 | n_sample = len(dataset) 275 | n_class = len(dataset.classes) 276 | data_idx_by_class = rearrange_dataset_by_class(dataset, n_class) 277 | 278 | smp = dataset[0][0] 279 | if not isinstance(smp, torch.Tensor): 280 | if isinstance(smp, np.ndarray): 281 | smp = torch.from_numpy(smp) 282 | else: 283 | smp = transforms.ToTensor()(smp) 284 | print(f" Total #samples: {n_sample}. sample shape: {smp.shape}") 285 | print(" #samples per class:\n", [len(v) for v in data_idx_by_class]) 286 | return n_sample, n_class, data_idx_by_class 287 | 288 | 289 | def rearrange_dataset_by_class(dataset, n_class): 290 | """Get data indexes for each class""" 291 | data_by_class = [[] for _ in range(n_class)] 292 | for i, y in enumerate(dataset.targets): 293 | data_by_class[y].append(i) 294 | return data_by_class 295 | 296 | 297 | class Partitioner: 298 | """Partition a sequence into shares.""" 299 | def __init__(self, rng, partition_mode="dir", 300 | max_n_sample_per_share=10, 301 | min_n_sample_per_share=2, 302 | max_n_sample=-1 303 | ): 304 | self.rng = rng 305 | self.partition_mode = partition_mode 306 | self.max_n_sample_per_share = max_n_sample_per_share 307 | self.min_n_sample_per_share = min_n_sample_per_share 308 | self.max_n_sample = max_n_sample 309 | 310 | def __call__(self, n_sample, n_share): 311 | """Partition a sequence of `n_sample` into `n_share` shares. 
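        Args:
            n_sample: Total number of samples to split (possibly reduced to
                `max_n_sample` or `n_share * max_n_sample_per_share` if those caps are set).
            n_share: Number of shares to produce.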
312 | Returns: 313 | partition: A list of num of samples for each share. 314 | """ 315 | print(f"{n_sample} samples => {n_share} shards by {self.partition_mode} distribution.") 316 | if self.max_n_sample > 0: 317 | n_sample = min((n_sample, self.max_n_sample)) 318 | if self.max_n_sample_per_share > 0: 319 | n_sample = min((n_sample, n_share * self.max_n_sample_per_share)) 320 | 321 | n_sample -= self.min_n_sample_per_share * n_share 322 | if self.partition_mode == "dir": 323 | partition = (self.rng.dirichlet(n_share * [1]) * n_sample).astype(int) 324 | elif self.partition_mode == "uni": 325 | partition = int(n_sample // n_share) * np.ones(n_share, dtype='int') 326 | else: 327 | raise ValueError(f"Invalid partition_mode: {self.partition_mode}") 328 | 329 | partition[-1] += n_sample - np.sum(partition) # add residual 330 | assert sum(partition) == n_sample, f"{sum(partition)} != {n_sample}" 331 | partition = partition + self.min_n_sample_per_share 332 | n_sample += self.min_n_sample_per_share * n_share 333 | # partition = np.minimum(partition, max_n_sample_per_share) 334 | partition = partition.tolist() 335 | 336 | assert sum(partition) == n_sample, f"{sum(partition)} != {n_sample}" 337 | assert len(partition) == n_share, f"{len(partition)} != {n_share}" 338 | return partition 339 | 340 | 341 | def simple_stat(arr): 342 | res = {} 343 | for metric in ("mean", "std", "max", "min", "median"): 344 | res[metric] = eval(f'np.{metric}(arr)') 345 | print(f" {metric}: {res[metric]:.4g}", end=", ") 346 | print("") 347 | return res 348 | 349 | 350 | def load_all_data_to_tensor(dataset): 351 | data_loader = DataLoader(dataset, batch_size=len(dataset), drop_last=False, 352 | shuffle=False) 353 | for x, y in data_loader: 354 | dataset.data, dataset.targets = x, y 355 | return dataset 356 | 357 | 358 | @hydra.main(config_name="Office31Av1", config_path="../config/dataset") 359 | def main(cfg): 360 | ds = FedExtDataset(cfg) 361 | if hasattr(cfg, "viz") and cfg.viz.do: 362 | # ds.generate() 363 | ds.viz(**cfg.viz) 364 | else: 365 | ds.generate() 366 | 367 | 368 | if __name__ == '__main__': 369 | main() 370 | -------------------------------------------------------------------------------- /fade/data/meta/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | from torchvision import transforms 4 | from fade.file import FileManager 5 | from omegaconf import ListConfig 6 | 7 | from typing import TYPE_CHECKING 8 | if TYPE_CHECKING: 9 | from omegaconf import OmegaConf 10 | from typing import Union, Tuple 11 | from torch.utils.data import Dataset 12 | 13 | 14 | def load_meta_dataset(cfg: OmegaConf, data_name: str, dataset_root_path: str) \ 15 | -> Tuple[Dataset, Dataset, int, int]: 16 | """Load meta dataset 17 | 18 | Args: 19 | cfg: omega config object 20 | data_name: The data name 21 | dataset_root_path: Root to where the data (data_name) is stored. 22 | 23 | Returns: 24 | testset, trainset, est_max_n_sample_per_shard, min_n_sample_per_shard 25 | testset, trainset are torch Dataset objects. 
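        est_max_n_sample_per_shard and min_n_sample_per_shard are per-dataset
        defaults; FedExtDataset._generate falls back to them when
        max_n_sample_per_share / min_n_sample_per_share are left unset (None)
        in the dataset config.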
26 | """ 27 | est_max_n_sample_per_shard = -1 28 | min_n_sample_per_shard = 2 29 | if data_name.lower() == "mnist": 30 | from torchvision.datasets import MNIST 31 | 32 | nc = cfg.n_channel 33 | if nc == 3: 34 | from ..utils import GrayscaleToRgb 35 | # trans = [transforms.Lambda(lambda x: x.convert("RGB"))] 36 | trans = [GrayscaleToRgb()] 37 | else: 38 | assert nc == 1 39 | trans = [] 40 | # if cfg.binarize: # used in DANN experiments. But need to change normalize 41 | # trans += [ToBinary()] 42 | if hasattr(cfg, 'resize'): # to match SVHN 32 43 | assert cfg.resize > 0 44 | trans += [transforms.Resize(cfg.resize)] 45 | trans += [ 46 | # GrayscaleToRgb(), 47 | transforms.ToTensor(), 48 | # TODO when construst MnistM, we also need this. 49 | # ToBinary(scale=torch.tensor(1.)), # Used in DANN (official code, https://github.com/pumpikano/tf-dann/blob/master/MNIST-DANN.ipynb 50 | # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 51 | transforms.Normalize([.5], [.5]) 52 | ] 53 | trans = transforms.Compose(trans) 54 | testset = MNIST(root=dataset_root_path, train=False, transform=trans, download=True) 55 | trainset = MNIST(root=dataset_root_path, train=True, transform=trans, download=True) 56 | 57 | # %% Configure %% 58 | est_max_n_sample_per_shard = 10 59 | min_n_sample_per_shard = 2 60 | elif data_name.lower().startswith("usps"): # NOTE: place this before mnist. 61 | from .usps import USPS 62 | 63 | assert cfg.n_channel in [1, 3], \ 64 | f"Invalid n_channel: {cfg.n_channel}. Expected 1 or 3." 65 | trans = [] 66 | if cfg.n_channel == 3: 67 | from ..utils import GrayscaleToRgb 68 | trans += [GrayscaleToRgb()] 69 | if cfg.random_crop_rot: 70 | trans += [ 71 | transforms.RandomCrop(28, padding=4), 72 | transforms.RandomRotation(10), 73 | ] 74 | trans += [ 75 | # GrayscaleToRgb(), 76 | transforms.ToTensor(), 77 | transforms.Normalize([0.5], [0.5]) 78 | ] 79 | trans = transforms.Compose(trans) 80 | trainset = USPS(root=dataset_root_path, train=True, transform=trans, download=True) 81 | 82 | trans = [] 83 | if cfg.n_channel == 3: 84 | from ..utils import GrayscaleToRgb 85 | trans += [GrayscaleToRgb()] 86 | trans += [ 87 | # GrayscaleToRgb(), 88 | transforms.ToTensor(), 89 | transforms.Normalize([0.5], [0.5]) 90 | ] 91 | trans = transforms.Compose(trans) 92 | testset = USPS(root=dataset_root_path, train=False, transform=trans, download=True) 93 | trainset.classes = list(range(10)) 94 | testset.classes = list(range(10)) 95 | 96 | # %% Configure %% 97 | est_max_n_sample_per_shard = 10 98 | min_n_sample_per_shard = 2 99 | elif data_name.lower().startswith("office"): 100 | from .office_caltech import get_office_caltech_dataset 101 | 102 | if data_name.lower().startswith("OfficeCal10".lower()): 103 | source = "OfficeCaltech10" 104 | elif data_name.lower().startswith("Office31".lower()): 105 | source = "office31" 106 | elif data_name.lower().startswith("OfficeHome65".lower()): 107 | source = "OfficeHome65" 108 | else: 109 | raise ValueError(f"data_name: {data_name}") 110 | 111 | # NOTE: DO NOT use a different domain as test domain. Otherwise, you 112 | # will see a drop in the loss after each test (evaluation). 
113 | if source == "OfficeHome65": 114 | if hasattr(cfg, 'domain') and cfg.domain != "default": 115 | domain = cfg.domain 116 | test_domain = domain 117 | else: 118 | if data_name[-1].lower() == "a": 119 | domain = "Art" 120 | test_domain = "Clipart" 121 | elif data_name[-1].lower() == "c": 122 | domain = "Clipart" 123 | test_domain = "Art" 124 | elif data_name[-1].lower() == "p": 125 | domain = "Product" 126 | test_domain = "Art" 127 | elif data_name[-1].lower() == "r": 128 | domain = "RealWorld" 129 | test_domain = "Art" 130 | else: 131 | raise ValueError(f"data_name: {data_name}") 132 | else: # Office31 133 | if hasattr(cfg, 'domain') and cfg.domain != "default": 134 | domain = cfg.domain 135 | test_domain = domain 136 | else: 137 | if data_name[-1].lower() == "a": 138 | domain = "amazon" 139 | test_domain = "webcam" 140 | elif data_name[-1].lower() == "d": 141 | domain = "dslr" 142 | test_domain = "webcam" 143 | elif data_name[-1].lower() == "w": 144 | domain = "webcam" 145 | test_domain = "dslr" 146 | elif source == "OfficeCaltech10" and data_name[-1].lower() == "c": 147 | domain = "caltech10" 148 | test_domain = "amazon" 149 | else: 150 | raise ValueError(f"data_name: {data_name}") 151 | 152 | if hasattr(cfg, "ood_test_domain"): # out-of-distribution test 153 | if isinstance(cfg.ood_test_domain, bool): 154 | if not cfg.ood_test_domain: 155 | test_domain = domain 156 | else: 157 | assert isinstance(cfg.ood_test_domain, str) 158 | if cfg.ood_test_domain == "self": 159 | test_domain = domain 160 | elif cfg.ood_test_domain == "default": 161 | pass 162 | else: 163 | test_domain = cfg.ood_test_domain 164 | else: 165 | test_domain = domain 166 | if cfg.feature_type == "images": 167 | # standard AlexNet and ResNet101 preprocessing. 168 | # Refer to 169 | # - AlexNet: https://pytorch.org/hub/pytorch_vision_alexnet/ 170 | # - ResNet: https://pytorch.org/hub/pytorch_vision_resnet/ 171 | train_data_kwargs = {} 172 | train_trans_ = [ 173 | transforms.Resize((256, 256)), 174 | transforms.RandomCrop(224), 175 | transforms.RandomHorizontalFlip() 176 | ] 177 | if hasattr(cfg, "distort_train") and cfg.distort_train != "none": 178 | raise NotImplementedError(f"Not support distortion anymore.") 179 | train_trans_ += [ 180 | transforms.ToTensor(), 181 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 182 | ] 183 | train_trans = transforms.Compose(train_trans_) 184 | 185 | test_data_kwargs = {} 186 | test_trans_ = [ 187 | transforms.Resize((256, 256)), 188 | transforms.CenterCrop(224), 189 | ] 190 | if hasattr(cfg, "distort") and cfg.distort != "none" and cfg.severity > 0: 191 | raise NotImplementedError(f"Not support distortion anymore.") 192 | test_trans_ += [ 193 | transforms.ToTensor(), 194 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 195 | ] 196 | test_trans = transforms.Compose(test_trans_) 197 | 198 | trainset = get_office_caltech_dataset(source=source, domain=domain, 199 | transform=train_trans, **train_data_kwargs) 200 | testset = get_office_caltech_dataset(source=source, domain=test_domain, 201 | transform=test_trans, **test_data_kwargs) 202 | else: 203 | trainset = get_office_caltech_dataset(source=source, domain=domain, 204 | feature_type=cfg.feature_type) 205 | testset = get_office_caltech_dataset(source=source, domain=test_domain, 206 | feature_type=cfg.feature_type) 207 | elif data_name.lower().startswith("adult"): 208 | from .adult import Adult 209 | root = FileManager.data('adult', is_dir=True) 210 | if 
data_name.lower().endswith("w"): 211 | group = "white" 212 | group_by = "white_black" 213 | elif data_name.lower().endswith("b"): 214 | group = "black" 215 | group_by = "white_black" 216 | elif data_name.lower().endswith("f"): 217 | group = "female" 218 | group_by = "gender" 219 | elif data_name.lower().endswith("m"): 220 | group = "male" 221 | group_by = "gender" 222 | else: 223 | raise ValueError(f"data_name : {data_name}") 224 | trainset = Adult(root, train=True, group_by=group_by, group=group) 225 | testset = Adult(root, train=False, group_by=group_by, group=group) 226 | else: 227 | raise ValueError(f"Unknown data name: {data_name}") 228 | return testset, trainset, est_max_n_sample_per_shard, min_n_sample_per_shard 229 | -------------------------------------------------------------------------------- /fade/data/meta/adult.py: -------------------------------------------------------------------------------- 1 | """Adult dataset. 2 | 3 | References 4 | https://github.com/jctaillandier/adult_neuralnet/blob/master/adult_nn.ipynb 5 | https://github.com/htwang14/fairness/blob/main/dataloaders/adult.py 6 | """ 7 | import gzip 8 | import os 9 | import pickle 10 | import urllib 11 | 12 | import numpy as np 13 | import pickle as pk 14 | import torch 15 | import torch.utils.data as data 16 | from torch.utils.data.sampler import WeightedRandomSampler 17 | from torchvision import datasets, transforms 18 | 19 | import pandas as pd 20 | from sklearn.preprocessing import LabelEncoder 21 | from sklearn.preprocessing import OneHotEncoder 22 | from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler 23 | from torchvision import datasets, transforms 24 | from torch.utils.data import DataLoader, Subset 25 | 26 | import time, os, random 27 | from tqdm import tqdm 28 | 29 | 30 | class Adult(data.Dataset): 31 | """Adult Dataset. 32 | 33 | Args: 34 | root (string): Root directory of dataset where dataset file exist. 35 | train (bool, optional): If True, resample from dataset randomly. 36 | download (bool, optional): If true, downloads the dataset 37 | from the internet and puts it in root directory. 38 | If dataset is already downloaded, it is not downloaded again. 39 | transform (callable, optional): A function/transform that takes in 40 | an PIL image and returns a transformed version. 41 | E.g, ``transforms.RandomCrop`` 42 | group (str): Rely on group_by 43 | group_by (str): white_black 44 | """ 45 | 46 | url = "https://www.kaggle.com/wenruliu/adult-income-dataset" 47 | classes = ['<=50K', '>50K'] 48 | 49 | def __init__(self, root, train=True, transform=None, group='white', create_new=True, 50 | group_by='white_black'): 51 | # init params 52 | self.root = root 53 | self.train = train 54 | subset = "train" if train else "test" 55 | self.filename = f"{subset}_{group}.npy" 56 | self.adult_csv_filename = "adult.csv" 57 | # Num of Train = , Num ot Test 58 | self.transform = transform 59 | # self.train_size = 30000 60 | 61 | if create_new and not self._check_exists(): 62 | csv_path = os.path.join(root, self.adult_csv_filename) 63 | if not os.path.exists(csv_path): 64 | raise RuntimeError( 65 | f"File not found at {csv_path}. Download csv file from {self.url} " 66 | f"and place it at {csv_path}") 67 | prep_adult(original_csv_path=csv_path, save_path=self.root) 68 | 69 | if not self._check_exists(): 70 | filename = os.path.join(self.root, self.filename) 71 | raise RuntimeError(f"File not found at {filename}. 
Use create_new=True") 72 | 73 | self.data, self.targets = torch.load(os.path.join(self.root, self.filename)) 74 | 75 | def __getitem__(self, index): 76 | """Get images and target for data loader. 77 | Args: 78 | index (int): Index 79 | Returns: 80 | tuple: (image, target) where target is index of the target class. 81 | """ 82 | x, label = self.data[index], self.targets[index] 83 | x = torch.from_numpy(x) 84 | if self.transform is not None: 85 | x = self.transform(x) 86 | return x, label.astype("int64") 87 | 88 | def __len__(self): 89 | """Return size of dataset.""" 90 | return len(self.data) 91 | 92 | def _check_exists(self): 93 | """Check if dataset is download and in right place.""" 94 | return os.path.exists(os.path.join(self.root, self.filename)) 95 | 96 | 97 | def whiten(X, mean, std): 98 | X = X - mean 99 | X = np.divide(X, std + 1e-6) 100 | return X 101 | 102 | 103 | def prep_adult(original_csv_path="datasets/adult.csv", save_path='datasets/adult'): 104 | full_data = pd.read_csv( 105 | original_csv_path, 106 | names=[ 107 | "Age", "Workclass", "fnlwgt", "Education", "Education-Num", "Martial Status", 108 | "Occupation", "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss", 109 | "Hours per week", "Country", "Target"], 110 | sep=r'\s*,\s*', 111 | engine='python', skiprows=1, 112 | na_values="?", 113 | dtype={0: int, 1: str, 2: int, 3: str, 4: int, 5: str, 6: str, 7: str, 8: str, 9: str, 114 | 10: int, 11: int, 12: int, 13: str, 14: str}) 115 | 116 | print('Dataset size: ', full_data.shape[0]) 117 | 118 | str_list = [] 119 | for data in [full_data]: 120 | for colname, colvalue in data.iteritems(): 121 | if type(colvalue[1]) == str: 122 | str_list.append(colname) 123 | num_list = data.columns.difference(str_list) 124 | 125 | # Replace '?' 
with NaN, then delete those rows: 126 | full_size = full_data.shape[0] 127 | print('Dataset size Before pruning: ', full_size) 128 | for data in [full_data]: 129 | for i in full_data: 130 | data[i].replace('nan', np.nan, inplace=True) 131 | data.dropna(inplace=True) 132 | real_size = full_data.shape[0] 133 | 134 | print('Dataset size after pruning: ', real_size) 135 | print('We eliminated ', (full_size - real_size), ' datapoints') 136 | 137 | # Take labels out and encode them: 138 | full_labels = full_data['Target'].copy() 139 | label_encoder = LabelEncoder() 140 | full_labels = label_encoder.fit_transform(full_labels) 141 | print(f"Classes: {np.unique(full_labels)}, {label_encoder.classes_}") 142 | 143 | # Get male_idx and female_idx: 144 | print(full_data.head()) 145 | male_idx = np.array(full_data['Sex'] == 'Male') # boolean, len=45222 146 | female_idx = np.array(full_data['Sex'] == 'Female') # boolean, len=45222 147 | print('male_idx:', male_idx[0:5], male_idx.shape) 148 | print('female_idx:', female_idx[0:5], female_idx.shape) 149 | np.save(os.path.join(save_path, 'male_idx.npy'), male_idx) 150 | np.save(os.path.join(save_path, 'female_idx.npy'), female_idx) 151 | 152 | # get race idx: 153 | indian_idx = np.array(full_data['Race'] == 'Amer-Indian-Eskimo') 154 | asian_idx = np.array(full_data['Race'] == 'Asian-Pac-Islander') 155 | black_idx = np.array(full_data['Race'] == 'Black') 156 | other_idx = np.array(full_data['Race'] == 'Other') 157 | white_idx = np.array(full_data['Race'] == 'White') 158 | 159 | # Deal with categorical variables: 160 | full_data = full_data.drop(['Target'], axis=1) 161 | cat_data = full_data.select_dtypes(include=['object']).copy() 162 | other_data = full_data.select_dtypes(include=['int']).copy() 163 | print('cat_data:', cat_data.shape) # cat_data: (45222, 8) 164 | print('other_data:', other_data.shape) # other_data: (45222, 6) 165 | 166 | # Then One Hot encode other Categorical Variables: 167 | newcat_data = pd.get_dummies(cat_data, columns=[ 168 | "Workclass", "Education", "Country", "Relationship", "Martial Status", "Occupation", 169 | "Relationship", 170 | "Race", "Sex" 171 | ]) 172 | print('newcat_data:', newcat_data.shape) # newcat_data: (45222, 104) 173 | 174 | # Append all columns back together: 175 | full_data = pd.concat([other_data, newcat_data], axis=1) 176 | print('full_data:', full_data.shape) # full_data: (45222, 110) 177 | 178 | # Dataframe to npy: 179 | full_data = np.asarray(full_data).astype(np.float32) 180 | 181 | # Split and whitening: 182 | train_size = 30000 # Given 45222 datapoints, # test_size is the remainder 183 | 184 | train_x = full_data[:train_size, 185 | :] # M: train_x[i,-2:] == np.array([0,1]); F: train_x[i,-2:] == np.array([1,0]) 186 | test_x = full_data[train_size:, :] 187 | # print('train_x:', train_x.shape) 188 | # print(train_x[0:5,-2:]) 189 | 190 | mean = np.mean(train_x, axis=0) 191 | std = np.std(train_x, axis=0) 192 | # print(mean, std) 193 | train_x = whiten(train_x, mean, std) 194 | print('train_x:', 195 | train_x.shape) # M: train_x[i,-2:] == np.array([-0.69225496 , 0.692255]); F: train_x[i,-2:] == np.array([1.4445544, -1.4445543]) 196 | test_x = whiten(test_x, mean, std) 197 | print('test_x:', test_x.shape) 198 | 199 | full_data = np.concatenate([train_x, test_x], axis=0) 200 | print('full_data:', full_data.shape) 201 | print() 202 | 203 | train_labels = full_labels[:train_size] 204 | test_labels = full_labels[train_size:] 205 | 206 | # Save male and female data seperately as .npy files: 207 | train_male_idx = 
male_idx[:train_size] 208 | train_female_idx = female_idx[:train_size] 209 | test_male_idx = male_idx[train_size:] 210 | test_female_idx = female_idx[train_size:] 211 | 212 | train_male_data = train_x[train_male_idx] # train_male_data: (20281, 110) 213 | train_male_targets = train_labels[train_male_idx] # train_male_targets: (20281,) 214 | train_female_data = train_x[train_female_idx] # train_female_data: (9719, 110) 215 | train_female_targets = train_labels[train_female_idx] # train_female_targets: (9719,) 216 | print('train_male_data:', train_male_data.shape) 217 | print('train_male_targets:', train_male_targets.shape) 218 | print('train_female_data:', train_female_data.shape) 219 | print('train_female_targets:', train_female_targets.shape) 220 | print() 221 | 222 | test_male_data = test_x[test_male_idx] # test_male_data: (10246, 110) 223 | test_male_targets = test_labels[test_male_idx] # test_male_targets: (10246,) 224 | test_female_data = test_x[test_female_idx] # test_female_data: (4976, 110) 225 | test_female_targets = test_labels[test_female_idx] # test_female_targets: (4976,) 226 | print('test_male_data:', test_male_data.shape) 227 | print('test_male_targets:', test_male_targets.shape) 228 | print('test_female_data:', test_female_data.shape) 229 | print('test_female_targets:', test_female_targets.shape) 230 | print() 231 | 232 | # np.save(os.path.join(save_path, 'train_male_data.npy'), train_male_data) 233 | # np.save(os.path.join(save_path, 'train_male_targets.npy'), train_male_targets) 234 | # np.save(os.path.join(save_path, 'train_female_data.npy'), train_female_data) 235 | # np.save(os.path.join(save_path, 'train_female_targets.npy'), train_female_targets) 236 | # 237 | # np.save(os.path.join(save_path, 'test_male_data.npy'), test_male_data) 238 | # np.save(os.path.join(save_path, 'test_male_targets.npy'), test_male_targets) 239 | # np.save(os.path.join(save_path, 'test_female_data.npy'), test_female_data) 240 | # np.save(os.path.join(save_path, 'test_female_targets.npy'), test_female_targets) 241 | 242 | torch.save((train_male_data, train_male_targets), os.path.join(save_path, 'train_male.npy')) 243 | torch.save((train_female_data, train_female_targets), os.path.join(save_path, 'train_female.npy')) 244 | 245 | torch.save((test_male_data, test_male_targets), os.path.join(save_path, 'test_male.npy')) 246 | torch.save((test_female_data, test_female_targets), os.path.join(save_path, 'test_female.npy')) 247 | 248 | # Save race data seperately as .npy files: 249 | train_white_idx = white_idx[:train_size] 250 | train_black_idx = black_idx[:train_size] 251 | test_white_idx = white_idx[train_size:] 252 | test_black_idx = black_idx[train_size:] 253 | 254 | train_white_data = train_x[train_white_idx] # train_white_data: (25800, 110) 255 | train_white_targets = train_labels[train_white_idx] # train_white_targets: (25800,) 256 | train_black_data = train_x[train_black_idx] # train_black_data: (2797, 110) 257 | train_black_targets = train_labels[train_black_idx] # train_black_targets: (2797,) 258 | print('train_white_data:', train_white_data.shape) 259 | print('train_white_targets:', train_white_targets.shape) 260 | print('train_black_data:', train_black_data.shape) 261 | print('train_black_targets:', train_black_targets.shape) 262 | print() 263 | 264 | test_white_data = test_x[test_white_idx] # test_white_data: (13103, 110) 265 | test_white_targets = test_labels[test_white_idx] # test_white_targets: (13103,) 266 | test_black_data = test_x[test_black_idx] # test_black_data: (1431, 110) 267 | 
test_black_targets = test_labels[test_black_idx] # test_black_targets: (1431,) 268 | print('test_white_data:', test_white_data.shape) 269 | print('test_white_targets:', test_white_targets.shape) 270 | print('test_black_data:', test_black_data.shape) 271 | print('test_black_targets:', test_black_targets.shape) 272 | print() 273 | 274 | # np.save(os.path.join(save_path, 'train_white_data.npy'), train_white_data) 275 | # np.save(os.path.join(save_path, 'train_white_targets.npy'), train_white_targets) 276 | # np.save(os.path.join(save_path, 'train_black_data.npy'), train_black_data) 277 | # np.save(os.path.join(save_path, 'train_black_targets.npy'), train_black_targets) 278 | # 279 | # np.save(os.path.join(save_path, 'test_white_data.npy'), test_white_data) 280 | # np.save(os.path.join(save_path, 'test_white_targets.npy'), test_white_targets) 281 | # np.save(os.path.join(save_path, 'test_black_data.npy'), test_black_data) 282 | # np.save(os.path.join(save_path, 'test_black_targets.npy'), test_black_targets) 283 | 284 | torch.save((train_white_data, train_white_targets), os.path.join(save_path, 'train_white.npy')) 285 | torch.save((train_black_data, train_black_targets), os.path.join(save_path, 'train_black.npy')) 286 | 287 | torch.save((test_white_data, test_white_targets), os.path.join(save_path, 'test_white.npy')) 288 | torch.save((test_black_data, test_black_targets), os.path.join(save_path, 'test_black.npy')) 289 | 290 | # save data as npy 291 | z = 0 292 | start = time.time() 293 | for x in range(full_data.shape[0]): 294 | for y in range(2): 295 | if full_labels[x] == y: 296 | temp = (full_data[x, :]) 297 | 298 | directory = os.path.join(save_path, str(label_encoder.classes_[y])) 299 | if not os.path.exists(directory): 300 | os.makedirs(directory) 301 | 302 | np.save((directory + '/' + str(z) + '.npy'), temp) 303 | 304 | z += 1 305 | 306 | end = time.time() 307 | 308 | print('Time to process: ', end - start) 309 | print(z, ' datapoints saved to path') 310 | return label_encoder.classes_ 311 | -------------------------------------------------------------------------------- /fade/data/meta/office_caltech.py: -------------------------------------------------------------------------------- 1 | """Office and Caltech10""" 2 | from __future__ import print_function 3 | 4 | import os 5 | from torchvision.datasets.folder import default_loader, ImageFolder 6 | 7 | from fade.file import FileManager 8 | 9 | ALL_SOURCES = ["office31", "officehome65"] 10 | ALL_DOMAINS = { 11 | "office31": ["amazon", "dslr", "webcam"], 12 | "officehome65": ["Art", "Clipart", "Product", "RealWorld"], 13 | } 14 | ALL_URLS = { 15 | # links from https://github.com/tim-learn/SHOT 16 | "office31": "https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view", 17 | "officehome65": "https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view?resourcekey=0-2SNWq0CDAuWOBRRBL7ZZsw" 18 | } 19 | 20 | 21 | class LoadImageFolder(ImageFolder): 22 | """Different from ImageFolder, you need to use transform to load images. If transform is None, 23 | then the default loader is used to load images. Otherwise, transform has to process the path str 24 | to load images. 
25 | """ 26 | 27 | def __getitem__(self, index): 28 | path, target = self.samples[index] 29 | if self.transform is None: 30 | sample = self.loader(path) 31 | else: 32 | sample = self.transform(path) 33 | if self.target_transform is not None: 34 | target = self.target_transform(target) 35 | 36 | return sample, target 37 | 38 | 39 | class DefaultImageLoader(object): 40 | """Transformer to load image from path.""" 41 | def __init__(self): 42 | pass 43 | 44 | def __call__(self, path): 45 | return default_loader(path) 46 | 47 | 48 | def get_office_caltech_dataset(source, domain, transform=None, target_transform=None, 49 | feature_type="images", load_img_by_transform=False): 50 | """load_img_by_transform: The dataset will not auto load image for transform. 51 | Use transform to process path str and transform path to images, instead.""" 52 | root = FileManager.data(os.path.join(source), is_dir=True) 53 | if source.lower() == "office31": 54 | assert feature_type == "images", "Office31 only support image features." 55 | assert domain in ALL_DOMAINS[source.lower()], f"Unknown domain: {domain}" 56 | image_path = os.path.join(root, domain, "images") 57 | if not os.path.exists(image_path): 58 | print(f"### cwd: {os.getcwd()}") 59 | raise FileNotFoundError(f"No found image directory at: {image_path}. " 60 | f"Download zip file from {ALL_URLS[source.lower()]} and unpack" 61 | f" into {root}. Verify the file structure to make" 62 | f" sure the missing image path exist.") 63 | if load_img_by_transform: 64 | ds = LoadImageFolder(root=image_path, transform=transform, 65 | target_transform=target_transform) 66 | else: 67 | ds = ImageFolder(root=image_path, transform=transform, 68 | target_transform=target_transform) 69 | elif source.lower() == "officehome65": 70 | assert feature_type == "images", "OfficeHome65 only support image features." 71 | assert domain in ALL_DOMAINS[source.lower()], f"Unknown domain: {domain}" 72 | image_path = os.path.join(root, domain) 73 | if not os.path.exists(image_path): 74 | raise FileNotFoundError(f"No found image directory at: {image_path}. " 75 | f"Download zip file from {ALL_URLS[source.lower()]} and unpack" 76 | f" into {root}. 
Verify the file structure to make" 77 | f" sure the missing image path exist.") 78 | if load_img_by_transform: 79 | ds = LoadImageFolder(root=image_path, transform=transform, 80 | target_transform=target_transform) 81 | else: 82 | ds = ImageFolder(root=image_path, transform=transform, 83 | target_transform=target_transform) 84 | else: 85 | raise ValueError(f"Invalid source: {source}") 86 | return ds 87 | 88 | 89 | def main(): 90 | """Verify the consistence of classes.""" 91 | for source in ALL_SOURCES: 92 | class_to_idx = [] 93 | for domain in ALL_DOMAINS[source]: 94 | print() 95 | print(f"====== source: {source}, domain: {domain} ======") 96 | ds = get_office_caltech_dataset(source, domain) 97 | print(f" classes: {ds.classes}") 98 | print(f" class_to_idx: {ds.class_to_idx}") 99 | if len(class_to_idx) == 0: 100 | pass 101 | else: 102 | for (name0, idx0), (name1, idx1) in zip(class_to_idx[-1].items(), ds.class_to_idx.items()): 103 | assert name0 == name1 104 | assert idx0 == idx1 105 | print(f"[OK] {len(ds.classes)} classes in domain {domain} of {source} are verified.") 106 | class_to_idx.append(ds.class_to_idx) 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /fade/data/meta/usps.py: -------------------------------------------------------------------------------- 1 | """Dataset setting and data loader for USPS. 2 | Modified from 3 | + https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py 4 | + SHOT 5 | """ 6 | 7 | import gzip 8 | import os 9 | import pickle 10 | import urllib 11 | from PIL import Image 12 | 13 | import numpy as np 14 | import torch 15 | import torch.utils.data as data 16 | from torch.utils.data.sampler import WeightedRandomSampler 17 | from torchvision import datasets, transforms 18 | 19 | 20 | class USPS(data.Dataset): 21 | """USPS Dataset. 22 | Args: 23 | root (string): Root directory of dataset where dataset file exist. 24 | train (bool, optional): If True, resample from dataset randomly. 25 | download (bool, optional): If true, downloads the dataset 26 | from the internet and puts it in root directory. 27 | If dataset is already downloaded, it is not downloaded again. 28 | transform (callable, optional): A function/transform that takes in 29 | an PIL image and returns a transformed version. 30 | E.g, ``transforms.RandomCrop`` 31 | """ 32 | 33 | url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl" 34 | 35 | def __init__(self, root, train=True, transform=None, download=False): 36 | """Init USPS dataset.""" 37 | # init params 38 | self.root = os.path.expanduser(root) 39 | self.filename = "usps_28x28.pkl" 40 | self.train = train 41 | # Num of Train = 7438, Num ot Test 1860 42 | self.transform = transform 43 | self.dataset_size = None 44 | 45 | # download dataset. 46 | if download: 47 | self.download() 48 | if not self._check_exists(): 49 | raise RuntimeError("Dataset not found." + 50 | " You can use download=True to download it") 51 | 52 | self.data, self.targets = self.load_samples() 53 | if self.train: 54 | total_num_samples = self.targets.shape[0] 55 | indices = np.arange(total_num_samples) 56 | self.data = self.data[indices[0:self.dataset_size], ::] 57 | self.targets = self.targets[indices[0:self.dataset_size]] 58 | self.data *= 255.0 59 | self.data = np.squeeze(self.data).astype(np.uint8) 60 | 61 | def __getitem__(self, index): 62 | """Get images and target for data loader. 
63 | Args: 64 | index (int): Index 65 | Returns: 66 | tuple: (image, target) where target is index of the target class. 67 | """ 68 | img, label = self.data[index], self.targets[index] 69 | img = Image.fromarray(img, mode='L') 70 | img = img.copy() 71 | if self.transform is not None: 72 | img = self.transform(img) 73 | return img, label.astype("int64") 74 | 75 | def __len__(self): 76 | """Return size of dataset.""" 77 | return len(self.data) 78 | 79 | def _check_exists(self): 80 | """Check if dataset is download and in right place.""" 81 | return os.path.exists(os.path.join(self.root, self.filename)) 82 | 83 | def download(self): 84 | """Download dataset.""" 85 | filename = os.path.join(self.root, self.filename) 86 | dirname = os.path.dirname(filename) 87 | if not os.path.isdir(dirname): 88 | os.makedirs(dirname) 89 | if os.path.isfile(filename): 90 | return 91 | print("Download %s to %s" % (self.url, os.path.abspath(filename))) 92 | urllib.request.urlretrieve(self.url, filename) 93 | print("[DONE]") 94 | return 95 | 96 | def load_samples(self): 97 | """Load sample images from dataset.""" 98 | filename = os.path.join(self.root, self.filename) 99 | f = gzip.open(filename, "rb") 100 | data_set = pickle.load(f, encoding="bytes") 101 | f.close() 102 | if self.train: 103 | images = data_set[0][0] 104 | labels = data_set[0][1] 105 | self.dataset_size = labels.shape[0] 106 | else: 107 | images = data_set[1][0] 108 | labels = data_set[1][1] 109 | self.dataset_size = labels.shape[0] 110 | return images, labels 111 | -------------------------------------------------------------------------------- /fade/data/multi_domain.py: -------------------------------------------------------------------------------- 1 | """Fed set including multiple domains. 2 | 3 | Examples: 4 | # Mnist + MnistM 5 | python -m fade.data.multi_domain -cn comb/MnistM 6 | # change n_user of Mnist 7 | python -m fade.data.multi_domain -cn comb/MnistM --cfg job dataset.meta_datasets.0.n_user=10 8 | 9 | # Office31_A2W 10 | python -m fade.data.multi_domain -cn comb/Office31_A2W 11 | 12 | =================== 13 | # MnistM + Mnist 14 | # Step 1: generate meta sets 15 | python -m fade.data.extend -d Mnist -c 5 -u 20 16 | python -m fade.data.extend -d MnistM -c 5 -u 20 17 | # python -m fade.data.extend -d Svhn -c 5 -u 20 18 | # Step 1: combine them 19 | python -m fade.data.multi_domain -d comb/MnistM_c5u40 20 | """ 21 | import os 22 | import torch 23 | import numpy as np 24 | from typing import List 25 | # import argparse 26 | import hydra 27 | from omegaconf import DictConfig, ListConfig 28 | from typing import Iterable 29 | 30 | from fade.file import FileManager 31 | from fade.utils import _log_time_usage 32 | from .utils import load_fed_dataset 33 | from . import FedDataset 34 | 35 | 36 | class MDFedDataset(FedDataset): 37 | def _generate(self): 38 | assert isinstance(self.cfg.meta_datasets, Iterable), f"type is {type(self.cfg.meta_datasets)}" 39 | assert isinstance(self.cfg.meta_datasets[0], DictConfig) 40 | 41 | if hasattr(self.cfg, "meta_fed_ds") and self.cfg.meta_fed_ds != "extend": 42 | if self.cfg.meta_fed_ds == "federalize": 43 | from .federalize import FedExtDataset as ExtendedDataset 44 | else: 45 | raise RuntimeError(f"meta_fed_ds: {self.cfg.meta_fed_ds}") 46 | else: 47 | raise ValueError(f"Not support 'extend' module anymore. Set 'meta_fed_ds' in " 48 | f"dataset as 'federalize'.") 49 | 50 | # Auto set n user for the unset one. 
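# A meta dataset may leave `n_user: -1` as a placeholder. The block below fills it in so
# that the per-dataset user counts sum to `total_n_user`. Worked example with hypothetical
# config values: for total_n_user=10 and meta-dataset n_user settings [3, -1],
# cur_n_user = (3 - 1) + 1 = 3, so the placeholder dataset is assigned 10 - 3 = 7 users.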
51 | if hasattr(self.cfg, 'total_n_user'): 52 | n_users = [md.n_user for md in self.cfg.meta_datasets] 53 | void_idx = np.nonzero(np.array(n_users) == -1)[0] 54 | num_void = len(void_idx) 55 | assert num_void <= 1, f"Only allow one meta dataset to set n_user " \ 56 | f"as -1, but get {num_void}. All settings " \ 57 | f"are {n_users}." 58 | if num_void > 0: 59 | void_idx = int(void_idx[0]) 60 | cur_n_user = sum(n_users) + 1 # complement -1 61 | assert cur_n_user < self.cfg.total_n_user,\ 62 | f"Already have {cur_n_user} users more than total " \ 63 | f"{self.cfg.total_n_user} users." 64 | self.cfg.meta_datasets[void_idx].n_user = self.cfg.total_n_user - cur_n_user 65 | print(f"Set the {void_idx}-th meta dataset with {self.cfg.meta_datasets[void_idx].n_user} users.") 66 | 67 | # check meta datasets existence or generate new. 68 | for ds_cfg in self.cfg.meta_datasets: 69 | ext_ds = ExtendedDataset(ds_cfg) 70 | for subset in ("train", "test"): 71 | if not ext_ds.exist(subset): 72 | print(f"Not found {subset} for {ds_cfg.name}") 73 | ext_ds.generate() 74 | break 75 | else: 76 | print(f"CHECKED: {ds_cfg.name} has been generated with specific config.") 77 | 78 | return combine(self.cfg.meta_datasets, ExtendedDataset) 79 | 80 | 81 | def combine(meta_datasets, ExtendedDataset): 82 | """Read and combine datasets from the list. 83 | 84 | Args: 85 | meta_datasets (List[str] | ListConfig): List of meta fed datasets. 86 | 87 | Returns: 88 | A dict compress train/test fed sets. 89 | """ 90 | comb_ds = { 91 | "train": {"users": [], "user_data": {}, "hierarchies": [], "num_samples": []}, 92 | "test": {"users": [], "user_data": {}, "hierarchies": [], "num_samples": []}, 93 | } 94 | 95 | def rename_client_id(data_name, old_id): 96 | return f"{data_name}_{old_id}" 97 | 98 | if isinstance(meta_datasets[0], DictConfig): 99 | subsets = ("train", "test") 100 | for cfg in meta_datasets: 101 | fed_dict = ExtendedDataset(cfg).load(subsets) 102 | 103 | for subset in subsets: 104 | data_dict = fed_dict[subset] 105 | assert ("hierarchies" not in data_dict) or len(data_dict["hierarchies"]) == 0, \ 106 | "Not support: meta dataset include groups." 107 | new_users = [rename_client_id(cfg.name, c) for c in data_dict["users"]] 108 | comb_ds[subset]["users"].extend(new_users) 109 | comb_ds[subset]["num_samples"].extend(data_dict["num_samples"]) 110 | for c in data_dict["users"]: 111 | comb_ds[subset]["user_data"][rename_client_id(cfg.name, c)] = \ 112 | data_dict["user_data"][c] 113 | # NOTE: not sure how to construct this. just a try. 114 | comb_ds[subset]["hierarchies"].extend([cfg.name] * len(new_users)) 115 | else: 116 | for data_name in meta_datasets: 117 | for subset in ("train", "test"): 118 | print(f"Reading {data_name}, {subset}...") 119 | fed_data_dict = load_fed_dataset(data_name, subset=subset) 120 | 121 | assert ("hierarchies" not in fed_data_dict) or len(fed_data_dict["hierarchies"]) == 0, \ 122 | "Not support: meta dataset include groups." 123 | new_users = [rename_client_id(data_name, c) for c in fed_data_dict["users"]] 124 | comb_ds[subset]["users"].extend(new_users) 125 | comb_ds[subset]["num_samples"].extend(fed_data_dict["num_samples"]) 126 | for c in fed_data_dict["users"]: 127 | comb_ds[subset]["user_data"][rename_client_id(data_name, c)] = fed_data_dict["user_data"][c] 128 | # NOTE: not sure how to construct this. just a try. 
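# Sketch of the combined structure built by this function (derived from the code above;
# the layout is the same in both branches): for each subset in ("train", "test"),
#   comb_ds[subset]["users"]       -> client ids renamed to f"{data_name}_{old_id}"
#   comb_ds[subset]["user_data"]   -> {renamed_id: that user's original data dict}
#   comb_ds[subset]["num_samples"] -> per-user sample counts
#   comb_ds[subset]["hierarchies"] -> the source dataset name, repeated once per user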
129 | comb_ds[subset]["hierarchies"].extend([data_name] * len(new_users)) 130 | 131 | return comb_ds 132 | 133 | 134 | @hydra.main(config_name="comb/MnistM", config_path="../config/dataset/") 135 | def main(cfg): 136 | MDFedDataset(cfg.dataset).generate() 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /fade/file.py: -------------------------------------------------------------------------------- 1 | """Utility for accessing files.""" 2 | # from __future__ import annotations 3 | import os 4 | from pathlib import Path 5 | import shutil 6 | from enum import Enum, auto 7 | from datetime import datetime 8 | import matplotlib.pyplot as plt 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | import fade 13 | _absolute_global_root_path = Path(os.path.realpath(fade.__file__)).parents[1] 14 | global_root_path = _absolute_global_root_path # absolute path. 15 | 16 | 17 | class FileManager(object): 18 | """All the path is based on the `src` folder.""" 19 | def __init__(self): 20 | pass 21 | 22 | task_path = "" # TODO set the task path in experiments. 23 | 24 | class CleanMode(Enum): 25 | all = auto() 26 | log = auto() 27 | out = auto() 28 | none = auto() 29 | # old = auto() # TODO only keep the lastest 30 | 31 | @staticmethod 32 | def clean_dir(dir_to_clean: str): 33 | if os.path.exists(dir_to_clean): 34 | logger.info(f"Recursively removing dir: {dir_to_clean}") 35 | 36 | def rm_error_handler(function, path, excinfo): 37 | e = excinfo[1] 38 | if isinstance(e, OSError) and e.errno == 39: # 39: e.strerror == "Directory not empty": 39 | # print(f"Delete non-empty dir: {path}") 40 | # shutil.rmtree(path, ignore_errors=True) 41 | logger.error(f"Fail to delete folder '{path}' due to some ramianing files. " 42 | f"Try to close the tensorboard or any occupying process.") 43 | raise e 44 | try: 45 | shutil.rmtree(dir_to_clean, onerror=rm_error_handler) 46 | except OSError as e: 47 | print(f"Error: {e.filename} - {e.strerror}.") 48 | 49 | @classmethod 50 | def generate_path_with_root(cls, root, subdir, is_dir=False, create_new=True, 51 | overwrite=True, verbose=False, return_date=False): 52 | """Generate path given root and sub-dir. 53 | 54 | :param root: The root of the path, e.g., 'out', 'data'. This is used for specify the function. 55 | :param subdir: (or filename) Usually, this presents the unique name of the storage, e.g., data name. 56 | :param is_dir: True if the `subdir` is a dir, else it is treated as a file. 57 | :param create_new: Create new folder and sub-folders if not exists. 58 | :param overwrite: Overwrite folders if exists. 59 | :param verbose: Print info when overwriting or creating. 60 | :param return_date: Return the date when the file/dir was modified. 61 | :return: Generated path str. str of modification time of the file (if required). 62 | """ 63 | path = os.path.join(global_root_path, root, cls.task_path, subdir) 64 | if create_new: 65 | if is_dir: 66 | fld = path 67 | else: 68 | fld = os.path.dirname(path) 69 | if os.path.exists(path): # os.path.dirname(path)): 70 | if overwrite: 71 | pass 72 | # if verbose: 73 | # logger.warning("'{}' already exists. 
Overwrite...".format(fld)) 74 | else: 75 | return path 76 | elif verbose: 77 | logger.debug("Creating dir: {}".format(fld)) 78 | os.makedirs(fld, exist_ok=overwrite) 79 | if return_date: 80 | if os.path.exists(path): 81 | import pathlib 82 | pp = pathlib.Path(path) 83 | mtime = datetime.fromtimestamp(pp.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S") 84 | else: 85 | mtime = "n/a" 86 | return path, mtime 87 | return path 88 | 89 | @classmethod 90 | def out(cls, filename, **kwargs): 91 | """Generate path given folder. 92 | 93 | :param filename: This presents the unique name of the storage, e.g., data name. 94 | :param create_new: Create new folder and sub-folders if not exists. 95 | :param overwrite: Overwrite folders if exists. 96 | :param verbose: Print info when overwriting or creating. 97 | :return: Generated path str. 98 | """ 99 | return cls.generate_path_with_root("out", filename, **kwargs) 100 | 101 | @classmethod 102 | def log(cls, filename, **kwargs): 103 | """Generate path given folder. 104 | 105 | :param filename: This presents the unique name of the storage, e.g., data name. 106 | :param create_new: Create new folder and sub-folders if not exists. 107 | :param overwrite: Overwrite folders if exists. 108 | :param verbose: Print info when overwriting or creating. 109 | :return: Generated path str. 110 | """ 111 | return cls.generate_path_with_root("log", filename, **kwargs) 112 | 113 | @classmethod 114 | def data(cls, filename, **kwargs): 115 | """Generate path given folder. 116 | 117 | :param filename: data filename. 118 | :param create_new: Create new folder and sub-folders if not exists. 119 | :param overwrite: Overwrite folders if exists. 120 | :param verbose: Print info when overwriting or creating. 121 | :return: Generated path str. 122 | """ 123 | return cls.generate_path_with_root("data", filename, **kwargs) 124 | 125 | hpcc_jobid = "" 126 | __logid = None # type: str 127 | __logdir = None # type: str 128 | 129 | @classmethod 130 | def get_logid(cls): 131 | """Return a unique id for the current log.""" 132 | if cls.__logid is None: 133 | cls.__logid = datetime.now().strftime("%Y%m%d-%H%M%S") 134 | if len(cls.hpcc_jobid) > 0: 135 | cls.__logid = cls.hpcc_jobid + "@" + cls.__logid 136 | return cls.__logid 137 | 138 | @classmethod 139 | def get_logdir(cls, new=True): 140 | """Get the log path. 
Set `new` False if an existing path is expected.""" 141 | if cls.__logdir is None: 142 | cls.__logdir = cls.log(cls.get_logid(), is_dir=True, create_new=new, 143 | overwrite=new, verbose=True) 144 | return cls.__logdir 145 | 146 | 147 | def save_current_fig(name="untitled"): 148 | file_name = FileManager.out(f"./fig/{name}.pdf", 149 | create_new=True, overwrite=True) 150 | # file_name = FileManager.out(f"fig/{name}.pdf", create_new=True, overwrite=True) 151 | plt.savefig(file_name, bbox_inches="tight") 152 | print(f"save figure => {file_name}") 153 | -------------------------------------------------------------------------------- /fade/mainx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | from fade.server.base import ServerAgent 4 | import torch 5 | from time import time 6 | from datetime import timedelta 7 | import numpy as np 8 | import hydra 9 | from omegaconf import DictConfig, OmegaConf 10 | import wandb 11 | import logging 12 | 13 | os.putenv("LC_ALL", "C.UTF-8") 14 | os.putenv("LANG", "C.UTF-8") 15 | os.putenv("LANGUAGE", "C.UTF-8") 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def train_loop_body(i, seed, server_agent: ServerAgent, device): 20 | """Major function to create server and run training.""" 21 | cfg = server_agent.full_config 22 | if cfg.i_rep >= 0 and cfg.i_rep != i: 23 | return 24 | # change seed every time. 25 | torch.manual_seed(seed) 26 | np.random.seed(seed) 27 | print(f"--------------- Running with seed #{i}: {seed} ------------") 28 | server = server_agent.create_server(i, device=device) 29 | if cfg.load_model.do: 30 | server.load_model(hash_name=cfg.load_model.hash_name, 31 | to_load=cfg.load_model.load) 32 | if 'user' not in cfg.load_model.load: 33 | server.send_parameters(partial=False) 34 | 35 | # Run training with clients 36 | server.train() 37 | 38 | # The final eval results will not override the one runed inside train(). 39 | # NOTE for Central server, we use use the server model for eval 40 | res_dict = server.evaluate(reduce_users=False, add_res_to_record=False, return_dict=True, 41 | personal=False, full_info=False, model=None) 42 | res_dict = dict(("g_" + k, v) for k, v in res_dict.items()) 43 | server.log(res_dict, commit=True) 44 | server.dump_to_file(personal=False, key="user_eval", obj=res_dict) 45 | 46 | 47 | def eval_loop_body(i, seed, cfgs, server_agent: ServerAgent): 48 | """Evaluation.""" 49 | if cfgs.i_rep >= 0 and cfgs.i_rep != i: 50 | return 51 | # change seed every time. 52 | torch.manual_seed(seed) 53 | print(f"---------------Running time: {i}, seed {seed} ------------") 54 | server = server_agent.create_server(i, device=cfgs.device) 55 | assert cfgs.load_model.do, "Model is required to be loaded when evaluating." 56 | print(f"Warning: Load models by specifying hash name as: {cfgs.load_model.hash_name}") 57 | server.load_model(hash_name=cfgs.load_model.hash_name, to_load=cfgs.load_model.load) 58 | # Send model to users. 59 | # NOTE send params will not send states e.g., running mean of BatchNorm. 60 | if 'user' not in cfgs.load_model.load: 61 | logger.warning(f"Not load users' models. Sending loaded server model params to users. " 62 | f"NOTE: this will not send states of BN's, which " 63 | f"may cause low acc in some cases.") 64 | server.send_parameters(partial=False) 65 | 66 | # The eval results will not override the one runed inside train(). 
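# Sketch of the flow below (as implemented by the following calls): the full evaluation
# dict, including per-user vectors such as pred_group/pred_y/true_y/z, is dumped to disk
# first via dump_to_file; those vector fields are then popped from extra_info so that only
# scalar metrics are passed to the logger.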
67 | eval_dict = \ 68 | server.evaluate(reduce_users=False, full_info=True, add_res_to_record=False, 69 | model=None, return_dict=True, 70 | personal=False) # NOTE: only evaluate server model. 71 | # NOTE Remove info that can not be log, e.g., z which may cause inconsistent #sample between users. 72 | server.dump_to_file(personal=False, key="full_user_eval", obj=eval_dict) 73 | for info_group in ("train", "test"): 74 | # DO NOT log the vector info. 75 | for k in ('pred_group', 'pred_y', 'true_y', 'z'): 76 | for _id in eval_dict["extra_info"][info_group]: 77 | eval_dict["extra_info"][info_group][_id].pop(k) 78 | server.log(eval_dict, commit=True, is_summary=True) 79 | 80 | 81 | def run(cfgs): 82 | """Handle the non-critical config options.""" 83 | assert cfgs.i_rep < cfgs.n_rep, f"Found cfgs.i_rep ({cfgs.i_rep}) >= cfgs.n_rep ({cfgs.n_rep})" 84 | if cfgs.action in ["train", "eval"] and "wandb" in cfgs.logger.loggers \ 85 | and cfgs.n_jobs == 1 and cfgs.i_rep >= 0: 86 | wandb.init(**cfgs.logger.wandb, reinit=True, 87 | config=OmegaConf.to_container(cfgs, resolve=True, enum_to_str=True)) 88 | else: 89 | cfgs.logger.loggers = [v for v in cfgs.logger.loggers if v != "wandb"] 90 | 91 | # server agent will spawn duplicated servers at run. 92 | server_agent = ServerAgent(args=cfgs, times=cfgs.n_rep) 93 | 94 | rng = np.random.RandomState(cfgs.seed) 95 | random_seeds = rng.randint(np.iinfo(np.int32).max, size=10) # used for repetitions 96 | if cfgs.action in ["train", "average", "eval", "check_files"]: 97 | server_agent.preload_data(print_stat=True) 98 | 99 | # choose action 100 | if cfgs.action == "train": 101 | for i in range(cfgs.n_rep): 102 | train_loop_body(i, random_seeds[i], server_agent, cfgs.device) 103 | elif cfgs.action == "eval": 104 | for i in range(cfgs.n_rep): 105 | eval_loop_body(i, random_seeds[i], cfgs, server_agent) 106 | elif cfgs.action == "check_files": # check generated files, e.g., saved models. 107 | server = server_agent.create_server(cfgs.n_rep) 108 | server.print_config() 109 | print("Hash name:", server.get_hash_name(include_rep=False)) 110 | server.check_files(verbose=True, personal=False, times=range(cfgs.n_rep)) 111 | 112 | if cfgs.action in ["train", "average"]: 113 | # Average data 114 | # NOTE: The returned metric value may not include all users if `partial_eval` is true. 115 | if not (cfgs.action == "train" and cfgs.i_rep >= 0): 116 | res_stat = server_agent.average_results_and_save(cfgs.n_rep, cfgs.num_glob_iters) 117 | if cfgs.action in ["train", "eval"]: 118 | res_stat = {} 119 | else: 120 | res_stat = None 121 | 122 | return server_agent.get_hash_name(), res_stat 123 | 124 | 125 | @hydra.main(config_name="config.yaml", config_path="config") 126 | def app_main(args: DictConfig): 127 | print("=" * 60) 128 | print("Summary of training process:") 129 | # print("Dataset : {}".format(args.dataset.name)) 130 | print("Algorithm : {}".format(args.server.name)) 131 | print("Local Model : {}".format(args.model.name)) 132 | print("Optimizer : {}".format(args.user.optimizer.name)) 133 | print("Loss : {}".format(args.user.loss)) 134 | print("Batch size : {}".format(args.user.batch_size)) 135 | print("Learning rate : {}".format(args.user.optimizer.learning_rate)) 136 | print("Moving Average : {}".format(args.server.beta)) 137 | print("Subset of users : {}".format(args.server.num_users)) 138 | print("Num of global rounds : {}".format(args.num_glob_iters)) 139 | print("Num of local rounds : {}".format(args.user.local_epochs)) 140 | print("Partial evaluate? 
: {}".format(args.partial_eval)) 141 | print("Device : {}".format(args.device)) 142 | print("Logging : {}".format(args.logging)) 143 | print("=" * 60) 144 | 145 | from fade.utils import set_coloredlogs_env 146 | import coloredlogs 147 | # logging.basicConfig(level=args.logging.upper()) 148 | set_coloredlogs_env() 149 | coloredlogs.install(level=args.logging.upper()) 150 | 151 | start_time = time() 152 | exp_name, res_stat = run(cfgs=args) 153 | end_time = time() 154 | elapsed = str(timedelta(seconds=end_time - start_time)) 155 | print(f"--- Elapsed: {elapsed} secs ---") 156 | 157 | 158 | def entry(): 159 | # this function is required to allow automatic detection of the module name when running 160 | # from a binary script. 161 | # it should be called from the executable script and not the hydra.main() function directly. 162 | app_main() 163 | 164 | 165 | if __name__ == '__main__': 166 | app_main() 167 | -------------------------------------------------------------------------------- /fade/model/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | from torch.nn import Module 3 | 4 | 5 | def import_model(model_name, sub_module) -> Type[Module]: 6 | # These has to be loaded for access the `if_personal_local_adaptation()`. 7 | import importlib 8 | module = importlib.import_module(f"fade.model.{sub_module}") 9 | return getattr(module, model_name) 10 | 11 | 12 | def get_model(name="", dataset="", **kwargs): 13 | model = None 14 | if "Mnist" in dataset or dataset in ["comb/Digit", "comb/M2U", "comb/U2M", "comb/S2M", "USPS", "SVHN"]: 15 | if name == "cnn-split": 16 | name = "MnistCnnSplit" 17 | elif name == "cnn-seprep": 18 | name = "MnistCnnSepRep" 19 | model = import_model(name, "mnist")(**kwargs) 20 | elif dataset.startswith("Office") or dataset.startswith("comb/Office") \ 21 | or dataset.startswith("Visda") or dataset.startswith("comb/Visda")\ 22 | or dataset.startswith("DomainNet") or dataset.startswith("comb/DomainNet"): 23 | if name == "cnn-split": 24 | name = "OfficeCnnSplit" 25 | elif name == "dnn-split": 26 | name = "OfficeDnnSplit" 27 | model = import_model(name, "office")(**kwargs) 28 | elif dataset.startswith("Celeba"): 29 | if name == "cnn-split": 30 | name = "CelebaCnnSplit" 31 | model = import_model(name, "celeba")(**kwargs) 32 | elif dataset.startswith("Adult") or dataset.startswith("comb/Adult"): 33 | if name == "dnn-split": 34 | name = "AdultDNNSplit" 35 | model = import_model(name, "adult")(**kwargs) 36 | if model is None: 37 | raise NotImplementedError(f"{name}, {dataset}, thus model is {model}") 38 | return model 39 | -------------------------------------------------------------------------------- /fade/model/adult.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .adv import SplitAdvNet 4 | from .shot import init_weights 5 | from .utils import freeze_model 6 | 7 | 8 | class AdultDNNSplit(SplitAdvNet): 9 | """ 10 | Three branch network with GRL. 
11 | """ 12 | def __deepcopy__(self, memodict={}): 13 | new_model = self.__class__(mid_dim=self.mid_dim, 14 | n_class=self.n_class, 15 | n_task=self.n_task, in_channel=self.in_channel, 16 | rev_lambda_scale=self.rev_lambda_scale, 17 | freeze_backbone=self.freeze_backbone, 18 | freeze_decoder=self.freeze_decoder, 19 | disable_bn_stat=self.disable_bn_stat, 20 | CDAN_task=self.CDAN_task, 21 | ).to('cuda') 22 | if hasattr(self, 'feature_extractor'): 23 | new_model.feature_extractor.load_state_dict(self.feature_extractor.state_dict()) 24 | # # NOTE this may ignore some class args. 25 | new_model.encoder.load_state_dict(self.encoder.state_dict()) 26 | new_model.decoder.load_state_dict(self.decoder.state_dict()) 27 | new_model.task_decoder.load_state_dict(self.task_decoder.state_dict()) 28 | return new_model 29 | 30 | def __init__(self, mid_dim=512, n_class=10, n_task=0, 31 | in_channel=3, rev_lambda_scale=1., freeze_backbone=True, freeze_decoder=False, 32 | disable_bn_stat=False, CDAN_task=False): 33 | ''' 34 | Args: 35 | alpha: L = L_utility - alpha * L_adversarial 36 | ''' 37 | super().__init__(mid_dim=mid_dim, n_class=n_class, n_task=n_task, 38 | rev_lambda_scale=rev_lambda_scale, freeze_backbone=freeze_backbone, 39 | freeze_decoder=freeze_decoder, disable_bn_stat=disable_bn_stat) 40 | self.in_channel = in_channel 41 | self.CDAN_task = CDAN_task 42 | 43 | if self.CDAN_task: 44 | task_fea_dim = mid_dim * self.n_class 45 | else: 46 | task_fea_dim = mid_dim 47 | 48 | # f: shared feature extractor 49 | self.feature_extractor = torch.nn.Sequential( 50 | torch.nn.Linear(110, 100), 51 | torch.nn.ReLU(inplace=True), 52 | torch.nn.Dropout(p=0.25), 53 | ) 54 | freeze_model(self.feature_extractor, self.freeze_backbone) 55 | self.encoder = torch.nn.Linear(100, self.mid_dim) 56 | 57 | # g: utility task classifier: 58 | self.decoder = torch.nn.Sequential( 59 | torch.nn.Linear(self.mid_dim, 32), 60 | torch.nn.ReLU(inplace=True), 61 | torch.nn.Dropout(p=0.25), 62 | torch.nn.Linear(32, 2) 63 | ) 64 | freeze_model(self.decoder, self.freeze_decoder) 65 | 66 | if self.n_task > 0: 67 | # h: adversarial/privacy task classifier: 68 | self.task_decoder = torch.nn.Sequential( 69 | torch.nn.Linear(task_fea_dim, 32), 70 | torch.nn.ReLU(inplace=True), 71 | torch.nn.Dropout(p=0.25), 72 | torch.nn.Linear(32, self.n_task) 73 | ) 74 | self.task_decoder.apply(init_weights) 75 | self.shared += [p for p in self.task_decoder.parameters(recurse=True)] 76 | self.shared += [p for p in self.encoder.parameters(recurse=True)] 77 | self.private += [p for p in self.decoder.parameters(recurse=True)] 78 | -------------------------------------------------------------------------------- /fade/model/adv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from typing import Union, Dict 5 | from .split import SplitEncoder 6 | from .utils import GradientReversalFunction 7 | from .shot import init_weights 8 | 9 | 10 | class SplitAdvNet(SplitEncoder): 11 | """Split Adversarial Net with classifier and discriminator heads above the encoder. 12 | 13 | Model structure refer to Peng et al. ICLR 2020, Fed Adv Domain Adaptation. 14 | 15 | Args: 16 | backbone: 17 | - `alexnet`: Use AlexNet as feature_extractor. 18 | - `resnet`: Use ResNet101 as feature_extractor. 19 | We load the model pretrained on ImageNet from torch's rep. The module 20 | code can be found in https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py 21 | mid_dim: Split dimension size. 
NOTE this is useless for backbone=lenet5a. 22 | n_task: If n_task > 0, construct a `task_decoder` and predict task. 23 | rev_lambda: The param for Reversal Gradient layer in task decoder. 24 | in_channel: Num of input channels. 25 | disable_bn_stat: If set True, the batch norm layers will not update running mean and std 26 | both in train and eval mode. 27 | ``` 28 | """ 29 | def train(self, mode: bool = True): 30 | super(SplitAdvNet, self).train(mode) 31 | 32 | if self.disable_bn_stat: 33 | def stop_bn_stat(m): 34 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 35 | m.eval() 36 | self.apply(stop_bn_stat) 37 | return self 38 | 39 | def __deepcopy__(self, memodict={}): 40 | new_model = self.__class__(mid_dim=self.mid_dim, n_class=self.n_class, 41 | n_task=self.n_task, 42 | rev_lambda_scale=self.rev_lambda_scale, 43 | freeze_backbone=self.freeze_backbone, 44 | freeze_decoder=self.freeze_decoder, 45 | disable_bn_stat=self.disable_bn_stat, 46 | ).to('cuda') 47 | if hasattr(self, 'feature_extractor'): 48 | new_model.feature_extractor.load_state_dict(self.feature_extractor.state_dict()) 49 | new_model.encoder.load_state_dict(self.encoder.state_dict()) 50 | new_model.decoder.load_state_dict(self.decoder.state_dict()) 51 | new_model.task_decoder.load_state_dict(self.task_decoder.state_dict()) 52 | return new_model 53 | 54 | def load_state_dict(self, state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]], 55 | strict: bool = True): 56 | # Remove task_encoder keys 57 | state_dict = dict((k, v) for k, v in state_dict.items() if not k.startswith('task_decoder')) 58 | ret = super().load_state_dict(state_dict, strict=False) 59 | # missing_keys, unexpected_keys 60 | if strict: 61 | assert len(ret.unexpected_keys) == 0, f"Got unexpected_keys: {ret.unexpected_keys}" 62 | unexpected_missing_keys = [] 63 | for k in ret.missing_keys: 64 | if not k.startswith('task_decoder'): 65 | unexpected_missing_keys.append(k) 66 | if len(unexpected_missing_keys) > 0: 67 | raise RuntimeError(f"Got unexpected missing keys: {unexpected_missing_keys}." 
68 | f" Only allow task encoder keys to be missing.") 69 | return ret 70 | 71 | def __init__(self, mid_dim=512, n_class=10, n_task=0, 72 | rev_lambda_scale=1., freeze_backbone=True, freeze_decoder=False, 73 | disable_bn_stat=False): 74 | super().__init__() 75 | self.mid_dim = mid_dim 76 | # self.n_class = n_class 77 | self.n_class = n_class 78 | self.n_task = n_task 79 | self.rev_lambda_scale = rev_lambda_scale 80 | self.freeze_backbone = freeze_backbone 81 | self.freeze_decoder = freeze_decoder 82 | self.disable_bn_stat = disable_bn_stat 83 | 84 | def can_predict_task(self): 85 | return hasattr(self, 'task_decoder') 86 | 87 | def encode(self, x: torch.Tensor, a=0.5): 88 | if hasattr(self, 'feature_extractor'): 89 | x = self.feature_extractor(x) 90 | return self.encoder(x) 91 | 92 | def decode(self, z, a=0.5) -> torch.Tensor: 93 | return self.decoder(z) 94 | 95 | def predict_task(self, z, rev_lambda) -> torch.Tensor: 96 | z = GradientReversalFunction.apply(z, rev_lambda * self.rev_lambda_scale) 97 | # print("### z", z.shape) 98 | # print("### task_decoder", self.task_decoder) 99 | return self.task_decoder(z) 100 | 101 | def get_shared_submodule(self): 102 | return [self.encoder, self.task_decoder] 103 | 104 | def get_private_submodule(self): 105 | return self.decoder 106 | 107 | def get_param_group_with_lr(self, lr, param_group=[], **kwargs): 108 | for k, v in self.feature_extractor.named_parameters(): 109 | param_group += [{'params': v, 'lr': lr * 0.1, **kwargs}] 110 | for k, v in self.encoder.named_parameters(): 111 | param_group += [{'params': v, 'lr': lr, **kwargs}] 112 | for k, v in self.decoder.named_parameters(): 113 | param_group += [{'params': v, 'lr': lr, **kwargs}] 114 | for k, v in self.task_decoder.named_parameters(): 115 | param_group += [{'params': v, 'lr': lr, **kwargs}] 116 | return param_group 117 | 118 | def reset_task_decoder(self): 119 | """Reset task decoder by re-init.""" 120 | self.task_decoder.apply(init_weights) 121 | -------------------------------------------------------------------------------- /fade/model/mnist.py: -------------------------------------------------------------------------------- 1 | """Models for mnist-class datasets.""" 2 | import torch 3 | from torch import nn 4 | 5 | from .adv import SplitAdvNet 6 | from .utils import GradientReversalFunction, freeze_model 7 | 8 | 9 | class MnistCnnSplit(SplitAdvNet): 10 | """CNN split network for Mnist. 11 | Old name: MnistCNN_pEnc 12 | 13 | Args: 14 | backbone: One of 'lenet5a' (default), 'lenet5b', 15 | - `lenet5a` is "modified" version of LeNet5 using ReLU. 16 | - `lenet5b` is similar to "lenet5a" where we split at the last layer. 17 | mid_dim: Split dimension size. NOTE this is useless for backbone=lenet5a. 18 | n_task: If n_task > 0, construct a `task_decoder` and predict task. 19 | rev_lambda: The param for Reversal Gradient layer in task decoder. 20 | in_channel: Num of input channels. 
21 | """ 22 | 23 | def __deepcopy__(self, memodict={}): 24 | new_model = self.__class__(backbone=self.backbone, mid_dim=self.mid_dim, n_task=self.n_task, 25 | in_channel=self.in_channel, rev_lambda_scale=self.rev_lambda_scale, 26 | n_class=self.n_class, bottleneck_type=self.bottleneck_type, 27 | freeze_decoder=self.freeze_decoder, 28 | CDAN_task=self.CDAN_task, 29 | disable_bn_stat=self.disable_bn_stat, 30 | ).to('cuda') 31 | if hasattr(self, 'feature_extractor'): 32 | new_model.feature_extractor.load_state_dict(self.feature_extractor.state_dict()) 33 | new_model.encoder.load_state_dict(self.encoder.state_dict()) 34 | new_model.decoder.load_state_dict(self.decoder.state_dict()) 35 | new_model.task_decoder.load_state_dict(self.task_decoder.state_dict()) 36 | return new_model 37 | 38 | def __init__(self, backbone="lenet5a", mid_dim=100, n_task=0, 39 | in_channel=3, rev_lambda_scale=1., n_class=10, 40 | bottleneck_type='bn', freeze_decoder=False, 41 | CDAN_task=False, disable_bn_stat=False): 42 | super().__init__(mid_dim=mid_dim, n_class=n_class, n_task=n_task, 43 | rev_lambda_scale=rev_lambda_scale, freeze_backbone=False, 44 | freeze_decoder=freeze_decoder, 45 | disable_bn_stat=disable_bn_stat) 46 | self.backbone = backbone 47 | self.in_channel = in_channel 48 | self.bottleneck_type = bottleneck_type 49 | self.CDAN_task = CDAN_task 50 | 51 | if self.CDAN_task: 52 | task_fea_dim = mid_dim * self.n_class 53 | else: 54 | task_fea_dim = mid_dim 55 | if backbone.lower() == "lenet5c": 56 | from .shot_digit import LeNetBase, feat_bootleneck, feat_classifier 57 | base = LeNetBase(self.in_channel) # NOTE for s2m, use DTN 58 | # self.feature_extractor = base # may cause error when load 59 | 60 | netB = feat_bootleneck(type=self.bottleneck_type, feature_dim=base.in_features, 61 | bottleneck_dim=mid_dim) 62 | self.encoder = nn.Sequential( 63 | base, netB 64 | ) 65 | self.decoder = feat_classifier(type='wn', class_num=self.n_class, 66 | bottleneck_dim=mid_dim) 67 | if freeze_decoder: 68 | freeze_model(self.decoder) 69 | if n_task > 0: 70 | self.task_decoder = nn.Sequential( 71 | # GradientReversal(lambda_=rev_lambda), 72 | nn.Linear(task_fea_dim, 50), 73 | nn.ReLU(), 74 | nn.Linear(50, 20), 75 | nn.ReLU(), 76 | nn.Linear(20, n_task) 77 | ) 78 | self.shared += [p for p in self.task_decoder.parameters(recurse=True)] 79 | elif backbone.lower() == "dtn": # used for SVHN 80 | from .shot_digit import DTNBase, feat_bootleneck, feat_classifier 81 | base = DTNBase(self.in_channel) 82 | netB = feat_bootleneck(type=self.bottleneck_type, feature_dim=base.in_features, 83 | bottleneck_dim=mid_dim) 84 | # self.feature_extractor = base 85 | self.encoder = nn.Sequential( 86 | base, netB 87 | ) 88 | self.decoder = feat_classifier(type='wn', class_num=self.n_class, 89 | bottleneck_dim=mid_dim) 90 | if freeze_decoder: 91 | freeze_model(self.decoder) 92 | if n_task > 0: 93 | self.task_decoder = nn.Sequential( 94 | # GradientReversal(lambda_=rev_lambda), 95 | nn.Linear(task_fea_dim, 100), 96 | nn.ReLU(), 97 | nn.Linear(100, 100), 98 | nn.ReLU(), 99 | nn.Linear(100, n_task) 100 | ) 101 | self.shared += [p for p in self.task_decoder.parameters(recurse=True)] 102 | else: 103 | raise ValueError(f"backbone {backbone}") 104 | self.shared += [p for p in self.encoder.parameters(recurse=True)] 105 | self.private += [p for p in self.decoder.parameters(recurse=True)] 106 | 107 | # self.task_clf = nn.Parameter(torch.randn((mid_dim, 1))) 108 | 109 | def can_predict_task(self): 110 | return hasattr(self, 'task_decoder') 111 | 112 | def 
encode(self, x: torch.Tensor, a=0.5): 113 | if len(x.shape) < 4: 114 | x = torch.reshape(x, (x.shape[0], self.in_channel, 28, 28)) 115 | # print(f"### x shape: {x.shape}") 116 | return self.encoder(x) 117 | 118 | def decode(self, z, a=0.5) -> torch.Tensor: 119 | # print(f"### z shape: {z.shape}") 120 | if self.backbone.lower() == "se": 121 | z = self.se(z) 122 | z = z.view(z.shape[0], -1) 123 | z = self.decoder0(z) 124 | return self.decoder(z) 125 | elif self.backbone.lower() in ["lenet5a", "dann15"]: 126 | z = z.view(z.shape[0], -1) 127 | return self.decoder(z) 128 | else: 129 | return self.decoder(z) 130 | # F.log_softmax(x, dim=1) 131 | 132 | def predict_task(self, z, rev_lambda): 133 | # return torch.sum((z - torch.reshape(self.task_clf, (1, -1))) ** 2, dim=1) / z.shape[1] 134 | if self.backbone.lower() in ["lenet5c", "dtn"]: 135 | z = GradientReversalFunction.apply(z, rev_lambda * self.rev_lambda_scale) 136 | return self.task_decoder(z) 137 | else: 138 | raise NotImplementedError() 139 | # return self.task_decoder(F.relu(z)) 140 | 141 | def get_shared_submodule(self): 142 | return self.encoder 143 | 144 | def get_private_submodule(self): 145 | return self.decoder 146 | 147 | def get_param_group_with_lr(self, lr, param_group=[], **kwargs): 148 | return [{'params': self.parameters(), 'lr': lr, **kwargs}] 149 | -------------------------------------------------------------------------------- /fade/model/office.py: -------------------------------------------------------------------------------- 1 | """For dataset. office+Caltech10 2 | 3 | Model structure refer to Peng et al. ICLR 2020, Fed Adv Domain Adaptation. 4 | """ 5 | import torch 6 | from torch import nn 7 | from .utils import GradientReversalFunction, freeze_model 8 | from .shot import init_weights 9 | from .adv import SplitAdvNet 10 | 11 | 12 | class OfficeCnnSplit(SplitAdvNet): 13 | """CNN split network for Office + Caltech10 datasets for domain adaptation. 14 | 15 | Model structure refer to Peng et al. ICLR 2020, Fed Adv Domain Adaptation. 16 | 17 | Args: 18 | backbone: 19 | - `alexnet`: Use AlexNet as feature_extractor. 20 | - `resnet`: Use ResNet101 as feature_extractor. 21 | We load the model pretrained on ImageNet from torch's rep. The module 22 | code can be found in https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py 23 | mid_dim: Split dimension size. NOTE this is useless for backbone=lenet5a. 24 | n_task: If n_task > 0, construct a `task_decoder` and predict task. 25 | rev_lambda: The param for Reversal Gradient layer in task decoder. 26 | in_channel: Num of input channels. 27 | disable_bn_stat: If set True, the batch norm layers will not update running mean and std 28 | both in train and eval mode. 
29 | """ 30 | def __deepcopy__(self, memodict={}): 31 | new_model = self.__class__(backbone=self.backbone, mid_dim=self.mid_dim, 32 | n_class=self.n_class, 33 | n_task=self.n_task, in_channel=self.in_channel, 34 | rev_lambda_scale=self.rev_lambda_scale, 35 | freeze_backbone=self.freeze_backbone, 36 | freeze_decoder=self.freeze_decoder, 37 | disable_bn_stat=self.disable_bn_stat, 38 | bottleneck_type=self.bottleneck_type, 39 | CDAN_task=self.CDAN_task, 40 | pretrained=self.pretrained 41 | ).to('cuda') 42 | if hasattr(self, 'feature_extractor'): 43 | new_model.feature_extractor.load_state_dict(self.feature_extractor.state_dict()) 44 | new_model.encoder.load_state_dict(self.encoder.state_dict()) 45 | new_model.decoder.load_state_dict(self.decoder.state_dict()) 46 | new_model.task_decoder.load_state_dict(self.task_decoder.state_dict()) 47 | return new_model 48 | 49 | def __init__(self, backbone="alexnet", mid_dim=512, n_class=10, n_task=0, 50 | in_channel=3, rev_lambda_scale=1., freeze_backbone=True, freeze_decoder=False, 51 | disable_bn_stat=False, bottleneck_type='bn', 52 | CDAN_task=False, pretrained=True): 53 | super().__init__(mid_dim=mid_dim, n_class=n_class, n_task=n_task, 54 | rev_lambda_scale=rev_lambda_scale, freeze_backbone=freeze_backbone, 55 | freeze_decoder=freeze_decoder, disable_bn_stat=disable_bn_stat) 56 | self.backbone = backbone 57 | self.in_channel = in_channel 58 | self.bottleneck_type = bottleneck_type 59 | self.CDAN_task = CDAN_task 60 | self.pretrained = pretrained 61 | 62 | if self.CDAN_task: 63 | task_fea_dim = mid_dim * self.n_class 64 | else: 65 | task_fea_dim = mid_dim 66 | 67 | if backbone.lower().startswith('resnet'): 68 | # from torchvision.models import resnet50 69 | from .shot import ResBase, feat_bootleneck, feat_classifier 70 | if backbone.startswith('resnet'): 71 | base = ResBase(res_name=backbone, pretrained=self.pretrained) 72 | else: 73 | raise ValueError(f"Invalid backbone: {backbone}") 74 | bottleneck = feat_bootleneck(type=self.bottleneck_type, feature_dim=base.in_features, 75 | bottleneck_dim=mid_dim) 76 | 77 | freeze_model(base, self.freeze_backbone) 78 | self.feature_extractor = base 79 | self.encoder = bottleneck 80 | 81 | self.decoder = feat_classifier(type='wn', class_num=n_class, 82 | bottleneck_dim=mid_dim) 83 | freeze_model(self.decoder, self.freeze_decoder) 84 | 85 | if n_task > 0: 86 | if n_task > 1: 87 | self.task_decoder = nn.Linear(task_fea_dim, n_task) 88 | else: 89 | self.task_decoder = nn.Sequential( 90 | nn.Linear(task_fea_dim, 256), 91 | nn.ReLU(), 92 | nn.Dropout(), 93 | nn.Linear(256, n_task), 94 | ) 95 | self.task_decoder.apply(init_weights) 96 | self.shared += [p for p in self.task_decoder.parameters(recurse=True)] 97 | else: 98 | raise NotImplementedError(f"backbone {backbone}") 99 | self.shared += [p for p in self.encoder.parameters(recurse=True)] 100 | self.private += [p for p in self.decoder.parameters(recurse=True)] 101 | 102 | def encode(self, x: torch.Tensor, a=0.5): 103 | if hasattr(self, 'feature_extractor'): 104 | x = self.feature_extractor(x) 105 | return self.encoder(x) 106 | 107 | def decode(self, z, a=0.5) -> torch.Tensor: 108 | return self.decoder(z) 109 | 110 | def predict_task(self, z, rev_lambda) -> torch.Tensor: 111 | z = GradientReversalFunction.apply(z, rev_lambda * self.rev_lambda_scale) 112 | return self.task_decoder(z) 113 | 114 | def get_current_module_norm(self, mode="grad"): 115 | """Get submodules' (grad) norms. 
116 | 117 | Args: 118 | mode: 'grad' or 'weight' 119 | 120 | Returns: 121 | all_norms, a dict 122 | """ 123 | all_modules = {"enc": self.encoder, "task_dec": self.task_decoder, "dec": self.decoder} 124 | all_norms = {"enc": 0., "task_dec": 0., "dec": 0.} 125 | for name, module in all_modules.items(): 126 | for np, p in module.named_parameters(): 127 | if p.grad is not None: 128 | if mode == "grad": 129 | norm = p.grad.data.norm().item() 130 | all_norms[name] += p.grad.data.norm().item() 131 | elif mode == "weight": 132 | norm = p.data.norm().item() 133 | else: 134 | raise ValueError(f"mode: {mode}") 135 | all_norms[name] += norm 136 | # print(f" ### > {np} norm... {norm}") 137 | # print(f" ### model {mode} norm... ", name, all_norms[name]) 138 | all_norms = dict((k+f"_{mode}_norm", v) for k, v in all_norms.items()) 139 | return all_norms 140 | 141 | -------------------------------------------------------------------------------- /fade/model/shot.py: -------------------------------------------------------------------------------- 1 | """Models used by SHOT (ICML 2020)""" 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.utils.weight_norm as weightNorm 5 | from torchvision import models 6 | 7 | 8 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0): 9 | return np.float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low) 10 | 11 | def init_weights(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 14 | nn.init.kaiming_uniform_(m.weight) 15 | nn.init.zeros_(m.bias) 16 | elif classname.find('BatchNorm') != -1: 17 | nn.init.normal_(m.weight, 1.0, 0.02) 18 | nn.init.zeros_(m.bias) 19 | elif classname.find('Linear') != -1: 20 | nn.init.xavier_normal_(m.weight) 21 | if classname not in ["SimLinear"]: 22 | nn.init.zeros_(m.bias) 23 | 24 | res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50, 25 | "resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d} 26 | 27 | class ResBase(nn.Module): 28 | def __init__(self, res_name, pretrained=True): 29 | super(ResBase, self).__init__() 30 | model_resnet = res_dict[res_name](pretrained=pretrained) 31 | self.conv1 = model_resnet.conv1 32 | self.bn1 = model_resnet.bn1 33 | self.relu = model_resnet.relu 34 | self.maxpool = model_resnet.maxpool 35 | self.layer1 = model_resnet.layer1 36 | self.layer2 = model_resnet.layer2 37 | self.layer3 = model_resnet.layer3 38 | self.layer4 = model_resnet.layer4 39 | self.avgpool = model_resnet.avgpool 40 | self.in_features = model_resnet.fc.in_features 41 | 42 | def forward(self, x): 43 | x = self.conv1(x) 44 | x = self.bn1(x) 45 | x = self.relu(x) 46 | x = self.maxpool(x) 47 | 48 | x = self.layer1(x) 49 | x = self.layer2(x) 50 | x = self.layer3(x) 51 | x = self.layer4(x) 52 | 53 | x = self.avgpool(x) 54 | x = x.view(x.size(0), -1) 55 | return x 56 | 57 | class feat_bootleneck(nn.Module): 58 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 59 | super(feat_bootleneck, self).__init__() 60 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.dropout = nn.Dropout(p=0.5) 63 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 64 | self.bottleneck.apply(init_weights) 65 | self.type = type 66 | 67 | def forward(self, x): 68 | x = self.bottleneck(x) 69 | if self.type == "bn": 70 | x = self.bn(x) 71 
| elif self.type == 'dropout': 72 | x = self.relu(self.dropout(x)) 73 | # x = self.dropout(x) # NOTE: This will cause the group_loss to be inf. No idea why. 74 | else: 75 | raise ValueError(f"Wrong type: {self.type}") 76 | return x 77 | 78 | class feat_classifier(nn.Module): 79 | def __init__(self, class_num, bottleneck_dim=256, type="linear"): 80 | super(feat_classifier, self).__init__() 81 | self.type = type 82 | if type == 'wn': 83 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 84 | self.fc.apply(init_weights) 85 | else: 86 | self.fc = nn.Linear(bottleneck_dim, class_num) 87 | self.fc.apply(init_weights) 88 | 89 | def forward(self, x): 90 | x = self.fc(x) 91 | return x 92 | 93 | class feat_classifier_two(nn.Module): 94 | def __init__(self, class_num, input_dim, bottleneck_dim=256): 95 | super(feat_classifier_two, self).__init__() 96 | self.type = type 97 | self.fc0 = nn.Linear(input_dim, bottleneck_dim) 98 | self.fc0.apply(init_weights) 99 | self.fc1 = nn.Linear(bottleneck_dim, class_num) 100 | self.fc1.apply(init_weights) 101 | 102 | def forward(self, x): 103 | x = self.fc0(x) 104 | x = self.fc1(x) 105 | return x 106 | -------------------------------------------------------------------------------- /fade/model/shot_digit.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.utils.weight_norm as weightNorm 3 | 4 | 5 | def init_weights(m): 6 | classname = m.__class__.__name__ 7 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 8 | nn.init.kaiming_uniform_(m.weight) 9 | nn.init.zeros_(m.bias) 10 | elif classname.find('BatchNorm') != -1: 11 | nn.init.normal_(m.weight, 1.0, 0.02) 12 | nn.init.zeros_(m.bias) 13 | elif classname.find('Linear') != -1: 14 | nn.init.xavier_normal_(m.weight) 15 | nn.init.zeros_(m.bias) 16 | 17 | class feat_bootleneck(nn.Module): 18 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 19 | super(feat_bootleneck, self).__init__() 20 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 21 | self.dropout = nn.Dropout(p=0.5) 22 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 23 | self.bottleneck.apply(init_weights) 24 | self.type = type 25 | 26 | def forward(self, x): 27 | x = self.bottleneck(x) 28 | if self.type == "bn": 29 | x = self.bn(x) 30 | x = self.dropout(x) 31 | elif self.type == 'dropout': 32 | x = self.dropout(x) 33 | return x 34 | 35 | class feat_classifier(nn.Module): 36 | def __init__(self, class_num, bottleneck_dim=256, type="linear"): 37 | super(feat_classifier, self).__init__() 38 | if type == "linear": 39 | self.fc = nn.Linear(bottleneck_dim, class_num) 40 | else: 41 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 42 | self.fc.apply(init_weights) 43 | 44 | def forward(self, x): 45 | x = self.fc(x) 46 | return x 47 | 48 | class DTNBase(nn.Module): 49 | def __init__(self, in_channel): 50 | super(DTNBase, self).__init__() 51 | self.conv_params = nn.Sequential( 52 | nn.Conv2d(in_channel, 64, kernel_size=5, stride=2, padding=2), 53 | nn.BatchNorm2d(64), 54 | nn.Dropout2d(0.1), 55 | nn.ReLU(), 56 | nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), 57 | nn.BatchNorm2d(128), 58 | nn.Dropout2d(0.3), 59 | nn.ReLU(), 60 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), 61 | nn.BatchNorm2d(256), 62 | nn.Dropout2d(0.5), 63 | nn.ReLU() 64 | ) 65 | self.in_features = 256*4*4 66 | 67 | def forward(self, x): 68 | x = self.conv_params(x) 69 | x = x.view(x.size(0), 
-1) 70 | return x 71 | 72 | class LeNetBase(nn.Module): 73 | def __init__(self, in_channel): 74 | super(LeNetBase, self).__init__() 75 | self.in_channel = in_channel 76 | self.conv_params = nn.Sequential( 77 | nn.Conv2d(in_channel, 20, kernel_size=5), 78 | nn.MaxPool2d(2), 79 | nn.ReLU(), 80 | nn.Conv2d(20, 50, kernel_size=5), 81 | nn.Dropout2d(p=0.5), 82 | nn.MaxPool2d(2), 83 | nn.ReLU(), 84 | ) 85 | self.in_features = 50*4*4 86 | 87 | def forward(self, x): 88 | x = self.conv_params(x) 89 | x = x.view(x.size(0), -1) 90 | return x -------------------------------------------------------------------------------- /fade/model/split.py: -------------------------------------------------------------------------------- 1 | """Split networks""" 2 | from torch import nn 3 | from typing import List 4 | 5 | 6 | class SplitNet(nn.Module): 7 | """Only partial network is sharable. The shared subnet is stored in the attribute `shared`. 8 | 9 | Note: by default, the loss will be cross-entropy. 10 | """ 11 | def __init__(self): 12 | super().__init__() 13 | # put shared sub-modules in the list. 14 | self.shared = [] # type: List[nn.Parameter] 15 | self.private = [] # type: List[nn.Parameter] 16 | 17 | def get_shared_parameters(self, detach=True): 18 | """Return a list of shared parameters.""" 19 | return self.shared 20 | 21 | def get_private_parameters(self): 22 | """Return a list of shared parameters.""" 23 | return self.private 24 | 25 | def get_shared_submodule(self): 26 | """This is used in DP engine. A whole module is required by dp engine. 27 | If no module can be returned, (for example, shared part is not in a module), then None is 28 | returned. 29 | """ 30 | return None 31 | 32 | 33 | class SplitEncoder(SplitNet): 34 | def encode(self, x, a=0.5): 35 | return self.encoder(x) 36 | 37 | def decode(self, z, a=0.5): 38 | return self.decoder(z) 39 | 40 | def forward(self, x, a=0.5): 41 | z = self.encode(x, a=a) 42 | return self.decode(z, a=a) 43 | -------------------------------------------------------------------------------- /fade/model/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | source: https://github.com/jvanvugt/pytorch-domain-adaptation/blob/master/utils.py 3 | """ 4 | 5 | # from PIL import Image 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn, Tensor 10 | import torch.nn.functional as F 11 | from torch.autograd import Function 12 | from torchvision import models 13 | 14 | 15 | def set_requires_grad(model, requires_grad=True): 16 | for param in model.parameters(): 17 | param.requires_grad = requires_grad 18 | 19 | 20 | def loop_iterable(iterable): 21 | while True: 22 | yield from iterable 23 | 24 | 25 | class GradientReversalFunction(Function): 26 | """ 27 | Gradient Reversal Layer from: 28 | Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) 29 | Forward pass is the identity function. In the backward pass, 30 | the upstream gradients are multiplied by -lambda (i.e. 
gradient is reversed) 31 | 32 | Refer to: https://github.com/jvanvugt/pytorch-domain-adaptation/blob/master/utils.py 33 | """ 34 | 35 | @staticmethod 36 | def forward(ctx, x, lambda_): 37 | ctx.lambda_ = lambda_ 38 | return x.clone() 39 | 40 | @staticmethod 41 | def backward(ctx, grads): 42 | lambda_ = ctx.lambda_ 43 | lambda_ = grads.new_tensor(lambda_) 44 | dx = -lambda_ * grads 45 | return dx, None 46 | 47 | 48 | class GradientReversal(torch.nn.Module): 49 | def __init__(self, lambda_=1.): 50 | super(GradientReversal, self).__init__() 51 | self.lambda_ = lambda_ 52 | 53 | def forward(self, x): 54 | return GradientReversalFunction.apply(x, self.lambda_) 55 | 56 | 57 | def freeze_model(model, freeze=True): 58 | for param in model.parameters(): 59 | param.requires_grad = not freeze 60 | 61 | -------------------------------------------------------------------------------- /fade/server/FedAdv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fade.server.base import Server 3 | from fade.utils import _log_time_usage 4 | 5 | 6 | class FedAdv(Server): 7 | """Federated Adversarial Server assuming: 8 | * Each user from different adversarial groups, e.g., real vs fake images, male vs female. 9 | * The group indicates the adversarial group. 10 | """ 11 | if_personal_local_adaptation = False 12 | 13 | def train(self): 14 | loss = [] 15 | only_online_users = True 16 | glob_iter = -1 17 | for glob_iter in range(self.num_glob_iters): 18 | print("-------------Round number: ", glob_iter, " -------------") 19 | 20 | if hasattr(self.full_config.server, "rev_lambda_warmup_iter"): 21 | rev_lambda_warmup_iter = self.full_config.server.rev_lambda_warmup_iter 22 | else: 23 | rev_lambda_warmup_iter = 0 24 | if 0 < rev_lambda_warmup_iter < 1: 25 | rev_lambda_warmup_iter *= self.num_glob_iters 26 | progress = max(float(glob_iter - rev_lambda_warmup_iter) / self.num_glob_iters, 0) 27 | # progress = float(glob_iter) / self.num_glob_iters 28 | rev_lambda = 2. / (1. + np.exp(-10. * progress)) - 1 # from 0 to 1. 
Half-Sigmoid 29 | print(f"## rev_lambda: {rev_lambda}") 30 | 31 | # loss_ = 0 32 | if len(self.online_user_idxs) >= 1: 33 | self.send_parameters(glob_iter=glob_iter) 34 | print(f"Online: {len(self.online_user_idxs)}/{len(self.users)}") 35 | else: 36 | print(f"Local training.") 37 | only_online_users = False 38 | 39 | self.selected_users = self.select_users(glob_iter, self.num_users, 40 | only_online_users=only_online_users) 41 | eval_users = self.selected_users if self.partial_eval else self.users 42 | 43 | with _log_time_usage(): 44 | _do_save = False 45 | if hasattr(self.full_config, 'eval_freq'): 46 | if glob_iter % self.full_config.eval_freq == 0: 47 | _do_evaluation = True 48 | _do_save = True 49 | else: 50 | _do_evaluation = False 51 | else: 52 | _do_evaluation = True 53 | 54 | if _do_evaluation: 55 | # Evaluate model each iteration 56 | if hasattr(self.full_config, 'snapshot') and self.full_config.snapshot: 57 | raise RuntimeError(f"Not support snapshot") 58 | else: 59 | eval_dict = self.evaluate(eval_users, reduce_users=self.partial_eval, 60 | full_info=False, return_dict=True) 61 | eval_dict = dict(("g_" + k, v) for k, v in eval_dict.items()) 62 | self.log(eval_dict, commit=False) 63 | if _do_save: 64 | self.save_model() 65 | 66 | with _log_time_usage("train and aggregate"): 67 | if hasattr(self.user_cfg, 'no_local_model') and self.user_cfg.no_local_model: 68 | raise RuntimeError(f"Not support no_local_model.") 69 | else: 70 | self.train_users(rev_lambda=rev_lambda) 71 | 72 | if hasattr(self.full_config.server, 'sync_optimizer') and self.full_config.server.sync_optimizer: 73 | assert len(self.selected_users) == 1, \ 74 | "For copying user's opt states, only one selected user is allowed." 75 | sel_user = self.selected_users[0] 76 | for user in self.users: 77 | if user.id != sel_user.id: # ignore same user. 78 | user.optimizer.load_state_dict(sel_user.optimizer.state_dict()) 79 | 80 | self.log({"global epoch": glob_iter, "rev_lambda": rev_lambda}, commit=True) 81 | if len(self.online_user_idxs) >= 1: 82 | self.send_parameters(glob_iter=glob_iter+1) 83 | self.save_results() 84 | self.save_model() 85 | 86 | def train_users(self, **user_train_kwargs): 87 | """Train users and aggregate parameters. 88 | If fair_update is required, then aggregation will be weighted by softmax-ed losses. 
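        Concretely, as implemented below, each selected user's weight is its first
        recorded group_loss divided by the sum over all selected users, i.e.
        w_k = group_loss_k / sum_j group_loss_j; when fair_update is not set, the
        plain personalized aggregation is used instead.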
89 | """ 90 | user_losses = [] 91 | for user in self.selected_users: 92 | losses = user.train(**user_train_kwargs) # * user.train_samples 93 | user_losses.append(losses[0]) # Only keep the first one 94 | if hasattr(self.full_config.server, 'fair_update') and self.full_config.server.fair_update: 95 | group_losses = [] 96 | for loss_ in user_losses: 97 | group_losses.append(loss_['group_loss'][0].item()) 98 | total_group_loss = np.sum(group_losses) 99 | weights = [gl / total_group_loss for gl in group_losses] 100 | self.personalized_aggregate_parameters(weights=weights) 101 | else: 102 | self.personalized_aggregate_parameters() 103 | -------------------------------------------------------------------------------- /fade/server/FedAvg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fade.server.base import Server 3 | from fade.utils import _log_time_usage 4 | 5 | 6 | def softmax(logits, temp=1.): 7 | y = np.minimum(np.exp(logits * temp), 1e4) 8 | st = y / np.sum(y) 9 | return st 10 | 11 | 12 | class FedAvg(Server): 13 | if_personal_local_adaptation = False 14 | 15 | def train(self): 16 | loss = [] 17 | only_online_users = True 18 | glob_iter = -1 19 | probs = np.ones(len(self.users)) / len(self.users) 20 | for glob_iter in range(self.num_glob_iters): 21 | print("-------------Round number: ", glob_iter, " -------------") 22 | # loss_ = 0 23 | if len(self.online_user_idxs) >= 1: 24 | self.send_parameters(glob_iter=glob_iter) 25 | print(f"Online: {len(self.online_user_idxs)}/{len(self.users)}") 26 | else: 27 | print(f"Local training.") 28 | only_online_users = False 29 | 30 | self.selected_users = self.select_users(glob_iter, self.num_users, 31 | only_online_users=only_online_users, 32 | probs=probs) 33 | print("Select users:", [user.id for user in self.selected_users]) 34 | eval_users = self.selected_users if self.partial_eval else self.users 35 | 36 | with _log_time_usage(): 37 | _do_save = False 38 | if hasattr(self.full_config, 'eval_freq'): 39 | if glob_iter % self.full_config.eval_freq == 0: 40 | _do_evaluation = True 41 | _do_save = True 42 | else: 43 | _do_evaluation = False 44 | else: 45 | _do_evaluation = True 46 | 47 | if _do_evaluation: 48 | # Evaluate model each iteration 49 | # FIXME Ad-hoc reduce_users=self.partial_eval 50 | eval_dict = self.evaluate(eval_users, reduce_users=False, return_dict=True) 51 | eval_dict = dict(("g_"+k, v) for k, v in eval_dict.items()) 52 | self.log(eval_dict, commit=False) 53 | print("### g_train_loss", eval_dict["g_train_loss"]) 54 | print("### g_train_acc", eval_dict["g_train_acc"]) 55 | if _do_save: 56 | self.save_model() 57 | 58 | with _log_time_usage("train and aggregate"): 59 | if hasattr(self.user_cfg, 'no_local_model') and self.user_cfg.no_local_model: 60 | self.train_users_online_aggregate() 61 | else: 62 | self.train_users() 63 | 64 | self.log({"global epoch": glob_iter}, commit=True) 65 | if len(self.online_user_idxs) >= 1: 66 | self.send_parameters(glob_iter=glob_iter+1) 67 | self.save_results() 68 | self.save_model() 69 | -------------------------------------------------------------------------------- /fade/server/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /fade/user/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/illidanlab/FADE/7997485ab6470fd31c2f9353bf2415a1bec87363/fade/user/__init__.py -------------------------------------------------------------------------------- /fade/user/cdan_loss.py: -------------------------------------------------------------------------------- 1 | """CDAN Loss: 2 | 3 | References: 4 | https://github.com/thuml/CDAN/blob/f7889063b76fca0b9a7147c88103d356531924bd/pytorch/loss.py#L21 5 | """ 6 | import torch 7 | 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | 17 | def grl_hook(coeff): 18 | def fun1(grad): 19 | return -coeff*grad.clone() 20 | return fun1 21 | 22 | 23 | def CDAN_predict_task(feature, softmax_output, model, random_layer=None, alpha=None): 24 | softmax_output = softmax_output.detach() 25 | if random_layer is None: 26 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 27 | ad_out = model.predict_task(op_out.view(-1, softmax_output.size(1) * feature.size(1)), 28 | rev_lambda=alpha) 29 | else: 30 | random_out = random_layer.forward([feature, softmax_output]) 31 | ad_out = model.predict_task(random_out.view(-1, random_out.size(1)), rev_lambda=alpha) 32 | return ad_out 33 | 34 | 35 | def CDAN(group_loss, softmax_output, group_labels, compute_ent_weights=False, alpha=None): 36 | # group_loss = F.binary_cross_entropy_with_logits(pred_group, group_labels) 37 | if compute_ent_weights: 38 | entropy = Entropy(softmax_output) 39 | entropy.register_hook(grl_hook(alpha)) 40 | entropy = 1.0+torch.exp(-entropy) 41 | for g in (0, 1): 42 | mask = group_labels == g 43 | entropy[mask] = entropy[mask] / torch.sum(entropy[mask]).detach().item() 44 | return torch.sum(entropy.view(-1, 1) * group_loss) / torch.sum(entropy).detach().item() 45 | else: 46 | return group_loss 47 | -------------------------------------------------------------------------------- /fade/user/generic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import wandb 4 | 5 | from .base import User 6 | from fade.model.split import SplitEncoder 7 | 8 | 9 | class GenericUser(User): 10 | """Implementation for FedAvg clients""" 11 | def __init__(self, *args, negative_coef=1., **kwargs): 12 | super().__init__(*args, **kwargs) 13 | # self.personalized_model_params = deepcopy(list(self.model.parameters())) 14 | # self.personal_model = deepcopy(self.model) 15 | self.negative_coef = negative_coef 16 | 17 | self.is_privacy_budget_out = False 18 | 19 | def can_join_for_train(self): 20 | return not self.is_privacy_budget_out 21 | 22 | def has_sharable_model(self): 23 | return not self.is_privacy_budget_out and super().has_sharable_model() 24 | 25 | def compute_loss(self, X, y, model=None): 26 | if model is None: 27 | model = self.model 28 | losses = {} 29 | if isinstance(model, SplitEncoder): 30 | Z = model.encode(X) 31 | output = model.decode(Z) 32 | else: 33 | output = model(X) 34 | if isinstance(self.loss, nn.MSELoss): 35 | output = output.view_as(y) 36 | if isinstance(self.loss, (nn.BCELoss, nn.BCEWithLogitsLoss)): 37 | y = y.float() 38 | task_loss = self.loss(output, y) 39 | if model.n_class <= 2: 40 | losses["task_loss"] = (task_loss, self.negative_coef if 0 in y else 1.) 41 | else: 42 | losses["task_loss"] = (task_loss, 1.) 
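        # Each entry of `losses` maps a loss name to a (tensor, coefficient) pair;
        # train() below logs every entry to wandb and optimizes the weighted sum
        # loss = sum(value * coef), e.g. {"task_loss": (tensor(0.7), 1.0)} -> 0.7 * 1.0.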
43 | return losses 44 | 45 | def train(self): 46 | LOSS = [] 47 | if not self.no_local_model: 48 | self.load_model_parameters(self.local_model_params) 49 | self.model.train() 50 | for epoch in range(1, self.local_epochs + 1): 51 | self.model.train() 52 | X, y = self.get_next_train_batch() 53 | self.optimizer.zero_grad() 54 | if len(y) <= 1: # will cause error for BN layer. 55 | continue 56 | 57 | losses = self.compute_loss(X, y) 58 | LOSS.append(losses) 59 | loss = 0 60 | for k, (value, coef) in losses.items(): 61 | wandb.log({f"{self.id} {self.group} " + k: value}, commit=False) 62 | loss = loss + value * coef 63 | 64 | self.optimizer.zero_grad() 65 | loss.backward() 66 | self.optimizer.step() 67 | if hasattr(self, 'sch'): 68 | self.sch.step() 69 | wandb.log({f"{self.id} {self.group} " + "lr": self.sch.get_lr()[0]}, 70 | commit=False) 71 | if not self.no_local_model: 72 | self.clone_model_paramenter(self.model.parameters(), self.local_model_params) 73 | return LOSS 74 | 75 | -------------------------------------------------------------------------------- /fade/user/group_adv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import wandb 5 | 6 | from .base import User 7 | from fade.model.split import SplitEncoder 8 | 9 | from .shot_digit_loss import Entropy, cluster_estimate_label 10 | from ..data.utils import update_dataset_targets 11 | 12 | 13 | class GroupAdvUser(User): 14 | """Implementation for FedAvg clients""" 15 | def __init__(self, *args, adv_lambda=1., group_loss='bce', relabel_coef=0., 16 | cluster_threshold=10., negative_coef=1.0, group_loss_q=1, 17 | group_loss_dro_reg=0., loss_reshape='none', 18 | loss_reshape_q=1, clamp_grad=None, 19 | **kwargs): 20 | super().__init__(*args, **kwargs) 21 | self.adv_lambda = adv_lambda 22 | 23 | self.is_privacy_budget_out = False 24 | self.group_loss_q = group_loss_q # used with sq_bce 25 | 26 | self.group_loss = group_loss 27 | self.relabel_coef = relabel_coef 28 | self.current_steps = 0 29 | self.cluster_threshold = cluster_threshold # use 10 for PDA, 0 for DA. 30 | self.negative_coef = negative_coef 31 | 32 | # params for DRO or fair resource allocation 33 | self.group_loss_dro_reg = group_loss_dro_reg 34 | self.loss_reshape = loss_reshape 35 | self.loss_reshape_q = loss_reshape_q 36 | self.clamp_grad = clamp_grad 37 | 38 | def can_join_for_train(self): 39 | return not self.is_privacy_budget_out 40 | 41 | def has_sharable_model(self): 42 | return not self.is_privacy_budget_out and super().has_sharable_model() 43 | 44 | def compute_loss(self, X, y, rev_lambda=1., model=None): 45 | if model is None: 46 | model = self.model 47 | losses = {} 48 | if isinstance(model, SplitEncoder): 49 | Z = model.encode(X) 50 | 51 | wandb.log({f"{self.id} {self.group} Z mean": torch.mean(Z, dim=0).data.cpu().numpy()}, commit=False) 52 | wandb.log({f"{self.id} {self.group} Z std": torch.std(Z, dim=0).data.cpu().numpy()}, commit=False) 53 | 54 | assert hasattr(model, "predict_task") 55 | if self.group_loss in ('cdan', 'sq_cdan'): 56 | from .cdan_loss import CDAN_predict_task 57 | output = model.decode(Z) 58 | # NOTE do not detach softmax_out s.t. we can BP. 
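                # CDAN_predict_task (cdan_loss.py) conditions the discriminator on the
                # classifier output: it detaches its own copy of the softmax, forms the
                # outer product with Z and flattens it to mid_dim * n_class features,
                # which is why task_fea_dim is widened when CDAN_task is enabled.
                # The undetached softmax_out computed below is reused for the entropy
                # weighting inside CDAN() (compute_ent_weights=True), so gradients can
                # still flow through it there.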
59 | softmax_out = F.softmax(output, dim=1) 60 | pred_group = CDAN_predict_task(Z, softmax_out, model, 61 | alpha=rev_lambda) 62 | else: 63 | pred_group = model.predict_task(Z, rev_lambda=rev_lambda) 64 | 65 | if model.n_task == 1: 66 | group_label = torch.ones(pred_group.shape[0], dtype=torch.long).fill_( 67 | self.group).to(self.device) 68 | group_acc = torch.mean(((pred_group > 0.).int() == group_label).float()).item() 69 | wandb.log({f"{self.id} {self.group} group_acc": group_acc}, commit=False) 70 | # binary classification 71 | if self.group_loss == 'bce': 72 | assert 0 <= self.group < 2 73 | assert pred_group.shape[1] == 1, f"pred_group.shape={pred_group.shape}" 74 | group_loss = F.binary_cross_entropy_with_logits(pred_group.view(-1,), group_label.float()) 75 | elif self.group_loss == 'sq_bce': 76 | assert 0 <= self.group < 2 77 | assert pred_group.shape[1] == 1, f"pred_group.shape={pred_group.shape}" 78 | group_loss = F.binary_cross_entropy_with_logits(pred_group.view(-1,), group_label.float()) 79 | # FIXME ad-hoc, the 1/2 is not used previously. 80 | group_loss = group_loss ** (self.group_loss_q + 1.) / (1 + self.group_loss_q) 81 | elif self.group_loss == 'xent': 82 | assert pred_group.shape[1] > 1, f"pred_group.shape={pred_group.shape}" 83 | group_loss = F.cross_entropy(pred_group, group_label) 84 | elif self.group_loss in ('cdan', 'sq_cdan'): 85 | from .cdan_loss import CDAN 86 | group_loss = F.binary_cross_entropy_with_logits(pred_group.view(-1,), group_label.float()) 87 | group_loss = CDAN(group_loss, softmax_out, group_label.float(), 88 | compute_ent_weights=True, alpha=rev_lambda) 89 | if self.group_loss == 'sq_cdan': 90 | group_loss = group_loss ** 2 / 2. 91 | elif self.group_loss == 'none': 92 | pass 93 | else: 94 | raise ValueError(f"Invalid group_loss: {self.group_loss} for " 95 | f"{model.n_task} tasks.") 96 | else: 97 | group_label = torch.ones(pred_group.shape[0], dtype=torch.long).fill_( 98 | self.group).to(self.device) 99 | group_acc = torch.mean((torch.argmax(pred_group) == group_label).float()).item() 100 | wandb.log({f"{self.id} {self.group} group_acc": group_acc}, commit=False) 101 | if self.group_loss == 'bce': 102 | group_loss = F.cross_entropy(pred_group, group_label) 103 | # ic(self.id, group_loss, pred_group, group_label) 104 | elif self.group_loss == 'sq_bce': 105 | group_loss = F.cross_entropy(pred_group, group_label) 106 | # ic(self.id, group_loss, pred_group, group_label) 107 | # FIXME ad-hoc, the 1/2 is not used previously. 108 | group_loss = group_loss ** (self.group_loss_q + 1.) / (1 + self.group_loss_q) 109 | elif self.group_loss in ('cdan', 'sq_cdan'): 110 | from .cdan_loss import CDAN 111 | group_loss = F.cross_entropy(pred_group, group_label) 112 | group_loss = CDAN(group_loss, softmax_out, group_label.float(), 113 | compute_ent_weights=True, alpha=rev_lambda) 114 | if self.group_loss == 'sq_cdan': 115 | group_loss = group_loss ** 2 / 2. 
116 | elif self.group_loss == 'none': 117 | pass 118 | else: 119 | raise ValueError(f"Invalid group_loss: {self.group_loss} for " 120 | f"{model.n_task} tasks.") 121 | if self.group_loss != 'none' and self.adv_lambda > 0: 122 | # # FIXME not used 123 | if self.group_loss_dro_reg > 0.: 124 | losses["group_loss"] = (torch.abs(group_loss - self.group_loss_dro_reg), self.adv_lambda) 125 | else: 126 | # loss = loss + self.adv_lambda * group_loss 127 | losses["group_loss"] = (group_loss, self.adv_lambda) 128 | 129 | if self.label_mode == "supervised": 130 | output = model.decode(Z) 131 | if isinstance(self.loss, nn.MSELoss): 132 | output = output.view_as(y) 133 | if isinstance(self.loss, (nn.BCELoss, nn.BCEWithLogitsLoss)): 134 | y = y.float() 135 | # ic(self.loss(output, y)) 136 | # ic(output, y) 137 | 138 | if self.loss_reshape.lower() == 'dro': # distributionally robust opt 139 | _loss_reduction = self.loss.reduction 140 | self.loss.reduction = 'none' 141 | task_loss = self.loss(output, y) 142 | self.loss.reduction = _loss_reduction 143 | task_loss = torch.mean(torch.maximum(task_loss - self.loss_reshape_q, 144 | torch.zeros_like(task_loss)) ** 2) 145 | elif self.loss_reshape.lower() == 'fra': # fair resource allocation 146 | assert self.loss_reshape_q >= 0 147 | task_loss = self.loss(output, y) 148 | task_loss = task_loss ** (self.loss_reshape_q + 1) / (self.loss_reshape_q + 1) 149 | else: 150 | task_loss = self.loss(output, y) 151 | 152 | if model.n_class <= 2: 153 | losses["task_loss"] = (task_loss, self.negative_coef if 0 in y else 1.) 154 | else: 155 | losses["task_loss"] = (task_loss, 1.) 156 | elif self.label_mode == "unsupervised": 157 | pass 158 | elif self.label_mode == "self_supervised": # using Info-Max loss 159 | output = model.decode(Z) 160 | out_softmax = F.softmax(output, dim=1) 161 | # assert isinstance(self.loss, nn.CrossEntropyLoss) 162 | # losses["im_loss"] = (torch.mean(Entropy(out_softmax)), 1.) 163 | msoftmax = out_softmax.mean(dim=0) 164 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + 1e-5)) 165 | losses["im_loss"] = (torch.mean(Entropy(out_softmax)) - gentropy_loss, 0.) # 1.) FIXME ad-hoc set as 1. 166 | 167 | if self.relabel_coef > 0: 168 | losses["relabel_loss"] = (F.cross_entropy(output, y), self.relabel_coef) 169 | # print(f"#### y: {y}") 170 | else: 171 | raise ValueError(f"label_mode: {self.label_mode}") 172 | else: 173 | raise NotImplementedError(f"Model type is {type(model)}") 174 | return losses 175 | 176 | def train(self, mode="train", rev_lambda=1.): 177 | LOSS = [] 178 | if not self.no_local_model: 179 | self.load_model_parameters(self.local_model_params) 180 | 181 | if self.label_mode == "self_supervised" and self.relabel_coef > 0. 
and self.current_steps % self.relabel_interval == 0: 182 | # assert hasattr(self, 'nonshuffle_testloader') 183 | assert hasattr(self.model, 'n_class') 184 | self.model.eval() 185 | labels = cluster_estimate_label(self.static_trainloader, self.model, 186 | class_num=self.model.n_class, 187 | threshold=self.cluster_threshold) 188 | print(f"### relabel train set for user {self.id}") 189 | update_dataset_targets(self.train_data, labels) 190 | self.iter_trainloader = iter(self.trainloader) 191 | 192 | 193 | self.model.train() 194 | flag_large_group_loss = False 195 | for epoch in range(1, self.local_epochs + 1): 196 | self.current_steps += 1 197 | self.model.train() 198 | X, y = self.get_next_train_batch() 199 | if len(y) <= 1: 200 | # raise ValueError(f"len y <=1: {len(y)}") 201 | # 1 sample will result in error for BN layer. 202 | print(f"{self.id} Only one sample is in the batch.") 203 | continue 204 | self.optimizer.zero_grad() 205 | 206 | if flag_large_group_loss and hasattr(self.model, 'reset_task_decoder'): 207 | print(f"!! Reset task decoder.") 208 | self.model.reset_task_decoder() 209 | 210 | losses = self.compute_loss(X, y, rev_lambda=0. if flag_large_group_loss else rev_lambda) 211 | pre_flag_large_group_loss = flag_large_group_loss 212 | flag_large_group_loss = ("group_loss" in losses) and (losses["group_loss"][0] > 10) 213 | LOSS.append(losses) 214 | loss = 0 215 | # print(f"## {self.id} {self.group}:", end=" ") 216 | for k, (value, coef) in losses.items(): 217 | # FIXME When local_epochs > 1, this will result in multiple records in one global wandb step. 218 | wandb.log({f"{self.id} {self.group} " + k: value}, commit=False) 219 | print(f"### {self.id} {self.group} " + k, value.item()) 220 | if mode == "pretrain": 221 | if k == "group_loss": 222 | print(f"### PRETRAIN: Ignore group_loss") 223 | continue 224 | loss = loss + value * coef 225 | # print(f" {k}: {value} * {coef}", end="; ") 226 | # print() 227 | if not isinstance(loss, torch.Tensor): 228 | print(f"### No loss. Skip backward") 229 | continue 230 | 231 | self.optimizer.zero_grad() 232 | loss.backward() 233 | 234 | if self.clamp_grad is not None: 235 | assert self.clamp_grad > 0 236 | nn.utils.clip_grad_value_(self.model.task_decoder.parameters(), 237 | clip_value=self.clamp_grad) 238 | 239 | # Log the grad/weight norms of submodules. 
240 | if hasattr(self.model, "get_current_module_norm"): 241 | for mode in ("grad", "weight"): 242 | wandb.log(dict((f"{self.id} {self.group} {k}", v) 243 | for k, v in 244 | self.model.get_current_module_norm(mode=mode).items()), 245 | commit=False) 246 | 247 | self.optimizer.step() 248 | if hasattr(self, "sch"): 249 | self.sch.step() 250 | wandb.log({f"{self.id} {self.group} lr": self.sch.get_last_lr()[0]}, commit=False) 251 | 252 | try: 253 | self.optimizer.zero_grad(set_to_none=True) 254 | except TypeError: 255 | # try another call 256 | self.optimizer.zero_grad() 257 | 258 | if not self.no_local_model: 259 | self.clone_model_paramenter(self.model.parameters(), self.local_model_params) 260 | return LOSS 261 | 262 | -------------------------------------------------------------------------------- /fade/user/shot_digit_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy.spatial.distance import cdist 5 | 6 | 7 | def Entropy(input_): 8 | bs = input_.size(0) 9 | entropy = -input_ * torch.log(input_ + 1e-5) 10 | entropy = torch.sum(entropy, dim=1) 11 | return entropy 12 | 13 | class CrossEntropyLabelSmooth(nn.Module): 14 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, size_average=True): 15 | super(CrossEntropyLabelSmooth, self).__init__() 16 | self.num_classes = num_classes 17 | self.epsilon = epsilon 18 | self.use_gpu = use_gpu 19 | self.size_average = size_average 20 | self.logsoftmax = nn.LogSoftmax(dim=1) 21 | 22 | def forward(self, inputs, targets): 23 | log_probs = self.logsoftmax(inputs) 24 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 25 | if self.use_gpu: targets = targets.cuda() 26 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 27 | if self.size_average: 28 | loss = (- targets * log_probs).mean(0).sum() 29 | else: 30 | loss = (- targets * log_probs).sum(1) 31 | return loss 32 | 33 | 34 | def cluster_estimate_label(data_loader, model, class_num, epsilon=1e-5, distance='cosine', threshold=10): 35 | start_test = True 36 | with torch.no_grad(): 37 | iter_test = iter(data_loader) 38 | # FIXME this has to disable shuffle. 
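        # The pseudo-labels returned at the end follow the iteration order of
        # data_loader: features and outputs are collected batch by batch, class
        # centroids are built from the softmax affinities (classes predicted at most
        # `threshold` times are dropped), and each sample is relabelled by
        # nearest-centroid cosine distance. Because GroupAdvUser.train writes the
        # result back via update_dataset_targets, the loader must iterate in a fixed,
        # non-shuffled order so labels line up with the dataset.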
39 | for _ in range(len(data_loader)): 40 | data = iter_test.next() 41 | inputs = data[0] 42 | labels = data[1] 43 | inputs = inputs.cuda() 44 | feas = model.encode(inputs) 45 | outputs = model.decode(feas) 46 | if start_test: 47 | all_fea = feas.float().cpu() 48 | all_output = outputs.float().cpu() 49 | all_label = labels.float() 50 | start_test = False 51 | else: 52 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 53 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 54 | all_label = torch.cat((all_label, labels.float()), 0) 55 | 56 | all_output = nn.Softmax(dim=1)(all_output) 57 | ent = torch.sum(-all_output * torch.log(all_output + epsilon), dim=1) 58 | unknown_weight = 1 - ent / np.log(class_num) 59 | _, predict = torch.max(all_output, 1) 60 | 61 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 62 | if distance == 'cosine': 63 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 64 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 65 | 66 | all_fea = all_fea.float().cpu().numpy() 67 | K = all_output.size(1) 68 | aff = all_output.float().cpu().numpy() 69 | initc = aff.transpose().dot(all_fea) 70 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 71 | cls_count = np.eye(K)[predict].sum(axis=0) 72 | labelset = np.where(cls_count>threshold) 73 | labelset = labelset[0] 74 | # print(labelset) 75 | 76 | dd = cdist(all_fea, initc[labelset], distance) 77 | pred_label = dd.argmin(axis=1) 78 | pred_label = labelset[pred_label] 79 | 80 | for round in range(1): 81 | aff = np.eye(K)[pred_label] 82 | initc = aff.transpose().dot(all_fea) 83 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 84 | dd = cdist(all_fea, initc[labelset], distance) 85 | pred_label = dd.argmin(axis=1) 86 | pred_label = labelset[pred_label] 87 | 88 | return pred_label.astype('int') 89 | -------------------------------------------------------------------------------- /fade/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from contextlib import contextmanager 4 | import time 5 | import logging 6 | from datetime import timedelta 7 | import pandas as pd 8 | from hashlib import sha1 9 | from omegaconf import OmegaConf, DictConfig 10 | 11 | timer_logger = logging.getLogger("TIME") 12 | timer_logger.setLevel("DEBUG") 13 | 14 | 15 | def set_coloredlogs_env(): 16 | os.environ['COLOREDLOGS_LOG_FORMAT'] = "%(name)s[%(process)d] %(levelname)s %(message)s" 17 | os.environ['COLOREDLOGS_LEVEL_STYLES'] = \ 18 | 'spam=22;debug=28;verbose=34;notice=220;warning=184;' \ 19 | 'info=101;success=118,bold;error=161,bold;critical=background=red' 20 | 21 | 22 | def flatten_dict(d: dict, sep=".", handle_list=True): 23 | if handle_list: 24 | return _flatten(d, sep) 25 | else: 26 | df = pd.json_normalize(d, sep=sep) 27 | return df.to_dict(orient="records")[0] 28 | 29 | 30 | def _flatten(input_dict, separator='_', prefix=''): 31 | """Flatten a dict including nested list. 
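    For example, _flatten({"a": {"b": 1}}) returns {"a_b": 1} with the default
    separator "_"; the flatten_dict wrapper above passes "." as the separator.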
32 | Ref: https://stackoverflow.com/a/55834113/3503604 33 | """ 34 | output_dict = {} 35 | for key, value in input_dict.items(): 36 | if isinstance(value, dict) and value: 37 | deeper = _flatten(value, separator, prefix+key+separator) 38 | output_dict.update({key2: val2 for key2, val2 in deeper.items()}) 39 | elif isinstance(value, list) and value: 40 | for index, sublist in enumerate(value): 41 | if isinstance(sublist, dict) and sublist: 42 | deeper = _flatten(sublist, separator, prefix+key+separator+str(index)+separator) 43 | output_dict.update({key2: val2 for key2, val2 in deeper.items()}) 44 | else: 45 | output_dict[prefix+key+separator+str(index)] = value 46 | else: 47 | output_dict[prefix+key] = value 48 | return output_dict 49 | 50 | 51 | def hash_config(cfg, select_keys=[], exclude_keys=[]): 52 | if len(exclude_keys) > 0: 53 | for k in exclude_keys: 54 | if k in select_keys: 55 | raise ValueError(f"Try to exclude key {k} which is the selected keys:" 56 | f"{select_keys}") 57 | cfg = OmegaConf.masked_copy(cfg, [k for k in cfg if k not in exclude_keys]) 58 | if len(select_keys) > 0: 59 | cfg = OmegaConf.masked_copy(cfg, select_keys) 60 | cfg_str = OmegaConf.to_yaml(cfg, resolve=True, sort_keys=True) 61 | hash_hex = sha1(cfg_str.encode('utf-8')).hexdigest() 62 | return hash_hex 63 | 64 | 65 | @contextmanager 66 | def _log_time_usage(prefix="", debug_only=True): 67 | '''log the time usage in a code block 68 | prefix: the prefix text to show 69 | 70 | Refer: https://stackoverflow.com/a/37429875/3503604 71 | ''' 72 | start = time.time() 73 | try: 74 | info = f"=== {prefix} time block ===" 75 | if debug_only: 76 | timer_logger.debug(info) 77 | else: 78 | print(info) 79 | yield 80 | finally: 81 | end = time.time() 82 | elapsed = str(timedelta(seconds=end - start)) 83 | info = f"=== {prefix} elapsed: {elapsed} ===" 84 | if debug_only: 85 | timer_logger.debug(info) 86 | else: 87 | print(info) 88 | 89 | 90 | def average_smooth(data, window_len=20, window='hanning'): 91 | results = [] 92 | if window_len < 3: 93 | return data 94 | for i in range(len(data)): 95 | x = data[i] 96 | s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]] 97 | # print(len(s)) 98 | if window == 'flat': # moving average 99 | w = np.ones(window_len, 'd') 100 | else: 101 | w = eval('np.' + window + '(window_len)') 102 | 103 | y = np.convolve(w / w.sum(), s, mode='valid') 104 | results.append(y[window_len - 1:]) 105 | return np.array(results) 106 | 107 | 108 | def get_file_date(fname): 109 | import pathlib 110 | import datetime 111 | file = pathlib.Path(fname) 112 | mtime = datetime.datetime.fromtimestamp(file.stat().st_mtime) 113 | return mtime 114 | -------------------------------------------------------------------------------- /sweeps/Office31_UDA/A2X_cdan_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 
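# Typical usage (assuming the standard wandb CLI; the repo's sweep_all.sh may wrap this):
#   wandb sweep sweeps/Office31_UDA/A2X_cdan_nuser_sweep.yaml   # registers the sweep, prints a sweep ID
#   wandb agent <entity>/<project>/<sweep_id>                   # launches fade.mainx runs with the grid below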
6 | name: A2X_cdan_nuser 7 | project: fade-demo-Office31_X2X_UDA 8 | command: 9 | - ${interpreter} 10 | - -m 11 | - ${program} 12 | - logger.wandb.project=fade-demo-Office31_X2X_UDA 13 | - dataset=comb/Office31_X2X_1s_3t 14 | - server=FedAdv 15 | - model=Office31_CnnSplitAdv 16 | - user=group_adv_office_uda 17 | - num_glob_iters=601 18 | # 19 | - +eval_freq=30 20 | - logger.wandb.group='cdan-${dataset.meta_datasets.0.name}-c${dataset.meta_datasets.0.n_class}' 21 | - logger.wandb.name='r${i_rep}' 22 | - server.beta=.5 23 | - load_model.do=true 24 | - load_model.load=[server] 25 | - server.user_selection=random_uniform 26 | # 27 | - dataset.meta_datasets.1.name=Office31A 28 | - +server.group_label_mode.Office31A=supervised 29 | - +server.group_label_mode.Office31W=unsupervised 30 | - +server.group_label_mode.Office31D=unsupervised 31 | - load_model.hash_name=Office31A/FedAvg/OfficeCnnSplit/47095220f910831acac239dd12b241120a1a093c/g_0 32 | # 33 | - model.CDAN_task=True 34 | - model.freeze_backbone=False 35 | - model.freeze_decoder=False 36 | - model.rev_lambda_scale=1. 37 | - model.disable_bn_stat=True 38 | - model.bottleneck_type=dropout 39 | - user.group_loss=cdan 40 | - user.relabel_coef=0. 41 | - ${args_no_hyphens} 42 | method: grid 43 | metric: 44 | goal: maximize 45 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 46 | parameters: 47 | i_rep: 48 | values: 49 | - 0 50 | - 1 51 | - 2 52 | dataset.meta_datasets.0.name: 53 | values: 54 | - Office31W 55 | - Office31D 56 | server.num_users: 57 | values: 58 | - 2 59 | dataset.meta_datasets.0.n_class: 60 | values: 61 | - 15 62 | - 31 63 | # user.optimizer.learning_rate: 64 | # values: 65 | # - 0.01 66 | # - 0.001 67 | # user.adv_lambda: 68 | # values: 69 | ## - 1. 70 | # - 0.1 71 | # - 0.01 72 | program: fade.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/A2X_dann_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | name: A2X_dann_nuser 7 | project: fade-demo-Office31_X2X_UDA 8 | command: 9 | - ${interpreter} 10 | - -m 11 | - ${program} 12 | - logger.wandb.project=fade-demo-Office31_X2X_UDA 13 | - dataset=comb/Office31_X2X_1s_3t 14 | - server=FedAdv 15 | - model=Office31_CnnSplitAdv 16 | - user=group_adv_office_uda 17 | - num_glob_iters=300 18 | # 19 | - +eval_freq=30 20 | - logger.wandb.group='dann-${dataset.meta_datasets.0.name}-c${dataset.meta_datasets.0.n_class}' 21 | - logger.wandb.name='r${i_rep}' 22 | - server.beta=.5 23 | - load_model.do=true 24 | - load_model.load=[server] 25 | - server.user_selection=random_uniform 26 | # 27 | - dataset.meta_datasets.1.name=Office31A 28 | - +server.group_label_mode.Office31A=supervised 29 | - +server.group_label_mode.Office31W=unsupervised 30 | - +server.group_label_mode.Office31D=unsupervised 31 | - load_model.hash_name=Office31A/FedAvg/OfficeCnnSplit/47095220f910831acac239dd12b241120a1a093c/g_0 32 | # 33 | - model.freeze_backbone=False 34 | - model.freeze_decoder=False 35 | - model.rev_lambda_scale=1. 36 | - model.disable_bn_stat=True 37 | - model.bottleneck_type=dropout 38 | - user.group_loss=bce 39 | - user.relabel_coef=0. 
40 | - ${args_no_hyphens} 41 | method: grid 42 | metric: 43 | goal: maximize 44 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 45 | parameters: 46 | i_rep: 47 | values: 48 | - 0 49 | - 1 50 | - 2 51 | dataset.meta_datasets.0.name: 52 | values: 53 | - Office31W 54 | - Office31D 55 | dataset.meta_datasets.0.n_class: 56 | values: 57 | - 15 58 | - 31 59 | # user.group_loss: 60 | # values: 61 | # - bce 62 | # - sq_bce 63 | server.num_users: 64 | values: 65 | # - 1 66 | - 2 67 | # - 3 68 | # - 4 69 | model.freeze_backbone: 70 | values: 71 | # - True 72 | - False 73 | # user.optimizer.learning_rate: 74 | # values: 75 | # - 0.01 76 | # - 0.001 77 | # user.adv_lambda: 78 | # values: 79 | ## - 1. 80 | # - 0.1 81 | # - 0.01 82 | program: fade.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/A2X_shot_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | name: A2X_shot_nuser 7 | project: fade-demo-Office31_X2X_UDA 8 | command: 9 | - ${interpreter} 10 | - -m 11 | - ${program} 12 | - logger.wandb.project=fade-demo-Office31_X2X_UDA 13 | - dataset=comb/Office31_X2X_1s_3t 14 | - server=FedAdv 15 | - model=Office31_CnnSplitAdv 16 | - user=group_adv_office_uda 17 | - num_glob_iters=300 18 | # 19 | - +eval_freq=30 20 | - logger.wandb.group='shot-${dataset.meta_datasets.0.name}-c${dataset.meta_datasets.0.n_class}' 21 | - logger.wandb.name='r${i_rep}' 22 | - server.beta=.5 23 | - load_model.do=true 24 | - load_model.load=[server] 25 | - server.user_selection=random_uniform 26 | # 27 | - dataset.meta_datasets.1.name=Office31A 28 | - +server.group_label_mode.Office31A=supervised 29 | - +server.group_label_mode.Office31W=self_supervised 30 | - +server.group_label_mode.Office31D=self_supervised 31 | - load_model.hash_name=Office31A/FedAvg/OfficeCnnSplit/4bae7dd1a1c5bf247a4ca4ce1bf1f2394eb1f34b/g_0 32 | # 33 | - user.adv_lambda=0. 34 | - user.group_loss=none 35 | - model.freeze_backbone=False 36 | - model.freeze_decoder=True 37 | - model.rev_lambda_scale=0. 38 | - model.disable_bn_stat=True 39 | - model.bottleneck_type=bn 40 | - user.relabel_coef=0.1 41 | # 42 | - ${args_no_hyphens} 43 | method: grid 44 | metric: 45 | goal: maximize 46 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 47 | parameters: 48 | i_rep: 49 | values: 50 | - 0 51 | - 1 52 | - 2 53 | dataset.meta_datasets.0.name: 54 | values: 55 | - Office31W 56 | - Office31D 57 | server.num_users: 58 | values: 59 | - 2 60 | dataset.meta_datasets.0.n_class: 61 | values: 62 | - 15 63 | - 31 64 | program: fade.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/A_fedavg.sh: -------------------------------------------------------------------------------- 1 | # Goal: 2 | # Train CNN on OfficeHome65A dataset by Single-Task-Learning (STL). 3 | # All settings follow SHOT using central training. 4 | 5 | # TODO use MnistCnnSplit (w/o adv/task predictor) model. but be careful when load. 
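# Rough workflow, per the comments further down (note the goal above mentions
# OfficeHome65A while the kwargs below actually target Office31A):
#   1. Uncomment the "Train" block to pretrain the source-only model with FedAvg.
#   2. Copy the resulting server.pt hash path into load_model.hash_name.
#   3. Run the evaluation loop at the bottom over the Office31D / Office31W targets.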
6 | kwargs=" 7 | dataset=Office31 8 | dataset.name=Office31A 9 | logger.wandb.project=dmtl-Office31A_STL 10 | name=stl 11 | model=OfficeCnnSplitAdv 12 | user=generic 13 | user.batch_size=64 14 | user.local_epochs=-1 15 | num_glob_iters=50 16 | n_rep=3 17 | hydra.launcher.n_jobs=1 18 | tele.enable=False 19 | server.num_users=-1 20 | +user.total_local_epochs=38 21 | +eval_freq=10 22 | " 23 | data_kwargs="" 24 | data_kwargs_multi=" 25 | dataset.n_user=1 26 | " 27 | #data_kwargs_multi_select=" 28 | #dataset.n_user=5 29 | #dataset.n_class=15 30 | #+dataset.class_stride=10 31 | #" 32 | data_kwargs_multi_select=" 33 | dataset.n_user=1 34 | " 35 | kwargs_multi=" 36 | user.optimizer.learning_rate=0.01,0.001 37 | user.optimizer.name=rmsprop 38 | i_rep=0,1,2,3,4 39 | server=FedAvg 40 | " 41 | # adam works better than rmsprop 42 | kwargs_multi_select=" 43 | user.optimizer.learning_rate=0.01 44 | user.optimizer.name=sgd_sch 45 | user.loss=sxent 46 | i_rep=0 47 | server=FedAvg 48 | server.beta=1. 49 | model.mid_dim=256 50 | model.backbone=resnet50 51 | model.freeze_backbone=False 52 | +model.bottleneck_type=dropout 53 | model.disable_bn_stat=True 54 | " 55 | # +model.bottleneck_type=drop 56 | 57 | # NOTE: load_model.hash_name has to be updated after training. 58 | load_kwargs=" 59 | load_model.do=true 60 | load_model.load=[server] 61 | load_model.hash_name=Office31A/FedAvg/OfficeCnnSplit/3e6393ec7c8e457107995908726fb9e9c9b5ef56/g_0 62 | " 63 | # dropout bottleneck: 47095220f910831acac239dd12b241120a1a093c 64 | # dropout: 3e6393ec7c8e457107995908726fb9e9c9b5ef56 (NEW: no idea what is the difference) 65 | # bn battleneck: 4bae7dd1a1c5bf247a4ca4ce1bf1f2394eb1f34b 66 | #echo ===== 67 | #echo Generating dataset 68 | #echo ===== 69 | #echo 70 | ## python -m dmtl.data.sinusoid_mtl -cn MnistSTL --cfg job 71 | #python -m dmtl.data.extend -cn OfficeHome65 n_user=1 name=OfficeHome65A 72 | 73 | #echo ===== 74 | #echo Train 75 | #echo ===== 76 | #echo 77 | # print the config 78 | #python -m dmtl.mainx $kwargs logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $kwargs_multi_select $data_kwargs_multi_select --cfg job 79 | ## TODO add `$load_kwargs` to fine-tune the model. 80 | #python -m dmtl.mainx $kwargs $data_kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $kwargs_multi_select 81 | 82 | ## # Repeat experiments 83 | ##python -m dmtl.mainx $kwargs $data_kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $kwargs_multi_select i_rep=1,2 -m 84 | 85 | #echo ===== 86 | #echo Check generated files 87 | #echo ===== 88 | #echo 89 | #python -m dmtl.mainx $kwargs $kwargs_multi_select $data_kwargs $data_kwargs_multi_select logger.wandb.group='${server.name}-tr${dataset.tasks.1.total_tr_size}-u${dataset.tasks.1.n_user}-lr${user.optimizer.learning_rate}' action=check_files 90 | 91 | #echo 92 | #echo ========================================== 93 | #echo "Now copy copy the path to the root of server.pt as the hash name." 94 | #echo ========================================== 95 | # 96 | ## TODO after updating the hash name, uncomment below and run evaluation to check the performance. 
97 | 98 | #echo ===== 99 | #echo Eval on Office31A 100 | #echo ===== 101 | # 102 | #python -m dmtl.mainx $kwargs $data_kwargs_multi_select $load_kwargs $data_kwargs $kwargs_multi_select action=eval logger.loggers='[]' 103 | 104 | 105 | all_targets=( 106 | Office31D 107 | Office31W 108 | ) 109 | 110 | for target in ${all_targets[@]} 111 | do 112 | echo ===== 113 | echo Eval on ${target} 114 | echo ===== 115 | 116 | # echo With loaded model 117 | # python -m dmtl.mainx $kwargs $kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $data_kwargs_multi_select $load_kwargs action=eval dataset.name=$target logger.loggers='[]' 118 | 119 | 120 | echo With trained models 121 | python -m dmtl.mainx $kwargs $kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs action=eval dataset.name=$target logger.loggers='[]' $data_kwargs_multi_select i_rep=-1 load_model.do=true load_model.load=[server] load_model.do=true load_model.load=[server] load_model.hash_name=Office31A/FedAvg/OfficeCnnSplit/3e6393ec7c8e457107995908726fb9e9c9b5ef56 122 | # $data_kwargs_multi_select 123 | done -------------------------------------------------------------------------------- /sweeps/Office31_UDA/D2X_cdan_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | # Update: 7 | # disable bn stat 8 | # 9 | # Preparation: 10 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. the varying p_src will 11 | # not affect the quality of source domain. 12 | # 2. Run `m2u_fuda_puser.sh to create the project. 13 | # 14 | # Runs: 15 | # - [1/29] https://wandb.ai/jyhong/dmtl-Office31_X2X_FUDA_fuda/sweeps/51ojeo40/overview 16 | name: D2X_cdan_nuser 17 | project: fade-demo-Office31_X2X_UDA 18 | command: 19 | - ${interpreter} 20 | - -m 21 | - ${program} 22 | - logger.wandb.project=dmtl-Office31_X2X_UDA 23 | - dataset=comb/Office31_X2X_1s_3t 24 | - server=FedAdv 25 | - model=Office31_CnnSplitAdv 26 | - user=group_adv_office_uda 27 | - num_glob_iters=601 28 | # 29 | - +eval_freq=30 30 | - logger.wandb.group='cdan-c${dataset.meta_datasets.0.n_class}-a${user.adv_lambda}-lr{user.optimizer.learning_rate}' 31 | - logger.wandb.name='r${i_rep}' 32 | - server.beta=.5 33 | - load_model.do=true 34 | - load_model.load=[server] 35 | - server.user_selection=random_uniform 36 | # 37 | - dataset.meta_datasets.1.name=Office31D 38 | - +server.group_label_mode.Office31D=supervised 39 | - +server.group_label_mode.Office31A=unsupervised 40 | - +server.group_label_mode.Office31W=unsupervised 41 | - load_model.hash_name=Office31D/FedAvg/OfficeCnnSplit/0190b6bc823de135d22e99960783dbcd1ec8cc1d/g_0 42 | # 43 | - model.CDAN_task=True 44 | - user.adv_lambda=1. 45 | - model.freeze_backbone=False 46 | - model.freeze_decoder=False 47 | - model.rev_lambda_scale=1. 48 | - model.disable_bn_stat=True 49 | - model.bottleneck_type=dropout 50 | - user.group_loss=cdan 51 | - user.relabel_coef=0. 
52 | - ${args_no_hyphens} 53 | method: grid 54 | metric: 55 | goal: maximize 56 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 57 | parameters: 58 | i_rep: 59 | values: 60 | - 0 61 | - 1 62 | - 2 63 | dataset.meta_datasets.0.name: 64 | values: 65 | - Office31A 66 | - Office31W 67 | server.num_users: 68 | values: 69 | # - 1 70 | - 2 71 | # - 3 72 | # - 4 73 | dataset.meta_datasets.0.n_class: 74 | values: 75 | - 15 76 | - 31 77 | user.optimizer.learning_rate: 78 | values: 79 | - 0.01 80 | - 0.001 81 | user.adv_lambda: 82 | values: 83 | # - 1. 84 | - 0.1 85 | - 0.01 86 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/D2X_dann_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | # 7 | # Preparation: 8 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. the varying p_src will 9 | # not affect the quality of source domain. 10 | # 2. Run `m2u_fuda_puser.sh to create the project. 11 | # 12 | # Runs: 13 | # - [1/28] https://wandb.ai/jyhong/dmtl-Office31_A2W_FUDA_fuda_vs_shot/sweeps/g7ybmzo0/overview 14 | # - https://wandb.ai/jyhong/dmtl-Office31_A2W_FUDA_fuda_vs_shot/sweeps/lab3hesm/overview 15 | name: D2X_dann_nuser 16 | project: fade-demo-Office31_X2X_UDA 17 | command: 18 | - ${interpreter} 19 | - -m 20 | - ${program} 21 | - logger.wandb.project=dmtl-Office31_X2X_UDA 22 | - dataset=comb/Office31_X2X_1s_3t 23 | - server=FedAdv 24 | - model=Office31_CnnSplitAdv 25 | - user=group_adv_office_uda 26 | - num_glob_iters=601 # 300 27 | # 28 | - +eval_freq=30 29 | - logger.wandb.group='cdan-c${dataset.meta_datasets.0.n_class}-a${user.adv_lambda}-lr{user.optimizer.learning_rate}' 30 | - logger.wandb.name='r${i_rep}' 31 | - server.beta=.5 32 | - load_model.do=true 33 | - load_model.load=[server] 34 | - server.user_selection=random_uniform 35 | # 36 | - dataset.meta_datasets.1.name=Office31D 37 | - +server.group_label_mode.Office31D=supervised 38 | - +server.group_label_mode.Office31A=unsupervised 39 | - +server.group_label_mode.Office31W=unsupervised 40 | - load_model.hash_name=Office31D/FedAvg/OfficeCnnSplit/0190b6bc823de135d22e99960783dbcd1ec8cc1d/g_0 41 | # 42 | #- user.adv_lambda=1. 43 | - model.freeze_backbone=False 44 | - model.freeze_decoder=False 45 | - model.rev_lambda_scale=1. 46 | - model.disable_bn_stat=True 47 | - model.bottleneck_type=dropout 48 | - user.group_loss=bce 49 | - user.relabel_coef=0. 50 | - ${args_no_hyphens} 51 | method: grid 52 | metric: 53 | goal: maximize 54 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 55 | parameters: 56 | i_rep: 57 | values: 58 | - 0 59 | - 1 60 | - 2 61 | dataset.meta_datasets.0.name: 62 | values: 63 | # - Office31A 64 | - Office31W 65 | user.group_loss: 66 | values: 67 | - bce 68 | # - sq_bce 69 | server.num_users: 70 | values: 71 | # - 1 72 | - 2 73 | # - 3 74 | # - 4 75 | dataset.meta_datasets.0.n_class: 76 | values: 77 | - 15 78 | - 31 79 | user.optimizer.learning_rate: 80 | values: 81 | - 0.01 82 | - 0.001 83 | user.adv_lambda: 84 | values: 85 | # - 1. 
86 | - 0.1 87 | - 0.01 88 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/D2X_shot_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. D2X: Office31 DSLR (D) -> {Amazon (A), Webcam (W)}. 2 | # Goal: federated UDA with SHOT-style self-supervision from one labeled source domain (Office31D) 3 | # to the unlabeled target domains (Office31A/W). 4 | # Variable: target domain (dataset.meta_datasets.0.name) and the number of selected users (server.num_users). 5 | # Metric: g_test_acc on the target domain. 6 | # 7 | # Preparation: 8 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. the varying p_src will 9 | # not affect the quality of source domain. 10 | # 2. Run `m2u_fuda_puser.sh` to create the project. 11 | # 12 | # Runs: 13 | # - [1/28] https://wandb.ai/jyhong/dmtl-Office31_A2W_FUDA_fuda_vs_shot/sweeps/x37dpbwb/overview 14 | # - https://wandb.ai/jyhong/dmtl-Office31_A2W_FUDA_fuda_vs_shot/sweeps/tjvv5816/overview 15 | name: D2X_shot_nuser 16 | project: fade-demo-Office31_X2X_UDA 17 | command: 18 | - ${interpreter} 19 | - -m 20 | - ${program} 21 | - logger.wandb.project=dmtl-Office31_X2X_UDA 22 | - dataset=comb/Office31_X2X_1s_3t 23 | - server=FedAdv 24 | - model=Office31_CnnSplitAdv 25 | - user=group_adv_office_uda 26 | - num_glob_iters=300 27 | # 28 | - +eval_freq=30 29 | - logger.wandb.group='adv-shot-nuser-${dataset.meta_datasets.0.name}' 30 | - logger.wandb.name='r${i_rep}-shot-nuser-${dataset.meta_datasets.0.name}' 31 | - server.beta=.5 32 | - load_model.do=true 33 | - load_model.load=[server] 34 | - server.user_selection=random_uniform 35 | # 36 | - dataset.meta_datasets.1.name=Office31D 37 | - +server.group_label_mode.Office31D=supervised 38 | - +server.group_label_mode.Office31A=self_supervised 39 | - +server.group_label_mode.Office31W=self_supervised 40 | - load_model.hash_name=Office31D/FedAvg/OfficeCnnSplit/c131bb220bfae6defe855b253afbcb6788da0dc1/g_0 41 | # 42 | - user.adv_lambda=0. 43 | - user.group_loss=none 44 | - model.freeze_backbone=False 45 | - model.freeze_decoder=True 46 | - model.rev_lambda_scale=0. 47 | - model.disable_bn_stat=True 48 | - model.bottleneck_type=bn 49 | - user.relabel_coef=0.1 50 | # 51 | - ${args_no_hyphens} 52 | method: grid 53 | metric: 54 | goal: maximize 55 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 56 | parameters: 57 | i_rep: 58 | values: 59 | - 0 60 | - 1 61 | - 2 62 | dataset.meta_datasets.0.name: 63 | values: 64 | - Office31A 65 | - Office31W 66 | server.num_users: 67 | values: 68 | - 1 69 | - 2 70 | - 3 71 | - 4 72 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/D_fedavg.sh: -------------------------------------------------------------------------------- 1 | # Goal: 2 | # Train CNN on the Office31D dataset by Single-Task-Learning (STL). 3 | # All settings follow SHOT's central-training setup. 4 | 5 | # TODO use MnistCnnSplit (w/o adv/task predictor) model, but be careful when loading.
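# The comma-separated values in kwargs_multi (e.g. user.optimizer.learning_rate=0.01,0.001 and
# i_rep=0,1,2,3,4) only expand into a grid of runs when Hydra multirun is enabled by appending `-m`,
# as in the commented "run sweep" command further below; single runs use kwargs_multi_select instead.
# A hedged sketch of the multirun form:
#   python -m dmtl.mainx $kwargs $kwargs_multi $data_kwargs $data_kwargs_multi -m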
6 | kwargs=" 7 | dataset=Office31 8 | dataset.name=Office31D 9 | logger.wandb.project=dmtl-Office31A_STL 10 | name=stl 11 | model=OfficeCnnSplitAdv 12 | user=generic 13 | user.batch_size=64 14 | user.local_epochs=-1 15 | num_glob_iters=50 16 | n_rep=5 17 | hydra.launcher.n_jobs=3 18 | tele.enable=False 19 | server.num_users=-1 20 | +user.total_local_epochs=38 21 | +eval_freq=10 22 | " 23 | data_kwargs="" 24 | data_kwargs_multi=" 25 | dataset.n_user=1 26 | " 27 | #data_kwargs_multi_select=" 28 | #dataset.n_user=5 29 | #dataset.n_class=15 30 | #+dataset.class_stride=10 31 | #" 32 | data_kwargs_multi_select=" 33 | dataset.n_user=1 34 | " 35 | kwargs_multi=" 36 | user.optimizer.learning_rate=0.01,0.001 37 | user.optimizer.name=rmsprop 38 | i_rep=0,1,2,3,4 39 | server=FedAvg 40 | " 41 | # adam works better than rmsprop 42 | kwargs_multi_select=" 43 | user.optimizer.learning_rate=0.01 44 | user.optimizer.name=sgd_sch 45 | user.loss=sxent 46 | i_rep=0 47 | server=FedAvg 48 | server.beta=1. 49 | model.mid_dim=256 50 | model.backbone=resnet50 51 | model.freeze_backbone=False 52 | +model.bottleneck_type=bn 53 | " 54 | # +model.bottleneck_type=drop 55 | 56 | # NOTE: load_model.hash_name has to be updated after training. 57 | load_kwargs=" 58 | load_model.do=true 59 | load_model.load=[server] 60 | load_model.hash_name=Office31D/FedAvg/OfficeCnnSplit/c131bb220bfae6defe855b253afbcb6788da0dc1/g_0 61 | " 62 | # dropout bottleneck: 0190b6bc823de135d22e99960783dbcd1ec8cc1d 63 | # bn bottleneck: c131bb220bfae6defe855b253afbcb6788da0dc1 64 | #echo ===== 65 | #echo Generating dataset 66 | #echo ===== 67 | #echo 68 | ## python -m dmtl.data.sinusoid_mtl -cn MnistSTL --cfg job 69 | #python -m dmtl.data.extend -cn OfficeHome65 n_user=1 name=OfficeHome65A 70 | 71 | #echo ===== 72 | #echo Train 73 | #echo ===== 74 | #echo 75 | ## print the config 76 | #python -m dmtl.mainx $kwargs logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $kwargs_multi_select $data_kwargs_multi_select --cfg job 77 | ##echo 78 | ### run sweep 79 | ### python -m dmtl.mainx $kwargs $kwargs_multi logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $data_kwargs_multi -m 80 | ### Run only with the selected hparam 81 | ### TODO add `$load_kwargs` to fine-tune the model. 82 | #python -m dmtl.mainx $kwargs $data_kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $kwargs_multi_select 83 | # 84 | echo ===== 85 | echo Check generated files 86 | echo ===== 87 | echo 88 | python -m dmtl.mainx $kwargs $kwargs_multi_select $data_kwargs $data_kwargs_multi_select logger.wandb.group='${server.name}-tr${dataset.tasks.1.total_tr_size}-u${dataset.tasks.1.n_user}-lr${user.optimizer.learning_rate}' action=check_files 89 | # 90 | #echo 91 | #echo ========================================== 92 | #echo "Now copy the path to the root of server.pt as the hash name." 93 | #echo ========================================== 94 | 95 | # TODO after updating the hash name, uncomment below and run evaluation to check the performance.
96 | 97 | #echo ===== 98 | #echo Eval on Office31A 99 | #echo ===== 100 | # 101 | #python -m dmtl.mainx $kwargs $data_kwargs_multi_select $load_kwargs $data_kwargs $kwargs_multi_select action=eval logger.loggers='[]' 102 | 103 | 104 | echo ===== 105 | echo Eval on Office31A 106 | echo ===== 107 | # generate data 108 | #echo python -m dmtl.data.extend -cn MnistM $data_kwargs $data_kwargs_multi_select dataset.name=MnistMSTL 109 | 110 | # TODO the train set may use random crop for augmentation. 111 | 112 | python -m dmtl.mainx $kwargs $kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $load_kwargs action=eval dataset.name=Office31A logger.loggers='[]' 113 | # $data_kwargs_multi_select 114 | 115 | 116 | echo ===== 117 | echo Eval on Office31W 118 | echo ===== 119 | # generate data 120 | #echo python -m dmtl.data.extend -cn MnistM $data_kwargs $data_kwargs_multi_select dataset.name=MnistMSTL 121 | 122 | # TODO the train set may use random crop for augmentation. 123 | 124 | python -m dmtl.mainx $kwargs $kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $load_kwargs action=eval dataset.name=Office31W logger.loggers='[]' 125 | # $data_kwargs_multi_select -------------------------------------------------------------------------------- /sweeps/Office31_UDA/W2X_cdan_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | # Update: 7 | # disable bn stat 8 | # 9 | # Preparation: 10 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. the varying p_src will 11 | # not affect the quality of source domain. 12 | # 2. Run `m2u_fuda_puser.sh to create the project. 13 | # 14 | # Runs: 15 | # - [1/29] https://wandb.ai/jyhong/dmtl-Office31_X2X_FUDA_fuda/sweeps/51ojeo40/overview 16 | name: W2X_cdan_nuser 17 | project: fade-demo-Office31_X2X_UDA 18 | command: 19 | - ${interpreter} 20 | - -m 21 | - ${program} 22 | - logger.wandb.project=dmtl-Office31_X2X_UDA 23 | - dataset=comb/Office31_X2X_1s_3t 24 | - server=FedAdv 25 | - model=Office31_CnnSplitAdv 26 | - user=group_adv_office_uda 27 | - num_glob_iters=300 28 | # 29 | - +eval_freq=30 30 | - logger.wandb.group='cdan-c${dataset.meta_datasets.0.n_class}-a${user.adv_lambda}-lr{user.optimizer.learning_rate}' 31 | - logger.wandb.name='r${i_rep}' 32 | - server.beta=.5 33 | - load_model.do=true 34 | - load_model.load=[server] 35 | - server.user_selection=random_uniform 36 | # 37 | - dataset.meta_datasets.1.name=Office31W 38 | - +server.group_label_mode.Office31W=supervised 39 | - +server.group_label_mode.Office31A=unsupervised 40 | - +server.group_label_mode.Office31D=unsupervised 41 | - load_model.hash_name=Office31W/FedAvg/OfficeCnnSplit/0ff680f5368dcbe84087c270afefc79940dbd08b/g_0 42 | # 43 | - model.CDAN_task=True 44 | - user.adv_lambda=1. 45 | - model.freeze_backbone=False 46 | - model.freeze_decoder=False 47 | - model.rev_lambda_scale=1. 48 | - model.disable_bn_stat=True 49 | - model.bottleneck_type=dropout 50 | - user.group_loss=cdan 51 | - user.relabel_coef=0. 
52 | - ${args_no_hyphens} 53 | method: grid 54 | metric: 55 | goal: maximize 56 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 57 | parameters: 58 | i_rep: 59 | values: 60 | - 0 61 | - 1 62 | - 2 63 | dataset.meta_datasets.0.name: 64 | values: 65 | - Office31A 66 | - Office31D 67 | server.num_users: 68 | values: 69 | - 2 70 | dataset.meta_datasets.0.n_class: 71 | values: 72 | - 15 73 | - 31 74 | user.optimizer.learning_rate: 75 | values: 76 | - 0.01 77 | - 0.001 78 | user.adv_lambda: 79 | values: 80 | # - 1. 81 | - 0.1 82 | - 0.01 83 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/W2X_dann_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | # 7 | # Preparation: 8 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. the varying p_src will 9 | # not affect the quality of source domain. 10 | # 2. Run `m2u_fuda_puser.sh to create the project. 11 | # 12 | # Runs: 13 | # - [1/28] https://wandb.ai/jyhong/dmtl-Office31_A2W_FUDA_fuda_vs_shot/sweeps/9bgsg4ur/overview 14 | # - https://wandb.ai/jyhong/dmtl-Office31_A2W_FUDA_fuda_vs_shot/sweeps/t0m402eg/overview 15 | name: W2X_dann_nuser 16 | project: fade-demo-Office31_X2X_UDA 17 | command: 18 | - ${interpreter} 19 | - -m 20 | - ${program} 21 | - logger.wandb.project=dmtl-Office31_X2X_UDA 22 | - dataset=comb/Office31_X2X_1s_3t 23 | - server=FedAdv 24 | - model=Office31_CnnSplitAdv 25 | - user=group_adv_office_uda 26 | - num_glob_iters=300 27 | # task 28 | - +eval_freq=30 29 | - logger.wandb.group='cdan-c${dataset.meta_datasets.0.n_class}-a${user.adv_lambda}-lr{user.optimizer.learning_rate}' 30 | - logger.wandb.name='r${i_rep}' 31 | - server.beta=.5 32 | - load_model.do=true 33 | - load_model.load=[server] 34 | - server.user_selection=random_uniform 35 | # data 36 | - dataset.meta_datasets.1.name=Office31W 37 | - +server.group_label_mode.Office31W=supervised 38 | - +server.group_label_mode.Office31A=unsupervised 39 | - +server.group_label_mode.Office31D=unsupervised 40 | - load_model.hash_name=Office31W/FedAvg/OfficeCnnSplit/0ff680f5368dcbe84087c270afefc79940dbd08b/g_0 41 | # model 42 | - user.adv_lambda=1. 43 | - model.freeze_backbone=False 44 | - model.freeze_decoder=False 45 | - model.rev_lambda_scale=1. 46 | - model.disable_bn_stat=True 47 | - model.bottleneck_type=dropout 48 | - user.group_loss=bce 49 | - user.relabel_coef=0. 50 | - ${args_no_hyphens} 51 | method: grid 52 | metric: 53 | goal: maximize 54 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 55 | parameters: 56 | i_rep: 57 | values: 58 | - 0 59 | - 1 60 | - 2 61 | dataset.meta_datasets.0.name: 62 | values: 63 | - Office31A 64 | # - Office31D 65 | user.group_loss: 66 | values: 67 | - bce 68 | # - sq_bce 69 | server.num_users: 70 | values: 71 | # - 1 72 | - 2 73 | # - 3 74 | # - 4 75 | dataset.meta_datasets.0.n_class: 76 | values: 77 | - 15 78 | - 31 79 | user.optimizer.learning_rate: 80 | values: 81 | - 0.01 82 | - 0.001 83 | user.adv_lambda: 84 | values: 85 | # - 1. 
86 | - 0.1 87 | - 0.01 88 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/W2X_shot_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. W2X: Office31 Webcam (W) -> {Amazon (A), DSLR (D)}. 2 | # Goal: federated UDA with SHOT-style self-supervision from one labeled source domain (Office31W) 3 | # to the unlabeled target domains (Office31A/D). 4 | # Variable: target domain (dataset.meta_datasets.0.name) and the number of selected users (server.num_users). 5 | # Metric: g_test_acc on the target domain. 6 | # 7 | # Preparation: 8 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. the varying p_src will 9 | # not affect the quality of source domain. 10 | # 2. Run `m2u_fuda_puser.sh` to create the project. 11 | # 12 | # Runs: 13 | # - [1/28] https://wandb.ai/jyhong/dmtl-Office31_A2W_FUDA_fuda_vs_shot/sweeps/soz1s67k/overview 14 | # - https://wandb.ai/jyhong/dmtl-Office31_A2W_FUDA_fuda_vs_shot/sweeps/wd9mp9ap?workspace=user-jyhong 15 | name: W2X_shot_nuser 16 | project: fade-demo-Office31_X2X_UDA 17 | command: 18 | - ${interpreter} 19 | - -m 20 | - ${program} 21 | - logger.wandb.project=dmtl-Office31_X2X_UDA 22 | - dataset=comb/Office31_X2X_1s_3t 23 | - server=FedAdv 24 | - model=Office31_CnnSplitAdv 25 | - user=group_adv_office_uda 26 | - num_glob_iters=300 27 | # task 28 | - +eval_freq=30 29 | - logger.wandb.group='adv-shot-nuser-${dataset.meta_datasets.0.name}' 30 | - logger.wandb.name='r${i_rep}-shot-nuser-${dataset.meta_datasets.0.name}' 31 | - server.beta=.5 32 | - load_model.do=true 33 | - load_model.load=[server] 34 | - server.user_selection=random_uniform 35 | # data 36 | - dataset.meta_datasets.1.name=Office31W 37 | - +server.group_label_mode.Office31W=supervised 38 | - +server.group_label_mode.Office31A=self_supervised 39 | - +server.group_label_mode.Office31D=self_supervised 40 | - load_model.hash_name=Office31W/FedAvg/OfficeCnnSplit/1e6351a86a07eafba767541fb79aab07bb6eb01f/g_0 41 | # method 42 | - user.adv_lambda=0. 43 | - user.group_loss=none 44 | - model.freeze_backbone=False 45 | - model.freeze_decoder=True 46 | - model.rev_lambda_scale=0. 47 | - model.disable_bn_stat=True 48 | - model.bottleneck_type=bn 49 | - user.relabel_coef=0.1 50 | # 51 | - ${args_no_hyphens} 52 | method: grid 53 | metric: 54 | goal: maximize 55 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 56 | parameters: 57 | i_rep: 58 | values: 59 | - 0 60 | - 1 61 | - 2 62 | dataset.meta_datasets.0.name: 63 | values: 64 | - Office31A 65 | - Office31D 66 | server.num_users: 67 | values: 68 | - 1 69 | - 2 70 | - 3 71 | - 4 72 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/Office31_UDA/W_fedavg.sh: -------------------------------------------------------------------------------- 1 | # Goal: 2 | # Train CNN on the Office31W dataset by Single-Task-Learning (STL). 3 | # All settings follow SHOT's central-training setup. 4 | 5 | # TODO use MnistCnnSplit (w/o adv/task predictor) model, but be careful when loading.
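# Note on checkpoints: the W2X SHOT sweep (W2X_shot_nuser_sweep.yaml) loads the bn-bottleneck
# checkpoint (1e6351a8...), while the W2X dann/cdan sweeps load the dropout-bottleneck one
# (0ff680f5...); both hashes are listed next to load_kwargs below, so set +model.bottleneck_type
# here to match the checkpoint you intend to produce.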
6 | kwargs=" 7 | dataset=Office31 8 | dataset.name=Office31W 9 | logger.wandb.project=dmtl-Office31A_STL 10 | name=stl 11 | model=OfficeCnnSplitAdv 12 | user=generic 13 | user.batch_size=64 14 | user.local_epochs=-1 15 | num_glob_iters=50 16 | n_rep=5 17 | hydra.launcher.n_jobs=3 18 | tele.enable=False 19 | server.num_users=-1 20 | +user.total_local_epochs=38 21 | +eval_freq=10 22 | " 23 | data_kwargs="" 24 | data_kwargs_multi=" 25 | dataset.n_user=1 26 | " 27 | #data_kwargs_multi_select=" 28 | #dataset.n_user=5 29 | #dataset.n_class=15 30 | #+dataset.class_stride=10 31 | #" 32 | data_kwargs_multi_select=" 33 | dataset.n_user=1 34 | " 35 | kwargs_multi=" 36 | user.optimizer.learning_rate=0.01,0.001 37 | user.optimizer.name=rmsprop 38 | i_rep=0,1,2,3,4 39 | server=FedAvg 40 | " 41 | # adam works better than rmsprop 42 | kwargs_multi_select=" 43 | user.optimizer.learning_rate=0.01 44 | user.optimizer.name=sgd_sch 45 | user.loss=sxent 46 | i_rep=0 47 | server=FedAvg 48 | server.beta=1. 49 | model.mid_dim=256 50 | model.backbone=resnet50 51 | model.freeze_backbone=False 52 | +model.bottleneck_type=bn 53 | " 54 | # +model.bottleneck_type=drop 55 | 56 | # NOTE: load_model.hash_name has to be updated after training. 57 | load_kwargs=" 58 | load_model.do=true 59 | load_model.load=[server] 60 | load_model.hash_name=Office31W/FedAvg/OfficeCnnSplit/1e6351a86a07eafba767541fb79aab07bb6eb01f/g_0 61 | " 62 | # dropout bottleneck: 0ff680f5368dcbe84087c270afefc79940dbd08b 63 | # bn bottleneck: 1e6351a86a07eafba767541fb79aab07bb6eb01f 64 | #echo ===== 65 | #echo Generating dataset 66 | #echo ===== 67 | #echo 68 | ## python -m dmtl.data.sinusoid_mtl -cn MnistSTL --cfg job 69 | #python -m dmtl.data.extend -cn OfficeHome65 n_user=1 name=OfficeHome65A 70 | 71 | #echo ===== 72 | #echo Train 73 | #echo ===== 74 | #echo 75 | ## print the config 76 | #python -m dmtl.mainx $kwargs logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $kwargs_multi_select $data_kwargs_multi_select --cfg job 77 | ##echo 78 | ### run sweep 79 | ### python -m dmtl.mainx $kwargs $kwargs_multi logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $data_kwargs_multi -m 80 | ### Run only with the selected hparam 81 | ### TODO add `$load_kwargs` to fine-tune the model. 82 | #python -m dmtl.mainx $kwargs $data_kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $kwargs_multi_select 83 | # 84 | #echo ===== 85 | #echo Check generated files 86 | #echo ===== 87 | #echo 88 | #python -m dmtl.mainx $kwargs $kwargs_multi_select $data_kwargs $data_kwargs_multi_select logger.wandb.group='${server.name}-tr${dataset.tasks.1.total_tr_size}-u${dataset.tasks.1.n_user}-lr${user.optimizer.learning_rate}' action=check_files 89 | # 90 | #echo 91 | #echo ========================================== 92 | #echo "Now copy the path to the root of server.pt as the hash name." 93 | #echo ========================================== 94 | # 95 | ## TODO after updating the hash name, uncomment below and run evaluation to check the performance.
96 | 97 | #echo ===== 98 | #echo Eval on Office31A 99 | #echo ===== 100 | # 101 | #python -m dmtl.mainx $kwargs $data_kwargs_multi_select $load_kwargs $data_kwargs $kwargs_multi_select action=eval logger.loggers='[]' 102 | 103 | 104 | echo ===== 105 | echo Eval on Office31A 106 | echo ===== 107 | # generate data 108 | #echo python -m dmtl.data.extend -cn MnistM $data_kwargs $data_kwargs_multi_select dataset.name=MnistMSTL 109 | 110 | # TODO the train set may use random crop for augmentation. 111 | 112 | python -m dmtl.mainx $kwargs $kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $load_kwargs action=eval dataset.name=Office31A logger.loggers='[]' 113 | # $data_kwargs_multi_select 114 | 115 | 116 | echo ===== 117 | echo Eval on Office31D 118 | echo ===== 119 | # generate data 120 | #echo python -m dmtl.data.extend -cn MnistM $data_kwargs $data_kwargs_multi_select dataset.name=MnistMSTL 121 | 122 | # TODO the train set may use random crop for augmentation. 123 | 124 | python -m dmtl.mainx $kwargs $kwargs_multi_select logger.wandb.group='${server.name}-u${dataset.n_user}-lr${user.optimizer.learning_rate}' $data_kwargs $load_kwargs action=eval dataset.name=Office31D logger.loggers='[]' 125 | # $data_kwargs_multi_select -------------------------------------------------------------------------------- /sweeps/Office31_UDA/sweep_all.sh: -------------------------------------------------------------------------------- 1 | # NOTE only run this under `src/` 2 | 3 | # New sweeps: https://wandb.ai/jyhong/dmtl-Office31_X2X_UDA/sweeps 4 | 5 | # Old sweeps: https://wandb.ai/jyhong/dmtl-Office31_X2X_FUDA_fuda/sweeps 6 | # Sweep for Office31 1source vs 3 target (niid with target n_class=15): https://wandb.ai/jyhong/dmtl-Office31_X2X_FUDA_fuda/sweeps 7 | # Sweep for Office31 1source vs 3 target (niid with target n_class=31): 8 | # shotBNFX: 9 | 10 | echo current path: $(pwd) 11 | pwd=sweeps/Office31_UDA/*_sweep.yaml 12 | 13 | for f in $pwd 14 | do 15 | echo $f 16 | done 17 | 18 | echo "========================================" 19 | echo "TODO: Copy the print paths to all_files." 20 | echo "========================================" 21 | 22 | # TODO update the paths 23 | all_files=( 24 | # A2X 25 | sweeps/Office31_UDA/A2X_dann_nuser_sweep.yaml 26 | sweeps/Office31_UDA/A2X_cdan_nuser_sweep.yaml 27 | sweeps/Office31_UDA/A2X_shot_nuser_sweep.yaml 28 | ## D2X 29 | #sweeps/Office31_UDA/D2X_dann_nuser_sweep.yaml 30 | #sweeps/Office31_UDA/D2X_cdan_nuser_sweep.yaml 31 | #sweeps/Office31_UDA/D2X_shot_nuser_sweep.yaml 32 | ## W2X 33 | #sweeps/Office31_UDA/W2X_dann_nuser_sweep.yaml 34 | #sweeps/Office31_UDA/W2X_cdan_nuser_sweep.yaml 35 | #sweeps/Office31_UDA/W2X_shot_nuser_sweep.yaml 36 | ) 37 | 38 | #wandb sweep $(pwd) 39 | for f in ${all_files[@]} 40 | do 41 | echo "==== sweep $f ===" 42 | wandb sweep $f 43 | done -------------------------------------------------------------------------------- /sweeps/OfficeHome65_1to3_uda_iid/R2X_cdan_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | # Update: 7 | # disable bn stat 8 | # 9 | # Preparation: 10 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. 
the varying p_src will 11 | # not affect the quality of source domain. 12 | # 2. Run `m2u_fuda_puser.sh to create the project. 13 | # 14 | # Runs: 15 | # - [1/29] 16 | name: R2X_cdan_nuser_iid 17 | project: fade-demo-OfficeHome65_X2X_1to3 18 | command: 19 | - ${interpreter} 20 | - -m 21 | - ${program} 22 | - logger.wandb.project=dmtl-OfficeHome65_X2X_1to3 23 | - dataset=comb/OfficeHome65_X2X_1s_3t 24 | - server=FedAdv 25 | - model=OfficeHome65CnnSplitAdv 26 | - user=group_adv_office_uda 27 | - num_glob_iters=500 28 | # 29 | - +eval_freq=30 30 | - logger.wandb.group='adv-shot-nuser-${dataset.meta_datasets.0.name}' 31 | - logger.wandb.name='r${i_rep}-cdan-shot-nuser-${dataset.meta_datasets.0.name}' 32 | - server.beta=.5 33 | - load_model.do=true 34 | - load_model.load=[server] 35 | - server.user_selection=random_uniform 36 | # 37 | - dataset.meta_datasets.1.name=OfficeHome65R 38 | - +server.group_label_mode.OfficeHome65R=supervised 39 | - +server.group_label_mode.OfficeHome65A=unsupervised 40 | - +server.group_label_mode.OfficeHome65C=unsupervised 41 | - +server.group_label_mode.OfficeHome65P=unsupervised 42 | - load_model.hash_name=OfficeHome65R/FedAvg/OfficeCnnSplit/f9ec4fd2818018611936f4e1f363d6b722c6d80a/g_0 43 | # 44 | - model.CDAN_task=True 45 | - user.adv_lambda=1. 46 | - model.freeze_backbone=False 47 | - model.freeze_decoder=False 48 | - model.rev_lambda_scale=1. 49 | - model.disable_bn_stat=True 50 | - model.bottleneck_type=dropout 51 | - user.group_loss=cdan 52 | - user.relabel_coef=0. 53 | #- +user.cluster_threshold=1. 54 | - ${args_no_hyphens} 55 | method: grid 56 | metric: 57 | goal: maximize 58 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 59 | parameters: 60 | i_rep: 61 | values: 62 | - 0 63 | - 1 64 | - 2 65 | dataset.meta_datasets.0.name: 66 | values: 67 | - OfficeHome65A 68 | - OfficeHome65C 69 | - OfficeHome65P 70 | server.num_users: 71 | values: 72 | # - 1 73 | - 2 74 | # - 4 75 | model.freeze_backbone: 76 | values: 77 | # - True 78 | - False 79 | dataset.meta_datasets.0.n_class: 80 | values: 81 | # - 25 # runed 82 | - 45 83 | - 65 # runed 84 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/OfficeHome65_1to3_uda_iid/R2X_dann_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | # 7 | # Preparation: 8 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. the varying p_src will 9 | # not affect the quality of source domain. 10 | # 2. Run `m2u_fuda_puser.sh to create the project. 
11 | # 12 | # Runs: 13 | # - [1/28] 14 | name: R2X_adv_nuser_iid 15 | project: fade-demo-OfficeHome65_X2X_1to3 16 | command: 17 | - ${interpreter} 18 | - -m 19 | - ${program} 20 | - logger.wandb.project=dmtl-OfficeHome65_X2X_1to3 21 | - dataset=comb/OfficeHome65_X2X_1s_3t 22 | - server=FedAdv 23 | - model=OfficeHome65CnnSplitAdv 24 | - user=group_adv_office_uda 25 | - num_glob_iters=500 26 | # 27 | - +eval_freq=30 28 | - logger.wandb.group='adv-shot-nuser-${dataset.meta_datasets.0.name}' 29 | - logger.wandb.name='r${i_rep}-adv-shot-nuser-${dataset.meta_datasets.0.name}' 30 | - server.beta=.5 31 | - load_model.do=true 32 | - load_model.load=[server] 33 | - server.user_selection=random_uniform 34 | # 35 | - dataset.meta_datasets.1.name=OfficeHome65R 36 | - +server.group_label_mode.OfficeHome65R=supervised 37 | - +server.group_label_mode.OfficeHome65A=unsupervised 38 | - +server.group_label_mode.OfficeHome65C=unsupervised 39 | - +server.group_label_mode.OfficeHome65P=unsupervised 40 | - load_model.hash_name=OfficeHome65R/FedAvg/OfficeCnnSplit/f9ec4fd2818018611936f4e1f363d6b722c6d80a/g_0 41 | # 42 | - user.adv_lambda=1. 43 | - model.freeze_backbone=False 44 | - model.freeze_decoder=False 45 | - model.rev_lambda_scale=1. 46 | - model.disable_bn_stat=True 47 | - model.bottleneck_type=dropout 48 | - user.group_loss=bce 49 | - user.relabel_coef=0. 50 | - +user.cluster_threshold=1. 51 | - ${args_no_hyphens} 52 | method: grid 53 | metric: 54 | goal: maximize 55 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 56 | parameters: 57 | i_rep: 58 | values: 59 | - 0 60 | - 1 61 | - 2 62 | dataset.meta_datasets.0.name: 63 | values: 64 | - OfficeHome65A 65 | - OfficeHome65C 66 | - OfficeHome65P 67 | # user.group_loss: 68 | # values: 69 | # - bce 70 | # - sq_bce 71 | server.num_users: 72 | values: 73 | # - 1 74 | - 2 75 | # - 4 76 | model.freeze_backbone: 77 | values: 78 | # - True 79 | - False 80 | dataset.meta_datasets.0.n_class: 81 | values: 82 | # - 25 # runed 83 | - 45 84 | - 65 # runed 85 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/OfficeHome65_1to3_uda_iid/R2X_shot_nuser_sweep.yaml: -------------------------------------------------------------------------------- 1 | # Sweep file for wandb tuning. M2U: Mnist -> USPS 2 | # Goal: Source domain include 5 non-iid users. 3 | # point). 4 | # Variable: prob of source-domain user in each global epoch 5 | # Metric: Train acc on the target domain. 6 | # 7 | # Preparation: 8 | # 1. Run `bash dmtl/experiments/OfficeHome65a_fedavg_5user_niid.sh` to pretrain the source domain model s.t. the varying p_src will 9 | # not affect the quality of source domain. 10 | # 2. Run `m2u_fuda_puser.sh to create the project. 
11 | # 12 | # Runs: 13 | # - [1/28] 14 | name: R2X_shot_nuser_iid 15 | project: fade-demo-OfficeHome65_X2X_1to3 16 | command: 17 | - ${interpreter} 18 | - -m 19 | - ${program} 20 | - logger.wandb.project=dmtl-OfficeHome65_X2X_1to3 21 | - dataset=comb/OfficeHome65_X2X_1s_3t 22 | - server=FedAdv 23 | - model=OfficeHome65CnnSplitAdv 24 | - user=group_adv_office_uda 25 | - num_glob_iters=500 26 | # 27 | - +eval_freq=30 28 | - logger.wandb.group='adv-shot-nuser-${dataset.meta_datasets.0.name}' 29 | - logger.wandb.name='r${i_rep}-shot-nuser-${dataset.meta_datasets.0.name}' 30 | - server.beta=.5 31 | - load_model.do=true 32 | - load_model.load=[server] 33 | - server.user_selection=random_uniform 34 | # 35 | - dataset.meta_datasets.1.name=OfficeHome65R 36 | - +server.group_label_mode.OfficeHome65R=supervised 37 | - +server.group_label_mode.OfficeHome65A=self_supervised 38 | - +server.group_label_mode.OfficeHome65C=self_supervised 39 | - +server.group_label_mode.OfficeHome65P=self_supervised 40 | - load_model.hash_name=OfficeHome65R/FedAvg/OfficeCnnSplit/c920a6b39b93732d1bab471e906c98545205eb6d/g_0 41 | # 42 | - user.adv_lambda=0. 43 | - user.group_loss=none 44 | - model.freeze_backbone=False 45 | - model.freeze_decoder=True 46 | - model.rev_lambda_scale=0. 47 | - model.disable_bn_stat=False 48 | - model.bottleneck_type=bn 49 | - user.relabel_coef=0.1 50 | - +user.cluster_threshold=1. 51 | # 52 | - ${args_no_hyphens} 53 | method: grid 54 | metric: 55 | goal: maximize 56 | name: g_test_acc # need to change at GUI according to dataset.meta_datasets.0.name 57 | parameters: 58 | i_rep: 59 | values: 60 | - 0 61 | - 1 62 | - 2 63 | dataset.meta_datasets.0.name: 64 | values: 65 | - OfficeHome65A 66 | - OfficeHome65C 67 | - OfficeHome65P 68 | server.num_users: 69 | values: 70 | # - 1 71 | - 2 72 | # - 4 73 | dataset.meta_datasets.0.n_class: 74 | values: 75 | # - 25 # runed 76 | - 45 77 | - 65 # runed 78 | program: dmtl.mainx -------------------------------------------------------------------------------- /sweeps/OfficeHome65_1to3_uda_iid/sweep_all.sh: -------------------------------------------------------------------------------- 1 | # NOTE only run this under `src/` 2 | # Pretrain using Office65_1to3_uda/*fedavg.sh 3 | 4 | # Sweep for Office31 1source vs 3 target with iid users: https://wandb.ai/jyhong/dmtl-OfficeHome65_X2X_1to3/sweeps 5 | 6 | echo current path: $(pwd) 7 | pwd=dmtl/experiments/OfficeHome65_1to3_uda_iid/*_sweep.yaml 8 | 9 | for f in $pwd 10 | do 11 | echo $f 12 | done 13 | 14 | echo "========================================" 15 | echo "TODO: Copy the print paths to all_files." 16 | echo "========================================" 17 | 18 | ## TODO update the paths 19 | all_files=( 20 | dmtl/experiments/OfficeHome65_1to3_uda_iid/R2X_dann_nuser_sweep.yaml 21 | dmtl/experiments/OfficeHome65_1to3_uda_iid/R2X_cdan_nuser_sweep.yaml 22 | dmtl/experiments/OfficeHome65_1to3_uda_iid/R2X_shot_nuser_sweep.yaml 23 | ) 24 | 25 | #wandb sweep $(pwd) 26 | for f in ${all_files[@]} 27 | do 28 | echo "==== sweep $f ===" 29 | wandb sweep $f 30 | done --------------------------------------------------------------------------------
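# The sweep_all.sh scripts above only register sweeps via `wandb sweep`; runs are executed by
# launching agents against the sweep id that `wandb sweep` prints. A minimal sketch, with the
# entity and sweep id as placeholders:
#   wandb agent <entity>/fade-demo-OfficeHome65_X2X_1to3/<sweep_id>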