├── README.md ├── TransUNet ├── LICENSE ├── README.md ├── networks │ ├── vit_seg_configs.py │ ├── vit_seg_modeling.py │ └── vit_seg_modeling_resnet_skip.py ├── requirements.txt ├── test.py ├── train.py ├── trainer.py └── utils.py ├── data ├── .DS_Store ├── ._.DS_Store ├── VISDA-C │ ├── download_visda2017.sh │ ├── test_list.txt │ ├── train_list.txt │ └── validation_list.txt ├── generate_label.py ├── office-caltech │ ├── amazon_list.txt │ ├── caltech_list.txt │ ├── dslr_list.txt │ └── webcam_list.txt ├── office-home │ ├── Art.txt │ ├── Art_list.txt │ ├── Clipart.txt │ ├── Clipart_list.txt │ ├── Product.txt │ ├── Product_list.txt │ ├── RealWorld_list.txt │ └── Real_World.txt └── office │ ├── amazon_list.txt │ ├── dslr_list.txt │ └── webcam_list.txt ├── data_list.py ├── image ├── overview.png ├── result_office31.png └── result_officehome.png ├── image_pretrained.py ├── image_source.py ├── image_target.py ├── image_target_oda.py ├── image_test.py ├── loss.py ├── network.py ├── non_local_embedded_gaussian.py ├── run_office_home_more.sh ├── run_office_home_uda.sh ├── run_office_uda.sh ├── run_office_uda_ab.sh └── run_visda.sh /README.md: -------------------------------------------------------------------------------- 1 | # Official implementation for TransDA 2 | Official PyTorch implementation of [“Transformer-Based Source-Free Domain Adaptation”](https://arxiv.org/abs/2105.14138). 3 | Accepted by APIN 2022. 4 | ## Overview: 5 | ![overview](image/overview.png) 6 | 7 | ## Result: 8 | ![result_office31](image/result_office31.png) 9 | ![result_officehome](image/result_officehome.png) 10 | 11 | 12 | ## Prerequisites: 13 | - python == 3.6.8 14 | - pytorch == 1.1.0 15 | - torchvision == 0.3.0 16 | - numpy, scipy, sklearn, PIL, argparse, tqdm 17 | 18 | ## Prepare the pretrained model 19 | We choose R50-ViT-B_16 as our encoder. 20 | ```bash 21 | wget https://storage.googleapis.com/vit_models/imagenet21k/R50+ViT-B_16.npz 22 | mkdir -p ./model/vit_checkpoint/imagenet21k 23 | mv R50+ViT-B_16.npz ./model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz 24 | ``` 25 | Our checkpoints can be found in [Dropbox](https://www.dropbox.com/sh/vost4yt3c2vuuec/AAAHEszAwM4ZTA-BxRe6_9p2a?dl=0). 26 | ## Dataset: 27 | - Please manually download the datasets [Office](https://www.dropbox.com/sh/vja4cdimm0k2um3/AACCKNKV8-HVbEZDPDCyAyf_a?dl=0), [Office-Home](https://www.dropbox.com/sh/vja4cdimm0k2um3/AACCKNKV8-HVbEZDPDCyAyf_a?dl=0), [VisDA](https://github.com/VisionLearningGroup/taskcv-2017-public/tree/master/classification), [Office-Caltech](https://www.dropbox.com/sh/vja4cdimm0k2um3/AACCKNKV8-HVbEZDPDCyAyf_a?dl=0) from the official websites, and modify the path of images in each '.txt' file under the folder './data/'. 28 | - The script "download_visda2017.sh" under './data/VISDA-C/' can also be used to download VisDA. 29 | ## Training 30 | ### Office-31 31 | ```bash 32 | sh run_office_uda.sh 33 | ``` 34 | ### Office-Home 35 | ```bash 36 | sh run_office_home_uda.sh 37 | ``` 38 | ### VisDA 39 | ```bash 40 | sh run_visda.sh 41 | ``` 42 | # Reference 43 | 44 | [ViT](https://github.com/jeonsworld/ViT-pytorch) 45 | 46 | [TransUNet](https://github.com/Beckschen/TransUNet) 47 | 48 | [SHOT](https://github.com/tim-learn/SHOT) 49 | -------------------------------------------------------------------------------- /TransUNet/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /TransUNet/README.md: -------------------------------------------------------------------------------- 1 | # TransUNet 2 | This repo holds code for [TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation](https://arxiv.org/pdf/2102.04306.pdf) 3 | 4 | ## Usage 5 | 6 | ### 1. Download Google pre-trained ViT models 7 | * [Get models in this link](https://console.cloud.google.com/storage/vit_models/): R50-ViT-B_16, ViT-B_16, ViT-L_16... 8 | ```bash 9 | wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz && 10 | mkdir ../model/vit_checkpoint/imagenet21k && 11 | mv {MODEL_NAME}.npz ../model/vit_checkpoint/imagenet21k/{MODEL_NAME}.npz 12 | ``` 13 | 14 | ### 2. Prepare data 15 | 16 | Please go to ["./datasets/README.md"](datasets/README.md) for details, or please send an Email to jienengchen01 AT gmail.com to request the preprocessed data. If you would like to use the preprocessed data, please use it for research purposes and do not redistribute it. 17 | 18 | ### 3. Environment 19 | 20 | Please prepare an environment with python=3.7, and then use the command "pip install -r requirements.txt" for the dependencies. 21 | 22 | ### 4. Train/Test 23 | 24 | - Run the train script on synapse dataset. The batch size can be reduced to 12 or 6 to save memory (please also decrease the base_lr linearly), and both can reach similar performance. 25 | 26 | ```bash 27 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset Synapse --vit_name R50-ViT-B_16 28 | ``` 29 | 30 | - Run the test script on synapse dataset. It supports testing for both 2D images and 3D volumes. 
31 | 32 | ```bash 33 | python test.py --dataset Synapse --vit_name R50-ViT-B_16 34 | ``` 35 | 36 | ## Reference 37 | * [Google ViT](https://github.com/google-research/vision_transformer) 38 | * [ViT-pytorch](https://github.com/jeonsworld/ViT-pytorch) 39 | * [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) 40 | 41 | ## Citations 42 | 43 | ```bibtex 44 | @article{chen2021transunet, 45 | title={TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation}, 46 | author={Chen, Jieneng and Lu, Yongyi and Yu, Qihang and Luo, Xiangde and Adeli, Ehsan and Wang, Yan and Lu, Le and Yuille, Alan L., and Zhou, Yuyin}, 47 | journal={arXiv preprint arXiv:2102.04306}, 48 | year={2021} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /TransUNet/networks/vit_seg_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | def get_b16_config(): 4 | """Returns the ViT-B/16 configuration.""" 5 | config = ml_collections.ConfigDict() 6 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 7 | config.hidden_size = 768 8 | config.transformer = ml_collections.ConfigDict() 9 | config.transformer.mlp_dim = 3072 10 | config.transformer.num_heads = 12 11 | config.transformer.num_layers = 12 12 | config.transformer.attention_dropout_rate = 0.0 13 | config.transformer.dropout_rate = 0.1 14 | 15 | config.classifier = 'seg' 16 | config.representation_size = None 17 | config.resnet_pretrained_path = None 18 | config.pretrained_path = './model/vit_checkpoint/imagenet21k/ViT-B_16.npz' 19 | config.patch_size = 16 20 | 21 | config.decoder_channels = (256, 128, 64, 16) 22 | config.n_classes = 2 23 | config.activation = 'softmax' 24 | return config 25 | 26 | 27 | def get_testing(): 28 | """Returns a minimal configuration for testing.""" 29 | config = ml_collections.ConfigDict() 30 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 31 | config.hidden_size = 1 32 | config.transformer = ml_collections.ConfigDict() 33 | config.transformer.mlp_dim = 1 34 | config.transformer.num_heads = 1 35 | config.transformer.num_layers = 1 36 | config.transformer.attention_dropout_rate = 0.0 37 | config.transformer.dropout_rate = 0.1 38 | config.classifier = 'token' 39 | config.representation_size = None 40 | return config 41 | 42 | def get_r50_b16_config(): 43 | """Returns the Resnet50 + ViT-B/16 configuration.""" 44 | config = get_b16_config() 45 | config.patches.grid = (16, 16) 46 | config.resnet = ml_collections.ConfigDict() 47 | config.resnet.num_layers = (3, 4, 9) 48 | config.resnet.width_factor = 1 49 | 50 | config.classifier = 'seg' 51 | config.pretrained_path = './model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' 52 | config.decoder_channels = (256, 128, 64, 16) 53 | config.skip_channels = [512, 256, 64, 16] 54 | config.n_classes = 2 55 | config.n_skip = 3 56 | config.activation = 'softmax' 57 | 58 | return config 59 | 60 | 61 | def get_b32_config(): 62 | """Returns the ViT-B/32 configuration.""" 63 | config = get_b16_config() 64 | config.patches.size = (32, 32) 65 | config.pretrained_path = './model/vit_checkpoint/imagenet21k/ViT-B_32.npz' 66 | return config 67 | 68 | 69 | def get_l16_config(): 70 | """Returns the ViT-L/16 configuration.""" 71 | config = ml_collections.ConfigDict() 72 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 73 | config.hidden_size = 1024 74 | config.transformer = ml_collections.ConfigDict() 75 | 
config.transformer.mlp_dim = 4096 76 | config.transformer.num_heads = 16 77 | config.transformer.num_layers = 24 78 | config.transformer.attention_dropout_rate = 0.0 79 | config.transformer.dropout_rate = 0.1 80 | config.representation_size = None 81 | 82 | # custom 83 | config.classifier = 'seg' 84 | config.resnet_pretrained_path = None 85 | config.pretrained_path = './model/vit_checkpoint/imagenet21k/ViT-L_16.npz' 86 | config.decoder_channels = (256, 128, 64, 16) 87 | config.n_classes = 2 88 | config.activation = 'softmax' 89 | return config 90 | 91 | 92 | def get_r50_l16_config(): 93 | """Returns the Resnet50 + ViT-L/16 configuration. customized """ 94 | config = get_l16_config() 95 | config.patches.grid = (16, 16) 96 | config.resnet = ml_collections.ConfigDict() 97 | config.resnet.num_layers = (3, 4, 9) 98 | config.resnet.width_factor = 1 99 | 100 | config.classifier = 'seg' 101 | config.resnet_pretrained_path = './model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' 102 | config.decoder_channels = (256, 128, 64, 16) 103 | config.skip_channels = [512, 256, 64, 16] 104 | config.n_classes = 2 105 | config.activation = 'softmax' 106 | return config 107 | 108 | 109 | def get_l32_config(): 110 | """Returns the ViT-L/32 configuration.""" 111 | config = get_l16_config() 112 | config.patches.size = (32, 32) 113 | return config 114 | 115 | 116 | def get_h14_config(): 117 | """Returns the ViT-L/16 configuration.""" 118 | config = ml_collections.ConfigDict() 119 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 120 | config.hidden_size = 1280 121 | config.transformer = ml_collections.ConfigDict() 122 | config.transformer.mlp_dim = 5120 123 | config.transformer.num_heads = 16 124 | config.transformer.num_layers = 32 125 | config.transformer.attention_dropout_rate = 0.0 126 | config.transformer.dropout_rate = 0.1 127 | config.classifier = 'token' 128 | config.representation_size = None 129 | 130 | return config 131 | -------------------------------------------------------------------------------- /TransUNet/networks/vit_seg_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | 16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 17 | from torch.nn.modules.utils import _pair 18 | from scipy import ndimage 19 | from . 
import vit_seg_configs as configs 20 | from .vit_seg_modeling_resnet_skip import ResNetV2 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 27 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 28 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 29 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 30 | FC_0 = "MlpBlock_3/Dense_0" 31 | FC_1 = "MlpBlock_3/Dense_1" 32 | ATTENTION_NORM = "LayerNorm_0" 33 | MLP_NORM = "LayerNorm_2" 34 | 35 | 36 | def np2th(weights, conv=False): 37 | """Possibly convert HWIO to OIHW.""" 38 | if conv: 39 | weights = weights.transpose([3, 2, 0, 1]) 40 | return torch.from_numpy(weights) 41 | 42 | 43 | def swish(x): 44 | return x * torch.sigmoid(x) 45 | 46 | 47 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} 48 | 49 | 50 | class Attention(nn.Module): 51 | def __init__(self, config, vis): 52 | super(Attention, self).__init__() 53 | self.vis = vis 54 | self.num_attention_heads = config.transformer["num_heads"] 55 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 56 | self.all_head_size = self.num_attention_heads * self.attention_head_size 57 | 58 | self.query = Linear(config.hidden_size, self.all_head_size) 59 | self.key = Linear(config.hidden_size, self.all_head_size) 60 | self.value = Linear(config.hidden_size, self.all_head_size) 61 | 62 | self.out = Linear(config.hidden_size, config.hidden_size) 63 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 64 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 65 | 66 | self.softmax = Softmax(dim=-1) 67 | 68 | def transpose_for_scores(self, x): 69 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 70 | x = x.view(*new_x_shape) 71 | return x.permute(0, 2, 1, 3) 72 | 73 | def forward(self, hidden_states): 74 | mixed_query_layer = self.query(hidden_states) 75 | mixed_key_layer = self.key(hidden_states) 76 | mixed_value_layer = self.value(hidden_states) 77 | 78 | query_layer = self.transpose_for_scores(mixed_query_layer) 79 | key_layer = self.transpose_for_scores(mixed_key_layer) 80 | value_layer = self.transpose_for_scores(mixed_value_layer) 81 | 82 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 83 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 84 | attention_probs = self.softmax(attention_scores) 85 | weights = attention_probs if self.vis else None 86 | attention_probs = self.attn_dropout(attention_probs) 87 | 88 | context_layer = torch.matmul(attention_probs, value_layer) 89 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 90 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 91 | context_layer = context_layer.view(*new_context_layer_shape) 92 | attention_output = self.out(context_layer) 93 | attention_output = self.proj_dropout(attention_output) 94 | return attention_output, weights 95 | 96 | 97 | class Mlp(nn.Module): 98 | def __init__(self, config): 99 | super(Mlp, self).__init__() 100 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 101 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 102 | self.act_fn = ACT2FN["gelu"] 103 | self.dropout = Dropout(config.transformer["dropout_rate"]) 104 | 105 | self._init_weights() 106 | 107 | def _init_weights(self): 108 | nn.init.xavier_uniform_(self.fc1.weight) 109 | 
nn.init.xavier_uniform_(self.fc2.weight) 110 | nn.init.normal_(self.fc1.bias, std=1e-6) 111 | nn.init.normal_(self.fc2.bias, std=1e-6) 112 | 113 | def forward(self, x): 114 | x = self.fc1(x) 115 | x = self.act_fn(x) 116 | x = self.dropout(x) 117 | x = self.fc2(x) 118 | x = self.dropout(x) 119 | return x 120 | 121 | 122 | class Embeddings(nn.Module): 123 | """Construct the embeddings from patch, position embeddings. 124 | """ 125 | def __init__(self, config, img_size, in_channels=3): 126 | super(Embeddings, self).__init__() 127 | self.hybrid = None 128 | self.config = config 129 | # img_size = _pair(img_size) 130 | 131 | if config.patches.get("grid") is not None: # ResNet 132 | grid_size = config.patches["grid"] 133 | patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) 134 | patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) 135 | n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) 136 | self.hybrid = True 137 | else: 138 | patch_size = _pair(config.patches["size"]) 139 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) 140 | self.hybrid = False 141 | 142 | if self.hybrid: 143 | self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) 144 | in_channels = self.hybrid_model.width * 16 145 | self.patch_embeddings = Conv2d(in_channels=in_channels, 146 | out_channels=config.hidden_size, 147 | kernel_size=patch_size, 148 | stride=patch_size) 149 | # self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) 150 | 151 | self.dropout = Dropout(config.transformer["dropout_rate"]) 152 | 153 | 154 | def forward(self, x): 155 | if self.hybrid: 156 | x, features = self.hybrid_model(x) 157 | else: 158 | features = None 159 | x = self.patch_embeddings(x) # (B, hidden. 
n_patches^(1/2), n_patches^(1/2)) 160 | x = x.flatten(2) 161 | x = x.transpose(-1, -2) # (B, n_patches, hidden) 162 | embeddings = x 163 | embeddings = self.dropout(embeddings) 164 | 165 | return embeddings, features 166 | 167 | 168 | class Block(nn.Module): 169 | def __init__(self, config, vis): 170 | super(Block, self).__init__() 171 | self.hidden_size = config.hidden_size 172 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 173 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 174 | self.ffn = Mlp(config) 175 | self.attn = Attention(config, vis) 176 | 177 | def forward(self, x): 178 | h = x 179 | x = self.attention_norm(x) 180 | x, weights = self.attn(x) 181 | x = x + h 182 | 183 | h = x 184 | x = self.ffn_norm(x) 185 | x = self.ffn(x) 186 | x = x + h 187 | return x, weights 188 | 189 | def load_from(self, weights, n_block): 190 | ROOT = f"Transformer/encoderblock_{n_block}" 191 | with torch.no_grad(): 192 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() 193 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 194 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() 195 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() 196 | 197 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 198 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) 199 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 200 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 201 | 202 | self.attn.query.weight.copy_(query_weight) 203 | self.attn.key.weight.copy_(key_weight) 204 | self.attn.value.weight.copy_(value_weight) 205 | self.attn.out.weight.copy_(out_weight) 206 | self.attn.query.bias.copy_(query_bias) 207 | self.attn.key.bias.copy_(key_bias) 208 | self.attn.value.bias.copy_(value_bias) 209 | self.attn.out.bias.copy_(out_bias) 210 | 211 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() 212 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() 213 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() 214 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() 215 | 216 | self.ffn.fc1.weight.copy_(mlp_weight_0) 217 | self.ffn.fc2.weight.copy_(mlp_weight_1) 218 | self.ffn.fc1.bias.copy_(mlp_bias_0) 219 | self.ffn.fc2.bias.copy_(mlp_bias_1) 220 | 221 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) 222 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) 223 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) 224 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) 225 | 226 | 227 | class Encoder(nn.Module): 228 | def __init__(self, config, vis): 229 | super(Encoder, self).__init__() 230 | self.vis = vis 231 | self.layer = nn.ModuleList() 232 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 233 | for _ in range(config.transformer["num_layers"]): 234 | layer = Block(config, vis) 235 | self.layer.append(copy.deepcopy(layer)) 236 | 237 | def forward(self, hidden_states): 238 | attn_weights = [] 239 | for layer_block in self.layer: 240 | hidden_states, weights = layer_block(hidden_states) 241 | if self.vis: 242 | attn_weights.append(weights) 243 | encoded = self.encoder_norm(hidden_states) 244 | return 
encoded, attn_weights 245 | 246 | 247 | class Transformer(nn.Module): 248 | def __init__(self, config, img_size, vis): 249 | super(Transformer, self).__init__() 250 | self.embeddings = Embeddings(config, img_size=img_size) 251 | self.encoder = Encoder(config, vis) 252 | 253 | def forward(self, input_ids): 254 | embedding_output, features = self.embeddings(input_ids) 255 | encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) 256 | return encoded, attn_weights, features 257 | 258 | 259 | class Conv2dReLU(nn.Sequential): 260 | def __init__( 261 | self, 262 | in_channels, 263 | out_channels, 264 | kernel_size, 265 | padding=0, 266 | stride=1, 267 | use_batchnorm=True, 268 | ): 269 | conv = nn.Conv2d( 270 | in_channels, 271 | out_channels, 272 | kernel_size, 273 | stride=stride, 274 | padding=padding, 275 | bias=not (use_batchnorm), 276 | ) 277 | relu = nn.ReLU(inplace=True) 278 | 279 | bn = nn.BatchNorm2d(out_channels) 280 | 281 | super(Conv2dReLU, self).__init__(conv, bn, relu) 282 | 283 | 284 | class DecoderBlock(nn.Module): 285 | def __init__( 286 | self, 287 | in_channels, 288 | out_channels, 289 | skip_channels=0, 290 | use_batchnorm=True, 291 | ): 292 | super().__init__() 293 | self.conv1 = Conv2dReLU( 294 | in_channels + skip_channels, 295 | out_channels, 296 | kernel_size=3, 297 | padding=1, 298 | use_batchnorm=use_batchnorm, 299 | ) 300 | self.conv2 = Conv2dReLU( 301 | out_channels, 302 | out_channels, 303 | kernel_size=3, 304 | padding=1, 305 | use_batchnorm=use_batchnorm, 306 | ) 307 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 308 | 309 | def forward(self, x, skip=None): 310 | x = self.up(x) 311 | if skip is not None: 312 | x = torch.cat([x, skip], dim=1) 313 | x = self.conv1(x) 314 | x = self.conv2(x) 315 | return x 316 | 317 | 318 | class SegmentationHead(nn.Sequential): 319 | 320 | def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): 321 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 322 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 323 | super().__init__(conv2d, upsampling) 324 | 325 | 326 | class DecoderCup(nn.Module): 327 | def __init__(self, config): 328 | super().__init__() 329 | self.config = config 330 | head_channels = 512 331 | # self.conv_more = Conv2dReLU( 332 | # config.hidden_size, 333 | # 1024, 334 | # kernel_size=3, 335 | # padding=1, 336 | # use_batchnorm=True, 337 | # ) 338 | # self.conv_more_ = Conv2dReLU( 339 | # 1024, 340 | # 2048, 341 | # kernel_size=3, 342 | # stride=2, 343 | # padding=1, 344 | # use_batchnorm=True, 345 | # ) 346 | self.fc= nn.Linear(config.hidden_size,2048) 347 | decoder_channels = config.decoder_channels 348 | in_channels = [head_channels] + list(decoder_channels[:-1]) 349 | out_channels = decoder_channels 350 | 351 | if self.config.n_skip != 0: 352 | skip_channels = self.config.skip_channels 353 | for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip 354 | skip_channels[3-i]=0 355 | 356 | else: 357 | skip_channels=[0,0,0,0] 358 | 359 | blocks = [ 360 | DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) 361 | ] 362 | self.blocks = nn.ModuleList(blocks) 363 | self.avgpool= nn.AdaptiveAvgPool1d(1) 364 | # self.transfermer_f34=nn.Transformer(nhead=4, num_encoder_layers=3,d_model=768) 365 | def forward(self, hidden_states, features=None): 366 | B, n_patch, hidden = hidden_states.size() # 
reshape from (B, n_patch, hidden) to (B, h, w, hidden) 367 | h, w = int(features[0].shape[2]/2), int(features[0].shape[3]/2) 368 | ### unified multi-scale transformer 369 | 370 | x = hidden_states.permute(0, 2, 1) 371 | x = self.avgpool(x) 372 | ### for vis 373 | # vis = x.contiguous().view(B, hidden, h, w) 374 | # vis = functional.interpolate(vis, size=(224,224), mode="bilinear", align_corners=False) 375 | 376 | x = x.contiguous().view(B, hidden) 377 | x = self.fc(x) 378 | 379 | return x 380 | 381 | 382 | class VisionTransformer(nn.Module): 383 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 384 | super(VisionTransformer, self).__init__() 385 | self.num_classes = num_classes 386 | self.zero_head = zero_head 387 | self.classifier = config.classifier 388 | self.transformer = Transformer(config, img_size, vis) 389 | self.decoder = DecoderCup(config) 390 | 391 | self.config = config 392 | 393 | def forward(self, x): 394 | if x.size()[1] == 1: 395 | x = x.repeat(1,3,1,1) 396 | x0, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) 397 | x1 = self.decoder(x0, features) 398 | f=list(reversed(features)) 399 | f.append(x1) 400 | f.insert(0, x) 401 | return f,x1 402 | 403 | def load_from(self, weights): 404 | with torch.no_grad(): 405 | 406 | res_weight = weights 407 | self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) 408 | self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) 409 | 410 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) 411 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) 412 | 413 | # posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) 414 | # 415 | # posemb_new = self.transformer.embeddings.position_embeddings 416 | # if posemb.size() == posemb_new.size(): 417 | # self.transformer.embeddings.position_embeddings.copy_(posemb) 418 | # elif posemb.size()[1]-1 == posemb_new.size()[1]: 419 | # posemb = posemb[:, 1:] 420 | # self.transformer.embeddings.position_embeddings.copy_(posemb) 421 | # else: 422 | # logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) 423 | # ntok_new = posemb_new.size(1) 424 | # if self.classifier == "seg": 425 | # _, posemb_grid = posemb[:, :1], posemb[0, 1:] 426 | # gs_old = int(np.sqrt(len(posemb_grid))) 427 | # gs_new = int(np.sqrt(ntok_new)) 428 | # print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) 429 | # posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 430 | # zoom = (gs_new / gs_old, gs_new / gs_old, 1) 431 | # posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np 432 | # posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 433 | # posemb = posemb_grid 434 | # 435 | # # self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) 436 | # self.transformer.embeddings.position_embeddings[:,0:posemb.shape[1],:]=np2th(posemb) 437 | # Encoder whole 438 | for bname, block in self.transformer.encoder.named_children(): 439 | for uname, unit in block.named_children(): 440 | unit.load_from(weights, n_block=uname) 441 | 442 | if self.transformer.embeddings.hybrid: 443 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) 444 | gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) 445 | gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) 446 | 
self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) 447 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) 448 | 449 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): 450 | for uname, unit in block.named_children(): 451 | unit.load_from(res_weight, n_block=bname, n_unit=uname) 452 | 453 | CONFIGS = { 454 | 'ViT-B_16': configs.get_b16_config(), 455 | 'ViT-B_32': configs.get_b32_config(), 456 | 'ViT-L_16': configs.get_l16_config(), 457 | 'ViT-L_32': configs.get_l32_config(), 458 | 'ViT-H_14': configs.get_h14_config(), 459 | 'R50-ViT-B_16': configs.get_r50_b16_config(), 460 | 'R50-ViT-L_16': configs.get_r50_l16_config(), 461 | 'testing': configs.get_testing(), 462 | } 463 | 464 | 465 | -------------------------------------------------------------------------------- /TransUNet/networks/vit_seg_modeling_resnet_skip.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from os.path import join as pjoin 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def np2th(weights, conv=False): 12 | """Possibly convert HWIO to OIHW.""" 13 | if conv: 14 | weights = weights.transpose([3, 2, 0, 1]) 15 | return torch.from_numpy(weights) 16 | 17 | 18 | class StdConv2d(nn.Conv2d): 19 | 20 | def forward(self, x): 21 | w = self.weight 22 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 23 | w = (w - m) / torch.sqrt(v + 1e-5) 24 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 25 | self.dilation, self.groups) 26 | 27 | 28 | def conv3x3(cin, cout, stride=1, groups=1, bias=False): 29 | return StdConv2d(cin, cout, kernel_size=3, stride=stride, 30 | padding=1, bias=bias, groups=groups) 31 | 32 | 33 | def conv1x1(cin, cout, stride=1, bias=False): 34 | return StdConv2d(cin, cout, kernel_size=1, stride=stride, 35 | padding=0, bias=bias) 36 | 37 | 38 | class PreActBottleneck(nn.Module): 39 | """Pre-activation (v2) bottleneck block. 40 | """ 41 | 42 | def __init__(self, cin, cout=None, cmid=None, stride=1): 43 | super().__init__() 44 | cout = cout or cin 45 | cmid = cmid or cout//4 46 | 47 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) 48 | self.conv1 = conv1x1(cin, cmid, bias=False) 49 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) 50 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! 51 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) 52 | self.conv3 = conv1x1(cmid, cout, bias=False) 53 | self.relu = nn.ReLU(inplace=True) 54 | 55 | if (stride != 1 or cin != cout): 56 | # Projection also with pre-activation according to paper. 
57 | self.downsample = conv1x1(cin, cout, stride, bias=False) 58 | self.gn_proj = nn.GroupNorm(cout, cout) 59 | 60 | def forward(self, x): 61 | 62 | # Residual branch 63 | residual = x 64 | if hasattr(self, 'downsample'): 65 | residual = self.downsample(x) 66 | residual = self.gn_proj(residual) 67 | 68 | # Unit's branch 69 | y = self.relu(self.gn1(self.conv1(x))) 70 | y = self.relu(self.gn2(self.conv2(y))) 71 | y = self.gn3(self.conv3(y)) 72 | 73 | y = self.relu(residual + y) 74 | return y 75 | 76 | def load_from(self, weights, n_block, n_unit): 77 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) 78 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) 79 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) 80 | 81 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) 82 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) 83 | 84 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) 85 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) 86 | 87 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) 88 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) 89 | 90 | self.conv1.weight.copy_(conv1_weight) 91 | self.conv2.weight.copy_(conv2_weight) 92 | self.conv3.weight.copy_(conv3_weight) 93 | 94 | self.gn1.weight.copy_(gn1_weight.view(-1)) 95 | self.gn1.bias.copy_(gn1_bias.view(-1)) 96 | 97 | self.gn2.weight.copy_(gn2_weight.view(-1)) 98 | self.gn2.bias.copy_(gn2_bias.view(-1)) 99 | 100 | self.gn3.weight.copy_(gn3_weight.view(-1)) 101 | self.gn3.bias.copy_(gn3_bias.view(-1)) 102 | 103 | if hasattr(self, 'downsample'): 104 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) 105 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) 106 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) 107 | 108 | self.downsample.weight.copy_(proj_conv_weight) 109 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) 110 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) 111 | 112 | class ResNetV2(nn.Module): 113 | """Implementation of Pre-activation (v2) ResNet mode.""" 114 | 115 | def __init__(self, block_units, width_factor): 116 | super().__init__() 117 | width = int(64 * width_factor) 118 | self.width = width 119 | 120 | self.root = nn.Sequential(OrderedDict([ 121 | ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), 122 | ('gn', nn.GroupNorm(32, width, eps=1e-6)), 123 | ('relu', nn.ReLU(inplace=True)), 124 | # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) 125 | ])) 126 | 127 | self.body = nn.Sequential(OrderedDict([ 128 | ('block1', nn.Sequential(OrderedDict( 129 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + 130 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], 131 | ))), 132 | ('block2', nn.Sequential(OrderedDict( 133 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + 134 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], 135 | ))), 136 | ('block3', nn.Sequential(OrderedDict( 137 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + 138 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], 139 | ))), 140 | ])) 141 | 142 | def forward(self, x): 
143 | features = [] 144 | b, c, in_size, size_ = x.size() 145 | x = self.root(x) 146 | features.append(x) 147 | x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) 148 | for i in range(len(self.body)-1): 149 | x = self.body[i](x) 150 | left_size = int(in_size / 4 / (i+1)) 151 | right_size =int(size_ / 4 / (i+1)) 152 | if x.size()[2] != left_size: 153 | pad = left_size - x.size()[2] 154 | assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), left_size) 155 | feat = torch.zeros((b, x.size()[1], left_size,right_size ), device=x.device) 156 | feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] 157 | else: 158 | feat = x 159 | features.append(feat) 160 | x = self.body[-1](x) 161 | return x, features[::-1] 162 | -------------------------------------------------------------------------------- /TransUNet/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | torchvision==0.5.0 3 | numpy 4 | tqdm 5 | tensorboard 6 | tensorboardX 7 | ml-collections 8 | medpy 9 | SimpleITK 10 | scipy 11 | h5py 12 | -------------------------------------------------------------------------------- /TransUNet/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from datasets.dataset_synapse import Synapse_dataset 13 | from utils import test_single_volume 14 | from networks.vit_seg_modeling import VisionTransformer as ViT_seg 15 | from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--volume_path', type=str, 19 | default='../data/Synapse/test_vol_h5', help='root dir for validation volume data') # for acdc volume_path=root_dir 20 | parser.add_argument('--dataset', type=str, 21 | default='Synapse', help='experiment_name') 22 | parser.add_argument('--num_classes', type=int, 23 | default=4, help='output channel of network') 24 | parser.add_argument('--list_dir', type=str, 25 | default='./lists/lists_Synapse', help='list dir') 26 | 27 | parser.add_argument('--max_iterations', type=int,default=20000, help='maximum epoch number to train') 28 | parser.add_argument('--max_epochs', type=int, default=30, help='maximum epoch number to train') 29 | parser.add_argument('--batch_size', type=int, default=24, 30 | help='batch_size per gpu') 31 | parser.add_argument('--img_size', type=int, default=224, help='input patch size of network input') 32 | parser.add_argument('--is_savenii', action="store_true", help='whether to save results during inference') 33 | 34 | parser.add_argument('--n_skip', type=int, default=3, help='using number of skip-connect, default is num') 35 | parser.add_argument('--vit_name', type=str, default='ViT-B_16', help='select one vit model') 36 | 37 | parser.add_argument('--test_save_dir', type=str, default='../predictions', help='saving prediction as nii!') 38 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 39 | parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate') 40 | parser.add_argument('--seed', type=int, default=1234, help='random seed') 41 | parser.add_argument('--vit_patches_size', type=int, default=16, help='vit_patches_size, default is 16') 42 | args = 
parser.parse_args() 43 | 44 | 45 | def inference(args, model, test_save_path=None): 46 | db_test = args.Dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir) 47 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 48 | logging.info("{} test iterations per epoch".format(len(testloader))) 49 | model.eval() 50 | metric_list = 0.0 51 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 52 | h, w = sampled_batch["image"].size()[2:] 53 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0] 54 | metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size], 55 | test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing) 56 | metric_list += np.array(metric_i) 57 | logging.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1])) 58 | metric_list = metric_list / len(db_test) 59 | for i in range(1, args.num_classes): 60 | logging.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i-1][0], metric_list[i-1][1])) 61 | performance = np.mean(metric_list, axis=0)[0] 62 | mean_hd95 = np.mean(metric_list, axis=0)[1] 63 | logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f' % (performance, mean_hd95)) 64 | return "Testing Finished!" 65 | 66 | 67 | if __name__ == "__main__": 68 | 69 | if not args.deterministic: 70 | cudnn.benchmark = True 71 | cudnn.deterministic = False 72 | else: 73 | cudnn.benchmark = False 74 | cudnn.deterministic = True 75 | random.seed(args.seed) 76 | np.random.seed(args.seed) 77 | torch.manual_seed(args.seed) 78 | torch.cuda.manual_seed(args.seed) 79 | 80 | dataset_config = { 81 | 'Synapse': { 82 | 'Dataset': Synapse_dataset, 83 | 'volume_path': '../data/Synapse/test_vol_h5', 84 | 'list_dir': './lists/lists_Synapse', 85 | 'num_classes': 9, 86 | 'z_spacing': 1, 87 | }, 88 | } 89 | dataset_name = args.dataset 90 | args.num_classes = dataset_config[dataset_name]['num_classes'] 91 | args.volume_path = dataset_config[dataset_name]['volume_path'] 92 | args.Dataset = dataset_config[dataset_name]['Dataset'] 93 | args.list_dir = dataset_config[dataset_name]['list_dir'] 94 | args.z_spacing = dataset_config[dataset_name]['z_spacing'] 95 | args.is_pretrain = True 96 | 97 | # name the same snapshot defined in train script! 
98 | args.exp = 'TU_' + dataset_name + str(args.img_size) 99 | snapshot_path = "../model/{}/{}".format(args.exp, 'TU') 100 | snapshot_path = snapshot_path + '_pretrain' if args.is_pretrain else snapshot_path 101 | snapshot_path += '_' + args.vit_name 102 | snapshot_path = snapshot_path + '_skip' + str(args.n_skip) 103 | snapshot_path = snapshot_path + '_vitpatch' + str(args.vit_patches_size) if args.vit_patches_size!=16 else snapshot_path 104 | snapshot_path = snapshot_path + '_epo' + str(args.max_epochs) if args.max_epochs != 30 else snapshot_path 105 | if dataset_name == 'ACDC': # using max_epoch instead of iteration to control training duration 106 | snapshot_path = snapshot_path + '_' + str(args.max_iterations)[0:2] + 'k' if args.max_iterations != 30000 else snapshot_path 107 | snapshot_path = snapshot_path+'_bs'+str(args.batch_size) 108 | snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path 109 | snapshot_path = snapshot_path + '_'+str(args.img_size) 110 | snapshot_path = snapshot_path + '_s'+str(args.seed) if args.seed!=1234 else snapshot_path 111 | 112 | config_vit = CONFIGS_ViT_seg[args.vit_name] 113 | config_vit.n_classes = args.num_classes 114 | config_vit.n_skip = args.n_skip 115 | config_vit.patches.size = (args.vit_patches_size, args.vit_patches_size) 116 | if args.vit_name.find('R50') !=-1: 117 | config_vit.patches.grid = (int(args.img_size/args.vit_patches_size), int(args.img_size/args.vit_patches_size)) 118 | net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda() 119 | 120 | snapshot = os.path.join(snapshot_path, 'best_model.pth') 121 | if not os.path.exists(snapshot): snapshot = snapshot.replace('best_model', 'epoch_'+str(args.max_epochs-1)) 122 | net.load_state_dict(torch.load(snapshot)) 123 | snapshot_name = snapshot_path.split('/')[-1] 124 | 125 | log_folder = './test_log/test_log_' + args.exp 126 | os.makedirs(log_folder, exist_ok=True) 127 | logging.basicConfig(filename=log_folder + '/'+snapshot_name+".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 128 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 129 | logging.info(str(args)) 130 | logging.info(snapshot_name) 131 | 132 | if args.is_savenii: 133 | args.test_save_dir = '../predictions' 134 | test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name) 135 | os.makedirs(test_save_path, exist_ok=True) 136 | else: 137 | test_save_path = None 138 | inference(args, net, test_save_path) 139 | 140 | 141 | -------------------------------------------------------------------------------- /TransUNet/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from networks.vit_seg_modeling import VisionTransformer as ViT_seg 9 | from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg 10 | from trainer import trainer_synapse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--root_path', type=str, 14 | default='../data/Synapse/train_npz', help='root dir for data') 15 | parser.add_argument('--dataset', type=str, 16 | default='Synapse', help='experiment_name') 17 | parser.add_argument('--list_dir', type=str, 18 | default='./lists/lists_Synapse', help='list dir') 19 | parser.add_argument('--num_classes', type=int, 20 | default=9, help='output channel of 
network') 21 | parser.add_argument('--max_iterations', type=int, 22 | default=30000, help='maximum epoch number to train') 23 | parser.add_argument('--max_epochs', type=int, 24 | default=150, help='maximum epoch number to train') 25 | parser.add_argument('--batch_size', type=int, 26 | default=24, help='batch_size per gpu') 27 | parser.add_argument('--n_gpu', type=int, default=1, help='total gpu') 28 | parser.add_argument('--deterministic', type=int, default=1, 29 | help='whether use deterministic training') 30 | parser.add_argument('--base_lr', type=float, default=0.01, 31 | help='segmentation network learning rate') 32 | parser.add_argument('--img_size', type=int, 33 | default=224, help='input patch size of network input') 34 | parser.add_argument('--seed', type=int, 35 | default=1234, help='random seed') 36 | parser.add_argument('--n_skip', type=int, 37 | default=3, help='using number of skip-connect, default is num') 38 | parser.add_argument('--vit_name', type=str, 39 | default='R50-ViT-B_16', help='select one vit model') 40 | parser.add_argument('--vit_patches_size', type=int, 41 | default=16, help='vit_patches_size, default is 16') 42 | args = parser.parse_args() 43 | 44 | 45 | if __name__ == "__main__": 46 | if not args.deterministic: 47 | cudnn.benchmark = True 48 | cudnn.deterministic = False 49 | else: 50 | cudnn.benchmark = False 51 | cudnn.deterministic = True 52 | 53 | random.seed(args.seed) 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | torch.cuda.manual_seed(args.seed) 57 | dataset_name = args.dataset 58 | dataset_config = { 59 | 'Synapse': { 60 | 'root_path': '../data/Synapse/train_npz', 61 | 'list_dir': './lists/lists_Synapse', 62 | 'num_classes': 9, 63 | }, 64 | } 65 | args.num_classes = dataset_config[dataset_name]['num_classes'] 66 | args.root_path = dataset_config[dataset_name]['root_path'] 67 | args.list_dir = dataset_config[dataset_name]['list_dir'] 68 | args.is_pretrain = True 69 | args.exp = 'TU_' + dataset_name + str(args.img_size) 70 | snapshot_path = "../model/{}/{}".format(args.exp, 'TU') 71 | snapshot_path = snapshot_path + '_pretrain' if args.is_pretrain else snapshot_path 72 | snapshot_path += '_' + args.vit_name 73 | snapshot_path = snapshot_path + '_skip' + str(args.n_skip) 74 | snapshot_path = snapshot_path + '_vitpatch' + str(args.vit_patches_size) if args.vit_patches_size!=16 else snapshot_path 75 | snapshot_path = snapshot_path+'_'+str(args.max_iterations)[0:2]+'k' if args.max_iterations != 30000 else snapshot_path 76 | snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 30 else snapshot_path 77 | snapshot_path = snapshot_path+'_bs'+str(args.batch_size) 78 | snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path 79 | snapshot_path = snapshot_path + '_'+str(args.img_size) 80 | snapshot_path = snapshot_path + '_s'+str(args.seed) if args.seed!=1234 else snapshot_path 81 | 82 | if not os.path.exists(snapshot_path): 83 | os.makedirs(snapshot_path) 84 | config_vit = CONFIGS_ViT_seg[args.vit_name] 85 | config_vit.n_classes = args.num_classes 86 | config_vit.n_skip = args.n_skip 87 | if args.vit_name.find('R50') != -1: 88 | config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size)) 89 | net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda() 90 | net.load_from(weights=np.load(config_vit.pretrained_path)) 91 | 92 | trainer = {'Synapse': trainer_synapse,} 93 | 
trainer[dataset_name](args, net, snapshot_path) -------------------------------------------------------------------------------- /TransUNet/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import time 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from tensorboardX import SummaryWriter 12 | from torch.nn.modules.loss import CrossEntropyLoss 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | from utils import DiceLoss 16 | from torchvision import transforms 17 | 18 | def trainer_synapse(args, model, snapshot_path): 19 | from datasets.dataset_synapse import Synapse_dataset, RandomGenerator 20 | logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO, 21 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 22 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 23 | logging.info(str(args)) 24 | base_lr = args.base_lr 25 | num_classes = args.num_classes 26 | batch_size = args.batch_size * args.n_gpu 27 | # max_iterations = args.max_iterations 28 | db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train", 29 | transform=transforms.Compose( 30 | [RandomGenerator(output_size=[args.img_size, args.img_size])])) 31 | print("The length of train set is: {}".format(len(db_train))) 32 | 33 | def worker_init_fn(worker_id): 34 | random.seed(args.seed + worker_id) 35 | 36 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, 37 | worker_init_fn=worker_init_fn) 38 | if args.n_gpu > 1: 39 | model = nn.DataParallel(model) 40 | model.train() 41 | ce_loss = CrossEntropyLoss() 42 | dice_loss = DiceLoss(num_classes) 43 | optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 44 | writer = SummaryWriter(snapshot_path + '/log') 45 | iter_num = 0 46 | max_epoch = args.max_epochs 47 | max_iterations = args.max_epochs * len(trainloader) # max_epoch = max_iterations // len(trainloader) + 1 48 | logging.info("{} iterations per epoch. 
{} max iterations ".format(len(trainloader), max_iterations)) 49 | best_performance = 0.0 50 | iterator = tqdm(range(max_epoch), ncols=70) 51 | for epoch_num in iterator: 52 | for i_batch, sampled_batch in enumerate(trainloader): 53 | image_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 54 | image_batch, label_batch = image_batch.cuda(), label_batch.cuda() 55 | outputs = model(image_batch) 56 | loss_ce = ce_loss(outputs, label_batch[:].long()) 57 | loss_dice = dice_loss(outputs, label_batch, softmax=True) 58 | loss = 0.5 * loss_ce + 0.5 * loss_dice 59 | optimizer.zero_grad() 60 | loss.backward() 61 | optimizer.step() 62 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 63 | for param_group in optimizer.param_groups: 64 | param_group['lr'] = lr_ 65 | 66 | iter_num = iter_num + 1 67 | writer.add_scalar('info/lr', lr_, iter_num) 68 | writer.add_scalar('info/total_loss', loss, iter_num) 69 | writer.add_scalar('info/loss_ce', loss_ce, iter_num) 70 | 71 | logging.info('iteration %d : loss : %f, loss_ce: %f' % (iter_num, loss.item(), loss_ce.item())) 72 | 73 | if iter_num % 20 == 0: 74 | image = image_batch[1, 0:1, :, :] 75 | image = (image - image.min()) / (image.max() - image.min()) 76 | writer.add_image('train/Image', image, iter_num) 77 | outputs = torch.argmax(torch.softmax(outputs, dim=1), dim=1, keepdim=True) 78 | writer.add_image('train/Prediction', outputs[1, ...] * 50, iter_num) 79 | labs = label_batch[1, ...].unsqueeze(0) * 50 80 | writer.add_image('train/GroundTruth', labs, iter_num) 81 | 82 | save_interval = 50 # int(max_epoch/6) 83 | if epoch_num > int(max_epoch / 2) and (epoch_num + 1) % save_interval == 0: 84 | save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth') 85 | torch.save(model.state_dict(), save_mode_path) 86 | logging.info("save model to {}".format(save_mode_path)) 87 | 88 | if epoch_num >= max_epoch - 1: 89 | save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth') 90 | torch.save(model.state_dict(), save_mode_path) 91 | logging.info("save model to {}".format(save_mode_path)) 92 | iterator.close() 93 | break 94 | 95 | writer.close() 96 | return "Training Finished!" 
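A minimal, self-contained sketch (not part of the repository) of the two training details that `trainer_synapse` above relies on: the equally weighted CE + Dice objective and the per-iteration polynomial learning-rate decay `lr = base_lr * (1 - iter/max_iter)**0.9`. The helper names `poly_lr` and `training_step` are illustrative only; `model`, `ce_loss`, and `dice_loss` stand in for the ViT_seg network, `CrossEntropyLoss`, and the `DiceLoss` defined in `utils.py` below.

```python
# Illustrative sketch only -- mirrors the loss weighting and LR schedule
# used in trainer_synapse, with hypothetical helper names.
import torch


def poly_lr(base_lr, iter_num, max_iterations, power=0.9):
    # Polynomial ("poly") decay, applied once per optimizer step.
    return base_lr * (1.0 - iter_num / max_iterations) ** power


def training_step(model, optimizer, ce_loss, dice_loss, image, label,
                  base_lr, iter_num, max_iterations):
    outputs = model(image)                               # (B, C, H, W) logits
    loss_ce = ce_loss(outputs, label.long())             # pixel-wise cross-entropy
    loss_dice = dice_loss(outputs, label, softmax=True)  # DiceLoss from utils.py
    loss = 0.5 * loss_ce + 0.5 * loss_dice               # equal weighting, as in the loop above

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Decay the learning rate for the next iteration.
    new_lr = poly_lr(base_lr, iter_num, max_iterations)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return loss.item()
```

The poly schedule is a common choice in segmentation codebases; note that here it is updated after every iteration rather than per epoch, so `max_iterations` is `max_epochs * len(trainloader)` as computed above.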
-------------------------------------------------------------------------------- /TransUNet/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from medpy import metric 4 | from scipy.ndimage import zoom 5 | import torch.nn as nn 6 | import SimpleITK as sitk 7 | 8 | 9 | class DiceLoss(nn.Module): 10 | def __init__(self, n_classes): 11 | super(DiceLoss, self).__init__() 12 | self.n_classes = n_classes 13 | 14 | def _one_hot_encoder(self, input_tensor): 15 | tensor_list = [] 16 | for i in range(self.n_classes): 17 | temp_prob = input_tensor == i # * torch.ones_like(input_tensor) 18 | tensor_list.append(temp_prob.unsqueeze(1)) 19 | output_tensor = torch.cat(tensor_list, dim=1) 20 | return output_tensor.float() 21 | 22 | def _dice_loss(self, score, target): 23 | target = target.float() 24 | smooth = 1e-5 25 | intersect = torch.sum(score * target) 26 | y_sum = torch.sum(target * target) 27 | z_sum = torch.sum(score * score) 28 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 29 | loss = 1 - loss 30 | return loss 31 | 32 | def forward(self, inputs, target, weight=None, softmax=False): 33 | if softmax: 34 | inputs = torch.softmax(inputs, dim=1) 35 | target = self._one_hot_encoder(target) 36 | if weight is None: 37 | weight = [1] * self.n_classes 38 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size()) 39 | class_wise_dice = [] 40 | loss = 0.0 41 | for i in range(0, self.n_classes): 42 | dice = self._dice_loss(inputs[:, i], target[:, i]) 43 | class_wise_dice.append(1.0 - dice.item()) 44 | loss += dice * weight[i] 45 | return loss / self.n_classes 46 | 47 | 48 | def calculate_metric_percase(pred, gt): 49 | pred[pred > 0] = 1 50 | gt[gt > 0] = 1 51 | if pred.sum() > 0 and gt.sum()>0: 52 | dice = metric.binary.dc(pred, gt) 53 | hd95 = metric.binary.hd95(pred, gt) 54 | return dice, hd95 55 | elif pred.sum() > 0 and gt.sum()==0: 56 | return 1, 0 57 | else: 58 | return 0, 0 59 | 60 | 61 | def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1): 62 | image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 63 | if len(image.shape) == 3: 64 | prediction = np.zeros_like(label) 65 | for ind in range(image.shape[0]): 66 | slice = image[ind, :, :] 67 | x, y = slice.shape[0], slice.shape[1] 68 | if x != patch_size[0] or y != patch_size[1]: 69 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0 70 | input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 71 | net.eval() 72 | with torch.no_grad(): 73 | outputs = net(input) 74 | out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0) 75 | out = out.cpu().detach().numpy() 76 | if x != patch_size[0] or y != patch_size[1]: 77 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 78 | else: 79 | pred = out 80 | prediction[ind] = pred 81 | else: 82 | input = torch.from_numpy(image).unsqueeze( 83 | 0).unsqueeze(0).float().cuda() 84 | net.eval() 85 | with torch.no_grad(): 86 | out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0) 87 | prediction = out.cpu().detach().numpy() 88 | metric_list = [] 89 | for i in range(1, classes): 90 | metric_list.append(calculate_metric_percase(prediction == i, label == i)) 91 | 92 | if test_save_path is not None: 93 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 94 | 
prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 95 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 96 | img_itk.SetSpacing((1, 1, z_spacing)) 97 | prd_itk.SetSpacing((1, 1, z_spacing)) 98 | lab_itk.SetSpacing((1, 1, z_spacing)) 99 | sitk.WriteImage(prd_itk, test_save_path + '/'+case + "_pred.nii.gz") 100 | sitk.WriteImage(img_itk, test_save_path + '/'+ case + "_img.nii.gz") 101 | sitk.WriteImage(lab_itk, test_save_path + '/'+ case + "_gt.nii.gz") 102 | return metric_list -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygjwd12345/TransDA/76c76cc42a00ce465c353d51f084eb13d7f53620/data/.DS_Store -------------------------------------------------------------------------------- /data/._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygjwd12345/TransDA/76c76cc42a00ce465c353d51f084eb13d7f53620/data/._.DS_Store -------------------------------------------------------------------------------- /data/VISDA-C/download_visda2017.sh: -------------------------------------------------------------------------------- 1 | wget http://csr.bu.edu/ftp/visda17/clf/train.tar; 2 | tar xvf train.tar; 3 | wget http://csr.bu.edu/ftp/visda17/clf/validation.tar; 4 | tar xvf validation.tar; 5 | 6 | wget http://csr.bu.edu/ftp/visda17/clf/test.tar; 7 | tar xvf test.tar; -------------------------------------------------------------------------------- /data/generate_label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets as datasets 3 | import torch 4 | import glob 5 | 6 | phase=['amazon','caltech','dslr','webcam'] 7 | 8 | dict={'back_pack': 0, 9 | 'bike':1, 10 | 'calculator':2, 11 | 'headphones':3, 12 | 'keyboard':4, 13 | 'laptop_computer':5, 14 | 'monitor':6, 15 | 'mouse':7, 16 | 'mug':8, 17 | 'projector':9} 18 | for i in range(len(phase)): 19 | path = os.path.join('/data2/gyang/DA-transformer-other/object/data/office_caltech',phase[i]) 20 | text_path=phase[i]+'_list.txt' 21 | f=open(text_path,'w') 22 | 23 | for label in os.listdir(path): 24 | img_list = glob.glob(os.path.join(path, label, "*.jpg")) 25 | for img in img_list: 26 | f.write(img + " " + str(dict[label]) + "\n") 27 | print("create txt done...") -------------------------------------------------------------------------------- /data/office-caltech/dslr_list.txt: -------------------------------------------------------------------------------- 1 | ./data/office-caltech/dslr/headphones/frame_0007.jpg 3 2 | ./data/office-caltech/dslr/headphones/frame_0001.jpg 3 3 | ./data/office-caltech/dslr/headphones/frame_0004.jpg 3 4 | ./data/office-caltech/dslr/headphones/frame_0002.jpg 3 5 | ./data/office-caltech/dslr/headphones/frame_0012.jpg 3 6 | ./data/office-caltech/dslr/headphones/frame_0013.jpg 3 7 | ./data/office-caltech/dslr/headphones/frame_0008.jpg 3 8 | ./data/office-caltech/dslr/headphones/frame_0010.jpg 3 9 | ./data/office-caltech/dslr/headphones/frame_0005.jpg 3 10 | ./data/office-caltech/dslr/headphones/frame_0003.jpg 3 11 | ./data/office-caltech/dslr/headphones/frame_0011.jpg 3 12 | ./data/office-caltech/dslr/headphones/frame_0006.jpg 3 13 | ./data/office-caltech/dslr/headphones/frame_0009.jpg 3 14 | ./data/office-caltech/dslr/laptop_computer/frame_0014.jpg 5 15 | 
./data/office-caltech/dslr/laptop_computer/frame_0017.jpg 5 16 | ./data/office-caltech/dslr/laptop_computer/frame_0007.jpg 5 17 | ./data/office-caltech/dslr/laptop_computer/frame_0001.jpg 5 18 | ./data/office-caltech/dslr/laptop_computer/frame_0004.jpg 5 19 | ./data/office-caltech/dslr/laptop_computer/frame_0024.jpg 5 20 | ./data/office-caltech/dslr/laptop_computer/frame_0002.jpg 5 21 | ./data/office-caltech/dslr/laptop_computer/frame_0023.jpg 5 22 | ./data/office-caltech/dslr/laptop_computer/frame_0022.jpg 5 23 | ./data/office-caltech/dslr/laptop_computer/frame_0018.jpg 5 24 | ./data/office-caltech/dslr/laptop_computer/frame_0012.jpg 5 25 | ./data/office-caltech/dslr/laptop_computer/frame_0021.jpg 5 26 | ./data/office-caltech/dslr/laptop_computer/frame_0013.jpg 5 27 | ./data/office-caltech/dslr/laptop_computer/frame_0020.jpg 5 28 | ./data/office-caltech/dslr/laptop_computer/frame_0008.jpg 5 29 | ./data/office-caltech/dslr/laptop_computer/frame_0010.jpg 5 30 | ./data/office-caltech/dslr/laptop_computer/frame_0005.jpg 5 31 | ./data/office-caltech/dslr/laptop_computer/frame_0003.jpg 5 32 | ./data/office-caltech/dslr/laptop_computer/frame_0011.jpg 5 33 | ./data/office-caltech/dslr/laptop_computer/frame_0015.jpg 5 34 | ./data/office-caltech/dslr/laptop_computer/frame_0016.jpg 5 35 | ./data/office-caltech/dslr/laptop_computer/frame_0006.jpg 5 36 | ./data/office-caltech/dslr/laptop_computer/frame_0009.jpg 5 37 | ./data/office-caltech/dslr/laptop_computer/frame_0019.jpg 5 38 | ./data/office-caltech/dslr/back_pack/frame_0007.jpg 0 39 | ./data/office-caltech/dslr/back_pack/frame_0001.jpg 0 40 | ./data/office-caltech/dslr/back_pack/frame_0004.jpg 0 41 | ./data/office-caltech/dslr/back_pack/frame_0002.jpg 0 42 | ./data/office-caltech/dslr/back_pack/frame_0012.jpg 0 43 | ./data/office-caltech/dslr/back_pack/frame_0008.jpg 0 44 | ./data/office-caltech/dslr/back_pack/frame_0010.jpg 0 45 | ./data/office-caltech/dslr/back_pack/frame_0005.jpg 0 46 | ./data/office-caltech/dslr/back_pack/frame_0003.jpg 0 47 | ./data/office-caltech/dslr/back_pack/frame_0011.jpg 0 48 | ./data/office-caltech/dslr/back_pack/frame_0006.jpg 0 49 | ./data/office-caltech/dslr/back_pack/frame_0009.jpg 0 50 | ./data/office-caltech/dslr/bike/frame_0014.jpg 1 51 | ./data/office-caltech/dslr/bike/frame_0017.jpg 1 52 | ./data/office-caltech/dslr/bike/frame_0007.jpg 1 53 | ./data/office-caltech/dslr/bike/frame_0001.jpg 1 54 | ./data/office-caltech/dslr/bike/frame_0004.jpg 1 55 | ./data/office-caltech/dslr/bike/frame_0002.jpg 1 56 | ./data/office-caltech/dslr/bike/frame_0018.jpg 1 57 | ./data/office-caltech/dslr/bike/frame_0012.jpg 1 58 | ./data/office-caltech/dslr/bike/frame_0021.jpg 1 59 | ./data/office-caltech/dslr/bike/frame_0013.jpg 1 60 | ./data/office-caltech/dslr/bike/frame_0020.jpg 1 61 | ./data/office-caltech/dslr/bike/frame_0008.jpg 1 62 | ./data/office-caltech/dslr/bike/frame_0010.jpg 1 63 | ./data/office-caltech/dslr/bike/frame_0005.jpg 1 64 | ./data/office-caltech/dslr/bike/frame_0003.jpg 1 65 | ./data/office-caltech/dslr/bike/frame_0011.jpg 1 66 | ./data/office-caltech/dslr/bike/frame_0015.jpg 1 67 | ./data/office-caltech/dslr/bike/frame_0016.jpg 1 68 | ./data/office-caltech/dslr/bike/frame_0006.jpg 1 69 | ./data/office-caltech/dslr/bike/frame_0009.jpg 1 70 | ./data/office-caltech/dslr/bike/frame_0019.jpg 1 71 | ./data/office-caltech/dslr/monitor/frame_0014.jpg 6 72 | ./data/office-caltech/dslr/monitor/frame_0017.jpg 6 73 | ./data/office-caltech/dslr/monitor/frame_0007.jpg 6 74 | 
./data/office-caltech/dslr/monitor/frame_0001.jpg 6 75 | ./data/office-caltech/dslr/monitor/frame_0004.jpg 6 76 | ./data/office-caltech/dslr/monitor/frame_0002.jpg 6 77 | ./data/office-caltech/dslr/monitor/frame_0022.jpg 6 78 | ./data/office-caltech/dslr/monitor/frame_0018.jpg 6 79 | ./data/office-caltech/dslr/monitor/frame_0012.jpg 6 80 | ./data/office-caltech/dslr/monitor/frame_0021.jpg 6 81 | ./data/office-caltech/dslr/monitor/frame_0013.jpg 6 82 | ./data/office-caltech/dslr/monitor/frame_0020.jpg 6 83 | ./data/office-caltech/dslr/monitor/frame_0008.jpg 6 84 | ./data/office-caltech/dslr/monitor/frame_0010.jpg 6 85 | ./data/office-caltech/dslr/monitor/frame_0005.jpg 6 86 | ./data/office-caltech/dslr/monitor/frame_0003.jpg 6 87 | ./data/office-caltech/dslr/monitor/frame_0011.jpg 6 88 | ./data/office-caltech/dslr/monitor/frame_0015.jpg 6 89 | ./data/office-caltech/dslr/monitor/frame_0016.jpg 6 90 | ./data/office-caltech/dslr/monitor/frame_0006.jpg 6 91 | ./data/office-caltech/dslr/monitor/frame_0009.jpg 6 92 | ./data/office-caltech/dslr/monitor/frame_0019.jpg 6 93 | ./data/office-caltech/dslr/calculator/frame_0007.jpg 2 94 | ./data/office-caltech/dslr/calculator/frame_0001.jpg 2 95 | ./data/office-caltech/dslr/calculator/frame_0004.jpg 2 96 | ./data/office-caltech/dslr/calculator/frame_0002.jpg 2 97 | ./data/office-caltech/dslr/calculator/frame_0012.jpg 2 98 | ./data/office-caltech/dslr/calculator/frame_0008.jpg 2 99 | ./data/office-caltech/dslr/calculator/frame_0010.jpg 2 100 | ./data/office-caltech/dslr/calculator/frame_0005.jpg 2 101 | ./data/office-caltech/dslr/calculator/frame_0003.jpg 2 102 | ./data/office-caltech/dslr/calculator/frame_0011.jpg 2 103 | ./data/office-caltech/dslr/calculator/frame_0006.jpg 2 104 | ./data/office-caltech/dslr/calculator/frame_0009.jpg 2 105 | ./data/office-caltech/dslr/mug/frame_0007.jpg 8 106 | ./data/office-caltech/dslr/mug/frame_0001.jpg 8 107 | ./data/office-caltech/dslr/mug/frame_0004.jpg 8 108 | ./data/office-caltech/dslr/mug/frame_0002.jpg 8 109 | ./data/office-caltech/dslr/mug/frame_0008.jpg 8 110 | ./data/office-caltech/dslr/mug/frame_0005.jpg 8 111 | ./data/office-caltech/dslr/mug/frame_0003.jpg 8 112 | ./data/office-caltech/dslr/mug/frame_0006.jpg 8 113 | ./data/office-caltech/dslr/keyboard/frame_0007.jpg 4 114 | ./data/office-caltech/dslr/keyboard/frame_0001.jpg 4 115 | ./data/office-caltech/dslr/keyboard/frame_0004.jpg 4 116 | ./data/office-caltech/dslr/keyboard/frame_0002.jpg 4 117 | ./data/office-caltech/dslr/keyboard/frame_0008.jpg 4 118 | ./data/office-caltech/dslr/keyboard/frame_0010.jpg 4 119 | ./data/office-caltech/dslr/keyboard/frame_0005.jpg 4 120 | ./data/office-caltech/dslr/keyboard/frame_0003.jpg 4 121 | ./data/office-caltech/dslr/keyboard/frame_0006.jpg 4 122 | ./data/office-caltech/dslr/keyboard/frame_0009.jpg 4 123 | ./data/office-caltech/dslr/mouse/frame_0007.jpg 7 124 | ./data/office-caltech/dslr/mouse/frame_0001.jpg 7 125 | ./data/office-caltech/dslr/mouse/frame_0004.jpg 7 126 | ./data/office-caltech/dslr/mouse/frame_0002.jpg 7 127 | ./data/office-caltech/dslr/mouse/frame_0012.jpg 7 128 | ./data/office-caltech/dslr/mouse/frame_0008.jpg 7 129 | ./data/office-caltech/dslr/mouse/frame_0010.jpg 7 130 | ./data/office-caltech/dslr/mouse/frame_0005.jpg 7 131 | ./data/office-caltech/dslr/mouse/frame_0003.jpg 7 132 | ./data/office-caltech/dslr/mouse/frame_0011.jpg 7 133 | ./data/office-caltech/dslr/mouse/frame_0006.jpg 7 134 | ./data/office-caltech/dslr/mouse/frame_0009.jpg 7 135 | ./data/office-caltech/dslr/projector/frame_0014.jpg 
9 136 | ./data/office-caltech/dslr/projector/frame_0017.jpg 9 137 | ./data/office-caltech/dslr/projector/frame_0007.jpg 9 138 | ./data/office-caltech/dslr/projector/frame_0001.jpg 9 139 | ./data/office-caltech/dslr/projector/frame_0004.jpg 9 140 | ./data/office-caltech/dslr/projector/frame_0002.jpg 9 141 | ./data/office-caltech/dslr/projector/frame_0023.jpg 9 142 | ./data/office-caltech/dslr/projector/frame_0022.jpg 9 143 | ./data/office-caltech/dslr/projector/frame_0018.jpg 9 144 | ./data/office-caltech/dslr/projector/frame_0012.jpg 9 145 | ./data/office-caltech/dslr/projector/frame_0021.jpg 9 146 | ./data/office-caltech/dslr/projector/frame_0013.jpg 9 147 | ./data/office-caltech/dslr/projector/frame_0020.jpg 9 148 | ./data/office-caltech/dslr/projector/frame_0008.jpg 9 149 | ./data/office-caltech/dslr/projector/frame_0010.jpg 9 150 | ./data/office-caltech/dslr/projector/frame_0005.jpg 9 151 | ./data/office-caltech/dslr/projector/frame_0003.jpg 9 152 | ./data/office-caltech/dslr/projector/frame_0011.jpg 9 153 | ./data/office-caltech/dslr/projector/frame_0015.jpg 9 154 | ./data/office-caltech/dslr/projector/frame_0016.jpg 9 155 | ./data/office-caltech/dslr/projector/frame_0006.jpg 9 156 | ./data/office-caltech/dslr/projector/frame_0009.jpg 9 157 | ./data/office-caltech/dslr/projector/frame_0019.jpg 9 158 | -------------------------------------------------------------------------------- /data/office-caltech/webcam_list.txt: -------------------------------------------------------------------------------- 1 | ./data/office-caltech/webcam/headphones/frame_0014.jpg 3 2 | ./data/office-caltech/webcam/headphones/frame_0017.jpg 3 3 | ./data/office-caltech/webcam/headphones/frame_0007.jpg 3 4 | ./data/office-caltech/webcam/headphones/frame_0001.jpg 3 5 | ./data/office-caltech/webcam/headphones/frame_0004.jpg 3 6 | ./data/office-caltech/webcam/headphones/frame_0025.jpg 3 7 | ./data/office-caltech/webcam/headphones/frame_0024.jpg 3 8 | ./data/office-caltech/webcam/headphones/frame_0002.jpg 3 9 | ./data/office-caltech/webcam/headphones/frame_0023.jpg 3 10 | ./data/office-caltech/webcam/headphones/frame_0022.jpg 3 11 | ./data/office-caltech/webcam/headphones/frame_0018.jpg 3 12 | ./data/office-caltech/webcam/headphones/frame_0012.jpg 3 13 | ./data/office-caltech/webcam/headphones/frame_0021.jpg 3 14 | ./data/office-caltech/webcam/headphones/frame_0013.jpg 3 15 | ./data/office-caltech/webcam/headphones/frame_0020.jpg 3 16 | ./data/office-caltech/webcam/headphones/frame_0026.jpg 3 17 | ./data/office-caltech/webcam/headphones/frame_0008.jpg 3 18 | ./data/office-caltech/webcam/headphones/frame_0010.jpg 3 19 | ./data/office-caltech/webcam/headphones/frame_0027.jpg 3 20 | ./data/office-caltech/webcam/headphones/frame_0005.jpg 3 21 | ./data/office-caltech/webcam/headphones/frame_0003.jpg 3 22 | ./data/office-caltech/webcam/headphones/frame_0011.jpg 3 23 | ./data/office-caltech/webcam/headphones/frame_0015.jpg 3 24 | ./data/office-caltech/webcam/headphones/frame_0016.jpg 3 25 | ./data/office-caltech/webcam/headphones/frame_0006.jpg 3 26 | ./data/office-caltech/webcam/headphones/frame_0009.jpg 3 27 | ./data/office-caltech/webcam/headphones/frame_0019.jpg 3 28 | ./data/office-caltech/webcam/laptop_computer/frame_0014.jpg 5 29 | ./data/office-caltech/webcam/laptop_computer/frame_0017.jpg 5 30 | ./data/office-caltech/webcam/laptop_computer/frame_0007.jpg 5 31 | ./data/office-caltech/webcam/laptop_computer/frame_0001.jpg 5 32 | ./data/office-caltech/webcam/laptop_computer/frame_0004.jpg 5 33 | 
./data/office-caltech/webcam/laptop_computer/frame_0025.jpg 5 34 | ./data/office-caltech/webcam/laptop_computer/frame_0024.jpg 5 35 | ./data/office-caltech/webcam/laptop_computer/frame_0002.jpg 5 36 | ./data/office-caltech/webcam/laptop_computer/frame_0029.jpg 5 37 | ./data/office-caltech/webcam/laptop_computer/frame_0023.jpg 5 38 | ./data/office-caltech/webcam/laptop_computer/frame_0022.jpg 5 39 | ./data/office-caltech/webcam/laptop_computer/frame_0018.jpg 5 40 | ./data/office-caltech/webcam/laptop_computer/frame_0030.jpg 5 41 | ./data/office-caltech/webcam/laptop_computer/frame_0012.jpg 5 42 | ./data/office-caltech/webcam/laptop_computer/frame_0021.jpg 5 43 | ./data/office-caltech/webcam/laptop_computer/frame_0013.jpg 5 44 | ./data/office-caltech/webcam/laptop_computer/frame_0020.jpg 5 45 | ./data/office-caltech/webcam/laptop_computer/frame_0026.jpg 5 46 | ./data/office-caltech/webcam/laptop_computer/frame_0008.jpg 5 47 | ./data/office-caltech/webcam/laptop_computer/frame_0010.jpg 5 48 | ./data/office-caltech/webcam/laptop_computer/frame_0027.jpg 5 49 | ./data/office-caltech/webcam/laptop_computer/frame_0005.jpg 5 50 | ./data/office-caltech/webcam/laptop_computer/frame_0003.jpg 5 51 | ./data/office-caltech/webcam/laptop_computer/frame_0011.jpg 5 52 | ./data/office-caltech/webcam/laptop_computer/frame_0015.jpg 5 53 | ./data/office-caltech/webcam/laptop_computer/frame_0016.jpg 5 54 | ./data/office-caltech/webcam/laptop_computer/frame_0028.jpg 5 55 | ./data/office-caltech/webcam/laptop_computer/frame_0006.jpg 5 56 | ./data/office-caltech/webcam/laptop_computer/frame_0009.jpg 5 57 | ./data/office-caltech/webcam/laptop_computer/frame_0019.jpg 5 58 | ./data/office-caltech/webcam/back_pack/frame_0014.jpg 0 59 | ./data/office-caltech/webcam/back_pack/frame_0017.jpg 0 60 | ./data/office-caltech/webcam/back_pack/frame_0007.jpg 0 61 | ./data/office-caltech/webcam/back_pack/frame_0001.jpg 0 62 | ./data/office-caltech/webcam/back_pack/frame_0004.jpg 0 63 | ./data/office-caltech/webcam/back_pack/frame_0025.jpg 0 64 | ./data/office-caltech/webcam/back_pack/frame_0024.jpg 0 65 | ./data/office-caltech/webcam/back_pack/frame_0002.jpg 0 66 | ./data/office-caltech/webcam/back_pack/frame_0029.jpg 0 67 | ./data/office-caltech/webcam/back_pack/frame_0023.jpg 0 68 | ./data/office-caltech/webcam/back_pack/frame_0022.jpg 0 69 | ./data/office-caltech/webcam/back_pack/frame_0018.jpg 0 70 | ./data/office-caltech/webcam/back_pack/frame_0012.jpg 0 71 | ./data/office-caltech/webcam/back_pack/frame_0021.jpg 0 72 | ./data/office-caltech/webcam/back_pack/frame_0013.jpg 0 73 | ./data/office-caltech/webcam/back_pack/frame_0020.jpg 0 74 | ./data/office-caltech/webcam/back_pack/frame_0026.jpg 0 75 | ./data/office-caltech/webcam/back_pack/frame_0008.jpg 0 76 | ./data/office-caltech/webcam/back_pack/frame_0010.jpg 0 77 | ./data/office-caltech/webcam/back_pack/frame_0027.jpg 0 78 | ./data/office-caltech/webcam/back_pack/frame_0005.jpg 0 79 | ./data/office-caltech/webcam/back_pack/frame_0003.jpg 0 80 | ./data/office-caltech/webcam/back_pack/frame_0011.jpg 0 81 | ./data/office-caltech/webcam/back_pack/frame_0015.jpg 0 82 | ./data/office-caltech/webcam/back_pack/frame_0016.jpg 0 83 | ./data/office-caltech/webcam/back_pack/frame_0028.jpg 0 84 | ./data/office-caltech/webcam/back_pack/frame_0006.jpg 0 85 | ./data/office-caltech/webcam/back_pack/frame_0009.jpg 0 86 | ./data/office-caltech/webcam/back_pack/frame_0019.jpg 0 87 | ./data/office-caltech/webcam/bike/frame_0014.jpg 1 88 | ./data/office-caltech/webcam/bike/frame_0017.jpg 1 89 
| ./data/office-caltech/webcam/bike/frame_0007.jpg 1 90 | ./data/office-caltech/webcam/bike/frame_0001.jpg 1 91 | ./data/office-caltech/webcam/bike/frame_0004.jpg 1 92 | ./data/office-caltech/webcam/bike/frame_0002.jpg 1 93 | ./data/office-caltech/webcam/bike/frame_0018.jpg 1 94 | ./data/office-caltech/webcam/bike/frame_0012.jpg 1 95 | ./data/office-caltech/webcam/bike/frame_0021.jpg 1 96 | ./data/office-caltech/webcam/bike/frame_0013.jpg 1 97 | ./data/office-caltech/webcam/bike/frame_0020.jpg 1 98 | ./data/office-caltech/webcam/bike/frame_0008.jpg 1 99 | ./data/office-caltech/webcam/bike/frame_0010.jpg 1 100 | ./data/office-caltech/webcam/bike/frame_0005.jpg 1 101 | ./data/office-caltech/webcam/bike/frame_0003.jpg 1 102 | ./data/office-caltech/webcam/bike/frame_0011.jpg 1 103 | ./data/office-caltech/webcam/bike/frame_0015.jpg 1 104 | ./data/office-caltech/webcam/bike/frame_0016.jpg 1 105 | ./data/office-caltech/webcam/bike/frame_0006.jpg 1 106 | ./data/office-caltech/webcam/bike/frame_0009.jpg 1 107 | ./data/office-caltech/webcam/bike/frame_0019.jpg 1 108 | ./data/office-caltech/webcam/monitor/frame_0014.jpg 6 109 | ./data/office-caltech/webcam/monitor/frame_0017.jpg 6 110 | ./data/office-caltech/webcam/monitor/frame_0035.jpg 6 111 | ./data/office-caltech/webcam/monitor/frame_0033.jpg 6 112 | ./data/office-caltech/webcam/monitor/frame_0007.jpg 6 113 | ./data/office-caltech/webcam/monitor/frame_0037.jpg 6 114 | ./data/office-caltech/webcam/monitor/frame_0001.jpg 6 115 | ./data/office-caltech/webcam/monitor/frame_0039.jpg 6 116 | ./data/office-caltech/webcam/monitor/frame_0004.jpg 6 117 | ./data/office-caltech/webcam/monitor/frame_0025.jpg 6 118 | ./data/office-caltech/webcam/monitor/frame_0024.jpg 6 119 | ./data/office-caltech/webcam/monitor/frame_0002.jpg 6 120 | ./data/office-caltech/webcam/monitor/frame_0029.jpg 6 121 | ./data/office-caltech/webcam/monitor/frame_0023.jpg 6 122 | ./data/office-caltech/webcam/monitor/frame_0022.jpg 6 123 | ./data/office-caltech/webcam/monitor/frame_0036.jpg 6 124 | ./data/office-caltech/webcam/monitor/frame_0038.jpg 6 125 | ./data/office-caltech/webcam/monitor/frame_0018.jpg 6 126 | ./data/office-caltech/webcam/monitor/frame_0030.jpg 6 127 | ./data/office-caltech/webcam/monitor/frame_0012.jpg 6 128 | ./data/office-caltech/webcam/monitor/frame_0021.jpg 6 129 | ./data/office-caltech/webcam/monitor/frame_0031.jpg 6 130 | ./data/office-caltech/webcam/monitor/frame_0013.jpg 6 131 | ./data/office-caltech/webcam/monitor/frame_0020.jpg 6 132 | ./data/office-caltech/webcam/monitor/frame_0026.jpg 6 133 | ./data/office-caltech/webcam/monitor/frame_0008.jpg 6 134 | ./data/office-caltech/webcam/monitor/frame_0032.jpg 6 135 | ./data/office-caltech/webcam/monitor/frame_0010.jpg 6 136 | ./data/office-caltech/webcam/monitor/frame_0027.jpg 6 137 | ./data/office-caltech/webcam/monitor/frame_0041.jpg 6 138 | ./data/office-caltech/webcam/monitor/frame_0040.jpg 6 139 | ./data/office-caltech/webcam/monitor/frame_0034.jpg 6 140 | ./data/office-caltech/webcam/monitor/frame_0043.jpg 6 141 | ./data/office-caltech/webcam/monitor/frame_0005.jpg 6 142 | ./data/office-caltech/webcam/monitor/frame_0042.jpg 6 143 | ./data/office-caltech/webcam/monitor/frame_0003.jpg 6 144 | ./data/office-caltech/webcam/monitor/frame_0011.jpg 6 145 | ./data/office-caltech/webcam/monitor/frame_0015.jpg 6 146 | ./data/office-caltech/webcam/monitor/frame_0016.jpg 6 147 | ./data/office-caltech/webcam/monitor/frame_0028.jpg 6 148 | ./data/office-caltech/webcam/monitor/frame_0006.jpg 6 149 | 
./data/office-caltech/webcam/monitor/frame_0009.jpg 6 150 | ./data/office-caltech/webcam/monitor/frame_0019.jpg 6 151 | ./data/office-caltech/webcam/calculator/frame_0014.jpg 2 152 | ./data/office-caltech/webcam/calculator/frame_0017.jpg 2 153 | ./data/office-caltech/webcam/calculator/frame_0007.jpg 2 154 | ./data/office-caltech/webcam/calculator/frame_0001.jpg 2 155 | ./data/office-caltech/webcam/calculator/frame_0004.jpg 2 156 | ./data/office-caltech/webcam/calculator/frame_0025.jpg 2 157 | ./data/office-caltech/webcam/calculator/frame_0024.jpg 2 158 | ./data/office-caltech/webcam/calculator/frame_0002.jpg 2 159 | ./data/office-caltech/webcam/calculator/frame_0029.jpg 2 160 | ./data/office-caltech/webcam/calculator/frame_0023.jpg 2 161 | ./data/office-caltech/webcam/calculator/frame_0022.jpg 2 162 | ./data/office-caltech/webcam/calculator/frame_0018.jpg 2 163 | ./data/office-caltech/webcam/calculator/frame_0030.jpg 2 164 | ./data/office-caltech/webcam/calculator/frame_0012.jpg 2 165 | ./data/office-caltech/webcam/calculator/frame_0021.jpg 2 166 | ./data/office-caltech/webcam/calculator/frame_0031.jpg 2 167 | ./data/office-caltech/webcam/calculator/frame_0013.jpg 2 168 | ./data/office-caltech/webcam/calculator/frame_0020.jpg 2 169 | ./data/office-caltech/webcam/calculator/frame_0026.jpg 2 170 | ./data/office-caltech/webcam/calculator/frame_0008.jpg 2 171 | ./data/office-caltech/webcam/calculator/frame_0010.jpg 2 172 | ./data/office-caltech/webcam/calculator/frame_0027.jpg 2 173 | ./data/office-caltech/webcam/calculator/frame_0005.jpg 2 174 | ./data/office-caltech/webcam/calculator/frame_0003.jpg 2 175 | ./data/office-caltech/webcam/calculator/frame_0011.jpg 2 176 | ./data/office-caltech/webcam/calculator/frame_0015.jpg 2 177 | ./data/office-caltech/webcam/calculator/frame_0016.jpg 2 178 | ./data/office-caltech/webcam/calculator/frame_0028.jpg 2 179 | ./data/office-caltech/webcam/calculator/frame_0006.jpg 2 180 | ./data/office-caltech/webcam/calculator/frame_0009.jpg 2 181 | ./data/office-caltech/webcam/calculator/frame_0019.jpg 2 182 | ./data/office-caltech/webcam/mug/frame_0014.jpg 8 183 | ./data/office-caltech/webcam/mug/frame_0017.jpg 8 184 | ./data/office-caltech/webcam/mug/frame_0007.jpg 8 185 | ./data/office-caltech/webcam/mug/frame_0001.jpg 8 186 | ./data/office-caltech/webcam/mug/frame_0004.jpg 8 187 | ./data/office-caltech/webcam/mug/frame_0025.jpg 8 188 | ./data/office-caltech/webcam/mug/frame_0024.jpg 8 189 | ./data/office-caltech/webcam/mug/frame_0002.jpg 8 190 | ./data/office-caltech/webcam/mug/frame_0023.jpg 8 191 | ./data/office-caltech/webcam/mug/frame_0022.jpg 8 192 | ./data/office-caltech/webcam/mug/frame_0018.jpg 8 193 | ./data/office-caltech/webcam/mug/frame_0012.jpg 8 194 | ./data/office-caltech/webcam/mug/frame_0021.jpg 8 195 | ./data/office-caltech/webcam/mug/frame_0013.jpg 8 196 | ./data/office-caltech/webcam/mug/frame_0020.jpg 8 197 | ./data/office-caltech/webcam/mug/frame_0026.jpg 8 198 | ./data/office-caltech/webcam/mug/frame_0008.jpg 8 199 | ./data/office-caltech/webcam/mug/frame_0010.jpg 8 200 | ./data/office-caltech/webcam/mug/frame_0027.jpg 8 201 | ./data/office-caltech/webcam/mug/frame_0005.jpg 8 202 | ./data/office-caltech/webcam/mug/frame_0003.jpg 8 203 | ./data/office-caltech/webcam/mug/frame_0011.jpg 8 204 | ./data/office-caltech/webcam/mug/frame_0015.jpg 8 205 | ./data/office-caltech/webcam/mug/frame_0016.jpg 8 206 | ./data/office-caltech/webcam/mug/frame_0006.jpg 8 207 | ./data/office-caltech/webcam/mug/frame_0009.jpg 8 208 | 
./data/office-caltech/webcam/mug/frame_0019.jpg 8 209 | ./data/office-caltech/webcam/keyboard/frame_0014.jpg 4 210 | ./data/office-caltech/webcam/keyboard/frame_0017.jpg 4 211 | ./data/office-caltech/webcam/keyboard/frame_0007.jpg 4 212 | ./data/office-caltech/webcam/keyboard/frame_0001.jpg 4 213 | ./data/office-caltech/webcam/keyboard/frame_0004.jpg 4 214 | ./data/office-caltech/webcam/keyboard/frame_0025.jpg 4 215 | ./data/office-caltech/webcam/keyboard/frame_0024.jpg 4 216 | ./data/office-caltech/webcam/keyboard/frame_0002.jpg 4 217 | ./data/office-caltech/webcam/keyboard/frame_0023.jpg 4 218 | ./data/office-caltech/webcam/keyboard/frame_0022.jpg 4 219 | ./data/office-caltech/webcam/keyboard/frame_0018.jpg 4 220 | ./data/office-caltech/webcam/keyboard/frame_0012.jpg 4 221 | ./data/office-caltech/webcam/keyboard/frame_0021.jpg 4 222 | ./data/office-caltech/webcam/keyboard/frame_0013.jpg 4 223 | ./data/office-caltech/webcam/keyboard/frame_0020.jpg 4 224 | ./data/office-caltech/webcam/keyboard/frame_0026.jpg 4 225 | ./data/office-caltech/webcam/keyboard/frame_0008.jpg 4 226 | ./data/office-caltech/webcam/keyboard/frame_0010.jpg 4 227 | ./data/office-caltech/webcam/keyboard/frame_0027.jpg 4 228 | ./data/office-caltech/webcam/keyboard/frame_0005.jpg 4 229 | ./data/office-caltech/webcam/keyboard/frame_0003.jpg 4 230 | ./data/office-caltech/webcam/keyboard/frame_0011.jpg 4 231 | ./data/office-caltech/webcam/keyboard/frame_0015.jpg 4 232 | ./data/office-caltech/webcam/keyboard/frame_0016.jpg 4 233 | ./data/office-caltech/webcam/keyboard/frame_0006.jpg 4 234 | ./data/office-caltech/webcam/keyboard/frame_0009.jpg 4 235 | ./data/office-caltech/webcam/keyboard/frame_0019.jpg 4 236 | ./data/office-caltech/webcam/mouse/frame_0014.jpg 7 237 | ./data/office-caltech/webcam/mouse/frame_0017.jpg 7 238 | ./data/office-caltech/webcam/mouse/frame_0007.jpg 7 239 | ./data/office-caltech/webcam/mouse/frame_0001.jpg 7 240 | ./data/office-caltech/webcam/mouse/frame_0004.jpg 7 241 | ./data/office-caltech/webcam/mouse/frame_0025.jpg 7 242 | ./data/office-caltech/webcam/mouse/frame_0024.jpg 7 243 | ./data/office-caltech/webcam/mouse/frame_0002.jpg 7 244 | ./data/office-caltech/webcam/mouse/frame_0029.jpg 7 245 | ./data/office-caltech/webcam/mouse/frame_0023.jpg 7 246 | ./data/office-caltech/webcam/mouse/frame_0022.jpg 7 247 | ./data/office-caltech/webcam/mouse/frame_0018.jpg 7 248 | ./data/office-caltech/webcam/mouse/frame_0030.jpg 7 249 | ./data/office-caltech/webcam/mouse/frame_0012.jpg 7 250 | ./data/office-caltech/webcam/mouse/frame_0021.jpg 7 251 | ./data/office-caltech/webcam/mouse/frame_0013.jpg 7 252 | ./data/office-caltech/webcam/mouse/frame_0020.jpg 7 253 | ./data/office-caltech/webcam/mouse/frame_0026.jpg 7 254 | ./data/office-caltech/webcam/mouse/frame_0008.jpg 7 255 | ./data/office-caltech/webcam/mouse/frame_0010.jpg 7 256 | ./data/office-caltech/webcam/mouse/frame_0027.jpg 7 257 | ./data/office-caltech/webcam/mouse/frame_0005.jpg 7 258 | ./data/office-caltech/webcam/mouse/frame_0003.jpg 7 259 | ./data/office-caltech/webcam/mouse/frame_0011.jpg 7 260 | ./data/office-caltech/webcam/mouse/frame_0015.jpg 7 261 | ./data/office-caltech/webcam/mouse/frame_0016.jpg 7 262 | ./data/office-caltech/webcam/mouse/frame_0028.jpg 7 263 | ./data/office-caltech/webcam/mouse/frame_0006.jpg 7 264 | ./data/office-caltech/webcam/mouse/frame_0009.jpg 7 265 | ./data/office-caltech/webcam/mouse/frame_0019.jpg 7 266 | ./data/office-caltech/webcam/projector/frame_0014.jpg 9 267 | 
./data/office-caltech/webcam/projector/frame_0017.jpg 9 268 | ./data/office-caltech/webcam/projector/frame_0007.jpg 9 269 | ./data/office-caltech/webcam/projector/frame_0001.jpg 9 270 | ./data/office-caltech/webcam/projector/frame_0004.jpg 9 271 | ./data/office-caltech/webcam/projector/frame_0025.jpg 9 272 | ./data/office-caltech/webcam/projector/frame_0024.jpg 9 273 | ./data/office-caltech/webcam/projector/frame_0002.jpg 9 274 | ./data/office-caltech/webcam/projector/frame_0029.jpg 9 275 | ./data/office-caltech/webcam/projector/frame_0023.jpg 9 276 | ./data/office-caltech/webcam/projector/frame_0022.jpg 9 277 | ./data/office-caltech/webcam/projector/frame_0018.jpg 9 278 | ./data/office-caltech/webcam/projector/frame_0030.jpg 9 279 | ./data/office-caltech/webcam/projector/frame_0012.jpg 9 280 | ./data/office-caltech/webcam/projector/frame_0021.jpg 9 281 | ./data/office-caltech/webcam/projector/frame_0013.jpg 9 282 | ./data/office-caltech/webcam/projector/frame_0020.jpg 9 283 | ./data/office-caltech/webcam/projector/frame_0026.jpg 9 284 | ./data/office-caltech/webcam/projector/frame_0008.jpg 9 285 | ./data/office-caltech/webcam/projector/frame_0010.jpg 9 286 | ./data/office-caltech/webcam/projector/frame_0027.jpg 9 287 | ./data/office-caltech/webcam/projector/frame_0005.jpg 9 288 | ./data/office-caltech/webcam/projector/frame_0003.jpg 9 289 | ./data/office-caltech/webcam/projector/frame_0011.jpg 9 290 | ./data/office-caltech/webcam/projector/frame_0015.jpg 9 291 | ./data/office-caltech/webcam/projector/frame_0016.jpg 9 292 | ./data/office-caltech/webcam/projector/frame_0028.jpg 9 293 | ./data/office-caltech/webcam/projector/frame_0006.jpg 9 294 | ./data/office-caltech/webcam/projector/frame_0009.jpg 9 295 | ./data/office-caltech/webcam/projector/frame_0019.jpg 9 296 | -------------------------------------------------------------------------------- /data_list.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | import os 7 | import os.path 8 | import cv2 9 | import torchvision 10 | 11 | def make_dataset(image_list, labels): 12 | if labels: 13 | len_ = len(image_list) 14 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)] 15 | else: 16 | if len(image_list[0].split()) > 2: 17 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list] 18 | else: 19 | images = [(val.split()[0], int(val.split()[1])) for val in image_list] 20 | return images 21 | 22 | 23 | def rgb_loader(path): 24 | with open(path, 'rb') as f: 25 | with Image.open(f) as img: 26 | return img.convert('RGB') 27 | 28 | def l_loader(path): 29 | with open(path, 'rb') as f: 30 | with Image.open(f) as img: 31 | return img.convert('L') 32 | 33 | class ImageList(Dataset): 34 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 35 | imgs = make_dataset(image_list, labels) 36 | if len(imgs) == 0: 37 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 38 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 39 | 40 | self.imgs = imgs 41 | self.transform = transform 42 | self.target_transform = target_transform 43 | if mode == 'RGB': 44 | self.loader = rgb_loader 45 | elif mode == 'L': 46 | self.loader = l_loader 47 | 48 | def __getitem__(self, index): 49 | path, target = self.imgs[index] 50 | img = self.loader(path) 51 | if self.transform is not None: 
52 | img = self.transform(img) 53 | if self.target_transform is not None: 54 | target = self.target_transform(target) 55 | 56 | return img, target 57 | 58 | def __len__(self): 59 | return len(self.imgs) 60 | 61 | class ImageList_idx(Dataset): 62 | def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'): 63 | imgs = make_dataset(image_list, labels) 64 | if len(imgs) == 0: 65 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 66 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 67 | 68 | self.imgs = imgs 69 | self.transform = transform 70 | self.target_transform = target_transform 71 | if mode == 'RGB': 72 | self.loader = rgb_loader 73 | elif mode == 'L': 74 | self.loader = l_loader 75 | 76 | def __getitem__(self, index): 77 | path, target = self.imgs[index] 78 | img = self.loader(path) 79 | if self.transform is not None: 80 | img = self.transform(img) 81 | if self.target_transform is not None: 82 | target = self.target_transform(target) 83 | 84 | return img, target, index 85 | 86 | def __len__(self): 87 | return len(self.imgs) -------------------------------------------------------------------------------- /image/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygjwd12345/TransDA/76c76cc42a00ce465c353d51f084eb13d7f53620/image/overview.png -------------------------------------------------------------------------------- /image/result_office31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygjwd12345/TransDA/76c76cc42a00ce465c353d51f084eb13d7f53620/image/result_office31.png -------------------------------------------------------------------------------- /image/result_officehome.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ygjwd12345/TransDA/76c76cc42a00ce465c353d51f084eb13d7f53620/image/result_officehome.png -------------------------------------------------------------------------------- /image_pretrained.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | from torchvision import transforms 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | 18 | def op_copy(optimizer): 19 | for param_group in optimizer.param_groups: 20 | param_group['lr0'] = param_group['lr'] 21 | return optimizer 22 | 23 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 24 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 25 | for param_group in optimizer.param_groups: 26 | param_group['lr'] = param_group['lr0'] * decay 27 | param_group['weight_decay'] = 1e-3 28 | param_group['momentum'] = 0.9 29 | param_group['nesterov'] = True 30 | return optimizer 31 | 32 | def image_train(resize_size=256, crop_size=224, alexnet=False): 33 | if not alexnet: 34 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 35 | std=[0.229, 0.224, 0.225]) 36 | else: 37 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 38 | return transforms.Compose([ 39 | 
transforms.Resize((resize_size, resize_size)), 40 | transforms.RandomCrop(crop_size), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | normalize 44 | ]) 45 | 46 | def image_test(resize_size=256, crop_size=224, alexnet=False): 47 | if not alexnet: 48 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225]) 50 | else: 51 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 52 | return transforms.Compose([ 53 | transforms.Resize((resize_size, resize_size)), 54 | transforms.CenterCrop(crop_size), 55 | transforms.ToTensor(), 56 | normalize 57 | ]) 58 | 59 | def data_load(args): 60 | dsets = {} 61 | dset_loaders = {} 62 | train_bs = args.batch_size 63 | txt_tar = open(args.t_dset_path).readlines() 64 | txt_test = open(args.test_dset_path).readlines() 65 | 66 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 67 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 68 | dsets["test"] = ImageList_idx(txt_test, transform=image_test()) 69 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*3, shuffle=False, num_workers=args.worker, drop_last=False) 70 | 71 | return dset_loaders 72 | 73 | def cal_acc(loader, net, flag=False): 74 | start_test = True 75 | with torch.no_grad(): 76 | iter_test = iter(loader) 77 | for i in range(len(loader)): 78 | data = iter_test.next() 79 | inputs = data[0] 80 | labels = data[1] 81 | inputs = inputs.cuda() 82 | _, outputs = net(inputs) 83 | if start_test: 84 | all_output = outputs.float().cpu() 85 | all_label = labels.float() 86 | start_test = False 87 | else: 88 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 89 | all_label = torch.cat((all_label, labels.float()), 0) 90 | _, predict = torch.max(all_output, 1) 91 | all_output = nn.Softmax(dim=1)(all_output) 92 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(all_output.size(1)) 93 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 94 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 95 | 96 | return accuracy, mean_ent 97 | 98 | def train_target(args): 99 | dset_loaders = data_load(args) 100 | netF = network.Res50().cuda() 101 | 102 | param_group = [] 103 | for k, v in netF.named_parameters(): 104 | if k.__contains__("fc"): 105 | v.requires_grad = False 106 | else: 107 | param_group += [{'params': v, 'lr': args.lr*args.lr_decay1}] 108 | 109 | optimizer = optim.SGD(param_group) 110 | optimizer = op_copy(optimizer) 111 | 112 | max_iter = args.max_epoch * len(dset_loaders["target"]) 113 | interval_iter = max_iter // args.interval 114 | iter_num = 0 115 | 116 | netF.train() 117 | while iter_num < max_iter: 118 | try: 119 | inputs_test, _, tar_idx = iter_test.next() 120 | except: 121 | iter_test = iter(dset_loaders["target"]) 122 | inputs_test, _, tar_idx = iter_test.next() 123 | 124 | if inputs_test.size(0) == 1: 125 | continue 126 | 127 | if iter_num % interval_iter == 0 and args.cls_par > 0: 128 | netF.eval() 129 | mem_label = obtain_label(dset_loaders['test'], netF, args) 130 | mem_label = torch.from_numpy(mem_label).cuda() 131 | netF.train() 132 | 133 | inputs_test = inputs_test.cuda() 134 | iter_num += 1 135 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 136 | 137 | features_test, outputs_test = netF(inputs_test) 138 | 139 | if args.cls_par > 0: 140 | pred = mem_label[tar_idx] 141 | 
classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred) 142 | classifier_loss *= args.cls_par 143 | else: 144 | classifier_loss = torch.tensor(0.0).cuda() 145 | 146 | if args.ent: 147 | softmax_out = nn.Softmax(dim=1)(outputs_test) 148 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 149 | if args.gent: 150 | msoftmax = softmax_out.mean(dim=0) 151 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon)) 152 | entropy_loss -= gentropy_loss 153 | classifier_loss += entropy_loss * args.ent_par 154 | 155 | optimizer.zero_grad() 156 | classifier_loss.backward() 157 | optimizer.step() 158 | 159 | if iter_num % interval_iter == 0 or iter_num == max_iter: 160 | netF.eval() 161 | acc, ment = cal_acc(dset_loaders['test'], netF) 162 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.dset, iter_num, max_iter, acc*100) 163 | args.out_file.write(log_str + '\n') 164 | args.out_file.flush() 165 | print(log_str+'\n') 166 | netF.train() 167 | 168 | if args.issave: 169 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target" + args.savename + ".pt")) 170 | 171 | return netF 172 | 173 | def print_args(args): 174 | s = "==========================================\n" 175 | for arg, content in args.__dict__.items(): 176 | s += "{}:{}\n".format(arg, content) 177 | return s 178 | 179 | def obtain_label(loader, net, args): 180 | start_test = True 181 | with torch.no_grad(): 182 | iter_test = iter(loader) 183 | for _ in range(len(loader)): 184 | data = iter_test.next() 185 | inputs = data[0] 186 | labels = data[1] 187 | inputs = inputs.cuda() 188 | feas, outputs = net(inputs) 189 | if start_test: 190 | all_fea = feas.float().cpu() 191 | all_output = outputs.float().cpu() 192 | all_label = labels.float() 193 | start_test = False 194 | else: 195 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 196 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 197 | all_label = torch.cat((all_label, labels.float()), 0) 198 | 199 | all_output = nn.Softmax(dim=1)(all_output) 200 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) 201 | unknown_weight = 1 - ent / np.log(args.class_num) 202 | _, predict = torch.max(all_output, 1) 203 | 204 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 205 | if args.distance == 'cosine': 206 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 207 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 208 | 209 | all_fea = all_fea.float().cpu().numpy() 210 | K = all_output.size(1) 211 | aff = all_output.float().cpu().numpy() 212 | initc = aff.transpose().dot(all_fea) 213 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 214 | cls_count = np.eye(K)[predict].sum(axis=0) 215 | labelset = np.where(cls_count>args.threshold) 216 | labelset = labelset[0] 217 | # print(labelset) 218 | 219 | dd = cdist(all_fea, initc[labelset], args.distance) 220 | pred_label = dd.argmin(axis=1) 221 | pred_label = labelset[pred_label] 222 | 223 | for round in range(1): 224 | aff = np.eye(K)[pred_label] 225 | initc = aff.transpose().dot(all_fea) 226 | initc = initc / (1e-8 + aff.sum(axis=0)[:,None]) 227 | dd = cdist(all_fea, initc[labelset], args.distance) 228 | pred_label = dd.argmin(axis=1) 229 | pred_label = labelset[pred_label] 230 | 231 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 232 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy*100, acc*100) 233 | 234 | args.out_file.write(log_str + '\n') 235 | 
args.out_file.flush() 236 | print(log_str+'\n') 237 | 238 | return pred_label.astype('int') #, labelset 239 | 240 | 241 | if __name__ == "__main__": 242 | parser = argparse.ArgumentParser(description='SHOT') 243 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 244 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 245 | parser.add_argument('--interval', type=int, default=15, help="max iterations") 246 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 247 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 248 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 249 | parser.add_argument('--dset', type=str, default='imagenet_caltech', choices=['imagenet_caltech']) 250 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 251 | parser.add_argument('--net', type=str, default='resnet50', help="vgg16, resnet50, resnet101") 252 | parser.add_argument('--seed', type=int, default=2019, help="random seed") 253 | parser.add_argument('--epsilon', type=float, default=1e-5) 254 | parser.add_argument('--gent', type=bool, default=False) 255 | parser.add_argument('--ent', type=bool, default=True) 256 | parser.add_argument('--threshold', type=int, default=30) 257 | 258 | parser.add_argument('--cls_par', type=float, default=0.3) 259 | parser.add_argument('--ent_par', type=float, default=1.0) 260 | parser.add_argument('--output', type=str, default='seed') 261 | parser.add_argument('--da', type=str, default='pda', choices=['pda']) 262 | parser.add_argument('--issave', type=bool, default=True) 263 | parser.add_argument('--lr_decay1', type=float, default=0.1) 264 | 265 | args = parser.parse_args() 266 | 267 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 268 | SEED = args.seed 269 | torch.manual_seed(SEED) 270 | torch.cuda.manual_seed(SEED) 271 | np.random.seed(SEED) 272 | random.seed(SEED) 273 | # torch.backends.cudnn.deterministic = True 274 | 275 | args.class_num = 1000 276 | folder = './data/' 277 | if args.da == 'pda': 278 | args.t_dset_path = folder + args.dset + '/' + 'caltech_84' + '_list.txt' 279 | args.test_dset_path = args.t_dset_path 280 | 281 | args.output_dir = osp.join(args.output, args.da, args.dset) 282 | args.name = args.dset 283 | 284 | if not osp.exists(args.output_dir): 285 | os.system('mkdir -p ' + args.output_dir) 286 | if not osp.exists(args.output_dir): 287 | os.mkdir(args.output_dir) 288 | 289 | args.savename = 'par_' + str(args.cls_par) 290 | if args.da == 'pda': 291 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold) 292 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 293 | args.out_file.write(print_args(args)+'\n') 294 | args.out_file.flush() 295 | train_target(args) -------------------------------------------------------------------------------- /image_source.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from loss import CrossEntropyLabelSmooth 16 | from scipy.spatial.distance import 
cdist 17 | from sklearn.metrics import confusion_matrix 18 | from sklearn.cluster import KMeans 19 | 20 | def op_copy(optimizer): 21 | for param_group in optimizer.param_groups: 22 | param_group['lr0'] = param_group['lr'] 23 | return optimizer 24 | 25 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 26 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = param_group['lr0'] * decay 29 | param_group['weight_decay'] = 1e-3 30 | param_group['momentum'] = 0.9 31 | param_group['nesterov'] = True 32 | return optimizer 33 | 34 | def image_train(resize_size=256, crop_size=224, alexnet=False): 35 | if not alexnet: 36 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 37 | std=[0.229, 0.224, 0.225]) 38 | else: 39 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 40 | return transforms.Compose([ 41 | transforms.Resize((resize_size, resize_size)), 42 | transforms.RandomCrop(crop_size), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | normalize 46 | ]) 47 | 48 | def image_test(resize_size=256, crop_size=224, alexnet=False): 49 | if not alexnet: 50 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 51 | std=[0.229, 0.224, 0.225]) 52 | else: 53 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 54 | return transforms.Compose([ 55 | transforms.Resize((resize_size, resize_size)), 56 | transforms.CenterCrop(crop_size), 57 | transforms.ToTensor(), 58 | normalize 59 | ]) 60 | 61 | def data_load(args): 62 | ## prepare data 63 | dsets = {} 64 | dset_loaders = {} 65 | train_bs = args.batch_size 66 | # print(args.s_dset_path) 67 | txt_src = open(args.s_dset_path).readlines() 68 | txt_test = open(args.test_dset_path).readlines() 69 | 70 | if not args.da == 'uda': 71 | label_map_s = {} 72 | for i in range(len(args.src_classes)): 73 | label_map_s[args.src_classes[i]] = i 74 | 75 | new_src = [] 76 | for i in range(len(txt_src)): 77 | rec = txt_src[i] 78 | reci = rec.strip().split(' ') 79 | if int(reci[1]) in args.src_classes: 80 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 81 | new_src.append(line) 82 | txt_src = new_src.copy() 83 | 84 | new_tar = [] 85 | for i in range(len(txt_test)): 86 | rec = txt_test[i] 87 | reci = rec.strip().split(' ') 88 | if int(reci[1]) in args.tar_classes: 89 | if int(reci[1]) in args.src_classes: 90 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 91 | new_tar.append(line) 92 | else: 93 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 94 | new_tar.append(line) 95 | txt_test = new_tar.copy() 96 | 97 | if args.trte == "val": 98 | dsize = len(txt_src) 99 | tr_size = int(0.9*dsize) 100 | # print(dsize, tr_size, dsize - tr_size) 101 | tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 102 | else: 103 | dsize = len(txt_src) 104 | tr_size = int(0.9*dsize) 105 | _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size]) 106 | tr_txt = txt_src 107 | 108 | dsets["source_tr"] = ImageList(tr_txt, transform=image_train()) 109 | dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 110 | dsets["source_te"] = ImageList(te_txt, transform=image_test()) 111 | dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True, num_workers=args.worker, drop_last=False) 112 | dsets["test"] = ImageList(txt_test, transform=image_test()) 113 | 
dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs*2, shuffle=True, num_workers=args.worker, drop_last=False) 114 | 115 | return dset_loaders 116 | 117 | def cal_acc(loader, netF, netB, netC, flag=False): 118 | start_test = True 119 | with torch.no_grad(): 120 | iter_test = iter(loader) 121 | for i in range(len(loader)): 122 | data = iter_test.next() 123 | inputs = data[0] 124 | labels = data[1] 125 | inputs = inputs.cuda() 126 | outputs = netC(netB(netF(inputs))) 127 | if start_test: 128 | all_output = outputs.float().cpu() 129 | all_label = labels.float() 130 | start_test = False 131 | else: 132 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 133 | all_label = torch.cat((all_label, labels.float()), 0) 134 | 135 | all_output = nn.Softmax(dim=1)(all_output) 136 | _, predict = torch.max(all_output, 1) 137 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 138 | mean_ent = torch.mean(loss.Entropy(all_output)).cpu().data.item() 139 | 140 | if flag: 141 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 142 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 143 | aacc = acc.mean() 144 | aa = [str(np.round(i, 2)) for i in acc] 145 | acc = ' '.join(aa) 146 | return aacc, acc 147 | else: 148 | return accuracy*100, mean_ent 149 | 150 | def cal_acc_oda(loader, netF, netB, netC): 151 | start_test = True 152 | with torch.no_grad(): 153 | iter_test = iter(loader) 154 | for i in range(len(loader)): 155 | data = iter_test.next() 156 | inputs = data[0] 157 | labels = data[1] 158 | inputs = inputs.cuda() 159 | outputs = netC(netB(netF(inputs))) 160 | if start_test: 161 | all_output = outputs.float().cpu() 162 | all_label = labels.float() 163 | start_test = False 164 | else: 165 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 166 | all_label = torch.cat((all_label, labels.float()), 0) 167 | 168 | all_output = nn.Softmax(dim=1)(all_output) 169 | _, predict = torch.max(all_output, 1) 170 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num) 171 | ent = ent.float().cpu() 172 | initc = np.array([[0], [1]]) 173 | kmeans = KMeans(n_clusters=2, random_state=0, init=initc, n_init=1).fit(ent.reshape(-1,1)) 174 | threshold = (kmeans.cluster_centers_).mean() 175 | 176 | predict[ent>threshold] = args.class_num 177 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 178 | matrix = matrix[np.unique(all_label).astype(int),:] 179 | 180 | acc = matrix.diagonal()/matrix.sum(axis=1) * 100 181 | unknown_acc = acc[-1:].item() 182 | 183 | return np.mean(acc[:-1]), np.mean(acc), unknown_acc 184 | # return np.mean(acc), np.mean(acc[:-1]) 185 | 186 | def train_source(args): 187 | dset_loaders = data_load(args) 188 | ## set base network 189 | if args.net[0:3] == 'res': 190 | netF = network.ResBase(res_name=args.net,se=args.se,nl=args.nl).cuda() 191 | elif args.net[0:3] == 'vgg': 192 | netF = network.VGGBase(vgg_name=args.net).cuda() 193 | elif args.net == 'vit': 194 | netF = network.ViT().cuda() 195 | 196 | ### test model paremet size 197 | # model=network.ResBase(res_name=args.net) 198 | # num_params = sum([np.prod(p.size()) for p in model.parameters()]) 199 | # print("Total number of parameters: {}".format(num_params)) 200 | # 201 | # num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad]) 202 | # print("Total number of learning parameters: {}".format(num_params_update)) 203 | 204 | netB = 
network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 205 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 206 | 207 | param_group = [] 208 | learning_rate = args.lr 209 | for k, v in netF.named_parameters(): 210 | param_group += [{'params': v, 'lr': learning_rate*0.1}] 211 | for k, v in netB.named_parameters(): 212 | param_group += [{'params': v, 'lr': learning_rate}] 213 | for k, v in netC.named_parameters(): 214 | param_group += [{'params': v, 'lr': learning_rate}] 215 | optimizer = optim.SGD(param_group) 216 | optimizer = op_copy(optimizer) 217 | 218 | acc_init = 0 219 | max_iter = args.max_epoch * len(dset_loaders["source_tr"]) 220 | interval_iter = max_iter // 10 221 | iter_num = 0 222 | 223 | netF.train() 224 | netB.train() 225 | netC.train() 226 | 227 | while iter_num < max_iter: 228 | try: 229 | inputs_source, labels_source = iter_source.next() 230 | except: 231 | iter_source = iter(dset_loaders["source_tr"]) 232 | inputs_source, labels_source = iter_source.next() 233 | 234 | if inputs_source.size(0) == 1: 235 | continue 236 | 237 | iter_num += 1 238 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 239 | 240 | inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda() 241 | outputs_source = netC(netB(netF(inputs_source))) 242 | classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source, labels_source) 243 | 244 | 245 | optimizer.zero_grad() 246 | classifier_loss.backward() 247 | optimizer.step() 248 | 249 | if iter_num % interval_iter == 0 or iter_num == max_iter: 250 | netF.eval() 251 | netB.eval() 252 | netC.eval() 253 | if args.dset=='VISDA-C': 254 | acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF, netB, netC, True) 255 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) + '\n' + acc_list 256 | else: 257 | acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC, False) 258 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te) 259 | args.out_file.write(log_str + '\n') 260 | args.out_file.flush() 261 | print(log_str+'\n') 262 | 263 | if acc_s_te >= acc_init: 264 | acc_init = acc_s_te 265 | best_netF = netF.state_dict() 266 | best_netB = netB.state_dict() 267 | best_netC = netC.state_dict() 268 | 269 | netF.train() 270 | netB.train() 271 | netC.train() 272 | 273 | torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt")) 274 | torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt")) 275 | torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt")) 276 | 277 | return netF, netB, netC 278 | 279 | def test_target(args): 280 | dset_loaders = data_load(args) 281 | ## set base network 282 | if args.net[0:3] == 'res': 283 | netF = network.ResBase(res_name=args.net).cuda() 284 | elif args.net[0:3] == 'vgg': 285 | netF = network.VGGBase(vgg_name=args.net).cuda() 286 | else: 287 | netF = network.ViT().cuda() 288 | 289 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, bottleneck_dim=args.bottleneck).cuda() 290 | netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).cuda() 291 | 292 | args.modelpath = args.output_dir_src + '/source_F.pt' 293 | netF.load_state_dict(torch.load(args.modelpath)) 294 | args.modelpath = args.output_dir_src + 
'/source_B.pt' 295 | netB.load_state_dict(torch.load(args.modelpath)) 296 | args.modelpath = args.output_dir_src + '/source_C.pt' 297 | netC.load_state_dict(torch.load(args.modelpath)) 298 | netF.eval() 299 | netB.eval() 300 | netC.eval() 301 | 302 | if args.da == 'oda': 303 | acc_os1, acc_os2, acc_unknown = cal_acc_oda(dset_loaders['test'], netF, netB, netC) 304 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.trte, args.name, acc_os2, acc_os1, acc_unknown) 305 | else: 306 | if args.dset=='VISDA-C': 307 | acc, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True) 308 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc) + '\n' + acc_list 309 | else: 310 | acc, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 311 | log_str = '\nTraining: {}, Task: {}, Accuracy = {:.2f}%'.format(args.trte, args.name, acc) 312 | 313 | args.out_file.write(log_str) 314 | args.out_file.flush() 315 | print(log_str) 316 | 317 | def print_args(args): 318 | s = "==========================================\n" 319 | for arg, content in args.__dict__.items(): 320 | s += "{}:{}\n".format(arg, content) 321 | return s 322 | 323 | if __name__ == "__main__": 324 | parser = argparse.ArgumentParser(description='SHOT') 325 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 326 | parser.add_argument('--s', type=int, default=0, help="source") 327 | parser.add_argument('--t', type=int, default=1, help="target") 328 | parser.add_argument('--max_epoch', type=int, default=20, help="max iterations") 329 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 330 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 331 | parser.add_argument('--dset', type=str, default='office-home', choices=['VISDA-C', 'office', 'office-home', 'office-caltech']) 332 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 333 | parser.add_argument('--net', type=str, default='vit', help="vgg16, resnet50, resnet101") 334 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 335 | parser.add_argument('--bottleneck', type=int, default=256) 336 | parser.add_argument('--epsilon', type=float, default=1e-5) 337 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 338 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 339 | parser.add_argument('--smooth', type=float, default=0.1) 340 | parser.add_argument('--output', type=str, default='san') 341 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda', 'oda']) 342 | parser.add_argument('--trte', type=str, default='val', choices=['full', 'val']) 343 | parser.add_argument('--bsp', type=bool, default=False) 344 | parser.add_argument('--se', type=bool, default=False) 345 | parser.add_argument('--nl', type=bool, default=False) 346 | args = parser.parse_args() 347 | 348 | if args.dset == 'office-home': 349 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 350 | args.class_num = 65 351 | if args.dset == 'office': 352 | names = ['amazon', 'dslr', 'webcam'] 353 | args.class_num = 31 354 | if args.dset == 'VISDA-C': 355 | names = ['train', 'validation'] 356 | args.class_num = 12 357 | if args.dset == 'office-caltech': 358 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 359 | args.class_num = 10 360 | 361 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 362 | SEED = args.seed 363 | 
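# Fix every random number generator (PyTorch CPU/GPU, NumPy, Python) so source training is repeatable;
# uncommenting torch.backends.cudnn.deterministic = True below trades some speed for fully
# deterministic cuDNN kernels.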
torch.manual_seed(SEED) 364 | torch.cuda.manual_seed(SEED) 365 | np.random.seed(SEED) 366 | random.seed(SEED) 367 | # torch.backends.cudnn.deterministic = True 368 | 369 | folder = './data/' 370 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 371 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 372 | 373 | if args.dset == 'office-home': 374 | if args.da == 'pda': 375 | args.class_num = 65 376 | args.src_classes = [i for i in range(65)] 377 | args.tar_classes = [i for i in range(25)] 378 | if args.da == 'oda': 379 | args.class_num = 25 380 | args.src_classes = [i for i in range(25)] 381 | args.tar_classes = [i for i in range(65)] 382 | 383 | args.output_dir_src = osp.join(args.output, args.da, args.dset, names[args.s][0].upper()) 384 | args.name_src = names[args.s][0].upper() 385 | if not osp.exists(args.output_dir_src): 386 | os.system('mkdir -p ' + args.output_dir_src) 387 | if not osp.exists(args.output_dir_src): 388 | os.mkdir(args.output_dir_src) 389 | 390 | args.out_file = open(osp.join(args.output_dir_src, 'log.txt'), 'w') 391 | args.out_file.write(print_args(args)+'\n') 392 | args.out_file.flush() 393 | train_source(args) 394 | 395 | args.out_file = open(osp.join(args.output_dir_src, 'log_test.txt'), 'w') 396 | for i in range(len(names)): 397 | if i == args.s: 398 | continue 399 | args.t = i 400 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 401 | 402 | folder = './data/' 403 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 404 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 405 | 406 | if args.dset == 'office-home': 407 | if args.da == 'pda': 408 | args.class_num = 65 409 | args.src_classes = [i for i in range(65)] 410 | args.tar_classes = [i for i in range(25)] 411 | if args.da == 'oda': 412 | args.class_num = 25 413 | args.src_classes = [i for i in range(25)] 414 | args.tar_classes = [i for i in range(65)] 415 | 416 | test_target(args) 417 | 418 | 419 | -------------------------------------------------------------------------------- /image_target.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | from loss import KnowledgeDistillationLoss 18 | 19 | 20 | def op_copy(optimizer): 21 | for param_group in optimizer.param_groups: 22 | param_group['lr0'] = param_group['lr'] 23 | return optimizer 24 | 25 | 26 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 27 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = param_group['lr0'] * decay 30 | param_group['weight_decay'] = 1e-3 31 | param_group['momentum'] = 0.9 32 | param_group['nesterov'] = True 33 | return optimizer 34 | 35 | 36 | def image_train(resize_size=256, crop_size=224, alexnet=False): 37 | if not alexnet: 38 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 39 | std=[0.229, 0.224, 0.225]) 40 | else: 41 | normalize = 
Normalize(meanfile='./ilsvrc_2012_mean.npy') 42 | return transforms.Compose([ 43 | transforms.Resize((resize_size, resize_size)), 44 | transforms.RandomCrop(crop_size), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | normalize 48 | ]) 49 | 50 | 51 | def image_test(resize_size=256, crop_size=224, alexnet=False): 52 | if not alexnet: 53 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 54 | std=[0.229, 0.224, 0.225]) 55 | else: 56 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 57 | return transforms.Compose([ 58 | transforms.Resize((resize_size, resize_size)), 59 | transforms.CenterCrop(crop_size), 60 | transforms.ToTensor(), 61 | normalize 62 | ]) 63 | 64 | 65 | def data_load(args): 66 | ## prepare data 67 | dsets = {} 68 | dset_loaders = {} 69 | train_bs = args.batch_size 70 | txt_tar = open(args.t_dset_path).readlines() 71 | txt_test = open(args.test_dset_path).readlines() 72 | 73 | if not args.da == 'uda': 74 | label_map_s = {} 75 | for i in range(len(args.src_classes)): 76 | label_map_s[args.src_classes[i]] = i 77 | 78 | new_tar = [] 79 | for i in range(len(txt_tar)): 80 | rec = txt_tar[i] 81 | reci = rec.strip().split(' ') 82 | if int(reci[1]) in args.tar_classes: 83 | if int(reci[1]) in args.src_classes: 84 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 85 | new_tar.append(line) 86 | else: 87 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 88 | new_tar.append(line) 89 | txt_tar = new_tar.copy() 90 | txt_test = txt_tar.copy() 91 | 92 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 93 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, 94 | drop_last=False) 95 | dsets["test"] = ImageList_idx(txt_test, transform=image_test()) 96 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 3, shuffle=False, num_workers=args.worker, 97 | drop_last=False) 98 | 99 | return dset_loaders 100 | 101 | 102 | def cal_acc(loader, netF, netB, netC, flag=False): 103 | start_test = True 104 | with torch.no_grad(): 105 | iter_test = iter(loader) 106 | for i in range(len(loader)): 107 | data = iter_test.next() 108 | inputs = data[0] 109 | labels = data[1] 110 | inputs = inputs.cuda() 111 | outputs = netC(netB(netF(inputs))) 112 | if start_test: 113 | all_output = outputs.float().cpu() 114 | all_label = labels.float() 115 | start_test = False 116 | else: 117 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 118 | all_label = torch.cat((all_label, labels.float()), 0) 119 | _, predict = torch.max(all_output, 1) 120 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 121 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 122 | 123 | if flag: 124 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 125 | acc = matrix.diagonal() / matrix.sum(axis=1) * 100 126 | aacc = acc.mean() 127 | aa = [str(np.round(i, 2)) for i in acc] 128 | acc = ' '.join(aa) 129 | return aacc, acc 130 | else: 131 | return accuracy * 100, mean_ent 132 | 133 | 134 | def train_target(args): 135 | dset_loaders = data_load(args) 136 | ## set base network 137 | if args.net[0:3] == 'res': 138 | netF = network.ResBase(res_name=args.net, se=args.se, nl=args.nl).cuda() 139 | elif args.net[0:3] == 'vgg': 140 | netF = network.VGGBase(vgg_name=args.net).cuda() 141 | elif args.net == 'vit': 142 | netF = network.ViT().cuda() 143 | netB = 
network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, 144 | bottleneck_dim=args.bottleneck).cuda() 145 | netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda() 146 | 147 | modelpath = args.output_dir_src + '/source_F.pt' 148 | netF.load_state_dict(torch.load(modelpath), strict=False) 149 | modelpath = args.output_dir_src + '/source_B.pt' 150 | netB.load_state_dict(torch.load(modelpath)) 151 | modelpath = args.output_dir_src + '/source_C.pt' 152 | netC.load_state_dict(torch.load(modelpath)) 153 | netC.eval() 154 | for k, v in netC.named_parameters(): 155 | v.requires_grad = False 156 | ### add teacher module 157 | if args.net[0:3] == 'res': 158 | netF_t = network.ResBase(res_name=args.net, se=args.se, nl=args.nl).cuda() 159 | elif args.net[0:3] == 'vgg': 160 | netF_t = network.VGGBase(vgg_name=args.net).cuda() 161 | elif args.net == 'vit': 162 | netF_t = network.ViT().cuda() 163 | netB_t = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, 164 | bottleneck_dim=args.bottleneck).cuda() 165 | ### initial from student 166 | netF_t.load_state_dict(netF.state_dict()) 167 | netB_t.load_state_dict(netB.state_dict()) 168 | 169 | ### remove grad 170 | for k, v in netF_t.named_parameters(): 171 | v.requires_grad = False 172 | for k, v in netB_t.named_parameters(): 173 | v.requires_grad = False 174 | 175 | param_group = [] 176 | for k, v in netF.named_parameters(): 177 | if args.lr_decay1 > 0: 178 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}] 179 | else: 180 | v.requires_grad = False 181 | for k, v in netB.named_parameters(): 182 | if args.lr_decay2 > 0: 183 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}] 184 | else: 185 | v.requires_grad = False 186 | 187 | optimizer = optim.SGD(param_group) 188 | optimizer = op_copy(optimizer) 189 | 190 | max_iter = args.max_epoch * len(dset_loaders["target"]) 191 | interval_iter = max_iter // args.interval 192 | iter_num = 0 193 | 194 | while iter_num < max_iter: 195 | try: 196 | inputs_test, _, tar_idx = iter_test.next() 197 | except: 198 | iter_test = iter(dset_loaders["target"]) 199 | inputs_test, _, tar_idx = iter_test.next() 200 | 201 | if inputs_test.size(0) == 1: 202 | continue 203 | 204 | if iter_num % interval_iter == 0 and args.cls_par > 0: 205 | netF.eval() 206 | netB.eval() 207 | netF_t.eval() 208 | netB_t.eval() 209 | mem_label, dd = obtain_label(dset_loaders['test'], netF_t, netB_t, netC, args) 210 | mem_label = torch.from_numpy(mem_label).cuda() 211 | dd = torch.from_numpy(dd).cuda() 212 | 213 | netF.train() 214 | netB.train() 215 | 216 | inputs_test = inputs_test.cuda() 217 | 218 | iter_num += 1 219 | lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 220 | 221 | features_test = netB(netF(inputs_test)) 222 | outputs_test = netC(features_test) 223 | 224 | if args.cls_par > 0: 225 | pred = mem_label[tar_idx] 226 | pred_soft = dd[tar_idx] 227 | classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred) 228 | classifier_loss *= args.cls_par 229 | if args.kd: 230 | kd_loss = KnowledgeDistillationLoss()(outputs_test, pred_soft) 231 | classifier_loss += kd_loss 232 | if iter_num < interval_iter and args.dset == "VISDA-C": 233 | classifier_loss *= 0 234 | else: 235 | classifier_loss = torch.tensor(0.0).cuda() 236 | 237 | if args.ent: 238 | softmax_out = nn.Softmax(dim=1)(outputs_test) 239 | entropy_loss = torch.mean(loss.Entropy(softmax_out)) 240 | if args.gent: 241 | msoftmax = softmax_out.mean(dim=0) 242 | 
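# Diversity term of the information-maximization objective: the entropy of the batch-mean prediction
# msoftmax, computed next and subtracted from the mean per-sample entropy, so that minimizing the total
# favours predictions that are confident yet spread across classes.
# A minimal standalone sketch of the same objective (assuming probabilities p of shape [B, C]):
#   ent  = -(p * (p + eps).log()).sum(1).mean()           # mean per-sample entropy
#   gent = -(p.mean(0) * (p.mean(0) + eps).log()).sum()   # entropy of the mean prediction
#   im_loss = ent_par * (ent - gent)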
gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon)) 243 | entropy_loss -= gentropy_loss 244 | im_loss = entropy_loss * args.ent_par 245 | classifier_loss += im_loss 246 | 247 | optimizer.zero_grad() 248 | classifier_loss.backward() 249 | optimizer.step() 250 | # EMA update for the teacher 251 | with torch.no_grad(): 252 | m = 0.001 # momentum parameter 253 | for param_q, param_k in zip(netF.parameters(), netF_t.parameters()): 254 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 255 | for param_q, param_k in zip(netB.parameters(), netB_t.parameters()): 256 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 257 | if iter_num % interval_iter == 0 or iter_num == max_iter: 258 | netF.eval() 259 | netB.eval() 260 | if args.dset == 'VISDA-C': 261 | acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True) 262 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, 263 | acc_s_te) + '\n' + acc_list 264 | else: 265 | acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 266 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name, iter_num, max_iter, acc_s_te) 267 | 268 | args.out_file.write(log_str + '\n') 269 | args.out_file.flush() 270 | print(log_str + '\n') 271 | netF.train() 272 | netB.train() 273 | 274 | if args.issave: 275 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt")) 276 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt")) 277 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt")) 278 | 279 | return netF, netB, netC 280 | 281 | 282 | def print_args(args): 283 | s = "==========================================\n" 284 | for arg, content in args.__dict__.items(): 285 | s += "{}:{}\n".format(arg, content) 286 | return s 287 | 288 | 289 | def obtain_label(loader, netF, netB, netC, args): 290 | start_test = True 291 | with torch.no_grad(): 292 | iter_test = iter(loader) 293 | for _ in range(len(loader)): 294 | data = iter_test.next() 295 | inputs = data[0] 296 | labels = data[1] 297 | inputs = inputs.cuda() 298 | feas = netB(netF(inputs)) 299 | outputs = netC(feas) 300 | if start_test: 301 | all_fea = feas.float().cpu() 302 | all_output = outputs.float().cpu() 303 | all_label = labels.float() 304 | start_test = False 305 | else: 306 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 307 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 308 | all_label = torch.cat((all_label, labels.float()), 0) 309 | 310 | all_output = nn.Softmax(dim=1)(all_output) 311 | # print(all_output.shape) 312 | # ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) 313 | # unknown_weight = 1 - ent / np.log(args.class_num) 314 | _, predict = torch.max(all_output, 1) 315 | 316 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 317 | 318 | if args.distance == 'cosine': 319 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 320 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 321 | ### all_fea: extractor feature [bs,N] 322 | # print(all_fea.shape) 323 | all_fea = all_fea.float().cpu().numpy() 324 | K = all_output.size(1) 325 | aff = all_output.float().cpu().numpy() 326 | ### aff: softmax output [bs,c] 327 | # print(aff.shape) 328 | initc = aff.transpose().dot(all_fea) 329 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 330 | # 
print(initc.shape) 331 | cls_count = np.eye(K)[predict].sum(axis=0) 332 | labelset = np.where(cls_count > args.threshold) 333 | labelset = labelset[0] 334 | # print(labelset) 335 | 336 | dd = cdist(all_fea, initc[labelset], args.distance) 337 | pred_label = dd.argmin(axis=1) 338 | pred_label = labelset[pred_label] 339 | # acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 340 | # log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100) 341 | # args.out_file.write(log_str + '\n') 342 | # args.out_file.flush() 343 | # print(log_str+'\n') 344 | 345 | for round in range(1): 346 | aff = np.eye(K)[pred_label] 347 | initc = aff.transpose().dot(all_fea) 348 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 349 | dd = cdist(all_fea, initc[labelset], args.distance) 350 | pred_label = dd.argmin(axis=1) 351 | pred_label = labelset[pred_label] 352 | 353 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 354 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100) 355 | 356 | args.out_file.write(log_str + '\n') 357 | args.out_file.flush() 358 | print(log_str + '\n') 359 | 360 | return pred_label.astype('int'), dd 361 | 362 | 363 | if __name__ == "__main__": 364 | parser = argparse.ArgumentParser(description='SHOT') 365 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 366 | parser.add_argument('--s', type=int, default=0, help="source") 367 | parser.add_argument('--t', type=int, default=1, help="target") 368 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 369 | parser.add_argument('--interval', type=int, default=15) 370 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 371 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 372 | parser.add_argument('--dset', type=str, default='office-home', 373 | choices=['VISDA-C', 'office', 'office-home', 'office-caltech']) 374 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 375 | parser.add_argument('--net', type=str, default='vit', help="alexnet, vgg16, resnet50, res101") 376 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 377 | 378 | parser.add_argument('--gent', type=bool, default=True) 379 | parser.add_argument('--ent', type=bool, default=True) 380 | parser.add_argument('--kd', type=bool, default=False) 381 | parser.add_argument('--se', type=bool, default=False) 382 | parser.add_argument('--nl', type=bool, default=False) 383 | 384 | parser.add_argument('--threshold', type=int, default=0) 385 | parser.add_argument('--cls_par', type=float, default=0.3) 386 | parser.add_argument('--ent_par', type=float, default=1.0) 387 | parser.add_argument('--lr_decay1', type=float, default=0.1) 388 | parser.add_argument('--lr_decay2', type=float, default=1.0) 389 | 390 | parser.add_argument('--bottleneck', type=int, default=256) 391 | parser.add_argument('--epsilon', type=float, default=1e-5) 392 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 393 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 394 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 395 | parser.add_argument('--output', type=str, default='san') 396 | parser.add_argument('--output_src', type=str, default='san') 397 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda']) 398 | parser.add_argument('--issave', 
type=bool, default=True) 399 | args = parser.parse_args() 400 | 401 | if args.dset == 'office-home': 402 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 403 | args.class_num = 65 404 | if args.dset == 'office': 405 | names = ['amazon', 'dslr', 'webcam'] 406 | args.class_num = 31 407 | if args.dset == 'VISDA-C': 408 | names = ['train', 'validation'] 409 | args.class_num = 12 410 | if args.dset == 'office-caltech': 411 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 412 | args.class_num = 10 413 | 414 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 415 | SEED = args.seed 416 | torch.manual_seed(SEED) 417 | torch.cuda.manual_seed(SEED) 418 | np.random.seed(SEED) 419 | random.seed(SEED) 420 | # torch.backends.cudnn.deterministic = True 421 | 422 | for i in range(len(names)): 423 | if i == args.s: 424 | continue 425 | args.t = i 426 | 427 | folder = './data/' 428 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 429 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 430 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 431 | 432 | if args.dset == 'office-home': 433 | if args.da == 'pda': 434 | args.class_num = 65 435 | args.src_classes = [i for i in range(65)] 436 | args.tar_classes = [i for i in range(25)] 437 | 438 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 439 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper() + names[args.t][0].upper()) 440 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 441 | 442 | if not osp.exists(args.output_dir): 443 | os.system('mkdir -p ' + args.output_dir) 444 | if not osp.exists(args.output_dir): 445 | os.mkdir(args.output_dir) 446 | 447 | args.savename = 'par_' + str(args.cls_par) 448 | if args.da == 'pda': 449 | args.gent = '' 450 | args.savename = 'par_' + str(args.cls_par) + '_thr' + str(args.threshold) 451 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 452 | args.out_file.write(print_args(args) + '\n') 453 | args.out_file.flush() 454 | train_target(args) -------------------------------------------------------------------------------- /image_target_oda.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | from sklearn.cluster import KMeans 18 | from loss import KnowledgeDistillationLoss 19 | 20 | 21 | def op_copy(optimizer): 22 | for param_group in optimizer.param_groups: 23 | param_group['lr0'] = param_group['lr'] 24 | return optimizer 25 | 26 | 27 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 28 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 29 | for param_group in optimizer.param_groups: 30 | param_group['lr'] = param_group['lr0'] * decay 31 | param_group['weight_decay'] = 1e-3 32 | param_group['momentum'] = 0.9 33 | param_group['nesterov'] = True 34 | return optimizer 35 | 36 | 37 | def image_train(resize_size=256, crop_size=224, alexnet=False): 38 | if 
not alexnet: 39 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 40 | std=[0.229, 0.224, 0.225]) 41 | else: 42 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 43 | return transforms.Compose([ 44 | transforms.Resize((resize_size, resize_size)), 45 | transforms.RandomCrop(crop_size), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | normalize 49 | ]) 50 | 51 | 52 | def image_test(resize_size=256, crop_size=224, alexnet=False): 53 | if not alexnet: 54 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 55 | std=[0.229, 0.224, 0.225]) 56 | else: 57 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 58 | return transforms.Compose([ 59 | transforms.Resize((resize_size, resize_size)), 60 | transforms.CenterCrop(crop_size), 61 | transforms.ToTensor(), 62 | normalize 63 | ]) 64 | 65 | 66 | def data_load(args): 67 | ## prepare data 68 | dsets = {} 69 | dset_loaders = {} 70 | train_bs = args.batch_size 71 | txt_src = open(args.s_dset_path).readlines() 72 | txt_tar = open(args.t_dset_path).readlines() 73 | txt_test = open(args.test_dset_path).readlines() 74 | 75 | if not args.da == 'uda': 76 | label_map_s = {} 77 | for i in range(len(args.src_classes)): 78 | label_map_s[args.src_classes[i]] = i 79 | 80 | new_tar = [] 81 | for i in range(len(txt_tar)): 82 | rec = txt_tar[i] 83 | reci = rec.strip().split(' ') 84 | if int(reci[1]) in args.tar_classes: 85 | if int(reci[1]) in args.src_classes: 86 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 87 | new_tar.append(line) 88 | else: 89 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 90 | new_tar.append(line) 91 | txt_tar = new_tar.copy() 92 | txt_test = txt_tar.copy() 93 | 94 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 95 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, 96 | drop_last=False) 97 | dsets["test"] = ImageList(txt_test, transform=image_test()) 98 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 3, shuffle=False, num_workers=args.worker, 99 | drop_last=False) 100 | 101 | return dset_loaders 102 | 103 | 104 | def cal_acc(loader, netF, netB, netC, flag=False, threshold=0.1): 105 | start_test = True 106 | with torch.no_grad(): 107 | iter_test = iter(loader) 108 | for i in range(len(loader)): 109 | data = iter_test.next() 110 | inputs = data[0] 111 | labels = data[1] 112 | inputs = inputs.cuda() 113 | outputs = netC(netB(netF(inputs))) 114 | if start_test: 115 | all_output = outputs.float().cpu() 116 | all_label = labels.float() 117 | start_test = False 118 | else: 119 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 120 | all_label = torch.cat((all_label, labels.float()), 0) 121 | _, predict = torch.max(all_output, 1) 122 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 123 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 124 | 125 | if flag: 126 | all_output = nn.Softmax(dim=1)(all_output) 127 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num) 128 | 129 | from sklearn.cluster import KMeans 130 | kmeans = KMeans(2, random_state=0).fit(ent.reshape(-1, 1)) 131 | labels = kmeans.predict(ent.reshape(-1, 1)) 132 | 133 | idx = np.where(labels == 1)[0] 134 | iidx = 0 135 | if ent[idx].mean() > ent.mean(): 136 | iidx = 1 137 | predict[np.where(labels == iidx)[0]] = args.class_num 138 | 139 | matrix = 
confusion_matrix(all_label, torch.squeeze(predict).float()) 140 | matrix = matrix[np.unique(all_label).astype(int), :] 141 | 142 | acc = matrix.diagonal() / matrix.sum(axis=1) * 100 143 | unknown_acc = acc[-1:].item() 144 | return np.mean(acc[:-1]), np.mean(acc), unknown_acc 145 | else: 146 | return accuracy * 100, mean_ent 147 | 148 | 149 | def print_args(args): 150 | s = "==========================================\n" 151 | for arg, content in args.__dict__.items(): 152 | s += "{}:{}\n".format(arg, content) 153 | return s 154 | 155 | 156 | def train_target(args): 157 | dset_loaders = data_load(args) 158 | ## set base network 159 | if args.net[0:3] == 'res': 160 | netF = network.ResBase(res_name=args.net).cuda() 161 | elif args.net[0:3] == 'vgg': 162 | netF = network.VGGBase(vgg_name=args.net).cuda() 163 | else: 164 | netF = network.ViT().cuda() 165 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, 166 | bottleneck_dim=args.bottleneck).cuda() 167 | netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda() 168 | 169 | args.modelpath = args.output_dir_src + '/source_F.pt' 170 | netF.load_state_dict(torch.load(args.modelpath)) 171 | args.modelpath = args.output_dir_src + '/source_B.pt' 172 | netB.load_state_dict(torch.load(args.modelpath)) 173 | args.modelpath = args.output_dir_src + '/source_C.pt' 174 | netC.load_state_dict(torch.load(args.modelpath)) 175 | netC.eval() 176 | for k, v in netC.named_parameters(): 177 | v.requires_grad = False 178 | ### add teacher module 179 | if args.net[0:3] == 'res': 180 | netF_t = network.ResBase(res_name=args.net, se=args.se, nl=args.nl).cuda() 181 | elif args.net[0:3] == 'vgg': 182 | netF_t = network.VGGBase(vgg_name=args.net).cuda() 183 | elif args.net == 'vit': 184 | netF_t = network.ViT().cuda() 185 | netB_t = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, 186 | bottleneck_dim=args.bottleneck).cuda() 187 | ### initial from student 188 | netF_t.load_state_dict(netF.state_dict()) 189 | netB_t.load_state_dict(netB.state_dict()) 190 | 191 | ### remove grad 192 | for k, v in netF_t.named_parameters(): 193 | v.requires_grad = False 194 | for k, v in netB_t.named_parameters(): 195 | v.requires_grad = False 196 | param_group = [] 197 | 198 | for k, v in netF.named_parameters(): 199 | if args.lr_decay1 > 0: 200 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}] 201 | else: 202 | v.requires_grad = False 203 | for k, v in netB.named_parameters(): 204 | if args.lr_decay2 > 0: 205 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}] 206 | else: 207 | v.requires_grad = False 208 | 209 | optimizer = optim.SGD(param_group) 210 | optimizer = op_copy(optimizer) 211 | 212 | tt = 0 213 | iter_num = 0 214 | max_iter = args.max_epoch * len(dset_loaders["target"]) 215 | interval_iter = max_iter // args.interval 216 | 217 | while iter_num < max_iter: 218 | try: 219 | inputs_test, _, tar_idx = iter_test.next() 220 | except: 221 | iter_test = iter(dset_loaders["target"]) 222 | inputs_test, _, tar_idx = iter_test.next() 223 | 224 | if inputs_test.size(0) == 1: 225 | continue 226 | 227 | if iter_num % interval_iter == 0: 228 | netF.eval() 229 | netB.eval() 230 | mem_label, ENT_THRESHOLD,dd = obtain_label(dset_loaders['test'], netF_t, netB_t, netC, args) 231 | mem_label = torch.from_numpy(mem_label).cuda() 232 | netF.train() 233 | netB.train() 234 | 235 | inputs_test = inputs_test.cuda() 236 | 237 | iter_num += 1 238 | 
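# Learning-rate schedule applied below (see lr_scheduler above):
#   lr = lr0 * (1 + gamma * iter_num / max_iter) ** (-power), with gamma=10, power=0.75;
# e.g. halfway through training the decay factor is 6 ** -0.75 ≈ 0.26. Each call also rewrites
# weight decay (1e-3), momentum (0.9) and Nesterov acceleration for every param group.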
lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter) 239 | 240 | pred = mem_label[tar_idx] 241 | features_test = netB(netF(inputs_test)) 242 | outputs_test = netC(features_test) 243 | 244 | softmax_out = nn.Softmax(dim=1)(outputs_test) 245 | outputs_test_known = outputs_test[pred < args.class_num, :] 246 | pred = pred[pred < args.class_num] 247 | pred_soft = dd[tar_idx] 248 | pred_soft = torch.tensor(pred_soft).cuda() 249 | if len(pred) == 0: 250 | print(tt) 251 | del features_test 252 | del outputs_test 253 | tt += 1 254 | continue 255 | 256 | if args.cls_par > 0: 257 | classifier_loss = nn.CrossEntropyLoss()(outputs_test_known, pred) 258 | classifier_loss *= args.cls_par 259 | if args.kd: 260 | kd_loss = KnowledgeDistillationLoss()(outputs_test, pred_soft) 261 | classifier_loss += kd_loss 262 | else: 263 | classifier_loss = torch.tensor(0.0).cuda() 264 | 265 | if args.ent: 266 | softmax_out_known = nn.Softmax(dim=1)(outputs_test_known) 267 | entropy_loss = torch.mean(loss.Entropy(softmax_out_known)) 268 | if args.gent: 269 | msoftmax = softmax_out.mean(dim=0) 270 | gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon)) 271 | entropy_loss -= gentropy_loss 272 | classifier_loss += entropy_loss * args.ent_par 273 | # EMA update for the teacher 274 | with torch.no_grad(): 275 | m = 0.001 # momentum parameter 276 | for param_q, param_k in zip(netF.parameters(), netF_t.parameters()): 277 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 278 | for param_q, param_k in zip(netB.parameters(), netB_t.parameters()): 279 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 280 | optimizer.zero_grad() 281 | classifier_loss.backward() 282 | optimizer.step() 283 | 284 | if iter_num % interval_iter == 0 or iter_num == max_iter: 285 | netF.eval() 286 | netB.eval() 287 | acc_os1, acc_os2, acc_unknown = cal_acc(dset_loaders['test'], netF, netB, netC, True, ENT_THRESHOLD) 288 | log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}% / {:.2f}% / {:.2f}%'.format(args.name, iter_num, 289 | max_iter, acc_os2, acc_os1, 290 | acc_unknown) 291 | args.out_file.write(log_str + '\n') 292 | args.out_file.flush() 293 | print(log_str + '\n') 294 | netF.train() 295 | netB.train() 296 | 297 | if args.issave: 298 | torch.save(netF.state_dict(), osp.join(args.output_dir, "target_F_" + args.savename + ".pt")) 299 | torch.save(netB.state_dict(), osp.join(args.output_dir, "target_B_" + args.savename + ".pt")) 300 | torch.save(netC.state_dict(), osp.join(args.output_dir, "target_C_" + args.savename + ".pt")) 301 | 302 | return netF, netB, netC 303 | 304 | 305 | def obtain_label(loader, netF, netB, netC, args): 306 | start_test = True 307 | with torch.no_grad(): 308 | iter_test = iter(loader) 309 | for _ in range(len(loader)): 310 | data = iter_test.next() 311 | inputs = data[0] 312 | labels = data[1] 313 | inputs = inputs.cuda() 314 | feas = netB(netF(inputs)) 315 | outputs = netC(feas) 316 | if start_test: 317 | all_fea = feas.float().cpu() 318 | all_output = outputs.float().cpu() 319 | all_label = labels.float() 320 | start_test = False 321 | else: 322 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 323 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 324 | all_label = torch.cat((all_label, labels.float()), 0) 325 | 326 | all_output = nn.Softmax(dim=1)(all_output) 327 | _, predict = torch.max(all_output, 1) 328 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 329 | if args.distance == 'cosine': 330 | 
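# SHOT-style pseudo-labelling: for cosine distance each feature is augmented with a constant 1 and
# L2-normalised (next two lines); class centroids are then initialised as softmax-weighted feature means,
#   initc[k] = sum_i p_ik * f_i / sum_i p_ik,
# each sample is assigned to its nearest centroid, and one refinement round recomputes centroids
# from the hard assignments.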
all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 331 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 332 | 333 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) / np.log(args.class_num) 334 | ent = ent.float().cpu() 335 | 336 | from sklearn.cluster import KMeans 337 | kmeans = KMeans(2, random_state=0).fit(ent.reshape(-1, 1)) 338 | labels = kmeans.predict(ent.reshape(-1, 1)) 339 | 340 | idx = np.where(labels == 1)[0] 341 | iidx = 0 342 | if ent[idx].mean() > ent.mean(): 343 | iidx = 1 344 | known_idx = np.where(kmeans.labels_ != iidx)[0] 345 | 346 | all_fea = all_fea[known_idx, :] 347 | all_output = all_output[known_idx, :] 348 | predict = predict[known_idx] 349 | all_label_idx = all_label[known_idx] 350 | ENT_THRESHOLD = (kmeans.cluster_centers_).mean() 351 | 352 | all_fea = all_fea.float().cpu().numpy() 353 | K = all_output.size(1) 354 | aff = all_output.float().cpu().numpy() 355 | initc = aff.transpose().dot(all_fea) 356 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 357 | cls_count = np.eye(K)[predict].sum(axis=0) 358 | labelset = np.where(cls_count > args.threshold) 359 | labelset = labelset[0] 360 | 361 | dd = cdist(all_fea, initc[labelset], args.distance) 362 | pred_label = dd.argmin(axis=1) 363 | pred_label = labelset[pred_label] 364 | 365 | for round in range(1): 366 | aff = np.eye(K)[pred_label] 367 | initc = aff.transpose().dot(all_fea) 368 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 369 | dd = cdist(all_fea, initc[labelset], args.distance) 370 | pred_label = dd.argmin(axis=1) 371 | pred_label = labelset[pred_label] 372 | 373 | guess_label = args.class_num * np.ones(len(all_label), ) 374 | guess_label[known_idx] = pred_label 375 | D =np.ones((len(all_label),dd.shape[1] )) 376 | D[known_idx] = dd 377 | acc = np.sum(guess_label == all_label.float().numpy()) / len(all_label_idx) 378 | log_str = 'Threshold = {:.2f}, Accuracy = {:.2f}% -> {:.2f}%'.format(ENT_THRESHOLD, accuracy * 100, acc * 100) 379 | 380 | return guess_label.astype('int'), ENT_THRESHOLD,D 381 | 382 | 383 | if __name__ == "__main__": 384 | parser = argparse.ArgumentParser(description='SHOT') 385 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 386 | parser.add_argument('--s', type=int, default=0, help="source") 387 | parser.add_argument('--t', type=int, default=1, help="target") 388 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 389 | parser.add_argument('--interval', type=int, default=15) 390 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 391 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 392 | parser.add_argument('--dset', type=str, default='office-home', choices=['office-home']) 393 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 394 | parser.add_argument('--net', type=str, default='vit', help="vgg16, resnet50, resnet101") 395 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 396 | 397 | parser.add_argument('--gent', type=bool, default=True) 398 | parser.add_argument('--ent', type=bool, default=True) 399 | parser.add_argument('--kd', type=bool, default=False) 400 | 401 | parser.add_argument('--threshold', type=int, default=0) 402 | parser.add_argument('--cls_par', type=float, default=0.3) 403 | parser.add_argument('--ent_par', type=float, default=1.0) 404 | parser.add_argument('--lr_decay1', type=float, default=0.1) 405 | 
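# lr_decay1 scales the backbone (netF) learning rate and lr_decay2 scales the bottleneck (netB)
# learning rate in train_target; setting either to 0 freezes the corresponding module.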
parser.add_argument('--lr_decay2', type=float, default=1.0) 406 | 407 | parser.add_argument('--bottleneck', type=int, default=256) 408 | parser.add_argument('--epsilon', type=float, default=1e-5) 409 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 410 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 411 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 412 | parser.add_argument('--output', type=str, default='san') 413 | parser.add_argument('--output_src', type=str, default='san') 414 | parser.add_argument('--da', type=str, default='oda', choices=['oda']) 415 | parser.add_argument('--issave', type=bool, default=True) 416 | args = parser.parse_args() 417 | 418 | if args.dset == 'office-home': 419 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 420 | args.class_num = 65 421 | 422 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 423 | SEED = args.seed 424 | torch.manual_seed(SEED) 425 | torch.cuda.manual_seed(SEED) 426 | np.random.seed(SEED) 427 | random.seed(SEED) 428 | # torch.backends.cudnn.deterministic = True 429 | 430 | for i in range(len(names)): 431 | if i == args.s: 432 | continue 433 | args.t = i 434 | 435 | folder = './data/' 436 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 437 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 438 | args.test_dset_path = args.t_dset_path 439 | 440 | if args.dset == 'office-home': 441 | if args.da == 'oda': 442 | args.class_num = 25 443 | args.src_classes = [i for i in range(25)] 444 | args.tar_classes = [i for i in range(65)] 445 | 446 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 447 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper() + names[args.t][0].upper()) 448 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 449 | 450 | if not osp.exists(args.output_dir): 451 | os.system('mkdir -p ' + args.output_dir) 452 | if not osp.exists(args.output_dir): 453 | os.mkdir(args.output_dir) 454 | 455 | args.savename = 'par_' + str(args.cls_par) 456 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 457 | args.out_file.write(print_args(args) + '\n') 458 | args.out_file.flush() 459 | train_target(args) -------------------------------------------------------------------------------- /image_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import torchvision 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import transforms 10 | import network, loss 11 | from torch.utils.data import DataLoader 12 | from data_list import ImageList, ImageList_idx 13 | import random, pdb, math, copy 14 | from tqdm import tqdm 15 | from scipy.spatial.distance import cdist 16 | from sklearn.metrics import confusion_matrix 17 | 18 | 19 | def op_copy(optimizer): 20 | for param_group in optimizer.param_groups: 21 | param_group['lr0'] = param_group['lr'] 22 | return optimizer 23 | 24 | 25 | def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75): 26 | decay = (1 + gamma * iter_num / max_iter) ** (-power) 27 | for param_group in optimizer.param_groups: 28 | param_group['lr'] = param_group['lr0'] * decay 29 | param_group['weight_decay'] = 1e-3 30 | param_group['momentum'] = 0.9 31 | 
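# Each scheduler call rewrites the SGD hyper-parameters of every param group (weight decay 1e-3,
# momentum 0.9, Nesterov acceleration) on top of the polynomially decayed learning rate.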
param_group['nesterov'] = True 32 | return optimizer 33 | 34 | 35 | def image_train(resize_size=256, crop_size=224, alexnet=False): 36 | if not alexnet: 37 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 38 | std=[0.229, 0.224, 0.225]) 39 | else: 40 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 41 | return transforms.Compose([ 42 | transforms.Resize((resize_size, resize_size)), 43 | transforms.RandomCrop(crop_size), 44 | transforms.RandomHorizontalFlip(), 45 | transforms.ToTensor(), 46 | normalize 47 | ]) 48 | 49 | 50 | def image_test(resize_size=256, crop_size=224, alexnet=False): 51 | if not alexnet: 52 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225]) 54 | else: 55 | normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy') 56 | return transforms.Compose([ 57 | transforms.Resize((resize_size, resize_size)), 58 | transforms.CenterCrop(crop_size), 59 | transforms.ToTensor(), 60 | normalize 61 | ]) 62 | 63 | 64 | def data_load(args): 65 | ## prepare data 66 | dsets = {} 67 | dset_loaders = {} 68 | train_bs = args.batch_size 69 | txt_tar = open(args.t_dset_path).readlines() 70 | txt_test = open(args.test_dset_path).readlines() 71 | 72 | if not args.da == 'uda': 73 | label_map_s = {} 74 | for i in range(len(args.src_classes)): 75 | label_map_s[args.src_classes[i]] = i 76 | 77 | new_tar = [] 78 | for i in range(len(txt_tar)): 79 | rec = txt_tar[i] 80 | reci = rec.strip().split(' ') 81 | if int(reci[1]) in args.tar_classes: 82 | if int(reci[1]) in args.src_classes: 83 | line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n' 84 | new_tar.append(line) 85 | else: 86 | line = reci[0] + ' ' + str(len(label_map_s)) + '\n' 87 | new_tar.append(line) 88 | txt_tar = new_tar.copy() 89 | txt_test = txt_tar.copy() 90 | 91 | dsets["target"] = ImageList_idx(txt_tar, transform=image_train()) 92 | dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs, shuffle=True, num_workers=args.worker, 93 | drop_last=False) 94 | dsets["test"] = ImageList_idx(txt_test, transform=image_test()) 95 | dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 3, shuffle=False, num_workers=args.worker, 96 | drop_last=False) 97 | 98 | return dset_loaders 99 | 100 | 101 | def cal_acc(loader, netF, netB, netC, flag=False): 102 | start_test = True 103 | with torch.no_grad(): 104 | iter_test = iter(loader) 105 | for i in range(len(loader)): 106 | data = iter_test.next() 107 | inputs = data[0] 108 | labels = data[1] 109 | inputs = inputs.cuda() 110 | outputs = netC(netB(netF(inputs))) 111 | outputs_tsne=netB(netF(inputs)) 112 | if start_test: 113 | all_output = outputs.float().cpu() 114 | all_output_tsne= outputs_tsne.float().cpu() 115 | all_label = labels.float() 116 | start_test = False 117 | else: 118 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 119 | all_output_tsne = torch.cat((all_output_tsne, outputs_tsne.float().cpu()), 0) 120 | all_label = torch.cat((all_label, labels.float()), 0) 121 | ### tsne 122 | b_out = all_output_tsne.squeeze().cpu() 123 | b_label = all_label.cpu() 124 | colormap_dir = './tsne' 125 | if not os.path.isdir(colormap_dir): 126 | os.mkdir(colormap_dir) 127 | from sklearn.manifold import TSNE 128 | from matplotlib import pyplot as plt 129 | import seaborn as sns 130 | from datetime import datetime 131 | now = datetime.now() 132 | timestamp = datetime.timestamp(now) 133 | dt_object = datetime.fromtimestamp(timestamp) 134 | 135 | colors = sns.color_palette('pastel').as_hex() + 
sns.color_palette('dark').as_hex() + sns.color_palette('deep').as_hex() + sns.color_palette('muted').as_hex() 136 | print('tsne start!!!') 137 | tsne = TSNE(n_components=2, random_state=0, n_jobs=16) 138 | out_2d = tsne.fit_transform(b_out) 139 | print('tsne done!!!') 140 | # plot the result 141 | vis_x = out_2d[:, 0] 142 | vis_y = out_2d[:, 1] 143 | fig, ax = plt.subplots() 144 | print(np.unique(b_label).shape) 145 | for j in np.unique(b_label).astype(np.int64): 146 | plt.scatter(vis_x[b_label == j], vis_y[b_label == j], c=colors[j]) 147 | # plt.colorbar(ticks=range(21)) 148 | fig.tight_layout() 149 | ## save confusion matrx 150 | fig.savefig(os.path.join(colormap_dir, str(dt_object)+'tsne.png')) 151 | ### 152 | 153 | _, predict = torch.max(all_output, 1) 154 | ### output record 155 | 156 | with open("Output.txt", "w") as text_file: 157 | for i in range(all_label.size()[0]): 158 | a=(torch.squeeze(predict).float() == all_label)[i] 159 | text_file.write("%s \n" % a) 160 | 161 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 162 | mean_ent = torch.mean(loss.Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item() 163 | 164 | if flag: 165 | matrix = confusion_matrix(all_label, torch.squeeze(predict).float()) 166 | acc = matrix.diagonal() / matrix.sum(axis=1) * 100 167 | aacc = acc.mean() 168 | aa = [str(np.round(i, 2)) for i in acc] 169 | acc = ' '.join(aa) 170 | return aacc, acc 171 | else: 172 | return accuracy * 100, mean_ent 173 | 174 | 175 | def train_target(args): 176 | dset_loaders = data_load(args) 177 | # netF = network.ResBase(res_name=args.net).cuda() 178 | 179 | netF = network.ViT().cuda() 180 | netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features, 181 | bottleneck_dim=args.bottleneck).cuda() 182 | netC = network.feat_classifier(type=args.layer, class_num=args.class_num, bottleneck_dim=args.bottleneck).cuda() 183 | ### target 184 | modelpath = args.checkpoint+ '/target_F_par_0.3.pt' 185 | netF.load_state_dict(torch.load(modelpath),strict=False) 186 | modelpath = args.checkpoint + '/target_B_par_0.3.pt' 187 | netB.load_state_dict(torch.load(modelpath)) 188 | modelpath = args.checkpoint + '/target_C_par_0.3.pt' 189 | netC.load_state_dict(torch.load(modelpath)) 190 | ### source 191 | # modelpath = args.checkpoint+ '/source_F.pt' 192 | # netF.load_state_dict(torch.load(modelpath)) 193 | # modelpath = args.checkpoint + '/source_B.pt' 194 | # netB.load_state_dict(torch.load(modelpath)) 195 | # modelpath = args.checkpoint + '/source_C.pt' 196 | # netC.load_state_dict(torch.load(modelpath)) 197 | netC.eval() 198 | for k, v in netC.named_parameters(): 199 | v.requires_grad = False 200 | 201 | param_group = [] 202 | for k, v in netF.named_parameters(): 203 | if args.lr_decay1 > 0: 204 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}] 205 | else: 206 | v.requires_grad = False 207 | for k, v in netB.named_parameters(): 208 | if args.lr_decay2 > 0: 209 | param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}] 210 | else: 211 | v.requires_grad = False 212 | 213 | optimizer = optim.SGD(param_group) 214 | optimizer = op_copy(optimizer) 215 | 216 | max_iter = args.max_epoch * len(dset_loaders["target"]) 217 | interval_iter = max_iter // args.interval 218 | 219 | 220 | 221 | netF.eval() 222 | netB.eval() 223 | if args.dset == 'VISDA-C': 224 | acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netB, netC, True) 225 | log_str = 'Task: {}; Accuracy = {:.2f}%'.format(args.name, acc_s_te) + 
'\n' + acc_list 226 | else: 227 | acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC, False) 228 | log_str = 'Task: {}; Accuracy = {:.2f}%'.format(args.name, acc_s_te) 229 | 230 | args.out_file.write(log_str + '\n') 231 | args.out_file.flush() 232 | print(log_str + '\n') 233 | 234 | 235 | 236 | return 1 237 | 238 | 239 | def print_args(args): 240 | s = "==========================================\n" 241 | for arg, content in args.__dict__.items(): 242 | s += "{}:{}\n".format(arg, content) 243 | return s 244 | 245 | 246 | def obtain_label(loader, netF, netB, netC, args): 247 | start_test = True 248 | with torch.no_grad(): 249 | iter_test = iter(loader) 250 | for _ in range(len(loader)): 251 | data = iter_test.next() 252 | inputs = data[0] 253 | labels = data[1] 254 | inputs = inputs.cuda() 255 | feas = netB(netF(inputs)) 256 | outputs = netC(feas) 257 | if start_test: 258 | all_fea = feas.float().cpu() 259 | all_output = outputs.float().cpu() 260 | all_label = labels.float() 261 | start_test = False 262 | else: 263 | all_fea = torch.cat((all_fea, feas.float().cpu()), 0) 264 | all_output = torch.cat((all_output, outputs.float().cpu()), 0) 265 | all_label = torch.cat((all_label, labels.float()), 0) 266 | 267 | all_output = nn.Softmax(dim=1)(all_output) 268 | ent = torch.sum(-all_output * torch.log(all_output + args.epsilon), dim=1) 269 | unknown_weight = 1 - ent / np.log(args.class_num) 270 | _, predict = torch.max(all_output, 1) 271 | 272 | accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) 273 | if args.distance == 'cosine': 274 | all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) 275 | all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() 276 | 277 | all_fea = all_fea.float().cpu().numpy() 278 | K = all_output.size(1) 279 | aff = all_output.float().cpu().numpy() 280 | initc = aff.transpose().dot(all_fea) 281 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 282 | cls_count = np.eye(K)[predict].sum(axis=0) 283 | labelset = np.where(cls_count > args.threshold) 284 | labelset = labelset[0] 285 | # print(labelset) 286 | 287 | dd = cdist(all_fea, initc[labelset], args.distance) 288 | pred_label = dd.argmin(axis=1) 289 | pred_label = labelset[pred_label] 290 | 291 | for round in range(1): 292 | aff = np.eye(K)[pred_label] 293 | initc = aff.transpose().dot(all_fea) 294 | initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) 295 | dd = cdist(all_fea, initc[labelset], args.distance) 296 | pred_label = dd.argmin(axis=1) 297 | pred_label = labelset[pred_label] 298 | 299 | acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) 300 | log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100) 301 | 302 | args.out_file.write(log_str + '\n') 303 | args.out_file.flush() 304 | print(log_str + '\n') 305 | 306 | return pred_label.astype('int') 307 | 308 | 309 | if __name__ == "__main__": 310 | parser = argparse.ArgumentParser(description='SHOT') 311 | parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run") 312 | parser.add_argument('--s', type=int, default=0, help="source") 313 | parser.add_argument('--t', type=int, default=1, help="target") 314 | parser.add_argument('--max_epoch', type=int, default=15, help="max iterations") 315 | parser.add_argument('--interval', type=int, default=15) 316 | parser.add_argument('--batch_size', type=int, default=64, help="batch_size") 317 | parser.add_argument('--worker', type=int, default=4, help="number of workers") 318 | 
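# image_test.py is evaluation-only: its train_target loads target_F/B/C_par_0.3.pt from --checkpoint,
# reports accuracy on the target list and writes a t-SNE plot of bottleneck features to ./tsne.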
parser.add_argument('--dset', type=str, default='office-home', 319 | choices=['VISDA-C', 'office', 'office-home', 'office-caltech']) 320 | parser.add_argument('--lr', type=float, default=1e-2, help="learning rate") 321 | parser.add_argument('--net', type=str, default='vit', help="alexnet, vgg16, resnet50, res101") 322 | parser.add_argument('--seed', type=int, default=2020, help="random seed") 323 | 324 | parser.add_argument('--gent', type=bool, default=True) 325 | parser.add_argument('--ent', type=bool, default=True) 326 | parser.add_argument('--bnm', type=bool, default=False) 327 | 328 | parser.add_argument('--threshold', type=int, default=0) 329 | parser.add_argument('--cls_par', type=float, default=0.3) 330 | parser.add_argument('--ent_par', type=float, default=1.0) 331 | parser.add_argument('--lr_decay1', type=float, default=0.1) 332 | parser.add_argument('--lr_decay2', type=float, default=1.0) 333 | 334 | parser.add_argument('--bottleneck', type=int, default=256) 335 | parser.add_argument('--epsilon', type=float, default=1e-5) 336 | parser.add_argument('--layer', type=str, default="wn", choices=["linear", "wn"]) 337 | parser.add_argument('--classifier', type=str, default="bn", choices=["ori", "bn"]) 338 | parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"]) 339 | parser.add_argument('--output', type=str, default='san') 340 | parser.add_argument('--output_src', type=str, default='san') 341 | parser.add_argument('--da', type=str, default='uda', choices=['uda', 'pda']) 342 | parser.add_argument('--issave', type=bool, default=True) 343 | parser.add_argument('--checkpoint', type=str, default='') 344 | args = parser.parse_args() 345 | 346 | if args.dset == 'office-home': 347 | names = ['Art', 'Clipart', 'Product', 'RealWorld'] 348 | args.class_num = 65 349 | if args.dset == 'office': 350 | names = ['amazon', 'dslr', 'webcam'] 351 | args.class_num = 31 352 | if args.dset == 'VISDA-C': 353 | names = ['train', 'validation'] 354 | args.class_num = 12 355 | if args.dset == 'office-caltech': 356 | names = ['amazon', 'caltech', 'dslr', 'webcam'] 357 | args.class_num = 10 358 | 359 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 360 | SEED = args.seed 361 | torch.manual_seed(SEED) 362 | torch.cuda.manual_seed(SEED) 363 | np.random.seed(SEED) 364 | random.seed(SEED) 365 | # torch.backends.cudnn.deterministic = True 366 | 367 | for i in range(len(names)): 368 | if i == args.s: 369 | continue 370 | args.t = i 371 | 372 | folder = './data/' 373 | args.s_dset_path = folder + args.dset + '/' + names[args.s] + '_list.txt' 374 | args.t_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 375 | args.test_dset_path = folder + args.dset + '/' + names[args.t] + '_list.txt' 376 | 377 | if args.dset == 'office-home': 378 | if args.da == 'pda': 379 | args.class_num = 65 380 | args.src_classes = [i for i in range(65)] 381 | args.tar_classes = [i for i in range(25)] 382 | 383 | args.output_dir_src = osp.join(args.output_src, args.da, args.dset, names[args.s][0].upper()) 384 | args.output_dir = osp.join(args.output, args.da, args.dset, names[args.s][0].upper() + names[args.t][0].upper()) 385 | args.name = names[args.s][0].upper() + names[args.t][0].upper() 386 | 387 | if not osp.exists(args.output_dir): 388 | os.system('mkdir -p ' + args.output_dir) 389 | if not osp.exists(args.output_dir): 390 | os.mkdir(args.output_dir) 391 | 392 | args.savename = 'par_' + str(args.cls_par) 393 | if args.da == 'pda': 394 | args.gent = '' 395 | args.savename = 'par_' + 
str(args.cls_par) + '_thr' + str(args.threshold) 396 | args.out_file = open(osp.join(args.output_dir, 'log_' + args.savename + '.txt'), 'w') 397 | args.out_file.write(print_args(args) + '\n') 398 | args.out_file.flush() 399 | train_target(args) -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | def grl_hook(coeff): 17 | def fun1(grad): 18 | return -coeff*grad.clone() 19 | return fun1 20 | 21 | def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None): 22 | softmax_output = input_list[1].detach() 23 | feature = input_list[0] 24 | if random_layer is None: 25 | op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1)) 26 | ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1))) 27 | else: 28 | random_out = random_layer.forward([feature, softmax_output]) 29 | ad_out = ad_net(random_out.view(-1, random_out.size(1))) 30 | batch_size = softmax_output.size(0) // 2 31 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 32 | if entropy is not None: 33 | entropy.register_hook(grl_hook(coeff)) 34 | entropy = 1.0+torch.exp(-entropy) 35 | source_mask = torch.ones_like(entropy) 36 | source_mask[feature.size(0)//2:] = 0 37 | source_weight = entropy*source_mask 38 | target_mask = torch.ones_like(entropy) 39 | target_mask[0:feature.size(0)//2] = 0 40 | target_weight = entropy*target_mask 41 | weight = source_weight / torch.sum(source_weight).detach().item() + \ 42 | target_weight / torch.sum(target_weight).detach().item() 43 | return torch.sum(weight.view(-1, 1) * nn.BCELoss(reduction='none')(ad_out, dc_target)) / torch.sum(weight).detach().item() 44 | else: 45 | return nn.BCELoss()(ad_out, dc_target) 46 | 47 | def DANN(features, ad_net): 48 | ad_out = ad_net(features) 49 | batch_size = ad_out.size(0) // 2 50 | dc_target = torch.from_numpy(np.array([[1]] * batch_size + [[0]] * batch_size)).float().cuda() 51 | return nn.BCELoss()(ad_out, dc_target) 52 | 53 | 54 | class CrossEntropyLabelSmooth(nn.Module): 55 | """Cross entropy loss with label smoothing regularizer. 56 | Reference: 57 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 58 | Equation: y = (1 - epsilon) * y + epsilon / K. 59 | Args: 60 | num_classes (int): number of classes. 61 | epsilon (float): weight. 
62 | """ 63 | 64 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True): 65 | super(CrossEntropyLabelSmooth, self).__init__() 66 | self.num_classes = num_classes 67 | self.epsilon = epsilon 68 | self.use_gpu = use_gpu 69 | self.reduction = reduction 70 | self.logsoftmax = nn.LogSoftmax(dim=1) 71 | 72 | def forward(self, inputs, targets): 73 | """ 74 | Args: 75 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 76 | targets: ground truth labels with shape (num_classes) 77 | """ 78 | log_probs = self.logsoftmax(inputs) 79 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 80 | if self.use_gpu: targets = targets.cuda() 81 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 82 | loss = (- targets * log_probs).sum(dim=1) 83 | if self.reduction: 84 | return loss.mean() 85 | else: 86 | return loss 87 | return loss 88 | 89 | class SupConLoss(nn.Module): 90 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 91 | It also supports the unsupervised contrastive loss in SimCLR""" 92 | def __init__(self, temperature=0.07, contrast_mode='all', 93 | base_temperature=0.07): 94 | super(SupConLoss, self).__init__() 95 | self.temperature = temperature 96 | self.contrast_mode = contrast_mode 97 | self.base_temperature = base_temperature 98 | 99 | def forward(self, features, labels=None, mask=None): 100 | """Compute loss for model. If both `labels` and `mask` are None, 101 | it degenerates to SimCLR unsupervised loss: 102 | https://arxiv.org/pdf/2002.05709.pdf 103 | Args: 104 | features: hidden vector of shape [bsz, n_views, ...]. 105 | labels: ground truth of shape [bsz]. 106 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 107 | has the same class as sample i. Can be asymmetric. 108 | Returns: 109 | A loss scalar. 
110 | """ 111 | device = (torch.device('cuda') 112 | if features.is_cuda 113 | else torch.device('cpu')) 114 | 115 | if len(features.shape) < 3: 116 | features=features.unsqueeze(dim=1) 117 | # raise ValueError('`features` needs to be [bsz, n_views, ...],' 118 | # 'at least 3 dimensions are required') 119 | if len(features.shape) > 3: 120 | features = features.view(features.shape[0], features.shape[1], -1) 121 | 122 | batch_size = features.shape[0] 123 | if labels is not None and mask is not None: 124 | raise ValueError('Cannot define both `labels` and `mask`') 125 | elif labels is None and mask is None: 126 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 127 | elif labels is not None: 128 | labels = labels.contiguous().view(-1, 1) 129 | if labels.shape[0] != batch_size: 130 | raise ValueError('Num of labels does not match num of features') 131 | mask = torch.eq(labels, labels.T).float().to(device) 132 | else: 133 | mask = mask.float().to(device) 134 | 135 | contrast_count = features.shape[1] 136 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 137 | if self.contrast_mode == 'one': 138 | anchor_feature = features[:, 0] 139 | anchor_count = 1 140 | elif self.contrast_mode == 'all': 141 | anchor_feature = contrast_feature 142 | anchor_count = contrast_count 143 | else: 144 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 145 | 146 | # compute logits 147 | anchor_dot_contrast = torch.div( 148 | torch.matmul(anchor_feature, contrast_feature.T), 149 | self.temperature) 150 | # for numerical stability 151 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 152 | logits = anchor_dot_contrast - logits_max.detach() 153 | 154 | # tile mask 155 | mask = mask.repeat(anchor_count, contrast_count) 156 | # mask-out self-contrast cases 157 | logits_mask = torch.scatter( 158 | torch.ones_like(mask), 159 | 1, 160 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 161 | 0 162 | ) 163 | mask = mask * logits_mask 164 | 165 | # compute log_prob 166 | exp_logits = torch.exp(logits) * logits_mask 167 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 168 | 169 | # compute mean of log-likelihood over positive 170 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 171 | 172 | # loss 173 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 174 | loss = loss.view(anchor_count, batch_size).mean() 175 | 176 | return loss 177 | class SCELoss(torch.nn.Module): 178 | def __init__(self, alpha, beta, num_classes=10): 179 | super(SCELoss, self).__init__() 180 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 181 | self.alpha = alpha 182 | self.beta = beta 183 | self.num_classes = num_classes 184 | self.cross_entropy = torch.nn.CrossEntropyLoss() 185 | 186 | def forward(self, pred, labels): 187 | # CCE 188 | ce = self.cross_entropy(pred, labels) 189 | 190 | # RCE 191 | pred = F.softmax(pred, dim=1) 192 | pred = torch.clamp(pred, min=1e-7, max=1.0) 193 | label_one_hot = torch.nn.functional.one_hot(labels, self.num_classes).float().to(self.device) 194 | label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0) 195 | rce = (-1*torch.sum(pred * torch.log(label_one_hot), dim=1)) 196 | 197 | # Loss 198 | loss = self.alpha * ce + self.beta * rce.mean() 199 | return loss 200 | 201 | 202 | class KnowledgeDistillationLoss(nn.Module): 203 | def __init__(self, reduction='mean', alpha=-1.): 204 | super().__init__() 205 | self.reduction = reduction 206 | self.alpha = alpha 207 | 208 | def 
forward(self, inputs, targets, mask=None): 209 | inputs = inputs.narrow(1, 0, targets.shape[1]) 210 | 211 | outputs = torch.log_softmax(inputs, dim=1) 212 | labels = torch.softmax(targets * self.alpha, dim=1) 213 | 214 | loss = (outputs * labels).mean(dim=1) 215 | 216 | if mask is not None: 217 | loss = loss * mask.float() 218 | 219 | if self.reduction == 'mean': 220 | outputs = -torch.mean(loss) 221 | elif self.reduction == 'sum': 222 | outputs = -torch.sum(loss) 223 | else: 224 | outputs = -loss 225 | 226 | return outputs -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | from torchvision import models 6 | from torch.autograd import Variable 7 | import math 8 | import torch.nn.utils.weight_norm as weightNorm 9 | from collections import OrderedDict 10 | from TransUNet.networks.vit_seg_modeling import VisionTransformer as ViT_seg 11 | from TransUNet.networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg 12 | from non_local_embedded_gaussian import NONLocalBlock2D 13 | 14 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0): 15 | return np.float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low) 16 | 17 | def init_weights(m): 18 | classname = m.__class__.__name__ 19 | if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1: 20 | nn.init.kaiming_uniform_(m.weight) 21 | nn.init.zeros_(m.bias) 22 | elif classname.find('BatchNorm') != -1: 23 | nn.init.normal_(m.weight, 1.0, 0.02) 24 | nn.init.zeros_(m.bias) 25 | elif classname.find('Linear') != -1: 26 | nn.init.xavier_normal_(m.weight) 27 | nn.init.zeros_(m.bias) 28 | 29 | vgg_dict = {"vgg11":models.vgg11, "vgg13":models.vgg13, "vgg16":models.vgg16, "vgg19":models.vgg19, 30 | "vgg11bn":models.vgg11_bn, "vgg13bn":models.vgg13_bn, "vgg16bn":models.vgg16_bn, "vgg19bn":models.vgg19_bn} 31 | class VGGBase(nn.Module): 32 | def __init__(self, vgg_name): 33 | super(VGGBase, self).__init__() 34 | model_vgg = vgg_dict[vgg_name](pretrained=True) 35 | self.features = model_vgg.features 36 | self.classifier = nn.Sequential() 37 | for i in range(6): 38 | self.classifier.add_module("classifier"+str(i), model_vgg.classifier[i]) 39 | self.in_features = model_vgg.classifier[6].in_features 40 | 41 | def forward(self, x): 42 | x = self.features(x) 43 | x = x.view(x.size(0), -1) 44 | x = self.classifier(x) 45 | return x 46 | 47 | res_dict = {"resnet18":models.resnet18, "resnet34":models.resnet34, "resnet50":models.resnet50, 48 | "resnet101":models.resnet101, "resnet152":models.resnet152, "resnext50":models.resnext50_32x4d, "resnext101":models.resnext101_32x8d} 49 | 50 | class ViT(nn.Module): 51 | def __init__(self): 52 | super(ViT, self).__init__() 53 | config_vit = CONFIGS_ViT_seg['R50-ViT-B_16'] 54 | config_vit.n_classes = 100 55 | config_vit.n_skip = 3 56 | config_vit.patches.grid = (int(224 / 16), int(224 / 16)) 57 | self.feature_extractor = ViT_seg(config_vit, img_size=[224, 224], num_classes=config_vit.n_classes) 58 | self.feature_extractor.load_from(weights=np.load(config_vit.pretrained_path)) 59 | self.in_features = 2048 60 | 61 | def forward(self, x): 62 | _, feat = self.feature_extractor(x) 63 | return feat 64 | 65 | class ResBase(nn.Module): 66 | def __init__(self, res_name,se=False, nl=False): 67 | super(ResBase, self).__init__() 68 | model_resnet = 
res_dict[res_name](pretrained=True) 69 | self.conv1 = model_resnet.conv1 70 | self.bn1 = model_resnet.bn1 71 | self.relu = model_resnet.relu 72 | self.maxpool = model_resnet.maxpool 73 | self.layer1 = model_resnet.layer1 74 | self.layer2 = model_resnet.layer2 75 | self.layer3 = model_resnet.layer3 76 | self.layer4 = model_resnet.layer4 77 | self.avgpool = model_resnet.avgpool 78 | self.in_features = model_resnet.fc.in_features 79 | self.se=se 80 | self.nl=nl 81 | if self.se: 82 | self.SELayer=SELayer(self.in_features) 83 | if self.nl: 84 | self.nlLayer=NONLocalBlock2D(self.in_features) 85 | 86 | def forward(self, x): 87 | x = self.conv1(x) 88 | x = self.bn1(x) 89 | x = self.relu(x) 90 | x = self.maxpool(x) 91 | x = self.layer1(x) 92 | x = self.layer2(x) 93 | x = self.layer3(x) 94 | x = self.layer4(x) 95 | if self.se: 96 | x=self.SELayer(x) 97 | if self.nl: 98 | x=self.nlLayer(x) 99 | x = self.avgpool(x) 100 | x = x.view(x.size(0), -1) 101 | 102 | return x 103 | 104 | class feat_bootleneck(nn.Module): 105 | def __init__(self, feature_dim, bottleneck_dim=256, type="ori"): 106 | super(feat_bootleneck, self).__init__() 107 | self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.dropout = nn.Dropout(p=0.5) 110 | self.bottleneck = nn.Linear(feature_dim, bottleneck_dim) 111 | self.bottleneck.apply(init_weights) 112 | self.type = type 113 | 114 | def forward(self, x): 115 | x = self.bottleneck(x) 116 | if self.type == "bn": 117 | x = self.bn(x) 118 | return x 119 | 120 | class feat_classifier(nn.Module): 121 | def __init__(self, class_num, bottleneck_dim=256, type="linear"): 122 | super(feat_classifier, self).__init__() 123 | self.type = type 124 | if type == 'wn': 125 | self.fc = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight") 126 | self.fc.apply(init_weights) 127 | else: 128 | self.fc = nn.Linear(bottleneck_dim, class_num) 129 | self.fc.apply(init_weights) 130 | 131 | def forward(self, x): 132 | x = self.fc(x) 133 | return x 134 | 135 | class feat_classifier_two(nn.Module): 136 | def __init__(self, class_num, input_dim, bottleneck_dim=256): 137 | super(feat_classifier_two, self).__init__() 138 | self.type = type 139 | self.fc0 = nn.Linear(input_dim, bottleneck_dim) 140 | self.fc0.apply(init_weights) 141 | self.fc1 = nn.Linear(bottleneck_dim, class_num) 142 | self.fc1.apply(init_weights) 143 | 144 | def forward(self, x): 145 | x = self.fc0(x) 146 | x = self.fc1(x) 147 | return x 148 | 149 | class Res50(nn.Module): 150 | def __init__(self): 151 | super(Res50, self).__init__() 152 | model_resnet = models.resnet50(pretrained=True) 153 | self.conv1 = model_resnet.conv1 154 | self.bn1 = model_resnet.bn1 155 | self.relu = model_resnet.relu 156 | self.maxpool = model_resnet.maxpool 157 | self.layer1 = model_resnet.layer1 158 | self.layer2 = model_resnet.layer2 159 | self.layer3 = model_resnet.layer3 160 | self.layer4 = model_resnet.layer4 161 | self.avgpool = model_resnet.avgpool 162 | self.in_features = model_resnet.fc.in_features 163 | self.fc = model_resnet.fc 164 | 165 | def forward(self, x): 166 | x = self.conv1(x) 167 | x = self.bn1(x) 168 | x = self.relu(x) 169 | x = self.maxpool(x) 170 | x = self.layer1(x) 171 | x = self.layer2(x) 172 | x = self.layer3(x) 173 | x = self.layer4(x) 174 | x = self.avgpool(x) 175 | x = x.view(x.size(0), -1) 176 | y = self.fc(x) 177 | return x, y 178 | class SELayer(nn.Module): 179 | def __init__(self, channel, reduction=16): 180 | super(SELayer, self).__init__() 181 | self.avg_pool = 
nn.AdaptiveAvgPool2d(1) 182 | self.fc = nn.Sequential( 183 | nn.Linear(channel, channel // reduction, bias=False), 184 | nn.ReLU(inplace=True), 185 | nn.Linear(channel // reduction, channel, bias=False), 186 | nn.Sigmoid() 187 | ) 188 | 189 | def forward(self, x): 190 | b, c, _, _ = x.size() 191 | y = self.avg_pool(x).view(b, c) 192 | y = self.fc(y).view(b, c, 1, 1) 193 | return x * y.expand_as(x) -------------------------------------------------------------------------------- /non_local_embedded_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | super(_NonLocalBlockND, self).__init__() 9 | 10 | assert dimension in [1, 2, 3] 11 | 12 | self.dimension = dimension 13 | self.sub_sample = sub_sample 14 | 15 | self.in_channels = in_channels 16 | self.inter_channels = inter_channels 17 | 18 | if self.inter_channels is None: 19 | self.inter_channels = in_channels // 2 20 | if self.inter_channels == 0: 21 | self.inter_channels = 1 22 | 23 | if dimension == 3: 24 | conv_nd = nn.Conv3d 25 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 26 | bn = nn.BatchNorm3d 27 | elif dimension == 2: 28 | conv_nd = nn.Conv2d 29 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 30 | bn = nn.BatchNorm2d 31 | else: 32 | conv_nd = nn.Conv1d 33 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 34 | bn = nn.BatchNorm1d 35 | 36 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 37 | kernel_size=1, stride=1, padding=0) 38 | 39 | if bn_layer: 40 | self.W = nn.Sequential( 41 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 42 | kernel_size=1, stride=1, padding=0), 43 | bn(self.in_channels) 44 | ) 45 | nn.init.constant_(self.W[1].weight, 0) 46 | nn.init.constant_(self.W[1].bias, 0) 47 | else: 48 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0) 50 | nn.init.constant_(self.W.weight, 0) 51 | nn.init.constant_(self.W.bias, 0) 52 | 53 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 56 | kernel_size=1, stride=1, padding=0) 57 | 58 | if sub_sample: 59 | self.g = nn.Sequential(self.g, max_pool_layer) 60 | self.phi = nn.Sequential(self.phi, max_pool_layer) 61 | 62 | def forward(self, x): 63 | ''' 64 | :param x: (b, c, t, h, w) 65 | :return: 66 | ''' 67 | 68 | batch_size = x.size(0) 69 | 70 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 71 | g_x = g_x.permute(0, 2, 1) 72 | 73 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 74 | theta_x = theta_x.permute(0, 2, 1) 75 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 76 | f = torch.matmul(theta_x, phi_x) 77 | f_div_C = F.softmax(f, dim=-1) 78 | 79 | y = torch.matmul(f_div_C, g_x) 80 | y = y.permute(0, 2, 1).contiguous() 81 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 82 | W_y = self.W(y) 83 | z = W_y+x 84 | 85 | return z 86 | 87 | 88 | class NONLocalBlock1D(_NonLocalBlockND): 89 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 90 | super(NONLocalBlock1D, self).__init__(in_channels, 91 | 
inter_channels=inter_channels, 92 | dimension=1, sub_sample=sub_sample, 93 | bn_layer=bn_layer) 94 | 95 | 96 | class NONLocalBlock2D(_NonLocalBlockND): 97 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 98 | super(NONLocalBlock2D, self).__init__(in_channels, 99 | inter_channels=inter_channels, 100 | dimension=2, sub_sample=sub_sample, 101 | bn_layer=bn_layer) 102 | 103 | 104 | class NONLocalBlock3D(_NonLocalBlockND): 105 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 106 | super(NONLocalBlock3D, self).__init__(in_channels, 107 | inter_channels=inter_channels, 108 | dimension=3, sub_sample=sub_sample, 109 | bn_layer=bn_layer) 110 | 111 | 112 | if __name__ == '__main__': 113 | import torch 114 | 115 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 116 | img = torch.zeros(2, 3, 20) 117 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 118 | out = net(img) 119 | print(out.size()) 120 | 121 | img = torch.zeros(2, 3, 20, 20) 122 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 123 | out = net(img) 124 | print(out.size()) 125 | 126 | img = torch.randn(2, 3, 8, 20, 20) 127 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 128 | out = net(img) 129 | print(out.size()) -------------------------------------------------------------------------------- /run_office_home_more.sh: -------------------------------------------------------------------------------- 1 | ### pda 2 | python image_source.py --trte val --output ckps/source/ --da pda --gpu_id 0 --dset office-home --max_epoch 50 --s 0 --seed 2019; 3 | python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --seed 2019 --kd True; 4 | 5 | python image_source.py --trte val --output ckps/source/ --da pda --gpu_id 0 --dset office-home --max_epoch 50 --s 1 --seed 2019; 6 | python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 1 --output_src ckps/source/ --output ckps/target/ --seed 2019 --kd True; 7 | 8 | python image_source.py --trte val --output ckps/source/ --da pda --gpu_id 0 --dset office-home --max_epoch 50 --s 2 --seed 2019; 9 | python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 2 --output_src ckps/source/ --output ckps/target/ --seed 2019 --kd True; 10 | 11 | python image_source.py --trte val --output ckps/source/ --da pda --gpu_id 0 --dset office-home --max_epoch 50 --s 3 --seed 2019; 12 | python image_target.py --cls_par 0.3 --threshold 10 --da pda --dset office-home --gpu_id 0 --s 3 --output_src ckps/source/ --output ckps/target/ --seed 2019 --kd True; 13 | 14 | 15 | ### oda 16 | python image_source.py --trte val --output ckps/source/ --da oda --gpu_id 0 --dset office-home --max_epoch 50 --s 0 --seed 2019; 17 | python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --seed 2019 --kd True; 18 | 19 | python image_source.py --trte val --output ckps/source/ --da oda --gpu_id 0 --dset office-home --max_epoch 50 --s 1 --seed 2019; 20 | python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 1 --output_src ckps/source/ --output ckps/target/ --seed 2019 --kd True; 21 | 22 | python image_source.py --trte val --output ckps/source/ --da oda --gpu_id 0 --dset office-home --max_epoch 50 --s 2 --seed 2019;
23 | python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 2 --output_src ckps/source/ --output ckps/target/ --seed 2019 --kd True; 24 | 25 | python image_source.py --trte val --output ckps/source/ --da oda --gpu_id 0 --dset office-home --max_epoch 50 --s 3 --seed 2019; 26 | python image_target_oda.py --cls_par 0.3 --da oda --dset office-home --gpu_id 0 --s 3 --output_src ckps/source/ --output ckps/target/ --seed 2019 --kd True; 27 | -------------------------------------------------------------------------------- /run_office_home_uda.sh: -------------------------------------------------------------------------------- 1 | 2 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-home --max_epoch 100 --s 0; 3 | python image_target.py --cls_par 0.3 --da uda --dset office-home --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --kd True; 4 | 5 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-home --max_epoch 100 --s 1; 6 | python image_target.py --cls_par 0.3 --da uda --dset office-home --gpu_id 0 --s 1 --output_src ckps/source/ --output ckps/target/ --kd True; 7 | 8 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-home --max_epoch 100 --s 2; 9 | python image_target.py --cls_par 0.3 --da uda --dset office-home --gpu_id 0 --s 2 --output_src ckps/source/ --output ckps/target/ --kd True; 10 | 11 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office-home --max_epoch 100 --s 3; 12 | python image_target.py --cls_par 0.3 --da uda --dset office-home --gpu_id 0 --s 3 --output_src ckps/source/ --output ckps/target/ --kd True; 13 | 14 | 15 | -------------------------------------------------------------------------------- /run_office_uda.sh: -------------------------------------------------------------------------------- 1 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 0; 2 | python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --kd True; 3 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 1; 4 | python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 1 --output_src ckps/source/ --output ckps/target/ --kd True; 5 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 2; 6 | python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 2 --output_src ckps/source/ --output ckps/target/ --kd True; 7 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 1 --dset office --max_epoch 100 --s 0 --bs 2; 8 | -------------------------------------------------------------------------------- /run_office_uda_ab.sh: -------------------------------------------------------------------------------- 1 | ### se 2 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 0 --net resnet50 --se True ; 3 | python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --se True --net resnet50; 4 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 1 --net resnet50 --se True; 5 | python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 1
--output_src ckps/source/ --output ckps/target/ --net resnet50 --se True; 6 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 2 --net resnet50 --se True ; 7 | python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 2 --output_src ckps/source/ --output ckps/target/ --net resnet50 --se True; 8 | ### nonlocal 9 | #python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 0 --net resnet50 --nl True; 10 | #python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --net resnet50 --nl True; 11 | #python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 1 --net resnet50 --nl True; 12 | #python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 1 --output_src ckps/source/ --output ckps/target/ --net resnet50 --nl True; 13 | #python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset office --max_epoch 100 --s 2 --net resnet50 --nl True; 14 | #python image_target.py --cls_par 0.3 --da uda --dset office --gpu_id 0 --s 2 --output_src ckps/source/ --output ckps/target/ --net resnet50 --nl True; 15 | 16 | -------------------------------------------------------------------------------- /run_visda.sh: -------------------------------------------------------------------------------- 1 | python image_source.py --trte val --output ckps/source/ --da uda --gpu_id 0 --dset VISDA-C --net resnet101 --lr 1e-3 --max_epoch 10 --s 0; 2 | python image_target.py --cls_par 0.3 --da uda --dset VISDA-C --gpu_id 0 --s 0 --output_src ckps/source/ --output ckps/target/ --kd True --lr 1e-3; 3 | --------------------------------------------------------------------------------
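
Note on putting the pieces together: the evaluation entry point above (image_test.py) rebuilds the adapted model as feature extractor -> bottleneck -> classifier using the modules defined in network.py, then loads the `target_F/B/C_par_0.3.pt` checkpoints. The snippet below is a minimal, hypothetical sketch of that same composition for a single forward pass; it is not part of the repository. The checkpoint directory `ckps/target/uda/office-home/AC`, the 65-class Office-Home setting, and the availability of the R50+ViT-B_16.npz pretrained file (see the README) are assumptions, not fixed paths.
```python
# Minimal sketch (assumed paths/settings): compose the TransDA target model the
# same way train_target()/cal_acc() in image_test.py do, then classify one input.
import torch
import torch.nn as nn
import network

checkpoint_dir = 'ckps/target/uda/office-home/AC'  # hypothetical checkpoint folder

# Feature extractor (R50-ViT-B_16, 2048-d features), bottleneck, and classifier,
# mirroring the constructor calls used in image_test.py.
netF = network.ViT().cuda()
netB = network.feat_bootleneck(type='bn', feature_dim=netF.in_features, bottleneck_dim=256).cuda()
netC = network.feat_classifier(type='wn', class_num=65, bottleneck_dim=256).cuda()

# Load the adapted target checkpoints (strict=False for the ViT backbone, as in image_test.py).
netF.load_state_dict(torch.load(checkpoint_dir + '/target_F_par_0.3.pt'), strict=False)
netB.load_state_dict(torch.load(checkpoint_dir + '/target_B_par_0.3.pt'))
netC.load_state_dict(torch.load(checkpoint_dir + '/target_C_par_0.3.pt'))
netF.eval(); netB.eval(); netC.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 224, 224).cuda()   # stand-in for one normalized target image
    logits = netC(netB(netF(x)))              # feature -> bottleneck -> classifier
    pred = nn.Softmax(dim=1)(logits).argmax(dim=1)
    print(pred)
```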