├── office-train-config.yaml
├── data.py
├── SFDA_test.py
├── config.py
├── lib.py
├── README.md
├── APM_update.py
├── net.py
└── SFDA_train.py
/office-train-config.yaml:
--------------------------------------------------------------------------------
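# SFDA training/testing arguments. `source` and `target` are indices into the
# domain list defined in config.py (office: amazon=0, dslr=1, webcam=2).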
data:
  dataloader: {batch_size: 32, class_balance: true, data_workers: 2}
  dataset: {n_share: 31, n_total: 31, name: office, root_path: ../data/office,
            source: 0, target: 2}
log: {log_interval: 10, root_dir: public}
misc: {gpus: 1}
model: {base_model: resnet50, pretrained_model: False}
test: {resume_file: Clipart_to_Art.pkl, test_interval: 1000, test_only: false}
train: {update_freq: 100, lr: 0.001, min_step: 5000, momentum: 0.9, weight_decay: 0.0005}
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from config import *
from easydl import *
from collections import Counter
from torchvision.transforms.transforms import *
from torch.utils.data import DataLoader, WeightedRandomSampler


source_classes = [i for i in range(args.data.dataset.n_total)]
target_classes = [i for i in range(args.data.dataset.n_share)]

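# The transforms below only resize/crop/flip and convert to tensors; ImageNet
# mean/std normalization is applied inside the feature extractors (see net.py).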
train_transform = Compose([
    Resize(256),
    RandomCrop(224),
    RandomHorizontalFlip(),
    ToTensor(),
])

test_transform = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
])


source_train_ds = FileListDataset(list_path=source_file, path_prefix=dataset.prefixes[args.data.dataset.source],
                                  transform=train_transform, filter=(lambda x: x in source_classes))
source_test_ds = FileListDataset(list_path=source_file, path_prefix=dataset.prefixes[args.data.dataset.source],
                                 transform=test_transform, filter=(lambda x: x in source_classes))
target_train_ds = FileListDataset(list_path=target_file, path_prefix=dataset.prefixes[args.data.dataset.target],
                                  transform=train_transform, filter=(lambda x: x in target_classes))
target_test_ds = FileListDataset(list_path=target_file, path_prefix=dataset.prefixes[args.data.dataset.target],
                                 transform=test_transform, filter=(lambda x: x in target_classes))

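# class-balanced source sampling: each source image is weighted by the inverse
# frequency of its class (uniform weights when class_balance is false)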
classes = source_train_ds.labels
freq = Counter(classes)
class_weight = {x: 1.0 / freq[x] if args.data.dataloader.class_balance else 1.0 for x in freq}

source_weights = [class_weight[x] for x in source_train_ds.labels]
sampler = WeightedRandomSampler(source_weights, len(source_train_ds.labels))

source_train_dl = DataLoader(dataset=source_train_ds, batch_size=args.data.dataloader.batch_size,
                             sampler=sampler, num_workers=args.data.dataloader.data_workers, drop_last=True)
source_test_dl = DataLoader(dataset=source_test_ds, batch_size=args.data.dataloader.batch_size, shuffle=False,
                            num_workers=1, drop_last=False)
target_train_dl = DataLoader(dataset=target_train_ds, batch_size=args.data.dataloader.batch_size, shuffle=True,
                             num_workers=args.data.dataloader.data_workers, drop_last=True)
target_test_dl = DataLoader(dataset=target_test_ds, batch_size=args.data.dataloader.batch_size, shuffle=False,
                            num_workers=1, drop_last=False)
--------------------------------------------------------------------------------
/SFDA_test.py:
--------------------------------------------------------------------------------
from data import *
from net import *
from lib import *
from torch import optim
from APM_update import *
import torch.backends.cudnn as cudnn
import time

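# note: benchmark=True favors speed over reproducibility and can undermine the
# deterministic flag below; set it to False for fully reproducible runs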
cudnn.benchmark = True
cudnn.deterministic = True

def seed_everything(seed=1234):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    import os
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything()

save_model_path = 'pretrained_weights/' + str(args.data.dataset.source) + str(args.data.dataset.target) + '/' + 'domain' + str(args.data.dataset.source) + str(args.data.dataset.target) + 'accBEST_model_checkpoint.pth.tar'
save_model_statedict = torch.load(save_model_path)['state_dict']

model_dict = {
    'resnet50': ResNet50Fc,
    'vgg16': VGG16Fc
}

# ======= network architecture =======
class Target_TrainableNet(nn.Module):
    def __init__(self):
        super(Target_TrainableNet, self).__init__()
        self.feature_extractor = model_dict[args.model.base_model](args.model.pretrained_model)
        classifier_output_dim = len(source_classes)
        self.classifier = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)
        self.cls_multibranch = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)


# ======= target network =======
trainable_targetNet = Target_TrainableNet()
trainable_targetNet.load_state_dict(save_model_statedict)

feature_extractor_t = (trainable_targetNet.feature_extractor).cuda()
classifier_s2t = (trainable_targetNet.classifier).cuda()
classifier_t = (trainable_targetNet.cls_multibranch).cuda()
print("Finished loading the model...")


domains = dataset.domains  # ['amazon', 'dslr', 'webcam'] for Office31; generalizes to the other datasets in config.py
print('domain....' + domains[args.data.dataset.source] + '>>>>>>' + domains[args.data.dataset.target])
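
# evaluate the self-learning branch (cls_multibranch) on the target test set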
counter = AccuracyCounter()
with TrainingModeManager([feature_extractor_t, classifier_t], train=False) as mgr, torch.no_grad():

    for i, (img, label) in enumerate(target_test_dl):
        img = img.cuda()
        label = label.cuda()

        feature = feature_extractor_t.forward(img)
        _, _, before_softmax, predict_prob = classifier_t.forward(feature)

        counter.addOneBatch(variable_to_numpy(predict_prob), variable_to_numpy(one_hot(label, args.data.dataset.n_total)))

acc_test = counter.reportAccuracy()
print('>>>>>>> Test Accuracy >>>>>>>')
print(acc_test)
print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')

exit()
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os
import argparse
import yaml
import easydict
from os.path import join


class Dataset:
    def __init__(self, path, domains, files, prefix):
        self.path = path
        self.prefix = prefix
        self.domains = domains
        self.files = [join(path, file) for file in files]
        self.prefixes = [self.prefix] * len(self.domains)


parser = argparse.ArgumentParser(description='Code for *SFDA: Domain Adaptation without Source Data*',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--config', type=str, default='config.yaml', help='/path/to/config/file')

args = parser.parse_args()

config_file = args.config

# yaml.load() without an explicit Loader is deprecated in PyYAML >= 5.1;
# safe_load is sufficient since the config contains only plain scalars
with open(config_file) as f:
    args = yaml.safe_load(f)
with open(config_file) as f:
    save_config = yaml.safe_load(f)

args = easydict.EasyDict(args)

dataset = None
if args.data.dataset.name == 'office':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['amazon', 'dslr', 'webcam'],  # source/target indices: amazon=0, dslr=1, webcam=2
        files=[
            'amazon_31_list.txt',
            'dslr_31_list.txt',
            'webcam_31_list.txt'
        ],
        prefix=args.data.dataset.root_path)
elif args.data.dataset.name == 'officehome':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['Art', 'Clipart', 'Product', 'Real_World'],
        files=[
            'Art.txt',
            'Clipart.txt',
            'Product.txt',
            'Real_World.txt'
        ],
        prefix=args.data.dataset.root_path)
elif args.data.dataset.name == 'visda':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['simulation', 'real'],
        files=[
            'simulation_image_list.txt',
            'real_image_list.txt'
        ],
        prefix=args.data.dataset.root_path)
elif args.data.dataset.name == 'C2I':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['Caltech', 'ImageNet'],
        files=[
            'C2I_caltech256.txt',
            'C2I_imagenet84.txt'
        ],
        prefix=args.data.dataset.root_path)
elif args.data.dataset.name == 'I2C':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['ImageNet', 'Caltech'],
        files=[
            'I2C_imagenet1000.txt',
            'I2C_caltech84.txt'
        ],
        prefix=args.data.dataset.root_path)
else:
    raise Exception(f'dataset {args.data.dataset.name} not supported!')

source_domain_name = dataset.domains[args.data.dataset.source]
target_domain_name = dataset.domains[args.data.dataset.target]
source_file = dataset.files[args.data.dataset.source]
target_file = dataset.files[args.data.dataset.target]
print("======================================")
print("source_domain_name:", source_domain_name, "target_domain_name:", target_domain_name)
--------------------------------------------------------------------------------
/lib.py:
--------------------------------------------------------------------------------
from easydl import *
import torch.nn.functional as F


def reverse_sigmoid(y):
    # numerically-stabilized inverse of the sigmoid (the logit function)
    return torch.log(y / (1.0 - y + 1e-10) + 1e-10)


def get_source_share_weight(domain_out, before_softmax, domain_temperature=1.0, class_temperature=10.0):
    before_softmax = before_softmax / class_temperature
    after_softmax = nn.Softmax(-1)(before_softmax)
    # re-temper the domain output: invert the sigmoid, rescale the logit, squash again
    domain_logit = reverse_sigmoid(domain_out)
    domain_logit = domain_logit / domain_temperature
    domain_out = nn.Sigmoid()(domain_logit)

    entropy = torch.sum(- after_softmax * torch.log(after_softmax + 1e-10), dim=1, keepdim=True)
    entropy_norm = entropy / np.log(after_softmax.size(1))
    weight = entropy_norm - domain_out
    weight = weight.detach()
    return weight


def get_source_share_weight_onlyentropy(before_softmax, class_temperature=10.0):
    before_softmax = before_softmax / class_temperature
    after_softmax = nn.Softmax(-1)(before_softmax)

    entropy = torch.sum(- after_softmax * torch.log(after_softmax + 1e-5), dim=1, keepdim=True)
    entropy_norm = entropy / np.log(after_softmax.size(1))
    weight = entropy_norm.detach()
    return weight

def hellinger_distance(p, q):
    # H(p, q) = ||sqrt(p) - sqrt(q)||_2 / sqrt(2); lies in [0, 1] for probability vectors
    return torch.norm((torch.sqrt(p) - torch.sqrt(q)), p=2, dim=1) / np.sqrt(2)


def get_commonness_weight(ps_s, pt_s, ps_t, pt_t, class_temperature=10.0):
    # explicit dim avoids PyTorch's implicit-dim softmax deprecation warning
    ps_s = F.softmax(ps_s / class_temperature, dim=1)
    pt_s = F.softmax(pt_s / class_temperature, dim=1)
    ps_t = F.softmax(ps_t, dim=1)
    pt_t = F.softmax(pt_t, dim=1)

    ws = hellinger_distance(ps_s, pt_s).detach()
    wt = hellinger_distance(ps_t, pt_t).detach()

    return ws, wt


def get_entropy(domain_out, before_softmax, domain_temperature=1.0, class_temperature=10.0):
    before_softmax = before_softmax / class_temperature
    after_softmax = nn.Softmax(-1)(before_softmax)
    # the re-tempered domain output is computed but unused below; kept for parity
    # with get_source_share_weight
    domain_logit = reverse_sigmoid(domain_out)
    domain_logit = domain_logit / domain_temperature
    domain_out = nn.Sigmoid()(domain_logit)

    entropy = torch.sum(- after_softmax * torch.log(after_softmax + 1e-10), dim=1, keepdim=True)
    entropy_norm = entropy / np.log(after_softmax.size(1))
    weight = entropy_norm.detach()
    return weight


def get_target_share_weight(domain_out, before_softmax, domain_temperature=1.0, class_temperature=10.0):
    return - get_source_share_weight(domain_out, before_softmax, domain_temperature, class_temperature)


def normalize_weight(x):
    min_val = x.min()
    max_val = x.max()
    x = (x - min_val) / (max_val - min_val + 1e-5)
    x = x / (torch.mean(x) + 1e-5)  # rescale so the batch mean is ~1, keeping loss magnitudes comparable
    return x.detach()


def normalize_weight01(x):
    min_val = x.min()
    max_val = x.max()
    x = (x - min_val) / (max_val - min_val)
    return x.detach()


def normalize_weight_11(x):
    min_val = x.min()
    max_val = x.max()
    x = (x - min_val) / (max_val - min_val)
    x = x * 2 - 1
    return x.detach()


def seed_everything(seed=1234):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    import os
    os.environ['PYTHONHASHSEED'] = str(seed)


def tensor_l2normalization(q):
    qn = torch.norm(q, p=2, dim=1).detach().unsqueeze(1)
    q = q.div(qn.expand_as(q))
    return q
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SFDA - Domain Adaptation without Source Data

## Prerequisites
* Ubuntu 18.04
* Python 3.6+
* PyTorch 1.5+ (a recent version is recommended)
* NVIDIA GPU (>= 12GB memory)
* CUDA 10.0 (optional)
* CUDNN 7.5 (optional)

## Getting Started

### Installation
* Configure a virtual (anaconda) environment
```
conda create -n env_name python=3.6
source activate env_name
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
```
* Install python libraries
```
conda install -c conda-forge matplotlib
conda install -c anaconda yaml
conda install -c anaconda pyyaml
conda install -c anaconda scipy
conda install -c anaconda scikit-learn
conda install -c conda-forge easydict
pip install easydl
```

### Download this repository
* We provide two versions of the repository (with / without the dataset) for flexible experimentation

* Full SFDA repository (with dataset): [download link][aa]
  * In this case, go to the ```training and testing``` step directly

[aa]: https://drive.google.com/drive/folders/11g8yOWxIG47G-5vImtX98qrg0Y4UxrGd?usp=sharing

* SFDA repository (without dataset): [download link][a]

[a]: https://drive.google.com/drive/folders/1ndxbQLAkDxxvlPs7E65_6fQ4dNbxXkHR?usp=sharing


* Visualization of the repository structure (Full SFDA repository)

```
|-- APM_update.py
|-- SFDA_test.py
|-- SFDA_train.py
|-- config.py
|-- data.py
|-- lib.py
|-- net.py
|-- office-train-config.yaml
|-- data
|   `-- office
|       |-- domain_adaptation_images
|       |   |-- amazon
|       |   |   `-- images
|       |   |-- dslr
|       |   |   `-- images
|       |   `-- webcam
|       |       `-- images
|       |-- amazon_31_list.txt
|       |-- dslr_31_list.txt
|       `-- webcam_31_list.txt
|-- pretrained_weights
|   `-- 02
|       `-- domain02accBEST_model_checkpoint.pth.tar
`-- source_pretrained_weights
    `-- 02
        `-- model_checkpoint.pth.tar
```

### Download dataset
* Download the Office31 dataset ([link][b]) and unzip it into ```./data/office```

[b]: https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view

* Download the image-list text files ([link][c]), i.e., amazon_31_list.txt, dslr_31_list.txt, and webcam_31_list.txt, into ```./data/office```

[c]: https://drive.google.com/drive/folders/11wFsBoG--cm7uD0L-7L5X5hprWDCMBpH?usp=sharing


### Download source-pretrained parameters (Fs and Cs of Figure 2 in our main paper)
* Download the source-pretrained parameters ([link][d]) into ```./source_pretrained_weights/[scenario_number]```

[d]: https://drive.google.com/drive/folders/1mkzEl8SHQ0mVFnYV0CvZIdeLstCm2shy?usp=sharing

ex) Source-pretrained parameters of the A[0] -> W[2] scenario should be located in ```./source_pretrained_weights/02```
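
For instance, after downloading the A -> W parameters, the folder should contain the checkpoint shown in the repository tree above:

```
$ ls source_pretrained_weights/02/
model_checkpoint.pth.tar
```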


## Training and testing

* Arguments required for training and testing are contained in ```office-train-config.yaml```
* Here is an example of running an experiment on Office31 (default: A -> W)
* The scenario can be changed by editing ```source: 0, target: 2``` in ```office-train-config.yaml```, as shown in the snippet below
* We will update the full version of our framework, including settings for ```OfficeHome``` and ```VisDA-C```
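
* For example, to run D -> A instead of the default A -> W, set the indices according to the domain order in ```config.py``` (amazon=0, dslr=1, webcam=2):

```
# office-train-config.yaml (excerpt)
data:
  dataset: {n_share: 31, n_total: 31, name: office, root_path: ../data/office,
            source: 1, target: 0}   # dslr -> amazon
```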

### Training

* Run the following command

```
python SFDA_train.py --config office-train-config.yaml
```

### Testing (on the pretrained model)

* As a first step, download the SFDA pretrained parameters ([link][e]) into ```./pretrained_weights/[scenario_number]```

ex) SFDA pretrained parameters of the A[0] -> W[2] scenario should be located in ```./pretrained_weights/02```

[e]: https://drive.google.com/drive/folders/1XiWZXsES_oEAI2WMdOBxqjKieA7zOOwZ?usp=sharing

* or run the training code first to obtain the pretrained weights yourself

* Run the following command

```
python SFDA_test.py --config office-train-config.yaml
```
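
* The test script loads ```domain02accBEST_model_checkpoint.pth.tar``` (for A -> W) from ```./pretrained_weights/02``` and prints the target-domain test accuracy (see ```SFDA_test.py```)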


## Experimental results on Office31

* Results using the provided code

|              |  A→W  |  D→W  |  W→D  |  A→D  |  D→A  |  W→A  |  Avg  |
|:-------------|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|
| Accuracy (%) | 91.06 | 97.35 | 98.99 | 91.96 | 71.60 | 68.62 | 86.60 |
--------------------------------------------------------------------------------
/APM_update.py:
--------------------------------------------------------------------------------
from data import *
from lib import *
import time


@torch.no_grad()  # the APM update is pure inference; disabling autograd saves memory
def APM_init_update(feature_extractor, classifier_t):
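    """Build the Adaptive Prototype Memory (APM).

    For every class, target samples pseudo-labeled to that class are ranked by
    normalized prediction entropy, and the most confident ones become prototypes.
    Returns the stacked prototype features, the (uniform) number of prototypes
    per class, and a per-class dictionary holding the same features.
    """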
    start_time = time.time()
    available_cls = []
    h_dict = {}
    feat_dict = {}
    missing_cls = []
    after_softmax_numpy_for_emergency = []
    feature_numpy_for_emergency = []
    max_prototype_bound = 100

    for cls in range(len(source_classes)):
        h_dict[cls] = []
        feat_dict[cls] = []

    for (im_target_lbcorr, label_target_lbcorr) in target_train_dl:
        im_target_lbcorr = im_target_lbcorr.cuda()
        fc1_lbcorr = feature_extractor.forward(im_target_lbcorr)
        _, _, _, after_softmax = classifier_t.forward(fc1_lbcorr)
        after_softmax_numpy_for_emergency.append(after_softmax.data.cpu().numpy())
        feature_numpy_for_emergency.append(fc1_lbcorr.data.cpu().numpy())

        pseudo_label = torch.argmax(after_softmax, dim=1)
        pseudo_label = pseudo_label.cpu()

        entropy = torch.sum(- after_softmax * torch.log(after_softmax + 1e-10), dim=1, keepdim=True)
        entropy_norm = entropy / np.log(after_softmax.size(1))
        entropy_norm = entropy_norm.squeeze(1)
        entropy_norm = entropy_norm.cpu()

        for cls in range(len(source_classes)):
            # stack the entropies (H) and features of samples pseudo-labeled as `cls`
            cls_filter = (pseudo_label == cls)
            list_loc = (torch.where(cls_filter == 1))[0]
            if len(list_loc) == 0:
                missing_cls.append(cls)
                continue
            available_cls.append(cls)
            filtered_ent = torch.gather(entropy_norm, dim=0, index=list_loc)
            filtered_feat = torch.gather(fc1_lbcorr.cpu(), dim=0, index=list_loc.unsqueeze(1).repeat(1, 2048))

            h_dict[cls].append(filtered_ent.cpu().data.numpy())
            feat_dict[cls].append(filtered_feat.cpu().data.numpy())

    available_cls = np.unique(available_cls)

    prototype_memory = []
    prototype_memory_dict = {}
    after_softmax_numpy_for_emergency = np.concatenate(after_softmax_numpy_for_emergency, axis=0)
    feature_numpy_for_emergency = np.concatenate(feature_numpy_for_emergency, axis=0)

    # the largest per-class minimum entropy serves as the reliability threshold
    max_top1_ent = 0
    for cls in available_cls:
        ents_np = np.concatenate(h_dict[cls], axis=0)
        ent_idxs = np.argsort(ents_np)
        top1_ent = ents_np[ent_idxs[0]]
        if max_top1_ent < top1_ent:
            max_top1_ent = top1_ent
            max_top1_class = cls

    class_prototypeNum_dict = {}
    max_prototype = 0

    for cls in available_cls:
        ents_np = np.concatenate(h_dict[cls], axis=0)
        ents_np_filtered = (ents_np <= max_top1_ent)
        class_prototypeNum_dict[cls] = ents_np_filtered.sum()

        if max_prototype < ents_np_filtered.sum():
            max_prototype = ents_np_filtered.sum()

    if max_prototype > max_prototype_bound:
        max_prototype = max_prototype_bound

    for cls in range(len(source_classes)):

        if cls in available_cls:
            ents_np = np.concatenate(h_dict[cls], axis=0)
            feats_np = np.concatenate(feat_dict[cls], axis=0)
            ent_idxs = np.argsort(ents_np)

            # keep the lowest-entropy features and tile them up to max_prototype rows
            truncated_feat = feats_np[ent_idxs[:class_prototypeNum_dict[cls]]]
            fit_to_max_prototype = np.concatenate([truncated_feat] * (int(max_prototype / truncated_feat.shape[0]) + 1),
                                                  axis=0)
            fit_to_max_prototype = fit_to_max_prototype[:max_prototype, :]

            prototype_memory.append(fit_to_max_prototype)
            prototype_memory_dict[cls] = fit_to_max_prototype
        else:
            # emergency fallback: no sample was pseudo-labeled as `cls`, so reuse the
            # sample with the highest predicted probability for that class
            after_softmax_torch_for_emergency = torch.Tensor(after_softmax_numpy_for_emergency)
            emergency_idx = torch.argsort(after_softmax_torch_for_emergency, descending=True, dim=0)  # rank samples per class
            cls_emergency_idx = emergency_idx[:, cls]
            cls_emergency_idx = cls_emergency_idx[0]
            cls_emergency_idx_numpy = cls_emergency_idx.data.numpy()

            copied_features_emergency = np.concatenate(
                [np.expand_dims(feature_numpy_for_emergency[cls_emergency_idx_numpy], axis=0)] * max_prototype, axis=0)

            prototype_memory.append(copied_features_emergency)
            prototype_memory_dict[cls] = copied_features_emergency

    print("** APM update... time:", time.time() - start_time)
    prototype_memory = np.concatenate(prototype_memory, axis=0)
    num_prototype_ = int(max_prototype)

    return prototype_memory, num_prototype_, prototype_memory_dict
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import os

from easydl import *
from torchvision import models
import torch
import torch.nn as nn
import torch.nn.functional as F

class BaseFeatureExtractor(nn.Module):
    def forward(self, *input):
        raise NotImplementedError

    def __init__(self):
        super(BaseFeatureExtractor, self).__init__()

    def output_num(self):
        raise NotImplementedError

    def train(self, mode=True):
        # freeze BatchNorm running mean/std; note that self.children() only visits
        # direct children, so BN layers nested inside residual blocks are unaffected
        for module in self.children():
            if isinstance(module, nn.BatchNorm2d):
                module.train(False)
            else:
                module.train(mode)


class ResNet50Fc(BaseFeatureExtractor):
    def __init__(self, model_path=None, normalize=True):
        super(ResNet50Fc, self).__init__()
        if model_path:
            if os.path.exists(model_path):
                self.model_resnet = models.resnet50(pretrained=False)
                self.model_resnet.load_state_dict(torch.load(model_path))
            else:
                raise Exception('invalid model path!')
        else:
            # weights are loaded later from a source-pretrained checkpoint
            self.model_resnet = models.resnet50(pretrained=False)

        if model_path or normalize:
            # a pretrained model is used, so apply ImageNet normalization
            self.normalize = True
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        else:
            self.normalize = False

        model_resnet = self.model_resnet
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.__in_features = model_resnet.fc.in_features

    def forward(self, x):
        if self.normalize:
            x = (x - self.mean) / self.std
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x

    def output_num(self):
        return self.__in_features


class VGG16Fc(BaseFeatureExtractor):
    def __init__(self, model_path=None, normalize=True):
        super(VGG16Fc, self).__init__()
        if model_path:
            if os.path.exists(model_path):
                self.model_vgg = models.vgg16(pretrained=False)
                self.model_vgg.load_state_dict(torch.load(model_path))
            else:
                raise Exception('invalid model path!')
        else:
            self.model_vgg = models.vgg16(pretrained=True)

        if model_path or normalize:
            # a pretrained model is used, so apply ImageNet normalization
            self.normalize = True
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        else:
            self.normalize = False

        model_vgg = self.model_vgg
        self.features = model_vgg.features
        self.classifier = nn.Sequential()
        for i in range(6):
            self.classifier.add_module("classifier" + str(i), model_vgg.classifier[i])
        self.feature_layers = nn.Sequential(self.features, self.classifier)

        self.__in_features = 4096

    def forward(self, x):
        if self.normalize:
            x = (x - self.mean) / self.std
        x = self.features(x)
        x = x.view(x.size(0), 25088)
        x = self.classifier(x)
        return x

    def output_num(self):
        return self.__in_features


class CLS(nn.Module):

    def __init__(self, in_dim, out_dim, bottle_neck_dim=256, pretrain=False):
        super(CLS, self).__init__()
        self.pretrain = pretrain
        if bottle_neck_dim:
            self.bottleneck = nn.Linear(in_dim, bottle_neck_dim)
            self.fc = nn.Linear(bottle_neck_dim, out_dim)
            self.main = nn.Sequential(self.bottleneck, self.fc, nn.Softmax(dim=-1))
        else:
            self.fc = nn.Linear(in_dim, out_dim)
            self.main = nn.Sequential(self.fc, nn.Softmax(dim=-1))

    def forward(self, x):
        # returns [input, bottleneck, logits, softmax] when a bottleneck is used
        out = [x]
        for module in self.main.children():
            x = module(x)
            out.append(x)
        return out


class CLS_copy(nn.Module):

    def __init__(self, in_dim, out_dim, bottle_neck_dim=256, pretrain=False):
        super(CLS_copy, self).__init__()
        self.pretrain = pretrain
        if bottle_neck_dim:
            self.bottleneck = nn.Linear(in_dim, bottle_neck_dim)
            self.fc = nn.Linear(bottle_neck_dim, out_dim)
            self.main = nn.Sequential(self.bottleneck, self.fc)
        else:
            self.fc = nn.Linear(in_dim, out_dim)
            self.main = nn.Sequential(self.fc)

    def forward(self, x):
        for module in self.main.children():
            x = module(x)
        return x


class AdversarialNetwork(nn.Module):

    def __init__(self, in_feature):
        super(AdversarialNetwork, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_feature, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
        self.grl = GradientReverseModule(lambda step: aToBSheduler(step, 0.0, 1.0, gamma=10, max_iter=10000))

    def forward(self, x):
        x_ = self.grl(x)
        y = self.main(x_)
        return y
--------------------------------------------------------------------------------
/SFDA_train.py:
--------------------------------------------------------------------------------
from data import *
from net import *
from lib import *
from torch import optim
from APM_update import *
import torch.backends.cudnn as cudnn
import time

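# note: benchmark=True favors speed over reproducibility and can undermine the
# deterministic flag below; set it to False for fully reproducible runs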
cudnn.benchmark = True
cudnn.deterministic = True

def seed_everything(seed=1234):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    import os
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything()

save_model_path = 'source_pretrained_weights/' + str(args.data.dataset.source) + str(args.data.dataset.target) + '/' + 'model_checkpoint.pth.tar'
save_model_statedict = torch.load(save_model_path)['state_dict']

model_dict = {
    'resnet50': ResNet50Fc,
    'vgg16': VGG16Fc
}


# ======= network architecture =======
class Source_FixedNet(nn.Module):
    def __init__(self):
        super(Source_FixedNet, self).__init__()
        self.feature_extractor = model_dict[args.model.base_model](args.model.pretrained_model)
        classifier_output_dim = len(source_classes)
        self.classifier = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)

class Target_TrainableNet(nn.Module):
    def __init__(self):
        super(Target_TrainableNet, self).__init__()
        self.feature_extractor = model_dict[args.model.base_model](args.model.pretrained_model)
        classifier_output_dim = len(source_classes)
        self.classifier = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)
        self.cls_multibranch = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)


# ======= pre-trained source network (frozen) =======
fixed_sourceNet = Source_FixedNet()
fixed_sourceNet.load_state_dict(save_model_statedict)
fixed_feature_extractor_s = (fixed_sourceNet.feature_extractor).cuda()
fixed_classifier_s = (fixed_sourceNet.classifier).cuda()
fixed_feature_extractor_s.eval()
fixed_classifier_s.eval()

# ======= trainable target network (initialized from the source weights) =======
trainable_targetNet = Target_TrainableNet()
feature_extractor_t = (trainable_targetNet.feature_extractor).cuda()
feature_extractor_t.load_state_dict(fixed_sourceNet.feature_extractor.state_dict())
classifier_s2t = (trainable_targetNet.classifier).cuda()
classifier_s2t.load_state_dict(fixed_sourceNet.classifier.state_dict())
classifier_t = (trainable_targetNet.cls_multibranch).cuda()
classifier_t.load_state_dict(fixed_sourceNet.classifier.state_dict())


# renamed from model_dict to avoid clobbering the architecture lookup table above
best_model_dict = {
    'global_step': 0,
    'state_dict': trainable_targetNet.state_dict(),
    'accuracy': 0}


feature_extractor_t.train()
classifier_s2t.train()
classifier_t.train()
print("Finished loading the model...")

domains = dataset.domains  # ['amazon', 'dslr', 'webcam'] for Office31
print('domain....' + domains[args.data.dataset.source] + '>>>>>>' + domains[args.data.dataset.target])

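# all three optimizers share an inverse-decay LR schedule; the backbone is
# fine-tuned with a 10x smaller learning rate than the classifier heads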
scheduler = lambda step, initial_lr: inverseDecaySheduler(step, initial_lr, gamma=10, power=0.75, max_iter=(args.train.min_step))

optimizer_finetune = OptimWithSheduler(
    optim.SGD(feature_extractor_t.parameters(), lr=args.train.lr / 10.0, weight_decay=args.train.weight_decay, momentum=args.train.momentum, nesterov=True),
    scheduler)
optimizer_classifier_s2t = OptimWithSheduler(
    optim.SGD(classifier_s2t.parameters(), lr=args.train.lr, weight_decay=args.train.weight_decay, momentum=args.train.momentum, nesterov=True),
    scheduler)
optimizer_classifier_t = OptimWithSheduler(
    optim.SGD(classifier_t.parameters(), lr=args.train.lr, weight_decay=args.train.weight_decay, momentum=args.train.momentum, nesterov=True),
    scheduler)

global_step = 0
best_acc = 0
epoch_id = 0
class_num = args.data.dataset.n_total
pt_memory_update_frequency = args.train.update_freq


while global_step < args.train.min_step:

    epoch_id += 1

    for i, (img_target, label_target) in enumerate(target_train_dl):

        # APM init/update: rebuild the prototype memory every `update_freq` steps
        if global_step % pt_memory_update_frequency == 0:
            prototype_memory, num_prototype_, prototype_memory_dict = APM_init_update(feature_extractor_t, classifier_t)

        img_target = img_target.cuda()

        # forward pass: the frozen source network provides source-anchored pseudo labels
        fixed_fc1_s = fixed_feature_extractor_s.forward(img_target)
        _, _, _, prob_s = fixed_classifier_s.forward(fixed_fc1_s)
        pseudo_label_s = torch.argmax(prob_s, dim=1)

        # forward pass: target network (shared backbone, two classifier heads)
        fc1_t = feature_extractor_t.forward(img_target)
        _, _, logit_s2t, _ = classifier_s2t.forward(fc1_t)
        _, _, logit_t, _ = classifier_t(fc1_t)

        # prototype-based pseudo labels: cosine similarity between batch features
        # and every prototype, average-pooled per class
        proto_feat_tensor = torch.Tensor(prototype_memory)  # (#class * P, 2048)
        feature_embed_tensor = fc1_t.cpu()
        proto_feat_tensor = tensor_l2normalization(proto_feat_tensor)
        batch_feat_tensor = tensor_l2normalization(feature_embed_tensor)

        sim_mat = torch.mm(batch_feat_tensor, proto_feat_tensor.permute(1, 0))
        sim_mat = F.avg_pool1d(sim_mat.unsqueeze(0), kernel_size=num_prototype_, stride=num_prototype_).squeeze(0)  # (B, #class)
        pseudo_label_t = torch.argmax(sim_mat, dim=1).cuda()

        # confidence-based filtering: a sample is kept only if its worst-case distance
        # to the top-1 class's prototypes beats its best-case distance to the runner-up's
        arg_idxs = torch.argsort(sim_mat, dim=1, descending=True)  # (B, #class)

        first_group_idx = arg_idxs[:, 0]
        second_group_idx = arg_idxs[:, 1]

        first_group_feat = [prototype_memory_dict[int(x.data.numpy())] for x in first_group_idx]
        first_group_feat_tensor = torch.tensor(np.concatenate(first_group_feat, axis=0))  # (B*P, 2048)
        first_group_feat_tensor = tensor_l2normalization(first_group_feat_tensor)

        second_group_feat = [prototype_memory_dict[int(x.data.numpy())] for x in second_group_idx]
        second_group_feat_tensor = torch.tensor(np.concatenate(second_group_feat, axis=0))  # (B*P, 2048)
        second_group_feat_tensor = tensor_l2normalization(second_group_feat_tensor)

        feature_embed_tensor_repeat = torch.Tensor(np.repeat(feature_embed_tensor.cpu().data.numpy(), repeats=num_prototype_, axis=0))
        feature_embed_tensor_repeat = tensor_l2normalization(feature_embed_tensor_repeat)

        first_dist_mat = 1 - torch.mm(first_group_feat_tensor, feature_embed_tensor_repeat.permute(1, 0))  # distance = 1 - similarity
        second_dist_mat = 1 - torch.mm(second_group_feat_tensor, feature_embed_tensor_repeat.permute(1, 0))

        # max pooling over each P x P block gives the worst-case (largest) distance ...
        first_dist_mat = F.max_pool2d(first_dist_mat.permute(1, 0).unsqueeze(0).unsqueeze(0), kernel_size=num_prototype_, stride=num_prototype_).squeeze(0).squeeze(0)  # (B, B)
        # ... and -max(-x) gives the best-case (smallest) distance per block
        second_dist_mat = -1 * F.max_pool2d(-1 * second_dist_mat.permute(1, 0).unsqueeze(0).unsqueeze(0), kernel_size=num_prototype_, stride=num_prototype_).squeeze(0).squeeze(0)  # (B, B)

        first_dist_vec = torch.diag(first_dist_mat)  # (B,)
        second_dist_vec = torch.diag(second_dist_mat)  # (B,)

        confidence_mask = ((first_dist_vec - second_dist_vec) < 0).cuda()

        # optimize the target network with both pseudo labels: the source-anchored
        # labels train classifier_s2t, the prototype labels train classifier_t
        ce_from_s2t = nn.CrossEntropyLoss()(logit_s2t, pseudo_label_s)
        ce_from_t = nn.CrossEntropyLoss(reduction='none')(logit_t, pseudo_label_t).view(-1, 1).squeeze(1)
        ce_from_t = torch.mean(ce_from_t * confidence_mask.float(), dim=0, keepdim=True)

        # alpha ramps from 0 toward 1, gradually shifting trust from the source head
        # to the self-learning (prototype) head; np.float is deprecated, use float
        alpha = float(2.0 / (1.0 + np.exp(-10 * global_step / float(args.train.min_step // 2))) - 1.0)
        ce_total = (1 - alpha) * ce_from_s2t + alpha * ce_from_t

        with OptimizerManager([optimizer_finetune, optimizer_classifier_s2t, optimizer_classifier_t]):
            loss = ce_total
            loss.backward()

        global_step += 1

        # evaluation during training
        if global_step % args.test.test_interval == 0:

            counter = AccuracyCounter()
            with TrainingModeManager([feature_extractor_t, classifier_t], train=False) as mgr, torch.no_grad():

                for i, (img, label) in enumerate(target_test_dl):
                    img = img.cuda()
                    label = label.cuda()

                    feature = feature_extractor_t.forward(img)
                    _, _, _, predict_prob_t = classifier_t.forward(feature)

                    counter.addOneBatch(variable_to_numpy(predict_prob_t), variable_to_numpy(one_hot(label, args.data.dataset.n_total)))

            acc_test = counter.reportAccuracy()
            print('>>>>>>>>>>> accuracy >>>>>>>>>>>')
            print(acc_test)
            print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
            if best_acc < acc_test:
                best_acc = acc_test
                best_model_dict = {
                    'global_step': global_step + 1,
                    'state_dict': trainable_targetNet.state_dict(),
                    'accuracy': acc_test}

                torch.save(best_model_dict, join('pretrained_weights/' + str(args.data.dataset.source) + str(args.data.dataset.target) + '/' + 'domain' + str(args.data.dataset.source) + str(args.data.dataset.target) + 'accBEST_model_checkpoint.pth.tar'))


exit()
--------------------------------------------------------------------------------