├── models ├── __init__.py └── models.py ├── datasets ├── __init__.py ├── get_office-home.sh ├── download_domain_net.sh ├── get_office.sh ├── gdrive.sh ├── preprocess_office31.py ├── preprocess_office-home.py └── datasets.py ├── images ├── overview.gif └── results_office.png ├── train_test ├── __init__.py ├── test.py └── train.py ├── utils └── utils.py ├── LICENSE ├── requirements.txt ├── README.md ├── visualize.py ├── main.py └── active_learning.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.models import * 2 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.datasets import * 2 | -------------------------------------------------------------------------------- /images/overview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/s3vaada/HEAD/images/overview.gif -------------------------------------------------------------------------------- /images/results_office.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/val-iisc/s3vaada/HEAD/images/results_office.png -------------------------------------------------------------------------------- /train_test/__init__.py: -------------------------------------------------------------------------------- 1 | from train_test.train import * 2 | from train_test.test import * 3 | 4 | -------------------------------------------------------------------------------- /datasets/get_office-home.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "data" ]; then 4 | mkdir "data" 5 | fi 6 | 7 | if [ ! -d "data/Office-Home" ]; then 8 | mkdir "data/Office-Home" 9 | fi 10 | 11 | 12 | bash gdrive.sh "https://drive.google.com/uc?export=download&id=0B81rNlvomiwed0V1YUxQdC1uOTg" tmp.zip 13 | unzip tmp.zip -d data/Office-Home 14 | rm tmp.zip -------------------------------------------------------------------------------- /datasets/download_domain_net.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | mkdir data/multi 3 | cd data/multi 4 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip -O real.zip 5 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip -O sketch.zip 6 | wget http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip -O clipart.zip 7 | unzip real.zip 8 | unzip sketch.zip 9 | unzip clipart.zip -------------------------------------------------------------------------------- /datasets/get_office.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "data" ]; then 4 | mkdir "data" 5 | fi 6 | 7 | if [ ! -d "data/Office31" ]; then 8 | mkdir "data/Office31" 9 | fi 10 | 11 | if [ ! -d "office31" ]; then 12 | mkdir "office31" 13 | fi 14 | 15 | bash gdrive.sh "https://drive.google.com/uc?export=download&id=0B4IapRTv9pJ1WGZVd1VDMmhwdlE" tmp.tar.gz 16 | tar -xvzf tmp.tar.gz -C data/Office31 17 | rm tmp.tar.gz 18 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def optimizer_scheduler(optimizer, p): 4 | """ 5 | Adjust the learning rate of optimizer 6 | - optimizer: optimizer for updating parameters 7 | - p: a variable for adjusting learning rate 8 | return: optimizer 9 | """ 10 | for param_group in optimizer.param_groups: 11 | if "i_lr" not in param_group: 12 | param_group["i_lr"] = param_group["lr"] 13 | param_group['lr'] = param_group["i_lr"] / (1. + 10 * p) ** 0.75 14 | 15 | return optimizer 16 | 17 | def sigmoid(parameter, plasticity): 18 | return 1. / (1 + torch.exp((-plasticity * parameter))) 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Video Analytics Lab -- IISc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /datasets/gdrive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Download large files from Google Drive without having to manually confirm 3 | # Author: https://www.matthuisman.nz/2019/01/download-google-drive-files-wget-curl.html 4 | 5 | url=$1 6 | filename=$2 7 | 8 | [ -z "$url" ] && echo A URL or ID is required first argument && exit 1 9 | 10 | fileid="" 11 | declare -a patterns=("s/.*\/file\/d\/\(.*\)\/.*/\1/p" "s/.*id\=\(.*\)/\1/p" "s/\(.*\)/\1/p") 12 | for i in "${patterns[@]}" 13 | do 14 | fileid=$(echo $url | sed -n $i) 15 | [ ! -z "$fileid" ] && break 16 | done 17 | 18 | [ -z "$fileid" ] && echo Could not find Google ID && exit 1 19 | 20 | echo File ID: $fileid 21 | 22 | tmp_file="$filename.$$.file" 23 | tmp_cookies="$filename.$$.cookies" 24 | tmp_headers="$filename.$$.headers" 25 | 26 | url='https://docs.google.com/uc?export=download&id='$fileid 27 | echo Downloading: "$url > $tmp_file" 28 | wget --save-cookies "$tmp_cookies" -q -S -O - $url 2> "$tmp_headers" 1> "$tmp_file" 29 | 30 | if [[ ! $(find "$tmp_file" -type f -size +10000c 2>/dev/null) ]]; then 31 | confirm=$(cat "$tmp_file" | sed -En 's/.*confirm=([0-9A-Za-z_]+).*/\1/p') 32 | fi 33 | 34 | if [ ! -z "$confirm" ]; then 35 | url='https://docs.google.com/uc?export=download&id='$fileid'&confirm='$confirm 36 | echo Downloading: "$url > $tmp_file" 37 | wget --load-cookies "$tmp_cookies" -q -S -O - $url 2> "$tmp_headers" 1> "$tmp_file" 38 | fi 39 | 40 | [ -z "$filename" ] && filename=$(cat "$tmp_headers" | sed -En 's/.*filename=\"(.*)\".*/\1/p') 41 | [ -z "$filename" ] && filename="google_drive.file" 42 | 43 | echo Moving: "$tmp_file > $filename" 44 | 45 | mv "$tmp_file" "$filename" 46 | 47 | rm -f "$tmp_cookies" "$tmp_headers" 48 | 49 | echo Saved: "$filename" 50 | echo DONE! 51 | 52 | exit 0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | cachetools==4.2.0 3 | certifi==2020.6.20 4 | cffi==1.14.3 5 | chardet==3.0.4 6 | click==7.1.2 7 | cloudpickle==1.6.0 8 | cmake==3.18.2.post1 9 | configparser==5.0.0 10 | cycler==0.10.0 11 | Cython==0.29.21 12 | decorator==4.4.2 13 | docker-pycreds==0.4.0 14 | easydict==1.9 15 | future==0.18.2 16 | fvcore==0.1.2.post20210115 17 | gitdb==4.0.5 18 | GitPython==3.1.9 19 | google-auth==1.24.0 20 | google-auth-oauthlib==0.4.2 21 | grpcio==1.34.1 22 | idna==2.10 23 | imageio==2.9.0 24 | importlib-metadata==3.4.0 25 | iopath==0.1.2 26 | joblib==1.2.0 27 | kiwisolver==1.2.0 28 | kornia==0.4.0 29 | Markdown==3.3.3 30 | matplotlib==3.3.2 31 | MulticoreTSNE==0.1 32 | networkx==2.5 33 | numpy==1.22.0 34 | oauthlib==3.1.0 35 | opencv-contrib-python==4.2.0.34 36 | pandas==1.1.2 37 | pathtools==0.1.2 38 | Pillow==9.3.0 39 | plotly==4.11.0 40 | portalocker==2.0.0 41 | promise==2.3 42 | protobuf==3.18.3 43 | psutil==5.7.2 44 | pyasn1==0.4.8 45 | pyasn1-modules==0.2.8 46 | pycocotools==2.0.2 47 | pycparser==2.20 48 | pydot==1.4.1 49 | pyparsing==2.4.7 50 | python-dateutil==2.8.1 51 | pytz==2020.1 52 | PyWavelets==1.1.1 53 | PyYAML==5.4 54 | requests==2.24.0 55 | requests-oauthlib==1.3.0 56 | retrying==1.3.3 57 | rsa==4.7 58 | scikit-image==0.17.2 59 | scikit-learn==0.23.2 60 | scipy==1.5.2 61 | seaborn==0.11.0 62 | sentry-sdk==0.18.0 63 | shortuuid==1.0.1 64 | six==1.15.0 65 | smmap==3.0.4 66 | subprocess32==3.5.4 67 | tabulate==0.8.7 68 | tensorboard==2.4.1 69 | tensorboard-plugin-wit==1.7.0 70 | tensorboardX==2.1 71 | termcolor==1.1.0 72 | threadpoolctl==2.1.0 73 | tifffile==2020.9.3 74 | torch==1.6.0 75 | torchvision==0.7.0 76 | tqdm==4.50.0 77 | typing-extensions==3.7.4.3 78 | urllib3==1.26.5 79 | wandb==0.10.4 80 | watchdog==0.10.3 81 | Werkzeug==1.0.1 82 | yacs==0.1.8 83 | zipp==3.4.0 84 | -------------------------------------------------------------------------------- /datasets/preprocess_office31.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as numpy 3 | from skimage import io, transform 4 | import numpy as np 5 | 6 | weight = 256 7 | hight = 256 8 | data = None 9 | labels = None 10 | 11 | domains = ["amazon", "dslr", "webcam"] 12 | 13 | 14 | def process_image(image, weight, hight): 15 | img = io.imread(image) 16 | img = transform.resize(img, (weight, hight), mode="reflect") 17 | return img 18 | 19 | 20 | for d in domains: 21 | #path = "domain_adaptation_images/" + d + "/" "images/" 22 | path = "data/Office31/" + d + "/images/" 23 | print("processing " + path) 24 | 25 | for _, dirnames, _ in os.walk(path): 26 | dirnames.sort() 27 | for dirname in dirnames: 28 | index = dirnames.index(dirname) 29 | workdir = os.path.join(path, dirname) 30 | print(workdir) 31 | processed_images = io.ImageCollection( 32 | workdir + "/*.jpg", load_func=process_image, weight=weight, hight=hight) 33 | label = np.full(len(processed_images), 34 | fill_value=index, dtype=np.int64) 35 | # print(processed_images) 36 | images = io.concatenate_images(processed_images) 37 | 38 | if index == 0: 39 | data = images 40 | labels = label 41 | 42 | else: 43 | data = np.vstack((data, images)) 44 | labels = np.append(labels, label) 45 | 46 | print(np.shape(data)) 47 | print(np.shape(labels)) 48 | 49 | partial = [0, 1, 5, 10, 11, 12, 15, 16, 17, 22] 50 | idx = np.where(np.isin(labels, partial)) 51 | data_p = data[idx] 52 | label_p = labels[idx] 53 | 54 | print(np.shape(data_p)) 55 | print(np.shape(label_p)) 56 | 57 | np.savez("office31/"+d+"10.npz", 58 | data=data_p, label=label_p) 59 | print("Saved {}10.npz. It's length is {}".format(d, len(labels[idx]))) 60 | np.savez("office31/"+d+"31.npz", data=data, label=labels) 61 | print("Saved {}31.npz. It's length is {}".format(d, len(labels))) 62 | -------------------------------------------------------------------------------- /datasets/preprocess_office-home.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as numpy 3 | from skimage import io, transform 4 | import numpy as np 5 | 6 | weight = 256 7 | hight = 256 8 | data = None 9 | labels = None 10 | 11 | domains = ["Art", "Clipart", "Product","Real World"] 12 | 13 | 14 | def process_image(image, weight, hight): 15 | img = io.imread(image) 16 | img = transform.resize(img, (weight, hight), mode="reflect") 17 | return img 18 | 19 | 20 | for d in domains: 21 | path = "data/Office-Home/OfficeHomeDataset_10072016/" + d + "/" 22 | print("processing " + path) 23 | 24 | for _, dirnames, _ in os.walk(path): 25 | dirnames.sort() 26 | for dirname in dirnames: 27 | index = dirnames.index(dirname) 28 | workdir = os.path.join(path, dirname) 29 | print(workdir) 30 | processed_images = io.ImageCollection( 31 | workdir + "/*.jpg", load_func=process_image, weight=weight, hight=hight) 32 | label = np.full(len(processed_images), 33 | fill_value=index, dtype=np.int64) 34 | # print(processed_images) 35 | images = io.concatenate_images(processed_images) 36 | 37 | if index == 0: 38 | data = images 39 | labels = label 40 | 41 | else: 42 | data = np.vstack((data, images)) 43 | labels = np.append(labels, label) 44 | 45 | # print(np.shape(data)) 46 | # print(np.shape(labels)) 47 | 48 | # partial = [0, 1, 5, 10, 11, 12, 15, 16, 17, 22] 49 | # idx = np.where(np.isin(labels, partial)) 50 | # data_p = data[idx] 51 | # label_p = labels[idx] 52 | 53 | # print(np.shape(data_p)) 54 | # print(np.shape(label_p)) 55 | 56 | # np.savez("office-home/"+d+"10.npz", 57 | # data=data_p, label=label_p) 58 | # print("Saved {}10.npz. It's length is {}".format(d, len(labels[idx]))) 59 | np.savez("/home/sumukh/office-home/"+d+"65.npz", data=data, label=labels) 60 | print("Saved {}65.npz. It's length is {}".format(d, len(labels))) 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # S3VAADA: Submodular Subset Selection for Virtual Adversarial Active Domain Adaptation 2 | ## ICCV 2021 3 | Harsh Rangwani, Arihant Jain*, Sumukh K Aithal*, R. Venkatesh Babu\ 4 | Video Analytics Lab, Indian Institute of Science, Bengaluru 5 | ## [[Project Webpage](https://sites.google.com/iisc.ac.in/s3vaada-iccv2021/)] [[Paper](https://arxiv.org/pdf/2109.08901v1.pdf)] 6 | 7 | 8 | 9 | 10 | ## TLDR 11 | Obtain performance close to supervised learning using small amount of labelled data (~10%) in target domain, for adapting a model from source to target domain. 12 | 13 | ![Alt Text](images/overview.gif) 14 | 15 | ## Results on Office-Home 16 | ![Results](images/results_office.png) 17 | 18 | ## Setup the requirements 19 | Run the following commands to setup your environment 20 | ``` 21 | git clone https://github.com/val-iisc/s3vaada.git 22 | cd s3vaada/ 23 | pip install -r requirements.txt 24 | cd models/ 25 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth 26 | mv resnet50-19c8e357.pth resnet50.pth 27 | cd ../ 28 | ``` 29 | 30 | ## Dataset 31 | ### Office-31 32 | Run the following commands to download and preprocess Office-31 dataset 33 | ``` 34 | cd datasets/ 35 | sh get_office.sh 36 | python preprocess_office31.py 37 | cd ../ 38 | ``` 39 | ### Office-Home 40 | Run the following commands to download and preprocess Office-31 dataset 41 | ``` 42 | cd datasets/ 43 | sh get_office-home.sh 44 | python preprocess_office-home.py 45 | cd ../ 46 | ``` 47 | 48 | 49 | ## Training 50 | ``` 51 | python main.py --name w2a-s3vaada --source webcam --target amazon 52 | ``` 53 | 54 | ## Citation 55 | If you find our work useful cite our paper using the following BibTeX entry. 56 | ``` 57 | @InProceedings{Rangwani_2021_ICCV, 58 | author = {Rangwani, Harsh and Jain, Arihant and Aithal, Sumukh K and Babu, R. Venkatesh}, 59 | title = {S3VAADA: Submodular Subset Selection for Virtual Adversarial Active Domain Adaptation}, 60 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 61 | month = {October}, 62 | year = {2021}, 63 | pages = {7516-7525} 64 | } 65 | ``` 66 | -------------------------------------------------------------------------------- /train_test/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import wandb 4 | from torch.autograd import Variable 5 | import os 6 | import random 7 | import copy 8 | 9 | best_source_acc = 0 10 | best_target_acc = 0 11 | best_domain_acc = 0 12 | 13 | 14 | def test(net, source_dataloader, target_dataloader, epoch, cycle, args, device, final_epoch): 15 | # Setup model 16 | net.eval() 17 | 18 | source_label_correct = 0.0 19 | target_label_correct = 0.0 20 | 21 | source_domain_correct = 0.0 22 | target_domain_correct = 0.0 23 | domain_correct = 0.0 24 | 25 | lamda = 0 26 | len_src_dataset = len(source_dataloader.dataset) 27 | 28 | # Testing on small subset of Source dataset 29 | small_src_dataset = torch.utils.data.Subset(source_dataloader.dataset, random.sample( 30 | range(0, len_src_dataset), len_src_dataset//10)) 31 | source_dataloader = torch.utils.data.DataLoader( 32 | small_src_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 33 | # Test source data 34 | with torch.no_grad(): 35 | for batch_idx, source_data in enumerate(source_dataloader): 36 | 37 | source_input, source_label = source_data 38 | 39 | source_input, source_label = Variable( 40 | source_input.to(device)), Variable(source_label.to(device)) 41 | source_labels = Variable(torch.zeros( 42 | (source_input.size()[0])).type(torch.LongTensor).to(device)) 43 | 44 | source_label_pred, source_domain_pred, _ = net( 45 | source_input, 'source', lamda) 46 | 47 | source_label_pred = source_label_pred.data.max(1, keepdim=True)[1] 48 | source_label_correct += source_label_pred.eq( 49 | source_label.data.view_as(source_label_pred)).cpu().sum() 50 | 51 | source_domain_pred = source_domain_pred.data.max(1, keepdim=True)[ 52 | 1] 53 | source_domain_correct += source_domain_pred.eq( 54 | source_labels.data.view_as(source_domain_pred)).cpu().sum() 55 | 56 | # Test target data 57 | with torch.no_grad(): 58 | for batch_idx, target_data in enumerate(target_dataloader): 59 | 60 | target_input, target_label = target_data 61 | 62 | target_input, target_label = target_input.type(torch.FloatTensor).to( 63 | device), target_label.type(torch.LongTensor).to(device) 64 | target_labels = Variable(torch.ones( 65 | (target_input.size()[0])).type(torch.LongTensor).to(device)) 66 | # Compute target accuracy both for label and domain predictions 67 | target_label_pred_, target_domain_pred, _ = net( 68 | target_input, 'target', lamda) 69 | 70 | target_label_pred = target_label_pred_.data.max(1, keepdim=True)[1] 71 | target_label_correct += target_label_pred.eq( 72 | target_label.data.view_as(target_label_pred)).cpu().sum() 73 | 74 | target_domain_pred = target_domain_pred.data.max(1, keepdim=True)[ 75 | 1] 76 | target_domain_correct += target_domain_pred.eq( 77 | target_labels.data.view_as(target_domain_pred)).cpu().sum() 78 | 79 | # Compute domain correctness 80 | domain_correct = source_domain_correct + target_domain_correct 81 | 82 | global best_source_acc, best_target_acc, best_domain_acc 83 | 84 | target_acc = float(target_label_correct) / len(target_dataloader.dataset) 85 | source_acc = float(source_label_correct) / len(source_dataloader.dataset) 86 | domain_acc = float(domain_correct) / \ 87 | (len(source_dataloader.dataset) + len(target_dataloader.dataset)) 88 | 89 | if target_acc > best_target_acc: 90 | best_target_acc = target_acc 91 | best_source_acc = source_acc 92 | best_domain_acc = domain_acc 93 | 94 | CURRENT_DIR_PATH = os.path.abspath( 95 | os.path.join(os.path.dirname(__file__), os.pardir)) 96 | 97 | MODEL_CHECKPOINTS = CURRENT_DIR_PATH + '/models/models_checkpoints/' 98 | model_root = MODEL_CHECKPOINTS + args.source + '-' + \ 99 | args.target + "/" + args.sampling + "/" + args.time_stamp 100 | 101 | PATH = model_root + "/" + str(cycle) + ".pth" 102 | torch.save(net.state_dict(), PATH) 103 | 104 | # Print results 105 | if args.log_results and final_epoch: 106 | wandb.log({"Source Accuracy": 100. * best_source_acc, 107 | "Domain Accuracy": 100. * best_domain_acc, 108 | "Number of labeled images": cycle*args.budget, "Target Accuracy": 100. * best_target_acc}) 109 | # After every cycle: reset 110 | best_target_acc = 0 111 | best_source_acc = 0 112 | best_domain_acc = 0 113 | 114 | print('\nSource Accuracy: {}/{} ({:.4f}%)\nTarget Accuracy: {}/{} ({:.4f}%)\n' 115 | 'Domain Accuracy: {}/{} ({:.4f}%)\n'. 116 | format( 117 | source_label_correct, len(source_dataloader.dataset), 118 | 100. * source_acc, 119 | target_label_correct, len(target_dataloader.dataset), 120 | 100. * target_acc, 121 | domain_correct, len(source_dataloader.dataset) + 122 | len(target_dataloader.dataset), 123 | 100. * float(domain_correct) / ( 124 | len(source_dataloader.dataset) + len( 125 | target_dataloader.dataset)) 126 | )) 127 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import train_test 2 | import wandb 3 | import torch.nn as nn 4 | import torch 5 | import seaborn as sns 6 | import matplotlib.patheffects as PathEffects 7 | import matplotlib.pyplot as plt 8 | from MulticoreTSNE import MulticoreTSNE as TSNE 9 | import numpy as np 10 | import os 11 | from sklearn import manifold 12 | import matplotlib 13 | matplotlib.use('Agg') 14 | 15 | 16 | def new_TSNE(net, source_dataloader, target_dataloader, new_data_loader, temp_dataloader, cycle, device, args): 17 | 18 | net.eval() 19 | source_embedding = torch.tensor([]).to(device) 20 | source_labels = torch.tensor([]).type(torch.LongTensor) # .to(device) 21 | 22 | f = plt.figure(figsize=(16, 16)) 23 | lamda = 0 24 | with torch.no_grad(): 25 | for batch_idx, source_data in enumerate(source_dataloader): 26 | source_input, source_label = source_data 27 | source_input = source_input.to(device) 28 | source_feature = net.feature_extractor( 29 | source_input, 'source', lamda) 30 | 31 | source_embedding = torch.cat((source_embedding, source_feature), 0) 32 | source_labels = torch.cat((source_labels, source_label), 0) 33 | 34 | target_embedding = torch.tensor([]).to(device) 35 | target_labels = torch.tensor([]).type(torch.LongTensor) # .to(device) 36 | 37 | with torch.no_grad(): 38 | for batch_idx, (inputs, labels) in enumerate(target_dataloader): 39 | p = float(batch_idx) / len(target_dataloader) 40 | inputs = inputs.to(device) 41 | 42 | target_feature = net.feature_extractor(inputs, 'target', lamda) 43 | target_embedding = torch.cat((target_embedding, target_feature), 0) 44 | target_labels = torch.cat((target_labels, labels), 0) 45 | 46 | newly_labeled_embedding = torch.tensor([]).to(device) 47 | new_labels = torch.tensor([]).type(torch.LongTensor) 48 | with torch.no_grad(): 49 | for batch_idx, (inputs, labels) in enumerate(temp_dataloader): 50 | inputs = inputs.to(device) 51 | target_feature = net.feature_extractor(inputs, 'target', lamda) 52 | newly_labeled_embedding = torch.cat( 53 | (newly_labeled_embedding, target_feature), 0) 54 | new_labels = torch.cat((new_labels, labels), 0) 55 | 56 | labeled_embedding = torch.tensor([]).to(device) 57 | old_labels = torch.tensor([]).type(torch.LongTensor) 58 | with torch.no_grad(): 59 | if new_data_loader is not None: 60 | for batch_idx, (inputs, labels) in enumerate(new_data_loader): 61 | inputs = inputs.to(device) 62 | target_feature = net.feature_extractor(inputs, 'target', lamda) 63 | labeled_embedding = torch.cat( 64 | (labeled_embedding, target_feature), 0) 65 | old_labels = torch.cat((old_labels, labels), 0) 66 | 67 | source_embedding = source_embedding.cpu().numpy() 68 | labeled_embedding = labeled_embedding.cpu().numpy() 69 | target_embedding = target_embedding.cpu().numpy() 70 | newly_labeled_embedding = newly_labeled_embedding.cpu().numpy() 71 | 72 | source_labels = source_labels.cpu().numpy() 73 | target_labels = target_labels.cpu().numpy() 74 | new_labels = new_labels.cpu().numpy() 75 | old_labels = old_labels.cpu().numpy() 76 | 77 | if new_data_loader is None: 78 | X = np.concatenate( 79 | (source_embedding, target_embedding, newly_labeled_embedding), axis=0) 80 | else: 81 | X = np.concatenate((source_embedding, target_embedding, 82 | newly_labeled_embedding, labeled_embedding), axis=0) 83 | tsne = TSNE(n_jobs=8) 84 | X_tsne = tsne.fit_transform(X) 85 | 86 | source_embedding = X_tsne[:len(source_embedding)] 87 | target_embedding = X_tsne[len(source_embedding):len( 88 | source_embedding)+len(target_embedding)] 89 | 90 | newly_labeled_embedding = X_tsne[len(source_embedding)+len(target_embedding):len( 91 | source_embedding)+len(target_embedding)+len(newly_labeled_embedding)] 92 | if new_data_loader is not None: 93 | labeled_embedding = X_tsne[len( 94 | source_embedding)+len(target_embedding)+len(newly_labeled_embedding):] 95 | 96 | n_class = args.num_classes 97 | palette = np.array(sns.color_palette('hls', n_class)) 98 | 99 | plt.scatter(source_embedding[:, 0], source_embedding[:, 1], lw=0, s=20, 100 | c=palette[source_labels.astype(np.int)], marker='o') # , alpha=0.3) 101 | plt.scatter(target_embedding[:, 0], target_embedding[:, 1], lw=0, s=20, 102 | c=palette[target_labels.astype(np.int)], marker='*') # , alpha=0.7) 103 | #plt.plot(newly_labeled_embedding[:, 0], newly_labeled_embedding[:, 1],linestyle='none',markersize=100, markeredgecolor="orange", markeredgewidth=10) 104 | plt.scatter(newly_labeled_embedding[:, 0], newly_labeled_embedding[:, 1], s=80, c=palette[new_labels.astype( 105 | np.int)], marker='s', edgecolor='red', linewidths=3) # , alpha=0.5) 106 | if new_data_loader is not None: 107 | plt.scatter(labeled_embedding[:, 0], labeled_embedding[:, 1], lw=0, s=90, 108 | c=palette[old_labels.astype(np.int)], marker='>') # , alpha=0.5) 109 | 110 | if args.log_results: 111 | wandb.log({f"Cycle{cycle+1}": plt}) 112 | 113 | plt.close() 114 | 115 | 116 | def analyze(idx, target_dataset, net, args, device): 117 | 118 | print("The queried samples belong to the following classes: ") 119 | for index in idx: 120 | print(target_dataset[index][1], end=' ') 121 | # classes_of_new_samples = torch.cat((classes_of_new_samples,torch.from_numpy(target_dataset[index][1])),0) 122 | print() 123 | 124 | 125 | if __name__ == '__main__': 126 | 127 | X = np.random.randn(1000, 50) 128 | tsne = TSNE(n_jobs=4) 129 | Y = tsne.fit_transform(X) 130 | print() 131 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import random 4 | import copy 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Function, Variable 10 | from torch.nn import init 11 | from torchvision.models.resnet import conv3x3, conv1x1, BasicBlock, Bottleneck 12 | 13 | 14 | class ResNet_FeatureClassifer(nn.Module): 15 | 16 | def __init__(self, num_classes): 17 | super(ResNet_FeatureClassifer, self).__init__() 18 | self.feature_classifier = nn.Sequential() 19 | self.feature_classifier.add_module('fc_n', nn.Linear(256, num_classes)) 20 | 21 | def forward(self, feature): 22 | return self.feature_classifier(feature) 23 | 24 | 25 | class ResNet_DomainClassifier(nn.Module): 26 | 27 | def __init__(self): 28 | super(ResNet_DomainClassifier, self).__init__() 29 | self.domain_classifier = nn.Sequential() 30 | self.domain_classifier.add_module('d_fc1', nn.Linear(256, 1024)) 31 | self.domain_classifier.add_module('d_relu1', nn.ReLU(True)) 32 | self.domain_classifier.add_module('d_fc2', nn.Linear(1024, 1024)) 33 | self.domain_classifier.add_module('d_relu2', nn.ReLU(True)) 34 | self.domain_classifier.add_module('d_fc3', nn.Linear(1024, 2)) 35 | 36 | def forward(self, feature): 37 | return self.domain_classifier(feature) 38 | 39 | 40 | class ResNet_FeatureExtractor(nn.Module): 41 | def __init__(self, layers=[3, 4, 6, 3], block=Bottleneck): 42 | 43 | super(ResNet_FeatureExtractor, self).__init__() 44 | 45 | self.inplanes = 64 46 | self.dilation = 1 47 | norm_layer = nn.BatchNorm2d 48 | self._norm_layer = nn.BatchNorm2d 49 | replace_stride_with_dilation = [False, False, False] 50 | self.groups = 1 51 | self.base_width = 64 52 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 53 | bias=False) 54 | self.bn1 = norm_layer(self.inplanes) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 57 | self.layer1 = self._make_layer(block, 64, layers[0]) 58 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 59 | dilate=replace_stride_with_dilation[0]) 60 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 61 | dilate=replace_stride_with_dilation[1]) 62 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 63 | dilate=replace_stride_with_dilation[2]) 64 | self.bottleneck = nn.Sequential() 65 | self.bottleneck.add_module('avgpool1', nn.AdaptiveAvgPool2d((1, 1))) 66 | self.linear_layer = nn.Linear(2048, 256) 67 | 68 | zero_init_residual = False 69 | for m in self.modules(): 70 | if isinstance(m, nn.Conv2d): 71 | nn.init.kaiming_normal_( 72 | m.weight, mode='fan_out', nonlinearity='relu') 73 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 74 | nn.init.constant_(m.weight, 1) 75 | nn.init.constant_(m.bias, 0) 76 | 77 | # Zero-initialize the last BN in each residual branch, 78 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 79 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 80 | if zero_init_residual: 81 | for m in self.modules(): 82 | if isinstance(m, Bottleneck): 83 | nn.init.constant_(m.bn3.weight, 0) 84 | elif isinstance(m, BasicBlock): 85 | nn.init.constant_(m.bn2.weight, 0) 86 | 87 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 88 | norm_layer = self._norm_layer 89 | downsample = None 90 | previous_dilation = self.dilation 91 | if dilate: 92 | self.dilation *= stride 93 | stride = 1 94 | if stride != 1 or self.inplanes != planes * block.expansion: 95 | downsample = nn.Sequential( 96 | conv1x1(self.inplanes, planes * block.expansion, stride), 97 | norm_layer(planes * block.expansion), 98 | ) 99 | 100 | layers = [] 101 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 102 | self.base_width, previous_dilation, norm_layer)) 103 | self.inplanes = planes * block.expansion 104 | for _ in range(1, blocks): 105 | layers.append(block(self.inplanes, planes, groups=self.groups, 106 | base_width=self.base_width, dilation=self.dilation, 107 | norm_layer=norm_layer)) 108 | 109 | return nn.Sequential(*layers) 110 | 111 | def forward(self, x, domain='target', lamda=0): 112 | 113 | x = self.conv1(x) 114 | x = self.bn1(x) 115 | x = self.relu(x) 116 | x = self.maxpool(x) 117 | 118 | x = self.layer1(x) 119 | x = self.layer2(x) 120 | x = self.layer3(x) 121 | x = self.layer4(x) 122 | 123 | feature = self.bottleneck(x) 124 | feature = feature.view(-1, 2048) 125 | feature = self.linear_layer(feature) 126 | 127 | return feature 128 | 129 | class ResNet(nn.Module): 130 | 131 | def __init__(self, num_classes, device, args): 132 | 133 | super(ResNet, self).__init__() 134 | 135 | self.feature_extractor = ResNet_FeatureExtractor() 136 | self.feature_classifier = ResNet_FeatureClassifer(num_classes) 137 | self.domain_classifier = ResNet_DomainClassifier() 138 | self.method = args.method 139 | 140 | def forward(self, input, domain, lamda): 141 | feature = self.feature_extractor(input, domain, lamda) 142 | 143 | class_prediction = self.feature_classifier(feature) 144 | 145 | if self.method == "vaada": 146 | domain_prediction = self.domain_classifier(feature) 147 | else: 148 | reverse_feature = ReverseLayer.apply(feature, lamda) 149 | domain_prediction = self.domain_classifier(reverse_feature) 150 | 151 | return class_prediction, domain_prediction, feature 152 | 153 | 154 | class ReverseLayer(Function): 155 | 156 | @staticmethod 157 | def forward(ctx, x, lamda): 158 | ctx.lamda = lamda 159 | return x.view_as(x) 160 | 161 | @staticmethod 162 | def backward(ctx, grad_output): 163 | output = grad_output.neg() * ctx.lamda 164 | return output, None 165 | 166 | 167 | def load_single_state_dict(net, state_dict): 168 | 169 | own_state = net.state_dict() 170 | count = 0 171 | fe_param_count = 0 172 | for name, param in own_state.items(): 173 | if('feature_extractor' not in name): 174 | continue 175 | fe_param_count += 1 176 | # parsed = name.split('.') 177 | new_name = name.replace('feature_extractor.', '') 178 | if new_name in state_dict.keys(): 179 | # print(new_name) 180 | param_ = state_dict[new_name] 181 | if isinstance(param_, torch.nn.Parameter): 182 | # backwards compatibility for serialized parameters 183 | param_ = param_.data 184 | param_data = param_.data 185 | own_state[name].copy_(param_data) 186 | count += 1 187 | else: 188 | pass 189 | # print(name) 190 | print("Imagenet pre-trained weights loaded") 191 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data import Subset 10 | 11 | import torchvision 12 | from torchvision import datasets 13 | from torchvision import transforms 14 | 15 | import numpy as np 16 | from skimage import io, transform 17 | 18 | from tqdm import tqdm 19 | from sklearn.model_selection import train_test_split 20 | 21 | CURRENT_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) 22 | 23 | 24 | class Office31(Dataset): 25 | 26 | def __init__(self, domain, root="office31/", train=True, partial=False, transform=None, target_transform=None): 27 | super(Office31, self).__init__() 28 | self.root = os.path.join(os.getcwd(), 'datasets', root) 29 | self.train = train 30 | self.partial = partial 31 | self.transform = transform 32 | self.target_transform = target_transform 33 | 34 | if self.partial: 35 | dataset_ = np.load(os.path.join(self.root, domain+"10.npz")) 36 | else: 37 | dataset_ = np.load(os.path.join(self.root, domain+"31.npz")) 38 | 39 | self.data, self.label = dataset_["data"], dataset_["label"] 40 | 41 | def __getitem__(self, index): 42 | 43 | data, label = self.data[index], self.label[index] 44 | data = Image.fromarray(np.uint8(data*255.0), mode="RGB") 45 | 46 | if self.transform is not None: 47 | data = self.transform(data) 48 | 49 | if self.target_transform is not None: 50 | label = self.target_transform(label) 51 | 52 | return data, label 53 | 54 | def __len__(self): 55 | return (len(self.label)) 56 | 57 | 58 | class OfficeHome(Dataset): 59 | 60 | def __init__(self, domain, root="office-home/", train=True, transform=None, target_transform=None): 61 | super(OfficeHome, self).__init__() 62 | self.root = os.path.join(os.getcwd(), 'datasets', root) 63 | self.train = train 64 | self.transform = transform 65 | self.target_transform = target_transform 66 | 67 | dataset_ = np.load(os.path.join(self.root, domain+"65.npz")) 68 | 69 | self.data, self.label = dataset_["data"], dataset_["label"] 70 | 71 | def __getitem__(self, index): 72 | 73 | data, label = self.data[index], self.label[index] 74 | data = Image.fromarray(np.uint8(data*255.0), mode="RGB") 75 | 76 | if self.transform is not None: 77 | data = self.transform(data) 78 | 79 | if self.target_transform is not None: 80 | label = self.target_transform(label) 81 | 82 | return data, label 83 | 84 | def __len__(self): 85 | return (len(self.label)) 86 | 87 | 88 | class DomainNet(Dataset): 89 | def __init__(self, domain, args, root="datasets/", 90 | transform=None, target_transform=None, test=False): 91 | if test: 92 | imgs, labels = make_dataset_fromlist(os.path.join( 93 | os.getcwd(), "datasets/txt", domain+"_test.txt")) 94 | args.num_classes = len(return_classlist(os.path.join( 95 | os.getcwd(), "datasets/txt", domain+"_test.txt"))) 96 | else: 97 | imgs, labels = make_dataset_fromlist(os.path.join( 98 | os.getcwd(), "datasets/txt", domain+"_train.txt")) 99 | args.num_classes = len(return_classlist(os.path.join( 100 | os.getcwd(), "datasets/txt", domain+"_train.txt"))) 101 | self.imgs = imgs 102 | self.label = labels 103 | self.transform = transform 104 | self.target_transform = target_transform 105 | self.loader = pil_loader 106 | self.root = root # READ FROM SSD 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Args: 111 | index (int): Index 112 | Returns: 113 | tuple: (image, target) where target is 114 | class_index of the target class. 115 | """ 116 | path = os.path.join(self.root, self.imgs[index]) 117 | target = self.label[index] 118 | img = self.loader(path) 119 | if self.transform is not None: 120 | img = self.transform(img) 121 | if self.target_transform is not None: 122 | target = self.target_transform(target) 123 | return img, target 124 | 125 | def __len__(self): 126 | return len(self.imgs) 127 | 128 | 129 | def get_dataset_visda(args, domain, train): 130 | # Data loading code 131 | dir_s = os.path.join(os.getcwd()+"/datasets/data/VisDA-18", args.source) 132 | dir_t = os.path.join(os.getcwd()+"/datasets/data/VisDA-18", args.target) 133 | 134 | if not os.path.isdir(dir_s): 135 | raise ValueError( 136 | 'The required data path is not exist, please download the dataset!') 137 | if not os.path.isdir(dir_t): 138 | raise ValueError( 139 | 'The required data path is not exist, please download the dataset!') 140 | 141 | # transformation 142 | data_transforms = transforms.Compose([ 143 | transforms.Resize((224, 224)), 144 | transforms.ToTensor(), 145 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 146 | ]) 147 | 148 | if domain == 'source': 149 | dataset = datasets.ImageFolder(root=dir_s, transform=data_transforms) 150 | else: 151 | dataset = datasets.ImageFolder(root=dir_t, transform=data_transforms) 152 | 153 | return dataset 154 | 155 | 156 | def get_source_domain(source_name, args, train=True): 157 | # Define root folder to store source dataset 158 | root = CURRENT_DIR_PATH + "/source" 159 | try: 160 | os.makedirs(root) 161 | except OSError as e: 162 | if e.errno == errno.EEXIST: 163 | pass 164 | else: 165 | raise 166 | 167 | # Define image source domain transformation 168 | source_transforms = transforms.Compose([ 169 | transforms.Resize((256, 256)), 170 | transforms.CenterCrop(args.image_size), 171 | transforms.ToTensor(), 172 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 173 | ]) 174 | # Define source dataset 175 | if source_name == 'amazon' or source_name == 'dslr' or source_name == 'webcam': 176 | source_dataset = Office31( 177 | source_name, transform=source_transforms, partial=False) 178 | elif source_name == 'Art' or source_name == 'Clipart' or source_name == 'Product' or source_name == 'Real World': 179 | source_dataset = OfficeHome(source_name, transform=source_transforms) 180 | elif source_name == 'real' or source_name == 'synthetic': # visda-18 181 | source_dataset = get_dataset_visda(args, 'source', train) 182 | elif source_name == "real-DN" or source_name == "clipart" or source_name == "sketch": 183 | source_dataset = DomainNet( 184 | source_name, args, transform=source_transforms) 185 | 186 | loader = torch.utils.data.DataLoader( 187 | source_dataset, batch_size=args.batch_size, shuffle=train, num_workers=args.workers) 188 | # Return source's dataset DataLoader object 189 | return loader, source_dataset 190 | 191 | 192 | def get_target_domain(target_name, args): 193 | # Define root folder to store target dataset 194 | root = CURRENT_DIR_PATH + "/target" 195 | try: 196 | os.makedirs(root) 197 | except OSError as e: 198 | if e.errno == errno.EEXIST: 199 | pass 200 | else: 201 | raise 202 | 203 | target_img_transforms = transforms.Compose([ 204 | transforms.Resize((256, 256)), 205 | transforms.CenterCrop(args.image_size), 206 | transforms.ToTensor(), 207 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 208 | ]) 209 | # Define image target domain transformation 210 | # Define target dataset 211 | if target_name == 'amazon' or target_name == 'dslr' or target_name == 'webcam': 212 | train_target_dataset = Office31(target_name, 213 | transform=target_img_transforms) 214 | test_target_dataset = Office31(target_name, 215 | transform=target_img_transforms) 216 | args.num_classes = 31 217 | 218 | elif target_name == 'Art' or target_name == 'Clipart' or target_name == 'Product' or target_name == 'Real World': 219 | train_target_dataset = OfficeHome( 220 | target_name, transform=target_img_transforms) 221 | test_target_dataset = OfficeHome( 222 | target_name, transform=target_img_transforms) 223 | args.num_classes = 65 224 | 225 | elif target_name == "real" or target_name == "synthetic": 226 | train_target_dataset = get_dataset_visda(args, 'train_target', None) 227 | test_target_dataset = get_dataset_visda(args, 'test_target', None) 228 | args.num_classes = 12 229 | 230 | elif target_name == "sketch" or target_name == "clipart": 231 | train_target_dataset = DomainNet( 232 | target_name, args, transform=target_img_transforms) 233 | test_target_dataset = DomainNet( 234 | target_name, args, transform=target_img_transforms, test=True) 235 | 236 | # Define target dataloader 237 | if target_name == "real" or target_name == "synthetic": 238 | train_idx, val_idx = train_test_split(list(range(len( 239 | train_target_dataset))), test_size=0.2, random_state=42, stratify=train_target_dataset.targets) 240 | elif target_name != "sketch" or target_name != "clipart": 241 | train_idx, val_idx = train_test_split(list(range(len( 242 | train_target_dataset))), test_size=0.2, random_state=42, stratify=train_target_dataset.label) 243 | if target_name != "sketch" and target_name != "clipart": 244 | train_dataset = Subset(train_target_dataset, train_idx) 245 | test_dataset = Subset(test_target_dataset, val_idx) 246 | else: 247 | train_dataset = train_target_dataset 248 | test_dataset = test_target_dataset 249 | print("Number of images in test dataset of domain ", 250 | target_name, ":", len(test_dataset)) 251 | 252 | train_loader = torch.utils.data.DataLoader( 253 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) 254 | test_loader = torch.utils.data.DataLoader( 255 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 256 | 257 | # Return target's dataset DataLoader object 258 | return (train_dataset, train_loader), (test_dataset, test_loader) 259 | 260 | 261 | def pil_loader(path): 262 | with open(path, 'rb') as f: 263 | img = Image.open(f) 264 | return img.convert('RGB') 265 | 266 | # DomainNet 267 | def make_dataset_fromlist(image_list): 268 | with open(image_list) as f: 269 | image_index = [x.split(' ')[0] for x in f.readlines()] 270 | with open(image_list) as f: 271 | label_list = [] 272 | selected_list = [] 273 | for ind, x in enumerate(f.readlines()): 274 | label = x.split(' ')[1].strip() 275 | label_list.append(int(label)) 276 | selected_list.append(ind) 277 | image_index = np.array(image_index) 278 | label_list = np.array(label_list) 279 | image_index = image_index[selected_list] 280 | return image_index, label_list 281 | 282 | 283 | def return_classlist(image_list): 284 | with open(image_list) as f: 285 | label_list = [] 286 | for ind, x in enumerate(f.readlines()): 287 | label = x.split(' ')[0].split('/')[-2] 288 | if label not in label_list: 289 | label_list.append(str(label)) 290 | return label_list 291 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datasets 3 | import errno 4 | import models 5 | import torch 6 | 7 | import train_test 8 | import active_learning as al 9 | from visualize import new_TSNE, analyze 10 | 11 | import os 12 | from datetime import datetime 13 | import numpy as np 14 | import torch.optim as optim 15 | 16 | import wandb 17 | import matplotlib.pyplot as plt 18 | from torch.utils.data import DataLoader 19 | 20 | CURRENT_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) 21 | MODEL_CHECKPOINTS = CURRENT_DIR_PATH + '/models/models_checkpoints/' 22 | 23 | # dd/mm/YY H:M:S 24 | time_stamp = datetime.now().strftime("%d%m%Y_%H%M%S") 25 | 26 | 27 | def make_args_parser(): 28 | # create an ArgumentParser object 29 | parser = argparse.ArgumentParser( 30 | description='Active Domain Adaptation via S3VAADA') 31 | # fill parser with information about program arguments 32 | parser.add_argument('-s', '--source', default='webcam', type=str, 33 | help='Define the source domain') 34 | parser.add_argument('-t', '--target', default='amazon', type=str, 35 | help='Define the target domain') 36 | parser.add_argument('-m', '--model', default='ResNet', type=str, 37 | help='Define the architecture') 38 | parser.add_argument('-bs', '--batch_size', default=36, type=int, 39 | help='Batch Size') 40 | parser.add_argument('-c', '--cycles', default=6, type=int, 41 | help='Number of Cycles') 42 | parser.add_argument('-e', '--epochs', default=100, type=int, 43 | help='Number of Epochs') 44 | parser.add_argument('-k', '--learning_rate', default=1e-2, type=float, 45 | help='Learning rate') 46 | parser.add_argument('-w', '--workers', default=4, type=int, 47 | help='Number of workers') 48 | parser.add_argument('-al', '--sampling', default='s3vaada', type=str, 49 | help='Sampling Strategy for active learning') 50 | parser.add_argument('-im', '--image_size', default=224, type=int, 51 | help='Image Size') 52 | parser.add_argument('-mo', '--momentum', default=0.9, type=float, 53 | help='Momentum') 54 | parser.add_argument('-wd', '--weight_decay', default=0.0005, type=float, 55 | help='weight decay for SGD') 56 | parser.add_argument('-se', '--seed', default=123, type=int, 57 | help='Seed for the run') 58 | parser.add_argument('-met', '--method', default="vaada", type=str, 59 | help='Method : dann or vaada') 60 | parser.add_argument('-clip', '--clip_value', default=1, type=float, 61 | help='Clip value for max norm') 62 | parser.add_argument('-g', '--gamma', default=10, type=float, 63 | help='Gamma value in the schedule (as defined in DANN)') 64 | parser.add_argument('-log', '--log_interval', default=50, type=int, 65 | help='Log interval for wandb') 66 | parser.add_argument('-na', '--name', default="test", type=str, 67 | help='Wandb name run') 68 | parser.add_argument('-amp', '--use_amp', default=True, type=bool, 69 | help='Mixed Precision Training') 70 | parser.add_argument('-logr', '--log_results', default=True, type=bool, 71 | help='To log results or not') 72 | parser.add_argument('-gid', '--gpu', default=1, type=int, 73 | help='GPU to use') 74 | parser.add_argument('-a', '--alpha', default=0.5, type=float, 75 | help="alpha value for submodular function") 76 | parser.add_argument('-b', '--beta', default=0.3, type=float, 77 | help="beta value for submodular function") 78 | parser.add_argument('-r', '--resume', default="", type=str, 79 | help="Resume from checkpoint") 80 | parser.add_argument('-bud', '--budget', default=None, type=int, 81 | help='Budget to use') 82 | return parser.parse_args() 83 | 84 | 85 | def print_args(args): 86 | print("Running with the following configuration") 87 | 88 | args_map = vars(args) 89 | for key in args_map: 90 | print('\t', key, '-->', args_map[key]) 91 | print() 92 | 93 | 94 | def main(): 95 | # parse and print arguments 96 | args = make_args_parser() 97 | print_args(args) 98 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 99 | 100 | # Check device available 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | print("Running on: {}".format(device)) 103 | 104 | # Seed Everything 105 | seed = args.seed 106 | torch.manual_seed(seed) 107 | torch.cuda.manual_seed_all(seed) 108 | torch.backends.cudnn.deterministic = True 109 | torch.backends.cudnn.benchmark = False 110 | np.random.seed(seed) 111 | 112 | # Timestamp 113 | args.time_stamp = time_stamp 114 | 115 | # Load both source and target domain datasets 116 | source_dataloader, source_dataset = datasets.get_source_domain( 117 | args.source, args) 118 | source_test_dataloader, _ = datasets.get_source_domain( 119 | args.source, args, train=False) 120 | 121 | (target_dataset, target_dataloader), (test_dataset, 122 | target_test_dataloader) = datasets.get_target_domain(args.target, args) 123 | 124 | # Set Budget as 2% of the number of samples in the target dataset 125 | if args.budget is None: 126 | args.budget = int(len(target_dataset)*0.02) 127 | 128 | print("Budget for every cycle : ", args.budget) 129 | # Create directory to save model's checkpoints 130 | try: 131 | model_root = MODEL_CHECKPOINTS + args.source + '-' + \ 132 | args.target + "/" + args.sampling + "/" + args.time_stamp + "/" 133 | print("Model saved at = ", model_root) 134 | os.makedirs(model_root) 135 | except OSError as e: 136 | if e.errno == errno.EEXIST: 137 | pass 138 | else: 139 | raise 140 | # Intialize Wandb 141 | if args.log_results: 142 | wandb.init(project="active-learning", 143 | entity="active-learning", name=args.name) 144 | wandb.config.update(args) 145 | wandb.config.update({"Optimizer": "SGD"}) 146 | 147 | # Initialize model 148 | 149 | net = models.ResNet(args.num_classes, device, args) 150 | param_dict = torch.load('models/resnet50.pth') 151 | models.load_single_state_dict(net, param_dict) 152 | 153 | net = net.to(device) 154 | 155 | domain_loss = torch.nn.CrossEntropyLoss() 156 | class_loss = torch.nn.CrossEntropyLoss() 157 | 158 | if args.log_results: 159 | wandb.watch(net) 160 | torch.save(net.state_dict(), model_root + "/" + args.name + ".pth") 161 | cycle_no = 0 162 | 163 | if args.resume: 164 | last_cycle_weight = sorted([x for x in os.listdir(args.resume) if x.endswith( 165 | ".pth") and len(x.strip(".pth")) < 3], key=lambda x: (len(x), x))[-2] 166 | print("Resuming from checkpoint:", last_cycle_weight) 167 | net.load_state_dict(torch.load( 168 | os.path.join(args.resume, last_cycle_weight))) 169 | cycle_no = last_cycle_weight.strip(".pth") # [0] 170 | all_idx = np.array([]) 171 | for i in range(int(cycle_no)+1): 172 | idx = np.load(os.path.join(args.resume, str(i)+".npy")) 173 | all_idx = np.concatenate((all_idx, idx)) 174 | all_idx = all_idx.astype(int) 175 | all_idx = torch.from_numpy(all_idx) 176 | all_indices = torch.arange(0, len(target_dataset)) 177 | new_data_set = torch.utils.data.Subset(target_dataset, all_idx) 178 | target_dataset = torch.utils.data.Subset(target_dataset, torch.from_numpy( 179 | np.setdiff1d(all_indices.numpy(), all_idx.numpy()))) 180 | target_dataloader = DataLoader( 181 | dataset=target_dataset, 182 | batch_size=args.batch_size, num_workers=args.workers, 183 | shuffle=True 184 | ) 185 | new_data_loader = DataLoader( 186 | dataset=new_data_set, 187 | batch_size=args.batch_size, num_workers=args.workers, 188 | shuffle=True 189 | ) 190 | print("Number of labeled target samples:", len(all_idx)) 191 | cycle_no = int(cycle_no)+1 192 | 193 | print("Number of classes: ", args.num_classes) 194 | print("Number of images in the target dataset : ", len(target_dataset)) 195 | print("Number of images in the source dataset : ", len(source_dataset)) 196 | 197 | new_data_loader = None 198 | 199 | for cycle in range(cycle_no, args.cycles): 200 | 201 | print('Cycle: ', cycle+1) 202 | if args.log_results: 203 | wandb.log({"Cycle": cycle+1}) 204 | 205 | # Load the original ResNet-50 weights 206 | net.load_state_dict(torch.load(model_root + "/" + args.name + ".pth")) 207 | 208 | dc_optimizer = optim.SGD(net.domain_classifier.parameters( 209 | ), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) 210 | 211 | if args.method == "dann": 212 | fg_optimizer = optim.SGD(net.feature_extractor.parameters( 213 | ), lr=args.learning_rate/10, momentum=args.momentum, weight_decay=args.weight_decay) 214 | else: 215 | fg_optimizer = optim.SGD(net.feature_extractor.parameters( 216 | ), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) 217 | 218 | fc_optimizer = optim.SGD(net.feature_classifier.parameters( 219 | ), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) 220 | 221 | train_test.train(net, class_loss, domain_loss, source_dataloader, 222 | target_dataloader, new_data_loader, source_test_dataloader, target_test_dataloader, 223 | (fg_optimizer, fc_optimizer, dc_optimizer), 224 | cycle, model_root, args, device) 225 | 226 | # To sample the images from the unlabeled target dataset 227 | unshuffled_dataloader = DataLoader( 228 | dataset=target_dataset, 229 | batch_size=args.batch_size, 230 | num_workers=args.workers, 231 | shuffle=False 232 | ) 233 | 234 | len_data_loader = len(unshuffled_dataloader.dataset) 235 | all_indices = torch.arange(0, len_data_loader) 236 | 237 | idx = al.get_active_learning_method( 238 | net, unshuffled_dataloader, device, args, source_dataloader, cycle, new_data_loader) 239 | # Displays which classes the selected samples belong to 240 | analyze(idx, target_dataset, net, args, device) 241 | 242 | temp_dataset = torch.utils.data.Subset(target_dataset, idx) 243 | temp_dataloader = DataLoader( 244 | dataset=temp_dataset, 245 | batch_size=args.batch_size, num_workers=args.workers, 246 | shuffle=False 247 | ) 248 | # Visualize 249 | new_TSNE(net, source_dataloader, target_dataloader, 250 | new_data_loader, temp_dataloader, cycle, device, args) 251 | 252 | if new_data_loader is None: 253 | new_data_set = torch.utils.data.Subset(target_dataset, idx) 254 | else: 255 | new_data_set = torch.utils.data.ConcatDataset( 256 | [new_data_set, torch.utils.data.Subset(target_dataset, idx)]) 257 | 258 | # Remove the labeled images from the target dataset 259 | target_dataset = torch.utils.data.Subset(target_dataset, torch.from_numpy( 260 | np.setdiff1d(all_indices.numpy(), idx.numpy()))) 261 | target_dataloader = DataLoader( 262 | dataset=target_dataset, 263 | batch_size=args.batch_size, num_workers=args.workers, 264 | shuffle=True 265 | ) 266 | # new_data_loader contains the labeled target images 267 | new_data_loader = DataLoader( 268 | dataset=new_data_set, 269 | batch_size=args.batch_size, num_workers=args.workers, 270 | shuffle=True 271 | ) 272 | np.save(model_root+str(cycle)+".npy", idx) 273 | 274 | 275 | if __name__ == '__main__': 276 | main() 277 | -------------------------------------------------------------------------------- /train_test/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import time 4 | import itertools 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from utils.utils import optimizer_scheduler 11 | import train_test 12 | 13 | import wandb 14 | 15 | import copy 16 | from copy import deepcopy 17 | from collections import OrderedDict 18 | from sys import stderr 19 | 20 | from torch import Tensor 21 | 22 | 23 | class ConditionalEntropyLoss(nn.Module): 24 | def __init__(self): 25 | super(ConditionalEntropyLoss, self).__init__() 26 | 27 | def forward(self, x): 28 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1) 29 | b = b.sum(dim=1) 30 | return -1.0*b.mean(dim=0) 31 | 32 | 33 | class VAT(nn.Module): 34 | def __init__(self, model, reduction='mean'): 35 | super(VAT, self).__init__() 36 | self.n_power = 1 37 | self.XI = 1e-6 38 | self.model = model 39 | self.epsilon = 5 40 | self.reduction = reduction 41 | 42 | def forward(self, X, logit, domain, lamda): 43 | vat_loss, r_vadv = self.virtual_adversarial_loss( 44 | X, logit, domain, lamda) 45 | return vat_loss, r_vadv 46 | 47 | def generate_virtual_adversarial_perturbation(self, x, logit, domain, lamda): 48 | d = torch.randn_like(x, device='cuda') 49 | 50 | for _ in range(self.n_power): 51 | d = self.XI * self.get_normalized_vector(d).requires_grad_() 52 | logit_m, _, _ = self.model(x + d, domain, lamda) 53 | dist = self.kl_divergence_with_logit(logit, logit_m) 54 | if self.reduction == 'mean': 55 | grad = torch.autograd.grad(dist, [d])[0] 56 | d = grad.detach() 57 | 58 | return self.epsilon * self.get_normalized_vector(d) 59 | 60 | def kl_divergence_with_logit(self, q_logit, p_logit): 61 | q = F.softmax(q_logit, dim=1) 62 | if self.reduction == 'mean': 63 | qlogq = torch.mean( 64 | torch.sum(q * F.log_softmax(q_logit, dim=1), dim=1)) 65 | qlogp = torch.mean( 66 | torch.sum(q * F.log_softmax(p_logit, dim=1), dim=1)) 67 | else: 68 | qlogq = torch.sum(q*F.log_softmax(q_logit, dim=1), dim=1) 69 | qlogp = torch.sum(q*F.log_softmax(p_logit, dim=1), dim=1) 70 | return qlogq - qlogp 71 | 72 | def get_normalized_vector(self, d): 73 | return F.normalize(d.view(d.size(0), -1), p=2, dim=1).reshape(d.size()) 74 | 75 | def virtual_adversarial_loss(self, x, logit, domain, lamda): 76 | r_vadv = self.generate_virtual_adversarial_perturbation( 77 | x, logit, domain, lamda) 78 | logit_p = logit.detach() 79 | logit_m, _, _ = self.model(x + r_vadv, domain, lamda) 80 | loss = self.kl_divergence_with_logit(logit_p, logit_m) 81 | return loss, r_vadv 82 | 83 | 84 | def train(net, class_loss, domain_loss, source_dataloader, 85 | target_dataloader, new_data_loader, source_test_dataloader, target_test_dataloader, optimizer_, cycle, model_root, args, device): 86 | 87 | if args.use_amp: 88 | scaler = torch.cuda.amp.GradScaler() 89 | 90 | fg_optimizer, fc_optimizer, dc_optimizer = optimizer_ 91 | 92 | if new_data_loader is not None: 93 | new_source_dataloader = [new_data_loader, source_dataloader] 94 | else: 95 | new_source_dataloader = [source_dataloader] 96 | 97 | len_dataloader = min(len(source_dataloader), len(target_dataloader)) 98 | log_frequency = int(0.5*len_dataloader) 99 | 100 | cent = ConditionalEntropyLoss().to(device) 101 | vat_loss = VAT(net).to(device) 102 | 103 | device = torch.device(device) 104 | 105 | for epoch in range(args.epochs): 106 | 107 | start_time = time.time() 108 | net.train() 109 | 110 | if args.log_results: 111 | wandb.log({'Epoch': epoch+1}) 112 | 113 | print("Epoch :", epoch+1) 114 | 115 | for batch_idx, (source, target) in enumerate(zip(itertools.chain.from_iterable(new_source_dataloader), target_dataloader)): 116 | # Setup hyperparameters 117 | p = float(batch_idx + epoch * len_dataloader) / \ 118 | (args.epochs * len_dataloader) 119 | 120 | lamda = 2. / (1. + np.exp(-args.gamma * p)) - 1 121 | 122 | # Get data input along with corresponding label 123 | source_input, source_label = source 124 | target_input, target_label = target 125 | 126 | if args.method == 'dann': 127 | fc_optimizer = optimizer_scheduler(fc_optimizer, p) 128 | fg_optimizer = optimizer_scheduler(fg_optimizer, p) 129 | dc_optimizer = optimizer_scheduler(dc_optimizer, p) 130 | 131 | source_input, source_label = source_input.type(torch.FloatTensor).to( 132 | device), source_label.type(torch.LongTensor).to(device) 133 | target_input, target_label = target_input.type(torch.FloatTensor).to( 134 | device), target_label.type(torch.LongTensor).to(device) 135 | 136 | fg_optimizer.zero_grad() 137 | fc_optimizer.zero_grad() 138 | dc_optimizer.zero_grad() 139 | 140 | domain_source_labels = torch.ones( 141 | source_input.shape[0], device=device, dtype=torch.long) 142 | domain_target_labels = torch.zeros( 143 | target_input.size()[0], device=device, dtype=torch.long) 144 | 145 | domain_target_labels_new = torch.ones( 146 | target_input.size()[0], dtype=torch.long, device=device) 147 | domain_source_labels_new = torch.zeros( 148 | source_input.size()[0], device=device, dtype=torch.long) 149 | 150 | domain_of_batch = 'source' 151 | 152 | if new_data_loader is not None: 153 | if batch_idx < len(new_data_loader): 154 | domain_of_batch = 'target' 155 | else: 156 | domain_of_batch = 'source' 157 | 158 | if args.method == 'vaada': 159 | # Method is VAADA 160 | with torch.cuda.amp.autocast(args.use_amp): 161 | source_class_output, source_domain_output, source_features = net( 162 | source_input, domain_of_batch, lamda) 163 | source_class_loss = class_loss( 164 | source_class_output, source_label) 165 | 166 | target_class_output, target_domain_output, target_features = net( 167 | target_input, 'target', lamda) 168 | loss_target_cent = cent(target_class_output) 169 | 170 | source_domain_loss = domain_loss( 171 | source_domain_output, domain_source_labels) 172 | target_domain_loss = domain_loss( 173 | target_domain_output, domain_target_labels) 174 | 175 | source_domain_output_d = net.domain_classifier( 176 | source_features.detach()) 177 | target_domain_output_d = net.domain_classifier( 178 | target_features.detach()) 179 | 180 | domain_loss_total = (domain_loss(source_domain_output_d, domain_source_labels_new) + 181 | domain_loss(target_domain_output_d, domain_target_labels_new)) 182 | source_loss_vat, _ = vat_loss( 183 | source_input, source_class_output, domain_of_batch, lamda) 184 | target_loss_vat, _ = vat_loss( 185 | target_input, target_class_output, 'target', lamda) 186 | 187 | dc_optimizer.zero_grad() 188 | scaler.scale(domain_loss_total).backward() 189 | 190 | scaler.unscale_(dc_optimizer) 191 | torch.nn.utils.clip_grad_norm_( 192 | net.domain_classifier.parameters(), args.clip_value) 193 | scaler.step(dc_optimizer) 194 | 195 | total_domain_loss = 1*(source_domain_loss + target_domain_loss) 196 | 197 | loss = source_class_loss + 0.01*total_domain_loss + 0.01 * \ 198 | loss_target_cent + source_loss_vat + 0.01*target_loss_vat 199 | 200 | fg_optimizer.zero_grad() 201 | fc_optimizer.zero_grad() 202 | scaler.scale(loss).backward() 203 | 204 | scaler.unscale_(fg_optimizer) 205 | scaler.unscale_(fc_optimizer) 206 | 207 | torch.nn.utils.clip_grad_norm_( 208 | net.feature_extractor.parameters(), args.clip_value) 209 | torch.nn.utils.clip_grad_norm_( 210 | net.feature_classifier.parameters(), args.clip_value) 211 | 212 | scaler.step(fg_optimizer) 213 | scaler.step(fc_optimizer) 214 | 215 | else: 216 | #Method is DANN 217 | with torch.cuda.amp.autocast(args.use_amp): 218 | source_class_output, source_domain_output, _ = net( 219 | source_input, domain_of_batch, lamda) 220 | source_class_loss = class_loss( 221 | source_class_output, source_label) 222 | 223 | target_features, target_domain_output, _ = net( 224 | target_input, 'target', lamda) 225 | 226 | source_domain_loss = domain_loss( 227 | source_domain_output, domain_source_labels) 228 | target_domain_loss = domain_loss( 229 | target_domain_output, domain_target_labels) 230 | 231 | loss = source_class_loss + source_domain_loss + target_domain_loss 232 | 233 | fg_optimizer.zero_grad() 234 | fc_optimizer.zero_grad() 235 | dc_optimizer.zero_grad() 236 | 237 | scaler.scale(loss).backward() 238 | 239 | scaler.unscale_(fg_optimizer) 240 | scaler.unscale_(fc_optimizer) 241 | scaler.unscale_(dc_optimizer) 242 | 243 | torch.nn.utils.clip_grad_norm_( 244 | net.feature_extractor.parameters(), args.clip_value) 245 | torch.nn.utils.clip_grad_norm_( 246 | net.feature_classifier.parameters(), args.clip_value) 247 | torch.nn.utils.clip_grad_norm_( 248 | net.domain_classifier.parameters(), args.clip_value) 249 | 250 | scaler.step(fg_optimizer) 251 | scaler.step(fc_optimizer) 252 | scaler.step(dc_optimizer) 253 | 254 | if args.use_amp: 255 | scaler.update() 256 | 257 | if(batch_idx % args.log_interval == 0) and args.log_results: 258 | wandb.log({'Source Class Loss': source_class_loss, 259 | "Source Domain Loss": source_domain_loss, 260 | "Target Domain Loss": target_domain_loss}) 261 | 262 | if (batch_idx + 1) % log_frequency == 0: 263 | print('[{}/{} ({:.0f}%)]Source Class Loss: {:.6f}\tSource Domain Loss: {:.6f}\tTarget Domain Loss: {:.6f}'.format( 264 | batch_idx * 265 | args.batch_size, (len_dataloader*args.batch_size), 266 | 100. * batch_idx / len_dataloader, 267 | source_class_loss.item(), 268 | source_domain_loss.item(), target_domain_loss.item() 269 | )) 270 | 271 | avg_target_class_loss = 0 272 | if new_data_loader is not None: 273 | net.eval() 274 | len_new_data_loader = len(new_data_loader) 275 | with torch.no_grad(): 276 | for batch_idx, target in enumerate(new_data_loader): 277 | p = float(batch_idx) / (len_new_data_loader) 278 | lamda = 0 # 2. / (1. + np.exp(-args.gamma * p)) - 1 279 | 280 | target_input, target_label = target 281 | target_input, target_label = target_input.type(torch.FloatTensor).to( 282 | device), target_label.type(torch.LongTensor).to(device) 283 | target_class_output, _, _ = net( 284 | target_input, 'target', lamda) 285 | 286 | target_class_loss = class_loss( 287 | target_class_output, target_label) 288 | 289 | avg_target_class_loss += target_class_loss 290 | print("Target Class Loss : ", 291 | avg_target_class_loss.item()/len(new_data_loader)) 292 | 293 | end_time = time.time() 294 | print('Time taken :', end_time-start_time) 295 | if (epoch+1) % 10 == 0 or args.epochs-epoch <= 50: 296 | cur_state = torch.get_rng_state() 297 | train_test.test(net, source_test_dataloader, target_test_dataloader, 298 | epoch, cycle, args, device, epoch == (args.epochs-1)) 299 | torch.set_rng_state(cur_state) 300 | -------------------------------------------------------------------------------- /active_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import train_test 4 | 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | from copy import deepcopy 11 | 12 | from sklearn.cluster import KMeans 13 | from scipy.spatial import distance 14 | from sklearn.metrics.pairwise import euclidean_distances 15 | from sklearn.metrics import pairwise_distances 16 | from scipy import stats 17 | 18 | 19 | def get_active_learning_method(net, unlabeled_loader, device, args, source_dataloader, cycle, new_dataloader): 20 | 21 | if args.sampling == 'random': 22 | idx = random_sampling(net, unlabeled_loader, device, args) 23 | elif args.sampling == 'aada': 24 | idx = aada(net, unlabeled_loader, device, args) 25 | elif args.sampling == 'bvsb': 26 | idx = BvSB(net, unlabeled_loader, device, args) 27 | elif args.sampling == 'coreset': 28 | idx = coreset_sampling(net, unlabeled_loader, device, args) 29 | elif args.sampling == 'badge': 30 | idx = badge_sampling(net, unlabeled_loader, device, args) 31 | elif args.sampling == 's3vaada': 32 | idx = s3vaada(net, unlabeled_loader, device, args, 33 | cycle, source_dataloader, new_dataloader) 34 | else: 35 | raise NotImplementedError() 36 | return idx 37 | 38 | def H(x): 39 | return -1*torch.sum(torch.exp(x) * x, dim=1) 40 | 41 | 42 | def aada(models, unlabeled_loader, device, args): 43 | 44 | models.eval() 45 | uncertainty = torch.tensor([]).to(device) 46 | 47 | with torch.no_grad(): 48 | for batch_idx, (inputs, _) in enumerate(unlabeled_loader): 49 | p = float(batch_idx) / len(unlabeled_loader) 50 | lamda = 0 51 | inputs = inputs.to(device) 52 | scores, domain_scores, _ = models(inputs, 'target', lamda) 53 | 54 | pt = torch.exp(domain_scores[:, 1]) / \ 55 | torch.exp(domain_scores).sum(dim=1) 56 | ps = torch.exp(domain_scores[:, 0]) / \ 57 | torch.exp(domain_scores).sum(dim=1) 58 | weight = (1 - pt) / pt # Importance weight 59 | weight = pt / ps 60 | # Add the entropy term to the weight 61 | weight = weight * H(scores) 62 | weight = weight.view(weight.size(0)) 63 | uncertainty = torch.cat((uncertainty, weight), 0) 64 | 65 | uncertainty = torch.argsort(uncertainty, descending=True) 66 | idx = uncertainty.narrow(0, 0, args.budget) 67 | idx = idx.cpu() 68 | return idx 69 | 70 | 71 | def random_sampling(net, unlabeled_loader, device, args): 72 | number_of_unlabeled_samples = len(unlabeled_loader.dataset) 73 | idx = torch.from_numpy(np.random.choice( 74 | number_of_unlabeled_samples, args.budget, replace=False)) 75 | return idx 76 | 77 | 78 | def BvSB(net, unlabeled_loader, device, args): 79 | net.eval() 80 | diff_top2 = torch.tensor([]) 81 | lamda = 0 82 | with torch.no_grad(): 83 | for batch_idx, (inputs, _) in enumerate(unlabeled_loader): 84 | p = float(batch_idx) / len(unlabeled_loader) 85 | inputs = inputs.to(device) 86 | target_class_pred, _, _ = net(inputs, 'target', lamda) 87 | class_prob = F.softmax(target_class_pred, dim=1) 88 | top2_prob, top2_pred = torch.topk(class_prob, 2) 89 | for i in top2_prob: 90 | diff_top2 = torch.cat( 91 | (diff_top2, torch.tensor([(i[0]-i[1])])), 0) 92 | ranked = torch.argsort(diff_top2, descending=False) 93 | print(ranked) 94 | idx = ranked.narrow(0, 0, args.budget) 95 | return idx 96 | 97 | 98 | class VAT(nn.Module): 99 | def __init__(self, model, reduction='mean'): 100 | super(VAT, self).__init__() 101 | self.n_power = 1 102 | self.XI = 1e-6 103 | self.model = model 104 | self.epsilon = 5.0 105 | 106 | def forward(self, X, logit, domain, lamda): 107 | vat_loss, r_vadv = self.virtual_adversarial_loss( 108 | X, logit, domain, lamda) 109 | return vat_loss, r_vadv 110 | 111 | def generate_virtual_adversarial_perturbation(self, x, logit, domain, lamda, random=None): 112 | if random is None: 113 | d = torch.randn_like(x, device='cuda') 114 | else: 115 | d = random 116 | lamda = 0 117 | for _ in range(self.n_power): 118 | d = self.XI * self.get_normalized_vector(d).requires_grad_() 119 | logit_m, _, _ = self.model(x + d, domain, lamda) 120 | dist = self.kl_divergence_with_logit(logit, logit_m) 121 | grad = torch.autograd.grad(dist, [d])[0] 122 | d = grad.detach() 123 | 124 | return self.epsilon * self.get_normalized_vector(d) 125 | 126 | def kl_divergence_with_logit(self, q_logit, p_logit, reduction="mean"): 127 | q = F.softmax(q_logit, dim=1) 128 | if reduction == 'mean': 129 | qlogq = torch.mean( 130 | torch.sum(q * F.log_softmax(q_logit, dim=1), dim=1)) 131 | qlogp = torch.mean( 132 | torch.sum(q * F.log_softmax(p_logit, dim=1), dim=1)) 133 | else: 134 | qlogq = torch.sum(q*F.log_softmax(q_logit, dim=1), dim=1) 135 | qlogp = torch.sum(q*F.log_softmax(p_logit, dim=1), dim=1) 136 | return qlogq - qlogp 137 | 138 | def get_normalized_vector(self, d): 139 | return F.normalize(d.view(d.size(0), -1), p=2, dim=1).reshape(d.size()) 140 | 141 | def virtual_adversarial_loss(self, x, logit, domain, lamda): 142 | r_vadv = self.generate_virtual_adversarial_perturbation( 143 | x, logit, domain, lamda) 144 | logit_p = logit.detach() 145 | with torch.no_grad(): 146 | logit_m, _, _ = self.model(x + r_vadv, domain, lamda) 147 | loss = self.kl_divergence_with_logit( 148 | logit_p, logit_m, reduction="none") 149 | return loss, (r_vadv, logit_m) 150 | 151 | 152 | # K-Means++ utility function 153 | def init_centers(X, K): 154 | # take the maximum norm one vector as c0 155 | ind = np.argmax([np.linalg.norm(s, 2) for s in X]) 156 | mu = [X[ind]] 157 | indsAll = [ind] 158 | cent = 0 159 | print('#Samps\tTotal Distance') 160 | while len(mu) < K: 161 | if len(mu) == 1: 162 | D2 = pairwise_distances(X, mu).ravel().astype(float) 163 | else: 164 | newD = pairwise_distances(X, [mu[-1]]).ravel().astype(float) 165 | for i in range(len(X)): 166 | if D2[i] > newD[i]: 167 | # centInds[i] = cent 168 | D2[i] = newD[i] 169 | print(str(len(mu)) + '\t' + str(sum(D2)), flush=True) 170 | D2 = D2.ravel().astype(float) 171 | # SAMPLING WITH PMF = D2/sum(D2) 172 | Ddist = (D2 ** 2) / sum(D2 ** 2) 173 | customDist = stats.rv_discrete( 174 | name='custm', values=(np.arange(len(D2)), Ddist)) 175 | ind = customDist.rvs(size=1)[0] 176 | mu.append(X[ind]) 177 | indsAll.append(ind) 178 | return indsAll 179 | 180 | 181 | def get_grad_embedding(model, unlabeled_loader, args): 182 | embDim = 256 183 | model.eval() 184 | nLab = args.num_classes 185 | embedding = np.zeros([len(unlabeled_loader.dataset), embDim * nLab]) 186 | with torch.no_grad(): 187 | for batch_idx, (x, y) in enumerate(unlabeled_loader): 188 | x, y = Variable(x.cuda()), Variable(y.cuda()) 189 | idxs = np.arange(len(x)) + args.batch_size * batch_idx 190 | lamda = 0 191 | cout, _, out = model(x, 'target', lamda) 192 | out = out.data.cpu().numpy() 193 | batchProbs = F.softmax(cout, dim=1).data.cpu().numpy() 194 | maxInds = np.argmax(batchProbs, 1) 195 | for j in range(len(y)): 196 | for c in range(nLab): 197 | if c == maxInds[j]: 198 | embedding[idxs[j]][embDim * c: embDim * 199 | (c+1)] = deepcopy(out[j]) * (1 - batchProbs[j][c]) 200 | else: 201 | embedding[idxs[j]][embDim * c: embDim * 202 | (c+1)] = deepcopy(out[j]) * (-1 * batchProbs[j][c]) 203 | return torch.Tensor(embedding) 204 | 205 | 206 | def badge_sampling(net, unlabeled_loader, device, args): 207 | net.eval() 208 | idxs_unlabeled = np.arange(len(unlabeled_loader.dataset)) 209 | gradEmbedding = get_grad_embedding( 210 | net, unlabeled_loader, args).cpu().numpy() 211 | print("Grad embedding shape = ", gradEmbedding.shape) 212 | chosen = init_centers(gradEmbedding, args.budget) 213 | print("chosen = ", chosen) 214 | idxs = idxs_unlabeled[chosen] 215 | print("idxs = ", idxs) 216 | return torch.from_numpy(idxs) 217 | 218 | 219 | def coreset_sampling(net, unlabeled_loader, device, args): 220 | print('Core-Set Sampling') 221 | net.eval() 222 | lamda = 0 223 | embedding = torch.tensor([]).to(device) 224 | with torch.no_grad(): 225 | for batch_idx, (inputs, labels) in enumerate(unlabeled_loader): 226 | p = float(batch_idx) / len(unlabeled_loader) 227 | inputs = inputs.to(device) 228 | feature = net.feature_extractor(inputs, 'target', lamda) 229 | embedding = torch.cat((embedding, feature), 0) 230 | 231 | embedding = embedding.cpu().numpy() 232 | number_of_unlabeled_samples = len(unlabeled_loader.dataset) 233 | 234 | dist_mat = np.matmul(embedding, embedding.transpose()) 235 | 236 | sq = np.array(dist_mat.diagonal()).reshape(number_of_unlabeled_samples, 1) 237 | dist_mat *= -2 238 | dist_mat += sq 239 | dist_mat += sq.transpose() 240 | dist_mat = np.sqrt(dist_mat) 241 | 242 | NUM_INIT_LB = 10 243 | idxs_lb = np.zeros(number_of_unlabeled_samples, dtype=bool) 244 | idxs_tmp = np.arange(number_of_unlabeled_samples) 245 | np.random.shuffle(idxs_tmp) 246 | idxs_lb[idxs_tmp[:NUM_INIT_LB]] = True 247 | 248 | lb_flag = idxs_lb.copy() 249 | mat = dist_mat[~lb_flag, :][:, lb_flag] 250 | 251 | for i in range(args.budget): 252 | mat_min = mat.min(axis=1) 253 | q_idx_ = mat_min.argmax() 254 | q_idx = np.arange(number_of_unlabeled_samples)[~lb_flag][q_idx_] 255 | lb_flag[q_idx] = True 256 | mat = np.delete(mat, q_idx_, 0) 257 | mat = np.append(mat, dist_mat[~lb_flag, q_idx][:, None], axis=1) 258 | 259 | opt = mat.min(axis=1).max() 260 | 261 | bound_u = opt 262 | bound_l = opt/2.0 263 | delta = opt 264 | 265 | xx, yy = np.where(dist_mat <= opt) 266 | dd = dist_mat[xx, yy] 267 | 268 | lb_flag_ = idxs_lb.copy() 269 | subset = np.where(lb_flag_ == True)[0].tolist() 270 | 271 | SEED = 5 272 | 273 | import pickle 274 | #pickle.dump((xx.tolist(), yy.tolist(), dd.tolist(), subset, float(opt), args.budget, number_of_unlabeled_samples), open('mip{}.pkl'.format(SEED), 'wb'), 2) 275 | 276 | # import ipdb 277 | # ipdb.set_trace() 278 | # solving MIP 279 | # download Gurobi software from http://www.gurobi.com/ 280 | # sh {GUROBI_HOME}/linux64/bin/gurobi.sh < core_set_sovle_solve.py 281 | 282 | #import os 283 | #os.system('sh ./gurobi902/linux64/bin/gurobi.sh < core_set_sovle_solve.py') 284 | 285 | #sols = pickle.load(open('sols{}.pkl'.format(SEED), 'rb')) 286 | sols = None 287 | 288 | if sols is None: 289 | q_idxs = lb_flag 290 | else: 291 | lb_flag_[sols] = True 292 | q_idxs = lb_flag_ 293 | print('sum q_idxs = {}'.format(q_idxs.sum())) 294 | 295 | return torch.from_numpy(np.arange(number_of_unlabeled_samples)[(idxs_lb ^ q_idxs)]) 296 | 297 | 298 | def get_vat(net, unlabeled_loader, device, args): 299 | vat_loss = VAT(net, reduction='mean').to(device) 300 | net.eval() 301 | vat_loss_all = torch.tensor([]).to(device) 302 | restarts = 5 303 | lamda = 0 304 | 305 | for batch_idx, target in enumerate(unlabeled_loader): 306 | target_input, target_label = target 307 | target_input, target_label = target_input.type(torch.FloatTensor).to( 308 | device), target_label.type(torch.LongTensor).to(device) 309 | target_class_output, _, _ = net(target_input, 'target', lamda) 310 | logits = target_class_output 311 | 312 | vat_loss_restarts = None 313 | logit_batch = None 314 | for i in range(restarts): 315 | target_loss_vat, (r_vadv, logit) = vat_loss( 316 | target_input, logits, 'target', lamda) 317 | if vat_loss_restarts is None: 318 | vat_loss_restarts = target_loss_vat.unsqueeze(0) 319 | logit_batch = logit.unsqueeze(0) 320 | else: 321 | vat_loss_restarts = torch.cat( 322 | (vat_loss_restarts, target_loss_vat.unsqueeze(0)), 0) 323 | logit_batch = torch.cat((logit_batch, logit.unsqueeze(0)), 0) 324 | kl_avg = vat_loss_restarts 325 | for i in range(restarts): 326 | for j in range(restarts): 327 | if i != j: 328 | x = vat_loss.kl_divergence_with_logit( 329 | logit_batch[i], logit_batch[j], reduction="none") 330 | if kl_avg is None: 331 | kl_avg = x.unsqueeze(0) 332 | else: 333 | kl_avg = torch.cat((kl_avg, x.unsqueeze(0)), 0) 334 | output = torch.sum(kl_avg, dim=0) 335 | vat_loss_all = torch.cat((vat_loss_all, output), 0) 336 | return vat_loss_all.cpu().numpy()/(restarts**2) 337 | 338 | 339 | def get_embedding(net, unlabeled_loader, device, args): 340 | net.eval() 341 | lamda = 0 342 | embedding = torch.tensor([]).to(device) 343 | with torch.no_grad(): 344 | for _, (inputs, _) in enumerate(unlabeled_loader): 345 | inputs = inputs.to(device) 346 | feature = net.feature_extractor(inputs, 'target', lamda) 347 | embedding = torch.cat((embedding, feature), 0) 348 | return embedding.cpu().numpy() 349 | 350 | 351 | def get_softmax_output(net, unlabeled_loader, device, args): 352 | net.eval() 353 | softmax_output = torch.tensor([]).to(device) 354 | lamda = 0 355 | with torch.no_grad(): 356 | for batch_idx, (inputs, labels) in enumerate(unlabeled_loader): 357 | p = float(batch_idx) / len(unlabeled_loader) 358 | inputs = inputs.to(device) 359 | target_class_pred, _, feat = net(inputs, 'target', lamda) 360 | target_class_pred = F.softmax(target_class_pred, dim=1) 361 | softmax_output = torch.cat((softmax_output, target_class_pred), 0) 362 | print('softmax_output shape = ', softmax_output.shape) 363 | return softmax_output.cpu().numpy() 364 | 365 | 366 | def Gain(vat, kl, similarity, A, S, alpha=1., beta=1.): 367 | vat = vat[A] 368 | 369 | kl_score = kl[S, :][:, A].T # + kl[A,:][:,S] 370 | if kl_score.shape[1] == 0: 371 | kl_score = np.zeros(kl_score.shape[0]) 372 | else: 373 | kl_score = kl_score.min(axis=1) 374 | 375 | sim_score_of_all_with_selected = similarity[:, S] 376 | if sim_score_of_all_with_selected.shape[1] == 0: 377 | sim_score = np.zeros(sim_score_of_all_with_selected.shape[0]) 378 | else: 379 | sim_score_of_all_with_selected = similarity[:, S].max( 380 | axis=1).reshape(-1, 1) 381 | sim_score_of_all_with_not_selected = similarity[:, A] 382 | sim_score_of_all = sim_score_of_all_with_not_selected - \ 383 | sim_score_of_all_with_selected 384 | sim_score_of_all[sim_score_of_all < 0] = 0 385 | sim_score = sim_score_of_all.mean(axis=0) 386 | 387 | # Combining the three scores 388 | score = alpha*vat + beta*kl_score + (1-alpha-beta)*sim_score 389 | 390 | selected = score.argmax() 391 | print("Convex comb: VAP, KL = ", 392 | vat[selected], kl_score[selected], sim_score[selected]) 393 | 394 | return score 395 | 396 | 397 | def pairwise_kl_gpu(A, B): 398 | A1 = A[:, None, :] 399 | A2 = B[None, :, :] 400 | div = A1/A2 401 | log = torch.log(div) 402 | log = A1*log 403 | s = torch.sum(log, axis=-1) 404 | return s 405 | 406 | 407 | def pairwise_bc_similarity_gpu(A, B): 408 | A1 = A[:, None, :] 409 | A2 = B[None, :, :] 410 | mul = A1*A2 411 | mul = torch.sqrt(mul) 412 | s = torch.sum(mul, axis=-1) 413 | s = -torch.log(1 - s + 1e-6) 414 | return s 415 | 416 | 417 | def s3vaada(net, unlabeled_loader, device, args, cycle, source_dataloader, new_dataloader): 418 | print("S3VAADA Sampling") 419 | 420 | print("alpha = ", args.alpha) 421 | print("beta = ", args.beta) 422 | 423 | vat = get_vat(net, unlabeled_loader, device, args) 424 | vat = (vat - vat.min())/(vat.max() - vat.min()) 425 | softmax_output = get_softmax_output(net, unlabeled_loader, device, args) 426 | 427 | softmax_output = torch.Tensor(softmax_output).to(device) 428 | D = np.zeros((softmax_output.shape[0], softmax_output.shape[0])) 429 | b = 1000 430 | for i in range(0, softmax_output.shape[0], b): 431 | s1 = i 432 | e1 = min(i+b, softmax_output.shape[0]) 433 | for j in range(0, softmax_output.shape[0], b): 434 | s2 = j 435 | e2 = min(j+b, softmax_output.shape[0]) 436 | D[s1:e1, s2:e2] = pairwise_kl_gpu( 437 | A=softmax_output[s1:e1], B=softmax_output[s2:e2]).cpu().numpy() 438 | dists = D 439 | dists = (dists - dists.min())/(dists.max() - dists.min()) 440 | 441 | similarity = np.zeros((softmax_output.shape[0], softmax_output.shape[0])) 442 | b = 1000 443 | for i in range(0, softmax_output.shape[0], b): 444 | s1 = i 445 | e1 = min(i+b, softmax_output.shape[0]) 446 | for j in range(0, softmax_output.shape[0], b): 447 | s2 = j 448 | e2 = min(j+b, softmax_output.shape[0]) 449 | similarity[s1:e1, s2:e2] = pairwise_bc_similarity_gpu( 450 | A=softmax_output[s1:e1], B=softmax_output[s2:e2]).cpu().numpy() 451 | similarity = (similarity - similarity.min()) / \ 452 | (similarity.max() - similarity.min()) 453 | 454 | number_of_unlabeled_samples = len(unlabeled_loader.dataset) 455 | S = [] 456 | 457 | for i in range(args.budget): 458 | A = [j for j in range(number_of_unlabeled_samples) if j not in S] 459 | G = Gain(vat, dists, similarity, A, S, 460 | alpha=args.alpha, beta=args.beta) 461 | S.append(A[G.argmax()]) 462 | 463 | print(S) 464 | return torch.from_numpy(np.array(S)) 465 | --------------------------------------------------------------------------------