├── .DS_Store ├── README.md ├── code ├── .gitignore ├── Dassl.pytorch │ ├── .flake8 │ ├── .gitignore │ ├── .isort.cfg │ ├── .style.yapf │ ├── DATASETS.md │ ├── LICENSE │ ├── README.md │ ├── configs │ │ ├── README.md │ │ ├── datasets │ │ │ ├── da │ │ │ │ ├── cifar_stl.yaml │ │ │ │ ├── digit5.yaml │ │ │ │ ├── domainnet.yaml │ │ │ │ ├── mini_domainnet.yaml │ │ │ │ ├── office31.yaml │ │ │ │ ├── office_home.yaml │ │ │ │ └── visda17.yaml │ │ │ ├── dg │ │ │ │ ├── camelyon17.yaml │ │ │ │ ├── cifar100_c.yaml │ │ │ │ ├── cifar10_c.yaml │ │ │ │ ├── digit_single.yaml │ │ │ │ ├── digits_dg.yaml │ │ │ │ ├── fmow.yaml │ │ │ │ ├── iwildcam.yaml │ │ │ │ ├── office_home_dg.yaml │ │ │ │ ├── pacs.yaml │ │ │ │ └── vlcs.yaml │ │ │ └── ssl │ │ │ │ ├── cifar10.yaml │ │ │ │ ├── cifar100.yaml │ │ │ │ ├── stl10.yaml │ │ │ │ └── svhn.yaml │ │ └── trainers │ │ │ ├── da │ │ │ ├── cdac │ │ │ │ ├── digit5.yaml │ │ │ │ ├── domainnet.yaml │ │ │ │ └── mini_domainnet.yaml │ │ │ ├── dael │ │ │ │ ├── digit5.yaml │ │ │ │ ├── domainnet.yaml │ │ │ │ └── mini_domainnet.yaml │ │ │ ├── m3sda │ │ │ │ ├── digit5.yaml │ │ │ │ ├── domainnet.yaml │ │ │ │ └── mini_domainnet.yaml │ │ │ └── source_only │ │ │ │ ├── digit5.yaml │ │ │ │ ├── mini_domainnet.yaml │ │ │ │ ├── office31.yaml │ │ │ │ └── visda17.yaml │ │ │ ├── dg │ │ │ ├── daeldg │ │ │ │ ├── digits_dg.yaml │ │ │ │ ├── office_home_dg.yaml │ │ │ │ └── pacs.yaml │ │ │ ├── ddaig │ │ │ │ ├── digits_dg.yaml │ │ │ │ ├── office_home_dg.yaml │ │ │ │ └── pacs.yaml │ │ │ └── vanilla │ │ │ │ ├── digits_dg.yaml │ │ │ │ ├── mini_domainnet.yaml │ │ │ │ ├── office_home_dg.yaml │ │ │ │ └── pacs.yaml │ │ │ └── ssl │ │ │ └── fixmatch │ │ │ └── cifar10.yaml │ ├── dassl │ │ ├── __init__.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ └── defaults.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── data_manager.py │ │ │ ├── datasets │ │ │ │ ├── __init__.py │ │ │ │ ├── base_dataset.py │ │ │ │ ├── build.py │ │ │ │ ├── da │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cifarstl.py │ │ │ │ │ ├── 
digit5.py │ │ │ │ │ ├── domainnet.py │ │ │ │ │ ├── mini_domainnet.py │ │ │ │ │ ├── office31.py │ │ │ │ │ ├── office_home.py │ │ │ │ │ └── visda17.py │ │ │ │ ├── dg │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cifar_c.py │ │ │ │ │ ├── digit_single.py │ │ │ │ │ ├── digits_dg.py │ │ │ │ │ ├── office_home_dg.py │ │ │ │ │ ├── pacs.py │ │ │ │ │ ├── vlcs.py │ │ │ │ │ └── wilds │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── camelyon17.py │ │ │ │ │ │ ├── fmow.py │ │ │ │ │ │ ├── iwildcam.py │ │ │ │ │ │ └── wilds_base.py │ │ │ │ └── ssl │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cifar.py │ │ │ │ │ ├── stl10.py │ │ │ │ │ └── svhn.py │ │ │ ├── samplers.py │ │ │ └── transforms │ │ │ │ ├── __init__.py │ │ │ │ ├── autoaugment.py │ │ │ │ ├── randaugment.py │ │ │ │ └── transforms.py │ │ ├── engine │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── da │ │ │ │ ├── __init__.py │ │ │ │ ├── adabn.py │ │ │ │ ├── adda.py │ │ │ │ ├── cdac.py │ │ │ │ ├── dael.py │ │ │ │ ├── dann.py │ │ │ │ ├── m3sda.py │ │ │ │ ├── mcd.py │ │ │ │ ├── mme.py │ │ │ │ ├── se.py │ │ │ │ └── source_only.py │ │ │ ├── dg │ │ │ │ ├── __init__.py │ │ │ │ ├── crossgrad.py │ │ │ │ ├── daeldg.py │ │ │ │ ├── ddaig.py │ │ │ │ ├── domain_mix.py │ │ │ │ └── vanilla.py │ │ │ ├── ssl │ │ │ │ ├── __init__.py │ │ │ │ ├── entmin.py │ │ │ │ ├── fixmatch.py │ │ │ │ ├── mean_teacher.py │ │ │ │ ├── mixmatch.py │ │ │ │ └── sup_baseline.py │ │ │ └── trainer.py │ │ ├── evaluation │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── evaluator.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ ├── accuracy.py │ │ │ └── distance.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── backbone │ │ │ │ ├── __init__.py │ │ │ │ ├── alexnet.py │ │ │ │ ├── backbone.py │ │ │ │ ├── build.py │ │ │ │ ├── cnn_digit5_m3sda.py │ │ │ │ ├── cnn_digitsdg.py │ │ │ │ ├── cnn_digitsingle.py │ │ │ │ ├── efficientnet │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── model.py │ │ │ │ │ └── utils.py │ │ │ │ ├── preact_resnet18.py │ │ │ │ ├── resnet.py │ │ │ │ ├── resnet_dynamic.py │ │ │ │ ├── 
vgg.py │ │ │ │ └── wide_resnet.py │ │ │ ├── head │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ └── mlp.py │ │ │ ├── network │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ └── ddaig_fcn.py │ │ │ └── ops │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── conv.py │ │ │ │ ├── cross_entropy.py │ │ │ │ ├── dsbn.py │ │ │ │ ├── efdmix.py │ │ │ │ ├── mixstyle.py │ │ │ │ ├── mixup.py │ │ │ │ ├── mmd.py │ │ │ │ ├── optimal_transport.py │ │ │ │ ├── reverse_grad.py │ │ │ │ ├── sequential2.py │ │ │ │ ├── transnorm.py │ │ │ │ └── utils.py │ │ ├── optim │ │ │ ├── __init__.py │ │ │ ├── lr_scheduler.py │ │ │ ├── optimizer.py │ │ │ └── radam.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── logger.py │ │ │ ├── meters.py │ │ │ ├── registry.py │ │ │ ├── tools.py │ │ │ └── torchtools.py │ ├── datasets │ │ ├── da │ │ │ ├── cifar_stl.py │ │ │ ├── digit5.py │ │ │ └── visda17.sh │ │ ├── dg │ │ │ └── cifar_c.py │ │ └── ssl │ │ │ ├── cifar10_cifar100_svhn.py │ │ │ └── stl10.py │ ├── linter.sh │ ├── requirements.txt │ ├── setup.py │ └── tools │ │ ├── parse_test_res.py │ │ ├── replace_text.py │ │ └── train.py ├── LEGAL.md ├── LICENSE ├── README.md ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py ├── configs │ ├── datasets │ │ ├── caltech101.yaml │ │ ├── dtd.yaml │ │ ├── eurosat.yaml │ │ ├── fgvc_aircraft.yaml │ │ ├── food101.yaml │ │ ├── imagenet.yaml │ │ ├── imagenet_a.yaml │ │ ├── imagenet_r.yaml │ │ ├── imagenet_sketch.yaml │ │ ├── imagenetv2.yaml │ │ ├── oxford_flowers.yaml │ │ ├── oxford_pets.yaml │ │ ├── stanford_cars.yaml │ │ ├── sun397.yaml │ │ └── ucf101.yaml │ └── trainers │ │ ├── CoCoOp │ │ ├── rn50_b16_c4_ep10_batch1.yaml │ │ ├── vit_b16_c16_ep10_batch1.yaml │ │ ├── vit_b16_c4_ep10_batch1.yaml │ │ ├── vit_b16_c4_ep10_batch1_ctxv1.yaml │ │ └── vit_b16_c8_ep10_batch1.yaml │ │ ├── CoOp │ │ ├── rn101.yaml │ │ ├── rn101_ep50.yaml │ │ ├── rn50.yaml │ │ ├── rn50_batch16.yaml │ │ ├── rn50_ctxv1.yaml │ 
│ ├── rn50_ep100.yaml │ │ ├── rn50_ep50.yaml │ │ ├── rn50_ep50_ctxv1.yaml │ │ ├── rn50_val.yaml │ │ ├── vit_b16.yaml │ │ ├── vit_b16_ctxv1.yaml │ │ ├── vit_b16_ep100.yaml │ │ ├── vit_b16_ep100_ctxv1.yaml │ │ ├── vit_b16_ep50.yaml │ │ ├── vit_b16_ep50_ctxv1.yaml │ │ ├── vit_b16_val.yaml │ │ ├── vit_b32.yaml │ │ └── vit_b32_ep50.yaml │ │ └── MASK │ │ ├── imagenet_vit_b16.yaml │ │ └── vit_b16.yaml ├── datasets │ ├── __init__.py │ ├── caltech101.py │ ├── dtd.py │ ├── eurosat.py │ ├── fgvc_aircraft.py │ ├── food101.py │ ├── imagenet.py │ ├── imagenet_a.py │ ├── imagenet_r.py │ ├── imagenet_sketch.py │ ├── imagenetv2.py │ ├── oxford_flowers.py │ ├── oxford_pets.py │ ├── stanford_cars.py │ ├── sun397.py │ └── ucf101.py ├── loss │ └── reg_loss.py ├── modules │ ├── __init__.py │ ├── masklayers.py │ ├── resnet.py │ └── visiontransformer.py ├── requirements.txt ├── scripts │ ├── cocoop │ │ ├── eval.sh │ │ └── main.sh │ ├── coop │ │ ├── eval_imagenet.sh │ │ ├── main.sh │ │ ├── main_imagenet.sh │ │ ├── main_oxford_pets.sh │ │ └── zeroshot.sh │ └── masktuning │ │ ├── .main.sh.swp │ │ ├── base2new_eval.sh │ │ ├── base2new_train.sh │ │ ├── eval.sh │ │ └── train.sh ├── train.py └── trainers │ ├── __init__.py │ ├── cocoop.py │ ├── coop.py │ ├── imagenet_templates.py │ ├── masktuning.py │ └── zsclip.py ├── figures ├── clip_mask.jpg └── init.txt ├── index.html └── static ├── css ├── bulma-carousel.min.css ├── bulma-slider.min.css ├── bulma.css.map.txt ├── bulma.min.css ├── fontawesome.all.min.css └── index.css ├── images ├── Pipeline.jpg ├── exp.png └── teaser1-1.jpg └── js ├── bulma-carousel.js ├── bulma-carousel.min.js ├── bulma-slider.js ├── bulma-slider.min.js ├── fontawesome.all.min.js ├── index.js └── video_comparison.js /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/.DS_Store 
-------------------------------------------------------------------------------- /code/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Custom 132 | output/ 133 | debug.sh 134 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # At least two spaces before inline comment 4 | E261, 5 | # Line lengths are recommended to be no greater than 79 characters 6 | E501, 7 | # Missing whitespace around arithmetic operator 8 | E226, 9 | # Blank line contains whitespace 10 | W293, 11 | # Do not use bare 'except' 12 | E722, 13 | # Line break after binary operator 14 | W504, 15 | # Too many leading '#' for block comment 16 | E266, 17 | # Line break before binary operator 18 | W503, 19 | # Continuation line over-indented for hanging indent 20 | E126, 21 | # Module level import not at top of file 22 | E402 23 | max-line-length = 79 24 | exclude = __init__.py, build -------------------------------------------------------------------------------- /code/Dassl.pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | 
lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # OS X 132 | .DS_Store 133 | .Spotlight-V100 134 | .Trashes 135 | ._* 136 | 137 | # This project 138 | output/ 139 | debug/ 140 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/.isort.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=79 3 | multi_line_output=6 4 | length_sort=true 5 | known_standard_library=numpy,setuptools 6 | known_myself=dassl 7 | known_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs,scipy,gdown 8 | no_lines_before=STDLIB,THIRDPARTY 9 | sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER 10 | default_section=FIRSTPARTY -------------------------------------------------------------------------------- /code/Dassl.pytorch/.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | BASED_ON_STYLE = pep8 3 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 4 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true 5 | DEDENT_CLOSING_BRACKETS = true 6 | SPACES_BEFORE_COMMENT = 2 7 | ARITHMETIC_PRECEDENCE_INDICATION = true -------------------------------------------------------------------------------- /code/Dassl.pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kaiyang 4 | 5 | Permission is hereby 
granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/README.md: -------------------------------------------------------------------------------- 1 | The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, data augmentation, network architecture) used by most papers. The `trainers/` folder contains method-specific config files which define optimization algorithms (e.g., optimizer, epoch) and hyperparameter settings. 
2 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/da/cifar_stl.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | PIXEL_MEAN: [0.5, 0.5, 0.5] 4 | PIXEL_STD: [0.5, 0.5, 0.5] 5 | 6 | DATASET: 7 | NAME: "CIFARSTL" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/da/digit5.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | PIXEL_MEAN: [0.5, 0.5, 0.5] 4 | PIXEL_STD: [0.5, 0.5, 0.5] 5 | TRANSFORMS: ["normalize"] 6 | 7 | DATASET: 8 | NAME: "Digit5" 9 | 10 | MODEL: 11 | BACKBONE: 12 | NAME: "cnn_digit5_m3sda" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/da/domainnet.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "DomainNet" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet101" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/da/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (96, 96) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "miniDomainNet" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet18" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/da/office31.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "Office31" 7 | 8 | MODEL: 9 | BACKBONE: 10 | 
NAME: "resnet50" 11 | HEAD: 12 | NAME: "mlp" 13 | HIDDEN_LAYERS: [256] 14 | DROPOUT: 0. -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/da/office_home.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | 4 | DATASET: 5 | NAME: "OfficeHome" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/da/visda17.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "center_crop", "normalize"] 4 | 5 | DATASET: 6 | NAME: "VisDA17" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet101" 11 | 12 | TEST: 13 | PER_CLASS_RESULT: True -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/camelyon17.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 4 | 5 | DATASET: 6 | NAME: "Camelyon17" 7 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/cifar100_c.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "CIFAR100C" 9 | CIFAR_C_TYPE: "fog" 10 | CIFAR_C_LEVEL: 5 11 | 12 | MODEL: 13 | BACKBONE: 14 | NAME: "wide_resnet_16_4" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/cifar10_c.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_flip", "random_crop", 
"normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "CIFAR10C" 9 | CIFAR_C_TYPE: "fog" 10 | CIFAR_C_LEVEL: 5 11 | 12 | MODEL: 13 | BACKBONE: 14 | NAME: "wide_resnet_16_4" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/digit_single.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "DigitSingle" 9 | 10 | MODEL: 11 | BACKBONE: 12 | NAME: "cnn_digitsingle" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/digits_dg.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "DigitsDG" 9 | 10 | MODEL: 11 | BACKBONE: 12 | NAME: "cnn_digitsdg" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/fmow.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 4 | 5 | DATASET: 6 | NAME: "FMoW" 7 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/iwildcam.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 4 | 5 | DATASET: 6 | NAME: "IWildCam" 7 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/office_home_dg.yaml: 
-------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "OfficeHomeDG" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet18" 11 | PRETRAINED: True -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/pacs.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "PACS" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet18" 11 | PRETRAINED: True -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/dg/vlcs.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "VLCS" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet18" 11 | PRETRAINED: True -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/ssl/cifar10.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "CIFAR10" 9 | NUM_LABELED: 4000 10 | VAL_PERCENT: 0. 
11 | 12 | MODEL: 13 | BACKBONE: 14 | NAME: "wide_resnet_28_2" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/ssl/cifar100.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | CROP_PADDING: 4 7 | 8 | DATASET: 9 | NAME: "CIFAR100" 10 | NUM_LABELED: 10000 11 | VAL_PERCENT: 0. 12 | 13 | MODEL: 14 | BACKBONE: 15 | NAME: "wide_resnet_28_2" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/ssl/stl10.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (96, 96) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | CROP_PADDING: 4 7 | 8 | DATASET: 9 | NAME: "STL10" 10 | STL10_FOLD: 0 11 | 12 | MODEL: 13 | BACKBONE: 14 | NAME: "wide_resnet_28_2" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/datasets/ssl/svhn.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | CROP_PADDING: 4 7 | 8 | DATASET: 9 | NAME: "SVHN" 10 | NUM_LABELED: 1000 11 | VAL_PERCENT: 0. 
12 | 13 | MODEL: 14 | BACKBONE: 15 | NAME: "wide_resnet_28_2" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/cdac/digit5.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomSampler" 4 | BATCH_SIZE: 64 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 192 8 | TEST: 9 | BATCH_SIZE: 256 10 | K_TRANSFORMS: 2 11 | 12 | OPTIM: 13 | NAME: "sgd" 14 | LR: 0.001 15 | MAX_EPOCH: 90 16 | RAMPUP_ITRS: 10000 17 | 18 | TRAINER: 19 | CDAC: 20 | STRONG_TRANSFORMS: ["randaugment", "normalize"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/cdac/domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 6 8 | TEST: 9 | BATCH_SIZE: 30 10 | K_TRANSFORMS: 2 11 | 12 | OPTIM: 13 | NAME: "sgd" 14 | LR: 0.001 15 | MAX_EPOCH: 90 16 | RAMPUP_ITRS: 10000 17 | 18 | TRAINER: 19 | CDAC: 20 | STRONG_TRANSFORMS: ["randaugment", "normalize"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/cdac/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 64 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 192 8 | TEST: 9 | BATCH_SIZE: 200 10 | K_TRANSFORMS: 2 11 | 12 | OPTIM: 13 | NAME: "sgd" 14 | LR: 0.001 15 | MAX_EPOCH: 60 16 | RAMPUP_ITRS: 10000 17 | LR_SCHEDULER: "cosine" 18 | 19 | TRAINER: 20 | CDAC: 21 | STRONG_TRANSFORMS: ["randaugment", "normalize"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/dael/digit5.yaml: 
-------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 256 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 64 8 | TEST: 9 | BATCH_SIZE: 256 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.05 14 | STEPSIZE: [30] 15 | MAX_EPOCH: 30 16 | LR_SCHEDULER: "cosine" 17 | 18 | TRAINER: 19 | DAEL: 20 | STRONG_TRANSFORMS: ["randaugment2", "normalize"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/dael/domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 6 8 | TEST: 9 | BATCH_SIZE: 30 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.002 14 | MAX_EPOCH: 40 15 | LR_SCHEDULER: "cosine" 16 | 17 | TRAINER: 18 | DAEL: 19 | STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/dael/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 192 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 64 8 | TEST: 9 | BATCH_SIZE: 200 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.005 14 | MAX_EPOCH: 60 15 | LR_SCHEDULER: "cosine" 16 | 17 | TRAINER: 18 | DAEL: 19 | STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/m3sda/digit5.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 256 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 64 8 
| TEST: 9 | BATCH_SIZE: 256 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.05 14 | STEPSIZE: [30] 15 | MAX_EPOCH: 30 16 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/m3sda/domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 6 8 | TEST: 9 | BATCH_SIZE: 30 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.002 14 | MAX_EPOCH: 40 15 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/m3sda/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 192 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 64 8 | TEST: 9 | BATCH_SIZE: 200 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.005 14 | MAX_EPOCH: 60 15 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/source_only/digit5.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 256 4 | TEST: 5 | BATCH_SIZE: 256 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.05 10 | STEPSIZE: [30] 11 | MAX_EPOCH: 30 12 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/source_only/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 128 4 | TEST: 5 | BATCH_SIZE: 128 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.005 10 | MAX_EPOCH: 60 11 | LR_SCHEDULER: "cosine" 
-------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/source_only/office31.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 32 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.002 10 | STEPSIZE: [20] 11 | MAX_EPOCH: 20 -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/da/source_only/visda17.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 32 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.0001 10 | STEPSIZE: [2] 11 | MAX_EPOCH: 2 12 | 13 | TRAIN: 14 | PRINT_FREQ: 50 15 | COUNT_ITER: "train_u" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/daeldg/digits_dg.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 120 5 | TEST: 6 | BATCH_SIZE: 100 7 | 8 | OPTIM: 9 | NAME: "sgd" 10 | LR: 0.05 11 | STEPSIZE: [20] 12 | MAX_EPOCH: 50 13 | 14 | TRAINER: 15 | DAELDG: 16 | STRONG_TRANSFORMS: ["randaugment2", "normalize"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/daeldg/office_home_dg.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TEST: 6 | BATCH_SIZE: 100 7 | 8 | OPTIM: 9 | NAME: "sgd" 10 | LR: 0.002 11 | MAX_EPOCH: 40 12 | LR_SCHEDULER: "cosine" 13 | 14 | TRAINER: 15 | DAELDG: 16 | STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"] -------------------------------------------------------------------------------- 
/code/Dassl.pytorch/configs/trainers/dg/daeldg/pacs.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TEST: 6 | BATCH_SIZE: 100 7 | 8 | OPTIM: 9 | NAME: "sgd" 10 | LR: 0.002 11 | MAX_EPOCH: 40 12 | LR_SCHEDULER: "cosine" 13 | 14 | TRAINER: 15 | DAELDG: 16 | STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/ddaig/digits_dg.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | PIXEL_MEAN: [0., 0., 0.] 3 | PIXEL_STD: [1., 1., 1.] 4 | 5 | DATALOADER: 6 | TRAIN_X: 7 | BATCH_SIZE: 128 8 | TEST: 9 | BATCH_SIZE: 128 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.05 14 | STEPSIZE: [20] 15 | MAX_EPOCH: 50 16 | 17 | TRAINER: 18 | DDAIG: 19 | G_ARCH: "fcn_3x32_gctx" 20 | LMDA: 0.3 -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/ddaig/office_home_dg.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | PIXEL_MEAN: [0., 0., 0.] 3 | PIXEL_STD: [1., 1., 1.] 4 | 5 | DATALOADER: 6 | TRAIN_X: 7 | BATCH_SIZE: 16 8 | TEST: 9 | BATCH_SIZE: 16 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.0005 14 | STEPSIZE: [20] 15 | MAX_EPOCH: 25 16 | 17 | TRAINER: 18 | DDAIG: 19 | G_ARCH: "fcn_3x64_gctx" 20 | WARMUP: 3 21 | LMDA: 0.3 -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/ddaig/pacs.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | PIXEL_MEAN: [0., 0., 0.] 3 | PIXEL_STD: [1., 1., 1.] 
4 | 5 | DATALOADER: 6 | TRAIN_X: 7 | BATCH_SIZE: 16 8 | TEST: 9 | BATCH_SIZE: 16 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.0005 14 | STEPSIZE: [20] 15 | MAX_EPOCH: 25 16 | 17 | TRAINER: 18 | DDAIG: 19 | G_ARCH: "fcn_3x64_gctx" 20 | WARMUP: 3 21 | LMDA: 0.3 -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/vanilla/digits_dg.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 128 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | OPTIM: 9 | NAME: "sgd" 10 | LR: 0.05 11 | STEPSIZE: [20] 12 | MAX_EPOCH: 50 13 | 14 | TRAIN: 15 | PRINT_FREQ: 20 -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/vanilla/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 128 4 | TEST: 5 | BATCH_SIZE: 128 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.005 10 | MAX_EPOCH: 60 11 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/vanilla/office_home_dg.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 64 4 | TEST: 5 | BATCH_SIZE: 100 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.001 10 | MAX_EPOCH: 50 11 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /code/Dassl.pytorch/configs/trainers/dg/vanilla/pacs.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 64 4 | TEST: 5 | BATCH_SIZE: 100 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.001 10 | MAX_EPOCH: 50 11 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- 
/code/Dassl.pytorch/configs/trainers/ssl/fixmatch/cifar10.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 64 4 | TRAIN_U: 5 | SAME_AS_X: False 6 | BATCH_SIZE: 448 7 | TEST: 8 | BATCH_SIZE: 500 9 | 10 | OPTIM: 11 | NAME: "sgd" 12 | LR: 0.05 13 | STEPSIZE: [4000] 14 | MAX_EPOCH: 4000 15 | LR_SCHEDULER: "cosine" 16 | 17 | TRAIN: 18 | COUNT_ITER: "train_u" 19 | PRINT_FREQ: 10 20 | 21 | TRAINER: 22 | FIXMATCH: 23 | STRONG_TRANSFORMS: ["random_flip", "randaugment_fixmatch", "normalize", "cutout"] -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dassl 3 | ------ 4 | PyTorch toolbox for domain adaptation and semi-supervised learning. 5 | 6 | URL: https://github.com/KaiyangZhou/Dassl.pytorch 7 | 8 | @article{zhou2020domain, 9 | title={Domain Adaptive Ensemble Learning}, 10 | author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao}, 11 | journal={arXiv preprint arXiv:2003.07325}, 12 | year={2020} 13 | } 14 | """ 15 | 16 | __version__ = "0.6.3" 17 | __author__ = "Kaiyang Zhou" 18 | __homepage__ = "https://kaiyangzhou.github.io/" 19 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg_default 2 | 3 | 4 | def get_cfg_default(): 5 | return cfg_default.clone() 6 | 7 | 8 | def clean_cfg(cfg, trainer): 9 | """Remove unused trainers (configs). 10 | 11 | Aim: Only show relevant information when calling print(cfg). 12 | 13 | Args: 14 | cfg (_C): cfg instance. 15 | trainer (str): trainer name. 
16 | """ 17 | keys = list(cfg.TRAINER.keys()) 18 | for key in keys: 19 | if key == "NAME" or key == trainer.upper(): 20 | continue 21 | cfg.TRAINER.pop(key, None) 22 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_manager import DataManager, DatasetWrapper 2 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import DATASET_REGISTRY, build_dataset # isort:skip 2 | from .base_dataset import Datum, DatasetBase # isort:skip 3 | 4 | from .da import * 5 | from .dg import * 6 | from .ssl import * 7 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | DATASET_REGISTRY = Registry("DATASET") 4 | 5 | 6 | def build_dataset(cfg): 7 | avai_datasets = DATASET_REGISTRY.registered_names() 8 | check_availability(cfg.DATASET.NAME, avai_datasets) 9 | if cfg.VERBOSE: 10 | print("Loading dataset: {}".format(cfg.DATASET.NAME)) 11 | return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg) 12 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/da/__init__.py: -------------------------------------------------------------------------------- 1 | from .digit5 import Digit5 2 | from .visda17 import VisDA17 3 | from .cifarstl import CIFARSTL 4 | from .office31 import Office31 5 | from .domainnet import DomainNet 6 | from .office_home import OfficeHome 7 | from .mini_domainnet import miniDomainNet 8 | 
-------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/da/cifarstl.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class CIFARSTL(DatasetBase): 11 | """CIFAR-10 and STL-10. 12 | 13 | CIFAR-10: 14 | - 60,000 32x32 colour images. 15 | - 10 classes, with 6,000 images per class. 16 | - 50,000 training images and 10,000 test images. 17 | - URL: https://www.cs.toronto.edu/~kriz/cifar.html. 18 | 19 | STL-10: 20 | - 10 classes: airplane, bird, car, cat, deer, dog, horse, 21 | monkey, ship, truck. 22 | - Images are 96x96 pixels, color. 23 | - 500 training images (10 pre-defined folds), 800 test images 24 | per class. 25 | - URL: https://cs.stanford.edu/~acoates/stl10/. 26 | 27 | Reference: 28 | - Krizhevsky. Learning Multiple Layers of Features 29 | from Tiny Images. Tech report. 30 | - Coates et al. An Analysis of Single Layer Networks in 31 | Unsupervised Feature Learning. AISTATS 2011. 
32 | """ 33 | 34 | dataset_dir = "cifar_stl" 35 | domains = ["cifar", "stl"] 36 | 37 | def __init__(self, cfg): 38 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 39 | self.dataset_dir = osp.join(root, self.dataset_dir) 40 | 41 | self.check_input_domains( 42 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 43 | ) 44 | 45 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 46 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 47 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 48 | 49 | super().__init__(train_x=train_x, train_u=train_u, test=test) 50 | 51 | def _read_data(self, input_domains, split="train"): 52 | items = [] 53 | 54 | for domain, dname in enumerate(input_domains): 55 | data_dir = osp.join(self.dataset_dir, dname, split) 56 | class_names = listdir_nohidden(data_dir) 57 | 58 | for class_name in class_names: 59 | class_dir = osp.join(data_dir, class_name) 60 | imnames = listdir_nohidden(class_dir) 61 | label = int(class_name.split("_")[0]) 62 | 63 | for imname in imnames: 64 | impath = osp.join(class_dir, imname) 65 | item = Datum(impath=impath, label=label, domain=domain) 66 | items.append(item) 67 | 68 | return items 69 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/da/domainnet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class DomainNet(DatasetBase): 9 | """DomainNet. 10 | 11 | Statistics: 12 | - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw, 13 | Real, Sketch. 14 | - Around 0.6M images. 15 | - 345 categories. 16 | - URL: http://ai.bu.edu/M3SDA/. 17 | 18 | Special note: the t-shirt class (327) is missing in painting_train.txt. 19 | 20 | Reference: 21 | - Peng et al. 
Moment Matching for Multi-Source Domain 22 | Adaptation. ICCV 2019. 23 | """ 24 | 25 | dataset_dir = "domainnet" 26 | domains = [ 27 | "clipart", "infograph", "painting", "quickdraw", "real", "sketch" 28 | ] 29 | 30 | def __init__(self, cfg): 31 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 32 | self.dataset_dir = osp.join(root, self.dataset_dir) 33 | self.split_dir = osp.join(self.dataset_dir, "splits") 34 | 35 | self.check_input_domains( 36 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 37 | ) 38 | 39 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 40 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 41 | val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test") 42 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 43 | 44 | super().__init__(train_x=train_x, train_u=train_u, val=val, test=test) 45 | 46 | def _read_data(self, input_domains, split="train"): 47 | items = [] 48 | 49 | for domain, dname in enumerate(input_domains): 50 | filename = dname + "_" + split + ".txt" 51 | split_file = osp.join(self.split_dir, filename) 52 | 53 | with open(split_file, "r") as f: 54 | lines = f.readlines() 55 | for line in lines: 56 | line = line.strip() 57 | impath, label = line.split(" ") 58 | classname = impath.split("/")[1] 59 | impath = osp.join(self.dataset_dir, impath) 60 | label = int(label) 61 | item = Datum( 62 | impath=impath, 63 | label=label, 64 | domain=domain, 65 | classname=classname 66 | ) 67 | items.append(item) 68 | 69 | return items 70 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/da/mini_domainnet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class miniDomainNet(DatasetBase): 9 | """A subset of 
DomainNet. 10 | 11 | Reference: 12 | - Peng et al. Moment Matching for Multi-Source Domain 13 | Adaptation. ICCV 2019. 14 | - Zhou et al. Domain Adaptive Ensemble Learning. 15 | """ 16 | 17 | dataset_dir = "domainnet" 18 | domains = ["clipart", "painting", "real", "sketch"] 19 | 20 | def __init__(self, cfg): 21 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = osp.join(root, self.dataset_dir) 23 | self.split_dir = osp.join(self.dataset_dir, "splits_mini") 24 | 25 | self.check_input_domains( 26 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 27 | ) 28 | 29 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 30 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 31 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 32 | 33 | super().__init__(train_x=train_x, train_u=train_u, test=test) 34 | 35 | def _read_data(self, input_domains, split="train"): 36 | items = [] 37 | 38 | for domain, dname in enumerate(input_domains): 39 | filename = dname + "_" + split + ".txt" 40 | split_file = osp.join(self.split_dir, filename) 41 | 42 | with open(split_file, "r") as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | line = line.strip() 46 | impath, label = line.split(" ") 47 | classname = impath.split("/")[1] 48 | impath = osp.join(self.dataset_dir, impath) 49 | label = int(label) 50 | item = Datum( 51 | impath=impath, 52 | label=label, 53 | domain=domain, 54 | classname=classname 55 | ) 56 | items.append(item) 57 | 58 | return items 59 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/da/office31.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class 
Office31(DatasetBase): 11 | """Office-31. 12 | 13 | Statistics: 14 | - 4,110 images. 15 | - 31 classes related to office objects. 16 | - 3 domains: Amazon, Webcam, Dslr. 17 | - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/. 18 | 19 | Reference: 20 | - Saenko et al. Adapting visual category models to 21 | new domains. ECCV 2010. 22 | """ 23 | 24 | dataset_dir = "office31" 25 | domains = ["amazon", "webcam", "dslr"] 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | self.check_input_domains( 32 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 33 | ) 34 | 35 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) 36 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) 37 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS) 38 | 39 | super().__init__(train_x=train_x, train_u=train_u, test=test) 40 | 41 | def _read_data(self, input_domains): 42 | items = [] 43 | 44 | for domain, dname in enumerate(input_domains): 45 | domain_dir = osp.join(self.dataset_dir, dname) 46 | class_names = listdir_nohidden(domain_dir) 47 | class_names.sort() 48 | 49 | for label, class_name in enumerate(class_names): 50 | class_path = osp.join(domain_dir, class_name) 51 | imnames = listdir_nohidden(class_path) 52 | 53 | for imname in imnames: 54 | impath = osp.join(class_path, imname) 55 | item = Datum( 56 | impath=impath, 57 | label=label, 58 | domain=domain, 59 | classname=class_name 60 | ) 61 | items.append(item) 62 | 63 | return items 64 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/da/office_home.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | 
@DATASET_REGISTRY.register() 10 | class OfficeHome(DatasetBase): 11 | """Office-Home. 12 | 13 | Statistics: 14 | - Around 15,500 images. 15 | - 65 classes related to office and home objects. 16 | - 4 domains: Art, Clipart, Product, Real World. 17 | - URL: http://hemanthdv.org/OfficeHome-Dataset/. 18 | 19 | Reference: 20 | - Venkateswara et al. Deep Hashing Network for Unsupervised 21 | Domain Adaptation. CVPR 2017. 22 | """ 23 | 24 | dataset_dir = "office_home" 25 | domains = ["art", "clipart", "product", "real_world"] 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | self.check_input_domains( 32 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 33 | ) 34 | 35 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) 36 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) 37 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS) 38 | 39 | super().__init__(train_x=train_x, train_u=train_u, test=test) 40 | 41 | def _read_data(self, input_domains): 42 | items = [] 43 | 44 | for domain, dname in enumerate(input_domains): 45 | domain_dir = osp.join(self.dataset_dir, dname) 46 | class_names = listdir_nohidden(domain_dir) 47 | class_names.sort() 48 | 49 | for label, class_name in enumerate(class_names): 50 | class_path = osp.join(domain_dir, class_name) 51 | imnames = listdir_nohidden(class_path) 52 | 53 | for imname in imnames: 54 | impath = osp.join(class_path, imname) 55 | item = Datum( 56 | impath=impath, 57 | label=label, 58 | domain=domain, 59 | classname=class_name.lower(), 60 | ) 61 | items.append(item) 62 | 63 | return items 64 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/da/visda17.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, 
DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class VisDA17(DatasetBase): 9 | """VisDA17. 10 | 11 | Focusing on simulation-to-reality domain shift. 12 | 13 | URL: http://ai.bu.edu/visda-2017/. 14 | 15 | Reference: 16 | - Peng et al. VisDA: The Visual Domain Adaptation 17 | Challenge. ArXiv 2017. 18 | """ 19 | 20 | dataset_dir = "visda17" 21 | domains = ["synthetic", "real"] 22 | 23 | def __init__(self, cfg): 24 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | 27 | self.check_input_domains( 28 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 29 | ) 30 | 31 | train_x = self._read_data("synthetic") 32 | train_u = self._read_data("real") 33 | test = self._read_data("real") 34 | 35 | super().__init__(train_x=train_x, train_u=train_u, test=test) 36 | 37 | def _read_data(self, dname): 38 | filedir = "train" if dname == "synthetic" else "validation" 39 | image_list = osp.join(self.dataset_dir, filedir, "image_list.txt") 40 | items = [] 41 | # There is only one source domain 42 | domain = 0 43 | 44 | with open(image_list, "r") as f: 45 | lines = f.readlines() 46 | 47 | for line in lines: 48 | line = line.strip() 49 | impath, label = line.split(" ") 50 | classname = impath.split("/")[0] 51 | impath = osp.join(self.dataset_dir, filedir, impath) 52 | label = int(label) 53 | item = Datum( 54 | impath=impath, 55 | label=label, 56 | domain=domain, 57 | classname=classname 58 | ) 59 | items.append(item) 60 | 61 | return items 62 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/dg/__init__.py: -------------------------------------------------------------------------------- 1 | from .pacs import PACS 2 | from .vlcs import VLCS 3 | from .wilds import * 4 | from .cifar_c import CIFAR10C, CIFAR100C 5 | from .digits_dg import DigitsDG 6 | from .digit_single import DigitSingle 7 | from .office_home_dg import OfficeHomeDG 8 | 
-------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/dg/office_home_dg.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from .digits_dg import DigitsDG 5 | from ..base_dataset import DatasetBase 6 | 7 | 8 | @DATASET_REGISTRY.register() 9 | class OfficeHomeDG(DatasetBase): 10 | """Office-Home. 11 | 12 | Statistics: 13 | - Around 15,500 images. 14 | - 65 classes related to office and home objects. 15 | - 4 domains: Art, Clipart, Product, Real World. 16 | - URL: http://hemanthdv.org/OfficeHome-Dataset/. 17 | 18 | Reference: 19 | - Venkateswara et al. Deep Hashing Network for Unsupervised 20 | Domain Adaptation. CVPR 2017. 21 | """ 22 | 23 | dataset_dir = "office_home_dg" 24 | domains = ["art", "clipart", "product", "real_world"] 25 | data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa" 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | if not osp.exists(self.dataset_dir): 32 | dst = osp.join(root, "office_home_dg.zip") 33 | self.download_data(self.data_url, dst, from_gdrive=True) 34 | 35 | self.check_input_domains( 36 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 37 | ) 38 | 39 | train = DigitsDG.read_data( 40 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train" 41 | ) 42 | val = DigitsDG.read_data( 43 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val" 44 | ) 45 | test = DigitsDG.read_data( 46 | self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all" 47 | ) 48 | 49 | super().__init__(train_x=train, val=val, test=test) 50 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/dg/vlcs.py: -------------------------------------------------------------------------------- 1 | import glob 
2 | import os.path as osp 3 | 4 | from dassl.utils import listdir_nohidden 5 | 6 | from ..build import DATASET_REGISTRY 7 | from ..base_dataset import Datum, DatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class VLCS(DatasetBase): 12 | """VLCS. 13 | 14 | Statistics: 15 | - 4 domains: CALTECH, LABELME, PASCAL, SUN 16 | - 5 categories: bird, car, chair, dog, and person. 17 | 18 | Reference: 19 | - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011. 20 | """ 21 | 22 | dataset_dir = "VLCS" 23 | domains = ["caltech", "labelme", "pascal", "sun"] 24 | data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd" 25 | 26 | def __init__(self, cfg): 27 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 28 | self.dataset_dir = osp.join(root, self.dataset_dir) 29 | 30 | if not osp.exists(self.dataset_dir): 31 | dst = osp.join(root, "vlcs.zip") 32 | self.download_data(self.data_url, dst, from_gdrive=True) 33 | 34 | self.check_input_domains( 35 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 36 | ) 37 | 38 | train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train") 39 | val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval") 40 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test") 41 | 42 | super().__init__(train_x=train, val=val, test=test) 43 | 44 | def _read_data(self, input_domains, split): 45 | items = [] 46 | 47 | for domain, dname in enumerate(input_domains): 48 | dname = dname.upper() 49 | path = osp.join(self.dataset_dir, dname, split) 50 | folders = listdir_nohidden(path) 51 | folders.sort() 52 | 53 | for label, folder in enumerate(folders): 54 | impaths = glob.glob(osp.join(path, folder, "*.jpg")) 55 | 56 | for impath in impaths: 57 | item = Datum(impath=impath, label=label, domain=domain) 58 | items.append(item) 59 | 60 | return items 61 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/dg/wilds/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .fmow import FMoW 2 | from .iwildcam import IWildCam 3 | from .camelyon17 import Camelyon17 4 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/dg/wilds/camelyon17.py: -------------------------------------------------------------------------------- 1 | from dassl.data.datasets import DATASET_REGISTRY 2 | 3 | from .wilds_base import WILDSBase 4 | 5 | 6 | @DATASET_REGISTRY.register() 7 | class Camelyon17(WILDSBase): 8 | """Tumor tissue recognition. 9 | 10 | 2 classes (whether a given region of tissue contains tumor tissue). 11 | 12 | Reference: 13 | - Bandi et al. "From detection of individual metastases to classification of lymph 14 | node status at the patient level: the CAMELYON17 challenge." TMI 2021. 15 | - Koh et al. "Wilds: A benchmark of in-the-wild distribution shifts." ICML 2021. 16 | """ 17 | 18 | dataset_dir = "camelyon17_v1.0" 19 | 20 | def __init__(self, cfg): 21 | super().__init__(cfg) 22 | 23 | def load_classnames(self): 24 | return {0: "healthy tissue", 1: "tumor tissue"} 25 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/dg/wilds/fmow.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY 4 | 5 | from .wilds_base import WILDSBase 6 | 7 | CATEGORIES = [ 8 | "airport", "airport_hangar", "airport_terminal", "amusement_park", 9 | "aquaculture", "archaeological_site", "barn", "border_checkpoint", 10 | "burial_site", "car_dealership", "construction_site", "crop_field", "dam", 11 | "debris_or_rubble", "educational_institution", "electric_substation", 12 | "factory_or_powerplant", "fire_station", "flooded_road", "fountain", 13 | "gas_station", "golf_course", "ground_transportation_station", "helipad", 14 | 
"hospital", "impoverished_settlement", "interchange", "lake_or_pond", 15 | "lighthouse", "military_facility", "multi-unit_residential", 16 | "nuclear_powerplant", "office_building", "oil_or_gas_facility", "park", 17 | "parking_lot_or_garage", "place_of_worship", "police_station", "port", 18 | "prison", "race_track", "railway_bridge", "recreational_facility", 19 | "road_bridge", "runway", "shipyard", "shopping_mall", 20 | "single-unit_residential", "smokestack", "solar_farm", "space_facility", 21 | "stadium", "storage_tank", "surface_mine", "swimming_pool", "toll_booth", 22 | "tower", "tunnel_opening", "waste_disposal", "water_treatment_facility", 23 | "wind_farm", "zoo" 24 | ] 25 | 26 | 27 | @DATASET_REGISTRY.register() 28 | class FMoW(WILDSBase): 29 | """Satellite imagery classification. 30 | 31 | 62 classes (building or land use categories). 32 | 33 | Reference: 34 | - Christie et al. "Functional Map of the World." CVPR 2018. 35 | - Koh et al. "Wilds: A benchmark of in-the-wild distribution shifts." ICML 2021. 
36 | """ 37 | 38 | dataset_dir = "fmow_v1.1" 39 | 40 | def __init__(self, cfg): 41 | super().__init__(cfg) 42 | 43 | def get_image_path(self, dataset, idx): 44 | idx = dataset.full_idxs[idx] 45 | image_name = f"rgb_img_{idx}.png" 46 | image_path = osp.join(self.dataset_dir, "images", image_name) 47 | return image_path 48 | 49 | def get_domain(self, dataset, idx): 50 | # number of regions: 5 or 6 51 | # number of years: 16 52 | region_id = int(dataset.metadata_array[idx][0]) 53 | year_id = int(dataset.metadata_array[idx][1]) 54 | return region_id*16 + year_id 55 | 56 | def load_classnames(self): 57 | return {i: cat for i, cat in enumerate(CATEGORIES)} 58 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/dg/wilds/iwildcam.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pandas as pd 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY 5 | 6 | from .wilds_base import WILDSBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class IWildCam(WILDSBase): 11 | """Animal species recognition. 12 | 13 | 182 classes (species). 14 | 15 | Reference: 16 | - Beery et al. "The iwildcam 2021 competition dataset." arXiv 2021. 17 | - Koh et al. "Wilds: A benchmark of in-the-wild distribution shifts." ICML 2021. 
18 | """ 19 | 20 | dataset_dir = "iwildcam_v2.0" 21 | 22 | def __init__(self, cfg): 23 | super().__init__(cfg) 24 | 25 | def get_image_path(self, dataset, idx): 26 | image_name = dataset._input_array[idx] 27 | image_path = osp.join(self.dataset_dir, "train", image_name) 28 | return image_path 29 | 30 | def load_classnames(self): 31 | df = pd.read_csv(osp.join(self.dataset_dir, "categories.csv")) 32 | return dict(df["name"]) 33 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/ssl/__init__.py: -------------------------------------------------------------------------------- 1 | from .svhn import SVHN 2 | from .cifar import CIFAR10, CIFAR100 3 | from .stl10 import STL10 4 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/data/datasets/ssl/stl10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | 4 | from dassl.utils import listdir_nohidden 5 | 6 | from ..build import DATASET_REGISTRY 7 | from ..base_dataset import Datum, DatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class STL10(DatasetBase): 12 | """STL-10 dataset. 13 | 14 | Description: 15 | - 10 classes: airplane, bird, car, cat, deer, dog, horse, 16 | monkey, ship, truck. 17 | - Images are 96x96 pixels, color. 18 | - 500 training images per class, 800 test images per class. 19 | - 100,000 unlabeled images for unsupervised learning. 20 | 21 | Reference: 22 | - Coates et al. An Analysis of Single Layer Networks in 23 | Unsupervised Feature Learning. AISTATS 2011. 
24 | """ 25 | 26 | dataset_dir = "stl10" 27 | 28 | def __init__(self, cfg): 29 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | train_dir = osp.join(self.dataset_dir, "train") 32 | test_dir = osp.join(self.dataset_dir, "test") 33 | unlabeled_dir = osp.join(self.dataset_dir, "unlabeled") 34 | fold_file = osp.join( 35 | self.dataset_dir, "stl10_binary", "fold_indices.txt" 36 | ) 37 | 38 | # Only use the first five splits 39 | assert 0 <= cfg.DATASET.STL10_FOLD <= 4 40 | 41 | train_x = self._read_data_train( 42 | train_dir, cfg.DATASET.STL10_FOLD, fold_file 43 | ) 44 | train_u = self._read_data_all(unlabeled_dir) 45 | test = self._read_data_all(test_dir) 46 | 47 | if cfg.DATASET.ALL_AS_UNLABELED: 48 | train_u = train_u + train_x 49 | 50 | super().__init__(train_x=train_x, train_u=train_u, test=test) 51 | 52 | def _read_data_train(self, data_dir, fold, fold_file): 53 | imnames = listdir_nohidden(data_dir) 54 | imnames.sort() 55 | items = [] 56 | 57 | list_idx = list(range(len(imnames))) 58 | if fold >= 0: 59 | with open(fold_file, "r") as f: 60 | str_idx = f.read().splitlines()[fold] 61 | list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=" ") 62 | 63 | for i in list_idx: 64 | imname = imnames[i] 65 | impath = osp.join(data_dir, imname) 66 | label = osp.splitext(imname)[0].split("_")[1] 67 | label = int(label) 68 | item = Datum(impath=impath, label=label) 69 | items.append(item) 70 | 71 | return items 72 | 73 | def _read_data_all(self, data_dir): 74 | imnames = listdir_nohidden(data_dir) 75 | items = [] 76 | 77 | for imname in imnames: 78 | impath = osp.join(data_dir, imname) 79 | label = osp.splitext(imname)[0].split("_")[1] 80 | if label == "none": 81 | label = -1 82 | else: 83 | label = int(label) 84 | item = Datum(impath=impath, label=label) 85 | items.append(item) 86 | 87 | return items 88 | -------------------------------------------------------------------------------- 
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/data/datasets/ssl/svhn.py
# --------------------------------------------------------------------------
from .cifar import CIFAR10
from ..build import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SVHN(CIFAR10):
    """SVHN for SSL.

    Reference:
        - Netzer et al. Reading Digits in Natural Images with
          Unsupervised Feature Learning. NIPS-W 2011.
    """

    dataset_dir = "svhn"

    def __init__(self, cfg):
        # Identical pipeline to CIFAR10; only the directory name differs.
        super().__init__(cfg)


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/data/transforms/__init__.py
# --------------------------------------------------------------------------
from .transforms import INTERPOLATION_MODES, build_transform


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/__init__.py
# --------------------------------------------------------------------------
from .build import TRAINER_REGISTRY, build_trainer  # isort:skip
from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet  # isort:skip

from .da import *
from .dg import *
from .ssl import *


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/build.py
# --------------------------------------------------------------------------
from dassl.utils import Registry, check_availability

TRAINER_REGISTRY = Registry("TRAINER")


def build_trainer(cfg):
    """Instantiate the trainer class registered under cfg.TRAINER.NAME."""
    trainer_names = TRAINER_REGISTRY.registered_names()
    check_availability(cfg.TRAINER.NAME, trainer_names)
    if cfg.VERBOSE:
        print("Loading trainer: {}".format(cfg.TRAINER.NAME))
    return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/da/__init__.py
# --------------------------------------------------------------------------
from .se import SE
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .cdac import CDAC
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
from .m3sda import M3SDA
from .source_only import SourceOnly


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/da/adabn.py
# --------------------------------------------------------------------------
import torch

from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU


@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
    """Adaptive Batch Normalization.

    https://arxiv.org/abs/1603.04779.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Flag so BN statistics are reset exactly once.
        self.done_reset_bn_stats = False

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def before_epoch(self):
        # Reset every BatchNorm layer's running stats before the first
        # epoch only; subsequent epochs keep accumulating target stats.
        if self.done_reset_bn_stats:
            return
        for module in self.model.modules():
            if "BatchNorm" in module.__class__.__name__:
                module.reset_running_stats()
        self.done_reset_bn_stats = True

    def forward_backward(self, batch_x, batch_u):
        # Forward target-domain images only; no gradients are needed —
        # the BN layers re-estimate their statistics as a side effect.
        input_u = batch_u["img"].to(self.device)

        with torch.no_grad():
            self.model(input_u)

        return None
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/da/adda.py
# --------------------------------------------------------------------------
import copy
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head


@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
    """Adversarial Discriminative Domain Adaptation.

    https://arxiv.org/abs/1702.05464.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Only the feature extractor (plus the head, if any) is adapted.
        self.open_layers = ["backbone"]
        if isinstance(self.model.head, nn.Module):
            self.open_layers.append("head")

        # Frozen copy of the source-trained model: provides the reference
        # feature distribution the critic discriminates against.
        self.source_model = copy.deepcopy(self.model)
        self.source_model.eval()
        for p in self.source_model.parameters():
            p.requires_grad_(False)

        self.build_critic()

        self.bce = nn.BCEWithLogitsLoss()

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def build_critic(self):
        """Create the domain discriminator with its optimizer/scheduler."""
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        mlp = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim // 2],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(mlp, nn.Linear(fdim // 2, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)

    def forward_backward(self, batch_x, batch_u):
        open_specified_layers(self.model, self.open_layers)
        input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
        # Source images are domain 1, target images domain 0.
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        _, feat_x = self.source_model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        # Step 1: train the critic to separate source from target features
        # (target features detached so only the critic updates).
        logit_xd = self.critic(feat_x)
        logit_ud = self.critic(feat_u.detach())
        loss_critic = self.bce(logit_xd, domain_x) + self.bce(logit_ud, domain_u)
        self.model_backward_and_update(loss_critic, "critic")

        # Step 2: train the target model to fool the critic by flipping
        # the target-domain labels.
        logit_ud = self.critic(feat_u)
        loss_model = self.bce(logit_ud, 1 - domain_u)
        self.model_backward_and_update(loss_model, "model")

        loss_summary = {
            "loss_critic": loss_critic.item(),
            "loss_model": loss_model.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/da/dann.py
# --------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad


@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
    """Domain-Adversarial Neural Networks.

    https://arxiv.org/abs/1505.07818.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.build_critic()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()

    def build_critic(self):
        """Create the domain classifier trained through gradient reversal."""
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        mlp = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(mlp, nn.Linear(fdim, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)
        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        # Source = domain 1, target = domain 0.
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        # Ramp the gradient-reversal coefficient from 0 towards 1 over
        # training, following the schedule of the original paper.
        global_step = self.batch_idx + self.epoch * self.num_batches
        progress = global_step / (self.max_epoch * self.num_batches)
        lmda = 2 / (1 + np.exp(-10 * progress)) - 1

        logit_x, feat_x = self.model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        # Supervised classification loss on labeled source data.
        loss_x = self.ce(logit_x, label_x)

        # Features pass through the reversal layer before the critic, so
        # the backbone is trained to confuse it.
        feat_x = self.revgrad(feat_x, grad_scaling=lmda)
        feat_u = self.revgrad(feat_u, grad_scaling=lmda)
        output_xd = self.critic(feat_x)
        output_ud = self.critic(feat_u)
        loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)

        loss = loss_x + loss_d
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/da/mme.py
# --------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet


class Prototypes(nn.Module):
    """Cosine-similarity classifier: L2-normalized features scored
    against class prototypes, scaled by a temperature."""

    def __init__(self, fdim, num_classes, temp=0.05):
        super().__init__()
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp

    def forward(self, x):
        normed = F.normalize(x, p=2, dim=1)
        return self.prototypes(normed) / self.temp


@TRAINER_REGISTRY.register()
class MME(TrainerXU):
    """Minimax Entropy.

    https://arxiv.org/abs/1904.06487.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.MME.LMDA

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)  # 0 classes: features only
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building C")
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Step 1: supervised update on labeled source data.
        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        # Step 2: adversarial entropy on unlabeled target data; gradient
        # reversal makes F and C play the minimax game of the paper.
        feat_u = self.revgrad(self.F(input_u))
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        # Negative entropy of the target predictions (equivalent to the
        # original -(-p*log p) formulation).
        loss_u = (prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.C(self.F(input))
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/da/se.py
# --------------------------------------------------------------------------
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class SE(TrainerXU):
    """Self-ensembling for visual domain adaptation.

    https://arxiv.org/abs/1706.05208.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
        self.conf_thre = cfg.TRAINER.SE.CONF_THRE
        self.rampup = cfg.TRAINER.SE.RAMPUP

        # Teacher = EMA copy of the student; never updated by gradients.
        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for p in self.teacher.parameters():
            p.requires_grad_(False)

    def check_cfg(self, cfg):
        # Two augmented views per unlabeled image are required.
        assert cfg.DATALOADER.K_TRANSFORMS == 2

    def forward_backward(self, batch_x, batch_u):
        global_step = self.batch_idx + self.epoch * self.num_batches
        input_x, label_x, input_u1, input_u2 = self.parse_batch_train(
            batch_x, batch_u
        )

        # Supervised loss on labeled source data.
        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        # Consistency between student (view 1) and teacher (view 2).
        prob_u = F.softmax(self.model(input_u1), 1)
        t_prob_u = F.softmax(self.teacher(input_u2), 1)
        loss_u = ((prob_u - t_prob_u)**2).sum(1)

        if self.conf_thre:
            # Keep only confidently-predicted unlabeled samples.
            mask = (t_prob_u.max(1)[0] > self.conf_thre).float()
            loss_u = (loss_u * mask).mean()
        else:
            # Otherwise ramp the consistency weight up over training.
            loss_u = loss_u.mean() * sigmoid_rampup(global_step, self.rampup)

        loss = loss_x + loss_u
        self.model_backward_and_update(loss)

        # EMA update of the teacher after each student step.
        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        # Labeled batch carries K views; use only the first one.
        input_x = batch_x["img"][0].to(self.device)
        label_x = batch_x["label"].to(self.device)
        input_u1, input_u2 = batch_u["img"]
        input_u1 = input_u1.to(self.device)
        input_u2 = input_u2.to(self.device)
        return input_x, label_x, input_u1, input_u2
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/da/source_only.py
# --------------------------------------------------------------------------
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
    """Baseline model for domain adaptation, which is
    trained using source data only.
    """

    def forward_backward(self, batch_x, batch_u):
        # batch_u is ignored by design: this is the source-only baseline.
        images, labels = self.parse_batch_train(batch_x, batch_u)
        logits = self.model(images)
        loss = F.cross_entropy(logits, labels)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(logits, labels)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        # Only the labeled source batch is consumed.
        images = batch_x["img"].to(self.device)
        labels = batch_x["label"].to(self.device)
        return images, labels


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/dg/__init__.py
# --------------------------------------------------------------------------
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad
from .domain_mix import DomainMix
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/dg/crossgrad.py
# --------------------------------------------------------------------------
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
    """Cross-gradient training.

    https://arxiv.org/abs/1804.10745.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.eps_f = cfg.TRAINER.CROSSGRAD.EPS_F
        self.eps_d = cfg.TRAINER.CROSSGRAD.EPS_D
        self.alpha_f = cfg.TRAINER.CROSSGRAD.ALPHA_F
        self.alpha_d = cfg.TRAINER.CROSSGRAD.ALPHA_D

    def build_model(self):
        cfg = self.cfg

        # F: label classifier.
        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        # D: domain classifier.
        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        # Gradients w.r.t. the input are needed for the perturbations.
        input.requires_grad = True

        # Perturb the input along the domain classifier's gradient.
        loss_d = F.cross_entropy(self.D(input), domain)
        loss_d.backward()
        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        # NOTE(review): eps_f scales the domain-gradient perturbation and
        # eps_d the label-gradient one, mirroring the reference code —
        # confirm this naming is intentional.
        input_d = input.data + self.eps_f * grad_d

        # Perturb the input along the label classifier's gradient.
        input.grad.data.zero_()
        loss_f = F.cross_entropy(self.F(input), label)
        loss_f.backward()
        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_f = input.data + self.eps_d * grad_f

        input = input.detach()

        # Update label net on the clean batch and the domain-perturbed one.
        loss_f1 = F.cross_entropy(self.F(input), label)
        loss_f2 = F.cross_entropy(self.F(input_d), label)
        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
        self.model_backward_and_update(loss_f, "F")

        # Update domain net on the clean batch and the label-perturbed one.
        loss_d1 = F.cross_entropy(self.D(input), domain)
        loss_d2 = F.cross_entropy(self.D(input_f), domain)
        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        # Inference uses the label classifier only.
        return self.F(input)
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/dg/domain_mix.py
# --------------------------------------------------------------------------
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy

__all__ = ["DomainMix"]


@TRAINER_REGISTRY.register()
class DomainMix(TrainerX):
    """DomainMix.

    Dynamic Domain Generalization.

    https://github.com/MetaVisionLab/DDG
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.mix_type = cfg.TRAINER.DOMAINMIX.TYPE
        self.alpha = cfg.TRAINER.DOMAINMIX.ALPHA
        self.beta = cfg.TRAINER.DOMAINMIX.BETA
        self.dist_beta = torch.distributions.Beta(self.alpha, self.beta)

    def forward_backward(self, batch):
        images, label_a, label_b, lam = self.parse_batch_train(batch)
        output = self.model(images)
        # Convex combination of the two mixed labels' losses.
        loss = lam * F.cross_entropy(
            output, label_a
        ) + (1-lam) * F.cross_entropy(output, label_b)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label_a)[0].item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        images = batch["img"].to(self.device)
        target = batch["label"].to(self.device)
        domain = batch["domain"].to(self.device)
        images, target_a, target_b, lam = self.domain_mix(
            images, target, domain
        )
        return images, target_a, target_b, lam

    def domain_mix(self, x, target, domain):
        """Mix each image with a partner image (random or from a
        different domain) and return both labels plus the mix weight."""
        lam = (
            self.dist_beta.rsample((1, ))
            if self.alpha > 0 else torch.tensor(1)
        ).to(x.device)

        # random shuffle
        perm = torch.randperm(x.size(0), dtype=torch.int64, device=x.device)
        if self.mix_type == "crossdomain":
            domain_list = torch.unique(domain)
            if len(domain_list) > 1:
                # For every domain, draw partners from the other domains
                # (with replacement only if there are too few of them).
                for idx in domain_list:
                    cnt_a = torch.sum(domain == idx)
                    idx_b = (domain != idx).nonzero().squeeze(-1)
                    cnt_b = idx_b.shape[0]
                    perm_b = torch.ones(cnt_b).multinomial(
                        num_samples=cnt_a, replacement=bool(cnt_a > cnt_b)
                    )
                    perm[domain == idx] = idx_b[perm_b]
        elif self.mix_type != "random":
            raise NotImplementedError(
                f"Chooses {'random', 'crossdomain'}, but got {self.mix_type}."
            )
        mixed_x = lam*x + (1-lam) * x[perm, :]
        target_a, target_b = target, target[perm]
        return mixed_x, target_a, target_b, lam


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/dg/vanilla.py
# --------------------------------------------------------------------------
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
    """Vanilla model.

    A.k.a. Empirical Risk Minimization, or ERM.
    """

    def forward_backward(self, batch):
        images, labels = self.parse_batch_train(batch)
        logits = self.model(images)
        loss = F.cross_entropy(logits, labels)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(logits, labels)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        images = batch["img"].to(self.device)
        labels = batch["label"].to(self.device)
        return images, labels
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/ssl/entmin.py
# --------------------------------------------------------------------------
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
    """Entropy Minimization.

    http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Weight of the entropy term.
        self.lmda = cfg.TRAINER.ENTMIN.LMDA

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised loss on labeled data.
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Entropy of the predictions on unlabeled data (eps for log(0)).
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = (-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()

        loss = loss_x + loss_u * self.lmda

        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/ssl/mean_teacher.py
# --------------------------------------------------------------------------
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
    """Mean teacher.

    https://arxiv.org/abs/1703.01780.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MEANTEACHER.WEIGHT_U
        self.ema_alpha = cfg.TRAINER.MEANTEACHER.EMA_ALPHA
        self.rampup = cfg.TRAINER.MEANTEACHER.RAMPUP

        # Teacher = EMA copy of the student; never updated by gradients.
        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for p in self.teacher.parameters():
            p.requires_grad_(False)

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised loss on labeled data.
        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        # Consistency: student predictions should match teacher targets.
        target_u = F.softmax(self.teacher(input_u), 1)
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = ((prob_u - target_u)**2).sum(1).mean()

        # Ramp the consistency weight up over epochs.
        weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        # EMA update of the teacher after each student step.
        global_step = self.batch_idx + self.epoch * self.num_batches
        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/ssl/mixmatch.py
# --------------------------------------------------------------------------
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
    sharpen_prob, create_onehot, linear_rampup, shuffle_index
)


@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
    """MixMatch: A Holistic Approach to Semi-Supervised Learning.

    https://arxiv.org/abs/1905.02249.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
        self.temp = cfg.TRAINER.MIXMATCH.TEMP
        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP

    def check_cfg(self, cfg):
        # Multiple augmented views of each unlabeled image are required.
        assert cfg.DATALOADER.K_TRANSFORMS > 1

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        num_x = input_x.shape[0]

        global_step = self.batch_idx + self.epoch * self.num_batches
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo-labels: average the softmax over the K augmented
        # views, then sharpen with the temperature.
        with torch.no_grad():
            mean_prob = 0
            for view in input_u:
                mean_prob += F.softmax(self.model(view), 1)
            mean_prob /= len(input_u)
            label_u = sharpen_prob(mean_prob, self.temp)
            label_u = torch.cat([label_u] * len(input_u), 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data.
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup each subset with its slice of the shuffled pool.
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )

        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Cross-entropy against soft labels on the labeled mix.
        prob_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(prob_x + 1e-5)).sum(1).mean()

        # Brier-style L2 loss on the unlabeled mix.
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - prob_u)**2).mean()

        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        # Labeled batch carries K views; only the first one is used.
        input_x = batch_x["img"][0].to(self.device)
        label_x = create_onehot(batch_x["label"], self.num_classes)
        label_x = label_x.to(self.device)
        input_u = [view.to(self.device) for view in batch_u["img"]]
        return input_x, label_x, input_u


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/engine/ssl/sup_baseline.py
# --------------------------------------------------------------------------
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
    """Supervised Baseline."""

    def forward_backward(self, batch_x, batch_u):
        # The unlabeled batch is ignored by design.
        images, labels = self.parse_batch_train(batch_x, batch_u)
        logits = self.model(images)
        loss = F.cross_entropy(logits, labels)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(logits, labels)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        images = batch_x["img"].to(self.device)
        labels = batch_x["label"].to(self.device)
        return images, labels


# --------------------------------------------------------------------------
# code/Dassl.pytorch/dassl/evaluation/__init__.py
# --------------------------------------------------------------------------
from .build import build_evaluator, EVALUATOR_REGISTRY  # isort:skip
from .evaluator import EvaluatorBase, Classification 4 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/evaluation/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | EVALUATOR_REGISTRY = Registry("EVALUATOR") 4 | 5 | 6 | def build_evaluator(cfg, **kwargs): 7 | avai_evaluators = EVALUATOR_REGISTRY.registered_names() 8 | check_availability(cfg.TEST.EVALUATOR, avai_evaluators) 9 | if cfg.VERBOSE: 10 | print("Loading evaluator: {}".format(cfg.TEST.EVALUATOR)) 11 | return EVALUATOR_REGISTRY.get(cfg.TEST.EVALUATOR)(cfg, **kwargs) 12 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import compute_accuracy 2 | from .distance import ( 3 | cosine_distance, compute_distance_matrix, euclidean_squared_distance 4 | ) 5 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | def compute_accuracy(output, target, topk=(1, )): 2 | """Computes the accuracy over the k top predictions for 3 | the specified values of k. 4 | 5 | Args: 6 | output (torch.Tensor): prediction matrix with shape (batch_size, num_classes). 7 | target (torch.LongTensor): ground truth labels with shape (batch_size). 8 | topk (tuple, optional): accuracy at top-k will be computed. For example, 9 | topk=(1, 5) means accuracy at top-1 and top-5 will be computed. 10 | 11 | Returns: 12 | list: accuracy at top-k. 
13 | """ 14 | maxk = max(topk) 15 | batch_size = target.size(0) 16 | 17 | if isinstance(output, (tuple, list)): 18 | output = output[0] 19 | 20 | _, pred = output.topk(maxk, 1, True, True) 21 | pred = pred.t() 22 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 23 | 24 | res = [] 25 | for k in topk: 26 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 27 | acc = correct_k.mul_(100.0 / batch_size) 28 | res.append(acc) 29 | 30 | return res 31 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/metrics/distance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source: https://github.com/KaiyangZhou/deep-person-reid 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | def compute_distance_matrix(input1, input2, metric="euclidean"): 9 | """A wrapper function for computing distance matrix. 10 | 11 | Each input matrix has the shape (n_data, feature_dim). 12 | 13 | Args: 14 | input1 (torch.Tensor): 2-D feature matrix. 15 | input2 (torch.Tensor): 2-D feature matrix. 16 | metric (str, optional): "euclidean" or "cosine". 17 | Default is "euclidean". 18 | 19 | Returns: 20 | torch.Tensor: distance matrix. 21 | """ 22 | # check input 23 | assert isinstance(input1, torch.Tensor) 24 | assert isinstance(input2, torch.Tensor) 25 | assert input1.dim() == 2, "Expected 2-D tensor, but got {}-D".format( 26 | input1.dim() 27 | ) 28 | assert input2.dim() == 2, "Expected 2-D tensor, but got {}-D".format( 29 | input2.dim() 30 | ) 31 | assert input1.size(1) == input2.size(1) 32 | 33 | if metric == "euclidean": 34 | distmat = euclidean_squared_distance(input1, input2) 35 | elif metric == "cosine": 36 | distmat = cosine_distance(input1, input2) 37 | else: 38 | raise ValueError( 39 | "Unknown distance metric: {}. 
" 40 | 'Please choose either "euclidean" or "cosine"'.format(metric) 41 | ) 42 | 43 | return distmat 44 | 45 | 46 | def euclidean_squared_distance(input1, input2): 47 | """Computes euclidean squared distance. 48 | 49 | Args: 50 | input1 (torch.Tensor): 2-D feature matrix. 51 | input2 (torch.Tensor): 2-D feature matrix. 52 | 53 | Returns: 54 | torch.Tensor: distance matrix. 55 | """ 56 | m, n = input1.size(0), input2.size(0) 57 | mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) 58 | mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 59 | distmat = mat1 + mat2 60 | distmat.addmm_(1, -2, input1, input2.t()) 61 | return distmat 62 | 63 | 64 | def cosine_distance(input1, input2): 65 | """Computes cosine distance. 66 | 67 | Args: 68 | input1 (torch.Tensor): 2-D feature matrix. 69 | input2 (torch.Tensor): 2-D feature matrix. 70 | 71 | Returns: 72 | torch.Tensor: distance matrix. 73 | """ 74 | input1_normed = F.normalize(input1, p=2, dim=1) 75 | input2_normed = F.normalize(input2, p=2, dim=1) 76 | distmat = 1 - torch.mm(input1_normed, input2_normed.t()) 77 | return distmat 78 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .head import HEAD_REGISTRY, build_head 2 | from .network import NETWORK_REGISTRY, build_network 3 | from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone 4 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_backbone, BACKBONE_REGISTRY # isort:skip 2 | from .backbone import Backbone # isort:skip 3 | 4 | from .vgg import vgg16 5 | from .resnet import ( 6 | resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1, 7 | resnet50_ms_l1, 
resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1, 8 | resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123, 9 | resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12, 10 | resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123, 11 | resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123 12 | ) 13 | from .alexnet import alexnet 14 | from .wide_resnet import wide_resnet_16_4, wide_resnet_28_2 15 | from .cnn_digitsdg import cnn_digitsdg 16 | from .efficientnet import ( 17 | efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, 18 | efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7 19 | ) 20 | from .resnet_dynamic import * 21 | from .cnn_digitsingle import cnn_digitsingle 22 | from .preact_resnet18 import preact_resnet18 23 | from .cnn_digit5_m3sda import cnn_digit5_m3sda 24 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/backbone/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from .build import BACKBONE_REGISTRY 6 | from .backbone import Backbone 7 | 8 | model_urls = { 9 | "alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth", 10 | } 11 | 12 | 13 | class AlexNet(Backbone): 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | 
nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 33 | # Note that self.classifier outputs features rather than logits 34 | self.classifier = nn.Sequential( 35 | nn.Dropout(), 36 | nn.Linear(256 * 6 * 6, 4096), 37 | nn.ReLU(inplace=True), 38 | nn.Dropout(), 39 | nn.Linear(4096, 4096), 40 | nn.ReLU(inplace=True), 41 | ) 42 | 43 | self._out_features = 4096 44 | 45 | def forward(self, x): 46 | x = self.features(x) 47 | x = self.avgpool(x) 48 | x = torch.flatten(x, 1) 49 | return self.classifier(x) 50 | 51 | 52 | def init_pretrained_weights(model, model_url): 53 | pretrain_dict = model_zoo.load_url(model_url) 54 | model.load_state_dict(pretrain_dict, strict=False) 55 | 56 | 57 | @BACKBONE_REGISTRY.register() 58 | def alexnet(pretrained=True, **kwargs): 59 | model = AlexNet() 60 | 61 | if pretrained: 62 | init_pretrained_weights(model, model_urls["alexnet"]) 63 | 64 | return model 65 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Backbone(nn.Module): 5 | 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self): 10 | pass 11 | 12 | @property 13 | def out_features(self): 14 | """Output feature dimension.""" 15 | if self.__dict__.get("_out_features") is None: 16 | return None 17 | return self._out_features 18 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/backbone/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | BACKBONE_REGISTRY = Registry("BACKBONE") 4 | 5 | 6 | def build_backbone(name, verbose=True, **kwargs): 7 | avai_backbones = BACKBONE_REGISTRY.registered_names() 8 | check_availability(name, avai_backbones) 9 | if 
verbose: 10 | print("Backbone: {}".format(name)) 11 | return BACKBONE_REGISTRY.get(name)(**kwargs) 12 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference 3 | 4 | https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA 5 | """ 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | 9 | from .build import BACKBONE_REGISTRY 10 | from .backbone import Backbone 11 | 12 | 13 | class FeatureExtractor(Backbone): 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2) 18 | self.bn1 = nn.BatchNorm2d(64) 19 | self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2) 20 | self.bn2 = nn.BatchNorm2d(64) 21 | self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2) 22 | self.bn3 = nn.BatchNorm2d(128) 23 | self.fc1 = nn.Linear(8192, 3072) 24 | self.bn1_fc = nn.BatchNorm1d(3072) 25 | self.fc2 = nn.Linear(3072, 2048) 26 | self.bn2_fc = nn.BatchNorm1d(2048) 27 | 28 | self._out_features = 2048 29 | 30 | def _check_input(self, x): 31 | H, W = x.shape[2:] 32 | assert ( 33 | H == 32 and W == 32 34 | ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) 35 | 36 | def forward(self, x): 37 | self._check_input(x) 38 | x = F.relu(self.bn1(self.conv1(x))) 39 | x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1) 40 | x = F.relu(self.bn2(self.conv2(x))) 41 | x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1) 42 | x = F.relu(self.bn3(self.conv3(x))) 43 | x = x.view(x.size(0), 8192) 44 | x = F.relu(self.bn1_fc(self.fc1(x))) 45 | x = F.dropout(x, training=self.training) 46 | x = F.relu(self.bn2_fc(self.fc2(x))) 47 | return x 48 | 49 | 50 | @BACKBONE_REGISTRY.register() 51 | def cnn_digit5_m3sda(**kwargs): 52 | """ 53 | This architecture was used for 
the Digit-5 dataset in: 54 | 55 | - Peng et al. Moment Matching for Multi-Source 56 | Domain Adaptation. ICCV 2019. 57 | """ 58 | return FeatureExtractor() 59 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/backbone/cnn_digitsdg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | 4 | from dassl.utils import init_network_weights 5 | 6 | from .build import BACKBONE_REGISTRY 7 | from .backbone import Backbone 8 | 9 | 10 | class Convolution(nn.Module): 11 | 12 | def __init__(self, c_in, c_out): 13 | super().__init__() 14 | self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1) 15 | self.relu = nn.ReLU(True) 16 | 17 | def forward(self, x): 18 | return self.relu(self.conv(x)) 19 | 20 | 21 | class ConvNet(Backbone): 22 | 23 | def __init__(self, c_hidden=64): 24 | super().__init__() 25 | self.conv1 = Convolution(3, c_hidden) 26 | self.conv2 = Convolution(c_hidden, c_hidden) 27 | self.conv3 = Convolution(c_hidden, c_hidden) 28 | self.conv4 = Convolution(c_hidden, c_hidden) 29 | 30 | self._out_features = 2**2 * c_hidden 31 | 32 | def _check_input(self, x): 33 | H, W = x.shape[2:] 34 | assert ( 35 | H == 32 and W == 32 36 | ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) 37 | 38 | def forward(self, x): 39 | self._check_input(x) 40 | x = self.conv1(x) 41 | x = F.max_pool2d(x, 2) 42 | x = self.conv2(x) 43 | x = F.max_pool2d(x, 2) 44 | x = self.conv3(x) 45 | x = F.max_pool2d(x, 2) 46 | x = self.conv4(x) 47 | x = F.max_pool2d(x, 2) 48 | return x.view(x.size(0), -1) 49 | 50 | 51 | @BACKBONE_REGISTRY.register() 52 | def cnn_digitsdg(**kwargs): 53 | """ 54 | This architecture was used for DigitsDG dataset in: 55 | 56 | - Zhou et al. Deep Domain-Adversarial Image Generation 57 | for Domain Generalisation. AAAI 2020. 
58 | """ 59 | model = ConvNet(c_hidden=64) 60 | init_network_weights(model, init_type="kaiming") 61 | return model 62 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/backbone/cnn_digitsingle.py: -------------------------------------------------------------------------------- 1 | """ 2 | This model is built based on 3 | https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py 4 | """ 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | from dassl.utils import init_network_weights 9 | 10 | from .build import BACKBONE_REGISTRY 11 | from .backbone import Backbone 12 | 13 | 14 | class CNN(Backbone): 15 | 16 | def __init__(self): 17 | super().__init__() 18 | self.conv1 = nn.Conv2d(3, 64, 5) 19 | self.conv2 = nn.Conv2d(64, 128, 5) 20 | self.fc3 = nn.Linear(5 * 5 * 128, 1024) 21 | self.fc4 = nn.Linear(1024, 1024) 22 | 23 | self._out_features = 1024 24 | 25 | def _check_input(self, x): 26 | H, W = x.shape[2:] 27 | assert ( 28 | H == 32 and W == 32 29 | ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) 30 | 31 | def forward(self, x): 32 | self._check_input(x) 33 | x = self.conv1(x) 34 | x = F.relu(x) 35 | x = F.max_pool2d(x, 2) 36 | 37 | x = self.conv2(x) 38 | x = F.relu(x) 39 | x = F.max_pool2d(x, 2) 40 | 41 | x = x.view(x.size(0), -1) 42 | 43 | x = self.fc3(x) 44 | x = F.relu(x) 45 | 46 | x = self.fc4(x) 47 | x = F.relu(x) 48 | 49 | return x 50 | 51 | 52 | @BACKBONE_REGISTRY.register() 53 | def cnn_digitsingle(**kwargs): 54 | model = CNN() 55 | init_network_weights(model, init_type="kaiming") 56 | return model 57 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/backbone/efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source: https://github.com/lukemelas/EfficientNet-PyTorch. 
3 | """ 4 | __version__ = "0.6.4" 5 | from .model import ( 6 | EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2, 7 | efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, 8 | efficientnet_b7 9 | ) 10 | from .utils import ( 11 | BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params 12 | ) 13 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_head, HEAD_REGISTRY # isort:skip 2 | 3 | from .mlp import mlp 4 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/head/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | HEAD_REGISTRY = Registry("HEAD") 4 | 5 | 6 | def build_head(name, verbose=True, **kwargs): 7 | avai_heads = HEAD_REGISTRY.registered_names() 8 | check_availability(name, avai_heads) 9 | if verbose: 10 | print("Head: {}".format(name)) 11 | return HEAD_REGISTRY.get(name)(**kwargs) 12 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/head/mlp.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | from .build import HEAD_REGISTRY 5 | 6 | 7 | class MLP(nn.Module): 8 | 9 | def __init__( 10 | self, 11 | in_features=2048, 12 | hidden_layers=[], 13 | activation="relu", 14 | bn=True, 15 | dropout=0.0, 16 | ): 17 | super().__init__() 18 | if isinstance(hidden_layers, int): 19 | hidden_layers = [hidden_layers] 20 | 21 | assert len(hidden_layers) > 0 22 | self.out_features = hidden_layers[-1] 23 | 24 | mlp = [] 25 | 26 | if activation == "relu": 27 | act_fn = functools.partial(nn.ReLU, inplace=True) 28 | elif activation 
== "leaky_relu": 29 | act_fn = functools.partial(nn.LeakyReLU, inplace=True) 30 | else: 31 | raise NotImplementedError 32 | 33 | for hidden_dim in hidden_layers: 34 | mlp += [nn.Linear(in_features, hidden_dim)] 35 | if bn: 36 | mlp += [nn.BatchNorm1d(hidden_dim)] 37 | mlp += [act_fn()] 38 | if dropout > 0: 39 | mlp += [nn.Dropout(dropout)] 40 | in_features = hidden_dim 41 | 42 | self.mlp = nn.Sequential(*mlp) 43 | 44 | def forward(self, x): 45 | return self.mlp(x) 46 | 47 | 48 | @HEAD_REGISTRY.register() 49 | def mlp(**kwargs): 50 | return MLP(**kwargs) 51 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_network, NETWORK_REGISTRY # isort:skip 2 | 3 | from .ddaig_fcn import ( 4 | fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn 5 | ) 6 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/network/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | NETWORK_REGISTRY = Registry("NETWORK") 4 | 5 | 6 | def build_network(name, verbose=True, **kwargs): 7 | avai_models = NETWORK_REGISTRY.registered_names() 8 | check_availability(name, avai_models) 9 | if verbose: 10 | print("Network: {}".format(name)) 11 | return NETWORK_REGISTRY.get(name)(**kwargs) 12 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .mmd import MaximumMeanDiscrepancy 2 | from .conv import * 3 | from .dsbn import DSBN1d, DSBN2d 4 | from .mixup import mixup 5 | from .efdmix import ( 6 | EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, 
deactivate_efdmix, 7 | crossdomain_efdmix, run_without_efdmix 8 | ) 9 | from .mixstyle import ( 10 | MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle, 11 | deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle 12 | ) 13 | from .attention import * 14 | from .transnorm import TransNorm1d, TransNorm2d 15 | from .sequential2 import Sequential2 16 | from .reverse_grad import ReverseGrad 17 | from .cross_entropy import cross_entropy 18 | from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance 19 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | 4 | __all__ = ["Attention"] 5 | 6 | 7 | class Attention(nn.Module): 8 | """Attention from `"Dynamic Domain Generalization" `_. 9 | """ 10 | 11 | def __init__( 12 | self, 13 | in_channels: int, 14 | out_features: int, 15 | squeeze=None, 16 | bias: bool = True 17 | ): 18 | super(Attention, self).__init__() 19 | self.squeeze = squeeze if squeeze else in_channels // 16 20 | assert self.squeeze > 0 21 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 22 | self.fc1 = nn.Linear(in_channels, self.squeeze, bias=bias) 23 | self.fc2 = nn.Linear(self.squeeze, out_features, bias=bias) 24 | self.sf = nn.Softmax(dim=-1) 25 | 26 | def forward(self, x): 27 | x = self.avg_pool(x).view(x.shape[:-2]) 28 | x = self.fc1(x) 29 | x = F.relu(x, inplace=True) 30 | x = self.fc2(x) 31 | return self.sf(x) 32 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/conv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .attention import Attention 4 | 5 | __all__ = ["Conv2dDynamic"] 6 | 7 | 8 | class Conv2dDynamic(nn.Module): 9 | """Conv2dDynamic from 
`"Dynamic Domain Generalization" `_. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | in_channels: int, 15 | out_channels: int, 16 | kernel_size: int, 17 | stride: int, 18 | padding: int, 19 | bias: bool = True, 20 | squeeze: int = None, 21 | attention_in_channels: int = None 22 | ) -> None: 23 | super(Conv2dDynamic, self).__init__() 24 | 25 | if kernel_size // 2 != padding: 26 | # Only when this condition is met, we can ensure that different 27 | # kernel_size can obtain feature maps of consistent size. 28 | # Let I, K, S, P, O: O = (I + 2P - K) // S + 1, if P = K // 2, then O = (I - K % 2) // S + 1 29 | # This means that the output of two different Ks with the same parity can be made the same by adjusting P. 30 | raise ValueError("`padding` must be equal to `kernel_size // 2`.") 31 | if kernel_size % 2 == 0: 32 | raise ValueError( 33 | "Kernel_size must be odd now because the templates we used are odd (kernel_size=1)." 34 | ) 35 | 36 | self.conv = nn.Conv2d( 37 | in_channels, 38 | out_channels, 39 | kernel_size=kernel_size, 40 | stride=stride, 41 | padding=padding, 42 | bias=bias 43 | ) 44 | self.kernel_templates = nn.ModuleDict() 45 | self.kernel_templates["conv_nn"] = nn.Conv2d( 46 | in_channels, 47 | out_channels, 48 | kernel_size=kernel_size, 49 | stride=stride, 50 | padding=padding, 51 | groups=min(in_channels, out_channels), 52 | bias=bias 53 | ) 54 | self.kernel_templates["conv_11"] = nn.Conv2d( 55 | in_channels, 56 | out_channels, 57 | kernel_size=1, 58 | stride=stride, 59 | padding=0, 60 | bias=bias 61 | ) 62 | self.kernel_templates["conv_n1"] = nn.Conv2d( 63 | in_channels, 64 | out_channels, 65 | kernel_size=(kernel_size, 1), 66 | stride=stride, 67 | padding=(padding, 0), 68 | bias=bias 69 | ) 70 | self.kernel_templates["conv_1n"] = nn.Conv2d( 71 | in_channels, 72 | out_channels, 73 | kernel_size=(1, kernel_size), 74 | stride=stride, 75 | padding=(0, padding), 76 | bias=bias 77 | ) 78 | self.attention = Attention( 79 | attention_in_channels if 
attention_in_channels else in_channels, 80 | 4, 81 | squeeze, 82 | bias=bias 83 | ) 84 | 85 | def forward(self, x, attention_x=None): 86 | attention_x = x if attention_x is None else attention_x 87 | y = self.attention(attention_x) 88 | 89 | out = self.conv(x) 90 | 91 | for i, template in enumerate(self.kernel_templates): 92 | out += self.kernel_templates[template](x) * y[:, 93 | i].view(-1, 1, 1, 1) 94 | 95 | return out 96 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | def cross_entropy(input, target, label_smooth=0, reduction="mean"): 6 | """Cross entropy loss. 7 | 8 | Args: 9 | input (torch.Tensor): logit matrix with shape of (batch, num_classes). 10 | target (torch.LongTensor): int label matrix. 11 | label_smooth (float, optional): label smoothing hyper-parameter. 12 | Default is 0. 13 | reduction (str, optional): how the losses for a mini-batch 14 | will be aggregated. Default is 'mean'. 15 | """ 16 | num_classes = input.shape[1] 17 | log_prob = F.log_softmax(input, dim=1) 18 | zeros = torch.zeros(log_prob.size()) 19 | target = zeros.scatter_(1, target.unsqueeze(1).data.cpu(), 1) 20 | target = target.type_as(input) 21 | target = (1-label_smooth) * target + label_smooth/num_classes 22 | loss = (-target * log_prob).sum(1) 23 | if reduction == "mean": 24 | return loss.mean() 25 | elif reduction == "sum": 26 | return loss.sum() 27 | elif reduction == "none": 28 | return loss 29 | else: 30 | raise ValueError 31 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/dsbn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class _DSBN(nn.Module): 5 | """Domain Specific Batch Normalization. 
6 | 7 | Args: 8 | num_features (int): number of features. 9 | n_domain (int): number of domains. 10 | bn_type (str): type of bn. Choices are ['1d', '2d']. 11 | """ 12 | 13 | def __init__(self, num_features, n_domain, bn_type): 14 | super().__init__() 15 | if bn_type == "1d": 16 | BN = nn.BatchNorm1d 17 | elif bn_type == "2d": 18 | BN = nn.BatchNorm2d 19 | else: 20 | raise ValueError 21 | 22 | self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain)) 23 | 24 | self.valid_domain_idxs = list(range(n_domain)) 25 | self.n_domain = n_domain 26 | self.domain_idx = 0 27 | 28 | def select_bn(self, domain_idx=0): 29 | assert domain_idx in self.valid_domain_idxs 30 | self.domain_idx = domain_idx 31 | 32 | def forward(self, x): 33 | return self.bn[self.domain_idx](x) 34 | 35 | 36 | class DSBN1d(_DSBN): 37 | 38 | def __init__(self, num_features, n_domain): 39 | super().__init__(num_features, n_domain, "1d") 40 | 41 | 42 | class DSBN2d(_DSBN): 43 | 44 | def __init__(self, num_features, n_domain): 45 | super().__init__(num_features, n_domain, "2d") 46 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mixup(x1, x2, y1, y2, beta, preserve_order=False): 5 | """Mixup. 6 | 7 | Args: 8 | x1 (torch.Tensor): data with shape of (b, c, h, w). 9 | x2 (torch.Tensor): data with shape of (b, c, h, w). 10 | y1 (torch.Tensor): label with shape of (b, n). 11 | y2 (torch.Tensor): label with shape of (b, n). 12 | beta (float): hyper-parameter for Beta sampling. 13 | preserve_order (bool): apply lmda=max(lmda, 1-lmda). 14 | Default is False. 
15 | """ 16 | lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1]) 17 | if preserve_order: 18 | lmda = torch.max(lmda, 1 - lmda) 19 | lmda = lmda.to(x1.device) 20 | xmix = x1*lmda + x2 * (1-lmda) 21 | lmda = lmda[:, :, 0, 0] 22 | ymix = y1*lmda + y2 * (1-lmda) 23 | return xmix, ymix 24 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class MaximumMeanDiscrepancy(nn.Module): 7 | 8 | def __init__(self, kernel_type="rbf", normalize=False): 9 | super().__init__() 10 | self.kernel_type = kernel_type 11 | self.normalize = normalize 12 | 13 | def forward(self, x, y): 14 | # x, y: two batches of data with shape (batch, dim) 15 | # MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y') 16 | if self.normalize: 17 | x = F.normalize(x, dim=1) 18 | y = F.normalize(y, dim=1) 19 | if self.kernel_type == "linear": 20 | return self.linear_mmd(x, y) 21 | elif self.kernel_type == "poly": 22 | return self.poly_mmd(x, y) 23 | elif self.kernel_type == "rbf": 24 | return self.rbf_mmd(x, y) 25 | else: 26 | raise NotImplementedError 27 | 28 | def linear_mmd(self, x, y): 29 | # k(x, y) = x^T y 30 | k_xx = self.remove_self_distance(torch.mm(x, x.t())) 31 | k_yy = self.remove_self_distance(torch.mm(y, y.t())) 32 | k_xy = torch.mm(x, y.t()) 33 | return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() 34 | 35 | def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2): 36 | # k(x, y) = (alpha * x^T y + c)^d 37 | k_xx = self.remove_self_distance(torch.mm(x, x.t())) 38 | k_xx = (alpha*k_xx + c).pow(d) 39 | k_yy = self.remove_self_distance(torch.mm(y, y.t())) 40 | k_yy = (alpha*k_yy + c).pow(d) 41 | k_xy = torch.mm(x, y.t()) 42 | k_xy = (alpha*k_xy + c).pow(d) 43 | return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() 44 | 45 | def rbf_mmd(self, x, 
y): 46 | # k_xx 47 | d_xx = self.euclidean_squared_distance(x, x) 48 | d_xx = self.remove_self_distance(d_xx) 49 | k_xx = self.rbf_kernel_mixture(d_xx) 50 | # k_yy 51 | d_yy = self.euclidean_squared_distance(y, y) 52 | d_yy = self.remove_self_distance(d_yy) 53 | k_yy = self.rbf_kernel_mixture(d_yy) 54 | # k_xy 55 | d_xy = self.euclidean_squared_distance(x, y) 56 | k_xy = self.rbf_kernel_mixture(d_xy) 57 | return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() 58 | 59 | @staticmethod 60 | def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]): 61 | K = 0 62 | for sigma in sigmas: 63 | gamma = 1.0 / (2.0 * sigma**2) 64 | K += torch.exp(-gamma * exponent) 65 | return K 66 | 67 | @staticmethod 68 | def remove_self_distance(distmat): 69 | tmp_list = [] 70 | for i, row in enumerate(distmat): 71 | row1 = torch.cat([row[:i], row[i + 1:]]) 72 | tmp_list.append(row1) 73 | return torch.stack(tmp_list) 74 | 75 | @staticmethod 76 | def euclidean_squared_distance(x, y): 77 | m, n = x.size(0), y.size(0) 78 | distmat = ( 79 | torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + 80 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 81 | ) 82 | # distmat.addmm_(1, -2, x, y.t()) 83 | distmat.addmm_(x, y.t(), beta=1, alpha=-2) 84 | return distmat 85 | 86 | 87 | if __name__ == "__main__": 88 | mmd = MaximumMeanDiscrepancy(kernel_type="rbf") 89 | input1, input2 = torch.rand(3, 100), torch.rand(3, 100) 90 | d = mmd(input1, input2) 91 | print(d.item()) 92 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/reverse_grad.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Function 3 | 4 | 5 | class _ReverseGrad(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input, grad_scaling): 9 | ctx.grad_scaling = grad_scaling 10 | return input.view_as(input) 11 | 12 | @staticmethod 13 | def backward(ctx, grad_output): 14 | 
grad_scaling = ctx.grad_scaling 15 | return -grad_scaling * grad_output, None 16 | 17 | 18 | reverse_grad = _ReverseGrad.apply 19 | 20 | 21 | class ReverseGrad(nn.Module): 22 | """Gradient reversal layer. 23 | 24 | It acts as an identity layer in the forward, 25 | but reverses the sign of the gradient in 26 | the backward. 27 | """ 28 | 29 | def forward(self, x, grad_scaling=1.0): 30 | assert (grad_scaling >= 31 | 0), "grad_scaling must be non-negative, " "but got {}".format( 32 | grad_scaling 33 | ) 34 | return reverse_grad(x, grad_scaling) 35 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/sequential2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Sequential2(nn.Sequential): 5 | """An alternative sequential container to nn.Sequential, 6 | which accepts an arbitrary number of input arguments. 7 | """ 8 | 9 | def forward(self, *inputs): 10 | for module in self._modules.values(): 11 | if isinstance(inputs, tuple): 12 | inputs = module(*inputs) 13 | else: 14 | inputs = module(inputs) 15 | return inputs 16 | -------------------------------------------------------------------------------- /code/Dassl.pytorch/dassl/modeling/ops/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def sharpen_prob(p, temperature=2): 6 | """Sharpening probability with a temperature. 7 | 8 | Args: 9 | p (torch.Tensor): probability matrix (batch_size, n_classes) 10 | temperature (float): temperature. 
def create_onehot(label, num_classes):
    """Create one-hot tensor.

    We suggest using nn.functional.one_hot.

    Args:
        label (torch.Tensor): 1-D tensor of integer class indices.
        num_classes (int): number of classes.
    """
    batch = label.shape[0]
    encoded = torch.zeros(batch, num_classes)
    rows = torch.arange(batch)
    # Index on CPU so the fill works regardless of the label device,
    # then move the result back to where the labels live.
    encoded[rows, label.data.cpu()] = 1
    return encoded.to(label.device)
class Logger:
    """Mirror console output into an external text file.

    Imported from ``<https://github.com/Cysu/open-reid>``_

    Args:
        fpath (str): path of the logging file (optional). When given, the
            parent directory is created on demand and everything written
            to this object is also appended to that file.

    Examples::
        >>> import sys
        >>> import os.path as osp
        >>> save_dir = 'output/experiment-1'
        >>> log_name = 'train.log'
        >>> sys.stdout = Logger(osp.join(save_dir, log_name))
    """

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(osp.dirname(fpath))
            self.file = open(fpath, "w")

    def __del__(self):
        self.close()

    def __enter__(self):
        # Bug fix: return self so ``with Logger(...) as log:`` binds the
        # logger (the original returned None, making the context manager
        # useless).
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            # Force the OS to persist the log even on abrupt termination.
            os.fsync(self.file.fileno())

    def close(self):
        # Bug fix: do NOT close ``self.console`` — it is the process-wide
        # sys.stdout, and closing it would raise ValueError on every later
        # print(). Only the private log file is owned by this object.
        if self.file is not None:
            self.file.close()
Update meter after every mini-batch update 14 | >>> losses.update(loss_value, batch_size) 15 | """ 16 | 17 | def __init__(self, ema=False): 18 | """ 19 | Args: 20 | ema (bool, optional): apply exponential moving average. 21 | """ 22 | self.ema = ema 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | if isinstance(val, torch.Tensor): 33 | val = val.item() 34 | 35 | self.val = val 36 | self.sum += val * n 37 | self.count += n 38 | 39 | if self.ema: 40 | self.avg = self.avg * 0.9 + self.val * 0.1 41 | else: 42 | self.avg = self.sum / self.count 43 | 44 | 45 | class MetricMeter: 46 | """Store the average and current value for a set of metrics. 47 | 48 | Examples:: 49 | >>> # 1. Create an instance of MetricMeter 50 | >>> metric = MetricMeter() 51 | >>> # 2. Update using a dictionary as input 52 | >>> input_dict = {'loss_1': value_1, 'loss_2': value_2} 53 | >>> metric.update(input_dict) 54 | >>> # 3. 
class Registry:
    """A registry providing name -> object mapping, to support
    custom modules.

    Create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    Register an object with the decorator form:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone(nn.Module):
            ...

    or with a plain call:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, force=False):
        # Refuse duplicates unless the caller explicitly forces the
        # overwrite.
        duplicate = name in self._obj_map
        if duplicate and not force:
            raise KeyError(
                'An object named "{}" was already '
                'registered in "{}" registry'.format(name, self._name)
            )
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        if obj is not None:
            # Plain function-call form: register immediately.
            self._do_register(obj.__name__, obj, force=force)
            return

        # Decorator form: defer registration until the wrapped target
        # is defined, then hand it back unchanged.
        def wrapper(fn_or_class):
            self._do_register(fn_or_class.__name__, fn_or_class, force=force)
            return fn_or_class

        return wrapper

    def get(self, name):
        if name in self._obj_map:
            return self._obj_map[name]
        raise KeyError(
            'Object name "{}" does not exist '
            'in "{}" registry'.format(name, self._name)
        )

    def registered_names(self):
        return list(self._obj_map.keys())
def extract_and_save_image(dataset, save_dir, discard, label2name):
    """Dump a dataset to disk as per-class jpg folders, skipping one
    discarded class and remapping labels to the shared CIFAR/STL space.

    Args:
        dataset: indexable dataset yielding (image, int label) pairs;
            images must support ``.save(path)`` (e.g. PIL images —
            presumably torchvision datasets without a transform; confirm
            against callers).
        save_dir (str): destination folder; extraction is skipped entirely
            if it already exists.
        discard (int): original label to drop (the class with no
            counterpart in the other domain, e.g. frog/monkey).
        label2name (dict): original label -> class-name mapping.
    """
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    for i in range(len(dataset)):
        img, label = dataset[i]
        if label == discard:
            continue
        class_name = label2name[label]
        # Remap through the shared 9-class name space (new_name2label) so
        # both domains agree on label indices.
        label_new = new_name2label[class_name]
        # Folder layout: <save_dir>/<zero-padded new label>_<class name>/
        class_dir = osp.join(
            save_dir,
            str(label_new).zfill(3) + "_" + class_name
        )
        mkdir_if_missing(class_dir)
        impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg")
        img.save(impath)
# ------------------------------------------------------------------------
# ROOT is the root directory where you put your domain datasets.
#
# Suppose you wanna put the dataset under $DATA, which stores all the
# domain datasets, run the following command in your terminal to
# download VisDa17:
#
# $ sh visda17.sh $DATA
#------------------------------------------------------------------------

ROOT=$1

# Fail early with a usage hint when no root directory was supplied.
if [ -z "$ROOT" ]; then
    echo "Usage: sh visda17.sh ROOT" >&2
    exit 1
fi

# -p tolerates an already-existing directory; quoting protects paths
# that contain spaces. Abort if we cannot enter the target folder,
# otherwise the archives would be downloaded into the wrong place.
mkdir -p "$ROOT/visda17"
cd "$ROOT/visda17" || exit 1

wget http://csr.bu.edu/ftp/visda17/clf/train.tar
tar xvf train.tar

wget http://csr.bu.edu/ftp/visda17/clf/validation.tar
tar xvf validation.tar

wget http://csr.bu.edu/ftp/visda17/clf/test.tar
tar xvf test.tar

wget https://raw.githubusercontent.com/VisionLearningGroup/taskcv-2017-public/master/classification/data/image_list.txt -O test/image_list.txt
def main(npy_folder):
    """Convert a CIFAR-10-C / CIFAR-100-C .npy dump into per-corruption,
    per-severity jpg image folders next to the source folder.

    Args:
        npy_folder (str): path to the "CIFAR-10-C" or "CIFAR-100-C"
            folder containing one .npy file per corruption plus
            labels.npy.
    """
    npy_folder = osp.abspath(osp.expanduser(npy_folder))
    dataset_cap = osp.basename(npy_folder)

    assert dataset_cap in ["CIFAR-10-C", "CIFAR-100-C"]

    if dataset_cap == "CIFAR-10-C":
        dataset = "cifar10_c"
    else:
        dataset = "cifar100_c"

    if not osp.exists(npy_folder):
        print('The given folder "{}" does not exist'.format(npy_folder))
        # Bug fix: the original fell through after printing and later
        # crashed inside np.load with a confusing error; bail out here.
        return

    root = osp.dirname(npy_folder)
    im_folder = osp.join(root, dataset)

    mkdir_if_missing(im_folder)

    dirnames = os.listdir(npy_folder)
    dirnames.remove("labels.npy")
    if "README.txt" in dirnames:
        dirnames.remove("README.txt")
    # The benchmark ships 19 corruption types; anything else means a
    # corrupted or partial download.
    assert len(dirnames) == 19
    labels = np.load(osp.join(npy_folder, "labels.npy"))

    for dirname in dirnames:
        corruption = dirname.split(".")[0]
        corruption_folder = osp.join(im_folder, corruption)
        mkdir_if_missing(corruption_folder)

        npy_filename = osp.join(npy_folder, dirname)
        images = np.load(npy_filename)
        # 10,000 images x 5 severity levels per corruption file.
        assert images.shape[0] == 50000

        for level in range(5):
            dst = osp.join(corruption_folder, str(level + 1))
            mkdir_if_missing(dst)
            print('Saving images to "{}"'.format(dst))
            extract_and_save(images, labels, level, dst)
def download_and_prepare(name, root):
    """Download one of the supported SSL datasets and dump it as
    per-class jpg folders under ``root/name/{train,test}``.

    Args:
        name (str): one of "cifar10", "cifar100" or "svhn".
        root (str): directory holding the raw archives and the extracted
            image folders.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    print("Dataset: {}".format(name))
    print("Root: {}".format(root))

    if name == "cifar10":
        train = CIFAR10(root, train=True, download=True)
        test = CIFAR10(root, train=False)
    elif name == "cifar100":
        train = CIFAR100(root, train=True, download=True)
        test = CIFAR100(root, train=False)
    elif name == "svhn":
        train = SVHN(root, split="train", download=True)
        test = SVHN(root, split="test", download=True)
    else:
        # Bug fix: the original raised a bare ValueError, leaving the
        # caller with no clue which argument was wrong.
        raise ValueError(
            "Unknown dataset: {} (expected cifar10, cifar100 or svhn)"
            .format(name)
        )

    train_dir = osp.join(root, name, "train")
    test_dir = osp.join(root, name, "test")

    extract_and_save_image(train, train_dir)
    extract_and_save_image(test, test_dir)
def extract_and_save_image(dataset, save_dir):
    """Save every sample of an STL10 split as a flat jpg file named
    ``<zero-padded index>_<label>.jpg`` (label "none" for unlabeled data).

    Args:
        dataset: indexable dataset yielding (image, int label) pairs;
            images must support ``.save(path)`` (e.g. PIL images).
        save_dir (str): destination folder; extraction is skipped
            entirely if it already exists.
    """
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    for i in range(len(dataset)):
        img, label = dataset[i]
        # STL10 marks samples of the unlabeled split with label == -1.
        if label == -1:
            label_name = "none"
        else:
            label_name = str(label)
        imname = str(i).zfill(6) + "_" + label_name + ".jpg"
        impath = osp.join(save_dir, imname)
        img.save(impath)
def numpy_include():
    """Locate the numpy C header directory, tolerating very old numpy
    releases that only expose the legacy ``get_numpy_include`` API."""
    try:
        include_dir = np.get_include()
    except AttributeError:
        include_dir = np.get_numpy_include()
    return include_dir
def update_file(filename, text_to_search, replacement_text):
    """Replace every occurrence of ``text_to_search`` with
    ``replacement_text`` inside ``filename``, editing the file in place.

    With ``inplace=True`` fileinput redirects print() into the file
    being read, so each (possibly modified) line is written straight
    back; ``backup=""`` means no backup copy of the original is kept.
    """
    print("Processing {}".format(filename))
    with fileinput.FileInput(filename, inplace=True, backup="") as file:
        for line in file:
            # end="" because each line already carries its newline.
            print(line.replace(text_to_search, replacement_text), end="")
text_to_search, replacement_text) 62 | elif osp.isdir(file_or_dir): 63 | recursive_update(file_or_dir, text_to_search, replacement_text) 64 | else: 65 | raise NotImplementedError 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /code/LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 -------------------------------------------------------------------------------- /code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kaiyang Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code/README.md: -------------------------------------------------------------------------------- 1 | # Regularized Mask Tuning: Uncovering Hidden Knowledge in Pre-trained Vision-Language Models 2 | Official implementation of ['Regularized Mask Tuning: Uncovering Hidden Knowledge in Pre-trained Vision-Language Models](https://arxiv.org/abs/2307.15049)'. 3 | 4 | ## How to Install 5 | 6 | This code is built on top of [CoOP](https://github.com/KaiyangZhou/CoOp). So you need to install the environment following CoOP first. After that, run pip install -r requirements.txt to install a few more packages. 7 | 8 | Follow [DATASET.md](https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md) to install 11 datasets referring to CoOp. 9 | 10 | 11 | ## How to Run 12 | 13 | The running scripts are provided in `scripts/masktuning/`, which allow you to reproduce the results on the ICCV'23 paper. 14 | 15 | 16 | ### Few-shot Classification 17 | This corresponds to the experiments in Section 4.2, i.e., Fig 4. 18 | 19 | You will need `scripts/masktuning/train.sh` for training. The script has two input arguments, i.e., `DATASET` and `GDR_LAMBDA`. `DATASET` takes as input a dataset name, like `imagenet` or `caltech101`. `GDR_LAMBDA` is the parameter in Eq 7, which is set to be is set to 0.3 for datasets except for ImageNet, SUN397, and Food101 in 16-shot experiments. And l is set to 1.0 in other experiments. 20 | 21 | For evaluation, you will need `scripts/masktuning/eval.sh`. The script has one input argument, i.e., `DATASET`. 
22 | 23 | Below we provide an example on how to train and evaluate the model on ImageNet and Caltech101. 24 | 25 | ```bash 26 | # train 27 | bash scripts/masktuning/train.sh imagenet 3e-1 28 | bash scripts/masktuning/train.sh caltech101 10e-1 29 | 30 | # eval 31 | bash scripts/masktuning/eval.sh imagenet 32 | bash scripts/masktuning/eval.sh caltech101 33 | ``` 34 | 35 | ### Generalization From Base to New Classes 36 | 37 | This corresponds to the experiments in Section 4.2, i.e., Table 1. 38 | 39 | You will need `scripts/masktuning/base2new_train.sh` and `scripts/masktuning/base2new_eval.sh` for training and evaluation. Both scripts have one input argument, i.e., `DATASET`. 40 | 41 | Below we provide an example on how to train and evaluate the model on ImageNet. 42 | 43 | ```bash 44 | # train 45 | bash scripts/masktuning/base2new_train.sh imagenet 46 | 47 | # eval 48 | bash scripts/masktuning/base2new_eval.sh imagenet 49 | ``` -------------------------------------------------------------------------------- /code/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /code/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/code/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /code/configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | 
-------------------------------------------------------------------------------- /code/configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Food101" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | 
-------------------------------------------------------------------------------- /code/configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" -------------------------------------------------------------------------------- /code/configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" -------------------------------------------------------------------------------- /code/configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SUN397" 3 | -------------------------------------------------------------------------------- /code/configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | -------------------------------------------------------------------------------- /code/configs/trainers/CoCoOp/rn50_b16_c4_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 2 26 | 27 | MODEL: 28 | BACKBONE: 
29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "" 35 | PREC: "fp16" 36 | -------------------------------------------------------------------------------- /code/configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 16 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /code/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- 
/code/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /code/configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 8 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn101.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | 
TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn101_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | RRCROP_SCALE: (0.5, 1.0) 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.002 19 | MAX_EPOCH: 200 20 | 
LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 5 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "RN50" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn50_batch16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 16 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.001 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn50_ep100.yaml: 
-------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn50_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: 
["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 200 4 | TEST: 5 | BATCH_SIZE: 200 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | MODEL: 16 | BACKBONE: 17 | NAME: "RN50" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b16_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | 
TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b16_ep100_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", 
"normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b16_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COOP: 33 | 
CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b16_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 200 4 | TEST: 5 | BATCH_SIZE: 200 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | TRAIN: 16 | PRINT_FREQ: 5 17 | 18 | MODEL: 19 | BACKBONE: 20 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b32.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /code/configs/trainers/CoOp/vit_b32_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", 
"random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /code/configs/trainers/MASK/imagenet_vit_b16.yaml: -------------------------------------------------------------------------------- 1 | MASK: 2 | INIT: '1s' 3 | SCALE: 1e-2 4 | THRESHOLD: 5e-3 5 | MASK_LOSS: True 6 | LOSS_WEIGHT: 1e-4 7 | MASK_MLP: False 8 | GDR: True 9 | GDR_LAMBDA: 1. 10 | 11 | TEST: 12 | NO_TEST: True 13 | DATALOADER: 14 | TRAIN_X: 15 | BATCH_SIZE: 32 16 | TEST: 17 | BATCH_SIZE: 100 18 | NUM_WORKERS: 8 19 | 20 | INPUT: 21 | SIZE: (224, 224) 22 | INTERPOLATION: "bicubic" 23 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 24 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 25 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 26 | RRCROP_SCALE: (0.5, 1.0) 27 | 28 | OPTIM: 29 | NAME: "adam" 30 | WEIGHT_DECAY: 0. 31 | LR: 3e-5 32 | MAX_EPOCH: 10 33 | LR_SCHEDULER: "cosine" 34 | WARMUP_EPOCH: -1 35 | 36 | TRAIN: 37 | PRINT_FREQ: 5 38 | 39 | MODEL: 40 | BACKBONE: 41 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /code/configs/trainers/MASK/vit_b16.yaml: -------------------------------------------------------------------------------- 1 | MASK: 2 | INIT: '1s' 3 | SCALE: 1e-2 4 | THRESHOLD: 5e-3 5 | MASK_LOSS: True 6 | LOSS_WEIGHT: 1e-4 7 | MASK_MLP: False 8 | GDR: True 9 | GDR_LAMBDA: 1. 
10 | 11 | TEST: 12 | NO_TEST: True 13 | DATALOADER: 14 | TRAIN_X: 15 | BATCH_SIZE: 32 16 | TEST: 17 | BATCH_SIZE: 100 18 | NUM_WORKERS: 8 19 | 20 | INPUT: 21 | SIZE: (224, 224) 22 | INTERPOLATION: "bicubic" 23 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 24 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 25 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 26 | RRCROP_SCALE: (0.5, 1.0) 27 | 28 | OPTIM: 29 | NAME: "adam" 30 | WEIGHT_DECAY: 0. 31 | LR: 8e-5 32 | MAX_EPOCH: 30 33 | LR_SCHEDULER: "cosine" 34 | WARMUP_EPOCH: -1 35 | 36 | TRAIN: 37 | PRINT_FREQ: 5 38 | 39 | MODEL: 40 | BACKBONE: 41 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /code/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/code/datasets/__init__.py -------------------------------------------------------------------------------- /code/datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 11 | NEW_CNAMES = { 12 | "airplanes": "airplane", 13 | "Faces": "face", 14 | "Leopards": "leopard", 15 | "Motorbikes": "motorbike", 16 | } 17 | 18 | 19 | @DATASET_REGISTRY.register() 20 | class Caltech101(DatasetBase): 21 | 22 | dataset_dir = "caltech-101" 23 | 24 | def __init__(self, cfg): 25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 28 | self.split_path = os.path.join(self.dataset_dir, 
"split_zhou_Caltech101.json") 29 | self.split_fewshot_dir = "../data_config/caltech-101_split_fewshot" #os.path.join(self.dataset_dir, "split_fewshot") 30 | mkdir_if_missing(self.split_fewshot_dir) 31 | 32 | if os.path.exists(self.split_path): 33 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 34 | else: 35 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | -------------------------------------------------------------------------------- /code/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | NEW_CNAMES = { 11 | "AnnualCrop": "Annual 
Crop Land", 12 | "Forest": "Forest", 13 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 14 | "Highway": "Highway or Road", 15 | "Industrial": "Industrial Buildings", 16 | "Pasture": "Pasture Land", 17 | "PermanentCrop": "Permanent Crop Land", 18 | "Residential": "Residential Buildings", 19 | "River": "River", 20 | "SeaLake": "Sea or Lake", 21 | } 22 | 23 | 24 | @DATASET_REGISTRY.register() 25 | class EuroSAT(DatasetBase): 26 | 27 | dataset_dir = "eurosat" 28 | 29 | def __init__(self, cfg): 30 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 31 | self.dataset_dir = os.path.join(root, self.dataset_dir) 32 | self.image_dir = os.path.join(self.dataset_dir, "2750") 33 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 34 | self.split_fewshot_dir = "../data_config/eurosat_split_fewshot" #os.path.join(self.dataset_dir, "split_fewshot") 35 | mkdir_if_missing(self.split_fewshot_dir) 36 | 37 | if os.path.exists(self.split_path): 38 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 39 | else: 40 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 41 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 42 | 43 | num_shots = cfg.DATASET.NUM_SHOTS 44 | if num_shots >= 1: 45 | seed = cfg.SEED 46 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 47 | 48 | if os.path.exists(preprocessed): 49 | print(f"Loading preprocessed few-shot data from {preprocessed}") 50 | with open(preprocessed, "rb") as file: 51 | data = pickle.load(file) 52 | train, val = data["train"], data["val"] 53 | else: 54 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 55 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 56 | data = {"train": train, "val": val} 57 | print(f"Saving preprocessed few-shot data to {preprocessed}") 58 | with open(preprocessed, "wb") as file: 59 | pickle.dump(data, file, 
protocol=pickle.HIGHEST_PROTOCOL) 60 | 61 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 62 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 63 | 64 | super().__init__(train_x=train, val=val, test=test) 65 | 66 | def update_classname(self, dataset_old): 67 | dataset_new = [] 68 | for item_old in dataset_old: 69 | cname_old = item_old.classname 70 | cname_new = NEW_CNAMES[cname_old] 71 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 72 | dataset_new.append(item_new) 73 | return dataset_new 74 | -------------------------------------------------------------------------------- /code/datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class FGVCAircraft(DatasetBase): 12 | 13 | dataset_dir = "fgvc-aircraft-2013b/" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = "../data_config/fgvc_aircraft_split_fewshot" #os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "images_variant_train.txt") 30 | val = self.read_data(cname2lab, "images_variant_val.txt") 31 | test = self.read_data(cname2lab, "images_variant_test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if 
num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, cname2lab, split_file): 57 | filepath = os.path.join(self.dataset_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip().split(" ") 64 | imname = line[0] + ".jpg" 65 | classname = " ".join(line[1:]) 66 | impath = os.path.join(self.image_dir, imname) 67 | label = cname2lab[classname] 68 | item = Datum(impath=impath, label=label, classname=classname) 69 | items.append(item) 70 | 71 | return items 72 | -------------------------------------------------------------------------------- /code/datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Food101(DatasetBase): 13 | 14 | 
dataset_dir = "food-101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 21 | self.split_fewshot_dir = "../data_config/food-101_split_fewshot" # os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = DTD.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | -------------------------------------------------------------------------------- /code/datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from 
dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetA(DatasetBase): 13 | """ImageNet-A(dversarial). 14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenet-adversarial" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 24 | 25 | text_file = os.path.join(root, "./classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /code/datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetR(DatasetBase): 13 | """ImageNet-R(endition). 14 | 15 | This dataset is used for testing only. 
16 | """ 17 | 18 | dataset_dir = "imagenet-rendition" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 24 | 25 | text_file = os.path.join(root, "./classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /code/datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetSketch(DatasetBase): 11 | """ImageNet-Sketch. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenet-sketch" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "images") 22 | 23 | text_file = os.path.join(root, "./classnames.txt") 24 | classnames = ImageNet.read_classnames(text_file) 25 | 26 | data = self.read_data(classnames) 27 | 28 | super().__init__(train_x=data, test=data) 29 | 30 | def read_data(self, classnames): 31 | image_dir = self.image_dir 32 | folders = listdir_nohidden(image_dir, sort=True) 33 | items = [] 34 | 35 | for label, folder in enumerate(folders): 36 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 37 | classname = classnames[folder] 38 | for imname in imnames: 39 | impath = os.path.join(image_dir, folder, imname) 40 | item = Datum(impath=impath, label=label, classname=classname) 41 | items.append(item) 42 | 43 | return items 44 | -------------------------------------------------------------------------------- /code/datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetV2(DatasetBase): 11 | """ImageNetV2. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenetv2" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | image_dir = "imagenetv2-matched-frequency-format-val" 22 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 23 | 24 | text_file = os.path.join(root, "./classnames.txt") 25 | classnames = ImageNet.read_classnames(text_file) 26 | 27 | data = self.read_data(classnames) 28 | 29 | super().__init__(train_x=data, test=data) 30 | 31 | def read_data(self, classnames): 32 | image_dir = self.image_dir 33 | folders = list(classnames.keys()) 34 | items = [] 35 | 36 | for label in range(1000): 37 | class_dir = os.path.join(image_dir, str(label)) 38 | imnames = listdir_nohidden(class_dir) 39 | folder = folders[label] 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(class_dir, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /code/loss/reg_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def RegLoss(param, k): 6 | assert k in [1,2] 7 | param = param.view(-1) 8 | reg_loss = torch.norm(param, k) 9 | return reg_loss 10 | 11 | 12 | -------------------------------------------------------------------------------- /code/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/code/modules/__init__.py -------------------------------------------------------------------------------- /code/requirements.txt: -------------------------------------------------------------------------------- 1 | flake8==3.7.9 2 | 
yapf==0.29.0 3 | isort==4.3.21 4 | yacs 5 | gdown 6 | tb-nightly 7 | future 8 | scipy 9 | scikit-learn 10 | tqdm 11 | ftfy 12 | regex 13 | wilds==1.2.2 14 | tabulate 15 | opencv-python 16 | ttach 17 | matplotlib 18 | thop -------------------------------------------------------------------------------- /code/scripts/cocoop/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/usr/zkc/data 5 | TRAINER=CoCoOp 6 | 7 | DATASET=imagenet 8 | CFG=rn50_b16_c4_ep10_batch1 # config file 9 | 10 | SHOTS=16 # number of shots (1, 2, 4, 8, 16) 11 | SUB=base 12 | LOADEP=10 13 | 14 | for SEED in 1 #2 3 15 | do 16 | DIR=output/evaluation/${TRAINER}/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}/epoch${LOADEP}_${SUB} 17 | if [ -d "$DIR" ]; then 18 | echo "Oops! The results exist at ${DIR} (so skip this job)" 19 | else 20 | CUDA_VISIBLE_DEVICES=0,1 python train.py \ 21 | --root ${DATA} \ 22 | --seed ${SEED} \ 23 | --trainer ${TRAINER} \ 24 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 25 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 26 | --output-dir ${DIR} \ 27 | --model-dir output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} \ 28 | --eval-only \ 29 | --load-epoch ${LOADEP} \ 30 | DATASET.NUM_SHOTS ${SHOTS} \ 31 | DATASET.SUBSAMPLE_CLASSES ${SUB} 32 | fi 33 | done 34 | -------------------------------------------------------------------------------- /code/scripts/cocoop/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../../ 5 | TRAINER=CoCoOp 6 | 7 | DATASET=imagenet 8 | CFG=vit_b16_c4_ep10_batch1 # config file 9 | 10 | SHOTS=16 # number of shots (1, 2, 4, 8, 16) 11 | 12 | for SEED in 1 #2 3 13 | do 14 | DIR=../output/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 15 | if [ -d "$DIR" ]; then 16 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 17 | else 18 | python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir ${DIR} \ 25 | DATALOADER.TRAIN_X.BATCH_SIZE 1 \ 26 | DATASET.NUM_SHOTS ${SHOTS} \ 27 | DATASET.SUBSAMPLE_CLASSES base 28 | fi 29 | done 30 | -------------------------------------------------------------------------------- /code/scripts/coop/eval_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../../ 5 | TRAINER=CoOp 6 | 7 | DATASET=imagenet 8 | CFG=vit_b16 # config file 9 | CTP=end # class token position (end or middle) 10 | NCTX=16 # number of context tokens 11 | # SHOTS=16 # number of shots (1, 2, 4, 8, 16) 12 | CSC=False # class-specific context (False or True) 13 | 14 | 15 | for SHOTS in 16 16 | do 17 | for SEED in 1 18 | do 19 | DIR=../output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 20 | 21 | python train.py \ 22 | --root ${DATA} \ 23 | --seed ${SEED} \ 24 | --trainer ${TRAINER} \ 25 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 26 | --config-file configs/trainers/CoOp/${CFG}_ep50.yaml \ 27 | --output-dir ${DIR}/eval \ 28 | --eval-only \ 29 | DATALOADER.TEST.BATCH_SIZE 1 \ 30 | TRAINER.COOP.N_CTX ${NCTX} \ 31 | TRAINER.COOP.CSC ${CSC} \ 32 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 33 | DATASET.NUM_SHOTS ${SHOTS} 34 | done 35 | done 36 | -------------------------------------------------------------------------------- /code/scripts/coop/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../../datasets 5 | TRAINER=CoOp 6 | 7 | # DATASET=stanford_cars 8 | CFG=vit_b16 # config file 9 | CTP=end # class token position (end or middle) 10 | 
NCTX=16 # number of context tokens 11 | # SHOTS=16 # number of shots (1, 2, 4, 8, 16) 12 | CSC=False # class-specific context (False or True) 13 | 14 | for DATASET in stanford_cars dtd food101 sun397 caltech101 ucf101 15 | do 16 | for SHOTS in 16 8 17 | do 18 | for SEED in 1 2 3 19 | do 20 | DIR=../output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 21 | if [ -d "$DIR" ]; then 22 | echo "Oops! The results exist at ${DIR} (so skip this job)" 23 | else 24 | python train.py \ 25 | --root ${DATA} \ 26 | --seed ${SEED} \ 27 | --trainer ${TRAINER} \ 28 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 29 | --config-file configs/trainers/CoOp/${CFG}.yaml \ 30 | --output-dir ${DIR} \ 31 | TRAINER.COOP.N_CTX ${NCTX} \ 32 | TRAINER.COOP.CSC ${CSC} \ 33 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 34 | DATASET.NUM_SHOTS ${SHOTS} 35 | fi 36 | done 37 | done 38 | 39 | for SHOTS in 4 2 40 | do 41 | for SEED in 1 2 3 42 | do 43 | DIR=../output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 44 | if [ -d "$DIR" ]; then 45 | echo "Oops! The results exist at ${DIR} (so skip this job)" 46 | else 47 | python train.py \ 48 | --root ${DATA} \ 49 | --seed ${SEED} \ 50 | --trainer ${TRAINER} \ 51 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 52 | --config-file configs/trainers/CoOp/${CFG}_ep100.yaml \ 53 | --output-dir ${DIR} \ 54 | TRAINER.COOP.N_CTX ${NCTX} \ 55 | TRAINER.COOP.CSC ${CSC} \ 56 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 57 | DATASET.NUM_SHOTS ${SHOTS} 58 | fi 59 | done 60 | done 61 | 62 | for SHOTS in 1 63 | do 64 | for SEED in 1 2 3 65 | do 66 | DIR=../output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 67 | if [ -d "$DIR" ]; then 68 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 69 | else 70 | python train.py \ 71 | --root ${DATA} \ 72 | --seed ${SEED} \ 73 | --trainer ${TRAINER} \ 74 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 75 | --config-file configs/trainers/CoOp/${CFG}_ep50.yaml \ 76 | --output-dir ${DIR} \ 77 | TRAINER.COOP.N_CTX ${NCTX} \ 78 | TRAINER.COOP.CSC ${CSC} \ 79 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 80 | DATASET.NUM_SHOTS ${SHOTS} 81 | fi 82 | done 83 | done 84 | 85 | done -------------------------------------------------------------------------------- /code/scripts/coop/main_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../../ 5 | TRAINER=CoOp 6 | 7 | DATASET=imagenet 8 | CFG=vit_b16 # config file 9 | CTP=end # class token position (end or middle) 10 | NCTX=16 # number of context tokens 11 | # SHOTS=16 # number of shots (1, 2, 4, 8, 16) 12 | CSC=False # class-specific context (False or True) 13 | 14 | 15 | for SHOTS in 16 #8 4 2 1 16 | do 17 | for SEED in 1 #2 3 18 | do 19 | DIR=../output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}_test/seed${SEED} 20 | if [ -d "$DIR" ]; then 21 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 22 | else 23 | python train.py \ 24 | --root ${DATA} \ 25 | --seed ${SEED} \ 26 | --trainer ${TRAINER} \ 27 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 28 | --config-file configs/trainers/CoOp/${CFG}_ep50.yaml \ 29 | --output-dir ${DIR} \ 30 | DATALOADER.TRAIN_X.BATCH_SIZE 1 \ 31 | TRAINER.COOP.N_CTX ${NCTX} \ 32 | TRAINER.COOP.CSC ${CSC} \ 33 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 34 | DATASET.NUM_SHOTS ${SHOTS} 35 | fi 36 | done 37 | done 38 | -------------------------------------------------------------------------------- /code/scripts/coop/main_oxford_pets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/workspace1/datasets/ #/usr/zkc/data 5 | TRAINER=CoOp 6 | 7 | DATASET=oxford_pets 8 | CFG=rn50_aircraft # config file 9 | CTP=end # class token position (end or middle) 10 | NCTX=16 # number of context tokens 11 | SHOTS=16 # number of shots (1, 2, 4, 8, 16) 12 | CSC=False # class-specific context (False or True) 13 | LR=8e-5 14 | 15 | for SEED in 1 #2 3 16 | do 17 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 18 | if [ -d "$DIR" ]; then 19 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 20 | else 21 | CUDA_VISIBLE_DEVICES=1 python train.py \ 22 | --root ${DATA} \ 23 | --seed ${SEED} \ 24 | --trainer ${TRAINER} \ 25 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 26 | --config-file configs/trainers/MASK/${CFG}.yaml \ 27 | --output-dir ${DIR} \ 28 | TRAINER.COOP.N_CTX ${NCTX} \ 29 | TRAINER.COOP.CSC ${CSC} \ 30 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | OPTIM.LR ${LR} 33 | fi 34 | done 35 | -------------------------------------------------------------------------------- /code/scripts/coop/zeroshot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../../datasets/ 5 | TRAINER=ZeroshotCLIP 6 | DATASET=imagenet_a 7 | CFG=vit_b16 # rn50, rn101, vit_b32 or vit_b16 8 | 9 | python train.py \ 10 | --root ${DATA} \ 11 | --trainer ${TRAINER} \ 12 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 13 | --config-file configs/trainers/MASK/imagenet/${CFG}.yaml \ 14 | --output-dir ../output/${TRAINER}/${CFG}/${DATASET} \ 15 | --eval-only 16 | -------------------------------------------------------------------------------- /code/scripts/masktuning/.main.sh.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/code/scripts/masktuning/.main.sh.swp -------------------------------------------------------------------------------- /code/scripts/masktuning/base2new_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../datasets # data root 5 | TRAINER=MaskTuning 6 | 7 | # ================= 8 | DATASET=$1 # oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet 9 | # ================= 10 | CFG=vit_b16 11 | if [[ ${DATASET} == 
imagenet ]] 12 | then 13 | CFG_FILE=imagenet_${CFG}.yaml 14 | else 15 | CFG_FILE=${CFG}.yaml 16 | fi 17 | 18 | SHOTS=16 # number of shots (1, 2, 4, 8, 16) 19 | for SEED in 1 2 3 20 | do 21 | DIR=../output_ramt_b2n/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/MASK/${CFG_FILE} \ 28 | --output-dir ${DIR}/eval \ 29 | --model-dir ${DIR} \ 30 | --eval-only \ 31 | DATASET.SUBSAMPLE_CLASSES "new" 32 | 33 | done -------------------------------------------------------------------------------- /code/scripts/masktuning/base2new_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../datasets # data root 5 | TRAINER=MaskTuning 6 | 7 | # ================= 8 | DATASET=$1 # oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet 9 | # ================= 10 | CFG=vit_b16 11 | if [[ ${DATASET} == imagenet ]] 12 | then 13 | CFG_FILE=imagenet_${CFG}.yaml 14 | else 15 | CFG_FILE=${CFG}.yaml 16 | fi 17 | SHOTS=16 # number of shots (1, 2, 4, 8, 16) 18 | for SEED in 1 2 3 19 | do 20 | DIR=../output_ramt_b2n/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 21 | if [ -d "$DIR" ]; then 22 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 23 | else 24 | python train.py \ 25 | --root ${DATA} \ 26 | --seed ${SEED} \ 27 | --trainer ${TRAINER} \ 28 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 29 | --config-file configs/trainers/MASK/${CFG_FILE} \ 30 | --output-dir ${DIR} \ 31 | DATASET.SUBSAMPLE_CLASSES "base" \ 32 | TEST.FINAL_MODEL "last_step" \ 33 | DATASET.NUM_SHOTS ${SHOTS} 34 | fi 35 | done -------------------------------------------------------------------------------- /code/scripts/masktuning/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../datasets # data root 5 | TRAINER=MaskTuning 6 | 7 | # ================= 8 | DATASET=$1 # oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet 9 | # ================= 10 | CFG=vit_b16 11 | if [[ ${DATASET} == imagenet ]] 12 | then 13 | CFG_FILE=imagenet_${CFG}.yaml 14 | else 15 | CFG_FILE=${CFG}.yaml 16 | fi 17 | 18 | SHOTS=16 # number of shots (1, 2, 4, 8, 16) 19 | for SEED in 1 2 3 20 | do 21 | DIR=../output_ramt/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/MASK/${CFG_FILE} \ 28 | --output-dir ${DIR}/eval \ 29 | --model-dir ${DIR} \ 30 | --eval-only 31 | done -------------------------------------------------------------------------------- /code/scripts/masktuning/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=../datasets # data root 5 | TRAINER=MaskTuning 6 | 7 | # ================= 8 | DATASET=$1 # oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet 9 | GDR_LAMBDA=$2 #$ # use cosine lr scheduler 10 | # 
================= 11 | CFG=vit_b16 12 | if [[ ${DATASET} == imagenet ]] 13 | then 14 | CFG_FILE=imagenet_${CFG}.yaml 15 | else 16 | CFG_FILE=${CFG}.yaml 17 | fi 18 | 19 | SHOTS=16 # number of shots (1, 2, 4, 8, 16) 20 | for SEED in 1 2 3 21 | do 22 | DIR=../output_ramt/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 23 | if [ -d "$DIR" ]; then 24 | echo "Oops! The results exist at ${DIR} (so skip this job)" 25 | else 26 | python train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/MASK/${CFG_FILE} \ 32 | --output-dir ${DIR} \ 33 | TEST.FINAL_MODEL "last_step" \ 34 | DATASET.NUM_SHOTS ${SHOTS} \ 35 | MASK.GDR_LAMBDA ${GDR_LAMBDA} 36 | fi 37 | done -------------------------------------------------------------------------------- /code/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/code/trainers/__init__.py -------------------------------------------------------------------------------- /code/trainers/imagenet_templates.py: -------------------------------------------------------------------------------- 1 | IMAGENET_TEMPLATES = [ 2 | "a bad photo of a {}.", 3 | "a photo of many {}.", 4 | "a sculpture of a {}.", 5 | "a photo of the hard to see {}.", 6 | "a low resolution photo of the {}.", 7 | "a rendering of a {}.", 8 | "graffiti of a {}.", 9 | "a bad photo of the {}.", 10 | "a cropped photo of the {}.", 11 | "a tattoo of a {}.", 12 | "the embroidered {}.", 13 | "a photo of a hard to see {}.", 14 | "a bright photo of a {}.", 15 | "a photo of a clean {}.", 16 | "a photo of a dirty {}.", 17 | "a dark photo of the {}.", 18 | "a drawing of a {}.", 19 | "a photo of my {}.", 20 | "the plastic {}.", 21 | "a photo of the cool {}.", 22 | "a close-up photo of a {}.", 23 | "a black and white photo of the 
{}.", 24 | "a painting of the {}.", 25 | "a painting of a {}.", 26 | "a pixelated photo of the {}.", 27 | "a sculpture of the {}.", 28 | "a bright photo of the {}.", 29 | "a cropped photo of a {}.", 30 | "a plastic {}.", 31 | "a photo of the dirty {}.", 32 | "a jpeg corrupted photo of a {}.", 33 | "a blurry photo of the {}.", 34 | "a photo of the {}.", 35 | "a good photo of the {}.", 36 | "a rendering of the {}.", 37 | "a {} in a video game.", 38 | "a photo of one {}.", 39 | "a doodle of a {}.", 40 | "a close-up photo of the {}.", 41 | "a photo of a {}.", 42 | "the origami {}.", 43 | "the {} in a video game.", 44 | "a sketch of a {}.", 45 | "a doodle of the {}.", 46 | "a origami {}.", 47 | "a low resolution photo of a {}.", 48 | "the toy {}.", 49 | "a rendition of the {}.", 50 | "a photo of the clean {}.", 51 | "a photo of a large {}.", 52 | "a rendition of a {}.", 53 | "a photo of a nice {}.", 54 | "a photo of a weird {}.", 55 | "a blurry photo of a {}.", 56 | "a cartoon {}.", 57 | "art of a {}.", 58 | "a sketch of the {}.", 59 | "a embroidered {}.", 60 | "a pixelated photo of a {}.", 61 | "itap of the {}.", 62 | "a jpeg corrupted photo of the {}.", 63 | "a good photo of a {}.", 64 | "a plushie {}.", 65 | "a photo of the nice {}.", 66 | "a photo of the small {}.", 67 | "a photo of the weird {}.", 68 | "the cartoon {}.", 69 | "art of the {}.", 70 | "a drawing of the {}.", 71 | "a photo of the large {}.", 72 | "a black and white photo of a {}.", 73 | "the plushie {}.", 74 | "a dark photo of a {}.", 75 | "itap of a {}.", 76 | "graffiti of the {}.", 77 | "a toy {}.", 78 | "itap of my {}.", 79 | "a photo of a cool {}.", 80 | "a photo of a small {}.", 81 | "a tattoo of the {}.", 82 | ] 83 | 84 | IMAGENET_TEMPLATES_SELECT = [ 85 | "itap of a {}.", 86 | "a bad photo of the {}.", 87 | "a origami {}.", 88 | "a photo of the large {}.", 89 | "a {} in a video game.", 90 | "art of the {}.", 91 | "a photo of the small {}.", 92 | ] 
-------------------------------------------------------------------------------- /figures/clip_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/figures/clip_mask.jpg -------------------------------------------------------------------------------- /figures/init.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item 
.video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel 
.slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /static/images/Pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/static/images/Pipeline.jpg -------------------------------------------------------------------------------- /static/images/exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/static/images/exp.png -------------------------------------------------------------------------------- /static/images/teaser1-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuw2019/R-AMT/07f8d2c36670e1fb9498aab169b9f3790ccda66b/static/images/teaser1-1.jpg -------------------------------------------------------------------------------- /static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = 
interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 
| }) 79 | --------------------------------------------------------------------------------