├── DATASET.md ├── Dassl.pytorch ├── .flake8 ├── .gitignore ├── .isort.cfg ├── .style.yapf ├── DATASETS.md ├── LICENSE ├── README.md ├── configs │ ├── README.md │ ├── datasets │ │ ├── da │ │ │ ├── cifar_stl.yaml │ │ │ ├── digit5.yaml │ │ │ ├── domainnet.yaml │ │ │ ├── mini_domainnet.yaml │ │ │ ├── office31.yaml │ │ │ ├── office_home.yaml │ │ │ └── visda17.yaml │ │ ├── dg │ │ │ ├── camelyon17.yaml │ │ │ ├── cifar100_c.yaml │ │ │ ├── cifar10_c.yaml │ │ │ ├── digit_single.yaml │ │ │ ├── digits_dg.yaml │ │ │ ├── fmow.yaml │ │ │ ├── iwildcam.yaml │ │ │ ├── office_home_dg.yaml │ │ │ ├── pacs.yaml │ │ │ └── vlcs.yaml │ │ └── ssl │ │ │ ├── cifar10.yaml │ │ │ ├── cifar100.yaml │ │ │ ├── stl10.yaml │ │ │ └── svhn.yaml │ └── trainers │ │ ├── da │ │ ├── cdac │ │ │ ├── digit5.yaml │ │ │ ├── domainnet.yaml │ │ │ └── mini_domainnet.yaml │ │ ├── dael │ │ │ ├── digit5.yaml │ │ │ ├── domainnet.yaml │ │ │ └── mini_domainnet.yaml │ │ ├── m3sda │ │ │ ├── digit5.yaml │ │ │ ├── domainnet.yaml │ │ │ └── mini_domainnet.yaml │ │ └── source_only │ │ │ ├── digit5.yaml │ │ │ ├── mini_domainnet.yaml │ │ │ ├── office31.yaml │ │ │ └── visda17.yaml │ │ ├── dg │ │ ├── daeldg │ │ │ ├── digits_dg.yaml │ │ │ ├── office_home_dg.yaml │ │ │ └── pacs.yaml │ │ ├── ddaig │ │ │ ├── digits_dg.yaml │ │ │ ├── office_home_dg.yaml │ │ │ └── pacs.yaml │ │ └── vanilla │ │ │ ├── digits_dg.yaml │ │ │ ├── mini_domainnet.yaml │ │ │ ├── office_home_dg.yaml │ │ │ └── pacs.yaml │ │ └── ssl │ │ └── fixmatch │ │ └── cifar10.yaml ├── dassl │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ └── defaults.py │ ├── data │ │ ├── __init__.py │ │ ├── data_manager.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── base_dataset.py │ │ │ ├── build.py │ │ │ ├── da │ │ │ │ ├── __init__.py │ │ │ │ ├── cifarstl.py │ │ │ │ ├── digit5.py │ │ │ │ ├── domainnet.py │ │ │ │ ├── mini_domainnet.py │ │ │ │ ├── office31.py │ │ │ │ ├── office_home.py │ │ │ │ └── visda17.py │ │ │ ├── dg │ │ │ │ ├── __init__.py │ │ │ │ ├── cifar_c.py │ │ │ │ ├── 
digit_single.py │ │ │ │ ├── digits_dg.py │ │ │ │ ├── office_home_dg.py │ │ │ │ ├── pacs.py │ │ │ │ ├── vlcs.py │ │ │ │ └── wilds │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── camelyon17.py │ │ │ │ │ ├── fmow.py │ │ │ │ │ ├── iwildcam.py │ │ │ │ │ └── wilds_base.py │ │ │ └── ssl │ │ │ │ ├── __init__.py │ │ │ │ ├── cifar.py │ │ │ │ ├── stl10.py │ │ │ │ └── svhn.py │ │ ├── samplers.py │ │ └── transforms │ │ │ ├── __init__.py │ │ │ ├── autoaugment.py │ │ │ ├── randaugment.py │ │ │ └── transforms.py │ ├── engine │ │ ├── __init__.py │ │ ├── build.py │ │ ├── da │ │ │ ├── __init__.py │ │ │ ├── adabn.py │ │ │ ├── adda.py │ │ │ ├── cdac.py │ │ │ ├── dael.py │ │ │ ├── dann.py │ │ │ ├── m3sda.py │ │ │ ├── mcd.py │ │ │ ├── mme.py │ │ │ ├── se.py │ │ │ └── source_only.py │ │ ├── dg │ │ │ ├── __init__.py │ │ │ ├── crossgrad.py │ │ │ ├── daeldg.py │ │ │ ├── ddaig.py │ │ │ ├── domain_mix.py │ │ │ └── vanilla.py │ │ ├── ssl │ │ │ ├── __init__.py │ │ │ ├── entmin.py │ │ │ ├── fixmatch.py │ │ │ ├── mean_teacher.py │ │ │ ├── mixmatch.py │ │ │ └── sup_baseline.py │ │ └── trainer.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── build.py │ │ └── evaluator.py │ ├── metrics │ │ ├── __init__.py │ │ ├── accuracy.py │ │ └── distance.py │ ├── modeling │ │ ├── __init__.py │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ ├── alexnet.py │ │ │ ├── backbone.py │ │ │ ├── build.py │ │ │ ├── cnn_digit5_m3sda.py │ │ │ ├── cnn_digitsdg.py │ │ │ ├── cnn_digitsingle.py │ │ │ ├── efficientnet │ │ │ │ ├── __init__.py │ │ │ │ ├── model.py │ │ │ │ └── utils.py │ │ │ ├── preact_resnet18.py │ │ │ ├── resnet.py │ │ │ ├── resnet_dynamic.py │ │ │ ├── vgg.py │ │ │ └── wide_resnet.py │ │ ├── head │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── mlp.py │ │ ├── network │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── ddaig_fcn.py │ │ └── ops │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── conv.py │ │ │ ├── cross_entropy.py │ │ │ ├── dsbn.py │ │ │ ├── efdmix.py │ │ │ ├── mixstyle.py │ │ │ ├── mixup.py │ │ │ ├── mmd.py │ 
│ │ ├── optimal_transport.py │ │ │ ├── reverse_grad.py │ │ │ ├── sequential2.py │ │ │ ├── transnorm.py │ │ │ └── utils.py │ ├── optim │ │ ├── __init__.py │ │ ├── lr_scheduler.py │ │ ├── optimizer.py │ │ └── radam.py │ └── utils │ │ ├── __init__.py │ │ ├── logger.py │ │ ├── meters.py │ │ ├── registry.py │ │ ├── tools.py │ │ └── torchtools.py ├── datasets │ ├── da │ │ ├── cifar_stl.py │ │ ├── digit5.py │ │ └── visda17.sh │ ├── dg │ │ └── cifar_c.py │ └── ssl │ │ ├── cifar10_cifar100_svhn.py │ │ └── stl10.py ├── linter.sh ├── requirements.txt ├── setup.py └── tools │ ├── parse_test_res.py │ ├── replace_text.py │ └── train.py ├── LICENSE ├── README.md ├── asset └── dapt.png ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── datasets │ ├── caltech101.yaml │ ├── dtd.yaml │ ├── eurosat.yaml │ ├── fgvc_aircraft.yaml │ ├── food101.yaml │ ├── imagenet.yaml │ ├── imagenet_a.yaml │ ├── imagenet_r.yaml │ ├── imagenet_sketch.yaml │ ├── imagenetv2.yaml │ ├── oxford_flowers.yaml │ ├── oxford_pets.yaml │ ├── stanford_cars.yaml │ ├── sun397.yaml │ └── ucf101.yaml └── trainers │ └── DAPT │ ├── vit_b16.yaml │ ├── vit_b16_ep100.yaml │ └── vit_b16_ep50.yaml ├── datasets ├── __init__.py ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc_aircraft.py ├── food101.py ├── imagenet.py ├── imagenet_a.py ├── imagenet_r.py ├── imagenet_sketch.py ├── imagenetv2.py ├── oxford_flowers.py ├── oxford_pets.py ├── stanford_cars.py ├── sun397.py └── ucf101.py ├── env.yaml ├── scripts ├── eval.sh ├── gen_prototype.sh └── main.sh ├── train.py └── trainers ├── __init__.py └── dapt.py /Dassl.pytorch/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # At least two spaces before inline comment 4 | E261, 5 | # Line lengths are recommended to be no greater than 79 characters 6 | E501, 7 | # Missing whitespace around arithmetic operator 8 | E226, 9 | # Blank line contains 
whitespace 10 | W293, 11 | # Do not use bare 'except' 12 | E722, 13 | # Line break after binary operator 14 | W504, 15 | # Too many leading '#' for block comment 16 | E266, 17 | # Line break before binary operator 18 | W503, 19 | # Continuation line over-indented for hanging indent 20 | E126, 21 | # Module level import not at top of file 22 | E402 23 | max-line-length = 79 24 | exclude = __init__.py, build -------------------------------------------------------------------------------- /Dassl.pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # OS X 132 | .DS_Store 133 | .Spotlight-V100 134 | .Trashes 135 | ._* 136 | 137 | # This project 138 | output/ 139 | debug/ 140 | -------------------------------------------------------------------------------- /Dassl.pytorch/.isort.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=79 3 | multi_line_output=6 4 | length_sort=true 5 | known_standard_library=numpy,setuptools 6 | known_myself=dassl 7 | known_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs,scipy,gdown 8 | no_lines_before=STDLIB,THIRDPARTY 9 | sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER 10 | default_section=FIRSTPARTY -------------------------------------------------------------------------------- /Dassl.pytorch/.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | BASED_ON_STYLE = pep8 3 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 4 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true 5 | DEDENT_CLOSING_BRACKETS = true 6 | SPACES_BEFORE_COMMENT = 2 7 | ARITHMETIC_PRECEDENCE_INDICATION = true -------------------------------------------------------------------------------- /Dassl.pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kaiyang 4 | 5 | Permission is hereby granted, free of 
charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Dassl.pytorch/configs/README.md: -------------------------------------------------------------------------------- 1 | The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, data augmentation, network architecture) used by most papers. The `trainers/` folder contains method-specific config files which define optimization algorithms (e.g., optimizer, epoch) and hyperparameter settings. 
2 | -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/da/cifar_stl.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | PIXEL_MEAN: [0.5, 0.5, 0.5] 4 | PIXEL_STD: [0.5, 0.5, 0.5] 5 | 6 | DATASET: 7 | NAME: "CIFARSTL" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/da/digit5.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | PIXEL_MEAN: [0.5, 0.5, 0.5] 4 | PIXEL_STD: [0.5, 0.5, 0.5] 5 | TRANSFORMS: ["normalize"] 6 | 7 | DATASET: 8 | NAME: "Digit5" 9 | 10 | MODEL: 11 | BACKBONE: 12 | NAME: "cnn_digit5_m3sda" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/da/domainnet.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "DomainNet" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet101" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/da/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (96, 96) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "miniDomainNet" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet18" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/da/office31.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "Office31" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet50" 11 | 
HEAD: 12 | NAME: "mlp" 13 | HIDDEN_LAYERS: [256] 14 | DROPOUT: 0. -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/da/office_home.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | 4 | DATASET: 5 | NAME: "OfficeHome" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/da/visda17.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "center_crop", "normalize"] 4 | 5 | DATASET: 6 | NAME: "VisDA17" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet101" 11 | 12 | TEST: 13 | PER_CLASS_RESULT: True -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/camelyon17.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 4 | 5 | DATASET: 6 | NAME: "Camelyon17" 7 | -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/cifar100_c.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "CIFAR100C" 9 | CIFAR_C_TYPE: "fog" 10 | CIFAR_C_LEVEL: 5 11 | 12 | MODEL: 13 | BACKBONE: 14 | NAME: "wide_resnet_16_4" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/cifar10_c.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | 
PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "CIFAR10C" 9 | CIFAR_C_TYPE: "fog" 10 | CIFAR_C_LEVEL: 5 11 | 12 | MODEL: 13 | BACKBONE: 14 | NAME: "wide_resnet_16_4" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/digit_single.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "DigitSingle" 9 | 10 | MODEL: 11 | BACKBONE: 12 | NAME: "cnn_digitsingle" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/digits_dg.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "DigitsDG" 9 | 10 | MODEL: 11 | BACKBONE: 12 | NAME: "cnn_digitsdg" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/fmow.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 4 | 5 | DATASET: 6 | NAME: "FMoW" 7 | -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/iwildcam.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 4 | 5 | DATASET: 6 | NAME: "IWildCam" 7 | -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/office_home_dg.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | 
TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "OfficeHomeDG" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet18" 11 | PRETRAINED: True -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/pacs.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "PACS" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet18" 11 | PRETRAINED: True -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/dg/vlcs.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (224, 224) 3 | TRANSFORMS: ["random_flip", "random_translation", "normalize"] 4 | 5 | DATASET: 6 | NAME: "VLCS" 7 | 8 | MODEL: 9 | BACKBONE: 10 | NAME: "resnet18" 11 | PRETRAINED: True -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/ssl/cifar10.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | 7 | DATASET: 8 | NAME: "CIFAR10" 9 | NUM_LABELED: 4000 10 | VAL_PERCENT: 0. 11 | 12 | MODEL: 13 | BACKBONE: 14 | NAME: "wide_resnet_28_2" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/ssl/cifar100.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | CROP_PADDING: 4 7 | 8 | DATASET: 9 | NAME: "CIFAR100" 10 | NUM_LABELED: 10000 11 | VAL_PERCENT: 0. 
12 | 13 | MODEL: 14 | BACKBONE: 15 | NAME: "wide_resnet_28_2" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/ssl/stl10.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (96, 96) 3 | TRANSFORMS: ["random_flip", "random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | CROP_PADDING: 4 7 | 8 | DATASET: 9 | NAME: "STL10" 10 | STL10_FOLD: 0 11 | 12 | MODEL: 13 | BACKBONE: 14 | NAME: "wide_resnet_28_2" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/datasets/ssl/svhn.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | SIZE: (32, 32) 3 | TRANSFORMS: ["random_crop", "normalize"] 4 | PIXEL_MEAN: [0.5, 0.5, 0.5] 5 | PIXEL_STD: [0.5, 0.5, 0.5] 6 | CROP_PADDING: 4 7 | 8 | DATASET: 9 | NAME: "SVHN" 10 | NUM_LABELED: 1000 11 | VAL_PERCENT: 0. 
12 | 13 | MODEL: 14 | BACKBONE: 15 | NAME: "wide_resnet_28_2" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/cdac/digit5.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomSampler" 4 | BATCH_SIZE: 64 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 192 8 | TEST: 9 | BATCH_SIZE: 256 10 | K_TRANSFORMS: 2 11 | 12 | OPTIM: 13 | NAME: "sgd" 14 | LR: 0.001 15 | MAX_EPOCH: 90 16 | RAMPUP_ITRS: 10000 17 | 18 | TRAINER: 19 | CDAC: 20 | STRONG_TRANSFORMS: ["randaugment", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/cdac/domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 6 8 | TEST: 9 | BATCH_SIZE: 30 10 | K_TRANSFORMS: 2 11 | 12 | OPTIM: 13 | NAME: "sgd" 14 | LR: 0.001 15 | MAX_EPOCH: 90 16 | RAMPUP_ITRS: 10000 17 | 18 | TRAINER: 19 | CDAC: 20 | STRONG_TRANSFORMS: ["randaugment", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/cdac/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 64 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 192 8 | TEST: 9 | BATCH_SIZE: 200 10 | K_TRANSFORMS: 2 11 | 12 | OPTIM: 13 | NAME: "sgd" 14 | LR: 0.001 15 | MAX_EPOCH: 60 16 | RAMPUP_ITRS: 10000 17 | LR_SCHEDULER: "cosine" 18 | 19 | TRAINER: 20 | CDAC: 21 | STRONG_TRANSFORMS: ["randaugment", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/dael/digit5.yaml: 
-------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 256 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 64 8 | TEST: 9 | BATCH_SIZE: 256 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.05 14 | STEPSIZE: [30] 15 | MAX_EPOCH: 30 16 | LR_SCHEDULER: "cosine" 17 | 18 | TRAINER: 19 | DAEL: 20 | STRONG_TRANSFORMS: ["randaugment2", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/dael/domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 6 8 | TEST: 9 | BATCH_SIZE: 30 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.002 14 | MAX_EPOCH: 40 15 | LR_SCHEDULER: "cosine" 16 | 17 | TRAINER: 18 | DAEL: 19 | STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/dael/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 192 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 64 8 | TEST: 9 | BATCH_SIZE: 200 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.005 14 | MAX_EPOCH: 60 15 | LR_SCHEDULER: "cosine" 16 | 17 | TRAINER: 18 | DAEL: 19 | STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/m3sda/digit5.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 256 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 64 8 | TEST: 9 | 
BATCH_SIZE: 256 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.05 14 | STEPSIZE: [30] 15 | MAX_EPOCH: 30 16 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/m3sda/domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 6 8 | TEST: 9 | BATCH_SIZE: 30 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.002 14 | MAX_EPOCH: 40 15 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/m3sda/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 192 5 | TRAIN_U: 6 | SAME_AS_X: False 7 | BATCH_SIZE: 64 8 | TEST: 9 | BATCH_SIZE: 200 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.005 14 | MAX_EPOCH: 60 15 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/source_only/digit5.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 256 4 | TEST: 5 | BATCH_SIZE: 256 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.05 10 | STEPSIZE: [30] 11 | MAX_EPOCH: 30 12 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/source_only/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 128 4 | TEST: 5 | BATCH_SIZE: 128 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.005 10 | MAX_EPOCH: 60 11 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- 
/Dassl.pytorch/configs/trainers/da/source_only/office31.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 32 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.002 10 | STEPSIZE: [20] 11 | MAX_EPOCH: 20 -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/da/source_only/visda17.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 32 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.0001 10 | STEPSIZE: [2] 11 | MAX_EPOCH: 2 12 | 13 | TRAIN: 14 | PRINT_FREQ: 50 15 | COUNT_ITER: "train_u" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/daeldg/digits_dg.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 120 5 | TEST: 6 | BATCH_SIZE: 100 7 | 8 | OPTIM: 9 | NAME: "sgd" 10 | LR: 0.05 11 | STEPSIZE: [20] 12 | MAX_EPOCH: 50 13 | 14 | TRAINER: 15 | DAELDG: 16 | STRONG_TRANSFORMS: ["randaugment2", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/daeldg/office_home_dg.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TEST: 6 | BATCH_SIZE: 100 7 | 8 | OPTIM: 9 | NAME: "sgd" 10 | LR: 0.002 11 | MAX_EPOCH: 40 12 | LR_SCHEDULER: "cosine" 13 | 14 | TRAINER: 15 | DAELDG: 16 | STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/daeldg/pacs.yaml: 
-------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | SAMPLER: "RandomDomainSampler" 4 | BATCH_SIZE: 30 5 | TEST: 6 | BATCH_SIZE: 100 7 | 8 | OPTIM: 9 | NAME: "sgd" 10 | LR: 0.002 11 | MAX_EPOCH: 40 12 | LR_SCHEDULER: "cosine" 13 | 14 | TRAINER: 15 | DAELDG: 16 | STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"] -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/ddaig/digits_dg.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | PIXEL_MEAN: [0., 0., 0.] 3 | PIXEL_STD: [1., 1., 1.] 4 | 5 | DATALOADER: 6 | TRAIN_X: 7 | BATCH_SIZE: 128 8 | TEST: 9 | BATCH_SIZE: 128 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.05 14 | STEPSIZE: [20] 15 | MAX_EPOCH: 50 16 | 17 | TRAINER: 18 | DDAIG: 19 | G_ARCH: "fcn_3x32_gctx" 20 | LMDA: 0.3 -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/ddaig/office_home_dg.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | PIXEL_MEAN: [0., 0., 0.] 3 | PIXEL_STD: [1., 1., 1.] 4 | 5 | DATALOADER: 6 | TRAIN_X: 7 | BATCH_SIZE: 16 8 | TEST: 9 | BATCH_SIZE: 16 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.0005 14 | STEPSIZE: [20] 15 | MAX_EPOCH: 25 16 | 17 | TRAINER: 18 | DDAIG: 19 | G_ARCH: "fcn_3x64_gctx" 20 | WARMUP: 3 21 | LMDA: 0.3 -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/ddaig/pacs.yaml: -------------------------------------------------------------------------------- 1 | INPUT: 2 | PIXEL_MEAN: [0., 0., 0.] 3 | PIXEL_STD: [1., 1., 1.] 
4 | 5 | DATALOADER: 6 | TRAIN_X: 7 | BATCH_SIZE: 16 8 | TEST: 9 | BATCH_SIZE: 16 10 | 11 | OPTIM: 12 | NAME: "sgd" 13 | LR: 0.0005 14 | STEPSIZE: [20] 15 | MAX_EPOCH: 25 16 | 17 | TRAINER: 18 | DDAIG: 19 | G_ARCH: "fcn_3x64_gctx" 20 | WARMUP: 3 21 | LMDA: 0.3 -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/vanilla/digits_dg.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 128 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | OPTIM: 9 | NAME: "sgd" 10 | LR: 0.05 11 | STEPSIZE: [20] 12 | MAX_EPOCH: 50 13 | 14 | TRAIN: 15 | PRINT_FREQ: 20 -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/vanilla/mini_domainnet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 128 4 | TEST: 5 | BATCH_SIZE: 128 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.005 10 | MAX_EPOCH: 60 11 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/vanilla/office_home_dg.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 64 4 | TEST: 5 | BATCH_SIZE: 100 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.001 10 | MAX_EPOCH: 50 11 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- /Dassl.pytorch/configs/trainers/dg/vanilla/pacs.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 64 4 | TEST: 5 | BATCH_SIZE: 100 6 | 7 | OPTIM: 8 | NAME: "sgd" 9 | LR: 0.001 10 | MAX_EPOCH: 50 11 | LR_SCHEDULER: "cosine" -------------------------------------------------------------------------------- 
/Dassl.pytorch/configs/trainers/ssl/fixmatch/cifar10.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 64 4 | TRAIN_U: 5 | SAME_AS_X: False 6 | BATCH_SIZE: 448 7 | TEST: 8 | BATCH_SIZE: 500 9 | 10 | OPTIM: 11 | NAME: "sgd" 12 | LR: 0.05 13 | STEPSIZE: [4000] 14 | MAX_EPOCH: 4000 15 | LR_SCHEDULER: "cosine" 16 | 17 | TRAIN: 18 | COUNT_ITER: "train_u" 19 | PRINT_FREQ: 10 20 | 21 | TRAINER: 22 | FIXMATCH: 23 | STRONG_TRANSFORMS: ["random_flip", "randaugment_fixmatch", "normalize", "cutout"] -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dassl 3 | ------ 4 | PyTorch toolbox for domain adaptation and semi-supervised learning. 5 | 6 | URL: https://github.com/KaiyangZhou/Dassl.pytorch 7 | 8 | @article{zhou2020domain, 9 | title={Domain Adaptive Ensemble Learning}, 10 | author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao}, 11 | journal={arXiv preprint arXiv:2003.07325}, 12 | year={2020} 13 | } 14 | """ 15 | 16 | __version__ = "0.6.3" 17 | __author__ = "Kaiyang Zhou" 18 | __homepage__ = "https://kaiyangzhou.github.io/" 19 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg_default 2 | 3 | 4 | def get_cfg_default(): 5 | return cfg_default.clone() 6 | 7 | 8 | def clean_cfg(cfg, trainer): 9 | """Remove unused trainers (configs). 10 | 11 | Aim: Only show relevant information when calling print(cfg). 12 | 13 | Args: 14 | cfg (_C): cfg instance. 15 | trainer (str): trainer name. 
16 | """ 17 | keys = list(cfg.TRAINER.keys()) 18 | for key in keys: 19 | if key == "NAME" or key == trainer.upper(): 20 | continue 21 | cfg.TRAINER.pop(key, None) 22 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_manager import DataManager, DatasetWrapper 2 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import DATASET_REGISTRY, build_dataset # isort:skip 2 | from .base_dataset import Datum, DatasetBase # isort:skip 3 | 4 | from .da import * 5 | from .dg import * 6 | from .ssl import * 7 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | DATASET_REGISTRY = Registry("DATASET") 4 | 5 | 6 | def build_dataset(cfg): 7 | avai_datasets = DATASET_REGISTRY.registered_names() 8 | check_availability(cfg.DATASET.NAME, avai_datasets) 9 | if cfg.VERBOSE: 10 | print("Loading dataset: {}".format(cfg.DATASET.NAME)) 11 | return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg) 12 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/da/__init__.py: -------------------------------------------------------------------------------- 1 | from .digit5 import Digit5 2 | from .visda17 import VisDA17 3 | from .cifarstl import CIFARSTL 4 | from .office31 import Office31 5 | from .domainnet import DomainNet 6 | from .office_home import OfficeHome 7 | from .mini_domainnet import miniDomainNet 8 | -------------------------------------------------------------------------------- 
/Dassl.pytorch/dassl/data/datasets/da/cifarstl.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class CIFARSTL(DatasetBase): 11 | """CIFAR-10 and STL-10. 12 | 13 | CIFAR-10: 14 | - 60,000 32x32 colour images. 15 | - 10 classes, with 6,000 images per class. 16 | - 50,000 training images and 10,000 test images. 17 | - URL: https://www.cs.toronto.edu/~kriz/cifar.html. 18 | 19 | STL-10: 20 | - 10 classes: airplane, bird, car, cat, deer, dog, horse, 21 | monkey, ship, truck. 22 | - Images are 96x96 pixels, color. 23 | - 500 training images (10 pre-defined folds), 800 test images 24 | per class. 25 | - URL: https://cs.stanford.edu/~acoates/stl10/. 26 | 27 | Reference: 28 | - Krizhevsky. Learning Multiple Layers of Features 29 | from Tiny Images. Tech report. 30 | - Coates et al. An Analysis of Single Layer Networks in 31 | Unsupervised Feature Learning. AISTATS 2011. 
32 | """ 33 | 34 | dataset_dir = "cifar_stl" 35 | domains = ["cifar", "stl"] 36 | 37 | def __init__(self, cfg): 38 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 39 | self.dataset_dir = osp.join(root, self.dataset_dir) 40 | 41 | self.check_input_domains( 42 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 43 | ) 44 | 45 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 46 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 47 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 48 | 49 | super().__init__(train_x=train_x, train_u=train_u, test=test) 50 | 51 | def _read_data(self, input_domains, split="train"): 52 | items = [] 53 | 54 | for domain, dname in enumerate(input_domains): 55 | data_dir = osp.join(self.dataset_dir, dname, split) 56 | class_names = listdir_nohidden(data_dir) 57 | 58 | for class_name in class_names: 59 | class_dir = osp.join(data_dir, class_name) 60 | imnames = listdir_nohidden(class_dir) 61 | label = int(class_name.split("_")[0]) 62 | 63 | for imname in imnames: 64 | impath = osp.join(class_dir, imname) 65 | item = Datum(impath=impath, label=label, domain=domain) 66 | items.append(item) 67 | 68 | return items 69 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/da/domainnet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class DomainNet(DatasetBase): 9 | """DomainNet. 10 | 11 | Statistics: 12 | - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw, 13 | Real, Sketch. 14 | - Around 0.6M images. 15 | - 345 categories. 16 | - URL: http://ai.bu.edu/M3SDA/. 17 | 18 | Special note: the t-shirt class (327) is missing in painting_train.txt. 19 | 20 | Reference: 21 | - Peng et al. 
Moment Matching for Multi-Source Domain 22 | Adaptation. ICCV 2019. 23 | """ 24 | 25 | dataset_dir = "domainnet" 26 | domains = [ 27 | "clipart", "infograph", "painting", "quickdraw", "real", "sketch" 28 | ] 29 | 30 | def __init__(self, cfg): 31 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 32 | self.dataset_dir = osp.join(root, self.dataset_dir) 33 | self.split_dir = osp.join(self.dataset_dir, "splits") 34 | 35 | self.check_input_domains( 36 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 37 | ) 38 | 39 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 40 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 41 | val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test") 42 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 43 | 44 | super().__init__(train_x=train_x, train_u=train_u, val=val, test=test) 45 | 46 | def _read_data(self, input_domains, split="train"): 47 | items = [] 48 | 49 | for domain, dname in enumerate(input_domains): 50 | filename = dname + "_" + split + ".txt" 51 | split_file = osp.join(self.split_dir, filename) 52 | 53 | with open(split_file, "r") as f: 54 | lines = f.readlines() 55 | for line in lines: 56 | line = line.strip() 57 | impath, label = line.split(" ") 58 | classname = impath.split("/")[1] 59 | impath = osp.join(self.dataset_dir, impath) 60 | label = int(label) 61 | item = Datum( 62 | impath=impath, 63 | label=label, 64 | domain=domain, 65 | classname=classname 66 | ) 67 | items.append(item) 68 | 69 | return items 70 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/da/mini_domainnet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class miniDomainNet(DatasetBase): 9 | """A subset of 
DomainNet. 10 | 11 | Reference: 12 | - Peng et al. Moment Matching for Multi-Source Domain 13 | Adaptation. ICCV 2019. 14 | - Zhou et al. Domain Adaptive Ensemble Learning. 15 | """ 16 | 17 | dataset_dir = "domainnet" 18 | domains = ["clipart", "painting", "real", "sketch"] 19 | 20 | def __init__(self, cfg): 21 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = osp.join(root, self.dataset_dir) 23 | self.split_dir = osp.join(self.dataset_dir, "splits_mini") 24 | 25 | self.check_input_domains( 26 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 27 | ) 28 | 29 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 30 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 31 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 32 | 33 | super().__init__(train_x=train_x, train_u=train_u, test=test) 34 | 35 | def _read_data(self, input_domains, split="train"): 36 | items = [] 37 | 38 | for domain, dname in enumerate(input_domains): 39 | filename = dname + "_" + split + ".txt" 40 | split_file = osp.join(self.split_dir, filename) 41 | 42 | with open(split_file, "r") as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | line = line.strip() 46 | impath, label = line.split(" ") 47 | classname = impath.split("/")[1] 48 | impath = osp.join(self.dataset_dir, impath) 49 | label = int(label) 50 | item = Datum( 51 | impath=impath, 52 | label=label, 53 | domain=domain, 54 | classname=classname 55 | ) 56 | items.append(item) 57 | 58 | return items 59 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/da/office31.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class 
Office31(DatasetBase): 11 | """Office-31. 12 | 13 | Statistics: 14 | - 4,110 images. 15 | - 31 classes related to office objects. 16 | - 3 domains: Amazon, Webcam, Dslr. 17 | - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/. 18 | 19 | Reference: 20 | - Saenko et al. Adapting visual category models to 21 | new domains. ECCV 2010. 22 | """ 23 | 24 | dataset_dir = "office31" 25 | domains = ["amazon", "webcam", "dslr"] 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | self.check_input_domains( 32 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 33 | ) 34 | 35 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) 36 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) 37 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS) 38 | 39 | super().__init__(train_x=train_x, train_u=train_u, test=test) 40 | 41 | def _read_data(self, input_domains): 42 | items = [] 43 | 44 | for domain, dname in enumerate(input_domains): 45 | domain_dir = osp.join(self.dataset_dir, dname) 46 | class_names = listdir_nohidden(domain_dir) 47 | class_names.sort() 48 | 49 | for label, class_name in enumerate(class_names): 50 | class_path = osp.join(domain_dir, class_name) 51 | imnames = listdir_nohidden(class_path) 52 | 53 | for imname in imnames: 54 | impath = osp.join(class_path, imname) 55 | item = Datum( 56 | impath=impath, 57 | label=label, 58 | domain=domain, 59 | classname=class_name 60 | ) 61 | items.append(item) 62 | 63 | return items 64 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/da/office_home.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 
10 | class OfficeHome(DatasetBase): 11 | """Office-Home. 12 | 13 | Statistics: 14 | - Around 15,500 images. 15 | - 65 classes related to office and home objects. 16 | - 4 domains: Art, Clipart, Product, Real World. 17 | - URL: http://hemanthdv.org/OfficeHome-Dataset/. 18 | 19 | Reference: 20 | - Venkateswara et al. Deep Hashing Network for Unsupervised 21 | Domain Adaptation. CVPR 2017. 22 | """ 23 | 24 | dataset_dir = "office_home" 25 | domains = ["art", "clipart", "product", "real_world"] 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | self.check_input_domains( 32 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 33 | ) 34 | 35 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) 36 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) 37 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS) 38 | 39 | super().__init__(train_x=train_x, train_u=train_u, test=test) 40 | 41 | def _read_data(self, input_domains): 42 | items = [] 43 | 44 | for domain, dname in enumerate(input_domains): 45 | domain_dir = osp.join(self.dataset_dir, dname) 46 | class_names = listdir_nohidden(domain_dir) 47 | class_names.sort() 48 | 49 | for label, class_name in enumerate(class_names): 50 | class_path = osp.join(domain_dir, class_name) 51 | imnames = listdir_nohidden(class_path) 52 | 53 | for imname in imnames: 54 | impath = osp.join(class_path, imname) 55 | item = Datum( 56 | impath=impath, 57 | label=label, 58 | domain=domain, 59 | classname=class_name.lower(), 60 | ) 61 | items.append(item) 62 | 63 | return items 64 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/da/visda17.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | 
@DATASET_REGISTRY.register() 8 | class VisDA17(DatasetBase): 9 | """VisDA17. 10 | 11 | Focusing on simulation-to-reality domain shift. 12 | 13 | URL: http://ai.bu.edu/visda-2017/. 14 | 15 | Reference: 16 | - Peng et al. VisDA: The Visual Domain Adaptation 17 | Challenge. ArXiv 2017. 18 | """ 19 | 20 | dataset_dir = "visda17" 21 | domains = ["synthetic", "real"] 22 | 23 | def __init__(self, cfg): 24 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | 27 | self.check_input_domains( 28 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 29 | ) 30 | 31 | train_x = self._read_data("synthetic") 32 | train_u = self._read_data("real") 33 | test = self._read_data("real") 34 | 35 | super().__init__(train_x=train_x, train_u=train_u, test=test) 36 | 37 | def _read_data(self, dname): 38 | filedir = "train" if dname == "synthetic" else "validation" 39 | image_list = osp.join(self.dataset_dir, filedir, "image_list.txt") 40 | items = [] 41 | # There is only one source domain 42 | domain = 0 43 | 44 | with open(image_list, "r") as f: 45 | lines = f.readlines() 46 | 47 | for line in lines: 48 | line = line.strip() 49 | impath, label = line.split(" ") 50 | classname = impath.split("/")[0] 51 | impath = osp.join(self.dataset_dir, filedir, impath) 52 | label = int(label) 53 | item = Datum( 54 | impath=impath, 55 | label=label, 56 | domain=domain, 57 | classname=classname 58 | ) 59 | items.append(item) 60 | 61 | return items 62 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/__init__.py: -------------------------------------------------------------------------------- 1 | from .pacs import PACS 2 | from .vlcs import VLCS 3 | from .wilds import * 4 | from .cifar_c import CIFAR10C, CIFAR100C 5 | from .digits_dg import DigitsDG 6 | from .digit_single import DigitSingle 7 | from .office_home_dg import OfficeHomeDG 8 | 
-------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/cifar_c.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | AVAI_C_TYPES = [ 9 | "brightness", 10 | "contrast", 11 | "defocus_blur", 12 | "elastic_transform", 13 | "fog", 14 | "frost", 15 | "gaussian_blur", 16 | "gaussian_noise", 17 | "glass_blur", 18 | "impulse_noise", 19 | "jpeg_compression", 20 | "motion_blur", 21 | "pixelate", 22 | "saturate", 23 | "shot_noise", 24 | "snow", 25 | "spatter", 26 | "speckle_noise", 27 | "zoom_blur", 28 | ] 29 | 30 | 31 | @DATASET_REGISTRY.register() 32 | class CIFAR10C(DatasetBase): 33 | """CIFAR-10 -> CIFAR-10-C. 34 | 35 | Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o 36 | 37 | Statistics: 38 | - 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10 39 | - 10 categories 40 | 41 | Reference: 42 | - Hendrycks et al. Benchmarking neural network robustness 43 | to common corruptions and perturbations. ICLR 2019. 
44 | """ 45 | 46 | dataset_dir = "" 47 | domains = ["cifar10", "cifar10_c"] 48 | 49 | def __init__(self, cfg): 50 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 51 | self.dataset_dir = root 52 | 53 | self.check_input_domains( 54 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 55 | ) 56 | source_domain = cfg.DATASET.SOURCE_DOMAINS[0] 57 | target_domain = cfg.DATASET.TARGET_DOMAINS[0] 58 | assert source_domain == self.domains[0] 59 | assert target_domain == self.domains[1] 60 | 61 | c_type = cfg.DATASET.CIFAR_C_TYPE 62 | c_level = cfg.DATASET.CIFAR_C_LEVEL 63 | 64 | if not c_type: 65 | raise ValueError( 66 | "Please specify DATASET.CIFAR_C_TYPE in the config file" 67 | ) 68 | 69 | assert ( 70 | c_type in AVAI_C_TYPES 71 | ), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"' 72 | assert 1 <= c_level <= 5 73 | 74 | train_dir = osp.join(self.dataset_dir, source_domain, "train") 75 | test_dir = osp.join( 76 | self.dataset_dir, target_domain, c_type, str(c_level) 77 | ) 78 | 79 | if not osp.exists(test_dir): 80 | raise ValueError 81 | 82 | train = self._read_data(train_dir) 83 | test = self._read_data(test_dir) 84 | 85 | super().__init__(train_x=train, test=test) 86 | 87 | def _read_data(self, data_dir): 88 | class_names = listdir_nohidden(data_dir) 89 | class_names.sort() 90 | items = [] 91 | 92 | for label, class_name in enumerate(class_names): 93 | class_dir = osp.join(data_dir, class_name) 94 | imnames = listdir_nohidden(class_dir) 95 | 96 | for imname in imnames: 97 | impath = osp.join(class_dir, imname) 98 | item = Datum(impath=impath, label=label, domain=0) 99 | items.append(item) 100 | 101 | return items 102 | 103 | 104 | @DATASET_REGISTRY.register() 105 | class CIFAR100C(CIFAR10C): 106 | """CIFAR-100 -> CIFAR-100-C. 107 | 108 | Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o 109 | 110 | Statistics: 111 | - 2 domains: the normal CIFAR-100 vs. 
a corrupted CIFAR-100 112 | - 100 categories 113 | 114 | Reference: 115 | - Hendrycks et al. Benchmarking neural network robustness 116 | to common corruptions and perturbations. ICLR 2019. 117 | """ 118 | 119 | dataset_dir = "" 120 | domains = ["cifar100", "cifar100_c"] 121 | 122 | def __init__(self, cfg): 123 | super().__init__(cfg) 124 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/digits_dg.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | 4 | from dassl.utils import listdir_nohidden 5 | 6 | from ..build import DATASET_REGISTRY 7 | from ..base_dataset import Datum, DatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class DigitsDG(DatasetBase): 12 | """Digits-DG. 13 | 14 | It contains 4 digit datasets: 15 | - MNIST: hand-written digits. 16 | - MNIST-M: variant of MNIST with blended background. 17 | - SVHN: street view house number. 18 | - SYN: synthetic digits. 19 | 20 | Reference: 21 | - Lecun et al. Gradient-based learning applied to document 22 | recognition. IEEE 1998. 23 | - Ganin et al. Domain-adversarial training of neural networks. 24 | JMLR 2016. 25 | - Netzer et al. Reading digits in natural images with unsupervised 26 | feature learning. NIPS-W 2011. 27 | - Zhou et al. Deep Domain-Adversarial Image Generation for Domain 28 | Generalisation. AAAI 2020. 
29 | """ 30 | 31 | dataset_dir = "digits_dg" 32 | domains = ["mnist", "mnist_m", "svhn", "syn"] 33 | data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7" 34 | 35 | def __init__(self, cfg): 36 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 37 | self.dataset_dir = osp.join(root, self.dataset_dir) 38 | 39 | if not osp.exists(self.dataset_dir): 40 | dst = osp.join(root, "digits_dg.zip") 41 | self.download_data(self.data_url, dst, from_gdrive=True) 42 | 43 | self.check_input_domains( 44 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 45 | ) 46 | 47 | train = self.read_data( 48 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train" 49 | ) 50 | val = self.read_data( 51 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val" 52 | ) 53 | test = self.read_data( 54 | self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all" 55 | ) 56 | 57 | super().__init__(train_x=train, val=val, test=test) 58 | 59 | @staticmethod 60 | def read_data(dataset_dir, input_domains, split): 61 | 62 | def _load_data_from_directory(directory): 63 | folders = listdir_nohidden(directory) 64 | folders.sort() 65 | items_ = [] 66 | 67 | for label, folder in enumerate(folders): 68 | impaths = glob.glob(osp.join(directory, folder, "*.jpg")) 69 | 70 | for impath in impaths: 71 | items_.append((impath, label)) 72 | 73 | return items_ 74 | 75 | items = [] 76 | 77 | for domain, dname in enumerate(input_domains): 78 | if split == "all": 79 | train_dir = osp.join(dataset_dir, dname, "train") 80 | impath_label_list = _load_data_from_directory(train_dir) 81 | val_dir = osp.join(dataset_dir, dname, "val") 82 | impath_label_list += _load_data_from_directory(val_dir) 83 | else: 84 | split_dir = osp.join(dataset_dir, dname, split) 85 | impath_label_list = _load_data_from_directory(split_dir) 86 | 87 | for impath, label in impath_label_list: 88 | class_name = impath.split("/")[-2].lower() 89 | item = Datum( 90 | impath=impath, 91 | label=label, 92 | domain=domain, 93 | 
classname=class_name 94 | ) 95 | items.append(item) 96 | 97 | return items 98 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/office_home_dg.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from .digits_dg import DigitsDG 5 | from ..base_dataset import DatasetBase 6 | 7 | 8 | @DATASET_REGISTRY.register() 9 | class OfficeHomeDG(DatasetBase): 10 | """Office-Home. 11 | 12 | Statistics: 13 | - Around 15,500 images. 14 | - 65 classes related to office and home objects. 15 | - 4 domains: Art, Clipart, Product, Real World. 16 | - URL: http://hemanthdv.org/OfficeHome-Dataset/. 17 | 18 | Reference: 19 | - Venkateswara et al. Deep Hashing Network for Unsupervised 20 | Domain Adaptation. CVPR 2017. 21 | """ 22 | 23 | dataset_dir = "office_home_dg" 24 | domains = ["art", "clipart", "product", "real_world"] 25 | data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa" 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | if not osp.exists(self.dataset_dir): 32 | dst = osp.join(root, "office_home_dg.zip") 33 | self.download_data(self.data_url, dst, from_gdrive=True) 34 | 35 | self.check_input_domains( 36 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 37 | ) 38 | 39 | train = DigitsDG.read_data( 40 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train" 41 | ) 42 | val = DigitsDG.read_data( 43 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val" 44 | ) 45 | test = DigitsDG.read_data( 46 | self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all" 47 | ) 48 | 49 | super().__init__(train_x=train, val=val, test=test) 50 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/pacs.py: 
-------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class PACS(DatasetBase): 9 | """PACS. 10 | 11 | Statistics: 12 | - 4 domains: Photo (1,670), Art (2,048), Cartoon 13 | (2,344), Sketch (3,929). 14 | - 7 categories: dog, elephant, giraffe, guitar, horse, 15 | house and person. 16 | 17 | Reference: 18 | - Li et al. Deeper, broader and artier domain generalization. 19 | ICCV 2017. 20 | """ 21 | 22 | dataset_dir = "pacs" 23 | domains = ["art_painting", "cartoon", "photo", "sketch"] 24 | data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE" 25 | # the following images contain errors and should be ignored 26 | _error_paths = ["sketch/dog/n02103406_4068-1.png"] 27 | 28 | def __init__(self, cfg): 29 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.image_dir = osp.join(self.dataset_dir, "images") 32 | self.split_dir = osp.join(self.dataset_dir, "splits") 33 | 34 | if not osp.exists(self.dataset_dir): 35 | dst = osp.join(root, "pacs.zip") 36 | self.download_data(self.data_url, dst, from_gdrive=True) 37 | 38 | self.check_input_domains( 39 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 40 | ) 41 | 42 | train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train") 43 | val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval") 44 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all") 45 | 46 | super().__init__(train_x=train, val=val, test=test) 47 | 48 | def _read_data(self, input_domains, split): 49 | items = [] 50 | 51 | for domain, dname in enumerate(input_domains): 52 | if split == "all": 53 | file_train = osp.join( 54 | self.split_dir, dname + "_train_kfold.txt" 55 | ) 56 | impath_label_list = self._read_split_pacs(file_train) 57 | file_val = osp.join( 58 | 
self.split_dir, dname + "_crossval_kfold.txt" 59 | ) 60 | impath_label_list += self._read_split_pacs(file_val) 61 | else: 62 | file = osp.join( 63 | self.split_dir, dname + "_" + split + "_kfold.txt" 64 | ) 65 | impath_label_list = self._read_split_pacs(file) 66 | 67 | for impath, label in impath_label_list: 68 | classname = impath.split("/")[-2] 69 | item = Datum( 70 | impath=impath, 71 | label=label, 72 | domain=domain, 73 | classname=classname 74 | ) 75 | items.append(item) 76 | 77 | return items 78 | 79 | def _read_split_pacs(self, split_file): 80 | items = [] 81 | 82 | with open(split_file, "r") as f: 83 | lines = f.readlines() 84 | 85 | for line in lines: 86 | line = line.strip() 87 | impath, label = line.split(" ") 88 | if impath in self._error_paths: 89 | continue 90 | impath = osp.join(self.image_dir, impath) 91 | label = int(label) - 1 92 | items.append((impath, label)) 93 | 94 | return items 95 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/vlcs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | 4 | from dassl.utils import listdir_nohidden 5 | 6 | from ..build import DATASET_REGISTRY 7 | from ..base_dataset import Datum, DatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class VLCS(DatasetBase): 12 | """VLCS. 13 | 14 | Statistics: 15 | - 4 domains: CALTECH, LABELME, PASCAL, SUN 16 | - 5 categories: bird, car, chair, dog, and person. 17 | 18 | Reference: 19 | - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011. 
20 | """ 21 | 22 | dataset_dir = "VLCS" 23 | domains = ["caltech", "labelme", "pascal", "sun"] 24 | data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd" 25 | 26 | def __init__(self, cfg): 27 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 28 | self.dataset_dir = osp.join(root, self.dataset_dir) 29 | 30 | if not osp.exists(self.dataset_dir): 31 | dst = osp.join(root, "vlcs.zip") 32 | self.download_data(self.data_url, dst, from_gdrive=True) 33 | 34 | self.check_input_domains( 35 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 36 | ) 37 | 38 | train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train") 39 | val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval") 40 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test") 41 | 42 | super().__init__(train_x=train, val=val, test=test) 43 | 44 | def _read_data(self, input_domains, split): 45 | items = [] 46 | 47 | for domain, dname in enumerate(input_domains): 48 | dname = dname.upper() 49 | path = osp.join(self.dataset_dir, dname, split) 50 | folders = listdir_nohidden(path) 51 | folders.sort() 52 | 53 | for label, folder in enumerate(folders): 54 | impaths = glob.glob(osp.join(path, folder, "*.jpg")) 55 | 56 | for impath in impaths: 57 | item = Datum(impath=impath, label=label, domain=domain) 58 | items.append(item) 59 | 60 | return items 61 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/wilds/__init__.py: -------------------------------------------------------------------------------- 1 | from .fmow import FMoW 2 | from .iwildcam import IWildCam 3 | from .camelyon17 import Camelyon17 4 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/wilds/camelyon17.py: -------------------------------------------------------------------------------- 1 | from dassl.data.datasets import DATASET_REGISTRY 2 | 3 | from .wilds_base import 
WILDSBase 4 | 5 | 6 | @DATASET_REGISTRY.register() 7 | class Camelyon17(WILDSBase): 8 | """Tumor tissue recognition. 9 | 10 | 2 classes (whether a given region of tissue contains tumor tissue). 11 | 12 | Reference: 13 | - Bandi et al. "From detection of individual metastases to classification of lymph 14 | node status at the patient level: the CAMELYON17 challenge." TMI 2021. 15 | - Koh et al. "Wilds: A benchmark of in-the-wild distribution shifts." ICML 2021. 16 | """ 17 | 18 | dataset_dir = "camelyon17_v1.0" 19 | 20 | def __init__(self, cfg): 21 | super().__init__(cfg) 22 | 23 | def load_classnames(self): 24 | return {0: "healthy tissue", 1: "tumor tissue"} 25 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/data/datasets/dg/wilds/fmow.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY 4 | 5 | from .wilds_base import WILDSBase 6 | 7 | CATEGORIES = [ 8 | "airport", "airport_hangar", "airport_terminal", "amusement_park", 9 | "aquaculture", "archaeological_site", "barn", "border_checkpoint", 10 | "burial_site", "car_dealership", "construction_site", "crop_field", "dam", 11 | "debris_or_rubble", "educational_institution", "electric_substation", 12 | "factory_or_powerplant", "fire_station", "flooded_road", "fountain", 13 | "gas_station", "golf_course", "ground_transportation_station", "helipad", 14 | "hospital", "impoverished_settlement", "interchange", "lake_or_pond", 15 | "lighthouse", "military_facility", "multi-unit_residential", 16 | "nuclear_powerplant", "office_building", "oil_or_gas_facility", "park", 17 | "parking_lot_or_garage", "place_of_worship", "police_station", "port", 18 | "prison", "race_track", "railway_bridge", "recreational_facility", 19 | "road_bridge", "runway", "shipyard", "shopping_mall", 20 | "single-unit_residential", "smokestack", "solar_farm", "space_facility", 21 | 
@DATASET_REGISTRY.register()
class FMoW(WILDSBase):
    """Satellite imagery classification.

    62 classes (building or land use categories).

    Reference:
        - Christie et al. "Functional Map of the World." CVPR 2018.
        - Koh et al. "Wilds: A benchmark of in-the-wild distribution shifts." ICML 2021.
    """

    dataset_dir = "fmow_v1.1"

    def __init__(self, cfg):
        super().__init__(cfg)

    def get_image_path(self, dataset, idx):
        """Build the on-disk path of the ``idx``-th example's RGB image."""
        full_idx = dataset.full_idxs[idx]
        return osp.join(self.dataset_dir, "images", f"rgb_img_{full_idx}.png")

    def get_domain(self, dataset, idx):
        """Encode the (region, year) metadata pair as one domain id.

        There are 16 years, so ids are laid out as ``region * 16 + year``
        (number of regions: 5 or 6).
        """
        region_id = int(dataset.metadata_array[idx][0])
        year_id = int(dataset.metadata_array[idx][1])
        return region_id * 16 + year_id

    def load_classnames(self):
        """Map class index -> category name."""
        return dict(enumerate(CATEGORIES))
    def load_classnames(self):
        """Return the {label: classname} mapping read from categories.csv.

        ``dict(df["name"])`` maps the DataFrame's integer row index to the
        "name" column value; this assumes the CSV's row order matches the
        label ids used by the dataset — TODO confirm against categories.csv.
        """
        df = pd.read_csv(osp.join(self.dataset_dir, "categories.csv"))
        return dict(df["name"])
    def __init__(self, cfg):
        """Build labeled/unlabeled/val/test splits from an image-folder layout.

        Expects ``<root>/<dataset_dir>/{train,test}/<classname>/...`` on disk.
        """
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")

        # SSL setting requires at least one labeled example
        assert cfg.DATASET.NUM_LABELED > 0

        train_x, train_u, val = self._read_data_train(
            train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT
        )
        test = self._read_data_test(test_dir)

        # Optionally also treat the labeled images as unlabeled data
        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        # An empty val split means "no validation set"
        if len(val) == 0:
            val = None

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
2018 54 | # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data 55 | num_val = math.floor(len(imnames) * val_percent) 56 | imnames_train = imnames[num_val:] 57 | imnames_val = imnames[:num_val] 58 | 59 | # Note we do shuffle after split 60 | random.shuffle(imnames_train) 61 | 62 | for i, imname in enumerate(imnames_train): 63 | impath = osp.join(class_dir, imname) 64 | item = Datum(impath=impath, label=label) 65 | 66 | if (i + 1) <= num_labeled_per_class: 67 | items_x.append(item) 68 | 69 | else: 70 | items_u.append(item) 71 | 72 | for imname in imnames_val: 73 | impath = osp.join(class_dir, imname) 74 | item = Datum(impath=impath, label=label) 75 | items_v.append(item) 76 | 77 | return items_x, items_u, items_v 78 | 79 | def _read_data_test(self, data_dir): 80 | class_names = listdir_nohidden(data_dir) 81 | class_names.sort() 82 | items = [] 83 | 84 | for label, class_name in enumerate(class_names): 85 | class_dir = osp.join(data_dir, class_name) 86 | imnames = listdir_nohidden(class_dir) 87 | 88 | for imname in imnames: 89 | impath = osp.join(class_dir, imname) 90 | item = Datum(impath=impath, label=label) 91 | items.append(item) 92 | 93 | return items 94 | 95 | 96 | @DATASET_REGISTRY.register() 97 | class CIFAR100(CIFAR10): 98 | """CIFAR100 for SSL. 99 | 100 | Reference: 101 | - Krizhevsky. Learning Multiple Layers of Features 102 | from Tiny Images. Tech report. 
    def __init__(self, cfg):
        # Identical to CIFAR10: only the class attribute ``dataset_dir``
        # differs, so this override is redundant and simply forwards.
        super().__init__(cfg)
24 | """ 25 | 26 | dataset_dir = "stl10" 27 | 28 | def __init__(self, cfg): 29 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | train_dir = osp.join(self.dataset_dir, "train") 32 | test_dir = osp.join(self.dataset_dir, "test") 33 | unlabeled_dir = osp.join(self.dataset_dir, "unlabeled") 34 | fold_file = osp.join( 35 | self.dataset_dir, "stl10_binary", "fold_indices.txt" 36 | ) 37 | 38 | # Only use the first five splits 39 | assert 0 <= cfg.DATASET.STL10_FOLD <= 4 40 | 41 | train_x = self._read_data_train( 42 | train_dir, cfg.DATASET.STL10_FOLD, fold_file 43 | ) 44 | train_u = self._read_data_all(unlabeled_dir) 45 | test = self._read_data_all(test_dir) 46 | 47 | if cfg.DATASET.ALL_AS_UNLABELED: 48 | train_u = train_u + train_x 49 | 50 | super().__init__(train_x=train_x, train_u=train_u, test=test) 51 | 52 | def _read_data_train(self, data_dir, fold, fold_file): 53 | imnames = listdir_nohidden(data_dir) 54 | imnames.sort() 55 | items = [] 56 | 57 | list_idx = list(range(len(imnames))) 58 | if fold >= 0: 59 | with open(fold_file, "r") as f: 60 | str_idx = f.read().splitlines()[fold] 61 | list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=" ") 62 | 63 | for i in list_idx: 64 | imname = imnames[i] 65 | impath = osp.join(data_dir, imname) 66 | label = osp.splitext(imname)[0].split("_")[1] 67 | label = int(label) 68 | item = Datum(impath=impath, label=label) 69 | items.append(item) 70 | 71 | return items 72 | 73 | def _read_data_all(self, data_dir): 74 | imnames = listdir_nohidden(data_dir) 75 | items = [] 76 | 77 | for imname in imnames: 78 | impath = osp.join(data_dir, imname) 79 | label = osp.splitext(imname)[0].split("_")[1] 80 | if label == "none": 81 | label = -1 82 | else: 83 | label = int(label) 84 | item = Datum(impath=impath, label=label) 85 | items.append(item) 86 | 87 | return items 88 | -------------------------------------------------------------------------------- 
def build_trainer(cfg):
    """Instantiate the registered trainer class named by ``cfg.TRAINER.NAME``."""
    check_availability(cfg.TRAINER.NAME, TRAINER_REGISTRY.registered_names())
    if cfg.VERBOSE:
        print("Loading trainer: {}".format(cfg.TRAINER.NAME))
    trainer_cls = TRAINER_REGISTRY.get(cfg.TRAINER.NAME)
    return trainer_cls(cfg)
@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
    """Adaptive Batch Normalization.

    https://arxiv.org/abs/1603.04779.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # BN statistics are reset exactly once, at the first epoch's start
        self.done_reset_bn_stats = False

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def before_epoch(self):
        # One-shot reset of every BatchNorm layer's running statistics so
        # they are re-estimated on target-domain data.
        if self.done_reset_bn_stats:
            return

        for module in self.model.modules():
            if "BatchNorm" in module.__class__.__name__:
                module.reset_running_stats()

        self.done_reset_bn_stats = True

    def forward_backward(self, batch_x, batch_u):
        # No loss is computed: forwarding target data under no_grad updates
        # the BN running statistics as a side effect of the forward pass.
        unlabeled = batch_u["img"].to(self.device)

        with torch.no_grad():
            self.model(unlabeled)

        return None
    def forward_backward(self, batch_x, batch_u):
        """One ADDA step: update the critic, then the target model.

        The source model is frozen; the target model's backbone (and head,
        if any) is trained adversarially against the domain critic.
        """
        # Only unfreeze the feature-producing layers of the target model
        open_specified_layers(self.model, self.open_layers)
        input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
        # Domain labels: 1 = source, 0 = target
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        # Source features come from the frozen source model, target features
        # from the trainable target model
        _, feat_x = self.source_model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        # --- Critic update: classify source vs target features.
        # feat_u is detached so this step does not touch the target model.
        logit_xd = self.critic(feat_x)
        logit_ud = self.critic(feat_u.detach())

        loss_critic = self.bce(logit_xd, domain_x)
        loss_critic += self.bce(logit_ud, domain_u)
        self.model_backward_and_update(loss_critic, "critic")

        # --- Target model update: fool the critic by pushing target
        # features toward the source label (1 - domain_u inverts the labels)
        logit_ud = self.critic(feat_u)
        loss_model = self.bce(logit_ud, 1 - domain_u)
        self.model_backward_and_update(loss_model, "model")

        loss_summary = {
            "loss_critic": loss_critic.item(),
            "loss_model": loss_model.item(),
        }

        # Step the LR schedulers at the end of each epoch
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
18 | """ 19 | 20 | def __init__(self, cfg): 21 | super().__init__(cfg) 22 | self.build_critic() 23 | self.ce = nn.CrossEntropyLoss() 24 | self.bce = nn.BCEWithLogitsLoss() 25 | 26 | def build_critic(self): 27 | cfg = self.cfg 28 | 29 | print("Building critic network") 30 | fdim = self.model.fdim 31 | critic_body = build_head( 32 | "mlp", 33 | verbose=cfg.VERBOSE, 34 | in_features=fdim, 35 | hidden_layers=[fdim, fdim], 36 | activation="leaky_relu", 37 | ) 38 | self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1)) 39 | print("# params: {:,}".format(count_num_param(self.critic))) 40 | self.critic.to(self.device) 41 | self.optim_c = build_optimizer(self.critic, cfg.OPTIM) 42 | self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM) 43 | self.register_model("critic", self.critic, self.optim_c, self.sched_c) 44 | self.revgrad = ReverseGrad() 45 | 46 | def forward_backward(self, batch_x, batch_u): 47 | input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u) 48 | domain_x = torch.ones(input_x.shape[0], 1).to(self.device) 49 | domain_u = torch.zeros(input_u.shape[0], 1).to(self.device) 50 | 51 | global_step = self.batch_idx + self.epoch * self.num_batches 52 | progress = global_step / (self.max_epoch * self.num_batches) 53 | lmda = 2 / (1 + np.exp(-10 * progress)) - 1 54 | 55 | logit_x, feat_x = self.model(input_x, return_feature=True) 56 | _, feat_u = self.model(input_u, return_feature=True) 57 | 58 | loss_x = self.ce(logit_x, label_x) 59 | 60 | feat_x = self.revgrad(feat_x, grad_scaling=lmda) 61 | feat_u = self.revgrad(feat_u, grad_scaling=lmda) 62 | output_xd = self.critic(feat_x) 63 | output_ud = self.critic(feat_u) 64 | loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u) 65 | 66 | loss = loss_x + loss_d 67 | self.model_backward_and_update(loss) 68 | 69 | loss_summary = { 70 | "loss_x": loss_x.item(), 71 | "acc_x": compute_accuracy(logit_x, label_x)[0].item(), 72 | "loss_d": loss_d.item(), 73 | } 74 | 75 | if (self.batch_idx + 
class Prototypes(nn.Module):
    """Temperature-scaled cosine classifier.

    Holds one prototype (weight vector) per class in a bias-free linear
    layer; the logit for a class is the dot product between the
    L2-normalized input feature and that prototype, divided by ``temp``.
    """

    def __init__(self, fdim, num_classes, temp=0.05):
        super().__init__()
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp

    def forward(self, x):
        normalized = F.normalize(x, p=2, dim=1)
        return self.prototypes(normalized) / self.temp
    def forward_backward(self, batch_x, batch_u):
        """One MME step: supervised update, then adversarial entropy update.

        Labeled (source) data trains F and C with cross-entropy.  On
        unlabeled (target) data, loss_u = -H(p); gradient descent on -H
        makes C maximize the entropy, while the gradient-reversal layer
        flips the sign for F so it minimizes the entropy instead.
        """
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised branch on labeled data
        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

        # Adversarial entropy branch on unlabeled data
        feat_u = self.F(input_u)
        feat_u = self.revgrad(feat_u)  # reverses gradients flowing into F
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        # Negative entropy; 1e-5 guards log(0)
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        # Step the LR schedulers at the end of each epoch
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
    def __init__(self, cfg):
        """Initialize the student model plus a frozen EMA teacher.

        The teacher starts as a deep copy of the student; its parameters
        never receive gradients and are moved toward the student via
        ``ema_model_update`` during training.
        """
        super().__init__(cfg)
        self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA  # EMA decay ceiling
        self.conf_thre = cfg.TRAINER.SE.CONF_THRE  # confidence mask threshold (falsy = off)
        self.rampup = cfg.TRAINER.SE.RAMPUP  # ramp-up length for the consistency weight

        self.teacher = copy.deepcopy(self.model)
        # Kept in train() mode (e.g. BN uses batch statistics) but frozen:
        # weights change only through the EMA update
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)
@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
    """Baseline model for domain adaptation, which is
    trained using source data only.
    """

    def forward_backward(self, batch_x, batch_u):
        images, labels = self.parse_batch_train(batch_x, batch_u)
        logits = self.model(images)
        loss = F.cross_entropy(logits, labels)
        self.model_backward_and_update(loss)

        summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(logits, labels)[0].item(),
        }

        # Step the LR scheduler at the end of each epoch
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return summary

    def parse_batch_train(self, batch_x, batch_u):
        # The unlabeled (target) batch is intentionally ignored
        images = batch_x["img"].to(self.device)
        labels = batch_x["label"].to(self.device)
        return images, labels
    def forward_backward(self, batch):
        """One CrossGrad step.

        Each network is trained on both the clean input and an input
        perturbed along the *other* network's input gradient: the label
        net F additionally sees a domain-gradient perturbation, and the
        domain net D additionally sees a label-gradient perturbation.
        """
        input, label, domain = self.parse_batch_train(batch)

        # Input gradients are needed to build the perturbations
        input.requires_grad = True

        # Compute domain perturbation
        loss_d = F.cross_entropy(self.D(input), domain)
        loss_d.backward()
        # Clip the raw gradient to keep the perturbation bounded
        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_d = input.data + self.eps_f * grad_d

        # Compute label perturbation (clear grads left by the domain pass)
        input.grad.data.zero_()
        loss_f = F.cross_entropy(self.F(input), label)
        loss_f.backward()
        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_f = input.data + self.eps_d * grad_f

        input = input.detach()

        # Update label net: blend clean loss and domain-perturbed loss
        loss_f1 = F.cross_entropy(self.F(input), label)
        loss_f2 = F.cross_entropy(self.F(input_d), label)
        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
        self.model_backward_and_update(loss_f, "F")

        # Update domain net: blend clean loss and label-perturbed loss
        loss_d1 = F.cross_entropy(self.D(input), domain)
        loss_d2 = F.cross_entropy(self.D(input_f), domain)
        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}

        # Step the LR schedulers at the end of each epoch
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
15 | 16 | https://github.com/MetaVisionLab/DDG 17 | """ 18 | 19 | def __init__(self, cfg): 20 | super(DomainMix, self).__init__(cfg) 21 | self.mix_type = cfg.TRAINER.DOMAINMIX.TYPE 22 | self.alpha = cfg.TRAINER.DOMAINMIX.ALPHA 23 | self.beta = cfg.TRAINER.DOMAINMIX.BETA 24 | self.dist_beta = torch.distributions.Beta(self.alpha, self.beta) 25 | 26 | def forward_backward(self, batch): 27 | images, label_a, label_b, lam = self.parse_batch_train(batch) 28 | output = self.model(images) 29 | loss = lam * F.cross_entropy( 30 | output, label_a 31 | ) + (1-lam) * F.cross_entropy(output, label_b) 32 | self.model_backward_and_update(loss) 33 | 34 | loss_summary = { 35 | "loss": loss.item(), 36 | "acc": compute_accuracy(output, label_a)[0].item() 37 | } 38 | 39 | if (self.batch_idx + 1) == self.num_batches: 40 | self.update_lr() 41 | 42 | return loss_summary 43 | 44 | def parse_batch_train(self, batch): 45 | images = batch["img"] 46 | target = batch["label"] 47 | domain = batch["domain"] 48 | images = images.to(self.device) 49 | target = target.to(self.device) 50 | domain = domain.to(self.device) 51 | images, target_a, target_b, lam = self.domain_mix( 52 | images, target, domain 53 | ) 54 | return images, target_a, target_b, lam 55 | 56 | def domain_mix(self, x, target, domain): 57 | lam = ( 58 | self.dist_beta.rsample((1, )) 59 | if self.alpha > 0 else torch.tensor(1) 60 | ).to(x.device) 61 | 62 | # random shuffle 63 | perm = torch.randperm(x.size(0), dtype=torch.int64, device=x.device) 64 | if self.mix_type == "crossdomain": 65 | domain_list = torch.unique(domain) 66 | if len(domain_list) > 1: 67 | for idx in domain_list: 68 | cnt_a = torch.sum(domain == idx) 69 | idx_b = (domain != idx).nonzero().squeeze(-1) 70 | cnt_b = idx_b.shape[0] 71 | perm_b = torch.ones(cnt_b).multinomial( 72 | num_samples=cnt_a, replacement=bool(cnt_a > cnt_b) 73 | ) 74 | perm[domain == idx] = idx_b[perm_b] 75 | elif self.mix_type != "random": 76 | raise NotImplementedError( 77 | f"Chooses 
@TRAINER_REGISTRY.register()
class Vanilla(TrainerX):
    """Vanilla model.

    A.k.a. Empirical Risk Minimization, or ERM.
    """

    def forward_backward(self, batch):
        images, labels = self.parse_batch_train(batch)
        logits = self.model(images)
        loss = F.cross_entropy(logits, labels)
        self.model_backward_and_update(loss)

        summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(logits, labels)[0].item(),
        }

        # Step the LR scheduler at the end of each epoch
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return summary

    def parse_batch_train(self, batch):
        images = batch["img"].to(self.device)
        labels = batch["label"].to(self.device)
        return images, labels
@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
    """Entropy Minimization.

    http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Weight balancing the unlabeled entropy term against the CE term.
        self.lmda = cfg.TRAINER.ENTMIN.LMDA

    def forward_backward(self, batch_x, batch_u):
        """One step: supervised CE on labeled + entropy on unlabeled data."""
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised branch.
        logits_x = self.model(input_x)
        loss_x = F.cross_entropy(logits_x, label_x)

        # Unsupervised branch: minimize prediction entropy.
        # 1e-5 guards log() against exact zeros in the softmax output.
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = (-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()

        self.model_backward_and_update(loss_x + loss_u * self.lmda)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logits_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        # Step the LR scheduler once per epoch (on the last batch).
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
14 | """ 15 | 16 | def __init__(self, cfg): 17 | super().__init__(cfg) 18 | self.weight_u = cfg.TRAINER.MEANTEACHER.WEIGHT_U 19 | self.ema_alpha = cfg.TRAINER.MEANTEACHER.EMA_ALPHA 20 | self.rampup = cfg.TRAINER.MEANTEACHER.RAMPUP 21 | 22 | self.teacher = copy.deepcopy(self.model) 23 | self.teacher.train() 24 | for param in self.teacher.parameters(): 25 | param.requires_grad_(False) 26 | 27 | def forward_backward(self, batch_x, batch_u): 28 | input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u) 29 | 30 | logit_x = self.model(input_x) 31 | loss_x = F.cross_entropy(logit_x, label_x) 32 | 33 | target_u = F.softmax(self.teacher(input_u), 1) 34 | prob_u = F.softmax(self.model(input_u), 1) 35 | loss_u = ((prob_u - target_u)**2).sum(1).mean() 36 | 37 | weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup) 38 | loss = loss_x + loss_u*weight_u 39 | self.model_backward_and_update(loss) 40 | 41 | global_step = self.batch_idx + self.epoch * self.num_batches 42 | ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha) 43 | ema_model_update(self.model, self.teacher, ema_alpha) 44 | 45 | loss_summary = { 46 | "loss_x": loss_x.item(), 47 | "acc_x": compute_accuracy(logit_x, label_x)[0].item(), 48 | "loss_u": loss_u.item(), 49 | } 50 | 51 | if (self.batch_idx + 1) == self.num_batches: 52 | self.update_lr() 53 | 54 | return loss_summary 55 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/engine/ssl/mixmatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 5 | from dassl.modeling.ops import mixup 6 | from dassl.modeling.ops.utils import ( 7 | sharpen_prob, create_onehot, linear_rampup, shuffle_index 8 | ) 9 | 10 | 11 | @TRAINER_REGISTRY.register() 12 | class MixMatch(TrainerXU): 13 | """MixMatch: A Holistic Approach to Semi-Supervised Learning. 
    def __init__(self, cfg):
        super().__init__(cfg)
        # Consistency-loss weight, sharpening temperature, mixup Beta
        # parameter and ramp-up length (in global steps), all from config.
        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
        self.temp = cfg.TRAINER.MIXMATCH.TEMP
        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP

    def check_cfg(self, cfg):
        # MixMatch averages predictions over K > 1 augmented views per image.
        assert cfg.DATALOADER.K_TRANSFORMS > 1

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        num_x = input_x.shape[0]

        # Linearly ramp up the unlabeled-loss weight over training.
        global_step = self.batch_idx + self.epoch * self.num_batches
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo-label for unlabeled data: average softmax
        # predictions over the K augmented views, then sharpen.
        with torch.no_grad():
            output_u = 0
            for input_ui in input_u:
                output_ui = F.softmax(self.model(input_ui), 1)
                output_u += output_ui
            output_u /= len(input_u)
            label_u = sharpen_prob(output_u, self.temp)
            # Every augmented view of an image shares the same pseudo-label.
            label_u = [label_u] * len(input_u)
            label_u = torch.cat(label_u, 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup each set with its slice of the shuffled pool
        # (preserve_order keeps the first argument dominant in the mix).
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )

        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Compute losses: soft cross-entropy on labeled data,
        # squared error (Brier-style) on unlabeled data.
        # 1e-5 guards log() against exact zeros in the softmax output.
        output_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()

        output_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - output_u)**2).mean()

        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        # Step the LR scheduler once per epoch (on the last batch).
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        # Labeled data: keep only the first augmented view; convert labels
        # to one-hot because the losses above use soft targets.
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        label_x = create_onehot(label_x, self.num_classes)
        # Unlabeled data: keep all K augmented views (a list of tensors).
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = [input_ui.to(self.device) for input_ui in input_u]

        return input_x, label_x, input_u
def compute_accuracy(output, target, topk=(1, )):
    """Computes the accuracy over the k top predictions for
    the specified values of k.

    Args:
        output (torch.Tensor): prediction matrix with shape (batch_size, num_classes).
        target (torch.LongTensor): ground truth labels with shape (batch_size).
        topk (tuple, optional): accuracy at top-k will be computed. For example,
            topk=(1, 5) means accuracy at top-1 and top-5 will be computed.

    Returns:
        list: accuracy at top-k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Some models return (logits, extras); keep only the logits.
    if isinstance(output, (tuple, list)):
        output = output[0]

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Use reshape(-1), not view(-1): slices of `correct` are not
        # guaranteed contiguous, and view() raises a RuntimeError on
        # non-contiguous input (same fix as in pytorch/examples ImageNet).
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        acc = correct_k.mul_(100.0 / batch_size)
        res.append(acc)

    return res
" 40 | 'Please choose either "euclidean" or "cosine"'.format(metric) 41 | ) 42 | 43 | return distmat 44 | 45 | 46 | def euclidean_squared_distance(input1, input2): 47 | """Computes euclidean squared distance. 48 | 49 | Args: 50 | input1 (torch.Tensor): 2-D feature matrix. 51 | input2 (torch.Tensor): 2-D feature matrix. 52 | 53 | Returns: 54 | torch.Tensor: distance matrix. 55 | """ 56 | m, n = input1.size(0), input2.size(0) 57 | mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) 58 | mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 59 | distmat = mat1 + mat2 60 | distmat.addmm_(1, -2, input1, input2.t()) 61 | return distmat 62 | 63 | 64 | def cosine_distance(input1, input2): 65 | """Computes cosine distance. 66 | 67 | Args: 68 | input1 (torch.Tensor): 2-D feature matrix. 69 | input2 (torch.Tensor): 2-D feature matrix. 70 | 71 | Returns: 72 | torch.Tensor: distance matrix. 73 | """ 74 | input1_normed = F.normalize(input1, p=2, dim=1) 75 | input2_normed = F.normalize(input2, p=2, dim=1) 76 | distmat = 1 - torch.mm(input1_normed, input2_normed.t()) 77 | return distmat 78 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .head import HEAD_REGISTRY, build_head 2 | from .network import NETWORK_REGISTRY, build_network 3 | from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone 4 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_backbone, BACKBONE_REGISTRY # isort:skip 2 | from .backbone import Backbone # isort:skip 3 | 4 | from .vgg import vgg16 5 | from .resnet import ( 6 | resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1, 7 | resnet50_ms_l1, 
model_urls = {
    "alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth",
}


class AlexNet(Backbone):
    """AlexNet backbone producing 4096-d features (no classification logits).

    Layer names follow torchvision's AlexNet so that the pretrained
    checkpoint loads by matching state-dict keys.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Note that self.classifier outputs features rather than logits
        # (the final FC classification layer is intentionally omitted).
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
        )

        # Feature dimension exposed via Backbone.out_features.
        self._out_features = 4096

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


def init_pretrained_weights(model, model_url):
    # strict=False: checkpoint keys absent from the model (e.g. the final
    # classifier layer) are silently skipped.
    pretrain_dict = model_zoo.load_url(model_url)
    model.load_state_dict(pretrain_dict, strict=False)


@BACKBONE_REGISTRY.register()
def alexnet(pretrained=True, **kwargs):
    # Factory registered with the backbone registry; optionally loads
    # ImageNet-pretrained weights.
    model = AlexNet()

    if pretrained:
        init_pretrained_weights(model, model_urls["alexnet"])

    return model
print("Backbone: {}".format(name)) 11 | return BACKBONE_REGISTRY.get(name)(**kwargs) 12 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference 3 | 4 | https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA 5 | """ 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | 9 | from .build import BACKBONE_REGISTRY 10 | from .backbone import Backbone 11 | 12 | 13 | class FeatureExtractor(Backbone): 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2) 18 | self.bn1 = nn.BatchNorm2d(64) 19 | self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2) 20 | self.bn2 = nn.BatchNorm2d(64) 21 | self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2) 22 | self.bn3 = nn.BatchNorm2d(128) 23 | self.fc1 = nn.Linear(8192, 3072) 24 | self.bn1_fc = nn.BatchNorm1d(3072) 25 | self.fc2 = nn.Linear(3072, 2048) 26 | self.bn2_fc = nn.BatchNorm1d(2048) 27 | 28 | self._out_features = 2048 29 | 30 | def _check_input(self, x): 31 | H, W = x.shape[2:] 32 | assert ( 33 | H == 32 and W == 32 34 | ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) 35 | 36 | def forward(self, x): 37 | self._check_input(x) 38 | x = F.relu(self.bn1(self.conv1(x))) 39 | x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1) 40 | x = F.relu(self.bn2(self.conv2(x))) 41 | x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1) 42 | x = F.relu(self.bn3(self.conv3(x))) 43 | x = x.view(x.size(0), 8192) 44 | x = F.relu(self.bn1_fc(self.fc1(x))) 45 | x = F.dropout(x, training=self.training) 46 | x = F.relu(self.bn2_fc(self.fc2(x))) 47 | return x 48 | 49 | 50 | @BACKBONE_REGISTRY.register() 51 | def cnn_digit5_m3sda(**kwargs): 52 | """ 53 | This architecture was used for the Digit-5 
class Convolution(nn.Module):
    """3x3 same-padding convolution followed by in-place ReLU."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        return self.relu(self.conv(x))


class ConvNet(Backbone):
    """Four-block ConvNet for 32x32 images (DigitsDG).

    Each block is conv+ReLU followed by 2x2 max pooling; the final
    2x2 x c_hidden feature map is flattened into the feature vector.
    """

    def __init__(self, c_hidden=64):
        super().__init__()
        self.conv1 = Convolution(3, c_hidden)
        self.conv2 = Convolution(c_hidden, c_hidden)
        self.conv3 = Convolution(c_hidden, c_hidden)
        self.conv4 = Convolution(c_hidden, c_hidden)

        # 2**2 = 2x2 spatial size left after four 2x2 pools on 32x32 input.
        self._out_features = 2**2 * c_hidden

    def _check_input(self, x):
        # _out_features above hard-codes a 32x32 input.
        H, W = x.shape[2:]
        assert (
            H == 32 and W == 32
        ), "Input to network must be 32x32, " "but got {}x{}".format(H, W)

    def forward(self, x):
        self._check_input(x)
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.conv3(x)
        x = F.max_pool2d(x, 2)
        x = self.conv4(x)
        x = F.max_pool2d(x, 2)
        return x.view(x.size(0), -1)
58 | """ 59 | model = ConvNet(c_hidden=64) 60 | init_network_weights(model, init_type="kaiming") 61 | return model 62 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/backbone/cnn_digitsingle.py: -------------------------------------------------------------------------------- 1 | """ 2 | This model is built based on 3 | https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py 4 | """ 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | from dassl.utils import init_network_weights 9 | 10 | from .build import BACKBONE_REGISTRY 11 | from .backbone import Backbone 12 | 13 | 14 | class CNN(Backbone): 15 | 16 | def __init__(self): 17 | super().__init__() 18 | self.conv1 = nn.Conv2d(3, 64, 5) 19 | self.conv2 = nn.Conv2d(64, 128, 5) 20 | self.fc3 = nn.Linear(5 * 5 * 128, 1024) 21 | self.fc4 = nn.Linear(1024, 1024) 22 | 23 | self._out_features = 1024 24 | 25 | def _check_input(self, x): 26 | H, W = x.shape[2:] 27 | assert ( 28 | H == 32 and W == 32 29 | ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) 30 | 31 | def forward(self, x): 32 | self._check_input(x) 33 | x = self.conv1(x) 34 | x = F.relu(x) 35 | x = F.max_pool2d(x, 2) 36 | 37 | x = self.conv2(x) 38 | x = F.relu(x) 39 | x = F.max_pool2d(x, 2) 40 | 41 | x = x.view(x.size(0), -1) 42 | 43 | x = self.fc3(x) 44 | x = F.relu(x) 45 | 46 | x = self.fc4(x) 47 | x = F.relu(x) 48 | 49 | return x 50 | 51 | 52 | @BACKBONE_REGISTRY.register() 53 | def cnn_digitsingle(**kwargs): 54 | model = CNN() 55 | init_network_weights(model, init_type="kaiming") 56 | return model 57 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/backbone/efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source: https://github.com/lukemelas/EfficientNet-PyTorch. 
3 | """ 4 | __version__ = "0.6.4" 5 | from .model import ( 6 | EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2, 7 | efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, 8 | efficientnet_b7 9 | ) 10 | from .utils import ( 11 | BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params 12 | ) 13 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_head, HEAD_REGISTRY # isort:skip 2 | 3 | from .mlp import mlp 4 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/head/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | HEAD_REGISTRY = Registry("HEAD") 4 | 5 | 6 | def build_head(name, verbose=True, **kwargs): 7 | avai_heads = HEAD_REGISTRY.registered_names() 8 | check_availability(name, avai_heads) 9 | if verbose: 10 | print("Head: {}".format(name)) 11 | return HEAD_REGISTRY.get(name)(**kwargs) 12 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/head/mlp.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | from .build import HEAD_REGISTRY 5 | 6 | 7 | class MLP(nn.Module): 8 | 9 | def __init__( 10 | self, 11 | in_features=2048, 12 | hidden_layers=[], 13 | activation="relu", 14 | bn=True, 15 | dropout=0.0, 16 | ): 17 | super().__init__() 18 | if isinstance(hidden_layers, int): 19 | hidden_layers = [hidden_layers] 20 | 21 | assert len(hidden_layers) > 0 22 | self.out_features = hidden_layers[-1] 23 | 24 | mlp = [] 25 | 26 | if activation == "relu": 27 | act_fn = functools.partial(nn.ReLU, inplace=True) 28 | elif activation == 
"leaky_relu": 29 | act_fn = functools.partial(nn.LeakyReLU, inplace=True) 30 | else: 31 | raise NotImplementedError 32 | 33 | for hidden_dim in hidden_layers: 34 | mlp += [nn.Linear(in_features, hidden_dim)] 35 | if bn: 36 | mlp += [nn.BatchNorm1d(hidden_dim)] 37 | mlp += [act_fn()] 38 | if dropout > 0: 39 | mlp += [nn.Dropout(dropout)] 40 | in_features = hidden_dim 41 | 42 | self.mlp = nn.Sequential(*mlp) 43 | 44 | def forward(self, x): 45 | return self.mlp(x) 46 | 47 | 48 | @HEAD_REGISTRY.register() 49 | def mlp(**kwargs): 50 | return MLP(**kwargs) 51 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_network, NETWORK_REGISTRY # isort:skip 2 | 3 | from .ddaig_fcn import ( 4 | fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn 5 | ) 6 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/network/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | NETWORK_REGISTRY = Registry("NETWORK") 4 | 5 | 6 | def build_network(name, verbose=True, **kwargs): 7 | avai_models = NETWORK_REGISTRY.registered_names() 8 | check_availability(name, avai_models) 9 | if verbose: 10 | print("Network: {}".format(name)) 11 | return NETWORK_REGISTRY.get(name)(**kwargs) 12 | -------------------------------------------------------------------------------- /Dassl.pytorch/dassl/modeling/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .mmd import MaximumMeanDiscrepancy 2 | from .conv import * 3 | from .dsbn import DSBN1d, DSBN2d 4 | from .mixup import mixup 5 | from .efdmix import ( 6 | EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix, 7 | 
class Attention(nn.Module):
    """Squeeze-and-excite style attention producing softmax weights.

    From "Dynamic Domain Generalization" (https://github.com/MetaVisionLab/DDG).
    """

    def __init__(
        self,
        in_channels: int,
        out_features: int,
        squeeze=None,
        bias: bool = True
    ):
        super(Attention, self).__init__()
        # Default bottleneck width is in_channels / 16; must be positive.
        self.squeeze = squeeze if squeeze else in_channels // 16
        assert self.squeeze > 0
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, self.squeeze, bias=bias)
        self.fc2 = nn.Linear(self.squeeze, out_features, bias=bias)
        self.sf = nn.Softmax(dim=-1)

    def forward(self, x):
        # Global average pool (B, C, H, W) -> (B, C), then a two-layer
        # bottleneck MLP, normalized with softmax over the outputs.
        pooled = self.avg_pool(x).view(x.shape[:-2])
        hidden = F.relu(self.fc1(pooled), inplace=True)
        return self.sf(self.fc2(hidden))
10 | """ 11 | 12 | def __init__( 13 | self, 14 | in_channels: int, 15 | out_channels: int, 16 | kernel_size: int, 17 | stride: int, 18 | padding: int, 19 | bias: bool = True, 20 | squeeze: int = None, 21 | attention_in_channels: int = None 22 | ) -> None: 23 | super(Conv2dDynamic, self).__init__() 24 | 25 | if kernel_size // 2 != padding: 26 | # Only when this condition is met, we can ensure that different 27 | # kernel_size can obtain feature maps of consistent size. 28 | # Let I, K, S, P, O: O = (I + 2P - K) // S + 1, if P = K // 2, then O = (I - K % 2) // S + 1 29 | # This means that the output of two different Ks with the same parity can be made the same by adjusting P. 30 | raise ValueError("`padding` must be equal to `kernel_size // 2`.") 31 | if kernel_size % 2 == 0: 32 | raise ValueError( 33 | "Kernel_size must be odd now because the templates we used are odd (kernel_size=1)." 34 | ) 35 | 36 | self.conv = nn.Conv2d( 37 | in_channels, 38 | out_channels, 39 | kernel_size=kernel_size, 40 | stride=stride, 41 | padding=padding, 42 | bias=bias 43 | ) 44 | self.kernel_templates = nn.ModuleDict() 45 | self.kernel_templates["conv_nn"] = nn.Conv2d( 46 | in_channels, 47 | out_channels, 48 | kernel_size=kernel_size, 49 | stride=stride, 50 | padding=padding, 51 | groups=min(in_channels, out_channels), 52 | bias=bias 53 | ) 54 | self.kernel_templates["conv_11"] = nn.Conv2d( 55 | in_channels, 56 | out_channels, 57 | kernel_size=1, 58 | stride=stride, 59 | padding=0, 60 | bias=bias 61 | ) 62 | self.kernel_templates["conv_n1"] = nn.Conv2d( 63 | in_channels, 64 | out_channels, 65 | kernel_size=(kernel_size, 1), 66 | stride=stride, 67 | padding=(padding, 0), 68 | bias=bias 69 | ) 70 | self.kernel_templates["conv_1n"] = nn.Conv2d( 71 | in_channels, 72 | out_channels, 73 | kernel_size=(1, kernel_size), 74 | stride=stride, 75 | padding=(0, padding), 76 | bias=bias 77 | ) 78 | self.attention = Attention( 79 | attention_in_channels if attention_in_channels else in_channels, 80 | 
def cross_entropy(input, target, label_smooth=0, reduction="mean"):
    """Cross entropy loss with optional label smoothing.

    Args:
        input (torch.Tensor): logit matrix with shape of (batch, num_classes).
        target (torch.LongTensor): int label matrix.
        label_smooth (float, optional): label smoothing hyper-parameter.
            Default is 0.
        reduction (str, optional): how the losses for a mini-batch
            will be aggregated ('mean', 'sum' or 'none'). Default is 'mean'.

    Raises:
        ValueError: if ``reduction`` is not one of the accepted values.
    """
    num_classes = input.shape[1]
    log_prob = F.log_softmax(input, dim=1)
    # Build the one-hot target directly on input's device/dtype. (The
    # previous version allocated zeros on CPU and moved `target` to CPU,
    # forcing a device round-trip on every call when training on GPU.)
    onehot = torch.zeros_like(log_prob)
    onehot = onehot.scatter_(1, target.unsqueeze(1), 1)
    # Smoothing with label_smooth == 0 leaves the target unchanged.
    onehot = (1 - label_smooth) * onehot + label_smooth / num_classes
    loss = (-onehot * log_prob).sum(1)
    if reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()
    elif reduction == "none":
        return loss
    else:
        raise ValueError
class _DSBN(nn.Module):
    """Domain Specific Batch Normalization.

    Keeps one BN layer per domain; forward() normalizes with the layer
    chosen via :meth:`select_bn` (domain 0 by default).

    Args:
        num_features (int): number of features.
        n_domain (int): number of domains.
        bn_type (str): type of bn. Choices are ['1d', '2d'].

    Raises:
        ValueError: if ``bn_type`` is not '1d' or '2d'.
    """

    def __init__(self, num_features, n_domain, bn_type):
        super().__init__()
        if bn_type == "1d":
            BN = nn.BatchNorm1d
        elif bn_type == "2d":
            BN = nn.BatchNorm2d
        else:
            raise ValueError(
                "bn_type must be '1d' or '2d', but got {}".format(bn_type)
            )

        self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))

        self.valid_domain_idxs = list(range(n_domain))
        self.n_domain = n_domain
        self.domain_idx = 0  # domain whose BN layer forward() uses

    def select_bn(self, domain_idx=0):
        """Select which domain's BN layer subsequent forwards will use."""
        assert domain_idx in self.valid_domain_idxs
        self.domain_idx = domain_idx

    def forward(self, x):
        return self.bn[self.domain_idx](x)


class DSBN1d(_DSBN):
    """Domain-specific BatchNorm1d for (N, C) or (N, C, L) inputs."""

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "1d")


class DSBN2d(_DSBN):
    """Domain-specific BatchNorm2d for (N, C, H, W) inputs."""

    def __init__(self, num_features, n_domain):
        super().__init__(num_features, n_domain, "2d")
class EFDMix(nn.Module):
    """EFDMix.

    Exact feature distribution matching: mixes the *sorted* feature values
    of each instance with those of a randomly paired instance in the batch
    (instead of only mixing mean/std as MixStyle does).

    Reference:
      Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of using MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix ('random' or 'crossdomain').
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        # Fix: report the actual class name; the original said "MixStyle",
        # a copy-paste slip from mixstyle.py.
        return (
            f"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        self._activated = status

    def update_mix_method(self, mix="random"):
        self.mix = mix

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3)
        x_view = x.view(B, C, -1)
        value_x, index_x = torch.sort(x_view)  # sort inputs
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)

        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.shape[0])]
            perm_a = perm_a[torch.randperm(perm_a.shape[0])]
            perm = torch.cat([perm_b, perm_a], 0)

        else:
            raise NotImplementedError

        # Scatter the partner's sorted values back into this instance's
        # rank positions, then interpolate; gradients flow through x only
        # (x_view is detached inside the mixing term).
        inverse_index = index_x.argsort(-1)
        x_view_copy = value_x[perm].gather(-1, inverse_index)
        new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda)
        return new_x.view(B, C, W, H)
class MixStyle(nn.Module):
    """MixStyle.

    Perturbs per-channel feature statistics (mean/std) by mixing them with
    the statistics of another instance in the mini-batch.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
            p (float): probability of using MixStyle.
            alpha (float): parameter of the Beta distribution.
            eps (float): scaling parameter to avoid numerical issues.
            mix (str): how to mix ('random' or 'crossdomain').
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        return (
            f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        """Enable/disable mixing (used by the activate/deactivate hooks)."""
        self._activated = status

    def update_mix_method(self, mix="random"):
        """Switch between 'random' and 'crossdomain' pairing."""
        self.mix = mix

    def _sample_perm(self, batch_size):
        """Return the batch permutation deciding whose stats are mixed in."""
        if self.mix == "random":
            # any instance may be paired with any other
            return torch.randperm(batch_size)
        if self.mix == "crossdomain":
            # Batch is assumed to be [domain A | domain B]: reverse it so
            # halves swap, then shuffle within each half.
            flipped = torch.arange(batch_size - 1, -1, -1)
            half_b, half_a = flipped.chunk(2)
            half_b = half_b[torch.randperm(half_b.shape[0])]
            half_a = half_a[torch.randperm(half_a.shape[0])]
            return torch.cat([half_b, half_a], 0)
        raise NotImplementedError

    def forward(self, x):
        # Identity outside training, when deactivated, or with prob 1 - p.
        if (not self.training) or (not self._activated):
            return x
        if random.random() > self.p:
            return x

        batch_size = x.size(0)

        mean = x.mean(dim=[2, 3], keepdim=True)
        std = (x.var(dim=[2, 3], keepdim=True) + self.eps).sqrt()
        # Statistics are treated as constants: gradients only flow through
        # the normalized content.
        mean, std = mean.detach(), std.detach()
        normed = (x - mean) / std

        weight = self.beta.sample((batch_size, 1, 1, 1)).to(x.device)
        perm = self._sample_perm(batch_size)

        mixed_mean = mean * weight + mean[perm] * (1 - weight)
        mixed_std = std * weight + std[perm] * (1 - weight)
        return normed * mixed_std + mixed_mean
def mixup(x1, x2, y1, y2, beta, preserve_order=False):
    """Mixup.

    Convexly combines two batches of images and their (one-hot) labels
    with per-example weights drawn from Beta(beta, beta).

    Args:
        x1 (torch.Tensor): data with shape of (b, c, h, w).
        x2 (torch.Tensor): data with shape of (b, c, h, w).
        y1 (torch.Tensor): label with shape of (b, n).
        y2 (torch.Tensor): label with shape of (b, n).
        beta (float): hyper-parameter for Beta sampling.
        preserve_order (bool): apply lmda = max(lmda, 1-lmda) so the first
            input always dominates. Default is False.

    Returns:
        tuple: (mixed images, mixed labels).
    """
    weight = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])
    if preserve_order:
        weight = torch.max(weight, 1 - weight)
    weight = weight.to(x1.device)
    xmix = weight * x1 + (1 - weight) * x2
    # Collapse the broadcast dims to (b, 1) for label mixing.
    weight = weight[:, :, 0, 0]
    ymix = weight * y1 + (1 - weight) * y2
    return xmix, ymix
class MaximumMeanDiscrepancy(nn.Module):
    """Maximum Mean Discrepancy (MMD).

    Estimates MMD^2(x, y) = E[k(x, x')] - 2 E[k(x, y)] + E[k(y, y')] on two
    batches of features with a linear, polynomial or RBF-mixture kernel.

    Args:
        kernel_type (str): 'linear', 'poly' or 'rbf'. Default 'rbf'.
        normalize (bool): L2-normalize inputs before computing MMD.
    """

    def __init__(self, kernel_type="rbf", normalize=False):
        super().__init__()
        self.kernel_type = kernel_type
        self.normalize = normalize

    def forward(self, x, y):
        """Compute MMD^2 between x and y, each of shape (batch, dim).

        Raises:
            NotImplementedError: for an unknown kernel type.
        """
        if self.normalize:
            x = F.normalize(x, dim=1)
            y = F.normalize(y, dim=1)
        if self.kernel_type == "linear":
            return self.linear_mmd(x, y)
        elif self.kernel_type == "poly":
            return self.poly_mmd(x, y)
        elif self.kernel_type == "rbf":
            return self.rbf_mmd(x, y)
        else:
            raise NotImplementedError

    def linear_mmd(self, x, y):
        # k(x, y) = x^T y; self-similarities are excluded from k_xx/k_yy.
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_xy = torch.mm(x, y.t())
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
        # k(x, y) = (alpha * x^T y + c)^d
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_xx = (alpha*k_xx + c).pow(d)
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_yy = (alpha*k_yy + c).pow(d)
        k_xy = torch.mm(x, y.t())
        k_xy = (alpha*k_xy + c).pow(d)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def rbf_mmd(self, x, y):
        # k_xx
        d_xx = self.euclidean_squared_distance(x, x)
        d_xx = self.remove_self_distance(d_xx)
        k_xx = self.rbf_kernel_mixture(d_xx)
        # k_yy
        d_yy = self.euclidean_squared_distance(y, y)
        d_yy = self.remove_self_distance(d_yy)
        k_yy = self.rbf_kernel_mixture(d_yy)
        # k_xy
        d_xy = self.euclidean_squared_distance(x, y)
        k_xy = self.rbf_kernel_mixture(d_xy)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    @staticmethod
    def rbf_kernel_mixture(exponent, sigmas=(1, 5, 10)):
        """Sum of RBF kernels over several bandwidths.

        ``sigmas`` is a tuple (the original used a mutable list default).
        """
        K = 0
        for sigma in sigmas:
            gamma = 1.0 / (2.0 * sigma**2)
            K += torch.exp(-gamma * exponent)
        return K

    @staticmethod
    def remove_self_distance(distmat):
        """Drop the diagonal of an (n, n) matrix, returning (n, n-1).

        Vectorized replacement for the original per-row Python loop
        (which built and stacked n slices).
        """
        n = distmat.size(0)
        mask = ~torch.eye(n, dtype=torch.bool, device=distmat.device)
        return distmat[mask].view(n, n - 1)

    @staticmethod
    def euclidean_squared_distance(x, y):
        """Pairwise squared Euclidean distances, shape (len(x), len(y))."""
        m, n = x.size(0), y.size(0)
        distmat = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
        # Guard against tiny negative values from floating-point error,
        # which would otherwise feed exp() with a positive exponent.
        return distmat.clamp_(min=0)
class _ReverseGrad(Function):
    """autograd Function: identity forward, sign-flipped scaled backward."""

    @staticmethod
    def forward(ctx, input, grad_scaling):
        # Stash the scaling factor for the backward pass.
        ctx.grad_scaling = grad_scaling
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate and scale the incoming gradient; no grad for grad_scaling.
        return -ctx.grad_scaling * grad_output, None


reverse_grad = _ReverseGrad.apply


class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    It acts as an identity layer in the forward, but reverses the sign of
    the gradient in the backward (optionally scaled by ``grad_scaling``).
    """

    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, (
            "grad_scaling must be non-negative, "
            "but got {}".format(grad_scaling)
        )
        return reverse_grad(x, grad_scaling)
def create_onehot(label, num_classes):
    """Create one-hot tensor.

    We suggest using nn.functional.one_hot.

    Args:
        label (torch.Tensor): 1-D tensor of integer class indices.
        num_classes (int): number of classes.

    Returns:
        torch.Tensor: float tensor of shape (len(label), num_classes),
        on ``label``'s device.
    """
    # Build the matrix directly on label's device; the original scattered
    # on the CPU (label.data.cpu()) and then moved the result back, which
    # is an unnecessary round trip for CUDA labels. .long() also accepts
    # non-int64 integer label tensors.
    onehot = torch.zeros(label.shape[0], num_classes, device=label.device)
    onehot.scatter_(1, label.unsqueeze(1).long(), 1)
    return onehot
def ema_model_update(model, ema_model, alpha):
    """Exponential moving average of model parameters.

    In-place update of each EMA parameter:
    ema = alpha * ema + (1 - alpha) * param.

    Args:
        model (nn.Module): model being trained.
        ema_model (nn.Module): ema of the model.
        alpha (float): ema decay rate.
    """
    for ema_param, live_param in zip(ema_model.parameters(),
                                     model.parameters()):
        blended = alpha * ema_param.data + (1 - alpha) * live_param.data
        ema_param.data.copy_(blended)
class Logger:
    """Write console output to external text file.

    Imported from ``_

    Args:
        fpath (str): directory to save logging file.

    Examples::
        >>> import sys
        >>> import os.path as osp
        >>> save_dir = 'output/experiment-1'
        >>> log_name = 'train.log'
        >>> sys.stdout = Logger(osp.join(save_dir, log_name))
    """

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(osp.dirname(fpath))
            self.file = open(fpath, "w")

    def __del__(self):
        self.close()

    def __enter__(self):
        # Fix: return self so `with Logger(...) as log:` binds the logger;
        # the original returned None, making the as-target useless.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Write msg to the console and, when enabled, to the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush both streams; fsync the file so logs survive a crash."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        # NOTE(review): this also closes the console stream (sys.stdout by
        # default), mirroring the original behavior — confirm this is
        # intended before relying on stdout after close().
        self.console.close()
        if self.file is not None:
            self.file.close()
class AverageMeter:
    """Compute and store the average and current value.

    Examples::
        >>> # 1. Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # 2. Update meter after every mini-batch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self, ema=False):
        """
        Args:
            ema (bool, optional): apply exponential moving average.
        """
        self.ema = ema
        self.reset()

    def reset(self):
        """Zero out all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a value observed ``n`` times."""
        if isinstance(val, torch.Tensor):
            val = val.item()

        self.val = val
        self.sum += n * val
        self.count += n

        # EMA keeps a smoothed estimate; otherwise report the exact mean.
        if self.ema:
            self.avg = 0.9 * self.avg + 0.1 * self.val
        else:
            self.avg = self.sum / self.count
class MetricMeter:
    """Store the average and current value for a set of metrics.

    Examples::
        >>> # 1. Create an instance of MetricMeter
        >>> metric = MetricMeter()
        >>> # 2. Update using a dictionary as input
        >>> input_dict = {'loss_1': value_1, 'loss_2': value_2}
        >>> metric.update(input_dict)
        >>> # 3. Convert to string and print
        >>> print(str(metric))
    """

    def __init__(self, delimiter=" "):
        # One AverageMeter per metric name, created lazily on first update.
        self.meters = defaultdict(AverageMeter)
        self.delimiter = delimiter

    def update(self, input_dict):
        """Feed a {name: value} dict into the per-metric meters."""
        if input_dict is None:
            return

        if not isinstance(input_dict, dict):
            raise TypeError(
                "Input to MetricMeter.update() must be a dictionary"
            )

        for name, value in input_dict.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.meters[name].update(value)

    def __str__(self):
        """Render every metric as 'name current (average)'."""
        return self.delimiter.join(
            f"{name} {meter.val:.4f} ({meter.avg:.4f})"
            for name, meter in self.meters.items()
        )
class Registry:
    """A registry providing name -> object mapping, to support
    custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone(nn.Module):
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        self._name = name
        self._obj_map = dict()

    def _do_register(self, name, obj, force=False):
        """Insert name -> obj; raise KeyError on duplicates unless force."""
        if name in self._obj_map and not force:
            raise KeyError(
                'An object named "{}" was already '
                'registered in "{}" registry'.format(name, self._name)
            )

        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        """Register ``obj``, or act as a decorator when called without it.

        Returns the registered object in both usages; the original
        returned None in the plain function-call form, which made
        ``x = registry.register(X)`` silently bind None.
        """
        if obj is None:
            # Used as a decorator
            def wrapper(fn_or_class):
                name = fn_or_class.__name__
                self._do_register(name, fn_or_class, force=force)
                return fn_or_class

            return wrapper

        # Used as a function call
        name = obj.__name__
        self._do_register(name, obj, force=force)
        return obj

    def get(self, name):
        """Look up a registered object; raise KeyError if missing."""
        if name not in self._obj_map:
            raise KeyError(
                'Object name "{}" does not exist '
                'in "{}" registry'.format(name, self._name)
            )

        return self._obj_map[name]

    def registered_names(self):
        """Return all registered names as a list."""
        return list(self._obj_map.keys())
def extract_and_save_image(dataset, save_dir, discard, label2name):
    """Dump a torchvision dataset to per-class JPEG folders under save_dir.

    Images whose label equals ``discard`` are skipped; remaining labels are
    remapped through ``label2name`` and the shared ``new_name2label`` table
    so CIFAR and STL end up in one common label space. Skips entirely if
    ``save_dir`` already exists.
    """
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    for idx, (img, label) in enumerate(dataset):
        if label == discard:
            continue
        class_name = label2name[label]
        new_label = new_name2label[class_name]
        class_dir = osp.join(
            save_dir,
            str(new_label).zfill(3) + "_" + class_name
        )
        mkdir_if_missing(class_dir)
        impath = osp.join(class_dir, str(idx + 1).zfill(5) + ".jpg")
        img.save(impath)
# ------------------------------------------------------------------------
# ROOT is the root directory where you put your domain datasets.
#
# Suppose you wanna put the dataset under $DATA, which stores all the
# domain datasets, run the following command in your terminal to
# download VisDA17:
#
#   $ sh visda17.sh $DATA
# ------------------------------------------------------------------------

# Abort on the first failed download/extraction instead of continuing.
set -e

ROOT=$1

# -p: succeed even if the directory already exists; quoting keeps paths
# with spaces intact (the original used bare $ROOT and plain mkdir).
mkdir -p "$ROOT/visda17"
cd "$ROOT/visda17"

wget http://csr.bu.edu/ftp/visda17/clf/train.tar
tar xvf train.tar

wget http://csr.bu.edu/ftp/visda17/clf/validation.tar
tar xvf validation.tar

wget http://csr.bu.edu/ftp/visda17/clf/test.tar
tar xvf test.tar

wget https://raw.githubusercontent.com/VisionLearningGroup/taskcv-2017-public/master/classification/data/image_list.txt -O test/image_list.txt
def extract_and_save(images, labels, level, dst):
    """Save the 10k images of one corruption-intensity level under dst.

    ``images``/``labels`` hold five consecutive 10k-image chunks, one per
    intensity level; ``level`` (0-based) selects which chunk to export,
    grouped into zero-padded per-class folders.
    """
    # level denotes the corruption intensity level (0-based)
    assert 0 <= level <= 4

    offset = level * 10000
    for i in range(10000):
        image = Image.fromarray(images[offset + i])
        label = int(labels[offset + i])
        category_dir = osp.join(dst, str(label).zfill(3))
        mkdir_if_missing(category_dir)
        save_path = osp.join(category_dir, str(i + 1).zfill(5) + ".jpg")
        image.save(save_path)
def download_and_prepare(name, root):
    """Download a torchvision dataset and export it as image folders.

    Args:
        name (str): one of 'cifar10', 'cifar100' or 'svhn'.
        root (str): directory where raw data and image folders are written.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    print("Dataset: {}".format(name))
    print("Root: {}".format(root))

    if name == "cifar10":
        train = CIFAR10(root, train=True, download=True)
        test = CIFAR10(root, train=False)
    elif name == "cifar100":
        train = CIFAR100(root, train=True, download=True)
        test = CIFAR100(root, train=False)
    elif name == "svhn":
        train = SVHN(root, split="train", download=True)
        test = SVHN(root, split="test", download=True)
    else:
        # The original raised a bare ValueError with no message.
        raise ValueError('Unknown dataset name: "{}"'.format(name))

    train_dir = osp.join(root, name, "train")
    test_dir = osp.join(root, name, "test")

    extract_and_save_image(train, train_dir)
    extract_and_save_image(test, test_dir)
def extract_and_save_image(dataset, save_dir):
    """Dump an STL10 split to save_dir as '<index>_<label>.jpg' files.

    Unlabeled images (label == -1) get the label name 'none'. Skips
    entirely if ``save_dir`` already exists.
    """
    if osp.exists(save_dir):
        print('Folder "{}" already exists'.format(save_dir))
        return

    print('Extracting images to "{}" ...'.format(save_dir))
    mkdir_if_missing(save_dir)

    for idx, (img, label) in enumerate(dataset):
        label_name = "none" if label == -1 else str(label)
        imname = str(idx).zfill(6) + "_" + label_name + ".jpg"
        img.save(osp.join(save_dir, imname))
11 | echo "Done" -------------------------------------------------------------------------------- /Dassl.pytorch/requirements.txt: -------------------------------------------------------------------------------- 1 | flake8==3.7.9 2 | yapf==0.29.0 3 | isort==4.3.21 4 | yacs 5 | gdown 6 | tb-nightly 7 | future 8 | scipy 9 | scikit-learn 10 | tqdm 11 | ftfy 12 | regex 13 | wilds==1.2.2 14 | tabulate 15 | -------------------------------------------------------------------------------- /Dassl.pytorch/setup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | from setuptools import setup, find_packages 4 | 5 | 6 | def readme(): 7 | with open('README.md') as f: 8 | content = f.read() 9 | return content 10 | 11 | 12 | def find_version(): 13 | version_file = 'dassl/__init__.py' 14 | with open(version_file, 'r') as f: 15 | exec(compile(f.read(), version_file, 'exec')) 16 | return locals()['__version__'] 17 | 18 | 19 | def numpy_include(): 20 | try: 21 | numpy_include = np.get_include() 22 | except AttributeError: 23 | numpy_include = np.get_numpy_include() 24 | return numpy_include 25 | 26 | 27 | def get_requirements(filename='requirements.txt'): 28 | here = osp.dirname(osp.realpath(__file__)) 29 | with open(osp.join(here, filename), 'r') as f: 30 | requires = [line.replace('\n', '') for line in f.readlines()] 31 | return requires 32 | 33 | 34 | setup( 35 | name='dassl', 36 | version=find_version(), 37 | description='Dassl: Domain adaptation and semi-supervised learning', 38 | author='Kaiyang Zhou', 39 | license='MIT', 40 | long_description=readme(), 41 | url='https://github.com/KaiyangZhou/Dassl.pytorch', 42 | packages=find_packages(), 43 | install_requires=get_requirements(), 44 | keywords=[ 45 | 'Domain Adaptation', 'Domain Generalization', 46 | 'Semi-Supervised Learning', 'Pytorch' 47 | ] 48 | ) 49 | -------------------------------------------------------------------------------- 
/Dassl.pytorch/tools/replace_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | Replace text in python files. 3 | """ 4 | import glob 5 | import os.path as osp 6 | import argparse 7 | import fileinput 8 | 9 | EXTENSION = ".py" 10 | 11 | 12 | def is_python_file(filename): 13 | ext = osp.splitext(filename)[1] 14 | return ext == EXTENSION 15 | 16 | 17 | def update_file(filename, text_to_search, replacement_text): 18 | print("Processing {}".format(filename)) 19 | with fileinput.FileInput(filename, inplace=True, backup="") as file: 20 | for line in file: 21 | print(line.replace(text_to_search, replacement_text), end="") 22 | 23 | 24 | def recursive_update(directory, text_to_search, replacement_text): 25 | filenames = glob.glob(osp.join(directory, "*")) 26 | 27 | for filename in filenames: 28 | if osp.isfile(filename): 29 | if not is_python_file(filename): 30 | continue 31 | update_file(filename, text_to_search, replacement_text) 32 | elif osp.isdir(filename): 33 | recursive_update(filename, text_to_search, replacement_text) 34 | else: 35 | raise NotImplementedError 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument( 41 | "file_or_dir", type=str, help="path to file or directory" 42 | ) 43 | parser.add_argument("text_to_search", type=str, help="name to be replaced") 44 | parser.add_argument("replacement_text", type=str, help="new name") 45 | parser.add_argument( 46 | "--ext", type=str, default=".py", help="file extension" 47 | ) 48 | args = parser.parse_args() 49 | 50 | file_or_dir = args.file_or_dir 51 | text_to_search = args.text_to_search 52 | replacement_text = args.replacement_text 53 | extension = args.ext 54 | 55 | global EXTENSION 56 | EXTENSION = extension 57 | 58 | if osp.isfile(file_or_dir): 59 | if not is_python_file(file_or_dir): 60 | return 61 | update_file(file_or_dir, text_to_search, replacement_text) 62 | elif osp.isdir(file_or_dir): 63 | 
recursive_update(file_or_dir, text_to_search, replacement_text) 64 | else: 65 | raise NotImplementedError 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MLV Lab (Machine Learning and Vision Lab at Korea University) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distribution-Aware Prompt Tuning for Vision-Language Models 2 | Official pytorch implementation of "[Distribution-Aware Prompt Tuning for Vision-Language Models](https://openaccess.thecvf.com/content/ICCV2023/papers/Cho_Distribution-Aware_Prompt_Tuning_for_Vision-Language_Models_ICCV_2023_paper.pdf)" (ICCV 2023). 3 |
4 | 5 |
6 | 7 | ## Setup 8 | ### Clone repository 9 | ``` 10 | git clone https://github.com/mlvlab/DAPT.git 11 | cd DAPT 12 | ``` 13 | 14 | ### Prepare dataset 15 | Follow [DATASET.md](DATASET.md) to install the datasets. 16 | 17 | ### Setup conda environment 18 | Before creating the environment, you should modify the appropriate conda path in `env.yaml` 19 | 20 | ``` 21 | conda env create --file env.yaml 22 | conda activate dapt 23 | pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html 24 | ``` 25 | 26 | ### Setup [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch) package 27 | ``` 28 | cd Dassl.pytorch 29 | python setup.py develop 30 | cd .. 31 | ``` 32 | 33 | ## Run 34 | ### Dataset path setting 35 | Modify the data path `$DATA` in `main.sh`, `gen_prototype.sh`, and `eval.sh` to match the path to the dataset you downloaded. 36 | 37 | ### Generate prototype 38 | When the dataset is ready, you can generate the prototype as follows. 39 | 40 | ``` 41 | bash scripts/gen_prototype.sh [gpu_id] 42 | ``` 43 | 44 | ### Few-shot image classification 45 | Below is an example of Caltech101 for each shot. 46 | 47 | Note that for ImageNet, we use `configs/trainers/DAPT/vit_b16_ep50.yaml` for all settings following [CoOp](https://arxiv.org/abs/2109.01134). 48 | 49 | ``` 50 | # 1shot 51 | bash scripts/main.sh caltech101 1 [gpu_id] 52 | # 2shots 53 | bash scripts/main.sh caltech101 2 [gpu_id] 54 | # 4shots 55 | bash scripts/main.sh caltech101 4 [gpu_id] 56 | # 8shots 57 | bash scripts/main.sh caltech101 8 [gpu_id] 58 | # 16shots 59 | bash scripts/main.sh caltech101 16 [gpu_id] 60 | ``` 61 | 62 | ### Domain generalization 63 | Before domain generalization, you should have completed few-shot image classification on ImageNet.
64 | 65 | After the few-shot image classification experiment on ImageNet is finished, you can load the model learned on ImageNet using `--eval-only` command to conduct domain generalization on `imagenetv2`, `imagenet-sketch`, `imagenet-a`, and `imagenet-r`. 66 | ``` 67 | bash scripts/eval.sh [gpu_id] 68 | ``` 69 | 70 | ## Acknowledgement 71 | This repository is built upon [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch), [CoOp](https://github.com/KaiyangZhou/CoOp), and [VPT](https://github.com/KMnP/vpt). We thank the authors for their code. 72 | 73 | ## Citation 74 | If you use this code in your research, please kindly cite the following paper: 75 | ``` 76 | @InProceedings{Cho_2023_ICCV, 77 | author = {Cho, Eulrang and Kim, Jooyeon and Kim, Hyunwoo J}, 78 | title = {Distribution-Aware Prompt Tuning for Vision-Language Models}, 79 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 80 | month = {October}, 81 | year = {2023}, 82 | pages = {22004-22013} 83 | } 84 | ``` 85 | 86 | ## License 87 | Licensed under [MIT License](https://github.com/mlvlab/DAPT/blob/master/LICENSE) 88 | > Copyright (c) 2023 MLV Lab (Machine Learning and Vision Lab at Korea University) 89 | -------------------------------------------------------------------------------- /asset/dapt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/DAPT/cacb311eee17acb0781e9c0aec628b57c112cd5e/asset/dapt.png -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mlvlab/DAPT/cacb311eee17acb0781e9c0aec628b57c112cd5e/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | OPTIM: 4 | LR: 0.2 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 10.0 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | OPTIM: 4 | LR: 20.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 100.0 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | OPTIM: 4 | LR: 20.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 100.0 8 | TXT_BETA: 10.0 -------------------------------------------------------------------------------- /configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | OPTIM: 4 | LR: 2.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 100.0 8 | TXT_BETA: 0.01 9 | -------------------------------------------------------------------------------- /configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Food101" 3 | OPTIM: 4 | LR: 20.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 10.0 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | OPTIM: 4 | LR: 2.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 
0.01 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | OPTIM: 4 | LR: 2.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 0.01 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | OPTIM: 4 | LR: 2.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 0.01 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | OPTIM: 4 | LR: 2.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 0.01 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | OPTIM: 4 | LR: 2.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 0.01 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" 3 | OPTIM: 4 | LR: 0.002 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 10.0 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" 3 | OPTIM: 4 | LR: 0.02 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 10.0 8 | TXT_BETA: 0.1 
-------------------------------------------------------------------------------- /configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | OPTIM: 4 | LR: 0.002 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 100.0 8 | TXT_BETA: 0.1 -------------------------------------------------------------------------------- /configs/datasets/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SUN397" 3 | OPTIM: 4 | LR: 20.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 100.0 8 | TXT_BETA: 0.01 -------------------------------------------------------------------------------- /configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | OPTIM: 4 | LR: 20.0 5 | TRAINER: 6 | DAPT: 7 | VIS_BETA: 0.1 8 | TXT_BETA: 0.1 -------------------------------------------------------------------------------- /configs/trainers/DAPT/vit_b16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | MAX_EPOCH: 200 18 | LR_SCHEDULER: "cosine" 19 | WARMUP_EPOCH: 1 20 | WARMUP_TYPE: "constant" 21 | WARMUP_CONS_LR: 1e-5 22 | 23 | TRAIN: 24 | PRINT_FREQ: 5 25 | 26 | MODEL: 27 | BACKBONE: 28 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/DAPT/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | 
BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | MAX_EPOCH: 100 18 | LR_SCHEDULER: "cosine" 19 | WARMUP_EPOCH: 1 20 | WARMUP_TYPE: "constant" 21 | WARMUP_CONS_LR: 1e-5 22 | 23 | TRAIN: 24 | PRINT_FREQ: 5 25 | 26 | MODEL: 27 | BACKBONE: 28 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/DAPT/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | MAX_EPOCH: 50 18 | LR_SCHEDULER: "cosine" 19 | WARMUP_EPOCH: 1 20 | WARMUP_TYPE: "constant" 21 | WARMUP_CONS_LR: 1e-5 22 | 23 | TRAIN: 24 | PRINT_FREQ: 5 25 | 26 | MODEL: 27 | BACKBONE: 28 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/DAPT/cacb311eee17acb0781e9c0aec628b57c112cd5e/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from 
.dtd import DescribableTextures as DTD 9 | 10 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 11 | NEW_CNAMES = { 12 | "airplanes": "airplane", 13 | "Faces": "face", 14 | "Leopards": "leopard", 15 | "Motorbikes": "motorbike", 16 | } 17 | 18 | 19 | @DATASET_REGISTRY.register() 20 | class Caltech101(DatasetBase): 21 | 22 | dataset_dir = "caltech-101" 23 | 24 | def __init__(self, cfg): 25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 28 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 29 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 30 | mkdir_if_missing(self.split_fewshot_dir) 31 | 32 | if os.path.exists(self.split_path): 33 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 34 | else: 35 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = 
OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | NEW_CNAMES = { 11 | "AnnualCrop": "Annual Crop Land", 12 | "Forest": "Forest", 13 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 14 | "Highway": "Highway or Road", 15 | "Industrial": "Industrial Buildings", 16 | "Pasture": "Pasture Land", 17 | "PermanentCrop": "Permanent Crop Land", 18 | "Residential": "Residential Buildings", 19 | "River": "River", 20 | "SeaLake": "Sea or Lake", 21 | } 22 | 23 | 24 | @DATASET_REGISTRY.register() 25 | class EuroSAT(DatasetBase): 26 | 27 | dataset_dir = "eurosat" 28 | 29 | def __init__(self, cfg): 30 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 31 | self.dataset_dir = os.path.join(root, self.dataset_dir) 32 | self.image_dir = os.path.join(self.dataset_dir, "2750") 33 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 34 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 35 | mkdir_if_missing(self.split_fewshot_dir) 36 | 37 | if os.path.exists(self.split_path): 38 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 39 | else: 40 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 41 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 42 | 43 | num_shots = cfg.DATASET.NUM_SHOTS 44 | if num_shots >= 1: 45 | seed = cfg.SEED 46 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 47 
| 48 | if os.path.exists(preprocessed): 49 | print(f"Loading preprocessed few-shot data from {preprocessed}") 50 | with open(preprocessed, "rb") as file: 51 | data = pickle.load(file) 52 | train, val = data["train"], data["val"] 53 | else: 54 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 55 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 56 | data = {"train": train, "val": val} 57 | print(f"Saving preprocessed few-shot data to {preprocessed}") 58 | with open(preprocessed, "wb") as file: 59 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 60 | 61 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 62 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 63 | 64 | super().__init__(train_x=train, val=val, test=test) 65 | 66 | def update_classname(self, dataset_old): 67 | dataset_new = [] 68 | for item_old in dataset_old: 69 | cname_old = item_old.classname 70 | cname_new = NEW_CLASSNAMES[cname_old] 71 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 72 | dataset_new.append(item_new) 73 | return dataset_new 74 | -------------------------------------------------------------------------------- /datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class FGVCAircraft(DatasetBase): 12 | 13 | dataset_dir = "fgvc_aircraft" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | 
mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "images_variant_train.txt") 30 | val = self.read_data(cname2lab, "images_variant_val.txt") 31 | test = self.read_data(cname2lab, "images_variant_test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, cname2lab, split_file): 57 | filepath = os.path.join(self.dataset_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip().split(" ") 64 | imname = line[0] + ".jpg" 65 | classname = " ".join(line[1:]) 66 | impath = os.path.join(self.image_dir, imname) 67 | label = cname2lab[classname] 68 | item = Datum(impath=impath, label=label, classname=classname) 69 | items.append(item) 70 | 71 | return 
items 72 | -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Food101(DatasetBase): 13 | 14 | dataset_dir = "food-101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = DTD.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | 
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNet(DatasetBase): 13 | 14 | dataset_dir = "imagenet" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.preprocessed): 25 | with open(self.preprocessed, "rb") as f: 26 | preprocessed = pickle.load(f) 27 | train = preprocessed["train"] 28 | test = preprocessed["test"] 29 | else: 30 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 31 | classnames = self.read_classnames(text_file) 32 | train = self.read_data(classnames, "train") 33 | # Follow standard practice to perform evaluation on the val set 34 | # Also used as the val set (so evaluate the last-step model) 35 | test = self.read_data(classnames, "val") 36 | 37 | preprocessed = {"train": train, "test": test} 38 | with open(self.preprocessed, "wb") as f: 39 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 40 | 41 | num_shots = 
cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train = data["train"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | data = {"train": train} 54 | print(f"Saving preprocessed few-shot data to {preprocessed}") 55 | with open(preprocessed, "wb") as file: 56 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 57 | 58 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 59 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 60 | 61 | super().__init__(train_x=train, val=test, test=test) 62 | 63 | @staticmethod 64 | def read_classnames(text_file): 65 | """Return a dictionary containing 66 | key-value pairs of : . 67 | """ 68 | classnames = OrderedDict() 69 | with open(text_file, "r") as f: 70 | lines = f.readlines() 71 | for line in lines: 72 | line = line.strip().split(" ") 73 | folder = line[0] 74 | classname = " ".join(line[1:]) 75 | classnames[folder] = classname 76 | return classnames 77 | 78 | def read_data(self, classnames, split_dir): 79 | split_dir = os.path.join(self.image_dir, split_dir) 80 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 81 | items = [] 82 | 83 | for label, folder in enumerate(folders): 84 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 85 | classname = classnames[folder] 86 | for imname in imnames: 87 | impath = os.path.join(split_dir, folder, imname) 88 | item = Datum(impath=impath, label=label, classname=classname) 89 | items.append(item) 90 | 91 | return items 92 | -------------------------------------------------------------------------------- /datasets/imagenet_a.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetA(DatasetBase): 13 | """ImageNet-A(dversarial). 14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenet-adversarial" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetR(DatasetBase): 
13 | """ImageNet-R(endition). 14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenet-rendition" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetSketch(DatasetBase): 11 | """ImageNet-Sketch. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenet-sketch" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "images") 22 | 23 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 24 | classnames = ImageNet.read_classnames(text_file) 25 | 26 | data = self.read_data(classnames) 27 | 28 | super().__init__(train_x=data, test=data) 29 | 30 | def read_data(self, classnames): 31 | image_dir = self.image_dir 32 | folders = listdir_nohidden(image_dir, sort=True) 33 | items = [] 34 | 35 | for label, folder in enumerate(folders): 36 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 37 | classname = classnames[folder] 38 | for imname in imnames: 39 | impath = os.path.join(image_dir, folder, imname) 40 | item = Datum(impath=impath, label=label, classname=classname) 41 | items.append(item) 42 | 43 | return items 44 | -------------------------------------------------------------------------------- /datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetV2(DatasetBase): 11 | """ImageNetV2. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenetv2" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | image_dir = "imagenetv2-matched-frequency-format-val" 22 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 23 | 24 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 25 | classnames = ImageNet.read_classnames(text_file) 26 | 27 | data = self.read_data(classnames) 28 | 29 | super().__init__(train_x=data, test=data) 30 | 31 | def read_data(self, classnames): 32 | image_dir = self.image_dir 33 | folders = list(classnames.keys()) 34 | items = [] 35 | 36 | for label in range(1000): 37 | class_dir = os.path.join(image_dir, str(label)) 38 | imnames = listdir_nohidden(class_dir) 39 | folder = folders[label] 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(class_dir, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from scipy.io import loadmat 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, mkdir_if_missing 9 | 10 | from .oxford_pets import OxfordPets 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class OxfordFlowers(DatasetBase): 15 | 16 | dataset_dir = "oxford_flowers" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 23 | 
self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = self.read_data() 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self): 58 | tracker = defaultdict(list) 59 | label_file = loadmat(self.label_file)["labels"][0] 60 | for i, label in enumerate(label_file): 61 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 62 | impath = os.path.join(self.image_dir, imname) 63 | label = int(label) 64 | tracker[label].append(impath) 65 | 66 | print("Splitting data into 50% train, 20% val, and 30% test") 67 | 68 | def _collate(ims, y, c): 69 | items = [] 70 | for im in ims: 
71 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label 72 | items.append(item) 73 | return items 74 | 75 | lab2cname = read_json(self.lab2cname_file) 76 | train, val, test = [], [], [] 77 | for label, impaths in tracker.items(): 78 | random.shuffle(impaths) 79 | n_total = len(impaths) 80 | n_train = round(n_total * 0.5) 81 | n_val = round(n_total * 0.2) 82 | n_test = n_total - n_train - n_val 83 | assert n_train > 0 and n_val > 0 and n_test > 0 84 | cname = lab2cname[str(label)] 85 | train.extend(_collate(impaths[:n_train], label, cname)) 86 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname)) 87 | test.extend(_collate(impaths[n_train + n_val :], label, cname)) 88 | 89 | return train, val, test 90 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from scipy.io import loadmat 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class StanfordCars(DatasetBase): 13 | 14 | dataset_dir = "stanford_cars" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 25 | else: 26 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 27 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 28 | meta_file = 
os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 29 | trainval = self.read_data("cars_train", trainval_file, meta_file) 30 | test = self.read_data("cars_test", test_file, meta_file) 31 | train, val = OxfordPets.split_trainval(trainval) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self, image_dir, anno_file, meta_file): 58 | anno_file = loadmat(anno_file)["annotations"][0] 59 | meta_file = loadmat(meta_file)["class_names"][0] 60 | items = [] 61 | 62 | for i in range(len(anno_file)): 63 | imname = anno_file[i]["fname"][0] 64 | impath = os.path.join(self.dataset_dir, image_dir, imname) 65 | label = anno_file[i]["class"][0, 0] 66 | label = int(label) - 1 # convert to 0-based index 67 | classname = meta_file[label][0] 68 | names = classname.split(" ") 69 | year = names.pop(-1) 70 | names.insert(0, year) 71 | classname = " ".join(names) 72 | item = Datum(impath=impath, label=label, classname=classname) 73 | items.append(item) 74 | 
75 | return items 76 | -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = "sun397" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 25 | else: 26 | classnames = [] 27 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | line = line.strip()[1:] # remove / 31 | classnames.append(line) 32 | cname2lab = {c: i for i, c in enumerate(classnames)} 33 | trainval = self.read_data(cname2lab, "Training_01.txt") 34 | test = self.read_data(cname2lab, "Testing_01.txt") 35 | train, val = OxfordPets.split_trainval(trainval) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = 
data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | 61 | def read_data(self, cname2lab, text_file): 62 | text_file = os.path.join(self.dataset_dir, text_file) 63 | items = [] 64 | 65 | with open(text_file, "r") as f: 66 | lines = f.readlines() 67 | for line in lines: 68 | imname = line.strip()[1:] # remove / 69 | classname = os.path.dirname(imname) 70 | label = cname2lab[classname] 71 | impath = os.path.join(self.image_dir, imname) 72 | 73 | names = classname.split("/")[1:] # remove 1st letter 74 | names = names[::-1] # put words like indoor/outdoor at first 75 | classname = " ".join(names) 76 | 77 | item = Datum(impath=impath, label=label, classname=classname) 78 | items.append(item) 79 | 80 | return items 81 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class UCF101(DatasetBase): 13 | 14 | dataset_dir = "ucf101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, 
"UCF-101-midframes") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | cname2lab = {} 28 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 29 | with open(filepath, "r") as f: 30 | lines = f.readlines() 31 | for line in lines: 32 | label, classname = line.strip().split(" ") 33 | label = int(label) - 1 # conver to 0-based index 34 | cname2lab[classname] = label 35 | 36 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 37 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 38 | train, val = OxfordPets.split_trainval(trainval) 39 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train, val = data["train"], data["val"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 54 | data = {"train": train, "val": val} 55 | print(f"Saving preprocessed few-shot data to {preprocessed}") 56 | with open(preprocessed, "wb") as file: 57 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 58 | 59 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 60 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 61 | 62 | super().__init__(train_x=train, val=val, test=test) 63 | 64 | def read_data(self, 
cname2lab, text_file): 65 | text_file = os.path.join(self.dataset_dir, text_file) 66 | items = [] 67 | 68 | with open(text_file, "r") as f: 69 | lines = f.readlines() 70 | for line in lines: 71 | line = line.strip().split(" ")[0] # trainlist: filename, label 72 | action, filename = line.split("/") 73 | label = cname2lab[action] 74 | 75 | elements = re.findall("[A-Z][^A-Z]*", action) 76 | renamed_action = "_".join(elements) 77 | 78 | filename = filename.replace(".avi", ".jpg") 79 | impath = os.path.join(self.image_dir, renamed_action, filename) 80 | 81 | item = Datum(impath=impath, label=label, classname=renamed_action) 82 | items.append(item) 83 | 84 | return items 85 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/path/to/datasets 5 | TRAINER=DAPT 6 | CFG=vit_b16_ep50 7 | SHOTS=16 8 | 9 | DEVICE=$1 10 | 11 | for DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r 12 | do 13 | for SEED in 1 2 3 14 | do 15 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 16 | if [ -d "$DIR" ]; then 17 | echo "Oops! 
The results exist at ${DIR} (so skip this job)" 18 | else 19 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 20 | python train.py \ 21 | --root ${DATA} \ 22 | --seed ${SEED} \ 23 | --trainer ${TRAINER} \ 24 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 25 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 26 | --output-dir ${DIR} \ 27 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 28 | --load-epoch 50 \ 29 | --eval-only 30 | fi 31 | done 32 | done -------------------------------------------------------------------------------- /scripts/gen_prototype.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/path/to/datasets 5 | TRAINER=DAPT 6 | 7 | DEVICE=$1 8 | 9 | for DATASET in caltech101 dtd eurosat fgvc_aircraft food101 imagenet oxford_flowers oxford_pets stanford_cars sun397 ucf101 10 | do 11 | for SHOTS in 1 2 4 8 16 12 | do 13 | if [ ${DATASET} == "imagenet" ]; then 14 | CFG=vit_b16_ep50 15 | elif [ ${SHOTS} -eq 1 ]; then 16 | CFG=vit_b16_ep50 17 | elif [ ${SHOTS} -eq 2 ] || [ ${SHOTS} -eq 4 ]; then 18 | CFG=vit_b16_ep100 19 | elif [ ${SHOTS} -eq 8 ] || [ ${SHOTS} -eq 16 ]; then 20 | CFG=vit_b16 21 | fi 22 | 23 | for SEED in 1 2 3 24 | do 25 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 26 | python train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 32 | DATASET.NUM_SHOTS ${SHOTS} \ 33 | TRAINER.DAPT.PROTOTYPE_GEN True 34 | done 35 | done 36 | done -------------------------------------------------------------------------------- /scripts/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # custom config 4 | DATA=/path/to/datasets 5 | TRAINER=DAPT 6 | 7 | DATASET=$1 8 | SHOTS=$2 9 | DEVICE=$3 10 | 11 | for SEED in 1 2 3 12 | do 13 | if [ 
${DATASET} == "imagenet" ]; then 14 | CFG=vit_b16_ep50 15 | elif [ ${SHOTS} -eq 1 ]; then 16 | CFG=vit_b16_ep50 17 | elif [ ${SHOTS} -eq 2 ] || [ ${SHOTS} -eq 4 ]; then 18 | CFG=vit_b16_ep100 19 | elif [ ${SHOTS} -eq 8 ] || [ ${SHOTS} -eq 16 ]; then 20 | CFG=vit_b16 21 | fi 22 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 23 | if [ -d "$DIR" ]; then 24 | echo "Oops! The results exist at ${DIR} (so skip this job)" 25 | else 26 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 27 | python train.py \ 28 | --root ${DATA} \ 29 | --seed ${SEED} \ 30 | --trainer ${TRAINER} \ 31 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 32 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 33 | --output-dir ${DIR} \ 34 | DATASET.NUM_SHOTS ${SHOTS} 35 | fi 36 | done -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/DAPT/cacb311eee17acb0781e9c0aec628b57c112cd5e/trainers/__init__.py --------------------------------------------------------------------------------