├── .DS_Store ├── .gitignore ├── .idea ├── MultiClassDA.iml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── LICENSE ├── README.md ├── config ├── __init__.py └── config.py ├── data ├── domainnet_data_prepare.py ├── folder_mec.py ├── folder_new.py ├── prepare_data.py └── vision.py ├── experiments ├── .DS_Store ├── Clarification_on_tasks.md ├── Open │ └── logOffice31_amazon_s_busto2webcam_t_busto_SymmNetsV2_None │ │ └── log.txt ├── Partial │ └── logOfficeHome_Art2Real_World_25_SymmNetsV2_None │ │ └── log.txt ├── SymmNets │ ├── logImageCLEF_c2i_SymmNetsV1_None │ │ └── log.txt │ ├── logImageCLEF_c2i_SymmNetsV2_None │ │ └── log.txt │ ├── logOffice31_amazon2dslr_SymmNetsV2_Noneclosed │ │ └── log.txt │ ├── logOffice31_amazon2dslr_SymmNetsV2_Noneclosedsc │ │ └── log.txt │ ├── logOffice31_webcam2amazon_SymmNetsV2_Noneclosed │ │ └── log.txt │ ├── logOffice31_webcam2amazon_SymmNetsV2_Noneclosedsc │ │ └── log.txt │ ├── logVisDA_train2validation_SymmNetsV2_None │ │ └── log.txt │ └── logres101VisDA_train2validation_SymmNetsV2_None │ │ └── log.txt ├── ckpt │ ├── logImageCLEF_c2i_McDalNet_CE │ │ └── log.txt │ ├── logImageCLEF_c2i_McDalNet_DANN │ │ └── log.txt │ ├── logImageCLEF_c2i_McDalNet_KL │ │ └── log.txt │ ├── logImageCLEF_c2i_McDalNet_L1 │ │ └── log.txt │ ├── logImageCLEF_c2i_McDalNet_MDD │ │ └── log.txt │ ├── logVisDA_train2validation_McDalNet_CE │ │ └── log.txt │ ├── logVisDA_train2validation_McDalNet_DANN │ │ └── log.txt │ ├── logVisDA_train2validation_McDalNet_KL │ │ └── log.txt │ ├── logVisDA_train2validation_McDalNet_L1 │ │ └── log.txt │ └── logVisDA_train2validation_McDalNet_MDD │ │ └── log.txt └── configs │ ├── .DS_Store │ ├── ImageCLEF │ ├── .DS_Store │ ├── McDalNet │ │ └── clef_train_c2i_cfg.yaml │ └── SymmNets │ │ └── clef_train_c2i_cfg.yaml │ ├── Office31 │ ├── .DS_Store │ ├── McDalNet │ │ └── office31_train_webcam2amazon_cfg.yaml │ └── SymmNets │ │ ├── .DS_Store │ │ ├── office31_train_amazon2dslr_cfg.yaml │ │ ├── 
office31_train_amazon2dslr_cfg_SC.yaml │ │ ├── office31_train_amazon2webcam_open_cfg.yaml │ │ ├── office31_train_webcam2amazon_cfg.yaml │ │ └── office31_train_webcam2amazon_cfg_SC.yaml │ ├── OfficeHome │ ├── .DS_Store │ └── SymmNets │ │ └── home_train_A2R_partial_cfg.yaml │ └── VisDA │ ├── McDalNet │ └── visda17_train_train2val_cfg.yaml │ └── SymmNets │ ├── visda17_train_train2val_cfg.yaml │ └── visda17_train_train2val_cfg_res101.yaml ├── models ├── loss_utils.py ├── resnet_McDalNet.py └── resnet_SymmNet.py ├── run.sh ├── run_symmnets.sh ├── run_temp.sh ├── solver ├── McDalNet_solver.py ├── SymmNetsV1_solver.py ├── SymmNetsV2Open_solver.py ├── SymmNetsV2Partial_solver.py ├── SymmNetsV2SC_solver.py ├── SymmNetsV2_solver.py └── base_solver.py ├── tools └── train.py └── utils ├── __init__.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.idea/MultiClassDA.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 13 | 14 | 19 | 20 | 21 | 22 | target_train_dataset 23 | DatasetFolder 24 | resume_dict 25 | self.opt 26 | nClass 27 | model_eval 28 | McDalNetLoss 29 | SAVE_DIR 30 | iters_per_epoch 31 | cfg 32 | PROCE 33 | BasicBlock 34 | net 35 | prob_s 36 | counter_all_auxi1 37 | counter_acc_auxi2 38 | expansion 39 | exp_name 40 | self.num_classes: 41 | 
source_gt_for_ft_in_fst 42 | 43 | 44 | counter_all_fs 45 | counter_acc_ft 46 | 47 | 48 | 49 | 51 | 52 | 53 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 1581836641713 80 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | file://$PROJECT_DIR$/data/prepare_data.py 91 | 122 92 | 94 | 95 | file://$PROJECT_DIR$/solver/SymmNetsV1_solver.py 96 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yabin Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiClassDA (TPAMI2020) 2 | Code release for ["Unsupervised Multi-Class Domain Adaptation: Theory, Algorithms, and Practice"](https://arxiv.org/pdf/2002.08681.pdf), which is 3 | an extension of our preliminary work, SymmNets [[Paper](https://zpascal.net/cvpr2019/Zhang_Domain-Symmetric_Networks_for_Adversarial_Domain_Adaptation_CVPR_2019_paper.pdf)] [[Code](https://github.com/YBZh/SymNets)]. 4 | 5 | Please refer to "run_temp.sh" for usage. 6 | All experimental results are logged in the "./experiments" folder. 7 | 8 | 9 | 10 | 11 | ## Included code: 12 | 1. Code of McDalNets -->./solver/McDalNet_solver.py 13 | 2. Code of SymNets-V2 14 | 1. For the Closed Set DA -->./solver/SymmNetsV2_solver.py 15 | 2. For the Strengthened Closed Set DA -->./solver/SymmNetsV2SC_solver.py 16 | 3. For the Partial DA -->./solver/SymmNetsV2Partial_solver.py 17 | 4. For the Open Set DA -->./solver/SymmNetsV2Open_solver.py 18 | 19 | 20 | ## Dataset 21 | The dataset directory should be structured as follows: 22 | 23 | ``` 24 | Office-31 25 | |_ amazon 26 | | |_ back_pack 27 | | |_ .jpg 28 | | |_ ... 29 | | |_ .jpg 30 | | |_ bike 31 | | |_ .jpg 32 | | |_ ... 33 | | |_ .jpg 34 | | |_ ... 35 | |_ dslr 36 | | |_ back_pack 37 | | |_ .jpg 38 | | |_ ... 39 | | |_ .jpg 40 | | |_ bike 41 | | |_ .jpg 42 | | |_ ... 43 | | |_ .jpg 44 | | |_ ... 45 | |_ ...
46 | ``` 47 | 48 | 49 | ## Citation 50 | 51 | @inproceedings{zhang2019domain, 52 | title={Domain-symmetric networks for adversarial domain adaptation}, 53 | author={Zhang, Yabin and Tang, Hui and Jia, Kui and Tan, Mingkui}, 54 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 55 | pages={5031--5040}, 56 | year={2019} 57 | } 58 | @article{zhang2020unsupervised, 59 | title={Unsupervised Multi-Class Domain Adaptation: Theory, Algorithms, and Practice}, 60 | author={Zhang, Yabin and Deng, Bin and Tang, Hui and Zhang, Lei and Jia, Kui}, 61 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 62 | year={2020}, 63 | publisher={IEEE} 64 | } 65 | 66 | ## Contact 67 | If you have any problems with our code, feel free to contact 68 | - zhang.yabin@mail.scut.edu.cn 69 | 70 | or describe your problem in Issues. 71 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/config/__init__.py -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ipdb 3 | import numpy as np 4 | from easydict import EasyDict as edict 5 | 6 | __C = edict() 7 | cfg = __C 8 | 9 | # Dataset options 10 | # 11 | __C.DATASET = edict() 12 | __C.DATASET.NUM_CLASSES = 0 13 | __C.DATASET.DATASET = '' 14 | __C.DATASET.DATAROOT = '' 15 | __C.DATASET.SOURCE_NAME = '' 16 | __C.DATASET.TARGET_NAME = '' 17 | __C.DATASET.VAL_NAME = '' 18 | 19 | # Model options 20 | __C.MODEL = edict() 21 | __C.MODEL.FEATURE_EXTRACTOR = 'resnet101' 22 | __C.MODEL.PRETRAINED = True 23 | # __C.MODEL.FC_HIDDEN_DIMS = () ### options for multiple layers.
24 | 25 | # data pre-processing options 26 | # 27 | __C.DATA_TRANSFORM = edict() 28 | __C.DATA_TRANSFORM.TYPE = 'ours' ### trun to simple for the VisDA dataset. 29 | 30 | 31 | # Training options 32 | # 33 | __C.TRAIN = edict() 34 | # batch size setting 35 | __C.TRAIN.SOURCE_BATCH_SIZE = 128 36 | __C.TRAIN.TARGET_BATCH_SIZE = 128 37 | 38 | # learning rate schedule 39 | __C.TRAIN.BASE_LR = 0.01 40 | __C.TRAIN.MOMENTUM = 0.9 41 | __C.TRAIN.OPTIMIZER = 'SGD' 42 | __C.TRAIN.WEIGHT_DECAY = 0.0001 43 | __C.TRAIN.LR_SCHEDULE = 'inv' 44 | __C.TRAIN.MAX_EPOCH = 400 45 | __C.TRAIN.SAVING = False ## whether to save the intermediate status of model. 46 | __C.TRAIN.PROCESS_COUNTER = 'iteration' 47 | # __C.TRAIN.STOP_THRESHOLDS = (0.001, 0.001, 0.001) 48 | # __C.TRAIN.TEST_INTERVAL = 1.0 # percentage of total iterations each loop 49 | # __C.TRAIN.SAVE_CKPT_INTERVAL = 1.0 # percentage of total iterations in each loop 50 | 51 | 52 | __C.STRENGTHEN = edict() 53 | __C.STRENGTHEN.DATALOAD = 'normal' ## normal | hard | soft. The original class aware sampling adopt the hard mode. 
54 | __C.STRENGTHEN.PERCATE = 10 55 | __C.STRENGTHEN.CLUSTER_FREQ = 6 56 | 57 | 58 | 59 | # optimizer options 60 | __C.MCDALNET = edict() 61 | __C.MCDALNET.DISTANCE_TYPE = '' ## choose in L1 | KL | CE | MDD | DANN | SourceOnly 62 | 63 | # optimizer options 64 | __C.ADAM = edict() ### adopted by the Digits dataset only 65 | __C.ADAM.BETA1 = 0.9 66 | __C.ADAM.BETA2 = 0.999 67 | 68 | __C.INV = edict() 69 | __C.INV.ALPHA = 10.0 70 | __C.INV.BETA = 0.75 71 | 72 | __C.OPEN = edict() 73 | __C.OPEN.WEIGHT_UNK = 6.0 74 | 75 | # Testing options 76 | # 77 | __C.TEST = edict() 78 | __C.TEST.BATCH_SIZE = 128 79 | 80 | 81 | # MISC 82 | __C.RESUME = '' 83 | __C.TASK = 'closed' ## closed | partial | open 84 | __C.EVAL_METRIC = "accu" # "mean_accu" as alternative 85 | __C.EXP_NAME = 'exp' 86 | __C.SAVE_DIR = '' 87 | __C.NUM_WORKERS = 6 88 | __C.PRINT_STEP = 3 89 | 90 | def _merge_a_into_b(a, b): 91 | """Merge config dictionary a into config dictionary b, clobbering the 92 | options in b whenever they are also specified in a. 93 | """ 94 | if type(a) is not edict: 95 | return 96 | 97 | for k in a: 98 | # a must specify keys that are in b 99 | v = a[k] 100 | if k not in b: 101 | raise KeyError('{} is not a valid config key'.format(k)) 102 | 103 | # the types must match, too 104 | old_type = type(b[k]) 105 | if old_type is not type(v): 106 | if isinstance(b[k], np.ndarray): 107 | v = np.array(v, dtype=b[k].dtype) 108 | else: 109 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 110 | 'for config key: {}').format(type(b[k]), 111 | type(v), k)) 112 | 113 | # recursively merge dicts 114 | if type(v) is edict: 115 | try: 116 | _merge_a_into_b(a[k], b[k]) 117 | except: 118 | print('Error under config key: {}'.format(k)) 119 | raise 120 | else: 121 | b[k] = v 122 | 123 | def cfg_from_file(filename): 124 | """Load a config file and merge it into the default options.""" 125 | import yaml 126 | if filename[-1] == '\r': 127 | filename = filename[:-1] ## delete the '\r' at the end of the str 128 | with open(filename, 'r') as f: 129 | yaml_cfg = edict(yaml.load(f, Loader=yaml.FullLoader)) 130 | 131 | _merge_a_into_b(yaml_cfg, __C) 132 | 133 | def cfg_from_list(cfg_list): 134 | """Set config keys via list (e.g., from command line).""" 135 | from ast import literal_eval 136 | assert len(cfg_list) % 2 == 0 137 | for k, v in zip(cfg_list[0::2], cfg_list[1::2]): 138 | key_list = k.split('.') 139 | d = __C 140 | for subkey in key_list[:-1]: 141 | assert subkey in d 142 | d = d[subkey] 143 | subkey = key_list[-1] 144 | assert subkey in d 145 | try: 146 | value = literal_eval(v) 147 | except: 148 | # handle the case when v is a string literal 149 | value = v 150 | assert type(value) == type(d[subkey]), \ 151 | 'type {} does not match original type {}'.format( 152 | type(value), type(d[subkey])) 153 | d[subkey] = value -------------------------------------------------------------------------------- /data/domainnet_data_prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import ipdb 4 | 5 | data_file = ['clipart', 'painting', 'real', 'sketch', 'infograph', 'quickdraw'] 6 | current_path = os.getcwd() 7 | for task in data_file: 8 | source_path = current_path + '/' + task 9 | source_train_path = current_path + '/' + task + '_train' 10 | if not os.path.isdir(source_train_path): 11 | os.makedirs(source_train_path) 12 | txt_train = current_path + '/' + task + '_train.txt' 13 | 
txt_file = open(txt_train) 14 | line = txt_file.readline() 15 | while line: 16 | image_path = line.split(' ')[0] 17 | image_path_split_list = image_path.split('/') 18 | source_image_path = source_path + '/' + image_path_split_list[1] + '/' + image_path_split_list[2] 19 | source_train_category = source_train_path + '/' + image_path_split_list[1] 20 | if not os.path.isdir(source_train_category): 21 | os.makedirs(source_train_category) 22 | source_train_image_path = source_train_category + '/' + image_path_split_list[2] 23 | print('copy image from %s -> %s' % (source_image_path, source_train_image_path)) 24 | shutil.copyfile(source_image_path, source_train_image_path) 25 | line = txt_file.readline() 26 | 27 | 28 | 29 | source_test_path = current_path + '/' + task + '_test' 30 | if not os.path.isdir(source_test_path): 31 | os.makedirs(source_test_path) 32 | txt_test = current_path + '/' + task + '_test.txt' 33 | txt_file = open(txt_test) 34 | line = txt_file.readline() 35 | while line: 36 | image_path = line.split(' ')[0] 37 | image_path_split_list = image_path.split('/') 38 | source_image_path = source_path + '/' + image_path_split_list[1] + '/' + image_path_split_list[2] 39 | source_test_category = source_test_path + '/' + image_path_split_list[1] 40 | if not os.path.isdir(source_test_category): 41 | os.makedirs(source_test_category) 42 | source_test_image_path = source_test_category + '/' + image_path_split_list[2] 43 | print('copy image from %s -> %s' % (source_image_path, source_test_image_path)) 44 | shutil.copyfile(source_image_path, source_test_image_path) 45 | line = txt_file.readline() 46 | -------------------------------------------------------------------------------- /data/folder_mec.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # 
Modified the original code so that XXX 5 | ############################################################################### 6 | from data.vision import VisionDataset 7 | 8 | from PIL import Image 9 | 10 | import os 11 | import os.path 12 | import sys 13 | ########################### added part ### 14 | from PIL import ImageFile 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True ## used to handle some error when loading the special images. 16 | 17 | def has_file_allowed_extension(filename, extensions): 18 | """Checks if a file is an allowed extension. 19 | Args: 20 | filename (string): path to a file 21 | extensions (tuple of strings): extensions to consider (lowercase) 22 | Returns: 23 | bool: True if the filename ends with one of given extensions 24 | """ 25 | return filename.lower().endswith(extensions) 26 | 27 | 28 | def is_image_file(filename): 29 | """Checks if a file is an allowed image extension. 30 | Args: 31 | filename (string): path to a file 32 | Returns: 33 | bool: True if the filename ends with a known image extension 34 | """ 35 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 36 | 37 | 38 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None): 39 | images = [] 40 | dir = os.path.expanduser(dir) 41 | if not ((extensions is None) ^ (is_valid_file is None)): 42 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 43 | if extensions is not None: 44 | def is_valid_file(x): 45 | return has_file_allowed_extension(x, extensions) 46 | for target in sorted(class_to_idx.keys()): 47 | d = os.path.join(dir, target) 48 | if not os.path.isdir(d): 49 | continue 50 | for root, _, fnames in sorted(os.walk(d, followlinks=True)): 51 | for fname in sorted(fnames): 52 | path = os.path.join(root, fname) 53 | if is_valid_file(path): 54 | item = (path, class_to_idx[target]) 55 | images.append(item) 56 | 57 | return images 58 | 59 | 60 | class DatasetFolder(VisionDataset): 61 | """A generic data loader 
where the samples are arranged in this way: :: 62 | root/class_x/xxx.ext 63 | root/class_x/xxy.ext 64 | root/class_x/xxz.ext 65 | root/class_y/123.ext 66 | root/class_y/nsdf3.ext 67 | root/class_y/asd932_.ext 68 | Args: 69 | root (string): Root directory path. 70 | loader (callable): A function to load a sample given its path. 71 | extensions (tuple[string]): A list of allowed extensions. 72 | both extensions and is_valid_file should not be passed. 73 | transform (callable, optional): A function/transform that takes in 74 | a sample and returns a transformed version. 75 | E.g, ``transforms.RandomCrop`` for images. 76 | target_transform (callable, optional): A function/transform that takes 77 | in the target and transforms it. 78 | is_valid_file (callable, optional): A function that takes path of a file 79 | and check if the file is a valid file (used to check of corrupt files) 80 | both extensions and is_valid_file should not be passed. 81 | Attributes: 82 | classes (list): List of the class names. 83 | class_to_idx (dict): Dict with items (class_name, class_index). 
84 | samples (list): List of (sample path, class_index) tuples 85 | targets (list): The class_index value for each image in the dataset 86 | """ 87 | 88 | def __init__(self, root, loader, extensions=None, transform=None, transform_mec=None, 89 | target_transform=None, is_valid_file=None): 90 | super(DatasetFolder, self).__init__(root, transform=transform, 91 | target_transform=target_transform) 92 | classes, class_to_idx = self._find_classes(self.root) 93 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 94 | if len(samples) == 0: 95 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n" 96 | "Supported extensions are: " + ",".join(extensions))) 97 | 98 | self.loader = loader 99 | self.extensions = extensions 100 | 101 | self.classes = classes 102 | self.class_to_idx = class_to_idx 103 | self.samples = samples 104 | self.targets = [s[1] for s in samples] 105 | self.transform_mec = transform_mec 106 | 107 | def _find_classes(self, dir): 108 | """ 109 | Finds the class folders in a dataset. 110 | Args: 111 | dir (string): Root directory path. 112 | Returns: 113 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 114 | Ensures: 115 | No class is a subdirectory of another. 116 | """ 117 | if sys.version_info >= (3, 5): 118 | # Faster and available in Python 3.5 and above 119 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 120 | else: 121 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 122 | classes.sort() 123 | class_to_idx = {classes[i]: i for i in range(len(classes))} 124 | return classes, class_to_idx 125 | 126 | def __getitem__(self, index): 127 | """ 128 | Args: 129 | index (int): Index 130 | Returns: 131 | tuple: (sample, target) where target is class_index of the target class. 
132 | """ 133 | path, target = self.samples[index] 134 | sample = self.loader(path) 135 | if self.transform is not None: 136 | sample_ori = self.transform(sample) 137 | if self.transform_mec is not None: 138 | sample_mec = self.transform(sample) 139 | if self.target_transform is not None: 140 | target = self.target_transform(target) 141 | 142 | return sample_ori, sample_mec, target, path 143 | 144 | def __len__(self): 145 | return len(self.samples) 146 | 147 | 148 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 149 | 150 | 151 | def pil_loader(path): 152 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 153 | with open(path, 'rb') as f: 154 | img = Image.open(f) 155 | return img.convert('RGB') 156 | 157 | 158 | def accimage_loader(path): 159 | import accimage 160 | try: 161 | return accimage.Image(path) 162 | except IOError: 163 | # Potentially a decoding problem, fall back to PIL.Image 164 | return pil_loader(path) 165 | 166 | 167 | def default_loader(path): 168 | from torchvision import get_image_backend 169 | if get_image_backend() == 'accimage': 170 | return accimage_loader(path) 171 | else: 172 | return pil_loader(path) 173 | 174 | 175 | class ImageFolder_MEC(DatasetFolder): 176 | """A generic data loader where the images are arranged in this way: :: 177 | root/dog/xxx.png 178 | root/dog/xxy.png 179 | root/dog/xxz.png 180 | root/cat/123.png 181 | root/cat/nsdf3.png 182 | root/cat/asd932_.png 183 | Args: 184 | root (string): Root directory path. 185 | transform (callable, optional): A function/transform that takes in an PIL image 186 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 187 | target_transform (callable, optional): A function/transform that takes in the 188 | target and transforms it. 189 | loader (callable, optional): A function to load an image given its path. 
190 | is_valid_file (callable, optional): A function that takes path of an Image file 191 | and check if the file is a valid file (used to check of corrupt files) 192 | Attributes: 193 | classes (list): List of the class names. 194 | class_to_idx (dict): Dict with items (class_name, class_index). 195 | imgs (list): List of (image path, class_index) tuples 196 | """ 197 | 198 | def __init__(self, root, transform=None, transform_mec=None, target_transform=None, 199 | loader=default_loader, is_valid_file=None): 200 | super(ImageFolder_MEC, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 201 | transform=transform, 202 | transform_mec=transform_mec, 203 | target_transform=target_transform, 204 | is_valid_file=is_valid_file) 205 | self.imgs = self.samples -------------------------------------------------------------------------------- /data/folder_new.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that XXX 5 | ############################################################################### 6 | from data.vision import VisionDataset 7 | 8 | from PIL import Image 9 | 10 | import os 11 | import os.path 12 | import sys 13 | ########################### added part ### 14 | from PIL import ImageFile 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True ## used to handle some error when loading the special images. 16 | 17 | def has_file_allowed_extension(filename, extensions): 18 | """Checks if a file is an allowed extension. 
19 | Args: 20 | filename (string): path to a file 21 | extensions (tuple of strings): extensions to consider (lowercase) 22 | Returns: 23 | bool: True if the filename ends with one of given extensions 24 | """ 25 | return filename.lower().endswith(extensions) 26 | 27 | 28 | def is_image_file(filename): 29 | """Checks if a file is an allowed image extension. 30 | Args: 31 | filename (string): path to a file 32 | Returns: 33 | bool: True if the filename ends with a known image extension 34 | """ 35 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 36 | 37 | 38 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None): 39 | images = [] 40 | dir = os.path.expanduser(dir) 41 | if not ((extensions is None) ^ (is_valid_file is None)): 42 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 43 | if extensions is not None: 44 | def is_valid_file(x): 45 | return has_file_allowed_extension(x, extensions) 46 | for target in sorted(class_to_idx.keys()): 47 | d = os.path.join(dir, target) 48 | if not os.path.isdir(d): 49 | continue 50 | for root, _, fnames in sorted(os.walk(d, followlinks=True)): 51 | for fname in sorted(fnames): 52 | path = os.path.join(root, fname) 53 | if is_valid_file(path): 54 | item = (path, class_to_idx[target]) 55 | images.append(item) 56 | 57 | return images 58 | 59 | 60 | class DatasetFolder(VisionDataset): 61 | """A generic data loader where the samples are arranged in this way: :: 62 | root/class_x/xxx.ext 63 | root/class_x/xxy.ext 64 | root/class_x/xxz.ext 65 | root/class_y/123.ext 66 | root/class_y/nsdf3.ext 67 | root/class_y/asd932_.ext 68 | Args: 69 | root (string): Root directory path. 70 | loader (callable): A function to load a sample given its path. 71 | extensions (tuple[string]): A list of allowed extensions. 72 | both extensions and is_valid_file should not be passed. 
73 | transform (callable, optional): A function/transform that takes in 74 | a sample and returns a transformed version. 75 | E.g, ``transforms.RandomCrop`` for images. 76 | target_transform (callable, optional): A function/transform that takes 77 | in the target and transforms it. 78 | is_valid_file (callable, optional): A function that takes path of a file 79 | and check if the file is a valid file (used to check of corrupt files) 80 | both extensions and is_valid_file should not be passed. 81 | Attributes: 82 | classes (list): List of the class names. 83 | class_to_idx (dict): Dict with items (class_name, class_index). 84 | samples (list): List of (sample path, class_index) tuples 85 | targets (list): The class_index value for each image in the dataset 86 | """ 87 | 88 | def __init__(self, root, loader, extensions=None, transform=None, 89 | target_transform=None, is_valid_file=None): 90 | super(DatasetFolder, self).__init__(root, transform=transform, 91 | target_transform=target_transform) 92 | classes, class_to_idx = self._find_classes(self.root) 93 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 94 | if len(samples) == 0: 95 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n" 96 | "Supported extensions are: " + ",".join(extensions))) 97 | 98 | self.loader = loader 99 | self.extensions = extensions 100 | 101 | self.classes = classes 102 | self.class_to_idx = class_to_idx 103 | self.samples = samples 104 | self.targets = [s[1] for s in samples] 105 | 106 | def _find_classes(self, dir): 107 | """ 108 | Finds the class folders in a dataset. 109 | Args: 110 | dir (string): Root directory path. 111 | Returns: 112 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 113 | Ensures: 114 | No class is a subdirectory of another. 
115 | """ 116 | if sys.version_info >= (3, 5): 117 | # Faster and available in Python 3.5 and above 118 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 119 | else: 120 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 121 | classes.sort() 122 | class_to_idx = {classes[i]: i for i in range(len(classes))} 123 | return classes, class_to_idx 124 | 125 | def __getitem__(self, index): 126 | """ 127 | Args: 128 | index (int): Index 129 | Returns: 130 | tuple: (sample, target) where target is class_index of the target class. 131 | """ 132 | path, target = self.samples[index] 133 | sample = self.loader(path) 134 | if self.transform is not None: 135 | sample = self.transform(sample) 136 | if self.target_transform is not None: 137 | target = self.target_transform(target) 138 | 139 | return sample, target, path 140 | 141 | def __len__(self): 142 | return len(self.samples) 143 | 144 | 145 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 146 | 147 | 148 | def pil_loader(path): 149 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 150 | with open(path, 'rb') as f: 151 | img = Image.open(f) 152 | return img.convert('RGB') 153 | 154 | 155 | def accimage_loader(path): 156 | import accimage 157 | try: 158 | return accimage.Image(path) 159 | except IOError: 160 | # Potentially a decoding problem, fall back to PIL.Image 161 | return pil_loader(path) 162 | 163 | 164 | def default_loader(path): 165 | from torchvision import get_image_backend 166 | if get_image_backend() == 'accimage': 167 | return accimage_loader(path) 168 | else: 169 | return pil_loader(path) 170 | 171 | 172 | class ImageFolder_Withpath(DatasetFolder): 173 | """A generic data loader where the images are arranged in this way: :: 174 | root/dog/xxx.png 175 | root/dog/xxy.png 176 | root/dog/xxz.png 177 | root/cat/123.png 178 | root/cat/nsdf3.png 179 | root/cat/asd932_.png 180 | Args: 181 
| root (string): Root directory path. 182 | transform (callable, optional): A function/transform that takes in an PIL image 183 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 184 | target_transform (callable, optional): A function/transform that takes in the 185 | target and transforms it. 186 | loader (callable, optional): A function to load an image given its path. 187 | is_valid_file (callable, optional): A function that takes path of an Image file 188 | and check if the file is a valid file (used to check of corrupt files) 189 | Attributes: 190 | classes (list): List of the class names. 191 | class_to_idx (dict): Dict with items (class_name, class_index). 192 | imgs (list): List of (image path, class_index) tuples 193 | """ 194 | 195 | def __init__(self, root, transform=None, target_transform=None, 196 | loader=default_loader, is_valid_file=None): 197 | super(ImageFolder_Withpath, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, 198 | transform=transform, 199 | target_transform=target_transform, 200 | is_valid_file=is_valid_file) 201 | self.imgs = self.samples -------------------------------------------------------------------------------- /data/vision.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py 4 | # Modified the original code so that XXX 5 | ############################################################################### 6 | import os 7 | import torch 8 | import torch.utils.data as data 9 | 10 | 11 | class VisionDataset(data.Dataset): 12 | _repr_indent = 4 13 | 14 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 15 | if isinstance(root, torch._six.string_classes): 16 | root = os.path.expanduser(root) 17 | self.root = root 18 | 19 | has_transforms = 
transforms is not None 20 | has_separate_transform = transform is not None or target_transform is not None 21 | if has_transforms and has_separate_transform: 22 | raise ValueError("Only transforms or transform/target_transform can " 23 | "be passed as argument") 24 | 25 | # for backwards-compatibility 26 | self.transform = transform 27 | self.target_transform = target_transform 28 | 29 | if has_separate_transform: 30 | transforms = StandardTransform(transform, target_transform) 31 | self.transforms = transforms 32 | 33 | def __getitem__(self, index): 34 | raise NotImplementedError 35 | 36 | def __len__(self): 37 | raise NotImplementedError 38 | 39 | def __repr__(self): 40 | head = "Dataset " + self.__class__.__name__ 41 | body = ["Number of datapoints: {}".format(self.__len__())] 42 | if self.root is not None: 43 | body.append("Root location: {}".format(self.root)) 44 | body += self.extra_repr().splitlines() 45 | if hasattr(self, "transforms") and self.transforms is not None: 46 | body += [repr(self.transforms)] 47 | lines = [head] + [" " * self._repr_indent + line for line in body] 48 | return '\n'.join(lines) 49 | 50 | def _format_transform_repr(self, transform, head): 51 | lines = transform.__repr__().splitlines() 52 | return (["{}{}".format(head, lines[0])] + 53 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 54 | 55 | def extra_repr(self): 56 | return "" 57 | 58 | 59 | class StandardTransform(object): 60 | def __init__(self, transform=None, target_transform=None): 61 | self.transform = transform 62 | self.target_transform = target_transform 63 | 64 | def __call__(self, input, target): 65 | if self.transform is not None: 66 | input = self.transform(input) 67 | if self.target_transform is not None: 68 | target = self.target_transform(target) 69 | return input, target 70 | 71 | def _format_transform_repr(self, transform, head): 72 | lines = transform.__repr__().splitlines() 73 | return (["{}{}".format(head, lines[0])] + 74 | ["{}{}".format(" " * 
len(head), line) for line in lines[1:]]) 75 | 76 | def __repr__(self): 77 | body = [self.__class__.__name__] 78 | if self.transform is not None: 79 | body += self._format_transform_repr(self.transform, 80 | "Transform: ") 81 | if self.target_transform is not None: 82 | body += self._format_transform_repr(self.target_transform, 83 | "Target transform: ") 84 | 85 | return '\n'.join(body) -------------------------------------------------------------------------------- /experiments/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/.DS_Store -------------------------------------------------------------------------------- /experiments/Clarification_on_tasks.md: -------------------------------------------------------------------------------- 1 | # Clarification on tasks 2 | There are so many experiments associated with this project that it is quite easy to run an experiment with the 3 | wrong code, so we clarify the setup here. 4 | 5 | ## Examples 6 | We provide some example scripts in 'run_temp.sh', covering: 7 | 1. closed set UDA 8 | 2. partial DA 9 | 3. open set DA 10 | 4. SymmNets-V2 Strengthened for Closed Set UDA 11 | 12 | We also provide the corresponding log files in ./experiments. More specifically, 13 | 1. The results of McDalNets on the c2i task of ImageCLEF and on VisDA are stored in ./experiments/ckpt 14 | 2. The results of partial DA and open set DA are stored in ./experiments/Partial and ./experiments/Open, respectively. 15 | 3. The results of SymmNets and SymmNets-SC are stored in ./experiments/SymmNets 16 | 17 | ## Configs 18 | We provide configs for the example experiments in ./experiments/configs. 19 | The naming rule is: 20 | 1. 'dataset' _ 21 | 2. 'train' _ 22 | 3. 'source domain' 2 'target domain' 23 | 4. _ 'task settings (closed set by default | partial | open)' 24 | 5.
_cfg 25 | 6. SC (only for SymmNets-SC) 26 | 27 | Note that the configs for different task settings may differ. 28 | 29 | ## Solver 30 | The solvers for all task settings are provided in ./solver 31 | -------------------------------------------------------------------------------- /experiments/SymmNets/logVisDA_train2validation_SymmNetsV2_None/log.txt: -------------------------------------------------------------------------------- 1 | 2 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "None"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "TASK": "closed", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_SymmNetsV2_None", "SAVE_DIR": "./experiments/SymmNets/logVisDA_train2validation_SymmNetsV2_None", "NUM_WORKERS": 8, "PRINT_STEP": 3} 3 | 4 | Train:epoch: 0:[432/432], LossCla: 2.754657, LossFeat: 1.551413, AccFs: 85.244865, AccFt: 82.904732 5 | Test:epoch: 0, AccFs: 0.519134, AccFt: 0.588042 6 | Class-wise Acc of Ft:1st: 0.738069, 2nd: 0.423309, 3rd: 0.680597, 4th: 0.608307, 5th: 0.794074, 6th: 0.342169, 7th: 0.839890, 8th: 0.527250, 9th: 0.733568, 10th: 0.301622, 11th: 0.936969, 12th: 0.130678 Best Acc so far: 0.588042 7 | Train:epoch: 1:[432/432], LossCla: 1.635495, LossFeat: 1.372633, AccFs: 92.420792, AccFt: 91.610603 8 | Test:epoch: 1, AccFs: 0.599373, AccFt: 0.639031 9 | Class-wise Acc of Ft:1st: 0.831322, 2nd: 0.523453, 3rd: 0.690832, 4th: 0.563311, 5th: 0.797911, 6th:
0.400482, 7th: 0.841615, 8th: 0.746250, 9th: 0.762805, 10th: 0.271372, 11th: 0.929178, 12th: 0.309841 Best Acc so far: 0.639031 10 | Train:epoch: 2:[432/432], LossCla: 1.503049, LossFeat: 1.370504, AccFs: 93.858505, AccFt: 93.198425 11 | Test:epoch: 2, AccFs: 0.647094, AccFt: 0.670877 12 | Class-wise Acc of Ft:1st: 0.877126, 2nd: 0.494676, 3rd: 0.696375, 4th: 0.598308, 5th: 0.872522, 6th: 0.639036, 7th: 0.883023, 8th: 0.750500, 9th: 0.813585, 10th: 0.290224, 11th: 0.925165, 12th: 0.209986 Best Acc so far: 0.670877 13 | Train:epoch: 3:[432/432], LossCla: 1.420064, LossFeat: 1.361121, AccFs: 95.155167, AccFt: 94.270836 14 | Test:epoch: 3, AccFs: 0.649891, AccFt: 0.682259 15 | Class-wise Acc of Ft:1st: 0.862041, 2nd: 0.507050, 3rd: 0.713859, 4th: 0.601000, 5th: 0.872735, 6th: 0.683855, 7th: 0.874741, 8th: 0.759500, 9th: 0.799077, 10th: 0.289347, 11th: 0.924693, 12th: 0.299207 Best Acc so far: 0.682259 16 | Train:epoch: 4:[432/432], LossCla: 1.366017, LossFeat: 1.360252, AccFs: 95.775467, AccFt: 94.977936 17 | Test:epoch: 4, AccFs: 0.672196, AccFt: 0.688111 18 | Class-wise Acc of Ft:1st: 0.859572, 2nd: 0.545324, 3rd: 0.715139, 4th: 0.648880, 5th: 0.864848, 6th: 0.631807, 7th: 0.878019, 8th: 0.801500, 9th: 0.844142, 10th: 0.299868, 11th: 0.913362, 12th: 0.254867 Best Acc so far: 0.688111 19 | Train:epoch: 5:[432/432], LossCla: 1.349571, LossFeat: 1.359669, AccFs: 96.155235, AccFt: 95.668762 20 | Test:epoch: 5, AccFs: 0.678105, AccFt: 0.690267 21 | Class-wise Acc of Ft:1st: 0.846133, 2nd: 0.551655, 3rd: 0.736247, 4th: 0.672531, 5th: 0.856321, 6th: 0.645301, 7th: 0.887336, 8th: 0.794250, 9th: 0.846560, 10th: 0.296361, 11th: 0.903211, 12th: 0.247296 Best Acc so far: 0.690267 22 | Train:epoch: 6:[432/432], LossCla: 1.332122, LossFeat: 1.351513, AccFs: 96.715858, AccFt: 96.160660 23 | Test:epoch: 6, AccFs: 0.690127, AccFt: 0.698896 24 | Class-wise Acc of Ft:1st: 0.854635, 2nd: 0.584173, 3rd: 0.750320, 4th: 0.670609, 5th: 0.884886, 6th: 0.630843, 7th: 0.902346, 8th: 0.785250, 
9th: 0.858211, 10th: 0.303376, 11th: 0.886686, 12th: 0.275415 Best Acc so far: 0.698896 25 | Train:epoch: 7:[432/432], LossCla: 1.328282, LossFeat: 1.341111, AccFs: 97.158928, AccFt: 96.784576 26 | Test:epoch: 7, AccFs: 0.692217, AccFt: 0.703444 27 | Class-wise Acc of Ft:1st: 0.862041, 2nd: 0.560576, 3rd: 0.725586, 4th: 0.638304, 5th: 0.882115, 6th: 0.688675, 7th: 0.886128, 8th: 0.796000, 9th: 0.876017, 10th: 0.337133, 11th: 0.901086, 12th: 0.287671 Best Acc so far: 0.703444 28 | Train:epoch: 8:[432/432], LossCla: 1.344697, LossFeat: 1.339598, AccFs: 97.298180, AccFt: 97.012444 29 | Test:epoch: 8, AccFs: 0.699641, AccFt: 0.711915 30 | Class-wise Acc of Ft:1st: 0.874931, 2nd: 0.679712, 3rd: 0.745842, 4th: 0.639169, 5th: 0.872735, 6th: 0.667952, 7th: 0.877674, 8th: 0.811750, 9th: 0.877555, 10th: 0.335818, 11th: 0.898489, 12th: 0.261355 Best Acc so far: 0.711915 31 | Train:epoch: 9:[432/432], LossCla: 1.335095, LossFeat: 1.324822, AccFs: 97.696037, AccFt: 97.350624 32 | Test:epoch: 9, AccFs: 0.703786, AccFt: 0.711389 33 | Class-wise Acc of Ft:1st: 0.874657, 2nd: 0.669353, 3rd: 0.749680, 4th: 0.667820, 5th: 0.859731, 6th: 0.617831, 7th: 0.885783, 8th: 0.802250, 9th: 0.880633, 10th: 0.372205, 11th: 0.888338, 12th: 0.268385 34 | Train:epoch: 10:[432/432], LossCla: 1.346777, LossFeat: 1.320226, AccFs: 97.685188, AccFt: 97.417534 35 | Test:epoch: 10, AccFs: 0.712391, AccFt: 0.717691 36 | Class-wise Acc of Ft:1st: 0.873012, 2nd: 0.709065, 3rd: 0.751599, 4th: 0.666282, 5th: 0.850352, 6th: 0.635663, 7th: 0.858696, 8th: 0.784000, 9th: 0.911629, 10th: 0.432267, 11th: 0.894240, 12th: 0.245494 Best Acc so far: 0.717691 37 | Train:epoch: 11:[432/432], LossCla: 1.343779, LossFeat: 1.302535, AccFs: 97.918472, AccFt: 97.652634 38 | Test:epoch: 11, AccFs: 0.699865, AccFt: 0.707937 39 | Class-wise Acc of Ft:1st: 0.864783, 2nd: 0.634532, 3rd: 0.753518, 4th: 0.699164, 5th: 0.855042, 6th: 0.616386, 7th: 0.891994, 8th: 0.741750, 9th: 0.872939, 10th: 0.463393, 11th: 0.885269, 12th: 0.216474 
40 | Train:epoch: 12:[432/432], LossCla: 1.353641, LossFeat: 1.292348, AccFs: 97.931137, AccFt: 97.688805 41 | Test:epoch: 12, AccFs: 0.701588, AccFt: 0.705192 42 | Class-wise Acc of Ft:1st: 0.844213, 2nd: 0.669640, 3rd: 0.751386, 4th: 0.692145, 5th: 0.840119, 6th: 0.666024, 7th: 0.860593, 8th: 0.762250, 9th: 0.841284, 10th: 0.443665, 11th: 0.902266, 12th: 0.188717 43 | Train:epoch: 13:[432/432], LossCla: 1.342840, LossFeat: 1.280037, AccFs: 98.263893, AccFt: 97.963684 44 | Test:epoch: 13, AccFs: 0.709544, AccFt: 0.708189 45 | Class-wise Acc of Ft:1st: 0.845584, 2nd: 0.605755, 3rd: 0.749680, 4th: 0.673108, 5th: 0.844383, 6th: 0.728193, 7th: 0.861111, 8th: 0.763750, 9th: 0.903495, 10th: 0.419991, 11th: 0.903683, 12th: 0.199531 46 | Train:epoch: 14:[432/432], LossCla: 1.344371, LossFeat: 1.279572, AccFs: 98.209633, AccFt: 97.987198 47 | Test:epoch: 14, AccFs: 0.716340, AccFt: 0.712602 48 | Class-wise Acc of Ft:1st: 0.820351, 2nd: 0.650935, 3rd: 0.768657, 4th: 0.697241, 5th: 0.832658, 6th: 0.746506, 7th: 0.849034, 8th: 0.758000, 9th: 0.880633, 10th: 0.469969, 11th: 0.888338, 12th: 0.188897 49 | Train:epoch: 15:[432/432], LossCla: 1.346274, LossFeat: 1.272384, AccFs: 98.457397, AccFt: 98.126450 50 | Test:epoch: 15, AccFs: 0.712162, AccFt: 0.708201 51 | Class-wise Acc of Ft:1st: 0.843664, 2nd: 0.557410, 3rd: 0.772281, 4th: 0.678396, 5th: 0.826903, 6th: 0.747952, 7th: 0.873361, 8th: 0.739750, 9th: 0.873599, 10th: 0.483560, 11th: 0.893532, 12th: 0.208003 52 | Train:epoch: 16:[432/432], LossCla: 1.345689, LossFeat: 1.257215, AccFs: 98.564095, AccFt: 98.260269 53 | Test:epoch: 16, AccFs: 0.712009, AccFt: 0.705590 54 | Class-wise Acc of Ft:1st: 0.797038, 2nd: 0.629353, 3rd: 0.774627, 4th: 0.679262, 5th: 0.829247, 6th: 0.701205, 7th: 0.847654, 8th: 0.751500, 9th: 0.859530, 10th: 0.508549, 11th: 0.895892, 12th: 0.193223 55 | Train:epoch: 17:[432/432], LossCla: 1.348905, LossFeat: 1.249159, AccFs: 98.661751, AccFt: 98.365166 56 | Test:epoch: 17, AccFs: 0.716529, AccFt: 0.709276 
57 | Class-wise Acc of Ft:1st: 0.794569, 2nd: 0.633669, 3rd: 0.785501, 4th: 0.690607, 5th: 0.792368, 6th: 0.676145, 7th: 0.860766, 8th: 0.732500, 9th: 0.898000, 10th: 0.525647, 11th: 0.886686, 12th: 0.234859 58 | Train:epoch: 18:[432/432], LossCla: 1.348393, LossFeat: 1.240162, AccFs: 98.710579, AccFt: 98.462822 59 | Test:epoch: 18, AccFs: 0.714328, AccFt: 0.709291 60 | Class-wise Acc of Ft:1st: 0.845310, 2nd: 0.524892, 3rd: 0.768657, 4th: 0.696952, 5th: 0.825197, 6th: 0.695904, 7th: 0.893375, 8th: 0.751250, 9th: 0.864805, 10th: 0.542744, 11th: 0.885033, 12th: 0.217376 61 | Train:epoch: 19:[432/432], LossCla: 1.341572, LossFeat: 1.234422, AccFs: 98.885994, AccFt: 98.594833 62 | Test:epoch: 19, AccFs: 0.707589, AccFt: 0.696340 63 | Class-wise Acc of Ft:1st: 0.782501, 2nd: 0.585036, 3rd: 0.765672, 4th: 0.680127, 5th: 0.797485, 6th: 0.673735, 7th: 0.880090, 8th: 0.731500, 9th: 0.856672, 10th: 0.508110, 11th: 0.910765, 12th: 0.184391 64 | Train:epoch: 20:[432/432], LossCla: 1.345093, LossFeat: 1.230574, AccFs: 98.878761, AccFt: 98.614731 65 | Test:epoch: 20, AccFs: 0.723617, AccFt: 0.713527 66 | Class-wise Acc of Ft:1st: 0.818431, 2nd: 0.611511, 3rd: 0.766524, 4th: 0.696375, 5th: 0.797485, 6th: 0.732530, 7th: 0.886473, 8th: 0.718750, 9th: 0.881073, 10th: 0.539676, 11th: 0.902974, 12th: 0.210526 Best Acc so far: 0.723617 67 | Train:epoch: 21:[432/432], LossCla: 1.345878, LossFeat: 1.223302, AccFs: 98.931206, AccFt: 98.656326 68 | Test:epoch: 21, AccFs: 0.717828, AccFt: 0.709226 69 | Class-wise Acc of Ft:1st: 0.781953, 2nd: 0.615540, 3rd: 0.799787, 4th: 0.647918, 5th: 0.831592, 6th: 0.708434, 7th: 0.834714, 8th: 0.722250, 9th: 0.870081, 10th: 0.607628, 11th: 0.904627, 12th: 0.186193 70 | Train:epoch: 22:[432/432], LossCla: 1.346019, LossFeat: 1.226002, AccFs: 99.005356, AccFt: 98.692490 71 | Test:epoch: 22, AccFs: 0.711813, AccFt: 0.699680 72 | Class-wise Acc of Ft:1st: 0.801700, 2nd: 0.569784, 3rd: 0.772068, 4th: 0.677435, 5th: 0.817310, 6th: 0.704578, 7th: 0.890959, 
8th: 0.687500, 9th: 0.848098, 10th: 0.524770, 11th: 0.903683, 12th: 0.198270 73 | Train:epoch: 23:[432/432], LossCla: 1.345264, LossFeat: 1.223078, AccFs: 99.099396, AccFt: 98.853447 74 | Test:epoch: 23, AccFs: 0.705851, AccFt: 0.694408 75 | Class-wise Acc of Ft:1st: 0.756445, 2nd: 0.572086, 3rd: 0.785501, 4th: 0.678877, 5th: 0.791729, 6th: 0.684819, 7th: 0.864044, 8th: 0.663750, 9th: 0.872060, 10th: 0.539676, 11th: 0.898961, 12th: 0.224946 76 | Train:epoch: 24:[432/432], LossCla: 1.341958, LossFeat: 1.216860, AccFs: 99.148224, AccFt: 98.867912 77 | Test:epoch: 24, AccFs: 0.723894, AccFt: 0.714709 78 | Class-wise Acc of Ft:1st: 0.783873, 2nd: 0.568345, 3rd: 0.778678, 4th: 0.689068, 5th: 0.829247, 6th: 0.746506, 7th: 0.884748, 8th: 0.697000, 9th: 0.883931, 10th: 0.574748, 11th: 0.889282, 12th: 0.251081 Best Acc so far: 0.723894 79 | Train:epoch: 25:[432/432], LossCla: 1.332571, LossFeat: 1.214391, AccFs: 99.168114, AccFt: 98.923973 80 | Test:epoch: 25, AccFs: 0.702965, AccFt: 0.693480 81 | Class-wise Acc of Ft:1st: 0.745200, 2nd: 0.549928, 3rd: 0.769723, 4th: 0.616768, 5th: 0.819868, 6th: 0.718072, 7th: 0.893720, 8th: 0.664750, 9th: 0.897780, 10th: 0.505918, 11th: 0.910765, 12th: 0.229272 82 | Train:epoch: 26:[432/432], LossCla: 1.329841, LossFeat: 1.215475, AccFs: 99.144608, AccFt: 98.972801 83 | Test:epoch: 26, AccFs: 0.718568, AccFt: 0.707460 84 | Class-wise Acc of Ft:1st: 0.764948, 2nd: 0.593669, 3rd: 0.784648, 4th: 0.677146, 5th: 0.791303, 6th: 0.762892, 7th: 0.869220, 8th: 0.690750, 9th: 0.903495, 10th: 0.533538, 11th: 0.875118, 12th: 0.242790 85 | Train:epoch: 27:[432/432], LossCla: 1.329338, LossFeat: 1.212966, AccFs: 99.200668, AccFt: 98.956528 86 | Test:epoch: 27, AccFs: 0.716024, AccFt: 0.706511 87 | Class-wise Acc of Ft:1st: 0.758914, 2nd: 0.635108, 3rd: 0.792111, 4th: 0.683300, 5th: 0.802388, 6th: 0.738795, 7th: 0.871463, 8th: 0.692500, 9th: 0.877555, 10th: 0.497589, 11th: 0.882200, 12th: 0.246215 88 | Train:epoch: 28:[432/432], LossCla: 1.328192, 
LossFeat: 1.203063, AccFs: 99.339920, AccFt: 99.083115 89 | Test:epoch: 28, AccFs: 0.710096, AccFt: 0.700704 90 | Class-wise Acc of Ft:1st: 0.802249, 2nd: 0.567770, 3rd: 0.795096, 4th: 0.675416, 5th: 0.799616, 6th: 0.671807, 7th: 0.892685, 8th: 0.663250, 9th: 0.910310, 10th: 0.530469, 11th: 0.887630, 12th: 0.212149 91 | Train:epoch: 29:[432/432], LossCla: 1.329534, LossFeat: 1.206706, AccFs: 99.336304, AccFt: 99.115669 92 | Test:epoch: 29, AccFs: 0.721606, AccFt: 0.709754 93 | Class-wise Acc of Ft:1st: 0.792924, 2nd: 0.633669, 3rd: 0.791045, 4th: 0.668493, 5th: 0.822000, 6th: 0.735422, 7th: 0.881125, 8th: 0.666250, 9th: 0.864586, 10th: 0.555897, 11th: 0.899433, 12th: 0.206200 94 | Train:epoch: 30:[432/432], LossCla: 1.331311, LossFeat: 1.210111, AccFs: 99.381508, AccFt: 99.055992 95 | Test:epoch: 30, AccFs: 0.710322, AccFt: 0.697093 96 | Class-wise Acc of Ft:1st: 0.782776, 2nd: 0.529784, 3rd: 0.742217, 4th: 0.669647, 5th: 0.795992, 6th: 0.749880, 7th: 0.891994, 8th: 0.671250, 9th: 0.843042, 10th: 0.544498, 11th: 0.919263, 12th: 0.224766 97 | Train:epoch: 31:[432/432], LossCla: 1.326631, LossFeat: 1.207292, AccFs: 99.415871, AccFt: 99.198860 98 | Test:epoch: 31, AccFs: 0.711356, AccFt: 0.698533 99 | Class-wise Acc of Ft:1st: 0.820351, 2nd: 0.520000, 3rd: 0.770149, 4th: 0.675320, 5th: 0.760392, 6th: 0.761928, 7th: 0.868012, 8th: 0.669250, 9th: 0.819741, 10th: 0.562473, 11th: 0.909112, 12th: 0.245674 100 | -------------------------------------------------------------------------------- /experiments/SymmNets/logres101VisDA_train2validation_SymmNetsV2_None/log.txt: -------------------------------------------------------------------------------- 1 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/data1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet101", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, 
"TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 40, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "None"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "TASK": "closed", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logres101\rVisDA_train2validation_SymmNetsV2_None", "SAVE_DIR": "./experiments/SymmNets/logres101\rVisDA_train2validation_SymmNetsV2_None", "NUM_WORKERS": 8, "PRINT_STEP": 3} 2 | 3 | Train:epoch: 0:[432/432], LossCla: 2.581147, LossFeat: 1.475838, AccFs: 86.642799, AccFt: 84.563080 4 | Test:epoch: 0, AccFs: 0.603141, AccFt: 0.649918 5 | Class-wise Acc of Ft:1st: 0.873834, 2nd: 0.471367, 3rd: 0.728998, 4th: 0.609557, 5th: 0.770625, 6th: 0.629398, 7th: 0.898378, 8th: 0.671500, 9th: 0.848978, 10th: 0.278825, 11th: 0.919500, 12th: 0.098053 Best Acc so far: 0.649918 6 | Train:epoch: 1:[432/432], LossCla: 1.556984, LossFeat: 1.260458, AccFs: 93.570961, AccFt: 92.699295 7 | Test:epoch: 1, AccFs: 0.654738, AccFt: 0.687854 8 | Class-wise Acc of Ft:1st: 0.910861, 2nd: 0.515683, 3rd: 0.772281, 4th: 0.649168, 5th: 0.828182, 6th: 0.786988, 7th: 0.874396, 8th: 0.766500, 9th: 0.854254, 10th: 0.269619, 11th: 0.906988, 12th: 0.119322 Best Acc so far: 0.687854 9 | Train:epoch: 2:[432/432], LossCla: 1.446185, LossFeat: 1.245270, AccFs: 94.833260, AccFt: 94.019463 10 | Test:epoch: 2, AccFs: 0.692207, AccFt: 0.715441 11 | Class-wise Acc of Ft:1st: 0.930060, 2nd: 0.544460, 3rd: 0.795949, 4th: 0.701183, 5th: 0.855255, 6th: 0.883855, 7th: 0.879572, 8th: 0.800000, 9th: 0.834469, 10th: 0.282332, 11th: 0.870869, 12th: 0.207282 Best Acc so far: 0.715441 12 | Train:epoch: 3:[432/432], LossCla: 1.365695, LossFeat: 1.234548, AccFs: 96.044922, AccFt: 95.247398 13 | Test:epoch: 3, AccFs: 0.715807, AccFt: 0.732017 14 | Class-wise Acc of Ft:1st: 0.921009, 2nd: 0.606043, 3rd: 0.826013, 4th: 
0.700317, 5th: 0.887657, 6th: 0.899759, 7th: 0.879917, 8th: 0.789750, 9th: 0.864146, 10th: 0.330995, 11th: 0.869688, 12th: 0.208904 Best Acc so far: 0.732017 15 | Train:epoch: 4:[432/432], LossCla: 1.332583, LossFeat: 1.235615, AccFs: 96.475334, AccFt: 95.826103 16 | Test:epoch: 4, AccFs: 0.734888, AccFt: 0.746103 17 | Class-wise Acc of Ft:1st: 0.935546, 2nd: 0.682014, 3rd: 0.808102, 4th: 0.687818, 5th: 0.879130, 6th: 0.924337, 7th: 0.884403, 8th: 0.802250, 9th: 0.858430, 10th: 0.370452, 11th: 0.894004, 12th: 0.226748 Best Acc so far: 0.746103 18 | Train:epoch: 5:[432/432], LossCla: 1.293805, LossFeat: 1.245844, AccFs: 97.010635, AccFt: 96.522354 19 | Test:epoch: 5, AccFs: 0.742592, AccFt: 0.754611 20 | Class-wise Acc of Ft:1st: 0.934174, 2nd: 0.717122, 3rd: 0.780384, 4th: 0.722431, 5th: 0.901087, 6th: 0.900723, 7th: 0.866287, 8th: 0.794500, 9th: 0.901297, 10th: 0.382727, 11th: 0.906043, 12th: 0.248558 Best Acc so far: 0.754611 21 | Train:epoch: 6:[432/432], LossCla: 1.282527, LossFeat: 1.256809, AccFs: 97.384987, AccFt: 96.965424 22 | Test:epoch: 6, AccFs: 0.740717, AccFt: 0.748055 23 | Class-wise Acc of Ft:1st: 0.923478, 2nd: 0.649784, 3rd: 0.789126, 4th: 0.731564, 5th: 0.884033, 6th: 0.933976, 7th: 0.860593, 8th: 0.815750, 9th: 0.868762, 10th: 0.382727, 11th: 0.899670, 12th: 0.237203 24 | Train:epoch: 7:[432/432], LossCla: 1.289282, LossFeat: 1.270966, AccFs: 97.674332, AccFt: 97.225838 25 | Test:epoch: 7, AccFs: 0.750884, AccFt: 0.749084 26 | Class-wise Acc of Ft:1st: 0.918267, 2nd: 0.710791, 3rd: 0.799787, 4th: 0.736948, 5th: 0.885739, 6th: 0.926747, 7th: 0.857660, 8th: 0.814500, 9th: 0.875797, 10th: 0.329680, 11th: 0.879131, 12th: 0.253965 27 | Train:epoch: 8:[432/432], LossCla: 1.315241, LossFeat: 1.275234, AccFs: 97.940178, AccFt: 97.616463 28 | Test:epoch: 8, AccFs: 0.754999, AccFt: 0.750322 29 | Class-wise Acc of Ft:1st: 0.903456, 2nd: 0.697266, 3rd: 0.808955, 4th: 0.719162, 5th: 0.897890, 6th: 0.918554, 7th: 0.833678, 8th: 0.825500, 9th: 0.881952, 10th: 
0.376151, 11th: 0.874174, 12th: 0.267123 Best Acc so far: 0.754999 30 | Train:epoch: 9:[432/432], LossCla: 1.322757, LossFeat: 1.267201, AccFs: 98.102936, AccFt: 97.746674 31 | Test:epoch: 9, AccFs: 0.762605, AccFt: 0.755622 32 | Class-wise Acc of Ft:1st: 0.923478, 2nd: 0.715683, 3rd: 0.783795, 4th: 0.698106, 5th: 0.889789, 6th: 0.887711, 7th: 0.848344, 8th: 0.834750, 9th: 0.897560, 10th: 0.417361, 11th: 0.885741, 12th: 0.285148 Best Acc so far: 0.762605 33 | Train:epoch: 10:[432/432], LossCla: 1.325336, LossFeat: 1.256211, AccFs: 98.238571, AccFt: 97.976349 34 | Test:epoch: 10, AccFs: 0.757504, AccFt: 0.755875 35 | Class-wise Acc of Ft:1st: 0.925123, 2nd: 0.690647, 3rd: 0.812793, 4th: 0.663494, 5th: 0.879343, 6th: 0.888193, 7th: 0.857660, 8th: 0.834250, 9th: 0.885909, 10th: 0.481368, 11th: 0.886213, 12th: 0.265501 36 | Train:epoch: 11:[432/432], LossCla: 1.322062, LossFeat: 1.247246, AccFs: 98.379631, AccFt: 98.149956 37 | Test:epoch: 11, AccFs: 0.763305, AccFt: 0.760014 38 | Class-wise Acc of Ft:1st: 0.908393, 2nd: 0.668777, 3rd: 0.798507, 4th: 0.675993, 5th: 0.889789, 6th: 0.891084, 7th: 0.870945, 8th: 0.835500, 9th: 0.896461, 10th: 0.495397, 11th: 0.886686, 12th: 0.302632 Best Acc so far: 0.763305 39 | Train:epoch: 12:[432/432], LossCla: 1.315084, LossFeat: 1.236000, AccFs: 98.649086, AccFt: 98.348885 40 | Test:epoch: 12, AccFs: 0.764516, AccFt: 0.760719 41 | Class-wise Acc of Ft:1st: 0.910861, 2nd: 0.688921, 3rd: 0.807889, 4th: 0.666667, 5th: 0.888084, 6th: 0.846265, 7th: 0.879400, 8th: 0.843000, 9th: 0.904375, 10th: 0.531346, 11th: 0.882436, 12th: 0.279380 Best Acc so far: 0.764516 42 | Train:epoch: 13:[432/432], LossCla: 1.321137, LossFeat: 1.231841, AccFs: 98.696106, AccFt: 98.520691 43 | Test:epoch: 13, AccFs: 0.763517, AccFt: 0.759895 44 | Class-wise Acc of Ft:1st: 0.895502, 2nd: 0.685180, 3rd: 0.816844, 4th: 0.690607, 5th: 0.910680, 6th: 0.818313, 7th: 0.865942, 8th: 0.828500, 9th: 0.892944, 10th: 0.568610, 11th: 0.871105, 12th: 0.274513 45 | 
Train:epoch: 14:[432/432], LossCla: 1.325543, LossFeat: 1.226663, AccFs: 98.687065, AccFt: 98.423035 46 | Test:epoch: 14, AccFs: 0.755246, AccFt: 0.752819 47 | Class-wise Acc of Ft:1st: 0.872189, 2nd: 0.682302, 3rd: 0.788060, 4th: 0.644457, 5th: 0.883181, 6th: 0.749398, 7th: 0.862836, 8th: 0.834250, 9th: 0.895801, 10th: 0.644893, 11th: 0.906280, 12th: 0.270187 48 | Train:epoch: 15:[432/432], LossCla: 1.317506, LossFeat: 1.212464, AccFs: 98.858871, AccFt: 98.696106 49 | Test:epoch: 15, AccFs: 0.755130, AccFt: 0.753893 50 | Class-wise Acc of Ft:1st: 0.891114, 2nd: 0.638273, 3rd: 0.817697, 4th: 0.650611, 5th: 0.896824, 6th: 0.786506, 7th: 0.863527, 8th: 0.818500, 9th: 0.918663, 10th: 0.644454, 11th: 0.888574, 12th: 0.231975 51 | Train:epoch: 16:[432/432], LossCla: 1.315734, LossFeat: 1.204639, AccFs: 98.891418, AccFt: 98.810043 52 | Test:epoch: 16, AccFs: 0.761036, AccFt: 0.761904 53 | Class-wise Acc of Ft:1st: 0.900987, 2nd: 0.680000, 3rd: 0.811087, 4th: 0.622825, 5th: 0.908335, 6th: 0.781687, 7th: 0.849206, 8th: 0.805250, 9th: 0.913168, 10th: 0.705831, 11th: 0.902030, 12th: 0.262437 54 | Train:epoch: 17:[432/432], LossCla: 1.318342, LossFeat: 1.198416, AccFs: 98.999931, AccFt: 98.837166 55 | Test:epoch: 17, AccFs: 0.754808, AccFt: 0.756363 56 | Class-wise Acc of Ft:1st: 0.896325, 2nd: 0.672806, 3rd: 0.795309, 4th: 0.582829, 5th: 0.896397, 6th: 0.783614, 7th: 0.857488, 8th: 0.798000, 9th: 0.894922, 10th: 0.718106, 11th: 0.912181, 12th: 0.268385 57 | Train:epoch: 18:[432/432], LossCla: 1.305724, LossFeat: 1.187242, AccFs: 99.189819, AccFt: 99.005356 58 | Test:epoch: 18, AccFs: 0.759740, AccFt: 0.759087 59 | -------------------------------------------------------------------------------- /experiments/ckpt/logVisDA_train2validation_McDalNet_CE/log.txt: -------------------------------------------------------------------------------- 1 | 2 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", 
"TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "CE"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_McDalNet_CE", "SAVE_DIR": "./experiments/ckpt/logVisDA_train2validation_McDalNet_CE", "NUM_WORKERS": 8, "PRINT_STEP": 3} 3 | 4 | Train:epoch: 0:[432/432], LossCE: 0.690750, LossDA: -0.058507, LossAll: 2.029814, Auxi1: 82.930046, Auxi2: 82.752823, Task: 82.785378 5 | Test:epoch: 0, Top1_auxi1: 0.558361, Top1_auxi2: 0.552648, Top1: 0.556366 6 | Class-wise Acc:1st: 0.775919, 2nd: 0.379281, 3rd: 0.638806, 4th: 0.651572, 5th: 0.636751, 6th: 0.265542, 7th: 0.823844, 8th: 0.322250, 9th: 0.827874, 10th: 0.292416, 11th: 0.882436, 12th: 0.179704 Best Acc so far: 0.558361 7 | Train:epoch: 1:[432/432], LossCE: 0.244714, LossDA: -0.080817, LossAll: 0.664051, Auxi1: 92.214630, Auxi2: 92.138672, Task: 92.294197 8 | Test:epoch: 1, Top1_auxi1: 0.558922, Top1_auxi2: 0.526064, Top1: 0.563260 9 | Class-wise Acc:1st: 0.739166, 2nd: 0.392518, 3rd: 0.631557, 4th: 0.621575, 5th: 0.633340, 6th: 0.272289, 7th: 0.858351, 8th: 0.322000, 9th: 0.843482, 10th: 0.338886, 11th: 0.911473, 12th: 0.194485 Best Acc so far: 0.563260 10 | Train:epoch: 2:[432/432], LossCE: 0.183937, LossDA: -0.185786, LossAll: 0.430443, Auxi1: 93.030235, Auxi2: 93.485970, Task: 93.945312 11 | Test:epoch: 2, Top1_auxi1: 0.366009, Top1_auxi2: 0.437074, Top1: 0.582062 12 | Class-wise Acc:1st: 0.777839, 2nd: 0.392230, 3rd: 0.694883, 4th: 0.619556, 5th: 0.687913, 6th: 0.267952, 7th: 0.869565, 8th: 0.350500, 9th: 0.853814, 
10th: 0.322665, 11th: 0.881964, 12th: 0.265862 Best Acc so far: 0.582062 13 | Train:epoch: 3:[432/432], LossCE: 0.145429, LossDA: -0.456710, LossAll: 0.132864, Auxi1: 92.097076, Auxi2: 94.017654, Task: 95.269096 14 | Test:epoch: 3, Top1_auxi1: 0.250198, Top1_auxi2: 0.457712, Top1: 0.612280 15 | Class-wise Acc:1st: 0.808283, 2nd: 0.411223, 3rd: 0.711301, 4th: 0.684453, 5th: 0.750160, 6th: 0.311325, 7th: 0.879227, 8th: 0.438500, 9th: 0.902176, 10th: 0.325296, 11th: 0.878659, 12th: 0.246756 Best Acc so far: 0.612280 16 | Train:epoch: 4:[432/432], LossCE: 0.143135, LossDA: -0.437714, LossAll: 0.156751, Auxi1: 92.183884, Auxi2: 94.439018, Task: 95.659721 17 | Test:epoch: 4, Top1_auxi1: 0.428402, Top1_auxi2: 0.617549, Top1: 0.664672 18 | Class-wise Acc:1st: 0.837630, 2nd: 0.542734, 3rd: 0.728145, 4th: 0.646284, 5th: 0.777873, 6th: 0.486747, 7th: 0.880090, 8th: 0.688250, 9th: 0.920202, 10th: 0.312582, 11th: 0.881020, 12th: 0.274513 Best Acc so far: 0.664672 19 | Train:epoch: 5:[432/432], LossCE: 0.132734, LossDA: -0.183213, LossAll: 0.290770, Auxi1: 94.561996, Auxi2: 95.509621, Task: 96.117264 20 | Test:epoch: 5, Top1_auxi1: 0.494570, Top1_auxi2: 0.659047, Top1: 0.679029 21 | Class-wise Acc:1st: 0.864235, 2nd: 0.500432, 3rd: 0.760128, 4th: 0.648207, 5th: 0.831806, 6th: 0.581205, 7th: 0.862491, 8th: 0.727000, 9th: 0.936250, 10th: 0.322665, 11th: 0.875826, 12th: 0.238104 Best Acc so far: 0.679029 22 | Train:epoch: 6:[432/432], LossCE: 0.113248, LossDA: -0.137117, LossAll: 0.251889, Auxi1: 95.701317, Auxi2: 96.173325, Task: 96.656181 23 | Test:epoch: 6, Top1_auxi1: 0.523407, Top1_auxi2: 0.667125, Top1: 0.678395 24 | Class-wise Acc:1st: 0.828579, 2nd: 0.551655, 3rd: 0.741578, 4th: 0.650034, 5th: 0.805159, 6th: 0.581687, 7th: 0.881988, 8th: 0.739750, 9th: 0.950758, 10th: 0.282332, 11th: 0.888574, 12th: 0.238645 25 | Train:epoch: 7:[432/432], LossCE: 0.100754, LossDA: -0.099633, LossAll: 0.238848, Auxi1: 96.310768, Auxi2: 96.567566, Task: 97.039566 26 | Test:epoch: 7, 
Top1_auxi1: 0.587965, Top1_auxi2: 0.690583, Top1: 0.692860 27 | Class-wise Acc:1st: 0.854361, 2nd: 0.586187, 3rd: 0.756077, 4th: 0.659360, 5th: 0.854402, 6th: 0.654458, 7th: 0.918910, 8th: 0.688750, 9th: 0.947681, 10th: 0.277510, 11th: 0.882672, 12th: 0.233958 Best Acc so far: 0.692860 28 | Train:epoch: 8:[432/432], LossCE: 0.084525, LossDA: -0.065558, LossAll: 0.211381, Auxi1: 97.055847, Auxi2: 97.249352, Task: 97.500725 29 | Test:epoch: 8, Top1_auxi1: 0.608555, Top1_auxi2: 0.693317, Top1: 0.690395 30 | Class-wise Acc:1st: 0.862863, 2nd: 0.578417, 3rd: 0.736034, 4th: 0.669839, 5th: 0.837135, 6th: 0.657831, 7th: 0.906142, 8th: 0.713000, 9th: 0.949659, 10th: 0.277071, 11th: 0.899197, 12th: 0.197549 Best Acc so far: 0.693317 31 | Train:epoch: 9:[432/432], LossCE: 0.071856, LossDA: -0.047084, LossAll: 0.183912, Auxi1: 97.545937, Auxi2: 97.703270, Task: 97.929329 32 | Test:epoch: 9, Top1_auxi1: 0.627759, Top1_auxi2: 0.701944, Top1: 0.696261 33 | Class-wise Acc:1st: 0.866703, 2nd: 0.556259, 3rd: 0.762260, 4th: 0.667820, 5th: 0.852057, 6th: 0.704096, 7th: 0.895790, 8th: 0.723250, 9th: 0.954276, 10th: 0.275756, 11th: 0.892823, 12th: 0.204037 Best Acc so far: 0.701944 34 | Train:epoch: 10:[432/432], LossCE: 0.061862, LossDA: -0.039253, LossAll: 0.158299, Auxi1: 97.900391, Auxi2: 98.066772, Task: 98.202400 35 | Test:epoch: 10, Top1_auxi1: 0.637854, Top1_auxi2: 0.704391, Top1: 0.699586 36 | Class-wise Acc:1st: 0.859846, 2nd: 0.563741, 3rd: 0.767804, 4th: 0.680031, 5th: 0.836069, 6th: 0.735904, 7th: 0.897170, 8th: 0.730750, 9th: 0.951418, 10th: 0.290662, 11th: 0.894004, 12th: 0.187635 Best Acc so far: 0.704391 37 | Train:epoch: 11:[432/432], LossCE: 0.056392, LossDA: -0.033153, LossAll: 0.145713, Auxi1: 98.146339, Auxi2: 98.247612, Task: 98.352501 38 | Test:epoch: 11, Top1_auxi1: 0.652609, Top1_auxi2: 0.707246, Top1: 0.702464 39 | Class-wise Acc:1st: 0.853538, 2nd: 0.575827, 3rd: 0.764606, 4th: 0.689645, 5th: 0.836709, 6th: 0.729639, 7th: 0.896653, 8th: 0.740750, 9th: 
0.948120, 10th: 0.298992, 11th: 0.891407, 12th: 0.203677 Best Acc so far: 0.707246 40 | Train:epoch: 12:[432/432], LossCE: 0.050839, LossDA: -0.029478, LossAll: 0.130653, Auxi1: 98.339844, Auxi2: 98.442924, Task: 98.522499 41 | Test:epoch: 12, Top1_auxi1: 0.651128, Top1_auxi2: 0.697198, Top1: 0.692559 42 | Class-wise Acc:1st: 0.852167, 2nd: 0.523741, 3rd: 0.767804, 4th: 0.698491, 5th: 0.833724, 6th: 0.763855, 7th: 0.933057, 8th: 0.724250, 9th: 0.923500, 10th: 0.237177, 11th: 0.896128, 12th: 0.156813 43 | Train:epoch: 13:[432/432], LossCE: 0.044340, LossDA: -0.028602, LossAll: 0.111692, Auxi1: 98.611115, Auxi2: 98.710579, Task: 98.744934 44 | Test:epoch: 13, Top1_auxi1: 0.646009, Top1_auxi2: 0.697454, Top1: 0.690091 45 | Class-wise Acc:1st: 0.838179, 2nd: 0.527770, 3rd: 0.755437, 4th: 0.693106, 5th: 0.828821, 6th: 0.747470, 7th: 0.897170, 8th: 0.727500, 9th: 0.946362, 10th: 0.244630, 11th: 0.905571, 12th: 0.169070 46 | Train:epoch: 14:[432/432], LossCE: 0.041508, LossDA: -0.027529, LossAll: 0.104398, Auxi1: 98.697914, Auxi2: 98.743126, Task: 98.831741 47 | Test:epoch: 14, Top1_auxi1: 0.652566, Top1_auxi2: 0.701091, Top1: 0.692790 48 | Class-wise Acc:1st: 0.853538, 2nd: 0.501871, 3rd: 0.770576, 4th: 0.674166, 5th: 0.830313, 6th: 0.754699, 7th: 0.924776, 8th: 0.752250, 9th: 0.945922, 10th: 0.228409, 11th: 0.903919, 12th: 0.173035 49 | Train:epoch: 15:[432/432], LossCE: 0.037630, LossDA: -0.026932, LossAll: 0.092539, Auxi1: 98.786530, Auxi2: 98.885994, Task: 98.972801 50 | Test:epoch: 15, Top1_auxi1: 0.661947, Top1_auxi2: 0.712143, Top1: 0.702932 51 | Class-wise Acc:1st: 0.855458, 2nd: 0.541583, 3rd: 0.785714, 4th: 0.691857, 5th: 0.836709, 6th: 0.788434, 7th: 0.925121, 8th: 0.761500, 9th: 0.952957, 10th: 0.249890, 11th: 0.886449, 12th: 0.159517 Best Acc so far: 0.712143 52 | Train:epoch: 16:[432/432], LossCE: 0.032933, LossDA: -0.027620, LossAll: 0.078184, Auxi1: 98.927589, Auxi2: 99.034286, Task: 99.093971 53 | Test:epoch: 16, Top1_auxi1: 0.658064, Top1_auxi2: 
0.708298, Top1: 0.697768 54 | Class-wise Acc:1st: 0.854087, 2nd: 0.524029, 3rd: 0.758209, 4th: 0.676858, 5th: 0.829034, 6th: 0.821205, 7th: 0.912871, 8th: 0.763500, 9th: 0.956034, 10th: 0.203858, 11th: 0.910765, 12th: 0.162761 55 | Train:epoch: 17:[432/432], LossCE: 0.031930, LossDA: -0.028328, LossAll: 0.074900, Auxi1: 98.996315, Auxi2: 99.063225, Task: 99.151840 56 | Test:epoch: 17, Top1_auxi1: 0.665599, Top1_auxi2: 0.713065, Top1: 0.704245 57 | Class-wise Acc:1st: 0.851892, 2nd: 0.554532, 3rd: 0.798934, 4th: 0.702913, 5th: 0.856747, 6th: 0.784096, 7th: 0.924948, 8th: 0.754750, 9th: 0.950099, 10th: 0.224463, 11th: 0.881020, 12th: 0.166547 Best Acc so far: 0.713065 58 | Train:epoch: 18:[432/432], LossCE: 0.028280, LossDA: -0.027890, LossAll: 0.064488, Auxi1: 99.133751, Auxi2: 99.159073, Task: 99.273003 59 | Test:epoch: 18, Top1_auxi1: 0.672141, Top1_auxi2: 0.719737, Top1: 0.709900 60 | Class-wise Acc:1st: 0.881788, 2nd: 0.595108, 3rd: 0.778038, 4th: 0.682627, 5th: 0.833511, 6th: 0.802410, 7th: 0.946342, 8th: 0.757250, 9th: 0.948780, 10th: 0.225778, 11th: 0.896837, 12th: 0.170332 Best Acc so far: 0.719737 61 | -------------------------------------------------------------------------------- /experiments/ckpt/logVisDA_train2validation_McDalNet_DANN/log.txt: -------------------------------------------------------------------------------- 1 | 2 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "DANN"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": 
{"BATCH_SIZE": 128}, "RESUME": "", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_McDalNet_DANN", "SAVE_DIR": "./experiments/ckpt/logVisDA_train2validation_McDalNet_DANN", "NUM_WORKERS": 8, "PRINT_STEP": 3} 3 | 4 | Train:epoch: 0:[432/432], LossCE: 0.689842, LossDA: 1.220195, LossAll: 1.910038, Auxi1: 5.635128, Auxi2: 4.967810, Task: 83.132599 5 | Test:epoch: 0, Top1_auxi1: 0.061966, Top1_auxi2: 0.089132, Top1: 0.497863 6 | Class-wise Acc:1st: 0.684586, 2nd: 0.348489, 3rd: 0.371642, 4th: 0.568888, 5th: 0.566830, 6th: 0.240482, 7th: 0.865942, 8th: 0.211750, 9th: 0.772697, 10th: 0.256028, 11th: 0.958687, 12th: 0.128335 Best Acc so far: 0.497863 7 | Train:epoch: 1:[432/432], LossCE: 0.249252, LossDA: 1.135072, LossAll: 1.384324, Auxi1: 5.407263, Auxi2: 4.445168, Task: 92.261650 8 | Test:epoch: 1, Top1_auxi1: 0.065403, Top1_auxi2: 0.092108, Top1: 0.556021 9 | Class-wise Acc:1st: 0.805815, 2nd: 0.368058, 3rd: 0.598934, 4th: 0.607442, 5th: 0.653805, 6th: 0.180723, 7th: 0.846618, 8th: 0.280750, 9th: 0.832711, 10th: 0.419991, 11th: 0.921860, 12th: 0.155552 Best Acc so far: 0.556021 10 | Train:epoch: 2:[432/432], LossCE: 0.187990, LossDA: 1.307358, LossAll: 1.495347, Auxi1: 5.485026, Auxi2: 4.302300, Task: 93.992332 11 | Test:epoch: 2, Top1_auxi1: 0.063872, Top1_auxi2: 0.096490, Top1: 0.567105 12 | Class-wise Acc:1st: 0.797586, 2nd: 0.421007, 3rd: 0.587846, 4th: 0.601961, 5th: 0.654231, 6th: 0.278072, 7th: 0.840235, 8th: 0.310250, 9th: 0.797978, 10th: 0.427883, 11th: 0.919500, 12th: 0.168709 Best Acc so far: 0.567105 13 | Train:epoch: 3:[432/432], LossCE: 0.152223, LossDA: 1.563822, LossAll: 1.716045, Auxi1: 5.629702, Auxi2: 4.356554, Task: 95.131653 14 | Test:epoch: 3, Top1_auxi1: 0.068933, Top1_auxi2: 0.087752, Top1: 0.565053 15 | Class-wise Acc:1st: 0.771530, 2nd: 0.403453, 3rd: 0.648188, 4th: 0.559850, 5th: 0.662332, 6th: 0.320482, 7th: 0.843513, 8th: 0.283750, 9th: 0.763245, 10th: 0.423498, 11th: 0.914778, 12th: 0.186013 16 | Train:epoch: 
4:[432/432], LossCE: 0.130460, LossDA: 1.403738, LossAll: 1.534199, Auxi1: 4.700159, Auxi2: 4.197410, Task: 95.806206 17 | Test:epoch: 4, Top1_auxi1: 0.063369, Top1_auxi2: 0.082462, Top1: 0.567353 18 | Class-wise Acc:1st: 0.763028, 2nd: 0.410360, 3rd: 0.663753, 4th: 0.568215, 5th: 0.681731, 6th: 0.241446, 7th: 0.844203, 8th: 0.344250, 9th: 0.791822, 10th: 0.385796, 11th: 0.915722, 12th: 0.197909 Best Acc so far: 0.567353 19 | Train:epoch: 5:[432/432], LossCE: 0.109417, LossDA: 1.380119, LossAll: 1.489536, Auxi1: 4.079861, Auxi2: 4.309534, Task: 96.551285 20 | Test:epoch: 5, Top1_auxi1: 0.065876, Top1_auxi2: 0.080367, Top1: 0.556004 21 | Class-wise Acc:1st: 0.719967, 2nd: 0.349353, 3rd: 0.687846, 4th: 0.614268, 5th: 0.675549, 6th: 0.221687, 7th: 0.829365, 8th: 0.321750, 9th: 0.760607, 10th: 0.421306, 11th: 0.897309, 12th: 0.173035 22 | Train:epoch: 6:[432/432], LossCE: 0.095272, LossDA: 1.400877, LossAll: 1.496149, Auxi1: 4.456018, Auxi2: 4.295066, Task: 96.934677 23 | Test:epoch: 6, Top1_auxi1: 0.060089, Top1_auxi2: 0.087750, Top1: 0.578876 24 | Class-wise Acc:1st: 0.815963, 2nd: 0.366619, 3rd: 0.700426, 4th: 0.611768, 5th: 0.672351, 6th: 0.248675, 7th: 0.847999, 8th: 0.303500, 9th: 0.834249, 10th: 0.434897, 11th: 0.880076, 12th: 0.229993 Best Acc so far: 0.578876 25 | Train:epoch: 7:[432/432], LossCE: 0.081930, LossDA: 1.395533, LossAll: 1.477463, Auxi1: 4.390914, Auxi2: 3.891783, Task: 97.455513 26 | Test:epoch: 7, Top1_auxi1: 0.067571, Top1_auxi2: 0.082103, Top1: 0.571576 27 | Class-wise Acc:1st: 0.782501, 2nd: 0.386763, 3rd: 0.605970, 4th: 0.588309, 5th: 0.727350, 6th: 0.194217, 7th: 0.849379, 8th: 0.353500, 9th: 0.799956, 10th: 0.410346, 11th: 0.915486, 12th: 0.245133 28 | Train:epoch: 8:[432/432], LossCE: 0.071810, LossDA: 1.377522, LossAll: 1.449332, Auxi1: 4.537399, Auxi2: 4.146774, Task: 97.762947 29 | Test:epoch: 8, Top1_auxi1: 0.067663, Top1_auxi2: 0.079991, Top1: 0.570504 30 | Class-wise Acc:1st: 0.770433, 2nd: 0.358561, 3rd: 0.715778, 4th: 0.583117, 
5th: 0.690258, 6th: 0.228916, 7th: 0.847999, 8th: 0.306750, 9th: 0.807210, 10th: 0.412977, 11th: 0.874410, 12th: 0.249640 31 | Train:epoch: 9:[432/432], LossCE: 0.063001, LossDA: 1.389740, LossAll: 1.452742, Auxi1: 3.976779, Auxi2: 4.201027, Task: 98.126450 32 | Test:epoch: 9, Top1_auxi1: 0.063785, Top1_auxi2: 0.080066, Top1: 0.582241 33 | Class-wise Acc:1st: 0.803895, 2nd: 0.416403, 3rd: 0.700853, 4th: 0.591193, 5th: 0.700490, 6th: 0.235663, 7th: 0.853175, 8th: 0.397250, 9th: 0.821719, 10th: 0.371767, 11th: 0.888102, 12th: 0.206381 Best Acc so far: 0.582241 34 | Train:epoch: 10:[432/432], LossCE: 0.057746, LossDA: 1.394590, LossAll: 1.452338, Auxi1: 3.803169, Auxi2: 4.465061, Task: 98.285591 35 | Test:epoch: 10, Top1_auxi1: 0.062671, Top1_auxi2: 0.078781, Top1: 0.574760 36 | Class-wise Acc:1st: 0.764125, 2nd: 0.391655, 3rd: 0.682729, 4th: 0.549947, 5th: 0.716905, 6th: 0.228434, 7th: 0.851622, 8th: 0.378750, 9th: 0.796219, 10th: 0.378781, 11th: 0.893532, 12th: 0.264420 37 | Train:epoch: 11:[432/432], LossCE: 0.051314, LossDA: 1.387914, LossAll: 1.439229, Auxi1: 4.229962, Auxi2: 4.823134, Task: 98.457397 38 | Test:epoch: 11, Top1_auxi1: 0.064863, Top1_auxi2: 0.084203, Top1: 0.584895 39 | Class-wise Acc:1st: 0.778113, 2nd: 0.423022, 3rd: 0.688486, 4th: 0.509182, 5th: 0.735451, 6th: 0.213494, 7th: 0.859041, 8th: 0.439750, 9th: 0.839305, 10th: 0.402017, 11th: 0.885741, 12th: 0.245133 Best Acc so far: 0.584895 40 | Train:epoch: 12:[432/432], LossCE: 0.046685, LossDA: 1.391144, LossAll: 1.437830, Auxi1: 4.184751, Auxi2: 5.063657, Task: 98.647278 41 | Test:epoch: 12, Top1_auxi1: 0.067518, Top1_auxi2: 0.085447, Top1: 0.576178 42 | Class-wise Acc:1st: 0.776193, 2nd: 0.397986, 3rd: 0.689765, 4th: 0.550812, 5th: 0.708165, 6th: 0.168193, 7th: 0.880262, 8th: 0.409500, 9th: 0.808969, 10th: 0.373520, 11th: 0.868508, 12th: 0.282264 43 | Train:epoch: 13:[432/432], LossCE: 0.043811, LossDA: 1.394163, LossAll: 1.437974, Auxi1: 4.036458, Auxi2: 5.079934, Task: 98.703346 44 | 
Test:epoch: 13, Top1_auxi1: 0.065180, Top1_auxi2: 0.081070, Top1: 0.560906 45 | Class-wise Acc:1st: 0.752606, 2nd: 0.378705, 3rd: 0.678038, 4th: 0.605134, 5th: 0.664677, 6th: 0.151325, 7th: 0.837992, 8th: 0.322000, 9th: 0.830073, 10th: 0.375712, 11th: 0.888574, 12th: 0.246035 46 | Train:epoch: 14:[432/432], LossCE: 0.039512, LossDA: 1.390498, LossAll: 1.430009, Auxi1: 4.132306, Auxi2: 4.503038, Task: 98.862488 47 | Test:epoch: 14, Top1_auxi1: 0.066672, Top1_auxi2: 0.084346, Top1: 0.555455 48 | Class-wise Acc:1st: 0.727647, 2nd: 0.380719, 3rd: 0.681876, 4th: 0.575137, 5th: 0.671499, 6th: 0.205301, 7th: 0.844893, 8th: 0.287000, 9th: 0.799736, 10th: 0.339325, 11th: 0.881964, 12th: 0.270368 49 | Train:epoch: 15:[432/432], LossCE: 0.036059, LossDA: 1.391966, LossAll: 1.428025, Auxi1: 3.616898, Auxi2: 4.685691, Task: 99.019821 50 | Test:epoch: 15, Top1_auxi1: 0.064970, Top1_auxi2: 0.083399, Top1: 0.557222 51 | Class-wise Acc:1st: 0.720516, 2nd: 0.358561, 3rd: 0.682729, 4th: 0.548120, 5th: 0.710083, 6th: 0.215904, 7th: 0.852657, 8th: 0.360500, 9th: 0.770059, 10th: 0.330995, 11th: 0.888338, 12th: 0.248198 52 | Train:epoch: 16:[432/432], LossCE: 0.032596, LossDA: 1.392626, LossAll: 1.425221, Auxi1: 3.915292, Auxi2: 4.548249, Task: 99.079498 53 | Test:epoch: 16, Top1_auxi1: 0.067874, Top1_auxi2: 0.082306, Top1: 0.577458 54 | Class-wise Acc:1st: 0.757268, 2nd: 0.414388, 3rd: 0.726226, 4th: 0.533026, 5th: 0.745896, 6th: 0.216386, 7th: 0.825224, 8th: 0.378500, 9th: 0.803913, 10th: 0.377904, 11th: 0.865439, 12th: 0.285328 55 | Train:epoch: 17:[432/432], LossCE: 0.030033, LossDA: 1.393536, LossAll: 1.423569, Auxi1: 3.618707, Auxi2: 5.087167, Task: 99.209709 56 | Test:epoch: 17, Top1_auxi1: 0.066593, Top1_auxi2: 0.082290, Top1: 0.571776 57 | Class-wise Acc:1st: 0.733681, 2nd: 0.443165, 3rd: 0.693177, 4th: 0.535718, 5th: 0.699851, 6th: 0.176867, 7th: 0.857660, 8th: 0.397500, 9th: 0.800835, 10th: 0.369575, 11th: 0.885977, 12th: 0.267304 58 | Train:epoch: 18:[432/432], LossCE: 
0.027941, LossDA: 1.391720, LossAll: 1.419661, Auxi1: 3.432436, Auxi2: 5.251736, Task: 99.282043 59 | Test:epoch: 18, Top1_auxi1: 0.065550, Top1_auxi2: 0.083747, Top1: 0.574060 60 | Class-wise Acc:1st: 0.758914, 2nd: 0.398273, 3rd: 0.703838, 4th: 0.535141, 5th: 0.715626, 6th: 0.195181, 7th: 0.862146, 8th: 0.366250, 9th: 0.830073, 10th: 0.374397, 11th: 0.856704, 12th: 0.292177 61 | Train:epoch: 19:[432/432], LossCE: 0.025859, LossDA: 1.390437, LossAll: 1.416294, Auxi1: 3.837529, Auxi2: 5.121528, Task: 99.320023 62 | Test:epoch: 19, Top1_auxi1: 0.066620, Top1_auxi2: 0.082075, Top1: 0.558489 63 | Class-wise Acc:1st: 0.765222, 2nd: 0.355396, 3rd: 0.678891, 4th: 0.560427, 5th: 0.725432, 6th: 0.185060, 7th: 0.811594, 8th: 0.326000, 9th: 0.754671, 10th: 0.395441, 11th: 0.898961, 12th: 0.244773 64 | Train:epoch: 20:[432/432], LossCE: 0.023830, LossDA: 1.388485, LossAll: 1.412315, Auxi1: 3.848380, Auxi2: 5.255353, Task: 99.370659 65 | Test:epoch: 20, Top1_auxi1: 0.059989, Top1_auxi2: 0.081193, Top1: 0.579480 66 | Class-wise Acc:1st: 0.763302, 2nd: 0.403453, 3rd: 0.706397, 4th: 0.596385, 5th: 0.748241, 6th: 0.205301, 7th: 0.838854, 8th: 0.368000, 9th: 0.810947, 10th: 0.395441, 11th: 0.860954, 12th: 0.256489 67 | Train:epoch: 21:[432/432], LossCE: 0.022136, LossDA: 1.388649, LossAll: 1.410784, Auxi1: 3.792318, Auxi2: 5.060040, Task: 99.421295 68 | Test:epoch: 21, Top1_auxi1: 0.063822, Top1_auxi2: 0.078161, Top1: 0.571030 69 | Class-wise Acc:1st: 0.746572, 2nd: 0.384748, 3rd: 0.700640, 4th: 0.555139, 5th: 0.736090, 6th: 0.202892, 7th: 0.831263, 8th: 0.360500, 9th: 0.819301, 10th: 0.378781, 11th: 0.874174, 12th: 0.262257 70 | Train:epoch: 22:[432/432], LossCE: 0.019700, LossDA: 1.390863, LossAll: 1.410562, Auxi1: 3.669343, Auxi2: 4.622396, Task: 99.529800 71 | Test:epoch: 22, Top1_auxi1: 0.065152, Top1_auxi2: 0.079538, Top1: 0.572174 72 | Class-wise Acc:1st: 0.712013, 2nd: 0.401439, 3rd: 0.701493, 4th: 0.552062, 5th: 0.726498, 6th: 0.231807, 7th: 0.844720, 8th: 0.379000, 9th: 
0.812926, 10th: 0.362999, 11th: 0.880312, 12th: 0.260815 73 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "DANN"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "TASK": "closed", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_McDalNet_DANN", "SAVE_DIR": "./experiments/ckpt/logVisDA_train2validation_McDalNet_DANN", "NUM_WORKERS": 8, "PRINT_STEP": 3} 74 | -------------------------------------------------------------------------------- /experiments/ckpt/logVisDA_train2validation_McDalNet_L1/log.txt: -------------------------------------------------------------------------------- 1 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "L1"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "TASK": "closed", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_McDalNet_L1", "SAVE_DIR": 
"./experiments/ckpt/logVisDA_train2validation_McDalNet_L1", "NUM_WORKERS": 8, "PRINT_STEP": 3} 2 | 3 | Train:epoch: 0:[432/432], LossCE: 0.693553, LossDA: -0.014865, LossAll: 2.081821, Auxi1: 82.879410, Auxi2: 83.246529, Task: 83.201317 4 | Test:epoch: 0, Top1_auxi1: 0.516962, Top1_auxi2: 0.535738, Top1: 0.541355 5 | Class-wise Acc:1st: 0.746846, 2nd: 0.426187, 3rd: 0.655437, 4th: 0.558985, 5th: 0.662972, 6th: 0.197590, 7th: 0.886128, 8th: 0.177000, 9th: 0.729171, 10th: 0.366506, 11th: 0.904863, 12th: 0.184571 Best Acc so far: 0.541355 6 | Train:epoch: 1:[432/432], LossCE: 0.246766, LossDA: -0.043768, LossAll: 0.711814, Auxi1: 92.135056, Auxi2: 92.173035, Task: 92.312286 7 | Test:epoch: 1, Top1_auxi1: 0.492681, Top1_auxi2: 0.529946, Top1: 0.542277 8 | Class-wise Acc:1st: 0.741909, 2nd: 0.447482, 3rd: 0.627079, 4th: 0.657245, 5th: 0.614794, 6th: 0.185542, 7th: 0.866460, 8th: 0.197750, 9th: 0.791822, 10th: 0.322665, 11th: 0.915250, 12th: 0.139329 Best Acc so far: 0.542277 9 | Train:epoch: 2:[432/432], LossCE: 0.183054, LossDA: -0.059813, LossAll: 0.505924, Auxi1: 93.845848, Auxi2: 93.782555, Task: 94.048393 10 | Test:epoch: 2, Top1_auxi1: 0.490163, Top1_auxi2: 0.522411, Top1: 0.551915 11 | Class-wise Acc:1st: 0.768513, 2nd: 0.391942, 3rd: 0.585927, 4th: 0.589078, 5th: 0.667022, 6th: 0.203373, 7th: 0.874051, 8th: 0.290500, 9th: 0.763245, 10th: 0.325296, 11th: 0.932956, 12th: 0.231074 Best Acc so far: 0.551915 12 | Train:epoch: 3:[432/432], LossCE: 0.144287, LossDA: -0.068417, LossAll: 0.381817, Auxi1: 95.093681, Auxi2: 95.081017, Task: 95.278137 13 | Test:epoch: 3, Top1_auxi1: 0.486351, Top1_auxi2: 0.520537, Top1: 0.563128 14 | Class-wise Acc:1st: 0.756171, 2nd: 0.396259, 3rd: 0.614073, 4th: 0.633689, 5th: 0.704541, 6th: 0.203373, 7th: 0.870600, 8th: 0.308000, 9th: 0.800616, 10th: 0.366068, 11th: 0.922096, 12th: 0.182048 Best Acc so far: 0.563128 15 | Train:epoch: 4:[432/432], LossCE: 0.119764, LossDA: -0.073706, LossAll: 0.302515, Auxi1: 95.912903, Auxi2: 95.921951, 
Task: 96.135345 16 | Test:epoch: 4, Top1_auxi1: 0.478332, Top1_auxi2: 0.531994, Top1: 0.566777 17 | Class-wise Acc:1st: 0.769062, 2nd: 0.445180, 3rd: 0.653092, 4th: 0.602923, 5th: 0.694308, 6th: 0.194699, 7th: 0.888199, 8th: 0.314000, 9th: 0.768960, 10th: 0.327050, 11th: 0.909348, 12th: 0.234499 Best Acc so far: 0.566777 18 | Train:epoch: 5:[432/432], LossCE: 0.102737, LossDA: -0.076498, LossAll: 0.248881, Auxi1: 96.408424, Auxi2: 96.482567, Task: 96.585648 19 | Test:epoch: 5, Top1_auxi1: 0.488338, Top1_auxi2: 0.531829, Top1: 0.582575 20 | Class-wise Acc:1st: 0.773725, 2nd: 0.455540, 3rd: 0.684222, 4th: 0.618114, 5th: 0.718610, 6th: 0.180241, 7th: 0.864734, 8th: 0.364750, 9th: 0.792262, 10th: 0.405085, 11th: 0.910293, 12th: 0.223324 Best Acc so far: 0.582575 21 | Train:epoch: 6:[432/432], LossCE: 0.085473, LossDA: -0.077619, LossAll: 0.195600, Auxi1: 97.139038, Auxi2: 97.082970, Task: 97.327110 22 | Test:epoch: 6, Top1_auxi1: 0.496820, Top1_auxi2: 0.537137, Top1: 0.594800 23 | Class-wise Acc:1st: 0.812671, 2nd: 0.442590, 3rd: 0.681237, 4th: 0.618883, 5th: 0.729695, 6th: 0.296867, 7th: 0.891822, 8th: 0.367000, 9th: 0.854693, 10th: 0.345024, 11th: 0.905335, 12th: 0.191781 Best Acc so far: 0.594800 24 | Train:epoch: 7:[432/432], LossCE: 0.075982, LossDA: -0.077116, LossAll: 0.167648, Auxi1: 97.437431, Auxi2: 97.289139, Task: 97.562210 25 | Test:epoch: 7, Top1_auxi1: 0.483066, Top1_auxi2: 0.540713, Top1: 0.590002 26 | Class-wise Acc:1st: 0.791278, 2nd: 0.413813, 3rd: 0.741578, 4th: 0.607634, 5th: 0.724579, 6th: 0.240000, 7th: 0.868875, 8th: 0.377000, 9th: 0.783029, 10th: 0.423060, 11th: 0.906043, 12th: 0.203136 27 | Train:epoch: 8:[432/432], LossCE: 0.066726, LossDA: -0.076585, LossAll: 0.139383, Auxi1: 97.706886, Auxi2: 97.652634, Task: 97.882309 28 | Test:epoch: 8, Top1_auxi1: 0.495741, Top1_auxi2: 0.555948, Top1: 0.606180 29 | Class-wise Acc:1st: 0.811300, 2nd: 0.391942, 3rd: 0.729638, 4th: 0.640611, 5th: 0.777659, 6th: 0.277108, 7th: 0.908040, 8th: 0.440000, 9th: 
0.824577, 10th: 0.395002, 11th: 0.893532, 12th: 0.184751 Best Acc so far: 0.606180 30 | Train:epoch: 9:[432/432], LossCE: 0.060872, LossDA: -0.075088, LossAll: 0.123370, Auxi1: 97.862411, Auxi2: 97.838905, Task: 98.025177 31 | Test:epoch: 9, Top1_auxi1: 0.499938, Top1_auxi2: 0.551240, Top1: 0.599826 32 | Class-wise Acc:1st: 0.800329, 2nd: 0.367770, 3rd: 0.725586, 4th: 0.612345, 5th: 0.756342, 6th: 0.315181, 7th: 0.908213, 8th: 0.409000, 9th: 0.841064, 10th: 0.325296, 11th: 0.897781, 12th: 0.239005 33 | Train:epoch: 10:[432/432], LossCE: 0.054908, LossDA: -0.072794, LossAll: 0.107301, Auxi1: 98.144531, Auxi2: 98.090279, Task: 98.260269 34 | Test:epoch: 10, Top1_auxi1: 0.504898, Top1_auxi2: 0.558344, Top1: 0.609342 35 | Class-wise Acc:1st: 0.818431, 2nd: 0.417842, 3rd: 0.735181, 4th: 0.609461, 5th: 0.767853, 6th: 0.324819, 7th: 0.896998, 8th: 0.445000, 9th: 0.830292, 10th: 0.345463, 11th: 0.903211, 12th: 0.217556 Best Acc so far: 0.609342 36 | Train:epoch: 11:[432/432], LossCE: 0.049372, LossDA: -0.070811, LossAll: 0.092945, Auxi1: 98.386864, Auxi2: 98.327187, Task: 98.515266 37 | Test:epoch: 11, Top1_auxi1: 0.517144, Top1_auxi2: 0.574096, Top1: 0.623316 38 | Class-wise Acc:1st: 0.831596, 2nd: 0.427050, 3rd: 0.749040, 4th: 0.631766, 5th: 0.789810, 6th: 0.357590, 7th: 0.906832, 8th: 0.505000, 9th: 0.845680, 10th: 0.340202, 11th: 0.897309, 12th: 0.197909 Best Acc so far: 0.623316 39 | Train:epoch: 12:[432/432], LossCE: 0.045497, LossDA: -0.068416, LossAll: 0.082828, Auxi1: 98.457397, Auxi2: 98.441116, Task: 98.647278 40 | Test:epoch: 12, Top1_auxi1: 0.516069, Top1_auxi2: 0.578118, Top1: 0.624275 41 | Class-wise Acc:1st: 0.831322, 2nd: 0.396547, 3rd: 0.720682, 4th: 0.660225, 5th: 0.797058, 6th: 0.358554, 7th: 0.907867, 8th: 0.515500, 9th: 0.838646, 10th: 0.372644, 11th: 0.898489, 12th: 0.193764 Best Acc so far: 0.624275 42 | Train:epoch: 13:[432/432], LossCE: 0.041655, LossDA: -0.066161, LossAll: 0.073744, Auxi1: 98.509842, Auxi2: 98.511650, Task: 98.712387 43 | 
Test:epoch: 13, Top1_auxi1: 0.528750, Top1_auxi2: 0.580861, Top1: 0.628952 44 | Class-wise Acc:1st: 0.820077, 2nd: 0.452374, 3rd: 0.728998, 4th: 0.672243, 5th: 0.796206, 6th: 0.361928, 7th: 0.904589, 8th: 0.501250, 9th: 0.863486, 10th: 0.358615, 11th: 0.889754, 12th: 0.197909 Best Acc so far: 0.628952 45 | Train:epoch: 14:[432/432], LossCE: 0.037410, LossDA: -0.063568, LossAll: 0.063342, Auxi1: 98.708771, Auxi2: 98.694298, Task: 98.900467 46 | Test:epoch: 14, Top1_auxi1: 0.537866, Top1_auxi2: 0.599769, Top1: 0.643354 47 | Class-wise Acc:1st: 0.849698, 2nd: 0.498129, 3rd: 0.764392, 4th: 0.650611, 5th: 0.782989, 6th: 0.446747, 7th: 0.919082, 8th: 0.512000, 9th: 0.857111, 10th: 0.359053, 11th: 0.887394, 12th: 0.193043 Best Acc so far: 0.643354 48 | Train:epoch: 15:[432/432], LossCE: 0.035579, LossDA: -0.061481, LossAll: 0.059481, Auxi1: 98.770256, Auxi2: 98.753983, Task: 98.954720 49 | Test:epoch: 15, Top1_auxi1: 0.551721, Top1_auxi2: 0.614768, Top1: 0.655238 50 | Class-wise Acc:1st: 0.886177, 2nd: 0.496691, 3rd: 0.755864, 4th: 0.661090, 5th: 0.815604, 6th: 0.421687, 7th: 0.929607, 8th: 0.584750, 9th: 0.866564, 10th: 0.367383, 11th: 0.883853, 12th: 0.193583 Best Acc so far: 0.655238 51 | Train:epoch: 16:[432/432], LossCE: 0.032109, LossDA: -0.058913, LossAll: 0.051380, Auxi1: 98.909508, Auxi2: 98.858871, Task: 99.045143 52 | Test:epoch: 16, Top1_auxi1: 0.553002, Top1_auxi2: 0.608265, Top1: 0.646800 53 | Class-wise Acc:1st: 0.833242, 2nd: 0.502734, 3rd: 0.739659, 4th: 0.665224, 5th: 0.816670, 6th: 0.442410, 7th: 0.919255, 8th: 0.516500, 9th: 0.854474, 10th: 0.380973, 11th: 0.897781, 12th: 0.192682 54 | Train:epoch: 17:[432/432], LossCE: 0.030613, LossDA: -0.056364, LossAll: 0.049122, Auxi1: 99.028862, Auxi2: 98.958336, Task: 99.131943 55 | Test:epoch: 17, Top1_auxi1: 0.565271, Top1_auxi2: 0.612788, Top1: 0.655612 56 | Class-wise Acc:1st: 0.835985, 2nd: 0.496403, 3rd: 0.763539, 4th: 0.689549, 5th: 0.821147, 6th: 0.509880, 7th: 0.910973, 8th: 0.545750, 9th: 0.900418, 
10th: 0.349847, 11th: 0.890463, 12th: 0.153389 Best Acc so far: 0.655612 57 | Train:epoch: 18:[432/432], LossCE: 0.027097, LossDA: -0.054544, LossAll: 0.039466, Auxi1: 99.121094, Auxi2: 99.061417, Task: 99.260345 58 | Test:epoch: 18, Top1_auxi1: 0.559192, Top1_auxi2: 0.617162, Top1: 0.655087 59 | Class-wise Acc:1st: 0.873834, 2nd: 0.437122, 3rd: 0.718763, 4th: 0.682434, 5th: 0.809209, 6th: 0.503133, 7th: 0.924431, 8th: 0.591250, 9th: 0.863047, 10th: 0.375274, 11th: 0.903919, 12th: 0.178623 60 | Train:epoch: 19:[432/432], LossCE: 0.025011, LossDA: -0.052258, LossAll: 0.035955, Auxi1: 99.153648, Auxi2: 99.115669, Task: 99.318214 61 | Test:epoch: 19, Top1_auxi1: 0.587179, Top1_auxi2: 0.623123, Top1: 0.671575 62 | Class-wise Acc:1st: 0.842567, 2nd: 0.544748, 3rd: 0.766524, 4th: 0.681761, 5th: 0.827116, 6th: 0.567229, 7th: 0.909593, 8th: 0.617250, 9th: 0.887008, 10th: 0.346778, 11th: 0.902502, 12th: 0.165826 Best Acc so far: 0.671575 63 | Train:epoch: 20:[432/432], LossCE: 0.024907, LossDA: -0.050518, LossAll: 0.036782, Auxi1: 99.164497, Auxi2: 99.151840, Task: 99.298325 64 | Test:epoch: 20, Top1_auxi1: 0.584258, Top1_auxi2: 0.621111, Top1: 0.667436 65 | Class-wise Acc:1st: 0.854361, 2nd: 0.468201, 3rd: 0.757783, 4th: 0.676473, 5th: 0.818802, 6th: 0.585542, 7th: 0.917702, 8th: 0.648250, 9th: 0.891185, 10th: 0.344586, 11th: 0.901794, 12th: 0.144557 66 | Train:epoch: 21:[432/432], LossCE: 0.024208, LossDA: -0.048013, LossAll: 0.037116, Auxi1: 99.184387, Auxi2: 99.184387, Task: 99.316406 67 | Test:epoch: 21, Top1_auxi1: 0.596875, Top1_auxi2: 0.643675, Top1: 0.678382 68 | Class-wise Acc:1st: 0.870817, 2nd: 0.480576, 3rd: 0.763326, 4th: 0.696952, 5th: 0.841398, 6th: 0.622651, 7th: 0.927536, 8th: 0.617750, 9th: 0.914047, 10th: 0.367821, 11th: 0.884325, 12th: 0.153389 Best Acc so far: 0.678382 69 | Train:epoch: 22:[432/432], LossCE: 0.022081, LossDA: -0.047065, LossAll: 0.030939, Auxi1: 99.244072, Auxi2: 99.233215, Task: 99.345345 70 | Test:epoch: 22, Top1_auxi1: 0.604596, 
Top1_auxi2: 0.639697, Top1: 0.681608 71 | Class-wise Acc:1st: 0.845036, 2nd: 0.511079, 3rd: 0.783795, 4th: 0.696760, 5th: 0.810701, 6th: 0.627952, 7th: 0.914079, 8th: 0.686500, 9th: 0.910969, 10th: 0.356861, 11th: 0.893532, 12th: 0.142033 Best Acc so far: 0.681608 72 | Train:epoch: 23:[432/432], LossCE: 0.021154, LossDA: -0.044631, LossAll: 0.030898, Auxi1: 99.345345, Auxi2: 99.229599, Task: 99.430336 73 | Test:epoch: 23, Top1_auxi1: 0.604646, Top1_auxi2: 0.653916, Top1: 0.684511 74 | Class-wise Acc:1st: 0.863412, 2nd: 0.514245, 3rd: 0.793817, 4th: 0.713681, 5th: 0.851418, 6th: 0.622651, 7th: 0.911146, 8th: 0.645500, 9th: 0.888327, 10th: 0.387111, 11th: 0.883853, 12th: 0.138969 Best Acc so far: 0.684511 75 | Train:epoch: 24:[432/432], LossCE: 0.019100, LossDA: -0.043429, LossAll: 0.025213, Auxi1: 99.361618, Auxi2: 99.339920, Task: 99.488213 76 | Test:epoch: 24, Top1_auxi1: 0.599118, Top1_auxi2: 0.656066, Top1: 0.682732 77 | Class-wise Acc:1st: 0.867526, 2nd: 0.522014, 3rd: 0.776333, 4th: 0.695510, 5th: 0.822426, 6th: 0.631807, 7th: 0.944790, 8th: 0.665500, 9th: 0.892064, 10th: 0.330118, 11th: 0.887158, 12th: 0.157534 78 | Train:epoch: 25:[432/432], LossCE: 0.018116, LossDA: -0.042065, LossAll: 0.023505, Auxi1: 99.414062, Auxi2: 99.394173, Task: 99.506294 79 | Test:epoch: 25, Top1_auxi1: 0.616083, Top1_auxi2: 0.658810, Top1: 0.691398 80 | Class-wise Acc:1st: 0.877674, 2nd: 0.548489, 3rd: 0.767804, 4th: 0.699356, 5th: 0.835856, 6th: 0.639518, 7th: 0.907350, 8th: 0.668000, 9th: 0.902176, 10th: 0.393249, 11th: 0.903919, 12th: 0.153389 Best Acc so far: 0.691398 81 | Train:epoch: 26:[432/432], LossCE: 0.016651, LossDA: -0.040484, LossAll: 0.020489, Auxi1: 99.452042, Auxi2: 99.417679, Task: 99.582253 82 | Test:epoch: 26, Top1_auxi1: 0.611797, Top1_auxi2: 0.662253, Top1: 0.687778 83 | Class-wise Acc:1st: 0.847778, 2nd: 0.561727, 3rd: 0.799787, 4th: 0.722815, 5th: 0.819655, 6th: 0.642892, 7th: 0.930297, 8th: 0.674500, 9th: 0.907232, 10th: 0.327488, 11th: 0.879839, 12th: 
0.139329 84 | Train:epoch: 27:[432/432], LossCE: 0.015983, LossDA: -0.038920, LossAll: 0.020103, Auxi1: 99.509911, Auxi2: 99.435768, Task: 99.582253 85 | Test:epoch: 27, Top1_auxi1: 0.629375, Top1_auxi2: 0.668425, Top1: 0.694345 86 | Class-wise Acc:1st: 0.873560, 2nd: 0.544460, 3rd: 0.792964, 4th: 0.700221, 5th: 0.833298, 6th: 0.686265, 7th: 0.935300, 8th: 0.698500, 9th: 0.926357, 10th: 0.296361, 11th: 0.884797, 12th: 0.160058 Best Acc so far: 0.694345 87 | Train:epoch: 28:[432/432], LossCE: 0.015241, LossDA: -0.037351, LossAll: 0.018945, Auxi1: 99.490021, Auxi2: 99.462891, Task: 99.600334 88 | Test:epoch: 28, Top1_auxi1: 0.641193, Top1_auxi2: 0.675345, Top1: 0.703984 89 | Class-wise Acc:1st: 0.861492, 2nd: 0.596547, 3rd: 0.814499, 4th: 0.695702, 5th: 0.845449, 6th: 0.708434, 7th: 0.930469, 8th: 0.730250, 9th: 0.928995, 10th: 0.305129, 11th: 0.872049, 12th: 0.158796 Best Acc so far: 0.703984 90 | Train:epoch: 29:[432/432], LossCE: 0.014775, LossDA: -0.036507, LossAll: 0.017794, Auxi1: 99.513527, Auxi2: 99.477356, Task: 99.600334 91 | Test:epoch: 29, Top1_auxi1: 0.622866, Top1_auxi2: 0.670656, Top1: 0.695423 92 | Class-wise Acc:1st: 0.879594, 2nd: 0.567194, 3rd: 0.788913, 4th: 0.712720, 5th: 0.836069, 6th: 0.631325, 7th: 0.941339, 8th: 0.727250, 9th: 0.904375, 10th: 0.334941, 11th: 0.888338, 12th: 0.133021 93 | Train:epoch: 30:[432/432], LossCE: 0.012954, LossDA: -0.035282, LossAll: 0.013344, Auxi1: 99.594910, Auxi2: 99.584061, Task: 99.696182 94 | Test:epoch: 30, Top1_auxi1: 0.629052, Top1_auxi2: 0.672304, Top1: 0.694756 95 | Class-wise Acc:1st: 0.873012, 2nd: 0.532950, 3rd: 0.818763, 4th: 0.708586, 5th: 0.840546, 6th: 0.681928, 7th: 0.927536, 8th: 0.709750, 9th: 0.906353, 10th: 0.327488, 11th: 0.889754, 12th: 0.120404 96 | Train:epoch: 31:[432/432], LossCE: 0.012972, LossDA: -0.034260, LossAll: 0.014438, Auxi1: 99.573204, Auxi2: 99.533424, Task: 99.681717 97 | Test:epoch: 31, Top1_auxi1: 0.645591, Top1_auxi2: 0.677350, Top1: 0.703731 98 | Class-wise Acc:1st: 
0.882063, 2nd: 0.583022, 3rd: 0.795949, 4th: 0.708009, 5th: 0.850991, 6th: 0.708434, 7th: 0.924431, 8th: 0.749000, 9th: 0.914707, 10th: 0.300745, 11th: 0.891879, 12th: 0.135544 99 | -------------------------------------------------------------------------------- /experiments/configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/.DS_Store -------------------------------------------------------------------------------- /experiments/configs/ImageCLEF/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/ImageCLEF/.DS_Store -------------------------------------------------------------------------------- /experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 12 3 | DATASET: 'ImageCLEF' 4 | DATAROOT: '/data1/domain_adaptation/image_CLEF' 5 | SOURCE_NAME: 'c' 6 | TARGET_NAME: 'i' 7 | VAL_NAME: 'i' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'ours' 14 | 15 | 16 | EVAL_METRIC: "accu" 17 | SAVE_DIR: "./experiments/ckpt" 18 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/ImageCLEF/SymmNets/clef_train_c2i_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 12 3 | DATASET: 'ImageCLEF' 4 | DATAROOT: '/data1/domain_adaptation/image_CLEF' 5 | SOURCE_NAME: 'c' 6 | TARGET_NAME: 'i' 7 | VAL_NAME: 'i' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'ours' 14 | 15 | 16 | EVAL_METRIC: "accu" 17 | SAVE_DIR: 
"./experiments/SymmNets" 18 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/Office31/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/Office31/.DS_Store -------------------------------------------------------------------------------- /experiments/configs/Office31/McDalNet/office31_train_webcam2amazon_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 31 3 | DATASET: 'Office31' 4 | DATAROOT: '/data1/domain_adaptation/Office31' 5 | SOURCE_NAME: 'webcam' 6 | TARGET_NAME: 'amazon' 7 | VAL_NAME: 'amazon' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'ours' 14 | 15 | 16 | EVAL_METRIC: "accu" 17 | SAVE_DIR: "./experiments/ckpt" 18 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/Office31/SymmNets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/Office31/SymmNets/.DS_Store -------------------------------------------------------------------------------- /experiments/configs/Office31/SymmNets/office31_train_amazon2dslr_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 31 3 | DATASET: 'Office31' 4 | DATAROOT: '/data1/domain_adaptation/Office31' 5 | SOURCE_NAME: 'amazon' 6 | TARGET_NAME: 'dslr' 7 | VAL_NAME: 'dslr' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'longs' 14 | 15 | TRAIN: 16 | MAX_EPOCH: 200 17 | # More training epoch is ok, but not necessary. 
18 | 19 | EVAL_METRIC: "accu" 20 | SAVE_DIR: "./experiments/SymmNets" 21 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/Office31/SymmNets/office31_train_amazon2dslr_cfg_SC.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 31 3 | DATASET: 'Office31' 4 | DATAROOT: '/data1/domain_adaptation/Office31' 5 | SOURCE_NAME: 'amazon' 6 | TARGET_NAME: 'dslr' 7 | VAL_NAME: 'dslr' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'longs' 14 | 15 | STRENGTHEN: 16 | DATALOAD: 'soft' 17 | PERCATE: 4 18 | CLUSTER_FREQ: 20 19 | 20 | TRAIN: 21 | MAX_EPOCH: 200 22 | # More training epoch is ok, but not necessary. 23 | 24 | EVAL_METRIC: "accu" 25 | SAVE_DIR: "./experiments/SymmNets" 26 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/Office31/SymmNets/office31_train_amazon2webcam_open_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 11 3 | DATASET: 'Office31' 4 | DATAROOT: '/data1/domain_adaptation/openset_Office31/' 5 | SOURCE_NAME: 'amazon_s_busto' 6 | TARGET_NAME: 'webcam_t_busto' 7 | VAL_NAME: 'webcam_t_busto' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'ours' 14 | 15 | 16 | EVAL_METRIC: "accu_mean" 17 | SAVE_DIR: "./experiments/Open" 18 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/Office31/SymmNets/office31_train_webcam2amazon_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 31 3 | DATASET: 'Office31' 4 | DATAROOT: '/data1/domain_adaptation/Office31' 5 | SOURCE_NAME: 'webcam' 6 | TARGET_NAME: 'amazon' 7 | VAL_NAME: 'amazon' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | 
DATA_TRANSFORM: 13 | TYPE: 'ours' 14 | 15 | TRAIN: 16 | MAX_EPOCH: 200 17 | # More training epoch is ok, but not necessary. 18 | 19 | EVAL_METRIC: "accu" 20 | SAVE_DIR: "./experiments/SymmNets" 21 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/Office31/SymmNets/office31_train_webcam2amazon_cfg_SC.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 31 3 | DATASET: 'Office31' 4 | DATAROOT: '/data1/domain_adaptation/Office31' 5 | SOURCE_NAME: 'webcam' 6 | TARGET_NAME: 'amazon' 7 | VAL_NAME: 'amazon' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'longs' 14 | 15 | STRENGTHEN: 16 | DATALOAD: 'soft' 17 | PERCATE: 4 18 | CLUSTER_FREQ: 20 19 | 20 | TRAIN: 21 | MAX_EPOCH: 200 22 | # More training epoch is ok, but not necessary. 23 | 24 | 25 | EVAL_METRIC: "accu" 26 | SAVE_DIR: "./experiments/SymmNets" 27 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/OfficeHome/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/OfficeHome/.DS_Store -------------------------------------------------------------------------------- /experiments/configs/OfficeHome/SymmNets/home_train_A2R_partial_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 65 3 | DATASET: 'OfficeHome' 4 | DATAROOT: '/data1/domain_adaptation/OfficeHome' 5 | SOURCE_NAME: 'Art' 6 | TARGET_NAME: 'Real_World_25' 7 | VAL_NAME: 'Real_World_25' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'ours' 14 | 15 | TRAIN: 16 | MAX_EPOCH: 80 17 | 18 | EVAL_METRIC: "accu" 19 | SAVE_DIR: "./experiments/Partial" 20 | 
NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 12 3 | DATASET: 'VisDA' 4 | DATAROOT: '/data1/domain_adaptation/visDA' 5 | SOURCE_NAME: 'train' 6 | TARGET_NAME: 'validation' 7 | VAL_NAME: 'validation' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'simple' 14 | 15 | TRAIN: 16 | BASE_LR: 0.001 17 | MAX_EPOCH: 30 18 | LR_SCHEDULE: 'fix' 19 | 20 | EVAL_METRIC: "accu_mean" 21 | SAVE_DIR: "./experiments/ckpt" 22 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 12 3 | DATASET: 'VisDA' 4 | DATAROOT: '/data1/domain_adaptation/visDA' 5 | SOURCE_NAME: 'train' 6 | TARGET_NAME: 'validation' 7 | VAL_NAME: 'validation' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet50' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'simple' 14 | 15 | TRAIN: 16 | BASE_LR: 0.001 17 | MAX_EPOCH: 30 18 | LR_SCHEDULE: 'fix' 19 | 20 | EVAL_METRIC: "accu_mean" 21 | SAVE_DIR: "./experiments/SymmNets" 22 | NUM_WORKERS: 8 -------------------------------------------------------------------------------- /experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg_res101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NUM_CLASSES: 12 3 | DATASET: 'VisDA' 4 | DATAROOT: '/data1/domain_adaptation/visDA' 5 | SOURCE_NAME: 'train' 6 | TARGET_NAME: 'validation' 7 | VAL_NAME: 'validation' 8 | 9 | MODEL: 10 | FEATURE_EXTRACTOR: 'resnet101' 11 | 12 | DATA_TRANSFORM: 13 | TYPE: 'simple' 14 | 15 | TRAIN: 16 | BASE_LR: 0.001 17 | MAX_EPOCH: 40 18 | LR_SCHEDULE: 'fix' 19 | 20 | 
EVAL_METRIC: "accu_mean" 21 | SAVE_DIR: "./experiments/SymmNets" 22 | NUM_WORKERS: 8 23 | -------------------------------------------------------------------------------- /models/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.utils import process_zero_values 5 | import ipdb 6 | 7 | 8 | def _assert_no_grad(variable): 9 | assert not variable.requires_grad, \ 10 | "nn criterions don't compute the gradient w.r.t. targets - please " \ 11 | "mark these variables as volatile or not requiring gradients" 12 | 13 | 14 | class _Loss(nn.Module): 15 | def __init__(self, size_average=True): 16 | super(_Loss, self).__init__() 17 | self.size_average = size_average 18 | 19 | 20 | class _WeightedLoss(_Loss): 21 | def __init__(self, weight=None, size_average=True): 22 | super(_WeightedLoss, self).__init__(size_average) 23 | self.register_buffer('weight', weight) 24 | 25 | 26 | class CrossEntropyClassWeighted(_Loss): 27 | 28 | def __init__(self, size_average=True, ignore_index=-100, reduce=None, reduction='elementwise_mean'): 29 | super(CrossEntropyClassWeighted, self).__init__(size_average) 30 | self.ignore_index = ignore_index 31 | self.reduction = reduction 32 | 33 | def forward(self, input, target, weight=None): 34 | return F.cross_entropy(input, target, weight, ignore_index=self.ignore_index, reduction=self.reduction) 35 | 36 | 37 | ### clone this function from: https://github.com/krumo/swd_pytorch/blob/master/swd_pytorch.py. 
[Unofficial] 38 | def discrepancy_slice_wasserstein(p1, p2): 39 | s = p1.shape 40 | if s[1] > 1: 41 | proj = torch.randn(s[1], 128).cuda() 42 | proj *= torch.rsqrt(torch.sum(torch.mul(proj, proj), 0, keepdim=True)) 43 | p1 = torch.matmul(p1, proj) 44 | p2 = torch.matmul(p2, proj) 45 | p1 = torch.topk(p1, s[0], dim=0)[0] 46 | p2 = torch.topk(p2, s[0], dim=0)[0] 47 | dist = p1 - p2 48 | wdist = torch.mean(torch.mul(dist, dist)) 49 | 50 | return wdist 51 | 52 | 53 | class McDalNetLoss(_WeightedLoss): 54 | 55 | def __init__(self, weight=None, size_average=True): 56 | super(McDalNetLoss, self).__init__(weight, size_average) 57 | 58 | def forward(self, input1, input2, dis_type='L1'): 59 | 60 | if dis_type == 'L1': 61 | prob_s = F.softmax(input1, dim=1) 62 | prob_t = F.softmax(input2, dim=1) 63 | loss = torch.mean(torch.abs(prob_s - prob_t)) ### element-wise 64 | elif dis_type == 'CE': ## Cross entropy 65 | loss = - ((F.log_softmax(input2, dim=1)).mul(F.softmax(input1, dim=1))).mean() - ( 66 | (F.log_softmax(input1, dim=1)).mul(F.softmax(input2, dim=1))).mean() 67 | loss = loss * 0.5 68 | elif dis_type == 'KL': ##### averaged over elements, not the real KL div (summed over elements of instance, and averaged over instance) 69 | ############# nn.KLDivLoss(size_average=False) Vs F.kl_div() 70 | loss = (F.kl_div(F.log_softmax(input1), F.softmax(input2))) + ( 71 | F.kl_div(F.log_softmax(input2), F.softmax(input1))) 72 | loss = loss * 0.5 73 | ############# the following two distances are not evaluated in our paper, and need further investigation 74 | elif dis_type == 'L2': 75 | nClass = input1.size()[1] 76 | prob_s = F.softmax(input1, dim=1) 77 | prob_t = F.softmax(input2, dim=1) 78 | loss = torch.norm(prob_s - prob_t, p=2, dim=1).mean() / nClass ### element-wise 79 | elif dis_type == 'Wasse': ## distance proposed in Sliced wasserstein discrepancy for unsupervised domain adaptation, 80 | prob_s = F.softmax(input1, dim=1) 81 | prob_t = F.softmax(input2, dim=1) 82 | loss = 
discrepancy_slice_wasserstein(prob_s, prob_t) 83 | 84 | return loss 85 | 86 | 87 | class TargetDiscrimLoss(_WeightedLoss): 88 | def __init__(self, weight=None, size_average=True, num_classes=31): 89 | super(TargetDiscrimLoss, self).__init__(weight, size_average) 90 | self.num_classes = num_classes 91 | 92 | def forward(self, input): 93 | batch_size = input.size(0) 94 | prob = F.softmax(input, dim=1) 95 | 96 | if (prob.data[:, self.num_classes:].sum(1) == 0).sum() != 0: ########### in case of log(0) 97 | soft_weight = torch.FloatTensor(batch_size).fill_(0) 98 | soft_weight[prob[:, self.num_classes:].sum(1).data.cpu() == 0] = 1e-6 99 | soft_weight_var = soft_weight.cuda() 100 | loss = -((prob[:, self.num_classes:].sum(1) + soft_weight_var).log().mean()) 101 | else: 102 | loss = -(prob[:, self.num_classes:].sum(1).log().mean()) 103 | return loss 104 | 105 | class SourceDiscrimLoss(_WeightedLoss): 106 | def __init__(self, weight=None, size_average=True, num_classes=31): 107 | super(SourceDiscrimLoss, self).__init__(weight, size_average) 108 | self.num_classes = num_classes 109 | 110 | def forward(self, input): 111 | batch_size = input.size(0) 112 | prob = F.softmax(input, dim=1) 113 | 114 | if (prob.data[:, :self.num_classes].sum(1) == 0).sum() != 0: ########### in case of log(0) 115 | soft_weight = torch.FloatTensor(batch_size).fill_(0) 116 | soft_weight[prob[:, :self.num_classes].sum(1).data.cpu() == 0] = 1e-6 117 | soft_weight_var = soft_weight.cuda() 118 | loss = -((prob[:, :self.num_classes].sum(1) + soft_weight_var).log().mean()) 119 | else: 120 | loss = -(prob[:, :self.num_classes].sum(1).log().mean()) 121 | return loss 122 | 123 | 124 | class ConcatenatedCELoss(_WeightedLoss): 125 | def __init__(self, weight=None, size_average=True, num_classes=31): 126 | super(ConcatenatedCELoss, self).__init__(weight, size_average) 127 | self.num_classes = num_classes 128 | 129 | def forward(self, input): 130 | prob = F.softmax(input, dim=1) 131 | prob_s = prob[:, 
:self.num_classes] 132 | prob_t = prob[:, self.num_classes:] 133 | 134 | prob_s = process_zero_values(prob_s) 135 | prob_t = process_zero_values(prob_t) 136 | loss = - (prob_s.log().mul(prob_t)).sum(1).mean() - (prob_t.log().mul(prob_s)).sum(1).mean() 137 | loss = loss * 0.5 138 | return loss 139 | 140 | 141 | 142 | class ConcatenatedEMLoss(_WeightedLoss): 143 | def __init__(self, weight=None, size_average=True, num_classes=31): 144 | super(ConcatenatedEMLoss, self).__init__(weight, size_average) 145 | self.num_classes = num_classes 146 | 147 | def forward(self, input): 148 | prob = F.softmax(input, dim=1) 149 | prob_s = prob[:, :self.num_classes] 150 | prob_t = prob[:, self.num_classes:] 151 | prob_sum = prob_s + prob_t 152 | prob_sum = process_zero_values(prob_sum) 153 | loss = - prob_sum.log().mul(prob_sum).sum(1).mean() 154 | 155 | return loss 156 | 157 | class MinEntropyConsensusLoss(nn.Module): 158 | def __init__(self, num_classes): 159 | super(MinEntropyConsensusLoss, self).__init__() 160 | self.num_classes = num_classes 161 | 162 | def forward(self, x, y): 163 | i = torch.eye(self.num_classes).unsqueeze(0).cuda() 164 | x = F.log_softmax(x, dim=1) 165 | y = F.log_softmax(y, dim=1) 166 | x = x.unsqueeze(-1) 167 | y = y.unsqueeze(-1) 168 | 169 | ce_x = (- 1.0 * i * x).sum(1) 170 | ce_y = (- 1.0 * i * y).sum(1) 171 | 172 | ce = 0.5 * (ce_x + ce_y).min(1)[0].mean() 173 | 174 | return ce -------------------------------------------------------------------------------- /models/resnet_McDalNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from torch.autograd import Function 5 | from config.config import cfg 6 | import torch 7 | import ipdb 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152'] 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 
| 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | 22 | class ReverseLayerF(Function): 23 | 24 | @staticmethod 25 | def forward(ctx, x, alpha): 26 | ctx.alpha = alpha 27 | 28 | return x.view_as(x) 29 | 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | output = grad_output.neg() * ctx.alpha 33 | 34 | return output, None 35 | 36 | class ZeroLayerF(Function): 37 | 38 | @staticmethod 39 | def forward(ctx, x, alpha): 40 | ctx.alpha = 0.0 41 | 42 | return x.view_as(x) 43 | 44 | @staticmethod 45 | def backward(ctx, grad_output): 46 | output = grad_output * 0.0 47 | 48 | return output, None 49 | 50 | def conv3x3(in_planes, out_planes, stride=1): 51 | "3x3 convolution with padding" 52 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | 55 | 56 | class BasicBlock(nn.Module): 57 | expansion = 1 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(BasicBlock, self).__init__() 61 | self.conv1 = conv3x3(inplanes, planes, stride) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.conv2 = conv3x3(planes, planes) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.downsample = downsample 67 | self.stride = stride 68 | 69 | def forward(self, x): 70 | residual = x 71 | 72 | out = self.conv1(x) 73 | out = self.bn1(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv2(out) 77 | out = self.bn2(out) 78 | 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 81 | 82 | out += residual 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class Bottleneck(nn.Module): 89 | expansion = 4 90 | 91 | def __init__(self, inplanes, planes, stride=1, 
downsample=None): 92 | super(Bottleneck, self).__init__() 93 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 94 | self.bn1 = nn.BatchNorm2d(planes) 95 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 96 | padding=1, bias=False) 97 | self.bn2 = nn.BatchNorm2d(planes) 98 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 99 | self.bn3 = nn.BatchNorm2d(planes * 4) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | def forward(self, x): 105 | residual = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | residual = self.downsample(x) 120 | 121 | out += residual 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000): 130 | self.inplanes = 64 131 | super(ResNet, self).__init__() 132 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 133 | bias=False) 134 | self.bn1 = nn.BatchNorm2d(64) 135 | self.relu = nn.ReLU(inplace=True) 136 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0]) 138 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 139 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 140 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 141 | self.avgpool = nn.AvgPool2d(7) 142 | self.fc = nn.Linear(512 * block.expansion, num_classes) ## for classification 143 | self.fc_aux1 = nn.Linear(512 * block.expansion, num_classes) ## auxiliary classifier one 144 | self.fc_aux2 = nn.Linear(512 * block.expansion, num_classes) ## auxiliary classifier two 145 | self.fcdc = nn.Linear(512 * 
block.expansion, 1) ## domain classifier 146 | 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 150 | m.weight.data.normal_(0, math.sqrt(2. / n)) 151 | elif isinstance(m, nn.BatchNorm2d): 152 | m.weight.data.fill_(1) 153 | m.bias.data.zero_() 154 | 155 | def _make_layer(self, block, planes, blocks, stride=1): 156 | downsample = None 157 | if stride != 1 or self.inplanes != planes * block.expansion: 158 | downsample = nn.Sequential( 159 | nn.Conv2d(self.inplanes, planes * block.expansion, 160 | kernel_size=1, stride=stride, bias=False), 161 | nn.BatchNorm2d(planes * block.expansion), 162 | ) 163 | 164 | layers = [] 165 | layers.append(block(self.inplanes, planes, stride, downsample)) 166 | self.inplanes = planes * block.expansion 167 | for i in range(1, blocks): 168 | layers.append(block(self.inplanes, planes)) 169 | 170 | return nn.Sequential(*layers) 171 | 172 | def forward(self, x, alpha): 173 | x = self.conv1(x) 174 | x = self.bn1(x) 175 | x = self.relu(x) 176 | x = self.maxpool(x) 177 | 178 | x = self.layer1(x) 179 | x = self.layer2(x) 180 | x = self.layer3(x) 181 | x = self.layer4(x) 182 | 183 | x = self.avgpool(x) 184 | x = x.view(x.size(0), -1) 185 | out = self.fc(x) 186 | rev_x = ReverseLayerF.apply(x, alpha) ## gradient reverse 187 | out1 = self.fc_aux1(rev_x) 188 | out2 = self.fc_aux2(rev_x) 189 | outdc = self.fcdc(rev_x) 190 | 191 | trunc_x = ZeroLayerF.apply(x, 0.0) ## zero gradient 192 | out1_trunc = self.fc_aux1(trunc_x) 193 | out2_trunc = self.fc_aux2(trunc_x) 194 | 195 | return x, out, out1, out2, outdc, out1_trunc, out2_trunc 196 | 197 | 198 | def resnet18(): 199 | """Constructs a ResNet-18 model. 
200 | """ 201 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=cfg.DATASET.NUM_CLASSES) 202 | if cfg.MODEL.PRETRAINED: 203 | pretrained_dict = model_zoo.load_url(model_urls['resnet18']) 204 | model_dict = model.state_dict() 205 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 206 | pretrained_dict_temp.pop('fc.weight') 207 | pretrained_dict_temp.pop('fc.bias') 208 | model_dict.update(pretrained_dict_temp) 209 | model.load_state_dict(model_dict) 210 | return model 211 | 212 | 213 | def resnet34(): 214 | """Constructs a ResNet-34 model. 215 | """ 216 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=cfg.DATASET.NUM_CLASSES) 217 | if cfg.MODEL.PRETRAINED: 218 | pretrained_dict = model_zoo.load_url(model_urls['resnet34']) 219 | model_dict = model.state_dict() 220 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 221 | pretrained_dict_temp.pop('fc.weight') 222 | pretrained_dict_temp.pop('fc.bias') 223 | model_dict.update(pretrained_dict_temp) 224 | model.load_state_dict(model_dict) 225 | 226 | return model 227 | 228 | 229 | def resnet50(): 230 | """Constructs a ResNet-50 model. 231 | """ 232 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=cfg.DATASET.NUM_CLASSES) 233 | if cfg.MODEL.PRETRAINED: 234 | print('load the ImageNet pretrained parameters') 235 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 236 | model_dict = model.state_dict() 237 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 238 | pretrained_dict_temp.pop('fc.weight') 239 | pretrained_dict_temp.pop('fc.bias') 240 | model_dict.update(pretrained_dict_temp) 241 | model.load_state_dict(model_dict) 242 | 243 | return model 244 | 245 | 246 | def resnet101(): 247 | """Constructs a ResNet-101 model. 
248 | """ 249 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=cfg.DATASET.NUM_CLASSES) 250 | if cfg.MODEL.PRETRAINED: 251 | pretrained_dict = model_zoo.load_url(model_urls['resnet101']) 252 | model_dict = model.state_dict() 253 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 254 | pretrained_dict_temp.pop('fc.weight') 255 | pretrained_dict_temp.pop('fc.bias') 256 | model_dict.update(pretrained_dict_temp) 257 | model.load_state_dict(model_dict) 258 | 259 | return model 260 | 261 | 262 | def resnet152(): 263 | """Constructs a ResNet-152 model. 264 | """ 265 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=cfg.DATASET.NUM_CLASSES) 266 | if cfg.MODEL.PRETRAINED: 267 | pretrained_dict = model_zoo.load_url(model_urls['resnet152']) 268 | model_dict = model.state_dict() 269 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 270 | pretrained_dict_temp.pop('fc.weight') 271 | pretrained_dict_temp.pop('fc.bias') 272 | model_dict.update(pretrained_dict_temp) 273 | model.load_state_dict(model_dict) 274 | 275 | return model 276 | 277 | 278 | def resnet(): 279 | print("==> creating model '{}' ".format(cfg.MODEL.FEATURE_EXTRACTOR)) 280 | if cfg.MODEL.FEATURE_EXTRACTOR == 'resnet18': 281 | return resnet18() 282 | elif cfg.MODEL.FEATURE_EXTRACTOR == 'resnet34': 283 | return resnet34() 284 | elif cfg.MODEL.FEATURE_EXTRACTOR == 'resnet50': 285 | return resnet50() 286 | elif cfg.MODEL.FEATURE_EXTRACTOR == 'resnet101': 287 | return resnet101() 288 | elif cfg.MODEL.FEATURE_EXTRACTOR == 'resnet152': 289 | return resnet152() 290 | else: 291 | raise ValueError('Unrecognized model architecture', cfg.MODEL.FEATURE_EXTRACTOR) -------------------------------------------------------------------------------- /models/resnet_SymmNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from 
torch.autograd import Function 5 | from config.config import cfg 6 | import torch 7 | import ipdb 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152'] 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | 22 | class ReverseLayerF(Function): 23 | 24 | @staticmethod 25 | def forward(ctx, x, alpha): 26 | ctx.alpha = alpha 27 | 28 | return x.view_as(x) 29 | 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | output = grad_output.neg() * ctx.alpha 33 | 34 | return output, None 35 | 36 | class ZeroLayerF(Function): 37 | 38 | @staticmethod 39 | def forward(ctx, x, alpha): 40 | ctx.alpha = 0.0 41 | 42 | return x.view_as(x) 43 | 44 | @staticmethod 45 | def backward(ctx, grad_output): 46 | output = grad_output * 0.0 47 | 48 | return output, None 49 | 50 | def conv3x3(in_planes, out_planes, stride=1): 51 | "3x3 convolution with padding" 52 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | 55 | 56 | class BasicBlock(nn.Module): 57 | expansion = 1 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(BasicBlock, self).__init__() 61 | self.conv1 = conv3x3(inplanes, planes, stride) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.conv2 = conv3x3(planes, planes) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.downsample = downsample 67 | self.stride = stride 68 | 69 | def forward(self, x): 70 | residual = x 71 | 72 | out = self.conv1(x) 73 | out = self.bn1(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv2(out) 77 | 
out = self.bn2(out) 78 | 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 81 | 82 | out += residual 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class Bottleneck(nn.Module): 89 | expansion = 4 90 | 91 | def __init__(self, inplanes, planes, stride=1, downsample=None): 92 | super(Bottleneck, self).__init__() 93 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 94 | self.bn1 = nn.BatchNorm2d(planes) 95 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 96 | padding=1, bias=False) 97 | self.bn2 = nn.BatchNorm2d(planes) 98 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 99 | self.bn3 = nn.BatchNorm2d(planes * 4) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | def forward(self, x): 105 | residual = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | residual = self.downsample(x) 120 | 121 | out += residual 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000): 130 | self.inplanes = 64 131 | super(ResNet, self).__init__() 132 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 133 | bias=False) 134 | self.bn1 = nn.BatchNorm2d(64) 135 | self.relu = nn.ReLU(inplace=True) 136 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0]) 138 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 139 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 140 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 141 | self.avgpool = nn.AvgPool2d(7) 142 | # self.fc = 
nn.Linear(512 * block.expansion, num_classes*2) ## for classification 143 | 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 147 | m.weight.data.normal_(0, math.sqrt(2. / n)) 148 | elif isinstance(m, nn.BatchNorm2d): 149 | m.weight.data.fill_(1) 150 | m.bias.data.zero_() 151 | 152 | def _make_layer(self, block, planes, blocks, stride=1): 153 | downsample = None 154 | if stride != 1 or self.inplanes != planes * block.expansion: 155 | downsample = nn.Sequential( 156 | nn.Conv2d(self.inplanes, planes * block.expansion, 157 | kernel_size=1, stride=stride, bias=False), 158 | nn.BatchNorm2d(planes * block.expansion), 159 | ) 160 | 161 | layers = [] 162 | layers.append(block(self.inplanes, planes, stride, downsample)) 163 | self.inplanes = planes * block.expansion 164 | for i in range(1, blocks): 165 | layers.append(block(self.inplanes, planes)) 166 | 167 | return nn.Sequential(*layers) 168 | 169 | def forward(self, x): 170 | x = self.conv1(x) 171 | x = self.bn1(x) 172 | x = self.relu(x) 173 | x = self.maxpool(x) 174 | 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | 180 | x = self.avgpool(x) 181 | x = x.view(x.size(0), -1) 182 | # out = self.fc(x) 183 | 184 | return x 185 | 186 | 187 | def resnet18(): 188 | """Constructs a ResNet-18 model. 189 | """ 190 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=cfg.DATASET.NUM_CLASSES) 191 | if cfg.MODEL.PRETRAINED: 192 | pretrained_dict = model_zoo.load_url(model_urls['resnet18']) 193 | model_dict = model.state_dict() 194 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 195 | model_dict.update(pretrained_dict_temp) 196 | model.load_state_dict(model_dict) 197 | classifier = nn.Linear(512 * BasicBlock.expansion, cfg.DATASET.NUM_CLASSES*2) 198 | return model, classifier 199 | 200 | 201 | def resnet34(): 202 | """Constructs a ResNet-34 model. 
203 | """ 204 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=cfg.DATASET.NUM_CLASSES) 205 | if cfg.MODEL.PRETRAINED: 206 | pretrained_dict = model_zoo.load_url(model_urls['resnet34']) 207 | model_dict = model.state_dict() 208 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 209 | model_dict.update(pretrained_dict_temp) 210 | model.load_state_dict(model_dict) 211 | classifier = nn.Linear(512 * BasicBlock.expansion, cfg.DATASET.NUM_CLASSES * 2) 212 | return model, classifier 213 | 214 | 215 | def resnet50(): 216 | """Constructs a ResNet-50 model. 217 | """ 218 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=cfg.DATASET.NUM_CLASSES) 219 | if cfg.MODEL.PRETRAINED: 220 | print('load the ImageNet pretrained parameters') 221 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 222 | model_dict = model.state_dict() 223 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 224 | model_dict.update(pretrained_dict_temp) 225 | model.load_state_dict(model_dict) 226 | classifier = nn.Linear(512 * Bottleneck.expansion, cfg.DATASET.NUM_CLASSES * 2) ## the concatenation of two task classifiers 227 | return model, classifier 228 | 229 | 230 | def resnet101(): 231 | """Constructs a ResNet-101 model. 
232 | """ 233 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=cfg.DATASET.NUM_CLASSES) 234 | if cfg.MODEL.PRETRAINED: 235 | pretrained_dict = model_zoo.load_url(model_urls['resnet101']) 236 | model_dict = model.state_dict() 237 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 238 | model_dict.update(pretrained_dict_temp) 239 | model.load_state_dict(model_dict) 240 | # classifier = nn.Linear(512 * Bottleneck.expansion, cfg.DATASET.NUM_CLASSES * 2) 241 | classifier = nn.Sequential(nn.Linear(512 * Bottleneck.expansion, 512), 242 | nn.BatchNorm1d(512), 243 | nn.ReLU(inplace=True), 244 | nn.Linear(512, cfg.DATASET.NUM_CLASSES * 2)) 245 | return model, classifier 246 | 247 | 248 | def resnet152(): 249 | """Constructs a ResNet-152 model. 250 | """ 251 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=cfg.DATASET.NUM_CLASSES) 252 | if cfg.MODEL.PRETRAINED: 253 | pretrained_dict = model_zoo.load_url(model_urls['resnet152']) 254 | model_dict = model.state_dict() 255 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 256 | model_dict.update(pretrained_dict_temp) 257 | model.load_state_dict(model_dict) 258 | classifier = nn.Linear(512 * Bottleneck.expansion, cfg.DATASET.NUM_CLASSES * 2) 259 | return model, classifier 260 | 261 | 262 | def resnet(): 263 | print("==> creating model '{}' ".format(cfg.MODEL.FEATURE_EXTRACTOR)) 264 | if cfg.MODEL.FEATURE_EXTRACTOR == 'resnet18': 265 | return resnet18() 266 | elif cfg.MODEL.FEATURE_EXTRACTOR == 'resnet34': 267 | return resnet34() 268 | elif cfg.MODEL.FEATURE_EXTRACTOR == 'resnet50': 269 | return resnet50() 270 | elif cfg.MODEL.FEATURE_EXTRACTOR == 'resnet101': 271 | return resnet101() 272 | elif cfg.MODEL.FEATURE_EXTRACTOR == 'resnet152': 273 | return resnet152() 274 | else: 275 | raise ValueError('Unrecognized model architecture', cfg.MODEL.FEATURE_EXTRACTOR) -------------------------------------------------------------------------------- 
/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python ./tools/train.py --distance_type None --method SymmNetsV2 --task closedsc --cfg ./experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg_res101SC.yaml --exp_name logCloseSC 4 | 5 | python ./tools/train.py --distance_type None --method SymmNetsV2 --task closedsc --cfg ./experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg_res101SC.yaml --exp_name logCloseSC 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /run_temp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## The example script. all the log file are present in the ./experiments file for verifying 3 | 4 | ### ##################################### Closed Set DA example ######################################### 5 | 6 | ## McDalNet script 7 | #python ./tools/train.py --distance_type L1 --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml 8 | # 9 | #python ./tools/train.py --distance_type KL --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml 10 | # 11 | #python ./tools/train.py --distance_type CE --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml 12 | # 13 | #python ./tools/train.py --distance_type MDD --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml 14 | # 15 | #python ./tools/train.py --distance_type DANN --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml 16 | 17 | #python ./tools/train.py --distance_type L1 --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml 18 | # 19 | #python ./tools/train.py --distance_type KL --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml 20 | # 21 | #python 
./tools/train.py --distance_type CE --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml 22 | # 23 | #python ./tools/train.py --distance_type DANN --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml 24 | # 25 | #python ./tools/train.py --distance_type MDD --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml 26 | 27 | ## SymmNets script 28 | #python ./tools/train.py --distance_type None --method SymmNetsV1 --cfg ./experiments/configs/ImageCLEF/SymmNets/clef_train_c2i_cfg.yaml 29 | 30 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg_res101.yaml --exp_name logres101 31 | 32 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/ImageCLEF/SymmNets/clef_train_c2i_cfg.yaml 33 | 34 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg.yaml 35 | 36 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/Office31/SymmNets/office31_train_amazon2dslr_cfg.yaml 37 | 38 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/Office31/SymmNets/office31_train_webcam2amazon_cfg.yaml 39 | 40 | 41 | ## SymmNets-SC script 42 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python ./tools/train.py --distance_type None --task closedsc --method SymmNetsV2 --cfg ./experiments/configs/Office31/SymmNets/office31_train_webcam2amazon_cfg_SC.yaml 43 | 44 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python ./tools/train.py --distance_type None --task closedsc --method SymmNetsV2 --cfg ./experiments/configs/Office31/SymmNets/office31_train_amazon2dslr_cfg_SC.yaml 45 | 46 | 47 | 48 | 49 | ##################################### Partial DA example 
####################### 50 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --task partial --cfg ./experiments/configs/OfficeHome/SymmNets/home_train_A2R_partial_cfg.yaml 51 | 52 | 53 | 54 | 55 | ###################################### Open set DA example ##################### 56 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --task open --cfg ./experiments/configs/Office31/SymmNets/office31_train_amazon2webcam_open_cfg.yaml 57 | # 58 | -------------------------------------------------------------------------------- /solver/McDalNet_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import math 5 | import time 6 | from utils.utils import to_cuda, accuracy_for_each_class, accuracy, AverageMeter, process_one_values 7 | from config.config import cfg 8 | import torch.nn.functional as F 9 | from models.loss_utils import McDalNetLoss 10 | from .base_solver import BaseSolver 11 | import ipdb 12 | 13 | class McDalNetSolver(BaseSolver): 14 | def __init__(self, net, dataloaders, **kwargs): 15 | super(McDalNetSolver, self).__init__(net, dataloaders, **kwargs) 16 | self.BCELoss = nn.BCEWithLogitsLoss().cuda() 17 | self.McDalNetLoss = McDalNetLoss().cuda() 18 | if cfg.RESUME != '': 19 | resume_dict = torch.load(cfg.RESUME) 20 | model_state_dict = resume_dict['model_state_dict'] 21 | self.net.load_state_dict(model_state_dict) 22 | self.best_prec1 = resume_dict['best_prec1'] 23 | self.epoch = resume_dict['epoch'] 24 | 25 | def solve(self): 26 | stop = False 27 | while not stop: 28 | stop = self.complete_training() 29 | self.update_network() 30 | acc = self.test() 31 | if acc > self.best_prec1: 32 | self.best_prec1 = acc 33 | self.save_ckpt() 34 | self.epoch += 1 35 | 36 | 37 | def update_network(self, **kwargs): 38 | stop = False 39 | self.train_data['source']['iterator'] = iter(self.train_data['source']['loader']) 40 | 
self.train_data['target']['iterator'] = iter(self.train_data['target']['loader']) 41 | self.iters_per_epoch = len(self.train_data['target']['loader']) 42 | iters_counter_within_epoch = 0 43 | data_time = AverageMeter() 44 | batch_time = AverageMeter() 45 | total_loss = AverageMeter() 46 | ce_loss = AverageMeter() 47 | da_loss = AverageMeter() 48 | prec1_task = AverageMeter() 49 | prec1_aux1 = AverageMeter() 50 | prec1_aux2 = AverageMeter() 51 | self.net.train() 52 | end = time.time() 53 | if self.opt.TRAIN.PROCESS_COUNTER == 'epoch': 54 | lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1 55 | self.update_lr() 56 | print('value of lam is: %3f' % (lam)) 57 | while not stop: 58 | if self.opt.TRAIN.PROCESS_COUNTER == 'iteration': 59 | lam = 2 / (1 + math.exp(-1 * 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1 60 | print('value of lam is: %3f' % (lam)) 61 | self.update_lr() 62 | source_data, source_gt = self.get_samples('source') 63 | target_data, _ = self.get_samples('target') 64 | source_data = to_cuda(source_data) 65 | source_gt = to_cuda(source_gt) 66 | target_data = to_cuda(target_data) 67 | data_time.update(time.time() - end) 68 | 69 | feature_source, output_source, output_source1, output_source2, output_source_dc, output_source1_trunc, output_source2_trunc = self.net(source_data, lam) 70 | loss_task_auxiliary_1 = self.CELoss(output_source1_trunc, source_gt) 71 | loss_task_auxiliary_2 = self.CELoss(output_source2_trunc, source_gt) 72 | loss_task = self.CELoss(output_source, source_gt) 73 | if self.opt.MCDALNET.DISTANCE_TYPE != 'SourceOnly': 74 | feature_target, output_target, output_target1, output_target2, output_target_dc, output_target1_trunc, output_target2_trunc = self.net(target_data, lam) 75 | if self.opt.MCDALNET.DISTANCE_TYPE == 'DANN': 76 | num_source = source_data.size()[0] 77 | num_target = target_data.size()[0] 78 | dlabel_source = to_cuda(torch.zeros(num_source, 1)) 79 | dlabel_target = 
to_cuda(torch.ones(num_target, 1)) 80 | loss_domain_all = self.BCELoss(output_source_dc, dlabel_source) + self.BCELoss(output_target_dc, dlabel_target) 81 | loss_all = loss_task + loss_domain_all 82 | elif self.opt.MCDALNET.DISTANCE_TYPE == 'MDD': 83 | prob_target1 = F.softmax(output_target1, dim=1) 84 | _, target_pseudo_label = torch.topk(output_target2, 1) 85 | batch_index = torch.arange(output_target.size()[0]).long() 86 | pred_gt_prob = prob_target1[batch_index, target_pseudo_label] ## the prob values of the predicted gt 87 | pred_gt_prob = process_one_values(pred_gt_prob) 88 | loss_domain_target = (1 - pred_gt_prob).log().mean() 89 | 90 | _, source_pseudo_label = torch.topk(output_source2, 1) 91 | loss_domain_source = self.CELoss(output_source1, source_pseudo_label[:, 0]) 92 | loss_domain_all = loss_domain_source - loss_domain_target 93 | loss_all = loss_task + loss_domain_all + loss_task_auxiliary_1 + loss_task_auxiliary_2 94 | else: 95 | loss_domain_source = self.McDalNetLoss(output_source1, output_source2, self.opt.MCDALNET.DISTANCE_TYPE) 96 | loss_domain_target = self.McDalNetLoss(output_target1, output_target2, self.opt.MCDALNET.DISTANCE_TYPE) 97 | loss_domain_all = loss_domain_source - loss_domain_target 98 | loss_all = loss_task + loss_domain_all + loss_task_auxiliary_1 + loss_task_auxiliary_2 99 | da_loss.update(loss_domain_all, source_data.size()[0]) 100 | else: 101 | loss_all = loss_task 102 | ce_loss.update(loss_task, source_data.size()[0]) 103 | total_loss.update(loss_all, source_data.size()[0]) 104 | prec1_task.update(accuracy(output_source, source_gt), source_data.size()[0]) 105 | prec1_aux1.update(accuracy(output_source1, source_gt), source_data.size()[0]) 106 | prec1_aux2.update(accuracy(output_source2, source_gt), source_data.size()[0]) 107 | 108 | self.optimizer.zero_grad() 109 | loss_all.backward() 110 | self.optimizer.step() 111 | 112 | print(" Train:epoch: %d:[%d/%d], LossCE: %3f, LossDA: %3f, LossAll: %3f, Auxi1: %3f, Auxi2: %3f, Task: 
%3f" % \ 113 | (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, ce_loss.avg, da_loss.avg, total_loss.avg, prec1_aux1.avg, prec1_aux2.avg, prec1_task.avg)) 114 | 115 | batch_time.update(time.time() - end) 116 | end = time.time() 117 | self.iters += 1 118 | iters_counter_within_epoch += 1 119 | if iters_counter_within_epoch >= self.iters_per_epoch: 120 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 121 | log.write("\n") 122 | log.write(" Train:epoch: %d:[%d/%d], LossCE: %3f, LossDA: %3f, LossAll: %3f, Auxi1: %3f, Auxi2: %3f, Task: %3f" % \ 123 | (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, ce_loss.avg, da_loss.avg, total_loss.avg, prec1_aux1.avg, prec1_aux2.avg, prec1_task.avg)) 124 | log.close() 125 | stop = True 126 | 127 | 128 | def test(self): 129 | self.net.eval() 130 | prec1_task = AverageMeter() 131 | prec1_auxi1 = AverageMeter() 132 | prec1_auxi2 = AverageMeter() 133 | counter_all = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 134 | counter_all_auxi1 = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 135 | counter_all_auxi2 = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 136 | counter_acc = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 137 | counter_acc_auxi1 = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 138 | counter_acc_auxi2 = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 139 | 140 | for i, (input, target) in enumerate(self.test_data['loader']): 141 | input, target = to_cuda(input), to_cuda(target) 142 | with torch.no_grad(): 143 | _, output_test, output_test1, output_test2, _, _, _ = self.net(input, 1) ## the value of lam do not affect the test process 144 | 145 | if self.opt.EVAL_METRIC == 'accu': 146 | prec1_task_iter = accuracy(output_test, target) 147 | prec1_auxi1_iter = accuracy(output_test1, target) 148 | prec1_auxi2_iter = accuracy(output_test2, target) 149 | prec1_task.update(prec1_task_iter, input.size(0)) 150 | 
prec1_auxi1.update(prec1_auxi1_iter, input.size(0)) 151 | prec1_auxi2.update(prec1_auxi2_iter, input.size(0)) 152 | if i % self.opt.PRINT_STEP == 0: 153 | print(" Test:epoch: %d:[%d/%d], Auxi1: %3f, Auxi2: %3f, Task: %3f" % \ 154 | (self.epoch, i, len(self.test_data['loader']), prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg)) 155 | elif self.opt.EVAL_METRIC == 'accu_mean': 156 | prec1_task_iter = accuracy(output_test, target) 157 | prec1_task.update(prec1_task_iter, input.size(0)) 158 | counter_all, counter_acc = accuracy_for_each_class(output_test, target, counter_all, counter_acc) 159 | counter_all_auxi1, counter_acc_auxi1 = accuracy_for_each_class(output_test1, target, counter_all_auxi1, counter_acc_auxi1) 160 | counter_all_auxi2, counter_acc_auxi2 = accuracy_for_each_class(output_test2, target, counter_all_auxi2, counter_acc_auxi2) 161 | if i % self.opt.PRINT_STEP == 0: 162 | print(" Test:epoch: %d:[%d/%d], Task: %3f" % \ 163 | (self.epoch, i, len(self.test_data['loader']), prec1_task.avg)) 164 | else: 165 | raise NotImplementedError 166 | acc_for_each_class = counter_acc / counter_all 167 | acc_for_each_class_auxi1 = counter_acc_auxi1 / counter_all_auxi1 168 | acc_for_each_class_auxi2 = counter_acc_auxi2 / counter_all_auxi2 169 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 170 | log.write("\n") 171 | if self.opt.EVAL_METRIC == 'accu': 172 | log.write( 173 | " Test:epoch: %d, Top1_auxi1: %3f, Top1_auxi2: %3f, Top1: %3f" % \ 174 | (self.epoch, prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg)) 175 | log.close() 176 | return max(prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg) 177 | elif self.opt.EVAL_METRIC == 'accu_mean': 178 | log.write( 179 | " Test:epoch: %d, Top1_auxi1: %3f, Top1_auxi2: %3f, Top1: %3f" % \ 180 | (self.epoch, acc_for_each_class_auxi1.mean(), acc_for_each_class_auxi2.mean(), acc_for_each_class.mean())) 181 | log.write("\nClass-wise Acc:") ## based on the task classifier. 
182 | for i in range(self.opt.DATASET.NUM_CLASSES): 183 | if i == 0: 184 | log.write("%dst: %3f" % (i + 1, acc_for_each_class[i])) 185 | elif i == 1: 186 | log.write(", %dnd: %3f" % (i + 1, acc_for_each_class[i])) 187 | elif i == 2: 188 | log.write(", %drd: %3f" % (i + 1, acc_for_each_class[i])) 189 | else: 190 | log.write(", %dth: %3f" % (i + 1, acc_for_each_class[i])) 191 | log.close() 192 | return max(acc_for_each_class_auxi1.mean(), acc_for_each_class_auxi2.mean(), acc_for_each_class.mean()) 193 | 194 | def build_optimizer(self): 195 | if self.opt.TRAIN.OPTIMIZER == 'SGD': ## some params may not contribute the loss_all, thus they are not updated in the training process. 196 | self.optimizer = torch.optim.SGD([ 197 | {'params': self.net.module.conv1.parameters(), 'name': 'pre-trained'}, 198 | {'params': self.net.module.bn1.parameters(), 'name': 'pre-trained'}, 199 | {'params': self.net.module.layer1.parameters(), 'name': 'pre-trained'}, 200 | {'params': self.net.module.layer2.parameters(), 'name': 'pre-trained'}, 201 | {'params': self.net.module.layer3.parameters(), 'name': 'pre-trained'}, 202 | {'params': self.net.module.layer4.parameters(), 'name': 'pre-trained'}, 203 | {'params': self.net.module.fc.parameters(), 'name': 'new-added'}, 204 | {'params': self.net.module.fc_aux1.parameters(), 'name': 'new-added'}, 205 | {'params': self.net.module.fc_aux2.parameters(), 'name': 'new-added'}, 206 | {'params': self.net.module.fcdc.parameters(), 'name': 'new-added'} 207 | ], 208 | lr=self.opt.TRAIN.BASE_LR, 209 | momentum=self.opt.TRAIN.MOMENTUM, 210 | weight_decay=self.opt.TRAIN.WEIGHT_DECAY, 211 | nesterov=True) 212 | else: 213 | raise NotImplementedError 214 | print('Optimizer built') 215 | 216 | def update_lr(self): 217 | if self.opt.TRAIN.LR_SCHEDULE == 'inv': 218 | if self.opt.TRAIN.PROCESS_COUNTER == 'epoch': 219 | lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.epoch / self.opt.TRAIN.MAX_EPOCH), self.opt.INV.BETA) 220 | elif 
self.opt.TRAIN.PROCESS_COUNTER == 'iteration': 221 | lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch)), self.opt.INV.BETA) 222 | else: 223 | raise NotImplementedError 224 | elif self.opt.TRAIN.LR_SCHEDULE == 'fix': 225 | lr = self.opt.TRAIN.BASE_LR 226 | else: 227 | raise NotImplementedError 228 | lr_pretrain = lr * 0.1 229 | print('the lr is: %3f' % (lr)) 230 | for param_group in self.optimizer.param_groups: 231 | if param_group['name'] == 'pre-trained': 232 | param_group['lr'] = lr_pretrain 233 | elif param_group['name'] == 'new-added': 234 | param_group['lr'] = lr 235 | elif param_group['name'] == 'fixed': ## Fix the lr as 0 can not fix the runing mean/var of the BN layer 236 | param_group['lr'] = 0 237 | 238 | def save_ckpt(self): 239 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 240 | log.write(" Best Acc so far: %3f" % (self.best_prec1)) 241 | log.close() 242 | if self.opt.TRAIN.SAVING: 243 | save_path = self.opt.SAVE_DIR 244 | ckpt_resume = os.path.join(save_path, 'ckpt_%d.resume' % (self.loop)) 245 | torch.save({'epoch': self.epoch, 246 | 'best_prec1': self.best_prec1, 247 | 'model_state_dict': self.net.state_dict() 248 | }, ckpt_resume) -------------------------------------------------------------------------------- /solver/SymmNetsV1_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import math 5 | import time 6 | from utils.utils import to_cuda, accuracy_for_each_class, accuracy, AverageMeter, process_one_values 7 | from config.config import cfg 8 | import torch.nn.functional as F 9 | from models.loss_utils import TargetDiscrimLoss, SourceDiscrimLoss, ConcatenatedEMLoss 10 | from .base_solver import BaseSolver 11 | import ipdb 12 | 13 | class SymmNetsV1Solver(BaseSolver): 14 | def __init__(self, net, dataloaders, **kwargs): 15 | super(SymmNetsV1Solver, 
self).__init__(net, dataloaders, **kwargs) 16 | self.num_classes = cfg.DATASET.NUM_CLASSES 17 | self.TargetDiscrimLoss = TargetDiscrimLoss(num_classes=self.num_classes).cuda() 18 | self.SourceDiscrimLoss = SourceDiscrimLoss(num_classes=self.num_classes).cuda() 19 | self.ConcatenatedEMLoss = ConcatenatedEMLoss(num_classes=self.num_classes).cuda() 20 | self.feature_extractor = self.net['feature_extractor'] 21 | self.classifier = self.net['classifier'] 22 | 23 | if cfg.RESUME != '': 24 | resume_dict = torch.load(cfg.RESUME) 25 | self.net['feature_extractor'].load_state_dict(resume_dict['feature_extractor_state_dict']) 26 | self.net['classifier'].load_state_dict(resume_dict['classifier_state_dict']) 27 | self.best_prec1 = resume_dict['best_prec1'] 28 | self.epoch = resume_dict['epoch'] 29 | 30 | def solve(self): 31 | stop = False 32 | while not stop: 33 | stop = self.complete_training() 34 | self.update_network() 35 | acc = self.test() 36 | if acc > self.best_prec1: 37 | self.best_prec1 = acc 38 | self.save_ckpt() 39 | self.epoch += 1 40 | 41 | 42 | def update_network(self, **kwargs): 43 | stop = False 44 | self.train_data['source']['iterator'] = iter(self.train_data['source']['loader']) 45 | self.train_data['target']['iterator'] = iter(self.train_data['target']['loader']) 46 | self.iters_per_epoch = len(self.train_data['target']['loader']) 47 | iters_counter_within_epoch = 0 48 | data_time = AverageMeter() 49 | batch_time = AverageMeter() 50 | classifier_loss = AverageMeter() 51 | feature_extractor_loss = AverageMeter() 52 | prec1_fs = AverageMeter() 53 | prec1_ft = AverageMeter() 54 | self.feature_extractor.train() 55 | self.classifier.train() 56 | end = time.time() 57 | if self.opt.TRAIN.PROCESS_COUNTER == 'epoch': 58 | lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1 59 | self.update_lr() 60 | print('value of lam is: %3f' % (lam)) 61 | while not stop: 62 | if self.opt.TRAIN.PROCESS_COUNTER == 'iteration': 63 | lam = 2 / (1 + math.exp(-1 
* 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1 64 | print('value of lam is: %3f' % (lam)) 65 | self.update_lr() 66 | source_data, source_gt = self.get_samples('source') 67 | target_data, _ = self.get_samples('target') 68 | source_data = to_cuda(source_data) 69 | source_gt = to_cuda(source_gt) 70 | target_data = to_cuda(target_data) 71 | data_time.update(time.time() - end) 72 | 73 | feature_source = self.feature_extractor(source_data) 74 | output_source = self.classifier(feature_source) 75 | feature_target = self.feature_extractor(target_data) 76 | output_target = self.classifier(feature_target) 77 | 78 | loss_task_fs = self.CELoss(output_source[:,:self.num_classes], source_gt) 79 | loss_task_ft = self.CELoss(output_source[:,self.num_classes:], source_gt) 80 | loss_discrim_source = self.SourceDiscrimLoss(output_source) 81 | loss_discrim_target = self.TargetDiscrimLoss(output_target) 82 | loss_summary_classifier = loss_task_fs + loss_task_ft + loss_discrim_source + loss_discrim_target 83 | 84 | source_gt_for_ft_in_fst = source_gt + self.num_classes 85 | loss_confusion_source = 0.5 * self.CELoss(output_source, source_gt) + 0.5 * self.CELoss(output_source, source_gt_for_ft_in_fst) 86 | loss_confusion_target = 0.5 * self.SourceDiscrimLoss(output_target) + 0.5 * self.TargetDiscrimLoss(output_target) 87 | loss_em = self.ConcatenatedEMLoss(output_target) 88 | loss_summary_feature_extractor = loss_confusion_source + lam * (loss_confusion_target + loss_em) 89 | 90 | self.optimizer_classifier.zero_grad() 91 | loss_summary_classifier.backward(retain_graph=True) 92 | self.optimizer_classifier.step() 93 | 94 | self.optimizer_feature_extractor.zero_grad() 95 | loss_summary_feature_extractor.backward() 96 | self.optimizer_feature_extractor.step() 97 | 98 | classifier_loss.update(loss_summary_classifier, source_data.size()[0]) 99 | feature_extractor_loss.update(loss_summary_feature_extractor, source_data.size()[0]) 100 | 
prec1_fs.update(accuracy(output_source[:, :self.num_classes], source_gt), source_data.size()[0]) 101 | prec1_ft.update(accuracy(output_source[:, self.num_classes:], source_gt), source_data.size()[0]) 102 | 103 | print(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \ 104 | (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg)) 105 | 106 | batch_time.update(time.time() - end) 107 | end = time.time() 108 | self.iters += 1 109 | iters_counter_within_epoch += 1 110 | if iters_counter_within_epoch >= self.iters_per_epoch: 111 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 112 | log.write("\n") 113 | log.write(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \ 114 | (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg)) 115 | log.close() 116 | stop = True 117 | 118 | 119 | def test(self): 120 | self.feature_extractor.eval() 121 | self.classifier.eval() 122 | prec1_fs = AverageMeter() 123 | prec1_ft = AverageMeter() 124 | counter_all_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 125 | counter_all_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 126 | counter_acc_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 127 | counter_acc_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 128 | 129 | for i, (input, target) in enumerate(self.test_data['loader']): 130 | input, target = to_cuda(input), to_cuda(target) 131 | with torch.no_grad(): 132 | feature_test = self.feature_extractor(input) 133 | output_test = self.classifier(feature_test) 134 | 135 | 136 | if self.opt.EVAL_METRIC == 'accu': 137 | prec1_fs_iter = accuracy(output_test[:, :self.num_classes], target) 138 | prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target) 139 | prec1_fs.update(prec1_fs_iter, 
input.size(0)) 140 | prec1_ft.update(prec1_ft_iter, input.size(0)) 141 | if i % self.opt.PRINT_STEP == 0: 142 | print(" Test:epoch: %d:[%d/%d], AccFs: %3f, AccFt: %3f" % \ 143 | (self.epoch, i, len(self.test_data['loader']), prec1_fs.avg, prec1_ft.avg)) 144 | elif self.opt.EVAL_METRIC == 'accu_mean': 145 | prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target) 146 | prec1_ft.update(prec1_ft_iter, input.size(0)) 147 | counter_all_fs, counter_acc_fs = accuracy_for_each_class(output_test[:, :self.num_classes], target, counter_all_fs, counter_acc_fs) 148 | counter_all_ft, counter_acc_ft = accuracy_for_each_class(output_test[:, self.num_classes:], target, counter_all_ft, counter_acc_ft) 149 | if i % self.opt.PRINT_STEP == 0: 150 | print(" Test:epoch: %d:[%d/%d], Task: %3f" % \ 151 | (self.epoch, i, len(self.test_data['loader']), prec1_ft.avg)) 152 | else: 153 | raise NotImplementedError 154 | acc_for_each_class_fs = counter_acc_fs / counter_all_fs 155 | acc_for_each_class_ft = counter_acc_ft / counter_all_ft 156 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 157 | log.write("\n") 158 | if self.opt.EVAL_METRIC == 'accu': 159 | log.write( 160 | " Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \ 161 | (self.epoch, prec1_fs.avg, prec1_ft.avg)) 162 | log.close() 163 | return max(prec1_fs.avg, prec1_ft.avg) 164 | elif self.opt.EVAL_METRIC == 'accu_mean': 165 | log.write( 166 | " Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \ 167 | (self.epoch,acc_for_each_class_fs.mean(), acc_for_each_class_ft.mean())) 168 | log.write("\nClass-wise Acc of Ft:") ## based on the task classifier. 
169 | for i in range(self.opt.DATASET.NUM_CLASSES): 170 | if i == 0: 171 | log.write("%dst: %3f" % (i + 1, acc_for_each_class_ft[i])) 172 | elif i == 1: 173 | log.write(", %dnd: %3f" % (i + 1, acc_for_each_class_ft[i])) 174 | elif i == 2: 175 | log.write(", %drd: %3f" % (i + 1, acc_for_each_class_ft[i])) 176 | else: 177 | log.write(", %dth: %3f" % (i + 1, acc_for_each_class_ft[i])) 178 | log.close() 179 | return max(acc_for_each_class_ft.mean(), acc_for_each_class_fs.mean()) 180 | 181 | def build_optimizer(self): 182 | if self.opt.TRAIN.OPTIMIZER == 'SGD': ## some params may not contribute the loss_all, thus they are not updated in the training process. 183 | self.optimizer_feature_extractor = torch.optim.SGD([ 184 | {'params': self.net['feature_extractor'].module.conv1.parameters(), 'name': 'pre-trained'}, 185 | {'params': self.net['feature_extractor'].module.bn1.parameters(), 'name': 'pre-trained'}, 186 | {'params': self.net['feature_extractor'].module.layer1.parameters(), 'name': 'pre-trained'}, 187 | {'params': self.net['feature_extractor'].module.layer2.parameters(), 'name': 'pre-trained'}, 188 | {'params': self.net['feature_extractor'].module.layer3.parameters(), 'name': 'pre-trained'}, 189 | {'params': self.net['feature_extractor'].module.layer4.parameters(), 'name': 'pre-trained'}, 190 | ], 191 | lr=self.opt.TRAIN.BASE_LR, 192 | momentum=self.opt.TRAIN.MOMENTUM, 193 | weight_decay=self.opt.TRAIN.WEIGHT_DECAY, 194 | nesterov=True) 195 | 196 | self.optimizer_classifier = torch.optim.SGD([ 197 | {'params': self.net['classifier'].parameters(), 'name': 'new-added'}, 198 | ], 199 | lr=self.opt.TRAIN.BASE_LR, 200 | momentum=self.opt.TRAIN.MOMENTUM, 201 | weight_decay=self.opt.TRAIN.WEIGHT_DECAY, 202 | nesterov=True) 203 | else: 204 | raise NotImplementedError 205 | print('Optimizer built') 206 | 207 | def update_lr(self): 208 | if self.opt.TRAIN.LR_SCHEDULE == 'inv': 209 | if self.opt.TRAIN.PROCESS_COUNTER == 'epoch': 210 | lr = self.opt.TRAIN.BASE_LR / pow((1 + 
self.opt.INV.ALPHA * self.epoch / self.opt.TRAIN.MAX_EPOCH), self.opt.INV.BETA) 211 | elif self.opt.TRAIN.PROCESS_COUNTER == 'iteration': 212 | lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch)), self.opt.INV.BETA) 213 | else: 214 | raise NotImplementedError 215 | elif self.opt.TRAIN.LR_SCHEDULE == 'fix': 216 | lr = self.opt.TRAIN.BASE_LR 217 | else: 218 | raise NotImplementedError 219 | lr_pretrain = lr * 0.1 220 | print('the lr is: %3f' % (lr)) 221 | for param_group in self.optimizer_feature_extractor.param_groups: 222 | if param_group['name'] == 'pre-trained': 223 | param_group['lr'] = lr_pretrain 224 | elif param_group['name'] == 'new-added': 225 | param_group['lr'] = lr 226 | elif param_group['name'] == 'fixed': ## Fix the lr as 0 can not fix the runing mean/var of the BN layer 227 | param_group['lr'] = 0 228 | 229 | for param_group in self.optimizer_classifier.param_groups: 230 | if param_group['name'] == 'pre-trained': 231 | param_group['lr'] = lr_pretrain 232 | elif param_group['name'] == 'new-added': 233 | param_group['lr'] = lr 234 | elif param_group['name'] == 'fixed': ## Fix the lr as 0 can not fix the runing mean/var of the BN layer 235 | param_group['lr'] = 0 236 | 237 | def save_ckpt(self): 238 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 239 | log.write(" Best Acc so far: %3f" % (self.best_prec1)) 240 | log.close() 241 | if self.opt.TRAIN.SAVING: 242 | save_path = self.opt.SAVE_DIR 243 | ckpt_resume = os.path.join(save_path, 'ckpt_%d.resume' % (self.loop)) 244 | torch.save({'epoch': self.epoch, 245 | 'best_prec1': self.best_prec1, 246 | 'feature_extractor_state_dict': self.net['feature_extractor'].state_dict(), 247 | 'classifier_state_dict': self.net['classifier'].state_dict() 248 | }, ckpt_resume) -------------------------------------------------------------------------------- /solver/SymmNetsV2Partial_solver.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import math 5 | import time 6 | from utils.utils import to_cuda, accuracy_for_each_class, accuracy, AverageMeter, process_one_values 7 | from config.config import cfg 8 | import torch.nn.functional as F 9 | from models.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss, CrossEntropyClassWeighted 10 | from .base_solver import BaseSolver 11 | import ipdb 12 | 13 | class SymmNetsV2PartialSolver(BaseSolver): 14 | def __init__(self, net, dataloaders, **kwargs): 15 | super(SymmNetsV2PartialSolver, self).__init__(net, dataloaders, **kwargs) 16 | self.num_classes = cfg.DATASET.NUM_CLASSES 17 | self.TargetDiscrimLoss = TargetDiscrimLoss(num_classes=self.num_classes).cuda() 18 | self.ConcatenatedCELoss = ConcatenatedCELoss(num_classes=self.num_classes).cuda() 19 | self.feature_extractor = self.net['feature_extractor'] 20 | self.classifier = self.net['classifier'] 21 | self.lam = 0 22 | class_weight_initial = torch.ones(self.num_classes) ############################ class-level weight to filter out the outlier classes. 23 | self.class_weight_initial = class_weight_initial.cuda() 24 | class_weight = torch.ones(self.num_classes) ############################ class-level weight to filter out the outlier classes. 
25 | self.class_weight = class_weight.cuda() 26 | self.softweight = True 27 | self.CELossWeight = CrossEntropyClassWeighted() 28 | 29 | if cfg.RESUME != '': 30 | resume_dict = torch.load(cfg.RESUME) 31 | self.net['feature_extractor'].load_state_dict(resume_dict['feature_extractor_state_dict']) 32 | self.net['classifier'].load_state_dict(resume_dict['classifier_state_dict']) 33 | self.best_prec1 = resume_dict['best_prec1'] 34 | self.epoch = resume_dict['epoch'] 35 | 36 | def solve(self): 37 | stop = False 38 | while not stop: 39 | stop = self.complete_training() 40 | self.update_network() 41 | prediction_weight, acc = self.test() 42 | prediction_weight = prediction_weight.cuda() 43 | if self.softweight: 44 | self.class_weight = prediction_weight * self.lam + self.class_weight_initial * (1 - self.lam) 45 | else: 46 | self.class_weight = prediction_weight 47 | print('the class weight adopted in partial DA') 48 | print(self.class_weight) 49 | if acc > self.best_prec1: 50 | self.best_prec1 = acc 51 | self.save_ckpt() 52 | self.epoch += 1 53 | 54 | 55 | def update_network(self, **kwargs): 56 | stop = False 57 | self.train_data['source']['iterator'] = iter(self.train_data['source']['loader']) 58 | self.train_data['target']['iterator'] = iter(self.train_data['target']['loader']) 59 | self.iters_per_epoch = len(self.train_data['target']['loader']) 60 | iters_counter_within_epoch = 0 61 | data_time = AverageMeter() 62 | batch_time = AverageMeter() 63 | classifier_loss = AverageMeter() 64 | feature_extractor_loss = AverageMeter() 65 | prec1_fs = AverageMeter() 66 | prec1_ft = AverageMeter() 67 | self.feature_extractor.train() 68 | self.classifier.train() 69 | end = time.time() 70 | if self.opt.TRAIN.PROCESS_COUNTER == 'epoch': 71 | self.lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1 72 | self.update_lr() 73 | print('value of lam is: %3f' % (self.lam)) 74 | while not stop: 75 | if self.opt.TRAIN.PROCESS_COUNTER == 'iteration': 76 | self.lam = 2 / 
(1 + math.exp(-1 * 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1 77 | print('value of lam is: %3f' % (self.lam)) 78 | self.update_lr() 79 | source_data, source_gt = self.get_samples('source') 80 | target_data, _ = self.get_samples('target') 81 | source_data = to_cuda(source_data) 82 | source_gt = to_cuda(source_gt) 83 | target_data = to_cuda(target_data) 84 | data_time.update(time.time() - end) 85 | 86 | feature_source = self.feature_extractor(source_data) 87 | output_source = self.classifier(feature_source) 88 | feature_target = self.feature_extractor(target_data) 89 | output_target = self.classifier(feature_target) 90 | 91 | weight_concate = torch.cat((self.class_weight, self.class_weight)) 92 | loss_task_fs = self.CELossWeight(output_source[:,:self.num_classes], source_gt, self.class_weight) 93 | loss_task_ft = self.CELossWeight(output_source[:,self.num_classes:], source_gt, self.class_weight) 94 | loss_discrim_source = self.CELossWeight(output_source, source_gt, weight_concate) 95 | loss_discrim_target = self.TargetDiscrimLoss(output_target) 96 | loss_summary_classifier = loss_task_fs + loss_task_ft + loss_discrim_source + loss_discrim_target 97 | 98 | source_gt_for_ft_in_fst = source_gt + self.num_classes 99 | loss_confusion_source = 0.5 * self.CELossWeight(output_source, source_gt, weight_concate) + 0.5 * self.CELossWeight(output_source, source_gt_for_ft_in_fst, weight_concate) 100 | loss_confusion_target = self.ConcatenatedCELoss(output_target) 101 | loss_summary_feature_extractor = loss_confusion_source + self.lam * loss_confusion_target 102 | 103 | self.optimizer_classifier.zero_grad() 104 | loss_summary_classifier.backward(retain_graph=True) 105 | self.optimizer_classifier.step() 106 | 107 | self.optimizer_feature_extractor.zero_grad() 108 | loss_summary_feature_extractor.backward() 109 | self.optimizer_feature_extractor.step() 110 | 111 | classifier_loss.update(loss_summary_classifier, source_data.size()[0]) 112 | 
feature_extractor_loss.update(loss_summary_feature_extractor, source_data.size()[0]) 113 | prec1_fs.update(accuracy(output_source[:, :self.num_classes], source_gt), source_data.size()[0]) 114 | prec1_ft.update(accuracy(output_source[:, self.num_classes:], source_gt), source_data.size()[0]) 115 | 116 | print(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \ 117 | (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg)) 118 | 119 | batch_time.update(time.time() - end) 120 | end = time.time() 121 | self.iters += 1 122 | iters_counter_within_epoch += 1 123 | if iters_counter_within_epoch >= self.iters_per_epoch: 124 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 125 | log.write("\n") 126 | log.write(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \ 127 | (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg)) 128 | log.close() 129 | stop = True 130 | 131 | def test(self): 132 | self.feature_extractor.eval() 133 | self.classifier.eval() 134 | prec1_fs = AverageMeter() 135 | prec1_ft = AverageMeter() 136 | counter_all_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 137 | counter_all_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 138 | counter_acc_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 139 | counter_acc_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0) 140 | class_weight = torch.zeros(self.num_classes) 141 | class_weight = class_weight.cuda() 142 | count = 0 143 | 144 | 145 | for i, (input, target) in enumerate(self.test_data['loader']): 146 | input, target = to_cuda(input), to_cuda(target) 147 | with torch.no_grad(): 148 | feature_test = self.feature_extractor(input) 149 | output_test = self.classifier(feature_test) 150 | prob = F.softmax(output_test[:, 
self.num_classes:], dim=1) 151 | class_weight = class_weight + prob.data.sum(0) 152 | count = count + input.size(0) 153 | 154 | if self.opt.EVAL_METRIC == 'accu': 155 | prec1_fs_iter = accuracy(output_test[:, :self.num_classes], target) 156 | prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target) 157 | prec1_fs.update(prec1_fs_iter, input.size(0)) 158 | prec1_ft.update(prec1_ft_iter, input.size(0)) 159 | if i % self.opt.PRINT_STEP == 0: 160 | print(" Test:epoch: %d:[%d/%d], AccFs: %3f, AccFt: %3f" % \ 161 | (self.epoch, i, len(self.test_data['loader']), prec1_fs.avg, prec1_ft.avg)) 162 | elif self.opt.EVAL_METRIC == 'accu_mean': 163 | prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target) 164 | prec1_ft.update(prec1_ft_iter, input.size(0)) 165 | counter_all_fs, counter_acc_fs = accuracy_for_each_class(output_test[:, :self.num_classes], target, counter_all_fs, counter_acc_fs) 166 | counter_all_ft, counter_acc_ft = accuracy_for_each_class(output_test[:, self.num_classes:], target, counter_all_ft, counter_acc_ft) 167 | if i % self.opt.PRINT_STEP == 0: 168 | print(" Test:epoch: %d:[%d/%d], Task: %3f" % \ 169 | (self.epoch, i, len(self.test_data['loader']), prec1_ft.avg)) 170 | else: 171 | raise NotImplementedError 172 | acc_for_each_class_fs = counter_acc_fs / counter_all_fs 173 | acc_for_each_class_ft = counter_acc_ft / counter_all_ft 174 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 175 | log.write("\n") 176 | class_weight = class_weight / count 177 | class_weight = class_weight / max(class_weight) 178 | if self.opt.EVAL_METRIC == 'accu': 179 | log.write( 180 | " Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \ 181 | (self.epoch, prec1_fs.avg, prec1_ft.avg)) 182 | log.close() 183 | return class_weight, max(prec1_fs.avg, prec1_ft.avg) 184 | elif self.opt.EVAL_METRIC == 'accu_mean': 185 | log.write( 186 | " Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \ 187 | (self.epoch,acc_for_each_class_fs.mean(), acc_for_each_class_ft.mean())) 188 
| log.write("\nClass-wise Acc of Ft:") ## based on the task classifier. 189 | for i in range(self.opt.DATASET.NUM_CLASSES): 190 | if i == 0: 191 | log.write("%dst: %3f" % (i + 1, acc_for_each_class_ft[i])) 192 | elif i == 1: 193 | log.write(", %dnd: %3f" % (i + 1, acc_for_each_class_ft[i])) 194 | elif i == 2: 195 | log.write(", %drd: %3f" % (i + 1, acc_for_each_class_ft[i])) 196 | else: 197 | log.write(", %dth: %3f" % (i + 1, acc_for_each_class_ft[i])) 198 | log.close() 199 | return class_weight, max(acc_for_each_class_ft.mean(), acc_for_each_class_fs.mean()) 200 | 201 | def build_optimizer(self): 202 | if self.opt.TRAIN.OPTIMIZER == 'SGD': ## some params may not contribute the loss_all, thus they are not updated in the training process. 203 | self.optimizer_feature_extractor = torch.optim.SGD([ 204 | {'params': self.net['feature_extractor'].module.conv1.parameters(), 'name': 'pre-trained'}, 205 | {'params': self.net['feature_extractor'].module.bn1.parameters(), 'name': 'pre-trained'}, 206 | {'params': self.net['feature_extractor'].module.layer1.parameters(), 'name': 'pre-trained'}, 207 | {'params': self.net['feature_extractor'].module.layer2.parameters(), 'name': 'pre-trained'}, 208 | {'params': self.net['feature_extractor'].module.layer3.parameters(), 'name': 'pre-trained'}, 209 | {'params': self.net['feature_extractor'].module.layer4.parameters(), 'name': 'pre-trained'}, 210 | ], 211 | lr=self.opt.TRAIN.BASE_LR, 212 | momentum=self.opt.TRAIN.MOMENTUM, 213 | weight_decay=self.opt.TRAIN.WEIGHT_DECAY, 214 | nesterov=True) 215 | 216 | self.optimizer_classifier = torch.optim.SGD([ 217 | {'params': self.net['classifier'].parameters(), 'name': 'new-added'}, 218 | ], 219 | lr=self.opt.TRAIN.BASE_LR, 220 | momentum=self.opt.TRAIN.MOMENTUM, 221 | weight_decay=self.opt.TRAIN.WEIGHT_DECAY, 222 | nesterov=True) 223 | else: 224 | raise NotImplementedError 225 | print('Optimizer built') 226 | 227 | def update_lr(self): 228 | if self.opt.TRAIN.LR_SCHEDULE == 'inv': 229 | if 
self.opt.TRAIN.PROCESS_COUNTER == 'epoch': 230 | lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.epoch / self.opt.TRAIN.MAX_EPOCH), self.opt.INV.BETA) 231 | elif self.opt.TRAIN.PROCESS_COUNTER == 'iteration': 232 | lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch)), self.opt.INV.BETA) 233 | else: 234 | raise NotImplementedError 235 | elif self.opt.TRAIN.LR_SCHEDULE == 'fix': 236 | lr = self.opt.TRAIN.BASE_LR 237 | else: 238 | raise NotImplementedError 239 | lr_pretrain = lr * 0.1 240 | print('the lr is: %3f' % (lr)) 241 | for param_group in self.optimizer_feature_extractor.param_groups: 242 | if param_group['name'] == 'pre-trained': 243 | param_group['lr'] = lr_pretrain 244 | elif param_group['name'] == 'new-added': 245 | param_group['lr'] = lr 246 | elif param_group['name'] == 'fixed': ## Fix the lr as 0 can not fix the runing mean/var of the BN layer 247 | param_group['lr'] = 0 248 | 249 | for param_group in self.optimizer_classifier.param_groups: 250 | if param_group['name'] == 'pre-trained': 251 | param_group['lr'] = lr_pretrain 252 | elif param_group['name'] == 'new-added': 253 | param_group['lr'] = lr 254 | elif param_group['name'] == 'fixed': ## Fix the lr as 0 can not fix the runing mean/var of the BN layer 255 | param_group['lr'] = 0 256 | 257 | def save_ckpt(self): 258 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') 259 | log.write(" Best Acc so far: %3f" % (self.best_prec1)) 260 | log.close() 261 | if self.opt.TRAIN.SAVING: 262 | save_path = self.opt.SAVE_DIR 263 | ckpt_resume = os.path.join(save_path, 'ckpt_%d.resume' % (self.loop)) 264 | torch.save({'epoch': self.epoch, 265 | 'best_prec1': self.best_prec1, 266 | 'feature_extractor_state_dict': self.net['feature_extractor'].state_dict(), 267 | 'classifier_state_dict': self.net['classifier'].state_dict() 268 | }, ckpt_resume) 269 | 
# --------------------------------------------------------------------------------
# /solver/SymmNetsV2_solver.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import os
import math
import time
from utils.utils import to_cuda, accuracy_for_each_class, accuracy, AverageMeter, process_one_values
from config.config import cfg
import torch.nn.functional as F
from models.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
from .base_solver import BaseSolver
import ipdb


class SymmNetsV2Solver(BaseSolver):
    """Closed-set domain-adaptation solver for SymmNets-V2.

    The classifier output is split into two halves of ``num_classes`` columns:
    ``[:, :num_classes]`` is used as the source task classifier (Fs) and
    ``[:, num_classes:]`` as the target task classifier (Ft); the whole output
    is used as the concatenated classifier (Fst).
    NOTE(review): this assumes the classifier emits 2 * num_classes logits,
    as the slicing below implies -- confirm against models.resnet_SymmNet.
    """

    def __init__(self, net, dataloaders, **kwargs):
        super(SymmNetsV2Solver, self).__init__(net, dataloaders, **kwargs)
        self.num_classes = cfg.DATASET.NUM_CLASSES
        self.TargetDiscrimLoss = TargetDiscrimLoss(num_classes=self.num_classes).cuda()
        self.ConcatenatedCELoss = ConcatenatedCELoss(num_classes=self.num_classes).cuda()
        self.feature_extractor = self.net['feature_extractor']
        self.classifier = self.net['classifier']

        if cfg.RESUME != '':
            # Restore the state written by save_ckpt().
            resume_dict = torch.load(cfg.RESUME)
            self.net['feature_extractor'].load_state_dict(resume_dict['feature_extractor_state_dict'])
            self.net['classifier'].load_state_dict(resume_dict['classifier_state_dict'])
            self.best_prec1 = resume_dict['best_prec1']
            self.epoch = resume_dict['epoch']

    def solve(self):
        """Alternate one training epoch and one evaluation until training is complete.

        ``complete_training()`` is checked *before* each epoch, but that epoch
        is still run, so one final epoch follows the stop signal. A checkpoint
        is saved whenever the test accuracy beats ``self.best_prec1``.
        """
        stop = False
        while not stop:
            stop = self.complete_training()
            self.update_network()
            acc = self.test()
            if acc > self.best_prec1:
                self.best_prec1 = acc
                self.save_ckpt()
            self.epoch += 1

    @staticmethod
    def _confusion_weight(progress):
        """Return ``2 / (1 + exp(-10 * progress)) - 1``: 0 at progress 0, rising towards 1."""
        return 2 / (1 + math.exp(-1 * 10 * progress)) - 1

    def update_network(self, **kwargs):
        """Train for one epoch of ``iters_per_epoch`` source/target batch pairs.

        Each iteration first updates the classifier with the supervised task
        losses on Fs and Ft plus the domain-discrimination losses, then updates
        the feature extractor with the domain-confusion losses, the target term
        being weighted by ``lam``.

        Raises:
            NotImplementedError: if TRAIN.PROCESS_COUNTER is neither 'epoch'
                nor 'iteration' (``lam`` would otherwise be undefined).
        """
        counter = self.opt.TRAIN.PROCESS_COUNTER
        if counter not in ('epoch', 'iteration'):
            raise NotImplementedError
        self.train_data['source']['iterator'] = iter(self.train_data['source']['loader'])
        self.train_data['target']['iterator'] = iter(self.train_data['target']['loader'])
        # The longer loader defines the epoch; get_samples() restarts the shorter one.
        self.iters_per_epoch = max(len(self.train_data['target']['loader']),
                                   len(self.train_data['source']['loader']))
        classifier_loss = AverageMeter()
        feature_extractor_loss = AverageMeter()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        self.feature_extractor.train()
        self.classifier.train()
        if counter == 'epoch':
            # lam and the learning rate stay fixed for the whole epoch.
            lam = self._confusion_weight(self.epoch / self.opt.TRAIN.MAX_EPOCH)
            self.update_lr()
            print('value of lam is: %3f' % (lam))

        for iters_counter_within_epoch in range(self.iters_per_epoch):
            if counter == 'iteration':
                lam = self._confusion_weight(
                    self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))
                print('value of lam is: %3f' % (lam))
                self.update_lr()
            source_data, source_gt = self.get_samples('source')
            target_data, _ = self.get_samples('target')
            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            target_data = to_cuda(target_data)

            feature_source = self.feature_extractor(source_data)
            output_source = self.classifier(feature_source)
            feature_target = self.feature_extractor(target_data)
            output_target = self.classifier(feature_target)

            # Classifier objective: task losses on Fs and Ft, cross-entropy on
            # Fst with the source label (an index in the Fs half), and
            # TargetDiscrimLoss (models.loss_utils) on the target outputs.
            loss_task_fs = self.CELoss(output_source[:, :self.num_classes], source_gt)
            loss_task_ft = self.CELoss(output_source[:, self.num_classes:], source_gt)
            loss_discrim_source = self.CELoss(output_source, source_gt)
            loss_discrim_target = self.TargetDiscrimLoss(output_target)
            loss_summary_classifier = loss_task_fs + loss_task_ft + loss_discrim_source + loss_discrim_target

            # Feature-extractor objective: split each source sample's target
            # evenly between its Fs index and its Ft index (label + num_classes)
            # in Fst, plus the lam-weighted ConcatenatedCELoss on target outputs.
            source_gt_for_ft_in_fst = source_gt + self.num_classes
            loss_confusion_source = 0.5 * self.CELoss(output_source, source_gt) + \
                0.5 * self.CELoss(output_source, source_gt_for_ft_in_fst)
            loss_confusion_target = self.ConcatenatedCELoss(output_target)
            loss_summary_feature_extractor = loss_confusion_source + lam * loss_confusion_target

            # NOTE(review): the second backward() below reuses the graph retained
            # here *after* optimizer_classifier.step() has updated the classifier
            # weights in place. Recent PyTorch releases reject this ("modified by
            # an inplace operation"); this ordering appears to rely on an older
            # PyTorch -- confirm the required version before changing it.
            self.optimizer_classifier.zero_grad()
            loss_summary_classifier.backward(retain_graph=True)
            self.optimizer_classifier.step()

            # zero_grad() also discards the feature-extractor gradients produced
            # by the classifier loss above.
            self.optimizer_feature_extractor.zero_grad()
            loss_summary_feature_extractor.backward()
            self.optimizer_feature_extractor.step()

            batch_size = source_data.size()[0]
            # .item() so the meters do not keep each iteration's autograd graph alive.
            classifier_loss.update(loss_summary_classifier.item(), batch_size)
            feature_extractor_loss.update(loss_summary_feature_extractor.item(), batch_size)
            prec1_fs.update(accuracy(output_source[:, :self.num_classes], source_gt), batch_size)
            prec1_ft.update(accuracy(output_source[:, self.num_classes:], source_gt), batch_size)

            print(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg,
                   feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))
            self.iters += 1

        with open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') as log:
            log.write("\n")
            log.write(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                      (self.epoch, self.iters_per_epoch, self.iters_per_epoch, classifier_loss.avg,
                       feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))

    def test(self):
        """Evaluate Fs and Ft on the test loader and append the result to log.txt.

        Returns:
            The better of the Fs / Ft scores: overall top-1 accuracy when
            EVAL_METRIC is 'accu', mean per-class accuracy (over classes present
            in the test set) when it is 'accu_mean'.

        Raises:
            NotImplementedError: for any other EVAL_METRIC.
        """
        metric = self.opt.EVAL_METRIC
        if metric not in ('accu', 'accu_mean'):
            raise NotImplementedError
        self.feature_extractor.eval()
        self.classifier.eval()
        num_classes = self.opt.DATASET.NUM_CLASSES
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        counter_all_fs = torch.zeros(num_classes)
        counter_all_ft = torch.zeros(num_classes)
        counter_acc_fs = torch.zeros(num_classes)
        counter_acc_ft = torch.zeros(num_classes)
        num_batches = len(self.test_data['loader'])

        for i, (input, target) in enumerate(self.test_data['loader']):
            input, target = to_cuda(input), to_cuda(target)
            with torch.no_grad():
                feature_test = self.feature_extractor(input)
                output_test = self.classifier(feature_test)
            output_fs = output_test[:, :self.num_classes]
            output_ft = output_test[:, self.num_classes:]

            if metric == 'accu':
                prec1_fs.update(accuracy(output_fs, target), input.size(0))
                prec1_ft.update(accuracy(output_ft, target), input.size(0))
                if i % self.opt.PRINT_STEP == 0:
                    print(" Test:epoch: %d:[%d/%d], AccFs: %3f, AccFt: %3f" % \
                          (self.epoch, i, num_batches, prec1_fs.avg, prec1_ft.avg))
            else:  # 'accu_mean'
                prec1_ft.update(accuracy(output_ft, target), input.size(0))
                counter_all_fs, counter_acc_fs = accuracy_for_each_class(output_fs, target, counter_all_fs, counter_acc_fs)
                counter_all_ft, counter_acc_ft = accuracy_for_each_class(output_ft, target, counter_all_ft, counter_acc_ft)
                if i % self.opt.PRINT_STEP == 0:
                    print(" Test:epoch: %d:[%d/%d], Task: %3f" % \
                          (self.epoch, i, num_batches, prec1_ft.avg))

        log_path = os.path.join(self.opt.SAVE_DIR, 'log.txt')
        if metric == 'accu':
            with open(log_path, 'a') as log:
                log.write("\n")
                log.write(" Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                          (self.epoch, prec1_fs.avg, prec1_ft.avg))
            return max(prec1_fs.avg, prec1_ft.avg)

        acc_for_each_class_fs = counter_acc_fs / counter_all_fs
        acc_for_each_class_ft = counter_acc_ft / counter_all_ft
        # A class absent from the test set yields 0/0 = NaN; averaging only the
        # classes that occur keeps one missing class from making the mean NaN.
        mean_acc_fs = acc_for_each_class_fs[counter_all_fs > 0].mean()
        mean_acc_ft = acc_for_each_class_ft[counter_all_ft > 0].mean()
        ordinal_suffix = {0: 'st', 1: 'nd', 2: 'rd'}
        with open(log_path, 'a') as log:
            log.write("\n")
            log.write(" Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                      (self.epoch, mean_acc_fs, mean_acc_ft))
            log.write("\nClass-wise Acc of Ft:")  ## based on the task classifier.
            for i in range(num_classes):
                separator = '' if i == 0 else ', '
                log.write("%s%d%s: %3f" % (separator, i + 1, ordinal_suffix.get(i, 'th'), acc_for_each_class_ft[i]))
        return max(mean_acc_ft, mean_acc_fs)

    def build_optimizer(self):
        """Create SGD optimizers for the backbone ('pre-trained') and the classifier ('new-added')."""
        if self.opt.TRAIN.OPTIMIZER == 'SGD':  ## some params may not contribute the loss_all, thus they are not updated in the training process.
            self.optimizer_feature_extractor = torch.optim.SGD([
                {'params': self.net['feature_extractor'].module.conv1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.bn1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer2.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer3.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer4.parameters(), 'name': 'pre-trained'},
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)

            self.optimizer_classifier = torch.optim.SGD([
                {'params': self.net['classifier'].parameters(), 'name': 'new-added'},
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)
        else:
            raise NotImplementedError
        print('Optimizer built')

    def update_lr(self):
        """Set every param group's learning rate from TRAIN.LR_SCHEDULE.

        'inv' decays BASE_LR as ``BASE_LR / (1 + ALPHA * progress) ** BETA`` with
        progress counted in epochs or iterations (TRAIN.PROCESS_COUNTER); 'fix'
        keeps BASE_LR. Groups named 'pre-trained' get 0.1x the rate, 'new-added'
        the full rate and 'fixed' zero.

        Raises:
            NotImplementedError: for an unknown schedule or process counter.
        """
        if self.opt.TRAIN.LR_SCHEDULE == 'inv':
            if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
                lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.epoch / self.opt.TRAIN.MAX_EPOCH), self.opt.INV.BETA)
            elif self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch)), self.opt.INV.BETA)
            else:
                raise NotImplementedError
        elif self.opt.TRAIN.LR_SCHEDULE == 'fix':
            lr = self.opt.TRAIN.BASE_LR
        else:
            raise NotImplementedError
        lr_pretrain = lr * 0.1
        print('the lr is: %3f' % (lr))
        # A 'fixed' lr of 0 does not freeze the running mean/var of BN layers.
        lr_by_group_name = {'pre-trained': lr_pretrain, 'new-added': lr, 'fixed': 0}
        for optimizer in (self.optimizer_feature_extractor, self.optimizer_classifier):
            for param_group in optimizer.param_groups:
                if param_group['name'] in lr_by_group_name:
                    param_group['lr'] = lr_by_group_name[param_group['name']]

    def save_ckpt(self):
        """Log the best accuracy and, if TRAIN.SAVING is set, write a resumable checkpoint."""
        with open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a') as log:
            log.write(" Best Acc so far: %3f" % (self.best_prec1))
        if self.opt.TRAIN.SAVING:
            # Neither this class nor BaseSolver sets ``self.loop``; fall back to
            # the epoch index instead of raising AttributeError.
            ckpt_index = getattr(self, 'loop', self.epoch)
            ckpt_resume = os.path.join(self.opt.SAVE_DIR, 'ckpt_%d.resume' % (ckpt_index))
            torch.save({'epoch': self.epoch,
                        'best_prec1': self.best_prec1,
                        'feature_extractor_state_dict': self.net['feature_extractor'].state_dict(),
                        'classifier_state_dict': self.net['classifier'].state_dict()
                        }, ckpt_resume)


# --------------------------------------------------------------------------------
# /solver/base_solver.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import os
from utils.utils import to_cuda, AverageMeter
from config.config import cfg


class BaseSolver:
    """Shared solver state: config, CE loss, counters, optimizers and data iterators."""

    def __init__(self, net, dataloaders, **kwargs):
        self.opt = cfg
        self.net = net
        self.dataloaders = dataloaders
        self.CELoss = nn.CrossEntropyLoss()
        if torch.cuda.is_available():
            self.CELoss.cuda()
        self.epoch = 0
        self.iters = 0
        self.best_prec1 = 0
        self.iters_per_epoch = None
        self.build_optimizer()
        self.init_data(self.dataloaders)

    def init_data(self, dataloaders):
        """Register every loader except 'test' as training data, and 'test' as test data."""
        self.train_data = {key: {'loader': dataloaders[key], 'iterator': None}
                           for key in dataloaders if key != 'test'}

        if 'test' in dataloaders:
            self.test_data = {'loader': dataloaders['test']}

    def build_optimizer(self):
        print('Optimizer built')

    def complete_training(self):
        """Return True once ``self.epoch`` exceeds TRAIN.MAX_EPOCH, else False."""
        return self.epoch > self.opt.TRAIN.MAX_EPOCH

    def solve(self):
        print('Training Done!')

    def get_samples(self, data_name):
        """Return the next batch of training loader ``data_name``.

        When the iterator is exhausted, the loader is restarted and the new
        iterator is stored for subsequent calls.
        """
        assert(data_name in self.train_data)
        assert('loader' in self.train_data[data_name] and \
               'iterator' in self.train_data[data_name])

        data_loader = self.train_data[data_name]['loader']
        data_iterator = self.train_data[data_name]['iterator']
        assert data_loader is not None and data_iterator is not None, \
            'Check your dataloader of %s.' % data_name

        try:
            sample = next(data_iterator)
        except StopIteration:
            data_iterator = iter(data_loader)
            sample = next(data_iterator)
            self.train_data[data_name]['iterator'] = data_iterator
        return sample

    def update_network(self, **kwargs):
        pass


# --------------------------------------------------------------------------------
# /tools/train.py
# --------------------------------------------------------------------------------
import torch
import ipdb
import argparse
import os
import numpy as np
from torch.backends import cudnn
from config.config import cfg, cfg_from_file, cfg_from_list
import sys
import pprint
import json


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train script.')
    parser.add_argument('--resume', dest='resume',
                        help='initialize with saved solver status',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--method', dest='method',
                        help='set the method to use',
                        default='McDalNet', type=str)
    parser.add_argument('--task', dest='task',
                        help='closed | partial | open',
                        default='closed', type=str)
    parser.add_argument('--distance_type', dest='distance_type',
                        help='set distance type in McDalNet',
                        default='L1', type=str)

    parser.add_argument('--exp_name', dest='exp_name',
                        help='the experiment name',
                        default='log', type=str)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args


def _build_symmnet_net(Model):
    """Wrap the (feature_extractor, classifier) pair from ``Model()`` in DataParallel.

    Both modules are moved to the GPU when CUDA is available; returns the
    ``{'feature_extractor': ..., 'classifier': ...}`` dict the SymmNets solvers expect.
    """
    feature_extractor, classifier = Model()
    feature_extractor = torch.nn.DataParallel(feature_extractor)
    classifier = torch.nn.DataParallel(classifier)
    if torch.cuda.is_available():
        feature_extractor.cuda()
        classifier.cuda()
    return {'feature_extractor': feature_extractor, 'classifier': classifier}


def train(args):
    """Build the network, dataloaders and solver selected by ``args`` and run training.

    Raises:
        NotImplementedError: for the Digits dataset or an unsupported method/task.
    """
    # method-specific setting
    if args.method == 'McDalNet':
        if cfg.DATASET.DATASET == 'Digits':
            raise NotImplementedError
        from solver.McDalNet_solver import McDalNetSolver as Solver
        from models.resnet_McDalNet import resnet as Model
        from data.prepare_data import generate_dataloader as Dataloader
        net = Model()
        net = torch.nn.DataParallel(net)
        if torch.cuda.is_available():
            net.cuda()
    elif args.method == 'SymmNetsV2':
        if args.task == 'closed':
            if cfg.DATASET.DATASET == 'Digits':
                raise NotImplementedError
            from solver.SymmNetsV2_solver import SymmNetsV2Solver as Solver
            from models.resnet_SymmNet import resnet as Model
            from data.prepare_data import generate_dataloader as Dataloader
        elif args.task == 'partial':
            from solver.SymmNetsV2Partial_solver import SymmNetsV2PartialSolver as Solver
            from models.resnet_SymmNet import resnet as Model
            from data.prepare_data import generate_dataloader as Dataloader
        elif args.task == 'open':
            from solver.SymmNetsV2Open_solver import SymmNetsV2OpenSolver as Solver
            from models.resnet_SymmNet import resnet as Model
            from data.prepare_data import generate_dataloader_open as Dataloader
        elif args.task == 'closedsc':
            if cfg.DATASET.DATASET == 'Digits':
                raise NotImplementedError
            from solver.SymmNetsV2SC_solver import SymmNetsV2SolverSC as Solver
            from models.resnet_SymmNet import resnet as Model
            from data.prepare_data import generate_dataloader_sc as Dataloader
        else:
            raise NotImplementedError("Currently don't support the specified task: %s." % (args.task))
        net = _build_symmnet_net(Model)

    ## Algorithm proposed in our CVPR19 paper: Domain-Symmetric Networks for Adversarial Domain Adaptation
    ## It is the same with our previous implementation of https://github.com/YBZh/SymNets
    elif args.method == 'SymmNetsV1':
        if cfg.DATASET.DATASET == 'Digits':
            raise NotImplementedError
        from solver.SymmNetsV1_solver import SymmNetsV1Solver as Solver
        from models.resnet_SymmNet import resnet as Model
        from data.prepare_data import generate_dataloader as Dataloader
        net = _build_symmnet_net(Model)
    else:
        raise NotImplementedError("Currently don't support the specified method: %s." % (args.method))

    dataloaders = Dataloader()

    # initialize solver
    train_solver = Solver(net, dataloaders)

    # train
    train_solver.solve()
    print('Finished!')


if __name__ == '__main__':
    cudnn.benchmark = True
    args = parse_args()

    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    # Applying --set overrides (cfg_from_list) is disabled in this script;
    # warn instead of silently dropping them.
    if args.set_cfgs is not None:
        print('Warning: --set %s is ignored by this script.' % (args.set_cfgs,))

    if args.resume is not None:
        cfg.RESUME = args.resume
    if args.exp_name is not None:
        cfg.EXP_NAME = args.exp_name + cfg.DATASET.DATASET + '_' + cfg.DATASET.SOURCE_NAME + '2' + cfg.DATASET.VAL_NAME + '_' + args.method + '_' + args.distance_type + args.task
    if args.distance_type is not None:
        cfg.MCDALNET.DISTANCE_TYPE = args.distance_type

    print('Using config:')
    pprint.pprint(cfg)

    cfg.SAVE_DIR = os.path.join(cfg.SAVE_DIR, cfg.EXP_NAME)
    print('Output will be saved to %s.' % cfg.SAVE_DIR)
    os.makedirs(cfg.SAVE_DIR, exist_ok=True)
    with open(os.path.join(cfg.SAVE_DIR, 'log.txt'), 'a') as log:
        log.write("\n")
        log.write(json.dumps(cfg) + '\n')

    train(args)


# --------------------------------------------------------------------------------
# /utils/__init__.py
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/utils/__init__.py


# --------------------------------------------------------------------------------
# /utils/utils.py
# --------------------------------------------------------------------------------
import torch
import numpy as np


def to_cuda(x):
    """Move ``x`` to the GPU when CUDA is available; otherwise return it unchanged."""
    if torch.cuda.is_available():
        x = x.cuda()
    return x


def to_cpu(x):
    """Return ``x`` on the CPU."""
    return x.cpu()


def to_numpy(x):
    """Return ``x`` as a NumPy array (copied to the CPU first when CUDA is available)."""
    if torch.cuda.is_available():
        x = x.cpu()
    return x.data.numpy()


def to_onehot(label, num_classes):
    """Return one-hot rows for the integer ``label`` tensor, on ``label``'s device."""
    identity = torch.eye(num_classes).to(label.device)
    onehot = torch.index_select(identity, 0, label)
    return onehot


def accuracy(output, target):
    """Return top-1 accuracy (percent) of ``output`` scores vs ``target`` as a 1-element tensor."""
    batch_size = target.size(0)
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    # reshape (not view) so a non-contiguous slice never raises.
    correct = correct[:1].reshape(-1).float().sum(0, keepdim=True)
    res = correct.mul_(100.0 / batch_size)
    return res


def accuracy_for_each_class(output, target, total_vector, correct_vector):
    """Accumulate per-class sample counts and top-1 hits, indexed by ground-truth class.

    ``total_vector`` and ``correct_vector`` are updated in place and returned.
    Works for any batch size (including 1) and for ``output``/``target`` on a
    different device than the counters.
    """
    _, pred = output.topk(1, 1, True, True)
    pred = pred.view(-1).to(total_vector.device)
    target = target.view(-1).long().to(total_vector.device)
    correct = pred.eq(target).to(correct_vector.dtype)
    total_vector.index_add_(0, target, torch.ones_like(correct, dtype=total_vector.dtype))
    correct_vector.index_add_(0, target, correct)

    return total_vector, correct_vector


def recall_for_each_class(output, target, total_vector, correct_vector):
    """Accumulate per-class prediction counts and top-1 hits, indexed by predicted class.

    ``total_vector`` and ``correct_vector`` are updated in place and returned.
    Works for any batch size (including 1).
    """
    _, pred = output.topk(1, 1, True, True)
    pred = pred.view(-1).long().to(total_vector.device)
    target = target.view(-1).to(total_vector.device)
    correct = pred.eq(target).to(correct_vector.dtype)
    total_vector.index_add_(0, pred, torch.ones_like(correct, dtype=total_vector.dtype))
    correct_vector.index_add_(0, pred, correct)

    return total_vector, correct_vector


def process_one_values(tensor):
    """Return ``tensor`` with every element equal to 1 lowered by 1e-6 (unchanged if none)."""
    ones_mask = tensor == 1
    if ones_mask.sum() != 0:
        # float32 eps on the tensor's own device (the original hard-coded .cuda()).
        eps = torch.zeros(tensor.size(), device=tensor.device)
        eps[ones_mask] = 1e-6
        tensor = tensor - eps
    return tensor


def process_zero_values(tensor):
    """Return ``tensor`` with every element equal to 0 raised by 1e-6 (unchanged if none)."""
    zeros_mask = tensor == 0
    if zeros_mask.sum() != 0:
        eps = torch.zeros(tensor.size(), device=tensor.device)
        eps[zeros_mask] = 1e-6
        tensor = tensor + eps
    return tensor


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh the running average."""
        if torch.is_tensor(val):
            # Drop autograd history; otherwise ``sum`` chains every past graph.
            val = val.detach()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# --------------------------------------------------------------------------------