├── .gitignore ├── README.md ├── config ├── A2C.yaml ├── A2P.yaml ├── A2S.yaml ├── C2A.yaml ├── C2P.yaml ├── C2S.yaml ├── P2A.yaml ├── P2C.yaml ├── P2S.yaml ├── S2A.yaml ├── S2C.yaml ├── S2P.yaml ├── art2clipart.yaml ├── art2product.yaml ├── art2realWorld.yaml ├── clipart2art.yaml ├── clipart2product.yaml ├── clipart2realWorld.yaml ├── mnist2mnist_m.yaml ├── mnist2svhn.yaml ├── mnist2syn.yaml ├── mnist_m2mnist.yaml ├── mnist_m2svhn.yaml ├── mnist_m2syn.yaml ├── product2art.yaml ├── product2clipart.yaml ├── product2realWorld.yaml ├── realWorld2art.yaml ├── realWorld2clipart.yaml ├── realWorld2product.yaml ├── svhn2mnist.yaml ├── svhn2mnist_m.yaml ├── svhn2syn.yaml ├── syn2mnist.yaml ├── syn2mnist_m.yaml └── syn2svhn.yaml ├── coteaching_trainer.py ├── main.py ├── main_web_ignored.py ├── main_web_refined.py ├── main_web_retained.py ├── mcd_trainer.py ├── models ├── __init__.py ├── build.py ├── cnn_digitsdg.py ├── mixstyle.py └── resnet.py ├── requirements.txt ├── script ├── digits.sh ├── officeHome.sh ├── officeHome_web_ignored.sh ├── officeHome_web_refined.sh ├── officeHome_web_retained.sh ├── pacs.sh ├── pacs_web_ignored.sh ├── pacs_web_refined.sh └── pacs_web_retained.sh ├── tools └── parse_test_res_single.py └── utils ├── __init__.py ├── dataset.py ├── evaluator.py ├── gaussian_blur.py ├── logger.py ├── meters.py ├── mixup.py ├── sampler.py ├── tools.py ├── torchtools.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | .idea/
163 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Semi-Supervised Domain Generalization with Evolving Intermediate Domain
2 |
3 | [paper](https://www.sciencedirect.com/science/article/pii/S0031320324000311)
4 |
5 | ## Requirements
6 |
7 | - numpy==1.19.2
8 | - Pillow==8.1.0
9 | - PyYAML==5.3.1
10 | - scipy==1.9.3
11 | - scikit_learn==1.1.1
12 | - six==1.15.0
13 | - torch==1.7.1
14 | - torchvision==0.8.2
15 | - Ubuntu==18.04
16 | - Python==3.8.5
17 |
18 | ## Installation
19 |
20 | Local install:
21 |
22 | ```
23 | git clone https://github.com/MetaVisionLab/SSDG.git
24 | cd SSDG
25 | pip install -r requirements.txt
26 | ```
27 |
28 | Or use the docker image ```pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel```.
29 |
30 | ## Data Preparation
31 |
32 | You can download the datasets [PACS](https://csip.fzu.edu.cn/files/datasets/SSDG/pacs.zip), [Digits-DG](https://csip.fzu.edu.cn/files/datasets/SSDG/digits_dg.zip), [Office-Home-DG](https://csip.fzu.edu.cn/files/datasets/SSDG/office_home_dg.zip), [web_pacs](https://csip.fzu.edu.cn/files/datasets/SSDG/web_pacs.zip) and [web_office](https://csip.fzu.edu.cn/files/datasets/SSDG/web_office.zip) to the folder SSDG/data and unzip them; the first three archives are the three datasets used in our paper, and the last two contain the corresponding web-crawled data.
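For example, to fetch and unpack one dataset (a minimal sketch; it assumes `wget` and `unzip` are available on your machine):

```
mkdir -p data && cd data
wget https://csip.fzu.edu.cn/files/datasets/SSDG/pacs.zip
unzip pacs.zip && rm pacs.zip
cd ..
```

Repeat for digits_dg.zip, office_home_dg.zip, web_pacs.zip and web_office.zip.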
33 |
34 | File structure:
35 |
36 | ```
37 | SSDG/data
38 | |--digits_dg/
39 |    |--images/
40 |    |--splits/
41 | |--office_home_dg/
42 |    |--images/
43 |    |--splits/
44 | |--pacs/
45 |    |--images/
46 |    |--splits/
47 | |--web_office/
48 | |--web_pacs/
49 | ```
50 |
51 | ## Training
52 |
53 | ```
54 | bash script/pacs.sh        # train pacs
55 | bash script/digits.sh      # train digits_dg
56 | bash script/officeHome.sh  # train office_home_dg
57 | ...
58 | ```
59 |
60 | ## Validation
61 |
62 | We evaluate our method on all the SSDG tasks of each dataset and report the average accuracy. For each task (e.g., A2C), we report the average accuracy over the last 5 epochs.
63 |
64 | ```
65 | cd tools/
66 | python parse_test_res_single.py log/UDAG_A2C.log --test-log
67 | ```
68 |
69 | ## Citation
70 |
71 | Please cite our paper:
72 |
73 | ```
74 | @article{lin2024semi,
75 |   title={Semi-supervised domain generalization with evolving intermediate domain},
76 |   author={Lin, Luojun and Xie, Han and Sun, Zhishu and Chen, Weijie and Liu, Wenxi and Yu, Yuanlong and Zhang, Lei},
77 |   journal={Pattern Recognition},
78 |   volume={149},
79 |   pages={110280},
80 |   year={2024},
81 |   publisher={Elsevier}
82 | }
83 | ```
84 | ## Contact us
85 |
86 | For any questions, please feel free to contact [Han Xie](mailto:han_xie@foxmail.com) or Dr. [Luojun Lin](mailto:linluojun2009@126.com).
87 |
88 | ## Copyright
89 |
90 | This code is free to the academic community for research purposes only. For commercial usage, please contact Dr. [Luojun Lin](mailto:linluojun2009@126.com).
91 |
-------------------------------------------------------------------------------- /config/A2C.yaml: --------------------------------------------------------------------------------
1 | data_transforms:
2 |   input_size: 224
3 |
4 | trainer:
5 |   n_step_f: 4 # MCD N_STEP_F
6 |   batch_size: 128
7 |   # max_epochs: 200
8 |   num_workers: 4
9 |   save_model_addr: 'checkpoints/A2C'
10 |   forget_rate: 0.5
11 |
12 | optimizer:
13 |   name: 'SGD'
14 |   params:
15 |     lr: 0.001
16 |     momentum: 0.9
17 |     weight_decay: 0.0005
18 |
19 | data:
20 |   type: 'pacs'
21 |   class_number: 7
22 |   root: 'data/pacs'
23 |   webroot: 'data/web_pacs'
24 |   source_domain_x: 'art_painting'
25 |   source_domain_u_1: 'sketch'
26 |   source_domain_u_2: 'photo'
27 |   target_domain: 'cartoon'
28 |
29 | log:
30 |   save_addr: 'log'
31 |   save_name: 'UDAG_A2C'
32 |
33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set
34 |
-------------------------------------------------------------------------------- /config/A2P.yaml: --------------------------------------------------------------------------------
1 | data_transforms:
2 |   input_size: 224
3 |
4 | trainer:
5 |   n_step_f: 4 # MCD N_STEP_F
6 |   batch_size: 128
7 |   # max_epochs: 200
8 |   num_workers: 4
9 |   save_model_addr: 'checkpoints/A2P'
10 |   forget_rate: 0.5
11 |
12 | optimizer:
13 |   name: 'SGD'
14 |   params:
15 |     lr: 0.001
16 |     momentum: 0.9
17 |     weight_decay: 0.0005
18 |
19 | data:
20 |   type: 'pacs'
21 |   class_number: 7
22 |   root: 'data/pacs'
23 |   webroot: 'data/web_pacs'
24 |   source_domain_x: 'art_painting'
25 |   source_domain_u_1: 'sketch'
26 |   source_domain_u_2: 'cartoon'
27 |   target_domain: 'photo'
28 |
29 | log:
30 |   save_addr: 'log'
31 |   save_name: 'UDAG_A2P'
32 |
33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set
34 |
-------------------------------------------------------------------------------- /config/A2S.yaml: --------------------------------------------------------------------------------
1 | data_transforms:
2 |   input_size:
224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/A2S' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'art_painting' 25 | source_domain_u_1: 'photo' 26 | source_domain_u_2: 'cartoon' 27 | target_domain: 'sketch' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_A2S' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/C2A.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/C2A' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'cartoon' 25 | source_domain_u_1: 'sketch' 26 | source_domain_u_2: 'photo' 27 | target_domain: 'art_painting' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_C2A' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/C2P.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/C2P' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'cartoon' 25 | source_domain_u_1: 'sketch' 26 | source_domain_u_2: 'art_painting' 27 | target_domain: 'photo' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_C2P' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/C2S.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/C2S' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'cartoon' 25 | source_domain_u_1: 'art_painting' 26 | source_domain_u_2: 'photo' 27 | target_domain: 'sketch' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_C2S' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/P2A.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 
224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/P2A' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'photo' 25 | source_domain_u_1: 'sketch' 26 | source_domain_u_2: 'cartoon' 27 | target_domain: 'art_painting' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_P2A' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /config/P2C.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/P2C' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'photo' 25 | source_domain_u_1: 'sketch' 26 | source_domain_u_2: 'art_painting' 27 | target_domain: 'cartoon' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_P2C' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /config/P2S.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/P2S' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'photo' 25 | source_domain_u_1: 'cartoon' 26 | source_domain_u_2: 'art_painting' 27 | target_domain: 'sketch' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_P2S' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /config/S2A.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/S2A' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'sketch' 25 | source_domain_u_1: 'cartoon' 26 | source_domain_u_2: 'photo' 27 | target_domain: 'art_painting' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_S2A' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/S2C.yaml: 
-------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/S2C' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'sketch' 25 | source_domain_u_1: 'photo' 26 | source_domain_u_2: 'art_painting' 27 | target_domain: 'cartoon' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_S2C' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/S2P.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/S2P' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'pacs' 21 | class_number: 7 22 | root: 'data/pacs' 23 | webroot: 'data/web_pacs' 24 | source_domain_x: 'sketch' 25 | source_domain_u_1: 'cartoon' 26 | source_domain_u_2: 'art_painting' 27 | target_domain: 'photo' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'UDAG_S2P' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/art2clipart.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/art2clipart' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'art' 25 | source_domain_u_1: 'product' 26 | source_domain_u_2: 'real_world' 27 | target_domain: 'clipart' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'art2clipart' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/art2product.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/art2product' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'art' 25 | source_domain_u_1: 'clipart' 26 | source_domain_u_2: 'real_world' 27 | target_domain: 'product' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'art2product' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | 
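# Note: every config in this folder shares the same schema (the following is a
# summary based on how main.py consumes these files):
#   source_domain_x      - the labeled source domain
#   source_domain_u_1/2  - the unlabeled source domains, used only via pseudo labels during training
#   target_domain        - the unseen domain reserved for testing
#   webroot              - root of the web-crawled data (PACS / Office-Home configs only)
#   ratio                - fraction of the agreed pseudo-labeled samples kept as the "clean" set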
-------------------------------------------------------------------------------- /config/art2realWorld.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/art2realWorld' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'art' 25 | source_domain_u_1: 'product' 26 | source_domain_u_2: 'clipart' 27 | target_domain: 'real_world' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'art2realWorld' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/clipart2art.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/clipart2art' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'clipart' 25 | source_domain_u_1: 'product' 26 | source_domain_u_2: 'real_world' 27 | target_domain: 'art' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'clipart2art' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/clipart2product.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/clipart2product' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'clipart' 25 | source_domain_u_1: 'art' 26 | source_domain_u_2: 'real_world' 27 | target_domain: 'product' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'clipart2product' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/clipart2realWorld.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/clipart2realWorld' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'clipart' 25 | source_domain_u_1: 'product' 26 | source_domain_u_2: 'art' 27 | 
target_domain: 'real_world' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'clipart2realWorld' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/mnist2mnist_m.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/mnist2mnist_m' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'mnist' 24 | source_domain_u_1: 'svhn' 25 | source_domain_u_2: 'syn' 26 | target_domain: 'mnist_m' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'mnist2mnist_m' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/mnist2svhn.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/mnist2svhn' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'mnist' 24 | source_domain_u_1: 'mnist_m' 25 | source_domain_u_2: 'syn' 26 | target_domain: 'svhn' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'mnist2svhn' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/mnist2syn.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/mnist2syn' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'mnist' 24 | source_domain_u_1: 'mnist_m' 25 | source_domain_u_2: 'svhn' 26 | target_domain: 'syn' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'mnist2syn' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/mnist_m2mnist.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/mnist_m2mnist' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'mnist_m' 24 | source_domain_u_1: 'svhn' 25 | source_domain_u_2: 'syn' 26 | target_domain: 'mnist' 27 | 28 | log: 29 | 
save_addr: 'log' 30 | save_name: 'mnist_m2mnist' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/mnist_m2svhn.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/mnist_m2svhn' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'mnist_m' 24 | source_domain_u_1: 'syn' 25 | source_domain_u_2: 'mnist' 26 | target_domain: 'svhn' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'mnist_m2svhn' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/mnist_m2syn.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/mnist_m2syn' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'mnist_m' 24 | source_domain_u_1: 'mnist' 25 | source_domain_u_2: 'svhn' 26 | target_domain: 'syn' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'mnist_m2syn' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/product2art.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/product2art' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'product' 25 | source_domain_u_1: 'clipart' 26 | source_domain_u_2: 'real_world' 27 | target_domain: 'art' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'product2art' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/product2clipart.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/product2clipart' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'product' 25 | source_domain_u_1: 'art' 26 | source_domain_u_2: 'real_world' 
27 | target_domain: 'clipart' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'product2clipart' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/product2realWorld.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/product2realWorld' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'product' 25 | source_domain_u_1: 'clipart' 26 | source_domain_u_2: 'art' 27 | target_domain: 'real_world' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'product2realWorld' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/realWorld2art.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/realWorld2art' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'real_world' 25 | source_domain_u_1: 'product' 26 | source_domain_u_2: 'clipart' 27 | target_domain: 'art' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'realWorld2art' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/realWorld2clipart.yaml: -------------------------------------------------------------------------------- 1 | 2 | data_transforms: 3 | input_size: 224 4 | 5 | trainer: 6 | n_step_f: 4 # MCD N_STEP_F 7 | batch_size: 128 8 | # max_epochs: 200 9 | num_workers: 4 10 | save_model_addr: 'checkpoints/realWorld2clipart' 11 | forget_rate: 0.5 12 | 13 | optimizer: 14 | name: 'SGD' 15 | params: 16 | lr: 0.001 17 | momentum: 0.9 18 | weight_decay: 0.0005 19 | 20 | data: 21 | type: 'officeHome' 22 | class_number: 65 23 | root: 'data/office_home_dg' 24 | webroot: 'data/web_office' 25 | source_domain_x: 'real_world' 26 | source_domain_u_1: 'product' 27 | source_domain_u_2: 'art' 28 | target_domain: 'clipart' 29 | 30 | log: 31 | save_addr: 'log' 32 | save_name: 'realWorld2clipart' 33 | 34 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 35 | -------------------------------------------------------------------------------- /config/realWorld2product.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 224 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/realWorld2product' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | 
type: 'officeHome' 21 | class_number: 65 22 | root: 'data/office_home_dg' 23 | webroot: 'data/web_office' 24 | source_domain_x: 'real_world' 25 | source_domain_u_1: 'clipart' 26 | source_domain_u_2: 'art' 27 | target_domain: 'product' 28 | 29 | log: 30 | save_addr: 'log' 31 | save_name: 'realWorld2product' 32 | 33 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 34 | -------------------------------------------------------------------------------- /config/svhn2mnist.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/svhn2mnist' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'svhn' 24 | source_domain_u_1: 'mnist_m' 25 | source_domain_u_2: 'syn' 26 | target_domain: 'mnist' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'svhn2mnist' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/svhn2mnist_m.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/svhn2mnist_m' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'svhn' 24 | source_domain_u_1: 'syn' 25 | source_domain_u_2: 'mnist' 26 | target_domain: 'mnist_m' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'svhn2mnist_m' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/svhn2syn.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/svhn2syn' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'svhn' 24 | source_domain_u_1: 'mnist' 25 | source_domain_u_2: 'mnist_m' 26 | target_domain: 'syn' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'svhn2syn' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/syn2mnist.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/syn2mnist' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | 
class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'syn' 24 | source_domain_u_1: 'mnist_m' 25 | source_domain_u_2: 'svhn' 26 | target_domain: 'mnist' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'syn2mnist' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/syn2mnist_m.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/syn2mnist_m' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'syn' 24 | source_domain_u_1: 'mnist' 25 | source_domain_u_2: 'svhn' 26 | target_domain: 'mnist_m' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'syn2mnist_m' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /config/syn2svhn.yaml: -------------------------------------------------------------------------------- 1 | data_transforms: 2 | input_size: 32 3 | 4 | trainer: 5 | n_step_f: 4 # MCD N_STEP_F 6 | batch_size: 128 7 | # max_epochs: 200 8 | num_workers: 4 9 | save_model_addr: 'checkpoints/syn2svhn' 10 | forget_rate: 0.5 11 | 12 | optimizer: 13 | name: 'SGD' 14 | params: 15 | lr: 0.001 16 | momentum: 0.9 17 | weight_decay: 0.0005 18 | 19 | data: 20 | type: 'digits' 21 | class_number: 10 22 | root: 'data/digits_dg' 23 | source_domain_x: 'syn' 24 | source_domain_u_1: 'mnist' 25 | source_domain_u_2: 'mnist_m' 26 | target_domain: 'svhn' 27 | 28 | log: 29 | save_addr: 'log' 30 | save_name: 'syn2svhn' 31 | 32 | ratio: 0.4 # a ratio to split samples into clean set and noisy set 33 | -------------------------------------------------------------------------------- /coteaching_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from utils import * 7 | from utils.dataset import * 8 | from utils.evaluator import Classification 9 | import numpy as np 10 | 11 | 12 | def discrepancy(y1, y2): 13 | return (y1 - y2).abs().mean() 14 | 15 | 16 | class CoTeachingTrainer: 17 | def __init__(self, models, optimizer, device, index, source_domain, **params): 18 | self.index = index 19 | self._models = models 20 | self.F1 = models["F1"] 21 | self.F2 = models["F2"] 22 | self.C1 = models["C1"] 23 | self.C2 = models["C2"] 24 | self._optims = optimizer 25 | self.source_domain = source_domain 26 | 27 | self.device = device 28 | self.batch_size = params['batch_size'] 29 | self.num_workers = params['num_workers'] 30 | self.batch_size = params['batch_size'] 31 | self.forget_rate = params['forget_rate'] 32 | self.folder = osp.join(params['save_model_addr'], f"coTeaching{index}") 33 | 34 | self.max_epoch = 0 35 | self.start_epoch = self.epoch = 0 36 | self.evaluator = Classification() 37 | mkdir_if_missing(self.folder) 38 | 39 | def get_model_names(self, names=None): 40 | names_real = list(self._models.keys()) 41 | if names is not None: 42 | names = tolist_if_not(names) 43 | for name in names: 44 | assert name in names_real 45 | return names 46 | 
else: 47 | return names_real 48 | 49 | def get_current_lr(self, names=None): 50 | names = self.get_model_names(names) 51 | name = names[0] 52 | return self._optims[name].param_groups[0]['lr'] 53 | 54 | def set_model_mode(self, mode='train', names=None): 55 | names = self.get_model_names(names) 56 | 57 | for name in names: 58 | if mode == 'train': 59 | self._models[name].train() 60 | else: 61 | self._models[name].eval() 62 | 63 | def model_zero_grad(self, names=None): 64 | names = self.get_model_names(names) 65 | for name in names: 66 | if self._optims[name] is not None: 67 | self._optims[name].zero_grad() 68 | 69 | def detect_anomaly(self, loss): 70 | if not torch.isfinite(loss).all(): 71 | raise FloatingPointError('Loss is infinite or NaN!') 72 | 73 | def model_backward(self, loss, retain_graph): 74 | self.detect_anomaly(loss) 75 | loss.backward(retain_graph=retain_graph) 76 | 77 | def model_update(self, names=None): 78 | names = self.get_model_names(names) 79 | for name in names: 80 | if self._optims[name] is not None: 81 | self._optims[name].step() 82 | 83 | def model_backward_and_update(self, loss, names=None, retain_graph=False): 84 | self.model_zero_grad(names) 85 | self.model_backward(loss, retain_graph) 86 | self.model_update(names) 87 | 88 | def save_model(self, epoch, directory, is_best=False, model_name=''): 89 | names = self.get_model_names() 90 | 91 | for name in names: 92 | model_dict = self._models[name].state_dict() 93 | 94 | optim_dict = None 95 | if self._optims[name] is not None: 96 | optim_dict = self._optims[name].state_dict() 97 | 98 | save_checkpoint( 99 | { 100 | 'state_dict': model_dict, 101 | 'epoch': epoch + 1, 102 | 'optimizer': optim_dict 103 | }, 104 | osp.join(directory, name), 105 | is_best=is_best, 106 | model_name=model_name 107 | ) 108 | 109 | def after_epoch(self, test_loader): 110 | if (self.epoch + 1) % 1 == 0: 111 | self.test(test_loader) 112 | last_epoch = (self.epoch + 1) == self.max_epoch 113 | if last_epoch: 114 | self.save_model(self.epoch, self.folder) 115 | 116 | def before_train(self): 117 | # Remember the starting time (for computing the elapsed time) 118 | self.time_start = time.time() 119 | 120 | def after_train(self): 121 | print(f'Finished CoTeaching_{self.index} training') 122 | 123 | # Show elapsed time 124 | elapsed = round(time.time() - self.time_start) 125 | elapsed = str(datetime.timedelta(seconds=elapsed)) 126 | print('Elapsed: {}'.format(elapsed)) 127 | 128 | def train_dg(self, train_loader, test_loader, max_epoch): 129 | self.max_epoch = max_epoch 130 | self.rate_schedule = np.ones(self.max_epoch) * self.forget_rate 131 | self.rate_schedule[:10] = np.linspace(0, self.forget_rate ** 1, 10) 132 | self.before_train() 133 | for self.epoch in range(self.start_epoch, self.max_epoch): 134 | self.run_epoch(train_loader) 135 | self.after_epoch(test_loader) 136 | self.after_train() 137 | 138 | def run_epoch(self, train_loader): 139 | self.set_model_mode('train') 140 | losses = MetricMeter() 141 | batch_time = AverageMeter() 142 | data_time = AverageMeter() 143 | 144 | # Decide to iterate over labeled or unlabeled dataset 145 | len_train_loader = len(train_loader) 146 | num_batches = len_train_loader 147 | 148 | train_loader_iter = iter(train_loader) 149 | 150 | end = time.time() 151 | 152 | for self.batch_idx in range(len_train_loader): 153 | try: 154 | batch_train = next(train_loader_iter) 155 | except StopIteration: 156 | train_loader_iter = iter(train_loader) 157 | batch_train = next(train_loader_iter) 158 | 159 | 
            data_time.update(time.time() - end)
160 |             loss_summary = self.forward_backward(batch_train)
161 |             batch_time.update(time.time() - end)
162 |             losses.update(loss_summary)
163 |
164 |             if (self.batch_idx + 1) % 10 == 0:
165 |                 nb_this_epoch = num_batches - (self.batch_idx + 1)
166 |                 nb_future_epochs = (
167 |                     self.max_epoch - (self.epoch + 1)
168 |                 ) * num_batches
169 |                 eta_seconds = batch_time.avg * (nb_this_epoch + nb_future_epochs)
170 |                 eta = str(datetime.timedelta(seconds=int(eta_seconds)))
171 |                 print(
172 |                     'epoch [{0}/{1}][{2}/{3}]\t'
173 |                     'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
174 |                     'data {data_time.val:.3f} ({data_time.avg:.3f})\t'
175 |                     'eta {eta}\t'
176 |                     '{losses}\t'
177 |                     'lr {lr}'.format(
178 |                         self.epoch + 1,
179 |                         self.max_epoch,
180 |                         self.batch_idx + 1,
181 |                         num_batches,
182 |                         batch_time=batch_time,
183 |                         data_time=data_time,
184 |                         eta=eta,
185 |                         losses=losses,
186 |                         lr=self.get_current_lr()
187 |                     )
188 |                 )
189 |
190 |             end = time.time()
191 |
192 |     def ELogLoss(self, probx):  # note: no negative sign is applied, so this returns -entropy
193 |         '''
194 |         probx : [B, X] class probabilities; returns [B, 1] holding sum_x p(x) * log p(x) per sample
195 |         '''
196 |         logx = torch.log(probx + 1e-5)
197 |         # [B, X] => [B, 1, X]
198 |         probx = probx.unsqueeze(1)
199 |         # [B, X] => [B, X, 1]
200 |         logx = logx.unsqueeze(2)
201 |         # [B, 1, 1]
202 |         res = torch.bmm(probx, logx)
203 |         # [B, 1]
204 |         return res.squeeze(2)
205 |
206 |     def forward_backward(self, batch):
207 |         parsed = self.parse_batch_train(batch)
208 |         input, label, paths = parsed
209 |
210 |         output1 = self.F1(input)
211 |         output1 = self.C1(output1)
212 |         output2 = self.F2(input)
213 |         output2 = self.C2(output2)
214 |         # split the batch by image path: samples from the labeled source domain keep ground-truth labels, the rest carry pseudo labels
215 |         domain_true = []
216 |         domain_pse = []
217 |         for i, path in enumerate(paths):
218 |             if path.find(self.source_domain) != -1:
219 |                 domain_true.append(i)
220 |             else:
221 |                 domain_pse.append(i)
222 |
223 |
224 |         domain_true = torch.tensor(domain_true).type(torch.long)
225 |         domain_pse = torch.tensor(domain_pse).type(torch.long)
226 |         output1_true = output1[domain_true]
227 |         output2_true = output2[domain_true]
228 |         label_true = label[domain_true]
229 |
230 |         output1_pse = output1[domain_pse]
231 |         output2_pse = output2[domain_pse]
232 |         label_pse = label[domain_pse]
233 |
234 |         loss_1_true = F.cross_entropy(output1_true, label_true)
235 |         loss_2_true = F.cross_entropy(output2_true, label_true)
236 |
237 |         loss_1 = F.cross_entropy(output1_pse, label_pse, reduction='none')
238 |         ind_1_sorted = np.argsort(loss_1.data.cpu())
239 |         loss_1_sorted = loss_1[ind_1_sorted]
240 |
241 |         loss_2 = F.cross_entropy(output2_pse, label_pse, reduction='none')
242 |         ind_2_sorted = np.argsort(loss_2.data.cpu())
243 |         loss_2_sorted = loss_2[ind_2_sorted]
244 |
245 |         remember_rate = 1 - self.rate_schedule[self.epoch]
246 |         num_remember = int(remember_rate * len(loss_1_sorted))
247 |
248 |         ind_1_update = ind_1_sorted[:num_remember]
249 |         ind_2_update = ind_2_sorted[:num_remember]
250 |
251 |         loss_1_update = F.cross_entropy(output1_pse[ind_2_update], label_pse[ind_2_update])  # swap: each network learns from its peer's small-loss samples
252 |         loss_2_update = F.cross_entropy(output2_pse[ind_1_update], label_pse[ind_1_update])
253 |
254 |         # information maximization loss: minimize per-sample entropy, maximize the entropy of the batch-mean prediction
255 |         pred1_pse = F.softmax(output1_pse, 1)
256 |         pred2_pse = F.softmax(output2_pse, 1)
257 |         pred1_pse_hat = pred1_pse.mean(dim=0).unsqueeze(0)
258 |         loss_1_im = -self.ELogLoss(pred1_pse).mean(0) + self.ELogLoss(pred1_pse_hat).sum()
259 |         pred2_pse_hat = pred2_pse.mean(dim=0).unsqueeze(0)
260 |         loss_2_im = -self.ELogLoss(pred2_pse).mean(0) + self.ELogLoss(pred2_pse_hat).sum()
261 |
262 |         if domain_true.shape != torch.Size([0]):  # the batch contains ground-truth-labeled samples
263 |             loss_1_final = loss_1_true + loss_1_update + loss_1_im
264 |             loss_2_final = loss_2_true + loss_2_update + loss_2_im
265 |         else:
266 |             loss_1_final = loss_1_update + loss_1_im
267 |             loss_2_final = loss_2_update + loss_2_im
268 |         self.model_backward_and_update(loss_1_final, ['F1', 'C1'])
269 |         self.model_backward_and_update(loss_2_final, ['F2', 'C2'])
270 |
271 |         loss_summary = {
272 |             'loss_model1_true': loss_1_true.item(),
273 |             'loss_model2_true': loss_2_true.item(),
274 |             'loss_model1_pse': loss_1_update.item(),
275 |             'loss_model2_pse': loss_2_update.item(),
276 |             'loss_1_final': loss_1_final.item(),
277 |             'loss_2_final': loss_2_final.item()
278 |         }
279 |
280 |         return loss_summary
281 |
282 |
283 |
284 |
285 |     def parse_batch_train(self, batch):
286 |         input = batch['img']
287 |         label = batch['label']
288 |         impath = batch['impath']
289 |
290 |         input = input.to(self.device)
291 |         label = label.to(self.device)
292 |
293 |
294 |         return input, label, impath
295 |
296 |
297 |     def parse_batch_test(self, batch):
298 |         input = batch['img']
299 |         label = batch['label']
300 |         impath = batch["impath"]
301 |
302 |         input = input.to(self.device)
303 |         label = label.to(self.device)
304 |
305 |
306 |         return input, label, impath
307 |
308 |     def model_inference(self, input):
309 |         feat = self.F1(input)
310 |         return self.C1(feat)
311 |
312 |     def load_model(self, directory, epoch=None):
313 |         names = self.get_model_names()
314 |
315 |         # By default, the best model is loaded
316 |         model_file = 'model-best.pth.tar'
317 |
318 |         if epoch is not None:
319 |             model_file = 'model.pth.tar-' + str(epoch)
320 |
321 |         for name in names:
322 |             model_path = osp.join(directory, name, model_file)
323 |
324 |             if not osp.exists(model_path):
325 |                 raise FileNotFoundError(
326 |                     'Model not found at "{}"'.format(model_path)
327 |                 )
328 |
329 |             checkpoint = load_checkpoint(model_path)
330 |             state_dict = checkpoint['state_dict']
331 |             epoch = checkpoint['epoch']
332 |
333 |             print(
334 |                 'Loading weights to {} '
335 |                 'from "{}" (epoch = {})'.format(name, model_path, epoch)
336 |             )
337 |             self._models[name].load_state_dict(state_dict)
338 |
339 |     def test(self, test_loader):
340 |         """A testing pipeline."""
341 |         self.set_model_mode('eval')
342 |         self.evaluator.reset()
343 |
344 |         for batch_idx, batch in enumerate(test_loader):
345 |             input, label, _ = self.parse_batch_test(batch)
346 |             output = self.model_inference(input)
347 |             self.evaluator.process(output, label)
348 |
349 |         results = self.evaluator.evaluate()
350 |
351 |         return results['accuracy']
352 |
353 |     def update_lr(self, index=1):
354 |         for key in self._optims:
355 |             opt = self._optims[key]
356 |             lr = opt.defaults["lr"] / index**2
357 |             opt.param_groups[0]['lr'] = lr
358 |
359 |     def get_pl(self, target_loader, ratio):
360 |         pl = []
361 |         self.set_model_mode('eval')
362 |         for batch_idx, batch in enumerate(target_loader):
363 |             input, label, impath = self.parse_batch_test(batch)
364 |             output1 = self.F1(input)
365 |             output1 = self.C1(output1)
366 |             output2 = self.F2(input)
367 |             output2 = self.C2(output2)
368 |             pred1 = F.softmax(output1, 1)  # [B, num_classes]
369 |             pred2 = F.softmax(output2, 1)  # [B, num_classes]
370 |             _, p_label_1 = torch.max(pred1, 1)  # hard labels of network 1, e.g. [1, 3, 3, 5]
371 |             _, p_label_2 = torch.max(pred2, 1)  # hard labels of network 2, e.g. [2, 3, 4, 5]
372 |             agree_id = torch.where(p_label_1 == p_label_2)[0].cpu().numpy()  # indices where the two networks agree
373 |             label = label[agree_id]  # keep only the agreed samples
374 |             output1 = output1[agree_id]
375 |             output2 = output2[agree_id]
376 |             loss = F.cross_entropy(output1, label, reduction='none')
377 |             ind_sorted = np.argsort(loss.data.cpu())
378 |             loss_sorted = loss[ind_sorted]
379 |             num_remember = int(ratio * len(loss_sorted))
380 |             ind_update_clean = ind_sorted[:num_remember]
381 |             ind_update_clean = agree_id[ind_update_clean]
382 |             ind_update_hard = ind_sorted[num_remember:]  # higher-loss (hard) split, currently unused
383 |             ind_update_hard = agree_id[ind_update_hard]
384 |             pred = pred1.max(1)[1]
385 |             for index, path in enumerate(impath):
386 |                 if index in ind_update_clean:
387 |                     pl.append((path, int(pred[index])))
388 |         return pl
389 |
390 |
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import yaml
4 | import numpy as np
5 | import os.path as osp
6 | from torchvision import datasets
7 | from utils.transforms import *
8 | from utils.dataset import *
9 | from utils import *
10 | from torch.utils.data.dataloader import DataLoader
11 | from models import models
12 | from mcd_trainer import MCDTrainer
13 | from coteaching_trainer import CoTeachingTrainer
14 | import random  # used by set_random_seed below
15 | from models.build import build_mcd_model, build_dg_model, build_digits_dg_model, build_digits_mcd_model
16 |
17 | def set_random_seed(seed):
18 |     random.seed(seed)
19 |     np.random.seed(seed)
20 |     torch.manual_seed(seed)
21 |     torch.cuda.manual_seed_all(seed)
22 |
23 |
24 | def build_mcd_optimizer(model, config):
25 |     if config['optimizer']['name'] == 'SGD':
26 |         opt_F = torch.optim.SGD(list(model["F"].parameters()),
27 |                                 **config['optimizer']['params'])
28 |         opt_C1 = torch.optim.SGD(list(model["C1"].parameters()),
29 |                                  **config['optimizer']['params'])
30 |         opt_C2 = torch.optim.SGD(list(model["C2"].parameters()),
31 |                                  **config['optimizer']['params'])
32 |     else:
33 |         opt_F = torch.optim.Adam(list(model["F"].parameters()),
34 |                                  **config['optimizer']['params'])
35 |         opt_C1 = torch.optim.Adam(list(model["C1"].parameters()),
36 |                                   **config['optimizer']['params'])
37 |         opt_C2 = torch.optim.Adam(list(model["C2"].parameters()),
38 |                                   **config['optimizer']['params'])
39 |
40 |     return {"F": opt_F, "C1": opt_C1, "C2": opt_C2}
41 |
42 | def build_dg_optimizer(model, config):
43 |     if config['optimizer']['name'] == 'SGD':
44 |         opt_F1 = torch.optim.SGD(list(model["F1"].parameters()),
45 |                                  **config['optimizer']['params'])
46 |         opt_F2 = torch.optim.SGD(list(model["F2"].parameters()),
47 |                                  **config['optimizer']['params'])
48 |         opt_C1 = torch.optim.SGD(list(model["C1"].parameters()),
49 |                                  **config['optimizer']['params'])
50 |         opt_C2 = torch.optim.SGD(list(model["C2"].parameters()),
51 |                                  **config['optimizer']['params'])
52 |     else:
53 |         opt_F1 = torch.optim.Adam(list(model["F1"].parameters()),
54 |                                   **config['optimizer']['params'])
55 |         opt_F2 = torch.optim.Adam(list(model["F2"].parameters()),
56 |                                   **config['optimizer']['params'])
57 |         opt_C1 = torch.optim.Adam(list(model["C1"].parameters()),
58 |                                   **config['optimizer']['params'])
59 |         opt_C2 = torch.optim.Adam(list(model["C2"].parameters()),
60 |                                   **config['optimizer']['params'])
61 |
62 |     return {"F1": opt_F1, "F2": opt_F2, "C1": opt_C1, "C2": opt_C2}
63 |
64 |
65 | def build_dataloader(data_list, batch_size, config, transform=None, istrain=True):
66 |     dataset = base_dataset(data_list, transform=transform)
67 |     return DataLoader(dataset, batch_size=batch_size,
68 |                       num_workers=config['trainer']['num_workers'],
69 |                       drop_last=istrain, shuffle=istrain)
70 |
71 |
72 | def main():
73 |     import argparse
74 |     parser = argparse.ArgumentParser()
75 |     parser.add_argument('--task', default='A2C', help='task name (a config file under ./config, e.g. A2C)')
76 |     args = parser.parse_args()
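    # Overview of the loop below: training alternates for 3 rounds between
    # (i) two MCD models, each adapting the labeled source domain (plus the
    #     pseudo labels kept from the previous round) to one unlabeled source
    #     domain, after which each pseudo-labels its unlabeled domain, and
    # (ii) one co-teaching DG model trained on the labeled source plus both
    #      pseudo-labeled sets, which then re-labels the unlabeled domains
    #      with a cleaner subset (controlled by `ratio`) for the next round.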
77 |
78 |     set_random_seed(1)
79 |
80 |     config = yaml.load(open("./config/" + args.task + ".yaml", "r"), Loader=yaml.FullLoader)
81 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
82 |
83 |     # logger info
84 |     logdir = osp.join(osp.dirname(__file__), config['log']['save_addr'], config['log']['save_name'] + '.log')
85 |     setup_logger(logdir)
86 |     print(config)
87 |
88 |     # network
89 |     mcd_model_1 = build_mcd_model(config).to(device)
90 |     mcd_model_2 = build_mcd_model(config).to(device)
91 |     dg_model = build_dg_model(config).to(device)
92 |
93 |     # optim
94 |     mcd_opt_1 = build_mcd_optimizer(mcd_model_1, config)
95 |     mcd_opt_2 = build_mcd_optimizer(mcd_model_2, config)
96 |     dg_opt = build_dg_optimizer(dg_model, config)
97 |
98 |
99 |     # data_list, transform
100 |     input_size = config['data_transforms']['input_size']
101 |     batch_size = config['trainer']['batch_size']
102 |     mcd_transform_test = simple_transform_test(input_size=input_size, type=config['data']['type'])
103 |     mcd_transform_train = simple_transform_train(input_size=input_size, type=config['data']['type'])
104 |     impath_label_x = get_image_dirs(root=config['data']['root'],
105 |                                     dname=config['data']['source_domain_x'],
106 |                                     split="train")  # source domain 1 (labeled)
107 |     impath_label_u_1 = get_image_dirs(root=config['data']['root'],
108 |                                       dname=config['data']['source_domain_u_1'],
109 |                                       split="train")  # source domain 2 (unlabeled)
110 |     impath_label_u_2 = get_image_dirs(root=config['data']['root'],
111 |                                       dname=config['data']['source_domain_u_2'],
112 |                                       split="train")  # source domain 3 (unlabeled)
113 |     impath_label_t = get_image_dirs(root=config['data']['root'],
114 |                                     dname=config['data']['target_domain'],
115 |                                     split="all")
116 |     fake_mcd_u_1 = []
117 |     fake_mcd_u_2 = []
118 |     fake_dg_u_1 = []
119 |     fake_dg_u_2 = []
120 |
121 |     # trainer
122 |     trainer_mcd_1 = MCDTrainer(mcd_model_1, mcd_opt_1, device, 1, **config['trainer'])
123 |     trainer_mcd_2 = MCDTrainer(mcd_model_2, mcd_opt_2, device, 2, **config['trainer'])
124 |
125 |     for index in range(3):
126 |         print(f"Round {index}: Training MCD.".center(100, "#"))
127 |         # dataloader
128 |         train_data_1 = impath_label_x + fake_dg_u_1
129 |         train_data_2 = impath_label_x + fake_dg_u_2
130 |         dataloader_x_1 = build_dataloader(train_data_1, batch_size, config, mcd_transform_train)
131 |         print('dataset dataloader_x_1: {}'.format(len(dataloader_x_1)))
132 |         dataloader_x_2 = build_dataloader(train_data_2, batch_size, config, mcd_transform_train)
133 |         print('dataset dataloader_x_2: {}'.format(len(dataloader_x_2)))
134 |         dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_train)
135 |         print('dataset dataloader_u_1: {}'.format(len(dataloader_u_1)))
136 |         dataloader_u_2 = build_dataloader(impath_label_u_2, batch_size, config, mcd_transform_train)
137 |         print('dataset dataloader_u_2: {}'.format(len(dataloader_u_2)))
138 |
139 |         # train
140 |         trainer_mcd_1.update_lr(index + 1)
141 |         trainer_mcd_2.update_lr(index + 1)
142 |         trainer_mcd_1.train_mcd(dataloader_x_1, dataloader_u_1, 30)
143 |         trainer_mcd_2.train_mcd(dataloader_x_2, dataloader_u_2, 30)
144 |         del dataloader_x_1, dataloader_x_2, dataloader_u_1, dataloader_u_2
145 |
146 |         # test dataloader.
147 | dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_test, False) 148 | dataloader_u_2 = build_dataloader(impath_label_u_2, batch_size, config, mcd_transform_test, False) 149 | # test 150 | print("test dataloader_u_1.".center(60, "#")) 151 | trainer_mcd_1.test(dataloader_u_1) 152 | print("test dataloader_u_2.".center(60, "#")) 153 | trainer_mcd_2.test(dataloader_u_2) 154 | # get pseudo label. 155 | print("get pseudo label.".center(60, "#")) 156 | fake_mcd_u_1 = trainer_mcd_1.get_pl(dataloader_u_1) 157 | fake_mcd_u_2 = trainer_mcd_2.get_pl(dataloader_u_2) 158 | del dataloader_u_1, dataloader_u_2 159 | 160 | # Train Co-teaching 161 | print("Training CO-Teaching.".center(100, "#")) 162 | train_data = impath_label_x + fake_mcd_u_1 + fake_mcd_u_2 163 | dg_transform_test = simple_transform_test(input_size=input_size, type = config['data']['type']) 164 | dg_transform_train = simple_transform_train(input_size=input_size, type = config['data']['type']) 165 | dataloader_train = build_dataloader(train_data, batch_size, config, dg_transform_train) 166 | print('dataset dg train dataloader: {}'.format(len(dataloader_train))) 167 | dataloader_test = build_dataloader(impath_label_t, batch_size, config, dg_transform_test) 168 | print('dataset dg test dataloader: {}'.format(len(dataloader_test))) 169 | 170 | source_domain = config['data']['source_domain_x'] 171 | trainer_dg = CoTeachingTrainer(dg_model, dg_opt, device, 1, source_domain, **config['trainer']) 172 | #train 173 | trainer_dg.update_lr(index + 1) 174 | trainer_dg.train_dg(dataloader_train,dataloader_test,15) 175 | 176 | #test dataloader 177 | dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_test, False) 178 | dataloader_u_2 = build_dataloader(impath_label_u_2, batch_size, config, mcd_transform_test, False) 179 | 180 | #get pseudo label 181 | ratio = config['ratio'] 182 | print("get dg pseudo label.".center(60, "#")) 183 | fake_dg_u_1 = trainer_dg.get_pl(dataloader_u_1,ratio = ratio) 184 | fake_dg_u_2 = trainer_dg.get_pl(dataloader_u_2,ratio = ratio) 185 | del dataloader_u_1, dataloader_u_2 186 | 187 | 188 | 189 | if __name__ == '__main__': 190 | main() 191 | -------------------------------------------------------------------------------- /main_web_ignored.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import yaml 4 | import numpy as np 5 | import os.path as osp 6 | from torchvision import datasets 7 | from utils.transforms import * 8 | from utils.dataset import * 9 | from utils import * 10 | from torch.utils.data.dataloader import DataLoader 11 | from models import models 12 | from mcd_trainer import MCDTrainer 13 | from coteaching_trainer import CoTeachingTrainer 14 | 15 | from models.build import build_mcd_model,build_dg_model,build_digits_dg_model,build_digits_mcd_model 16 | 17 | def set_random_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | def build_mcd_optimizer(model, config): 24 | if config['optimizer']['name'] == 'SGD': 25 | opt_F = torch.optim.SGD(list(model["F"].parameters()), 26 | **config['optimizer']['params']) 27 | opt_C1 = torch.optim.SGD(list(model["C1"].parameters()), 28 | **config['optimizer']['params']) 29 | opt_C2 = torch.optim.SGD(list(model["C2"].parameters()), 30 | **config['optimizer']['params']) 31 | else: 32 | opt_F = torch.optim.Adam(list(model["F"].parameters()), 33 | 
**config['optimizer']['params']) 34 | opt_C1 = torch.optim.Adam(list(model["C1"].parameters()), 35 | **config['optimizer']['params']) 36 | opt_C2 = torch.optim.Adam(list(model["C2"].parameters()), 37 | **config['optimizer']['params']) 38 | 39 | return {"F": opt_F, "C1": opt_C1, "C2": opt_C2} 40 | 41 | def build_dg_optimizer(model, config): 42 | if config['optimizer']['name'] == 'SGD': 43 | opt_F1 = torch.optim.SGD(list(model["F1"].parameters()), 44 | **config['optimizer']['params']) 45 | opt_F2 = torch.optim.SGD(list(model["F2"].parameters()), 46 | **config['optimizer']['params']) 47 | opt_C1 = torch.optim.SGD(list(model["C1"].parameters()), 48 | **config['optimizer']['params']) 49 | opt_C2 = torch.optim.SGD(list(model["C2"].parameters()), 50 | **config['optimizer']['params']) 51 | else: 52 | opt_F1 = torch.optim.Adam(list(model["F1"].parameters()), 53 | **config['optimizer']['params']) 54 | opt_F2 = torch.optim.Adam(list(model["F2"].parameters()), 55 | **config['optimizer']['params']) 56 | opt_C1 = torch.optim.Adam(list(model["C1"].parameters()), 57 | **config['optimizer']['params']) 58 | opt_C2 = torch.optim.Adam(list(model["C2"].parameters()), 59 | **config['optimizer']['params']) 60 | 61 | return {"F1": opt_F1,"F2": opt_F2, "C1": opt_C1, "C2": opt_C2 } 62 | 63 | 64 | def build_dataloader(data_list, batch_size, config, transform=None, istrain=True): 65 | dataset = base_dataset(data_list, transform=transform) 66 | return DataLoader(dataset, batch_size=batch_size, 67 | num_workers=config['trainer']['num_workers'], 68 | drop_last=istrain, shuffle=istrain) 69 | 70 | 71 | def main(): 72 | import argparse 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--task', default='art2clipart', help='task name') 75 | args = parser.parse_args() 76 | 77 | set_random_seed(1) 78 | 79 | config = yaml.load(open("./config/" + args.task + ".yaml", "r"), Loader=yaml.FullLoader) 80 | 81 | config['trainer']['save_model_addr'] = f"{config['trainer']['save_model_addr']}_ignored" 82 | config['log']['save_name'] = f"{config['log']['save_name']}_ignored" 83 | 84 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 85 | 86 | # logger info 87 | logdir = osp.join(osp.dirname(__file__), config['log']['save_addr'], config['log']['save_name'] + '.log') 88 | setup_logger(logdir) 89 | print(config) 90 | 91 | # network: build the backbone models 92 | mcd_model_1 = build_mcd_model(config).to(device) 93 | #mcd_model_2 = build_mcd_model(config).to(device) # networks for the two APL modules 94 | dg_model = build_dg_model(config).to(device) # network for the DCG stage 95 | 96 | # optim 97 | mcd_opt_1 = build_mcd_optimizer(mcd_model_1, config) 98 | #mcd_opt_2 = build_mcd_optimizer(mcd_model_2, config) 99 | dg_opt = build_dg_optimizer(dg_model, config) 100 | 101 | 102 | # data_list, transform 103 | input_size = config['data_transforms']['input_size'] 104 | batch_size = config['trainer']['batch_size'] 105 | 106 | # transform setup 107 | mcd_transform_test = simple_transform_test(input_size=input_size, type = config['data']['type']) 108 | mcd_transform_train = simple_transform_train(input_size=input_size, type = config['data']['type']) 109 | 110 | # load the data lists 111 | impath_label_x = get_image_dirs(root=config['data']['root'], 112 | dname=config['data']['source_domain_x'], 113 | split="train") # source domain 1 114 | impath_label_u_1 = get_web_image_dirs(root=config['data']['webroot'] 115 | ) # web data 116 | 117 | impath_label_t = get_image_dirs(root=config['data']['root'], 118 | dname=config['data']['target_domain'], 119 | split="train") 120 | 121 | fake_dg_u_1 = [] 122 | 123 | 124 | # trainer
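    # Single-branch variant: the crawled web images serve as the only unlabeled
    # auxiliary source, and their noisy web labels are not used for training
    # (they only feed the accuracy printout in test()), so one MCD trainer
    # suffices and the second branch from main.py remains commented out above.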
125 | trainer_mcd_1 = MCDTrainer(mcd_model_1, mcd_opt_1, device, 1, **config['trainer']) 126 | 127 | 128 | for index in range(3): 129 | print(f"Round {index}: Training MCD.".center(100, "#")) 130 | # dataloader 131 | train_data_1 = impath_label_x + fake_dg_u_1 132 | dataloader_x_1 = build_dataloader(train_data_1, batch_size, config, mcd_transform_train) 133 | print('dataset dataloader_x_1: {}'.format(len(dataloader_x_1))) 134 | dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_train) 135 | print('dataset dataloader_u_1: {}'.format(len(dataloader_u_1))) 136 | 137 | # train 138 | trainer_mcd_1.update_lr(index + 1) 139 | trainer_mcd_1.train_mcd(dataloader_x_1, dataloader_u_1, 30) 140 | del dataloader_x_1, dataloader_u_1 141 | 142 | # test dataloader. 143 | dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_test, False) 144 | # test 145 | print("test dataloader_u_1.".center(60, "#")) 146 | trainer_mcd_1.test(dataloader_u_1) 147 | # get pseudo label. 148 | print("get pseudo label.".center(60, "#")) # after MCD training, run inference on the unlabeled set to obtain pseudo labels 149 | fake_mcd_u_1 = trainer_mcd_1.get_pl(dataloader_u_1) 150 | del dataloader_u_1 151 | 152 | # Train Co-teaching 153 | print("Training CO-Teaching.".center(100, "#")) 154 | train_data = impath_label_x + fake_mcd_u_1 # each sample is an (image path, label) pair, i.e. one line of the list file 155 | dg_transform_test = simple_transform_test(input_size=input_size, type = config['data']['type']) 156 | dg_transform_train = simple_transform_train(input_size=input_size, type = config['data']['type']) 157 | dataloader_train = build_dataloader(train_data, batch_size, config, dg_transform_train) 158 | print('dataset dg train dataloader: {}'.format(len(dataloader_train))) 159 | dataloader_test = build_dataloader(impath_label_t, batch_size, config, dg_transform_test) 160 | print('dataset dg test dataloader: {}'.format(len(dataloader_test))) 161 | 162 | source_domain = config['data']['source_domain_x'] 163 | trainer_dg = CoTeachingTrainer(dg_model, dg_opt, device, 1, source_domain, **config['trainer']) 164 | #train 165 | trainer_dg.update_lr(index + 1) 166 | trainer_dg.train_dg(dataloader_train,dataloader_test,15) 167 | 168 | #test dataloader 169 | dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_test, False) 170 | 171 | #get pseudo label 172 | ratio = config['ratio'] 173 | print("get dg pseudo label.".center(60, "#")) 174 | fake_dg_u_1 = trainer_dg.get_pl(dataloader_u_1,ratio = ratio) 175 | del dataloader_u_1 176 | 177 | 178 | 179 | if __name__ == '__main__': 180 | main() 181 | -------------------------------------------------------------------------------- /main_web_refined.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import yaml 4 | import numpy as np 5 | import os.path as osp 6 | from torchvision import datasets 7 | from utils.transforms import * 8 | from utils.dataset import * 9 | from utils import * 10 | from torch.utils.data.dataloader import DataLoader 11 | from models import models 12 | from mcd_trainer import MCDTrainer 13 | from coteaching_trainer import CoTeachingTrainer 14 | 15 | from models.build import build_mcd_model,build_dg_model,build_digits_dg_model,build_digits_mcd_model 16 | 17 | def set_random_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | def build_mcd_optimizer(model, config): 24 | if config['optimizer']['name'] == 'SGD': 25 | opt_F =
torch.optim.SGD(list(model["F"].parameters()), 26 | **config['optimizer']['params']) 27 | opt_C1 = torch.optim.SGD(list(model["C1"].parameters()), 28 | **config['optimizer']['params']) 29 | opt_C2 = torch.optim.SGD(list(model["C2"].parameters()), 30 | **config['optimizer']['params']) 31 | else: 32 | opt_F = torch.optim.Adam(list(model["F"].parameters()), 33 | **config['optimizer']['params']) 34 | opt_C1 = torch.optim.Adam(list(model["C1"].parameters()), 35 | **config['optimizer']['params']) 36 | opt_C2 = torch.optim.Adam(list(model["C2"].parameters()), 37 | **config['optimizer']['params']) 38 | 39 | return {"F": opt_F, "C1": opt_C1, "C2": opt_C2} 40 | 41 | def build_dg_optimizer(model, config): 42 | if config['optimizer']['name'] == 'SGD': 43 | opt_F1 = torch.optim.SGD(list(model["F1"].parameters()), 44 | **config['optimizer']['params']) 45 | opt_F2 = torch.optim.SGD(list(model["F2"].parameters()), 46 | **config['optimizer']['params']) 47 | opt_C1 = torch.optim.SGD(list(model["C1"].parameters()), 48 | **config['optimizer']['params']) 49 | opt_C2 = torch.optim.SGD(list(model["C2"].parameters()), 50 | **config['optimizer']['params']) 51 | else: 52 | opt_F1 = torch.optim.Adam(list(model["F1"].parameters()), 53 | **config['optimizer']['params']) 54 | opt_F2 = torch.optim.Adam(list(model["F2"].parameters()), 55 | **config['optimizer']['params']) 56 | opt_C1 = torch.optim.Adam(list(model["C1"].parameters()), 57 | **config['optimizer']['params']) 58 | opt_C2 = torch.optim.Adam(list(model["C2"].parameters()), 59 | **config['optimizer']['params']) 60 | 61 | return {"F1": opt_F1,"F2": opt_F2, "C1": opt_C1, "C2": opt_C2 } 62 | 63 | 64 | def build_dataloader(data_list, batch_size, config, transform=None, istrain=True): 65 | dataset = base_dataset(data_list, transform=transform) 66 | return DataLoader(dataset, batch_size=batch_size, 67 | num_workers=config['trainer']['num_workers'], 68 | drop_last=istrain, shuffle=istrain) 69 | 70 | 71 | def main(): 72 | import argparse 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--task', default='art2clipart', help='task name') 75 | args = parser.parse_args() 76 | 77 | set_random_seed(1) 78 | 79 | config = yaml.load(open("./config/" + args.task + ".yaml", "r"), Loader=yaml.FullLoader) 80 | 81 | config['trainer']['save_model_addr'] = f"{config['trainer']['save_model_addr']}_refined" 82 | config['log']['save_name'] = f"{config['log']['save_name']}_refined" 83 | 84 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 85 | 86 | # logger info 87 | logdir = osp.join(osp.dirname(__file__), config['log']['save_addr'], config['log']['save_name'] + '.log') 88 | setup_logger(logdir) 89 | print(config) 90 | 91 | # network 建立网络模型,即骨干网络 92 | mcd_model_1 = build_mcd_model(config).to(device) 93 | #mcd_model_2 = build_mcd_model(config).to(device) #两个APL模块的网络 94 | dg_model = build_dg_model(config).to(device) #DCG过程的网络 95 | 96 | # optim 97 | mcd_opt_1 = build_mcd_optimizer(mcd_model_1, config) 98 | #mcd_opt_2 = build_mcd_optimizer(mcd_model_2, config) 99 | dg_opt = build_dg_optimizer(dg_model, config) 100 | 101 | 102 | # data_list, transform 103 | input_size = config['data_transforms']['input_size'] 104 | batch_size = config['trainer']['batch_size'] 105 | 106 | #transform设置 107 | mcd_transform_test = simple_transform_test(input_size=input_size, type = config['data']['type']) 108 | mcd_transform_train = simple_transform_train(input_size=input_size, type = config['data']['type']) 109 | 110 | #读取数据 111 | impath_label_x = 
get_image_dirs(root=config['data']['root'], 112 | dname=config['data']['source_domain_x'], 113 | split="train") # source domain 1 114 | impath_label_u_1 = get_web_image_dirs(root=config['data']['webroot']) # source domain 2, i.e. the web data 115 | 116 | impath_label_t = get_image_dirs(root=config['data']['root'], 117 | dname=config['data']['target_domain'], 118 | split="train") 119 | 120 | fake_dg_u_1 = [] 121 | 122 | 123 | # trainer 124 | trainer_mcd_1 = MCDTrainer(mcd_model_1, mcd_opt_1, device, 1, **config['trainer']) 125 | 126 | 127 | for index in range(3): 128 | 129 | if index == 0: 130 | print(f"Round {index}: Training MCD.".center(100, "#")) 131 | fake_mcd_u_1 = impath_label_u_1 132 | 133 | else: 134 | 135 | print(f"Round {index}: Training MCD.".center(100, "#")) 136 | # dataloader 137 | train_data_1 = impath_label_x + fake_dg_u_1 138 | dataloader_x_1 = build_dataloader(train_data_1, batch_size, config, mcd_transform_train) 139 | print('dataset dataloader_x_1: {}'.format(len(dataloader_x_1))) 140 | dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_train) 141 | print('dataset dataloader_u_1: {}'.format(len(dataloader_u_1))) 142 | 143 | # train 144 | trainer_mcd_1.update_lr(index + 1) 145 | trainer_mcd_1.train_mcd(dataloader_x_1, dataloader_u_1, 30) 146 | del dataloader_x_1, dataloader_u_1 147 | 148 | # test dataloader. 149 | dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_test, False) 150 | # test 151 | print("test dataloader_u_1.".center(60, "#")) 152 | trainer_mcd_1.test(dataloader_u_1) 153 | # get pseudo label. 154 | print("get pseudo label.".center(60, "#")) # after MCD training, run inference on the unlabeled set to obtain pseudo labels 155 | fake_mcd_u_1 = trainer_mcd_1.get_pl(dataloader_u_1) 156 | del dataloader_u_1 157 | 158 | # Train Co-teaching 159 | print("Training CO-Teaching.".center(100, "#")) 160 | train_data = impath_label_x + fake_mcd_u_1 # each sample is an (image path, label) pair, i.e. one line of the list file 161 | dg_transform_test = simple_transform_test(input_size=input_size, type = config['data']['type']) 162 | dg_transform_train = simple_transform_train(input_size=input_size, type = config['data']['type']) 163 | dataloader_train = build_dataloader(train_data, batch_size, config, dg_transform_train) 164 | print('dataset dg train dataloader: {}'.format(len(dataloader_train))) 165 | dataloader_test = build_dataloader(impath_label_t, batch_size, config, dg_transform_test) 166 | print('dataset dg test dataloader: {}'.format(len(dataloader_test))) 167 | 168 | source_domain = config['data']['source_domain_x'] 169 | trainer_dg = CoTeachingTrainer(dg_model, dg_opt, device, 1, source_domain, **config['trainer']) 170 | #train 171 | trainer_dg.update_lr(index + 1) 172 | trainer_dg.train_dg(dataloader_train,dataloader_test,15) 173 | 174 | #test dataloader 175 | dataloader_u_1 = build_dataloader(impath_label_u_1, batch_size, config, mcd_transform_test, False) 176 | 177 | #get pseudo label 178 | ratio = config['ratio'] 179 | print("get dg pseudo label.".center(60, "#")) 180 | fake_dg_u_1 = trainer_dg.get_pl(dataloader_u_1,ratio = ratio) 181 | del dataloader_u_1 182 | 183 | 184 | 185 | if __name__ == '__main__': 186 | main() 187 | -------------------------------------------------------------------------------- /main_web_retained.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import yaml 4 | import numpy as np 5 | import os.path as osp 6 | from torchvision import datasets 7 | from utils.transforms import * 8 | from utils.dataset import * 9 | from utils import * 10 | from
torch.utils.data.dataloader import DataLoader 11 | from models import models 12 | from mcd_trainer import MCDTrainer 13 | from coteaching_trainer import CoTeachingTrainer 14 | 15 | from models.build import build_mcd_model,build_dg_model,build_digits_dg_model,build_digits_mcd_model 16 | 17 | def set_random_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | def build_mcd_optimizer(model, config): 24 | if config['optimizer']['name'] == 'SGD': 25 | opt_F = torch.optim.SGD(list(model["F"].parameters()), 26 | **config['optimizer']['params']) 27 | opt_C1 = torch.optim.SGD(list(model["C1"].parameters()), 28 | **config['optimizer']['params']) 29 | opt_C2 = torch.optim.SGD(list(model["C2"].parameters()), 30 | **config['optimizer']['params']) 31 | else: 32 | opt_F = torch.optim.Adam(list(model["F"].parameters()), 33 | **config['optimizer']['params']) 34 | opt_C1 = torch.optim.Adam(list(model["C1"].parameters()), 35 | **config['optimizer']['params']) 36 | opt_C2 = torch.optim.Adam(list(model["C2"].parameters()), 37 | **config['optimizer']['params']) 38 | 39 | return {"F": opt_F, "C1": opt_C1, "C2": opt_C2} 40 | 41 | def build_dg_optimizer(model, config): 42 | if config['optimizer']['name'] == 'SGD': 43 | opt_F1 = torch.optim.SGD(list(model["F1"].parameters()), 44 | **config['optimizer']['params']) 45 | opt_F2 = torch.optim.SGD(list(model["F2"].parameters()), 46 | **config['optimizer']['params']) 47 | opt_C1 = torch.optim.SGD(list(model["C1"].parameters()), 48 | **config['optimizer']['params']) 49 | opt_C2 = torch.optim.SGD(list(model["C2"].parameters()), 50 | **config['optimizer']['params']) 51 | else: 52 | opt_F1 = torch.optim.Adam(list(model["F1"].parameters()), 53 | **config['optimizer']['params']) 54 | opt_F2 = torch.optim.Adam(list(model["F2"].parameters()), 55 | **config['optimizer']['params']) 56 | opt_C1 = torch.optim.Adam(list(model["C1"].parameters()), 57 | **config['optimizer']['params']) 58 | opt_C2 = torch.optim.Adam(list(model["C2"].parameters()), 59 | **config['optimizer']['params']) 60 | 61 | return {"F1": opt_F1,"F2": opt_F2, "C1": opt_C1, "C2": opt_C2 } 62 | 63 | 64 | def build_dataloader(data_list, batch_size, config, transform=None, istrain=True): 65 | dataset = base_dataset(data_list, transform=transform) 66 | return DataLoader(dataset, batch_size=batch_size, 67 | num_workers=config['trainer']['num_workers'], 68 | drop_last=istrain, shuffle=istrain) 69 | 70 | 71 | def main(): 72 | import argparse 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--task', default='art2clipart', help='task name') 75 | args = parser.parse_args() 76 | 77 | set_random_seed(1) 78 | 79 | config = yaml.load(open("./config/" + args.task + ".yaml", "r"), Loader=yaml.FullLoader) 80 | 81 | config['trainer']['save_model_addr'] = f"{config['trainer']['save_model_addr']}_retained" 82 | config['log']['save_name'] = f"{config['log']['save_name']}_retained" 83 | 84 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 85 | 86 | # logger info 87 | logdir = osp.join(osp.dirname(__file__), config['log']['save_addr'], config['log']['save_name'] + '.log') 88 | setup_logger(logdir) 89 | print(config) 90 | 91 | # network: build the backbone models 92 | mcd_model_1 = build_mcd_model(config).to(device) 93 | dg_model = build_dg_model(config).to(device) # network for the DCG stage 94 | 95 | # optim 96 | mcd_opt_1 = build_mcd_optimizer(mcd_model_1, config) 97 | dg_opt = build_dg_optimizer(dg_model, config) 98 | 99 | 100 | # data_list, transform 101 |
input_size = config['data_transforms']['input_size'] 102 | batch_size = config['trainer']['batch_size'] 103 | 104 | # transform setup 105 | mcd_transform_test = simple_transform_test(input_size=input_size, type = config['data']['type']) 106 | mcd_transform_train = simple_transform_train(input_size=input_size, type = config['data']['type']) 107 | 108 | # load the data lists 109 | impath_label_x = get_image_dirs(root=config['data']['root'], 110 | dname=config['data']['source_domain_x'], 111 | split="train") # source domain 1 112 | impath_label_u_1 = get_web_image_dirs(root=config['data']['webroot']) # source domain 2, i.e. the web data 113 | 114 | impath_label_t = get_image_dirs(root=config['data']['root'], 115 | dname=config['data']['target_domain'], 116 | split="train") 117 | 118 | fake_dg_u_1 = [] 119 | 120 | # Train Co-teaching 121 | print("Training CO-Teaching.".center(100, "#")) 122 | train_data = impath_label_x + impath_label_u_1 # each sample is an (image path, label) pair, i.e. one line of the list file 123 | dg_transform_test = simple_transform_test(input_size=input_size, type = config['data']['type']) 124 | dg_transform_train = simple_transform_train(input_size=input_size, type = config['data']['type']) 125 | dataloader_train = build_dataloader(train_data, batch_size, config, dg_transform_train) 126 | print('dataset dg train dataloader: {}'.format(len(dataloader_train))) 127 | dataloader_test = build_dataloader(impath_label_t, batch_size, config, dg_transform_test) 128 | print('dataset dg test dataloader: {}'.format(len(dataloader_test))) 129 | 130 | source_domain = config['data']['source_domain_x'] 131 | trainer_dg = CoTeachingTrainer(dg_model, dg_opt, device, 1, source_domain, **config['trainer']) 132 | #train 133 | trainer_dg.update_lr(1) 134 | trainer_dg.train_dg(dataloader_train,dataloader_test,15) 135 | 136 | 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /mcd_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | 4 | from torch.nn import functional as F 5 | from utils import * 6 | from utils.dataset import * 7 | from utils.evaluator import Classification 8 | 9 | from utils import mixup 10 | 11 | 12 | def discrepancy(y1, y2): 13 | return (y1 - y2).abs().mean() 14 | 15 | 16 | class MCDTrainer: 17 | def __init__(self, models, optimizer, device, index, **params): 18 | self.index = index 19 | self.n_step_f = params["n_step_f"] 20 | self._models = models 21 | self.F = models["F"] 22 | self.C1 = models["C1"] 23 | self.C2 = models["C2"] 24 | self._optims = optimizer 25 | 26 | self.device = device 27 | self.batch_size = params['batch_size'] 28 | self.num_workers = params['num_workers'] 29 | self.batch_size = params['batch_size'] 30 | self.folder = osp.join(params['save_model_addr'], f"MCD{index}") 31 | 32 | self.max_epoch = 0 33 | self.start_epoch = self.epoch = 0 34 | self.evaluator = Classification() 35 | mkdir_if_missing(self.folder) 36 | 37 | def get_model_names(self, names=None): 38 | names_real = list(self._models.keys()) 39 | if names is not None: 40 | names = tolist_if_not(names) 41 | for name in names: 42 | assert name in names_real 43 | return names 44 | else: 45 | return names_real 46 | 47 | def get_current_lr(self, names=None): 48 | names = self.get_model_names(names) 49 | name = names[0] 50 | return self._optims[name].param_groups[0]['lr'] 51 | 52 | def set_model_mode(self, mode='train', names=None): 53 | names = self.get_model_names(names) 54 | 55 | for name in names: 56 | if mode == 'train': 57 |
self._models[name].train() 58 | else: 59 | self._models[name].eval() 60 | 61 | def model_zero_grad(self, names=None): 62 | names = self.get_model_names(names) 63 | for name in names: 64 | if self._optims[name] is not None: 65 | self._optims[name].zero_grad() 66 | 67 | def detect_anomaly(self, loss): 68 | if not torch.isfinite(loss).all(): 69 | raise FloatingPointError('Loss is infinite or NaN!') 70 | 71 | def model_backward(self, loss, retain_graph): 72 | self.detect_anomaly(loss) 73 | loss.backward(retain_graph=retain_graph) 74 | 75 | def model_update(self, names=None): 76 | names = self.get_model_names(names) 77 | for name in names: 78 | if self._optims[name] is not None: 79 | self._optims[name].step() 80 | 81 | def model_backward_and_update(self, loss, names=None, retain_graph=False): 82 | self.model_zero_grad(names) 83 | self.model_backward(loss, retain_graph) 84 | self.model_update(names) 85 | 86 | def save_model(self, epoch, directory, is_best=False, model_name=''): 87 | names = self.get_model_names() 88 | 89 | for name in names: 90 | model_dict = self._models[name].state_dict() 91 | 92 | optim_dict = None 93 | if self._optims[name] is not None: 94 | optim_dict = self._optims[name].state_dict() 95 | 96 | save_checkpoint( 97 | { 98 | 'state_dict': model_dict, 99 | 'epoch': epoch + 1, 100 | 'optimizer': optim_dict 101 | }, 102 | osp.join(directory, name), 103 | is_best=is_best, 104 | model_name=model_name 105 | ) 106 | 107 | def after_epoch(self, test_loader): 108 | if (self.epoch + 1) % 10 == 0: 109 | self.test(test_loader) 110 | last_epoch = (self.epoch + 1) == self.max_epoch 111 | if last_epoch: 112 | self.save_model(self.epoch, self.folder) 113 | 114 | def before_train(self): 115 | # Remember the starting time (for computing the elapsed time) 116 | self.time_start = time.time() 117 | 118 | def after_train(self): 119 | print(f'Finished MCD_{self.index} training') 120 | 121 | # Show elapsed time 122 | elapsed = round(time.time() - self.time_start) 123 | elapsed = str(datetime.timedelta(seconds=elapsed)) 124 | print('Elapsed: {}'.format(elapsed)) 125 | 126 | def train_mcd(self, train_loader, test_loader, max_epoch): 127 | self.max_epoch = max_epoch 128 | self.before_train() 129 | for self.epoch in range(self.start_epoch, self.max_epoch): 130 | self.run_epoch(train_loader, test_loader) 131 | self.after_epoch(test_loader) 132 | self.after_train() 133 | 134 | def run_epoch(self, train_loader, test_loader): 135 | self.set_model_mode('train') 136 | losses = MetricMeter() 137 | batch_time = AverageMeter() 138 | data_time = AverageMeter() 139 | 140 | # Decide to iterate over labeled or unlabeled dataset 141 | len_train_loader_x = len(train_loader) 142 | len_train_loader_u = len(test_loader) 143 | num_batches = min(len_train_loader_x, len_train_loader_u) 144 | 145 | train_loader_x_iter = iter(train_loader) 146 | train_loader_u_iter = iter(test_loader) 147 | 148 | end = time.time() 149 | 150 | for self.batch_idx in range(num_batches): 151 | try: 152 | batch_x = next(train_loader_x_iter) 153 | except StopIteration: 154 | train_loader_x_iter = iter(train_loader) 155 | batch_x = next(train_loader_x_iter) 156 | 157 | try: 158 | batch_u = next(train_loader_u_iter) 159 | except StopIteration: 160 | train_loader_u_iter = iter(test_loader) 161 | batch_u = next(train_loader_u_iter) 162 | 163 | data_time.update(time.time() - end) 164 | loss_summary = self.forward_backward(batch_x, batch_u) 165 | batch_time.update(time.time() - end) 166 | losses.update(loss_summary) 167 | 168 | if (self.batch_idx + 1) % 
10 == 0: 169 | nb_this_epoch = num_batches - (self.batch_idx + 1) 170 | nb_future_epochs = ( 171 | self.max_epoch - (self.epoch + 1) 172 | ) * num_batches 173 | eta_seconds = batch_time.avg * (nb_this_epoch + nb_future_epochs) 174 | eta = str(datetime.timedelta(seconds=int(eta_seconds))) 175 | print( 176 | 'epoch [{0}/{1}][{2}/{3}]\t' 177 | 'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 178 | 'data {data_time.val:.3f} ({data_time.avg:.3f})\t' 179 | 'eta {eta}\t' 180 | '{losses}\t' 181 | 'lr {lr}'.format( 182 | self.epoch + 1, 183 | self.max_epoch, 184 | self.batch_idx + 1, 185 | num_batches, 186 | batch_time=batch_time, 187 | data_time=data_time, 188 | eta=eta, 189 | losses=losses, 190 | lr=self.get_current_lr() 191 | ) 192 | ) 193 | 194 | end = time.time() 195 | 196 | def forward_backward(self, batch_x, batch_u): 197 | parsed = self.parse_batch_train(batch_x, batch_u) 198 | input_x, label_x_a, label_x_b, lam, input_u = parsed #mixup 199 | 200 | # Step A 201 | feat_x = self.F(input_x) 202 | logit_x1 = self.C1(feat_x) 203 | logit_x2 = self.C2(feat_x) 204 | # mixup 205 | loss_x1 = lam * F.cross_entropy(logit_x1, label_x_a) + (1 - lam) * F.cross_entropy(logit_x1, label_x_b) 206 | loss_x2 = lam * F.cross_entropy(logit_x2, label_x_a) + (1 - lam) * F.cross_entropy(logit_x2, label_x_b) 207 | loss_step_A = loss_x1 + loss_x2 208 | self.model_backward_and_update(loss_step_A) 209 | 210 | # Step B 211 | with torch.no_grad(): 212 | feat_x = self.F(input_x) 213 | logit_x1 = self.C1(feat_x) 214 | logit_x2 = self.C2(feat_x) 215 | # mixup 216 | loss_x1 = lam * F.cross_entropy(logit_x1, label_x_a) + (1 - lam) * F.cross_entropy(logit_x1, label_x_b) 217 | loss_x2 = lam * F.cross_entropy(logit_x2, label_x_a) + (1 - lam) * F.cross_entropy(logit_x2, label_x_b) 218 | loss_x = loss_x1 + loss_x2 219 | 220 | with torch.no_grad(): 221 | feat_u = self.F(input_u) 222 | pred_u1 = F.softmax(self.C1(feat_u), 1) 223 | pred_u2 = F.softmax(self.C2(feat_u), 1) 224 | loss_dis = discrepancy(pred_u1, pred_u2) 225 | 226 | loss_step_B = loss_x - loss_dis 227 | self.model_backward_and_update(loss_step_B, ['C1', 'C2']) 228 | 229 | # Step C 230 | for _ in range(self.n_step_f): 231 | feat_u = self.F(input_u) 232 | pred_u1 = F.softmax(self.C1(feat_u), 1) 233 | pred_u2 = F.softmax(self.C2(feat_u), 1) 234 | loss_step_C = discrepancy(pred_u1, pred_u2) 235 | self.model_backward_and_update(loss_step_C, 'F') 236 | 237 | loss_summary = { 238 | 'loss_step_A': loss_step_A.item(), 239 | 'loss_step_B': loss_step_B.item(), 240 | 'loss_step_C': loss_step_C.item() 241 | } 242 | 243 | return loss_summary 244 | 245 | def parse_batch_train(self, batch_x, batch_u): 246 | input_x = batch_x['img'] 247 | label_x = batch_x['label'] 248 | input_u = batch_u["img"] 249 | 250 | input_x = input_x.to(self.device) 251 | label_x = label_x.to(self.device) 252 | input_u = input_u.to(self.device) 253 | 254 | input_x, label_x_a, label_x_b, lam = mixup_data(input_x, label_x, alpha=1.0) 255 | 256 | return input_x, label_x_a, label_x_b, lam, input_u 257 | 258 | def parse_batch_test(self, batch): 259 | input = batch['img'] 260 | label = batch['label'] 261 | impath = batch["impath"] 262 | 263 | input = input.to(self.device) 264 | label = label.to(self.device) 265 | 266 | return input, label, impath 267 | 268 | def model_inference(self, input): 269 | feat = self.F(input) 270 | return self.C1(feat) 271 | 272 | def load_model(self, directory, epoch=None): 273 | names = self.get_model_names() 274 | 275 | # By default, the best model is loaded 276 | model_file = 
'model-best.pth.tar' 277 | 278 | if epoch is not None: 279 | model_file = 'model.pth.tar-' + str(epoch) 280 | 281 | for name in names: 282 | model_path = osp.join(directory, name, model_file) 283 | 284 | if not osp.exists(model_path): 285 | raise FileNotFoundError( 286 | 'Model not found at "{}"'.format(model_path) 287 | ) 288 | 289 | checkpoint = load_checkpoint(model_path) 290 | state_dict = checkpoint['state_dict'] 291 | epoch = checkpoint['epoch'] 292 | 293 | print( 294 | 'Loading weights to {} ' 295 | 'from "{}" (epoch = {})'.format(name, model_path, epoch) 296 | ) 297 | self._models[name].load_state_dict(state_dict) 298 | 299 | def test(self, test_loader): 300 | """A testing pipeline.""" 301 | self.set_model_mode('eval') 302 | self.evaluator.reset() 303 | 304 | for batch_idx, batch in enumerate(test_loader): 305 | input, label, _ = self.parse_batch_test(batch) 306 | output = self.model_inference(input) 307 | self.evaluator.process(output, label) 308 | 309 | results = self.evaluator.evaluate() 310 | 311 | return results['accuracy'] 312 | 313 | def update_lr(self, index=1): # shrink the base lr by a factor of index**2 at the start of each round 314 | for key in self._optims: 315 | opt = self._optims[key] 316 | lr = opt.defaults["lr"] / index ** 2 317 | opt.param_groups[0]['lr'] = lr 318 | 319 | def get_pl(self, target_loader): # collect (image path, C1 argmax prediction) pairs as pseudo labels 320 | pl = [] 321 | self.set_model_mode('eval') 322 | for batch_idx, batch in enumerate(target_loader): 323 | input, _, impath = self.parse_batch_test(batch) 324 | output = self.model_inference(input) 325 | pred = output.max(1)[1] 326 | for index, path in enumerate(impath): 327 | pl.append((path, int(pred[index]))) 328 | return pl 329 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .cnn_digitsdg import * 3 | 4 | models = { 5 | 'R18': resnet18, 6 | 'CNN': CNN_Digits 7 | } -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | from models import models 2 | import torch.nn as nn 3 | 4 | 5 | def build_mcd_model(config): 6 | num_classes = config['data']['class_number'] 7 | F = models["R18"]() 8 | fdim = F.fdim 9 | C1 = nn.Linear(fdim, num_classes) 10 | C2 = nn.Linear(fdim, num_classes) 11 | model = nn.ModuleDict({"F": F, "C1": C1, "C2": C2}) 12 | return model 13 | 14 | 15 | def build_dg_model(config): 16 | num_classes = config['data']['class_number'] 17 | F1 = models["R18"](ms_layers=['layer1', 'layer2', 'layer3']) 18 | F2 = models["R18"](ms_layers=['layer1', 'layer2', 'layer3']) 19 | fdim = F1.fdim 20 | C1 = nn.Linear(fdim, num_classes) 21 | C2 = nn.Linear(fdim, num_classes) 22 | model = nn.ModuleDict({"F1": F1, "F2": F2, "C1": C1, "C2": C2}) 23 | return model 24 | 25 | def build_digits_mcd_model(config): # the digits MCD model uses the small CNN backbone 26 | num_classes = config['data']['class_number'] 27 | F = models["CNN"]() 28 | fdim = F.fdim 29 | # init_network_weights(F, init_type='kaiming') 30 | C1 = nn.Linear(fdim, num_classes) 31 | C2 = nn.Linear(fdim, num_classes) 32 | model = nn.ModuleDict({"F": F, "C1": C1, "C2": C2}) 33 | return model 34 | 35 | def build_digits_dg_model(config): # the digits DG model also uses the small CNN backbone 36 | num_classes = config['data']['class_number'] 37 | F1 = models["CNN"](ms_layers=['layer1', 'layer2', 'layer3']) 38 | F2 = models["CNN"](ms_layers=['layer1', 'layer2', 'layer3']) 39 | fdim = F1.fdim 40 | # init_network_weights(F1, init_type='kaiming') 41 | #
init_network_weights(F2, init_type='kaiming') 42 | C1 = nn.Linear(fdim, num_classes) 43 | C2 = nn.Linear(fdim, num_classes) 44 | model = nn.ModuleDict({"F1": F1, "F2": F2, "C1": C1, "C2": C2}) 45 | return model 46 | 47 | 48 | 49 | def init_network_weights(model, init_type='normal', gain=0.02): 50 | 51 | def _init_func(m): 52 | classname = m.__class__.__name__ 53 | 54 | if hasattr(m, 'weight') and ( 55 | classname.find('Conv') != -1 or classname.find('Linear') != -1 56 | ): 57 | if init_type == 'normal': 58 | nn.init.normal_(m.weight.data, 0.0, gain) 59 | elif init_type == 'xavier': 60 | nn.init.xavier_normal_(m.weight.data, gain=gain) 61 | elif init_type == 'kaiming': 62 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 63 | elif init_type == 'orthogonal': 64 | nn.init.orthogonal_(m.weight.data, gain=gain) 65 | else: 66 | raise NotImplementedError( 67 | 'initialization method {} is not implemented'. 68 | format(init_type) 69 | ) 70 | if hasattr(m, 'bias') and m.bias is not None: 71 | nn.init.constant_(m.bias.data, 0.0) 72 | 73 | elif classname.find('BatchNorm') != -1: 74 | nn.init.constant_(m.weight.data, 1.0) 75 | nn.init.constant_(m.bias.data, 0.0) 76 | 77 | elif classname.find('InstanceNorm') != -1: 78 | if m.weight is not None and m.bias is not None: 79 | nn.init.constant_(m.weight.data, 1.0) 80 | nn.init.constant_(m.bias.data, 0.0) 81 | 82 | model.apply(_init_func) -------------------------------------------------------------------------------- /models/cnn_digitsdg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torchvision 5 | from .mixstyle import MixStyle 6 | from torch.nn import functional as F 7 | 8 | 9 | class Convolution(nn.Module): 10 | 11 | def __init__(self, c_in, c_out): 12 | super().__init__() 13 | self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1) 14 | self.relu = nn.ReLU(True) 15 | 16 | def forward(self, x): 17 | return self.relu(self.conv(x)) 18 | 19 | class CNN_Digits(nn.Module): 20 | 21 | def __init__(self, c_hidden=64, ms_layers=[], ms_p=0.5, ms_a=0.1): 22 | super().__init__() 23 | self.conv1 = Convolution(3, c_hidden) 24 | self.conv2 = Convolution(c_hidden, c_hidden) 25 | self.conv3 = Convolution(c_hidden, c_hidden) 26 | self.conv4 = Convolution(c_hidden, c_hidden) 27 | self.mixstyle = None 28 | if ms_layers: 29 | self.mixstyle = MixStyle(p=ms_p, alpha=ms_a) 30 | for layer_name in ms_layers: 31 | assert layer_name in ['layer1', 'layer2', 'layer3'] 32 | print(f'Insert MixStyle after {ms_layers}') 33 | self.ms_layers = ms_layers 34 | 35 | self.fdim = 2**2 * c_hidden 36 | 37 | 38 | def _check_input(self, x): 39 | H, W = x.shape[2:] 40 | assert H == 32 and W == 32, \ 41 | 'Input to network must be 32x32, ' \ 42 | 'but got {}x{}'.format(H, W) 43 | 44 | def forward(self, x): 45 | self._check_input(x) 46 | x = self.conv1(x) 47 | x = F.max_pool2d(x, 2) 48 | if 'layer1' in self.ms_layers: 49 | x = self.mixstyle(x) 50 | x = self.conv2(x) 51 | x = F.max_pool2d(x, 2) 52 | if 'layer2' in self.ms_layers: 53 | x = self.mixstyle(x) 54 | x = self.conv3(x) 55 | x = F.max_pool2d(x, 2) 56 | if 'layer3' in self.ms_layers: 57 | x = self.mixstyle(x) 58 | x = self.conv4(x) 59 | x = F.max_pool2d(x, 2) 60 | return x.view(x.size(0), -1) 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /models/mixstyle.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def deactivate_mixstyle(m): 7 | if type(m) == MixStyle: 8 | m.set_activation_status(False) 9 | 10 | 11 | def activate_mixstyle(m): 12 | if type(m) == MixStyle: 13 | m.set_activation_status(True) 14 | 15 | 16 | class MixStyle(nn.Module): 17 | """MixStyle. 18 | 19 | Reference: 20 | Zhou et al. Domain Generalization with MixStyle. ICLR 2021. 21 | """ 22 | 23 | def __init__(self, p=0.5, alpha=0.1, eps=1e-6): 24 | """ 25 | Args: 26 | p (float): probability of using MixStyle. 27 | alpha (float): parameter of the Beta distribution. 28 | eps (float): scaling parameter to avoid numerical issues. 29 | """ 30 | super().__init__() 31 | self.p = p 32 | self.beta = torch.distributions.Beta(alpha, alpha) 33 | self.eps = eps 34 | self.alpha = alpha 35 | 36 | self._activated = True 37 | 38 | def __repr__(self): 39 | return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})' 40 | 41 | def set_activation_status(self, status=True): 42 | self._activated = status 43 | 44 | def forward(self, x): 45 | if not self.training or not self._activated: 46 | return x 47 | 48 | if random.random() > self.p: 49 | return x 50 | 51 | B = x.size(0) 52 | 53 | mu = x.mean(dim=[2, 3], keepdim=True) 54 | var = x.var(dim=[2, 3], keepdim=True) 55 | sig = (var + self.eps).sqrt() 56 | mu, sig = mu.detach(), sig.detach() 57 | x_normed = (x - mu) / sig 58 | 59 | lmda = self.beta.sample((B, 1, 1, 1)) 60 | lmda = lmda.to(x.device) 61 | 62 | perm = torch.randperm(B) 63 | mu2, sig2 = mu[perm], sig[perm] 64 | mu_mix = mu * lmda + mu2 * (1 - lmda) 65 | sig_mix = sig * lmda + sig2 * (1 - lmda) 66 | 67 | return x_normed * sig_mix + mu_mix 68 | 69 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torchvision 5 | from .mixstyle import MixStyle 6 | 7 | model_urls = { 8 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 9 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 10 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 11 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 12 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 13 | } 14 | 15 | 16 | class resnet18(nn.Module): 17 | def __init__(self, pretrained=True,ms_layers=[],ms_p=0.5,ms_a=0.1): 18 | super(resnet18, self).__init__() 19 | self.model = torchvision.models.resnet18(pretrained=pretrained) 20 | self.fdim = self.model.fc.in_features 21 | self.mixstyle = None 22 | if ms_layers: 23 | self.mixstyle = MixStyle(p = ms_p, alpha= ms_a) 24 | for layer_name in ms_layers: 25 | assert layer_name in ['layer1', 'layer2', 'layer3'] 26 | print(f'Insert MixStyle after {ms_layers}') 27 | self.ms_layers = ms_layers 28 | 29 | def forward(self, x): 30 | x = self.model.conv1(x) 31 | x = self.model.bn1(x) 32 | x = self.model.relu(x) 33 | x = self.model.maxpool(x) 34 | x = self.model.layer1(x) 35 | if 'layer1' in self.ms_layers: 36 | x = self.mixstyle(x) 37 | x = self.model.layer2(x) 38 | if 'layer2' in self.ms_layers: 39 | x = self.mixstyle(x) 40 | x = self.model.layer3(x) 41 | if 'layer3' in self.ms_layers: 42 | x = self.mixstyle(x) 43 | x = self.model.layer4(x) 44 | x = self.model.avgpool(x) 45 
| x = torch.flatten(x, 1) 46 | 47 | return x 48 | 49 | def init_pretrained_weights(self, model_url): 50 | pretrain_dict = model_zoo.load_url(model_url) 51 | pretrain_dict_1 = {} 52 | for key in pretrain_dict: 53 | if "model." in key: 54 | pretrain_dict_1[key.replace("model.", "")] = pretrain_dict[key] 55 | self.model.load_state_dict(pretrain_dict_1, strict=False) 56 | 57 | def load_imagenet_dict(self, pretrained_dict): 58 | print('--------Loading weight--------') 59 | model_dict = self.model.state_dict() 60 | if list(model_dict.keys())[0].find('model') != -1: 61 | pretrained_dict = {'model.' + k: v for k, v in pretrained_dict.items()} 62 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 63 | model_dict.update(pretrained_dict) 64 | self.model.load_state_dict(model_dict) 65 | 66 | 67 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.2 2 | Pillow==8.1.0 3 | PyYAML==5.3.1 4 | scipy==1.9.3 5 | scikit_learn==1.1.1 6 | six==1.15.0 7 | torch==1.7.1 8 | torchvision==0.8.2 9 | -------------------------------------------------------------------------------- /script/digits.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python main.py --task mnist2mnist_m 5 | 6 | python main.py --task mnist2syn 7 | 8 | python main.py --task mnist2svhn 9 | 10 | python main.py --task mnist_m2mnist 11 | 12 | python main.py --task mnist_m2syn 13 | 14 | python main.py --task mnist_m2svhn 15 | 16 | python main.py --task syn2mnist 17 | 18 | python main.py --task syn2mnist_m 19 | 20 | python main.py --task syn2svhn 21 | 22 | python main.py --task svhn2mnist 23 | 24 | python main.py --task svhn2mnist_m 25 | 26 | python main.py --task svhn2syn 27 | -------------------------------------------------------------------------------- /script/officeHome.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python main.py --task art2realWorld 5 | 6 | python main.py --task art2clipart 7 | 8 | python main.py --task art2product 9 | 10 | python main.py --task realWorld2art 11 | 12 | python main.py --task realWorld2clipart 13 | 14 | python main.py --task realWorld2product 15 | 16 | python main.py --task clipart2art 17 | 18 | python main.py --task clipart2realWorld 19 | 20 | python main.py --task clipart2product 21 | 22 | python main.py --task product2art 23 | 24 | python main.py --task product2realWorld 25 | 26 | python main.py --task product2clipart 27 | -------------------------------------------------------------------------------- /script/officeHome_web_ignored.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python main_web_ignored.py --task art2realWorld 5 | 6 | python main_web_ignored.py --task art2clipart 7 | 8 | python main_web_ignored.py --task art2product 9 | 10 | python main_web_ignored.py --task realWorld2art 11 | 12 | python main_web_ignored.py --task realWorld2clipart 13 | 14 | python main_web_ignored.py --task realWorld2product 15 | 16 | python main_web_ignored.py --task clipart2art 17 | 18 | python main_web_ignored.py --task clipart2realWorld 19 | 20 | python main_web_ignored.py --task clipart2product 21 | 22 | python main_web_ignored.py --task product2art 23 | 24 | python main_web_ignored.py --task product2realWorld 25 | 26 | python 
main_web_ignored.py --task product2clipart 27 | -------------------------------------------------------------------------------- /script/officeHome_web_refined.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python main_web_refined.py --task art2realWorld 5 | 6 | python main_web_refined.py --task art2clipart 7 | 8 | python main_web_refined.py --task art2product 9 | 10 | python main_web_refined.py --task realWorld2art 11 | 12 | python main_web_refined.py --task realWorld2clipart 13 | 14 | python main_web_refined.py --task realWorld2product 15 | 16 | python main_web_refined.py --task clipart2art 17 | 18 | python main_web_refined.py --task clipart2realWorld 19 | 20 | python main_web_refined.py --task clipart2product 21 | 22 | python main_web_refined.py --task product2art 23 | 24 | python main_web_refined.py --task product2realWorld 25 | 26 | python main_web_refined.py --task product2clipart 27 | -------------------------------------------------------------------------------- /script/officeHome_web_retained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python main_web_retained.py --task art2realWorld 5 | 6 | python main_web_retained.py --task art2clipart 7 | 8 | python main_web_retained.py --task art2product 9 | 10 | python main_web_retained.py --task realWorld2art 11 | 12 | python main_web_retained.py --task realWorld2clipart 13 | 14 | python main_web_retained.py --task realWorld2product 15 | 16 | python main_web_retained.py --task clipart2art 17 | 18 | python main_web_retained.py --task clipart2realWorld 19 | 20 | python main_web_retained.py --task clipart2product 21 | 22 | python main_web_retained.py --task product2art 23 | 24 | python main_web_retained.py --task product2realWorld 25 | 26 | python main_web_retained.py --task product2clipart 27 | -------------------------------------------------------------------------------- /script/pacs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python main.py --task P2A 5 | 6 | python main.py --task P2C 7 | 8 | python main.py --task P2S 9 | 10 | python main.py --task S2A 11 | 12 | python main.py --task S2C 13 | 14 | python main.py --task S2P 15 | 16 | python main.py --task A2C 17 | 18 | python main.py --task A2S 19 | 20 | python main.py --task A2P 21 | 22 | python main.py --task C2P 23 | 24 | python main.py --task C2A 25 | 26 | python main.py --task C2S 27 | -------------------------------------------------------------------------------- /script/pacs_web_ignored.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python main_web_ignored.py --task P2A 5 | 6 | python main_web_ignored.py --task P2C 7 | 8 | python main_web_ignored.py --task P2S 9 | 10 | python main_web_ignored.py --task S2A 11 | 12 | python main_web_ignored.py --task S2C 13 | 14 | python main_web_ignored.py --task S2P 15 | 16 | python main_web_ignored.py --task A2C 17 | 18 | python main_web_ignored.py --task A2S 19 | 20 | python main_web_ignored.py --task A2P 21 | 22 | python main_web_ignored.py --task C2P 23 | 24 | python main_web_ignored.py --task C2A 25 | 26 | python main_web_ignored.py --task C2S 27 | -------------------------------------------------------------------------------- /script/pacs_web_refined.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | set -e 3 | 4 | python main_web_refined.py --task P2A 5 | 6 | python main_web_refined.py --task P2C 7 | 8 | python main_web_refined.py --task P2S 9 | 10 | python main_web_refined.py --task S2A 11 | 12 | python main_web_refined.py --task S2C 13 | 14 | python main_web_refined.py --task S2P 15 | 16 | python main_web_refined.py --task A2C 17 | 18 | python main_web_refined.py --task A2S 19 | 20 | python main_web_refined.py --task A2P 21 | 22 | python main_web_refined.py --task C2P 23 | 24 | python main_web_refined.py --task C2A 25 | 26 | python main_web_refined.py --task C2S 27 | -------------------------------------------------------------------------------- /script/pacs_web_retained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | python main_web_retained.py --task P2A 5 | 6 | python main_web_retained.py --task P2C 7 | 8 | python main_web_retained.py --task P2S 9 | 10 | python main_web_retained.py --task S2A 11 | 12 | python main_web_retained.py --task S2C 13 | 14 | python main_web_retained.py --task S2P 15 | 16 | python main_web_retained.py --task A2C 17 | 18 | python main_web_retained.py --task A2S 19 | 20 | python main_web_retained.py --task A2P 21 | 22 | python main_web_retained.py --task C2P 23 | 24 | python main_web_retained.py --task C2A 25 | 26 | python main_web_retained.py --task C2S 27 | -------------------------------------------------------------------------------- /tools/parse_test_res_single.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import os.path as osp 4 | import argparse 5 | import os 6 | 7 | 8 | def listdir_nohidden(path, sort=False): 9 | items = [f for f in os.listdir(path) if not f.startswith('.')] 10 | if sort: 11 | items.sort() 12 | return items 13 | 14 | def compute_ci95(res): 15 | return 1.96 * np.std(res) / np.sqrt(len(res)) 16 | 17 | 18 | def parse_dir(directory, end_signal, regex_acc, regex_err, args): 19 | print('Parsing {}'.format(directory)) 20 | 21 | valid_fpaths = [] 22 | valid_accs = [] 23 | valid_errs = [] 24 | 25 | 26 | fpath = directory 27 | good_to_go = False 28 | 29 | with open(fpath, 'r') as f: 30 | lines = f.readlines() 31 | 32 | for line in lines: 33 | line = line.strip() 34 | 35 | if line == end_signal: 36 | good_to_go = True 37 | 38 | match_acc = regex_acc.search(line) 39 | if match_acc and good_to_go: 40 | acc = float(match_acc.group(1)) 41 | valid_accs.append(acc) 42 | valid_fpaths.append(fpath) 43 | 44 | match_err = regex_err.search(line) 45 | if match_err and good_to_go: 46 | err = float(match_err.group(1)) 47 | valid_errs.append(err) 48 | valid_fpaths = valid_fpaths[-5:] 49 | valid_accs = valid_accs[-5:] 50 | valid_errs = valid_errs[-5:] 51 | 52 | for fpath, acc, err in zip(valid_fpaths, valid_accs, valid_errs): 53 | print('file: {}. acc: {:.2f}%. 
err: {:.2f}%'.format(fpath, acc, err)) 54 | 55 | acc_mean = np.mean(valid_accs) 56 | acc_std = compute_ci95(valid_accs) if args.ci95 else np.std(valid_accs) 57 | 58 | err_mean = np.mean(valid_errs) 59 | err_std = compute_ci95(valid_errs) if args.ci95 else np.std(valid_errs) 60 | 61 | print('===') 62 | print('outcome of directory: {}'.format(directory)) 63 | if args.res_format in ['acc', 'acc_and_err']: 64 | print('* acc: {:.2f}% +- {:.2f}%'.format(acc_mean, acc_std)) 65 | if args.res_format in ['err', 'acc_and_err']: 66 | print('* err: {:.2f}% +- {:.2f}%'.format(err_mean, err_std)) 67 | print('===') 68 | 69 | return acc_mean, err_mean 70 | 71 | 72 | def main(args, end_signal): 73 | regex_acc = re.compile(r'\* accuracy: ([\.\deE+-]+)%') 74 | regex_err = re.compile(r'\* error: ([\.\deE+-]+)%') 75 | 76 | if args.multi_exp: 77 | accs, errs = [], [] 78 | for directory in listdir_nohidden(args.directory, sort=True): 79 | directory = osp.join(args.directory, directory) 80 | acc, err = parse_dir( 81 | directory, end_signal, regex_acc, regex_err, args 82 | ) 83 | accs.append(acc) 84 | errs.append(err) 85 | acc_mean = np.mean(accs) 86 | err_mean = np.mean(errs) 87 | print('overall average') 88 | if args.res_format in ['acc', 'acc_and_err']: 89 | print('* acc: {:.2f}%'.format(acc_mean)) 90 | if args.res_format in ['err', 'acc_and_err']: 91 | print('* err: {:.2f}%'.format(err_mean)) 92 | else: 93 | parse_dir(args.directory, end_signal, regex_acc, regex_err, args) 94 | 95 | 96 | if __name__ == '__main__': 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument('directory', type=str, help='Path to directory') 99 | parser.add_argument( 100 | '--ci95', 101 | action='store_true', 102 | help=r'Compute 95\% confidence interval' 103 | ) 104 | parser.add_argument( 105 | '--test-log', action='store_true', help='Process test log' 106 | ) 107 | parser.add_argument( 108 | '--multi-exp', action='store_true', help='Multiple experiments' 109 | ) 110 | parser.add_argument( 111 | '--res-format', 112 | type=str, 113 | default='acc', 114 | choices=['acc', 'err', 'acc_and_err'] 115 | ) 116 | args = parser.parse_args() 117 | end_signal = 'Finished training' 118 | if args.test_log: 119 | end_signal = '=> result' 120 | main(args, end_signal) 121 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import * 2 | from .meters import * 3 | from .torchtools import * 4 | from .logger import * 5 | from .mixup import * 6 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import torch 4 | import torch.utils.data as data 5 | from PIL import Image, ImageFile 6 | ImageFile.LOAD_TRUNCATED_IMAGES = True 7 | 8 | def get_image_dirs(root, dname, split): 9 | root = osp.abspath(root) 10 | image_dir = osp.join(root, 'images') 11 | split_dir = osp.join(root, 'splits') 12 | 13 | def read_split_pacs(split_file): 14 | items = [] 15 | 16 | with open(split_file, 'r') as f: 17 | lines = f.readlines() 18 | 19 | for line in lines: 20 | line = line.strip() 21 | impath, label = line.split(' ') 22 | impath = osp.join(image_dir, impath) 23 | if 'pacs' in image_dir: 24 | label = int(label) - 1 25 | else: 26 | label = int(label) 27 | items.append((impath, label)) 28 | 29 | return items 30 | if split == 'all': 31 | file_train = 
osp.join( 32 | split_dir, dname + '_train_kfold.txt' 33 | ) 34 | impath_label_list = read_split_pacs(file_train) 35 | file_val = osp.join( 36 | split_dir, dname + '_crossval_kfold.txt' 37 | ) 38 | impath_label_list += read_split_pacs(file_val) 39 | else: 40 | file = osp.join( 41 | split_dir, dname + '_' + split + '_kfold.txt' 42 | ) 43 | impath_label_list = read_split_pacs(file) 44 | 45 | return impath_label_list 46 | 47 | 48 | class base_dataset(data.Dataset): 49 | def __init__(self, impath_label_list, transform): 50 | self.impath_label_list = impath_label_list 51 | self.transform = transform 52 | 53 | def __len__(self): 54 | return len(self.impath_label_list) 55 | 56 | def __getitem__(self, index): 57 | impath, label = self.impath_label_list[index] 58 | img = Image.open(impath).convert('RGB') 59 | img = self.transform(img) 60 | output = { 61 | 'img': img, 62 | 'label': torch.tensor(label), 63 | 'impath': impath 64 | } 65 | return output 66 | 67 | def get_web_image_dirs(root): 68 | path = os.path.join(root, 'all_list.txt') 69 | file = open(path, 'r',encoding="gbk") 70 | items = [] 71 | for line in file.readlines(): 72 | impath = line.split()[0] 73 | impath = osp.join(root, impath) 74 | label = int(line.split()[1]) 75 | items.append((impath, label)) 76 | return items 77 | -------------------------------------------------------------------------------- /utils/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | from collections import OrderedDict, defaultdict 4 | import torch 5 | from sklearn.metrics import confusion_matrix 6 | 7 | 8 | class EvaluatorBase: 9 | """Base evaluator.""" 10 | 11 | def reset(self): 12 | raise NotImplementedError 13 | 14 | def process(self, mo, gt): 15 | raise NotImplementedError 16 | 17 | def evaluate(self): 18 | raise NotImplementedError 19 | 20 | 21 | class Classification(EvaluatorBase): 22 | """Evaluator for classification.""" 23 | 24 | def __init__(self, lab2cname=None, **kwargs): 25 | super().__init__() 26 | self._lab2cname = lab2cname 27 | self._correct = 0 28 | self._total = 0 29 | self._per_class_res = None 30 | self._y_true = [] 31 | self._y_pred = [] 32 | 33 | def reset(self): 34 | self._correct = 0 35 | self._total = 0 36 | if self._per_class_res is not None: 37 | self._per_class_res = defaultdict(list) 38 | 39 | def process(self, mo, gt): 40 | # mo (torch.Tensor): model output [batch, num_classes] 41 | # gt (torch.LongTensor): ground truth [batch] 42 | pred = mo.max(1)[1] 43 | matches = pred.eq(gt).float() 44 | self._correct += int(matches.sum().item()) 45 | self._total += gt.shape[0] 46 | 47 | self._y_true.extend(gt.data.cpu().numpy().tolist()) 48 | self._y_pred.extend(pred.data.cpu().numpy().tolist()) 49 | 50 | if self._per_class_res is not None: 51 | for i, label in enumerate(gt): 52 | label = label.item() 53 | matches_i = int(matches[i].item()) 54 | self._per_class_res[label].append(matches_i) 55 | 56 | def evaluate(self): 57 | results = OrderedDict() 58 | acc = 100. * self._correct / self._total 59 | err = 100. 
--------------------------------------------------------------------------------
/utils/evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os.path as osp
3 | from collections import OrderedDict, defaultdict
4 | import torch
5 | from sklearn.metrics import confusion_matrix
6 | 
7 | 
8 | class EvaluatorBase:
9 |     """Base evaluator."""
10 | 
11 |     def reset(self):
12 |         raise NotImplementedError
13 | 
14 |     def process(self, mo, gt):
15 |         raise NotImplementedError
16 | 
17 |     def evaluate(self):
18 |         raise NotImplementedError
19 | 
20 | 
21 | class Classification(EvaluatorBase):
22 |     """Evaluator for classification."""
23 | 
24 |     def __init__(self, lab2cname=None, **kwargs):
25 |         super().__init__()
26 |         self._lab2cname = lab2cname
27 |         self._correct = 0
28 |         self._total = 0
29 |         self._per_class_res = None
30 |         self._y_true = []
31 |         self._y_pred = []
32 | 
33 |     def reset(self):
34 |         self._correct = 0
35 |         self._total = 0
36 |         if self._per_class_res is not None:
37 |             self._per_class_res = defaultdict(list)
38 | 
39 |     def process(self, mo, gt):
40 |         # mo (torch.Tensor): model output [batch, num_classes]
41 |         # gt (torch.LongTensor): ground truth [batch]
42 |         pred = mo.max(1)[1]
43 |         matches = pred.eq(gt).float()
44 |         self._correct += int(matches.sum().item())
45 |         self._total += gt.shape[0]
46 | 
47 |         self._y_true.extend(gt.data.cpu().numpy().tolist())
48 |         self._y_pred.extend(pred.data.cpu().numpy().tolist())
49 | 
50 |         if self._per_class_res is not None:
51 |             for i, label in enumerate(gt):
52 |                 label = label.item()
53 |                 matches_i = int(matches[i].item())
54 |                 self._per_class_res[label].append(matches_i)
55 | 
56 |     def evaluate(self):
57 |         results = OrderedDict()
58 |         acc = 100. * self._correct / self._total
59 |         err = 100. - acc
60 |         results['accuracy'] = acc
61 |         results['error_rate'] = err
62 | 
63 |         print(
64 |             '=> result\n'
65 |             '* total: {:,}\n'
66 |             '* correct: {:,}\n'
67 |             '* accuracy: {:.2f}%\n'
68 |             '* error: {:.2f}%'.format(self._total, self._correct, acc, err)
69 |         )
70 | 
71 |         if self._per_class_res is not None:
72 |             labels = list(self._per_class_res.keys())
73 |             labels.sort()
74 | 
75 |             print('=> per-class result')
76 |             accs = []
77 | 
78 |             for label in labels:
79 |                 classname = self._lab2cname[label]
80 |                 res = self._per_class_res[label]
81 |                 correct = sum(res)
82 |                 total = len(res)
83 |                 acc = 100. * correct / total
84 |                 accs.append(acc)
85 |                 print(
86 |                     '* class: {} ({})\t'
87 |                     'total: {:,}\t'
88 |                     'correct: {:,}\t'
89 |                     'acc: {:.2f}%'.format(
90 |                         label, classname, total, correct, acc
91 |                     )
92 |                 )
93 |             mean_acc = np.mean(accs)
94 |             print('* average: {:.2f}%'.format(mean_acc))
95 | 
96 |             results['perclass_accuracy'] = mean_acc
97 | 
98 |         return results
99 | 
--------------------------------------------------------------------------------
/utils/gaussian_blur.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | import torch.nn as nn
4 | import numpy as np
5 | 
6 | 
7 | class GaussianBlur(object):
8 |     """Blur a single image on CPU."""
9 | 
10 |     def __init__(self, kernel_size):
11 |         radius = kernel_size // 2
12 |         kernel_size = radius * 2 + 1  # force an odd kernel size
13 |         self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
14 |                                 stride=1, padding=0, bias=False, groups=3)
15 |         self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
16 |                                 stride=1, padding=0, bias=False, groups=3)
17 |         self.k = kernel_size
18 |         self.r = radius
19 | 
20 |         self.blur = nn.Sequential(
21 |             nn.ReflectionPad2d(radius),
22 |             self.blur_h,
23 |             self.blur_v
24 |         )
25 | 
26 |         self.pil_to_tensor = transforms.ToTensor()
27 |         self.tensor_to_pil = transforms.ToPILImage()
28 | 
29 |     def __call__(self, img):
30 |         img = self.pil_to_tensor(img).unsqueeze(0)
31 | 
32 |         sigma = np.random.uniform(0.1, 2.0)  # random blur strength
33 |         x = np.arange(-self.r, self.r + 1)
34 |         x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))  # 1D Gaussian kernel
35 |         x = x / x.sum()
36 |         x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
37 | 
38 |         self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))  # horizontal pass
39 |         self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))  # vertical pass
40 | 
41 |         with torch.no_grad():
42 |             img = self.blur(img)
43 |             img = img.squeeze()
44 | 
45 |         img = self.tensor_to_pil(img)
46 | 
47 |         return img
48 | 
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import os.path as osp
5 | 
6 | from .tools import mkdir_if_missing
7 | 
8 | __all__ = ['Logger', 'setup_logger']
9 | 
10 | 
11 | class Logger:
12 |     """Write console output to an external text file.
13 | 
14 |     Imported from `deep-person-reid <https://github.com/KaiyangZhou/deep-person-reid>`_
15 | 
16 |     Args:
17 |         fpath (str): path to the log file.
18 | 19 | Examples:: 20 | >>> import sys 21 | >>> import os.path as osp 22 | >>> save_dir = 'output/experiment-1' 23 | >>> log_name = 'train.log' 24 | >>> sys.stdout = Logger(osp.join(save_dir, log_name)) 25 | """ 26 | 27 | def __init__(self, fpath=None): 28 | self.console = sys.stdout 29 | self.file = None 30 | if fpath is not None: 31 | mkdir_if_missing(osp.dirname(fpath)) 32 | self.file = open(fpath, 'w') 33 | 34 | def __del__(self): 35 | self.close() 36 | 37 | def __enter__(self): 38 | pass 39 | 40 | def __exit__(self, *args): 41 | self.close() 42 | 43 | def write(self, msg): 44 | self.console.write(msg) 45 | if self.file is not None: 46 | self.file.write(msg) 47 | 48 | def flush(self): 49 | self.console.flush() 50 | if self.file is not None: 51 | self.file.flush() 52 | os.fsync(self.file.fileno()) 53 | 54 | def close(self): 55 | self.console.close() 56 | if self.file is not None: 57 | self.file.close() 58 | 59 | 60 | def setup_logger(output=None): 61 | if output is None: 62 | return 63 | 64 | if output.endswith('.txt') or output.endswith('.log'): 65 | fpath = output 66 | else: 67 | fpath = osp.join(output, 'log.txt') 68 | 69 | if osp.exists(fpath): 70 | # make sure the existing log file is not over-written 71 | fpath = fpath[:-4] + time.strftime('-%Y-%m-%d-%H-%M-%S') + ".log" 72 | 73 | sys.stdout = Logger(fpath) 74 | -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | 4 | __all__ = ['AverageMeter', 'MetricMeter'] 5 | 6 | 7 | class AverageMeter: 8 | """Compute and store the average and current value. 9 | 10 | Examples:: 11 | >>> # 1. Initialize a meter to record loss 12 | >>> losses = AverageMeter() 13 | >>> # 2. Update meter after every mini-batch update 14 | >>> losses.update(loss_value, batch_size) 15 | """ 16 | 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=1): 27 | if isinstance(val, torch.Tensor): 28 | val = val.item() 29 | self.val = val 30 | self.sum += val * n 31 | self.count += n 32 | self.avg = self.sum / self.count 33 | 34 | 35 | class MetricMeter: 36 | """Store the average and current value for a set of metrics. 37 | 38 | Examples:: 39 | >>> # 1. Create an instance of MetricMeter 40 | >>> metric = MetricMeter() 41 | >>> # 2. Update using a dictionary as input 42 | >>> input_dict = {'loss_1': value_1, 'loss_2': value_2} 43 | >>> metric.update(input_dict) 44 | >>> # 3. 
Convert to string and print
45 |     >>> print(str(metric))
46 |     """
47 | 
48 |     def __init__(self, delimiter='\t'):
49 |         self.meters = defaultdict(AverageMeter)
50 |         self.delimiter = delimiter
51 | 
52 |     def update(self, input_dict):
53 |         if input_dict is None:
54 |             return
55 | 
56 |         if not isinstance(input_dict, dict):
57 |             raise TypeError(
58 |                 'Input to MetricMeter.update() must be a dictionary'
59 |             )
60 | 
61 |         for k, v in input_dict.items():
62 |             if isinstance(v, torch.Tensor):
63 |                 v = v.item()
64 |             self.meters[k].update(v)
65 | 
66 |     def __str__(self):
67 |         output_str = []
68 |         for name, meter in self.meters.items():
69 |             output_str.append(
70 |                 '{} {:.4f} ({:.4f})'.format(name, meter.val, meter.avg)
71 |             )
72 |         return self.delimiter.join(output_str)
73 | 
--------------------------------------------------------------------------------
/utils/mixup.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def mixup_data(x, y, alpha=1.0):
6 |     """Returns mixed inputs, pairs of targets, and lambda."""
7 |     if alpha > 0:
8 |         lam = np.random.beta(alpha, alpha)
9 |     else:
10 |         lam = 1
11 | 
12 |     batch_size = x.size()[0]
13 |     index = torch.randperm(batch_size).to(x.device)
14 | 
15 |     mixed_x = lam * x + (1 - lam) * x[index, :]  # convex combination with a shuffled batch
16 |     y_a, y_b = y, y[index]
17 |     return mixed_x, y_a, y_b, lam
18 | 
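For reference, a minimal sketch of how `mixup_data` is typically consumed in a training step; the toy model and dummy batch are placeholders, not code from this repo. The loss is the lambda-weighted sum of the cross-entropies against both target sets:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.mixup import mixup_data

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # toy classifier
x = torch.randn(8, 3, 32, 32)   # dummy image batch
y = torch.randint(0, 10, (8,))  # dummy labels

mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0)
logits = model(mixed_x)
# lambda-weighted combination of the losses against both label sets
loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
loss.backward()
```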
--------------------------------------------------------------------------------
/utils/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data.sampler import Sampler
3 | 
4 | 
5 | class UnifLabelSampler(Sampler):
6 |     """Samples elements uniformly across pseudolabels.
7 |     Args:
8 |         N (int): size of the returned iterator.
9 |         images_lists: dict mapping each pseudolabel to the list of sample indices assigned to it
10 |     """
11 | 
12 |     def __init__(self, N, images_lists):
13 |         self.N = N
14 |         self.images_lists = images_lists
15 |         self.indexes = self.generate_indexes_epoch()
16 | 
17 |     def generate_indexes_epoch(self):
18 |         nmb_non_empty_clusters = 0
19 |         for i in range(len(self.images_lists)):
20 |             if len(self.images_lists[i]) != 0:
21 |                 nmb_non_empty_clusters += 1
22 | 
23 |         size_per_pseudolabel = int(self.N / nmb_non_empty_clusters) + 1
24 |         res = np.array([])
25 | 
26 |         for i in range(len(self.images_lists)):
27 |             # skip empty clusters
28 |             if len(self.images_lists[i]) == 0:
29 |                 continue
30 |             indexes = np.random.choice(
31 |                 self.images_lists[i],
32 |                 size_per_pseudolabel,
33 |                 replace=(len(self.images_lists[i]) <= size_per_pseudolabel)
34 |             )
35 |             res = np.concatenate((res, indexes))
36 | 
37 |         np.random.shuffle(res)
38 |         res = list(res.astype('int'))
39 |         if len(res) >= self.N:
40 |             return res[:self.N]
41 |         res += res[: (self.N - len(res))]  # pad by repeating to reach exactly N samples
42 |         return res
43 | 
44 |     def __iter__(self):
45 |         return iter(self.indexes)
46 | 
47 |     def __len__(self):
48 |         return len(self.indexes)
49 | 
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/KaiyangZhou/deep-person-reid
3 | """
4 | import os
5 | import sys
6 | import json
7 | import time
8 | import errno
9 | import numpy as np
10 | import random
11 | import os.path as osp
12 | import warnings
13 | from difflib import SequenceMatcher
14 | import PIL
15 | import torch
16 | from PIL import Image
17 | 
18 | __all__ = [
19 |     'mkdir_if_missing', 'check_isfile', 'read_json', 'write_json',
20 |     'set_random_seed', 'download_url', 'read_image', 'collect_env_info',
21 |     'listdir_nohidden', 'get_most_similar_str_to_a_from_b',
22 |     'check_availability', 'tolist_if_not'
23 | ]
24 | 
25 | 
26 | def mkdir_if_missing(dirname):
27 |     """Create dirname if it is missing."""
28 |     if not osp.exists(dirname):
29 |         try:
30 |             os.makedirs(dirname)
31 |         except OSError as e:
32 |             if e.errno != errno.EEXIST:
33 |                 raise
34 | 
35 | 
36 | def check_isfile(fpath):
37 |     """Check if the given path is a file.
38 | 
39 |     Args:
40 |         fpath (str): file path.
41 | 
42 |     Returns:
43 |         bool
44 |     """
45 |     isfile = osp.isfile(fpath)
46 |     if not isfile:
47 |         warnings.warn('No file found at "{}"'.format(fpath))
48 |     return isfile
49 | 
50 | 
51 | def read_json(fpath):
52 |     """Read json file from a path."""
53 |     with open(fpath, 'r') as f:
54 |         obj = json.load(f)
55 |     return obj
56 | 
57 | 
58 | def write_json(obj, fpath):
59 |     """Writes to a json file."""
60 |     mkdir_if_missing(osp.dirname(fpath))
61 |     with open(fpath, 'w') as f:
62 |         json.dump(obj, f, indent=4, separators=(',', ': '))
63 | 
64 | 
65 | def set_random_seed(seed):
66 |     random.seed(seed)
67 |     np.random.seed(seed)
68 |     torch.manual_seed(seed)
69 |     torch.cuda.manual_seed_all(seed)
70 | 
71 | 
72 | def download_url(url, dst):
73 |     """Download file from a url to a destination.
74 | 
75 |     Args:
76 |         url (str): url to download file.
77 |         dst (str): destination path.
78 | """ 79 | from six.moves import urllib 80 | print('* url="{}"'.format(url)) 81 | print('* destination="{}"'.format(dst)) 82 | 83 | def _reporthook(count, block_size, total_size): 84 | global start_time 85 | if count == 0: 86 | start_time = time.time() 87 | return 88 | duration = time.time() - start_time 89 | progress_size = int(count * block_size) 90 | speed = int(progress_size / (1024*duration)) 91 | percent = int(count * block_size * 100 / total_size) 92 | sys.stdout.write( 93 | '\r...%d%%, %d MB, %d KB/s, %d seconds passed' % 94 | (percent, progress_size / (1024*1024), speed, duration) 95 | ) 96 | sys.stdout.flush() 97 | 98 | urllib.request.urlretrieve(url, dst, _reporthook) 99 | sys.stdout.write('\n') 100 | 101 | 102 | def read_image(path): 103 | """Read image from path using ``PIL.Image``. 104 | 105 | Args: 106 | path (str): path to an image. 107 | 108 | Returns: 109 | PIL image 110 | """ 111 | if not osp.exists(path): 112 | raise IOError('No file exists at {}'.format(path)) 113 | 114 | while True: 115 | try: 116 | img = Image.open(path).convert('RGB') 117 | return img 118 | except IOError: 119 | print( 120 | 'Cannot read image from {}, ' 121 | 'probably due to heavy IO. Will re-try'.format(path) 122 | ) 123 | 124 | 125 | def collect_env_info(): 126 | """Return env info as a string. 127 | 128 | Code source: github.com/facebookresearch/maskrcnn-benchmark 129 | """ 130 | from torch.utils.collect_env import get_pretty_env_info 131 | env_str = get_pretty_env_info() 132 | env_str += '\n Pillow ({})'.format(PIL.__version__) 133 | return env_str 134 | 135 | 136 | def listdir_nohidden(path, sort=False): 137 | """List non-hidden items in a directory. 138 | 139 | Args: 140 | path (str): directory path. 141 | sort (bool): sort the items. 142 | """ 143 | items = [f for f in os.listdir(path) if not f.startswith('.')] 144 | if sort: 145 | items.sort() 146 | return items 147 | 148 | 149 | def get_most_similar_str_to_a_from_b(a, b): 150 | """Return the most similar string to a in b. 151 | 152 | Args: 153 | a (str): probe string. 154 | b (list): a list of candidate strings. 155 | """ 156 | highest_sim = 0 157 | chosen = None 158 | for candidate in b: 159 | sim = SequenceMatcher(None, a, candidate).ratio() 160 | if sim >= highest_sim: 161 | highest_sim = sim 162 | chosen = candidate 163 | return chosen 164 | 165 | 166 | def check_availability(requested, available): 167 | """Check if an element is available in a list. 168 | 169 | Args: 170 | requested (str): probe string. 171 | available (list): a list of available strings. 
172 | """ 173 | if requested not in available: 174 | psb_ans = get_most_similar_str_to_a_from_b(requested, available) 175 | raise ValueError( 176 | 'The requested one is expected ' 177 | 'to belong to {}, but got [{}] ' 178 | '(do you mean [{}]?)'.format(available, requested, psb_ans) 179 | ) 180 | 181 | 182 | def tolist_if_not(x): 183 | """Convert to a list.""" 184 | if not isinstance(x, list): 185 | x = [x] 186 | return x 187 | -------------------------------------------------------------------------------- /utils/torchtools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/KaiyangZhou/deep-person-reid 3 | """ 4 | import pickle 5 | import shutil 6 | import os.path as osp 7 | import warnings 8 | from functools import partial 9 | from collections import OrderedDict 10 | import torch 11 | import torch.nn as nn 12 | 13 | from .tools import mkdir_if_missing 14 | 15 | __all__ = [ 16 | 'save_checkpoint', 'load_checkpoint', 'resume_from_checkpoint', 17 | 'open_all_layers', 'open_specified_layers', 'count_num_param', 18 | 'load_pretrained_weights', 'init_network_weights' 19 | ] 20 | 21 | 22 | def save_checkpoint( 23 | state, 24 | save_dir, 25 | is_best=False, 26 | remove_module_from_keys=False, 27 | model_name='' 28 | ): 29 | r"""Save checkpoint. 30 | 31 | Args: 32 | state (dict): dictionary. 33 | save_dir (str): directory to save checkpoint. 34 | is_best (bool, optional): if True, this checkpoint will be copied and named 35 | ``model-best.pth.tar``. Default is False. 36 | remove_module_from_keys (bool, optional): whether to remove "module." 37 | from layer names. Default is False. 38 | model_name (str, optional): model name to save. 39 | 40 | Examples:: 41 | >>> state = { 42 | >>> 'state_dict': model.state_dict(), 43 | >>> 'epoch': 10, 44 | >>> 'optimizer': optimizer.state_dict() 45 | >>> } 46 | >>> save_checkpoint(state, 'log/my_model') 47 | """ 48 | mkdir_if_missing(save_dir) 49 | 50 | if remove_module_from_keys: 51 | # remove 'module.' in state_dict's keys 52 | state_dict = state['state_dict'] 53 | new_state_dict = OrderedDict() 54 | for k, v in state_dict.items(): 55 | if k.startswith('module.'): 56 | k = k[7:] 57 | new_state_dict[k] = v 58 | state['state_dict'] = new_state_dict 59 | 60 | # save model 61 | epoch = state['epoch'] 62 | if not model_name: 63 | model_name = 'model.pth.tar-' + str(epoch) 64 | fpath = osp.join(save_dir, model_name) 65 | torch.save(state, fpath) 66 | print('Checkpoint saved to "{}"'.format(fpath)) 67 | 68 | # save current model name 69 | checkpoint_file = osp.join(save_dir, 'checkpoint') 70 | checkpoint = open(checkpoint_file, 'w+') 71 | checkpoint.write('{}\n'.format(osp.basename(fpath))) 72 | checkpoint.close() 73 | 74 | if is_best: 75 | best_fpath = osp.join(osp.dirname(fpath), 'model-best.pth.tar') 76 | shutil.copy(fpath, best_fpath) 77 | print('Best checkpoint saved to "{}"'.format(best_fpath)) 78 | 79 | 80 | def load_checkpoint(fpath): 81 | r"""Load checkpoint. 82 | 83 | ``UnicodeDecodeError`` can be well handled, which means 84 | python2-saved files can be read from python3. 85 | 86 | Args: 87 | fpath (str): path to checkpoint. 
88 | 89 | Returns: 90 | dict 91 | 92 | Examples:: 93 | >>> fpath = 'log/my_model/model.pth.tar-10' 94 | >>> checkpoint = load_checkpoint(fpath) 95 | """ 96 | if fpath is None: 97 | raise ValueError('File path is None') 98 | 99 | if not osp.exists(fpath): 100 | raise FileNotFoundError('File is not found at "{}"'.format(fpath)) 101 | 102 | map_location = None if torch.cuda.is_available() else 'cpu' 103 | 104 | try: 105 | checkpoint = torch.load(fpath, map_location=map_location) 106 | 107 | except UnicodeDecodeError: 108 | pickle.load = partial(pickle.load, encoding="latin1") 109 | pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1") 110 | checkpoint = torch.load( 111 | fpath, pickle_module=pickle, map_location=map_location 112 | ) 113 | 114 | except Exception: 115 | print('Unable to load checkpoint from "{}"'.format(fpath)) 116 | raise 117 | 118 | return checkpoint 119 | 120 | 121 | def resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None): 122 | r"""Resume training from a checkpoint. 123 | 124 | This will load (1) model weights and (2) ``state_dict`` 125 | of optimizer if ``optimizer`` is not None. 126 | 127 | Args: 128 | fdir (str): directory where the model was saved. 129 | model (nn.Module): model. 130 | optimizer (Optimizer, optional): an Optimizer. 131 | scheduler (Scheduler, optional): an Scheduler. 132 | 133 | Returns: 134 | int: start_epoch. 135 | 136 | Examples:: 137 | >>> fdir = 'log/my_model' 138 | >>> start_epoch = resume_from_checkpoint(fdir, model, optimizer, scheduler) 139 | """ 140 | with open(osp.join(fdir, 'checkpoint'), 'r') as checkpoint: 141 | model_name = checkpoint.readlines()[0].strip('\n') 142 | fpath = osp.join(fdir, model_name) 143 | 144 | print('Loading checkpoint from "{}"'.format(fpath)) 145 | checkpoint = load_checkpoint(fpath) 146 | model.load_state_dict(checkpoint['state_dict']) 147 | print('Loaded model weights') 148 | 149 | if optimizer is not None and 'optimizer' in checkpoint.keys(): 150 | optimizer.load_state_dict(checkpoint['optimizer']) 151 | print('Loaded optimizer') 152 | 153 | if scheduler is not None and 'scheduler' in checkpoint.keys(): 154 | scheduler.load_state_dict(checkpoint['scheduler']) 155 | print('Loaded scheduler') 156 | 157 | start_epoch = checkpoint['epoch'] 158 | print('Previous epoch: {}'.format(start_epoch)) 159 | 160 | return start_epoch 161 | 162 | 163 | def adjust_learning_rate( 164 | optimizer, 165 | base_lr, 166 | epoch, 167 | stepsize=20, 168 | gamma=0.1, 169 | linear_decay=False, 170 | final_lr=0, 171 | max_epoch=100 172 | ): 173 | r"""Adjust learning rate. 174 | 175 | Deprecated. 176 | """ 177 | if linear_decay: 178 | # linearly decay learning rate from base_lr to final_lr 179 | frac_done = epoch / max_epoch 180 | lr = frac_done*final_lr + (1.-frac_done) * base_lr 181 | else: 182 | # decay learning rate by gamma for every stepsize 183 | lr = base_lr * (gamma**(epoch // stepsize)) 184 | 185 | for param_group in optimizer.param_groups: 186 | param_group['lr'] = lr 187 | 188 | 189 | def set_bn_to_eval(m): 190 | r"""Set BatchNorm layers to eval mode.""" 191 | # 1. no update for running mean and var 192 | # 2. scale and shift parameters are still trainable 193 | classname = m.__class__.__name__ 194 | if classname.find('BatchNorm') != -1: 195 | m.eval() 196 | 197 | 198 | def open_all_layers(model): 199 | r"""Open all layers in model for training. 
200 | 201 | Examples:: 202 | >>> open_all_layers(model) 203 | """ 204 | model.train() 205 | for p in model.parameters(): 206 | p.requires_grad = True 207 | 208 | 209 | def open_specified_layers(model, open_layers): 210 | r"""Open specified layers in model for training while keeping 211 | other layers frozen. 212 | 213 | Args: 214 | model (nn.Module): neural net model. 215 | open_layers (str or list): layers open for training. 216 | 217 | Examples:: 218 | >>> # Only model.classifier will be updated. 219 | >>> open_layers = 'classifier' 220 | >>> open_specified_layers(model, open_layers) 221 | >>> # Only model.fc and model.classifier will be updated. 222 | >>> open_layers = ['fc', 'classifier'] 223 | >>> open_specified_layers(model, open_layers) 224 | """ 225 | if isinstance(model, nn.DataParallel): 226 | model = model.module 227 | 228 | if isinstance(open_layers, str): 229 | open_layers = [open_layers] 230 | 231 | for layer in open_layers: 232 | assert hasattr( 233 | model, layer 234 | ), '"{}" is not an attribute of the model, please provide the correct name'.format( 235 | layer 236 | ) 237 | 238 | for name, module in model.named_children(): 239 | if name in open_layers: 240 | module.train() 241 | for p in module.parameters(): 242 | p.requires_grad = True 243 | else: 244 | module.eval() 245 | for p in module.parameters(): 246 | p.requires_grad = False 247 | 248 | 249 | def count_num_param(model): 250 | r"""Count number of parameters in a model. 251 | 252 | Args: 253 | model (nn.Module): network model. 254 | 255 | Examples:: 256 | >>> model_size = count_num_param(model) 257 | """ 258 | return sum(p.numel() for p in model.parameters()) 259 | 260 | 261 | def load_pretrained_weights(model, weight_path): 262 | r"""Load pretrianed weights to model. 263 | 264 | Features:: 265 | - Incompatible layers (unmatched in name or size) will be ignored. 266 | - Can automatically deal with keys containing "module.". 267 | 268 | Args: 269 | model (nn.Module): network model. 270 | weight_path (str): path to pretrained weights. 271 | 272 | Examples:: 273 | >>> weight_path = 'log/my_model/model-best.pth.tar' 274 | >>> load_pretrained_weights(model, weight_path) 275 | """ 276 | checkpoint = load_checkpoint(weight_path) 277 | if 'state_dict' in checkpoint: 278 | state_dict = checkpoint['state_dict'] 279 | else: 280 | state_dict = checkpoint 281 | 282 | model_dict = model.state_dict() 283 | new_state_dict = OrderedDict() 284 | matched_layers, discarded_layers = [], [] 285 | 286 | for k, v in state_dict.items(): 287 | if k.startswith('module.'): 288 | k = k[7:] # discard module. 289 | 290 | if k in model_dict and model_dict[k].size() == v.size(): 291 | new_state_dict[k] = v 292 | matched_layers.append(k) 293 | else: 294 | discarded_layers.append(k) 295 | 296 | model_dict.update(new_state_dict) 297 | model.load_state_dict(model_dict) 298 | 299 | if len(matched_layers) == 0: 300 | warnings.warn( 301 | 'The pretrained weights "{}" cannot be loaded, ' 302 | 'please check the key names manually ' 303 | '(** ignored and continue **)'.format(weight_path) 304 | ) 305 | else: 306 | print( 307 | 'Successfully loaded pretrained weights from "{}"'. 308 | format(weight_path) 309 | ) 310 | if len(discarded_layers) > 0: 311 | print( 312 | '** The following layers are discarded ' 313 | 'due to unmatched keys or layer size: {}'. 
314 | format(discarded_layers) 315 | ) 316 | 317 | 318 | def init_network_weights(model, init_type='normal', gain=0.02): 319 | 320 | def _init_func(m): 321 | classname = m.__class__.__name__ 322 | 323 | if hasattr(m, 'weight') and ( 324 | classname.find('Conv') != -1 or classname.find('Linear') != -1 325 | ): 326 | if init_type == 'normal': 327 | nn.init.normal_(m.weight.data, 0.0, gain) 328 | elif init_type == 'xavier': 329 | nn.init.xavier_normal_(m.weight.data, gain=gain) 330 | elif init_type == 'kaiming': 331 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 332 | elif init_type == 'orthogonal': 333 | nn.init.orthogonal_(m.weight.data, gain=gain) 334 | else: 335 | raise NotImplementedError( 336 | 'initialization method {} is not implemented'. 337 | format(init_type) 338 | ) 339 | if hasattr(m, 'bias') and m.bias is not None: 340 | nn.init.constant_(m.bias.data, 0.0) 341 | 342 | elif classname.find('BatchNorm') != -1: 343 | nn.init.constant_(m.weight.data, 1.0) 344 | nn.init.constant_(m.bias.data, 0.0) 345 | 346 | elif classname.find('InstanceNorm') != -1: 347 | if m.weight is not None and m.bias is not None: 348 | nn.init.constant_(m.weight.data, 1.0) 349 | nn.init.constant_(m.bias.data, 0.0) 350 | 351 | model.apply(_init_func) 352 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | from PIL import Image 3 | from torchvision.transforms import transforms 4 | from utils.gaussian_blur import GaussianBlur 5 | 6 | 7 | class Random2DTranslation: 8 | """Randomly translates the input image with a probability. 9 | Specifically, given a predefined shape (height, width), the 10 | input is first resized with a factor of 1.125, leading to 11 | (height*1.125, width*1.125), then a random crop is performed. 12 | Such operation is done with a probability. 13 | 14 | Args: 15 | height (int): target image height. 16 | width (int): target image width. 17 | p (float, optional): probability that this operation takes place. 18 | Default is 0.5. 19 | interpolation (int, optional): desired interpolation. Default is 20 | ``PIL.Image.BILINEAR`` 21 | """ 22 | 23 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.p = p 27 | self.interpolation = interpolation 28 | 29 | def __call__(self, img): 30 | if random.uniform(0, 1) > self.p: 31 | return img.resize((self.width, self.height), self.interpolation) 32 | 33 | new_width = int(round(self.width * 1.125)) 34 | new_height = int(round(self.height * 1.125)) 35 | resized_img = img.resize((new_width, new_height), self.interpolation) 36 | 37 | x_maxrange = new_width - self.width 38 | y_maxrange = new_height - self.height 39 | x1 = int(round(random.uniform(0, x_maxrange))) 40 | y1 = int(round(random.uniform(0, y_maxrange))) 41 | croped_img = resized_img.crop( 42 | (x1, y1, x1 + self.width, y1 + self.height) 43 | ) 44 | 45 | return croped_img 46 | 47 | 48 | class MultiViewDataInjector(object): 49 | def __init__(self, *args): 50 | self.transforms = args[0] 51 | 52 | def __call__(self, sample): 53 | output = [transform(sample) for transform in self.transforms] 54 | return output 55 | 56 | 57 | def get_simclr_data_transforms(input_shape, s=1): 58 | # get a set of data augmentation transformations as described in the SimCLR paper. 
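    # NOTE: `input_shape` is expected to be a string such as "(96, 96, 3)";
    # it is parsed with eval() below, so it should only come from trusted config.
    # `s` scales the colour-jitter strength, following the SimCLR recipe.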
59 |     color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
60 |     data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=eval(input_shape)[0]),
61 |                                           transforms.RandomHorizontalFlip(),
62 |                                           transforms.RandomApply([color_jitter], p=0.8),
63 |                                           transforms.RandomGrayscale(p=0.2),
64 |                                           GaussianBlur(kernel_size=int(0.1 * eval(input_shape)[0])),
65 |                                           transforms.ToTensor()])
66 |     return data_transforms
67 | 
68 | 
69 | def strong_transform_train(input_size, type='visda', s=1):
70 |     rotation = transforms.RandomRotation(30)
71 |     color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
72 |     if type == 'digits':
73 |         data_transforms = transforms.Compose([
74 |             rotation,
75 |             transforms.RandomResizedCrop(input_size, scale=(0.8, 1.0), ratio=(0.5, 2.0)),
76 |             transforms.ToTensor(),
77 |             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
78 |         ])
79 |     elif type == 'signs':
80 |         data_transforms = transforms.Compose([
81 |             transforms.RandomResizedCrop(input_size, scale=(0.8, 1.0), ratio=(0.5, 2.0)),
82 |             transforms.RandomApply([color_jitter], p=0.8),
83 |             transforms.RandomGrayscale(p=0.2),
84 |             GaussianBlur(kernel_size=int(0.1 * input_size)),
85 |             transforms.ToTensor(),
86 |             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
87 |         ])
88 |     else:  # visda
89 |         data_transforms = transforms.Compose([
90 |             transforms.RandomResizedCrop(input_size),
91 |             transforms.RandomHorizontalFlip(),
92 |             transforms.RandomApply([color_jitter], p=0.8),
93 |             transforms.RandomGrayscale(p=0.2),
94 |             GaussianBlur(kernel_size=int(0.1 * input_size)),
95 |             transforms.ToTensor(),
96 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
97 |         ])
98 |     return data_transforms
99 | 
100 | 
101 | def simple_transform_train(input_size, type=''):
102 |     if type == 'digits':
103 |         data_transforms = transforms.Compose([
104 |             transforms.Resize([input_size, input_size]),
105 |             transforms.ToTensor(),
106 |             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
107 |         ])
108 |     else:
109 |         data_transforms = transforms.Compose([
110 |             transforms.RandomHorizontalFlip(),
111 |             Random2DTranslation(input_size, input_size),
112 |             transforms.ToTensor(),
113 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
114 |         ])
115 |     return data_transforms
116 | 
117 | 
118 | 
119 | 
120 | def simple_transform_test(input_size, type='visda'):
121 |     if type == 'digits' or type == 'signs':
122 |         normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
123 |     else:
124 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
125 |     data_transforms = transforms.Compose([transforms.Resize([input_size, input_size]),
126 |                                           transforms.ToTensor(),
127 |                                           normalize])
128 |     return data_transforms
129 | 
130 | 
131 | def multiview_transform(input_size, padding, type='visda'):
132 |     totensor = transforms.ToTensor()
133 |     if type == 'digits' or type == 'signs':
134 |         normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
135 |     else:
136 |         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
137 |     if padding >= 0:  # up-scale
138 |         data_transforms = transforms.Compose([
139 |             transforms.Resize([input_size + 2 * padding, input_size + 2 * padding]),
140 |             transforms.CenterCrop([input_size, input_size]),
141 |             totensor,
142 |             normalize
143 |         ])
144 |     else:  # down-scale
145 |         data_transforms = transforms.Compose([
146 |             transforms.Resize([input_size, input_size]),
147 | transforms.Pad(-padding, fill=0, padding_mode='constant'), 148 | transforms.Resize([input_size, input_size]), 149 | totensor, 150 | normalize 151 | ]) 152 | return data_transforms 153 | --------------------------------------------------------------------------------
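As a closing reference, a minimal sketch of picking a transform per benchmark; the 32/224 input sizes are illustrative, following the usual digits vs. PACS/Office-Home conventions rather than values fixed by this file:

```
from utils.transforms import (
    multiview_transform, simple_transform_test, simple_transform_train, strong_transform_train
)

# Digits-DG style inputs: 32x32 with 0.5/0.5 normalization
digits_train_tf = simple_transform_train(32, type='digits')
digits_test_tf = simple_transform_test(32, type='digits')

# PACS / Office-Home style inputs: 224x224 with ImageNet normalization
strong_tf = strong_transform_train(224)           # type defaults to 'visda'
test_tf = simple_transform_test(224)
zoomed_tf = multiview_transform(224, padding=16)  # positive padding = up-scaled view
```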