├── office-train-config.yaml
├── data.py
├── SFDA_test.py
├── config.py
├── lib.py
├── README.md
├── APM_update.py
├── net.py
└── SFDA_train.py

/office-train-config.yaml:
--------------------------------------------------------------------------------
data:
  dataloader: {batch_size: 32, class_balance: true, data_workers: 2}
  dataset: {n_share: 31, n_total: 31, name: office, root_path: ../data/office,
    source: 0, target: 2}
log: {log_interval: 10, root_dir: public}
misc: {gpus: 1}
model: {base_model: resnet50, pretrained_model: False}
test: {resume_file: Clipart_to_Art.pkl, test_interval: 1000, test_only: false}
train: {update_freq: 100, lr: 0.001, min_step: 5000, momentum: 0.9, weight_decay: 0.0005}
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from config import *
from easydl import *
from collections import Counter
from torchvision.transforms.transforms import *
from torch.utils.data import DataLoader, WeightedRandomSampler


source_classes = [i for i in range(args.data.dataset.n_total)]
target_classes = [i for i in range(args.data.dataset.n_share)]


# input normalization happens inside the feature extractors (see net.py),
# so the transforms only resize, crop, flip, and convert to tensors
train_transform = Compose([
    Resize(256),
    RandomCrop(224),
    RandomHorizontalFlip(),
    ToTensor(),
])


test_transform = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
])


source_train_ds = FileListDataset(list_path=source_file, path_prefix=dataset.prefixes[args.data.dataset.source],
                                  transform=train_transform, filter=(lambda x: x in source_classes))
source_test_ds = FileListDataset(list_path=source_file, path_prefix=dataset.prefixes[args.data.dataset.source],
                                 transform=test_transform, filter=(lambda x: x in source_classes))
target_train_ds = FileListDataset(list_path=target_file, path_prefix=dataset.prefixes[args.data.dataset.target],
                                  transform=train_transform, filter=(lambda x: x in target_classes))
target_test_ds = FileListDataset(list_path=target_file, path_prefix=dataset.prefixes[args.data.dataset.target],
                                 transform=test_transform, filter=(lambda x: x in target_classes))


classes = source_train_ds.labels
freq = Counter(classes)
class_weight = {x: 1.0 / freq[x] if args.data.dataloader.class_balance else 1.0 for x in freq}


source_weights = [class_weight[x] for x in source_train_ds.labels]
sampler = WeightedRandomSampler(source_weights, len(source_train_ds.labels))
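# With class_balance enabled, each sample is weighted by 1 / frequency(its class), so the
# sampler draws every class with approximately equal probability in expectation; with it
# disabled, all weights are 1.0 and sampling reduces to ordinary uniform shuffling.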
source_train_dl = DataLoader(dataset=source_train_ds, batch_size=args.data.dataloader.batch_size,
                             sampler=sampler, num_workers=args.data.dataloader.data_workers, drop_last=True)
source_test_dl = DataLoader(dataset=source_test_ds, batch_size=args.data.dataloader.batch_size, shuffle=False,
                            num_workers=1, drop_last=False)
target_train_dl = DataLoader(dataset=target_train_ds, batch_size=args.data.dataloader.batch_size, shuffle=True,
                             num_workers=args.data.dataloader.data_workers, drop_last=True)
target_test_dl = DataLoader(dataset=target_test_ds, batch_size=args.data.dataloader.batch_size, shuffle=False,
                            num_workers=1, drop_last=False)
--------------------------------------------------------------------------------
/SFDA_test.py:
--------------------------------------------------------------------------------
from data import *
from net import *
from lib import *
from torch import optim
from APM_update import *
import torch.backends.cudnn as cudnn
import time

cudnn.benchmark = True   # note: benchmark=True trades a little determinism for speed
cudnn.deterministic = True

def seed_everything(seed=1234):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    import os
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything()

save_model_path = 'pretrained_weights/' + str(args.data.dataset.source) + str(args.data.dataset.target) + '/' + 'domain' + str(args.data.dataset.source) + str(args.data.dataset.target) + 'accBEST_model_checkpoint.pth.tar'
save_model_statedict = torch.load(save_model_path)['state_dict']

model_dict = {
    'resnet50': ResNet50Fc,
    'vgg16': VGG16Fc
}

# ======= network architecture =======
class Target_TrainableNet(nn.Module):
    def __init__(self):
        super(Target_TrainableNet, self).__init__()
        self.feature_extractor = model_dict[args.model.base_model](args.model.pretrained_model)
        classifier_output_dim = len(source_classes)
        self.classifier = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)
        self.cls_multibranch = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)


# ======= target network =======
trainable_targetNet = Target_TrainableNet()
trainable_targetNet.load_state_dict(save_model_statedict)

feature_extractor_t = (trainable_targetNet.feature_extractor).cuda()
classifier_s2t = (trainable_targetNet.classifier).cuda()
classifier_t = (trainable_targetNet.cls_multibranch).cuda()
print("Finished loading the model...")


domains = dataset.domains  # e.g., ['amazon', 'dslr', 'webcam'] for Office31
print('domain....' + domains[args.data.dataset.source] + '>>>>>>' + domains[args.data.dataset.target])

counter = AccuracyCounter()
with TrainingModeManager([feature_extractor_t, classifier_t], train=False) as mgr, torch.no_grad():

    for i, (img, label) in enumerate(target_test_dl):
        img = img.cuda()
        label = label.cuda()

        feature = feature_extractor_t.forward(img)
        _, _, before_softmax, predict_prob = classifier_t.forward(feature)

        counter.addOneBatch(variable_to_numpy(predict_prob), variable_to_numpy(one_hot(label, args.data.dataset.n_total)))

acc_test = counter.reportAccuracy()
print('>>>>>>> Test Accuracy >>>>>>>')
print(acc_test)
print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')

exit()
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os
import yaml
import easydict
from os.path import join


class Dataset:
    def __init__(self, path, domains, files, prefix):
        self.path = path
        self.prefix = prefix
        self.domains = domains
        self.files = [join(path, file) for file in files]
        self.prefixes = [self.prefix] * len(self.domains)


import argparse
parser = argparse.ArgumentParser(description='Code for *SFDA: Domain Adaptation without Source Data*',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--config', type=str, default='config.yaml', help='/path/to/config/file')

args = parser.parse_args()

config_file = args.config

# the Loader must be passed explicitly with PyYAML >= 5.1
args = yaml.load(open(config_file), Loader=yaml.FullLoader)

save_config = yaml.load(open(config_file), Loader=yaml.FullLoader)

args = easydict.EasyDict(args)

dataset = None
if args.data.dataset.name == 'office':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['amazon', 'dslr', 'webcam'],
        files=[
            'amazon_31_list.txt',
            'dslr_31_list.txt',
            'webcam_31_list.txt'
        ],
        prefix=args.data.dataset.root_path)
elif args.data.dataset.name == 'officehome':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['Art', 'Clipart', 'Product', 'Real_World'],
        files=[
            'Art.txt',
            'Clipart.txt',
            'Product.txt',
            'Real_World.txt'
        ],
        prefix=args.data.dataset.root_path)
elif args.data.dataset.name == 'visda':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['simulation', 'real'],
        files=[
            'simulation_image_list.txt',
            'real_image_list.txt'
        ],
        prefix=args.data.dataset.root_path)
elif args.data.dataset.name == 'C2I':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['Caltech', 'ImageNet'],
        files=[
            'C2I_caltech256.txt',
            'C2I_imagenet84.txt'
        ],
        prefix=args.data.dataset.root_path)
elif args.data.dataset.name == 'I2C':
    dataset = Dataset(
        path=args.data.dataset.root_path,
        domains=['ImageNet', 'Caltech'],
        files=[
            'I2C_imagenet1000.txt',
            'I2C_caltech84.txt'
        ],
        prefix=args.data.dataset.root_path)
else:
    raise Exception(f'dataset {args.data.dataset.name} not supported!')

source_domain_name = dataset.domains[args.data.dataset.source]
target_domain_name = dataset.domains[args.data.dataset.target]
source_file = dataset.files[args.data.dataset.source]
target_file = dataset.files[args.data.dataset.target]
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("source_domain_name :", source_domain_name, "target_domain_name :", target_domain_name)
--------------------------------------------------------------------------------
/lib.py:
--------------------------------------------------------------------------------
from easydl import *
import torch.nn.functional as F

def reverse_sigmoid(y):
    # inverse sigmoid (logit function): recovers the logit from a probability
    return torch.log(y / (1.0 - y + 1e-10) + 1e-10)


def get_source_share_weight(domain_out, before_softmax, domain_temperature=1.0, class_temperature=10.0):
    before_softmax = before_softmax / class_temperature
    after_softmax = nn.Softmax(-1)(before_softmax)
    domain_logit = reverse_sigmoid(domain_out)  # recover the domain logit so it can be re-scaled by its own temperature
    domain_logit = domain_logit / domain_temperature
    domain_out = nn.Sigmoid()(domain_logit)

    entropy = torch.sum(- after_softmax * torch.log(after_softmax + 1e-10), dim=1, keepdim=True)
    entropy_norm = entropy / np.log(after_softmax.size(1))  # divide by log(#classes) so the entropy lies in [0, 1]
    weight = entropy_norm - domain_out
    weight = weight.detach()
    return weight


def get_source_share_weight_onlyentropy(before_softmax, class_temperature=10.0):
    before_softmax = before_softmax / class_temperature
    after_softmax = nn.Softmax(-1)(before_softmax)

    entropy = torch.sum(- after_softmax * torch.log(after_softmax + 1e-5), dim=1, keepdim=True)
    entropy_norm = entropy / np.log(after_softmax.size(1))
    weight = entropy_norm
    weight = weight.detach()
    return weight

def hellinger_distance(p, q):
    # H(p, q) = ||sqrt(p) - sqrt(q)||_2 / sqrt(2), bounded in [0, 1]
    return torch.norm((torch.sqrt(p) - torch.sqrt(q)), p=2, dim=1) / np.sqrt(2)

def get_commonness_weight(ps_s, pt_s, ps_t, pt_t, class_temperature=10.0):
    # the softmax dim is given explicitly; relying on the implicit dim is deprecated in PyTorch
    ps_s = F.softmax(ps_s / class_temperature, dim=-1)
    pt_s = F.softmax(pt_s / class_temperature, dim=-1)
    ps_t = F.softmax(ps_t, dim=-1)
    pt_t = F.softmax(pt_t, dim=-1)

    ws = hellinger_distance(ps_s, pt_s).detach()
    wt = hellinger_distance(ps_t, pt_t).detach()

    return ws, wt


def get_entropy(domain_out, before_softmax, domain_temperature=1.0, class_temperature=10.0):
    before_softmax = before_softmax / class_temperature
    after_softmax = nn.Softmax(-1)(before_softmax)
    domain_logit = reverse_sigmoid(domain_out)
    domain_logit = domain_logit / domain_temperature
    domain_out = nn.Sigmoid()(domain_logit)

    entropy = torch.sum(- after_softmax * torch.log(after_softmax + 1e-10), dim=1, keepdim=True)
    entropy_norm = entropy / np.log(after_softmax.size(1))
    weight = entropy_norm
    weight = weight.detach()
    return weight


def get_target_share_weight(domain_out, before_softmax, domain_temperature=1.0, class_temperature=10.0):
    return - get_source_share_weight(domain_out, before_softmax, domain_temperature, class_temperature)


def normalize_weight(x):
    min_val = x.min()
    max_val = x.max()
    x = (x - min_val) / (max_val - min_val + 1e-5)
    x = x / (torch.mean(x) + 1e-5)  # rescale to mean ~1 so weighting does not change the overall loss scale
    return x.detach()


def normalize_weight01(x):
    min_val = x.min()
    max_val = x.max()
    x = (x - min_val) / (max_val - min_val + 1e-5)
    return x.detach()

def normalize_weight_11(x):
    min_val = x.min()
    max_val = x.max()
    x = (x - min_val) / (max_val - min_val + 1e-5)
    x = x * 2 - 1
    return x.detach()

def seed_everything(seed=1234):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    import os
    os.environ['PYTHONHASHSEED'] = str(seed)


def tensor_l2normalization(q):
    qn = torch.norm(q, p=2, dim=1).detach().unsqueeze(1)
    q = q.div(qn.expand_as(q))
    return q
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SFDA - Domain Adaptation without Source Data

## Prerequisites
* Ubuntu 18.04
* Python 3.6+
* PyTorch 1.5+ (a recent version is recommended)
* NVIDIA GPU (>= 12GB)
* CUDA 10.0 (optional)
* CUDNN 7.5 (optional)

## Getting Started

### Installation
* Configure a virtual (anaconda) environment
```
conda create -n env_name python=3.6
source activate env_name
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
```
* Install the required Python libraries
```
conda install -c conda-forge matplotlib
conda install -c anaconda yaml
conda install -c anaconda pyyaml
conda install -c anaconda scipy
conda install -c anaconda scikit-learn
conda install -c conda-forge easydict
pip install easydl
```

### Download this repository
* We provide two versions of the repository (with and without the dataset) for flexible experimentation

* Full SFDA repository (with dataset): [download link][aa]
	* In this case, go directly to the ```Training and testing``` step

[aa]: https://drive.google.com/drive/folders/11g8yOWxIG47G-5vImtX98qrg0Y4UxrGd?usp=sharing

* SFDA repository (without dataset): [download link][a]

[a]: https://drive.google.com/drive/folders/1ndxbQLAkDxxvlPs7E65_6fQ4dNbxXkHR?usp=sharing


* Visualization of the repository structure (full SFDA repository)

```
|-- APM_update.py
|-- SFDA_test.py
|-- SFDA_train.py
|-- config.py
|-- data.py
|-- lib.py
|-- net.py
|-- office-train-config.yaml
|-- data
|   `-- office
|       |-- domain_adaptation_images
|       |   |-- amazon
|       |   |   `-- images
|       |   |-- dslr
|       |   |   `-- images
|       |   `-- webcam
|       |       `-- images
|       |-- amazon_31_list.txt
|       |-- dslr_31_list.txt
|       `-- webcam_31_list.txt
|-- pretrained_weights
|   `-- 02
|       `-- domain02accBEST_model_checkpoint.pth.tar
`-- source_pretrained_weights
    `-- 02
        `-- model_checkpoint.pth.tar
```

### Download dataset
* Download the Office31 dataset ([link][b]) and unzip it into ```./data/office```

[b]: https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view

* Download the image list files ([link][c]), i.e., amazon_31_list.txt, dslr_31_list.txt, and webcam_31_list.txt, into ```./data/office```

[c]: https://drive.google.com/drive/folders/11wFsBoG--cm7uD0L-7L5X5hprWDCMBpH?usp=sharing


### Download source-pretrained parameters (Fs and Cs of Figure 2 in our main paper)
* Download the source-pretrained parameters ([link][d]) into ```./source_pretrained_weights/[scenario_number]```

[d]: https://drive.google.com/drive/folders/1mkzEl8SHQ0mVFnYV0CvZIdeLstCm2shy?usp=sharing

ex) The source-pretrained parameters of the A[0] -> W[2] scenario should be located in ```./source_pretrained_weights/02```

## Training and testing

* The arguments required for training and testing are contained in ```office-train-config.yaml```
* Here is an example of running an experiment on Office31 (default: A -> W)
* The scenario can be changed by editing ```source: 0, target: 2``` in ```office-train-config.yaml``` (see the example below)
* We will update the full version of our framework, including settings for ```OfficeHome``` and ```Visda-C```
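* For example, assuming the Office31 domain order defined in ```config.py``` (0: amazon, 1: dslr, 2: webcam), a D -> A experiment would edit the dataset entry as follows (a sketch of the change, not a full config):

```
dataset: {n_share: 31, n_total: 31, name: office, root_path: ../data/office,
  source: 1, target: 0}
```

* Following the folder-naming convention above, the matching weights for that scenario would then live in ```./source_pretrained_weights/10``` and ```./pretrained_weights/10```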
### Training

* Run the following command

```
python SFDA_train.py --config office-train-config.yaml
```

### Testing (on pretrained model)

* As a first step, download the SFDA pretrained parameters ([link][e]) into ```./pretrained_weights/[scenario_number]```

ex) The SFDA pretrained parameters of the A[0] -> W[2] scenario should be located in ```./pretrained_weights/02```

[e]: https://drive.google.com/drive/folders/1XiWZXsES_oEAI2WMdOBxqjKieA7zOOwZ?usp=sharing

* Or run the training code to obtain the pretrained weights yourself

* Run the following command

```
python SFDA_test.py --config office-train-config.yaml
```


## Experimental results on Office31

* Results using the provided code

|             |   A→W    |   D→W    |   W→D    |   A→D    |   D→A    |   W→A    |   Avg    |
|:------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
|Accuracy (%) |  91.06   |  97.35   |  98.99   |  91.96   |  71.60   |  68.62   |
86.60 |
--------------------------------------------------------------------------------
/APM_update.py:
--------------------------------------------------------------------------------
from data import *
from lib import *
import time

def APM_init_update(feature_extractor, classifier_t):
    start_time = time.time()
    available_cls = []
    h_dict = {}      # per-class entropies of pseudo-labeled target samples
    feat_dict = {}   # per-class features of pseudo-labeled target samples
    missing_cls = []
    after_softmax_numpy_for_emergency = []
    feature_numpy_for_emergency = []
    max_prototype_bound = 100

    for cls in range(len(source_classes)):
        h_dict[cls] = []
        feat_dict[cls] = []

    for (im_target_lbcorr, label_target_lbcorr) in target_train_dl:
        im_target_lbcorr = im_target_lbcorr.cuda()
        fc1_lbcorr = feature_extractor.forward(im_target_lbcorr)
        _, _, _, after_softmax = classifier_t.forward(fc1_lbcorr)
        after_softmax_numpy_for_emergency.append(after_softmax.data.cpu().numpy())
        feature_numpy_for_emergency.append(fc1_lbcorr.data.cpu().numpy())

        pseudo_label = torch.argmax(after_softmax, dim=1)
        pseudo_label = pseudo_label.cpu()

        entropy = torch.sum(- after_softmax * torch.log(after_softmax + 1e-10), dim=1, keepdim=True)
        entropy_norm = entropy / np.log(after_softmax.size(1))
        entropy_norm = entropy_norm.squeeze(1)
        entropy_norm = entropy_norm.cpu()

        for cls in range(len(source_classes)):
            # stack the entropies (H) and features for each pseudo-class
            cls_filter = (pseudo_label == cls)
            list_loc = (torch.where(cls_filter == 1))[0]
            if len(list_loc) == 0:
                missing_cls.append(cls)
                continue
            available_cls.append(cls)
            filtered_ent = torch.gather(entropy_norm, dim=0, index=list_loc)
            filtered_feat = torch.gather(fc1_lbcorr.cpu(), dim=0, index=list_loc.unsqueeze(1).repeat(1, 2048))  # 2048 = ResNet-50 feature dim

            h_dict[cls].append(filtered_ent.cpu().data.numpy())
            feat_dict[cls].append(filtered_feat.cpu().data.numpy())

    available_cls = np.unique(available_cls)

    prototype_memory = []
    prototype_memory_dict = {}
    after_softmax_numpy_for_emergency = np.concatenate(after_softmax_numpy_for_emergency, axis=0)
    feature_numpy_for_emergency = np.concatenate(feature_numpy_for_emergency, axis=0)

    # the reliability threshold is the largest per-class minimum entropy, so every
    # available class contributes at least one sample below the threshold
    max_top1_ent = 0
    for cls in available_cls:
        ents_np = np.concatenate(h_dict[cls], axis=0)
        ent_idxs = np.argsort(ents_np)
        top1_ent = ents_np[ent_idxs[0]]
        if max_top1_ent < top1_ent:
            max_top1_ent = top1_ent
            max_top1_class = cls

    class_prototypeNum_dict = {}
    max_prototype = 0

    for cls in available_cls:
        ents_np = np.concatenate(h_dict[cls], axis=0)
        ents_np_filtered = (ents_np <= max_top1_ent)
        class_prototypeNum_dict[cls] = ents_np_filtered.sum()

        if max_prototype < ents_np_filtered.sum():
            max_prototype = ents_np_filtered.sum()

    if max_prototype > max_prototype_bound:
        max_prototype = max_prototype_bound

    # build a fixed-size memory: every class stores exactly max_prototype feature vectors.
    # available classes replicate their low-entropy samples up to that size; classes with no
    # pseudo-labeled sample fall back to their single most confident feature (emergency case)
    for cls in range(len(source_classes)):

        if cls in available_cls:
            ents_np = np.concatenate(h_dict[cls], axis=0)
            feats_np = np.concatenate(feat_dict[cls], axis=0)
            ent_idxs = np.argsort(ents_np)

            truncated_feat = feats_np[ent_idxs[:class_prototypeNum_dict[cls]]]
            fit_to_max_prototype = np.concatenate([truncated_feat] * (int(max_prototype / truncated_feat.shape[0]) + 1),
                                                  axis=0)
            fit_to_max_prototype = fit_to_max_prototype[:max_prototype, :]

            prototype_memory.append(fit_to_max_prototype)
            prototype_memory_dict[cls] = fit_to_max_prototype
        else:
            after_softmax_torch_for_emergency = torch.Tensor(after_softmax_numpy_for_emergency)
            # sort the samples (dim=0) by their predicted probability for this class and
            # take the most confident one
            emergency_idx = torch.argsort(after_softmax_torch_for_emergency, descending=True, dim=0)
            cls_emergency_idx = emergency_idx[:, cls]
            cls_emergency_idx = cls_emergency_idx[0]
            cls_emergency_idx_numpy = cls_emergency_idx.data.numpy()

            copied_features_emergency = np.concatenate(
                [np.expand_dims(feature_numpy_for_emergency[cls_emergency_idx_numpy], axis=0)] * max_prototype, axis=0)

            prototype_memory.append(copied_features_emergency)
            prototype_memory_dict[cls] = copied_features_emergency

    print("** APM update... time:", time.time() - start_time)
    prototype_memory = np.concatenate(prototype_memory, axis=0)
    num_prototype_ = int(max_prototype)

    return prototype_memory, num_prototype_, prototype_memory_dict
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
from easydl import *
from torchvision import models
import torch.nn.functional as F
import torch.nn as nn
import os

class BaseFeatureExtractor(nn.Module):
    def forward(self, *input):
        pass

    def __init__(self):
        super(BaseFeatureExtractor, self).__init__()

    def output_num(self):
        pass

    def train(self, mode=True):
        # freeze BN mean and std
        for module in self.children():
            if isinstance(module, nn.BatchNorm2d):
                module.train(False)
            else:
                module.train(mode)


class ResNet50Fc(BaseFeatureExtractor):
    def __init__(self, model_path=None, normalize=True):
        super(ResNet50Fc, self).__init__()
        if model_path:
            if os.path.exists(model_path):
                self.model_resnet = models.resnet50(pretrained=False)
                self.model_resnet.load_state_dict(torch.load(model_path))
            else:
                raise Exception('invalid model path!')
        else:
            self.model_resnet = models.resnet50(pretrained=False)

        if model_path or normalize:
            # ImageNet-style input normalization (the task weights are loaded from a checkpoint later)
            self.normalize = True
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        else:
            self.normalize = False

        model_resnet = self.model_resnet
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.__in_features = model_resnet.fc.in_features

    def forward(self, x):
        if self.normalize:
            x = (x - self.mean) / self.std
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x

    def output_num(self):
        return self.__in_features


class VGG16Fc(BaseFeatureExtractor):
    def __init__(self, model_path=None, normalize=True):
        super(VGG16Fc, self).__init__()
        if model_path:
            if os.path.exists(model_path):
                self.model_vgg = models.vgg16(pretrained=False)
                self.model_vgg.load_state_dict(torch.load(model_path))
            else:
                raise Exception('invalid model path!')
        else:
            self.model_vgg = models.vgg16(pretrained=True)

        if model_path or normalize:
            # a pretrained model is used, so apply ImageNet normalization
            self.normalize = True
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        else:
            self.normalize = False

        model_vgg = self.model_vgg
        self.features = model_vgg.features
        self.classifier = nn.Sequential()
        for i in range(6):
            self.classifier.add_module("classifier" + str(i), model_vgg.classifier[i])
        self.feature_layers = nn.Sequential(self.features, self.classifier)

        self.__in_features = 4096

    def forward(self, x):
        if self.normalize:
            x = (x - self.mean) / self.std
        x = self.features(x)
        x = x.view(x.size(0), 25088)
        x = self.classifier(x)
        return x

    def output_num(self):
        return self.__in_features
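# note: CLS.forward below returns the full pipeline of intermediate results,
# [input_feature, bottleneck_feature, logits, softmax_probabilities], which is why
# callers throughout this repo unpack it as
# `_, _, before_softmax, predict_prob = classifier(feature)`.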

class CLS(nn.Module):

    def __init__(self, in_dim, out_dim, bottle_neck_dim=256, pretrain=False):
        super(CLS, self).__init__()
        self.pretrain = pretrain
        if bottle_neck_dim:
            self.bottleneck = nn.Linear(in_dim, bottle_neck_dim)
            self.fc = nn.Linear(bottle_neck_dim, out_dim)
            self.main = nn.Sequential(self.bottleneck, self.fc, nn.Softmax(dim=-1))
        else:
            self.fc = nn.Linear(in_dim, out_dim)
            self.main = nn.Sequential(self.fc, nn.Softmax(dim=-1))

    def forward(self, x):
        out = [x]
        for module in self.main.children():
            x = module(x)
            out.append(x)
        return out


class CLS_copy(nn.Module):

    def __init__(self, in_dim, out_dim, bottle_neck_dim=256, pretrain=False):
        super(CLS_copy, self).__init__()
        self.pretrain = pretrain
        if bottle_neck_dim:
            self.bottleneck = nn.Linear(in_dim, bottle_neck_dim)
            self.fc = nn.Linear(bottle_neck_dim, out_dim)
            self.main = nn.Sequential(self.bottleneck, self.fc)
        else:
            self.fc = nn.Linear(in_dim, out_dim)
            self.main = nn.Sequential(self.fc)

    def forward(self, x):
        for module in self.main.children():
            x = module(x)
        return x


class AdversarialNetwork(nn.Module):

    def __init__(self, in_feature):
        super(AdversarialNetwork, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_feature, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
        # gradient reversal layer; aToBSheduler (sic) is easydl's scheduler helper
        self.grl = GradientReverseModule(lambda step: aToBSheduler(step, 0.0, 1.0, gamma=10, max_iter=10000))

    def forward(self, x):
        x_ = self.grl(x)
        y = self.main(x_)
        return y
--------------------------------------------------------------------------------
/SFDA_train.py:
--------------------------------------------------------------------------------
from data import *
from net import *
from lib import *
from torch import optim
from APM_update import *
import torch.backends.cudnn as cudnn
import time

cudnn.benchmark = True   # note: benchmark=True trades a little determinism for speed
cudnn.deterministic = True

def seed_everything(seed=1234):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    import os
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything()

save_model_path = 'source_pretrained_weights/' + str(args.data.dataset.source) + str(args.data.dataset.target) + '/' + 'model_checkpoint.pth.tar'
save_model_statedict = torch.load(save_model_path)['state_dict']

model_dict = {
    'resnet50': ResNet50Fc,
    'vgg16': VGG16Fc
}


# ======= network architecture =======
class Source_FixedNet(nn.Module):
    def __init__(self):
        super(Source_FixedNet, self).__init__()
        self.feature_extractor = model_dict[args.model.base_model](args.model.pretrained_model)
        classifier_output_dim = len(source_classes)
        self.classifier = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)

class Target_TrainableNet(nn.Module):
    def __init__(self):
        super(Target_TrainableNet, self).__init__()
        self.feature_extractor = model_dict[args.model.base_model](args.model.pretrained_model)
        classifier_output_dim = len(source_classes)
        self.classifier = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)
        self.cls_multibranch = CLS(self.feature_extractor.output_num(), classifier_output_dim, bottle_neck_dim=256)


# ======= pre-trained source network (frozen) =======
fixed_sourceNet = Source_FixedNet()
fixed_sourceNet.load_state_dict(save_model_statedict)
fixed_feature_extractor_s = (fixed_sourceNet.feature_extractor).cuda()
fixed_classifier_s = (fixed_sourceNet.classifier).cuda()
fixed_feature_extractor_s.eval()
fixed_classifier_s.eval()

# ======= trainable target network (initialized from the source weights) =======
trainable_targetNet = Target_TrainableNet()
feature_extractor_t = (trainable_targetNet.feature_extractor).cuda()
feature_extractor_t.load_state_dict(fixed_sourceNet.feature_extractor.state_dict())
classifier_s2t = (trainable_targetNet.classifier).cuda()
classifier_s2t.load_state_dict(fixed_sourceNet.classifier.state_dict())
classifier_t = (trainable_targetNet.cls_multibranch).cuda()
classifier_t.load_state_dict(fixed_sourceNet.classifier.state_dict())


best_model_dict = {
    'global_step': 0,
    'state_dict': trainable_targetNet.state_dict(),
    'accuracy': 0}


feature_extractor_t.train()
classifier_s2t.train()
classifier_t.train()
print("Finished loading the models...")

domains = dataset.domains  # e.g., ['amazon', 'dslr', 'webcam'] for Office31
print('domain....' + domains[args.data.dataset.source] + '>>>>>>' + domains[args.data.dataset.target])

scheduler = lambda step, initial_lr: inverseDecaySheduler(step, initial_lr, gamma=10, power=0.75, max_iter=(args.train.min_step))

optimizer_finetune = OptimWithSheduler(
    optim.SGD(feature_extractor_t.parameters(), lr=args.train.lr / 10.0, weight_decay=args.train.weight_decay, momentum=args.train.momentum, nesterov=True),
    scheduler)
optimizer_classifier_s2t = OptimWithSheduler(
    optim.SGD(classifier_s2t.parameters(), lr=args.train.lr, weight_decay=args.train.weight_decay, momentum=args.train.momentum, nesterov=True),
    scheduler)
optimizer_classifier_t = OptimWithSheduler(
    optim.SGD(classifier_t.parameters(), lr=args.train.lr, weight_decay=args.train.weight_decay, momentum=args.train.momentum, nesterov=True),
    scheduler)
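# Assuming easydl's inverseDecaySheduler follows the common DANN-style decay, the
# learning rate is roughly lr(step) = initial_lr * (1 + gamma * step / max_iter) ** (-power);
# the backbone is fine-tuned at lr / 10 while both classifiers use the full lr.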

global_step = 0
best_acc = 0
epoch_id = 0
class_num = args.data.dataset.n_total
pt_memory_update_frequency = args.train.update_freq


while global_step < args.train.min_step:

    epoch_id += 1

    for i, (img_target, label_target) in enumerate(target_train_dl):

        # APM init/update
        if global_step % pt_memory_update_frequency == 0:
            prototype_memory, num_prototype_, prototype_memory_dict = APM_init_update(feature_extractor_t, classifier_t)

        img_target = img_target.cuda()

        # forward pass: source-pretrained network
        fixed_fc1_s = fixed_feature_extractor_s.forward(img_target)
        _, _, _, logit_s = fixed_classifier_s.forward(fixed_fc1_s)
        pseudo_label_s = torch.argmax(logit_s, dim=1)

        # forward pass: target network
        fc1_t = feature_extractor_t.forward(img_target)
        _, _, logit_s2t, _ = classifier_s2t.forward(fc1_t)
        _, _, logit_t, _ = classifier_t(fc1_t)

        # compute pseudo labels from the APM: average cosine similarity of each sample to
        # every class's prototypes (avg_pool1d pools the (#class * P) columns class-wise)
        proto_feat_tensor = torch.Tensor(prototype_memory)  # (#class * P, 2048)
        feature_embed_tensor = fc1_t.cpu()
        proto_feat_tensor = tensor_l2normalization(proto_feat_tensor)
        batch_feat_tensor = tensor_l2normalization(feature_embed_tensor)

        sim_mat = torch.mm(batch_feat_tensor, proto_feat_tensor.permute(1, 0))
        sim_mat = F.avg_pool1d(sim_mat.unsqueeze(0), kernel_size=num_prototype_, stride=num_prototype_).squeeze(0)  # (B, #class)
        pseudo_label_t = torch.argmax(sim_mat, dim=1).cuda()

        # confidence-based filtering: keep a sample only if its farthest first-choice
        # prototype is still closer than its nearest second-choice prototype
        arg_idxs = torch.argsort(sim_mat, dim=1, descending=True)  # (B, #class)

        first_group_idx = arg_idxs[:, 0]
        second_group_idx = arg_idxs[:, 1]

        first_group_feat = [prototype_memory_dict[int(x.data.numpy())] for x in first_group_idx]
        first_group_feat_tensor = torch.tensor(np.concatenate(first_group_feat, axis=0))  # (B * P, 2048)
        first_group_feat_tensor = tensor_l2normalization(first_group_feat_tensor)

        second_group_feat = [prototype_memory_dict[int(x.data.numpy())] for x in second_group_idx]
        second_group_feat_tensor = torch.tensor(np.concatenate(second_group_feat, axis=0))  # (B * P, 2048)
        second_group_feat_tensor = tensor_l2normalization(second_group_feat_tensor)

        feature_embed_tensor_repeat = torch.Tensor(np.repeat(feature_embed_tensor.cpu().data.numpy(), repeats=num_prototype_, axis=0))
        feature_embed_tensor_repeat = tensor_l2normalization(feature_embed_tensor_repeat)

        first_dist_mat = 1 - torch.mm(first_group_feat_tensor, feature_embed_tensor_repeat.permute(1, 0))  # distance = 1 - similarity
        second_dist_mat = 1 - torch.mm(second_group_feat_tensor, feature_embed_tensor_repeat.permute(1, 0))

        first_dist_mat = F.max_pool2d(first_dist_mat.permute(1, 0).unsqueeze(0).unsqueeze(0), kernel_size=num_prototype_, stride=num_prototype_).squeeze(0).squeeze(0)  # (B, #class)
        second_dist_mat = -1 * F.max_pool2d(-1 * second_dist_mat.permute(1, 0).unsqueeze(0).unsqueeze(0), kernel_size=num_prototype_, stride=num_prototype_).squeeze(0).squeeze(0)  # (B, #class)

        first_dist_vec = torch.diag(first_dist_mat)    # (B,) max distance to the best class's prototypes
        second_dist_vec = torch.diag(second_dist_mat)  # (B,) min distance to the runner-up class's prototypes

        confidence_mask = ((first_dist_vec - second_dist_vec) < 0).float().cuda()
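        # Curriculum between the two supervision signals: alpha below ramps smoothly from
        # 0 to 1 as alpha = 2 / (1 + exp(-10 * step / (min_step / 2))) - 1, so early training
        # trusts the frozen source classifier's pseudo labels (ce_from_s2t) and later training
        # trusts the APM-based pseudo labels (ce_from_t), with unconfident samples masked out.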
        # optimize the target network using the two types of pseudo labels
        ce_from_s2t = nn.CrossEntropyLoss()(logit_s2t, pseudo_label_s)
        ce_from_t = nn.CrossEntropyLoss(reduction='none')(logit_t, pseudo_label_t).view(-1, 1).squeeze(1)
        ce_from_t = torch.mean(ce_from_t * confidence_mask, dim=0, keepdim=True)

        alpha = float(2.0 / (1.0 + np.exp(-10 * global_step / float(args.train.min_step // 2))) - 1.0)
        ce_total = (1 - alpha) * ce_from_s2t + alpha * ce_from_t

        with OptimizerManager([optimizer_finetune, optimizer_classifier_s2t, optimizer_classifier_t]):
            loss = ce_total
            loss.backward()

        global_step += 1

        # evaluation during training
        if global_step % args.test.test_interval == 0:

            counter = AccuracyCounter()
            with TrainingModeManager([feature_extractor_t, classifier_t], train=False) as mgr, torch.no_grad():

                for i, (img, label) in enumerate(target_test_dl):
                    img = img.cuda()
                    label = label.cuda()

                    feature = feature_extractor_t.forward(img)
                    _, _, _, predict_prob_t = classifier_t.forward(feature)

                    counter.addOneBatch(variable_to_numpy(predict_prob_t), variable_to_numpy(one_hot(label, args.data.dataset.n_total)))

            acc_test = counter.reportAccuracy()
            print('>>>>>>>>>>> accuracy >>>>>>>>>>>')
            print(acc_test)
            print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
            if best_acc < acc_test:
                best_acc = acc_test
                best_model_dict = {
                    'global_step': global_step + 1,
                    'state_dict': trainable_targetNet.state_dict(),
                    'accuracy': acc_test}

                torch.save(best_model_dict, join('pretrained_weights/' + str(args.data.dataset.source) + str(args.data.dataset.target) + '/' + 'domain' + str(args.data.dataset.source) + str(args.data.dataset.target) + 'accBEST_model_checkpoint.pth.tar'))


exit()
--------------------------------------------------------------------------------