├── .DS_Store
├── .gitignore
├── .idea
├── MultiClassDA.iml
├── misc.xml
├── modules.xml
├── vcs.xml
└── workspace.xml
├── LICENSE
├── README.md
├── config
├── __init__.py
└── config.py
├── data
├── domainnet_data_prepare.py
├── folder_mec.py
├── folder_new.py
├── prepare_data.py
└── vision.py
├── experiments
├── .DS_Store
├── Clarification_on_tasks.md
├── Open
│ └── logOffice31_amazon_s_busto2webcam_t_busto_SymmNetsV2_None
│ │ └── log.txt
├── Partial
│ └── logOfficeHome_Art2Real_World_25_SymmNetsV2_None
│ │ └── log.txt
├── SymmNets
│ ├── logImageCLEF_c2i_SymmNetsV1_None
│ │ └── log.txt
│ ├── logImageCLEF_c2i_SymmNetsV2_None
│ │ └── log.txt
│ ├── logOffice31_amazon2dslr_SymmNetsV2_Noneclosed
│ │ └── log.txt
│ ├── logOffice31_amazon2dslr_SymmNetsV2_Noneclosedsc
│ │ └── log.txt
│ ├── logOffice31_webcam2amazon_SymmNetsV2_Noneclosed
│ │ └── log.txt
│ ├── logOffice31_webcam2amazon_SymmNetsV2_Noneclosedsc
│ │ └── log.txt
│ ├── logVisDA_train2validation_SymmNetsV2_None
│ │ └── log.txt
│ └── logres101VisDA_train2validation_SymmNetsV2_None
│ │ └── log.txt
├── ckpt
│ ├── logImageCLEF_c2i_McDalNet_CE
│ │ └── log.txt
│ ├── logImageCLEF_c2i_McDalNet_DANN
│ │ └── log.txt
│ ├── logImageCLEF_c2i_McDalNet_KL
│ │ └── log.txt
│ ├── logImageCLEF_c2i_McDalNet_L1
│ │ └── log.txt
│ ├── logImageCLEF_c2i_McDalNet_MDD
│ │ └── log.txt
│ ├── logVisDA_train2validation_McDalNet_CE
│ │ └── log.txt
│ ├── logVisDA_train2validation_McDalNet_DANN
│ │ └── log.txt
│ ├── logVisDA_train2validation_McDalNet_KL
│ │ └── log.txt
│ ├── logVisDA_train2validation_McDalNet_L1
│ │ └── log.txt
│ └── logVisDA_train2validation_McDalNet_MDD
│ │ └── log.txt
└── configs
│ ├── .DS_Store
│ ├── ImageCLEF
│ ├── .DS_Store
│ ├── McDalNet
│ │ └── clef_train_c2i_cfg.yaml
│ └── SymmNets
│ │ └── clef_train_c2i_cfg.yaml
│ ├── Office31
│ ├── .DS_Store
│ ├── McDalNet
│ │ └── office31_train_webcam2amazon_cfg.yaml
│ └── SymmNets
│ │ ├── .DS_Store
│ │ ├── office31_train_amazon2dslr_cfg.yaml
│ │ ├── office31_train_amazon2dslr_cfg_SC.yaml
│ │ ├── office31_train_amazon2webcam_open_cfg.yaml
│ │ ├── office31_train_webcam2amazon_cfg.yaml
│ │ └── office31_train_webcam2amazon_cfg_SC.yaml
│ ├── OfficeHome
│ ├── .DS_Store
│ └── SymmNets
│ │ └── home_train_A2R_partial_cfg.yaml
│ └── VisDA
│ ├── McDalNet
│ └── visda17_train_train2val_cfg.yaml
│ └── SymmNets
│ ├── visda17_train_train2val_cfg.yaml
│ └── visda17_train_train2val_cfg_res101.yaml
├── models
├── loss_utils.py
├── resnet_McDalNet.py
└── resnet_SymmNet.py
├── run.sh
├── run_symmnets.sh
├── run_temp.sh
├── solver
├── McDalNet_solver.py
├── SymmNetsV1_solver.py
├── SymmNetsV2Open_solver.py
├── SymmNetsV2Partial_solver.py
├── SymmNetsV2SC_solver.py
├── SymmNetsV2_solver.py
└── base_solver.py
├── tools
└── train.py
└── utils
├── __init__.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.idea/MultiClassDA.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
19 |
20 |
21 |
22 | target_train_dataset
23 | DatasetFolder
24 | resume_dict
25 | self.opt
26 | nClass
27 | model_eval
28 | McDalNetLoss
29 | SAVE_DIR
30 | iters_per_epoch
31 | cfg
32 | PROCE
33 | BasicBlock
34 | net
35 | prob_s
36 | counter_all_auxi1
37 | counter_acc_auxi2
38 | expansion
39 | exp_name
40 | self.num_classes:
41 | source_gt_for_ft_in_fst
42 |
43 |
44 | counter_all_fs
45 | counter_acc_ft
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 | 1581836641713
80 |
81 |
82 | 1581836641713
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 | file://$PROJECT_DIR$/data/prepare_data.py
91 | 122
92 |
93 |
94 |
95 | file://$PROJECT_DIR$/solver/SymmNetsV1_solver.py
96 |
97 |
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yabin Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MultiClassDA (TPAMI2020)
2 | Code release for ["Unsupervised Multi-Class Domain Adaptation: Theory, Algorithms, and Practice"](https://arxiv.org/pdf/2002.08681.pdf), which is
3 | an extension of our preliminary work of SymmNets [[Paper](https://zpascal.net/cvpr2019/Zhang_Domain-Symmetric_Networks_for_Adversarial_Domain_Adaptation_CVPR_2019_paper.pdf)] [[Code](https://github.com/YBZh/SymNets)]
4 |
5 | Please refer to the "run_temp.sh" for the usage.
6 | All expeimental results are logged in the file of "./experiments"
7 |
8 |
9 |
10 |
11 | ## Included codes:
12 | 1. Codes of McDalNets -->./solver/McDalNet_solver.py
13 | 2. Codes of SymNets-V2
14 | 1. For the Closed Set DA -->./solver/SymmNetsV2_solver.py
15 | 2. For the Strongthened Closed Set DA -->./solver/SymmNetsV2SC_solver.py
16 | 3. For the Partial DA -->./solver/SSymmNetsV2Partial_solver.py
17 | 4. For the Open Set DA -->./solver/SymmNetsV2Open_solver.py
18 |
19 |
20 | ## Dataset
21 | The structure of the dataset should be like
22 |
23 | ```
24 | Office-31
25 | |_ amazon
26 | | |_ back_pack
27 | | |_ .jpg
28 | | |_ ...
29 | | |_ .jpg
30 | | |_ bike
31 | | |_ .jpg
32 | | |_ ...
33 | | |_ .jpg
34 | | |_ ...
35 | |_ dslr
36 | | |_ back_pack
37 | | |_ .jpg
38 | | |_ ...
39 | | |_ .jpg
40 | | |_ bike
41 | | |_ .jpg
42 | | |_ ...
43 | | |_ .jpg
44 | | |_ ...
45 | |_ ...
46 | ```
47 |
48 |
49 | ## Citation
50 |
51 | @inproceedings{zhang2019domain,
52 | title={Domain-symmetric networks for adversarial domain adaptation},
53 | author={Zhang, Yabin and Tang, Hui and Jia, Kui and Tan, Mingkui},
54 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
55 | pages={5031--5040},
56 | year={2019}
57 | }
58 | @article{zhang2020unsupervised,
59 | title={Unsupervised Multi-Class Domain Adaptation: Theory, Algorithms, and Practice},
60 | author={Zhang, Yabin and Deng, Bin and Tang, Hui and Zhang, Lei and Jia, Kui},
61 | journal=IEEE Transactions on Pattern Analysis and Machine Intelligence},
62 | year={2020}
63 | publisher={IEEE}
64 | }
65 |
66 | ## Contact
67 | If you have any problem about our code, feel free to contact
68 | - zhang.yabin@mail.scut.edu.cn
69 |
70 | or describe your problem in Issues.
71 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/config/__init__.py
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ipdb
3 | import numpy as np
4 | from easydict import EasyDict as edict
5 |
6 | __C = edict()
7 | cfg = __C
8 |
9 | # Dataset options
10 | #
11 | __C.DATASET = edict()
12 | __C.DATASET.NUM_CLASSES = 0
13 | __C.DATASET.DATASET = ''
14 | __C.DATASET.DATAROOT = ''
15 | __C.DATASET.SOURCE_NAME = ''
16 | __C.DATASET.TARGET_NAME = ''
17 | __C.DATASET.VAL_NAME = ''
18 |
19 | # Model options
20 | __C.MODEL = edict()
21 | __C.MODEL.FEATURE_EXTRACTOR = 'resnet101'
22 | __C.MODEL.PRETRAINED = True
23 | # __C.MODEL.FC_HIDDEN_DIMS = () ### options for multiple layers.
24 |
25 | # data pre-processing options
26 | #
27 | __C.DATA_TRANSFORM = edict()
28 | __C.DATA_TRANSFORM.TYPE = 'ours' ### trun to simple for the VisDA dataset.
29 |
30 |
31 | # Training options
32 | #
33 | __C.TRAIN = edict()
34 | # batch size setting
35 | __C.TRAIN.SOURCE_BATCH_SIZE = 128
36 | __C.TRAIN.TARGET_BATCH_SIZE = 128
37 |
38 | # learning rate schedule
39 | __C.TRAIN.BASE_LR = 0.01
40 | __C.TRAIN.MOMENTUM = 0.9
41 | __C.TRAIN.OPTIMIZER = 'SGD'
42 | __C.TRAIN.WEIGHT_DECAY = 0.0001
43 | __C.TRAIN.LR_SCHEDULE = 'inv'
44 | __C.TRAIN.MAX_EPOCH = 400
45 | __C.TRAIN.SAVING = False ## whether to save the intermediate status of model.
46 | __C.TRAIN.PROCESS_COUNTER = 'iteration'
47 | # __C.TRAIN.STOP_THRESHOLDS = (0.001, 0.001, 0.001)
48 | # __C.TRAIN.TEST_INTERVAL = 1.0 # percentage of total iterations each loop
49 | # __C.TRAIN.SAVE_CKPT_INTERVAL = 1.0 # percentage of total iterations in each loop
50 |
51 |
52 | __C.STRENGTHEN = edict()
53 | __C.STRENGTHEN.DATALOAD = 'normal' ## normal | hard | soft. The original class aware sampling adopt the hard mode.
54 | __C.STRENGTHEN.PERCATE = 10
55 | __C.STRENGTHEN.CLUSTER_FREQ = 6
56 |
57 |
58 |
59 | # optimizer options
60 | __C.MCDALNET = edict()
61 | __C.MCDALNET.DISTANCE_TYPE = '' ## choose in L1 | KL | CE | MDD | DANN | SourceOnly
62 |
63 | # optimizer options
64 | __C.ADAM = edict() ### adopted by the Digits dataset only
65 | __C.ADAM.BETA1 = 0.9
66 | __C.ADAM.BETA2 = 0.999
67 |
68 | __C.INV = edict()
69 | __C.INV.ALPHA = 10.0
70 | __C.INV.BETA = 0.75
71 |
72 | __C.OPEN = edict()
73 | __C.OPEN.WEIGHT_UNK = 6.0
74 |
75 | # Testing options
76 | #
77 | __C.TEST = edict()
78 | __C.TEST.BATCH_SIZE = 128
79 |
80 |
81 | # MISC
82 | __C.RESUME = ''
83 | __C.TASK = 'closed' ## closed | partial | open
84 | __C.EVAL_METRIC = "accu" # "mean_accu" as alternative
85 | __C.EXP_NAME = 'exp'
86 | __C.SAVE_DIR = ''
87 | __C.NUM_WORKERS = 6
88 | __C.PRINT_STEP = 3
89 |
90 | def _merge_a_into_b(a, b):
91 | """Merge config dictionary a into config dictionary b, clobbering the
92 | options in b whenever they are also specified in a.
93 | """
94 | if type(a) is not edict:
95 | return
96 |
97 | for k in a:
98 | # a must specify keys that are in b
99 | v = a[k]
100 | if k not in b:
101 | raise KeyError('{} is not a valid config key'.format(k))
102 |
103 | # the types must match, too
104 | old_type = type(b[k])
105 | if old_type is not type(v):
106 | if isinstance(b[k], np.ndarray):
107 | v = np.array(v, dtype=b[k].dtype)
108 | else:
109 | raise ValueError(('Type mismatch ({} vs. {}) '
110 | 'for config key: {}').format(type(b[k]),
111 | type(v), k))
112 |
113 | # recursively merge dicts
114 | if type(v) is edict:
115 | try:
116 | _merge_a_into_b(a[k], b[k])
117 | except:
118 | print('Error under config key: {}'.format(k))
119 | raise
120 | else:
121 | b[k] = v
122 |
123 | def cfg_from_file(filename):
124 | """Load a config file and merge it into the default options."""
125 | import yaml
126 | if filename[-1] == '\r':
127 | filename = filename[:-1] ## delete the '\r' at the end of the str
128 | with open(filename, 'r') as f:
129 | yaml_cfg = edict(yaml.load(f, Loader=yaml.FullLoader))
130 |
131 | _merge_a_into_b(yaml_cfg, __C)
132 |
133 | def cfg_from_list(cfg_list):
134 | """Set config keys via list (e.g., from command line)."""
135 | from ast import literal_eval
136 | assert len(cfg_list) % 2 == 0
137 | for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
138 | key_list = k.split('.')
139 | d = __C
140 | for subkey in key_list[:-1]:
141 | assert subkey in d
142 | d = d[subkey]
143 | subkey = key_list[-1]
144 | assert subkey in d
145 | try:
146 | value = literal_eval(v)
147 | except:
148 | # handle the case when v is a string literal
149 | value = v
150 | assert type(value) == type(d[subkey]), \
151 | 'type {} does not match original type {}'.format(
152 | type(value), type(d[subkey]))
153 | d[subkey] = value
--------------------------------------------------------------------------------
/data/domainnet_data_prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import ipdb
4 |
5 | data_file = ['clipart', 'painting', 'real', 'sketch', 'infograph', 'quickdraw']
6 | current_path = os.getcwd()
7 | for task in data_file:
8 | source_path = current_path + '/' + task
9 | source_train_path = current_path + '/' + task + '_train'
10 | if not os.path.isdir(source_train_path):
11 | os.makedirs(source_train_path)
12 | txt_train = current_path + '/' + task + '_train.txt'
13 | txt_file = open(txt_train)
14 | line = txt_file.readline()
15 | while line:
16 | image_path = line.split(' ')[0]
17 | image_path_split_list = image_path.split('/')
18 | source_image_path = source_path + '/' + image_path_split_list[1] + '/' + image_path_split_list[2]
19 | source_train_category = source_train_path + '/' + image_path_split_list[1]
20 | if not os.path.isdir(source_train_category):
21 | os.makedirs(source_train_category)
22 | source_train_image_path = source_train_category + '/' + image_path_split_list[2]
23 | print('copy image from %s -> %s' % (source_image_path, source_train_image_path))
24 | shutil.copyfile(source_image_path, source_train_image_path)
25 | line = txt_file.readline()
26 |
27 |
28 |
29 | source_test_path = current_path + '/' + task + '_test'
30 | if not os.path.isdir(source_test_path):
31 | os.makedirs(source_test_path)
32 | txt_test = current_path + '/' + task + '_test.txt'
33 | txt_file = open(txt_test)
34 | line = txt_file.readline()
35 | while line:
36 | image_path = line.split(' ')[0]
37 | image_path_split_list = image_path.split('/')
38 | source_image_path = source_path + '/' + image_path_split_list[1] + '/' + image_path_split_list[2]
39 | source_test_category = source_test_path + '/' + image_path_split_list[1]
40 | if not os.path.isdir(source_test_category):
41 | os.makedirs(source_test_category)
42 | source_test_image_path = source_test_category + '/' + image_path_split_list[2]
43 | print('copy image from %s -> %s' % (source_image_path, source_test_image_path))
44 | shutil.copyfile(source_image_path, source_test_image_path)
45 | line = txt_file.readline()
46 |
--------------------------------------------------------------------------------
/data/folder_mec.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that XXX
5 | ###############################################################################
6 | from data.vision import VisionDataset
7 |
8 | from PIL import Image
9 |
10 | import os
11 | import os.path
12 | import sys
13 | ########################### added part ###
14 | from PIL import ImageFile
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True ## used to handle some error when loading the special images.
16 |
17 | def has_file_allowed_extension(filename, extensions):
18 | """Checks if a file is an allowed extension.
19 | Args:
20 | filename (string): path to a file
21 | extensions (tuple of strings): extensions to consider (lowercase)
22 | Returns:
23 | bool: True if the filename ends with one of given extensions
24 | """
25 | return filename.lower().endswith(extensions)
26 |
27 |
28 | def is_image_file(filename):
29 | """Checks if a file is an allowed image extension.
30 | Args:
31 | filename (string): path to a file
32 | Returns:
33 | bool: True if the filename ends with a known image extension
34 | """
35 | return has_file_allowed_extension(filename, IMG_EXTENSIONS)
36 |
37 |
38 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
39 | images = []
40 | dir = os.path.expanduser(dir)
41 | if not ((extensions is None) ^ (is_valid_file is None)):
42 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
43 | if extensions is not None:
44 | def is_valid_file(x):
45 | return has_file_allowed_extension(x, extensions)
46 | for target in sorted(class_to_idx.keys()):
47 | d = os.path.join(dir, target)
48 | if not os.path.isdir(d):
49 | continue
50 | for root, _, fnames in sorted(os.walk(d, followlinks=True)):
51 | for fname in sorted(fnames):
52 | path = os.path.join(root, fname)
53 | if is_valid_file(path):
54 | item = (path, class_to_idx[target])
55 | images.append(item)
56 |
57 | return images
58 |
59 |
60 | class DatasetFolder(VisionDataset):
61 | """A generic data loader where the samples are arranged in this way: ::
62 | root/class_x/xxx.ext
63 | root/class_x/xxy.ext
64 | root/class_x/xxz.ext
65 | root/class_y/123.ext
66 | root/class_y/nsdf3.ext
67 | root/class_y/asd932_.ext
68 | Args:
69 | root (string): Root directory path.
70 | loader (callable): A function to load a sample given its path.
71 | extensions (tuple[string]): A list of allowed extensions.
72 | both extensions and is_valid_file should not be passed.
73 | transform (callable, optional): A function/transform that takes in
74 | a sample and returns a transformed version.
75 | E.g, ``transforms.RandomCrop`` for images.
76 | target_transform (callable, optional): A function/transform that takes
77 | in the target and transforms it.
78 | is_valid_file (callable, optional): A function that takes path of a file
79 | and check if the file is a valid file (used to check of corrupt files)
80 | both extensions and is_valid_file should not be passed.
81 | Attributes:
82 | classes (list): List of the class names.
83 | class_to_idx (dict): Dict with items (class_name, class_index).
84 | samples (list): List of (sample path, class_index) tuples
85 | targets (list): The class_index value for each image in the dataset
86 | """
87 |
88 | def __init__(self, root, loader, extensions=None, transform=None, transform_mec=None,
89 | target_transform=None, is_valid_file=None):
90 | super(DatasetFolder, self).__init__(root, transform=transform,
91 | target_transform=target_transform)
92 | classes, class_to_idx = self._find_classes(self.root)
93 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
94 | if len(samples) == 0:
95 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
96 | "Supported extensions are: " + ",".join(extensions)))
97 |
98 | self.loader = loader
99 | self.extensions = extensions
100 |
101 | self.classes = classes
102 | self.class_to_idx = class_to_idx
103 | self.samples = samples
104 | self.targets = [s[1] for s in samples]
105 | self.transform_mec = transform_mec
106 |
107 | def _find_classes(self, dir):
108 | """
109 | Finds the class folders in a dataset.
110 | Args:
111 | dir (string): Root directory path.
112 | Returns:
113 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
114 | Ensures:
115 | No class is a subdirectory of another.
116 | """
117 | if sys.version_info >= (3, 5):
118 | # Faster and available in Python 3.5 and above
119 | classes = [d.name for d in os.scandir(dir) if d.is_dir()]
120 | else:
121 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
122 | classes.sort()
123 | class_to_idx = {classes[i]: i for i in range(len(classes))}
124 | return classes, class_to_idx
125 |
126 | def __getitem__(self, index):
127 | """
128 | Args:
129 | index (int): Index
130 | Returns:
131 | tuple: (sample, target) where target is class_index of the target class.
132 | """
133 | path, target = self.samples[index]
134 | sample = self.loader(path)
135 | if self.transform is not None:
136 | sample_ori = self.transform(sample)
137 | if self.transform_mec is not None:
138 | sample_mec = self.transform(sample)
139 | if self.target_transform is not None:
140 | target = self.target_transform(target)
141 |
142 | return sample_ori, sample_mec, target, path
143 |
144 | def __len__(self):
145 | return len(self.samples)
146 |
147 |
148 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
149 |
150 |
151 | def pil_loader(path):
152 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
153 | with open(path, 'rb') as f:
154 | img = Image.open(f)
155 | return img.convert('RGB')
156 |
157 |
158 | def accimage_loader(path):
159 | import accimage
160 | try:
161 | return accimage.Image(path)
162 | except IOError:
163 | # Potentially a decoding problem, fall back to PIL.Image
164 | return pil_loader(path)
165 |
166 |
167 | def default_loader(path):
168 | from torchvision import get_image_backend
169 | if get_image_backend() == 'accimage':
170 | return accimage_loader(path)
171 | else:
172 | return pil_loader(path)
173 |
174 |
175 | class ImageFolder_MEC(DatasetFolder):
176 | """A generic data loader where the images are arranged in this way: ::
177 | root/dog/xxx.png
178 | root/dog/xxy.png
179 | root/dog/xxz.png
180 | root/cat/123.png
181 | root/cat/nsdf3.png
182 | root/cat/asd932_.png
183 | Args:
184 | root (string): Root directory path.
185 | transform (callable, optional): A function/transform that takes in an PIL image
186 | and returns a transformed version. E.g, ``transforms.RandomCrop``
187 | target_transform (callable, optional): A function/transform that takes in the
188 | target and transforms it.
189 | loader (callable, optional): A function to load an image given its path.
190 | is_valid_file (callable, optional): A function that takes path of an Image file
191 | and check if the file is a valid file (used to check of corrupt files)
192 | Attributes:
193 | classes (list): List of the class names.
194 | class_to_idx (dict): Dict with items (class_name, class_index).
195 | imgs (list): List of (image path, class_index) tuples
196 | """
197 |
198 | def __init__(self, root, transform=None, transform_mec=None, target_transform=None,
199 | loader=default_loader, is_valid_file=None):
200 | super(ImageFolder_MEC, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
201 | transform=transform,
202 | transform_mec=transform_mec,
203 | target_transform=target_transform,
204 | is_valid_file=is_valid_file)
205 | self.imgs = self.samples
--------------------------------------------------------------------------------
/data/folder_new.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that XXX
5 | ###############################################################################
6 | from data.vision import VisionDataset
7 |
8 | from PIL import Image
9 |
10 | import os
11 | import os.path
12 | import sys
13 | ########################### added part ###
14 | from PIL import ImageFile
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True ## used to handle some error when loading the special images.
16 |
17 | def has_file_allowed_extension(filename, extensions):
18 | """Checks if a file is an allowed extension.
19 | Args:
20 | filename (string): path to a file
21 | extensions (tuple of strings): extensions to consider (lowercase)
22 | Returns:
23 | bool: True if the filename ends with one of given extensions
24 | """
25 | return filename.lower().endswith(extensions)
26 |
27 |
28 | def is_image_file(filename):
29 | """Checks if a file is an allowed image extension.
30 | Args:
31 | filename (string): path to a file
32 | Returns:
33 | bool: True if the filename ends with a known image extension
34 | """
35 | return has_file_allowed_extension(filename, IMG_EXTENSIONS)
36 |
37 |
38 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
39 | images = []
40 | dir = os.path.expanduser(dir)
41 | if not ((extensions is None) ^ (is_valid_file is None)):
42 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
43 | if extensions is not None:
44 | def is_valid_file(x):
45 | return has_file_allowed_extension(x, extensions)
46 | for target in sorted(class_to_idx.keys()):
47 | d = os.path.join(dir, target)
48 | if not os.path.isdir(d):
49 | continue
50 | for root, _, fnames in sorted(os.walk(d, followlinks=True)):
51 | for fname in sorted(fnames):
52 | path = os.path.join(root, fname)
53 | if is_valid_file(path):
54 | item = (path, class_to_idx[target])
55 | images.append(item)
56 |
57 | return images
58 |
59 |
60 | class DatasetFolder(VisionDataset):
61 | """A generic data loader where the samples are arranged in this way: ::
62 | root/class_x/xxx.ext
63 | root/class_x/xxy.ext
64 | root/class_x/xxz.ext
65 | root/class_y/123.ext
66 | root/class_y/nsdf3.ext
67 | root/class_y/asd932_.ext
68 | Args:
69 | root (string): Root directory path.
70 | loader (callable): A function to load a sample given its path.
71 | extensions (tuple[string]): A list of allowed extensions.
72 | both extensions and is_valid_file should not be passed.
73 | transform (callable, optional): A function/transform that takes in
74 | a sample and returns a transformed version.
75 | E.g, ``transforms.RandomCrop`` for images.
76 | target_transform (callable, optional): A function/transform that takes
77 | in the target and transforms it.
78 | is_valid_file (callable, optional): A function that takes path of a file
79 | and check if the file is a valid file (used to check of corrupt files)
80 | both extensions and is_valid_file should not be passed.
81 | Attributes:
82 | classes (list): List of the class names.
83 | class_to_idx (dict): Dict with items (class_name, class_index).
84 | samples (list): List of (sample path, class_index) tuples
85 | targets (list): The class_index value for each image in the dataset
86 | """
87 |
88 | def __init__(self, root, loader, extensions=None, transform=None,
89 | target_transform=None, is_valid_file=None):
90 | super(DatasetFolder, self).__init__(root, transform=transform,
91 | target_transform=target_transform)
92 | classes, class_to_idx = self._find_classes(self.root)
93 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
94 | if len(samples) == 0:
95 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
96 | "Supported extensions are: " + ",".join(extensions)))
97 |
98 | self.loader = loader
99 | self.extensions = extensions
100 |
101 | self.classes = classes
102 | self.class_to_idx = class_to_idx
103 | self.samples = samples
104 | self.targets = [s[1] for s in samples]
105 |
106 | def _find_classes(self, dir):
107 | """
108 | Finds the class folders in a dataset.
109 | Args:
110 | dir (string): Root directory path.
111 | Returns:
112 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
113 | Ensures:
114 | No class is a subdirectory of another.
115 | """
116 | if sys.version_info >= (3, 5):
117 | # Faster and available in Python 3.5 and above
118 | classes = [d.name for d in os.scandir(dir) if d.is_dir()]
119 | else:
120 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
121 | classes.sort()
122 | class_to_idx = {classes[i]: i for i in range(len(classes))}
123 | return classes, class_to_idx
124 |
125 | def __getitem__(self, index):
126 | """
127 | Args:
128 | index (int): Index
129 | Returns:
130 | tuple: (sample, target) where target is class_index of the target class.
131 | """
132 | path, target = self.samples[index]
133 | sample = self.loader(path)
134 | if self.transform is not None:
135 | sample = self.transform(sample)
136 | if self.target_transform is not None:
137 | target = self.target_transform(target)
138 |
139 | return sample, target, path
140 |
141 | def __len__(self):
142 | return len(self.samples)
143 |
144 |
# Lower-case file extensions accepted as images by the default dataset filter.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
146 |
147 |
def pil_loader(path):
    """Load the image at ``path`` with PIL and return it converted to RGB."""
    # Open via an explicit file object to avoid a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        image = Image.open(handle)
        return image.convert('RGB')
153 |
154 |
def accimage_loader(path):
    """Load the image at ``path`` with accimage, falling back to PIL on failure."""
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # accimage can fail on decoding problems; PIL is the safe fallback.
        return pil_loader(path)
162 |
163 |
def default_loader(path):
    """Load an image using the backend configured in torchvision settings."""
    from torchvision import get_image_backend
    backend = get_image_backend()
    if backend == 'accimage':
        return accimage_loader(path)
    return pil_loader(path)
170 |
171 |
class ImageFolder_Withpath(DatasetFolder):
    """An image dataset whose samples are arranged one folder per class: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version,
            e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes
            in the target (class index) and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes the path of
            an image file and checks whether it is a valid file (used to
            filter out corrupt files).
    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, is_valid_file=None):
        # Extension filtering and `is_valid_file` are mutually exclusive: pass
        # the default extension whitelist only when no predicate is supplied.
        extensions = IMG_EXTENSIONS if is_valid_file is None else None
        super(ImageFolder_Withpath, self).__init__(root, loader, extensions,
                                                   transform=transform,
                                                   target_transform=target_transform,
                                                   is_valid_file=is_valid_file)
        # Alias kept for code that expects the torchvision `imgs` attribute.
        self.imgs = self.samples
--------------------------------------------------------------------------------
/data/vision.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
# Modified from the original torchvision code for use in this project.
5 | ###############################################################################
6 | import os
7 | import torch
8 | import torch.utils.data as data
9 |
10 |
class VisionDataset(data.Dataset):
    """Base class for vision datasets: handles ``root`` and transform wiring.

    Exactly one of ``transforms`` or the pair ``transform`` /
    ``target_transform`` may be supplied; mixing them raises ``ValueError``.

    Args:
        root (string): Root directory of the dataset; ``~`` is expanded.
        transforms (callable, optional): Joint transform taking
            ``(input, target)``.
        transform (callable, optional): Transform applied to the input only.
        target_transform (callable, optional): Transform applied to the
            target only.

    Raises:
        ValueError: If ``transforms`` is combined with ``transform`` or
            ``target_transform``.
    """

    # Number of spaces used to indent body lines in __repr__.
    _repr_indent = 4

    def __init__(self, root, transforms=None, transform=None, target_transform=None):
        # Bug fix: `torch._six.string_classes` was removed in torch >= 1.11 and
        # made this constructor crash on modern torch. On Python 3 it was
        # simply (str,), so testing against str directly is equivalent.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can "
                             "be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            # Wrap the separate transforms into a single joint transform.
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index):
        # Subclasses must return the (sample, target) pair for `index`.
        raise NotImplementedError

    def __len__(self):
        # Subclasses must return the number of samples.
        raise NotImplementedError

    def __repr__(self):
        """Return a multi-line summary: class name, size, root and transforms."""
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        # Prefix the first repr line with `head` and pad continuation lines
        # so they align underneath it.
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def extra_repr(self):
        """Hook for subclasses to append extra lines to __repr__."""
        return ""
57 |
58 |
class StandardTransform(object):
    """Pairs an input transform with a target transform, applied jointly.

    Either transform may be ``None``, in which case the corresponding value
    passes through unchanged.
    """

    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input, target):
        """Apply the stored transforms and return the (input, target) pair."""
        transform = self.transform
        target_transform = self.target_transform
        if transform is not None:
            input = transform(input)
        if target_transform is not None:
            target = target_transform(target)
        return input, target

    def _format_transform_repr(self, transform, head):
        # First repr line carries `head`; continuation lines are padded so
        # they align underneath it.
        lines = repr(transform).splitlines()
        pad = " " * len(head)
        formatted = ["{}{}".format(head, lines[0])]
        for extra in lines[1:]:
            formatted.append("{}{}".format(pad, extra))
        return formatted

    def __repr__(self):
        body = [self.__class__.__name__]
        if self.transform is not None:
            body.extend(self._format_transform_repr(self.transform,
                                                    "Transform: "))
        if self.target_transform is not None:
            body.extend(self._format_transform_repr(self.target_transform,
                                                    "Target transform: "))

        return '\n'.join(body)
--------------------------------------------------------------------------------
/experiments/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/.DS_Store
--------------------------------------------------------------------------------
/experiments/Clarification_on_tasks.md:
--------------------------------------------------------------------------------
1 | # Clarification on tasks
There are many experiments associated with this project, so it is easy to run an experiment with the
wrong code. We therefore provide a clarification here.
4 |
5 | ## Examples
6 | We provide some script examples in the 'run_temp.sh', including scripts of:
1. closed set UDA
2. partial DA
3. open set DA
4. SymmNets-V2 Strengthened for Closed Set UDA
11 |
12 | We also provide the corresponding log file in the ./experiments. More specifically,
13 | 1. The results of McDalNets based on the c2i of ImageCLEF and Visda are stored in ./experiments/ckpt
14 | 2. The results of partial da and open set da are stored in ./experiments/Partial and ./experiments/Open, respectively.
3. The results of SymmNets and SymmNets-SC are stored in ./experiments/SymmNets
16 |
17 | ## Configs
18 | We provide configs for example experiments in the ./experiments/configs.
19 | The rule of name is:
20 | 1. 'dataset' _
21 | 2. 'train' _
22 | 3. 'source domain' 2 'target domain'
23 | 4. _ 'task settings (closed set for default | partial | open)'
24 | 5. _cfg
25 | 6. SC (only for Symmnets-SC)
26 |
27 | Note that the configs for different tasks settings may be different.
28 |
29 | ## Solver
The solvers for all task settings are provided in the ./solver directory.
31 |
--------------------------------------------------------------------------------
/experiments/SymmNets/logVisDA_train2validation_SymmNetsV2_None/log.txt:
--------------------------------------------------------------------------------
1 |
2 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "None"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "TASK": "closed", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_SymmNetsV2_None", "SAVE_DIR": "./experiments/SymmNets/logVisDA_train2validation_SymmNetsV2_None", "NUM_WORKERS": 8, "PRINT_STEP": 3}
3 |
4 | Train:epoch: 0:[432/432], LossCla: 2.754657, LossFeat: 1.551413, AccFs: 85.244865, AccFt: 82.904732
5 | Test:epoch: 0, AccFs: 0.519134, AccFt: 0.588042
6 | Class-wise Acc of Ft:1st: 0.738069, 2nd: 0.423309, 3rd: 0.680597, 4th: 0.608307, 5th: 0.794074, 6th: 0.342169, 7th: 0.839890, 8th: 0.527250, 9th: 0.733568, 10th: 0.301622, 11th: 0.936969, 12th: 0.130678 Best Acc so far: 0.588042
7 | Train:epoch: 1:[432/432], LossCla: 1.635495, LossFeat: 1.372633, AccFs: 92.420792, AccFt: 91.610603
8 | Test:epoch: 1, AccFs: 0.599373, AccFt: 0.639031
9 | Class-wise Acc of Ft:1st: 0.831322, 2nd: 0.523453, 3rd: 0.690832, 4th: 0.563311, 5th: 0.797911, 6th: 0.400482, 7th: 0.841615, 8th: 0.746250, 9th: 0.762805, 10th: 0.271372, 11th: 0.929178, 12th: 0.309841 Best Acc so far: 0.639031
10 | Train:epoch: 2:[432/432], LossCla: 1.503049, LossFeat: 1.370504, AccFs: 93.858505, AccFt: 93.198425
11 | Test:epoch: 2, AccFs: 0.647094, AccFt: 0.670877
12 | Class-wise Acc of Ft:1st: 0.877126, 2nd: 0.494676, 3rd: 0.696375, 4th: 0.598308, 5th: 0.872522, 6th: 0.639036, 7th: 0.883023, 8th: 0.750500, 9th: 0.813585, 10th: 0.290224, 11th: 0.925165, 12th: 0.209986 Best Acc so far: 0.670877
13 | Train:epoch: 3:[432/432], LossCla: 1.420064, LossFeat: 1.361121, AccFs: 95.155167, AccFt: 94.270836
14 | Test:epoch: 3, AccFs: 0.649891, AccFt: 0.682259
15 | Class-wise Acc of Ft:1st: 0.862041, 2nd: 0.507050, 3rd: 0.713859, 4th: 0.601000, 5th: 0.872735, 6th: 0.683855, 7th: 0.874741, 8th: 0.759500, 9th: 0.799077, 10th: 0.289347, 11th: 0.924693, 12th: 0.299207 Best Acc so far: 0.682259
16 | Train:epoch: 4:[432/432], LossCla: 1.366017, LossFeat: 1.360252, AccFs: 95.775467, AccFt: 94.977936
17 | Test:epoch: 4, AccFs: 0.672196, AccFt: 0.688111
18 | Class-wise Acc of Ft:1st: 0.859572, 2nd: 0.545324, 3rd: 0.715139, 4th: 0.648880, 5th: 0.864848, 6th: 0.631807, 7th: 0.878019, 8th: 0.801500, 9th: 0.844142, 10th: 0.299868, 11th: 0.913362, 12th: 0.254867 Best Acc so far: 0.688111
19 | Train:epoch: 5:[432/432], LossCla: 1.349571, LossFeat: 1.359669, AccFs: 96.155235, AccFt: 95.668762
20 | Test:epoch: 5, AccFs: 0.678105, AccFt: 0.690267
21 | Class-wise Acc of Ft:1st: 0.846133, 2nd: 0.551655, 3rd: 0.736247, 4th: 0.672531, 5th: 0.856321, 6th: 0.645301, 7th: 0.887336, 8th: 0.794250, 9th: 0.846560, 10th: 0.296361, 11th: 0.903211, 12th: 0.247296 Best Acc so far: 0.690267
22 | Train:epoch: 6:[432/432], LossCla: 1.332122, LossFeat: 1.351513, AccFs: 96.715858, AccFt: 96.160660
23 | Test:epoch: 6, AccFs: 0.690127, AccFt: 0.698896
24 | Class-wise Acc of Ft:1st: 0.854635, 2nd: 0.584173, 3rd: 0.750320, 4th: 0.670609, 5th: 0.884886, 6th: 0.630843, 7th: 0.902346, 8th: 0.785250, 9th: 0.858211, 10th: 0.303376, 11th: 0.886686, 12th: 0.275415 Best Acc so far: 0.698896
25 | Train:epoch: 7:[432/432], LossCla: 1.328282, LossFeat: 1.341111, AccFs: 97.158928, AccFt: 96.784576
26 | Test:epoch: 7, AccFs: 0.692217, AccFt: 0.703444
27 | Class-wise Acc of Ft:1st: 0.862041, 2nd: 0.560576, 3rd: 0.725586, 4th: 0.638304, 5th: 0.882115, 6th: 0.688675, 7th: 0.886128, 8th: 0.796000, 9th: 0.876017, 10th: 0.337133, 11th: 0.901086, 12th: 0.287671 Best Acc so far: 0.703444
28 | Train:epoch: 8:[432/432], LossCla: 1.344697, LossFeat: 1.339598, AccFs: 97.298180, AccFt: 97.012444
29 | Test:epoch: 8, AccFs: 0.699641, AccFt: 0.711915
30 | Class-wise Acc of Ft:1st: 0.874931, 2nd: 0.679712, 3rd: 0.745842, 4th: 0.639169, 5th: 0.872735, 6th: 0.667952, 7th: 0.877674, 8th: 0.811750, 9th: 0.877555, 10th: 0.335818, 11th: 0.898489, 12th: 0.261355 Best Acc so far: 0.711915
31 | Train:epoch: 9:[432/432], LossCla: 1.335095, LossFeat: 1.324822, AccFs: 97.696037, AccFt: 97.350624
32 | Test:epoch: 9, AccFs: 0.703786, AccFt: 0.711389
33 | Class-wise Acc of Ft:1st: 0.874657, 2nd: 0.669353, 3rd: 0.749680, 4th: 0.667820, 5th: 0.859731, 6th: 0.617831, 7th: 0.885783, 8th: 0.802250, 9th: 0.880633, 10th: 0.372205, 11th: 0.888338, 12th: 0.268385
34 | Train:epoch: 10:[432/432], LossCla: 1.346777, LossFeat: 1.320226, AccFs: 97.685188, AccFt: 97.417534
35 | Test:epoch: 10, AccFs: 0.712391, AccFt: 0.717691
36 | Class-wise Acc of Ft:1st: 0.873012, 2nd: 0.709065, 3rd: 0.751599, 4th: 0.666282, 5th: 0.850352, 6th: 0.635663, 7th: 0.858696, 8th: 0.784000, 9th: 0.911629, 10th: 0.432267, 11th: 0.894240, 12th: 0.245494 Best Acc so far: 0.717691
37 | Train:epoch: 11:[432/432], LossCla: 1.343779, LossFeat: 1.302535, AccFs: 97.918472, AccFt: 97.652634
38 | Test:epoch: 11, AccFs: 0.699865, AccFt: 0.707937
39 | Class-wise Acc of Ft:1st: 0.864783, 2nd: 0.634532, 3rd: 0.753518, 4th: 0.699164, 5th: 0.855042, 6th: 0.616386, 7th: 0.891994, 8th: 0.741750, 9th: 0.872939, 10th: 0.463393, 11th: 0.885269, 12th: 0.216474
40 | Train:epoch: 12:[432/432], LossCla: 1.353641, LossFeat: 1.292348, AccFs: 97.931137, AccFt: 97.688805
41 | Test:epoch: 12, AccFs: 0.701588, AccFt: 0.705192
42 | Class-wise Acc of Ft:1st: 0.844213, 2nd: 0.669640, 3rd: 0.751386, 4th: 0.692145, 5th: 0.840119, 6th: 0.666024, 7th: 0.860593, 8th: 0.762250, 9th: 0.841284, 10th: 0.443665, 11th: 0.902266, 12th: 0.188717
43 | Train:epoch: 13:[432/432], LossCla: 1.342840, LossFeat: 1.280037, AccFs: 98.263893, AccFt: 97.963684
44 | Test:epoch: 13, AccFs: 0.709544, AccFt: 0.708189
45 | Class-wise Acc of Ft:1st: 0.845584, 2nd: 0.605755, 3rd: 0.749680, 4th: 0.673108, 5th: 0.844383, 6th: 0.728193, 7th: 0.861111, 8th: 0.763750, 9th: 0.903495, 10th: 0.419991, 11th: 0.903683, 12th: 0.199531
46 | Train:epoch: 14:[432/432], LossCla: 1.344371, LossFeat: 1.279572, AccFs: 98.209633, AccFt: 97.987198
47 | Test:epoch: 14, AccFs: 0.716340, AccFt: 0.712602
48 | Class-wise Acc of Ft:1st: 0.820351, 2nd: 0.650935, 3rd: 0.768657, 4th: 0.697241, 5th: 0.832658, 6th: 0.746506, 7th: 0.849034, 8th: 0.758000, 9th: 0.880633, 10th: 0.469969, 11th: 0.888338, 12th: 0.188897
49 | Train:epoch: 15:[432/432], LossCla: 1.346274, LossFeat: 1.272384, AccFs: 98.457397, AccFt: 98.126450
50 | Test:epoch: 15, AccFs: 0.712162, AccFt: 0.708201
51 | Class-wise Acc of Ft:1st: 0.843664, 2nd: 0.557410, 3rd: 0.772281, 4th: 0.678396, 5th: 0.826903, 6th: 0.747952, 7th: 0.873361, 8th: 0.739750, 9th: 0.873599, 10th: 0.483560, 11th: 0.893532, 12th: 0.208003
52 | Train:epoch: 16:[432/432], LossCla: 1.345689, LossFeat: 1.257215, AccFs: 98.564095, AccFt: 98.260269
53 | Test:epoch: 16, AccFs: 0.712009, AccFt: 0.705590
54 | Class-wise Acc of Ft:1st: 0.797038, 2nd: 0.629353, 3rd: 0.774627, 4th: 0.679262, 5th: 0.829247, 6th: 0.701205, 7th: 0.847654, 8th: 0.751500, 9th: 0.859530, 10th: 0.508549, 11th: 0.895892, 12th: 0.193223
55 | Train:epoch: 17:[432/432], LossCla: 1.348905, LossFeat: 1.249159, AccFs: 98.661751, AccFt: 98.365166
56 | Test:epoch: 17, AccFs: 0.716529, AccFt: 0.709276
57 | Class-wise Acc of Ft:1st: 0.794569, 2nd: 0.633669, 3rd: 0.785501, 4th: 0.690607, 5th: 0.792368, 6th: 0.676145, 7th: 0.860766, 8th: 0.732500, 9th: 0.898000, 10th: 0.525647, 11th: 0.886686, 12th: 0.234859
58 | Train:epoch: 18:[432/432], LossCla: 1.348393, LossFeat: 1.240162, AccFs: 98.710579, AccFt: 98.462822
59 | Test:epoch: 18, AccFs: 0.714328, AccFt: 0.709291
60 | Class-wise Acc of Ft:1st: 0.845310, 2nd: 0.524892, 3rd: 0.768657, 4th: 0.696952, 5th: 0.825197, 6th: 0.695904, 7th: 0.893375, 8th: 0.751250, 9th: 0.864805, 10th: 0.542744, 11th: 0.885033, 12th: 0.217376
61 | Train:epoch: 19:[432/432], LossCla: 1.341572, LossFeat: 1.234422, AccFs: 98.885994, AccFt: 98.594833
62 | Test:epoch: 19, AccFs: 0.707589, AccFt: 0.696340
63 | Class-wise Acc of Ft:1st: 0.782501, 2nd: 0.585036, 3rd: 0.765672, 4th: 0.680127, 5th: 0.797485, 6th: 0.673735, 7th: 0.880090, 8th: 0.731500, 9th: 0.856672, 10th: 0.508110, 11th: 0.910765, 12th: 0.184391
64 | Train:epoch: 20:[432/432], LossCla: 1.345093, LossFeat: 1.230574, AccFs: 98.878761, AccFt: 98.614731
65 | Test:epoch: 20, AccFs: 0.723617, AccFt: 0.713527
66 | Class-wise Acc of Ft:1st: 0.818431, 2nd: 0.611511, 3rd: 0.766524, 4th: 0.696375, 5th: 0.797485, 6th: 0.732530, 7th: 0.886473, 8th: 0.718750, 9th: 0.881073, 10th: 0.539676, 11th: 0.902974, 12th: 0.210526 Best Acc so far: 0.723617
67 | Train:epoch: 21:[432/432], LossCla: 1.345878, LossFeat: 1.223302, AccFs: 98.931206, AccFt: 98.656326
68 | Test:epoch: 21, AccFs: 0.717828, AccFt: 0.709226
69 | Class-wise Acc of Ft:1st: 0.781953, 2nd: 0.615540, 3rd: 0.799787, 4th: 0.647918, 5th: 0.831592, 6th: 0.708434, 7th: 0.834714, 8th: 0.722250, 9th: 0.870081, 10th: 0.607628, 11th: 0.904627, 12th: 0.186193
70 | Train:epoch: 22:[432/432], LossCla: 1.346019, LossFeat: 1.226002, AccFs: 99.005356, AccFt: 98.692490
71 | Test:epoch: 22, AccFs: 0.711813, AccFt: 0.699680
72 | Class-wise Acc of Ft:1st: 0.801700, 2nd: 0.569784, 3rd: 0.772068, 4th: 0.677435, 5th: 0.817310, 6th: 0.704578, 7th: 0.890959, 8th: 0.687500, 9th: 0.848098, 10th: 0.524770, 11th: 0.903683, 12th: 0.198270
73 | Train:epoch: 23:[432/432], LossCla: 1.345264, LossFeat: 1.223078, AccFs: 99.099396, AccFt: 98.853447
74 | Test:epoch: 23, AccFs: 0.705851, AccFt: 0.694408
75 | Class-wise Acc of Ft:1st: 0.756445, 2nd: 0.572086, 3rd: 0.785501, 4th: 0.678877, 5th: 0.791729, 6th: 0.684819, 7th: 0.864044, 8th: 0.663750, 9th: 0.872060, 10th: 0.539676, 11th: 0.898961, 12th: 0.224946
76 | Train:epoch: 24:[432/432], LossCla: 1.341958, LossFeat: 1.216860, AccFs: 99.148224, AccFt: 98.867912
77 | Test:epoch: 24, AccFs: 0.723894, AccFt: 0.714709
78 | Class-wise Acc of Ft:1st: 0.783873, 2nd: 0.568345, 3rd: 0.778678, 4th: 0.689068, 5th: 0.829247, 6th: 0.746506, 7th: 0.884748, 8th: 0.697000, 9th: 0.883931, 10th: 0.574748, 11th: 0.889282, 12th: 0.251081 Best Acc so far: 0.723894
79 | Train:epoch: 25:[432/432], LossCla: 1.332571, LossFeat: 1.214391, AccFs: 99.168114, AccFt: 98.923973
80 | Test:epoch: 25, AccFs: 0.702965, AccFt: 0.693480
81 | Class-wise Acc of Ft:1st: 0.745200, 2nd: 0.549928, 3rd: 0.769723, 4th: 0.616768, 5th: 0.819868, 6th: 0.718072, 7th: 0.893720, 8th: 0.664750, 9th: 0.897780, 10th: 0.505918, 11th: 0.910765, 12th: 0.229272
82 | Train:epoch: 26:[432/432], LossCla: 1.329841, LossFeat: 1.215475, AccFs: 99.144608, AccFt: 98.972801
83 | Test:epoch: 26, AccFs: 0.718568, AccFt: 0.707460
84 | Class-wise Acc of Ft:1st: 0.764948, 2nd: 0.593669, 3rd: 0.784648, 4th: 0.677146, 5th: 0.791303, 6th: 0.762892, 7th: 0.869220, 8th: 0.690750, 9th: 0.903495, 10th: 0.533538, 11th: 0.875118, 12th: 0.242790
85 | Train:epoch: 27:[432/432], LossCla: 1.329338, LossFeat: 1.212966, AccFs: 99.200668, AccFt: 98.956528
86 | Test:epoch: 27, AccFs: 0.716024, AccFt: 0.706511
87 | Class-wise Acc of Ft:1st: 0.758914, 2nd: 0.635108, 3rd: 0.792111, 4th: 0.683300, 5th: 0.802388, 6th: 0.738795, 7th: 0.871463, 8th: 0.692500, 9th: 0.877555, 10th: 0.497589, 11th: 0.882200, 12th: 0.246215
88 | Train:epoch: 28:[432/432], LossCla: 1.328192, LossFeat: 1.203063, AccFs: 99.339920, AccFt: 99.083115
89 | Test:epoch: 28, AccFs: 0.710096, AccFt: 0.700704
90 | Class-wise Acc of Ft:1st: 0.802249, 2nd: 0.567770, 3rd: 0.795096, 4th: 0.675416, 5th: 0.799616, 6th: 0.671807, 7th: 0.892685, 8th: 0.663250, 9th: 0.910310, 10th: 0.530469, 11th: 0.887630, 12th: 0.212149
91 | Train:epoch: 29:[432/432], LossCla: 1.329534, LossFeat: 1.206706, AccFs: 99.336304, AccFt: 99.115669
92 | Test:epoch: 29, AccFs: 0.721606, AccFt: 0.709754
93 | Class-wise Acc of Ft:1st: 0.792924, 2nd: 0.633669, 3rd: 0.791045, 4th: 0.668493, 5th: 0.822000, 6th: 0.735422, 7th: 0.881125, 8th: 0.666250, 9th: 0.864586, 10th: 0.555897, 11th: 0.899433, 12th: 0.206200
94 | Train:epoch: 30:[432/432], LossCla: 1.331311, LossFeat: 1.210111, AccFs: 99.381508, AccFt: 99.055992
95 | Test:epoch: 30, AccFs: 0.710322, AccFt: 0.697093
96 | Class-wise Acc of Ft:1st: 0.782776, 2nd: 0.529784, 3rd: 0.742217, 4th: 0.669647, 5th: 0.795992, 6th: 0.749880, 7th: 0.891994, 8th: 0.671250, 9th: 0.843042, 10th: 0.544498, 11th: 0.919263, 12th: 0.224766
97 | Train:epoch: 31:[432/432], LossCla: 1.326631, LossFeat: 1.207292, AccFs: 99.415871, AccFt: 99.198860
98 | Test:epoch: 31, AccFs: 0.711356, AccFt: 0.698533
99 | Class-wise Acc of Ft:1st: 0.820351, 2nd: 0.520000, 3rd: 0.770149, 4th: 0.675320, 5th: 0.760392, 6th: 0.761928, 7th: 0.868012, 8th: 0.669250, 9th: 0.819741, 10th: 0.562473, 11th: 0.909112, 12th: 0.245674
100 |
--------------------------------------------------------------------------------
/experiments/SymmNets/logres101VisDA_train2validation_SymmNetsV2_None/log.txt:
--------------------------------------------------------------------------------
1 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/data1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet101", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 40, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "None"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "TASK": "closed", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logres101\rVisDA_train2validation_SymmNetsV2_None", "SAVE_DIR": "./experiments/SymmNets/logres101\rVisDA_train2validation_SymmNetsV2_None", "NUM_WORKERS": 8, "PRINT_STEP": 3}
2 |
3 | Train:epoch: 0:[432/432], LossCla: 2.581147, LossFeat: 1.475838, AccFs: 86.642799, AccFt: 84.563080
4 | Test:epoch: 0, AccFs: 0.603141, AccFt: 0.649918
5 | Class-wise Acc of Ft:1st: 0.873834, 2nd: 0.471367, 3rd: 0.728998, 4th: 0.609557, 5th: 0.770625, 6th: 0.629398, 7th: 0.898378, 8th: 0.671500, 9th: 0.848978, 10th: 0.278825, 11th: 0.919500, 12th: 0.098053 Best Acc so far: 0.649918
6 | Train:epoch: 1:[432/432], LossCla: 1.556984, LossFeat: 1.260458, AccFs: 93.570961, AccFt: 92.699295
7 | Test:epoch: 1, AccFs: 0.654738, AccFt: 0.687854
8 | Class-wise Acc of Ft:1st: 0.910861, 2nd: 0.515683, 3rd: 0.772281, 4th: 0.649168, 5th: 0.828182, 6th: 0.786988, 7th: 0.874396, 8th: 0.766500, 9th: 0.854254, 10th: 0.269619, 11th: 0.906988, 12th: 0.119322 Best Acc so far: 0.687854
9 | Train:epoch: 2:[432/432], LossCla: 1.446185, LossFeat: 1.245270, AccFs: 94.833260, AccFt: 94.019463
10 | Test:epoch: 2, AccFs: 0.692207, AccFt: 0.715441
11 | Class-wise Acc of Ft:1st: 0.930060, 2nd: 0.544460, 3rd: 0.795949, 4th: 0.701183, 5th: 0.855255, 6th: 0.883855, 7th: 0.879572, 8th: 0.800000, 9th: 0.834469, 10th: 0.282332, 11th: 0.870869, 12th: 0.207282 Best Acc so far: 0.715441
12 | Train:epoch: 3:[432/432], LossCla: 1.365695, LossFeat: 1.234548, AccFs: 96.044922, AccFt: 95.247398
13 | Test:epoch: 3, AccFs: 0.715807, AccFt: 0.732017
14 | Class-wise Acc of Ft:1st: 0.921009, 2nd: 0.606043, 3rd: 0.826013, 4th: 0.700317, 5th: 0.887657, 6th: 0.899759, 7th: 0.879917, 8th: 0.789750, 9th: 0.864146, 10th: 0.330995, 11th: 0.869688, 12th: 0.208904 Best Acc so far: 0.732017
15 | Train:epoch: 4:[432/432], LossCla: 1.332583, LossFeat: 1.235615, AccFs: 96.475334, AccFt: 95.826103
16 | Test:epoch: 4, AccFs: 0.734888, AccFt: 0.746103
17 | Class-wise Acc of Ft:1st: 0.935546, 2nd: 0.682014, 3rd: 0.808102, 4th: 0.687818, 5th: 0.879130, 6th: 0.924337, 7th: 0.884403, 8th: 0.802250, 9th: 0.858430, 10th: 0.370452, 11th: 0.894004, 12th: 0.226748 Best Acc so far: 0.746103
18 | Train:epoch: 5:[432/432], LossCla: 1.293805, LossFeat: 1.245844, AccFs: 97.010635, AccFt: 96.522354
19 | Test:epoch: 5, AccFs: 0.742592, AccFt: 0.754611
20 | Class-wise Acc of Ft:1st: 0.934174, 2nd: 0.717122, 3rd: 0.780384, 4th: 0.722431, 5th: 0.901087, 6th: 0.900723, 7th: 0.866287, 8th: 0.794500, 9th: 0.901297, 10th: 0.382727, 11th: 0.906043, 12th: 0.248558 Best Acc so far: 0.754611
21 | Train:epoch: 6:[432/432], LossCla: 1.282527, LossFeat: 1.256809, AccFs: 97.384987, AccFt: 96.965424
22 | Test:epoch: 6, AccFs: 0.740717, AccFt: 0.748055
23 | Class-wise Acc of Ft:1st: 0.923478, 2nd: 0.649784, 3rd: 0.789126, 4th: 0.731564, 5th: 0.884033, 6th: 0.933976, 7th: 0.860593, 8th: 0.815750, 9th: 0.868762, 10th: 0.382727, 11th: 0.899670, 12th: 0.237203
24 | Train:epoch: 7:[432/432], LossCla: 1.289282, LossFeat: 1.270966, AccFs: 97.674332, AccFt: 97.225838
25 | Test:epoch: 7, AccFs: 0.750884, AccFt: 0.749084
26 | Class-wise Acc of Ft:1st: 0.918267, 2nd: 0.710791, 3rd: 0.799787, 4th: 0.736948, 5th: 0.885739, 6th: 0.926747, 7th: 0.857660, 8th: 0.814500, 9th: 0.875797, 10th: 0.329680, 11th: 0.879131, 12th: 0.253965
27 | Train:epoch: 8:[432/432], LossCla: 1.315241, LossFeat: 1.275234, AccFs: 97.940178, AccFt: 97.616463
28 | Test:epoch: 8, AccFs: 0.754999, AccFt: 0.750322
29 | Class-wise Acc of Ft:1st: 0.903456, 2nd: 0.697266, 3rd: 0.808955, 4th: 0.719162, 5th: 0.897890, 6th: 0.918554, 7th: 0.833678, 8th: 0.825500, 9th: 0.881952, 10th: 0.376151, 11th: 0.874174, 12th: 0.267123 Best Acc so far: 0.754999
30 | Train:epoch: 9:[432/432], LossCla: 1.322757, LossFeat: 1.267201, AccFs: 98.102936, AccFt: 97.746674
31 | Test:epoch: 9, AccFs: 0.762605, AccFt: 0.755622
32 | Class-wise Acc of Ft:1st: 0.923478, 2nd: 0.715683, 3rd: 0.783795, 4th: 0.698106, 5th: 0.889789, 6th: 0.887711, 7th: 0.848344, 8th: 0.834750, 9th: 0.897560, 10th: 0.417361, 11th: 0.885741, 12th: 0.285148 Best Acc so far: 0.762605
33 | Train:epoch: 10:[432/432], LossCla: 1.325336, LossFeat: 1.256211, AccFs: 98.238571, AccFt: 97.976349
34 | Test:epoch: 10, AccFs: 0.757504, AccFt: 0.755875
35 | Class-wise Acc of Ft:1st: 0.925123, 2nd: 0.690647, 3rd: 0.812793, 4th: 0.663494, 5th: 0.879343, 6th: 0.888193, 7th: 0.857660, 8th: 0.834250, 9th: 0.885909, 10th: 0.481368, 11th: 0.886213, 12th: 0.265501
36 | Train:epoch: 11:[432/432], LossCla: 1.322062, LossFeat: 1.247246, AccFs: 98.379631, AccFt: 98.149956
37 | Test:epoch: 11, AccFs: 0.763305, AccFt: 0.760014
38 | Class-wise Acc of Ft:1st: 0.908393, 2nd: 0.668777, 3rd: 0.798507, 4th: 0.675993, 5th: 0.889789, 6th: 0.891084, 7th: 0.870945, 8th: 0.835500, 9th: 0.896461, 10th: 0.495397, 11th: 0.886686, 12th: 0.302632 Best Acc so far: 0.763305
39 | Train:epoch: 12:[432/432], LossCla: 1.315084, LossFeat: 1.236000, AccFs: 98.649086, AccFt: 98.348885
40 | Test:epoch: 12, AccFs: 0.764516, AccFt: 0.760719
41 | Class-wise Acc of Ft:1st: 0.910861, 2nd: 0.688921, 3rd: 0.807889, 4th: 0.666667, 5th: 0.888084, 6th: 0.846265, 7th: 0.879400, 8th: 0.843000, 9th: 0.904375, 10th: 0.531346, 11th: 0.882436, 12th: 0.279380 Best Acc so far: 0.764516
42 | Train:epoch: 13:[432/432], LossCla: 1.321137, LossFeat: 1.231841, AccFs: 98.696106, AccFt: 98.520691
43 | Test:epoch: 13, AccFs: 0.763517, AccFt: 0.759895
44 | Class-wise Acc of Ft:1st: 0.895502, 2nd: 0.685180, 3rd: 0.816844, 4th: 0.690607, 5th: 0.910680, 6th: 0.818313, 7th: 0.865942, 8th: 0.828500, 9th: 0.892944, 10th: 0.568610, 11th: 0.871105, 12th: 0.274513
45 | Train:epoch: 14:[432/432], LossCla: 1.325543, LossFeat: 1.226663, AccFs: 98.687065, AccFt: 98.423035
46 | Test:epoch: 14, AccFs: 0.755246, AccFt: 0.752819
47 | Class-wise Acc of Ft:1st: 0.872189, 2nd: 0.682302, 3rd: 0.788060, 4th: 0.644457, 5th: 0.883181, 6th: 0.749398, 7th: 0.862836, 8th: 0.834250, 9th: 0.895801, 10th: 0.644893, 11th: 0.906280, 12th: 0.270187
48 | Train:epoch: 15:[432/432], LossCla: 1.317506, LossFeat: 1.212464, AccFs: 98.858871, AccFt: 98.696106
49 | Test:epoch: 15, AccFs: 0.755130, AccFt: 0.753893
50 | Class-wise Acc of Ft:1st: 0.891114, 2nd: 0.638273, 3rd: 0.817697, 4th: 0.650611, 5th: 0.896824, 6th: 0.786506, 7th: 0.863527, 8th: 0.818500, 9th: 0.918663, 10th: 0.644454, 11th: 0.888574, 12th: 0.231975
51 | Train:epoch: 16:[432/432], LossCla: 1.315734, LossFeat: 1.204639, AccFs: 98.891418, AccFt: 98.810043
52 | Test:epoch: 16, AccFs: 0.761036, AccFt: 0.761904
53 | Class-wise Acc of Ft:1st: 0.900987, 2nd: 0.680000, 3rd: 0.811087, 4th: 0.622825, 5th: 0.908335, 6th: 0.781687, 7th: 0.849206, 8th: 0.805250, 9th: 0.913168, 10th: 0.705831, 11th: 0.902030, 12th: 0.262437
54 | Train:epoch: 17:[432/432], LossCla: 1.318342, LossFeat: 1.198416, AccFs: 98.999931, AccFt: 98.837166
55 | Test:epoch: 17, AccFs: 0.754808, AccFt: 0.756363
56 | Class-wise Acc of Ft:1st: 0.896325, 2nd: 0.672806, 3rd: 0.795309, 4th: 0.582829, 5th: 0.896397, 6th: 0.783614, 7th: 0.857488, 8th: 0.798000, 9th: 0.894922, 10th: 0.718106, 11th: 0.912181, 12th: 0.268385
57 | Train:epoch: 18:[432/432], LossCla: 1.305724, LossFeat: 1.187242, AccFs: 99.189819, AccFt: 99.005356
58 | Test:epoch: 18, AccFs: 0.759740, AccFt: 0.759087
59 |
--------------------------------------------------------------------------------
/experiments/ckpt/logVisDA_train2validation_McDalNet_CE/log.txt:
--------------------------------------------------------------------------------
1 |
2 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "CE"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_McDalNet_CE", "SAVE_DIR": "./experiments/ckpt/logVisDA_train2validation_McDalNet_CE", "NUM_WORKERS": 8, "PRINT_STEP": 3}
3 |
4 | Train:epoch: 0:[432/432], LossCE: 0.690750, LossDA: -0.058507, LossAll: 2.029814, Auxi1: 82.930046, Auxi2: 82.752823, Task: 82.785378
5 | Test:epoch: 0, Top1_auxi1: 0.558361, Top1_auxi2: 0.552648, Top1: 0.556366
6 | Class-wise Acc:1st: 0.775919, 2nd: 0.379281, 3rd: 0.638806, 4th: 0.651572, 5th: 0.636751, 6th: 0.265542, 7th: 0.823844, 8th: 0.322250, 9th: 0.827874, 10th: 0.292416, 11th: 0.882436, 12th: 0.179704 Best Acc so far: 0.558361
7 | Train:epoch: 1:[432/432], LossCE: 0.244714, LossDA: -0.080817, LossAll: 0.664051, Auxi1: 92.214630, Auxi2: 92.138672, Task: 92.294197
8 | Test:epoch: 1, Top1_auxi1: 0.558922, Top1_auxi2: 0.526064, Top1: 0.563260
9 | Class-wise Acc:1st: 0.739166, 2nd: 0.392518, 3rd: 0.631557, 4th: 0.621575, 5th: 0.633340, 6th: 0.272289, 7th: 0.858351, 8th: 0.322000, 9th: 0.843482, 10th: 0.338886, 11th: 0.911473, 12th: 0.194485 Best Acc so far: 0.563260
10 | Train:epoch: 2:[432/432], LossCE: 0.183937, LossDA: -0.185786, LossAll: 0.430443, Auxi1: 93.030235, Auxi2: 93.485970, Task: 93.945312
11 | Test:epoch: 2, Top1_auxi1: 0.366009, Top1_auxi2: 0.437074, Top1: 0.582062
12 | Class-wise Acc:1st: 0.777839, 2nd: 0.392230, 3rd: 0.694883, 4th: 0.619556, 5th: 0.687913, 6th: 0.267952, 7th: 0.869565, 8th: 0.350500, 9th: 0.853814, 10th: 0.322665, 11th: 0.881964, 12th: 0.265862 Best Acc so far: 0.582062
13 | Train:epoch: 3:[432/432], LossCE: 0.145429, LossDA: -0.456710, LossAll: 0.132864, Auxi1: 92.097076, Auxi2: 94.017654, Task: 95.269096
14 | Test:epoch: 3, Top1_auxi1: 0.250198, Top1_auxi2: 0.457712, Top1: 0.612280
15 | Class-wise Acc:1st: 0.808283, 2nd: 0.411223, 3rd: 0.711301, 4th: 0.684453, 5th: 0.750160, 6th: 0.311325, 7th: 0.879227, 8th: 0.438500, 9th: 0.902176, 10th: 0.325296, 11th: 0.878659, 12th: 0.246756 Best Acc so far: 0.612280
16 | Train:epoch: 4:[432/432], LossCE: 0.143135, LossDA: -0.437714, LossAll: 0.156751, Auxi1: 92.183884, Auxi2: 94.439018, Task: 95.659721
17 | Test:epoch: 4, Top1_auxi1: 0.428402, Top1_auxi2: 0.617549, Top1: 0.664672
18 | Class-wise Acc:1st: 0.837630, 2nd: 0.542734, 3rd: 0.728145, 4th: 0.646284, 5th: 0.777873, 6th: 0.486747, 7th: 0.880090, 8th: 0.688250, 9th: 0.920202, 10th: 0.312582, 11th: 0.881020, 12th: 0.274513 Best Acc so far: 0.664672
19 | Train:epoch: 5:[432/432], LossCE: 0.132734, LossDA: -0.183213, LossAll: 0.290770, Auxi1: 94.561996, Auxi2: 95.509621, Task: 96.117264
20 | Test:epoch: 5, Top1_auxi1: 0.494570, Top1_auxi2: 0.659047, Top1: 0.679029
21 | Class-wise Acc:1st: 0.864235, 2nd: 0.500432, 3rd: 0.760128, 4th: 0.648207, 5th: 0.831806, 6th: 0.581205, 7th: 0.862491, 8th: 0.727000, 9th: 0.936250, 10th: 0.322665, 11th: 0.875826, 12th: 0.238104 Best Acc so far: 0.679029
22 | Train:epoch: 6:[432/432], LossCE: 0.113248, LossDA: -0.137117, LossAll: 0.251889, Auxi1: 95.701317, Auxi2: 96.173325, Task: 96.656181
23 | Test:epoch: 6, Top1_auxi1: 0.523407, Top1_auxi2: 0.667125, Top1: 0.678395
24 | Class-wise Acc:1st: 0.828579, 2nd: 0.551655, 3rd: 0.741578, 4th: 0.650034, 5th: 0.805159, 6th: 0.581687, 7th: 0.881988, 8th: 0.739750, 9th: 0.950758, 10th: 0.282332, 11th: 0.888574, 12th: 0.238645
25 | Train:epoch: 7:[432/432], LossCE: 0.100754, LossDA: -0.099633, LossAll: 0.238848, Auxi1: 96.310768, Auxi2: 96.567566, Task: 97.039566
26 | Test:epoch: 7, Top1_auxi1: 0.587965, Top1_auxi2: 0.690583, Top1: 0.692860
27 | Class-wise Acc:1st: 0.854361, 2nd: 0.586187, 3rd: 0.756077, 4th: 0.659360, 5th: 0.854402, 6th: 0.654458, 7th: 0.918910, 8th: 0.688750, 9th: 0.947681, 10th: 0.277510, 11th: 0.882672, 12th: 0.233958 Best Acc so far: 0.692860
28 | Train:epoch: 8:[432/432], LossCE: 0.084525, LossDA: -0.065558, LossAll: 0.211381, Auxi1: 97.055847, Auxi2: 97.249352, Task: 97.500725
29 | Test:epoch: 8, Top1_auxi1: 0.608555, Top1_auxi2: 0.693317, Top1: 0.690395
30 | Class-wise Acc:1st: 0.862863, 2nd: 0.578417, 3rd: 0.736034, 4th: 0.669839, 5th: 0.837135, 6th: 0.657831, 7th: 0.906142, 8th: 0.713000, 9th: 0.949659, 10th: 0.277071, 11th: 0.899197, 12th: 0.197549 Best Acc so far: 0.693317
31 | Train:epoch: 9:[432/432], LossCE: 0.071856, LossDA: -0.047084, LossAll: 0.183912, Auxi1: 97.545937, Auxi2: 97.703270, Task: 97.929329
32 | Test:epoch: 9, Top1_auxi1: 0.627759, Top1_auxi2: 0.701944, Top1: 0.696261
33 | Class-wise Acc:1st: 0.866703, 2nd: 0.556259, 3rd: 0.762260, 4th: 0.667820, 5th: 0.852057, 6th: 0.704096, 7th: 0.895790, 8th: 0.723250, 9th: 0.954276, 10th: 0.275756, 11th: 0.892823, 12th: 0.204037 Best Acc so far: 0.701944
34 | Train:epoch: 10:[432/432], LossCE: 0.061862, LossDA: -0.039253, LossAll: 0.158299, Auxi1: 97.900391, Auxi2: 98.066772, Task: 98.202400
35 | Test:epoch: 10, Top1_auxi1: 0.637854, Top1_auxi2: 0.704391, Top1: 0.699586
36 | Class-wise Acc:1st: 0.859846, 2nd: 0.563741, 3rd: 0.767804, 4th: 0.680031, 5th: 0.836069, 6th: 0.735904, 7th: 0.897170, 8th: 0.730750, 9th: 0.951418, 10th: 0.290662, 11th: 0.894004, 12th: 0.187635 Best Acc so far: 0.704391
37 | Train:epoch: 11:[432/432], LossCE: 0.056392, LossDA: -0.033153, LossAll: 0.145713, Auxi1: 98.146339, Auxi2: 98.247612, Task: 98.352501
38 | Test:epoch: 11, Top1_auxi1: 0.652609, Top1_auxi2: 0.707246, Top1: 0.702464
39 | Class-wise Acc:1st: 0.853538, 2nd: 0.575827, 3rd: 0.764606, 4th: 0.689645, 5th: 0.836709, 6th: 0.729639, 7th: 0.896653, 8th: 0.740750, 9th: 0.948120, 10th: 0.298992, 11th: 0.891407, 12th: 0.203677 Best Acc so far: 0.707246
40 | Train:epoch: 12:[432/432], LossCE: 0.050839, LossDA: -0.029478, LossAll: 0.130653, Auxi1: 98.339844, Auxi2: 98.442924, Task: 98.522499
41 | Test:epoch: 12, Top1_auxi1: 0.651128, Top1_auxi2: 0.697198, Top1: 0.692559
42 | Class-wise Acc:1st: 0.852167, 2nd: 0.523741, 3rd: 0.767804, 4th: 0.698491, 5th: 0.833724, 6th: 0.763855, 7th: 0.933057, 8th: 0.724250, 9th: 0.923500, 10th: 0.237177, 11th: 0.896128, 12th: 0.156813
43 | Train:epoch: 13:[432/432], LossCE: 0.044340, LossDA: -0.028602, LossAll: 0.111692, Auxi1: 98.611115, Auxi2: 98.710579, Task: 98.744934
44 | Test:epoch: 13, Top1_auxi1: 0.646009, Top1_auxi2: 0.697454, Top1: 0.690091
45 | Class-wise Acc:1st: 0.838179, 2nd: 0.527770, 3rd: 0.755437, 4th: 0.693106, 5th: 0.828821, 6th: 0.747470, 7th: 0.897170, 8th: 0.727500, 9th: 0.946362, 10th: 0.244630, 11th: 0.905571, 12th: 0.169070
46 | Train:epoch: 14:[432/432], LossCE: 0.041508, LossDA: -0.027529, LossAll: 0.104398, Auxi1: 98.697914, Auxi2: 98.743126, Task: 98.831741
47 | Test:epoch: 14, Top1_auxi1: 0.652566, Top1_auxi2: 0.701091, Top1: 0.692790
48 | Class-wise Acc:1st: 0.853538, 2nd: 0.501871, 3rd: 0.770576, 4th: 0.674166, 5th: 0.830313, 6th: 0.754699, 7th: 0.924776, 8th: 0.752250, 9th: 0.945922, 10th: 0.228409, 11th: 0.903919, 12th: 0.173035
49 | Train:epoch: 15:[432/432], LossCE: 0.037630, LossDA: -0.026932, LossAll: 0.092539, Auxi1: 98.786530, Auxi2: 98.885994, Task: 98.972801
50 | Test:epoch: 15, Top1_auxi1: 0.661947, Top1_auxi2: 0.712143, Top1: 0.702932
51 | Class-wise Acc:1st: 0.855458, 2nd: 0.541583, 3rd: 0.785714, 4th: 0.691857, 5th: 0.836709, 6th: 0.788434, 7th: 0.925121, 8th: 0.761500, 9th: 0.952957, 10th: 0.249890, 11th: 0.886449, 12th: 0.159517 Best Acc so far: 0.712143
52 | Train:epoch: 16:[432/432], LossCE: 0.032933, LossDA: -0.027620, LossAll: 0.078184, Auxi1: 98.927589, Auxi2: 99.034286, Task: 99.093971
53 | Test:epoch: 16, Top1_auxi1: 0.658064, Top1_auxi2: 0.708298, Top1: 0.697768
54 | Class-wise Acc:1st: 0.854087, 2nd: 0.524029, 3rd: 0.758209, 4th: 0.676858, 5th: 0.829034, 6th: 0.821205, 7th: 0.912871, 8th: 0.763500, 9th: 0.956034, 10th: 0.203858, 11th: 0.910765, 12th: 0.162761
55 | Train:epoch: 17:[432/432], LossCE: 0.031930, LossDA: -0.028328, LossAll: 0.074900, Auxi1: 98.996315, Auxi2: 99.063225, Task: 99.151840
56 | Test:epoch: 17, Top1_auxi1: 0.665599, Top1_auxi2: 0.713065, Top1: 0.704245
57 | Class-wise Acc:1st: 0.851892, 2nd: 0.554532, 3rd: 0.798934, 4th: 0.702913, 5th: 0.856747, 6th: 0.784096, 7th: 0.924948, 8th: 0.754750, 9th: 0.950099, 10th: 0.224463, 11th: 0.881020, 12th: 0.166547 Best Acc so far: 0.713065
58 | Train:epoch: 18:[432/432], LossCE: 0.028280, LossDA: -0.027890, LossAll: 0.064488, Auxi1: 99.133751, Auxi2: 99.159073, Task: 99.273003
59 | Test:epoch: 18, Top1_auxi1: 0.672141, Top1_auxi2: 0.719737, Top1: 0.709900
60 | Class-wise Acc:1st: 0.881788, 2nd: 0.595108, 3rd: 0.778038, 4th: 0.682627, 5th: 0.833511, 6th: 0.802410, 7th: 0.946342, 8th: 0.757250, 9th: 0.948780, 10th: 0.225778, 11th: 0.896837, 12th: 0.170332 Best Acc so far: 0.719737
61 |
--------------------------------------------------------------------------------
/experiments/ckpt/logVisDA_train2validation_McDalNet_DANN/log.txt:
--------------------------------------------------------------------------------
1 |
2 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "DANN"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_McDalNet_DANN", "SAVE_DIR": "./experiments/ckpt/logVisDA_train2validation_McDalNet_DANN", "NUM_WORKERS": 8, "PRINT_STEP": 3}
3 |
4 | Train:epoch: 0:[432/432], LossCE: 0.689842, LossDA: 1.220195, LossAll: 1.910038, Auxi1: 5.635128, Auxi2: 4.967810, Task: 83.132599
5 | Test:epoch: 0, Top1_auxi1: 0.061966, Top1_auxi2: 0.089132, Top1: 0.497863
6 | Class-wise Acc:1st: 0.684586, 2nd: 0.348489, 3rd: 0.371642, 4th: 0.568888, 5th: 0.566830, 6th: 0.240482, 7th: 0.865942, 8th: 0.211750, 9th: 0.772697, 10th: 0.256028, 11th: 0.958687, 12th: 0.128335 Best Acc so far: 0.497863
7 | Train:epoch: 1:[432/432], LossCE: 0.249252, LossDA: 1.135072, LossAll: 1.384324, Auxi1: 5.407263, Auxi2: 4.445168, Task: 92.261650
8 | Test:epoch: 1, Top1_auxi1: 0.065403, Top1_auxi2: 0.092108, Top1: 0.556021
9 | Class-wise Acc:1st: 0.805815, 2nd: 0.368058, 3rd: 0.598934, 4th: 0.607442, 5th: 0.653805, 6th: 0.180723, 7th: 0.846618, 8th: 0.280750, 9th: 0.832711, 10th: 0.419991, 11th: 0.921860, 12th: 0.155552 Best Acc so far: 0.556021
10 | Train:epoch: 2:[432/432], LossCE: 0.187990, LossDA: 1.307358, LossAll: 1.495347, Auxi1: 5.485026, Auxi2: 4.302300, Task: 93.992332
11 | Test:epoch: 2, Top1_auxi1: 0.063872, Top1_auxi2: 0.096490, Top1: 0.567105
12 | Class-wise Acc:1st: 0.797586, 2nd: 0.421007, 3rd: 0.587846, 4th: 0.601961, 5th: 0.654231, 6th: 0.278072, 7th: 0.840235, 8th: 0.310250, 9th: 0.797978, 10th: 0.427883, 11th: 0.919500, 12th: 0.168709 Best Acc so far: 0.567105
13 | Train:epoch: 3:[432/432], LossCE: 0.152223, LossDA: 1.563822, LossAll: 1.716045, Auxi1: 5.629702, Auxi2: 4.356554, Task: 95.131653
14 | Test:epoch: 3, Top1_auxi1: 0.068933, Top1_auxi2: 0.087752, Top1: 0.565053
15 | Class-wise Acc:1st: 0.771530, 2nd: 0.403453, 3rd: 0.648188, 4th: 0.559850, 5th: 0.662332, 6th: 0.320482, 7th: 0.843513, 8th: 0.283750, 9th: 0.763245, 10th: 0.423498, 11th: 0.914778, 12th: 0.186013
16 | Train:epoch: 4:[432/432], LossCE: 0.130460, LossDA: 1.403738, LossAll: 1.534199, Auxi1: 4.700159, Auxi2: 4.197410, Task: 95.806206
17 | Test:epoch: 4, Top1_auxi1: 0.063369, Top1_auxi2: 0.082462, Top1: 0.567353
18 | Class-wise Acc:1st: 0.763028, 2nd: 0.410360, 3rd: 0.663753, 4th: 0.568215, 5th: 0.681731, 6th: 0.241446, 7th: 0.844203, 8th: 0.344250, 9th: 0.791822, 10th: 0.385796, 11th: 0.915722, 12th: 0.197909 Best Acc so far: 0.567353
19 | Train:epoch: 5:[432/432], LossCE: 0.109417, LossDA: 1.380119, LossAll: 1.489536, Auxi1: 4.079861, Auxi2: 4.309534, Task: 96.551285
20 | Test:epoch: 5, Top1_auxi1: 0.065876, Top1_auxi2: 0.080367, Top1: 0.556004
21 | Class-wise Acc:1st: 0.719967, 2nd: 0.349353, 3rd: 0.687846, 4th: 0.614268, 5th: 0.675549, 6th: 0.221687, 7th: 0.829365, 8th: 0.321750, 9th: 0.760607, 10th: 0.421306, 11th: 0.897309, 12th: 0.173035
22 | Train:epoch: 6:[432/432], LossCE: 0.095272, LossDA: 1.400877, LossAll: 1.496149, Auxi1: 4.456018, Auxi2: 4.295066, Task: 96.934677
23 | Test:epoch: 6, Top1_auxi1: 0.060089, Top1_auxi2: 0.087750, Top1: 0.578876
24 | Class-wise Acc:1st: 0.815963, 2nd: 0.366619, 3rd: 0.700426, 4th: 0.611768, 5th: 0.672351, 6th: 0.248675, 7th: 0.847999, 8th: 0.303500, 9th: 0.834249, 10th: 0.434897, 11th: 0.880076, 12th: 0.229993 Best Acc so far: 0.578876
25 | Train:epoch: 7:[432/432], LossCE: 0.081930, LossDA: 1.395533, LossAll: 1.477463, Auxi1: 4.390914, Auxi2: 3.891783, Task: 97.455513
26 | Test:epoch: 7, Top1_auxi1: 0.067571, Top1_auxi2: 0.082103, Top1: 0.571576
27 | Class-wise Acc:1st: 0.782501, 2nd: 0.386763, 3rd: 0.605970, 4th: 0.588309, 5th: 0.727350, 6th: 0.194217, 7th: 0.849379, 8th: 0.353500, 9th: 0.799956, 10th: 0.410346, 11th: 0.915486, 12th: 0.245133
28 | Train:epoch: 8:[432/432], LossCE: 0.071810, LossDA: 1.377522, LossAll: 1.449332, Auxi1: 4.537399, Auxi2: 4.146774, Task: 97.762947
29 | Test:epoch: 8, Top1_auxi1: 0.067663, Top1_auxi2: 0.079991, Top1: 0.570504
30 | Class-wise Acc:1st: 0.770433, 2nd: 0.358561, 3rd: 0.715778, 4th: 0.583117, 5th: 0.690258, 6th: 0.228916, 7th: 0.847999, 8th: 0.306750, 9th: 0.807210, 10th: 0.412977, 11th: 0.874410, 12th: 0.249640
31 | Train:epoch: 9:[432/432], LossCE: 0.063001, LossDA: 1.389740, LossAll: 1.452742, Auxi1: 3.976779, Auxi2: 4.201027, Task: 98.126450
32 | Test:epoch: 9, Top1_auxi1: 0.063785, Top1_auxi2: 0.080066, Top1: 0.582241
33 | Class-wise Acc:1st: 0.803895, 2nd: 0.416403, 3rd: 0.700853, 4th: 0.591193, 5th: 0.700490, 6th: 0.235663, 7th: 0.853175, 8th: 0.397250, 9th: 0.821719, 10th: 0.371767, 11th: 0.888102, 12th: 0.206381 Best Acc so far: 0.582241
34 | Train:epoch: 10:[432/432], LossCE: 0.057746, LossDA: 1.394590, LossAll: 1.452338, Auxi1: 3.803169, Auxi2: 4.465061, Task: 98.285591
35 | Test:epoch: 10, Top1_auxi1: 0.062671, Top1_auxi2: 0.078781, Top1: 0.574760
36 | Class-wise Acc:1st: 0.764125, 2nd: 0.391655, 3rd: 0.682729, 4th: 0.549947, 5th: 0.716905, 6th: 0.228434, 7th: 0.851622, 8th: 0.378750, 9th: 0.796219, 10th: 0.378781, 11th: 0.893532, 12th: 0.264420
37 | Train:epoch: 11:[432/432], LossCE: 0.051314, LossDA: 1.387914, LossAll: 1.439229, Auxi1: 4.229962, Auxi2: 4.823134, Task: 98.457397
38 | Test:epoch: 11, Top1_auxi1: 0.064863, Top1_auxi2: 0.084203, Top1: 0.584895
39 | Class-wise Acc:1st: 0.778113, 2nd: 0.423022, 3rd: 0.688486, 4th: 0.509182, 5th: 0.735451, 6th: 0.213494, 7th: 0.859041, 8th: 0.439750, 9th: 0.839305, 10th: 0.402017, 11th: 0.885741, 12th: 0.245133 Best Acc so far: 0.584895
40 | Train:epoch: 12:[432/432], LossCE: 0.046685, LossDA: 1.391144, LossAll: 1.437830, Auxi1: 4.184751, Auxi2: 5.063657, Task: 98.647278
41 | Test:epoch: 12, Top1_auxi1: 0.067518, Top1_auxi2: 0.085447, Top1: 0.576178
42 | Class-wise Acc:1st: 0.776193, 2nd: 0.397986, 3rd: 0.689765, 4th: 0.550812, 5th: 0.708165, 6th: 0.168193, 7th: 0.880262, 8th: 0.409500, 9th: 0.808969, 10th: 0.373520, 11th: 0.868508, 12th: 0.282264
43 | Train:epoch: 13:[432/432], LossCE: 0.043811, LossDA: 1.394163, LossAll: 1.437974, Auxi1: 4.036458, Auxi2: 5.079934, Task: 98.703346
44 | Test:epoch: 13, Top1_auxi1: 0.065180, Top1_auxi2: 0.081070, Top1: 0.560906
45 | Class-wise Acc:1st: 0.752606, 2nd: 0.378705, 3rd: 0.678038, 4th: 0.605134, 5th: 0.664677, 6th: 0.151325, 7th: 0.837992, 8th: 0.322000, 9th: 0.830073, 10th: 0.375712, 11th: 0.888574, 12th: 0.246035
46 | Train:epoch: 14:[432/432], LossCE: 0.039512, LossDA: 1.390498, LossAll: 1.430009, Auxi1: 4.132306, Auxi2: 4.503038, Task: 98.862488
47 | Test:epoch: 14, Top1_auxi1: 0.066672, Top1_auxi2: 0.084346, Top1: 0.555455
48 | Class-wise Acc:1st: 0.727647, 2nd: 0.380719, 3rd: 0.681876, 4th: 0.575137, 5th: 0.671499, 6th: 0.205301, 7th: 0.844893, 8th: 0.287000, 9th: 0.799736, 10th: 0.339325, 11th: 0.881964, 12th: 0.270368
49 | Train:epoch: 15:[432/432], LossCE: 0.036059, LossDA: 1.391966, LossAll: 1.428025, Auxi1: 3.616898, Auxi2: 4.685691, Task: 99.019821
50 | Test:epoch: 15, Top1_auxi1: 0.064970, Top1_auxi2: 0.083399, Top1: 0.557222
51 | Class-wise Acc:1st: 0.720516, 2nd: 0.358561, 3rd: 0.682729, 4th: 0.548120, 5th: 0.710083, 6th: 0.215904, 7th: 0.852657, 8th: 0.360500, 9th: 0.770059, 10th: 0.330995, 11th: 0.888338, 12th: 0.248198
52 | Train:epoch: 16:[432/432], LossCE: 0.032596, LossDA: 1.392626, LossAll: 1.425221, Auxi1: 3.915292, Auxi2: 4.548249, Task: 99.079498
53 | Test:epoch: 16, Top1_auxi1: 0.067874, Top1_auxi2: 0.082306, Top1: 0.577458
54 | Class-wise Acc:1st: 0.757268, 2nd: 0.414388, 3rd: 0.726226, 4th: 0.533026, 5th: 0.745896, 6th: 0.216386, 7th: 0.825224, 8th: 0.378500, 9th: 0.803913, 10th: 0.377904, 11th: 0.865439, 12th: 0.285328
55 | Train:epoch: 17:[432/432], LossCE: 0.030033, LossDA: 1.393536, LossAll: 1.423569, Auxi1: 3.618707, Auxi2: 5.087167, Task: 99.209709
56 | Test:epoch: 17, Top1_auxi1: 0.066593, Top1_auxi2: 0.082290, Top1: 0.571776
57 | Class-wise Acc:1st: 0.733681, 2nd: 0.443165, 3rd: 0.693177, 4th: 0.535718, 5th: 0.699851, 6th: 0.176867, 7th: 0.857660, 8th: 0.397500, 9th: 0.800835, 10th: 0.369575, 11th: 0.885977, 12th: 0.267304
58 | Train:epoch: 18:[432/432], LossCE: 0.027941, LossDA: 1.391720, LossAll: 1.419661, Auxi1: 3.432436, Auxi2: 5.251736, Task: 99.282043
59 | Test:epoch: 18, Top1_auxi1: 0.065550, Top1_auxi2: 0.083747, Top1: 0.574060
60 | Class-wise Acc:1st: 0.758914, 2nd: 0.398273, 3rd: 0.703838, 4th: 0.535141, 5th: 0.715626, 6th: 0.195181, 7th: 0.862146, 8th: 0.366250, 9th: 0.830073, 10th: 0.374397, 11th: 0.856704, 12th: 0.292177
61 | Train:epoch: 19:[432/432], LossCE: 0.025859, LossDA: 1.390437, LossAll: 1.416294, Auxi1: 3.837529, Auxi2: 5.121528, Task: 99.320023
62 | Test:epoch: 19, Top1_auxi1: 0.066620, Top1_auxi2: 0.082075, Top1: 0.558489
63 | Class-wise Acc:1st: 0.765222, 2nd: 0.355396, 3rd: 0.678891, 4th: 0.560427, 5th: 0.725432, 6th: 0.185060, 7th: 0.811594, 8th: 0.326000, 9th: 0.754671, 10th: 0.395441, 11th: 0.898961, 12th: 0.244773
64 | Train:epoch: 20:[432/432], LossCE: 0.023830, LossDA: 1.388485, LossAll: 1.412315, Auxi1: 3.848380, Auxi2: 5.255353, Task: 99.370659
65 | Test:epoch: 20, Top1_auxi1: 0.059989, Top1_auxi2: 0.081193, Top1: 0.579480
66 | Class-wise Acc:1st: 0.763302, 2nd: 0.403453, 3rd: 0.706397, 4th: 0.596385, 5th: 0.748241, 6th: 0.205301, 7th: 0.838854, 8th: 0.368000, 9th: 0.810947, 10th: 0.395441, 11th: 0.860954, 12th: 0.256489
67 | Train:epoch: 21:[432/432], LossCE: 0.022136, LossDA: 1.388649, LossAll: 1.410784, Auxi1: 3.792318, Auxi2: 5.060040, Task: 99.421295
68 | Test:epoch: 21, Top1_auxi1: 0.063822, Top1_auxi2: 0.078161, Top1: 0.571030
69 | Class-wise Acc:1st: 0.746572, 2nd: 0.384748, 3rd: 0.700640, 4th: 0.555139, 5th: 0.736090, 6th: 0.202892, 7th: 0.831263, 8th: 0.360500, 9th: 0.819301, 10th: 0.378781, 11th: 0.874174, 12th: 0.262257
70 | Train:epoch: 22:[432/432], LossCE: 0.019700, LossDA: 1.390863, LossAll: 1.410562, Auxi1: 3.669343, Auxi2: 4.622396, Task: 99.529800
71 | Test:epoch: 22, Top1_auxi1: 0.065152, Top1_auxi2: 0.079538, Top1: 0.572174
72 | Class-wise Acc:1st: 0.712013, 2nd: 0.401439, 3rd: 0.701493, 4th: 0.552062, 5th: 0.726498, 6th: 0.231807, 7th: 0.844720, 8th: 0.379000, 9th: 0.812926, 10th: 0.362999, 11th: 0.880312, 12th: 0.260815
73 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "DANN"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "TASK": "closed", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_McDalNet_DANN", "SAVE_DIR": "./experiments/ckpt/logVisDA_train2validation_McDalNet_DANN", "NUM_WORKERS": 8, "PRINT_STEP": 3}
74 |
--------------------------------------------------------------------------------
/experiments/ckpt/logVisDA_train2validation_McDalNet_L1/log.txt:
--------------------------------------------------------------------------------
1 | {"DATASET": {"NUM_CLASSES": 12, "DATASET": "VisDA", "DATAROOT": "/disk1/domain_adaptation/visDA", "SOURCE_NAME": "train", "TARGET_NAME": "validation", "VAL_NAME": "validation"}, "MODEL": {"FEATURE_EXTRACTOR": "resnet50", "PRETRAINED": true}, "DATA_TRANSFORM": {"TYPE": "simple"}, "TRAIN": {"SOURCE_BATCH_SIZE": 128, "TARGET_BATCH_SIZE": 128, "BASE_LR": 0.001, "MOMENTUM": 0.9, "OPTIMIZER": "SGD", "WEIGHT_DECAY": 0.0001, "LR_SCHEDULE": "fix", "MAX_EPOCH": 30, "SAVING": false, "PROCESS_COUNTER": "iteration"}, "MCDALNET": {"DISTANCE_TYPE": "L1"}, "ADAM": {"BETA1": 0.9, "BETA2": 0.999}, "INV": {"ALPHA": 10.0, "BETA": 0.75}, "TEST": {"BATCH_SIZE": 128}, "RESUME": "", "TASK": "closed", "EVAL_METRIC": "accu_mean", "EXP_NAME": "logVisDA_train2validation_McDalNet_L1", "SAVE_DIR": "./experiments/ckpt/logVisDA_train2validation_McDalNet_L1", "NUM_WORKERS": 8, "PRINT_STEP": 3}
2 |
3 | Train:epoch: 0:[432/432], LossCE: 0.693553, LossDA: -0.014865, LossAll: 2.081821, Auxi1: 82.879410, Auxi2: 83.246529, Task: 83.201317
4 | Test:epoch: 0, Top1_auxi1: 0.516962, Top1_auxi2: 0.535738, Top1: 0.541355
5 | Class-wise Acc:1st: 0.746846, 2nd: 0.426187, 3rd: 0.655437, 4th: 0.558985, 5th: 0.662972, 6th: 0.197590, 7th: 0.886128, 8th: 0.177000, 9th: 0.729171, 10th: 0.366506, 11th: 0.904863, 12th: 0.184571 Best Acc so far: 0.541355
6 | Train:epoch: 1:[432/432], LossCE: 0.246766, LossDA: -0.043768, LossAll: 0.711814, Auxi1: 92.135056, Auxi2: 92.173035, Task: 92.312286
7 | Test:epoch: 1, Top1_auxi1: 0.492681, Top1_auxi2: 0.529946, Top1: 0.542277
8 | Class-wise Acc:1st: 0.741909, 2nd: 0.447482, 3rd: 0.627079, 4th: 0.657245, 5th: 0.614794, 6th: 0.185542, 7th: 0.866460, 8th: 0.197750, 9th: 0.791822, 10th: 0.322665, 11th: 0.915250, 12th: 0.139329 Best Acc so far: 0.542277
9 | Train:epoch: 2:[432/432], LossCE: 0.183054, LossDA: -0.059813, LossAll: 0.505924, Auxi1: 93.845848, Auxi2: 93.782555, Task: 94.048393
10 | Test:epoch: 2, Top1_auxi1: 0.490163, Top1_auxi2: 0.522411, Top1: 0.551915
11 | Class-wise Acc:1st: 0.768513, 2nd: 0.391942, 3rd: 0.585927, 4th: 0.589078, 5th: 0.667022, 6th: 0.203373, 7th: 0.874051, 8th: 0.290500, 9th: 0.763245, 10th: 0.325296, 11th: 0.932956, 12th: 0.231074 Best Acc so far: 0.551915
12 | Train:epoch: 3:[432/432], LossCE: 0.144287, LossDA: -0.068417, LossAll: 0.381817, Auxi1: 95.093681, Auxi2: 95.081017, Task: 95.278137
13 | Test:epoch: 3, Top1_auxi1: 0.486351, Top1_auxi2: 0.520537, Top1: 0.563128
14 | Class-wise Acc:1st: 0.756171, 2nd: 0.396259, 3rd: 0.614073, 4th: 0.633689, 5th: 0.704541, 6th: 0.203373, 7th: 0.870600, 8th: 0.308000, 9th: 0.800616, 10th: 0.366068, 11th: 0.922096, 12th: 0.182048 Best Acc so far: 0.563128
15 | Train:epoch: 4:[432/432], LossCE: 0.119764, LossDA: -0.073706, LossAll: 0.302515, Auxi1: 95.912903, Auxi2: 95.921951, Task: 96.135345
16 | Test:epoch: 4, Top1_auxi1: 0.478332, Top1_auxi2: 0.531994, Top1: 0.566777
17 | Class-wise Acc:1st: 0.769062, 2nd: 0.445180, 3rd: 0.653092, 4th: 0.602923, 5th: 0.694308, 6th: 0.194699, 7th: 0.888199, 8th: 0.314000, 9th: 0.768960, 10th: 0.327050, 11th: 0.909348, 12th: 0.234499 Best Acc so far: 0.566777
18 | Train:epoch: 5:[432/432], LossCE: 0.102737, LossDA: -0.076498, LossAll: 0.248881, Auxi1: 96.408424, Auxi2: 96.482567, Task: 96.585648
19 | Test:epoch: 5, Top1_auxi1: 0.488338, Top1_auxi2: 0.531829, Top1: 0.582575
20 | Class-wise Acc:1st: 0.773725, 2nd: 0.455540, 3rd: 0.684222, 4th: 0.618114, 5th: 0.718610, 6th: 0.180241, 7th: 0.864734, 8th: 0.364750, 9th: 0.792262, 10th: 0.405085, 11th: 0.910293, 12th: 0.223324 Best Acc so far: 0.582575
21 | Train:epoch: 6:[432/432], LossCE: 0.085473, LossDA: -0.077619, LossAll: 0.195600, Auxi1: 97.139038, Auxi2: 97.082970, Task: 97.327110
22 | Test:epoch: 6, Top1_auxi1: 0.496820, Top1_auxi2: 0.537137, Top1: 0.594800
23 | Class-wise Acc:1st: 0.812671, 2nd: 0.442590, 3rd: 0.681237, 4th: 0.618883, 5th: 0.729695, 6th: 0.296867, 7th: 0.891822, 8th: 0.367000, 9th: 0.854693, 10th: 0.345024, 11th: 0.905335, 12th: 0.191781 Best Acc so far: 0.594800
24 | Train:epoch: 7:[432/432], LossCE: 0.075982, LossDA: -0.077116, LossAll: 0.167648, Auxi1: 97.437431, Auxi2: 97.289139, Task: 97.562210
25 | Test:epoch: 7, Top1_auxi1: 0.483066, Top1_auxi2: 0.540713, Top1: 0.590002
26 | Class-wise Acc:1st: 0.791278, 2nd: 0.413813, 3rd: 0.741578, 4th: 0.607634, 5th: 0.724579, 6th: 0.240000, 7th: 0.868875, 8th: 0.377000, 9th: 0.783029, 10th: 0.423060, 11th: 0.906043, 12th: 0.203136
27 | Train:epoch: 8:[432/432], LossCE: 0.066726, LossDA: -0.076585, LossAll: 0.139383, Auxi1: 97.706886, Auxi2: 97.652634, Task: 97.882309
28 | Test:epoch: 8, Top1_auxi1: 0.495741, Top1_auxi2: 0.555948, Top1: 0.606180
29 | Class-wise Acc:1st: 0.811300, 2nd: 0.391942, 3rd: 0.729638, 4th: 0.640611, 5th: 0.777659, 6th: 0.277108, 7th: 0.908040, 8th: 0.440000, 9th: 0.824577, 10th: 0.395002, 11th: 0.893532, 12th: 0.184751 Best Acc so far: 0.606180
30 | Train:epoch: 9:[432/432], LossCE: 0.060872, LossDA: -0.075088, LossAll: 0.123370, Auxi1: 97.862411, Auxi2: 97.838905, Task: 98.025177
31 | Test:epoch: 9, Top1_auxi1: 0.499938, Top1_auxi2: 0.551240, Top1: 0.599826
32 | Class-wise Acc:1st: 0.800329, 2nd: 0.367770, 3rd: 0.725586, 4th: 0.612345, 5th: 0.756342, 6th: 0.315181, 7th: 0.908213, 8th: 0.409000, 9th: 0.841064, 10th: 0.325296, 11th: 0.897781, 12th: 0.239005
33 | Train:epoch: 10:[432/432], LossCE: 0.054908, LossDA: -0.072794, LossAll: 0.107301, Auxi1: 98.144531, Auxi2: 98.090279, Task: 98.260269
34 | Test:epoch: 10, Top1_auxi1: 0.504898, Top1_auxi2: 0.558344, Top1: 0.609342
35 | Class-wise Acc:1st: 0.818431, 2nd: 0.417842, 3rd: 0.735181, 4th: 0.609461, 5th: 0.767853, 6th: 0.324819, 7th: 0.896998, 8th: 0.445000, 9th: 0.830292, 10th: 0.345463, 11th: 0.903211, 12th: 0.217556 Best Acc so far: 0.609342
36 | Train:epoch: 11:[432/432], LossCE: 0.049372, LossDA: -0.070811, LossAll: 0.092945, Auxi1: 98.386864, Auxi2: 98.327187, Task: 98.515266
37 | Test:epoch: 11, Top1_auxi1: 0.517144, Top1_auxi2: 0.574096, Top1: 0.623316
38 | Class-wise Acc:1st: 0.831596, 2nd: 0.427050, 3rd: 0.749040, 4th: 0.631766, 5th: 0.789810, 6th: 0.357590, 7th: 0.906832, 8th: 0.505000, 9th: 0.845680, 10th: 0.340202, 11th: 0.897309, 12th: 0.197909 Best Acc so far: 0.623316
39 | Train:epoch: 12:[432/432], LossCE: 0.045497, LossDA: -0.068416, LossAll: 0.082828, Auxi1: 98.457397, Auxi2: 98.441116, Task: 98.647278
40 | Test:epoch: 12, Top1_auxi1: 0.516069, Top1_auxi2: 0.578118, Top1: 0.624275
41 | Class-wise Acc:1st: 0.831322, 2nd: 0.396547, 3rd: 0.720682, 4th: 0.660225, 5th: 0.797058, 6th: 0.358554, 7th: 0.907867, 8th: 0.515500, 9th: 0.838646, 10th: 0.372644, 11th: 0.898489, 12th: 0.193764 Best Acc so far: 0.624275
42 | Train:epoch: 13:[432/432], LossCE: 0.041655, LossDA: -0.066161, LossAll: 0.073744, Auxi1: 98.509842, Auxi2: 98.511650, Task: 98.712387
43 | Test:epoch: 13, Top1_auxi1: 0.528750, Top1_auxi2: 0.580861, Top1: 0.628952
44 | Class-wise Acc:1st: 0.820077, 2nd: 0.452374, 3rd: 0.728998, 4th: 0.672243, 5th: 0.796206, 6th: 0.361928, 7th: 0.904589, 8th: 0.501250, 9th: 0.863486, 10th: 0.358615, 11th: 0.889754, 12th: 0.197909 Best Acc so far: 0.628952
45 | Train:epoch: 14:[432/432], LossCE: 0.037410, LossDA: -0.063568, LossAll: 0.063342, Auxi1: 98.708771, Auxi2: 98.694298, Task: 98.900467
46 | Test:epoch: 14, Top1_auxi1: 0.537866, Top1_auxi2: 0.599769, Top1: 0.643354
47 | Class-wise Acc:1st: 0.849698, 2nd: 0.498129, 3rd: 0.764392, 4th: 0.650611, 5th: 0.782989, 6th: 0.446747, 7th: 0.919082, 8th: 0.512000, 9th: 0.857111, 10th: 0.359053, 11th: 0.887394, 12th: 0.193043 Best Acc so far: 0.643354
48 | Train:epoch: 15:[432/432], LossCE: 0.035579, LossDA: -0.061481, LossAll: 0.059481, Auxi1: 98.770256, Auxi2: 98.753983, Task: 98.954720
49 | Test:epoch: 15, Top1_auxi1: 0.551721, Top1_auxi2: 0.614768, Top1: 0.655238
50 | Class-wise Acc:1st: 0.886177, 2nd: 0.496691, 3rd: 0.755864, 4th: 0.661090, 5th: 0.815604, 6th: 0.421687, 7th: 0.929607, 8th: 0.584750, 9th: 0.866564, 10th: 0.367383, 11th: 0.883853, 12th: 0.193583 Best Acc so far: 0.655238
51 | Train:epoch: 16:[432/432], LossCE: 0.032109, LossDA: -0.058913, LossAll: 0.051380, Auxi1: 98.909508, Auxi2: 98.858871, Task: 99.045143
52 | Test:epoch: 16, Top1_auxi1: 0.553002, Top1_auxi2: 0.608265, Top1: 0.646800
53 | Class-wise Acc:1st: 0.833242, 2nd: 0.502734, 3rd: 0.739659, 4th: 0.665224, 5th: 0.816670, 6th: 0.442410, 7th: 0.919255, 8th: 0.516500, 9th: 0.854474, 10th: 0.380973, 11th: 0.897781, 12th: 0.192682
54 | Train:epoch: 17:[432/432], LossCE: 0.030613, LossDA: -0.056364, LossAll: 0.049122, Auxi1: 99.028862, Auxi2: 98.958336, Task: 99.131943
55 | Test:epoch: 17, Top1_auxi1: 0.565271, Top1_auxi2: 0.612788, Top1: 0.655612
56 | Class-wise Acc:1st: 0.835985, 2nd: 0.496403, 3rd: 0.763539, 4th: 0.689549, 5th: 0.821147, 6th: 0.509880, 7th: 0.910973, 8th: 0.545750, 9th: 0.900418, 10th: 0.349847, 11th: 0.890463, 12th: 0.153389 Best Acc so far: 0.655612
57 | Train:epoch: 18:[432/432], LossCE: 0.027097, LossDA: -0.054544, LossAll: 0.039466, Auxi1: 99.121094, Auxi2: 99.061417, Task: 99.260345
58 | Test:epoch: 18, Top1_auxi1: 0.559192, Top1_auxi2: 0.617162, Top1: 0.655087
59 | Class-wise Acc:1st: 0.873834, 2nd: 0.437122, 3rd: 0.718763, 4th: 0.682434, 5th: 0.809209, 6th: 0.503133, 7th: 0.924431, 8th: 0.591250, 9th: 0.863047, 10th: 0.375274, 11th: 0.903919, 12th: 0.178623
60 | Train:epoch: 19:[432/432], LossCE: 0.025011, LossDA: -0.052258, LossAll: 0.035955, Auxi1: 99.153648, Auxi2: 99.115669, Task: 99.318214
61 | Test:epoch: 19, Top1_auxi1: 0.587179, Top1_auxi2: 0.623123, Top1: 0.671575
62 | Class-wise Acc:1st: 0.842567, 2nd: 0.544748, 3rd: 0.766524, 4th: 0.681761, 5th: 0.827116, 6th: 0.567229, 7th: 0.909593, 8th: 0.617250, 9th: 0.887008, 10th: 0.346778, 11th: 0.902502, 12th: 0.165826 Best Acc so far: 0.671575
63 | Train:epoch: 20:[432/432], LossCE: 0.024907, LossDA: -0.050518, LossAll: 0.036782, Auxi1: 99.164497, Auxi2: 99.151840, Task: 99.298325
64 | Test:epoch: 20, Top1_auxi1: 0.584258, Top1_auxi2: 0.621111, Top1: 0.667436
65 | Class-wise Acc:1st: 0.854361, 2nd: 0.468201, 3rd: 0.757783, 4th: 0.676473, 5th: 0.818802, 6th: 0.585542, 7th: 0.917702, 8th: 0.648250, 9th: 0.891185, 10th: 0.344586, 11th: 0.901794, 12th: 0.144557
66 | Train:epoch: 21:[432/432], LossCE: 0.024208, LossDA: -0.048013, LossAll: 0.037116, Auxi1: 99.184387, Auxi2: 99.184387, Task: 99.316406
67 | Test:epoch: 21, Top1_auxi1: 0.596875, Top1_auxi2: 0.643675, Top1: 0.678382
68 | Class-wise Acc:1st: 0.870817, 2nd: 0.480576, 3rd: 0.763326, 4th: 0.696952, 5th: 0.841398, 6th: 0.622651, 7th: 0.927536, 8th: 0.617750, 9th: 0.914047, 10th: 0.367821, 11th: 0.884325, 12th: 0.153389 Best Acc so far: 0.678382
69 | Train:epoch: 22:[432/432], LossCE: 0.022081, LossDA: -0.047065, LossAll: 0.030939, Auxi1: 99.244072, Auxi2: 99.233215, Task: 99.345345
70 | Test:epoch: 22, Top1_auxi1: 0.604596, Top1_auxi2: 0.639697, Top1: 0.681608
71 | Class-wise Acc:1st: 0.845036, 2nd: 0.511079, 3rd: 0.783795, 4th: 0.696760, 5th: 0.810701, 6th: 0.627952, 7th: 0.914079, 8th: 0.686500, 9th: 0.910969, 10th: 0.356861, 11th: 0.893532, 12th: 0.142033 Best Acc so far: 0.681608
72 | Train:epoch: 23:[432/432], LossCE: 0.021154, LossDA: -0.044631, LossAll: 0.030898, Auxi1: 99.345345, Auxi2: 99.229599, Task: 99.430336
73 | Test:epoch: 23, Top1_auxi1: 0.604646, Top1_auxi2: 0.653916, Top1: 0.684511
74 | Class-wise Acc:1st: 0.863412, 2nd: 0.514245, 3rd: 0.793817, 4th: 0.713681, 5th: 0.851418, 6th: 0.622651, 7th: 0.911146, 8th: 0.645500, 9th: 0.888327, 10th: 0.387111, 11th: 0.883853, 12th: 0.138969 Best Acc so far: 0.684511
75 | Train:epoch: 24:[432/432], LossCE: 0.019100, LossDA: -0.043429, LossAll: 0.025213, Auxi1: 99.361618, Auxi2: 99.339920, Task: 99.488213
76 | Test:epoch: 24, Top1_auxi1: 0.599118, Top1_auxi2: 0.656066, Top1: 0.682732
77 | Class-wise Acc:1st: 0.867526, 2nd: 0.522014, 3rd: 0.776333, 4th: 0.695510, 5th: 0.822426, 6th: 0.631807, 7th: 0.944790, 8th: 0.665500, 9th: 0.892064, 10th: 0.330118, 11th: 0.887158, 12th: 0.157534
78 | Train:epoch: 25:[432/432], LossCE: 0.018116, LossDA: -0.042065, LossAll: 0.023505, Auxi1: 99.414062, Auxi2: 99.394173, Task: 99.506294
79 | Test:epoch: 25, Top1_auxi1: 0.616083, Top1_auxi2: 0.658810, Top1: 0.691398
80 | Class-wise Acc:1st: 0.877674, 2nd: 0.548489, 3rd: 0.767804, 4th: 0.699356, 5th: 0.835856, 6th: 0.639518, 7th: 0.907350, 8th: 0.668000, 9th: 0.902176, 10th: 0.393249, 11th: 0.903919, 12th: 0.153389 Best Acc so far: 0.691398
81 | Train:epoch: 26:[432/432], LossCE: 0.016651, LossDA: -0.040484, LossAll: 0.020489, Auxi1: 99.452042, Auxi2: 99.417679, Task: 99.582253
82 | Test:epoch: 26, Top1_auxi1: 0.611797, Top1_auxi2: 0.662253, Top1: 0.687778
83 | Class-wise Acc:1st: 0.847778, 2nd: 0.561727, 3rd: 0.799787, 4th: 0.722815, 5th: 0.819655, 6th: 0.642892, 7th: 0.930297, 8th: 0.674500, 9th: 0.907232, 10th: 0.327488, 11th: 0.879839, 12th: 0.139329
84 | Train:epoch: 27:[432/432], LossCE: 0.015983, LossDA: -0.038920, LossAll: 0.020103, Auxi1: 99.509911, Auxi2: 99.435768, Task: 99.582253
85 | Test:epoch: 27, Top1_auxi1: 0.629375, Top1_auxi2: 0.668425, Top1: 0.694345
86 | Class-wise Acc:1st: 0.873560, 2nd: 0.544460, 3rd: 0.792964, 4th: 0.700221, 5th: 0.833298, 6th: 0.686265, 7th: 0.935300, 8th: 0.698500, 9th: 0.926357, 10th: 0.296361, 11th: 0.884797, 12th: 0.160058 Best Acc so far: 0.694345
87 | Train:epoch: 28:[432/432], LossCE: 0.015241, LossDA: -0.037351, LossAll: 0.018945, Auxi1: 99.490021, Auxi2: 99.462891, Task: 99.600334
88 | Test:epoch: 28, Top1_auxi1: 0.641193, Top1_auxi2: 0.675345, Top1: 0.703984
89 | Class-wise Acc:1st: 0.861492, 2nd: 0.596547, 3rd: 0.814499, 4th: 0.695702, 5th: 0.845449, 6th: 0.708434, 7th: 0.930469, 8th: 0.730250, 9th: 0.928995, 10th: 0.305129, 11th: 0.872049, 12th: 0.158796 Best Acc so far: 0.703984
90 | Train:epoch: 29:[432/432], LossCE: 0.014775, LossDA: -0.036507, LossAll: 0.017794, Auxi1: 99.513527, Auxi2: 99.477356, Task: 99.600334
91 | Test:epoch: 29, Top1_auxi1: 0.622866, Top1_auxi2: 0.670656, Top1: 0.695423
92 | Class-wise Acc:1st: 0.879594, 2nd: 0.567194, 3rd: 0.788913, 4th: 0.712720, 5th: 0.836069, 6th: 0.631325, 7th: 0.941339, 8th: 0.727250, 9th: 0.904375, 10th: 0.334941, 11th: 0.888338, 12th: 0.133021
93 | Train:epoch: 30:[432/432], LossCE: 0.012954, LossDA: -0.035282, LossAll: 0.013344, Auxi1: 99.594910, Auxi2: 99.584061, Task: 99.696182
94 | Test:epoch: 30, Top1_auxi1: 0.629052, Top1_auxi2: 0.672304, Top1: 0.694756
95 | Class-wise Acc:1st: 0.873012, 2nd: 0.532950, 3rd: 0.818763, 4th: 0.708586, 5th: 0.840546, 6th: 0.681928, 7th: 0.927536, 8th: 0.709750, 9th: 0.906353, 10th: 0.327488, 11th: 0.889754, 12th: 0.120404
96 | Train:epoch: 31:[432/432], LossCE: 0.012972, LossDA: -0.034260, LossAll: 0.014438, Auxi1: 99.573204, Auxi2: 99.533424, Task: 99.681717
97 | Test:epoch: 31, Top1_auxi1: 0.645591, Top1_auxi2: 0.677350, Top1: 0.703731
98 | Class-wise Acc:1st: 0.882063, 2nd: 0.583022, 3rd: 0.795949, 4th: 0.708009, 5th: 0.850991, 6th: 0.708434, 7th: 0.924431, 8th: 0.749000, 9th: 0.914707, 10th: 0.300745, 11th: 0.891879, 12th: 0.135544
99 |
--------------------------------------------------------------------------------
/experiments/configs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/.DS_Store
--------------------------------------------------------------------------------
/experiments/configs/ImageCLEF/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/ImageCLEF/.DS_Store
--------------------------------------------------------------------------------
/experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 12
3 | DATASET: 'ImageCLEF'
4 | DATAROOT: '/data1/domain_adaptation/image_CLEF'
5 | SOURCE_NAME: 'c'
6 | TARGET_NAME: 'i'
7 | VAL_NAME: 'i'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'ours'
14 |
15 |
16 | EVAL_METRIC: "accu"
17 | SAVE_DIR: "./experiments/ckpt"
18 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/ImageCLEF/SymmNets/clef_train_c2i_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 12
3 | DATASET: 'ImageCLEF'
4 | DATAROOT: '/data1/domain_adaptation/image_CLEF'
5 | SOURCE_NAME: 'c'
6 | TARGET_NAME: 'i'
7 | VAL_NAME: 'i'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'ours'
14 |
15 |
16 | EVAL_METRIC: "accu"
17 | SAVE_DIR: "./experiments/SymmNets"
18 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/Office31/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/Office31/.DS_Store
--------------------------------------------------------------------------------
/experiments/configs/Office31/McDalNet/office31_train_webcam2amazon_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 31
3 | DATASET: 'Office31'
4 | DATAROOT: '/data1/domain_adaptation/Office31'
5 | SOURCE_NAME: 'webcam'
6 | TARGET_NAME: 'amazon'
7 | VAL_NAME: 'amazon'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'ours'
14 |
15 |
16 | EVAL_METRIC: "accu"
17 | SAVE_DIR: "./experiments/ckpt"
18 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/Office31/SymmNets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/Office31/SymmNets/.DS_Store
--------------------------------------------------------------------------------
/experiments/configs/Office31/SymmNets/office31_train_amazon2dslr_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 31
3 | DATASET: 'Office31'
4 | DATAROOT: '/data1/domain_adaptation/Office31'
5 | SOURCE_NAME: 'amazon'
6 | TARGET_NAME: 'dslr'
7 | VAL_NAME: 'dslr'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'longs'
14 |
15 | TRAIN:
16 | MAX_EPOCH: 200
17 | # More training epochs are fine, but not necessary.
18 |
19 | EVAL_METRIC: "accu"
20 | SAVE_DIR: "./experiments/SymmNets"
21 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/Office31/SymmNets/office31_train_amazon2dslr_cfg_SC.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 31
3 | DATASET: 'Office31'
4 | DATAROOT: '/data1/domain_adaptation/Office31'
5 | SOURCE_NAME: 'amazon'
6 | TARGET_NAME: 'dslr'
7 | VAL_NAME: 'dslr'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'longs'
14 |
15 | STRENGTHEN:
16 | DATALOAD: 'soft'
17 | PERCATE: 4
18 | CLUSTER_FREQ: 20
19 |
20 | TRAIN:
21 | MAX_EPOCH: 200
22 | # More training epochs are fine, but not necessary.
23 |
24 | EVAL_METRIC: "accu"
25 | SAVE_DIR: "./experiments/SymmNets"
26 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/Office31/SymmNets/office31_train_amazon2webcam_open_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 11
3 | DATASET: 'Office31'
4 | DATAROOT: '/data1/domain_adaptation/openset_Office31/'
5 | SOURCE_NAME: 'amazon_s_busto'
6 | TARGET_NAME: 'webcam_t_busto'
7 | VAL_NAME: 'webcam_t_busto'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'ours'
14 |
15 |
16 | EVAL_METRIC: "accu_mean"
17 | SAVE_DIR: "./experiments/Open"
18 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/Office31/SymmNets/office31_train_webcam2amazon_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 31
3 | DATASET: 'Office31'
4 | DATAROOT: '/data1/domain_adaptation/Office31'
5 | SOURCE_NAME: 'webcam'
6 | TARGET_NAME: 'amazon'
7 | VAL_NAME: 'amazon'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'ours'
14 |
15 | TRAIN:
16 | MAX_EPOCH: 200
17 | # More training epochs are fine, but not necessary.
18 |
19 | EVAL_METRIC: "accu"
20 | SAVE_DIR: "./experiments/SymmNets"
21 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/Office31/SymmNets/office31_train_webcam2amazon_cfg_SC.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 31
3 | DATASET: 'Office31'
4 | DATAROOT: '/data1/domain_adaptation/Office31'
5 | SOURCE_NAME: 'webcam'
6 | TARGET_NAME: 'amazon'
7 | VAL_NAME: 'amazon'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'longs'
14 |
15 | STRENGTHEN:
16 | DATALOAD: 'soft'
17 | PERCATE: 4
18 | CLUSTER_FREQ: 20
19 |
20 | TRAIN:
21 | MAX_EPOCH: 200
22 | # More training epochs are fine, but not necessary.
23 |
24 |
25 | EVAL_METRIC: "accu"
26 | SAVE_DIR: "./experiments/SymmNets"
27 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/OfficeHome/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/experiments/configs/OfficeHome/.DS_Store
--------------------------------------------------------------------------------
/experiments/configs/OfficeHome/SymmNets/home_train_A2R_partial_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 65
3 | DATASET: 'OfficeHome'
4 | DATAROOT: '/data1/domain_adaptation/OfficeHome'
5 | SOURCE_NAME: 'Art'
6 | TARGET_NAME: 'Real_World_25'
7 | VAL_NAME: 'Real_World_25'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'ours'
14 |
15 | TRAIN:
16 | MAX_EPOCH: 80
17 |
18 | EVAL_METRIC: "accu"
19 | SAVE_DIR: "./experiments/Partial"
20 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 12
3 | DATASET: 'VisDA'
4 | DATAROOT: '/data1/domain_adaptation/visDA'
5 | SOURCE_NAME: 'train'
6 | TARGET_NAME: 'validation'
7 | VAL_NAME: 'validation'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'simple'
14 |
15 | TRAIN:
16 | BASE_LR: 0.001
17 | MAX_EPOCH: 30
18 | LR_SCHEDULE: 'fix'
19 |
20 | EVAL_METRIC: "accu_mean"
21 | SAVE_DIR: "./experiments/ckpt"
22 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 12
3 | DATASET: 'VisDA'
4 | DATAROOT: '/data1/domain_adaptation/visDA'
5 | SOURCE_NAME: 'train'
6 | TARGET_NAME: 'validation'
7 | VAL_NAME: 'validation'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet50'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'simple'
14 |
15 | TRAIN:
16 | BASE_LR: 0.001
17 | MAX_EPOCH: 30
18 | LR_SCHEDULE: 'fix'
19 |
20 | EVAL_METRIC: "accu_mean"
21 | SAVE_DIR: "./experiments/SymmNets"
22 | NUM_WORKERS: 8
--------------------------------------------------------------------------------
/experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg_res101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NUM_CLASSES: 12
3 | DATASET: 'VisDA'
4 | DATAROOT: '/data1/domain_adaptation/visDA'
5 | SOURCE_NAME: 'train'
6 | TARGET_NAME: 'validation'
7 | VAL_NAME: 'validation'
8 |
9 | MODEL:
10 | FEATURE_EXTRACTOR: 'resnet101'
11 |
12 | DATA_TRANSFORM:
13 | TYPE: 'simple'
14 |
15 | TRAIN:
16 | BASE_LR: 0.001
17 | MAX_EPOCH: 40
18 | LR_SCHEDULE: 'fix'
19 |
20 | EVAL_METRIC: "accu_mean"
21 | SAVE_DIR: "./experiments/SymmNets"
22 | NUM_WORKERS: 8
23 |
--------------------------------------------------------------------------------
/models/loss_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.utils import process_zero_values
5 | import ipdb
6 |
7 |
8 | def _assert_no_grad(variable):
9 | assert not variable.requires_grad, \
10 | "nn criterions don't compute the gradient w.r.t. targets - please " \
11 | "mark these variables as volatile or not requiring gradients"
12 |
13 |
14 | class _Loss(nn.Module):
15 | def __init__(self, size_average=True):
16 | super(_Loss, self).__init__()
17 | self.size_average = size_average
18 |
19 |
class _WeightedLoss(_Loss):
    """Loss base that carries an optional per-class weight tensor."""

    def __init__(self, weight=None, size_average=True):
        super(_WeightedLoss, self).__init__(size_average)
        # a buffer (not a parameter): moves with .to()/.cuda() but is not trained
        self.register_buffer('weight', weight)
24 |
25 |
class CrossEntropyClassWeighted(_Loss):
    """Cross-entropy loss whose class weights are supplied per call.

    Unlike ``nn.CrossEntropyLoss`` (weights fixed at construction), the
    optional ``weight`` tensor is passed to ``forward`` so it can change
    between batches.
    """

    def __init__(self, size_average=True, ignore_index=-100, reduce=None, reduction='elementwise_mean'):
        super(CrossEntropyClassWeighted, self).__init__(size_average)
        self.ignore_index = ignore_index
        # 'elementwise_mean' was renamed to 'mean' in PyTorch 1.0 and the old
        # spelling was later removed; translate it so the historical default
        # keeps working on modern PyTorch.
        self.reduction = 'mean' if reduction == 'elementwise_mean' else reduction

    def forward(self, input, target, weight=None):
        """Return the (reduced) cross entropy of logits `input` vs `target`."""
        return F.cross_entropy(input, target, weight, ignore_index=self.ignore_index, reduction=self.reduction)
35 |
36 |
37 | ### clone this function from: https://github.com/krumo/swd_pytorch/blob/master/swd_pytorch.py. [Unofficial]
def discrepancy_slice_wasserstein(p1, p2):
    """Sliced Wasserstein discrepancy between two batches of class probabilities.

    For multi-class inputs the rows are projected onto 128 random unit
    directions, each coordinate is sorted over the batch, and the mean squared
    difference of the sorted projections is returned.

    Args:
        p1, p2: (batch, classes) probability tensors of identical shape.
    Returns:
        Scalar tensor with the discrepancy.
    """
    s = p1.shape
    if s[1] > 1:
        # Allocate the random projections on the inputs' device instead of the
        # original hard-coded .cuda(), so the loss also runs on CPU tensors.
        proj = torch.randn(s[1], 128, device=p1.device)
        proj *= torch.rsqrt(torch.sum(torch.mul(proj, proj), 0, keepdim=True))  # unit-norm columns
        p1 = torch.matmul(p1, proj)
        p2 = torch.matmul(p2, proj)
    # top-k over the whole batch == sort each projected coordinate descending
    p1 = torch.topk(p1, s[0], dim=0)[0]
    p2 = torch.topk(p2, s[0], dim=0)[0]
    dist = p1 - p2
    wdist = torch.mean(torch.mul(dist, dist))

    return wdist
51 |
52 |
class McDalNetLoss(_WeightedLoss):
    """Discrepancy between two classifiers' logits, used by McDalNet.

    `dis_type` selects the measure; 'L1', 'CE' and 'KL' are the ones
    evaluated in the paper, 'L2' and 'Wasse' are experimental.
    """

    def __init__(self, weight=None, size_average=True):
        super(McDalNetLoss, self).__init__(weight, size_average)

    def forward(self, input1, input2, dis_type='L1'):
        """Return a scalar discrepancy between logits `input1` and `input2`.

        Raises:
            ValueError: if `dis_type` is not a supported measure.
        """
        if dis_type == 'L1':
            prob_s = F.softmax(input1, dim=1)
            prob_t = F.softmax(input2, dim=1)
            loss = torch.mean(torch.abs(prob_s - prob_t))  # element-wise mean
        elif dis_type == 'CE':  # symmetrized cross entropy
            loss = - ((F.log_softmax(input2, dim=1)).mul(F.softmax(input1, dim=1))).mean() - (
                (F.log_softmax(input1, dim=1)).mul(F.softmax(input2, dim=1))).mean()
            loss = loss * 0.5
        elif dis_type == 'KL':
            # Symmetrized KL averaged over elements, not the textbook
            # sum-over-classes / mean-over-instances KL divergence.
            # dim=1 is now explicit: the implicit-dim softmax form is deprecated.
            loss = (F.kl_div(F.log_softmax(input1, dim=1), F.softmax(input2, dim=1))) + (
                F.kl_div(F.log_softmax(input2, dim=1), F.softmax(input1, dim=1)))
            loss = loss * 0.5
        # the following two distances are not evaluated in the paper and need
        # further investigation
        elif dis_type == 'L2':
            nClass = input1.size()[1]
            prob_s = F.softmax(input1, dim=1)
            prob_t = F.softmax(input2, dim=1)
            loss = torch.norm(prob_s - prob_t, p=2, dim=1).mean() / nClass
        elif dis_type == 'Wasse':
            # distance from "Sliced Wasserstein Discrepancy for Unsupervised
            # Domain Adaptation"
            prob_s = F.softmax(input1, dim=1)
            prob_t = F.softmax(input2, dim=1)
            loss = discrepancy_slice_wasserstein(prob_s, prob_t)
        else:
            # previously an unknown type fell through and crashed with an
            # UnboundLocalError on `loss`; fail with a clear message instead
            raise ValueError('Unsupported dis_type: {}'.format(dis_type))

        return loss
85 |
86 |
class TargetDiscrimLoss(_WeightedLoss):
    """Discriminator loss over the target half of a 2K-way classifier.

    `input` holds logits over 2*num_classes entries; columns [num_classes:]
    are the target-domain copies. The loss is the negative log of the summed
    target-half probability, driving it toward 1.
    """

    def __init__(self, weight=None, size_average=True, num_classes=31):
        super(TargetDiscrimLoss, self).__init__(weight, size_average)
        self.num_classes = num_classes

    def forward(self, input):
        batch_size = input.size(0)
        prob = F.softmax(input, dim=1)
        target_mass = prob[:, self.num_classes:].sum(1)

        if (target_mass.data == 0).sum() != 0:  # guard against log(0)
            # Build the epsilon pad on the input's device instead of the
            # original hard-coded .cuda(), so CPU runs also work.
            soft_weight = torch.zeros(batch_size, device=input.device)
            soft_weight[target_mass.data == 0] = 1e-6
            loss = -((target_mass + soft_weight).log().mean())
        else:
            loss = -(target_mass.log().mean())
        return loss
104 |
class SourceDiscrimLoss(_WeightedLoss):
    """Discriminator loss over the source half of a 2K-way classifier.

    Mirror of TargetDiscrimLoss: columns [:num_classes] are the source-domain
    copies, and the loss drives their summed probability toward 1.
    """

    def __init__(self, weight=None, size_average=True, num_classes=31):
        super(SourceDiscrimLoss, self).__init__(weight, size_average)
        self.num_classes = num_classes

    def forward(self, input):
        batch_size = input.size(0)
        prob = F.softmax(input, dim=1)
        source_mass = prob[:, :self.num_classes].sum(1)

        if (source_mass.data == 0).sum() != 0:  # guard against log(0)
            # epsilon pad allocated on the input's device rather than the
            # original hard-coded .cuda(), so CPU runs also work
            soft_weight = torch.zeros(batch_size, device=input.device)
            soft_weight[source_mass.data == 0] = 1e-6
            loss = -((source_mass + soft_weight).log().mean())
        else:
            loss = -(source_mass.log().mean())
        return loss
122 |
123 |
class ConcatenatedCELoss(_WeightedLoss):
    """Symmetric cross entropy between the source and target halves of a
    concatenated 2K-way softmax output."""

    def __init__(self, weight=None, size_average=True, num_classes=31):
        super(ConcatenatedCELoss, self).__init__(weight, size_average)
        self.num_classes = num_classes

    def forward(self, input):
        prob = F.softmax(input, dim=1)
        src = prob[:, :self.num_classes]
        tgt = prob[:, self.num_classes:]

        # clamp exact zeros away before taking logs
        src = process_zero_values(src)
        tgt = process_zero_values(tgt)
        cross = (src.log().mul(tgt)).sum(1).mean() + (tgt.log().mul(src)).sum(1).mean()
        return -0.5 * cross
139 |
140 |
141 |
class ConcatenatedEMLoss(_WeightedLoss):
    """Entropy-minimization loss on the element-wise sum of the source and
    target halves of a concatenated 2K-way softmax output."""

    def __init__(self, weight=None, size_average=True, num_classes=31):
        super(ConcatenatedEMLoss, self).__init__(weight, size_average)
        self.num_classes = num_classes

    def forward(self, input):
        prob = F.softmax(input, dim=1)
        combined = prob[:, :self.num_classes] + prob[:, self.num_classes:]
        combined = process_zero_values(combined)  # avoid log(0)
        return -combined.log().mul(combined).sum(1).mean()
156 |
class MinEntropyConsensusLoss(nn.Module):
    """Minimum-entropy consensus loss between two predictions.

    For every sample, the cross entropy of each prediction is evaluated
    against every possible one-hot label; the candidate class minimizing the
    averaged cross entropy of the two views is selected, and the resulting
    per-sample minima are averaged over the batch.
    """

    def __init__(self, num_classes):
        super(MinEntropyConsensusLoss, self).__init__()
        self.num_classes = num_classes

    def forward(self, x, y):
        # Identity matrix enumerates all one-hot targets; build it on the
        # input's device instead of the original hard-coded .cuda(), so the
        # loss also runs on CPU tensors.
        i = torch.eye(self.num_classes, device=x.device).unsqueeze(0)
        x = F.log_softmax(x, dim=1)
        y = F.log_softmax(y, dim=1)
        x = x.unsqueeze(-1)
        y = y.unsqueeze(-1)

        # (batch, num_classes): CE of each prediction vs each one-hot label
        ce_x = (- 1.0 * i * x).sum(1)
        ce_y = (- 1.0 * i * y).sum(1)

        # per sample, keep the class minimizing the averaged CE of both views
        ce = 0.5 * (ce_x + ce_y).min(1)[0].mean()

        return ce
--------------------------------------------------------------------------------
/models/resnet_McDalNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | from torch.autograd import Function
5 | from config.config import cfg
6 | import torch
7 | import ipdb
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152']
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | }
19 |
20 |
21 |
class ReverseLayerF(Function):
    """Gradient-reversal layer: identity on the forward pass, gradient
    negated and scaled by `alpha` on the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # stash the scaling factor for the backward pass
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # flip the sign and scale; no gradient flows to `alpha`
        return grad_output.neg() * ctx.alpha, None
35 |
class ZeroLayerF(Function):
    """Gradient-blocking layer: identity forward, zero gradient backward.

    `alpha` is accepted only for signature symmetry and is ignored.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = 0.0
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # kill the gradient entirely; nothing flows to `alpha` either
        return grad_output * 0.0, None
49 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution, padding 1, no bias (the following BatchNorm shifts)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
54 |
55 |
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (ResNet-18/34 style)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # project the shortcut only when the shape changes
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
86 |
87 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50/101/152 style)."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # project the shortcut only when the shape changes
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))    # reduce channels
        out = self.relu(self.bn2(self.conv2(out)))  # 3x3 conv, may stride
        out = self.bn3(self.conv3(out))             # expand channels x4
        out += shortcut
        return self.relu(out)
125 |
126 |
class ResNet(nn.Module):
    """ResNet backbone for McDalNet with four linear heads: a task classifier
    (`fc`), two auxiliary classifiers (`fc_aux1`/`fc_aux2`) and a binary
    domain classifier (`fcdc`).

    The auxiliary/domain heads receive gradient-reversed features; the
    auxiliary heads are additionally evaluated on a gradient-blocked copy of
    the feature (see forward()).
    """

    def __init__(self, block, layers, num_classes=1000):
        # `block` is BasicBlock or Bottleneck; `layers` gives the number of
        # residual blocks in each of the four stages.
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)  # 7x7 window, i.e. sized for 224x224 inputs
        self.fc = nn.Linear(512 * block.expansion, num_classes) ## for classification
        self.fc_aux1 = nn.Linear(512 * block.expansion, num_classes) ## auxiliary classifier one
        self.fc_aux2 = nn.Linear(512 * block.expansion, num_classes) ## auxiliary classifier two
        self.fcdc = nn.Linear(512 * block.expansion, 1) ## domain classifier

        # He initialization for convolutions; BatchNorm starts as identity
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of `blocks` residual blocks; only the first block
        may stride/downsample.

        NOTE: mutates self.inplanes, so the four stages must be constructed
        in order.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 conv projection when the shortcut's shape changes
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, alpha):
        """Run the backbone and all heads.

        Args:
            x: input image batch.
            alpha: scaling factor for the reversed gradient.
        Returns:
            (feature, task_logits, aux1, aux2, domain_logit,
             aux1_truncated, aux2_truncated).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        rev_x = ReverseLayerF.apply(x, alpha) ## gradient reverse
        out1 = self.fc_aux1(rev_x)
        out2 = self.fc_aux2(rev_x)
        outdc = self.fcdc(rev_x)

        # the auxiliary heads are also evaluated on a gradient-blocked copy
        # of the feature, so these logits do not train the backbone
        trunc_x = ZeroLayerF.apply(x, 0.0) ## zero gradient
        out1_trunc = self.fc_aux1(trunc_x)
        out2_trunc = self.fc_aux2(trunc_x)

        return x, out, out1, out2, outdc, out1_trunc, out2_trunc
196 |
197 |
def resnet18():
    """Build the McDalNet ResNet-18, optionally seeded with ImageNet weights."""
    net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        state = net.state_dict()
        # adopt every ImageNet weight present in this model, except the
        # 1000-way classifier head whose shape no longer matches
        donor = {
            k: v for k, v in model_zoo.load_url(model_urls['resnet18']).items()
            if k in state and k not in ('fc.weight', 'fc.bias')
        }
        state.update(donor)
        net.load_state_dict(state)
    return net
211 |
212 |
def resnet34():
    """Build the McDalNet ResNet-34, optionally seeded with ImageNet weights."""
    net = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        state = net.state_dict()
        # adopt every ImageNet weight present in this model, except the
        # 1000-way classifier head whose shape no longer matches
        donor = {
            k: v for k, v in model_zoo.load_url(model_urls['resnet34']).items()
            if k in state and k not in ('fc.weight', 'fc.bias')
        }
        state.update(donor)
        net.load_state_dict(state)

    return net
227 |
228 |
def resnet50():
    """Build the McDalNet ResNet-50, optionally seeded with ImageNet weights."""
    net = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        print('load the ImageNet pretrained parameters')
        state = net.state_dict()
        # adopt every ImageNet weight present in this model, except the
        # 1000-way classifier head whose shape no longer matches
        donor = {
            k: v for k, v in model_zoo.load_url(model_urls['resnet50']).items()
            if k in state and k not in ('fc.weight', 'fc.bias')
        }
        state.update(donor)
        net.load_state_dict(state)

    return net
244 |
245 |
def resnet101():
    """Build the McDalNet ResNet-101, optionally seeded with ImageNet weights."""
    net = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        state = net.state_dict()
        # adopt every ImageNet weight present in this model, except the
        # 1000-way classifier head whose shape no longer matches
        donor = {
            k: v for k, v in model_zoo.load_url(model_urls['resnet101']).items()
            if k in state and k not in ('fc.weight', 'fc.bias')
        }
        state.update(donor)
        net.load_state_dict(state)

    return net
260 |
261 |
def resnet152():
    """Build the McDalNet ResNet-152, optionally seeded with ImageNet weights."""
    net = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        state = net.state_dict()
        # adopt every ImageNet weight present in this model, except the
        # 1000-way classifier head whose shape no longer matches
        donor = {
            k: v for k, v in model_zoo.load_url(model_urls['resnet152']).items()
            if k in state and k not in ('fc.weight', 'fc.bias')
        }
        state.update(donor)
        net.load_state_dict(state)

    return net
276 |
277 |
def resnet():
    """Instantiate the backbone selected by cfg.MODEL.FEATURE_EXTRACTOR."""
    print("==> creating model '{}' ".format(cfg.MODEL.FEATURE_EXTRACTOR))
    builders = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
    }
    if cfg.MODEL.FEATURE_EXTRACTOR not in builders:
        raise ValueError('Unrecognized model architecture', cfg.MODEL.FEATURE_EXTRACTOR)
    return builders[cfg.MODEL.FEATURE_EXTRACTOR]()
--------------------------------------------------------------------------------
/models/resnet_SymmNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | from torch.autograd import Function
5 | from config.config import cfg
6 | import torch
7 | import ipdb
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152']
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | }
19 |
20 |
21 |
class ReverseLayerF(Function):
    """Gradient-reversal layer: identity on the forward pass, gradient
    negated and scaled by `alpha` on the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # stash the scaling factor for the backward pass
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # flip the sign and scale; no gradient flows to `alpha`
        return grad_output.neg() * ctx.alpha, None
35 |
class ZeroLayerF(Function):
    """Gradient-blocking layer: identity forward, zero gradient backward.

    `alpha` is accepted only for signature symmetry and is ignored.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = 0.0
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # kill the gradient entirely; nothing flows to `alpha` either
        return grad_output * 0.0, None
49 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution, padding 1, no bias (the following BatchNorm shifts)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
54 |
55 |
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (ResNet-18/34 style)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # project the shortcut only when the shape changes
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
86 |
87 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50/101/152 style)."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # project the shortcut only when the shape changes
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))    # reduce channels
        out = self.relu(self.bn2(self.conv2(out)))  # 3x3 conv, may stride
        out = self.bn3(self.conv3(out))             # expand channels x4
        out += shortcut
        return self.relu(out)
125 |
126 |
class ResNet(nn.Module):
    """Headless ResNet backbone for SymmNets.

    forward() returns only the flattened, average-pooled feature; the
    classifier head is created separately in the factory functions below
    (e.g. resnet18 returns a (backbone, classifier) pair).
    """

    def __init__(self, block, layers, num_classes=1000):
        # `block` is BasicBlock or Bottleneck; `layers` gives the number of
        # residual blocks in each of the four stages. `num_classes` is unused
        # here since the classifier head is built externally.
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)  # 7x7 window, i.e. sized for 224x224 inputs
        # self.fc = nn.Linear(512 * block.expansion, num_classes*2) ## for classification

        # He initialization for convolutions; BatchNorm starts as identity
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of `blocks` residual blocks; only the first block
        may stride/downsample.

        NOTE: mutates self.inplanes, so the four stages must be constructed
        in order.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 conv projection when the shortcut's shape changes
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the pooled feature of shape (batch, 512 * block.expansion)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # out = self.fc(x)

        return x
185 |
186 |
def resnet18():
    """Constructs a ResNet-18 model.

    Returns a (feature_extractor, classifier) pair. The classifier output
    dimension is NUM_CLASSES * 2 — the concatenation of two task classifiers.
    When cfg.MODEL.PRETRAINED is set, ImageNet weights are copied in for
    every parameter name still present in this fc-less backbone.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        state = model.state_dict()
        pretrained = model_zoo.load_url(model_urls['resnet18'])
        ## keep only the pretrained tensors whose keys exist in this model
        state.update({k: v for k, v in pretrained.items() if k in state})
        model.load_state_dict(state)
    classifier = nn.Linear(512 * BasicBlock.expansion, cfg.DATASET.NUM_CLASSES * 2)
    return model, classifier
199 |
200 |
def resnet34():
    """Constructs a ResNet-34 model.

    Returns a (feature_extractor, classifier) pair; the classifier is the
    concatenation of two task heads (NUM_CLASSES * 2 outputs). Optionally
    initializes the backbone from ImageNet weights.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        state = model.state_dict()
        pretrained = model_zoo.load_url(model_urls['resnet34'])
        ## keep only the pretrained tensors whose keys exist in this model
        state.update({k: v for k, v in pretrained.items() if k in state})
        model.load_state_dict(state)
    classifier = nn.Linear(512 * BasicBlock.expansion, cfg.DATASET.NUM_CLASSES * 2)
    return model, classifier
213 |
214 |
def resnet50():
    """Constructs a ResNet-50 model.

    Returns a (feature_extractor, classifier) pair; the classifier is the
    concatenation of two task heads (NUM_CLASSES * 2 outputs). Optionally
    initializes the backbone from ImageNet weights.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        print('load the ImageNet pretrained parameters')
        state = model.state_dict()
        pretrained = model_zoo.load_url(model_urls['resnet50'])
        ## keep only the pretrained tensors whose keys exist in this model
        state.update({k: v for k, v in pretrained.items() if k in state})
        model.load_state_dict(state)
    classifier = nn.Linear(512 * Bottleneck.expansion, cfg.DATASET.NUM_CLASSES * 2)  ## the concatenation of two task classifiers
    return model, classifier
228 |
229 |
def resnet101():
    """Constructs a ResNet-101 model.

    Returns a (feature_extractor, classifier) pair. Unlike the other
    resnetXX() factories, the classifier here is a small MLP
    (Linear-BN-ReLU-Linear) rather than a single Linear layer; its final
    width is still NUM_CLASSES * 2 (two concatenated task heads).
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        state = model.state_dict()
        pretrained = model_zoo.load_url(model_urls['resnet101'])
        ## keep only the pretrained tensors whose keys exist in this model
        state.update({k: v for k, v in pretrained.items() if k in state})
        model.load_state_dict(state)
    classifier = nn.Sequential(nn.Linear(512 * Bottleneck.expansion, 512),
                               nn.BatchNorm1d(512),
                               nn.ReLU(inplace=True),
                               nn.Linear(512, cfg.DATASET.NUM_CLASSES * 2))
    return model, classifier
246 |
247 |
def resnet152():
    """Constructs a ResNet-152 model.

    Returns a (feature_extractor, classifier) pair; the classifier is the
    concatenation of two task heads (NUM_CLASSES * 2 outputs). Optionally
    initializes the backbone from ImageNet weights.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=cfg.DATASET.NUM_CLASSES)
    if cfg.MODEL.PRETRAINED:
        state = model.state_dict()
        pretrained = model_zoo.load_url(model_urls['resnet152'])
        ## keep only the pretrained tensors whose keys exist in this model
        state.update({k: v for k, v in pretrained.items() if k in state})
        model.load_state_dict(state)
    classifier = nn.Linear(512 * Bottleneck.expansion, cfg.DATASET.NUM_CLASSES * 2)
    return model, classifier
260 |
261 |
def resnet():
    """Dispatch to the resnetXX() factory named by cfg.MODEL.FEATURE_EXTRACTOR.

    Returns the (feature_extractor, classifier) pair produced by the chosen
    factory; raises ValueError for an unknown architecture name.
    """
    print("==> creating model '{}' ".format(cfg.MODEL.FEATURE_EXTRACTOR))
    factories = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
    }
    builder = factories.get(cfg.MODEL.FEATURE_EXTRACTOR)
    if builder is None:
        raise ValueError('Unrecognized model architecture', cfg.MODEL.FEATURE_EXTRACTOR)
    return builder()
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch SymmNetsV2 training (no auxiliary distance) on the VisDA
# closed-set source-classifier task ('closedsc') with the ResNet-101
# SC config; results go to the 'logCloseSC' experiment directory.
# NOTE(review): the identical command is launched twice in sequence --
# presumably two repeated trials; confirm this duplication is intended.

python ./tools/train.py --distance_type None --method SymmNetsV2 --task closedsc --cfg ./experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg_res101SC.yaml --exp_name logCloseSC

python ./tools/train.py --distance_type None --method SymmNetsV2 --task closedsc --cfg ./experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg_res101SC.yaml --exp_name logCloseSC



--------------------------------------------------------------------------------
/run_temp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ## The example script. all the log file are present in the ./experiments file for verifying
3 |
4 | ### ##################################### Closed Set DA example #########################################
5 |
6 | ## McDalNet script
7 | #python ./tools/train.py --distance_type L1 --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml
8 | #
9 | #python ./tools/train.py --distance_type KL --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml
10 | #
11 | #python ./tools/train.py --distance_type CE --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml
12 | #
13 | #python ./tools/train.py --distance_type MDD --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml
14 | #
15 | #python ./tools/train.py --distance_type DANN --method McDalNet --cfg ./experiments/configs/ImageCLEF/McDalNet/clef_train_c2i_cfg.yaml
16 |
17 | #python ./tools/train.py --distance_type L1 --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml
18 | #
19 | #python ./tools/train.py --distance_type KL --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml
20 | #
21 | #python ./tools/train.py --distance_type CE --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml
22 | #
23 | #python ./tools/train.py --distance_type DANN --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml
24 | #
25 | #python ./tools/train.py --distance_type MDD --method McDalNet --cfg ./experiments/configs/VisDA/McDalNet/visda17_train_train2val_cfg.yaml
26 |
27 | ## SymmNets script
28 | #python ./tools/train.py --distance_type None --method SymmNetsV1 --cfg ./experiments/configs/ImageCLEF/SymmNets/clef_train_c2i_cfg.yaml
29 |
30 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg_res101.yaml --exp_name logres101
31 |
32 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/ImageCLEF/SymmNets/clef_train_c2i_cfg.yaml
33 |
34 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/VisDA/SymmNets/visda17_train_train2val_cfg.yaml
35 |
36 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/Office31/SymmNets/office31_train_amazon2dslr_cfg.yaml
37 |
38 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python ./tools/train.py --distance_type None --method SymmNetsV2 --cfg ./experiments/configs/Office31/SymmNets/office31_train_webcam2amazon_cfg.yaml
39 |
40 |
41 | ## SymmNets-SC script
42 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python ./tools/train.py --distance_type None --task closedsc --method SymmNetsV2 --cfg ./experiments/configs/Office31/SymmNets/office31_train_webcam2amazon_cfg_SC.yaml
43 |
44 | #CUDA_VISIBLE_DEVICES=4,5,6,7 python ./tools/train.py --distance_type None --task closedsc --method SymmNetsV2 --cfg ./experiments/configs/Office31/SymmNets/office31_train_amazon2dslr_cfg_SC.yaml
45 |
46 |
47 |
48 |
49 | ##################################### Partial DA example #######################
50 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --task partial --cfg ./experiments/configs/OfficeHome/SymmNets/home_train_A2R_partial_cfg.yaml
51 |
52 |
53 |
54 |
55 | ###################################### Open set DA example #####################
56 | #python ./tools/train.py --distance_type None --method SymmNetsV2 --task open --cfg ./experiments/configs/Office31/SymmNets/office31_train_amazon2webcam_open_cfg.yaml
57 | #
58 |
--------------------------------------------------------------------------------
/solver/McDalNet_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import math
5 | import time
6 | from utils.utils import to_cuda, accuracy_for_each_class, accuracy, AverageMeter, process_one_values
7 | from config.config import cfg
8 | import torch.nn.functional as F
9 | from models.loss_utils import McDalNetLoss
10 | from .base_solver import BaseSolver
11 | import ipdb
12 |
class McDalNetSolver(BaseSolver):
    """Solver for McDalNet domain adaptation training.

    The wrapped network returns, per batch: a feature, the main task logits,
    two auxiliary classifier logits, a domain-classifier output, and
    "truncated" variants of the auxiliary logits. Source batches are trained
    with cross-entropy on the task and auxiliary heads; domain alignment is
    driven by the classifier discrepancy selected via
    cfg.MCDALNET.DISTANCE_TYPE ('SourceOnly', 'DANN', 'MDD', or a generic
    distance such as L1/KL/CE handled by McDalNetLoss).
    """

    def __init__(self, net, dataloaders, **kwargs):
        """Create loss modules and optionally resume from cfg.RESUME."""
        super(McDalNetSolver, self).__init__(net, dataloaders, **kwargs)
        self.BCELoss = nn.BCEWithLogitsLoss().cuda()  ## domain loss for the 'DANN' variant
        self.McDalNetLoss = McDalNetLoss().cuda()  ## generic discrepancy loss for the non-DANN/MDD variants
        if cfg.RESUME != '':
            ## restore model weights plus bookkeeping (best accuracy, epoch counter)
            resume_dict = torch.load(cfg.RESUME)
            model_state_dict = resume_dict['model_state_dict']
            self.net.load_state_dict(model_state_dict)
            self.best_prec1 = resume_dict['best_prec1']
            self.epoch = resume_dict['epoch']

    def solve(self):
        """Alternate train/test epochs until complete_training() signals stop,
        checkpointing whenever the test accuracy improves on the best so far."""
        stop = False
        while not stop:
            stop = self.complete_training()
            self.update_network()
            acc = self.test()
            if acc > self.best_prec1:
                self.best_prec1 = acc
                self.save_ckpt()
            self.epoch += 1


    def update_network(self, **kwargs):
        """Train for one epoch (one pass over the target loader).

        The adaptation weight `lam` ramps from 0 toward 1 via
        2 / (1 + exp(-10 * progress)) - 1, with progress measured per epoch
        or per iteration depending on cfg.TRAIN.PROCESS_COUNTER.
        """
        stop = False
        ## fresh iterators each epoch; epoch length follows the target loader
        self.train_data['source']['iterator'] = iter(self.train_data['source']['loader'])
        self.train_data['target']['iterator'] = iter(self.train_data['target']['loader'])
        self.iters_per_epoch = len(self.train_data['target']['loader'])
        iters_counter_within_epoch = 0
        data_time = AverageMeter()
        batch_time = AverageMeter()
        total_loss = AverageMeter()
        ce_loss = AverageMeter()
        da_loss = AverageMeter()
        prec1_task = AverageMeter()
        prec1_aux1 = AverageMeter()
        prec1_aux2 = AverageMeter()
        self.net.train()
        end = time.time()
        if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
            ## lam and lr are fixed for the whole epoch in this mode
            lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1
            self.update_lr()
            print('value of lam is: %3f' % (lam))
        while not stop:
            if self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                ## lam and lr are re-computed every iteration in this mode
                lam = 2 / (1 + math.exp(-1 * 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1
                print('value of lam is: %3f' % (lam))
                self.update_lr()
            source_data, source_gt = self.get_samples('source')
            target_data, _ = self.get_samples('target')  ## target labels are never used (unsupervised DA)
            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            target_data = to_cuda(target_data)
            data_time.update(time.time() - end)

            ## supervised losses on the source batch: task head + both auxiliary heads
            feature_source, output_source, output_source1, output_source2, output_source_dc, output_source1_trunc, output_source2_trunc = self.net(source_data, lam)
            loss_task_auxiliary_1 = self.CELoss(output_source1_trunc, source_gt)
            loss_task_auxiliary_2 = self.CELoss(output_source2_trunc, source_gt)
            loss_task = self.CELoss(output_source, source_gt)
            if self.opt.MCDALNET.DISTANCE_TYPE != 'SourceOnly':
                feature_target, output_target, output_target1, output_target2, output_target_dc, output_target1_trunc, output_target2_trunc = self.net(target_data, lam)
                if self.opt.MCDALNET.DISTANCE_TYPE == 'DANN':
                    ## binary domain discriminator: source labeled 0, target labeled 1
                    num_source = source_data.size()[0]
                    num_target = target_data.size()[0]
                    dlabel_source = to_cuda(torch.zeros(num_source, 1))
                    dlabel_target = to_cuda(torch.ones(num_target, 1))
                    loss_domain_all = self.BCELoss(output_source_dc, dlabel_source) + self.BCELoss(output_target_dc, dlabel_target)
                    loss_all = loss_task + loss_domain_all
                elif self.opt.MCDALNET.DISTANCE_TYPE == 'MDD':
                    ## margin-disparity-style discrepancy: classifier 2 provides
                    ## pseudo labels, classifier 1 is scored against them
                    prob_target1 = F.softmax(output_target1, dim=1)
                    _, target_pseudo_label = torch.topk(output_target2, 1)
                    batch_index = torch.arange(output_target.size()[0]).long()
                    pred_gt_prob = prob_target1[batch_index, target_pseudo_label] ## the prob values of the predicted gt
                    pred_gt_prob = process_one_values(pred_gt_prob)  ## guard before log() (presumably clamps probs of exactly 1)
                    loss_domain_target = (1 - pred_gt_prob).log().mean()

                    _, source_pseudo_label = torch.topk(output_source2, 1)
                    loss_domain_source = self.CELoss(output_source1, source_pseudo_label[:, 0])
                    ## target term enters with a minus sign: discrepancy is
                    ## minimized on source and maximized on target
                    loss_domain_all = loss_domain_source - loss_domain_target
                    loss_all = loss_task + loss_domain_all + loss_task_auxiliary_1 + loss_task_auxiliary_2
                else:
                    ## generic discrepancy (L1 / KL / CE) between the two auxiliary heads
                    loss_domain_source = self.McDalNetLoss(output_source1, output_source2, self.opt.MCDALNET.DISTANCE_TYPE)
                    loss_domain_target = self.McDalNetLoss(output_target1, output_target2, self.opt.MCDALNET.DISTANCE_TYPE)
                    loss_domain_all = loss_domain_source - loss_domain_target
                    loss_all = loss_task + loss_domain_all + loss_task_auxiliary_1 + loss_task_auxiliary_2
                da_loss.update(loss_domain_all, source_data.size()[0])
            else:
                ## 'SourceOnly' baseline: no adaptation term at all
                loss_all = loss_task
            ## NOTE(review): meters are fed 0-dim tensors rather than floats
            ## (.item()); works, but keeps tensors alive -- confirm intended.
            ce_loss.update(loss_task, source_data.size()[0])
            total_loss.update(loss_all, source_data.size()[0])
            prec1_task.update(accuracy(output_source, source_gt), source_data.size()[0])
            prec1_aux1.update(accuracy(output_source1, source_gt), source_data.size()[0])
            prec1_aux2.update(accuracy(output_source2, source_gt), source_data.size()[0])

            self.optimizer.zero_grad()
            loss_all.backward()
            self.optimizer.step()

            print(" Train:epoch: %d:[%d/%d], LossCE: %3f, LossDA: %3f, LossAll: %3f, Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, ce_loss.avg, da_loss.avg, total_loss.avg, prec1_aux1.avg, prec1_aux2.avg, prec1_task.avg))

            batch_time.update(time.time() - end)
            end = time.time()
            self.iters += 1
            iters_counter_within_epoch += 1
            if iters_counter_within_epoch >= self.iters_per_epoch:
                ## end of epoch: append the epoch summary to the log file
                log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
                log.write("\n")
                log.write(" Train:epoch: %d:[%d/%d], LossCE: %3f, LossDA: %3f, LossAll: %3f, Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, ce_loss.avg, da_loss.avg, total_loss.avg, prec1_aux1.avg, prec1_aux2.avg, prec1_task.avg))
                log.close()
                stop = True


    def test(self):
        """Evaluate on the test loader and return the best head's accuracy.

        With EVAL_METRIC 'accu' the overall top-1 accuracy is used; with
        'accu_mean' the mean of per-class accuracies is used. In both cases
        the maximum over {task head, auxiliary 1, auxiliary 2} is returned.
        """
        self.net.eval()
        prec1_task = AverageMeter()
        prec1_auxi1 = AverageMeter()
        prec1_auxi2 = AverageMeter()
        ## per-class sample / correct counters for the 'accu_mean' metric
        counter_all = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_all_auxi1 = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_all_auxi2 = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_auxi1 = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_auxi2 = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)

        for i, (input, target) in enumerate(self.test_data['loader']):
            input, target = to_cuda(input), to_cuda(target)
            with torch.no_grad():
                _, output_test, output_test1, output_test2, _, _, _ = self.net(input, 1) ## the value of lam does not affect the test process

            if self.opt.EVAL_METRIC == 'accu':
                prec1_task_iter = accuracy(output_test, target)
                prec1_auxi1_iter = accuracy(output_test1, target)
                prec1_auxi2_iter = accuracy(output_test2, target)
                prec1_task.update(prec1_task_iter, input.size(0))
                prec1_auxi1.update(prec1_auxi1_iter, input.size(0))
                prec1_auxi2.update(prec1_auxi2_iter, input.size(0))
                if i % self.opt.PRINT_STEP == 0:
                    print(" Test:epoch: %d:[%d/%d], Auxi1: %3f, Auxi2: %3f, Task: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg))
            elif self.opt.EVAL_METRIC == 'accu_mean':
                prec1_task_iter = accuracy(output_test, target)
                prec1_task.update(prec1_task_iter, input.size(0))
                counter_all, counter_acc = accuracy_for_each_class(output_test, target, counter_all, counter_acc)
                counter_all_auxi1, counter_acc_auxi1 = accuracy_for_each_class(output_test1, target, counter_all_auxi1, counter_acc_auxi1)
                counter_all_auxi2, counter_acc_auxi2 = accuracy_for_each_class(output_test2, target, counter_all_auxi2, counter_acc_auxi2)
                if i % self.opt.PRINT_STEP == 0:
                    print(" Test:epoch: %d:[%d/%d], Task: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_task.avg))
            else:
                raise NotImplementedError
        ## NOTE(review): under 'accu' the counters stay zero, so these are
        ## 0/0 divisions (nan tensors); they are unused on that branch.
        acc_for_each_class = counter_acc / counter_all
        acc_for_each_class_auxi1 = counter_acc_auxi1 / counter_all_auxi1
        acc_for_each_class_auxi2 = counter_acc_auxi2 / counter_all_auxi2
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write("\n")
        if self.opt.EVAL_METRIC == 'accu':
            log.write(
                "                                                          Test:epoch: %d, Top1_auxi1: %3f, Top1_auxi2: %3f, Top1: %3f" % \
                (self.epoch, prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg))
            log.close()
            return max(prec1_auxi1.avg, prec1_auxi2.avg, prec1_task.avg)
        elif self.opt.EVAL_METRIC == 'accu_mean':
            log.write(
                "                                                          Test:epoch: %d, Top1_auxi1: %3f, Top1_auxi2: %3f, Top1: %3f" % \
                (self.epoch, acc_for_each_class_auxi1.mean(), acc_for_each_class_auxi2.mean(), acc_for_each_class.mean()))
            log.write("\nClass-wise Acc:")  ## based on the task classifier.
            for i in range(self.opt.DATASET.NUM_CLASSES):
                if i == 0:
                    log.write("%dst: %3f" % (i + 1, acc_for_each_class[i]))
                elif i == 1:
                    log.write(",  %dnd: %3f" % (i + 1, acc_for_each_class[i]))
                elif i == 2:
                    log.write(", %drd: %3f" % (i + 1, acc_for_each_class[i]))
                else:
                    log.write(", %dth: %3f" % (i + 1, acc_for_each_class[i]))
            log.close()
            return max(acc_for_each_class_auxi1.mean(), acc_for_each_class_auxi2.mean(), acc_for_each_class.mean())

    def build_optimizer(self):
        """Build a single SGD optimizer with named parameter groups.

        Backbone groups are tagged 'pre-trained' and new heads 'new-added';
        update_lr() later assigns the backbone a 10x smaller learning rate.
        """
        if self.opt.TRAIN.OPTIMIZER == 'SGD':  ## some params may not contribute to loss_all and are thus never updated during training.
            self.optimizer = torch.optim.SGD([
                {'params': self.net.module.conv1.parameters(), 'name': 'pre-trained'},
                {'params': self.net.module.bn1.parameters(), 'name': 'pre-trained'},
                {'params': self.net.module.layer1.parameters(), 'name': 'pre-trained'},
                {'params': self.net.module.layer2.parameters(), 'name': 'pre-trained'},
                {'params': self.net.module.layer3.parameters(), 'name': 'pre-trained'},
                {'params': self.net.module.layer4.parameters(), 'name': 'pre-trained'},
                {'params': self.net.module.fc.parameters(), 'name': 'new-added'},
                {'params': self.net.module.fc_aux1.parameters(), 'name': 'new-added'},
                {'params': self.net.module.fc_aux2.parameters(), 'name': 'new-added'},
                {'params': self.net.module.fcdc.parameters(), 'name': 'new-added'}
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)
        else:
            raise NotImplementedError
        print('Optimizer built')

    def update_lr(self):
        """Set per-group learning rates from the configured schedule.

        'inv' divides BASE_LR by (1 + alpha * progress) ** beta, with progress
        measured per epoch or per iteration; 'fix' keeps BASE_LR. Pre-trained
        groups always run at a 10x smaller rate than newly added layers.
        """
        if self.opt.TRAIN.LR_SCHEDULE == 'inv':
            if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
                lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.epoch / self.opt.TRAIN.MAX_EPOCH), self.opt.INV.BETA)
            elif self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch)), self.opt.INV.BETA)
            else:
                raise NotImplementedError
        elif self.opt.TRAIN.LR_SCHEDULE == 'fix':
            lr = self.opt.TRAIN.BASE_LR
        else:
            raise NotImplementedError
        lr_pretrain = lr * 0.1
        print('the lr is: %3f' % (lr))
        for param_group in self.optimizer.param_groups:
            if param_group['name'] == 'pre-trained':
                param_group['lr'] = lr_pretrain
            elif param_group['name'] == 'new-added':
                param_group['lr'] = lr
            elif param_group['name'] == 'fixed':  ## setting lr to 0 does not freeze the running mean/var of BN layers
                param_group['lr'] = 0

    def save_ckpt(self):
        """Log the best accuracy so far and (if enabled) save a checkpoint."""
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write("      Best Acc so far: %3f" % (self.best_prec1))
        log.close()
        if self.opt.TRAIN.SAVING:
            save_path = self.opt.SAVE_DIR
            ckpt_resume = os.path.join(save_path, 'ckpt_%d.resume' % (self.loop))
            torch.save({'epoch': self.epoch,
                        'best_prec1': self.best_prec1,
                        'model_state_dict': self.net.state_dict()
                        }, ckpt_resume)
--------------------------------------------------------------------------------
/solver/SymmNetsV1_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import math
5 | import time
6 | from utils.utils import to_cuda, accuracy_for_each_class, accuracy, AverageMeter, process_one_values
7 | from config.config import cfg
8 | import torch.nn.functional as F
9 | from models.loss_utils import TargetDiscrimLoss, SourceDiscrimLoss, ConcatenatedEMLoss
10 | from .base_solver import BaseSolver
11 | import ipdb
12 |
13 | class SymmNetsV1Solver(BaseSolver):
    def __init__(self, net, dataloaders, **kwargs):
        """Create the SymmNetsV1 losses, unpack the two sub-networks
        (feature extractor + concatenated two-head classifier), and
        optionally resume from the cfg.RESUME checkpoint."""
        super(SymmNetsV1Solver, self).__init__(net, dataloaders, **kwargs)
        self.num_classes = cfg.DATASET.NUM_CLASSES
        ## discriminative / entropy losses over the concatenated 2K-way output
        self.TargetDiscrimLoss = TargetDiscrimLoss(num_classes=self.num_classes).cuda()
        self.SourceDiscrimLoss = SourceDiscrimLoss(num_classes=self.num_classes).cuda()
        self.ConcatenatedEMLoss = ConcatenatedEMLoss(num_classes=self.num_classes).cuda()
        ## net is a dict of the two separately optimized modules
        self.feature_extractor = self.net['feature_extractor']
        self.classifier = self.net['classifier']

        if cfg.RESUME != '':
            ## restore both modules plus bookkeeping (best accuracy, epoch)
            resume_dict = torch.load(cfg.RESUME)
            self.net['feature_extractor'].load_state_dict(resume_dict['feature_extractor_state_dict'])
            self.net['classifier'].load_state_dict(resume_dict['classifier_state_dict'])
            self.best_prec1 = resume_dict['best_prec1']
            self.epoch = resume_dict['epoch']
29 |
30 | def solve(self):
31 | stop = False
32 | while not stop:
33 | stop = self.complete_training()
34 | self.update_network()
35 | acc = self.test()
36 | if acc > self.best_prec1:
37 | self.best_prec1 = acc
38 | self.save_ckpt()
39 | self.epoch += 1
40 |
41 |
    def update_network(self, **kwargs):
        """Train one epoch of the SymmNetsV1 adversarial game.

        Each iteration performs two opposing updates on the same forward
        pass: first the classifier minimizes task + domain-discrimination
        losses (backward with retain_graph so the graph survives for the
        second backward), then the feature extractor minimizes the
        category-level confusion losses, with the target-side terms ramped
        in by `lam`.
        """
        stop = False
        ## fresh iterators each epoch; epoch length follows the target loader
        self.train_data['source']['iterator'] = iter(self.train_data['source']['loader'])
        self.train_data['target']['iterator'] = iter(self.train_data['target']['loader'])
        self.iters_per_epoch = len(self.train_data['target']['loader'])
        iters_counter_within_epoch = 0
        data_time = AverageMeter()
        batch_time = AverageMeter()
        classifier_loss = AverageMeter()
        feature_extractor_loss = AverageMeter()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        self.feature_extractor.train()
        self.classifier.train()
        end = time.time()
        if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
            ## lam ramps 0 -> ~1 via 2/(1+exp(-10*progress)) - 1; fixed per epoch here
            lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1
            self.update_lr()
            print('value of lam is: %3f' % (lam))
        while not stop:
            if self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                ## per-iteration ramp of lam and lr in this mode
                lam = 2 / (1 + math.exp(-1 * 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1
                print('value of lam is: %3f' % (lam))
                self.update_lr()
            source_data, source_gt = self.get_samples('source')
            target_data, _ = self.get_samples('target')  ## target labels are never used (unsupervised DA)
            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            target_data = to_cuda(target_data)
            data_time.update(time.time() - end)

            ## single classifier outputs 2K logits: columns [:K] are the
            ## "source" head Fs, columns [K:] the "target" head Ft
            feature_source = self.feature_extractor(source_data)
            output_source = self.classifier(feature_source)
            feature_target = self.feature_extractor(target_data)
            output_target = self.classifier(feature_target)

            ## classifier objective: supervised CE on both heads plus
            ## domain discrimination between the two halves
            loss_task_fs = self.CELoss(output_source[:,:self.num_classes], source_gt)
            loss_task_ft = self.CELoss(output_source[:,self.num_classes:], source_gt)
            loss_discrim_source = self.SourceDiscrimLoss(output_source)
            loss_discrim_target = self.TargetDiscrimLoss(output_target)
            loss_summary_classifier = loss_task_fs + loss_task_ft + loss_discrim_source + loss_discrim_target

            ## feature-extractor objective: confuse the two heads
            ## (same source label under both halves, balanced source/target
            ## discrimination on target data) plus entropy minimization
            source_gt_for_ft_in_fst = source_gt + self.num_classes  ## same class, shifted into the Ft half
            loss_confusion_source = 0.5 * self.CELoss(output_source, source_gt) + 0.5 * self.CELoss(output_source, source_gt_for_ft_in_fst)
            loss_confusion_target = 0.5 * self.SourceDiscrimLoss(output_target) + 0.5 * self.TargetDiscrimLoss(output_target)
            loss_em = self.ConcatenatedEMLoss(output_target)
            loss_summary_feature_extractor = loss_confusion_source + lam * (loss_confusion_target + loss_em)

            ## step 1: update the classifier; retain_graph keeps the shared
            ## forward graph alive for the feature-extractor backward below
            self.optimizer_classifier.zero_grad()
            loss_summary_classifier.backward(retain_graph=True)
            self.optimizer_classifier.step()

            ## step 2: update the feature extractor against the (already
            ## stepped) classifier
            self.optimizer_feature_extractor.zero_grad()
            loss_summary_feature_extractor.backward()
            self.optimizer_feature_extractor.step()

            ## NOTE(review): meters receive 0-dim tensors, not .item() floats
            classifier_loss.update(loss_summary_classifier, source_data.size()[0])
            feature_extractor_loss.update(loss_summary_feature_extractor, source_data.size()[0])
            prec1_fs.update(accuracy(output_source[:, :self.num_classes], source_gt), source_data.size()[0])
            prec1_ft.update(accuracy(output_source[:, self.num_classes:], source_gt), source_data.size()[0])

            print(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))

            batch_time.update(time.time() - end)
            end = time.time()
            self.iters += 1
            iters_counter_within_epoch += 1
            if iters_counter_within_epoch >= self.iters_per_epoch:
                ## end of epoch: append the epoch summary to the log file
                log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
                log.write("\n")
                log.write(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))
                log.close()
                stop = True
117 |
118 |
    def test(self):
        """Evaluate both classifier heads on the test loader.

        With EVAL_METRIC 'accu' the overall top-1 accuracy of each head is
        used; with 'accu_mean' the mean of per-class accuracies is used.
        Returns the better of the Fs / Ft head scores.
        """
        self.feature_extractor.eval()
        self.classifier.eval()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        ## per-class sample / correct counters for the 'accu_mean' metric
        counter_all_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_all_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)

        for i, (input, target) in enumerate(self.test_data['loader']):
            input, target = to_cuda(input), to_cuda(target)
            with torch.no_grad():
                feature_test = self.feature_extractor(input)
                output_test = self.classifier(feature_test)


            if self.opt.EVAL_METRIC == 'accu':
                ## columns [:K] are the Fs head, [K:] the Ft head
                prec1_fs_iter = accuracy(output_test[:, :self.num_classes], target)
                prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
                prec1_fs.update(prec1_fs_iter, input.size(0))
                prec1_ft.update(prec1_ft_iter, input.size(0))
                if i % self.opt.PRINT_STEP == 0:
                    print("  Test:epoch: %d:[%d/%d], AccFs: %3f, AccFt: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_fs.avg, prec1_ft.avg))
            elif self.opt.EVAL_METRIC == 'accu_mean':
                prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
                prec1_ft.update(prec1_ft_iter, input.size(0))
                counter_all_fs, counter_acc_fs = accuracy_for_each_class(output_test[:, :self.num_classes], target, counter_all_fs, counter_acc_fs)
                counter_all_ft, counter_acc_ft = accuracy_for_each_class(output_test[:, self.num_classes:], target, counter_all_ft, counter_acc_ft)
                if i % self.opt.PRINT_STEP == 0:
                    print("  Test:epoch: %d:[%d/%d], Task: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_ft.avg))
            else:
                raise NotImplementedError
        ## NOTE(review): under 'accu' the counters stay zero, so these are
        ## 0/0 divisions (nan tensors); they are unused on that branch.
        acc_for_each_class_fs = counter_acc_fs / counter_all_fs
        acc_for_each_class_ft = counter_acc_ft / counter_all_ft
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write("\n")
        if self.opt.EVAL_METRIC == 'accu':
            log.write(
                "                                                          Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                (self.epoch, prec1_fs.avg, prec1_ft.avg))
            log.close()
            return max(prec1_fs.avg, prec1_ft.avg)
        elif self.opt.EVAL_METRIC == 'accu_mean':
            log.write(
                "                                                          Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                (self.epoch,acc_for_each_class_fs.mean(), acc_for_each_class_ft.mean()))
            log.write("\nClass-wise Acc of Ft:")  ## based on the task classifier.
            for i in range(self.opt.DATASET.NUM_CLASSES):
                if i == 0:
                    log.write("%dst: %3f" % (i + 1, acc_for_each_class_ft[i]))
                elif i == 1:
                    log.write(",  %dnd: %3f" % (i + 1, acc_for_each_class_ft[i]))
                elif i == 2:
                    log.write(", %drd: %3f" % (i + 1, acc_for_each_class_ft[i]))
                else:
                    log.write(", %dth: %3f" % (i + 1, acc_for_each_class_ft[i]))
            log.close()
            return max(acc_for_each_class_ft.mean(), acc_for_each_class_fs.mean())
180 |
    def build_optimizer(self):
        """Build the two SGD optimizers of the adversarial game.

        The backbone's named layers form the feature-extractor optimizer
        (groups tagged 'pre-trained'); the whole classifier forms the second
        optimizer (tagged 'new-added'). update_lr() later gives 'pre-trained'
        groups a 10x smaller learning rate.
        """
        if self.opt.TRAIN.OPTIMIZER == 'SGD':  ## some params may not contribute to the loss and are thus never updated during training.
            self.optimizer_feature_extractor = torch.optim.SGD([
                {'params': self.net['feature_extractor'].module.conv1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.bn1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer2.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer3.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer4.parameters(), 'name': 'pre-trained'},
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)

            self.optimizer_classifier = torch.optim.SGD([
                {'params': self.net['classifier'].parameters(), 'name': 'new-added'},
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)
        else:
            raise NotImplementedError
        print('Optimizer built')
206 |
207 | def update_lr(self):
208 | if self.opt.TRAIN.LR_SCHEDULE == 'inv':
209 | if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
210 | lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.epoch / self.opt.TRAIN.MAX_EPOCH), self.opt.INV.BETA)
211 | elif self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
212 | lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch)), self.opt.INV.BETA)
213 | else:
214 | raise NotImplementedError
215 | elif self.opt.TRAIN.LR_SCHEDULE == 'fix':
216 | lr = self.opt.TRAIN.BASE_LR
217 | else:
218 | raise NotImplementedError
219 | lr_pretrain = lr * 0.1
220 | print('the lr is: %3f' % (lr))
221 | for param_group in self.optimizer_feature_extractor.param_groups:
222 | if param_group['name'] == 'pre-trained':
223 | param_group['lr'] = lr_pretrain
224 | elif param_group['name'] == 'new-added':
225 | param_group['lr'] = lr
226 | elif param_group['name'] == 'fixed': ## Fix the lr as 0 can not fix the runing mean/var of the BN layer
227 | param_group['lr'] = 0
228 |
229 | for param_group in self.optimizer_classifier.param_groups:
230 | if param_group['name'] == 'pre-trained':
231 | param_group['lr'] = lr_pretrain
232 | elif param_group['name'] == 'new-added':
233 | param_group['lr'] = lr
234 | elif param_group['name'] == 'fixed': ## Fix the lr as 0 can not fix the runing mean/var of the BN layer
235 | param_group['lr'] = 0
236 |
237 | def save_ckpt(self):
238 | log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
239 | log.write(" Best Acc so far: %3f" % (self.best_prec1))
240 | log.close()
241 | if self.opt.TRAIN.SAVING:
242 | save_path = self.opt.SAVE_DIR
243 | ckpt_resume = os.path.join(save_path, 'ckpt_%d.resume' % (self.loop))
244 | torch.save({'epoch': self.epoch,
245 | 'best_prec1': self.best_prec1,
246 | 'feature_extractor_state_dict': self.net['feature_extractor'].state_dict(),
247 | 'classifier_state_dict': self.net['classifier'].state_dict()
248 | }, ckpt_resume)
--------------------------------------------------------------------------------
/solver/SymmNetsV2Partial_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import math
5 | import time
6 | from utils.utils import to_cuda, accuracy_for_each_class, accuracy, AverageMeter, process_one_values
7 | from config.config import cfg
8 | import torch.nn.functional as F
9 | from models.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss, CrossEntropyClassWeighted
10 | from .base_solver import BaseSolver
11 | import ipdb
12 |
class SymmNetsV2PartialSolver(BaseSolver):
    """SymmNets-V2 solver for the partial domain adaptation task.

    Trains a shared feature extractor plus a symmetric 2*K-way classifier:
    the first K logits form the source head Fs, the last K the target head
    Ft. A class-level weight vector, re-estimated from target predictions
    after every epoch, down-weights source classes that appear absent from
    the target domain (the "outlier" classes of partial DA).
    """

    def __init__(self, net, dataloaders, **kwargs):
        super(SymmNetsV2PartialSolver, self).__init__(net, dataloaders, **kwargs)
        self.num_classes = cfg.DATASET.NUM_CLASSES
        # Losses that operate on the concatenated 2*K logits.
        self.TargetDiscrimLoss = TargetDiscrimLoss(num_classes=self.num_classes).cuda()
        self.ConcatenatedCELoss = ConcatenatedCELoss(num_classes=self.num_classes).cuda()
        self.feature_extractor = self.net['feature_extractor']
        self.classifier = self.net['classifier']
        # Ramp-up coefficient; recomputed every epoch/iteration in update_network().
        self.lam = 0
        # Uniform prior over classes; class-level weight to filter out the outlier classes.
        class_weight_initial = torch.ones(self.num_classes)
        self.class_weight_initial = class_weight_initial.cuda()
        # Current class-level weight estimate (starts uniform).
        class_weight = torch.ones(self.num_classes)
        self.class_weight = class_weight.cuda()
        # If True, blend the estimated weight with the uniform prior by lam.
        self.softweight = True
        self.CELossWeight = CrossEntropyClassWeighted()

        if cfg.RESUME != '':
            # Restore weights and bookkeeping from a saved checkpoint.
            resume_dict = torch.load(cfg.RESUME)
            self.net['feature_extractor'].load_state_dict(resume_dict['feature_extractor_state_dict'])
            self.net['classifier'].load_state_dict(resume_dict['classifier_state_dict'])
            self.best_prec1 = resume_dict['best_prec1']
            self.epoch = resume_dict['epoch']

    def solve(self):
        """Main loop: train one epoch, evaluate, refresh the class weights
        from target predictions, and checkpoint on improvement."""
        stop = False
        while not stop:
            stop = self.complete_training()
            self.update_network()
            prediction_weight, acc = self.test()
            prediction_weight = prediction_weight.cuda()
            if self.softweight:
                # Trust the target-estimated weight more as lam ramps up.
                self.class_weight = prediction_weight * self.lam + self.class_weight_initial * (1 - self.lam)
            else:
                self.class_weight = prediction_weight
            print('the class weight adopted in partial DA')
            print(self.class_weight)
            if acc > self.best_prec1:
                self.best_prec1 = acc
                self.save_ckpt()
            self.epoch += 1


    def update_network(self, **kwargs):
        """Run one training epoch, alternating a classifier step (discrimination)
        and a feature-extractor step (confusion), both class-weighted."""
        stop = False
        self.train_data['source']['iterator'] = iter(self.train_data['source']['loader'])
        self.train_data['target']['iterator'] = iter(self.train_data['target']['loader'])
        # Epoch length is defined by the target loader.
        self.iters_per_epoch = len(self.train_data['target']['loader'])
        iters_counter_within_epoch = 0
        data_time = AverageMeter()
        batch_time = AverageMeter()
        classifier_loss = AverageMeter()
        feature_extractor_loss = AverageMeter()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        self.feature_extractor.train()
        self.classifier.train()
        end = time.time()
        if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
            # DANN-style ramp: lam goes 0 -> 1 over training.
            self.lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1
            self.update_lr()
            print('value of lam is: %3f' % (self.lam))
        while not stop:
            if self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                self.lam = 2 / (1 + math.exp(-1 * 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1
                print('value of lam is: %3f' % (self.lam))
                self.update_lr()
            source_data, source_gt = self.get_samples('source')
            target_data, _ = self.get_samples('target')
            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            target_data = to_cuda(target_data)
            data_time.update(time.time() - end)

            feature_source = self.feature_extractor(source_data)
            output_source = self.classifier(feature_source)
            feature_target = self.feature_extractor(target_data)
            output_target = self.classifier(feature_target)

            # Duplicate the class weight to cover the concatenated Fs|Ft logits.
            weight_concate = torch.cat((self.class_weight, self.class_weight))
            # Task losses on each K-way head, weighted to suppress outlier classes.
            loss_task_fs = self.CELossWeight(output_source[:,:self.num_classes], source_gt, self.class_weight)
            loss_task_ft = self.CELossWeight(output_source[:,self.num_classes:], source_gt, self.class_weight)
            # Domain discrimination: source samples in the Fs half, target in the Ft half.
            loss_discrim_source = self.CELossWeight(output_source, source_gt, weight_concate)
            loss_discrim_target = self.TargetDiscrimLoss(output_target)
            loss_summary_classifier = loss_task_fs + loss_task_ft + loss_discrim_source + loss_discrim_target

            # Feature-extractor objective: confuse the two heads (adversarial step).
            source_gt_for_ft_in_fst = source_gt + self.num_classes
            loss_confusion_source = 0.5 * self.CELossWeight(output_source, source_gt, weight_concate) + 0.5 * self.CELossWeight(output_source, source_gt_for_ft_in_fst, weight_concate)
            loss_confusion_target = self.ConcatenatedCELoss(output_target)
            loss_summary_feature_extractor = loss_confusion_source + self.lam * loss_confusion_target

            # Classifier step first; retain the graph for the extractor step below.
            self.optimizer_classifier.zero_grad()
            loss_summary_classifier.backward(retain_graph=True)
            self.optimizer_classifier.step()

            self.optimizer_feature_extractor.zero_grad()
            loss_summary_feature_extractor.backward()
            self.optimizer_feature_extractor.step()

            # NOTE(review): the meters accumulate 0-dim tensors; calling .item()
            # would avoid keeping tensors on GPU — behavior left unchanged here.
            classifier_loss.update(loss_summary_classifier, source_data.size()[0])
            feature_extractor_loss.update(loss_summary_feature_extractor, source_data.size()[0])
            prec1_fs.update(accuracy(output_source[:, :self.num_classes], source_gt), source_data.size()[0])
            prec1_ft.update(accuracy(output_source[:, self.num_classes:], source_gt), source_data.size()[0])

            print(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))

            batch_time.update(time.time() - end)
            end = time.time()
            self.iters += 1
            iters_counter_within_epoch += 1
            if iters_counter_within_epoch >= self.iters_per_epoch:
                # End of epoch: append the epoch summary to the log file.
                log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
                log.write("\n")
                log.write(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))
                log.close()
                stop = True

    def test(self):
        """Evaluate both heads on the test loader.

        Returns a tuple (class_weight, accuracy): class_weight is the mean
        Ft softmax over all test samples, normalized so its maximum is 1
        (the partial-DA outlier-class weight); accuracy is the better of the
        two heads under the configured EVAL_METRIC.
        """
        self.feature_extractor.eval()
        self.classifier.eval()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        # Per-class sample/correct counters for the 'accu_mean' metric.
        counter_all_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_all_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        class_weight = torch.zeros(self.num_classes)
        class_weight = class_weight.cuda()
        count = 0


        for i, (input, target) in enumerate(self.test_data['loader']):
            input, target = to_cuda(input), to_cuda(target)
            with torch.no_grad():
                feature_test = self.feature_extractor(input)
                output_test = self.classifier(feature_test)
            # Accumulate the Ft softmax mass per class to estimate target class frequencies.
            prob = F.softmax(output_test[:, self.num_classes:], dim=1)
            class_weight = class_weight + prob.data.sum(0)
            count = count + input.size(0)

            if self.opt.EVAL_METRIC == 'accu':
                prec1_fs_iter = accuracy(output_test[:, :self.num_classes], target)
                prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
                prec1_fs.update(prec1_fs_iter, input.size(0))
                prec1_ft.update(prec1_ft_iter, input.size(0))
                if i % self.opt.PRINT_STEP == 0:
                    print(" Test:epoch: %d:[%d/%d], AccFs: %3f, AccFt: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_fs.avg, prec1_ft.avg))
            elif self.opt.EVAL_METRIC == 'accu_mean':
                prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
                prec1_ft.update(prec1_ft_iter, input.size(0))
                counter_all_fs, counter_acc_fs = accuracy_for_each_class(output_test[:, :self.num_classes], target, counter_all_fs, counter_acc_fs)
                counter_all_ft, counter_acc_ft = accuracy_for_each_class(output_test[:, self.num_classes:], target, counter_all_ft, counter_acc_ft)
                if i % self.opt.PRINT_STEP == 0:
                    print(" Test:epoch: %d:[%d/%d], Task: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_ft.avg))
            else:
                raise NotImplementedError
        acc_for_each_class_fs = counter_acc_fs / counter_all_fs
        acc_for_each_class_ft = counter_acc_ft / counter_all_ft
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write("\n")
        # Mean Ft probability per class, rescaled so the largest weight is 1.
        class_weight = class_weight / count
        class_weight = class_weight / max(class_weight)
        if self.opt.EVAL_METRIC == 'accu':
            log.write(
                " Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                (self.epoch, prec1_fs.avg, prec1_ft.avg))
            log.close()
            return class_weight, max(prec1_fs.avg, prec1_ft.avg)
        elif self.opt.EVAL_METRIC == 'accu_mean':
            log.write(
                " Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                (self.epoch,acc_for_each_class_fs.mean(), acc_for_each_class_ft.mean()))
            log.write("\nClass-wise Acc of Ft:")  ## based on the task classifier.
            for i in range(self.opt.DATASET.NUM_CLASSES):
                if i == 0:
                    log.write("%dst: %3f" % (i + 1, acc_for_each_class_ft[i]))
                elif i == 1:
                    log.write(", %dnd: %3f" % (i + 1, acc_for_each_class_ft[i]))
                elif i == 2:
                    log.write(", %drd: %3f" % (i + 1, acc_for_each_class_ft[i]))
                else:
                    log.write(", %dth: %3f" % (i + 1, acc_for_each_class_ft[i]))
            log.close()
            return class_weight, max(acc_for_each_class_ft.mean(), acc_for_each_class_fs.mean())

    def build_optimizer(self):
        """Build SGD optimizers for the backbone ('pre-trained' groups) and
        the classifier ('new-added' group)."""
        if self.opt.TRAIN.OPTIMIZER == 'SGD':  ## some params may not contribute the loss_all, thus they are not updated in the training process.
            self.optimizer_feature_extractor = torch.optim.SGD([
                {'params': self.net['feature_extractor'].module.conv1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.bn1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer2.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer3.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer4.parameters(), 'name': 'pre-trained'},
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)

            self.optimizer_classifier = torch.optim.SGD([
                {'params': self.net['classifier'].parameters(), 'name': 'new-added'},
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)
        else:
            raise NotImplementedError
        print('Optimizer built')

    def update_lr(self):
        """Recompute the lr per schedule and push it into both optimizers;
        'pre-trained' groups get 0.1x the base rate."""
        if self.opt.TRAIN.LR_SCHEDULE == 'inv':
            if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
                lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.epoch / self.opt.TRAIN.MAX_EPOCH), self.opt.INV.BETA)
            elif self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch)), self.opt.INV.BETA)
            else:
                raise NotImplementedError
        elif self.opt.TRAIN.LR_SCHEDULE == 'fix':
            lr = self.opt.TRAIN.BASE_LR
        else:
            raise NotImplementedError
        lr_pretrain = lr * 0.1
        print('the lr is: %3f' % (lr))
        for param_group in self.optimizer_feature_extractor.param_groups:
            if param_group['name'] == 'pre-trained':
                param_group['lr'] = lr_pretrain
            elif param_group['name'] == 'new-added':
                param_group['lr'] = lr
            elif param_group['name'] == 'fixed':  ## Fix the lr as 0 can not fix the runing mean/var of the BN layer
                param_group['lr'] = 0

        for param_group in self.optimizer_classifier.param_groups:
            if param_group['name'] == 'pre-trained':
                param_group['lr'] = lr_pretrain
            elif param_group['name'] == 'new-added':
                param_group['lr'] = lr
            elif param_group['name'] == 'fixed':  ## Fix the lr as 0 can not fix the runing mean/var of the BN layer
                param_group['lr'] = 0

    def save_ckpt(self):
        """Log the best accuracy so far and optionally save a resumable checkpoint."""
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write(" Best Acc so far: %3f" % (self.best_prec1))
        log.close()
        if self.opt.TRAIN.SAVING:
            save_path = self.opt.SAVE_DIR
            # NOTE(review): self.loop is not set in BaseSolver or this class —
            # presumably provided externally; confirm before relying on it.
            ckpt_resume = os.path.join(save_path, 'ckpt_%d.resume' % (self.loop))
            torch.save({'epoch': self.epoch,
                        'best_prec1': self.best_prec1,
                        'feature_extractor_state_dict': self.net['feature_extractor'].state_dict(),
                        'classifier_state_dict': self.net['classifier'].state_dict()
                        }, ckpt_resume)
269 |
--------------------------------------------------------------------------------
/solver/SymmNetsV2_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import math
5 | import time
6 | from utils.utils import to_cuda, accuracy_for_each_class, accuracy, AverageMeter, process_one_values
7 | from config.config import cfg
8 | import torch.nn.functional as F
9 | from models.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
10 | from .base_solver import BaseSolver
11 | import ipdb
12 |
class SymmNetsV2Solver(BaseSolver):
    """SymmNets-V2 solver for closed-set domain adaptation.

    Trains a shared feature extractor plus a symmetric 2*K-way classifier:
    the first K logits form the source head Fs, the last K the target head
    Ft. The classifier is trained to discriminate domains; the feature
    extractor is trained adversarially to confuse the two heads.
    """

    def __init__(self, net, dataloaders, **kwargs):
        super(SymmNetsV2Solver, self).__init__(net, dataloaders, **kwargs)
        self.num_classes = cfg.DATASET.NUM_CLASSES
        # Losses that operate on the concatenated 2*K logits.
        self.TargetDiscrimLoss = TargetDiscrimLoss(num_classes=self.num_classes).cuda()
        self.ConcatenatedCELoss = ConcatenatedCELoss(num_classes=self.num_classes).cuda()
        self.feature_extractor = self.net['feature_extractor']
        self.classifier = self.net['classifier']

        if cfg.RESUME != '':
            # Restore weights and bookkeeping from a saved checkpoint.
            resume_dict = torch.load(cfg.RESUME)
            self.net['feature_extractor'].load_state_dict(resume_dict['feature_extractor_state_dict'])
            self.net['classifier'].load_state_dict(resume_dict['classifier_state_dict'])
            self.best_prec1 = resume_dict['best_prec1']
            self.epoch = resume_dict['epoch']

    def solve(self):
        """Main loop: train one epoch, evaluate, and checkpoint on improvement."""
        stop = False
        while not stop:
            stop = self.complete_training()
            self.update_network()
            acc = self.test()
            if acc > self.best_prec1:
                self.best_prec1 = acc
                self.save_ckpt()
            self.epoch += 1


    def update_network(self, **kwargs):
        """Run one training epoch, alternating a classifier step (discrimination)
        and a feature-extractor step (confusion)."""
        stop = False
        self.train_data['source']['iterator'] = iter(self.train_data['source']['loader'])
        self.train_data['target']['iterator'] = iter(self.train_data['target']['loader'])
        # Epoch length is the longer of the two loaders (get_samples restarts the shorter).
        self.iters_per_epoch = max(len(self.train_data['target']['loader']), len(self.train_data['source']['loader']))
        iters_counter_within_epoch = 0
        data_time = AverageMeter()
        batch_time = AverageMeter()
        classifier_loss = AverageMeter()
        feature_extractor_loss = AverageMeter()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        self.feature_extractor.train()
        self.classifier.train()
        end = time.time()
        # NOTE(review): lam is only assigned under the 'epoch'/'iteration'
        # counters; any other PROCESS_COUNTER value would leave 'lam' unbound
        # at its use below — confirm the config is restricted to these two.
        if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
            # DANN-style ramp: lam goes 0 -> 1 over training.
            lam = 2 / (1 + math.exp(-1 * 10 * self.epoch / self.opt.TRAIN.MAX_EPOCH)) - 1
            self.update_lr()
            print('value of lam is: %3f' % (lam))
        while not stop:
            if self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                lam = 2 / (1 + math.exp(-1 * 10 * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch))) - 1
                print('value of lam is: %3f' % (lam))
                self.update_lr()
            source_data, source_gt = self.get_samples('source')
            target_data, _ = self.get_samples('target')
            source_data = to_cuda(source_data)
            source_gt = to_cuda(source_gt)
            target_data = to_cuda(target_data)
            data_time.update(time.time() - end)

            feature_source = self.feature_extractor(source_data)
            output_source = self.classifier(feature_source)
            feature_target = self.feature_extractor(target_data)
            output_target = self.classifier(feature_target)

            # Task losses on each K-way head.
            loss_task_fs = self.CELoss(output_source[:,:self.num_classes], source_gt)
            loss_task_ft = self.CELoss(output_source[:,self.num_classes:], source_gt)
            # Domain discrimination: source samples in the Fs half, target in the Ft half.
            loss_discrim_source = self.CELoss(output_source, source_gt)
            loss_discrim_target = self.TargetDiscrimLoss(output_target)
            loss_summary_classifier = loss_task_fs + loss_task_ft + loss_discrim_source + loss_discrim_target

            # Feature-extractor objective: confuse the two heads (adversarial step).
            source_gt_for_ft_in_fst = source_gt + self.num_classes
            loss_confusion_source = 0.5 * self.CELoss(output_source, source_gt) + 0.5 * self.CELoss(output_source, source_gt_for_ft_in_fst)
            loss_confusion_target = self.ConcatenatedCELoss(output_target)
            loss_summary_feature_extractor = loss_confusion_source + lam * loss_confusion_target

            # Classifier step first; retain the graph for the extractor step below.
            self.optimizer_classifier.zero_grad()
            loss_summary_classifier.backward(retain_graph=True)
            self.optimizer_classifier.step()

            self.optimizer_feature_extractor.zero_grad()
            loss_summary_feature_extractor.backward()
            self.optimizer_feature_extractor.step()

            # NOTE(review): the meters accumulate 0-dim tensors; calling .item()
            # would avoid keeping tensors on GPU — behavior left unchanged here.
            classifier_loss.update(loss_summary_classifier, source_data.size()[0])
            feature_extractor_loss.update(loss_summary_feature_extractor, source_data.size()[0])
            prec1_fs.update(accuracy(output_source[:, :self.num_classes], source_gt), source_data.size()[0])
            prec1_ft.update(accuracy(output_source[:, self.num_classes:], source_gt), source_data.size()[0])

            print(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))

            batch_time.update(time.time() - end)
            end = time.time()
            self.iters += 1
            iters_counter_within_epoch += 1
            if iters_counter_within_epoch >= self.iters_per_epoch:
                # End of epoch: append the epoch summary to the log file.
                log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
                log.write("\n")
                log.write(" Train:epoch: %d:[%d/%d], LossCla: %3f, LossFeat: %3f, AccFs: %3f, AccFt: %3f" % \
                  (self.epoch, iters_counter_within_epoch, self.iters_per_epoch, classifier_loss.avg, feature_extractor_loss.avg, prec1_fs.avg, prec1_ft.avg))
                log.close()
                stop = True


    def test(self):
        """Evaluate both heads on the test loader.

        Returns the better accuracy of the two heads: overall top-1 under
        'accu', or the mean of per-class accuracies under 'accu_mean'.
        """
        self.feature_extractor.eval()
        self.classifier.eval()
        prec1_fs = AverageMeter()
        prec1_ft = AverageMeter()
        # Per-class sample/correct counters for the 'accu_mean' metric.
        counter_all_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_all_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_fs = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)
        counter_acc_ft = torch.FloatTensor(self.opt.DATASET.NUM_CLASSES).fill_(0)

        for i, (input, target) in enumerate(self.test_data['loader']):
            input, target = to_cuda(input), to_cuda(target)
            with torch.no_grad():
                feature_test = self.feature_extractor(input)
                output_test = self.classifier(feature_test)


            if self.opt.EVAL_METRIC == 'accu':
                prec1_fs_iter = accuracy(output_test[:, :self.num_classes], target)
                prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
                prec1_fs.update(prec1_fs_iter, input.size(0))
                prec1_ft.update(prec1_ft_iter, input.size(0))
                if i % self.opt.PRINT_STEP == 0:
                    print(" Test:epoch: %d:[%d/%d], AccFs: %3f, AccFt: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_fs.avg, prec1_ft.avg))
            elif self.opt.EVAL_METRIC == 'accu_mean':
                prec1_ft_iter = accuracy(output_test[:, self.num_classes:], target)
                prec1_ft.update(prec1_ft_iter, input.size(0))
                counter_all_fs, counter_acc_fs = accuracy_for_each_class(output_test[:, :self.num_classes], target, counter_all_fs, counter_acc_fs)
                counter_all_ft, counter_acc_ft = accuracy_for_each_class(output_test[:, self.num_classes:], target, counter_all_ft, counter_acc_ft)
                if i % self.opt.PRINT_STEP == 0:
                    print(" Test:epoch: %d:[%d/%d], Task: %3f" % \
                          (self.epoch, i, len(self.test_data['loader']), prec1_ft.avg))
            else:
                raise NotImplementedError
        acc_for_each_class_fs = counter_acc_fs / counter_all_fs
        acc_for_each_class_ft = counter_acc_ft / counter_all_ft
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write("\n")
        if self.opt.EVAL_METRIC == 'accu':
            log.write(
                " Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                (self.epoch, prec1_fs.avg, prec1_ft.avg))
            log.close()
            return max(prec1_fs.avg, prec1_ft.avg)
        elif self.opt.EVAL_METRIC == 'accu_mean':
            log.write(
                " Test:epoch: %d, AccFs: %3f, AccFt: %3f" % \
                (self.epoch,acc_for_each_class_fs.mean(), acc_for_each_class_ft.mean()))
            log.write("\nClass-wise Acc of Ft:")  ## based on the task classifier.
            for i in range(self.opt.DATASET.NUM_CLASSES):
                if i == 0:
                    log.write("%dst: %3f" % (i + 1, acc_for_each_class_ft[i]))
                elif i == 1:
                    log.write(", %dnd: %3f" % (i + 1, acc_for_each_class_ft[i]))
                elif i == 2:
                    log.write(", %drd: %3f" % (i + 1, acc_for_each_class_ft[i]))
                else:
                    log.write(", %dth: %3f" % (i + 1, acc_for_each_class_ft[i]))
            log.close()
            return max(acc_for_each_class_ft.mean(), acc_for_each_class_fs.mean())

    def build_optimizer(self):
        """Build SGD optimizers for the backbone ('pre-trained' groups) and
        the classifier ('new-added' group)."""
        if self.opt.TRAIN.OPTIMIZER == 'SGD':  ## some params may not contribute the loss_all, thus they are not updated in the training process.
            self.optimizer_feature_extractor = torch.optim.SGD([
                {'params': self.net['feature_extractor'].module.conv1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.bn1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer1.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer2.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer3.parameters(), 'name': 'pre-trained'},
                {'params': self.net['feature_extractor'].module.layer4.parameters(), 'name': 'pre-trained'},
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)

            self.optimizer_classifier = torch.optim.SGD([
                {'params': self.net['classifier'].parameters(), 'name': 'new-added'},
            ],
                lr=self.opt.TRAIN.BASE_LR,
                momentum=self.opt.TRAIN.MOMENTUM,
                weight_decay=self.opt.TRAIN.WEIGHT_DECAY,
                nesterov=True)
        else:
            raise NotImplementedError
        print('Optimizer built')

    def update_lr(self):
        """Recompute the lr per schedule and push it into both optimizers;
        'pre-trained' groups get 0.1x the base rate."""
        if self.opt.TRAIN.LR_SCHEDULE == 'inv':
            if self.opt.TRAIN.PROCESS_COUNTER == 'epoch':
                lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.epoch / self.opt.TRAIN.MAX_EPOCH), self.opt.INV.BETA)
            elif self.opt.TRAIN.PROCESS_COUNTER == 'iteration':
                lr = self.opt.TRAIN.BASE_LR / pow((1 + self.opt.INV.ALPHA * self.iters / (self.opt.TRAIN.MAX_EPOCH * self.iters_per_epoch)), self.opt.INV.BETA)
            else:
                raise NotImplementedError
        elif self.opt.TRAIN.LR_SCHEDULE == 'fix':
            lr = self.opt.TRAIN.BASE_LR
        else:
            raise NotImplementedError
        lr_pretrain = lr * 0.1
        print('the lr is: %3f' % (lr))
        for param_group in self.optimizer_feature_extractor.param_groups:
            if param_group['name'] == 'pre-trained':
                param_group['lr'] = lr_pretrain
            elif param_group['name'] == 'new-added':
                param_group['lr'] = lr
            elif param_group['name'] == 'fixed':  ## Fix the lr as 0 can not fix the runing mean/var of the BN layer
                param_group['lr'] = 0

        for param_group in self.optimizer_classifier.param_groups:
            if param_group['name'] == 'pre-trained':
                param_group['lr'] = lr_pretrain
            elif param_group['name'] == 'new-added':
                param_group['lr'] = lr
            elif param_group['name'] == 'fixed':  ## Fix the lr as 0 can not fix the runing mean/var of the BN layer
                param_group['lr'] = 0

    def save_ckpt(self):
        """Log the best accuracy so far and optionally save a resumable checkpoint."""
        log = open(os.path.join(self.opt.SAVE_DIR, 'log.txt'), 'a')
        log.write(" Best Acc so far: %3f" % (self.best_prec1))
        log.close()
        if self.opt.TRAIN.SAVING:
            save_path = self.opt.SAVE_DIR
            # NOTE(review): self.loop is not set in BaseSolver or this class —
            # presumably provided externally; confirm before relying on it.
            ckpt_resume = os.path.join(save_path, 'ckpt_%d.resume' % (self.loop))
            torch.save({'epoch': self.epoch,
                        'best_prec1': self.best_prec1,
                        'feature_extractor_state_dict': self.net['feature_extractor'].state_dict(),
                        'classifier_state_dict': self.net['classifier'].state_dict()
                        }, ckpt_resume)
247 |
--------------------------------------------------------------------------------
/solver/base_solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | from utils.utils import to_cuda, AverageMeter
5 | from config.config import cfg
6 |
class BaseSolver:
    """Minimal training-solver skeleton shared by the concrete DA solvers.

    Holds the config, networks, dataloaders and common bookkeeping
    (epoch/iteration counters, best accuracy so far). Subclasses override
    build_optimizer(), solve() and update_network().
    """

    def __init__(self, net, dataloaders, **kwargs):
        self.opt = cfg
        self.net = net
        self.dataloaders = dataloaders
        self.CELoss = nn.CrossEntropyLoss()
        if torch.cuda.is_available():
            self.CELoss.cuda()
        self.epoch = 0
        self.iters = 0
        self.best_prec1 = 0
        self.iters_per_epoch = None
        self.build_optimizer()
        self.init_data(self.dataloaders)

    def init_data(self, dataloaders):
        """Split the dataloaders dict into training splits and an optional test split.

        Every key except 'test' becomes a training entry holding its loader and
        a lazily-created iterator. (Simplified: the original re-checked each key's
        membership in `dataloaders` even though the keys came from it.)
        """
        self.train_data = {key: {'loader': loader, 'iterator': None}
                           for key, loader in dataloaders.items() if key != 'test'}
        if 'test' in dataloaders:
            self.test_data = {'loader': dataloaders['test']}

    def build_optimizer(self):
        """Hook for subclasses; the base class builds no optimizer."""
        print('Optimizer built')

    def complete_training(self):
        """Return True once the epoch counter has passed TRAIN.MAX_EPOCH.

        Fix: always return an explicit bool — the original returned None
        (implicitly) when training was not complete, relying on None being falsy.
        """
        return self.epoch > self.opt.TRAIN.MAX_EPOCH

    def solve(self):
        """Hook for subclasses; the base class does no training."""
        print('Training Done!')

    def get_samples(self, data_name):
        """Return the next batch from the named training split.

        Restarts the iterator transparently when the underlying loader is
        exhausted, so callers can draw batches indefinitely.

        Raises AssertionError if the split is unknown or its loader/iterator
        has not been initialized.
        """
        assert (data_name in self.train_data)
        assert ('loader' in self.train_data[data_name] and
                'iterator' in self.train_data[data_name])

        data_loader = self.train_data[data_name]['loader']
        data_iterator = self.train_data[data_name]['iterator']
        assert data_loader is not None and data_iterator is not None, \
            'Check your dataloader of %s.' % data_name

        try:
            sample = next(data_iterator)
        except StopIteration:
            # Loader exhausted: start a fresh pass over the data.
            data_iterator = iter(data_loader)
            sample = next(data_iterator)
        self.train_data[data_name]['iterator'] = data_iterator
        return sample

    def update_network(self, **kwargs):
        """Hook for subclasses; one epoch of optimization."""
        pass
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import ipdb
3 | import argparse
4 | import os
5 | import numpy as np
6 | from torch.backends import cudnn
7 | from config.config import cfg, cfg_from_file, cfg_from_list
8 | import sys
9 | import pprint
10 | import json
11 |
12 |
def parse_args():
    """Parse input arguments for the training script.

    Prints the help text and exits with status 1 when the script is invoked
    with no arguments at all.
    """
    parser = argparse.ArgumentParser(description='Train script.')
    parser.add_argument('--resume', dest='resume', type=str, default=None,
                        help='initialize with saved solver status')
    parser.add_argument('--cfg', dest='cfg_file', type=str, default=None,
                        help='optional config file')
    parser.add_argument('--set', dest='set_cfgs', default=None,
                        nargs=argparse.REMAINDER,
                        help='set config keys')
    parser.add_argument('--method', dest='method', type=str, default='McDalNet',
                        help='set the method to use')
    parser.add_argument('--task', dest='task', type=str, default='closed',
                        help='closed | partial | open')
    parser.add_argument('--distance_type', dest='distance_type', type=str, default='L1',
                        help='set distance type in McDalNet')
    parser.add_argument('--exp_name', dest='exp_name', type=str, default='log',
                        help='the experiment name')

    # Bare invocation: show usage instead of silently training with defaults.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    return parser.parse_args()
48 |
def _pack_two_part_net(model_fn):
    """Instantiate the (feature_extractor, classifier) pair returned by *model_fn*,
    wrap each part in DataParallel, move both to GPU when one is available, and
    return them packed in the dict format the SymmNets solvers expect."""
    feature_extractor, classifier = model_fn()
    feature_extractor = torch.nn.DataParallel(feature_extractor)
    classifier = torch.nn.DataParallel(classifier)
    if torch.cuda.is_available():
        feature_extractor.cuda()
        classifier.cuda()
    return {'feature_extractor': feature_extractor, 'classifier': classifier}


def train(args):
    """Build the solver, model and dataloaders selected by ``args.method`` /
    ``args.task``, then run training to completion.

    Args:
        args: parsed command-line namespace (``method``, ``task``, ...).

    Raises:
        NotImplementedError: for the Digits dataset, or for an unsupported
            method / task combination.
    """
    # method-specific setting
    if args.method == 'McDalNet':
        if cfg.DATASET.DATASET == 'Digits':
            raise NotImplementedError
        from solver.McDalNet_solver import McDalNetSolver as Solver
        from models.resnet_McDalNet import resnet as Model
        from data.prepare_data import generate_dataloader as Dataloader
        net = Model()
        net = torch.nn.DataParallel(net)
        if torch.cuda.is_available():
            net.cuda()
    elif args.method == 'SymmNetsV2':
        if args.task == 'closed':
            if cfg.DATASET.DATASET == 'Digits':
                raise NotImplementedError
            from solver.SymmNetsV2_solver import SymmNetsV2Solver as Solver
            from models.resnet_SymmNet import resnet as Model
            from data.prepare_data import generate_dataloader as Dataloader
            net = _pack_two_part_net(Model)
        elif args.task == 'partial':
            from solver.SymmNetsV2Partial_solver import SymmNetsV2PartialSolver as Solver
            from models.resnet_SymmNet import resnet as Model
            from data.prepare_data import generate_dataloader as Dataloader
            net = _pack_two_part_net(Model)
        elif args.task == 'open':
            from solver.SymmNetsV2Open_solver import SymmNetsV2OpenSolver as Solver
            from models.resnet_SymmNet import resnet as Model
            from data.prepare_data import generate_dataloader_open as Dataloader
            net = _pack_two_part_net(Model)
        elif args.task == 'closedsc':
            if cfg.DATASET.DATASET == 'Digits':
                raise NotImplementedError
            from solver.SymmNetsV2SC_solver import SymmNetsV2SolverSC as Solver
            from models.resnet_SymmNet import resnet as Model
            from data.prepare_data import generate_dataloader_sc as Dataloader
            net = _pack_two_part_net(Model)
        else:
            # The offending value here is the task, not the method (the old
            # message said "method" while printing args.task).
            raise NotImplementedError("Currently don't support the specified task: %s." % (args.task))

    ## Algorithm proposed in our CVPR19 paper: Domain-Symnetric Networks for Adversarial Domain Adaptation
    ## It is the same with our previous implementation of https://github.com/YBZh/SymNets
    elif args.method == 'SymmNetsV1':
        if cfg.DATASET.DATASET == 'Digits':
            raise NotImplementedError
        from solver.SymmNetsV1_solver import SymmNetsV1Solver as Solver
        from models.resnet_SymmNet import resnet as Model
        from data.prepare_data import generate_dataloader as Dataloader
        net = _pack_two_part_net(Model)
    else:
        raise NotImplementedError("Currently don't support the specified method: %s." % (args.method))

    dataloaders = Dataloader()

    # initialize solver
    train_solver = Solver(net, dataloaders)

    # train
    train_solver.solve()
    print('Finished!')
144 |
if __name__ == '__main__':
    # Let cudnn pick the fastest convolution algorithms for fixed-size inputs.
    cudnn.benchmark = True
    args = parse_args()

    print('Called with args:')
    print(args)

    # Merge the YAML experiment config (if any) into the global cfg.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    # if args.set_cfgs is not None:
    #     cfg_from_list(args.set_cfgs)

    if args.resume is not None:
        cfg.RESUME = args.resume
    # NOTE(review): --exp_name defaults to 'log', so this branch always runs;
    # the experiment name encodes dataset, source/target domains, method,
    # distance type and task.
    if args.exp_name is not None:
        cfg.EXP_NAME = args.exp_name + cfg.DATASET.DATASET + '_' + cfg.DATASET.SOURCE_NAME + '2' + cfg.DATASET.VAL_NAME + '_' + args.method + '_' +args.distance_type + args.task
    if args.distance_type is not None:
        cfg.MCDALNET.DISTANCE_TYPE = args.distance_type

    print('Using config:')
    pprint.pprint(cfg)

    # All outputs (checkpoints, log.txt) go under SAVE_DIR/EXP_NAME.
    cfg.SAVE_DIR = os.path.join(cfg.SAVE_DIR, cfg.EXP_NAME)
    print('Output will be saved to %s.' % cfg.SAVE_DIR)
    if not os.path.isdir(cfg.SAVE_DIR):
        os.makedirs(cfg.SAVE_DIR)
    # Append the full config to the experiment log for reproducibility.
    log = open(os.path.join(cfg.SAVE_DIR, 'log.txt'), 'a')
    log.write("\n")
    log.write(json.dumps(cfg) + '\n')
    log.close()

    train(args)
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/MultiClassDA/b0f61a5fe82f8b5414a14e8d77753fbf5d4bcb93/utils/__init__.py
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
def to_cuda(x):
    """Move *x* onto the GPU when one is available; otherwise return it unchanged."""
    return x.cuda() if torch.cuda.is_available() else x
8 |
def to_cpu(x):
    """Return a CPU-resident version of tensor *x* (no-op for CPU tensors)."""
    return x.cpu()
11 |
def to_numpy(x):
    """Convert tensor *x* to a numpy array, detaching it from autograd and
    moving it to the CPU first.

    ``detach()`` replaces the legacy ``.data`` access and ``.cpu()`` is a
    no-op for CPU tensors, so this works uniformly on CPU and GPU.
    """
    return x.detach().cpu().numpy()
16 |
def to_onehot(label, num_classes):
    """Expand 1-D integer labels into float one-hot rows of width *num_classes*."""
    return torch.eye(num_classes, device=label.device)[label]
21 |
def accuracy(output, target):
    """Return the top-1 precision, in percent, of *output* scores vs *target* labels."""
    batch = target.size(0)
    _, top1 = output.topk(1, 1, True, True)      # (batch, 1) best-class indices
    hits = top1.t().eq(target.view(1, -1))       # (1, batch) boolean hit mask
    n_correct = hits[:1].view(-1).float().sum(0, keepdim=True)
    return n_correct.mul_(100.0 / batch)
32 |
33 |
def accuracy_for_each_class(output, target, total_vector, correct_vector):
    """Accumulate per-class precision counts over one batch.

    For every sample, adds 1 to ``total_vector[true_class]`` and, when the
    top-1 prediction matches, 1 to ``correct_vector[true_class]``.

    Args:
        output: (batch, num_classes) score tensor.
        target: (batch,) ground-truth class labels.
        total_vector, correct_vector: running per-class counters, updated
            in place and returned.
    """
    batch_size = target.size(0)
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    # squeeze(0) rather than squeeze(): a bare squeeze collapsed the mask to a
    # 0-dim tensor when batch_size == 1, making correct[i] below raise.
    correct = pred.eq(target.view(1, -1)).float().cpu().squeeze(0)
    for i in range(batch_size):
        total_vector[target[i]] += 1
        correct_vector[torch.LongTensor([target[i]])] += correct[i]

    return total_vector, correct_vector
45 |
def recall_for_each_class(output, target, total_vector, correct_vector):
    """Accumulate per-class recall counts over one batch.

    For every sample, adds 1 to ``total_vector[predicted_class]`` and, when the
    top-1 prediction matches the target, 1 to ``correct_vector[predicted_class]``.

    Args:
        output: (batch, num_classes) score tensor.
        target: (batch,) ground-truth class labels.
        total_vector, correct_vector: running per-class counters, updated
            in place and returned.
    """
    batch_size = target.size(0)
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    # squeeze(0) rather than squeeze(): a bare squeeze collapsed the mask to a
    # 0-dim tensor when batch_size == 1, making correct[i] below raise.
    correct = pred.eq(target.view(1, -1)).float().cpu().squeeze(0)
    for i in range(batch_size):
        total_vector[pred[0][i]] += 1
        correct_vector[torch.LongTensor([pred[0][i]])] += correct[i]

    return total_vector, correct_vector
57 |
def process_one_values(tensor):
    """Nudge entries exactly equal to 1 down by 1e-6 (e.g. to keep log(1 - p) finite).

    Entries not equal to 1 are returned unchanged.
    """
    if (tensor == 1).sum() != 0:
        eps = torch.FloatTensor(tensor.size()).fill_(0)
        eps[tensor.data.cpu() == 1] = 1e-6
        # Match the input's device; the old unconditional eps.cuda() crashed
        # on CPU-only machines.
        tensor = tensor - eps.to(tensor.device)
    return tensor
64 |
def process_zero_values(tensor):
    """Nudge entries exactly equal to 0 up by 1e-6 (e.g. to keep log(p) finite).

    Entries not equal to 0 are returned unchanged.
    """
    if (tensor == 0).sum() != 0:
        eps = torch.FloatTensor(tensor.size()).fill_(0)
        eps[tensor.data.cpu() == 0] = 1e-6
        # Match the input's device; the old unconditional eps.cuda() crashed
        # on CPU-only machines.
        tensor = tensor + eps.to(tensor.device)
    return tensor
71 |
72 |
class AverageMeter(object):
    """Tracks the latest value plus a running sum/count and their weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every tracked statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Fold in *val*, observed *n* times, and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------