├── .gitattributes ├── .gitignore ├── .gitkeep ├── CPL ├── COCOOP.md ├── COOP.md ├── DATASETS.md ├── Dassl.pytorch │ ├── .flake8 │ ├── .gitignore │ ├── .isort.cfg │ ├── .style.yapf │ ├── DATASETS.md │ ├── LICENSE │ ├── dassl │ │ ├── __init__.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ └── defaults.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── data_manager.py │ │ │ ├── datasets │ │ │ │ ├── __init__.py │ │ │ │ ├── base_dataset.py │ │ │ │ ├── build.py │ │ │ │ ├── da │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cifarstl.py │ │ │ │ │ ├── digit5.py │ │ │ │ │ ├── domainnet.py │ │ │ │ │ ├── mini_domainnet.py │ │ │ │ │ ├── office31.py │ │ │ │ │ ├── office_home.py │ │ │ │ │ └── visda17.py │ │ │ │ ├── dg │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cifar_c.py │ │ │ │ │ ├── digit_single.py │ │ │ │ │ ├── digits_dg.py │ │ │ │ │ ├── office_home_dg.py │ │ │ │ │ ├── pacs.py │ │ │ │ │ └── vlcs.py │ │ │ │ └── ssl │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── cifar.py │ │ │ │ │ ├── stl10.py │ │ │ │ │ └── svhn.py │ │ │ ├── samplers.py │ │ │ └── transforms │ │ │ │ ├── __init__.py │ │ │ │ ├── autoaugment.py │ │ │ │ ├── randaugment.py │ │ │ │ └── transforms.py │ │ ├── engine │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── da │ │ │ │ ├── __init__.py │ │ │ │ ├── adabn.py │ │ │ │ ├── adda.py │ │ │ │ ├── dael.py │ │ │ │ ├── dann.py │ │ │ │ ├── m3sda.py │ │ │ │ ├── mcd.py │ │ │ │ ├── mme.py │ │ │ │ ├── self_ensembling.py │ │ │ │ └── source_only.py │ │ │ ├── dg │ │ │ │ ├── __init__.py │ │ │ │ ├── crossgrad.py │ │ │ │ ├── daeldg.py │ │ │ │ ├── ddaig.py │ │ │ │ └── vanilla.py │ │ │ ├── ssl │ │ │ │ ├── __init__.py │ │ │ │ ├── entmin.py │ │ │ │ ├── fixmatch.py │ │ │ │ ├── mean_teacher.py │ │ │ │ ├── mixmatch.py │ │ │ │ └── sup_baseline.py │ │ │ └── trainer.py │ │ ├── evaluation │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── evaluator.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ ├── accuracy.py │ │ │ └── distance.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── backbone │ │ │ │ ├── __init__.py │ │ │ │ ├── alexnet.py │ │ │ │ ├── backbone.py │ │ │ │ ├── build.py │ │ │ │ ├── cnn_digit5_m3sda.py │ │ │ │ ├── cnn_digitsdg.py │ │ │ │ ├── cnn_digitsingle.py │ │ │ │ ├── efficientnet │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── model.py │ │ │ │ │ └── utils.py │ │ │ │ ├── mobilenetv2.py │ │ │ │ ├── preact_resnet18.py │ │ │ │ ├── resnet.py │ │ │ │ ├── shufflenetv2.py │ │ │ │ ├── vgg.py │ │ │ │ └── wide_resnet.py │ │ │ ├── head │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ └── mlp.py │ │ │ ├── network │ │ │ │ ├── __init__.py │ │ │ │ ├── build.py │ │ │ │ └── ddaig_fcn.py │ │ │ └── ops │ │ │ │ ├── __init__.py │ │ │ │ ├── cross_entropy.py │ │ │ │ ├── dsbn.py │ │ │ │ ├── efdmix.py │ │ │ │ ├── mixstyle.py │ │ │ │ ├── mixup.py │ │ │ │ ├── mmd.py │ │ │ │ ├── optimal_transport.py │ │ │ │ ├── reverse_grad.py │ │ │ │ ├── sequential2.py │ │ │ │ ├── transnorm.py │ │ │ │ └── utils.py │ │ ├── optim │ │ │ ├── __init__.py │ │ │ ├── lr_scheduler.py │ │ │ ├── optimizer.py │ │ │ └── radam.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── logger.py │ │ │ ├── meters.py │ │ │ ├── registry.py │ │ │ ├── tools.py │ │ │ └── torchtools.py │ ├── datasets │ │ ├── da │ │ │ ├── cifar_stl.py │ │ │ └── digit5.py │ │ ├── dg │ │ │ └── cifar_c.py │ │ └── ssl │ │ │ ├── cifar10_cifar100_svhn.py │ │ │ └── stl10.py │ ├── setup.py │ └── tools │ │ ├── __init__.py │ │ ├── parse_test_res.py │ │ ├── replace_text.py │ │ └── train.py ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py ├── dataset_processing │ ├── 
README.md │ ├── VQA │ │ ├── README.md │ │ ├── T5_filtered_prompt.json │ │ ├── all.json │ │ ├── answer_vocab.json │ │ ├── lib │ │ │ ├── rule.py │ │ │ ├── vqaEvaluation │ │ │ │ ├── __init__.py │ │ │ │ └── vqaEval.py │ │ │ └── vqaTools │ │ │ │ ├── __init__.py │ │ │ │ ├── vqa.py │ │ │ │ └── vqa_dataloader.py │ │ ├── number_category.json │ │ ├── other_category.json │ │ ├── split.py │ │ └── yesno_category.json │ ├── classification │ │ └── README.md │ ├── flickr30k │ │ ├── README.md │ │ ├── dataset_flickr30k.json │ │ └── split.py │ ├── flickr8k │ │ ├── README.md │ │ └── split.py │ └── mscoco │ │ ├── README.md │ │ └── split.py ├── datasets │ ├── __init__.py │ ├── caltech101.py │ ├── dtd.py │ ├── eurosat.py │ ├── fgvc_aircraft.py │ ├── flickr30k.py │ ├── flickr8k.py │ ├── food101.py │ ├── imagenet.py │ ├── imagenet_a.py │ ├── imagenet_r.py │ ├── imagenet_sketch.py │ ├── imagenetv2.py │ ├── mscoco.py │ ├── oxford_flowers.py │ ├── oxford_pets.py │ ├── stanford_cars.py │ ├── sun397.py │ ├── ucf101.py │ └── vqav2.py ├── interpret_prompt.py ├── lib │ ├── rule.py │ ├── vqaEvaluation │ │ ├── __init__.py │ │ └── vqaEval.py │ └── vqaTools │ │ ├── __init__.py │ │ ├── vqa.py │ │ └── vqa_dataloader.py ├── linear_probing │ ├── README.md │ ├── feat_extractor.py │ └── linear_probe.py ├── losses │ └── losses.py ├── parse_test_res.py ├── prompt │ ├── infilled_template │ │ ├── number.json │ │ ├── other.json │ │ └── yesno.json │ ├── number_template.json │ ├── other_T5filtered_answers.json │ ├── other_template.json │ └── yesno_template.json ├── requirements.txt ├── scripts │ ├── .vscode-upload.json │ ├── cocoop │ │ ├── .bash_logout │ │ └── .bashrc │ └── zsclip │ │ ├── .bash_logout │ │ └── .bashrc ├── train.py ├── train_cf.py ├── trainers │ ├── __init__.py │ ├── cocoop.py │ ├── cocoopcf.py │ ├── coop.py │ ├── imagenet_templates.py │ ├── prompters.py │ └── zsclip.py └── utils │ ├── __init__.py │ ├── _utils │ ├── vqa_clip_inference.py │ └── vqa_collect_Eval.py │ ├── comm.py │ ├── configs.py │ ├── loader.py │ ├── model.py │ ├── solver.py │ ├── utils.py │ └── visualization.py ├── LICENSE ├── README.md └── assets ├── .gitkeep └── motivation.png /.gitattributes: -------------------------------------------------------------------------------- 1 | *.json filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.pth 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Custom 132 | output/ 133 | */output/ 134 | debug.sh 135 | configs/ 136 | *.sh 137 | *.png 138 | *.jpg 139 | *.pth 140 | *.npy 141 | *.checkpoint 142 | */temp/* 143 | -------------------------------------------------------------------------------- /.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/.gitkeep -------------------------------------------------------------------------------- /CPL/COCOOP.md: -------------------------------------------------------------------------------- 1 | ## How to Run 2 | 3 | The running scripts are provided in `scripts/cocoop/`. Make sure you change the path in `DATA` and run the commands under `scripts/cocoop/`. 4 | 5 | ### Generalization From Base to New Classes 6 | 7 | This corresponds to the experiments in Section 4.1, i.e., Table 1. 8 | 9 | You will need both `scripts/cocoop/base2new_train.sh` and `scripts/cocoop/base2new_test.sh`. The former trains a model on base classes while the latter evaluates the trained model on new classes. Both scripts have two input arguments, i.e., `DATASET` and `SEED`. 10 | 11 | `DATASET` takes as input a dataset name, like `imagenet` or `caltech101`. The valid names are the files' names in `configs/datasets/`. 12 | 13 | Below we provide an example of how to train and evaluate the model on ImageNet.
14 | 15 | ```bash 16 | # seed=1 17 | bash base2new_train.sh imagenet 1 18 | bash base2new_test.sh imagenet 1 19 | 20 | # seed=2 21 | bash base2new_train.sh imagenet 2 22 | bash base2new_test.sh imagenet 2 23 | 24 | # seed=3 25 | bash base2new_train.sh imagenet 3 26 | bash base2new_test.sh imagenet 3 27 | ``` 28 | 29 | When the evaluation is done, you can use `parse_test_res.py` to automatically calculate the average results. For instance, after you finish the evaluation (including `base2new_train.sh` and `base2new_test.sh`) on ImageNet using the aforementioned commands, you would get 30 | 31 | ``` 32 | output 33 | |–– base2new/ 34 | | |–– test_new/ 35 | | | |–– imagenet/ 36 | | | | |–– shots_16/ 37 | | | | | |–– CoCoOp/ 38 | | | | | | |–– vit_b16_c4_ep10_batch1_ctxv1/ 39 | | | | | | | |–– seed1/ 40 | | | | | | | |–– seed2/ 41 | | | | | | | |–– seed3/ 42 | | |–– train_base/ 43 | | | |–– imagenet/ 44 | | | | |–– shots_16/ 45 | | | | | |–– CoCoOp/ 46 | | | | | | |–– vit_b16_c4_ep10_batch1_ctxv1/ 47 | | | | | | | |–– seed1/ 48 | | | | | | | |–– seed2/ 49 | | | | | | | |–– seed3/ 50 | ``` 51 | 52 | Then, to get the average performance on the base classes, run 53 | 54 | ```bash 55 | python parse_test_res.py output/base2new/train_base/imagenet/shots_16/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1 56 | ``` 57 | 58 | To get the average performance on the new classes, run 59 | 60 | ```bash 61 | python parse_test_res.py output/base2new/test_new/imagenet/shots_16/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1 --test-log 62 | ``` 63 | 64 | ### Cross-Dataset Transfer 65 | 66 | This corresponds to the experiments in Section 4.2, i.e., Table 2. 67 | 68 | The relevant scripts are `scripts/cocoop/xd_train.sh` and `scripts/cocoop/xd_test.sh` where the `DATASET` variable is set to the default, namely `imagenet`. To train the model, run 69 | 70 | ```bash 71 | # seed=1 72 | bash xd_train.sh 1 73 | 74 | # seed=2 75 | bash xd_train.sh 2 76 | 77 | # seed=3 78 | bash xd_train.sh 3 79 | ``` 80 | 81 | Then, you evaluate the model on other datasets, e.g., 82 | 83 | ```bash 84 | for SEED in 1 2 3 85 | do 86 | bash xd_test.sh caltech101 ${SEED} 87 | bash xd_test.sh oxford_pets ${SEED} 88 | bash xd_test.sh stanford_cars ${SEED} 89 | done 90 | ``` 91 | 92 | ### Domain Generalization 93 | 94 | This corresponds to the experiments in Section 4.3, i.e., Table 3. 95 | 96 | The steps are similar to those discussed in "Cross-Dataset Transfer" except you evaluate the model on the variants of ImageNet, i.e., `imagenetv2`, `imagenet_sketch`, `imagenet_a` and `imagenet_r`.
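97 | 98 | For instance, a sketch of the evaluation loop (reusing `xd_test.sh` exactly as in "Cross-Dataset Transfer"; the dataset names are the config file names listed above) would be 99 | 100 | ```bash 101 | for SEED in 1 2 3 102 | do 103 | bash xd_test.sh imagenetv2 ${SEED} 104 | bash xd_test.sh imagenet_sketch ${SEED} 105 | bash xd_test.sh imagenet_a ${SEED} 106 | bash xd_test.sh imagenet_r ${SEED} 107 | done 108 | ```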
-------------------------------------------------------------------------------- /CPL/Dassl.pytorch/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # At least two spaces before inline comment 4 | E261, 5 | # Line lengths are recommended to be no greater than 79 characters 6 | E501, 7 | # Missing whitespace around arithmetic operator 8 | E226, 9 | # Blank line contains whitespace 10 | W293, 11 | # Do not use bare 'except' 12 | E722, 13 | # Line break after binary operator 14 | W504, 15 | # Too many leading '#' for block comment 16 | E266, 17 | # line break before binary operator 18 | W503, 19 | # continuation line over-indented for hanging indent 20 | E126 21 | max-line-length = 79 22 | exclude = __init__.py, build -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # OS X 132 | .DS_Store 133 | .Spotlight-V100 134 | .Trashes 135 | ._* 136 | 137 | # This project 138 | output/ 139 | debug.sh 140 | debug.py 141 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/.isort.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=79 3 | multi_line_output=6 4 | length_sort=true 5 | known_standard_library=numpy,setuptools 6 | known_myself=dassl 7 | known_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs,scipy,gdown 8 | no_lines_before=STDLIB,THIRDPARTY 9 | sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER 10 | default_section=FIRSTPARTY -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | BASED_ON_STYLE = pep8 3 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 4 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true 5 | DEDENT_CLOSING_BRACKETS = true 6 | SPACES_BEFORE_COMMENT = 2 7 | ARITHMETIC_PRECEDENCE_INDICATION = true -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kaiyang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dassl 3 | ------ 4 | PyTorch toolbox for domain adaptation and semi-supervised learning. 
5 | 6 | URL: https://github.com/KaiyangZhou/Dassl.pytorch 7 | 8 | @article{zhou2020domain, 9 | title={Domain Adaptive Ensemble Learning}, 10 | author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao}, 11 | journal={arXiv preprint arXiv:2003.07325}, 12 | year={2020} 13 | } 14 | """ 15 | 16 | __version__ = "0.5.0" 17 | __author__ = "Kaiyang Zhou" 18 | __homepage__ = "https://kaiyangzhou.github.io/" 19 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg_default 2 | 3 | 4 | def get_cfg_default(): 5 | return cfg_default.clone() 6 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_manager import DataManager, DatasetWrapper 2 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import DATASET_REGISTRY, build_dataset # isort:skip 2 | from .base_dataset import Datum, DatasetBase # isort:skip 3 | 4 | from .da import * 5 | from .dg import * 6 | from .ssl import * 7 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | DATASET_REGISTRY = Registry("DATASET") 4 | 5 | 6 | def build_dataset(cfg): 7 | avai_datasets = DATASET_REGISTRY.registered_names() 8 | check_availability(cfg.DATASET.NAME, avai_datasets) 9 | if cfg.VERBOSE: 10 | print("Loading dataset: {}".format(cfg.DATASET.NAME)) 11 | return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg) 12 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/da/__init__.py: -------------------------------------------------------------------------------- 1 | from .digit5 import Digit5 2 | from .visda17 import VisDA17 3 | from .cifarstl import CIFARSTL 4 | from .office31 import Office31 5 | from .domainnet import DomainNet 6 | from .office_home import OfficeHome 7 | from .mini_domainnet import miniDomainNet 8 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/da/cifarstl.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class CIFARSTL(DatasetBase): 11 | """CIFAR-10 and STL-10. 12 | 13 | CIFAR-10: 14 | - 60,000 32x32 colour images. 15 | - 10 classes, with 6,000 images per class. 16 | - 50,000 training images and 10,000 test images. 17 | - URL: https://www.cs.toronto.edu/~kriz/cifar.html. 18 | 19 | STL-10: 20 | - 10 classes: airplane, bird, car, cat, deer, dog, horse, 21 | monkey, ship, truck. 22 | - Images are 96x96 pixels, color. 23 | - 500 training images (10 pre-defined folds), 800 test images 24 | per class. 25 | - URL: https://cs.stanford.edu/~acoates/stl10/. 26 | 27 | Reference: 28 | - Krizhevsky. 
Learning Multiple Layers of Features 29 | from Tiny Images. Tech report. 30 | - Coates et al. An Analysis of Single Layer Networks in 31 | Unsupervised Feature Learning. AISTATS 2011. 32 | """ 33 | 34 | dataset_dir = "cifar_stl" 35 | domains = ["cifar", "stl"] 36 | 37 | def __init__(self, cfg): 38 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 39 | self.dataset_dir = osp.join(root, self.dataset_dir) 40 | 41 | self.check_input_domains( 42 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 43 | ) 44 | 45 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 46 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 47 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 48 | 49 | super().__init__(train_x=train_x, train_u=train_u, test=test) 50 | 51 | def _read_data(self, input_domains, split="train"): 52 | items = [] 53 | 54 | for domain, dname in enumerate(input_domains): 55 | data_dir = osp.join(self.dataset_dir, dname, split) 56 | class_names = listdir_nohidden(data_dir) 57 | 58 | for class_name in class_names: 59 | class_dir = osp.join(data_dir, class_name) 60 | imnames = listdir_nohidden(class_dir) 61 | label = int(class_name.split("_")[0]) 62 | 63 | for imname in imnames: 64 | impath = osp.join(class_dir, imname) 65 | item = Datum(impath=impath, label=label, domain=domain) 66 | items.append(item) 67 | 68 | return items 69 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/da/domainnet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class DomainNet(DatasetBase): 9 | """DomainNet. 10 | 11 | Statistics: 12 | - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw, 13 | Real, Sketch. 14 | - Around 0.6M images. 15 | - 345 categories. 16 | - URL: http://ai.bu.edu/M3SDA/. 17 | 18 | Special note: the t-shirt class (327) is missing in painting_train.txt. 19 | 20 | Reference: 21 | - Peng et al. Moment Matching for Multi-Source Domain 22 | Adaptation. ICCV 2019. 
23 | """ 24 | 25 | dataset_dir = "domainnet" 26 | domains = [ 27 | "clipart", "infograph", "painting", "quickdraw", "real", "sketch" 28 | ] 29 | 30 | def __init__(self, cfg): 31 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 32 | self.dataset_dir = osp.join(root, self.dataset_dir) 33 | self.split_dir = osp.join(self.dataset_dir, "splits") 34 | 35 | self.check_input_domains( 36 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 37 | ) 38 | 39 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 40 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 41 | val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test") 42 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 43 | 44 | super().__init__(train_x=train_x, train_u=train_u, val=val, test=test) 45 | 46 | def _read_data(self, input_domains, split="train"): 47 | items = [] 48 | 49 | for domain, dname in enumerate(input_domains): 50 | filename = dname + "_" + split + ".txt" 51 | split_file = osp.join(self.split_dir, filename) 52 | 53 | with open(split_file, "r") as f: 54 | lines = f.readlines() 55 | for line in lines: 56 | line = line.strip() 57 | impath, label = line.split(" ") 58 | classname = impath.split("/")[1] 59 | impath = osp.join(self.dataset_dir, impath) 60 | label = int(label) 61 | item = Datum( 62 | impath=impath, 63 | label=label, 64 | domain=domain, 65 | classname=classname 66 | ) 67 | items.append(item) 68 | 69 | return items 70 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/da/mini_domainnet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class miniDomainNet(DatasetBase): 9 | """A subset of DomainNet. 10 | 11 | Reference: 12 | - Peng et al. Moment Matching for Multi-Source Domain 13 | Adaptation. ICCV 2019. 14 | - Zhou et al. Domain Adaptive Ensemble Learning. 
15 | """ 16 | 17 | dataset_dir = "domainnet" 18 | domains = ["clipart", "painting", "real", "sketch"] 19 | 20 | def __init__(self, cfg): 21 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = osp.join(root, self.dataset_dir) 23 | self.split_dir = osp.join(self.dataset_dir, "splits_mini") 24 | 25 | self.check_input_domains( 26 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 27 | ) 28 | 29 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") 30 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") 31 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") 32 | 33 | super().__init__(train_x=train_x, train_u=train_u, test=test) 34 | 35 | def _read_data(self, input_domains, split="train"): 36 | items = [] 37 | 38 | for domain, dname in enumerate(input_domains): 39 | filename = dname + "_" + split + ".txt" 40 | split_file = osp.join(self.split_dir, filename) 41 | 42 | with open(split_file, "r") as f: 43 | lines = f.readlines() 44 | for line in lines: 45 | line = line.strip() 46 | impath, label = line.split(" ") 47 | classname = impath.split("/")[1] 48 | impath = osp.join(self.dataset_dir, impath) 49 | label = int(label) 50 | item = Datum( 51 | impath=impath, 52 | label=label, 53 | domain=domain, 54 | classname=classname 55 | ) 56 | items.append(item) 57 | 58 | return items 59 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/da/office31.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class Office31(DatasetBase): 11 | """Office-31. 12 | 13 | Statistics: 14 | - 4,110 images. 15 | - 31 classes related to office objects. 16 | - 3 domains: Amazon, Webcam, Dslr. 17 | - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/. 18 | 19 | Reference: 20 | - Saenko et al. Adapting visual category models to 21 | new domains. ECCV 2010. 
22 | """ 23 | 24 | dataset_dir = "office31" 25 | domains = ["amazon", "webcam", "dslr"] 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | self.check_input_domains( 32 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 33 | ) 34 | 35 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) 36 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) 37 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS) 38 | 39 | super().__init__(train_x=train_x, train_u=train_u, test=test) 40 | 41 | def _read_data(self, input_domains): 42 | items = [] 43 | 44 | for domain, dname in enumerate(input_domains): 45 | domain_dir = osp.join(self.dataset_dir, dname) 46 | class_names = listdir_nohidden(domain_dir) 47 | class_names.sort() 48 | 49 | for label, class_name in enumerate(class_names): 50 | class_path = osp.join(domain_dir, class_name) 51 | imnames = listdir_nohidden(class_path) 52 | 53 | for imname in imnames: 54 | impath = osp.join(class_path, imname) 55 | item = Datum( 56 | impath=impath, 57 | label=label, 58 | domain=domain, 59 | classname=class_name 60 | ) 61 | items.append(item) 62 | 63 | return items 64 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/da/office_home.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from dassl.utils import listdir_nohidden 4 | 5 | from ..build import DATASET_REGISTRY 6 | from ..base_dataset import Datum, DatasetBase 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class OfficeHome(DatasetBase): 11 | """Office-Home. 12 | 13 | Statistics: 14 | - Around 15,500 images. 15 | - 65 classes related to office and home objects. 16 | - 4 domains: Art, Clipart, Product, Real World. 17 | - URL: http://hemanthdv.org/OfficeHome-Dataset/. 18 | 19 | Reference: 20 | - Venkateswara et al. Deep Hashing Network for Unsupervised 21 | Domain Adaptation. CVPR 2017. 
22 | """ 23 | 24 | dataset_dir = "office_home" 25 | domains = ["art", "clipart", "product", "real_world"] 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | self.check_input_domains( 32 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 33 | ) 34 | 35 | train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS) 36 | train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS) 37 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS) 38 | 39 | super().__init__(train_x=train_x, train_u=train_u, test=test) 40 | 41 | def _read_data(self, input_domains): 42 | items = [] 43 | 44 | for domain, dname in enumerate(input_domains): 45 | domain_dir = osp.join(self.dataset_dir, dname) 46 | class_names = listdir_nohidden(domain_dir) 47 | class_names.sort() 48 | 49 | for label, class_name in enumerate(class_names): 50 | class_path = osp.join(domain_dir, class_name) 51 | imnames = listdir_nohidden(class_path) 52 | 53 | for imname in imnames: 54 | impath = osp.join(class_path, imname) 55 | item = Datum( 56 | impath=impath, 57 | label=label, 58 | domain=domain, 59 | classname=class_name.lower(), 60 | ) 61 | items.append(item) 62 | 63 | return items 64 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/da/visda17.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class VisDA17(DatasetBase): 9 | """VisDA17. 10 | 11 | Focusing on simulation-to-reality domain shift. 12 | 13 | URL: http://ai.bu.edu/visda-2017/. 14 | 15 | Reference: 16 | - Peng et al. VisDA: The Visual Domain Adaptation 17 | Challenge. ArXiv 2017. 
18 | """ 19 | 20 | dataset_dir = "visda17" 21 | domains = ["synthetic", "real"] 22 | 23 | def __init__(self, cfg): 24 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 25 | self.dataset_dir = osp.join(root, self.dataset_dir) 26 | 27 | self.check_input_domains( 28 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 29 | ) 30 | 31 | train_x = self._read_data("synthetic") 32 | train_u = self._read_data("real") 33 | test = self._read_data("real") 34 | 35 | super().__init__(train_x=train_x, train_u=train_u, test=test) 36 | 37 | def _read_data(self, dname): 38 | filedir = "train" if dname == "synthetic" else "validation" 39 | image_list = osp.join(self.dataset_dir, filedir, "image_list.txt") 40 | items = [] 41 | # There is only one source domain 42 | domain = 0 43 | 44 | with open(image_list, "r") as f: 45 | lines = f.readlines() 46 | 47 | for line in lines: 48 | line = line.strip() 49 | impath, label = line.split(" ") 50 | classname = impath.split("/")[0] 51 | impath = osp.join(self.dataset_dir, filedir, impath) 52 | label = int(label) 53 | item = Datum( 54 | impath=impath, 55 | label=label, 56 | domain=domain, 57 | classname=classname 58 | ) 59 | items.append(item) 60 | 61 | return items 62 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/dg/__init__.py: -------------------------------------------------------------------------------- 1 | from .pacs import PACS 2 | from .vlcs import VLCS 3 | from .cifar_c import CIFAR10C, CIFAR100C 4 | from .digits_dg import DigitsDG 5 | from .digit_single import DigitSingle 6 | from .office_home_dg import OfficeHomeDG 7 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/dg/digits_dg.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | 4 | from dassl.utils import listdir_nohidden 5 | 6 | from ..build import DATASET_REGISTRY 7 | from ..base_dataset import Datum, DatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class DigitsDG(DatasetBase): 12 | """Digits-DG. 13 | 14 | It contains 4 digit datasets: 15 | - MNIST: hand-written digits. 16 | - MNIST-M: variant of MNIST with blended background. 17 | - SVHN: street view house number. 18 | - SYN: synthetic digits. 19 | 20 | Reference: 21 | - Lecun et al. Gradient-based learning applied to document 22 | recognition. IEEE 1998. 23 | - Ganin et al. Domain-adversarial training of neural networks. 24 | JMLR 2016. 25 | - Netzer et al. Reading digits in natural images with unsupervised 26 | feature learning. NIPS-W 2011. 27 | - Zhou et al. Deep Domain-Adversarial Image Generation for Domain 28 | Generalisation. AAAI 2020. 
29 | """ 30 | 31 | dataset_dir = "digits_dg" 32 | domains = ["mnist", "mnist_m", "svhn", "syn"] 33 | data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7" 34 | 35 | def __init__(self, cfg): 36 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 37 | self.dataset_dir = osp.join(root, self.dataset_dir) 38 | 39 | if not osp.exists(self.dataset_dir): 40 | dst = osp.join(root, "digits_dg.zip") 41 | self.download_data(self.data_url, dst, from_gdrive=True) 42 | 43 | self.check_input_domains( 44 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 45 | ) 46 | 47 | train = self.read_data( 48 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train" 49 | ) 50 | val = self.read_data( 51 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val" 52 | ) 53 | test = self.read_data( 54 | self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all" 55 | ) 56 | 57 | super().__init__(train_x=train, val=val, test=test) 58 | 59 | @staticmethod 60 | def read_data(dataset_dir, input_domains, split): 61 | 62 | def _load_data_from_directory(directory): 63 | folders = listdir_nohidden(directory) 64 | folders.sort() 65 | items_ = [] 66 | 67 | for label, folder in enumerate(folders): 68 | impaths = glob.glob(osp.join(directory, folder, "*.jpg")) 69 | 70 | for impath in impaths: 71 | items_.append((impath, label)) 72 | 73 | return items_ 74 | 75 | items = [] 76 | 77 | for domain, dname in enumerate(input_domains): 78 | if split == "all": 79 | train_dir = osp.join(dataset_dir, dname, "train") 80 | impath_label_list = _load_data_from_directory(train_dir) 81 | val_dir = osp.join(dataset_dir, dname, "val") 82 | impath_label_list += _load_data_from_directory(val_dir) 83 | else: 84 | split_dir = osp.join(dataset_dir, dname, split) 85 | impath_label_list = _load_data_from_directory(split_dir) 86 | 87 | for impath, label in impath_label_list: 88 | class_name = impath.split("/")[-2].lower() 89 | item = Datum( 90 | impath=impath, 91 | label=label, 92 | domain=domain, 93 | classname=class_name 94 | ) 95 | items.append(item) 96 | 97 | return items 98 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/dg/office_home_dg.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from .digits_dg import DigitsDG 5 | from ..base_dataset import DatasetBase 6 | 7 | 8 | @DATASET_REGISTRY.register() 9 | class OfficeHomeDG(DatasetBase): 10 | """Office-Home. 11 | 12 | Statistics: 13 | - Around 15,500 images. 14 | - 65 classes related to office and home objects. 15 | - 4 domains: Art, Clipart, Product, Real World. 16 | - URL: http://hemanthdv.org/OfficeHome-Dataset/. 17 | 18 | Reference: 19 | - Venkateswara et al. Deep Hashing Network for Unsupervised 20 | Domain Adaptation. CVPR 2017. 
21 | """ 22 | 23 | dataset_dir = "office_home_dg" 24 | domains = ["art", "clipart", "product", "real_world"] 25 | data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa" 26 | 27 | def __init__(self, cfg): 28 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | 31 | if not osp.exists(self.dataset_dir): 32 | dst = osp.join(root, "office_home_dg.zip") 33 | self.download_data(self.data_url, dst, from_gdrive=True) 34 | 35 | self.check_input_domains( 36 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 37 | ) 38 | 39 | train = DigitsDG.read_data( 40 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train" 41 | ) 42 | val = DigitsDG.read_data( 43 | self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val" 44 | ) 45 | test = DigitsDG.read_data( 46 | self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all" 47 | ) 48 | 49 | super().__init__(train_x=train, val=val, test=test) 50 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/dg/pacs.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..build import DATASET_REGISTRY 4 | from ..base_dataset import Datum, DatasetBase 5 | 6 | 7 | @DATASET_REGISTRY.register() 8 | class PACS(DatasetBase): 9 | """PACS. 10 | 11 | Statistics: 12 | - 4 domains: Photo (1,670), Art (2,048), Cartoon 13 | (2,344), Sketch (3,929). 14 | - 7 categories: dog, elephant, giraffe, guitar, horse, 15 | house and person. 16 | 17 | Reference: 18 | - Li et al. Deeper, broader and artier domain generalization. 19 | ICCV 2017. 20 | """ 21 | 22 | dataset_dir = "pacs" 23 | domains = ["art_painting", "cartoon", "photo", "sketch"] 24 | data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE" 25 | # the following images contain errors and should be ignored 26 | _error_paths = ["sketch/dog/n02103406_4068-1.png"] 27 | 28 | def __init__(self, cfg): 29 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.image_dir = osp.join(self.dataset_dir, "images") 32 | self.split_dir = osp.join(self.dataset_dir, "splits") 33 | 34 | if not osp.exists(self.dataset_dir): 35 | dst = osp.join(root, "pacs.zip") 36 | self.download_data(self.data_url, dst, from_gdrive=True) 37 | 38 | self.check_input_domains( 39 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 40 | ) 41 | 42 | train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train") 43 | val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval") 44 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all") 45 | 46 | super().__init__(train_x=train, val=val, test=test) 47 | 48 | def _read_data(self, input_domains, split): 49 | items = [] 50 | 51 | for domain, dname in enumerate(input_domains): 52 | if split == "all": 53 | file_train = osp.join( 54 | self.split_dir, dname + "_train_kfold.txt" 55 | ) 56 | impath_label_list = self._read_split_pacs(file_train) 57 | file_val = osp.join( 58 | self.split_dir, dname + "_crossval_kfold.txt" 59 | ) 60 | impath_label_list += self._read_split_pacs(file_val) 61 | else: 62 | file = osp.join( 63 | self.split_dir, dname + "_" + split + "_kfold.txt" 64 | ) 65 | impath_label_list = self._read_split_pacs(file) 66 | 67 | for impath, label in impath_label_list: 68 | classname = impath.split("/")[-2] 69 | item = Datum( 70 | impath=impath, 71 | label=label, 72 | domain=domain, 73 | classname=classname 74 | ) 75 | 
items.append(item) 76 | 77 | return items 78 | 79 | def _read_split_pacs(self, split_file): 80 | items = [] 81 | 82 | with open(split_file, "r") as f: 83 | lines = f.readlines() 84 | 85 | for line in lines: 86 | line = line.strip() 87 | impath, label = line.split(" ") 88 | if impath in self._error_paths: 89 | continue 90 | impath = osp.join(self.image_dir, impath) 91 | label = int(label) - 1 92 | items.append((impath, label)) 93 | 94 | return items 95 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/dg/vlcs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | 4 | from dassl.utils import listdir_nohidden 5 | 6 | from ..build import DATASET_REGISTRY 7 | from ..base_dataset import Datum, DatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class VLCS(DatasetBase): 12 | """VLCS. 13 | 14 | Statistics: 15 | - 4 domains: CALTECH, LABELME, PASCAL, SUN 16 | - 5 categories: bird, car, chair, dog, and person. 17 | 18 | Reference: 19 | - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011. 20 | """ 21 | 22 | dataset_dir = "VLCS" 23 | domains = ["caltech", "labelme", "pascal", "sun"] 24 | data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd" 25 | 26 | def __init__(self, cfg): 27 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 28 | self.dataset_dir = osp.join(root, self.dataset_dir) 29 | 30 | if not osp.exists(self.dataset_dir): 31 | dst = osp.join(root, "vlcs.zip") 32 | self.download_data(self.data_url, dst, from_gdrive=True) 33 | 34 | self.check_input_domains( 35 | cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS 36 | ) 37 | 38 | train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train") 39 | val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval") 40 | test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test") 41 | 42 | super().__init__(train_x=train, val=val, test=test) 43 | 44 | def _read_data(self, input_domains, split): 45 | items = [] 46 | 47 | for domain, dname in enumerate(input_domains): 48 | dname = dname.upper() 49 | path = osp.join(self.dataset_dir, dname, split) 50 | folders = listdir_nohidden(path) 51 | folders.sort() 52 | 53 | for label, folder in enumerate(folders): 54 | impaths = glob.glob(osp.join(path, folder, "*.jpg")) 55 | 56 | for impath in impaths: 57 | item = Datum(impath=impath, label=label, domain=domain) 58 | items.append(item) 59 | 60 | return items 61 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/ssl/__init__.py: -------------------------------------------------------------------------------- 1 | from .svhn import SVHN 2 | from .cifar import CIFAR10, CIFAR100 3 | from .stl10 import STL10 4 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/ssl/cifar.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import os.path as osp 4 | 5 | from dassl.utils import listdir_nohidden 6 | 7 | from ..build import DATASET_REGISTRY 8 | from ..base_dataset import Datum, DatasetBase 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class CIFAR10(DatasetBase): 13 | """CIFAR10 for SSL. 14 | 15 | Reference: 16 | - Krizhevsky. Learning Multiple Layers of Features 17 | from Tiny Images. Tech report. 
18 | """ 19 | 20 | dataset_dir = "cifar10" 21 | 22 | def __init__(self, cfg): 23 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | train_dir = osp.join(self.dataset_dir, "train") 26 | test_dir = osp.join(self.dataset_dir, "test") 27 | 28 | assert cfg.DATASET.NUM_LABELED > 0 29 | 30 | train_x, train_u, val = self._read_data_train( 31 | train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT 32 | ) 33 | test = self._read_data_test(test_dir) 34 | 35 | if cfg.DATASET.ALL_AS_UNLABELED: 36 | train_u = train_u + train_x 37 | 38 | if len(val) == 0: 39 | val = None 40 | 41 | super().__init__(train_x=train_x, train_u=train_u, val=val, test=test) 42 | 43 | def _read_data_train(self, data_dir, num_labeled, val_percent): 44 | class_names = listdir_nohidden(data_dir) 45 | class_names.sort() 46 | num_labeled_per_class = num_labeled / len(class_names) 47 | items_x, items_u, items_v = [], [], [] 48 | 49 | for label, class_name in enumerate(class_names): 50 | class_dir = osp.join(data_dir, class_name) 51 | imnames = listdir_nohidden(class_dir) 52 | 53 | # Split into train and val following Oliver et al. 2018 54 | # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data 55 | num_val = math.floor(len(imnames) * val_percent) 56 | imnames_train = imnames[num_val:] 57 | imnames_val = imnames[:num_val] 58 | 59 | # Note we do shuffle after split 60 | random.shuffle(imnames_train) 61 | 62 | for i, imname in enumerate(imnames_train): 63 | impath = osp.join(class_dir, imname) 64 | item = Datum(impath=impath, label=label) 65 | 66 | if (i + 1) <= num_labeled_per_class: 67 | items_x.append(item) 68 | 69 | else: 70 | items_u.append(item) 71 | 72 | for imname in imnames_val: 73 | impath = osp.join(class_dir, imname) 74 | item = Datum(impath=impath, label=label) 75 | items_v.append(item) 76 | 77 | return items_x, items_u, items_v 78 | 79 | def _read_data_test(self, data_dir): 80 | class_names = listdir_nohidden(data_dir) 81 | class_names.sort() 82 | items = [] 83 | 84 | for label, class_name in enumerate(class_names): 85 | class_dir = osp.join(data_dir, class_name) 86 | imnames = listdir_nohidden(class_dir) 87 | 88 | for imname in imnames: 89 | impath = osp.join(class_dir, imname) 90 | item = Datum(impath=impath, label=label) 91 | items.append(item) 92 | 93 | return items 94 | 95 | 96 | @DATASET_REGISTRY.register() 97 | class CIFAR100(CIFAR10): 98 | """CIFAR100 for SSL. 99 | 100 | Reference: 101 | - Krizhevsky. Learning Multiple Layers of Features 102 | from Tiny Images. Tech report. 103 | """ 104 | 105 | dataset_dir = "cifar100" 106 | 107 | def __init__(self, cfg): 108 | super().__init__(cfg) 109 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/ssl/stl10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | 4 | from dassl.utils import listdir_nohidden 5 | 6 | from ..build import DATASET_REGISTRY 7 | from ..base_dataset import Datum, DatasetBase 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class STL10(DatasetBase): 12 | """STL-10 dataset. 13 | 14 | Description: 15 | - 10 classes: airplane, bird, car, cat, deer, dog, horse, 16 | monkey, ship, truck. 17 | - Images are 96x96 pixels, color. 18 | - 500 training images per class, 800 test images per class. 19 | - 100,000 unlabeled images for unsupervised learning. 20 | 21 | Reference: 22 | - Coates et al. 
An Analysis of Single Layer Networks in 23 | Unsupervised Feature Learning. AISTATS 2011. 24 | """ 25 | 26 | dataset_dir = "stl10" 27 | 28 | def __init__(self, cfg): 29 | root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | train_dir = osp.join(self.dataset_dir, "train") 32 | test_dir = osp.join(self.dataset_dir, "test") 33 | unlabeled_dir = osp.join(self.dataset_dir, "unlabeled") 34 | fold_file = osp.join( 35 | self.dataset_dir, "stl10_binary", "fold_indices.txt" 36 | ) 37 | 38 | # Only use the first five splits 39 | assert 0 <= cfg.DATASET.STL10_FOLD <= 4 40 | 41 | train_x = self._read_data_train( 42 | train_dir, cfg.DATASET.STL10_FOLD, fold_file 43 | ) 44 | train_u = self._read_data_all(unlabeled_dir) 45 | test = self._read_data_all(test_dir) 46 | 47 | if cfg.DATASET.ALL_AS_UNLABELED: 48 | train_u = train_u + train_x 49 | 50 | super().__init__(train_x=train_x, train_u=train_u, test=test) 51 | 52 | def _read_data_train(self, data_dir, fold, fold_file): 53 | imnames = listdir_nohidden(data_dir) 54 | imnames.sort() 55 | items = [] 56 | 57 | list_idx = list(range(len(imnames))) 58 | if fold >= 0: 59 | with open(fold_file, "r") as f: 60 | str_idx = f.read().splitlines()[fold] 61 | list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=" ") 62 | 63 | for i in list_idx: 64 | imname = imnames[i] 65 | impath = osp.join(data_dir, imname) 66 | label = osp.splitext(imname)[0].split("_")[1] 67 | label = int(label) 68 | item = Datum(impath=impath, label=label) 69 | items.append(item) 70 | 71 | return items 72 | 73 | def _read_data_all(self, data_dir): 74 | imnames = listdir_nohidden(data_dir) 75 | items = [] 76 | 77 | for imname in imnames: 78 | impath = osp.join(data_dir, imname) 79 | label = osp.splitext(imname)[0].split("_")[1] 80 | if label == "none": 81 | label = -1 82 | else: 83 | label = int(label) 84 | item = Datum(impath=impath, label=label) 85 | items.append(item) 86 | 87 | return items 88 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/datasets/ssl/svhn.py: -------------------------------------------------------------------------------- 1 | from .cifar import CIFAR10 2 | from ..build import DATASET_REGISTRY 3 | 4 | 5 | @DATASET_REGISTRY.register() 6 | class SVHN(CIFAR10): 7 | """SVHN for SSL. 8 | 9 | Reference: 10 | - Netzer et al. Reading Digits in Natural Images with 11 | Unsupervised Feature Learning. NIPS-W 2011. 
12 | """ 13 | 14 | dataset_dir = "svhn" 15 | 16 | def __init__(self, cfg): 17 | super().__init__(cfg) 18 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import build_transform 2 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import TRAINER_REGISTRY, build_trainer # isort:skip 2 | from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet # isort:skip 3 | 4 | from .da import * 5 | from .dg import * 6 | from .ssl import * 7 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | TRAINER_REGISTRY = Registry("TRAINER") 4 | 5 | 6 | def build_trainer(cfg): 7 | avai_trainers = TRAINER_REGISTRY.registered_names() 8 | print(f"Check_availability {avai_trainers}") 9 | check_availability(cfg.TRAINER.NAME, avai_trainers) 10 | if cfg.VERBOSE: 11 | print("Loading trainer: {}".format(cfg.TRAINER.NAME)) 12 | return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg) 13 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/da/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcd import MCD 2 | from .mme import MME 3 | from .adda import ADDA 4 | from .dael import DAEL 5 | from .dann import DANN 6 | from .adabn import AdaBN 7 | from .m3sda import M3SDA 8 | from .source_only import SourceOnly 9 | from .self_ensembling import SelfEnsembling 10 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/da/adabn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dassl.utils import check_isfile 4 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 5 | 6 | 7 | @TRAINER_REGISTRY.register() 8 | class AdaBN(TrainerXU): 9 | """Adaptive Batch Normalization. 10 | 11 | https://arxiv.org/abs/1603.04779. 
12 | """ 13 | 14 | def __init__(self, cfg): 15 | super().__init__(cfg) 16 | self.done_reset_bn_stats = False 17 | 18 | def check_cfg(self, cfg): 19 | assert check_isfile( 20 | cfg.MODEL.INIT_WEIGHTS 21 | ), "The weights of source model must be provided" 22 | 23 | def before_epoch(self): 24 | if not self.done_reset_bn_stats: 25 | for m in self.model.modules(): 26 | classname = m.__class__.__name__ 27 | if classname.find("BatchNorm") != -1: 28 | m.reset_running_stats() 29 | 30 | self.done_reset_bn_stats = True 31 | 32 | def forward_backward(self, batch_x, batch_u): 33 | input_u = batch_u["img"].to(self.device) 34 | 35 | with torch.no_grad(): 36 | self.model(input_u) 37 | 38 | return None 39 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/da/adda.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | 5 | from dassl.optim import build_optimizer, build_lr_scheduler 6 | from dassl.utils import check_isfile, count_num_param, open_specified_layers 7 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 8 | from dassl.modeling import build_head 9 | 10 | 11 | @TRAINER_REGISTRY.register() 12 | class ADDA(TrainerXU): 13 | """Adversarial Discriminative Domain Adaptation. 14 | 15 | https://arxiv.org/abs/1702.05464. 16 | """ 17 | 18 | def __init__(self, cfg): 19 | super().__init__(cfg) 20 | self.open_layers = ["backbone"] 21 | if isinstance(self.model.head, nn.Module): 22 | self.open_layers.append("head") 23 | 24 | self.source_model = copy.deepcopy(self.model) 25 | self.source_model.eval() 26 | for param in self.source_model.parameters(): 27 | param.requires_grad_(False) 28 | 29 | self.build_critic() 30 | 31 | self.bce = nn.BCEWithLogitsLoss() 32 | 33 | def check_cfg(self, cfg): 34 | assert check_isfile( 35 | cfg.MODEL.INIT_WEIGHTS 36 | ), "The weights of source model must be provided" 37 | 38 | def build_critic(self): 39 | cfg = self.cfg 40 | 41 | print("Building critic network") 42 | fdim = self.model.fdim 43 | critic_body = build_head( 44 | "mlp", 45 | verbose=cfg.VERBOSE, 46 | in_features=fdim, 47 | hidden_layers=[fdim, fdim // 2], 48 | activation="leaky_relu", 49 | ) 50 | self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1)) 51 | print("# params: {:,}".format(count_num_param(self.critic))) 52 | self.critic.to(self.device) 53 | self.optim_c = build_optimizer(self.critic, cfg.OPTIM) 54 | self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM) 55 | self.register_model("critic", self.critic, self.optim_c, self.sched_c) 56 | 57 | def forward_backward(self, batch_x, batch_u): 58 | open_specified_layers(self.model, self.open_layers) 59 | input_x, _, input_u = self.parse_batch_train(batch_x, batch_u) 60 | domain_x = torch.ones(input_x.shape[0], 1).to(self.device) 61 | domain_u = torch.zeros(input_u.shape[0], 1).to(self.device) 62 | 63 | _, feat_x = self.source_model(input_x, return_feature=True) 64 | _, feat_u = self.model(input_u, return_feature=True) 65 | 66 | logit_xd = self.critic(feat_x) 67 | logit_ud = self.critic(feat_u.detach()) 68 | 69 | loss_critic = self.bce(logit_xd, domain_x) 70 | loss_critic += self.bce(logit_ud, domain_u) 71 | self.model_backward_and_update(loss_critic, "critic") 72 | 73 | logit_ud = self.critic(feat_u) 74 | loss_model = self.bce(logit_ud, 1 - domain_u) 75 | self.model_backward_and_update(loss_model, "model") 76 | 77 | loss_summary = { 78 | "loss_critic": loss_critic.item(), 79 | "loss_model": 
loss_model.item(), 80 | } 81 | 82 | if (self.batch_idx + 1) == self.num_batches: 83 | self.update_lr() 84 | 85 | return loss_summary 86 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/da/dann.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from dassl.optim import build_optimizer, build_lr_scheduler 6 | from dassl.utils import count_num_param 7 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 8 | from dassl.metrics import compute_accuracy 9 | from dassl.modeling import build_head 10 | from dassl.modeling.ops import ReverseGrad 11 | 12 | 13 | @TRAINER_REGISTRY.register() 14 | class DANN(TrainerXU): 15 | """Domain-Adversarial Neural Networks. 16 | 17 | https://arxiv.org/abs/1505.07818. 18 | """ 19 | 20 | def __init__(self, cfg): 21 | super().__init__(cfg) 22 | self.build_critic() 23 | self.ce = nn.CrossEntropyLoss() 24 | self.bce = nn.BCEWithLogitsLoss() 25 | 26 | def build_critic(self): 27 | cfg = self.cfg 28 | 29 | print("Building critic network") 30 | fdim = self.model.fdim 31 | critic_body = build_head( 32 | "mlp", 33 | verbose=cfg.VERBOSE, 34 | in_features=fdim, 35 | hidden_layers=[fdim, fdim], 36 | activation="leaky_relu", 37 | ) 38 | self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1)) 39 | print("# params: {:,}".format(count_num_param(self.critic))) 40 | self.critic.to(self.device) 41 | self.optim_c = build_optimizer(self.critic, cfg.OPTIM) 42 | self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM) 43 | self.register_model("critic", self.critic, self.optim_c, self.sched_c) 44 | self.revgrad = ReverseGrad() 45 | 46 | def forward_backward(self, batch_x, batch_u): 47 | input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u) 48 | domain_x = torch.ones(input_x.shape[0], 1).to(self.device) 49 | domain_u = torch.zeros(input_u.shape[0], 1).to(self.device) 50 | 51 | global_step = self.batch_idx + self.epoch * self.num_batches 52 | progress = global_step / (self.max_epoch * self.num_batches) 53 | lmda = 2 / (1 + np.exp(-10 * progress)) - 1 54 | 55 | logit_x, feat_x = self.model(input_x, return_feature=True) 56 | _, feat_u = self.model(input_u, return_feature=True) 57 | 58 | loss_x = self.ce(logit_x, label_x) 59 | 60 | feat_x = self.revgrad(feat_x, grad_scaling=lmda) 61 | feat_u = self.revgrad(feat_u, grad_scaling=lmda) 62 | output_xd = self.critic(feat_x) 63 | output_ud = self.critic(feat_u) 64 | loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u) 65 | 66 | loss = loss_x + loss_d 67 | self.model_backward_and_update(loss) 68 | 69 | loss_summary = { 70 | "loss_x": loss_x.item(), 71 | "acc_x": compute_accuracy(logit_x, label_x)[0].item(), 72 | "loss_d": loss_d.item(), 73 | } 74 | 75 | if (self.batch_idx + 1) == self.num_batches: 76 | self.update_lr() 77 | 78 | return loss_summary 79 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/da/mme.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from dassl.optim import build_optimizer, build_lr_scheduler 6 | from dassl.utils import count_num_param 7 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 8 | from dassl.metrics import compute_accuracy 9 | from dassl.modeling.ops import ReverseGrad 10 | from dassl.engine.trainer import SimpleNet 
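# Note: `Prototypes` below is a cosine-similarity classifier. Features are
# L2-normalized before a bias-free linear layer, so each weight row acts as a
# class prototype, and dividing by the temperature `temp` sharpens the logits
# (a smaller temp gives a peakier softmax).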
11 | 12 | 13 | class Prototypes(nn.Module): 14 | 15 | def __init__(self, fdim, num_classes, temp=0.05): 16 | super().__init__() 17 | self.prototypes = nn.Linear(fdim, num_classes, bias=False) 18 | self.temp = temp 19 | 20 | def forward(self, x): 21 | x = F.normalize(x, p=2, dim=1) 22 | out = self.prototypes(x) 23 | out = out / self.temp 24 | return out 25 | 26 | 27 | @TRAINER_REGISTRY.register() 28 | class MME(TrainerXU): 29 | """Minimax Entropy. 30 | 31 | https://arxiv.org/abs/1904.06487. 32 | """ 33 | 34 | def __init__(self, cfg): 35 | super().__init__(cfg) 36 | self.lmda = cfg.TRAINER.MME.LMDA 37 | 38 | def build_model(self): 39 | cfg = self.cfg 40 | 41 | print("Building F") 42 | self.F = SimpleNet(cfg, cfg.MODEL, 0) 43 | self.F.to(self.device) 44 | print("# params: {:,}".format(count_num_param(self.F))) 45 | self.optim_F = build_optimizer(self.F, cfg.OPTIM) 46 | self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM) 47 | self.register_model("F", self.F, self.optim_F, self.sched_F) 48 | 49 | print("Building C") 50 | self.C = Prototypes(self.F.fdim, self.num_classes) 51 | self.C.to(self.device) 52 | print("# params: {:,}".format(count_num_param(self.C))) 53 | self.optim_C = build_optimizer(self.C, cfg.OPTIM) 54 | self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM) 55 | self.register_model("C", self.C, self.optim_C, self.sched_C) 56 | 57 | self.revgrad = ReverseGrad() 58 | 59 | def forward_backward(self, batch_x, batch_u): 60 | input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u) 61 | 62 | feat_x = self.F(input_x) 63 | logit_x = self.C(feat_x) 64 | loss_x = F.cross_entropy(logit_x, label_x) 65 | self.model_backward_and_update(loss_x) 66 | 67 | feat_u = self.F(input_u) 68 | feat_u = self.revgrad(feat_u) 69 | logit_u = self.C(feat_u) 70 | prob_u = F.softmax(logit_u, 1) 71 | loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean() 72 | self.model_backward_and_update(loss_u * self.lmda) 73 | 74 | loss_summary = { 75 | "loss_x": loss_x.item(), 76 | "acc_x": compute_accuracy(logit_x, label_x)[0].item(), 77 | "loss_u": loss_u.item(), 78 | } 79 | 80 | if (self.batch_idx + 1) == self.num_batches: 81 | self.update_lr() 82 | 83 | return loss_summary 84 | 85 | def model_inference(self, input): 86 | return self.C(self.F(input)) 87 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/da/self_ensembling.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from torch.nn import functional as F 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 5 | from dassl.metrics import compute_accuracy 6 | from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update 7 | 8 | 9 | @TRAINER_REGISTRY.register() 10 | class SelfEnsembling(TrainerXU): 11 | """Self-ensembling for visual domain adaptation. 12 | 13 | https://arxiv.org/abs/1706.05208. 
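    The unlabeled consistency term is the squared difference between student
    and teacher predictions on two augmentations of the same image; the
    teacher is an EMA copy of the student. Sketch (illustrative;
    ``student``/``teacher`` play the roles of ``self.model``/``self.teacher``
    below)::

        >>> prob_s = F.softmax(student(u1), 1)
        >>> prob_t = F.softmax(teacher(u2), 1)  # no gradient into teacher
        >>> loss_u = ((prob_s - prob_t) ** 2).sum(1)  # then masked or ramped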
14 | """ 15 | 16 | def __init__(self, cfg): 17 | super().__init__(cfg) 18 | self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA 19 | self.conf_thre = cfg.TRAINER.SE.CONF_THRE 20 | self.rampup = cfg.TRAINER.SE.RAMPUP 21 | 22 | self.teacher = copy.deepcopy(self.model) 23 | self.teacher.train() 24 | for param in self.teacher.parameters(): 25 | param.requires_grad_(False) 26 | 27 | def check_cfg(self, cfg): 28 | assert cfg.DATALOADER.K_TRANSFORMS == 2 29 | 30 | def forward_backward(self, batch_x, batch_u): 31 | global_step = self.batch_idx + self.epoch * self.num_batches 32 | parsed = self.parse_batch_train(batch_x, batch_u) 33 | input_x, label_x, input_u1, input_u2 = parsed 34 | 35 | logit_x = self.model(input_x) 36 | loss_x = F.cross_entropy(logit_x, label_x) 37 | 38 | prob_u = F.softmax(self.model(input_u1), 1) 39 | t_prob_u = F.softmax(self.teacher(input_u2), 1) 40 | loss_u = ((prob_u - t_prob_u)**2).sum(1) 41 | 42 | if self.conf_thre: 43 | max_prob = t_prob_u.max(1)[0] 44 | mask = (max_prob > self.conf_thre).float() 45 | loss_u = (loss_u * mask).mean() 46 | else: 47 | weight_u = sigmoid_rampup(global_step, self.rampup) 48 | loss_u = loss_u.mean() * weight_u 49 | 50 | loss = loss_x + loss_u 51 | self.model_backward_and_update(loss) 52 | 53 | ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha) 54 | ema_model_update(self.model, self.teacher, ema_alpha) 55 | 56 | loss_summary = { 57 | "loss_x": loss_x.item(), 58 | "acc_x": compute_accuracy(logit_x, label_x)[0].item(), 59 | "loss_u": loss_u.item(), 60 | } 61 | 62 | if (self.batch_idx + 1) == self.num_batches: 63 | self.update_lr() 64 | 65 | return loss_summary 66 | 67 | def parse_batch_train(self, batch_x, batch_u): 68 | input_x = batch_x["img"][0] 69 | label_x = batch_x["label"] 70 | input_u = batch_u["img"] 71 | input_u1, input_u2 = input_u 72 | 73 | input_x = input_x.to(self.device) 74 | label_x = label_x.to(self.device) 75 | input_u1 = input_u1.to(self.device) 76 | input_u2 = input_u2.to(self.device) 77 | 78 | return input_x, label_x, input_u1, input_u2 79 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/da/source_only.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | 3 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 4 | from dassl.metrics import compute_accuracy 5 | 6 | 7 | @TRAINER_REGISTRY.register() 8 | class SourceOnly(TrainerXU): 9 | """Baseline model for domain adaptation, which is 10 | trained using source data only. 
11 | """ 12 | 13 | def forward_backward(self, batch_x, batch_u): 14 | input, label = self.parse_batch_train(batch_x, batch_u) 15 | output = self.model(input) 16 | loss = F.cross_entropy(output, label) 17 | self.model_backward_and_update(loss) 18 | 19 | loss_summary = { 20 | "loss": loss.item(), 21 | "acc": compute_accuracy(output, label)[0].item(), 22 | } 23 | 24 | if (self.batch_idx + 1) == self.num_batches: 25 | self.update_lr() 26 | 27 | return loss_summary 28 | 29 | def parse_batch_train(self, batch_x, batch_u): 30 | input = batch_x["img"] 31 | label = batch_x["label"] 32 | input = input.to(self.device) 33 | label = label.to(self.device) 34 | return input, label 35 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/dg/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddaig import DDAIG 2 | from .daeldg import DAELDG 3 | from .vanilla import Vanilla 4 | from .crossgrad import CrossGrad 5 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/dg/crossgrad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from dassl.optim import build_optimizer, build_lr_scheduler 5 | from dassl.utils import count_num_param 6 | from dassl.engine import TRAINER_REGISTRY, TrainerX 7 | from dassl.engine.trainer import SimpleNet 8 | 9 | 10 | @TRAINER_REGISTRY.register() 11 | class CrossGrad(TrainerX): 12 | """Cross-gradient training. 13 | 14 | https://arxiv.org/abs/1804.10745. 15 | """ 16 | 17 | def __init__(self, cfg): 18 | super().__init__(cfg) 19 | self.eps_f = cfg.TRAINER.CG.EPS_F 20 | self.eps_d = cfg.TRAINER.CG.EPS_D 21 | self.alpha_f = cfg.TRAINER.CG.ALPHA_F 22 | self.alpha_d = cfg.TRAINER.CG.ALPHA_D 23 | 24 | def build_model(self): 25 | cfg = self.cfg 26 | 27 | print("Building F") 28 | self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes) 29 | self.F.to(self.device) 30 | print("# params: {:,}".format(count_num_param(self.F))) 31 | self.optim_F = build_optimizer(self.F, cfg.OPTIM) 32 | self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM) 33 | self.register_model("F", self.F, self.optim_F, self.sched_F) 34 | 35 | print("Building D") 36 | self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains) 37 | self.D.to(self.device) 38 | print("# params: {:,}".format(count_num_param(self.D))) 39 | self.optim_D = build_optimizer(self.D, cfg.OPTIM) 40 | self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM) 41 | self.register_model("D", self.D, self.optim_D, self.sched_D) 42 | 43 | def forward_backward(self, batch): 44 | input, label, domain = self.parse_batch_train(batch) 45 | 46 | input.requires_grad = True 47 | 48 | # Compute domain perturbation 49 | loss_d = F.cross_entropy(self.D(input), domain) 50 | loss_d.backward() 51 | grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1) 52 | input_d = input.data + self.eps_f * grad_d 53 | 54 | # Compute label perturbation 55 | input.grad.data.zero_() 56 | loss_f = F.cross_entropy(self.F(input), label) 57 | loss_f.backward() 58 | grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1) 59 | input_f = input.data + self.eps_d * grad_f 60 | 61 | input = input.detach() 62 | 63 | # Update label net 64 | loss_f1 = F.cross_entropy(self.F(input), label) 65 | loss_f2 = F.cross_entropy(self.F(input_d), label) 66 | loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2 67 | 
self.model_backward_and_update(loss_f, "F") 68 | 69 | # Update domain net 70 | loss_d1 = F.cross_entropy(self.D(input), domain) 71 | loss_d2 = F.cross_entropy(self.D(input_f), domain) 72 | loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2 73 | self.model_backward_and_update(loss_d, "D") 74 | 75 | loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()} 76 | 77 | if (self.batch_idx + 1) == self.num_batches: 78 | self.update_lr() 79 | 80 | return loss_summary 81 | 82 | def model_inference(self, input): 83 | return self.F(input) 84 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/dg/vanilla.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | 3 | from dassl.engine import TRAINER_REGISTRY, TrainerX 4 | from dassl.metrics import compute_accuracy 5 | 6 | 7 | @TRAINER_REGISTRY.register() 8 | class Vanilla(TrainerX): 9 | """Vanilla baseline.""" 10 | 11 | def forward_backward(self, batch): 12 | input, label = self.parse_batch_train(batch) 13 | output = self.model(input) 14 | loss = F.cross_entropy(output, label) 15 | self.model_backward_and_update(loss) 16 | 17 | loss_summary = { 18 | "loss": loss.item(), 19 | "acc": compute_accuracy(output, label)[0].item(), 20 | } 21 | 22 | if (self.batch_idx + 1) == self.num_batches: 23 | self.update_lr() 24 | 25 | return loss_summary 26 | 27 | def parse_batch_train(self, batch): 28 | input = batch["img"] 29 | label = batch["label"] 30 | input = input.to(self.device) 31 | label = label.to(self.device) 32 | return input, label 33 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/ssl/__init__.py: -------------------------------------------------------------------------------- 1 | from .entmin import EntMin 2 | from .fixmatch import FixMatch 3 | from .mixmatch import MixMatch 4 | from .mean_teacher import MeanTeacher 5 | from .sup_baseline import SupBaseline 6 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/ssl/entmin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 5 | from dassl.metrics import compute_accuracy 6 | 7 | 8 | @TRAINER_REGISTRY.register() 9 | class EntMin(TrainerXU): 10 | """Entropy Minimization. 11 | 12 | http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf. 
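    The unlabeled objective is the Shannon entropy of the model's own
    predictions, H(p) = -sum_c p_c * log(p_c), added to the supervised loss
    with weight ``lmda``. Sketch mirroring ``forward_backward`` below::

        >>> p = F.softmax(model(input_u), 1)
        >>> loss_u = (-p * torch.log(p + 1e-5)).sum(1).mean()
        >>> loss = loss_x + loss_u * lmda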
13 | """ 14 | 15 | def __init__(self, cfg): 16 | super().__init__(cfg) 17 | self.lmda = cfg.TRAINER.ENTMIN.LMDA 18 | 19 | def forward_backward(self, batch_x, batch_u): 20 | input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u) 21 | 22 | output_x = self.model(input_x) 23 | loss_x = F.cross_entropy(output_x, label_x) 24 | 25 | output_u = F.softmax(self.model(input_u), 1) 26 | loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean() 27 | 28 | loss = loss_x + loss_u * self.lmda 29 | 30 | self.model_backward_and_update(loss) 31 | 32 | loss_summary = { 33 | "loss_x": loss_x.item(), 34 | "acc_x": compute_accuracy(output_x, label_x)[0].item(), 35 | "loss_u": loss_u.item(), 36 | } 37 | 38 | if (self.batch_idx + 1) == self.num_batches: 39 | self.update_lr() 40 | 41 | return loss_summary 42 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/ssl/mean_teacher.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from torch.nn import functional as F 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 5 | from dassl.metrics import compute_accuracy 6 | from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update 7 | 8 | 9 | @TRAINER_REGISTRY.register() 10 | class MeanTeacher(TrainerXU): 11 | """Mean teacher. 12 | 13 | https://arxiv.org/abs/1703.01780. 14 | """ 15 | 16 | def __init__(self, cfg): 17 | super().__init__(cfg) 18 | self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U 19 | self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA 20 | self.rampup = cfg.TRAINER.MEANTEA.RAMPUP 21 | 22 | self.teacher = copy.deepcopy(self.model) 23 | self.teacher.train() 24 | for param in self.teacher.parameters(): 25 | param.requires_grad_(False) 26 | 27 | def forward_backward(self, batch_x, batch_u): 28 | input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u) 29 | 30 | logit_x = self.model(input_x) 31 | loss_x = F.cross_entropy(logit_x, label_x) 32 | 33 | target_u = F.softmax(self.teacher(input_u), 1) 34 | prob_u = F.softmax(self.model(input_u), 1) 35 | loss_u = ((prob_u - target_u)**2).sum(1).mean() 36 | 37 | weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup) 38 | loss = loss_x + loss_u*weight_u 39 | self.model_backward_and_update(loss) 40 | 41 | global_step = self.batch_idx + self.epoch * self.num_batches 42 | ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha) 43 | ema_model_update(self.model, self.teacher, ema_alpha) 44 | 45 | loss_summary = { 46 | "loss_x": loss_x.item(), 47 | "acc_x": compute_accuracy(logit_x, label_x)[0].item(), 48 | "loss_u": loss_u.item(), 49 | } 50 | 51 | if (self.batch_idx + 1) == self.num_batches: 52 | self.update_lr() 53 | 54 | return loss_summary 55 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/ssl/mixmatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 5 | from dassl.modeling.ops import mixup 6 | from dassl.modeling.ops.utils import ( 7 | sharpen_prob, create_onehot, linear_rampup, shuffle_index 8 | ) 9 | 10 | 11 | @TRAINER_REGISTRY.register() 12 | class MixMatch(TrainerXU): 13 | """MixMatch: A Holistic Approach to Semi-Supervised Learning. 14 | 15 | https://arxiv.org/abs/1905.02249. 
16 | """ 17 | 18 | def __init__(self, cfg): 19 | super().__init__(cfg) 20 | self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U 21 | self.temp = cfg.TRAINER.MIXMATCH.TEMP 22 | self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA 23 | self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP 24 | 25 | def check_cfg(self, cfg): 26 | assert cfg.DATALOADER.K_TRANSFORMS > 1 27 | 28 | def forward_backward(self, batch_x, batch_u): 29 | input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u) 30 | num_x = input_x.shape[0] 31 | 32 | global_step = self.batch_idx + self.epoch * self.num_batches 33 | weight_u = self.weight_u * linear_rampup(global_step, self.rampup) 34 | 35 | # Generate pseudo-label for unlabeled data 36 | with torch.no_grad(): 37 | output_u = 0 38 | for input_ui in input_u: 39 | output_ui = F.softmax(self.model(input_ui), 1) 40 | output_u += output_ui 41 | output_u /= len(input_u) 42 | label_u = sharpen_prob(output_u, self.temp) 43 | label_u = [label_u] * len(input_u) 44 | label_u = torch.cat(label_u, 0) 45 | input_u = torch.cat(input_u, 0) 46 | 47 | # Combine and shuffle labeled and unlabeled data 48 | input_xu = torch.cat([input_x, input_u], 0) 49 | label_xu = torch.cat([label_x, label_u], 0) 50 | input_xu, label_xu = shuffle_index(input_xu, label_xu) 51 | 52 | # Mixup 53 | input_x, label_x = mixup( 54 | input_x, 55 | input_xu[:num_x], 56 | label_x, 57 | label_xu[:num_x], 58 | self.beta, 59 | preserve_order=True, 60 | ) 61 | 62 | input_u, label_u = mixup( 63 | input_u, 64 | input_xu[num_x:], 65 | label_u, 66 | label_xu[num_x:], 67 | self.beta, 68 | preserve_order=True, 69 | ) 70 | 71 | # Compute losses 72 | output_x = F.softmax(self.model(input_x), 1) 73 | loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean() 74 | 75 | output_u = F.softmax(self.model(input_u), 1) 76 | loss_u = ((label_u - output_u)**2).mean() 77 | 78 | loss = loss_x + loss_u*weight_u 79 | self.model_backward_and_update(loss) 80 | 81 | loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()} 82 | 83 | if (self.batch_idx + 1) == self.num_batches: 84 | self.update_lr() 85 | 86 | return loss_summary 87 | 88 | def parse_batch_train(self, batch_x, batch_u): 89 | input_x = batch_x["img"][0] 90 | label_x = batch_x["label"] 91 | label_x = create_onehot(label_x, self.num_classes) 92 | input_u = batch_u["img"] 93 | 94 | input_x = input_x.to(self.device) 95 | label_x = label_x.to(self.device) 96 | input_u = [input_ui.to(self.device) for input_ui in input_u] 97 | 98 | return input_x, label_x, input_u 99 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/engine/ssl/sup_baseline.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | 3 | from dassl.engine import TRAINER_REGISTRY, TrainerXU 4 | from dassl.metrics import compute_accuracy 5 | 6 | 7 | @TRAINER_REGISTRY.register() 8 | class SupBaseline(TrainerXU): 9 | """Supervised Baseline.""" 10 | 11 | def forward_backward(self, batch_x, batch_u): 12 | input, label = self.parse_batch_train(batch_x, batch_u) 13 | output = self.model(input) 14 | loss = F.cross_entropy(output, label) 15 | self.model_backward_and_update(loss) 16 | 17 | loss_summary = { 18 | "loss": loss.item(), 19 | "acc": compute_accuracy(output, label)[0].item(), 20 | } 21 | 22 | if (self.batch_idx + 1) == self.num_batches: 23 | self.update_lr() 24 | 25 | return loss_summary 26 | 27 | def parse_batch_train(self, batch_x, batch_u): 28 | input = batch_x["img"] 29 | label = batch_x["label"] 30 
| input = input.to(self.device) 31 | label = label.to(self.device) 32 | return input, label 33 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_evaluator, EVALUATOR_REGISTRY  # isort:skip 2 | 3 | from .evaluator import EvaluatorBase, Classification 4 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/evaluation/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | EVALUATOR_REGISTRY = Registry("EVALUATOR") 4 | 5 | 6 | def build_evaluator(cfg, **kwargs): 7 | avai_evaluators = EVALUATOR_REGISTRY.registered_names() 8 | check_availability(cfg.TEST.EVALUATOR, avai_evaluators) 9 | if cfg.VERBOSE: 10 | print("Loading evaluator: {}".format(cfg.TEST.EVALUATOR)) 11 | return EVALUATOR_REGISTRY.get(cfg.TEST.EVALUATOR)(cfg, **kwargs) 12 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import compute_accuracy 2 | from .distance import ( 3 | cosine_distance, compute_distance_matrix, euclidean_squared_distance 4 | ) 5 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | def compute_accuracy(output, target, topk=(1, )): 2 | """Computes the accuracy over the k top predictions for 3 | the specified values of k. 4 | 5 | Args: 6 | output (torch.Tensor): prediction matrix with shape (batch_size, num_classes). 7 | target (torch.LongTensor): ground truth labels with shape (batch_size). 8 | topk (tuple, optional): accuracy at top-k will be computed. For example, 9 | topk=(1, 5) means accuracy at top-1 and top-5 will be computed. 10 | 11 | Returns: 12 | list: accuracy at top-k. 13 | """ 14 | maxk = max(topk) 15 | batch_size = target.size(0) 16 | 17 | if isinstance(output, (tuple, list)): 18 | output = output[0] 19 | 20 | _, pred = output.topk(maxk, 1, True, True) 21 | pred = pred.t() 22 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 23 | 24 | res = [] 25 | for k in topk: 26 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: the slice may be non-contiguous 27 | acc = correct_k.mul_(100.0 / batch_size) 28 | res.append(acc) 29 | 30 | return res 31 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/metrics/distance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source: https://github.com/KaiyangZhou/deep-person-reid 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | def compute_distance_matrix(input1, input2, metric="euclidean"): 9 | """A wrapper function for computing distance matrix. 10 | 11 | Each input matrix has the shape (n_data, feature_dim). 12 | 13 | Args: 14 | input1 (torch.Tensor): 2-D feature matrix. 15 | input2 (torch.Tensor): 2-D feature matrix. 16 | metric (str, optional): "euclidean" or "cosine". 17 | Default is "euclidean". 18 | 19 | Returns: 20 | torch.Tensor: distance matrix.
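    Examples (shapes only; illustrative)::

        >>> import torch
        >>> from dassl.metrics import compute_distance_matrix
        >>> input1 = torch.rand(10, 2048)
        >>> input2 = torch.rand(100, 2048)
        >>> distmat = compute_distance_matrix(input1, input2)
        >>> distmat.size()  # torch.Size([10, 100])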
21 | """ 22 | # check input 23 | assert isinstance(input1, torch.Tensor) 24 | assert isinstance(input2, torch.Tensor) 25 | assert input1.dim() == 2, "Expected 2-D tensor, but got {}-D".format( 26 | input1.dim() 27 | ) 28 | assert input2.dim() == 2, "Expected 2-D tensor, but got {}-D".format( 29 | input2.dim() 30 | ) 31 | assert input1.size(1) == input2.size(1) 32 | 33 | if metric == "euclidean": 34 | distmat = euclidean_squared_distance(input1, input2) 35 | elif metric == "cosine": 36 | distmat = cosine_distance(input1, input2) 37 | else: 38 | raise ValueError( 39 | "Unknown distance metric: {}. " 40 | 'Please choose either "euclidean" or "cosine"'.format(metric) 41 | ) 42 | 43 | return distmat 44 | 45 | 46 | def euclidean_squared_distance(input1, input2): 47 | """Computes euclidean squared distance. 48 | 49 | Args: 50 | input1 (torch.Tensor): 2-D feature matrix. 51 | input2 (torch.Tensor): 2-D feature matrix. 52 | 53 | Returns: 54 | torch.Tensor: distance matrix. 55 | """ 56 | m, n = input1.size(0), input2.size(0) 57 | mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) 58 | mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 59 | distmat = mat1 + mat2 60 | distmat.addmm_(1, -2, input1, input2.t()) 61 | return distmat 62 | 63 | 64 | def cosine_distance(input1, input2): 65 | """Computes cosine distance. 66 | 67 | Args: 68 | input1 (torch.Tensor): 2-D feature matrix. 69 | input2 (torch.Tensor): 2-D feature matrix. 70 | 71 | Returns: 72 | torch.Tensor: distance matrix. 73 | """ 74 | input1_normed = F.normalize(input1, p=2, dim=1) 75 | input2_normed = F.normalize(input2, p=2, dim=1) 76 | distmat = 1 - torch.mm(input1_normed, input2_normed.t()) 77 | return distmat 78 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .head import HEAD_REGISTRY, build_head 2 | from .network import NETWORK_REGISTRY, build_network 3 | from .backbone import BACKBONE_REGISTRY, Backbone, build_backbone 4 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_backbone, BACKBONE_REGISTRY # isort:skip 2 | from .backbone import Backbone # isort:skip 3 | 4 | from .vgg import vgg16 5 | from .resnet import ( 6 | resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1, 7 | resnet50_ms_l1, resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1, 8 | resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123, 9 | resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12, 10 | resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123, 11 | resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123 12 | ) 13 | from .alexnet import alexnet 14 | from .mobilenetv2 import mobilenetv2 15 | from .wide_resnet import wide_resnet_16_4, wide_resnet_28_2 16 | from .cnn_digitsdg import cnn_digitsdg 17 | from .efficientnet import ( 18 | efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3, 19 | efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7 20 | ) 21 | from .shufflenetv2 import ( 22 | shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, 23 | shufflenet_v2_x2_0 24 | ) 25 | from .cnn_digitsingle import cnn_digitsingle 26 | from .preact_resnet18 import preact_resnet18 27 | from 
.cnn_digit5_m3sda import cnn_digit5_m3sda 28 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/backbone/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from .build import BACKBONE_REGISTRY 6 | from .backbone import Backbone 7 | 8 | model_urls = { 9 | "alexnet": "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth", 10 | } 11 | 12 | 13 | class AlexNet(Backbone): 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 33 | # Note that self.classifier outputs features rather than logits 34 | self.classifier = nn.Sequential( 35 | nn.Dropout(), 36 | nn.Linear(256 * 6 * 6, 4096), 37 | nn.ReLU(inplace=True), 38 | nn.Dropout(), 39 | nn.Linear(4096, 4096), 40 | nn.ReLU(inplace=True), 41 | ) 42 | 43 | self._out_features = 4096 44 | 45 | def forward(self, x): 46 | x = self.features(x) 47 | x = self.avgpool(x) 48 | x = torch.flatten(x, 1) 49 | return self.classifier(x) 50 | 51 | 52 | def init_pretrained_weights(model, model_url): 53 | pretrain_dict = model_zoo.load_url(model_url) 54 | model.load_state_dict(pretrain_dict, strict=False) 55 | 56 | 57 | @BACKBONE_REGISTRY.register() 58 | def alexnet(pretrained=True, **kwargs): 59 | model = AlexNet() 60 | 61 | if pretrained: 62 | init_pretrained_weights(model, model_urls["alexnet"]) 63 | 64 | return model 65 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Backbone(nn.Module): 5 | 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self): 10 | pass 11 | 12 | @property 13 | def out_features(self): 14 | """Output feature dimension.""" 15 | if self.__dict__.get("_out_features") is None: 16 | return None 17 | return self._out_features 18 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/backbone/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | BACKBONE_REGISTRY = Registry("BACKBONE") 4 | 5 | 6 | def build_backbone(name, verbose=True, **kwargs): 7 | avai_backbones = BACKBONE_REGISTRY.registered_names() 8 | check_availability(name, avai_backbones) 9 | if verbose: 10 | print("Backbone: {}".format(name)) 11 | return BACKBONE_REGISTRY.get(name)(**kwargs) 12 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference 3 | 4 | 
https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA 5 | """ 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | 9 | from .build import BACKBONE_REGISTRY 10 | from .backbone import Backbone 11 | 12 | 13 | class FeatureExtractor(Backbone): 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2) 18 | self.bn1 = nn.BatchNorm2d(64) 19 | self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2) 20 | self.bn2 = nn.BatchNorm2d(64) 21 | self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2) 22 | self.bn3 = nn.BatchNorm2d(128) 23 | self.fc1 = nn.Linear(8192, 3072) 24 | self.bn1_fc = nn.BatchNorm1d(3072) 25 | self.fc2 = nn.Linear(3072, 2048) 26 | self.bn2_fc = nn.BatchNorm1d(2048) 27 | 28 | self._out_features = 2048 29 | 30 | def _check_input(self, x): 31 | H, W = x.shape[2:] 32 | assert ( 33 | H == 32 and W == 32 34 | ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) 35 | 36 | def forward(self, x): 37 | self._check_input(x) 38 | x = F.relu(self.bn1(self.conv1(x))) 39 | x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1) 40 | x = F.relu(self.bn2(self.conv2(x))) 41 | x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1) 42 | x = F.relu(self.bn3(self.conv3(x))) 43 | x = x.view(x.size(0), 8192) 44 | x = F.relu(self.bn1_fc(self.fc1(x))) 45 | x = F.dropout(x, training=self.training) 46 | x = F.relu(self.bn2_fc(self.fc2(x))) 47 | return x 48 | 49 | 50 | @BACKBONE_REGISTRY.register() 51 | def cnn_digit5_m3sda(**kwargs): 52 | """ 53 | This architecture was used for the Digit-5 dataset in: 54 | 55 | - Peng et al. Moment Matching for Multi-Source 56 | Domain Adaptation. ICCV 2019. 57 | """ 58 | return FeatureExtractor() 59 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/backbone/cnn_digitsdg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | 4 | from dassl.utils import init_network_weights 5 | 6 | from .build import BACKBONE_REGISTRY 7 | from .backbone import Backbone 8 | 9 | 10 | class Convolution(nn.Module): 11 | 12 | def __init__(self, c_in, c_out): 13 | super().__init__() 14 | self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1) 15 | self.relu = nn.ReLU(True) 16 | 17 | def forward(self, x): 18 | return self.relu(self.conv(x)) 19 | 20 | 21 | class ConvNet(Backbone): 22 | 23 | def __init__(self, c_hidden=64): 24 | super().__init__() 25 | self.conv1 = Convolution(3, c_hidden) 26 | self.conv2 = Convolution(c_hidden, c_hidden) 27 | self.conv3 = Convolution(c_hidden, c_hidden) 28 | self.conv4 = Convolution(c_hidden, c_hidden) 29 | 30 | self._out_features = 2**2 * c_hidden 31 | 32 | def _check_input(self, x): 33 | H, W = x.shape[2:] 34 | assert ( 35 | H == 32 and W == 32 36 | ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) 37 | 38 | def forward(self, x): 39 | self._check_input(x) 40 | x = self.conv1(x) 41 | x = F.max_pool2d(x, 2) 42 | x = self.conv2(x) 43 | x = F.max_pool2d(x, 2) 44 | x = self.conv3(x) 45 | x = F.max_pool2d(x, 2) 46 | x = self.conv4(x) 47 | x = F.max_pool2d(x, 2) 48 | return x.view(x.size(0), -1) 49 | 50 | 51 | @BACKBONE_REGISTRY.register() 52 | def cnn_digitsdg(**kwargs): 53 | """ 54 | This architecture was used for DigitsDG dataset in: 55 | 56 | - Zhou et al. Deep Domain-Adversarial Image Generation 57 | for Domain Generalisation. 
AAAI 2020. 58 | """ 59 | model = ConvNet(c_hidden=64) 60 | init_network_weights(model, init_type="kaiming") 61 | return model 62 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/backbone/cnn_digitsingle.py: -------------------------------------------------------------------------------- 1 | """ 2 | This model is built based on 3 | https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py 4 | """ 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | from dassl.utils import init_network_weights 9 | 10 | from .build import BACKBONE_REGISTRY 11 | from .backbone import Backbone 12 | 13 | 14 | class CNN(Backbone): 15 | 16 | def __init__(self): 17 | super().__init__() 18 | self.conv1 = nn.Conv2d(3, 64, 5) 19 | self.conv2 = nn.Conv2d(64, 128, 5) 20 | self.fc3 = nn.Linear(5 * 5 * 128, 1024) 21 | self.fc4 = nn.Linear(1024, 1024) 22 | 23 | self._out_features = 1024 24 | 25 | def _check_input(self, x): 26 | H, W = x.shape[2:] 27 | assert ( 28 | H == 32 and W == 32 29 | ), "Input to network must be 32x32, " "but got {}x{}".format(H, W) 30 | 31 | def forward(self, x): 32 | self._check_input(x) 33 | x = self.conv1(x) 34 | x = F.relu(x) 35 | x = F.max_pool2d(x, 2) 36 | 37 | x = self.conv2(x) 38 | x = F.relu(x) 39 | x = F.max_pool2d(x, 2) 40 | 41 | x = x.view(x.size(0), -1) 42 | 43 | x = self.fc3(x) 44 | x = F.relu(x) 45 | 46 | x = self.fc4(x) 47 | x = F.relu(x) 48 | 49 | return x 50 | 51 | 52 | @BACKBONE_REGISTRY.register() 53 | def cnn_digitsingle(**kwargs): 54 | model = CNN() 55 | init_network_weights(model, init_type="kaiming") 56 | return model 57 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/backbone/efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source: https://github.com/lukemelas/EfficientNet-PyTorch. 
3 | """ 4 | __version__ = "0.6.4" 5 | from .model import ( 6 | EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2, 7 | efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, 8 | efficientnet_b7 9 | ) 10 | from .utils import ( 11 | BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params 12 | ) 13 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_head, HEAD_REGISTRY # isort:skip 2 | 3 | from .mlp import mlp 4 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/head/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | HEAD_REGISTRY = Registry("HEAD") 4 | 5 | 6 | def build_head(name, verbose=True, **kwargs): 7 | avai_heads = HEAD_REGISTRY.registered_names() 8 | check_availability(name, avai_heads) 9 | if verbose: 10 | print("Head: {}".format(name)) 11 | return HEAD_REGISTRY.get(name)(**kwargs) 12 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/head/mlp.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | from .build import HEAD_REGISTRY 5 | 6 | 7 | class MLP(nn.Module): 8 | 9 | def __init__( 10 | self, 11 | in_features=2048, 12 | hidden_layers=[], 13 | activation="relu", 14 | bn=True, 15 | dropout=0.0, 16 | ): 17 | super().__init__() 18 | if isinstance(hidden_layers, int): 19 | hidden_layers = [hidden_layers] 20 | 21 | assert len(hidden_layers) > 0 22 | self.out_features = hidden_layers[-1] 23 | 24 | mlp = [] 25 | 26 | if activation == "relu": 27 | act_fn = functools.partial(nn.ReLU, inplace=True) 28 | elif activation == "leaky_relu": 29 | act_fn = functools.partial(nn.LeakyReLU, inplace=True) 30 | else: 31 | raise NotImplementedError 32 | 33 | for hidden_dim in hidden_layers: 34 | mlp += [nn.Linear(in_features, hidden_dim)] 35 | if bn: 36 | mlp += [nn.BatchNorm1d(hidden_dim)] 37 | mlp += [act_fn()] 38 | if dropout > 0: 39 | mlp += [nn.Dropout(dropout)] 40 | in_features = hidden_dim 41 | 42 | self.mlp = nn.Sequential(*mlp) 43 | 44 | def forward(self, x): 45 | return self.mlp(x) 46 | 47 | 48 | @HEAD_REGISTRY.register() 49 | def mlp(**kwargs): 50 | return MLP(**kwargs) 51 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_network, NETWORK_REGISTRY # isort:skip 2 | 3 | from .ddaig_fcn import ( 4 | fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn 5 | ) 6 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/network/build.py: -------------------------------------------------------------------------------- 1 | from dassl.utils import Registry, check_availability 2 | 3 | NETWORK_REGISTRY = Registry("NETWORK") 4 | 5 | 6 | def build_network(name, verbose=True, **kwargs): 7 | avai_models = NETWORK_REGISTRY.registered_names() 8 | check_availability(name, avai_models) 9 | if verbose: 10 | print("Network: {}".format(name)) 11 | return NETWORK_REGISTRY.get(name)(**kwargs) 12 | 
-------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .mmd import MaximumMeanDiscrepancy 2 | from .dsbn import DSBN1d, DSBN2d 3 | from .mixup import mixup 4 | from .efdmix import ( 5 | EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix, 6 | crossdomain_efdmix, run_without_efdmix 7 | ) 8 | from .mixstyle import ( 9 | MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle, 10 | deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle 11 | ) 12 | from .transnorm import TransNorm1d, TransNorm2d 13 | from .sequential2 import Sequential2 14 | from .reverse_grad import ReverseGrad 15 | from .cross_entropy import cross_entropy 16 | from .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance 17 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | def cross_entropy(input, target, label_smooth=0, reduction="mean"): 6 | """Cross entropy loss. 7 | 8 | Args: 9 | input (torch.Tensor): logit matrix with shape of (batch, num_classes). 10 | target (torch.LongTensor): int label matrix. 11 | label_smooth (float, optional): label smoothing hyper-parameter. 12 | Default is 0. 13 | reduction (str, optional): how the losses for a mini-batch 14 | will be aggregated. Default is 'mean'. 15 | """ 16 | num_classes = input.shape[1] 17 | log_prob = F.log_softmax(input, dim=1) 18 | zeros = torch.zeros(log_prob.size()) 19 | target = zeros.scatter_(1, target.unsqueeze(1).data.cpu(), 1) 20 | target = target.type_as(input) 21 | target = (1-label_smooth) * target + label_smooth/num_classes 22 | loss = (-target * log_prob).sum(1) 23 | if reduction == "mean": 24 | return loss.mean() 25 | elif reduction == "sum": 26 | return loss.sum() 27 | elif reduction == "none": 28 | return loss 29 | else: 30 | raise ValueError 31 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/dsbn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class _DSBN(nn.Module): 5 | """Domain Specific Batch Normalization. 6 | 7 | Args: 8 | num_features (int): number of features. 9 | n_domain (int): number of domains. 10 | bn_type (str): type of bn. Choices are ['1d', '2d']. 
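    Example (illustrative)::

        >>> bn = DSBN2d(num_features=64, n_domain=2)
        >>> bn.select_bn(domain_idx=1)  # route inputs through domain 1's BN
        >>> y = bn(x)  # x: (batch, 64, h, w)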
11 | """ 12 | 13 | def __init__(self, num_features, n_domain, bn_type): 14 | super().__init__() 15 | if bn_type == "1d": 16 | BN = nn.BatchNorm1d 17 | elif bn_type == "2d": 18 | BN = nn.BatchNorm2d 19 | else: 20 | raise ValueError 21 | 22 | self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain)) 23 | 24 | self.valid_domain_idxs = list(range(n_domain)) 25 | self.n_domain = n_domain 26 | self.domain_idx = 0 27 | 28 | def select_bn(self, domain_idx=0): 29 | assert domain_idx in self.valid_domain_idxs 30 | self.domain_idx = domain_idx 31 | 32 | def forward(self, x): 33 | return self.bn[self.domain_idx](x) 34 | 35 | 36 | class DSBN1d(_DSBN): 37 | 38 | def __init__(self, num_features, n_domain): 39 | super().__init__(num_features, n_domain, "1d") 40 | 41 | 42 | class DSBN2d(_DSBN): 43 | 44 | def __init__(self, num_features, n_domain): 45 | super().__init__(num_features, n_domain, "2d") 46 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/efdmix.py: -------------------------------------------------------------------------------- 1 | import random 2 | from contextlib import contextmanager 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def deactivate_efdmix(m): 8 | if type(m) == EFDMix: 9 | m.set_activation_status(False) 10 | 11 | 12 | def activate_efdmix(m): 13 | if type(m) == EFDMix: 14 | m.set_activation_status(True) 15 | 16 | 17 | def random_efdmix(m): 18 | if type(m) == EFDMix: 19 | m.update_mix_method("random") 20 | 21 | 22 | def crossdomain_efdmix(m): 23 | if type(m) == EFDMix: 24 | m.update_mix_method("crossdomain") 25 | 26 | 27 | @contextmanager 28 | def run_without_efdmix(model): 29 | # Assume MixStyle was initially activated 30 | try: 31 | model.apply(deactivate_efdmix) 32 | yield 33 | finally: 34 | model.apply(activate_efdmix) 35 | 36 | 37 | @contextmanager 38 | def run_with_efdmix(model, mix=None): 39 | # Assume MixStyle was initially deactivated 40 | if mix == "random": 41 | model.apply(random_efdmix) 42 | 43 | elif mix == "crossdomain": 44 | model.apply(crossdomain_efdmix) 45 | 46 | try: 47 | model.apply(activate_efdmix) 48 | yield 49 | finally: 50 | model.apply(deactivate_efdmix) 51 | 52 | 53 | class EFDMix(nn.Module): 54 | """EFDMix. 55 | 56 | Reference: 57 | Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. CVPR 2022. 58 | """ 59 | 60 | def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"): 61 | """ 62 | Args: 63 | p (float): probability of using MixStyle. 64 | alpha (float): parameter of the Beta distribution. 65 | eps (float): scaling parameter to avoid numerical issues. 66 | mix (str): how to mix. 
67 | """ 68 | super().__init__() 69 | self.p = p 70 | self.beta = torch.distributions.Beta(alpha, alpha) 71 | self.eps = eps 72 | self.alpha = alpha 73 | self.mix = mix 74 | self._activated = True 75 | 76 | def __repr__(self): 77 | return ( 78 | f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})" 79 | ) 80 | 81 | def set_activation_status(self, status=True): 82 | self._activated = status 83 | 84 | def update_mix_method(self, mix="random"): 85 | self.mix = mix 86 | 87 | def forward(self, x): 88 | if not self.training or not self._activated: 89 | return x 90 | 91 | if random.random() > self.p: 92 | return x 93 | 94 | B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3) 95 | x_view = x.view(B, C, -1) 96 | value_x, index_x = torch.sort(x_view) # sort inputs 97 | lmda = self.beta.sample((B, 1, 1)) 98 | lmda = lmda.to(x.device) 99 | 100 | if self.mix == "random": 101 | # random shuffle 102 | perm = torch.randperm(B) 103 | 104 | elif self.mix == "crossdomain": 105 | # split into two halves and swap the order 106 | perm = torch.arange(B - 1, -1, -1) # inverse index 107 | perm_b, perm_a = perm.chunk(2) 108 | perm_b = perm_b[torch.randperm(perm_b.shape[0])] 109 | perm_a = perm_a[torch.randperm(perm_a.shape[0])] 110 | perm = torch.cat([perm_b, perm_a], 0) 111 | 112 | else: 113 | raise NotImplementedError 114 | 115 | inverse_index = index_x.argsort(-1) 116 | x_view_copy = value_x[perm].gather(-1, inverse_index) 117 | new_x = x_view + (x_view_copy - x_view.detach()) * (1-lmda) 118 | return new_x.view(B, C, W, H) 119 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/mixstyle.py: -------------------------------------------------------------------------------- 1 | import random 2 | from contextlib import contextmanager 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def deactivate_mixstyle(m): 8 | if type(m) == MixStyle: 9 | m.set_activation_status(False) 10 | 11 | 12 | def activate_mixstyle(m): 13 | if type(m) == MixStyle: 14 | m.set_activation_status(True) 15 | 16 | 17 | def random_mixstyle(m): 18 | if type(m) == MixStyle: 19 | m.update_mix_method("random") 20 | 21 | 22 | def crossdomain_mixstyle(m): 23 | if type(m) == MixStyle: 24 | m.update_mix_method("crossdomain") 25 | 26 | 27 | @contextmanager 28 | def run_without_mixstyle(model): 29 | # Assume MixStyle was initially activated 30 | try: 31 | model.apply(deactivate_mixstyle) 32 | yield 33 | finally: 34 | model.apply(activate_mixstyle) 35 | 36 | 37 | @contextmanager 38 | def run_with_mixstyle(model, mix=None): 39 | # Assume MixStyle was initially deactivated 40 | if mix == "random": 41 | model.apply(random_mixstyle) 42 | 43 | elif mix == "crossdomain": 44 | model.apply(crossdomain_mixstyle) 45 | 46 | try: 47 | model.apply(activate_mixstyle) 48 | yield 49 | finally: 50 | model.apply(deactivate_mixstyle) 51 | 52 | 53 | class MixStyle(nn.Module): 54 | """MixStyle. 55 | 56 | Reference: 57 | Zhou et al. Domain Generalization with MixStyle. ICLR 2021. 58 | """ 59 | 60 | def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"): 61 | """ 62 | Args: 63 | p (float): probability of using MixStyle. 64 | alpha (float): parameter of the Beta distribution. 65 | eps (float): scaling parameter to avoid numerical issues. 66 | mix (str): how to mix. 
67 | """ 68 | super().__init__() 69 | self.p = p 70 | self.beta = torch.distributions.Beta(alpha, alpha) 71 | self.eps = eps 72 | self.alpha = alpha 73 | self.mix = mix 74 | self._activated = True 75 | 76 | def __repr__(self): 77 | return ( 78 | f"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})" 79 | ) 80 | 81 | def set_activation_status(self, status=True): 82 | self._activated = status 83 | 84 | def update_mix_method(self, mix="random"): 85 | self.mix = mix 86 | 87 | def forward(self, x): 88 | if not self.training or not self._activated: 89 | return x 90 | 91 | if random.random() > self.p: 92 | return x 93 | 94 | B = x.size(0) 95 | 96 | mu = x.mean(dim=[2, 3], keepdim=True) 97 | var = x.var(dim=[2, 3], keepdim=True) 98 | sig = (var + self.eps).sqrt() 99 | mu, sig = mu.detach(), sig.detach() 100 | x_normed = (x-mu) / sig 101 | 102 | lmda = self.beta.sample((B, 1, 1, 1)) 103 | lmda = lmda.to(x.device) 104 | 105 | if self.mix == "random": 106 | # random shuffle 107 | perm = torch.randperm(B) 108 | 109 | elif self.mix == "crossdomain": 110 | # split into two halves and swap the order 111 | perm = torch.arange(B - 1, -1, -1) # inverse index 112 | perm_b, perm_a = perm.chunk(2) 113 | perm_b = perm_b[torch.randperm(perm_b.shape[0])] 114 | perm_a = perm_a[torch.randperm(perm_a.shape[0])] 115 | perm = torch.cat([perm_b, perm_a], 0) 116 | 117 | else: 118 | raise NotImplementedError 119 | 120 | mu2, sig2 = mu[perm], sig[perm] 121 | mu_mix = mu*lmda + mu2 * (1-lmda) 122 | sig_mix = sig*lmda + sig2 * (1-lmda) 123 | 124 | return x_normed*sig_mix + mu_mix 125 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mixup(x1, x2, y1, y2, beta, preserve_order=False): 5 | """Mixup. 6 | 7 | Args: 8 | x1 (torch.Tensor): data with shape of (b, c, h, w). 9 | x2 (torch.Tensor): data with shape of (b, c, h, w). 10 | y1 (torch.Tensor): label with shape of (b, n). 11 | y2 (torch.Tensor): label with shape of (b, n). 12 | beta (float): hyper-parameter for Beta sampling. 13 | preserve_order (bool): apply lmda=max(lmda, 1-lmda). 14 | Default is False. 
15 | """ 16 | lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1]) 17 | if preserve_order: 18 | lmda = torch.max(lmda, 1 - lmda) 19 | lmda = lmda.to(x1.device) 20 | xmix = x1*lmda + x2 * (1-lmda) 21 | lmda = lmda[:, :, 0, 0] 22 | ymix = y1*lmda + y2 * (1-lmda) 23 | return xmix, ymix 24 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class MaximumMeanDiscrepancy(nn.Module): 7 | 8 | def __init__(self, kernel_type="rbf", normalize=False): 9 | super().__init__() 10 | self.kernel_type = kernel_type 11 | self.normalize = normalize 12 | 13 | def forward(self, x, y): 14 | # x, y: two batches of data with shape (batch, dim) 15 | # MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y') 16 | if self.normalize: 17 | x = F.normalize(x, dim=1) 18 | y = F.normalize(y, dim=1) 19 | if self.kernel_type == "linear": 20 | return self.linear_mmd(x, y) 21 | elif self.kernel_type == "poly": 22 | return self.poly_mmd(x, y) 23 | elif self.kernel_type == "rbf": 24 | return self.rbf_mmd(x, y) 25 | else: 26 | raise NotImplementedError 27 | 28 | def linear_mmd(self, x, y): 29 | # k(x, y) = x^T y 30 | k_xx = self.remove_self_distance(torch.mm(x, x.t())) 31 | k_yy = self.remove_self_distance(torch.mm(y, y.t())) 32 | k_xy = torch.mm(x, y.t()) 33 | return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() 34 | 35 | def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2): 36 | # k(x, y) = (alpha * x^T y + c)^d 37 | k_xx = self.remove_self_distance(torch.mm(x, x.t())) 38 | k_xx = (alpha*k_xx + c).pow(d) 39 | k_yy = self.remove_self_distance(torch.mm(y, y.t())) 40 | k_yy = (alpha*k_yy + c).pow(d) 41 | k_xy = torch.mm(x, y.t()) 42 | k_xy = (alpha*k_xy + c).pow(d) 43 | return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() 44 | 45 | def rbf_mmd(self, x, y): 46 | # k_xx 47 | d_xx = self.euclidean_squared_distance(x, x) 48 | d_xx = self.remove_self_distance(d_xx) 49 | k_xx = self.rbf_kernel_mixture(d_xx) 50 | # k_yy 51 | d_yy = self.euclidean_squared_distance(y, y) 52 | d_yy = self.remove_self_distance(d_yy) 53 | k_yy = self.rbf_kernel_mixture(d_yy) 54 | # k_xy 55 | d_xy = self.euclidean_squared_distance(x, y) 56 | k_xy = self.rbf_kernel_mixture(d_xy) 57 | return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean() 58 | 59 | @staticmethod 60 | def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]): 61 | K = 0 62 | for sigma in sigmas: 63 | gamma = 1.0 / (2.0 * sigma**2) 64 | K += torch.exp(-gamma * exponent) 65 | return K 66 | 67 | @staticmethod 68 | def remove_self_distance(distmat): 69 | tmp_list = [] 70 | for i, row in enumerate(distmat): 71 | row1 = torch.cat([row[:i], row[i + 1:]]) 72 | tmp_list.append(row1) 73 | return torch.stack(tmp_list) 74 | 75 | @staticmethod 76 | def euclidean_squared_distance(x, y): 77 | m, n = x.size(0), y.size(0) 78 | distmat = ( 79 | torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + 80 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 81 | ) 82 | # distmat.addmm_(1, -2, x, y.t()) 83 | distmat.addmm_(x, y.t(), beta=1, alpha=-2) 84 | return distmat 85 | 86 | 87 | if __name__ == "__main__": 88 | mmd = MaximumMeanDiscrepancy(kernel_type="rbf") 89 | input1, input2 = torch.rand(3, 100), torch.rand(3, 100) 90 | d = mmd(input1, input2) 91 | print(d.item()) 92 | -------------------------------------------------------------------------------- 
/CPL/Dassl.pytorch/dassl/modeling/ops/reverse_grad.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Function 3 | 4 | 5 | class _ReverseGrad(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input, grad_scaling): 9 | ctx.grad_scaling = grad_scaling 10 | return input.view_as(input) 11 | 12 | @staticmethod 13 | def backward(ctx, grad_output): 14 | grad_scaling = ctx.grad_scaling 15 | return -grad_scaling * grad_output, None 16 | 17 | 18 | reverse_grad = _ReverseGrad.apply 19 | 20 | 21 | class ReverseGrad(nn.Module): 22 | """Gradient reversal layer. 23 | 24 | It acts as an identity layer in the forward, 25 | but reverses the sign of the gradient in 26 | the backward. 27 | """ 28 | 29 | def forward(self, x, grad_scaling=1.0): 30 | assert (grad_scaling >= 31 | 0), "grad_scaling must be non-negative, " "but got {}".format( 32 | grad_scaling 33 | ) 34 | return reverse_grad(x, grad_scaling) 35 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/sequential2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Sequential2(nn.Sequential): 5 | """An alternative sequential container to nn.Sequential, 6 | which accepts an arbitrary number of input arguments. 7 | """ 8 | 9 | def forward(self, *inputs): 10 | for module in self._modules.values(): 11 | if isinstance(inputs, tuple): 12 | inputs = module(*inputs) 13 | else: 14 | inputs = module(inputs) 15 | return inputs 16 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/modeling/ops/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def sharpen_prob(p, temperature=2): 6 | """Sharpening probability with a temperature. 7 | 8 | Args: 9 | p (torch.Tensor): probability matrix (batch_size, n_classes) 10 | temperature (float): temperature. 11 | """ 12 | p = p.pow(temperature) 13 | return p / p.sum(1, keepdim=True) 14 | 15 | 16 | def reverse_index(data, label): 17 | """Reverse order.""" 18 | inv_idx = torch.arange(data.size(0) - 1, -1, -1).long() 19 | return data[inv_idx], label[inv_idx] 20 | 21 | 22 | def shuffle_index(data, label): 23 | """Shuffle order.""" 24 | rnd_idx = torch.randperm(data.shape[0]) 25 | return data[rnd_idx], label[rnd_idx] 26 | 27 | 28 | def create_onehot(label, num_classes): 29 | """Create one-hot tensor. 30 | 31 | We suggest using nn.functional.one_hot. 32 | 33 | Args: 34 | label (torch.Tensor): 1-D tensor. 35 | num_classes (int): number of classes. 36 | """ 37 | onehot = torch.zeros(label.shape[0], num_classes) 38 | return onehot.scatter(1, label.unsqueeze(1).data.cpu(), 1) 39 | 40 | 41 | def sigmoid_rampup(current, rampup_length): 42 | """Exponential rampup. 43 | 44 | Args: 45 | current (int): current step. 46 | rampup_length (int): maximum step. 47 | """ 48 | assert rampup_length > 0 49 | current = np.clip(current, 0.0, rampup_length) 50 | phase = 1.0 - current/rampup_length 51 | return float(np.exp(-5.0 * phase * phase)) 52 | 53 | 54 | def linear_rampup(current, rampup_length): 55 | """Linear rampup. 56 | 57 | Args: 58 | current (int): current step. 59 | rampup_length (int): maximum step. 
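    Example::

        >>> linear_rampup(5, 10)   # halfway through the ramp
        0.5
        >>> linear_rampup(20, 10)  # clipped beyond rampup_length
        1.0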
60 | """ 61 | assert rampup_length > 0 62 | ratio = np.clip(current / rampup_length, 0.0, 1.0) 63 | return float(ratio) 64 | 65 | 66 | def ema_model_update(model, ema_model, alpha): 67 | """Exponential moving average of model parameters. 68 | 69 | Args: 70 | model (nn.Module): model being trained. 71 | ema_model (nn.Module): ema of the model. 72 | alpha (float): ema decay rate. 73 | """ 74 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 75 | ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) 76 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizer import build_optimizer 2 | from .lr_scheduler import build_lr_scheduler 3 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import * 2 | from .logger import * 3 | from .meters import * 4 | from .registry import * 5 | from .torchtools import * 6 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import os.path as osp 5 | 6 | from .tools import mkdir_if_missing 7 | 8 | __all__ = ["Logger", "setup_logger"] 9 | 10 | 11 | class Logger: 12 | """Write console output to external text file. 13 | 14 | Imported from ``_ 15 | 16 | Args: 17 | fpath (str): directory to save logging file. 18 | 19 | Examples:: 20 | >>> import sys 21 | >>> import os.path as osp 22 | >>> save_dir = 'output/experiment-1' 23 | >>> log_name = 'train.log' 24 | >>> sys.stdout = Logger(osp.join(save_dir, log_name)) 25 | """ 26 | 27 | def __init__(self, fpath=None): 28 | self.console = sys.stdout 29 | self.file = None 30 | if fpath is not None: 31 | mkdir_if_missing(osp.dirname(fpath)) 32 | self.file = open(fpath, "w") 33 | 34 | def __del__(self): 35 | self.close() 36 | 37 | def __enter__(self): 38 | pass 39 | 40 | def __exit__(self, *args): 41 | self.close() 42 | 43 | def write(self, msg): 44 | self.console.write(msg) 45 | if self.file is not None: 46 | self.file.write(msg) 47 | 48 | def flush(self): 49 | self.console.flush() 50 | if self.file is not None: 51 | self.file.flush() 52 | os.fsync(self.file.fileno()) 53 | 54 | def close(self): 55 | self.console.close() 56 | if self.file is not None: 57 | self.file.close() 58 | 59 | 60 | def setup_logger(output=None): 61 | if output is None: 62 | return 63 | 64 | if output.endswith(".txt") or output.endswith(".log"): 65 | fpath = output 66 | else: 67 | fpath = osp.join(output, "log.txt") 68 | 69 | if osp.exists(fpath): 70 | # make sure the existing log file is not over-written 71 | fpath += time.strftime("-%Y-%m-%d-%H-%M-%S") 72 | 73 | sys.stdout = Logger(fpath) 74 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/utils/meters.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | 4 | __all__ = ["AverageMeter", "MetricMeter"] 5 | 6 | 7 | class AverageMeter: 8 | """Compute and store the average and current value. 9 | 10 | Examples:: 11 | >>> # 1. 
Initialize a meter to record loss 12 | >>> losses = AverageMeter() 13 | >>> # 2. Update meter after every mini-batch update 14 | >>> losses.update(loss_value, batch_size) 15 | """ 16 | 17 | def __init__(self, ema=False): 18 | """ 19 | Args: 20 | ema (bool, optional): apply exponential moving average. 21 | """ 22 | self.ema = ema 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | if isinstance(val, torch.Tensor): 33 | val = val.item() 34 | 35 | self.val = val 36 | self.sum += val * n 37 | self.count += n 38 | 39 | if self.ema: 40 | self.avg = self.avg * 0.9 + self.val * 0.1 41 | else: 42 | self.avg = self.sum / self.count 43 | 44 | 45 | class MetricMeter: 46 | """Store the average and current value for a set of metrics. 47 | 48 | Examples:: 49 | >>> # 1. Create an instance of MetricMeter 50 | >>> metric = MetricMeter() 51 | >>> # 2. Update using a dictionary as input 52 | >>> input_dict = {'loss_1': value_1, 'loss_2': value_2} 53 | >>> metric.update(input_dict) 54 | >>> # 3. Convert to string and print 55 | >>> print(str(metric)) 56 | """ 57 | 58 | def __init__(self, delimiter="\t"): 59 | self.meters = defaultdict(AverageMeter) 60 | self.delimiter = delimiter 61 | 62 | def update(self, input_dict): 63 | if input_dict is None: 64 | return 65 | 66 | if not isinstance(input_dict, dict): 67 | raise TypeError( 68 | "Input to MetricMeter.update() must be a dictionary" 69 | ) 70 | 71 | for k, v in input_dict.items(): 72 | if isinstance(v, torch.Tensor): 73 | v = v.item() 74 | self.meters[k].update(v) 75 | 76 | def __str__(self): 77 | output_str = [] 78 | for name, meter in self.meters.items(): 79 | output_str.append(f"{name} {meter.val:.4f} ({meter.avg:.4f})") 80 | return self.delimiter.join(output_str) 81 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/dassl/utils/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/facebookresearch/fvcore 3 | """ 4 | __all__ = ["Registry"] 5 | 6 | 7 | class Registry: 8 | """A registry providing name -> object mapping, to support 9 | custom modules. 10 | 11 | To create a registry (e.g. a backbone registry): 12 | 13 | .. code-block:: python 14 | 15 | BACKBONE_REGISTRY = Registry('BACKBONE') 16 | 17 | To register an object: 18 | 19 | .. code-block:: python 20 | 21 | @BACKBONE_REGISTRY.register() 22 | class MyBackbone(nn.Module): 23 | ... 24 | 25 | Or: 26 | 27 | .. 
code-block:: python 28 | 29 | BACKBONE_REGISTRY.register(MyBackbone) 30 | """ 31 | 32 | def __init__(self, name): 33 | self._name = name 34 | self._obj_map = dict() 35 | 36 | def _do_register(self, name, obj, force=False): 37 | if name in self._obj_map and not force: 38 | raise KeyError( 39 | 'An object named "{}" was already ' 40 | 'registered in "{}" registry'.format(name, self._name) 41 | ) 42 | 43 | self._obj_map[name] = obj 44 | 45 | def register(self, obj=None, force=False): 46 | if obj is None: 47 | # Used as a decorator 48 | def wrapper(fn_or_class): 49 | name = fn_or_class.__name__ 50 | self._do_register(name, fn_or_class, force=force) 51 | return fn_or_class 52 | 53 | return wrapper 54 | 55 | # Used as a function call 56 | name = obj.__name__ 57 | self._do_register(name, obj, force=force) 58 | 59 | def get(self, name): 60 | if name not in self._obj_map: 61 | raise KeyError( 62 | 'Object name "{}" does not exist ' 63 | 'in "{}" registry'.format(name, self._name) 64 | ) 65 | 66 | return self._obj_map[name] 67 | 68 | def registered_names(self): 69 | return list(self._obj_map.keys()) 70 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/datasets/da/cifar_stl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pprint as pp 3 | import os.path as osp 4 | from torchvision.datasets import STL10, CIFAR10 5 | 6 | from dassl.utils import mkdir_if_missing 7 | 8 | cifar_label2name = { 9 | 0: "airplane", 10 | 1: "car", # the original name was 'automobile' 11 | 2: "bird", 12 | 3: "cat", 13 | 4: "deer", 14 | 5: "dog", 15 | 6: "frog", # conflict class 16 | 7: "horse", 17 | 8: "ship", 18 | 9: "truck", 19 | } 20 | 21 | stl_label2name = { 22 | 0: "airplane", 23 | 1: "bird", 24 | 2: "car", 25 | 3: "cat", 26 | 4: "deer", 27 | 5: "dog", 28 | 6: "horse", 29 | 7: "monkey", # conflict class 30 | 8: "ship", 31 | 9: "truck", 32 | } 33 | 34 | new_name2label = { 35 | "airplane": 0, 36 | "bird": 1, 37 | "car": 2, 38 | "cat": 3, 39 | "deer": 4, 40 | "dog": 5, 41 | "horse": 6, 42 | "ship": 7, 43 | "truck": 8, 44 | } 45 | 46 | 47 | def extract_and_save_image(dataset, save_dir, discard, label2name): 48 | if osp.exists(save_dir): 49 | print('Folder "{}" already exists'.format(save_dir)) 50 | return 51 | 52 | print('Extracting images to "{}" ...'.format(save_dir)) 53 | mkdir_if_missing(save_dir) 54 | 55 | for i in range(len(dataset)): 56 | img, label = dataset[i] 57 | if label == discard: 58 | continue 59 | class_name = label2name[label] 60 | label_new = new_name2label[class_name] 61 | class_dir = osp.join( 62 | save_dir, 63 | str(label_new).zfill(3) + "_" + class_name 64 | ) 65 | mkdir_if_missing(class_dir) 66 | impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg") 67 | img.save(impath) 68 | 69 | 70 | def download_and_prepare(name, root, discarded_label, label2name): 71 | print("Dataset: {}".format(name)) 72 | print("Root: {}".format(root)) 73 | print("Old labels:") 74 | pp.pprint(label2name) 75 | print("Discarded label: {}".format(discarded_label)) 76 | print("New labels:") 77 | pp.pprint(new_name2label) 78 | 79 | if name == "cifar": 80 | train = CIFAR10(root, train=True, download=True) 81 | test = CIFAR10(root, train=False) 82 | else: 83 | train = STL10(root, split="train", download=True) 84 | test = STL10(root, split="test") 85 | 86 | train_dir = osp.join(root, name, "train") 87 | test_dir = osp.join(root, name, "test") 88 | 89 | extract_and_save_image(train, train_dir, discarded_label, label2name) 
90 | extract_and_save_image(test, test_dir, discarded_label, label2name) 91 | 92 | 93 | if __name__ == "__main__": 94 | download_and_prepare("cifar", sys.argv[1], 6, cifar_label2name) 95 | download_and_prepare("stl", sys.argv[1], 7, stl_label2name) 96 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/datasets/dg/cifar_c.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script 3 | - creates a folder named "cifar10_c" under the same directory as 'CIFAR-10-C' 4 | - extracts images from .npy files and save them as .jpg. 5 | """ 6 | import os 7 | import sys 8 | import numpy as np 9 | import os.path as osp 10 | from PIL import Image 11 | 12 | from dassl.utils import mkdir_if_missing 13 | 14 | 15 | def extract_and_save(images, labels, level, dst): 16 | # level denotes the corruption intensity level (0-based) 17 | assert 0 <= level <= 4 18 | 19 | for i in range(10000): 20 | real_i = i + level*10000 21 | im = Image.fromarray(images[real_i]) 22 | label = int(labels[real_i]) 23 | category_dir = osp.join(dst, str(label).zfill(3)) 24 | mkdir_if_missing(category_dir) 25 | save_path = osp.join(category_dir, str(i + 1).zfill(5) + ".jpg") 26 | im.save(save_path) 27 | 28 | 29 | def main(npy_folder): 30 | npy_folder = osp.abspath(osp.expanduser(npy_folder)) 31 | dataset_cap = osp.basename(npy_folder) 32 | 33 | assert dataset_cap in ["CIFAR-10-C", "CIFAR-100-C"] 34 | 35 | if dataset_cap == "CIFAR-10-C": 36 | dataset = "cifar10_c" 37 | else: 38 | dataset = "cifar100_c" 39 | 40 | if not osp.exists(npy_folder): 41 | print('The given folder "{}" does not exist'.format(npy_folder)) 42 | 43 | root = osp.dirname(npy_folder) 44 | im_folder = osp.join(root, dataset) 45 | 46 | mkdir_if_missing(im_folder) 47 | 48 | dirnames = os.listdir(npy_folder) 49 | dirnames.remove("labels.npy") 50 | if "README.txt" in dirnames: 51 | dirnames.remove("README.txt") 52 | assert len(dirnames) == 19 53 | labels = np.load(osp.join(npy_folder, "labels.npy")) 54 | 55 | for dirname in dirnames: 56 | corruption = dirname.split(".")[0] 57 | corruption_folder = osp.join(im_folder, corruption) 58 | mkdir_if_missing(corruption_folder) 59 | 60 | npy_filename = osp.join(npy_folder, dirname) 61 | images = np.load(npy_filename) 62 | assert images.shape[0] == 50000 63 | 64 | for level in range(5): 65 | dst = osp.join(corruption_folder, str(level + 1)) 66 | mkdir_if_missing(dst) 67 | print('Saving images to "{}"'.format(dst)) 68 | extract_and_save(images, labels, level, dst) 69 | 70 | 71 | if __name__ == "__main__": 72 | # sys.argv[1] contains the path to CIFAR-10-C or CIFAR-100-C 73 | main(sys.argv[1]) 74 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/datasets/ssl/cifar10_cifar100_svhn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | from torchvision.datasets import SVHN, CIFAR10, CIFAR100 4 | 5 | from dassl.utils import mkdir_if_missing 6 | 7 | 8 | def extract_and_save_image(dataset, save_dir): 9 | if osp.exists(save_dir): 10 | print('Folder "{}" already exists'.format(save_dir)) 11 | return 12 | 13 | print('Extracting images to "{}" ...'.format(save_dir)) 14 | mkdir_if_missing(save_dir) 15 | 16 | for i in range(len(dataset)): 17 | img, label = dataset[i] 18 | class_dir = osp.join(save_dir, str(label).zfill(3)) 19 | mkdir_if_missing(class_dir) 20 | impath = osp.join(class_dir, str(i + 1).zfill(5) + ".jpg") 21 | 
img.save(impath) 22 | 23 | 24 | def download_and_prepare(name, root): 25 | print("Dataset: {}".format(name)) 26 | print("Root: {}".format(root)) 27 | 28 | if name == "cifar10": 29 | train = CIFAR10(root, train=True, download=True) 30 | test = CIFAR10(root, train=False) 31 | elif name == "cifar100": 32 | train = CIFAR100(root, train=True, download=True) 33 | test = CIFAR100(root, train=False) 34 | elif name == "svhn": 35 | train = SVHN(root, split="train", download=True) 36 | test = SVHN(root, split="test", download=True) 37 | else: 38 | raise ValueError 39 | 40 | train_dir = osp.join(root, name, "train") 41 | test_dir = osp.join(root, name, "test") 42 | 43 | extract_and_save_image(train, train_dir) 44 | extract_and_save_image(test, test_dir) 45 | 46 | 47 | if __name__ == "__main__": 48 | download_and_prepare("cifar10", sys.argv[1]) 49 | download_and_prepare("cifar100", sys.argv[1]) 50 | download_and_prepare("svhn", sys.argv[1]) 51 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/datasets/ssl/stl10.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | from torchvision.datasets import STL10 4 | 5 | from dassl.utils import mkdir_if_missing 6 | 7 | 8 | def extract_and_save_image(dataset, save_dir): 9 | if osp.exists(save_dir): 10 | print('Folder "{}" already exists'.format(save_dir)) 11 | return 12 | 13 | print('Extracting images to "{}" ...'.format(save_dir)) 14 | mkdir_if_missing(save_dir) 15 | 16 | for i in range(len(dataset)): 17 | img, label = dataset[i] 18 | if label == -1: 19 | label_name = "none" 20 | else: 21 | label_name = str(label) 22 | imname = str(i).zfill(6) + "_" + label_name + ".jpg" 23 | impath = osp.join(save_dir, imname) 24 | img.save(impath) 25 | 26 | 27 | def download_and_prepare(root): 28 | train = STL10(root, split="train", download=True) 29 | test = STL10(root, split="test") 30 | unlabeled = STL10(root, split="unlabeled") 31 | 32 | train_dir = osp.join(root, "train") 33 | test_dir = osp.join(root, "test") 34 | unlabeled_dir = osp.join(root, "unlabeled") 35 | 36 | extract_and_save_image(train, train_dir) 37 | extract_and_save_image(test, test_dir) 38 | extract_and_save_image(unlabeled, unlabeled_dir) 39 | 40 | 41 | if __name__ == "__main__": 42 | download_and_prepare(sys.argv[1]) 43 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/setup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | from setuptools import setup, find_packages 4 | 5 | 6 | def readme(): 7 | with open('README.md') as f: 8 | content = f.read() 9 | return content 10 | 11 | 12 | def find_version(): 13 | version_file = 'dassl/__init__.py' 14 | with open(version_file, 'r') as f: 15 | exec(compile(f.read(), version_file, 'exec')) 16 | return locals()['__version__'] 17 | 18 | 19 | def numpy_include(): 20 | try: 21 | numpy_include = np.get_include() 22 | except AttributeError: 23 | numpy_include = np.get_numpy_include() 24 | return numpy_include 25 | 26 | 27 | def get_requirements(filename='requirements.txt'): 28 | here = osp.dirname(osp.realpath(__file__)) 29 | with open(osp.join(here, filename), 'r') as f: 30 | requires = [line.replace('\n', '') for line in f.readlines()] 31 | return requires 32 | 33 | 34 | setup( 35 | name='dassl', 36 | version=find_version(), 37 | description='Dassl: Domain adaptation and semi-supervised learning', 38 | 
author='Kaiyang Zhou', 39 | license='MIT', 40 | long_description=readme(), 41 | url='https://github.com/KaiyangZhou/Dassl.pytorch', 42 | packages=find_packages(), 43 | install_requires=get_requirements(), 44 | keywords=[ 45 | 'Domain Adaptation', 'Domain Generalization', 46 | 'Semi-Supervised Learning', 'Pytorch' 47 | ] 48 | ) 49 | -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/CPL/Dassl.pytorch/tools/__init__.py -------------------------------------------------------------------------------- /CPL/Dassl.pytorch/tools/replace_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | Replace text in python files. 3 | """ 4 | import glob 5 | import os.path as osp 6 | import argparse 7 | import fileinput 8 | 9 | EXTENSION = ".py" 10 | 11 | 12 | def is_python_file(filename): 13 | ext = osp.splitext(filename)[1] 14 | return ext == EXTENSION 15 | 16 | 17 | def update_file(filename, text_to_search, replacement_text): 18 | print("Processing {}".format(filename)) 19 | with fileinput.FileInput(filename, inplace=True, backup="") as file: 20 | for line in file: 21 | print(line.replace(text_to_search, replacement_text), end="") 22 | 23 | 24 | def recursive_update(directory, text_to_search, replacement_text): 25 | filenames = glob.glob(osp.join(directory, "*")) 26 | 27 | for filename in filenames: 28 | if osp.isfile(filename): 29 | if not is_python_file(filename): 30 | continue 31 | update_file(filename, text_to_search, replacement_text) 32 | elif osp.isdir(filename): 33 | recursive_update(filename, text_to_search, replacement_text) 34 | else: 35 | raise NotImplementedError 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument( 41 | "file_or_dir", type=str, help="path to file or directory" 42 | ) 43 | parser.add_argument("text_to_search", type=str, help="name to be replaced") 44 | parser.add_argument("replacement_text", type=str, help="new name") 45 | parser.add_argument( 46 | "--ext", type=str, default=".py", help="file extension" 47 | ) 48 | args = parser.parse_args() 49 | 50 | file_or_dir = args.file_or_dir 51 | text_to_search = args.text_to_search 52 | replacement_text = args.replacement_text 53 | extension = args.ext 54 | 55 | global EXTENSION 56 | EXTENSION = extension 57 | 58 | if osp.isfile(file_or_dir): 59 | if not is_python_file(file_or_dir): 60 | return 61 | update_file(file_or_dir, text_to_search, replacement_text) 62 | elif osp.isdir(file_or_dir): 63 | recursive_update(file_or_dir, text_to_search, replacement_text) 64 | else: 65 | raise NotImplementedError 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /CPL/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /CPL/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/CPL/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /CPL/dataset_processing/README.md: 
-------------------------------------------------------------------------------- 1 | # Data processing 2 | - This directory organizes the scripts and files used to preprocess the downloaded datasets. 3 | - We share the processed data, prompts, and BERTScore matrix here https://drive.google.com/drive/folders/1kvFEQUdeZ5scjTWrLWdhE4BJIbzJpFbY?usp=sharing 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/README.md: -------------------------------------------------------------------------------- 1 | # Data preparation 2 | - Run `python split.py $n $a $path` to generate the train, val, and test split files for *n-shot* training with random seed *a*, using the processed prompts 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/T5_filtered_prompt.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3d2cd3d79ab89e9960f2ff992c2fc863f968a78cdbdfdf41441a49c27f2ee9d5 3 | size 136051955 4 | -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/all.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:138717ee6847b90ccbebbada9174fb64fd8fd1e15c318fbdf0d97c73e89dcf48 3 | size 39323824 4 | -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/answer_vocab.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4359fce43607051deca1a059ba53ebfd1c95a493dd4dc0becc29ebe97260d052 3 | size 58077 4 | -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/lib/vqaEvaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/CPL/dataset_processing/VQA/lib/vqaEvaluation/__init__.py -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/lib/vqaTools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/CPL/dataset_processing/VQA/lib/vqaTools/__init__.py -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/number_category.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6617f5e11649ded4b6ec9e490ffa23d06b932ab1a6d409d12657951b136fc077 3 | size 1006671 4 | -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/other_category.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:37b22087c3ae59d9b0b5d50bca6341c3dd510de679321a2d4919b347faf912c9 3 | size 21936028 4 | -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/split.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import csv 3 | import random 4
| import sys 5 | from tqdm import tqdm 6 | import json 7 | import logging 8 | import pathlib 9 | import itertools 10 | from collections import defaultdict 11 | 12 | #RANDOM_SEED = 1 13 | FEW_SHOT = float(sys.argv[1]) # n-shot: 0.5% 1% 3% 14 | RANDOM_SEED = int(sys.argv[2]) # set random seed 15 | # PATH = pathlib.Path(sys.argv[3]) 16 | PATH = pathlib.Path('./') 17 | 18 | random.seed(RANDOM_SEED) 19 | 20 | def json_prec_dump(data, prec=6): 21 | return json.dumps(json.loads(json.dumps(data), parse_float=lambda x: round(float(x), prec))) 22 | 23 | all_types = [] 24 | for json_file in PATH.iterdir(): 25 | if json_file.suffix != '.json': 26 | print(f'Ignoring file {json_file.name} by suffix.') 27 | continue 28 | if json_file.name == 'T5_filtered_prompt.json' or json_file.name == 'answer_vocab.json' or json_file.name == 'all.json': 29 | print(f'Ignoring file {json_file.name}.') 30 | continue 31 | # json_data = json.loads(json_file.read_text()) 32 | json_data = json.loads(json_file.read_text()) 33 | 34 | all_types.extend(json_data) 35 | 36 | b = json.dumps(all_types) 37 | f = open('all.json', 'w') 38 | f.write(b) 39 | f.close() 40 | 41 | df = pd.read_json("all.json") 42 | # df = pd.read_json("/data1/xh/vqa/karpathy_test.json") 43 | # df = pd.read_json("/data1/xh/vqa/val.json") 44 | number = int(len(df)*FEW_SHOT) 45 | print(len(df), number, FEW_SHOT, len(df)*FEW_SHOT) 46 | df_train = df[:-1000].sample(n=number, random_state=RANDOM_SEED) # down sample to n-shot 47 | print(len(df_train), df_train.keys(), df_train[0:2]) 48 | df_test = df[-1000:] 49 | print('len is', len(df_test), df_test[0:2]) 50 | 51 | 52 | 53 | 54 | train_imnames = df_train['img'].tolist() 55 | train_prompts = [] 56 | for i in range(len(df_train)): 57 | train_prompts.append(df_train['question'].tolist()[i].rstrip().lstrip() +' ' + df_train['prompt'].tolist()[i].rstrip().lstrip()) 58 | 59 | test_imnames = [] 60 | test_prompts = [] 61 | 62 | print("Creating data split with ", FEW_SHOT, " training shot...") 63 | 64 | for i in tqdm(range(len(df_test))): 65 | test_imnames.append(df_test['img'].tolist()[i]) 66 | test_prompts.append(df_test['question'].tolist()[i].rstrip().lstrip() +' ' + df_test['prompt'].tolist()[i].rstrip().lstrip()) 67 | 68 | textfile = open("train.txt", "w") 69 | for i in range(len(train_imnames)): 70 | textfile.write(train_imnames[i] + "*" + train_prompts[i] + "\n") 71 | for i in range(len(test_imnames)): 72 | textfile.write(test_imnames[i] + "*" + test_prompts[i] + "\n") 73 | textfile.close() 74 | 75 | textfile = open("test.txt", "w") 76 | for i in range(len(train_imnames)): 77 | textfile.write(train_imnames[i] + "*" + train_prompts[i] + "\n") 78 | for i in range(len(test_imnames)): 79 | textfile.write(test_imnames[i] + "*" + test_prompts[i] + "\n") 80 | textfile.close() 81 | 82 | textfile = open("val.txt", "w") 83 | textfile.close() 84 | 85 | textfile = open("classnames.txt", "w") 86 | for i in range(len(train_prompts)): 87 | textfile.write(train_prompts[i] + "\n") 88 | for i in range(len(test_prompts)): 89 | textfile.write(test_prompts[i] + "\n") 90 | textfile.close() 91 | -------------------------------------------------------------------------------- /CPL/dataset_processing/VQA/yesno_category.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:85bf88e8ca7c4229376f2aa1b7fb832dd29929670dbc4f3ee95d77ee21b8e0aa 3 | size 16381125 4 | --------------------------------------------------------------------------------
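Every `split.py` under `dataset_processing/` writes the same plain-text record format: one `<image>*<text>` pair per line, with `*` as the field delimiter (the dataset loaders in `CPL/datasets/` split on the same character). A minimal sketch of reading a generated file back, assuming a `train.txt` produced by the script above:

```python
from pathlib import Path


def read_pairs(split_file):
    """Parse a CPL split file where each line is '<image>*<text>'."""
    pairs = []
    for line in Path(split_file).read_text().splitlines():
        imname, _, text = line.partition("*")  # split at the first '*'
        pairs.append((imname, text))
    return pairs


pairs = read_pairs("train.txt")
print(len(pairs), "pairs; first:", pairs[0])
```

Since `*` doubles as the delimiter, a question or caption that itself contains `*` would corrupt its record; none of the split scripts guard against this.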
/CPL/dataset_processing/classification/README.md: -------------------------------------------------------------------------------- 1 | # Image Classification 2 | 3 | Image classification scripts take three input arguments: `DATASET_NAME`, `SEED`, and `SPLIT`. 4 | 5 | Below are examples of training and testing the model on ImageNet using random seed 1: 6 | 7 | ```bash 8 | # train on base (first half) and test on new (second half) - corresponds to Table 1 9 | bash scripts/cocoop/base2new_train.sh imagenet 1 base 10 | bash scripts/cocoop/base2new_test.sh imagenet 1 new 11 | 12 | # train on new (second half) and test on base (first half) - corresponds to Table 5, Split One 13 | bash scripts/cocoop/base2new_train.sh imagenet 1 new 14 | bash scripts/cocoop/base2new_test.sh imagenet 1 base 15 | 16 | # train on base (first half) and test on new (second half) with shuffled labels - corresponds to Table 6, Split Two 17 | bash scripts/cocoop/base2new_train.sh imagenet 1 split_two 18 | bash scripts/cocoop/base2new_test.sh imagenet 1 new 19 | ``` -------------------------------------------------------------------------------- /CPL/dataset_processing/flickr30k/README.md: -------------------------------------------------------------------------------- 1 | # Data preparation 2 | - dataset_flickr30k.json: Original Karpathy split. This needs to be downloaded from [Here](https://www.kaggle.com/datasets/shtvkumar/karpathy-splits?select=flickr30k.json) 3 | - Run `python split.py $n $a` to generate the following 4 split files for *n-shot* training using random seed *a* 4 | 5 | # Flickr30k 6 | 7 | - train.txt contains: 8 | - *n* image-text pairs for few-shot training 9 | - *n* random images sampled from the training set 10 | - The caption for each image is randomly selected from five golden captions, which will be used as the (fewshot) training example in [subsample_classes](../../datasets/oxford_pets.py) 11 | - 1000\*5 image-text pairs for unseen 12 | - 1000 unseen images and their corresponding five captions 13 | - This provides all (unseen) captions information, which will be used later as the label dictionary in [subsample_classes](../../datasets/oxford_pets.py) 14 | 15 | - test.txt contains: 16 | - *n* training pairs (testing on seen data is meaningless in our Image-Text Retrieval setting, yet we keep the training pairs here to maintain consistency with the data format of image classification) 17 | - Test set follows the original split 18 | 19 | - val.txt: empty file 20 | 21 | - classnames.txt: all captions (*n* + 1000\*5 in total) -------------------------------------------------------------------------------- /CPL/dataset_processing/flickr30k/dataset_flickr30k.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:db779ee2c5df40c25ff0eebeca464f6124b620cc4a9398650aea71fd1b53d4af 3 | size 38318553 4 | -------------------------------------------------------------------------------- /CPL/dataset_processing/flickr30k/split.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import csv 3 | import random 4 | import sys 5 | 6 | #RANDOM_SEED = 1 7 | FEW_SHOT = int(sys.argv[1]) # n-shot: 145 290 870 1450 8 | RANDOM_SEED = int(sys.argv[2]) # set random seed 9 | 10 | random.seed(RANDOM_SEED) 11 | 12 | df = pd.read_json("dataset_flickr30k.json") 13 | df_train = df[df.images.str['split'] == 'train'] 14 | df_train = df_train.sample(n=FEW_SHOT, random_state=RANDOM_SEED) # down sample
to n-shot 15 | df_test = df[df.images.str['split'] == 'test'] 16 | 17 | # n-shot train data 18 | train_imnames = df_train.images.str['filename'].tolist() 19 | train_captions = [] 20 | for i in range(len(df_train)): 21 | train_captions.append(df_train.images.str['sentences'].tolist()[i][random.randint(0, 4)]['raw']) 22 | 23 | # all test data 24 | test_imnames = [] 25 | test_captions = [] 26 | 27 | print("Creating data split with ", FEW_SHOT, " training shot...") 28 | 29 | for i in range(len(df_test)): 30 | for j in range(5): 31 | test_imnames.append(df_test.images.str['filename'].tolist()[i]) 32 | test_captions.append(df_test.images.str['sentences'].tolist()[i][j]['raw'][:265]) # truncate at char 265 33 | 34 | # n-shot train + 5000 test 35 | textfile = open("train.txt", "w") 36 | for i in range(len(train_imnames)): 37 | textfile.write(train_imnames[i] + "*" + train_captions[i] + "\n") 38 | for i in range(len(test_imnames)): 39 | textfile.write(test_imnames[i] + "*" + test_captions[i] + "\n") 40 | textfile.close() 41 | 42 | # n-shot train + 1000 test 43 | textfile = open("test.txt", "w") 44 | for i in range(len(train_imnames)): 45 | textfile.write(train_imnames[i] + "*" + train_captions[i] + "\n") 46 | for i in range(1000): 47 | textfile.write(test_imnames[i*5] + "*" + test_captions[i*5] + "\n") 48 | textfile.close() 49 | 50 | textfile = open("val.txt", "w") 51 | textfile.close() 52 | 53 | # file with all captions 54 | textfile = open("classnames.txt", "w") 55 | for i in range(len(train_captions)): 56 | textfile.write(train_captions[i] + "\n") 57 | for i in range(len(test_captions)): 58 | textfile.write(test_captions[i] + "\n") 59 | textfile.close() 60 | -------------------------------------------------------------------------------- /CPL/dataset_processing/flickr8k/README.md: -------------------------------------------------------------------------------- 1 | # split.py 2 | Run `python split.py $n` to generate split files for n-shot training.
3 | 4 | 5 | 6 | # Flickr8k 7 | 8 | - train.txt: 9 | - n image-text pairs for few-shot training (fewshot) 10 | - n random images sampled from the training set in the Karpathy split 11 | - The caption for each image is randomly selected from five golden captions 12 | - 1000\*5 pairs for testing (unseen) 13 | - Testing set from the Karpathy split 14 | 15 | 16 | 17 | - test.txt: 18 | - Same as (fewshot) in train.txt 19 | - 1000 image-text pairs for testing 20 | 21 | - val.txt: empty file 22 | 23 | - classnames.txt: all captions (n + 1000\*5 in total) 24 | 25 | 26 | - dataset_flickr8k.json: Original Karpathy split 27 | 28 | -------------------------------------------------------------------------------- /CPL/dataset_processing/flickr8k/split.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import csv 3 | import random 4 | import sys 5 | 6 | RANDOM_SEED = 1 7 | FEW_SHOT = int(sys.argv[1]) # n-shot: 30 60 180 300 8 | 9 | random.seed(RANDOM_SEED) 10 | 11 | df = pd.read_json("dataset_flickr8k.json") 12 | df_train = df[df.images.str['split'] == 'train'] 13 | df_train = df_train.sample(n=FEW_SHOT, random_state=RANDOM_SEED) # down sample to n-shot 14 | df_test = df[df.images.str['split'] == 'test'] 15 | 16 | # n-shot train data 17 | train_imnames = df_train.images.str['filename'].tolist() 18 | train_captions = [] 19 | for i in range(len(df_train)): 20 | train_captions.append(df_train.images.str['sentences'].tolist()[i][random.randint(0, 4)]['raw']) 21 | 22 | # all test data 23 | test_imnames = [] 24 | test_captions = [] 25 | 26 | print("Creating data split with ", FEW_SHOT, " training shot...") 27 | 28 | for i in range(len(df_test)): 29 | for j in range(5): 30 | test_imnames.append(df_test.images.str['filename'].tolist()[i]) 31 | test_captions.append(df_test.images.str['sentences'].tolist()[i][j]['raw']) 32 | 33 | # n-shot train + 5000 test 34 | textfile = open("train.txt", "w") 35 | for i in range(len(train_imnames)): 36 | textfile.write(train_imnames[i] + "*" + train_captions[i] + "\n") 37 | for i in range(len(test_imnames)): 38 | textfile.write(test_imnames[i] + "*" + test_captions[i] + "\n") 39 | textfile.close() 40 | 41 | # n-shot train + 1000 test 42 | textfile = open("test.txt", "w") 43 | for i in range(len(train_imnames)): 44 | textfile.write(train_imnames[i] + "*" + train_captions[i] + "\n") 45 | for i in range(1000): 46 | textfile.write(test_imnames[i*5] + "*" + test_captions[i*5] + "\n") 47 | textfile.close() 48 | 49 | textfile = open("val.txt", "w") 50 | textfile.close() 51 | 52 | # file with all captions 53 | textfile = open("classnames.txt", "w") 54 | for i in range(len(train_captions)): 55 | textfile.write(train_captions[i] + "\n") 56 | for i in range(len(test_captions)): 57 | textfile.write(test_captions[i] + "\n") 58 | textfile.close() -------------------------------------------------------------------------------- /CPL/dataset_processing/mscoco/README.md: -------------------------------------------------------------------------------- 1 | # Data preparation 2 | - dataset_coco.json: Original Karpathy split.
This needs to be downloaded from [Here](https://www.kaggle.com/datasets/shtvkumar/karpathy-splits?select=dataset_coco.json) 3 | 4 | - Run `python split.py $n $a` to generate the following 4 split files for *n-shot* training using random seed *a*: 5 | 6 | # MSCOCO 7 | - train.txt contains: 8 | - *n* image-text pairs for few-shot training 9 | - *n* random images sampled from training set 10 | - The caption for each image is randomly selected from five golden captions, which will be used as the (fewshot) training example in [subsample_classes](../../datasets/oxford_pets.py) 11 | - 5000\*5 image-text pairs for unseen 12 | - 5000 unseen images and their corresponding five captions 13 | - This provides all (unseen) captions information, which will be used later as the label dictionary in [subsample_classes](../../datasets/oxford_pets.py) 14 | - One sentence with more than 77 tokens has been shortened by hand 15 | 16 | - test.txt contains: 17 | - *n* training pairs (testing on seen data is meaningless in our Image-Text Retrieval setting, yet we keep the training pairs here to maintain consistency with the data format of image classification) 18 | - Test set follows the original split 19 | 20 | - val.txt: empty file 21 | 22 | - classnames.txt: all captions (*n* + 5000\*5 in total) -------------------------------------------------------------------------------- /CPL/dataset_processing/mscoco/split.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import csv 3 | import random 4 | import sys 5 | from tqdm import tqdm 6 | 7 | #RANDOM_SEED = 1 8 | FEW_SHOT = int(sys.argv[1]) # n-shot: 414 828 2483 4139 9 | RANDOM_SEED = int(sys.argv[2]) # set random seed 10 | 11 | random.seed(RANDOM_SEED) 12 | 13 | df = pd.read_json("dataset_coco.json") 14 | df_train = df[df.images.str['split'] == 'val'] 15 | df_train = df_train.sample(n=FEW_SHOT, random_state=RANDOM_SEED) # down sample to n-shot 16 | df_test = df[df.images.str['split'] == 'test'] 17 | 18 | # n-shot train data 19 | train_imnames = df_train.images.str['filename'].tolist() 20 | train_captions = [] 21 | for i in range(len(df_train)): 22 | train_captions.append(df_train.images.str['sentences'].tolist()[i][random.randint(0, 4)]['raw'].rstrip().lstrip()) 23 | 24 | # all test data 25 | test_imnames = [] 26 | test_captions = [] 27 | 28 | print("Creating data split with ", FEW_SHOT, " training shot...") 29 | 30 | for i in tqdm(range(len(df_test))): 31 | for j in range(5): 32 | test_imnames.append(df_test.images.str['filename'].tolist()[i]) 33 | test_captions.append(df_test.images.str['sentences'].tolist()[i][j]['raw'].rstrip().lstrip()) 34 | 35 | # n-shot train + 5000*5 test 36 | textfile = open("train.txt", "w") 37 | for i in range(len(train_imnames)): 38 | textfile.write(train_imnames[i] + "*" + train_captions[i] + "\n") 39 | for i in range(len(test_imnames)): 40 | textfile.write(test_imnames[i] + "*" + test_captions[i] + "\n") 41 | textfile.close() 42 | 43 | # n-shot train + 5000 test 44 | textfile = open("test.txt", "w") 45 | for i in range(len(train_imnames)): 46 | textfile.write(train_imnames[i] + "*" + train_captions[i] + "\n") 47 | for i in range(5000): 48 | textfile.write(test_imnames[i*5] + "*" + test_captions[i*5] + "\n") 49 | textfile.close() 50 | 51 | textfile = open("val.txt", "w") 52 | textfile.close() 53 | 54 | # file with all captions 55 | textfile = open("classnames.txt", "w") 56 | for i in range(len(train_captions)): 57 | textfile.write(train_captions[i] + "\n") 58 | for i in
range(len(test_captions)): 59 | textfile.write(test_captions[i] + "\n") 60 | textfile.close() 61 | -------------------------------------------------------------------------------- /CPL/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/CPL/datasets/__init__.py -------------------------------------------------------------------------------- /CPL/datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 11 | NEW_CNAMES = { 12 | "airplanes": "airplane", 13 | "Faces": "face", 14 | "Leopards": "leopard", 15 | "Motorbikes": "motorbike", 16 | } 17 | 18 | 19 | @DATASET_REGISTRY.register() 20 | class Caltech101(DatasetBase): 21 | 22 | dataset_dir = "caltech-101" 23 | 24 | def __init__(self, cfg): 25 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 26 | self.dataset_dir = os.path.join(root, self.dataset_dir) 27 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 28 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 29 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 30 | mkdir_if_missing(self.split_fewshot_dir) 31 | 32 | if os.path.exists(self.split_path): 33 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 34 | else: 35 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | -------------------------------------------------------------------------------- /CPL/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | NEW_CNAMES = { 11 | "AnnualCrop": "Annual Crop Land", 12 | "Forest": "Forest", 13 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 14 | "Highway": "Highway or Road", 15 | "Industrial": 
"Industrial Buildings", 16 | "Pasture": "Pasture Land", 17 | "PermanentCrop": "Permanent Crop Land", 18 | "Residential": "Residential Buildings", 19 | "River": "River", 20 | "SeaLake": "Sea or Lake", 21 | } 22 | 23 | 24 | @DATASET_REGISTRY.register() 25 | class EuroSAT(DatasetBase): 26 | 27 | dataset_dir = "eurosat" 28 | 29 | def __init__(self, cfg): 30 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 31 | self.dataset_dir = os.path.join(root, self.dataset_dir) 32 | self.image_dir = os.path.join(self.dataset_dir, "2750") 33 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 34 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 35 | mkdir_if_missing(self.split_fewshot_dir) 36 | 37 | if os.path.exists(self.split_path): 38 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 39 | else: 40 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 41 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 42 | 43 | num_shots = cfg.DATASET.NUM_SHOTS 44 | if num_shots >= 1: 45 | seed = cfg.SEED 46 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 47 | 48 | if os.path.exists(preprocessed): 49 | print(f"Loading preprocessed few-shot data from {preprocessed}") 50 | with open(preprocessed, "rb") as file: 51 | data = pickle.load(file) 52 | train, val = data["train"], data["val"] 53 | else: 54 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 55 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 56 | data = {"train": train, "val": val} 57 | print(f"Saving preprocessed few-shot data to {preprocessed}") 58 | with open(preprocessed, "wb") as file: 59 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 60 | 61 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 62 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 63 | 64 | super().__init__(train_x=train, val=val, test=test) 65 | 66 | def update_classname(self, dataset_old): 67 | dataset_new = [] 68 | for item_old in dataset_old: 69 | cname_old = item_old.classname 70 | cname_new = NEW_CLASSNAMES[cname_old] 71 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 72 | dataset_new.append(item_new) 73 | return dataset_new 74 | -------------------------------------------------------------------------------- /CPL/datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class FGVCAircraft(DatasetBase): 12 | 13 | dataset_dir = "fgvc_aircraft" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, 
"images_variant_train.txt") 30 | val = self.read_data(cname2lab, "images_variant_val.txt") 31 | test = self.read_data(cname2lab, "images_variant_test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, cname2lab, split_file): 57 | filepath = os.path.join(self.dataset_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip().split(" ") 64 | imname = line[0] + ".jpg" 65 | classname = " ".join(line[1:]) 66 | impath = os.path.join(self.image_dir, imname) 67 | label = cname2lab[classname] 68 | item = Datum(impath=impath, label=label, classname=classname) 69 | items.append(item) 70 | 71 | return items 72 | -------------------------------------------------------------------------------- /CPL/datasets/flickr30k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class flickr30k(DatasetBase): 12 | 13 | dataset_dir = "flickr30k" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "classnames.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "train.txt") 30 | val = self.read_data(cname2lab, "val.txt") 31 | test = self.read_data(cname2lab, "test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) # 1 shot always 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": 
val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | with open(os.path.join(self.dataset_dir, "test.txt"), "r") as f: 52 | lines = f.readlines() 53 | train_shot = (len(lines)) - 1000 54 | 55 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 56 | train, val, test = OxfordPets.subsample_classes(train, val, test, train_shot, subsample=subsample) # subsample = fewshot for training; = unseen for testing 57 | 58 | super().__init__(train_x=train, val=val, test=test) 59 | 60 | def read_data(self, cname2lab, split_file): 61 | filepath = os.path.join(self.dataset_dir, split_file) 62 | items = [] 63 | 64 | with open(filepath, "r") as f: 65 | lines = f.readlines() 66 | for line in lines: 67 | line = line.strip().split("*") 68 | imname = line[0] 69 | classname = " ".join(line[1:]) 70 | impath = os.path.join(self.image_dir, imname) 71 | 72 | label = cname2lab[classname] 73 | item = Datum(impath=impath, label=label, classname=classname) 74 | items.append(item) 75 | 76 | return items -------------------------------------------------------------------------------- /CPL/datasets/flickr8k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class flickr8k(DatasetBase): 12 | 13 | dataset_dir = "flickr8k" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "classnames.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "train.txt") 30 | val = self.read_data(cname2lab, "val.txt") 31 | test = self.read_data(cname2lab, "test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) # 1 shot always 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | with open(os.path.join(self.dataset_dir, "test.txt"), "r") as f: 52 | lines = f.readlines() 53 | train_shot = (len(lines)) - 1000 54 | 55 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 56 | train, val, test = OxfordPets.subsample_classes(train, val, test, train_shot, subsample=subsample) # subsample = fewshot for training; = unseen for testing 57 | 58 | 
super().__init__(train_x=train, val=val, test=test) 59 | 60 | def read_data(self, cname2lab, split_file): 61 | filepath = os.path.join(self.dataset_dir, split_file) 62 | items = [] 63 | 64 | with open(filepath, "r") as f: 65 | lines = f.readlines() 66 | for line in lines: 67 | line = line.strip().split("*") 68 | imname = line[0] 69 | classname = " ".join(line[1:]) 70 | impath = os.path.join(self.image_dir, imname) 71 | 72 | label = cname2lab[classname] 73 | item = Datum(impath=impath, label=label, classname=classname) 74 | items.append(item) 75 | 76 | return items -------------------------------------------------------------------------------- /CPL/datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Food101(DatasetBase): 13 | 14 | dataset_dir = "food-101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = DTD.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | -------------------------------------------------------------------------------- /CPL/datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetA(DatasetBase): 13 | """ImageNet-A(dversarial). 14 | 15 | This dataset is used for testing only. 
16 | """ 17 | 18 | dataset_dir = "imagenet-adversarial" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /CPL/datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetR(DatasetBase): 13 | """ImageNet-R(endition). 14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenet-rendition" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /CPL/datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetSketch(DatasetBase): 11 | """ImageNet-Sketch. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenet-sketch" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "images") 22 | 23 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 24 | classnames = ImageNet.read_classnames(text_file) 25 | 26 | data = self.read_data(classnames) 27 | 28 | super().__init__(train_x=data, test=data) 29 | 30 | def read_data(self, classnames): 31 | image_dir = self.image_dir 32 | folders = listdir_nohidden(image_dir, sort=True) 33 | items = [] 34 | 35 | for label, folder in enumerate(folders): 36 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 37 | classname = classnames[folder] 38 | for imname in imnames: 39 | impath = os.path.join(image_dir, folder, imname) 40 | item = Datum(impath=impath, label=label, classname=classname) 41 | items.append(item) 42 | 43 | return items 44 | -------------------------------------------------------------------------------- /CPL/datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetV2(DatasetBase): 11 | """ImageNetV2. 12 | 13 | This dataset is used for testing only. 14 | """ 15 | 16 | dataset_dir = "imagenetv2" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | image_dir = "imagenetv2-matched-frequency-format-val" 22 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 23 | 24 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 25 | classnames = ImageNet.read_classnames(text_file) 26 | 27 | data = self.read_data(classnames) 28 | 29 | super().__init__(train_x=data, test=data) 30 | 31 | def read_data(self, classnames): 32 | image_dir = self.image_dir 33 | folders = list(classnames.keys()) 34 | items = [] 35 | 36 | for label in range(1000): 37 | class_dir = os.path.join(image_dir, str(label)) 38 | imnames = listdir_nohidden(class_dir) 39 | folder = folders[label] 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(class_dir, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /CPL/datasets/mscoco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class mscoco(DatasetBase): 12 | 13 | dataset_dir = "mscoco" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "val2014") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "classnames.txt"), "r") as f: 24 | lines = f.readlines() 25 | 
for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "train.txt") 30 | val = self.read_data(cname2lab, "val.txt") 31 | test = self.read_data(cname2lab, "test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) # 1 shot always 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | with open(os.path.join(self.dataset_dir, "test.txt"), "r") as f: 52 | lines = f.readlines() 53 | train_shot = (len(lines)) - 5000 54 | 55 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 56 | train, val, test = OxfordPets.subsample_classes(train, val, test, train_shot, subsample=subsample) # subsample = fewshot for training; = unseen for testing 57 | 58 | super().__init__(train_x=train, val=val, test=test) 59 | 60 | def read_data(self, cname2lab, split_file): 61 | filepath = os.path.join(self.dataset_dir, split_file) 62 | items = [] 63 | 64 | with open(filepath, "r") as f: 65 | lines = f.readlines() 66 | for line in lines: 67 | line = line.strip().split("*") 68 | imname = line[0] 69 | classname = " ".join(line[1:]) 70 | impath = os.path.join(self.image_dir, imname) 71 | 72 | label = cname2lab[classname] 73 | item = Datum(impath=impath, label=label, classname=classname) 74 | items.append(item) 75 | 76 | return items -------------------------------------------------------------------------------- /CPL/datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from scipy.io import loadmat 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class StanfordCars(DatasetBase): 13 | 14 | dataset_dir = "stanford_cars" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 25 | else: 26 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 27 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 28 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 29 | trainval = self.read_data("cars_train", trainval_file, meta_file) 30 | test = self.read_data("cars_test", test_file, meta_file) 31 | train, val = OxfordPets.split_trainval(trainval) 32 | OxfordPets.save_split(train, val, test, self.split_path, 
self.dataset_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self, image_dir, anno_file, meta_file): 58 | anno_file = loadmat(anno_file)["annotations"][0] 59 | meta_file = loadmat(meta_file)["class_names"][0] 60 | items = [] 61 | 62 | for i in range(len(anno_file)): 63 | imname = anno_file[i]["fname"][0] 64 | impath = os.path.join(self.dataset_dir, image_dir, imname) 65 | label = anno_file[i]["class"][0, 0] 66 | label = int(label) - 1 # convert to 0-based index 67 | classname = meta_file[label][0] 68 | names = classname.split(" ") 69 | year = names.pop(-1) 70 | names.insert(0, year) 71 | classname = " ".join(names) 72 | item = Datum(impath=impath, label=label, classname=classname) 73 | items.append(item) 74 | 75 | return items 76 | -------------------------------------------------------------------------------- /CPL/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = "sun397" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 25 | else: 26 | classnames = [] 27 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | line = line.strip()[1:] # remove / 31 | classnames.append(line) 32 | cname2lab = {c: i for i, c in enumerate(classnames)} 33 | trainval = self.read_data(cname2lab, "Training_01.txt") 34 | test = self.read_data(cname2lab, "Testing_01.txt") 35 | print(f'cname2lab is, {cname2lab}') 36 | train, val = OxfordPets.split_trainval(trainval) 37 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 38 | 39 | num_shots = cfg.DATASET.NUM_SHOTS 40 | if num_shots >= 1: 41 | seed = cfg.SEED 42 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 43 | 44 | if 
os.path.exists(preprocessed): 45 | print(f"Loading preprocessed few-shot data from {preprocessed}") 46 | with open(preprocessed, "rb") as file: 47 | data = pickle.load(file) 48 | train, val = data["train"], data["val"] 49 | else: 50 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 51 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 52 | data = {"train": train, "val": val} 53 | print(f"Saving preprocessed few-shot data to {preprocessed}") 54 | with open(preprocessed, "wb") as file: 55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 56 | 57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 58 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 59 | 60 | 61 | super().__init__(train_x=train, val=val, test=test) 62 | 63 | def read_data(self, cname2lab, text_file): 64 | text_file = os.path.join(self.dataset_dir, text_file) 65 | items = [] 66 | 67 | with open(text_file, "r") as f: 68 | lines = f.readlines() 69 | for line in lines: 70 | imname = line.strip()[1:] # remove / 71 | classname = os.path.dirname(imname) 72 | label = cname2lab[classname] 73 | impath = os.path.join(self.image_dir, imname) 74 | 75 | names = classname.split("/")[1:] # remove 1st letter 76 | names = names[::-1] # put words like indoor/outdoor at first 77 | classname = " ".join(names) 78 | 79 | item = Datum(impath=impath, label=label, classname=classname) 80 | items.append(item) 81 | 82 | return items 83 | -------------------------------------------------------------------------------- /CPL/datasets/vqav2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class vqav2(DatasetBase): 12 | 13 | dataset_dir = "" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "classnames.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "train.txt") 30 | val = self.read_data(cname2lab, "val.txt") 31 | test = self.read_data(cname2lab, "test.txt") 32 | num_shots = cfg.DATASET.NUM_SHOTS 33 | 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) # 1 shot always 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | with 
open(os.path.join(self.dataset_dir, "test.txt"), "r") as f:
52 |             lines = f.readlines()
53 |         train_shot = len(lines) - 1000
54 | 
55 |         subsample = cfg.DATASET.SUBSAMPLE_CLASSES
56 |         train, val, test = OxfordPets.subsample_classes(train, val, test, train_shot, subsample=subsample)  # SUBSAMPLE_CLASSES is "fewshot" when training and "unseen" when testing
57 | 
58 |         super().__init__(train_x=train, val=val, test=test)
59 | 
60 |     def read_data(self, cname2lab, split_file):
61 |         filepath = os.path.join(self.dataset_dir, split_file)
62 |         items = []
63 | 
64 |         with open(filepath, "r") as f:
65 |             lines = f.readlines()
66 |             for line in lines:
67 |                 line = line.strip().split("*")  # each record is "<image filename>*<classname>"
68 |                 imname = line[0]
69 |                 classname = " ".join(line[1:])
70 |                 impath = os.path.join(self.image_dir, imname)
71 | 
72 |                 label = cname2lab[classname]
73 |                 item = Datum(impath=impath, label=label, classname=classname)
74 |                 items.append(item)
75 | 
76 |         return items
--------------------------------------------------------------------------------
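The `read_data` method above is shared verbatim by several of the loaders in this directory (flickr8k, mscoco, vqav2). It parses an ad-hoc annotation format: one record per line, with the image filename and the possibly multi-word classname joined by `*`. A minimal sketch of such a record and how it is parsed; the filename and caption below are hypothetical placeholders, not real entries from the released splits:

```python
# Hypothetical train.txt record: "<image filename>*<classname>".
raw = "COCO_val2014_000000000042.jpg*a man riding a horse"

parts = raw.strip().split("*")
imname = parts[0]                # "COCO_val2014_000000000042.jpg"
classname = " ".join(parts[1:])  # "a man riding a horse"
```

Note that because the remaining fields are re-joined with spaces, a literal `*` inside a classname would silently turn into a space.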
color="#D2ACA3") 77 | # 图例 78 | axes.legend(loc='best') 79 | # 设置坐标轴刻度、标签 80 | axes.set_xticks([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) 81 | # axes.set_yticks([160, 165, 170, 175, 180, 185, 190]) 82 | # axes.set_ylim((0.8, 1.0)) 83 | # axes.set_xticklabels(['zhouyi', 'xuweijia', 'lurenchi', 'chenxiao', 'weiyu', 'guhaiyao']) 84 | fontdict = {'fontsize': 6} 85 | axes.set_xticklabels(words, fontdict=fontdict) 86 | # 设置title 87 | # axes.set_title('NLP group members heights') 88 | # 网格线 89 | axes.grid(linewidth=0.5, which="major", axis='y') 90 | # 隐藏上、右边框 91 | axes.spines['top'].set_visible(False) 92 | axes.spines['right'].set_visible(False) 93 | 94 | # for i in range(6): 95 | # axes.text(x[i], a1[i], a1[i], ha='center', va='bottom') 96 | 97 | plt.tight_layout() 98 | plt.show() 99 | fig.savefig(f"sun_dist_{m}.png", dpi=800) 100 | 101 | elif ctx.dim() == 3: 102 | # Class-specific context 103 | raise NotImplementedError 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /CPL/lib/vqaEvaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/CPL/lib/vqaEvaluation/__init__.py -------------------------------------------------------------------------------- /CPL/lib/vqaTools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/CPL/lib/vqaTools/__init__.py -------------------------------------------------------------------------------- /CPL/linear_probing/README.md: -------------------------------------------------------------------------------- 1 | # Linear Probe CLIP 2 | 3 | To run linear probe baselines, make sure that your current working directory is `lpclip/`. 4 | 5 | Step 1: Extract Features using the CLIP Image Encoder 6 | ```bash 7 | sh feat_extractor.sh 8 | ``` 9 | 10 | Step 2: Train few-shot linear probe 11 | ```bash 12 | sh linear_probe.sh 13 | ``` 14 | 15 | We follow the instructions stated in the Appendix A3 (pp.38) of [the original CLIP paper](https://arxiv.org/pdf/2103.00020.pdf), with a careful hyperparameter sweep. 16 | 17 | Note: please pull the latest Dassl (version >= `606a2c6`). 
18 | -------------------------------------------------------------------------------- /CPL/prompt/infilled_template/number.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d997459230b146cc740f5596e35d7095d827e43ac8c97d253baa435060ea69be 3 | size 819255 4 | -------------------------------------------------------------------------------- /CPL/prompt/infilled_template/other.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2753ee87011e6d013de2b243e684452f065d2e610f75c033a1422e2481225b33 3 | size 18028828 4 | -------------------------------------------------------------------------------- /CPL/prompt/infilled_template/yesno.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f52935d21509dda2d097666a53d1b26e566576cf1c725d1aad5501875d30e7ba 3 | size 13159485 4 | -------------------------------------------------------------------------------- /CPL/prompt/number_template.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:533bf1ca629db2f0ef165dd44afea6819d85a9798dabf2c4d904b7ec9dc3203f 3 | size 680616 4 | -------------------------------------------------------------------------------- /CPL/prompt/other_T5filtered_answers.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3d2cd3d79ab89e9960f2ff992c2fc863f968a78cdbdfdf41441a49c27f2ee9d5 3 | size 136051955 4 | -------------------------------------------------------------------------------- /CPL/prompt/other_template.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fbc696aa80edb4d2d2ce216d4c37bb7f647f0ec72d4c4523bf4c58ec63a6206c 3 | size 19113600 4 | -------------------------------------------------------------------------------- /CPL/prompt/yesno_template.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:313b73a44d9f4fb4bf8e7cf2c3739694016b6eec515eda9d011cd3b905106cf1 3 | size 14980626 4 | -------------------------------------------------------------------------------- /CPL/scripts/.vscode-upload.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0379226f349d49084b56f3fd2c28ccc2691ca7d5fb96a150ece6a9cd7bd3c0e3 3 | size 383 4 | -------------------------------------------------------------------------------- /CPL/scripts/cocoop/.bash_logout: -------------------------------------------------------------------------------- 1 | # ~/.bash_logout: executed by bash(1) when login shell exits. 2 | 3 | # when leaving the console clear the screen to increase privacy 4 | 5 | if [ "$SHLVL" = 1 ]; then 6 | [ -x /usr/bin/clear_console ] && /usr/bin/clear_console -q 7 | fi 8 | -------------------------------------------------------------------------------- /CPL/scripts/zsclip/.bash_logout: -------------------------------------------------------------------------------- 1 | # ~/.bash_logout: executed by bash(1) when login shell exits. 
2 | 3 | # when leaving the console clear the screen to increase privacy 4 | 5 | if [ "$SHLVL" = 1 ]; then 6 | [ -x /usr/bin/clear_console ] && /usr/bin/clear_console -q 7 | fi 8 | -------------------------------------------------------------------------------- /CPL/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/CPL/trainers/__init__.py -------------------------------------------------------------------------------- /CPL/trainers/imagenet_templates.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 2 | 3 | IMAGENET_TEMPLATES = [ 4 | "a bad photo of a {}.", 5 | "a photo of many {}.", 6 | "a sculpture of a {}.", 7 | "a photo of the hard to see {}.", 8 | "a low resolution photo of the {}.", 9 | "a rendering of a {}.", 10 | "graffiti of a {}.", 11 | "a bad photo of the {}.", 12 | "a cropped photo of the {}.", 13 | "a tattoo of a {}.", 14 | "the embroidered {}.", 15 | "a photo of a hard to see {}.", 16 | "a bright photo of a {}.", 17 | "a photo of a clean {}.", 18 | "a photo of a dirty {}.", 19 | "a dark photo of the {}.", 20 | "a drawing of a {}.", 21 | "a photo of my {}.", 22 | "the plastic {}.", 23 | "a photo of the cool {}.", 24 | "a close-up photo of a {}.", 25 | "a black and white photo of the {}.", 26 | "a painting of the {}.", 27 | "a painting of a {}.", 28 | "a pixelated photo of the {}.", 29 | "a sculpture of the {}.", 30 | "a bright photo of the {}.", 31 | "a cropped photo of a {}.", 32 | "a plastic {}.", 33 | "a photo of the dirty {}.", 34 | "a jpeg corrupted photo of a {}.", 35 | "a blurry photo of the {}.", 36 | "a photo of the {}.", 37 | "a good photo of the {}.", 38 | "a rendering of the {}.", 39 | "a {} in a video game.", 40 | "a photo of one {}.", 41 | "a doodle of a {}.", 42 | "a close-up photo of the {}.", 43 | "a photo of a {}.", 44 | "the origami {}.", 45 | "the {} in a video game.", 46 | "a sketch of a {}.", 47 | "a doodle of the {}.", 48 | "a origami {}.", 49 | "a low resolution photo of a {}.", 50 | "the toy {}.", 51 | "a rendition of the {}.", 52 | "a photo of the clean {}.", 53 | "a photo of a large {}.", 54 | "a rendition of a {}.", 55 | "a photo of a nice {}.", 56 | "a photo of a weird {}.", 57 | "a blurry photo of a {}.", 58 | "a cartoon {}.", 59 | "art of a {}.", 60 | "a sketch of the {}.", 61 | "a embroidered {}.", 62 | "a pixelated photo of a {}.", 63 | "itap of the {}.", 64 | "a jpeg corrupted photo of the {}.", 65 | "a good photo of a {}.", 66 | "a plushie {}.", 67 | "a photo of the nice {}.", 68 | "a photo of the small {}.", 69 | "a photo of the weird {}.", 70 | "the cartoon {}.", 71 | "art of the {}.", 72 | "a drawing of the {}.", 73 | "a photo of the large {}.", 74 | "a black and white photo of a {}.", 75 | "the plushie {}.", 76 | "a dark photo of a {}.", 77 | "itap of a {}.", 78 | "graffiti of the {}.", 79 | "a toy {}.", 80 | "itap of my {}.", 81 | "a photo of a cool {}.", 82 | "a photo of a small {}.", 83 | "a tattoo of the {}.", 84 | ] 85 | 86 | IMAGENET_TEMPLATES_SELECT = [ 87 | "itap of a {}.", 88 | "a bad photo of the {}.", 89 | "a origami {}.", 90 | "a photo of the large {}.", 91 | "a {} in a video game.", 92 | "art of the {}.", 93 | "a photo of the small {}.", 94 | ] 95 | -------------------------------------------------------------------------------- 
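`IMAGENET_TEMPLATES` above is the 80-prompt ensemble from CLIP's prompt-engineering notebook. As a reference, a minimal sketch of the standard way such templates are turned into zero-shot classifier weights (this mirrors common CLIP usage rather than a specific entry point of this repo; the import path assumes the script is run from `CPL/`):

```python
import torch
from clip import clip
from trainers.imagenet_templates import IMAGENET_TEMPLATES

model, _ = clip.load("RN50", device="cpu")

@torch.no_grad()
def zeroshot_weights(classnames):
    weights = []
    for name in classnames:
        texts = clip.tokenize([t.format(name) for t in IMAGENET_TEMPLATES])
        emb = model.encode_text(texts)
        emb = emb / emb.norm(dim=-1, keepdim=True)  # normalize each prompt embedding
        mean = emb.mean(dim=0)                      # average the 80 prompts per class
        weights.append(mean / mean.norm())          # re-normalize the ensemble
    return torch.stack(weights, dim=1)              # (embed_dim, num_classes)

w = zeroshot_weights(["goldfish", "tabby cat"])
```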
/CPL/trainers/prompters.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | 
5 | 
6 | class PadPrompter(nn.Module):  # learnable frame of pixels around the image border
7 |     def __init__(self, prompt_size, image_size):
8 |         super(PadPrompter, self).__init__()
9 |         pad_size = prompt_size
10 | 
11 |         self.base_size = image_size - pad_size*2
12 |         self.pad_up = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
13 |         self.pad_down = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
14 |         self.pad_left = nn.Parameter(torch.randn([1, 3, image_size - pad_size*2, pad_size]))
15 |         self.pad_right = nn.Parameter(torch.randn([1, 3, image_size - pad_size*2, pad_size]))
16 | 
17 |     def forward(self, x):
18 |         base = torch.zeros(1, 3, self.base_size, self.base_size).cuda()
19 |         prompt = torch.cat([self.pad_left, base, self.pad_right], dim=3).cuda()
20 |         prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=2).cuda()
21 |         prompt = torch.cat(x.size(0) * [prompt])  # repeat for the batch
22 | 
23 |         return x + prompt
24 | 
25 | 
26 | class FixedPatchPrompter(nn.Module):  # learnable patch fixed at the top-left corner
27 |     def __init__(self, image_size, prompt_size):
28 |         super(FixedPatchPrompter, self).__init__()
29 |         self.isize = image_size
30 |         self.psize = prompt_size
31 |         self.patch = nn.Parameter(torch.randn([1, 3, self.psize, self.psize]))
32 | 
33 |     def forward(self, x):
34 |         prompt = torch.zeros([1, 3, self.isize, self.isize]).cuda()
35 |         prompt[:, :, :self.psize, :self.psize] = self.patch
36 | 
37 |         return x + prompt
38 | 
39 | 
40 | class RandomPatchPrompter(nn.Module):  # learnable patch pasted at a random location
41 |     def __init__(self, args):
42 |         super(RandomPatchPrompter, self).__init__()
43 |         self.isize = args.image_size
44 |         self.psize = args.prompt_size
45 |         self.patch = nn.Parameter(torch.randn([1, 3, self.psize, self.psize]))
46 | 
47 |     def forward(self, x):
48 |         x_ = np.random.choice(self.isize - self.psize)
49 |         y_ = np.random.choice(self.isize - self.psize)
50 | 
51 |         prompt = torch.zeros([1, 3, self.isize, self.isize]).cuda()
52 |         prompt[:, :, x_:x_ + self.psize, y_:y_ + self.psize] = self.patch
53 | 
54 |         return x + prompt
55 | 
56 | 
57 | def padding(prompt_size, image_size):
58 |     return PadPrompter(prompt_size, image_size)
59 | 
60 | 
61 | def fixed_patch(args):
62 |     return FixedPatchPrompter(args.image_size, args.prompt_size)  # the class takes (image_size, prompt_size), not args
63 | 
64 | 
65 | def random_patch(args):
66 |     return RandomPatchPrompter(args)
--------------------------------------------------------------------------------
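All three prompters follow the same pattern: a module holding learnable pixels whose forward pass simply adds them to the input batch. A minimal usage sketch; note the forwards above allocate their buffers with `.cuda()`, so a GPU is assumed, and the import path assumes running from `CPL/`:

```python
import torch
from trainers.prompters import PadPrompter

prompter = PadPrompter(prompt_size=30, image_size=224).cuda()
images = torch.randn(8, 3, 224, 224).cuda()  # dummy normalized batch
prompted = prompter(images)                  # same shape, learned border added
assert prompted.shape == images.shape
```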
/CPL/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .configs import *
2 | from .utils import *
3 | from .solver import *
4 | from .comm import *
5 | from .visualization import *
6 | from ._utils import *
--------------------------------------------------------------------------------
/CPL/utils/comm.py:
--------------------------------------------------------------------------------
1 | import pickle  # needed by all_gather below
2 | 
3 | import torch
4 | import torch.distributed as dist
5 | 
6 | 
7 | class Comm(object):
8 |     def __init__(self):
9 |         self.local_rank = 0
10 | 
11 |     @property
12 |     def world_size(self):
13 |         if not dist.is_available():
14 |             return 1
15 |         if not dist.is_initialized():
16 |             return 1
17 |         return dist.get_world_size()
18 | 
19 |     @property
20 |     def rank(self):
21 |         if not dist.is_available():
22 |             return 0
23 |         if not dist.is_initialized():
24 |             return 0
25 |         return dist.get_rank()
26 | 
27 |     @property
28 |     def local_rank(self):
29 |         if not dist.is_available():
30 |             return 0
31 |         if not dist.is_initialized():
32 |             return 0
33 |         return self._local_rank
34 | 
35 |     @local_rank.setter
36 |     def local_rank(self, value):
37 |         # the getter above already falls back to 0 when torch.distributed is
38 |         # unavailable or uninitialized, so the value can be stored as-is
39 |         self._local_rank = value
40 | 
41 |     @property
42 |     def head(self):
43 |         return 'Rank[{}/{}]'.format(self.rank, self.world_size)
44 | 
45 |     def is_main_process(self):
46 |         return self.rank == 0
47 | 
48 |     def synchronize(self):
49 |         if self.world_size == 1:
50 |             return
51 |         dist.barrier()
52 | 
53 | 
54 | comm = Comm()
55 | 
56 | 
57 | def all_gather(data):
58 |     world_size = comm.world_size
59 |     if world_size == 1:
60 |         return [data]
61 | 
62 |     # serialize the arbitrary Python object to a byte tensor
63 |     buffer = pickle.dumps(data)
64 |     storage = torch.ByteStorage.from_buffer(buffer)
65 |     tensor = torch.ByteTensor(storage).to("cuda")
66 | 
67 |     # exchange the byte size of each rank's payload
68 |     local_size = torch.LongTensor([tensor.numel()]).to("cuda")
69 |     size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
70 |     dist.all_gather(size_list, local_size)
71 |     size_list = [int(size.item()) for size in size_list]
72 |     max_size = max(size_list)
73 | 
74 |     # pad every payload to max_size, since all_gather requires equal shapes
75 |     tensor_list = []
76 |     for _ in size_list:
77 |         tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
78 |     if local_size != max_size:
79 |         padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
80 |         tensor = torch.cat((tensor, padding), dim=0)
81 |     dist.all_gather(tensor_list, tensor)
82 | 
83 |     data_list = []
84 |     for size, tensor in zip(size_list, tensor_list):
85 |         buffer = tensor.cpu().numpy().tobytes()[:size]
86 |         data_list.append(pickle.loads(buffer))
87 | 
88 |     return data_list
89 | 
90 | 
91 | def reduce_dict(input_dict, average=True):
92 |     world_size = comm.world_size
93 |     if world_size < 2:
94 |         return input_dict
95 |     with torch.no_grad():
96 |         names = []
97 |         values = []
98 |         # sort keys so that the reduction order is consistent across ranks
99 |         for k in sorted(input_dict.keys()):
100 |             names.append(k)
101 |             values.append(input_dict[k])
102 |         values = torch.stack(values, dim=0)
103 |         dist.reduce(values, dst=0)
104 |         if dist.get_rank() == 0 and average:
105 |             values /= world_size
106 |         reduced_dict = {k: v for k, v in zip(names, values)}
107 |     return reduced_dict
108 | 
109 | 
110 | def gather_tensors(tensor):
111 |     tensors_gather = [
112 |         torch.ones_like(tensor)
113 |         for _ in range(comm.world_size)
114 |     ]
115 | 
116 |     dist.all_gather(tensors_gather, tensor, async_op=False)
117 |     tensors_gather[comm.rank] = tensor  # keep the local copy that carries gradients
118 |     output = torch.cat(tensors_gather, dim=0)
119 |     return output
--------------------------------------------------------------------------------
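A minimal sketch of how these helpers fit into a training loop, assuming the process group has already been initialized (e.g. under `torchrun` with `dist.init_process_group`) and a CUDA device is available; the `utils.comm` import path follows the package layout above:

```python
import torch
from utils.comm import comm, reduce_dict

loss_dict = {"loss_ce": torch.tensor(0.73, device="cuda")}
averaged = reduce_dict(loss_dict)  # rank 0 receives the mean over all workers
if comm.is_main_process():
    print(comm.head, {k: v.item() for k, v in averaged.items()})
comm.synchronize()                 # barrier before the next step
```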
/CPL/utils/loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 | import zipfile
6 | import argparse
7 | from configs import *
8 | from clip import clip
9 | import glob
10 | 
11 | 
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--dataset', type=str, default="mscoco")
14 | parser.add_argument('--seed', type=int, default=1)
15 | parser.add_argument('--opts', type=str, default="['DATASET.NUM_SHOTS', '1', 'DATASET.SUBSAMPLE_CLASSES', 'fewshot']", help='Options')
16 | parser.add_argument('--weights_folder', type=str, default="../weights")
17 | parser.add_argument('--file_prefix', type=str, default='')
18 | parser.add_argument('--label', type=str, default=None)
19 | args = parser.parse_args()
20 | 
21 | 
22 | # refuse to overwrite an existing merged weight file
23 | assert not os.path.exists(f'weights/{args.dataset}_{args.seed}_{args.opts}.pth')
24 | 
25 | rootpath = './'
26 | save_path = f'{args.weights_folder}'
27 | merge_path = os.path.join(rootpath, save_path)
28 | file_filter = args.file_prefix + '*.pth'
29 | pth_path = os.path.join(merge_path, file_filter)
30 | 
31 | files = glob.glob(pth_path, recursive=True)
32 | 
33 | uv = torch.zeros(torch.load(files[0]).size())
34 | 
35 | # accumulate the non-negative part of every saved weight tensor
36 | for file in files:
37 |     utemp = torch.clamp(torch.load(file), min=0.0)
38 |     uv += utemp
39 | 
40 | torch.save(uv.clamp(min=0, max=1), f'{args.weights_folder}/{args.dataset}_{args.seed}_{args.opts}.pth')
--------------------------------------------------------------------------------
/CPL/utils/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 | from clip import clip
6 | from clip.model import CLIP, convert_weights
7 | 
8 | 
9 | class TextEncoder(nn.Module):
10 |     def __init__(self, clip_model):
11 |         super().__init__()
12 |         self.transformer = clip_model.transformer
13 |         self.positional_embedding = clip_model.positional_embedding
14 |         self.ln_final = clip_model.ln_final
15 |         self.text_projection = clip_model.text_projection
16 |         self.dtype = clip_model.dtype
17 | 
18 |     def forward(self, prompts, tokenized_prompts):
19 |         x = prompts + self.positional_embedding.type(self.dtype)
20 |         x = x.permute(1, 0, 2)  # NLD -> LND
21 |         x = self.transformer(x)
22 |         x = x.permute(1, 0, 2)  # LND -> NLD
23 |         x = self.ln_final(x).type(self.dtype)
24 |         # take features at the EOT token (highest token id in each sequence)
25 |         x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
26 |         return x
27 | 
28 | 
29 | def build_model(state_dict: dict):
30 |     vit = "visual.proj" in state_dict
31 | 
32 |     if vit:
33 |         vision_width = state_dict["visual.conv1.weight"].shape[0]
34 |         vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
35 |         vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
36 |         grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
37 |         image_resolution = vision_patch_size * grid_size
38 |     else:
39 |         counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
40 |         vision_layers = tuple(counts)
41 |         vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
42 |         output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
43 |         vision_patch_size = None
44 |         assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
45 |         image_resolution = output_width * 32
46 | 
47 |     embed_dim = state_dict["text_projection"].shape[1]
48 |     context_length = state_dict["positional_embedding"].shape[0]
49 |     vocab_size = state_dict["token_embedding.weight"].shape[0]
50 |     transformer_width = state_dict["ln_final.weight"].shape[0]
51 |     transformer_heads = transformer_width // 64
52 |     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
53 | 
54 |     model = CLIP(
55 |         embed_dim,
56 |         image_resolution, vision_layers, vision_width, vision_patch_size,
57 |         context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
58 |     )
59 | 
60 |     for key in ["input_resolution", "context_length", "vocab_size"]:
61 |         if key in state_dict:
62 |             del state_dict[key]
63 | 
64 |     convert_weights(model)
65 |     model.load_state_dict(state_dict)
66 |     return model.eval()
--------------------------------------------------------------------------------
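A minimal sketch of how a `TextEncoder` like the one above is driven in CoOp-style prompt learning: the prompt embeddings (here just the embedded tokens, in practice partly learned vectors) are passed alongside the tokenized prompts, whose `argmax` locates the EOT position. This is standard CLIP plumbing rather than CPL's exact training wiring, and the import path assumes running from `CPL/`:

```python
import torch
from clip import clip
from utils.model import TextEncoder

clip_model, _ = clip.load("RN50", device="cpu")
text_encoder = TextEncoder(clip_model)

tokenized = clip.tokenize(["a photo of a dog", "a photo of a cat"])
with torch.no_grad():
    embedded = clip_model.token_embedding(tokenized).type(clip_model.dtype)
    features = text_encoder(embedded, tokenized)  # (2, embed_dim) text features
```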
/CPL/utils/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | from pathlib import Path
6 | 
7 | import os
8 | import logging
9 | import time
10 | 
11 | from .comm import comm
12 | 
13 | 
14 | def setup_logger(final_output_dir, rank, phase):
15 |     time_str = time.strftime('%Y-%m-%d-%H-%M')
16 |     log_file = f'{phase}_{time_str}_rank{rank}.txt'
17 |     final_log_file = os.path.join(final_output_dir, log_file)
18 |     head = "%(asctime)-15s:[P:%(process)d]:" + comm.head + ' %(message)s'
19 |     logging.basicConfig(
20 |         filename=str(final_log_file), format=head
21 |     )
22 |     logger = logging.getLogger()
23 |     logger.setLevel(logging.INFO)
24 |     console = logging.StreamHandler()
25 |     console.setFormatter(
26 |         logging.Formatter(head)
27 |     )
28 |     logging.getLogger('').addHandler(console)
29 | 
30 | 
31 | def create_logger(cfg, phase='train'):
32 |     root_output_dir = Path(cfg.OUTPUT_DIR)
33 |     dataset = cfg.DATASET.DATASET
34 |     cfg_name = cfg.NAME
35 | 
36 |     final_output_dir = root_output_dir / dataset / cfg_name
37 | 
38 |     print('=> creating {} ...'.format(root_output_dir))
39 |     root_output_dir.mkdir(parents=True, exist_ok=True)
40 |     print('=> creating {} ...'.format(final_output_dir))
41 |     final_output_dir.mkdir(parents=True, exist_ok=True)
42 | 
43 |     print('=> setup logger ...')
44 |     setup_logger(final_output_dir, cfg.RANK, phase)
45 | 
46 |     return str(final_output_dir)
47 | 
--------------------------------------------------------------------------------
/CPL/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 | from PIL import Image
4 | import numpy as np
5 | import cv2  # used by show_cam_on_image below
6 | import matplotlib.pyplot as plt
7 | from torchvision.transforms import ToTensor, ToPILImage
8 | from utils import *
9 | from clip.utils import *
10 | 
11 | 
12 | def interpret(image, text, model, device, index=None):
13 |     logits_per_image, logits_per_text = model(image, text)
14 |     probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
15 |     if index is None:
16 |         index = np.argmax(logits_per_image.cpu().data.numpy(), axis=-1)
17 |     one_hot = np.zeros((1, logits_per_image.size()[-1]), dtype=np.float32)
18 |     one_hot[0, index] = 1
19 |     one_hot = torch.from_numpy(one_hot).requires_grad_(True)
20 |     one_hot = torch.sum(one_hot.cuda() * logits_per_image)
21 |     model.zero_grad()
22 |     one_hot.backward(retain_graph=True)
23 | 
24 |     # gradient-weighted attention rollout over the visual transformer blocks
25 |     image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
26 |     num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
27 |     R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
28 |     for blk in image_attn_blocks:
29 |         grad = blk.attn_grad
30 |         cam = blk.attn_probs
31 |         cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
32 |         grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
33 |         cam = grad * cam
34 |         cam = cam.clamp(min=0).mean(dim=0)
35 |         R += torch.matmul(cam, R)
36 |     R[0, 0] = 0
37 |     image_relevance = R[0, 1:]
38 | 
39 |     def show_cam_on_image(img, mask):
40 |         heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
41 |         heatmap = np.float32(heatmap) / 255
42 |         cam = heatmap + np.float32(img)
43 |         cam = cam / np.max(cam)
44 |         return cam
45 | 
46 |     image_relevance = image_relevance.reshape(1, 1, 7, 7)  # 7x7 patch grid
47 |     image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
48 |     image_relevance = image_relevance.reshape(224, 224).cuda().data.cpu().numpy()
49 |     image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
50 |     image = image[0].permute(1, 2, 0).data.cpu().numpy()
51 |     image = (image - image.min()) / (image.max() - image.min())
52 |     vis = show_cam_on_image(image, image_relevance)
53 |     vis = np.uint8(255 * vis)
54 |     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
55 | 
56 |     plt.imshow(vis)
57 |     plt.show()
58 | 
59 |     print("Label probs:", probs)
60 | 
61 | 
62 | def cfgen(x, nx, u, y, use_cuda=True):
63 |     # blend x and nx using the mixing weights u for class y
64 |     xp = (1-u[:,y]).t()*x + u[:,y].t()*nx
65 |     return xp, u[:,y]
66 | 
67 | 
68 | def visualize_feature(model, img, img_features, img2, img2_features, S):
69 |     for idx, idx2 in S:
70 |         col = idx % img_features.shape[1]
71 |         row = idx // img_features.shape[1]
72 |         inv_size1 = get_size_before_generation(img_features.shape[1], model)
73 |         window, window2 = compute_window(col, row, inv_size1)
74 |         inv_size2 = get_size_before_generation(inv_size1, model)
75 |         window0 = [window[0]*2, (window[1]*2)+2, window[2]*2, (window[3]*2)+2]
76 |         img = visualize(img, img2, img_features, img2_features, window, window2, window0)
77 |     ToPILImage()(img).save("visualization.png")
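A minimal sketch of invoking `interpret`. It assumes the bundled `clip` package records `attn_probs` and `attn_grad` on the visual attention blocks (vanilla OpenAI CLIP does not), and a ViT-B/32 backbone, whose 224x224 input yields the 7x7 patch grid hard-coded in the reshape above; the image path is a placeholder:

```python
import torch
from PIL import Image
from clip import clip
from utils.visualization import interpret

device = "cuda"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder path
text = clip.tokenize(["a man riding a horse"]).to(device)

interpret(image, text, model, device)  # shows the relevance heatmap, prints label probs
```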
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Xuehai He
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/assets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/assets/.gitkeep
--------------------------------------------------------------------------------
/assets/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eric-ai-lab/CPL/ce47190b110ec6807081dd70571454b8c23368c4/assets/motivation.png
--------------------------------------------------------------------------------